In this section (and the three that follow), we investigate how to use `torch.autograd.Function` to implement hand-written operators. The tentative outline is:
- In this section (6), we cover the basics of `torch.autograd.Function`.
- In the next section (7), we work through the mathematical derivation for the “linear layer” operator.
- In section 8, we write a C++/CUDA extension for the “linear layer” operator.
- In section 9, we discuss building the extension into a module, as well as testing it, and then conclude what we have done so far.
Background
This article mainly takes the official tutorial as its reference, summarizing and explaining the important points.
By defining an operator with `torch.autograd.Function` and implementing its forward / backward functions, we can use this operator together with PyTorch's built-in operators. Operators defined via `torch.autograd.Function` are back-propagated through automatically.
As mentioned in the tutorial, we should use `torch.autograd.Function` in the following scenarios:
- The computation comes from another library (e.g. NumPy), so it does not support differentiation natively; we must define its backward function explicitly (see the sketch after this list).
- PyTorch's implementation of an operator cannot benefit from parallelization; we write a PyTorch C++/CUDA extension for better performance.
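As a minimal illustration of the first scenario (the `NumpyExp` name is a hypothetical placeholder for this sketch, and CPU tensors are assumed), the forward pass below runs in NumPy, outside autograd's view, so the backward rule has to be spelled out by hand:

```python
import numpy as np
import torch
from torch.autograd import Function


class NumpyExp(Function):
    # exp() is computed in NumPy, so autograd cannot trace the forward pass.
    @staticmethod
    def forward(ctx, input):
        result = torch.from_numpy(np.exp(input.detach().numpy()))
        ctx.save_for_backward(result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        (result,) = ctx.saved_tensors
        # d/dx exp(x) = exp(x), so the saved forward output is the local gradient.
        return grad_output * result
```

Calling `NumpyExp.apply(x)` on a CPU tensor with `requires_grad=True` then back-propagates like any built-in operator.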
Basic Structure
The following is the basic structure of the Function:
```python
import torch


# The running example of this series, following the official tutorial:
# a custom linear (fully-connected) layer.
class LinearFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, bias=None):
        # Save the tensors needed by backward on the context object.
        ctx.save_for_backward(input, weight, bias)
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None

        # Only compute the gradients that are actually needed.
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)

        return grad_input, grad_weight, grad_bias
```
The `forward` and `backward` functions are `staticmethod`s. The forward function has the form `o0, ..., oM = forward(ctx, i0, ..., iN)`: it computes the outputs `o0 ~ oM` from the inputs `i0 ~ iN`. The backward function has the form `g_i0, ..., g_iN = backward(ctx, g_o0, ..., g_oM)`: it computes the gradients with respect to the inputs `i0 ~ iN` from the gradients with respect to the outputs `o0 ~ oM`.

Since forward and backward are merely static functions, we need to store the input tensors in `ctx` during the forward pass (via `ctx.save_for_backward`) so that we can retrieve them in the backward pass (via `ctx.saved_tensors`). See here for an alternative way to define a `Function`.

`ctx.needs_input_grad` is a tuple of booleans, one per input, recording whether that input needs a gradient. We can therefore save computation when a tensor does not require gradients; in that case, the backward function returns `None` for that input.
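To gain confidence that a hand-written backward matches the true derivatives, `torch.autograd.gradcheck` compares it against numerical finite differences. A minimal sketch, assuming the `LinearFunction` above (the shapes are arbitrary; double precision is needed for the numerical comparison to be reliable):

```python
import torch
from torch.autograd import gradcheck

# Small double-precision inputs that all require gradients.
input = torch.randn(4, 3, dtype=torch.double, requires_grad=True)
weight = torch.randn(5, 3, dtype=torch.double, requires_grad=True)
bias = torch.randn(5, dtype=torch.double, requires_grad=True)

# Raises an error if the analytical and numerical gradients disagree.
print(gradcheck(LinearFunction.apply, (input, weight, bias), eps=1e-6, atol=1e-4))
```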
Use it
Pure functions
After defining the class, we call the operator through its `.apply` method. Simply write
```python
# Option 1: alias
linear = LinearFunction.apply
```
or,
```python
# Option 2: wrap in a function, to support default args and keyword args.
def linear(input, weight, bias=None):
    return LinearFunction.apply(input, weight, bias)
```
Then call as
```python
output = linear(input, weight, bias)  # input, weight, bias are all tensors!
```
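Gradients then flow through the custom operator exactly as they do for built-in ops. A small self-contained sketch (the shapes here are arbitrary, chosen only for illustration):

```python
input = torch.randn(4, 3, requires_grad=True)
weight = torch.randn(5, 3, requires_grad=True)
bias = torch.randn(5, requires_grad=True)

output = linear(input, weight, bias)   # forward pass through LinearFunction
output.sum().backward()                # autograd calls LinearFunction.backward

print(input.grad.shape, weight.grad.shape, bias.grad.shape)
# torch.Size([4, 3]) torch.Size([5, 3]) torch.Size([5])
```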
nn.Module
In most cases, the weight and bias are parameters that should be trainable. We can further wrap this linear function into a Linear module:
```python
from torch import nn


class Linear(nn.Module):
    def __init__(self, input_features, output_features, bias=True):
        super().__init__()
        # nn.Parameter registers the tensors as trainable parameters of the module.
        self.weight = nn.Parameter(torch.empty(output_features, input_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(output_features))
        else:
            # Register the optional parameter as None so the module still knows about it.
            self.register_parameter('bias', None)

        # Initialize the parameters with (small) random values.
        nn.init.uniform_(self.weight, -0.1, 0.1)
        if self.bias is not None:
            nn.init.uniform_(self.bias, -0.1, 0.1)

    def forward(self, input):
        return LinearFunction.apply(input, self.weight, self.bias)
```
As mentioned in sections 3 and 4 of this series, the weight and bias should be `nn.Parameter`s so that they are recognized correctly as trainable parameters. We then initialize them with random values.
In the `forward` function, we call the previously defined `LinearFunction.apply`. The backward pass is handled automatically, just as for other PyTorch modules.
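As a quick sanity check (the layer sizes and the choice of SGD here are arbitrary), the module can be trained like any other:

```python
layer = Linear(3, 5)
optimizer = torch.optim.SGD(layer.parameters(), lr=0.1)

x = torch.randn(4, 3)
loss = layer(x).pow(2).mean()

optimizer.zero_grad()
loss.backward()        # LinearFunction.backward fills weight.grad and bias.grad
optimizer.step()
```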