In this blog, we’ll learn the pytorch-lightning framework through some simple example code. I believe the code will significantly help you understand it. The official doc is here.
```python
# Before everything starts:
import pytorch_lightning as pl
```
To train a deep learning model, we need to code the following components: data, model, training & testing procedures, and hyper-parameter configs. In pytorch-lightning, different classes handle each of these.
Data: we can define a subclass of pl.LightningDataModule to implement procedures that initialize the Dataset and DataLoader.
Model: implement the model just as you would without pytorch-lightning – a subclass of nn.Module.
Training & Testing procedures: we can define a subclass of pl.LightningModule to implement the procedures.
Configs: pytorch-lightning has utilities for CLI arguments, but here we use another package called configargparse, which supports both config files and CLI arguments (see the sketch after this list).
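For reference, a minimal configargparse setup might look like the sketch below. The argument names here are illustrative, not taken from this post; the key feature is that values can come from a config file and be overridden on the command line.

```python
import configargparse

# Minimal configargparse sketch: values may come from a config file or the CLI.
parser = configargparse.ArgumentParser()
parser.add_argument('--config', is_config_file=True, help='path to a config file')
parser.add_argument('--data_dir', type=str, default='./')
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--lr', type=float, default=0.001)
args = parser.parse_args()  # e.g., python train.py --config cfg.yaml --lr 0.01
```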
1. Data
We can define a pl.LightningDataModule which implements procedures to initialize the Dataset and DataLoader.
Main components of pl.LightningDataModule:
__init__(): The constructor. You can save some configurations (hyper-parameters) to the class here.
setup(): Will be called before fit(), validate(), test(), and predict(). You can initialize the datasets as class attributes here. Note that it takes a stage argument indicating the current stage (because we may create different datasets for different stages).
prepare_data(): Will be called only on the main process. In distributed training, setup() is called on every process, so if something should happen only once (e.g., downloading data), implement it in prepare_data() (see the sketch after this list).
train_dataloader(), val_dataloader(), test_dataloader(), predict_dataloader(): Will be called during the corresponding phase; each should return the DataLoader for that stage.
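The example below skips prepare_data(), so here is a minimal sketch of how the one-time vs. per-process split might look. The class name is hypothetical, and download=True is torchvision’s flag:

```python
# Sketch: prepare_data() runs once (one process); setup() runs on every process.
class MNISTDataModule(pl.LightningDataModule):  # hypothetical name
    def __init__(self, data_dir='./'):
        super().__init__()
        self.data_dir = data_dir

    def prepare_data(self):
        # One-time work only (e.g., downloading); avoid assigning state here
        datasets.MNIST(self.data_dir, train=True, download=True)
        datasets.MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        # Per-process work: assign datasets as attributes
        self.mnist_train = datasets.MNIST(self.data_dir, train=True)
        self.mnist_test = datasets.MNIST(self.data_dir, train=False)
```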
Here’s an example; detailed descriptions of each function are in its docstring:
```python
# Imports used by the examples in this post
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

# DataModule
class MyDataModule(pl.LightningDataModule):
    def __init__(self, data_dir='./', batch_size=32):
        """
        Initializes the data module with the given config arguments.
        This constructor sets up the initial configurations like the data
        directory, batch size, and transformations to be applied to the data.
        Called when an instance of MyDataModule is created.
        """
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ])

    def setup(self, stage=None):
        """
        Prepares the data for use. It is automatically called at the beginning
        of fit, validate, test, and predict, or when manually calling
        `trainer.setup`. This method is responsible for setting up internal
        datasets (e.g., splitting the dataset into training, validation, and
        testing sets). `stage` can be used to differentiate between the fit,
        validate, and test stages.
        """
        # Assign train/val datasets for use in dataloaders
        if stage == "fit":
            mnist_full = datasets.MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(
                mnist_full, [55000, 5000],
                generator=torch.Generator().manual_seed(42)
            )
        # Assign test dataset for use in dataloader(s)
        if stage == "test":
            self.mnist_test = datasets.MNIST(self.data_dir, train=False, transform=self.transform)
        if stage == "predict":
            self.mnist_predict = datasets.MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        """
        Creates and returns the DataLoader for the training dataset.
        Automatically called during the training phase.
        """
        return DataLoader(self.mnist_train, batch_size=self.batch_size)

    def val_dataloader(self):
        """
        Creates and returns the DataLoader for the validation dataset.
        Automatically called during the validation phase.
        """
        return DataLoader(self.mnist_val, batch_size=self.batch_size, shuffle=False)

    def test_dataloader(self):
        """
        Creates and returns the DataLoader for the testing dataset.
        Automatically called during the testing phase.
        """
        return DataLoader(self.mnist_test, batch_size=self.batch_size, shuffle=False)
```
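To sanity-check the hooks, you can also drive them manually. A small sketch (it assumes MNIST is already downloaded to data_dir, since the example above does not pass download=True):

```python
# In real training the Trainer calls setup() and the dataloader hooks for you.
dm = MyDataModule(data_dir='./', batch_size=32)
dm.setup(stage="fit")
x, y = next(iter(dm.train_dataloader()))
print(x.shape, y.shape)  # torch.Size([32, 1, 28, 28]) torch.Size([32])
```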
2. Model
You may implement the model directly in pl.LightningModule, but for clarity, I recommend implementing your model as a separate nn.Module subclass and then creating the instances in the pl.LightningModule.
```python
# You can do it like this...
class MyLightningModule(pl.LightningModule):
    def __init__(self, hidden_dim=64):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.layer1 = torch.nn.Linear(28 * 28, self.hidden_dim)
        self.layer2 = torch.nn.Linear(self.hidden_dim, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = torch.relu(self.layer1(x))
        x = self.layer2(x)
        return x
```
```python
# But I recommend doing it like this:
class Model(nn.Module):
    ...  # define layers and forward() as usual, without pytorch-lightning

class MyLightningModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = Model()  # create the plain nn.Module instance here

    def forward(self, x):
        return self.model(x)
```
We can define a pl.LightningModule which implements procedures for training, validating, testing and predicting.
Main components of pl.LightningModule:
__init__(): The constructor. You can save some configurations (hyper-parameters) to the class here.
forward(): Makes it behave like an nn.Module. You can implement the forward procedure here (see Sec. 2 for the example).
configure_optimizers(): Defines the optimizers used for training. See here for some advanced usage, and the sketch after this list.
training_step(), validation_step(), test_step(), predict_step(): Will be called during the corresponding phase to perform the operations for one batch. Each takes a batch_idx argument, the index of the current batch within the epoch. training_step() should return the loss (a scalar tensor or a dict containing one), and Lightning will automatically perform backpropagation.
on_train_epoch_start(), on_train_epoch_end(), on_XXX_epoch_start(), on_XXX_epoch_end(): Will be called at the corresponding start/end points. You can perform epoch-level operations here (see the example below, where we collect all training batch losses to compute the average loss over an epoch).
Use self.log(name, value) to log a metric called “name” with value “value” to TensorBoard / the log file.
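As a taste of the advanced usage linked above, configure_optimizers() can also return an optimizer together with a learning-rate scheduler. A sketch (the StepLR settings here are arbitrary):

```python
# configure_optimizers() may return an optimizer plus an LR scheduler.
def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
    return {'optimizer': optimizer, 'lr_scheduler': scheduler}
```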
Here’s an example; detailed descriptions of each function are in its docstring:
```python
class MyLightningModule(pl.LightningModule):
    def __init__(self, hidden_dim=64):
        """
        Initializes the LightningModule with two linear layers. This
        constructor sets up the neural network architecture for the module.
        Called when an instance of MyLightningModule is created.
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        self.layer1 = torch.nn.Linear(28 * 28, self.hidden_dim)
        self.layer2 = torch.nn.Linear(self.hidden_dim, 10)
        self.training_results = []  # To store training losses in each epoch

    def forward(self, x):
        """
        Defines the forward pass of the model. It is automatically called when
        the model is used to make predictions. 'x' is the input data. The
        method reshapes the input, applies a ReLU activation after the first
        linear layer, and then passes it through the second linear layer.
        """
        x = x.view(x.size(0), -1)
        x = torch.relu(self.layer1(x))
        x = self.layer2(x)
        return x

    def training_step(self, batch, batch_idx):
        """
        Defines a single step in the training loop. It is automatically called
        for each batch of data during training. 'batch' contains the input
        data and labels, and 'batch_idx' is the index of the current batch.
        The method computes the model's loss using cross-entropy.
        """
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.log('train_loss', loss)
        # Detach so we don't keep every batch's computation graph alive
        self.training_results.append(loss.detach())
        return loss

    def validation_step(self, batch, batch_idx):
        ...

    def test_step(self, batch, batch_idx):
        ...

    def configure_optimizers(self):
        """
        Configures the optimizers used for training. This method is
        automatically called to configure the optimization algorithm.
        Returns an optimizer, in this case Adam, with a set learning rate.
        See advanced usage in the official docs.
        """
        return torch.optim.Adam(self.parameters(), lr=0.001)

    # Epoch-wise procedures
    def on_train_epoch_start(self):
        """
        Called at the start of every training epoch. Can be used to perform
        actions specific to the beginning of each training epoch.
        """
        print("Training Epoch Start")

    def on_train_epoch_end(self):
        """
        Called at the end of every training epoch. Can be used to perform
        actions specific to the end of each training epoch.
        """
        print("Training Epoch End")
        print("Training Loss: ", sum(self.training_results) / len(self.training_results))
        self.training_results.clear()
```
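Finally, to tie the two classes together, a minimal run might look like the sketch below (Trainer options kept to a bare minimum; since validation_step() and test_step() are left unimplemented above, only fit() is shown):

```python
# Minimal end-to-end sketch: the Trainer drives the hooks defined above.
dm = MyDataModule(data_dir='./', batch_size=32)
model = MyLightningModule(hidden_dim=64)
trainer = pl.Trainer(max_epochs=3)
trainer.fit(model, datamodule=dm)
```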