
PyTorch Practical Hand-Written Modules Basics 6--torch.autograd.Function

In this section (and the three that follow), we investigate how to use torch.autograd.Function to implement hand-written operators. The tentative outline is:

  • In this section (6), we cover the basics of torch.autograd.Function.
  • In the next section (7), we’ll cover the mathematical derivation of the “linear layer” operator.
  • In section 8, we cover writing a C++/CUDA extension for the “linear layer” operator.
  • In section 9, we cover building the extension into a module and testing it, and then conclude what we have done so far.

Background

This article is mainly based on the official tutorial, summarizing and explaining its important points.

By defining an operator with torch.autograd.Function and implementing its forward / backward functions, we can use this operator together with other PyTorch built-in operators. Operators defined with torch.autograd.Function support automatic back-propagation.

As mentioned in the tutorial, we should use torch.autograd.Function in the following scenarios:

  • The computation comes from another library and therefore does not support differentiation natively, so we have to define its backward function explicitly.
  • PyTorch’s implementation of an operator cannot benefit from parallelization, so we use a PyTorch C++/CUDA extension for better performance.

Basic Structure

The following is the basic structure of the Function:

import torch
from torch.autograd import Function

class LinearFunction(Function):

    @staticmethod
    def forward(ctx, input0, input1, ... , inputN):
        # Save the inputs for use in the backward pass.
        ctx.save_for_backward(input0, input1, ... , inputN)
        # Calculate output0, ... , outputM from the inputs.
        ......
        return output0, ... , outputM

    @staticmethod
    def backward(ctx, grad_output0, ... , grad_outputM):
        # Retrieve and unpack the tensors saved in the forward pass.
        input0, input1, ... , inputN = ctx.saved_tensors

        grad_input0 = grad_input1 = grad_inputN = None
        # needs_input_grad records whether each input requires a gradient,
        # so we can skip unnecessary computation.
        if ctx.needs_input_grad[0]:
            grad_input0 = ... # backward calculation
        if ctx.needs_input_grad[1]:
            grad_input1 = ... # backward calculation
        ......

        return grad_input0, grad_input1, ... , grad_inputN
  1. The forward and backward functions are staticmethods. The forward function is o0, ..., oM = forward(i0, ..., iN): it computes output0 ~ outputM from input0 ~ inputN. The backward function is g_i0, ..., g_iN = backward(g_o0, ..., g_oM): it computes the gradients of input0 ~ inputN from the gradients of output0 ~ outputM.

  2. Since forward and backward are merely static functions, we need to store the input tensors in ctx during the forward pass so that we can retrieve them in the backward function. See here for an alternative way to define a Function.

  3. ctx.needs_input_grad is a tuple of booleans, one per input, recording whether each input requires a gradient. This lets us save computation when a tensor does not need gradients; in that case, the backward function returns None for that input. A concrete example of this template is sketched below.
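
To make this template concrete, here is a minimal sketch of a linear Function (output = input @ weight.T + bias), closely following the example in the official tutorial. The gradient formulas used in backward are the standard ones for a fully connected layer; their derivation is the topic of the next section.

import torch
from torch.autograd import Function

class LinearFunction(Function):

    @staticmethod
    def forward(ctx, input, weight, bias=None):
        # Save the tensors needed to compute the gradients.
        ctx.save_for_backward(input, weight, bias)
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None

        # Only compute the gradients that are actually needed.
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)

        return grad_input, grad_weight, grad_bias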

Use it

Pure functions

After defining the class, we call the operator through its .apply method. Simply write

# Option 1: alias
linear = LinearFunction.apply

or,

# Option 2: wrap in a function, to support default args and keyword args.
def linear(input, weight, bias=None):
    return LinearFunction.apply(input, weight, bias)

Then call as

output = linear(input, weight, bias) # input, weight, bias are all tensors!
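
As a quick sanity check (assuming the concrete LinearFunction sketched above), we can run a forward/backward pass and compare the hand-written gradients against numerical ones with torch.autograd.gradcheck. Double-precision inputs are recommended for its default tolerances; the shapes below are chosen arbitrarily.

import torch
from torch.autograd import gradcheck

# A batch of 4 samples, 3 input features, 2 output features (arbitrary shapes).
input = torch.randn(4, 3, dtype=torch.double, requires_grad=True)
weight = torch.randn(2, 3, dtype=torch.double, requires_grad=True)
bias = torch.randn(2, dtype=torch.double, requires_grad=True)

output = linear(input, weight, bias)  # forward runs LinearFunction.forward
output.sum().backward()               # autograd calls LinearFunction.backward
print(input.grad.shape, weight.grad.shape, bias.grad.shape)

# gradcheck compares the analytical backward against numerical gradients.
print(gradcheck(linear, (input, weight, bias), eps=1e-6, atol=1e-4))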

nn.Module

In most cases, the weight and bias are trainable parameters. We can further wrap this linear function into a Linear module:

import torch
from torch import nn

class Linear(nn.Module):
    def __init__(self, input_features, output_features, bias=True):
        super().__init__()
        self.input_features = input_features
        self.output_features = output_features

        # nn.Parameters require gradients by default.
        self.weight = nn.Parameter(torch.empty(output_features, input_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(output_features))
        else:
            # You should always register all possible parameters, but the
            # optional ones can be None if you want.
            self.register_parameter('bias', None)

        # Not a very smart way to initialize weights
        nn.init.uniform_(self.weight, -0.1, 0.1)
        if self.bias is not None:
            nn.init.uniform_(self.bias, -0.1, 0.1)

    def forward(self, input):
        # See the autograd section for explanation of what happens here.
        return LinearFunction.apply(input, self.weight, self.bias)

    def extra_repr(self):
        # (Optional) Set the extra information about this module. You can test
        # it by printing an object of this class.
        return 'input_features={}, output_features={}, bias={}'.format(
            self.input_features, self.output_features, self.bias is not None
        )

As mentioned in sections 3 and 4 of this series, the weight and bias should be nn.Parameter instances so that they are registered correctly. We then initialize them with random values.

In the forward function, we call the LinearFunction.apply defined above. The backward pass is then handled automatically, just as with other PyTorch modules.
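
For illustration, a minimal usage sketch of the module above (with arbitrary shapes) might look like this:

import torch

layer = Linear(input_features=3, output_features=2, bias=True)
print(layer)  # extra_repr shows: input_features=3, output_features=2, bias=True

x = torch.randn(4, 3)
y = layer(x)             # forward goes through LinearFunction.apply
loss = y.pow(2).mean()
loss.backward()          # backward goes through LinearFunction.backward
print(layer.weight.grad.shape)  # torch.Size([2, 3])

# Because weight and bias are nn.Parameters, an optimizer picks them up automatically.
optimizer = torch.optim.SGD(layer.parameters(), lr=0.1)
optimizer.step()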