
PyTorch Practical Hand-Written Modules Basics 5--Implement a ResNet

In this section, we’ll use what we learned in the last section (see here) to implement a ResNet (paper).

Note that we follow the original paper. Our implementation is a simpler version of the official torchvision implementation: we only implement the key structure and the random weight initialization, and we don’t handle dilation or other extra options.

Preliminaries: Calculate the feature map size

  • Basic formula

Given a convolution kernel of size K, padding P, stride S, and an input feature map of size I, the output size is O = floor((I - K + 2P) / S) + 1. The floor matters whenever S does not divide (I - K + 2P) exactly, which happens for some stride-2 layers.

  • Corollary

Based on the formula above, when S=1:

  1. K=3, P=1 keeps the output size equal to the input size.
  2. K=1, P=0 keeps the output size equal to the input size.
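A small helper makes it easy to sanity-check the formula and the corollary. The name conv_output_size is our own (it is not a PyTorch function), and it assumes square inputs and kernels with no dilation.

```python
def conv_output_size(i: int, k: int, p: int, s: int) -> int:
    """O = floor((I - K + 2P) / S) + 1 for a square input and kernel, no dilation."""
    return (i - k + 2 * p) // s + 1


# Corollary: with S=1, K=3/P=1 and K=1/P=0 both preserve the size.
for size in (7, 14, 28, 56, 112, 224):
    assert conv_output_size(size, k=3, p=1, s=1) == size
    assert conv_output_size(size, k=1, p=0, s=1) == size

# With S=2 the division may not be exact, which is why the floor matters:
# (56 - 3 + 2) / 2 = 27.5, floored to 27, plus 1 gives 28.
print(conv_output_size(56, k=3, p=1, s=2))  # 28
```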

Overall Structure

Table 1 in the original paper shows the overall structure of ResNet:

[Image: Table 1 of the ResNet paper, the architectures of ResNet-18/34/50/101/152]

Starting from conv2_x, each layer consists of several stacked blocks. ResNet-18 and ResNet-34 use one kind of block (the basic block), while ResNet-50, 101, and 152 use another (the bottleneck block).

From the table, we can make two deductions:

  1. When the feature map enters the next layer (conv3_x, conv4_x, conv5_x), the first block downsamples it by setting the stride of one of its convolutions to 2. conv2_x does not need this, because the max pool right before it already halves the size.
  2. All other convolutions keep the feature map size unchanged, so they use the settings from the Preliminaries (S=1 with K=3, P=1 or K=1, P=0).
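Chaining the formula through the network reproduces the output sizes in Table 1 (112, 56, 28, 14, 7) for a 224x224 input. This is a sketch: the padding values (3 for the 7x7 conv1, 1 for the 3x3 max pool and the strided 3x3 convs, 0 for the 1x1 shortcut) are our assumption and match what torchvision uses; the helper name is ours.

```python
def conv_output_size(i: int, k: int, p: int, s: int) -> int:
    # Same formula as in the Preliminaries; max pooling follows it too.
    return (i - k + 2 * p) // s + 1


sizes = {}
size = conv_output_size(224, k=7, p=3, s=2)   # conv1: 7x7, stride 2, padding 3
sizes["conv1"] = size
size = conv_output_size(size, k=3, p=1, s=2)  # 3x3 max pool, stride 2, padding 1
sizes["conv2_x"] = size                       # conv2_x keeps this size
for name in ("conv3_x", "conv4_x", "conv5_x"):
    # First block of each layer: 3x3 conv, stride 2, padding 1.
    main = conv_output_size(size, k=3, p=1, s=2)
    # Its shortcut: 1x1 conv, stride 2, padding 0. Both paths must agree.
    shortcut = conv_output_size(size, k=1, p=0, s=2)
    assert main == shortcut
    size = main
    sizes[name] = size

print(sizes)  # {'conv1': 112, 'conv2_x': 56, 'conv3_x': 28, 'conv4_x': 14, 'conv5_x': 7}
```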

Basic Block Implementation

The basic block’s structure looks like this:

[Image: structure of the basic residual block]

In the code below, besides channels, which sets the number of channels inside the block, there are three extra parameters: in_channels, stride, and downsample. They let the same class serve as the FIRST block of each layer, where the input shape changes.

For example, in ResNet-18/34 the first block of layer3 (conv3_x) receives a 64x56x56 input and has three jobs:

  1. Shrink the feature map to 28x28, so its stride is set to 2.
  2. Increase the number of channels from 64 to 128, so in_channels is 64 and channels is 128.
  3. Since the input is 64x56x56 while the output is 128x28x28, the shortcut needs a downsample convolution so that it matches the output shape before the addition.
import torch
import torch.nn as nn


class ResidualBasicBlock(nn.Module):
    # The basic block's output has exactly `channels` channels.
    expansion: int = 1

    def __init__(self, in_channels: int, channels: int, stride: int = 1, downsample: nn.Module = None):
        super().__init__()
        self.downsample = downsample
        # 3x3 conv; stride=2 halves H and W in the first block of conv3_x ~ conv5_x.
        # bias=False because the following BatchNorm already has a bias term.
        self.conv1 = nn.Conv2d(in_channels, channels, 3, stride, 1, bias=False)
        self.batchnorm1 = nn.BatchNorm2d(channels)
        self.relu1 = nn.ReLU()
        # 3x3 conv with stride 1 and padding 1 keeps the size unchanged.
        self.conv2 = nn.Conv2d(channels, channels, 3, 1, 1, bias=False)
        self.batchnorm2 = nn.BatchNorm2d(channels)
        self.relu2 = nn.ReLU()

    def forward(self, x):
        residual = x
        x = self.conv1(x)
        x = self.batchnorm1(x)
        x = self.relu1(x)
        x = self.conv2(x)
        x = self.batchnorm2(x)
        # Project the shortcut when the shape changes (stride or channels).
        if self.downsample is not None:
            residual = self.downsample(residual)
        x += residual
        x = self.relu2(x)
        return x
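To see why the first block of layer3 needs both stride=2 and a downsample branch, the sketch below builds the two paths from plain layers (not the class above) for an assumed batch of two 64x56x56 inputs and checks that their shapes line up.

```python
import torch
import torch.nn as nn

# Assumed example: first block of layer3 (conv3_x) in ResNet-18/34, batch of 2.
x = torch.randn(2, 64, 56, 56)

# Main path: the first 3x3 conv uses stride 2 and maps 64 -> 128 channels.
main = nn.Sequential(
    nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(128),
    nn.ReLU(),
    nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False),
    nn.BatchNorm2d(128),
)

# Shortcut projection: a 1x1 conv with the same stride and output channels.
downsample = nn.Sequential(
    nn.Conv2d(64, 128, kernel_size=1, stride=2, bias=False),
    nn.BatchNorm2d(128),
)

out_main = main(x)
out_short = downsample(x)

# Without the projection, the raw input cannot be added to the main path.
try:
    out_main + x
except RuntimeError:
    print("64x56x56 and 128x28x28 cannot be added")

out = torch.relu(out_main + out_short)
print(out.shape)  # torch.Size([2, 128, 28, 28])
```

The 1x1 projection is the cheapest convolution that can change both the channel count and the spatial size at once, which is why it is the usual choice for the shortcut.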

Bottleneck Block Implementation

The bottleneck block’s structure looks like this:

[Image: structure of the bottleneck residual block]

To reduce computation, the bottleneck block uses a 1x1 convolution to map a large number of channels (e.g., 256) down to a small one (e.g., 64), performs the 3x3 convolution on the reduced channels, and then uses another 1x1 convolution to map the 64 channels back to 256.
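A quick weight count shows how much the 1x1 reduction saves. With stride 1, every weight is used once per output position, so these counts are proportional to the multiply-adds at a given feature map size. The "two 3x3 convs on 256 channels" alternative is a hypothetical baseline for comparison; biases and BatchNorm parameters are ignored.

```python
def conv_weights(c_in: int, c_out: int, k: int) -> int:
    # Number of weights in a KxK conv without bias.
    return k * k * c_in * c_out


# Bottleneck on a 256-channel input: 1x1 (256->64), 3x3 (64->64), 1x1 (64->256)
bottleneck = conv_weights(256, 64, 1) + conv_weights(64, 64, 3) + conv_weights(64, 256, 1)
# Hypothetical baseline: two 3x3 convs applied directly on 256 channels
plain_256 = 2 * conv_weights(256, 256, 3)
# Basic block on 64 channels (as in conv2_x of ResNet-18/34)
basic_64 = 2 * conv_weights(64, 64, 3)

print(bottleneck, plain_256, basic_64)  # 69632 1179648 73728
```

So the bottleneck processes 256-channel features with under 6% of the weights of the plain alternative, and with slightly fewer weights than a basic block working on only 64 channels.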

In the code below, just like the basic block, there are three extra parameters, in_channels, stride, and downsample, which make this block usable as the FIRST block of each layer. The reasons are the same as above.

class ResidualBottleNeck(nn.Module):
    # The last 1x1 conv expands the channels by 4, e.g. 64 -> 256.
    expansion: int = 4

    def __init__(self, in_channels: int, channels: int, stride: int = 1, downsample: nn.Module = None):
        super().__init__()
        self.downsample = downsample
        # 1x1 conv: reduce in_channels (e.g. 256) to channels (e.g. 64).
        self.conv1 = nn.Conv2d(in_channels, channels, 1, 1, bias=False)
        self.batchnorm1 = nn.BatchNorm2d(channels)
        self.relu1 = nn.ReLU()
        # 3x3 conv on the reduced channels. Like torchvision, we put the stride here;
        # the original paper's implementation puts it on the first 1x1 conv instead.
        self.conv2 = nn.Conv2d(channels, channels, 3, stride, 1, bias=False)
        self.batchnorm2 = nn.BatchNorm2d(channels)
        self.relu2 = nn.ReLU()
        # 1x1 conv: expand back to channels * expansion (e.g. 256).
        self.conv3 = nn.Conv2d(channels, channels * self.expansion, 1, 1, bias=False)
        self.batchnorm3 = nn.BatchNorm2d(channels * self.expansion)
        self.relu3 = nn.ReLU()

    def forward(self, x):
        residual = x
        x = self.conv1(x)
        x = self.batchnorm1(x)
        x = self.relu1(x)

        x = self.conv2(x)
        x = self.batchnorm2(x)
        x = self.relu2(x)

        x = self.conv3(x)
        x = self.batchnorm3(x)

        # Project the shortcut when the shape changes (stride or channels).
        if self.downsample is not None:
            residual = self.downsample(residual)

        x += residual
        x = self.relu3(x)
        return x
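The bottleneck needs a shortcut projection in a case the basic block never hits: the first block of conv2_x in ResNet-50 keeps the 56x56 size but widens the channels from 64 to 256. The sketch below rebuilds just the main path with plain layers to show that case and the identity shortcut of the block after it. bottleneck_path is a hypothetical helper for illustration, not part of the code above.

```python
import torch
import torch.nn as nn


def bottleneck_path(in_ch: int, mid_ch: int, out_ch: int) -> nn.Sequential:
    # 1x1 reduce -> 3x3 -> 1x1 expand, all with stride 1 (as in conv2_x)
    return nn.Sequential(
        nn.Conv2d(in_ch, mid_ch, 1, bias=False), nn.BatchNorm2d(mid_ch), nn.ReLU(),
        nn.Conv2d(mid_ch, mid_ch, 3, 1, 1, bias=False), nn.BatchNorm2d(mid_ch), nn.ReLU(),
        nn.Conv2d(mid_ch, out_ch, 1, bias=False), nn.BatchNorm2d(out_ch),
    )


x = torch.randn(2, 64, 56, 56)  # conv2_x input in ResNet-50 (after the max pool)

# First block: H and W stay 56, but channels go 64 -> 256,
# so the shortcut still needs a 1x1 projection (stride 1).
first = bottleneck_path(64, 64, 256)
projection = nn.Sequential(nn.Conv2d(64, 256, 1, bias=False), nn.BatchNorm2d(256))
out1 = torch.relu(first(x) + projection(x))

# Later blocks: the input already has 256 channels, so the identity shortcut works.
second = bottleneck_path(256, 64, 256)
out2 = torch.relu(second(out1) + out1)

print(out1.shape, out2.shape)  # torch.Size([2, 256, 56, 56]) torch.Size([2, 256, 56, 56])
```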

ResNet Base Implementation

Now we can put everything together to build the ResNet model. The overall structure is straightforward: we define the submodules one by one and implement the forward() function.

There are only two tricky points:

  1. To let ResNetBase support both kinds of blocks, the block class is passed to the initializer. Because the two blocks set their output channels differently, ResidualBasicBlock and ResidualBottleNeck each define a class attribute called expansion (1 and 4, respectively); ResNetBase uses it to compute the correct number of input and output channels for every layer.
  2. The _make_layer function below has to decide whether the first block of a layer needs a downsample branch on its shortcut. The conditions are explained in the comments in the code.
class ResNetBase(nn.Module):
    def __init__(self, block, layer_blocks: list, input_channels=3):
        super().__init__()
        self.block = block
        # conv1: 7x7, 64 channels, stride 2 (224 -> 112)
        self.conv1 = nn.Sequential(
            nn.Conv2d(input_channels, 64, 7, 2, 3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )
        # max pool: 3x3, stride 2 (112 -> 56)
        self.maxpool = nn.MaxPool2d(3, 2, 1)
        # conv2_x ~ conv5_x
        self.in_channels = 64
        self.conv2 = self._make_layer(64, layer_blocks[0])
        self.conv3 = self._make_layer(128, layer_blocks[1], 2)
        self.conv4 = self._make_layer(256, layer_blocks[2], 2)
        self.conv5 = self._make_layer(512, layer_blocks[3], 2)

        # Global average pooling; same as AvgPool2d(7) for a 224x224 input,
        # but it also works for other input sizes.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        output_numel = 512 * self.block.expansion
        self.fc = nn.Linear(output_numel, 1000)

        # init the weights
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, channel, replicates, stride=1):
        modules = []

        downsample = None
        if stride != 1 or self.in_channels != channel * self.block.expansion:
            # Use downsample to match the dimension in two cases:
            # 1. stride != 1, meaning we should downsample H, W in this layer.
            #    Then we need to match the residual's H, W and the output's H, W of this layer.
            # 2. self.in_channels != channel * block.expansion, meaning we should increase C in this layer.
            #    Then we need to match the residual's C and the output's C of this layer.
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channels, channel * self.block.expansion, 1, stride, bias=False),
                nn.BatchNorm2d(channel * self.block.expansion)
            )

        # Only the first block changes the shape, so only it gets stride and downsample.
        modules.append(self.block(self.in_channels, channel, stride, downsample))

        self.in_channels = channel * self.block.expansion
        for _ in range(1, replicates):
            modules.append(self.block(self.in_channels, channel))
        return nn.Sequential(*modules)

    def forward(self, x):
        # x: shape Bx3x224x224
        x = self.conv1(x)    # B x 64 x 112 x 112
        x = self.maxpool(x)  # B x 64 x 56 x 56

        x = self.conv2(x)    # B x (64 * expansion) x 56 x 56
        x = self.conv3(x)    # B x (128 * expansion) x 28 x 28
        x = self.conv4(x)    # B x (256 * expansion) x 14 x 14
        x = self.conv5(x)    # B x (512 * expansion) x 7 x 7

        x = self.avgpool(x)  # B x (512 * expansion) x 1 x 1
        x = torch.flatten(x, start_dim=1)
        x = self.fc(x)       # B x 1000

        return x
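To see which layers end up with a downsample branch, the condition in _make_layer can be replayed without building any modules. downsample_plan is a hypothetical helper for illustration; it assumes the four layer settings used in ResNetBase (64/128/256/512 channels, stride 1 for conv2_x and 2 afterwards) and a 64-channel input from the stem.

```python
def downsample_plan(expansion: int, in_channels: int = 64):
    """Replay _make_layer's downsample condition for conv2_x ~ conv5_x."""
    plan = []
    for channel, stride in ((64, 1), (128, 2), (256, 2), (512, 2)):
        needs_downsample = stride != 1 or in_channels != channel * expansion
        plan.append(needs_downsample)
        in_channels = channel * expansion  # what the next layer receives
    return plan, in_channels


print(downsample_plan(expansion=1))  # ([False, True, True, True], 512)
print(downsample_plan(expansion=4))  # ([True, True, True, True], 2048)
```

For the basic-block models, conv2_x uses plain identity shortcuts; for the bottleneck models, every layer, including conv2_x, starts with a projection because of the channel expansion. The final channel count is also why fc takes 512 * expansion inputs.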

Encapsulate the Constructors

Finally, we wrap the model construction in one function per depth:

def my_resnet18(in_channels=3):
    return ResNetBase(ResidualBasicBlock, [2, 2, 2, 2], in_channels)

def my_resnet34(in_channels=3):
    return ResNetBase(ResidualBasicBlock, [3, 4, 6, 3], in_channels)

def my_resnet50(in_channels=3):
    return ResNetBase(ResidualBottleNeck, [3, 4, 6, 3], in_channels)

def my_resnet101(in_channels=3):
    return ResNetBase(ResidualBottleNeck, [3, 4, 23, 3], in_channels)

def my_resnet152(in_channels=3):
    return ResNetBase(ResidualBottleNeck, [3, 8, 36, 3], in_channels)
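The number in each name counts the weighted layers: the stem conv1, every convolution inside the blocks (2 per basic block, 3 per bottleneck block), and the final fc; the 1x1 shortcut projections are not counted. A quick arithmetic check against the block lists above:

```python
# (blocks per layer, convs per block) for each constructor above
configs = {
    "resnet18": ([2, 2, 2, 2], 2),
    "resnet34": ([3, 4, 6, 3], 2),
    "resnet50": ([3, 4, 6, 3], 3),
    "resnet101": ([3, 4, 23, 3], 3),
    "resnet152": ([3, 8, 36, 3], 3),
}

depths = {}
for name, (blocks, convs_per_block) in configs.items():
    # +2 for conv1 and the final fc layer
    depths[name] = convs_per_block * sum(blocks) + 2

print(depths)  # {'resnet18': 18, 'resnet34': 34, 'resnet50': 50, 'resnet101': 101, 'resnet152': 152}
```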

Then we can use them like any other model:

img = torch.randn(1, 3, 224, 224)
model_my = my_resnet50()
res_my = model_my(img)
print(res_my.shape)  # torch.Size([1, 1000])