Learn PyTorch Lightning with Code Examples

In this blog, we’ll learn the pytorch-lightning framework through some simple code examples. I believe the code will significantly help you understand it. The official documentation is here.

# Before everything starts:
import pytorch_lightning as pl

To train a deep learning model, we need to code the following components: data, model, training & testing procedures, and hyper-parameter configs. In pytorch-lightning, different classes handle each of them.

  • Data: we can define a subclass of pl.LightningDataModule to implement procedures that initialize the Dataset and DataLoader.
  • Model: implement the model just like you would without pytorch-lightning – as a subclass of nn.Module.
  • Training & Testing procedures: we can define a subclass of pl.LightningModule to implement the procedures.
  • Configs: pytorch-lightning has utilities for CLI arguments, but here we use another package called configargparse, which supports both config files and CLI arguments (see the sketch after this list).
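
Since we’ll rely on configargparse for configs, here is a minimal sketch of how it is typically used (the option names below are hypothetical, chosen to match this post’s examples):

import configargparse

# Values can come from a config file, the CLI, or the defaults;
# CLI arguments override config-file values.
parser = configargparse.ArgumentParser()
parser.add_argument('-c', '--config', is_config_file=True, help='path to a config file')
parser.add_argument('--data_dir', type=str, default='./')
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--hidden_dim', type=int, default=64)
args = parser.parse_args()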

1. Data

We can define a pl.LightningDataModule which implements procedures to initialize the Dataset and DataLoader.

Main components of pl.LightningDataModule:

  • __init__(): The constructor. You can save some configurations (hyper-parameters) to the class here.
  • setup(): Will be called before fit(), validate(), test(), and predict(). You can initialize the datasets as class attributes here. Note that it takes an argument stage to indicate which stage it is being called for (because we may create different datasets for different stages).
  • prepare_data(): Will be called only on the master process. In distributed training, setup() is called on every process, so anything that should happen exactly once (e.g., downloading data) belongs in prepare_data() (see the sketch after this list).
  • train_dataloader(), val_dataloader(), test_dataloader(), predict_dataloader(): Will be called during the corresponding phase to get the DataLoader. Each should return a DataLoader for that stage.
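
The example below doesn’t implement prepare_data(). As a minimal sketch (assuming we download MNIST through torchvision’s download flag), it could look like this:

def prepare_data(self):
    # Runs once, on the master process only: download the data here.
    # setup() then loads the already-downloaded files on every process.
    datasets.MNIST(self.data_dir, train=True, download=True)
    datasets.MNIST(self.data_dir, train=False, download=True)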

Here’s an example; you can find detailed descriptions of each function in its docstring:

# DataModule
import torch
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms


class MyDataModule(pl.LightningDataModule):
    def __init__(self, data_dir='./', batch_size=32):
        """
        Initializes the data module with the given config arguments.
        This constructor sets up initial configurations like the data directory,
        batch size, and transformations to be applied to the data.

        Called when an instance of MyDataModule is created.
        """
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ])

    def setup(self, stage=None):
        """
        Prepares the data for use. It is automatically called at the beginning of fit,
        validate, test, and predict, or when manually calling `trainer.setup`. This method
        is responsible for setting up internal datasets (e.g., splitting the dataset into
        training, validation, and testing sets, etc.).

        `stage` can be used to differentiate between the fit, validate, test, and predict stages.
        """
        # Assign train/val datasets for use in dataloaders
        if stage == "fit":
            mnist_full = datasets.MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(
                mnist_full, [55000, 5000], generator=torch.Generator().manual_seed(42)
            )

        # Assign test dataset for use in dataloader(s)
        if stage == "test":
            self.mnist_test = datasets.MNIST(self.data_dir, train=False, transform=self.transform)

        if stage == "predict":
            self.mnist_predict = datasets.MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        """
        Creates and returns the DataLoader for the training dataset. This method is used to
        load the training data and is automatically called during the training phase.

        Returns a DataLoader instance for the training data.
        """
        # Shuffle the training data each epoch.
        return DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        """
        Creates and returns the DataLoader for the validation dataset. This method is used to
        load the validation data and is automatically called during the validation phase.

        Returns a DataLoader instance for the validation data.
        """
        return DataLoader(self.mnist_val, batch_size=self.batch_size, shuffle=False)

    def test_dataloader(self):
        """
        Creates and returns the DataLoader for the testing dataset. This method is used to
        load the test data and is automatically called during the testing phase.

        Returns a DataLoader instance for the test data.
        """
        return DataLoader(self.mnist_test, batch_size=self.batch_size, shuffle=False)

    def predict_dataloader(self):
        """
        Creates and returns the DataLoader for the prediction dataset, matching the
        `predict` stage prepared in setup().
        """
        return DataLoader(self.mnist_predict, batch_size=self.batch_size, shuffle=False)
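
Normally the Trainer calls these hooks for you, but as a quick sanity check you can also drive the module by hand (a small sketch; the printed shape assumes the MNIST transforms above):

dm = MyDataModule(data_dir='./', batch_size=32)
dm.prepare_data()      # download once (if implemented)
dm.setup(stage="fit")  # build the train/val datasets
x, y = next(iter(dm.train_dataloader()))
print(x.shape)  # torch.Size([32, 1, 28, 28])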

2. Model

You may implement the model directly in pl.LightningModule, but for clarity I recommend implementing your model as a separate nn.Module subclass, and then creating an instance of it in the pl.LightningModule.

# You can do it like this...
class MyLightningModule(pl.LightningModule):
    def __init__(self, hidden_dim=64):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.layer1 = torch.nn.Linear(28 * 28, self.hidden_dim)
        self.layer2 = torch.nn.Linear(self.hidden_dim, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)  # flatten images to (batch, 784)
        x = torch.relu(self.layer1(x))
        x = self.layer2(x)
        return x
# But I recommend doing it like this:
class Model(nn.Module):
    ...

class MyLightningModule(pl.LightningModule):
    def __init__(self, hidden_dim=64):
        super().__init__()
        self.model = Model(hidden_dim)

    def forward(self, x):
        return self.model(x)
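
To make the elided Model concrete, here is a minimal sketch that simply mirrors the two linear layers from the previous block:

import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self, hidden_dim=64):
        super().__init__()
        self.layer1 = nn.Linear(28 * 28, hidden_dim)
        self.layer2 = nn.Linear(hidden_dim, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)  # flatten images to (batch, 784)
        return self.layer2(torch.relu(self.layer1(x)))

This separation keeps the architecture reusable outside of Lightning (e.g., for plain PyTorch inference), while the pl.LightningModule only holds the training logic.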

3. Training & Testing procedures

We can define a pl.LightningModule which implements procedures for training, validating, testing and predicting.

Main components of pl.LightningModule:

  • __init__(): The constructor. You can save some configurations (hyper-parameters) to the class here.
  • forward(): Makes it behave like an nn.Module. You can implement the forward procedure here (see Sec. 2 for the example).
  • configure_optimizers(): Defines the training optimizers. See here for some advanced usage, and the scheduler sketch after this list.
  • training_step(), validation_step(), test_step(), predict_step(): Will be called during the corresponding phase to perform operations on a batch. Each takes an argument batch_idx, which indicates the batch index within the current epoch. training_step() should return the loss (a scalar or a dict), and Lightning will automatically perform backpropagation.
  • on_train_epoch_start(), on_train_epoch_end(), on_XXX_epoch_start(), on_XXX_epoch_end(): Will be called at the corresponding start/end point. You can perform epoch-level operations here. (See the example below, where we collect all training batch losses to compute the average loss for the epoch.)
  • Use self.log(tag, value) to log the batch’s “tag” metric with value “value” to TensorBoard / the log file.
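
As an example of the “advanced usage” of configure_optimizers(), here is a sketch that also returns a learning-rate scheduler (the optimizer and schedule choices are illustrative, not from the original post):

def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
    # Halve the learning rate every 10 epochs (illustrative schedule).
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
    return {
        "optimizer": optimizer,
        "lr_scheduler": {"scheduler": scheduler, "interval": "epoch"},
    }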

Here’s an example; you can find detailed descriptions of each function in its docstring:

import torch
import torch.nn.functional as F


class MyLightningModule(pl.LightningModule):
    def __init__(self, hidden_dim=64):
        """
        Initializes the LightningModule with two linear layers. This constructor
        sets up the neural network architecture for the module.

        Called when an instance of MyLightningModule is created.
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        self.layer1 = torch.nn.Linear(28 * 28, self.hidden_dim)
        self.layer2 = torch.nn.Linear(self.hidden_dim, 10)
        self.training_results = []  # Stores the per-batch training losses of the current epoch

    def forward(self, x):
        """
        Defines the forward pass of the model. It is automatically called when
        the model is used to make predictions.

        'x' is the input data. The method reshapes the input, applies a ReLU
        activation after the first linear layer, and then passes it through the
        second linear layer.
        """
        x = x.view(x.size(0), -1)
        x = torch.relu(self.layer1(x))
        x = self.layer2(x)
        return x

    def training_step(self, batch, batch_idx):
        """
        Defines a single step in the training loop. It is automatically called for each
        batch of data during training.

        'batch' contains the input data and labels, and 'batch_idx' is the index of
        the current batch. The method computes the model's loss using cross-entropy.
        """
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.log('train_loss', loss)
        # Detach before storing so the kept losses don't hold on to the autograd graph.
        self.training_results.append(loss.detach())
        return loss

    def validation_step(self, batch, batch_idx):
        ...

    def test_step(self, batch, batch_idx):
        ...

    def configure_optimizers(self):
        """
        Configures the optimizers used for training. This method is automatically
        called to configure the optimization algorithm.

        Returns an optimizer, in this case Adam, with a set learning rate.
        See advanced usage in the official docs.
        """
        return torch.optim.Adam(self.parameters(), lr=0.001)

    # Epoch-wise procedures
    def on_train_epoch_start(self):
        """
        Called at the start of every training epoch. Can be used to perform
        actions specific to the beginning of each training epoch.
        """
        print("Training Epoch Start")

    def on_train_epoch_end(self):
        """
        Called at the end of every training epoch. Can be used to perform
        actions specific to the end of each training epoch.
        """
        print("Training Epoch End")
        print("Training Loss:", sum(self.training_results) / len(self.training_results))
        self.training_results.clear()

    def on_validation_epoch_start(self):
        ...

    def on_validation_epoch_end(self):
        ...

    def on_test_epoch_start(self):
        ...

    def on_test_epoch_end(self):
        ...

    def on_predict_epoch_start(self):
        ...

    def on_predict_epoch_end(self):
        ...
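
Finally, here is a sketch of how the pieces fit together (the Trainer arguments are common illustrative choices, not prescriptive):

dm = MyDataModule(data_dir='./', batch_size=32)
module = MyLightningModule(hidden_dim=64)

trainer = pl.Trainer(max_epochs=5)
trainer.fit(module, datamodule=dm)   # runs training_step / validation_step
trainer.test(module, datamodule=dm)  # runs test_step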