After three articles about tensors, this article turns to the basics of hand-written PyTorch modules. You can see the outline on the left sidebar.
Basic structure
The model must inherit the nn.Module class. Basically, according to the official tutorial, nn.Module "creates a callable which behaves like a function, but can also contain state (such as neural net layer weights)."
The following is an example from the docs:
```python
import torch.nn as nn
```
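In full, this is essentially the small two-convolution network from the nn.Module docs:

```python
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        # Submodules are registered simply by assigning them as attributes
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        # Two convolution-with-ReLU operations
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))
```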
Some details
- First, our model is named Model and inherits the nn.Module class. super().__init__() must be called at the first line of the __init__ function.
- The Model contains two submodules as attributes, conv1 and conv2. They are nn.Conv2d (the PyTorch implementation of 2-D convolution).
- The forward() function does the forward propagation of the model. It receives a tensor x, applies two convolution-with-ReLU operations, and returns the result.
- As for the back-propagation, that step is computed automatically by PyTorch's autograd mechanism. You don't need to care about that.
load / store the model.state_dict()
Only the model's attributes that are subclasses of nn.Module (or of type nn.Parameter) are registered as valid parameters. These parameters are in the model.state_dict(), and can be loaded from / stored to the disk.
model.state_dict():
The state_dict() is an OrderedDict. It contains key-value pairs like "parameter name: tensor".
```python
model.state_dict()
```
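For the Model above, the registered entries are the weights and biases of the two convolution layers:

```python
model = Model()
print(model.state_dict().keys())
# odict_keys(['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias'])
```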
- Use the following code to store the parameters of the Model above to the disk:
```python
torch.save(model.state_dict(), 'model.pth')
```
- Use the following code to load the parameters from the disk:
```python
model.load_state_dict(torch.load('model.pth'))
```
Common Submodules
This subsection introduces some commonly used submodules. As mentioned above, to be registered as valid parameters, they must be subclasses of nn.Module or of type nn.Parameter.
clone the module
The module should be copied (cloned) with the copy.deepcopy method.
- Shallow copy (wrong!): the model is only shallow-copied. We can see that the two models' conv1 tensors are the same object!
```python
import copy
```
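A minimal sketch of the problem, reusing the Model class defined above:

```python
import copy

model = Model()
model_shallow = copy.copy(model)  # shallow copy: submodules are shared

# Both names refer to the very same Conv2d object and weight tensor
print(model_shallow.conv1 is model.conv1)                # True
print(model_shallow.conv1.weight is model.conv1.weight)  # True
```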
- Deep copy (right!)
```python
import copy
```
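The same check with copy.deepcopy, which recursively copies every submodule and tensor:

```python
import copy

model = Model()
model_deep = copy.deepcopy(model)

print(model_deep.conv1 is model.conv1)                # False
print(model_deep.conv1.weight is model.conv1.weight)  # False
```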
- Example: this is the code from DETR. It copies module N times, resulting in an nn.ModuleList.
```python
def _get_clones(module, N):
```
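The body is essentially a one-liner built on copy.deepcopy; the usage line with an nn.Linear is added here only for illustration:

```python
import copy
from torch import nn

def _get_clones(module, N):
    # N independent deep copies, wrapped in nn.ModuleList so they are registered
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])

layers = _get_clones(nn.Linear(512, 512), 6)  # six independently parameterized copies
```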
nn.ModuleList
nn.ModuleList is a list, but it inherits from nn.Module, so it can be recognized by the model correctly.
- Wrong example: from the output, we can see the submodules are not registered correctly.
```python
class Model2(nn.Module):
```
```python
print(Model2().state_dict().keys())
```
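A minimal sketch of such a wrongly-written model (the three Linear layers are placeholders for illustration):

```python
import torch.nn as nn

class Model2(nn.Module):
    def __init__(self):
        super().__init__()
        # A plain Python list: nn.Module does not register the layers inside it
        self.layers = [nn.Linear(10, 10) for _ in range(3)]

print(Model2().state_dict().keys())
# odict_keys([])  <- the Linear parameters are missing
```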
- Correct example: from the output, we can see the submodules are registered correctly.
```python
class Model3(nn.Module):
```
```python
print(Model3().state_dict().keys())
```
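The same sketch, with the plain list replaced by nn.ModuleList:

```python
import torch.nn as nn

class Model3(nn.Module):
    def __init__(self):
        super().__init__()
        # nn.ModuleList registers every layer it contains
        self.layers = nn.ModuleList([nn.Linear(10, 10) for _ in range(3)])

print(Model3().state_dict().keys())
# odict_keys(['layers.0.weight', 'layers.0.bias', 'layers.1.weight',
#             'layers.1.bias', 'layers.2.weight', 'layers.2.bias'])
```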
nn.ModuleDict is similar to nn.ModuleList, but it is a dictionary.
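For instance, a hypothetical model that picks a branch by name:

```python
import torch.nn as nn

class Model4(nn.Module):
    def __init__(self):
        super().__init__()
        self.branches = nn.ModuleDict({
            'small': nn.Linear(10, 10),
            'large': nn.Linear(10, 100),
        })

    def forward(self, x, key):
        # Submodules are looked up by name, like in an ordinary dict
        return self.branches[key](x)
```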
nn.Parameter
A plain tensor attribute cannot be registered to the model. We need to wrap it with nn.Parameter to make the model save the tensor's state correctly.
The following is modified from the official tutorial. In this example, self.weights is merely a torch.Tensor, so it cannot appear in the model's state_dict. The self.bias works normally, because it is an nn.Parameter.
```python
from torch import nn
```
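A sketch of that modified example (the 784 → 10 logistic-regression shapes follow the tutorial; the point is only how the two attributes are declared):

```python
import math
import torch
from torch import nn

class Mnist_Logistic(nn.Module):
    def __init__(self):
        super().__init__()
        # Plain tensor: NOT registered, so it will be missing from state_dict()
        self.weights = torch.randn(784, 10) / math.sqrt(784)
        # nn.Parameter: registered and saved correctly
        self.bias = nn.Parameter(torch.zeros(10))

    def forward(self, xb):
        return xb @ self.weights + self.bias
```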
Check whether the submodules are correctly registered:
```python
print(Mnist_Logistic().state_dict().keys())
```
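With the sketch above, the output only contains odict_keys(['bias']); wrapping self.weights in nn.Parameter as well would make both entries appear.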
nn.Sequential
This is a sequential container. Data flows through the contained submodules one by one. An example is shown below.
```python
from torch import nn
```
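A small sketch (the layer sizes are arbitrary):

```python
import torch
from torch import nn

model = nn.Sequential(
    nn.Linear(784, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
)

x = torch.randn(32, 784)
y = model(x)  # x flows through the three submodules in order; y has shape (32, 10)
```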
model.apply() & weight init
model.apply(fn) applies fn recursively to every submodule (as returned by model.children()) as well as to the model itself. Typical uses include initializing the parameters of a model (see also torch.nn.init).
A typical example can be:
```python
class Model(nn.Module):
```
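For instance, a sketch that initializes every Linear layer with Xavier-uniform weights and zero bias (the init scheme here is only an example):

```python
import torch
from torch import nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10))
        self.apply(self._init_weights)  # runs _init_weights on self and every submodule

    @staticmethod
    def _init_weights(m):
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            nn.init.zeros_(m.bias)

    def forward(self, x):
        return self.net(x)
```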
model.eval() / model.train() / .training
Modules such as BatchNorm and Dropout behave differently in the training and evaluation stages.
We can use model.train() to set the model to the training stage, and model.eval() to set it to the evaluation stage.
But what if our own modules need to behave differently in the two stages? The answer is that nn.Module has an attribute called training. It is True when training, False otherwise.
```python
class Model(nn.Module):
```
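A sketch of a custom module that branches on self.training (adding noise only during training, purely for illustration):

```python
import torch
from torch import nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 10)

    def forward(self, x):
        if self.training:
            # training-only behaviour
            x = x + 0.1 * torch.randn_like(x)
        return self.fc(x)

model = Model()
model.train()  # model.training is now True  -> the noise branch runs
model.eval()   # model.training is now False -> the noise branch is skipped
```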
As we can see, when we call model.train(), all submodules of model set their training attribute to True; model.eval() sets it to False.