What is the PyTorch permute() function for?

In PyTorch, the permute() function rearranges the dimensions of a tensor according to a specified order. It returns a view of the original tensor, so no data is copied. This is useful in many deep learning scenarios, such as reordering the dimensions of your input data to match the input format a model expects.

The function takes a sequence of integers as arguments, representing the new order of the dimensions. The length of the sequence should match the number of dimensions in the input tensor.

Here’s an example:

import torch

# Create a tensor of shape (2, 3, 4)
tensor = torch.randn(2, 3, 4)

# Permute the dimensions to get a new tensor of shape (3, 4, 2)
permuted_tensor = tensor.permute(1, 2, 0)

In this example, the original tensor has shape (2, 3, 4), and we use permute() to reorder the dimensions, resulting in a new tensor with shape (3, 4, 2). The first dimension (with size 2) is moved to the last position, the second dimension (with size 3) is moved to the first position, and the third dimension (with size 4) is moved to the second position.
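Two properties of this operation are worth verifying directly (a quick sanity check, not part of the original walkthrough): each element can be found at its permuted index, and because permute() returns a view, writes through one tensor are visible through the other.

```python
import torch

# permute() returns a view: no data is copied, only the strides change
tensor = torch.randn(2, 3, 4)
permuted = tensor.permute(1, 2, 0)

print(permuted.shape)  # torch.Size([3, 4, 2])

# The element at [i, j, k] in the original appears at [j, k, i] in the result
print(bool(permuted[1, 2, 0] == tensor[0, 1, 2]))  # True

# Since permute() returns a view, modifying the original is visible in the result
tensor[0, 0, 0] = 42.0
print(float(permuted[0, 0, 0]))  # 42.0
```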

Why is contiguous() used in conjunction with permute()?

In PyTorch, the contiguous() function returns a tensor with the same data as the input tensor but with a contiguous memory layout, meaning the elements are stored in a single row-major block of memory, allowing for efficient memory access and operations. If the tensor is already contiguous, contiguous() simply returns the original tensor rather than making a copy.

When you perform operations like slicing, transposing, or permuting a tensor, PyTorch returns a view with rearranged strides rather than copying data, so the resulting tensor can have a non-contiguous memory layout. Some operations, most notably view(), require a contiguous tensor and raise an error otherwise, and non-contiguous access patterns can also hurt performance. Calling contiguous() copies the data into a contiguous block, after which such operations work as expected.
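The most common concrete trigger for this pairing is view(): it requires contiguous memory, so flattening a permuted tensor fails until contiguous() is called. A minimal sketch:

```python
import torch

tensor = torch.randn(2, 3, 4)
permuted = tensor.permute(1, 2, 0)  # a non-contiguous view

# view() requires contiguous memory and raises a RuntimeError here
try:
    permuted.view(-1)
except RuntimeError as e:
    print("view() failed:", e)

# After contiguous() copies the data into a single block, view() succeeds
flat = permuted.contiguous().view(-1)
print(flat.shape)  # torch.Size([24])
```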

Here’s an example of using both PyTorch transpose() and contiguous() functions:

import torch

# Create a tensor of shape (2, 3, 4)
tensor = torch.randn(2, 3, 4)

# Transpose the tensor to create a new tensor with a non-contiguous memory layout
transposed_tensor = tensor.transpose(0, 1)

# Check if the transposed tensor is contiguous
print(transposed_tensor.is_contiguous())  # Output: False

# Create a contiguous version of the transposed tensor
contiguous_tensor = transposed_tensor.contiguous()

# Check if the contiguous tensor is contiguous
print(contiguous_tensor.is_contiguous())  # Output: True

In this example, we first create a tensor, and then transpose it to obtain a new tensor with a non-contiguous memory layout. By calling contiguous() on the transposed tensor, we create a new tensor with a contiguous memory layout.
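Under the hood, contiguity is a property of a tensor's strides, which describe how many storage elements to skip when stepping along each dimension. Inspecting stride() (a detail the example above does not show, but part of the public API) makes it clear why the transposed tensor is non-contiguous:

```python
import torch

tensor = torch.randn(2, 3, 4)
print(tensor.stride())  # (12, 4, 1): row-major, strides decrease left to right

transposed = tensor.transpose(0, 1)  # shape (3, 2, 4)
print(transposed.stride())  # (4, 12, 1): same storage, strides out of order

contiguous_tensor = transposed.contiguous()
print(contiguous_tensor.stride())  # (8, 4, 1): fresh row-major layout
```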

Permuting and reshaping a 3D tensor into a 2D tensor

import torch

# Create a tensor of shape (2, 3, 4)
tensor = torch.randn(2, 3, 4)

# Permute to shape (2, 4, 3), then reshape to (2, 12)
permuted_tensor = tensor.permute(0, 2, 1).reshape(2, -1)

print(permuted_tensor.shape)  # Output: torch.Size([2, 12])

In this example, we create a tensor of shape (2, 3, 4) and use permute() to reorder the dimensions to get a new tensor of shape (2, 4, 3). We then use the reshape() function to convert the tensor to a 2D tensor of shape (2, 12).
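reshape() is used here rather than view() because reshape() silently copies the data when the input is non-contiguous, whereas view() would raise a RuntimeError on the permuted tensor. An equivalent, more explicit spelling using contiguous():

```python
import torch

tensor = torch.randn(2, 3, 4)
permuted = tensor.permute(0, 2, 1)  # shape (2, 4, 3), non-contiguous

# reshape() copies when necessary, so it works on the non-contiguous view
flattened = permuted.reshape(2, -1)
print(flattened.shape)  # torch.Size([2, 12])

# view() needs an explicit contiguous() call to achieve the same result
flattened_view = permuted.contiguous().view(2, -1)
print(flattened_view.shape)  # torch.Size([2, 12])
```

Which spelling to prefer is largely a style choice; reshape() is more convenient, while contiguous().view() makes the potential copy explicit.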

Permuting a 4D tensor to a 3D tensor

import torch

# Create a tensor of shape (2, 3, 4, 5)
tensor = torch.randn(2, 3, 4, 5)

# Permute the dimensions to get a new tensor of shape (2, 4, 5, 3)
permuted_tensor = tensor.permute(0, 2, 3, 1)

# Select the first element along the first dimension
first_element = permuted_tensor[0]

print(first_element.shape)  # Output: torch.Size([4, 5, 3])

In this example, we create a tensor of shape (2, 3, 4, 5) and use permute() to reorder the dimensions to get a new tensor of shape (2, 4, 5, 3). We then select the first element along the first dimension to obtain a tensor of shape (4, 5, 3).
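A common real-world use of this pattern (an illustration, not part of the original example) is converting a batch of images from PyTorch's NCHW layout to the NHWC layout expected by many image libraries:

```python
import torch

# Hypothetical batch of 8 RGB images in NCHW order (batch, channels, height, width)
images_nchw = torch.randn(8, 3, 32, 32)

# Move the channel dimension last to get NHWC (batch, height, width, channels)
images_nhwc = images_nchw.permute(0, 2, 3, 1)
print(images_nhwc.shape)  # torch.Size([8, 32, 32, 3])
```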

Creating a contiguous tensor from a sliced tensor

import torch

# Create a tensor of shape (2, 3, 4)
tensor = torch.randn(2, 3, 4)

# Slice the tensor to create a new tensor with a non-contiguous memory layout
sliced_tensor = tensor[:, :, 1:3]

# Check if the sliced tensor is contiguous
print(sliced_tensor.is_contiguous())  # Output: False

# Create a contiguous version of the sliced tensor
contiguous_tensor = sliced_tensor.contiguous()

# Check if the contiguous tensor is contiguous
print(contiguous_tensor.is_contiguous())  # Output: True

In this example, we create a tensor of shape (2, 3, 4), slice it to obtain a new tensor with a non-contiguous memory layout, and then use contiguous() to create a new tensor with a contiguous memory layout.
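Not every slice is non-contiguous, though. As a rough rule (easy to verify with is_contiguous()), slicing only the leading dimensions preserves contiguity, while slicing inner dimensions or using a step does not:

```python
import torch

tensor = torch.randn(2, 3, 4)

# Taking a leading sub-block keeps the elements adjacent in memory
print(tensor[1].is_contiguous())          # True

# Cutting an inner dimension leaves gaps between rows in storage
print(tensor[:, 1:3].is_contiguous())     # False

# A step slice also skips over elements in storage
print(tensor[:, :, ::2].is_contiguous())  # False
```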

Here are some relevant links to documentation for the PyTorch permute() and contiguous() functions:

  1. torch.permute() documentation: the official reference for the permute() function, covering its parameters and example usage.
  2. torchvision.ops.Permute documentation: the reference for the Permute module in torchvision, which wraps the same operation as an nn.Module for use inside model definitions.
  3. torch.Tensor.contiguous() documentation: the official reference for the contiguous() method, covering its parameters and example usage.

These resources should provide a good starting point for learning more about the permute() and contiguous() functions in PyTorch.

