
PyTorch Practical Hand-Written Modules Basics 4--Hand-written modules basics

After three articles about tensors, this article covers the basics of hand-written modules in PyTorch.

Basic structure

A model must inherit from the nn.Module class. According to the official tutorial, nn.Module “creates a callable which behaves like a function, but can also contain state (such as neural net layer weights).”

The following is an example from the docs:

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Some details

  • Our model is named Model and inherits from the nn.Module class.
  • super().__init__() must be called in __init__ before any submodule or parameter is assigned; by convention it is the first line.
  • Model holds two submodules as attributes, conv1 and conv2. Both are nn.Conv2d layers (PyTorch's implementation of 2-D convolution).
  • The forward() method defines the forward pass: it receives a tensor x, applies two convolution + ReLU operations, and returns the result.
  • The backward pass is computed automatically by PyTorch's autograd, so you don't need to write it yourself.
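A minimal sketch of using the Model above, assuming a single-channel 28×28 input (the input size and the toy `y.sum()` loss are assumptions for illustration). We call `model(x)`, which goes through `nn.Module.__call__` and runs `forward()` (plus any registered hooks), rather than calling `model.forward(x)` directly. Calling `backward()` then fills in the gradients without any hand-written backward code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

model = Model()
x = torch.randn(1, 1, 28, 28)  # assumed input: batch of 1, 1 channel, 28x28 pixels
y = model(x)                   # call the module itself, not model.forward(x)
print(y.shape)                 # torch.Size([1, 20, 20, 20]): each 5x5 conv trims 4 pixels

loss = y.sum()                 # toy scalar "loss" for illustration only
loss.backward()                # autograd computes gradients for every registered parameter
print(model.conv1.weight.grad.shape)  # torch.Size([20, 1, 5, 5])
```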

load / store the model.state_dict()

An attribute is registered with the model only if it is an nn.Module (a submodule) or an nn.Parameter; buffers added with register_buffer() are tracked as well. The tensors of these registered parameters and buffers appear in model.state_dict() and can be saved to and loaded from disk.

  • model.state_dict():

state_dict() returns an OrderedDict whose key-value pairs have the form “parameter name: tensor”.

model.state_dict()

model.state_dict().keys()
# OUTPUT:
# odict_keys(['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias'])

model.state_dict().values()
# OUTPUT:
# odict_values([tensor([[[[ 1.0481e-01, -2.3481e-02, 9.1083e-02, 1.9955e-01, 1.0437e-01], ... ...
  • To save the parameters of the Model above to disk:
torch.save(model.state_dict(), 'model.pth')
  • To load the parameters from disk into a model that has already been constructed with the same architecture (for example, model = Model()):
model.load_state_dict(torch.load('model.pth'))
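A small round-trip sketch, assuming the Model class above; the temporary directory stands in for a real checkpoint path. It prints the shape stored under each state_dict key, saves the state_dict, loads it into a freshly constructed model, and checks that every tensor came back unchanged.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

model = Model()
for name, tensor in model.state_dict().items():
    print(name, tuple(tensor.shape))
# conv1.weight (20, 1, 5, 5)
# conv1.bias (20,)
# conv2.weight (20, 20, 5, 5)
# conv2.bias (20,)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.pth")  # temporary stand-in for 'model.pth'
    torch.save(model.state_dict(), path)

    # load_state_dict needs an existing model with the same architecture
    restored = Model()
    restored.load_state_dict(torch.load(path))

same = all(
    torch.equal(a, b)
    for a, b in zip(model.state_dict().values(), restored.state_dict().values())
)
print(same)  # True
```

Only the tensors are stored, not the class definition, which is why the model has to be constructed before loading.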

Common Submodules

This section introduces some commonly used submodules. As mentioned above, to be registered with the model, they must be subclasses of nn.Module or of type nn.Parameter.

clone the module

A module should be copied (cloned) with copy.deepcopy.

  • Shallow copy (wrong!)

copy.copy makes only a shallow copy: both models share the very same conv1 submodule (and therefore the same weight tensors).

import copy
model = Model()
model2 = copy.copy(model)  # shallow copy
print(id(model.conv1), id(model2.conv1))
# OUTPUT
# 2755774917472 2755774917472
  • Deep copy (right!)
import copy
model = Model()
model2 = copy.deepcopy(model)  # deep copy
print(id(model.conv1), id(model2.conv1))
# OUTPUT
# 2755774915552 2755774916272
  • Example:

The following helper is taken from DETR. It copies a module N times with copy.deepcopy and collects the copies in an nn.ModuleList.

def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
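A quick check of the DETR helper, assuming an arbitrary nn.Linear(4, 4) layer chosen only for illustration: the clones start with the same values as the original but do not share storage, so modifying one clone leaves the others untouched.

```python
import copy

import torch
import torch.nn as nn

def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])

layer = nn.Linear(4, 4)  # illustrative layer
clones = _get_clones(layer, 3)

# Each clone owns a distinct weight tensor...
print(len({id(c.weight) for c in clones}))  # 3
# ...initialized with the same values as the original layer
print(all(torch.equal(c.weight, layer.weight) for c in clones))  # True

# Zeroing one clone does not affect the other clones or the original
with torch.no_grad():
    clones[0].weight.zero_()
print(torch.equal(clones[1].weight, layer.weight))  # True
```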

nn.ModuleList

nn.ModuleList holds submodules like a Python list, but it is itself an nn.Module, so the model registers the modules it contains.

  • Wrong example: a plain Python list is invisible to the model; the output shows that no submodule is registered.
class Model2(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.mlp = [nn.Linear(10, 10) for _ in range(10)]

print(Model2().state_dict().keys())
# OUTPUT
# odict_keys([])
  • Correct example: with nn.ModuleList, the output shows that every layer is registered.
class Model3(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.mlp = nn.ModuleList([nn.Linear(10, 10) for _ in range(10)])

print(Model3().state_dict().keys())
# OUTPUT
# odict_keys(['mlp.0.weight', 'mlp.0.bias', ..., 'mlp.9.weight', 'mlp.9.bias'])
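The difference matters beyond state_dict(): model.parameters(), which is what you usually pass to an optimizer, also skips layers stored in a plain list. A small sketch reusing the Model2/Model3 layouts above; the simplified constructors, the `count_params` helper, and the ReLU loop in `forward` are illustrative additions, not part of the original examples.

```python
import torch
import torch.nn as nn

class Model2(nn.Module):  # plain list: the layers are NOT registered
    def __init__(self):
        super().__init__()
        self.mlp = [nn.Linear(10, 10) for _ in range(10)]

class Model3(nn.Module):  # nn.ModuleList: the layers ARE registered
    def __init__(self):
        super().__init__()
        self.mlp = nn.ModuleList([nn.Linear(10, 10) for _ in range(10)])

    def forward(self, x):
        # nn.ModuleList can be iterated like an ordinary list
        for layer in self.mlp:
            x = torch.relu(layer(x))
        return x

def count_params(m):  # hypothetical helper: total number of registered parameter elements
    return sum(p.numel() for p in m.parameters())

print(count_params(Model2()))  # 0: an optimizer would have nothing to train
print(count_params(Model3()))  # 1100 = 10 layers * (10*10 weights + 10 biases)
print(Model3()(torch.randn(2, 10)).shape)  # torch.Size([2, 10])
```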

nn.ModuleDict is the dictionary counterpart of nn.ModuleList: it stores submodules under string keys and registers all of them with the model.
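A minimal nn.ModuleDict sketch. The `MultiHead` class and its `"cls"`/`"box"` heads are hypothetical names chosen for illustration; the point is that submodules are looked up by key in `forward()` and that their parameters are registered under `<attribute>.<key>.<param>` names. The keys are sorted before printing only to make the output order explicit.

```python
import torch
import torch.nn as nn

class MultiHead(nn.Module):
    def __init__(self):
        super().__init__()
        # hypothetical task heads, stored under string keys
        self.heads = nn.ModuleDict({
            "cls": nn.Linear(16, 10),
            "box": nn.Linear(16, 4),
        })

    def forward(self, x, task):
        return self.heads[task](x)  # select a submodule by key

model = MultiHead()
print(sorted(model.state_dict().keys()))
# ['heads.box.bias', 'heads.box.weight', 'heads.cls.bias', 'heads.cls.weight']
print(model(torch.randn(2, 16), "box").shape)  # torch.Size([2, 4])
```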

nn.Parameter

A plain tensor attribute is not registered with the model. To make the model track the tensor (so that it appears in state_dict() and model.parameters()), wrap it in nn.Parameter.

The following is modified from the official tutorial. Here self.weights is a plain torch.Tensor, so it is not part of the model's state_dict(). self.bias works as expected because it is an nn.Parameter.

import math

import torch
from torch import nn

class Mnist_Logistic(nn.Module):
    def __init__(self):
        super().__init__()
        self.weights = torch.randn(784, 10) / math.sqrt(784)  # WRONG: plain tensor, not registered
        self.bias = nn.Parameter(torch.zeros(10))  # CORRECT: registered parameter

    def forward(self, xb):
        return xb @ self.weights + self.bias

Check which tensors are correctly registered:

print(Mnist_Logistic().state_dict().keys())
# OUTPUT
# odict_keys(['bias'])  # only `bias` is registered; `weights` is missing
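For comparison, a sketch with both tensors wrapped in nn.Parameter, which matches how the official tutorial defines this model. The batch of 64 flattened 28×28 inputs is an assumption for illustration. Both tensors now show up in state_dict() and in named_parameters(), so an optimizer would update them.

```python
import math

import torch
from torch import nn

class Mnist_Logistic(nn.Module):
    def __init__(self):
        super().__init__()
        # both tensors wrapped in nn.Parameter, so both are registered
        self.weights = nn.Parameter(torch.randn(784, 10) / math.sqrt(784))
        self.bias = nn.Parameter(torch.zeros(10))

    def forward(self, xb):
        return xb @ self.weights + self.bias

model = Mnist_Logistic()
print(list(model.state_dict().keys()))                 # ['weights', 'bias']
print([name for name, _ in model.named_parameters()])  # ['weights', 'bias']
print(model(torch.randn(64, 784)).shape)               # torch.Size([64, 10])
```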

nn.Sequential

nn.Sequential is a sequential container: the input flows through the contained submodules one by one, in the order they were passed in. An example is shown below.

from torch import nn
model = nn.Sequential(
    nn.Linear(784, 128),
    nn.ReLU(),
    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Linear(64, 10)
)
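A short sketch of what the container does, assuming a batch of 32 flattened 28×28 inputs (an illustrative choice). Calling the Sequential model gives the same result as feeding the input through each child in turn, and the submodules are registered under their positional indices. nn.ReLU has no parameters, so indices 1 and 3 do not appear in the state_dict.

```python
import torch
from torch import nn

model = nn.Sequential(
    nn.Linear(784, 128),
    nn.ReLU(),
    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Linear(64, 10),
)

x = torch.randn(32, 784)  # assumed batch of 32 flattened 28x28 inputs
out = model(x)
print(out.shape)  # torch.Size([32, 10])

# The same computation written as an explicit loop over the children
manual = x
for layer in model:
    manual = layer(manual)
print(torch.equal(out, manual))  # True

print(list(model.state_dict().keys()))
# ['0.weight', '0.bias', '2.weight', '2.bias', '4.weight', '4.bias']
```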

model.apply() & weight init

model.apply(fn) applies fn recursively to every submodule (as returned by .children()) as well as to the model itself. A typical use is initializing the parameters of a model (see also torch.nn.init).

A typical example:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

model = Model()

# initialize the parameters with model.apply():
def init_weights(m):
    # Linear and Conv2d layers get Xavier-uniform weights and a constant bias of 0.01
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        torch.nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0.01)

model.apply(init_weights)
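To see the order in which apply() visits modules, a sketch with a hypothetical `record` function that only logs each module's class name: the children are visited first and the model itself last. The same initialization function is then applied and the bias values are checked.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

visited = []
def record(m):  # hypothetical: logs the class name of every visited module
    visited.append(type(m).__name__)

model = Model()
model.apply(record)
print(visited)  # ['Conv2d', 'Conv2d', 'Model']: children first, then the model itself

def init_weights(m):
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        nn.init.xavier_uniform_(m.weight)
        nn.init.constant_(m.bias, 0.01)

model.apply(init_weights)
print(torch.allclose(model.conv2.bias, torch.full((20,), 0.01)))  # True
```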

model.eval() / model.train() / .training

Some modules, such as BatchNorm and Dropout, behave differently during training and evaluation.

Use model.train() to put the model in training mode, and model.eval() to put it in evaluation mode.
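A minimal sketch of the mode switch on nn.Dropout, using p=0.5 and an all-ones input as illustrative choices. In training mode, Dropout zeroes elements at random and scales the survivors by 1/(1-p), so every output value is either 0 or 2; in evaluation mode it passes the input through unchanged.

```python
import torch
import torch.nn as nn

drop = nn.Dropout(p=0.5)
x = torch.ones(1000)

drop.train()  # training mode (also the default right after construction)
y_train = drop(x)
print(set(y_train.unique().tolist()) <= {0.0, 2.0})  # True: dropped or scaled by 1/(1-0.5)

drop.eval()   # evaluation mode: Dropout acts as the identity
y_eval = drop(x)
print(torch.equal(y_eval, x))  # True
```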

But what if our own modules need to behave differently in the two modes? nn.Module has an attribute called training, which is True in training mode and False otherwise.

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        # layers skipped in this example

    def forward(self, x):
        if self.training:
            ...  # write the training-stage code here
        else:
            ...  # write the evaluation/inference-stage code here

Calling model.train() sets the training attribute to True on the model and, recursively, on all of its submodules; model.eval() sets it to False everywhere (it is equivalent to model.train(False)).
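A sketch of a custom module that branches on self.training. `GaussianNoise` is a hypothetical layer written only for this illustration: it adds noise during training and is the identity in evaluation mode. Wrapping it in nn.Sequential shows that a single model.eval() or model.train() call reaches every submodule.

```python
import torch
import torch.nn as nn

class GaussianNoise(nn.Module):
    """Hypothetical layer: adds Gaussian noise while training, identity at eval time."""

    def __init__(self, std=0.1):
        super().__init__()
        self.std = std

    def forward(self, x):
        if self.training:
            return x + torch.randn_like(x) * self.std
        return x

model = nn.Sequential(nn.Linear(8, 8), GaussianNoise(0.1))
x = torch.randn(4, 8)

model.eval()
eval_flags = [m.training for m in model.modules()]
print(eval_flags)                   # [False, False, False]
print(torch.equal(model[1](x), x))  # True: no noise in evaluation mode

model.train()
train_flags = [m.training for m in model.modules()]
print(train_flags)                  # [True, True, True]
```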