
PyTorch Practical Hand-Written Modules Basics 2--Tensor basics (Tensor functions, "axis" and indexing)

In this section, we briefly cover the arithmetic functions in PyTorch. Then we look in detail at the axis parameter (called dim in PyTorch) that most of these functions accept.

Finally, we talk about tensor indexing, which is one of the trickier parts of manipulating tensors.

Tensor functions

PyTorch supports many arithmetic functions for tensors. They are vectorized and behave very much like their numpy counterparts. (So if you are not familiar with numpy, learn it first.) In the following, I’ll introduce some functions with the official docs.

Key: What is the “dim” parameter?

Reduction functions such as argmax take a parameter called dim. What does it mean?

  • The default value of dim is None, which means argmax is computed over all entries: the tensor is treated as flattened, and the result is a single index into it.

  • On the other hand, if we specify dim, argmax is applied to each vector along that specific “axis”. All of the examples below use a 4x3x4 3-D tensor.

import torch

# create a 4x3x4 tensor
a = torch.randn(4, 3, 4)
  1. In the first case, we do:
a1 = torch.argmax(a, dim=0)
a1.shape
# OUTPUT
torch.Size([3, 4])

See the gif below. If we set dim=0, argmax is applied to each yellow vector (the vectors that run along dim 0). The original tensor has shape 4x3x4; reducing over dim 0 removes that dimension, so the result has shape 3x4 and holds the argmax of every yellow vector.

[gif: dim0]
  2. In the second case, we do:
a2 = torch.argmax(a, dim=1)
a2.shape
# OUTPUT
torch.Size([4, 4])

See the gif below. If we set dim=1, argmax is applied to each yellow vector (the vectors that run along dim 1). The original tensor has shape 4x3x4; reducing over dim 1 removes that dimension, so the result has shape 4x4.

[gif: dim1]
  3. In the third case, we do:
a3 = torch.argmax(a, dim=2)
a3.shape
# OUTPUT
torch.Size([4, 3])

See the gif below. If we set dim=2, argmax is applied to each yellow vector (the vectors that run along dim 2). The original tensor has shape 4x3x4; reducing over dim 2 removes that dimension, so the result has shape 4x3.

[gif: dim2]
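
To make the three cases concrete without the animations, here is a small sketch that checks the shapes and one entry by hand. It uses torch.arange instead of torch.randn so the values are reproducible; that choice, and the variable names, are assumptions made for this illustration only.

```python
import torch

# A deterministic 4x3x4 tensor (the text above uses torch.randn;
# arange is used here only so the results are reproducible).
a = torch.arange(4 * 3 * 4, dtype=torch.float32).reshape(4, 3, 4)

# Reducing over a dim removes that dim from the shape.
shapes = {d: tuple(torch.argmax(a, dim=d).shape) for d in range(a.dim())}

# Each entry of the dim=0 result is the argmax of one vector a[:, j, k].
a0 = torch.argmax(a, dim=0)
manual = torch.argmax(a[:, 1, 2])   # the vector at position (j=1, k=2)

# With dim=None (the default), the index refers to the flattened tensor.
flat_index = torch.argmax(a)
```

Here a0[1, 2] and manual should agree, since both are the argmax of the same length-4 vector along dim 0.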

As member function

Many of the functions mentioned above are also available as member functions (methods) of the tensor. For example, the two lines in each of the following pairs are equivalent.

a = torch.randn(3, 4)
# pair1
_ = torch.sum(a)
_ = a.sum()
# pair2
_ = torch.argmax(a, dim=0)
_ = a.argmax(dim=0)
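
As a quick sanity check of the claim above, the sketch below compares the two styles with torch.equal. The fixed input tensor and the variable names are assumptions chosen for illustration.

```python
import torch

a = torch.arange(12, dtype=torch.float32).reshape(3, 4)

# Function style and method style should return identical tensors.
same_sum = torch.equal(torch.sum(a), a.sum())
same_argmax = torch.equal(torch.argmax(a, dim=0), a.argmax(dim=0))
```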

As in-place function

The functions mentioned above return a new result tensor and leave the original one unchanged. In some cases, we can instead operate on the tensor in place. The names of in-place functions end with a _.

For example, within each group below, every line has the same effect on a.

a = torch.randn(3, 4)
# pair 1
a = torch.cos(a)
a = a.cos()
a.cos_()
# pair 2
a = torch.clamp(a, 1, 2)
a = a.clamp(1, 2)
a.clamp_(1, 2)
a.clamp(1, 2) # Wrong: this line has no effect. a stays the same because the returned tensor is not assigned to anything.
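
The difference between the in-place and the out-of-place versions becomes visible when a second name refers to the same tensor. The following is a minimal sketch under that assumption; alias and the other variable names are hypothetical and exist only for this illustration.

```python
import torch

a = torch.tensor([0.5, 1.5, 2.5])
alias = a                  # a second name for the same tensor object

b = a.clamp(1, 2)          # out-of-place: returns a new tensor
unchanged = torch.equal(a, torch.tensor([0.5, 1.5, 2.5]))

a.clamp_(1, 2)             # in-place: modifies the tensor itself
alias_sees_change = torch.equal(alias, torch.tensor([1.0, 1.5, 2.0]))

a = a.cos()                # rebinding: the name a now refers to a new tensor
alias_untouched = torch.equal(alias, torch.tensor([1.0, 1.5, 2.0]))
```

So a = a.clamp(1, 2) and a.clamp_(1, 2) give the same values for a, but only the in-place version is seen through other references to the same tensor.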

Tensor indexing

Indexing is very powerful in torch, and it works very much like indexing in numpy. Learn numpy first if you are not familiar with it.

a = torch.randn(4, 3)
# a is
tensor([[ 1.1351,  0.7592, -3.5945],
        [ 0.0192,  0.1052,  0.9603],
        [-0.5672, -0.5706,  1.5980],
        [ 0.1115, -0.0392,  1.4112]])

Indexing supports many kinds of index. You can pass:

  • An integer. a[1, 2] returns a single value as a 0-D tensor, tensor(0.9603): the element at (row 1, col 2).

  • A slice. a[1::2, 2] returns the 1-D tensor tensor([0.9603, 1.4112]): the two elements at (row 1, col 2) and (row 3, col 2).

  • A colon. A colon means everything along this dim. a[:, 2] returns the 1-D tensor tensor([-3.5945, 0.9603, 1.5980, 1.4112]): the 4 elements of column 2.

  • A None. None is used to create a new dim of size 1 at the given position. E.g., a[:, None, :] has the shape torch.Size([4, 1, 3]). A further example:

a[:, 2] returns the 1-D tensor tensor([-3.5945, 0.9603, 1.5980, 1.4112]): column 2 as a flat vector.

a[:, 2, None] returns the 2-D tensor tensor([[-3.5945], [0.9603], [1.5980], [1.4112]]) of shape (4, 1): column 2 again, but kept as a column, so the result still has two dims like the original.

  • A ... (Ellipsis). An Ellipsis stands for as many : as are needed to cover the remaining dims. E.g.,

    a = torch.arange(16).reshape(2,2,2,2)
    # Both of the following return the same values
    a[..., 1]
    a[:, :, :, 1]
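
The index types listed above can be tried side by side. The sketch below uses torch.arange rather than torch.randn so that the selected values are predictable; the tensor contents and variable names are assumptions made for this illustration.

```python
import torch

# Rows are [0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11].
a = torch.arange(12).reshape(4, 3)

scalar = a[1, 2]            # integer indices: a 0-D tensor
strided = a[1::2, 2]        # slice: rows 1 and 3 of column 2
column = a[:, 2]            # colon: every row of column 2, shape (4,)
column_2d = a[:, 2, None]   # None appends a size-1 dim, shape (4, 1)
expanded = a[:, None, :]    # None in the middle, shape (4, 1, 3)

b = torch.arange(16).reshape(2, 2, 2, 2)
ellipsis_matches = torch.equal(b[..., 1], b[:, :, :, 1])
```

Printing the .shape of each result is a quick way to confirm which dims were dropped, kept, or added.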