In this section (and the three that follow), we investigate how to use torch.autograd.Function to implement hand-written operators. The tentative outline is:
- In this section (6), we talk about the basics of torch.autograd.Function.
- In the next section (7), we'll talk about the mathematical derivation for the "linear layer" operator.
- In section 8, we talk about writing a C++/CUDA extension for the "linear layer" operator.
- In section 9, we talk about building the extension into a module, as well as testing it. Then we'll conclude what we've done so far.
Background
This article mainly draws on the official tutorial, summarizing and explaining its important points.
By defining an operator with torch.autograd.Function and implementing its forward / backward functions, we can use it together with PyTorch's built-in operators. Operators defined with torch.autograd.Function participate in automatic differentiation, so gradients are back-propagated through them automatically.
As mentioned in the tutorial, we should use torch.autograd.Function in the following scenarios:
- The computation comes from other libraries that do not support automatic differentiation natively, so we must explicitly define the backward function.
- PyTorch's implementation of an operator cannot take full advantage of parallelization, so we write a PyTorch C++/CUDA extension for better performance.
Basic Structure
The following is the basic structure of a Function.
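The sketch below uses the linear layer as the running example, closely following the custom-Function example in the official tutorial. The forward formula (output = input @ weight.T + bias) and the gradient expressions are taken from that tutorial; the math behind them will be derived properly in section 7, so treat this as a preview:

```python
import torch


class LinearFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, bias=None):
        # Stash what backward will need; save_for_backward also accepts None.
        ctx.save_for_backward(input, weight, bias)
        # output = input @ weight.T + bias
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        # Only compute the gradients that are actually required.
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)
        # One return value per forward input (ctx excluded), in the same order.
        return grad_input, grad_weight, grad_bias
```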
The forward and backward functions are static methods. The forward function has the form o0, ..., oM = forward(i0, ..., iN): it computes the outputs o0 ~ oM from the inputs i0 ~ iN. The backward function has the form g_i0, ..., g_iN = backward(g_o0, ..., g_oM): it computes the gradients of the inputs i0 ~ iN from the gradients of the outputs o0 ~ oM.

Since forward and backward are merely functions, we need to store the input tensors in ctx during the forward pass (with ctx.save_for_backward) so that we can retrieve them in the backward pass (from ctx.saved_tensors). See here for an alternative way to define a Function; a sketch of that style is given at the end of this subsection.

ctx.needs_input_grad is a tuple of booleans, one per input, recording whether that input needs its gradient computed. We can use it to save computation when a tensor does not need gradients; in that case, the backward function returns None for that input.
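For reference, the alternative definition style mentioned above (available since PyTorch 2.0) moves the ctx bookkeeping out of forward and into a separate setup_context static method. The class name LinearFunctionV2 is just a placeholder for this sketch; backward is unchanged:

```python
import torch


class LinearFunctionV2(torch.autograd.Function):
    # forward no longer receives ctx; it only computes the output.
    @staticmethod
    def forward(input, weight, bias):
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    # setup_context receives the forward inputs and output and does the
    # ctx bookkeeping that used to live inside forward.
    @staticmethod
    def setup_context(ctx, inputs, output):
        input, weight, bias = inputs
        ctx.save_for_backward(input, weight, bias)

    # backward is identical to the combined-forward version above.
    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)
        return grad_input, grad_weight, grad_bias
```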
Use it
Pure functions
After defining the class, we can call the operator through its .apply method. Simply write
```python
# Option 1: alias
linear = LinearFunction.apply
```
or,
```python
# Option 2: wrap in a function, to support default args and keyword args.
def linear(input, weight, bias=None):
    return LinearFunction.apply(input, weight, bias)
```
Then call it as

```python
output = linear(input, weight, bias)  # input, weight, bias are all tensors!
```
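As a quick sanity check that back-propagation really flows through the custom operator, we can run a tiny example. The shapes below assume the (out_features, in_features) weight layout from the earlier sketch and are otherwise arbitrary:

```python
# Quick sanity check: gradients flow through the custom operator.
input = torch.randn(8, 4, requires_grad=True)
weight = torch.randn(3, 4, requires_grad=True)   # (out_features, in_features)
bias = torch.randn(3, requires_grad=True)

output = linear(input, weight, bias)             # shape (8, 3)
output.sum().backward()
print(input.grad.shape, weight.grad.shape, bias.grad.shape)
```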
nn.Module
In most cases, the weight and bias are parameters that are trainable during training. We can further wrap this linear function into a Linear module.
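Here is one possible sketch of such a module, modeled on the official tutorial; the constructor arguments and the uniform initialization range are illustrative choices, not requirements:

```python
import torch
import torch.nn as nn


class Linear(nn.Module):
    def __init__(self, input_features, output_features, bias=True):
        super().__init__()
        self.input_features = input_features
        self.output_features = output_features
        # nn.Parameter registers the tensors so that .parameters() and
        # optimizers can find them (see sections 3 and 4 of this series).
        self.weight = nn.Parameter(torch.empty(output_features, input_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(output_features))
        else:
            # Register 'bias' even when unused so the module stays well-formed.
            self.register_parameter('bias', None)
        # Simple random initialization; any scheme could be used here.
        nn.init.uniform_(self.weight, -0.1, 0.1)
        if bias:
            nn.init.uniform_(self.bias, -0.1, 0.1)

    def forward(self, input):
        # Delegate to the custom Function defined earlier; autograd then
        # calls our backward automatically.
        return LinearFunction.apply(input, self.weight, self.bias)
```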
As mentioned in sections 3 and 4 of this series, the weight and bias should be nn.Parameter instances so that they are registered and recognized correctly. We then initialize them with random values.
In the forward function, we call the LinearFunction.apply function defined above. The backward pass is then handled automatically, just as for other PyTorch modules.
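A minimal usage example, assuming the constructor signature from the sketch above:

```python
# The module behaves like a built-in layer: parameters are registered,
# and gradients come from our custom backward.
layer = Linear(4, 3)
x = torch.randn(8, 4)
layer(x).sum().backward()
print(layer.weight.grad.shape)  # torch.Size([3, 4])
```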