
This post shows how to explore a dataset and prepare a baseline method with PyTorch Lightning and a TorchVision model.
In a previous post, we described best practices for configuring a Kaggle development environment and ingesting a dataset.
How to Prepare your Development Environment to Rank on Kaggle
Now that we have an environment configured, let’s walk through our Kaggle kernel and explain each cell step by step. The main stages will be: (1) loading data, (2) checking label distribution, and (3) viewing some sample images from each class. Next, we (4) wrap the data in PyTorch classes, namely Dataset and DataLoader. Finally, we put it all together and (5) train a simple baseline CNN model.
I have chosen PyTorch Lightning for this task as it helps to decouple data science code from deep learning engineering, letting us focus on:
a) Loading and processing data with [LightningDataModules](https://pytorch-lightning.readthedocs.io/en/stable/extensions/datamodules.html).
b) Selecting which model/architecture to use in our [LightningModules](https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html).
c) Evaluating performance with [TorchMetrics](https://torchmetrics.readthedocs.io/en/stable/).
GitHub – Borda/kaggle_plant-pathology: Identify the type of disease present on Appletree leafs
Data Exploration
Although the Plant Pathology 2021 – FGVC8 challenge organisers provide task descriptions, I always recommend doing your own data exploration to verify that you know how to manipulate the given data.
To start, we check the images folder and labels, load them, and observe what annotation format is used. In this case, it is a two-column table with image names in the first column and multiple string labels in the second column.
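A minimal sketch of this loading step, assuming the annotations are a CSV file with columns named `image` and `labels` and space-separated labels (the column names, the in-memory file, and the sample rows are illustrative assumptions):

```python
import io

import pandas as pd

# Stand-in for the annotation file; in a notebook this would be a path on disk.
# Column names and rows are illustrative assumptions.
SAMPLE_CSV = io.StringIO(
    "image,labels\n"
    "img_001.jpg,healthy\n"
    "img_002.jpg,scab frog_eye_leaf_spot\n"
    "img_003.jpg,rust\n"
)

df = pd.read_csv(SAMPLE_CSV)
# Each row holds one image name and a space-separated string of labels.
df["label_list"] = df["labels"].str.split(" ")
print(df.head())
```

Splitting the label string early makes it easier to spot images that carry more than one label.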

For classification tasks, it is helpful to check the label distribution to get some sense of what to expect. Ideally, there should be an almost equal number of samples in each class; if the dataset is heavily unbalanced, you should apply some balancing technique (these are out of the scope of this post).
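One way to check the distribution can be sketched with the standard library alone. The annotation strings below are made-up examples, and each individual label is counted, so a multi-label image contributes to several classes:

```python
from collections import Counter

# Hypothetical annotation strings, one per image (space-separated labels).
annotations = [
    "healthy",
    "scab",
    "rust scab",
    "healthy",
    "complex",
    "scab frog_eye_leaf_spot",
]

# Count every individual label across all images.
label_counts = Counter(label for row in annotations for label in row.split())

total = sum(label_counts.values())
for label, count in label_counts.most_common():
    print(f"{label:<20} {count:>3}  ({count / total:.0%})")
```

A strongly skewed printout here is the signal that some balancing technique may be needed later.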

Wrapping data to PyTorch classes
[PyTorch](https://pytorch.org/) Lightning is just organised PyTorch under the hood, so we use standard PyTorch objects and methods for data handling, such as the `Dataset` and `DataLoader` objects. The `Dataset` defines how to load the data from storage (e.g. hard drive) and couples input data (images) with annotations (labels). The primary method which needs to be implemented is `__getitem__`, which in its simplest form may look as follows:
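A minimal sketch, assuming in-memory tensors in place of image files so that it runs without any data on disk; the class name `LeafDataset` and its arguments are hypothetical:

```python
import torch
from torch.utils.data import Dataset


class LeafDataset(Dataset):
    """Couples each image with its label vector (names are hypothetical)."""

    def __init__(self, images, targets, transform=None):
        # `images` would normally be file paths opened lazily in __getitem__;
        # in-memory tensors are used so the sketch runs without any files.
        self.images = images
        self.targets = targets
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img = self.images[idx]
        if self.transform is not None:
            img = self.transform(img)
        # Float targets are what a binary cross-entropy loss expects.
        target = torch.as_tensor(self.targets[idx], dtype=torch.float32)
        return img, target


dataset = LeafDataset(
    images=[torch.rand(3, 32, 32) for _ in range(4)],
    targets=[
        [0, 0, 1, 0, 0, 0],
        [0, 0, 0, 0, 1, 1],
        [0, 0, 1, 0, 0, 0],
        [1, 0, 0, 0, 0, 0],
    ],
)
img, target = dataset[1]
print(img.shape, target)
```

With real data, `__getitem__` would typically open the image file at `idx` before applying the transform.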

As with almost all classification machine learning tasks, the model does not work with text labels directly, and we need to transform our text into a representation that we can model.
We represent the labels using a binary encoding – `1` if the particular label is present in the image and `0` otherwise. Each position in the resulting binary vector corresponds to the label index in the fixed set of all possible labels across all sub-dataset divisions – training/validation/testing. For example, an image of a healthy sample is represented by `[0,0,1,0,0,0]` and an image labeled with "rust scab" is `[0,0,0,0,1,1]`. Each sample returned by our implemented `Dataset` is such an image-encoding pair.
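The encoding can be sketched in a few lines. The label vocabulary below is an assumption (six names, sorted alphabetically), chosen because with this order it reproduces the two example vectors from the text; `encode_labels` is a hypothetical helper:

```python
# Assumed label vocabulary, sorted alphabetically; with this order the two
# example vectors from the text are reproduced.
LABELS = ["complex", "frog_eye_leaf_spot", "healthy", "powdery_mildew", "rust", "scab"]


def encode_labels(annotation, labels=LABELS):
    """Turn a space-separated label string into a binary vector."""
    present = set(annotation.split())
    return [1 if name in present else 0 for name in labels]


print(encode_labels("healthy"))
print(encode_labels("rust scab"))
```

Keeping the vocabulary fixed and shared across training, validation, and testing is what makes the vector positions comparable between splits.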

Wrapping data to Lightning class
The last data-related step is defining a `LightningDataModule`. This PyTorch Lightning component encapsulates all the data handling steps, including:
- Creating a `Dataset`
- Splitting the `Dataset` into training/validation(/testing) sub-datasets
- Wrapping sub-datasets into PyTorch native `DataLoader` objects

These `DataLoader` objects are then accessed directly from the `LightningDataModule` during model training. Below is an example to give you a brief idea about the `LightningDataModule` structure:

Basic data augmentation
Data augmentation is an essential machine learning technique that artificially extends the variability of the training dataset with synthetic samples, hopefully preventing overfitting. The idea is to generate new samples with appearance variations that may be missing in the training dataset but are still likely to appear later in validation or production. In the image domain, typical augmentations are geometrical transformations and color/intensity changes.
We use a random combination of vertical/horizontal flipping, rotation, minor crop, and small perspective transformation to simulate different observation positions (see the code snapshot below with a sequence of `torchvision` transformations).

Baseline Model Fitting
Before training a complex model, we should sanity-check our dataset with a baseline model to verify that our training pipeline does not leak. This best practice avoids downstream confusion: if we started with a very complex model that did not converge, we would not know whether the problem lies in the data processing, the training process, or the model itself.
Adjusting TorchVision model
We recommend a simple ResNet50 (generally considered a good trade-off between model complexity and learning capacity) from the TorchVision package with pre-trained weights to speed up convergence, wrapped into a PyTorch Module.
Wrapping the pre-trained model is required as the base ResNet50 has 1000 outputs (as it is trained on the ImageNet dataset), and our classification needs only six outputs/classes – we replace the last linear layer with a new one.

Preparing Lightning model
The `LightningModule` is a core object of PyTorch Lightning, which aggregates all model-related processes – model (architecture & weights), how to perform training/validation step, logging metrics, etc.
The first step is to define a new `LightningModule` – wrapping our model, metrics, loss function, and other parameters necessary for training. To evaluate the image classification, we use the classical Accuracy, Precision, and F1 score metrics. As a loss function, we use Binary Cross Entropy With Logits since the task is multi-label.

The next step in writing a `LightningModule` is to define the training and validation steps, where we specify how data is fed to the model to obtain predictions and compute the loss. The loss is passed to the optimizer for automatic weight updates. The two steps are very similar as we want to track similar metrics for training and validation; the difference is that `training_step` must return the loss.

The last step is to define which optimizer we want to use and an optional scheduler that manages learning rate changes over training depending on the training progress (driven by training step index). We use the AdamW optimizer (Adam with decoupled weight decay) with default parameters.
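A sketch of this configuration, written as a plain function so it runs without Lightning. The cosine scheduler and its `T_max` value are assumptions, since the post does not name the scheduler it uses:

```python
import torch
from torch import nn

model = nn.Linear(8, 6)  # stand-in for the real network


def configure_optimizers(module):
    """In a LightningModule this would be a method and `module` would be `self`."""
    optimizer = torch.optim.AdamW(module.parameters())  # default lr and weight decay
    # Scheduler type and T_max are illustrative assumptions.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
    # "interval": "step" asks Lightning to advance the scheduler after every
    # training step rather than once per epoch.
    return {
        "optimizer": optimizer,
        "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
    }


config = configure_optimizers(model)
optimizer = config["optimizer"]
scheduler = config["lr_scheduler"]["scheduler"]

# Simulate three training steps to see the learning rate being annealed.
lrs = []
for _ in range(3):
    optimizer.step()
    scheduler.step()
    lrs.append(scheduler.get_last_lr()[0])
print(lrs)
```

Returning a dictionary with an `lr_scheduler` entry is how Lightning is told to step the scheduler on its own during training.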

Training the model with Lightning
PyTorch Lightning has a rich ecosystem of callbacks. They provide out-of-the-box support for best practices from checkpointing (saving model weights during the training) and early-stopping (stop training if the model does not improve anymore) to advanced techniques such as Pruning or Stochastic Weight Averaging.
We configure checkpointing to save the model with the best validation F1-score during training:

    import pytorch_lightning as pl

    # Keep the checkpoint with the highest logged `valid_f1` value.
    ckpt = pl.callbacks.ModelCheckpoint(
        monitor='valid_f1',
        filename='checkpoint/{epoch:02d}-{valid_f1:.4f}',
        mode='max',
    )

Now we have all the code we need to start training.
In PyTorch Lightning, the training is abstracted to the Lightning `Trainer`. We set training parameters such as the number of epochs, how many GPUs we want to use (for Kaggle use all of them), training precision (we lower the float precision from 32 bits to 16 bits, which usually does not harm training performance but allows us to roughly double the amount of data that can fit in each batch), and more.

    trainer = pl.Trainer(
        gpus=1,  # Lightning 2.0+ uses accelerator='gpu', devices=1 instead
        callbacks=[ckpt],
        max_epochs=35,
        precision=16,  # mixed precision
        accumulate_grad_batches=24,  # sum gradients over 24 batches per update
        val_check_interval=0.5,  # validate twice per training epoch
    )

To train, call `trainer.fit(model=model, datamodule=dm)`, which kicks off the training of our model… sit back in your comfortable chair and watch how your model learns…

In this post, we shared how to screen the given dataset and what interesting aspects you should look at. We showed how to wrap a file-based dataset in the PyTorch Dataset class, which is the core of the data handling. Moreover, we wrote a basic image multi-label classification model within PyTorch Lightning based on a TorchVision model and trained it seamlessly on GPU with mixed precision without extra code.
In the future, I will show how to convert a notebook to a sharable Python package with some minimal sustainability tips and how to write a simple CLI with exposed training parameters for easier hyper-parameter searches.
Converting Scientific Kaggle Notebooks to Friendly Python Package
Stay tuned and follow me to learn more!
Best Practices to Rank on Kaggle Competition with PyTorch Lightning and Grid.ai Spot Instances
About the Author
Jirka Borovec has been working in Machine learning and Data science for several years in a few different IT companies. In particular, he enjoys exploring interesting world problems and solving them with state-of-the-art techniques. In addition, he has developed several open-source Python packages and actively participates in other well-known projects. He works at _Grid.ai_ as a Research Engineer and serves as a lead contributor of _PyTorchLightning.ai_.