
We have now covered all the content we wanted to discuss. In this section, we’ll give a quick summary of what we have covered and lay out plans for the future of this series.

Summary

Across the ten sections of this tutorial, we moved from the low level (tensors) to the high level (modules). In detail, the structure looks like this:

  • Tensor operations (Sec 1, 2)
  • Tensor-wise operations (Sec 3)
  • Module basics (Sec 4)
  • Implementation in pure Python (Sec 5, ResNet)
  • Implementation in CUDA (Sec 6, 7, 8, 9)

Conclusion

From this tutorial, we know that a model consists of nn.Modules. We implement the forward() function with tensor-wise operations to perform the forward pass.

PyTorch is highly optimized, and the Python side is enough for most cases, so it is usually unnecessary to implement an algorithm in C++/CUDA. (See Sec 9: our CUDA matrix multiplication operator is slower than PyTorch’s.) In addition, when writing in native Python, we don’t need to worry about the correctness of the gradient calculation.

However, in some rare cases the forward() implementation is complicated and may contain Python for-loops, so its performance is poor. Under such circumstances, you may consider writing the operator yourself. But keep in mind that:

  • You need to check that the forward & backward propagations are correct;
  • You need to run benchmarks - does my operator really get faster?

Therefore, manually writing an optimized CUDA operator is time-consuming and complicated, and it requires proficient CUDA knowledge. But once you write a good CUDA operator, your program can speed up many times. It is all about trade-offs.

Announce in Advance

Finally, let’s talk about some things I will do in the future:

  • This series will not end here. In article 11 and later, we’ll talk about some famous model implementations.
  • As mentioned above, writing CUDA operators requires proficient CUDA knowledge, so I’ll set up a new series on how to write good CUDA programs: CUDA Medium Tutorials.

In sections 6 to 9, we investigate how to use torch.autograd.Function to implement the hand-written operators. The tentative outline is:

  • In section 6, we talked about the basics of torch.autograd.Function. Operators defined with torch.autograd.Function can be automatically back-propagated.
  • In section 7, we talked about the mathematical derivation of the “linear layer” operator.
  • In section 8, we talked about writing the C++/CUDA extension for the “linear layer” operator.
  • In this section (9), we talk about building the extension into a Python module, as well as testing the module. Then we’ll conclude what we’ve done so far.

Note:

  • This blog is written with the following references:
    • PyTorch official tutorial about CUDA extension: website.
    • YouTube video about writing CUDA extension: video, code.
  • For how to write CUDA code, you can follow the official documentation and blogs (in Chinese); you can search for English tutorials and video tutorials yourself.
  • This blog only covers some important points of the matrix multiplication example. Code snippets are picked out for illustration; the whole code is at: code.

Python-side Wrapper

Purely using the C++ extension functions is not enough in our case. As mentioned in Section 6, we need to build our operators with torch.autograd.Function. It is not convenient to have users define the operator wrapper every time, so it is better to put the wrapper in a Python module. Then, users can simply import our Python module and use the wrapper classes and functions in it.

(Figure: cudaops-struct-improved — package structure with the Python-side wrapper)

The Python module is at mylinearops/. Following Section 6, we define some autograd.Function operators and nn.Module modules in mylinearops/mylinearops.py. Then, we export the operators and modules via the code in mylinearops/__init__.py:

from .mylinearops import matmul
from .mylinearops import linearop
from .mylinearops import LinearLayer

As a result, when a user imports mylinearops, only the matmul (Y = XW) function, the linearop (Y = XW + b) function, and the LinearLayer module are exposed to users.
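The exact contents of mylinearops/mylinearops.py are not shown here; as a rough, hedged sketch (assumed structure, with MatmulFunction named only for illustration, following Section 6), it might look like this:

import torch
import torch.nn as nn
from torch.autograd import Function
import mylinearops_cuda  # the compiled C++/CUDA extension (built below)


class MatmulFunction(Function):
    """Y = X W, forwarding to the CUDA extension."""
    @staticmethod
    def forward(ctx, X, W):
        ctx.save_for_backward(X, W)
        return mylinearops_cuda.matmul_forward(X, W)

    @staticmethod
    def backward(ctx, grad_output):
        X, W = ctx.saved_tensors
        grad_output = grad_output.contiguous()
        grad_X = mylinearops_cuda.matmul_dA_backward(grad_output, X, W)
        grad_W = mylinearops_cuda.matmul_dB_backward(grad_output, X, W)
        return grad_X, grad_W


def matmul(A, B):
    return MatmulFunction.apply(A, B)


class LinearLayer(nn.Module):
    """Y = X W (+ b), built on the hand-written operators."""
    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(in_features, out_features))
        if bias:
            self.bias = nn.Parameter(torch.randn(out_features))
        else:
            self.register_parameter('bias', None)

    def forward(self, x):
        # The real module presumably calls the fused linearop(x, weight, bias);
        # adding the bias with a plain "+" just keeps this sketch short.
        out = matmul(x, self.weight)
        if self.bias is not None:
            out = out + self.bias
        return out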

Writing setup.py and Building

setup.py script

The setup.py script is generally the same across packages. Next time, you can just copy-paste the code below and modify a few key components.

import glob
import os.path as osp
from setuptools import setup
from torch.utils.cpp_extension import CUDAExtension, BuildExtension


ROOT_DIR = osp.dirname(osp.abspath(__file__))
include_dirs = [osp.join(ROOT_DIR, "include")]

SRC_DIR = osp.join(ROOT_DIR, "src")
sources = glob.glob(osp.join(SRC_DIR, '*.cpp')) + glob.glob(osp.join(SRC_DIR, '*.cu'))


setup(
    name='mylinearops',
    version='1.0',
    author=...,
    author_email=...,
    description='Hand-written Linear ops for PyTorch',
    long_description='Simple demo for writing Linear ops in CUDA extensions with PyTorch',
    ext_modules=[
        CUDAExtension(
            name='mylinearops_cuda',
            sources=sources,
            include_dirs=include_dirs,
            extra_compile_args={'cxx': ['-O2'],
                                'nvcc': ['-O2']}
        )
    ],
    py_modules=['mylinearops.mylinearops'],
    cmdclass={
        'build_ext': BuildExtension
    }
)

At the beginning, we gather the path information: include_dirs (where we store our .h headers) and sources (where we store our C++/CUDA source files).

Then, we call the setup function. The parameters are explained as follows:

  • name: the package name, i.e., the name users refer to this package by
  • version: The version number, decided by the creator
  • author: The creator’s name
  • author_email: The creator’s email
  • description: The package’s description, short version
  • long_description: The package’s description, long version
  • ext_modules: the key to our building process. When building a PyTorch CUDA extension, we should use CUDAExtension so that the build helper knows how to compile correctly
    • name: the CUDA extension name. We import this name in our wrapper to access the cuda functions
    • sources: the source files
    • include_dirs: the header files
    • extra_compile_args: the extra compile flags. {'cxx': ['-O2'], 'nvcc': ['-O2']} is commonly used, which means compiling with the -O2 optimization level
  • py_modules: The Python modules needed for the package, which is our wrapper, mylinearops. In most cases, the wrapper module has the same name as the overall package name. ('mylinearops.mylinearops' stands for 'mylinearops/mylinearops.py')
  • cmdclass: When building the PyTorch CUDA extension, we always pass in this: {'build_ext': BuildExtension}

Building

Then, we can build the package. We first activate the conda environment that we want to install into:

conda activate <target_env>

Then run:

cd <proj_root>
python setup.py install

Note: don’t run pip install .; otherwise your Python module will not be installed successfully, at least in my case.

Compiling may take some time. If the build process ends with error messages, go and fix them. If it finally displays something like “successfully installed mylinearops”, you are ready to go.

To check if the installation is successful, we can try to import it:

$ python
Python 3.9.15 (main, Nov 24 2022, 14:31:59)
[GCC 11.2.0] :: Anaconda, Inc. on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import mylinearops
>>> dir(mylinearops)
['LinearLayer', '__builtins__', '__cached__', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__path__', '__spec__', 'linearop', 'matmul', 'mylinearops']
>>>

Further testing will be mentioned in the next subsection.

Module Testing

We will test the forward and backward passes of the matmul and LinearLayer calculations, respectively. To verify the answers, we’ll compare them with PyTorch’s implementation or check them with torch.autograd.gradcheck. To increase numerical accuracy, we recommend using the double (torch.float64) type instead of float (torch.float32).

For tensors: create with argument dtype=torch.float64.

For modules: a good way is to use model.double() to convert all the parameters and buffers to double.

forward

A typical method is to use torch.allclose to check whether two tensors are close to each other. We create the reference answer with PyTorch’s implementation.

  • matmul:
import torch
import mylinearops

A = torch.randn(20, 30, dtype=torch.float64).cuda().requires_grad_()
B = torch.randn(30, 40, dtype=torch.float64).cuda().requires_grad_()

res_my = mylinearops.matmul(A, B)
res_torch = torch.matmul(A, B)

print(torch.allclose(res_my, res_torch))
  • LinearLayer:
import torch
import mylinearops

A = torch.randn(40, 30, dtype=torch.float64).cuda().requires_grad_() * 100
linear = mylinearops.LinearLayer(30, 50).cuda().double()

res_my = linear(A)
res_torch = torch.matmul(A, linear.weight) + linear.bias

print(torch.allclose(res_my, res_torch))
print(torch.max(torch.abs(res_my - res_torch)))

It is worth noting that sometimes, because of floating-point error, the answer from PyTorch is not exactly consistent with the answer from our implementation. We have two options:

  1. Pass atol=1e-5, rtol=1e-5 into the torch.allclose to increase the tolerance level.
  2. [Not very recommended] We can inspect the absolute error with torch.max(torch.abs(res_my - res_torch)) for reference. If the result is merely 0.01 ~ 0.1, that would be OK in most cases. (See the example after this list.)
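For instance, with res_my and res_torch from the forward test above, a small sketch of both options:

# Option 1: loosen both tolerances when comparing against PyTorch's reference result
print(torch.allclose(res_my, res_torch, atol=1e-5, rtol=1e-5))
# Option 2 (for reference only): the largest absolute difference
print(torch.max(torch.abs(res_my - res_torch)))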

backward

For the backward calculation, we can use torch.autograd.gradcheck to verify the result. If some tensors are only float, a warning will occur:

……/torch/autograd/gradcheck.py:647: UserWarning: Input #0 requires gradient and is not a double precision floating point or complex. This check will likely fail if all the inputs are not of double precision floating point or complex.

So it is recommended to use the double type. Otherwise the check will likely fail.

  • matmul:

As mentioned above, for pure calculation functions, we can create all tensors with the double (torch.float64) type. Then we are ready to go:

import torch
import mylinearops

A = torch.randn(20, 30, dtype=torch.float64).cuda().requires_grad_()
B = torch.randn(30, 40, dtype=torch.float64).cuda().requires_grad_()

print(torch.autograd.gradcheck(mylinearops.matmul, (A, B))) # pass
  • LinearLayer:

As mentioned above, we can use model.double(). We are ready to go:

import torch
import mylinearops

## CHECK for Linear Layer with bias ##
A = torch.randn(40, 30, dtype=torch.float64).cuda().requires_grad_()
linear = mylinearops.LinearLayer(30, 40).cuda().double()
print(torch.autograd.gradcheck(linear, (A,))) # pass

## CHECK for Linear Layer without bias ##
A = torch.randn(40, 30, dtype=torch.float64).cuda().requires_grad_()
linear_nobias = mylinearops.LinearLayer(30, 40, bias=False).cuda().double()
print(torch.autograd.gradcheck(linear_nobias, (A,))) # pass

Full Example

Now, we use our linear module to build a classic three-layer linear model (784 → 256 → 256 → 10) to classify the MNIST digits. See the examples/main.py file.

Just as with nn.Linear, we create the model by:

class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.linear1 = mylinearops.LinearLayer(784, 256, bias=True)#.cuda()
        self.linear2 = mylinearops.LinearLayer(256, 256, bias=True)#.cuda()
        self.linear3 = mylinearops.LinearLayer(256, 10, bias=True)#.cuda()
        self.relu = nn.ReLU()
        # self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x = x.view(-1, 784)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        # x = self.softmax(self.linear3(x))
        x = self.linear3(x)
        return x

After writing the remaining boilerplate, we can run our model: python examples/tests.py.

We also build the same model with PyTorch’s nn.Linear. The resulting logs are:

# mylinearops
...
Epoch: [10/10], Step: [100/468], Loss: 0.0417, Acc: 0.9844
Epoch: [10/10], Step: [200/468], Loss: 0.0971, Acc: 0.9609
Epoch: [10/10], Step: [300/468], Loss: 0.0759, Acc: 0.9766
Epoch: [10/10], Step: [400/468], Loss: 0.0777, Acc: 0.9766
Time: 23.4661s

# torch
...
Epoch: [10/10], Step: [100/468], Loss: 0.1048, Acc: 0.9688
Epoch: [10/10], Step: [200/468], Loss: 0.0412, Acc: 0.9844
Epoch: [10/10], Step: [300/468], Loss: 0.0566, Acc: 0.9688
Epoch: [10/10], Step: [400/468], Loss: 0.0217, Acc: 0.9922
Time: 26.5896s

Surprisingly, our implementation appears even faster than the torch one. (But relax: after a few repetitions, we find ours is just about as fast as the torch one.) This is because the data scale is relatively small, so the computation takes up a small proportion of the total time. When the data scale is larger, ours may be slower than torch’s.

In sections 6 to 9, we investigate how to use torch.autograd.Function to implement the hand-written operators. The tentative outline is:

  • In section 6, we talked about the basics of torch.autograd.Function. Operators defined with torch.autograd.Function can be automatically back-propagated.
  • In the last section (7), we talked about the mathematical derivation of the “linear layer” operator.
  • In this section (8), we talk about writing the C++/CUDA extension for the “linear layer” operator.
  • In section 9, we will talk about building the extension into a Python module, as well as testing it. Then we’ll conclude what we’ve done so far.

Note:

  • This blog is written with the following references:
    • PyTorch official tutorial about CUDA extension: website.
    • YouTube video about writing CUDA extension: video, code.
  • For how to write CUDA code, you can follow the official documentation and blogs (in Chinese); you can search for English tutorials and video tutorials yourself.
  • This blog only covers some important points of the matrix multiplication example. Code snippets are picked out for illustration; the whole code is at: code.

Overall Structure

The general structure of our PyTorch C++ / CUDA extension looks like the following:

(Figure: cudaops-struct — overall file structure of the extension)

We mainly have three kinds of files: the library interface, the core code on CPU, and the core code on GPU. Let’s explain them in detail:

  • Library interface (.cpp)

    • Contains the function interfaces for Python to call. These functions usually take Tensor inputs and return Tensor values.
    • Contains a standard pybind declaration, since our extension uses pybind to bind the C++ functions for Python. It indicates which functions need to be bound.
  • Core code on CPU (.cpp)

    • Contains the core functions that do the calculation.
    • Contains wrappers for the core functions, which create the result tensors, check the input shapes, etc.
  • Core code on GPU (.cu)

    • Contains the CUDA kernel functions (__global__) that do the parallel calculation.
    • Contains wrappers for the kernels, which create the result tensors, check the input shapes, set the launch configurations, launch the kernels, etc.

Then, after we finish the code, we can use the Python build tools to compile it into a shared object library (.so file). We can then import it normally on the Python side and call the functions declared in the library interface through pybind11.

In our example code, we don’t provide code for CPU calculation; we only support GPU. So we only have two files (src/linearops.cpp and src/addmul_kernel.cu).

Pybind Interface

This is the src/linearops.cpp file in our repo.

1. Utils function

We usually define some utility macros in our code. They are in the include/utils.h header file.

// PyTorch CUDA Utils
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

// Kernel Config Utils
#define DIV_CEIL(a, b) (((a) + (b) - 1) / (b))

The third macro calls the first two macros, which make sure the tensor is on a CUDA device and is contiguous.

The last macro performs ceiling division, which is often used when setting the CUDA kernel launch configuration.

2. Interface functions

Thanks to pybind, we can simply define functions in C++ and use them in Python. A C++ function like

torch::Tensor func(torch::Tensor a, torch::Tensor b, int c){
    torch::Tensor res;
    ......
    return res;
}

is roughly equivalent to the Python function below.

def func(a: torch.Tensor, b: torch.Tensor, c: int) -> torch.Tensor:
    res = ... # torch.Tensor
    ......
    return res

Then, we can define our matrix multiplication interface as below. Note that we need to implement both the forward and backward functions!

  • forward

Check the inputs and their sizes, then call the CUDA function wrapper.

torch::Tensor matmul_forward(
    const torch::Tensor &A,
    const torch::Tensor &B)
{
    CHECK_INPUT(A);
    CHECK_INPUT(B);

    TORCH_CHECK(A.size(1) == B.size(0), "matmul_fast_forward: shape mismatch");

    return matmul_cuda(A, B);
}
  • backward

Also check the inputs and their sizes, then call the CUDA function wrapper. Note that we compute the backward of A * B = C with respect to the input matrices A and B in two different functions, so that whenever we don’t need the gradient of A, we can simply skip it (see the Python-side sketch after the code below).

The derivation of the gradient functions is covered in the last section, here.

/* Backward for A gradient */
torch::Tensor matmul_dA_backward(
    const torch::Tensor &grad_output,
    const torch::Tensor &A,
    const torch::Tensor &B)
{
    CHECK_INPUT(grad_output);
    CHECK_INPUT(B);

    // dL/dA = dL/dY * B^T
    auto grad_A = matmul_cuda(grad_output, transpose_cuda(B));

    return grad_A;
}

/* Backward for B gradient */
torch::Tensor matmul_dB_backward(
    const torch::Tensor &grad_output,
    const torch::Tensor &A,
    const torch::Tensor &B)
{
    CHECK_INPUT(grad_output);
    CHECK_INPUT(A);

    // dL/dB = A^T * dL/dY
    auto grad_B = matmul_cuda(transpose_cuda(A), grad_output);

    return grad_B;
}

3. Binding

At the end of src/linearops.cpp, we use the following code to bind the functions. The first argument is the function name on the Python side, the second is a pointer to the function to be called, and the last is the docstring for that function on the Python side.

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    ......
    m.def("matmul_forward", &matmul_forward, "Matmul forward");
    m.def("matmul_dA_backward", &matmul_dA_backward, "Matmul dA backward");
    m.def("matmul_dB_backward", &matmul_dB_backward, "Matmul dB backward");
    ......
}

CUDA wrapper

This is the src/addmul_kernel.cu file in our repo.

The wrapper for matrix multiplication looks like this:

torch::Tensor matmul_cuda(torch::Tensor A, torch::Tensor B) {
    // 1. Get metadata
    const int m = A.size(0);
    const int n = A.size(1);
    const int p = B.size(1);

    // 2. Create output tensor
    auto result = torch::empty({m, p}, A.options());

    // 3. Set launch configuration
    const dim3 blockSize = dim3(BLOCK_SIZE, BLOCK_SIZE);
    const dim3 gridSize = dim3(DIV_CEIL(m, BLOCK_SIZE), DIV_CEIL(p, BLOCK_SIZE));

    // 4. Call the cuda kernel launcher
    AT_DISPATCH_FLOATING_TYPES(A.type(), "matmul_cuda",
        ([&] {
            matmul_fw_kernel<scalar_t><<<gridSize, blockSize>>>(
                A.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>(),
                B.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>(),
                result.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>(),
                m, p
            );
        }));

    // 5. Return the value
    return result;
}

Now, let’s go through it in detail:

1. Get metadata

Just as with tensors in PyTorch, we can use Tensor.size(0) to access the size of dimension 0.

Note that since we already checked that the dimensions match on the interface side, we don’t need to check again here.

2. Create output tensor

We can either operate in-place or create a new tensor for the output. Use the following code to create a tensor of shape m x p with the same dtype / device as A.

auto result = torch::empty({m, p}, A.options());

In other situations, when we want a specific dtype / device, we can use a declaration like the one below:

torch::empty({m, p}, torch::dtype(torch::kInt32).device(feats.device()))

torch::empty only allocates the memory; it does not initialize the entries to 0. Since we fill the result tensor in the kernel function anyway, it is not necessary to initialize it to 0.
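The Python-side analogue behaves the same way; a quick sketch of the difference:

import torch

x = torch.empty(2, 3)  # memory is allocated but NOT initialized: values are arbitrary
y = torch.zeros(2, 3)  # memory is allocated AND filled with 0
print(x)               # whatever happened to be in that memory
print(y)               # all zeros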

3. Set launch configuration

You need some basic CUDA knowledge to understand this part. Basically, we set the launch configuration based on the input matrix size, using the macros defined earlier.

const dim3 blockSize = dim3(BLOCK_SIZE, BLOCK_SIZE);
const dim3 gridSize = dim3(DIV_CEIL(m, BLOCK_SIZE), DIV_CEIL(p, BLOCK_SIZE));

We set each thread block size to 16 x 16. Then, we set the number of blocks according to the input size.
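For intuition, the same ceiling-division logic in Python (a small sketch, assuming BLOCK_SIZE is 16 as stated above):

BLOCK_SIZE = 16

def div_ceil(a, b):
    # same idea as the DIV_CEIL macro: smallest number of blocks covering `a` elements
    return (a + b - 1) // b

m, p = 100, 40
grid_size = (div_ceil(m, BLOCK_SIZE), div_ceil(p, BLOCK_SIZE))
print(grid_size)  # (7, 3): 7 * 16 >= 100 rows, 3 * 16 >= 40 columns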

4. Call the cuda kernel launcher

Unlike normal CUDA programs, we use ATen’s dispatch macro to launch the kernel. This is a standard pattern, and you can copy-paste it anywhere.

AT_DISPATCH_FLOATING_TYPES(A.type(), "matmul_cuda",
    ([&] {
        matmul_fw_kernel<scalar_t><<<gridSize, blockSize>>>(
            A.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>(),
            B.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>(),
            result.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>(),
            m, p
        );
    }));
  • The macro is named AT_DISPATCH_FLOATING_TYPES, meaning the inner kernel supports floating-point types, i.e., float (32-bit) and double (64-bit). For float16, you can use AT_DISPATCH_ALL_TYPES_AND_HALF. For integers (int (32-bit), long long (64-bit), and more), use AT_DISPATCH_INTEGRAL_TYPES.

  • The first argument, A.type(), indicates the actual type chosen at runtime.

  • The second argument, "matmul_cuda", is a name used for error reporting.

  • The last argument, a lambda function, is the actual code to be run. Inside this lambda, we launch the kernel with the following statement:

    matmul_fw_kernel<scalar_t><<<gridSize, blockSize>>>(
        A.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>(),
        B.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>(),
        result.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>(),
        m, p
    );
    • matmul_fw_kernel is the kernel function name.
    • <scalar_t> is the template parameter; it will be instantiated with each possible type by the surrounding AT_DISPATCH_FLOATING_TYPES.
    • <<<gridSize, blockSize>>> passes in the launch configuration.
    • In the parameter list, if an argument is a Tensor, we pass in a packed accessor, which makes indexing convenient inside the kernel.
      • <scalar_t> is the template parameter.
      • 2 means the Tensor’s ndimension is 2.
      • torch::RestrictPtrTraits means the pointers (tensor memory) do not overlap, which enables some optimizations. Usually left unchanged.
      • size_t indicates the index type. Usually left unchanged.
    • If a parameter is a plain integer such as m or p, just pass it in as normal.

5. Return the value

If we have more than one return value, we can set the return type to std::vector<torch::Tensor> and return with {xxx, yyy}.

CUDA kernel

This is the src/addmul_kernel.cu file in our repo.

template <typename scalar_t>
__global__ void matmul_fw_kernel(
    const torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> A,
    const torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> B,
    torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> result,
    const int m, const int p
)
{
    const int row = blockIdx.x * blockDim.x + threadIdx.x;
    const int col = blockIdx.y * blockDim.y + threadIdx.y;

    if (row >= m || col >= p) return;

    scalar_t sum = 0;
    for (int i = 0; i < A.size(1); i++) {
        sum += A[row][i] * B[i][col];
    }
    result[row][col] = sum;
}
  • We define it as a template function (template <typename scalar_t>) so that the kernel can support different input tensor types.
  • Usually we mark the input PackedTensorAccessors as const to avoid unexpected modifications to them.
  • The main body is just a simple CUDA matrix multiplication kernel. This is very common; you can search online for explanations.

Ending

That was a lot for one section. In the next section, we’ll talk about how to write setup.py to compile the code and turn it into a Python module.

In sections 6 to 9, we investigate how to use torch.autograd.Function to implement the hand-written operators. The tentative outline is:

  • In the last section (6), we talked about the basics of torch.autograd.Function. Operators defined with torch.autograd.Function can be automatically back-propagated.
  • In this section (7), we’ll talk about the mathematical derivation of the “linear layer” operator.
  • In section 8, we will talk about writing the C++/CUDA extension for the “linear layer” operator.
  • In section 9, we will talk about building the extension into a module, as well as testing it. Then we’ll conclude what we’ve done so far.

The linear layer is defined as Y = XW + b. It consists of a matrix multiplication and a bias addition. We’ll discuss their forward/backward derivations separately.


Matrix multiplication: forward

Matrix multiplication is a common operator. Each entry of the result matrix is the dot product of a row of the first matrix and a column of the second: the (i, j) entry of the result comes from multiplying the first matrix’s row i with the second matrix’s column j. From this property, the number of columns of the first matrix must equal the number of rows of the second. The shapes are: [m, n] x [n, r] -> [m, r]. For more details, see the figure below.

(Figure: matmul-forward — illustration of the matrix multiplication forward pass)
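A quick sanity check of the shape rule in PyTorch (a minimal sketch):

import torch

A = torch.randn(2, 3)   # [m, n]
B = torch.randn(3, 5)   # [n, r]
C = A @ B
print(C.shape)          # torch.Size([2, 5]), i.e., [m, r]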

Matrix multiplication: backward

First, we should understand the goal of backward propagation. From upstream, we receive the gradient of the result matrix C. (A gradient matrix has the same shape as its corresponding matrix, i.e., if C has shape [m, r], then the gradient of C has shape [m, r] as well.) In this step, we need to obtain the gradients of matrices A and B. The gradients of A and B are functions of A, B, and the gradient of C. Specifically, by the chain rule, we can formulate this as

$$\frac{\partial L}{\partial A_{ij}} = \sum_{k,l} \frac{\partial L}{\partial C_{kl}}\,\frac{\partial C_{kl}}{\partial A_{ij}}, \qquad \frac{\partial L}{\partial B_{ij}} = \sum_{k,l} \frac{\partial L}{\partial C_{kl}}\,\frac{\partial C_{kl}}{\partial B_{ij}}$$

To figure out the gradient of A, we should first investigate how an entry A[i, j] contributes to the entries of the result matrix C. See the figure below:

(Figure: matmul-backward — how entry A[i, j] contributes to row i of C)

As shown above, entry A[i, j] multiplies with the entries in row j of matrix B, contributing to the entries in row i of matrix C. We can write the gradient down as the formula below:

$$\frac{\partial L}{\partial A_{ij}} = \sum_{k} \frac{\partial L}{\partial C_{ik}}\,B_{jk} = \sum_{k} \frac{\partial L}{\partial C_{ik}}\,(B^\top)_{kj}$$

The result above is the gradient for a single entry A[i, j], and it is a dot product between row i of the gradient of C and column j of B^T. Observing this formula, we can naturally extend it to the gradient of the whole matrix A, which becomes a matrix product.

$$\frac{\partial L}{\partial A} = \frac{\partial L}{\partial C}\,B^\top, \qquad \frac{\partial L}{\partial B} = A^\top\,\frac{\partial L}{\partial C}$$

Recall what we said before: “the gradients of A and B are functions of A, B, and the gradient of C.” Our derivation indeed shows that, right?

Add bias: forward

First, note that in the addition we are actually adding the XW matrix (shape [n, r]) and the bias vector (shape [r]). Broadcasting happens here: we add the bias to each row of the XW matrix.

(Figure: addbias-forward — broadcasting the bias over each row of XW)

Add bias: backward

Following a similar principle, we can obtain the gradient for the bias as well.

(Figure: addbias-backward — how each entry of b contributes to one column of the output)

For each entry in vector b, the gradient is:

$$\frac{\partial L}{\partial b_i} = \sum_{j} \frac{\partial L}{\partial C_{ji}}$$

That is, the gradient of entry b_i is the sum of the i-th column of the gradient of C. Overall, the gradient of b is the sum along each column (i.e., axis=0). In code, we write:

grad_b = torch.sum(grad_C, axis=0)

PyTorch Verification

Finally, we can write a PyTorch program to verify that our derivation is correct: we compare our hand-derived gradients with the gradients computed by PyTorch. If they are the same, our derivation is correct.

import torch
A = torch.randn(10, 20).requires_grad_()
B = torch.randn(20, 30).requires_grad_()

res = torch.mm(A, B)
res.retain_grad()
res.sum().backward()

print(torch.allclose(A.grad, torch.mm(res.grad, B.t()))) # grad_A = grad_res * B^T
print(torch.allclose(B.grad, torch.mm(A.t(), res.grad))) # grad_B = A^T * grad_res

Finally, the output is:

True
True

This means that our derivation is correct.
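The bias gradient can be verified in the same way (a minimal sketch):

import torch

X = torch.randn(10, 20)
W = torch.randn(20, 30, requires_grad=True)
b = torch.randn(30, requires_grad=True)

Y = X @ W + b   # b is broadcast over the rows
Y.retain_grad()
Y.sum().backward()

# grad_b should be the column-wise sum of grad_Y (i.e., summation along axis=0)
print(torch.allclose(b.grad, torch.sum(Y.grad, axis=0)))  # True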

In this section (and the three sections that follow), we investigate how to use torch.autograd.Function to implement hand-written operators. The tentative outline is:

  • In this section (6), we talk about the basics of torch.autograd.Function.
  • In the next section (7), we’ll talk about the mathematical derivation of the “linear layer” operator.
  • In section 8, we will talk about writing the C++/CUDA extension for the “linear layer” operator.
  • In section 9, we will talk about building the extension into a module, as well as testing it. Then we’ll conclude what we’ve done so far.

Backgrounds

This article mainly draws on the official tutorial, summarizing and explaining the important points.

By defining an operator with torch.autograd.Function and implementing its forward / backward functions, we can use this operator together with other PyTorch built-in operators. Operators defined with torch.autograd.Function can be automatically back-propagated.

As mentioned in the tutorial, we should use torch.autograd.Function in the following scenarios:

  • The computation comes from another library, so it does not support differentiation natively; we must explicitly define its backward function.
  • PyTorch’s implementation of an operator cannot benefit from parallelization, so we use a PyTorch C++/CUDA extension for better performance.

Basic Structure

The following is the basic structure of the Function:

import torch
from torch.autograd import Function

class LinearFunction(Function):

    @staticmethod
    def forward(ctx, input0, input1, ... , inputN):
        # Save the inputs for the backward use.
        ctx.save_for_backward(input0, input1, ... , inputN)
        # Calculate the output0, ... outputM given the inputs.
        ......
        return output0, ... , outputM

    @staticmethod
    def backward(ctx, grad_output0, ... , grad_outputM):
        # Get and unpack the tensors saved in the forward pass.
        input0, input1, ... , inputN = ctx.saved_tensors

        grad_input0 = grad_input1 = grad_inputN = None
        # needs_input_grad records whether each input needs its gradient computed.
        # Checking it can improve efficiency.
        if ctx.needs_input_grad[0]:
            grad_input0 = ... # backward calculation
        if ctx.needs_input_grad[1]:
            grad_input1 = ... # backward calculation
        ......

        return grad_input0, grad_input1, ... , grad_inputN
  1. The forward and backward functions are staticmethods. The forward function is o0, ..., oM = forward(i0, ..., iN): it calculates output0 ~ outputM from input0 ~ inputN. The backward function is g_i0, ..., g_iN = backward(g_o0, ..., g_oM): it calculates the gradients of input0 ~ inputN from the gradients of output0 ~ outputM.

  2. Since forward and backward are merely functions, we need to store the input tensors in ctx during the forward pass so that we can retrieve them in the backward function. See here for an alternative way to define a Function.

  3. ctx.needs_input_grad is a tuple of booleans recording whether each input requires a gradient. We can therefore save computation when a tensor doesn’t need gradients; in that case, the backward function returns None for that tensor.
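As a concrete toy example (a sketch along the lines of the official docs, not part of the linear-layer code), here is a Function that multiplies its input by a constant:

import torch
from torch.autograd import Function

class MulConstant(Function):
    @staticmethod
    def forward(ctx, tensor, constant):
        ctx.constant = constant          # non-tensor state can be stashed on ctx directly
        return tensor * constant

    @staticmethod
    def backward(ctx, grad_output):
        # d(tensor * constant) / d(tensor) = constant; the constant itself gets no gradient
        return grad_output * ctx.constant, None

x = torch.randn(3, dtype=torch.float64, requires_grad=True)
y = MulConstant.apply(x, 2.0)
y.sum().backward()
print(x.grad)                                                               # all entries are 2.0
print(torch.autograd.gradcheck(lambda t: MulConstant.apply(t, 2.0), (x,)))  # True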

Use it

Pure functions

After defining the class, we can use it through the .apply method. Simply write

# Option 1: alias
linear = LinearFunction.apply

or,

# Option 2: wrap in a function, to support default args and keyword args.
def linear(input, weight, bias=None):
    return LinearFunction.apply(input, weight, bias)

Then call as

output = linear(input, weight, bias) # input, weight, bias are all tensors!

nn.Module

In most cases, the weight and bias are trainable parameters. We can further wrap this linear function into a Linear module:

class Linear(nn.Module):
    def __init__(self, input_features, output_features, bias=True):
        super().__init__()
        self.input_features = input_features
        self.output_features = output_features

        # nn.Parameters require gradients by default.
        self.weight = nn.Parameter(torch.empty(output_features, input_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(output_features))
        else:
            # You should always register all possible parameters, but the
            # optional ones can be None if you want.
            self.register_parameter('bias', None)

        # Not a very smart way to initialize weights
        nn.init.uniform_(self.weight, -0.1, 0.1)
        if self.bias is not None:
            nn.init.uniform_(self.bias, -0.1, 0.1)

    def forward(self, input):
        # See the autograd section for explanation of what happens here.
        return LinearFunction.apply(input, self.weight, self.bias)

    def extra_repr(self):
        # (Optional) Set the extra information about this module. You can test
        # it by printing an object of this class.
        return 'input_features={}, output_features={}, bias={}'.format(
            self.input_features, self.output_features, self.bias is not None
        )

As mentioned in sections 3 and 4 of this series, weight and bias should be nn.Parameters so that they are registered correctly. Then we initialize the weights with random values.

In the forward function, we use the LinearFunction.apply we defined. The backward pass is then handled automatically, just as with other PyTorch modules.

Problem

Today, when I was running a PyTorch script, I hit a strange problem:

a = torch.rand(2, 2).to('cuda:1')
......
torch.cuda.synchronize()

but it resulted in the following error:

  File "....../test.py", line 67, in <module>
torch.cuda.synchronize()
File "....../miniconda3/envs/py39/lib/python3.9/site-packages/torch/cuda/__init__.py", line 495, in synchronize
return torch._C._cuda_synchronize()
RuntimeError: CUDA error: out of memory

But it is clear that GPU 1 has enough memory (we only need to allocate 16 bytes!):

|===============================+======================+======================|
| 0 NVIDIA GeForce ... Off | 00000000:1A:00.0 Off | N/A |
| 75% 73C P2 303W / 350W | 24222MiB / 24268MiB | 64% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
| 1 NVIDIA GeForce ... Off | 00000000:1B:00.0 Off | N/A |
| 90% 80C P2 328W / 350W | 15838MiB / 24268MiB | 92% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+

Normally, when we fail to allocate memory for tensors, the error looks like:

CUDA out of memory. Tried to allocate 16.00 MiB (GPU 0; 6.00 GiB total capacity; 4.54 GiB already allocated; 14.94 MiB free; 4.64 GiB reserved in total by PyTorch)

But our error message is much “simpler”. So what happened?

Possible Answer

This confused me for some time. According to this website:

When you initially do a CUDA call, it’ll create a cuda context and a THC context on the primary GPU (GPU0), and for that i think it needs 200 MB or so. That’s right at the edge of how much memory you have left.

Surprisingly, in my case, GPU 0 had 24222MiB / 24268MiB of its memory occupied, so there was no memory left for the context. This also explains why our error message is RuntimeError: CUDA error: out of memory rather than the message about tensor allocation failing.

Possible Solution

Set the CUDA_VISIBLE_DEVICES environment variable. We need to change the primary GPU (GPU 0) to another one.

Method 1

In the entry Python file:

# Do this before `import torch`
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1' # set to what you like, e.g., '1,2,3,4,5,6,7'

Method 2

In the shell:

# Do this before run python
export CUDA_VISIBLE_DEVICES=1 # set to what you like, e.g., '1,2,3,4,5,6,7'

And then, our program is ready to go.

In this section, we’ll use the knowledge we learned in the last section (see here) to implement a ResNet network (paper).

Note that we follow the original paper’s design. Our implementation is a simpler version of the official torchvision implementation (that is, we only implement the key structure and the random weight initialization; we don’t consider dilation or other options).

Preliminaries: Calculate the feature map size

  • Basic formula

Given a convolution kernel of size K, padding P, stride S, and an input feature map of size I, the output size is O = (I - K + 2P) / S + 1 (a quick numerical check follows the corollary below).

  • Corollary

Based on the formula above, we know that when S=1:

  1. K=3, P=1 keeps the output size the same as the input size.
  2. K=1, P=0 keeps the output size the same as the input size.
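A quick numerical check of the formula with nn.Conv2d (a minimal sketch):

import torch
import torch.nn as nn

x = torch.randn(1, 3, 56, 56)  # I = 56
print(nn.Conv2d(3, 8, kernel_size=3, stride=1, padding=1)(x).shape)  # 56 = (56 - 3 + 2*1)/1 + 1
print(nn.Conv2d(3, 8, kernel_size=1, stride=1, padding=0)(x).shape)  # 56 = (56 - 1 + 2*0)/1 + 1
print(nn.Conv2d(3, 8, kernel_size=3, stride=2, padding=1)(x).shape)  # 28 = floor((56 - 3 + 2*1)/2) + 1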

Overall Structure

Table 1 in the original paper illustrates the overall structure of ResNet:

(Figure: Table 1 from the ResNet paper — overall architectures of the 18/34/50/101/152-layer models)

We can see that from conv2 onward, each layer consists of several blocks, and the blocks in the 18- and 34-layer models differ from the blocks in the 50-, 101-, and 152-layer models.

We have several deductions:

  1. When the feature map enters the next layer, the first block needs to perform a downsampling operation. This is done by setting one of its convolutions’ stride to 2.
  2. For the other convolutions, the feature map size stays the same, so the convolution settings are the ones referred to in the Preliminaries.

Basic Block Implementation

The basic block’s structure looks like this:

(Figure: structure of the basic block)

Please see the code below. Apart from channels, which defines the number of channels inside the block, we have three additional parameters, in_channels, stride, and downsample, which make this block usable as the FIRST block of each layer.

According to the ResNet structure, for example, the first block in layer3 takes an input of 64*56*56 and has the following tasks:

  1. Make the feature map size 28*28. Thus, we need to set its stride to 2.
  2. Change the number of channels from 64 to 128. Thus, in_channels should be 64.
  3. In addition, since the input is 64*56*56 while the output is 128*28*28, we need a downsample convolution to match the shortcut input to the output size.
import torch
import torch.nn as nn

class ResidualBasicBlock(nn.Module):
    expansion: int = 1
    def __init__(self, in_channels: int, channels: int, stride: int = 1, downsample: nn.Module = None):
        super().__init__()
        self.downsample = downsample
        self.conv1 = nn.Conv2d(in_channels, channels, 3, stride, 1)
        self.batchnorm1 = nn.BatchNorm2d(channels)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(channels, channels, 3, 1, 1)
        self.batchnorm2 = nn.BatchNorm2d(channels)
        self.relu2 = nn.ReLU()

    def forward(self, x):
        residual = x
        x = self.conv1(x)
        x = self.batchnorm1(x)
        x = self.relu1(x)
        x = self.conv2(x)
        x = self.batchnorm2(x)
        if self.downsample:
            residual = self.downsample(residual)
        x += residual
        x = self.relu2(x)
        return x

Bottleneck Block Implementation

The bottleneck block’s structure looks like this:

(Figure: structure of the bottleneck block)

To reduce the computation cost, the bottleneck block uses a 1x1 kernel to map a high number of channels (e.g., 256) down to a low one (e.g., 64), performs the 3x3 convolution, and then maps the 64 channels back to 256.

Please see the code below. As with the basic block, we have three additional parameters, in_channels, stride, and downsample, which make this block usable as the FIRST block of each layer. The reasons are the same as above.

class ResidualBottleNeck(nn.Module):
    expansion: int = 4
    def __init__(self, in_channels: int, channels: int, stride: int = 1, downsample: nn.Module = None):
        super().__init__()
        self.downsample = downsample
        self.conv1 = nn.Conv2d(in_channels, channels, 1, 1)
        self.batchnorm1 = nn.BatchNorm2d(channels)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(channels, channels, 3, stride, 1)
        self.batchnorm2 = nn.BatchNorm2d(channels)
        self.relu2 = nn.ReLU()
        self.conv3 = nn.Conv2d(channels, channels*4, 1, 1)
        self.batchnorm3 = nn.BatchNorm2d(channels*4)
        self.relu3 = nn.ReLU()

    def forward(self, x):
        residual = x
        x = self.conv1(x)
        x = self.batchnorm1(x)
        x = self.relu1(x)

        x = self.conv2(x)
        x = self.batchnorm2(x)
        x = self.relu2(x)

        x = self.conv3(x)
        x = self.batchnorm3(x)

        if self.downsample:
            residual = self.downsample(residual)

        x += residual
        x = self.relu3(x)
        return x

ResNet Base Implementation

Then we can put things together to form the ResNet model! The whole structure is straightforward: we define the submodules one by one and implement the forward() function.

There are only two tricky points:

  1. To let ResNetBase support the two different block types, the block class is passed to the initializer. Since the two blocks differ slightly in how the channels are set, ResidualBasicBlock and ResidualBottleNeck each have an attribute called expansion, which makes it convenient to set the correct number of channels and outputs.
  2. See the _make_layer function below. It needs to determine whether a downsample is required; the condition and explanation are described in the code comments.
class ResNetBase(nn.Module):
    def __init__(self, block, layer_blocks: list, input_channels=3):
        super().__init__()
        self.block = block
        # conv1: 7x7
        self.conv1 = nn.Sequential(
            nn.Conv2d(input_channels, 64, 3, 2, 1),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )
        # max pool
        self.maxpool = nn.MaxPool2d(3, 2, 1)
        # conv2 ~ conv5_x
        self.in_channels = 64
        self.conv2 = self._make_layer(64, layer_blocks[0])
        self.conv3 = self._make_layer(128, layer_blocks[1], 2)
        self.conv4 = self._make_layer(256, layer_blocks[2], 2)
        self.conv5 = self._make_layer(512, layer_blocks[3], 2)

        self.downsample = nn.AvgPool2d(7)
        output_numel = 512 * self.block.expansion
        self.fc = nn.Linear(output_numel, 1000)

        # init the weights
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, channel, replicates, stride=1):
        modules = []

        downsample = None
        if stride != 1 or self.in_channels != channel*self.block.expansion:
            # Use downsample to match the dimension in two cases:
            # 1. stride != 1, meaning we should downsample H, W in this layer.
            #    Then we need to match the residual's H, W and the output's H, W of this layer.
            # 2. self.in_channels != channel*block.expansion, meaning we should increase C in this layer.
            #    Then we need to match the residual's C and the output's C of this layer.
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channels, channel*self.block.expansion, 1, stride),
                nn.BatchNorm2d(channel*self.block.expansion)
            )

        modules.append(self.block(self.in_channels, channel, stride, downsample))

        self.in_channels = channel * self.block.expansion
        for r in range(1, replicates):
            modules.append(self.block(self.in_channels, channel))
        return nn.Sequential(*modules)

    def forward(self, x):
        # x: shape Bx3x224x224
        x = self.conv1(x)
        x = self.maxpool(x)

        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)

        x = self.downsample(x)
        x = torch.flatten(x, start_dim=1)
        x = self.fc(x)

        return x

Encapsulate the Constructors

Finally, we can encapsulate the constructors by functions:

def my_resnet18(in_channels=3):
    return ResNetBase(ResidualBasicBlock, [2, 2, 2, 2], in_channels)

def my_resnet34(in_channels=3):
    return ResNetBase(ResidualBasicBlock, [3, 4, 6, 3], in_channels)

def my_resnet50(in_channels=3):
    return ResNetBase(ResidualBottleNeck, [3, 4, 6, 3], in_channels)

def my_resnet101(in_channels=3):
    return ResNetBase(ResidualBottleNeck, [3, 4, 23, 3], in_channels)

def my_resnet152(in_channels=3):
    return ResNetBase(ResidualBottleNeck, [3, 8, 36, 3], in_channels)

Then, we can use them like normal models:

img = torch.randn(1, 3, 224, 224)
model_my = my_resnet50()
res_my = model_my(img)

After three articles about tensors, in this article we will talk about the basics of hand-written PyTorch modules. You can see the outline in the left sidebar.

Basic structure

The model must inherit from the nn.Module class. Basically, according to the official tutorial, nn.Module “creates a callable which behaves like a function, but can also contain state (such as neural net layer weights).”

The following is an example from the docs:

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Some details

  • First, our model is named Model and inherits from the nn.Module class.
  • super().__init__() must be called on the first line of the __init__ function.
  • The Model contains two submodules as attributes, conv1 and conv2. They are nn.Conv2d modules (the PyTorch implementation of 2-D convolution).
  • The forward() function performs the forward propagation of the model. It receives a tensor x, applies two convolution-with-ReLU operations, and returns the result.
  • As for backward propagation, that step is computed automatically by PyTorch’s powerful autograd mechanism; you don’t need to worry about it.

load / store the model.state_dict()

Only model attributes that are nn.Module instances or nn.Parameters are registered as valid parameters. These parameters are in model.state_dict() and can be loaded from / stored to disk.

  • model.state_dict():

The state_dict() is an OrderedDict. It contains key-value pairs like “parameter name: tensor”.

model.state_dict()

model.state_dict().keys()
# OUTPUT:
# odict_keys(['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias'])

model.state_dict().values()
# OUTPUT:
# odict_values([tensor([[[[ 1.0481e-01, -2.3481e-02, 9.1083e-02, 1.9955e-01, 1.0437e-01], ... ...
  • Use the following code to store the parameters of the model Model above to the disk:
torch.save(model.state_dict(), 'model.pth')
  • Use the following code to load the parameters from the disk:
model.load_state_dict(torch.load('model.pth'))

Common Submodules

This subsection introduces some commonly used submodules. As mentioned above, to be registered as valid parameters, they must be subclasses of nn.Module or of type nn.Parameter.

clone the module

A module should be copied (cloned) with the copy.deepcopy method.

  • Shallow copy (wrong!)

The model is only shallow-copied. We can see that the two models’ conv1 modules are the same object!

import copy
model = Model()
model2 = copy.copy(model) # shallow copy
print(id(model.conv1), id(model2.conv1))
# OUTPUT
2755774917472 2755774917472
  • Deep copy (right!)
import copy
model = Model()
model2 = copy.deepcopy(model) # deep copy
print(id(model.conv1), id(model2.conv1))
# OUTPUT
2755774915552 2755774916272
  • Example:

This is code from DETR. It copies a module N times, resulting in an nn.ModuleList.

def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])

nn.ModuleList

nn.ModuleList is a list that also inherits from nn.Module, so it is registered by the model correctly.

  • Wrong example: from the output, we can see the submodule is not registered correctly.
class Model2(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.mlp = [nn.Linear(10, 10) for _ in range(10)]
print(Model2().state_dict().keys())
# OUTPUT
odict_keys([])
  • Correct example: from the output, we can see the submodule is registered correctly.
class Model3(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.mlp = nn.ModuleList([nn.Linear(10, 10) for _ in range(10)])
print(Model3().state_dict().keys())
# OUTPUT
odict_keys(['mlp.0.weight', 'mlp.0.bias', ..., 'mlp.9.weight', 'mlp.9.bias'])

nn.ModuleDict is similar to nn.ModuleList, but it is a dictionary; a short sketch follows below.
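A minimal sketch of nn.ModuleDict (the layer names here are made up for illustration):

import torch.nn as nn

class Model4(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleDict({
            'proj': nn.Linear(10, 10),
            'head': nn.Linear(10, 2),
        })

    def forward(self, x):
        return self.layers['head'](self.layers['proj'](x))

print(Model4().state_dict().keys())
# odict_keys(['layers.proj.weight', 'layers.proj.bias', 'layers.head.weight', 'layers.head.bias'])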

nn.Parameter

A plain tensor attribute cannot be registered to the model. We need to wrap it with nn.Parameter so that the model saves the tensor’s state correctly.

The following is modified from the official tutorial. In this example, self.weights is merely a torch.Tensor, so it does not end up in the model’s state_dict. self.bias works normally because it is an nn.Parameter.

import math
import torch
from torch import nn

class Mnist_Logistic(nn.Module):
    def __init__(self):
        super().__init__()
        self.weights = torch.randn(784, 10) / math.sqrt(784) # WRONG: plain tensor, not registered
        self.bias = nn.Parameter(torch.zeros(10)) # CORRECT

    def forward(self, xb):
        return xb @ self.weights + self.bias

Check whether the submodules are correctly registered:

print(Mnist_Logistic().state_dict().keys())
# OUTPUT
odict_keys(['bias']) # only `bias` regiestered! no `weights` here

nn.Sequential

This is a sequential container. Data flows through the contained submodules one by one. An example is shown below.

from torch import nn

model = nn.Sequential(
    nn.Linear(784, 128),
    nn.ReLU(),
    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Linear(64, 10)
)

model.apply() & weight init

Applies fn recursively to every submodule (as returned by model.children()) as well as self. Typical use includes initializing the parameters of a model (see also torch.nn.init).

A typical example can be:

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

model = Model()
# do init params with model.apply():
def init_weights(m):
    if type(m) == nn.Linear:
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)
    elif type(m) == nn.Conv2d:
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)
model.apply(init_weights)

model.eval() / model.train() / .training

Modules such as BatchNorm and Dropout behave differently in the training and evaluation stages.

We can use model.train() to set the model to the training stage and model.eval() to set it to the evaluation stage.

But what if our own modules need to behave differently in the two stages? The answer is that nn.Module has an attribute called training: it is True during training and False otherwise.

class Model(nn.Module):
    def __init__(self):
        ... # skipped in this example
    def forward(self, x):
        if self.training:
            ... # write the code for the training stage here
        else:
            ... # write the code for the evaluating/inference stage here

As we can see, when we call model.train(), all submodules of model set their training attribute to True; model.eval() sets it back to False.

In this section, we will talk about some PyTorch functions that operate on tensors.

torch.Tensor.expand

Signature: Tensor.expand(*sizes) -> Tensor

The expand function returns a new view of the self tensor with singleton dimensions expanded to a larger size. The passed sizes indicate the destination size. (“Singleton dimensions” are dimensions of size 1.)

Basic Usage

Passing -1 as the size for a dimension means not changing the size of that dimension.

x = torch.tensor([[1], [2], [3]]) # torch.Size([3, 1])
print(x)
print(x.expand(3, 4)) # torch.Size([3, 4])
print(x.expand(-1, 4)) # torch.Size([3, 4])
print(x.expand(3, -1)) # torch.Size([3, 1])
print(x.expand(-1, -1)) # torch.Size([3, 1])
# OUTPUT
tensor([[1],
[2],
[3]])
tensor([[1, 1, 1, 1],
[2, 2, 2, 2],
[3, 3, 3, 3]])
tensor([[1, 1, 1, 1],
[2, 2, 2, 2],
[3, 3, 3, 3]])
tensor([[1],
[2],
[3]])
tensor([[1],
[2],
[3]])

Wrong Usage

Only dimensions with size 1 can be expanded:

x = torch.tensor([[1], [2], [3]]) # torch.Size([3, 1])

print(x.expand(2, 2)) # ERROR! can't expand axis 0 shape from 3 (not 1)

Why use it?

The return value is only a view, not a new tensor. Therefore, if you only need to read from (not write to) an expanded tensor, using expand() saves a lot of GPU memory. Note that modifying the expanded tensor modifies the original as well.

x = torch.tensor([[1], [2], [3]]) # torch.Size([3, 1])
x.expand(3, 4)[0, 1] = 100
print(x)
# OUTPUT
tensor([[100],
[ 2],
[ 3]])

torch.Tensor.repeat

Signature: Tensor.repeat(*sizes) -> Tensor

Repeats this tensor along the specified dimensions. It is somewhat similar to torch.Tensor.expand(), but the passed parameters indicate the repeat counts. Also, this is a deep copy.

x = torch.tensor([1, 2, 3]) # torch.Size([3])
print(x.repeat(4, 2)) # torch.Size([4, 6])
# OUTPUT
tensor([[1, 2, 3, 1, 2, 3],
[1, 2, 3, 1, 2, 3],
[1, 2, 3, 1, 2, 3],
[1, 2, 3, 1, 2, 3]])

More dimensions than the given tensor

If sizes has more dimensions than the self tensor (as in the example below, where x only has shape 3x1 but we pass more than two parameters), additional dimensions are added at the front.

x = torch.tensor([[1], [2], [3]]) # torch.Size([3, 1])

print(x.repeat(4, 2, 1).shape)
# torch.Size([4, 6, 1]) first 1: same. last 2 dim: [3,1]*[2,1]=[6,1]

print(x.repeat(4, 2, 1, 1).shape)
# torch.Size([4, 2, 3, 1]) first 2: same. last 2 dim: [3,1]*[1,1]=[3,1]

print(x.repeat(1, 4, 2, 1).shape)
# torch.Size([1, 4, 6, 1]) first 2: same. last 2 dim: [3,1]*[2,1]=[6,1]

print(x.repeat(1, 1, 4, 2).shape)
# torch.Size([1, 1, 12, 2]) first 2: same. last 2 dim: [3,1]*[4,2]=[12,2]

torch.Tensor.transpose

Signature: torch.transpose(input, dim0, dim1) -> Tensor

Signature: torch.Tensor.transpose(dim0, dim1) -> Tensor

Returns a tensor that is a transposed version of input. The given dimensions dim0 and dim1 are swapped.

Therefore, as in the examples below, x.transpose(0, 1) and x.transpose(1, 0) are the same.

x = torch.randn(2, 3)
print(x) # shape: torch.Size([2, 3])
print(x.transpose(0, 1)) # shape: torch.Size([3, 2])
print(x.transpose(1, 0)) # shape: torch.Size([3, 2])
y = torch.randn(2, 3, 4)
print(y) # shape: torch.Size([2, 3, 4])

print(y.transpose(0, 1)) # shape: torch.Size([3, 2, 4])
print(y.transpose(1, 0)) # shape: torch.Size([3, 2, 4])

print(y.transpose(0, 2)) # shape: torch.Size([4, 3, 2])
print(y.transpose(2, 0)) # shape: torch.Size([4, 3, 2])

torch.Tensor.permute

Signature: torch.Tensor.permute(dims) -> Tensor

Signature: torch.permute(input, dims) -> Tensor

This function reorders the dimensions. See the example below.

y = torch.randn(2, 3, 4) # Shape: torch.Size([2, 3, 4])

print(y.permute(0, 1, 2)) # Shape: torch.Size([2, 3, 4])

print(y.permute(0, 2, 1)) # Shape: torch.Size([2, 4, 3])

Let’s take a close look at the last line, y.permute(0, 2, 1), as an example.

  • The first argument 0 means that the new tensor’s first dimension is the original dimension at 0, so the shape is 2.

  • The second argument 2 means that the new tensor’s second dimension is the original dimension at 2, so the shape is 4.

  • The third argument 1 means that the new tensor’s third dimension is the original dimension at 1, so the shape is 3.

Finally, the result shape is torch.Size([2, 4, 3]).

torch.Tensor.view / torch.Tensor.reshape

Signature: Tensor.view(*shape) -> Tensor

Signature: Tensor.reshape(*shape) -> Tensor

Reshape the Tensor to shape.

The reshape() function returns a tensor with the requested shape, making a copy of the data when necessary.

For view(), if the shape and memory layout satisfy certain conditions (see here), the copy can be avoided entirely, which saves GPU memory; otherwise view() raises an error and reshape() should be used.

x = torch.randn(4, 3)
print(x) # Shape: torch.Size([4, 3])

print(x.reshape(3, 4)) # Shape: torch.Size([3, 4])
print(x.reshape(-1, 4)) # Shape: torch.Size([3, 4])
x = torch.randn(4, 3)
print(x) # Shape: torch.Size([4, 3])

print(x.view(3, 4)) # Shape: torch.Size([3, 4])
print(x.view(-1, 4)) # Shape: torch.Size([3, 4])
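A small sketch of the difference: view() cannot be used on a non-contiguous tensor, while reshape() copies when it has to:

import torch

x = torch.randn(4, 3)
xt = x.t()                             # transpose: a non-contiguous view

print(xt.is_contiguous())              # False
print(xt.reshape(12).shape)            # torch.Size([12]): reshape copies if necessary
try:
    xt.view(12)
except RuntimeError as e:
    print("view failed:", e)           # view requires a compatible memory layout
print(xt.contiguous().view(12).shape)  # torch.Size([12]): works after contiguous()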

torch.cat

Signature: torch.cat(tensors, dim=0, out=None) -> Tensor

Concatenates the given sequence of tensors in the given dimension. All tensors must either have the same shape (except in the concatenating dimension) or be empty. For how to determine the dim, please refer to my previous article.

x = torch.randn(2, 3)
print(x) # Shape: torch.Size([2, 3])

y = torch.randn(2, 3)
print(y) # Shape: torch.Size([2, 3])

z = torch.cat((x, y), dim=0)
print(z) # Shape: torch.Size([4, 3]) [2+2, 3]

z = torch.cat((x, y), dim=1)
print(z) # Shape: torch.Size([2, 6]) [2, 3+3]
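
As stated above, all dimensions except the concatenating one must match. Here is a short illustration I added (the shapes are made up for demonstration):

import torch

a = torch.randn(2, 3)
b = torch.randn(3, 3)

# dim 0 differs (2 vs 3) but every other dim matches, so cat along dim=0 works
print(torch.cat((a, b), dim=0).shape) # torch.Size([5, 3])

# cat along dim=1 needs dim 0 to match, so this raises a RuntimeError
try:
    torch.cat((a, b), dim=1)
except RuntimeError as e:
    print("cat failed:", e)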

torch.stack

Signature: torch.stack(tensors, dim=0, out=None) -> Tensor

Concatenates a sequence of tensors along a new dimension. See example below.

x = torch.randn(2, 3) # Shape: torch.Size([2, 3])
y = torch.randn(2, 3) # Shape: torch.Size([2, 3])

z = torch.stack((x, y), dim=0)
print(z) # Shape: torch.Size([*2, 2, 3]) The first 2 is the new dimension

z = torch.stack((x, y), dim=1)
print(z) # Shape: torch.Size([2, *2, 3]) The second 2 is the new dimension

z = torch.stack((x, y), dim=2)
print(z) # Shape: torch.Size([2, 3, *2]) The last 2 is the new dimension
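
One way to understand stack() is: each input first gets a new size-1 dimension at dim (via unsqueeze()), and then the tensors are concatenated along it. A quick equivalence check I added (the loop variable d is just for illustration):

import torch

x = torch.randn(2, 3)
y = torch.randn(2, 3)

# stack() == give every tensor a new size-1 dim with unsqueeze(), then cat() along it
for d in range(3):
    stacked = torch.stack((x, y), dim=d)
    catted = torch.cat((x.unsqueeze(d), y.unsqueeze(d)), dim=d)
    assert torch.equal(stacked, catted)
    print(d, stacked.shape)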

torch.vstack/hstack

torch.vstack(...) stacks the tensors vertically, which for these 2-D tensors is equivalent to torch.cat(..., dim=0).

torch.hstack(...) stacks the tensors horizontally, which is equivalent to torch.cat(..., dim=1).

x = torch.randn(2, 3) # Shape: torch.Size([2, 3])
y = torch.randn(2, 3) # Shape: torch.Size([2, 3])

assert torch.vstack((x, y)).shape == torch.cat((x, y), dim=0).shape
assert torch.hstack((x, y)).shape == torch.cat((x, y), dim=1).shape

torch.split

Signature: torch.split(tensor, split_size_or_sections, dim=0)

  • If split_size_or_sections is an integer, then tensor will be split into equally sized chunks if possible; otherwise, the last chunk will be smaller (see the quick sketch after these examples).
x = torch.randn(4, 3) # Shape: torch.Size([4, 3])

print(torch.split(x, 2, dim=0)) # 2-item tuple, each Shape: (2, 3)
print(torch.split(x, 1, dim=1)) # 3-item tuple, each Shape: (4, 1)
  • If split_size_or_sections is a list, then tensor will be split into len(split_size_or_sections) chunks with sizes in dim according to split_size_or_sections.
x = torch.randn(4, 3) # Shape: torch.Size([4, 3])

print(torch.split(x, (1, 3), dim=0))
# 2-item tuple, shapes: (1, 3) and (3, 3)

print(torch.split(x, (1, 1, 1), dim=1))
# 3-item tuple, shapes: (4, 1), (4, 1), (4, 1)
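
Here is the quick sketch of the "last chunk is smaller" case mentioned in the first bullet above (my own example), using a dimension size that does not divide evenly:

import torch

x = torch.randn(5, 3)

chunks = torch.split(x, 2, dim=0) # 5 rows split into chunks of 2
print([c.shape for c in chunks])
# [torch.Size([2, 3]), torch.Size([2, 3]), torch.Size([1, 3])] <- the last chunk is smaller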

torch.vsplit/hsplit

These are the splitting counterparts of torch.vstack and torch.hstack: v means vertically, i.e. along dim=0, and h means horizontally, i.e. along dim=1.

x = torch.randn(3, 4) # Shape: torch.Size([3, 4])

# The following pairs are equivalent:
# pair 1
print(torch.vsplit(x, 3))
print(torch.split(x, 1, dim=0))
# pair 2
print(torch.hsplit(x, 4))
print(torch.split(x, 1, dim=1))

torch.flatten

Signature: torch.flatten(input, start_dim=0, end_dim=-1) -> Tensor

Flattens the dimensions from start_dim to end_dim into one. This is especially useful when converting a 3-D (image) feature tensor into a vector for a linear layer.

x = torch.randn(2, 4, 4)
print(x) # Shape: torch.Size([2, 4, 4])

flattened = torch.flatten(x, start_dim=1)
print(flattened) # Shape: torch.Size([2, 16])
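
For example, in a typical CNN classifier this is the step between the convolutional features and the first linear layer. For a contiguous tensor, torch.flatten(x, 1) gives the same result as the common x.view(x.size(0), -1) idiom, and nn.Flatten is the module form. A small sketch I added (names a, b, flat are just for illustration):

import torch

x = torch.randn(2, 4, 4)

a = torch.flatten(x, start_dim=1) # Shape: torch.Size([2, 16])
b = x.view(x.size(0), -1)         # keep the batch dim, flatten the rest
assert torch.equal(a, b)

# nn.Flatten(start_dim=1) is the module form, handy inside nn.Sequential
flat = torch.nn.Flatten()(x)
assert flat.shape == torch.Size([2, 16])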

These notes are learned from and put together based on the video and code.

Prerequisite: basic knowledge of C/C++.

Compile a Basic Program

See code here

A project root directory must contain a file called CMakeLists.txt, describing the build procedure of the project.

A typical simple CMakeLists.txt contains the following (assuming we have two source files in the current directory, main.cpp and hello.cpp):

cmake_minimum_required(VERSION 3.12) # describe the minimum cmake version

project(hellocmake LANGUAGES CXX) # describe the project name and the language

add_executable(a.out main.cpp hello.cpp)

The add_executable command’s signature is add_executable(<target> [source files...]): it compiles all the listed source files into the target executable.

To build the program, run in the shell:

cmake -B build
cmake --build build
# run with
./build/<program_name>

To clean and rebuild from scratch, just delete the build directory:

rm -rf build

Compile a Static/Shared Library

See code here

# compile a static OR a shared (dynamic) library -- keep only one of the two lines
add_library(hellolib STATIC hello.cpp)
add_library(hellolib SHARED hello.cpp)

add_executable(a.out main.cpp)

target_link_libraries(a.out PUBLIC hellolib)
  • The add_library command’s signature is add_library(<target> STATIC|SHARED [source files...]): it compiles the listed source files into a static/shared library target.
  • Then, target_link_libraries(a.out PUBLIC hellolib) links hellolib into a.out.

Compile a subdirectory

See code here

A sub-directory can contain its own set of source files that are compiled into a library or an executable.

# main CMakeLists.txt
cmake_minimum_required(VERSION 3.12)
project(hellocmake LANGUAGES CXX)

add_subdirectory(hellolib) # the name of subdirectory

add_executable(a.out main.cpp)
target_link_libraries(a.out PUBLIC hellolib)
# sub-directory CMakeLists.txt
add_library(hellolib STATIC hello.cpp)

If main.cpp uses the headers in the subdirectory hellolib, it has to write #include "hellolib/hello.h". To simplify the #include statement, we can add the following to the main CMakeLists.txt:

...
add_executable(a.out main.cpp)
target_include_directories(a.out PUBLIC hellolib)
...

This is still somewhat verbose. If we want to build two executables, we have to repeat the include directive for each of them:

...
add_executable(a.out main.cpp)
target_include_directories(a.out PUBLIC hellolib)
add_executable(b.out main.cpp)
target_include_directories(b.out PUBLIC hellolib)
...

A solution is to move target_include_directories() into the subdirectory’s CMakeLists.txt. Then every library/executable that depends on hellolib automatically inherits this include directory.

# sub-directory
target_include_directories(hellolib PUBLIC .)

If we change PUBLIC to PRIVATE, the include directory is used only when building hellolib itself and is not propagated to dependent targets.

To link a third-party library, use find_package(). For example, the following code links the OpenMP library:

find_package(OpenMP REQUIRED)
target_link_libraries(main PUBLIC OpenMP::OpenMP_CXX)

And use the following code to link the OpenCV library:

find_package(OpenCV REQUIRED)
target_link_libraries(main ${OpenCV_LIBS})

Further options

  • Set the build type (if it is not set, CMake uses an empty build type with no optimization flags):
set(CMAKE_BUILD_TYPE Release)
# Or set it on the command line at configure time
cmake -B build -DCMAKE_BUILD_TYPE=Release
# For multi-config generators (e.g. Visual Studio), choose it at build time instead
cmake --build build --config Release
  • Set C++ standard:
SET(CMAKE_CXX_STANDARD 17)
  • Set global / per-target macros:
# global
add_definitions(-DDEBUG)
# per target (the leading -D is optional here; both lines are equivalent)
target_compile_definitions(a.out PUBLIC -DDEBUG)
target_compile_definitions(a.out PUBLIC DEBUG)

# They have the same effect as
g++ xx.cpp -DDEBUG # (define a `DEBUG` macro when compiling the file)
  • Set global / per-target compile options:
# global
add_compile_options(-O2)
# per target
target_compile_options(a.out PUBLIC -O0)

# They have the same effect as
g++ xx.cpp -O0 # (add a `-O0` option in the compilation)
# Set SIMD and fast-math
target_compile_options(a.out PUBLIC -ffast-math -march=native)
  • Set global / per-target include directories:
# global
include_directories(hellolib)
# special target
target_include_directories(a.out PUBLIC hellolib)

CUDA with CMake

A common template can be:

cmake_minimum_required(VERSION 3.10)
project(main LANGUAGES CUDA CXX)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CUDA_STANDARD 17)

add_executable(main main.cu)
set_target_properties(main PROPERTIES CUDA_ARCHITECTURES "86")