
PyTorch Practical Hand-Written Modules Basics 3--Tensor-wise operations

In this section, we will go through some PyTorch functions that operate on tensors.

torch.Tensor.expand

Signature: Tensor.expand(*sizes) -> Tensor

The expand function returns a new view of the self tensor with singleton dimensions expanded to a larger size. The arguments give the target size. (“Singleton dimensions” are dimensions of size 1.)

Basic Usage

Passing -1 as the size for a dimension means not changing the size of that dimension.

import torch

x = torch.tensor([[1], [2], [3]]) # torch.Size([3, 1])
print(x)
print(x.expand(3, 4)) # torch.Size([3, 4])
print(x.expand(-1, 4)) # torch.Size([3, 4])
print(x.expand(3, -1)) # torch.Size([3, 1])
print(x.expand(-1, -1)) # torch.Size([3, 1])
# OUTPUT
tensor([[1],
        [2],
        [3]])
tensor([[1, 1, 1, 1],
        [2, 2, 2, 2],
        [3, 3, 3, 3]])
tensor([[1, 1, 1, 1],
        [2, 2, 2, 2],
        [3, 3, 3, 3]])
tensor([[1],
        [2],
        [3]])
tensor([[1],
        [2],
        [3]])
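You can also pass expand() more sizes than the tensor has dimensions. The PyTorch documentation says the new dimensions are then added at the front, and -1 is not allowed for them. Here is a small sketch of that behavior:

```python
import torch

x = torch.tensor([[1], [2], [3]])  # torch.Size([3, 1])

# Three sizes for a 2-D tensor: the extra size (2) becomes a new leading
# dimension, and the existing dimensions match the trailing sizes (3, 4).
a = x.expand(2, 3, 4)
print(a.shape)  # torch.Size([2, 3, 4])

# -1 keeps an existing dimension unchanged; it cannot be used for the new one.
b = x.expand(2, -1, 4)
print(b.shape)  # torch.Size([2, 3, 4])

# Every slice along the new dimension holds the same data.
print(torch.equal(a[0], a[1]))  # True
```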

Wrong Usage

Only dimensions of size 1 can be expanded:

import torch

x = torch.tensor([[1], [2], [3]]) # torch.Size([3, 1])

print(x.expand(2, 2)) # ERROR! dimension 0 has size 3 (not 1), so it can't be expanded
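This is a hard error, not a silent broadcast, so you can catch it. The sketch below assumes PyTorch raises a RuntimeError here, which is what recent versions do. When a non-singleton dimension really has to grow, it uses repeat() (covered below), which copies the data:

```python
import torch

x = torch.tensor([[1], [2], [3]])  # torch.Size([3, 1])

try:
    x.expand(2, 2)  # dimension 0 has size 3, not 1, so it cannot be expanded
    raised = False
except RuntimeError as err:
    raised = True
    print(err)  # PyTorch's message about the non-singleton dimension

# To enlarge a non-singleton dimension, the data has to be copied instead
print(x.repeat(2, 2).shape)  # torch.Size([6, 2])
```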

Why use it?

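The result is only a view, not a new tensor: no data is copied. Therefore, if you only need to read (not write) an expanded tensor, using expand() can save a lot of GPU memory. The expanded tensor shares memory with the original, so modifying the expanded tensor also modifies the original. Several elements of an expanded tensor can refer to the same memory location, so the PyTorch documentation recommends cloning the tensor first if you need to write to it.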

import torch

x = torch.tensor([[1], [2], [3]]) # torch.Size([3, 1])
x.expand(3, 4)[0, 1] = 100 # element [0, 1] of the view is stored at x[0, 0]
print(x)
# OUTPUT
tensor([[100],
        [  2],
        [  3]])
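One way to see why expand() is cheap is to look at the strides. The expanded dimension gets stride 0, so every position along it reads the same memory. The sketch below also follows the documentation's advice to clone() before writing:

```python
import torch

x = torch.tensor([[1], [2], [3]])  # torch.Size([3, 1])
e = x.expand(3, 4)

# The expanded dimension has stride 0: all four "columns" read the same element.
print(x.stride(), e.stride())  # (1, 1) (1, 0)
# Both tensors start at the same memory address.
print(e.data_ptr() == x.data_ptr())  # True

# clone() materializes real memory, so writes no longer reach x.
safe = e.clone()
safe[0, 1] = 100
print(x[0, 0].item())  # 1, the original is unchanged
```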

torch.Tensor.repeat

Signature: Tensor.repeat(*sizes) -> Tensor

Repeats this tensor along the specified dimensions. It is somewhat similar to torch.Tensor.expand(), but the arguments give the number of repetitions along each dimension rather than the target size. Also, unlike expand(), repeat() copies the tensor's data into new memory.

import torch

x = torch.tensor([1, 2, 3]) # torch.Size([3])
print(x.repeat(4, 2)) # torch.Size([4, 6])
# OUTPUT
tensor([[1, 2, 3, 1, 2, 3],
        [1, 2, 3, 1, 2, 3],
        [1, 2, 3, 1, 2, 3],
        [1, 2, 3, 1, 2, 3]])
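To make the difference from expand() concrete, the sketch below builds the same 3x4 values both ways. It then checks that repeat() owns its own memory:

```python
import torch

x = torch.tensor([[1], [2], [3]])  # torch.Size([3, 1])

expanded = x.expand(3, 4)   # view: no new memory
repeated = x.repeat(1, 4)   # new tensor holding copied data

print(torch.equal(expanded, repeated))      # True: same values
print(repeated.data_ptr() == x.data_ptr())  # False: separate memory

repeated[0, 1] = 100
print(x[0, 0].item())  # 1: writing to the copy leaves x unchanged
```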

More sizes than the tensor has dimensions

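If you pass more sizes than the tensor has dimensions, additional dimensions are added at the front. In the examples below, x has shape 3x1 (two dimensions), but repeat() receives three or four arguments. The tensor is treated as if its shape were left-padded with 1s, and each size is then multiplied by the matching repeat count.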

import torch

x = torch.tensor([[1], [2], [3]]) # torch.Size([3, 1])

print(x.repeat(4, 2, 1).shape)
# torch.Size([4, 6, 1]) new leading dim: 4. last 2 dims: [3,1]*[2,1]=[6,1]

print(x.repeat(4, 2, 1, 1).shape)
# torch.Size([4, 2, 3, 1]) new leading 2 dims: [4,2]. last 2 dims: [3,1]*[1,1]=[3,1]

print(x.repeat(1, 4, 2, 1).shape)
# torch.Size([1, 4, 6, 1]) new leading 2 dims: [1,4]. last 2 dims: [3,1]*[2,1]=[6,1]

print(x.repeat(1, 1, 4, 2).shape)
# torch.Size([1, 1, 12, 2]) new leading 2 dims: [1,1]. last 2 dims: [3,1]*[4,2]=[12,2]
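The shape rule fits in a small function. repeat_shape below is a hypothetical helper, not part of PyTorch. It left-pads the tensor's shape with 1s and multiplies element-wise, and the loop checks its predictions against the real repeat():

```python
import torch

def repeat_shape(shape, repeats):
    """Hypothetical helper: predict the shape returned by Tensor.repeat."""
    if len(repeats) < len(shape):
        # repeat() itself also rejects fewer sizes than tensor dimensions
        raise ValueError("need at least as many repeat sizes as tensor dims")
    # Left-pad the tensor's shape with 1s so both sequences have equal length
    padded = (1,) * (len(repeats) - len(shape)) + tuple(shape)
    # Each output size is the padded input size times its repeat count
    return tuple(s * r for s, r in zip(padded, repeats))

x = torch.tensor([[1], [2], [3]])  # torch.Size([3, 1])
for reps in [(4, 2, 1), (4, 2, 1, 1), (1, 4, 2, 1), (1, 1, 4, 2)]:
    predicted = repeat_shape(x.shape, reps)
    actual = tuple(x.repeat(*reps).shape)
    print(reps, predicted, predicted == actual)  # each line ends with True
```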

torch.Tensor.transpose

Signature: torch.transpose(input, dim0, dim1) -> Tensor

Signature: torch.Tensor.transpose(dim0, dim1) -> Tensor

Returns a tensor that is a transposed version of input. The given dimensions dim0 and dim1 are swapped.

Since the two dimensions are simply swapped, their order does not matter. As the examples below show, x.transpose(0, 1) and x.transpose(1, 0) give the same result.

import torch

x = torch.randn(2, 3)
print(x) # shape: torch.Size([2, 3])
print(x.transpose(0, 1)) # shape: torch.Size([3, 2])
print(x.transpose(1, 0)) # shape: torch.Size([3, 2])
y = torch.randn(2, 3, 4)
print(y) # shape: torch.Size([2, 3, 4])

print(y.transpose(0, 1)) # shape: torch.Size([3, 2, 4])
print(y.transpose(1, 0)) # shape: torch.Size([3, 2, 4])

print(y.transpose(0, 2)) # shape: torch.Size([4, 3, 2])
print(y.transpose(2, 0)) # shape: torch.Size([4, 3, 2])
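Two more properties of transpose() are worth knowing. It returns a view that shares memory with the input, and the result is generally no longer contiguous, which matters for view() (discussed below). A quick check:

```python
import torch

y = torch.randn(2, 3, 4)
t = y.transpose(0, 2)  # torch.Size([4, 3, 2])

# Swapping the argument order gives exactly the same tensor
print(torch.equal(t, y.transpose(2, 0)))  # True
# Element mapping for this swap: t[k, j, i] is y[i, j, k]
print((t[3, 1, 0] == y[0, 1, 3]).item())  # True
# transpose() returns a view: same memory, but no longer contiguous
print(t.data_ptr() == y.data_ptr(), t.is_contiguous())  # True False
```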

torch.Tensor.permute

Signature: torch.Tensor.permute(*dims) -> Tensor

Signature: torch.permute(input, dims) -> Tensor

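This function reorders the dimensions of a tensor. The i-th argument is the index of the original dimension that becomes dimension i of the result. See the example below.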

import torch

y = torch.randn(2, 3, 4) # Shape: torch.Size([2, 3, 4])

print(y.permute(0, 1, 2)) # Shape: torch.Size([2, 3, 4])

print(y.permute(0, 2, 1)) # Shape: torch.Size([2, 4, 3])

Let’s take a closer look at the last call, y.permute(0, 2, 1), as an example.

  • The first argument 0 means that the new tensor’s first dimension is the original dimension 0, so its size is 2.

  • The second argument 2 means that the new tensor’s second dimension is the original dimension 2, so its size is 4.

  • The third argument 1 means that the new tensor’s third dimension is the original dimension 1, so its size is 3.

Finally, the result shape is torch.Size([2, 4, 3]).
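You can check the rule from the bullets directly. The new shape is the old shape indexed by the permutation, and for permute(0, 2, 1), z[i, k, j] equals y[i, j, k]. This permutation only swaps two dimensions, so it also matches transpose(1, 2):

```python
import torch

y = torch.randn(2, 3, 4)
dims = (0, 2, 1)
z = y.permute(*dims)

# New size i is the old size at dims[i]
print(tuple(z.shape) == tuple(y.shape[d] for d in dims))  # True
# Element mapping for this permutation: z[i, k, j] is y[i, j, k]
print((z[1, 3, 2] == y[1, 2, 3]).item())  # True
# Swapping exactly two dimensions is the same as transpose()
print(torch.equal(z, y.transpose(1, 2)))  # True
```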

torch.Tensor.view / torch.Tensor.reshape

Signature: Tensor.view(*shape) -> Tensor

Signature: Tensor.reshape(*shape) -> Tensor

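Both functions return a tensor with the same data and number of elements as the input, but with the specified shape.

The function reshape() returns a view of the original tensor when possible and a copy otherwise. You should not rely on it either sharing or not sharing memory with the input.

The function view() never copies: it always returns a view that shares memory with the original, which saves GPU memory. However, it only works when the new shape is compatible with the original tensor's size and stride; otherwise, it raises an error. The exact conditions are listed in the PyTorch documentation for Tensor.view.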

import torch

x = torch.randn(4, 3)
print(x) # Shape: torch.Size([4, 3])

print(x.reshape(3, 4)) # Shape: torch.Size([3, 4])
print(x.reshape(-1, 4)) # Shape: torch.Size([3, 4])
x = torch.randn(4, 3)
print(x) # Shape: torch.Size([4, 3])

print(x.view(3, 4)) # Shape: torch.Size([3, 4])
print(x.view(-1, 4)) # Shape: torch.Size([3, 4])
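A typical case where view() fails is a transposed tensor, whose memory is no longer laid out row by row. The sketch below assumes the failure surfaces as a RuntimeError. It uses data_ptr() to tell a view (same memory) from a copy:

```python
import torch

x = torch.randn(4, 3)
t = x.transpose(0, 1)  # torch.Size([3, 4]), a non-contiguous view

try:
    t.view(12)  # view() cannot flatten this memory layout
    view_ok = True
except RuntimeError:
    view_ok = False
print(view_ok)  # False

r = t.reshape(12)  # reshape() falls back to copying the data
print(r.shape)  # torch.Size([12])
print(r.data_ptr() == x.data_ptr())  # False: a copy was made
# Making the tensor contiguous first lets view() work as well
print(torch.equal(r, t.contiguous().view(12)))  # True

# On a contiguous tensor, reshape() returns a view instead of a copy
v = x.reshape(3, 4)
print(v.data_ptr() == x.data_ptr())  # True
```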

torch.cat

Signature: torch.cat(tensors, dim=0, out=None) -> Tensor

Concatenates the given sequence of tensors in the given dimension. All tensors must either have the same shape (except in the concatenating dimension) or be empty. For how to determine the dim, please refer to my previous article.

import torch

x = torch.randn(2, 3)
print(x) # Shape: torch.Size([2, 3])

y = torch.randn(2, 3)
print(y) # Shape: torch.Size([2, 3])

z = torch.cat((x, y), dim=0)
print(z) # Shape: torch.Size([4, 3]) [2+2, 3]

z = torch.cat((x, y), dim=1)
print(z) # Shape: torch.Size([2, 6]) [2, 3+3]
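The shape rule means tensors may differ only along the concatenating dimension. In this sketch, w is an illustrative tensor with a single row. Concatenating along dim=0 works, but dim=1 fails because the sizes along dim 0 (2 and 1) differ. The sketch assumes that failure is raised as a RuntimeError:

```python
import torch

x = torch.randn(2, 3)
w = torch.randn(1, 3)  # differs from x only in dim 0

print(torch.cat((x, w), dim=0).shape)  # torch.Size([3, 3])

try:
    torch.cat((x, w), dim=1)  # sizes along dim 0 (2 vs 1) differ, so this fails
    ok = True
except RuntimeError:
    ok = False
print(ok)  # False
```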

torch.stack

Signature: torch.stack(tensors, dim=0, out=None) -> Tensor

Concatenates a sequence of tensors along a new dimension. All tensors must have the same shape. See the example below.

import torch

x = torch.randn(2, 3) # Shape: torch.Size([2, 3])
y = torch.randn(2, 3) # Shape: torch.Size([2, 3])

z = torch.stack((x, y), dim=0)
print(z) # Shape: torch.Size([*2, 2, 3]) The first 2 is the new dimension

z = torch.stack((x, y), dim=1)
print(z) # Shape: torch.Size([2, *2, 3]) The second 2 is the new dimension

z = torch.stack((x, y), dim=2)
print(z) # Shape: torch.Size([2, 3, *2]) The last 2 is the new dimension
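Another way to think about stack(): insert a size-1 dimension at position dim in every input (with unsqueeze()), then concatenate along it. The loop below checks this for all three valid positions:

```python
import torch

x = torch.randn(2, 3)
y = torch.randn(2, 3)

for d in range(3):
    stacked = torch.stack((x, y), dim=d)
    # Insert a size-1 dimension at position d, then concatenate along it
    via_cat = torch.cat((x.unsqueeze(d), y.unsqueeze(d)), dim=d)
    print(d, stacked.shape, torch.equal(stacked, via_cat))  # last value is True
```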

torch.vstack/hstack

torch.vstack(...) stacks the tensors vertically (row-wise). For tensors with two or more dimensions, this is equivalent to torch.cat(..., dim=0).

torch.hstack(...) stacks the tensors horizontally (column-wise). For tensors with two or more dimensions, this is equivalent to torch.cat(..., dim=1).

import torch

x = torch.randn(2, 3) # Shape: torch.Size([2, 3])
y = torch.randn(2, 3) # Shape: torch.Size([2, 3])

assert torch.vstack((x, y)).shape == torch.cat((x, y), dim=0).shape
assert torch.hstack((x, y)).shape == torch.cat((x, y), dim=1).shape
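The equivalences above assume 2-D inputs. For 1-D tensors the two functions behave differently, as the PyTorch docs describe. hstack() concatenates along dim 0, while vstack() first treats each tensor as a single row:

```python
import torch

a = torch.tensor([1, 2, 3])  # 1-D, torch.Size([3])
b = torch.tensor([4, 5, 6])

# hstack on 1-D tensors concatenates along dim 0
print(torch.hstack((a, b)))  # tensor([1, 2, 3, 4, 5, 6])
# vstack treats each 1-D tensor as a row of shape (1, 3)
print(torch.vstack((a, b)).shape)  # torch.Size([2, 3])
```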

torch.split

Signature: torch.split(tensor, split_size_or_sections, dim=0)

  • If split_size_or_sections is an integer, then tensor will be split into equally sized chunks of that size (if the size along dim is not divisible by it, the last chunk will be smaller).
import torch

x = torch.randn(4, 3) # Shape: torch.Size([4, 3])

print(torch.split(x, 2, dim=0)) # 2-item tuple, each Shape: (2, 3)
print(torch.split(x, 1, dim=1)) # 3-item tuple, each Shape: (4, 1)
  • If split_size_or_sections is a list, then tensor will be split into len(split_size_or_sections) chunks with sizes in dim according to split_size_or_sections.
x = torch.randn(4, 3) # Shape: torch.Size([4, 3])

print(torch.split(x, (1, 3), dim=0))
# 2-item tuple, shapes: (1, 3) and (3, 3)

print(torch.split(x, (1, 1, 1), dim=1))
# 3-item tuple, shapes: (4, 1), (4, 1) and (4, 1)
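Two details of split() are easy to miss. First, when the integer does not divide the dimension, only the last chunk is smaller. Second, according to the PyTorch docs, each chunk is a view of the input, so writing to a chunk writes to the original:

```python
import torch

x = torch.randn(4, 3)

parts = torch.split(x, 3, dim=0)  # chunks of 3 rows: the last one is smaller
print([tuple(p.shape) for p in parts])  # [(3, 3), (1, 3)]

# Concatenating the pieces along the same dim restores the original
print(torch.equal(torch.cat(parts, dim=0), x))  # True

# The pieces are views: writing to one changes x
parts[1][0, 0] = 42.0
print(x[3, 0].item())  # 42.0
```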

torch.vsplit/hsplit

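These are the splitting counterparts of torch.vstack and torch.hstack: v means vertically (along dim=0), and h means horizontally (along dim=1). When the second argument is an integer, it gives the number of equal sections (and must divide the size of that dimension), not the chunk size as in torch.split.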

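import torch

x = torch.randn(4, 3) # Shape: torch.Size([4, 3])

# The following pairs are equivalent:
# pair 1: 4 sections of 1 row each
print(torch.vsplit(x, 4))
print(torch.split(x, 1, dim=0))
# pair 2: 3 sections of 1 column each
print(torch.hsplit(x, 3))
print(torch.split(x, 1, dim=1))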
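The integer argument means different things in split() (chunk size) and in vsplit()/hsplit() (number of sections), so the two are easy to confuse. The sketch below compares them on the same 4x3 tensor. It also shows that a list passed to vsplit() is read as split indices, not sizes:

```python
import torch

x = torch.randn(4, 3)

# Integer argument: split() takes a chunk SIZE, vsplit() takes a number of SECTIONS
print(len(torch.split(x, 2, dim=0)))  # 2 chunks of 2 rows
print(len(torch.vsplit(x, 2)))        # 2 sections of 2 rows
print(len(torch.split(x, 1, dim=0)))  # 4 chunks of 1 row
print(len(torch.vsplit(x, 4)))        # 4 sections of 1 row

# A list argument to vsplit() gives split INDICES: rows [0:1], [1:3], [3:4]
print([tuple(p.shape) for p in torch.vsplit(x, [1, 3])])  # [(1, 3), (2, 3), (1, 3)]
```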

torch.flatten

Signature: torch.flatten(input, start_dim=0, end_dim=-1) -> Tensor

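Flattens the dimensions from start_dim to end_dim (both inclusive) into a single dimension. This is especially useful when converting multi-dimensional tensors such as images into linear vectors. In the example below, start_dim=1 keeps the first dimension (e.g., the batch) intact.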

import torch

x = torch.randn(2, 4, 4)
print(x) # Shape: torch.Size([2, 4, 4])

flattened = torch.flatten(x, start_dim=1)
print(flattened) # Shape: torch.Size([2, 16])
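A few more calls on the same shape show how start_dim and end_dim select the range. The sketch also compares flatten() with reshape() and with torch.nn.Flatten, whose start_dim defaults to 1, so the first dimension is kept:

```python
import torch

x = torch.randn(2, 4, 4)

print(torch.flatten(x).shape)                          # torch.Size([32]): all dims
print(torch.flatten(x, start_dim=0, end_dim=1).shape)  # torch.Size([8, 4])
# flatten(x, start_dim=1) gives the same result as reshaping to (2, -1)
print(torch.equal(torch.flatten(x, start_dim=1), x.reshape(2, -1)))  # True
# The nn.Flatten module does the same inside a model
print(torch.nn.Flatten()(x).shape)                     # torch.Size([2, 16])
```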