
PyTorch Lightning: Making your Training Phase Cleaner and Easier

An introduction to PyTorch Lightning with step-by-step examples to get you started now

Photo by NOAA on Unsplash

PyTorch has become one of the preferred frameworks in industry and academia thanks to its great flexibility for prototyping neural network architectures and the fine-grained control it gives over each component of the training phase of a deep learning model. However, prototyping the training phase can become a complex task for various reasons, which opens the door to several mistakes. This is where PyTorch Lightning comes in: the PyTorch extension that facilitates and organizes the training phase of a PyTorch model.

That is why in this blog we are going to learn what PyTorch Lightning is and some of its components, and we will see a couple of examples so you can start using PyTorch Lightning right away. The blog is divided into the following sections:

  • What is PyTorch Lightning?
  • Example 1: A basic training
  • Example 2: A more advanced training

Let’s get started!

What is PyTorch Lightning?

PyTorch Lightning is a PyTorch extension for prototyping the training, evaluation and testing phases of PyTorch models. Also, PyTorch Lightning provides a simple, friendly and intuitive structure for organizing each component of the training phase of a PyTorch model. On the other hand, it offers a great variety of functionalities and flags for detailed customization of the training of our model. In short, PyTorch Lightning organizes, simplifies and compacts the components involved in the training phase of a deep learning model: training, evaluation, testing, metrics tracking, experiment organization and logs.

Figure 1. From PyTorch to PyTorch Lightning | Image by author

The PyTorch Lightning project was started in 2016 by William Falcon while he was completing his PhD at NYU [1]. PyTorch Lightning was subsequently launched in March 2019 and made public in July of the same year. Also in 2019, PyTorch Lightning was adopted by the NeurIPS Reproducibility Challenge as the standard for submitting code to that conference [2]. Lightning is an open-source project that currently has more than 180 contributors [3].

Great, we already know what PyTorch Lightning is, who started the project and why it arose, so now let’s move on to the technical details. In the following sections we are going to see a couple of examples using PyTorch Lightning. The first example will explain a basic training with which you can start using Lightning. In the second example we will see how to make use of some of the many amenities that Lightning provides; specifically, we will see some important flags, the use of callbacks, and how to automatically find the suggested optimal "learning_rate" as well as the suggested optimal "batch_size". So let’s get started!

If you want to take a look at the complete implementation, I leave you the link to my GitHub repository: https://github.com/FernandoLpz/PyTorch-Lightning

Example 1: A basic training

The objective of this example is to show the basic structure of a training phase in PyTorch Lightning as well as the main components required to use it. In this example, we are going to create a dummy classification model using the breast cancer dataset (a toy dataset for classification).

So let’s get started!
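
A sketch of the first snippet could look like the following. The 30 input features match the breast cancer dataset, while the hidden layer size and the use of torchmetrics (the package where Lightning’s built-in metrics now live) are assumptions and may differ from the original code:

import torch
import torch.nn as nn
import pytorch_lightning as pl
import torchmetrics


class NeuralNet(pl.LightningModule):
  def __init__(self, learning_rate=None):
    super().__init__()
    self.learning_rate = learning_rate

    # Two linear layers; the breast cancer dataset has 30 features
    self.linear1 = nn.Linear(30, 64)
    self.linear2 = nn.Linear(64, 1)

    # Accuracy objects from the metrics API bundled with Lightning
    # (older releases accept Accuracy() with no arguments)
    self.train_accuracy = torchmetrics.Accuracy(task="binary")
    self.val_accuracy = torchmetrics.Accuracy(task="binary")

  def forward(self, x):
    x = torch.relu(self.linear1(x))
    return torch.sigmoid(self.linear2(x))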

In the above code snippet we are building a simple neural network with two linear layers. It is important to note that we are not extending from "nn.Module" as we commonly would in a pure PyTorch model; instead, we are extending from "pl.LightningModule". Likewise, in the class constructor we initialize a couple of objects from the metrics API that PyTorch Lightning provides, in this case for calculating the accuracy. It is worth mentioning that the metrics calculation can be done in several ways; we do it this way here because it seems to me the most practical.

Perfect, so far we have only defined our neural network (all normal for now). From this moment we begin to introduce the elements that are required to compact our training phase with PyTorch Lightning. Therefore, first I would like us to observe the following structure:

class NeuralNet(pl.LightningModule):
  def __init__(self, learning_rate=None):
    # Constructor
    pass

  def forward(self, x):
    # Forward pass
    pass

  def configure_optimizers(self):
    # The optimizer is initialized here
    pass

  def training_step(self, batch, batch_idx):
    # Called for every training batch
    pass

  def training_epoch_end(self, outputs):
    # Called once every training epoch is done
    pass

  def validation_step(self, batch, batch_idx):
    # Called for every validation batch
    pass

  def validation_epoch_end(self, outputs):
    # Called once every validation epoch is done
    pass

As you can see, the previous structure is the class shown in code snippet 1, to which we have added new functions. These functions are the hooks that PyTorch Lightning will call to perform each procedure in a super organized way.

Important: These are not the only functions that can be implemented. PyTorch Lightning provides a large number of hooks that we can use as best suits us; in this case, for practical reasons, we implement the ones mentioned above, since they let us approach the dynamics of training with PyTorch Lightning in a didactic way.

As we can see, the structure of the class is self-explanatory. In addition to defining the neural network and the forward function, with PyTorch Lightning we can define what we want to be done for each batch as well as at the end of each epoch, for both the training and the validation data. We also observe that the optimizer is isolated in its own function, which gives us a better organization of each element of the training phase.

So the real, full class will look like this:
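
The following is a sketch that fills in the skeleton above. The loss function and layer sizes are assumptions, and the *_epoch_end hooks follow the Lightning 1.x API used in this post (Lightning 2.x replaces them with on_*_epoch_end):

class NeuralNet(pl.LightningModule):
  def __init__(self, learning_rate=None):
    super().__init__()
    self.learning_rate = learning_rate
    self.linear1 = nn.Linear(30, 64)
    self.linear2 = nn.Linear(64, 1)
    self.train_accuracy = torchmetrics.Accuracy(task="binary")
    self.val_accuracy = torchmetrics.Accuracy(task="binary")

  def forward(self, x):
    x = torch.relu(self.linear1(x))
    return torch.sigmoid(self.linear2(x))

  def configure_optimizers(self):
    # The optimizer lives in its own hook, isolated from the training logic
    return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

  def training_step(self, batch, batch_idx):
    x, y = batch
    y_pred = self(x).squeeze(1)
    loss = nn.functional.binary_cross_entropy(y_pred, y.float())
    self.train_accuracy.update(y_pred, y.int())
    self.log("train_loss", loss)
    return loss

  def training_epoch_end(self, outputs):
    # Log the accuracy accumulated over the whole training epoch
    self.log("train_accuracy", self.train_accuracy.compute())
    self.train_accuracy.reset()

  def validation_step(self, batch, batch_idx):
    x, y = batch
    y_pred = self(x).squeeze(1)
    loss = nn.functional.binary_cross_entropy(y_pred, y.float())
    self.val_accuracy.update(y_pred, y.int())
    self.log("val_loss", loss)

  def validation_epoch_end(self, outputs):
    self.log("val_accuracy", self.val_accuracy.compute())
    self.val_accuracy.reset()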

As you can see, the content of each function is super simple; this is the practicality, flexibility and cleanliness that we get with PyTorch Lightning. Some important points to note:

  • For the metrics calculation, you can just make use of the built-in metric functions that Lightning provides.
  • The logs are saved "automatically"; that is, you do not have to use a specific logging library. By simply calling the "self.log" method that comes with the "pl.LightningModule" extension, your logs will be saved automatically.
  • Obviously you can customize your own metrics as well as the logs; for this I recommend that you take a look at the documentation.

Perfect, so far we have the class that defines our model ready, as well as the elements for training. We only need the data to be able to start the training, so let’s go for it!
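
A sketch of that snippet could be the following; the scikit-learn loader matches the breast cancer dataset, while the concrete values for the learning rate, batch size and number of epochs are assumptions:

from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset

# Download and split the data
x, y = load_breast_cancer(return_X_y=True)
x_train, x_val, y_train, y_val = train_test_split(x, y, test_size=0.2)

# Transform the NumPy arrays into PyTorch tensors
x_train = torch.tensor(x_train, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.float32)
x_val = torch.tensor(x_val, dtype=torch.float32)
y_val = torch.tensor(y_val, dtype=torch.float32)

# Wrap the tensors with the Dataset and DataLoader utilities
train_loader = DataLoader(TensorDataset(x_train, y_train), batch_size=32)
val_loader = DataLoader(TensorDataset(x_val, y_val), batch_size=32)

# Initialize the model, passing only the learning rate
model = NeuralNet(learning_rate=0.01)

# The Trainer is the key piece; here we only set the maximum number of epochs
trainer = pl.Trainer(max_epochs=50)

# Execute the training phase, passing the model and the dataloaders
trainer.fit(model, train_loader, val_loader)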

As we can see, first we download and split the data, and then we transform the NumPy arrays into PyTorch tensors. Next we make use of the PyTorch "Dataset" and "DataLoader" utilities. So far everything is normal; we would also have done these steps to train a pure PyTorch model. The big difference comes next: we initialize our model passing only the "learning_rate" as a parameter, and then we create the "Trainer" of PyTorch Lightning, which is the key piece of the training. We can pass a large number of arguments to the "Trainer" in order to customize the training phase in detail (for more options, I leave you the link to the documentation); in this case we only define the maximum number of epochs. Finally, we execute the training phase by calling "fit" with the model and the dataloaders as arguments, and that’s all!

If you want to take a look at the complete implementation, I leave you the link to my GitHub repository: https://github.com/FernandoLpz/PyTorch-Lightning

Perfect, with what is explained in this section you can start implementing your training phase with PyTorch Lightning. In the next section we are going to see an example where we do a more detailed customization of the same model, let’s go for it!

Example 2: A more advanced training

In this example we are going to make some changes with respect to the previous one, specifically we are going to:

  • Add the dataset handlers (for training, validation and testing) inside the model class. This allows us to have a better organization of each dataset.
  • Implement a couple of functionalities that PyTorch Lightning provides: automatically finding the optimal "learning_rate" as well as the optimal "batch_size".

So let’s do it!
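
A sketch of how the dataset handlers could be added is shown below. It reuses the imports and the class from Example 1; the split proportions, the to_dataset helper, the test_step and the subclassing are illustrative assumptions, since the original post defines everything in a single class:

def to_dataset(x, y):
  # Helper (not in the original post): wraps NumPy arrays as a TensorDataset
  return TensorDataset(torch.tensor(x, dtype=torch.float32),
                       torch.tensor(y, dtype=torch.float32))


class NeuralNetWithData(NeuralNet):
  def __init__(self, learning_rate=None, batch_size=2):
    super().__init__(learning_rate)
    # The batch size finder reads and overwrites this attribute
    self.batch_size = batch_size

  def prepare_data(self):
    # Do whatever is needed to get the dataset ready; here we just load it
    self.x, self.y = load_breast_cancer(return_X_y=True)

  def setup(self, stage=None):
    # Called with stage="fit" before training/validation and with
    # stage="test" before testing; we build all the splits here
    x_tmp, x_test, y_tmp, y_test = train_test_split(self.x, self.y, test_size=0.2)
    x_train, x_val, y_train, y_val = train_test_split(x_tmp, y_tmp, test_size=0.2)
    self.train_data = to_dataset(x_train, y_train)
    self.val_data = to_dataset(x_val, y_val)
    self.test_data = to_dataset(x_test, y_test)

  # Each handler simply returns its dataset wrapped in a DataLoader
  def train_dataloader(self):
    return DataLoader(self.train_data, batch_size=self.batch_size)

  def val_dataloader(self):
    return DataLoader(self.val_data, batch_size=self.batch_size)

  def test_dataloader(self):
    return DataLoader(self.test_data, batch_size=self.batch_size)

  def test_step(self, batch, batch_idx):
    # Needed later for trainer.test(); mirrors the validation step
    x, y = batch
    y_pred = self(x).squeeze(1)
    self.log("test_loss", nn.functional.binary_cross_entropy(y_pred, y.float()))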

The first important thing that we have to notice is the "setup" function, which is called when the model enters the "fit" or "test" stage so that the respective datasets can be prepared for each stage. On the other hand, it is important to note the use of the "prepare_data" function, where we basically do whatever we need to get our dataset ready. Finally, the handlers for each dataset basically return the required dataset wrapped in a "DataLoader" object.

Well, so far we have added the functions that manage the datasets inside the model class. Now let’s see how to use the functionalities that Lightning provides to find the optimal "learning_rate" as well as the optimal "batch_size".
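
A sketch of that snippet, with illustrative values for each flag, could be the following (trainer.tuner.lr_find is the Lightning 1.x call; very old releases expose it as trainer.lr_find, and 2.x moves it to pytorch_lightning.tuner.Tuner):

# Initialize the model; batch_size=2 is the starting point for the batch size search
model = NeuralNetWithData(learning_rate=0.001, batch_size=2)

# Build the Trainer with the additional flags described below
trainer = pl.Trainer(
  max_epochs=50,                  # maximum number of epochs
  check_val_every_n_epoch=1,      # how often the model will be validated
  precision=16,                   # 16-bit floating point (may require a GPU)
  auto_scale_batch_size="power",  # algorithm used to search for the optimal batch size
)

# Learning rate finder
lr_finder = trainer.tuner.lr_find(
  model,
  min_lr=0.0005,   # minimum value the learning_rate could take
  max_lr=0.005,    # maximum value the learning_rate could take
  mode="linear",   # search linearly ("exponential" is also available)
)

# Visualize the sweep; the marked point is the suggested learning rate (Figure 2)
fig = lr_finder.plot(suggest=True)
fig.show()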

In the snippet above we first initialize the model and then begin the initialization of the Trainer; as we can see, we are making use of some additional parameters, which are:

  • max_epochs: Determines the maximum number of epochs
  • check_val_every_n_epoch: Determines how often the model will be validated
  • precision: Determines the floating point precision used for each tensor; it is recommended to use 16-bit floating point [4]
  • auto_scale_batch_size: Specifies the algorithm that will be used to search for the optimal batch size

On the other hand, we create the learning rate finder that we will use to search for the optimal learning rate, passing the following arguments:

  • model: Our neural net model
  • min_lr: The minimum value that the "learning_rate" could take
  • max_lr: The maximum value that the "learning_rate" could take
  • mode: Whether the search proceeds linearly or exponentially [5]

Great! We have everything ready for the first step: finding the optimal learning rate. If we execute the snippet above and visualize the result, we obtain something like what is shown in Figure 2.

Figure 2. Optimal learning rate | Image by author

Perfect, we have already obtained our optimal learning rate. Now we will proceed to assign it to our model, as follows:
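
A minimal sketch of that assignment, assuming the lr_finder object from the previous snippet:

# Take the value suggested by the finder and assign it to the model
new_lr = lr_finder.suggestion()
model.learning_rate = new_lr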

Fantastic. What follows is to find the optimal batch size. This PyTorch Lightning feature refers to the largest batch that can fit in the memory of the device in question. Although it is not necessary for this example, it is important to know that this feature is available and that we can make use of it when working with larger datasets. Then, given that in the Trainer we already defined the search algorithm for the optimal batch size, it is enough to execute the "tune" method of the "Trainer", as follows:

trainer.tune(model)

which will output something like:

Batch size 2 succeeded, trying batch size 4
Batch size 4 succeeded, trying batch size 8
Batch size 8 succeeded, trying batch size 16
Batch size 16 succeeded, trying batch size 32
Batch size 32 succeeded, trying batch size 64
Batch size 64 succeeded, trying batch size 128
Batch size 128 succeeded, trying batch size 256
Batch size 256 succeeded, trying batch size 512
Batch size 278 succeeded, trying batch size 556
Finished batch size finder, will continue with full run using batch size 278

This would mean that the optimal batch size (or the one that best fits in the device memory) is 278 (however, since we are working with a toy dataset, this suggestion may be limited by the number of samples in our dataset).

Great! With this we have covered our objectives: finding the optimal learning rate as well as the optimal batch size. The only thing missing is to train and test, as follows:

# Once everything is done, let's train the model
trainer.fit(model)

# Testing the model
trainer.test()

If you want to take a look at the complete implementation, I leave you the link to my GitHub repository: https://github.com/FernandoLpz/PyTorch-Lightning

Congratulations, we have reached the end of this blog!

Conclusion

In this blog we have introduced PyTorch Lightning as well as its main components. We have also explained the objective of PyTorch Lightning as an extension of the well-known PyTorch framework.

On the other hand, we have walked through a couple of examples that show the flexibility, order and clarity that PyTorch Lightning provides when training a PyTorch model.

References

[1] https://www.pytorchlightning.ai/team

[2] https://reproducibility-challenge.github.io/neurips2019/resources/

[3] https://github.com/PyTorchLightning/pytorch-lightning

[4] https://pytorch-lightning.readthedocs.io/en/latest/amp.html

[5] https://arxiv.org/pdf/1506.01186.pdf

