
Practical Transfer Learning with PyTorch

Capitalizing on an already functional deep neural network is a huge speedup when solving an ML problem

Photo by Omar Flores on Unsplash

In a previous post, I explained how PyTorch and XGBoost can be combined to perform transfer learning.

Transfer learning with XGBoost and PyTorch: hack Alexnet for MNIST dataset

This was quite an unconventional way of transferring learning, mixing deep neural networks and gradient boosted trees. To illustrate the method, I explained how the knowledge acquired by AlexNet for categorizing images from the ILSVRC dataset can be transferred to another use case with XGBoost: identifying handwritten digits.

A few days after writing this post, I realized that the more common use case, transfer learning with neural networks only, was not really covered online for PyTorch. There is, however, an excellent article by Dipanjan (DJ) Sarkar that presents transfer learning for deep learning in detail, with code, but using Keras.

A Comprehensive Hands-on Guide to Transfer Learning with Real-World Applications in Deep Learning

I highly recommend reading this post.

In this post, we are going to fill this gap for transfer learning with PyTorch.

We will perform the same transfer as in the previous post, this time using neural networks only, and compare the performance of the two approaches: mixing in XGBoost versus using neural networks only.

Motivations

Before digging into the code, it is worth recalling why transfer learning is interesting. The main reason is time: collecting and labeling a database big enough for deep learning requires a lot of time and energy, and so does configuring, fitting, and tuning a deep neural network.

Capitalizing on an already functional deep neural network is a huge speedup when solving an ML problem.

In the case of deep neural networks designed for computer vision, most of the complexity is contained in the layers that build the features. The layers responsible for classification are usually only three layers deep; in AlexNet, for instance, the classifier consists of three fully connected layers.

Being able to exploit these pretrained feature-computing layers is very appealing.

The next section will show how this can be done using conventional deep learning methods.

The MNIST use case

For the sake of simplicity, we want to use a database that is easily available in Python. MNIST looks like a good candidate: it is open source and can be loaded in a second using torchvision. The train and test sets are quite large: 60k images for the train set and 10k for the test set. Moreover, MNIST images (28×28 pixels, grayscale) don't have the same dimensions as the images used to train AlexNet (RGB, typically 224×224 in torchvision). It's interesting to see how we can handle this mismatch.

AlexNet was initially trained to categorize the images of ILSVRC into 1,000 labels. During this training, it learned to build visual features with a deep convolutional neural network. The assumption we make in this part is that these features can be reused to perform a completely different classification, i.e. identifying handwritten digits with only 10 labels.

Retargeting AlexNet

What we need at this point is a way to keep the structure and the weights of the feature layers, and retrain the classification layers.

This is achieved by creating a new neural network model that takes the original model as input and generates a new one, in which the classifier network has been replaced by a new classifier. The feature layers remain untouched.

The important part is freezing the weights of the feature layers, which in PyTorch means setting requires_grad to False on their parameters. This keeps the layers that compute the features unchanged during training.

This is the step responsible for the transfer learning: all the knowledge acquired during the initial training on the ILSVRC dataset is reused, without modification, in a new use case.

Learning to recognize handwritten digits

All we have to do now is to build a new model, based on the original, pretrained AlexNet, and refit a new classifier to recognize handwritten digits.

Thanks to torchvision and PyTorch, this is a pretty simple task.

The approach is straightforward, except for the transform part, which is required to resize the MNIST images to the size of the images used for training AlexNet.

The number of labels has been reduced from 1,000 to 10.

At the end of the training, which takes quite some time on my laptop, the model is serialized and stored. We are going to use this dump to evaluate the performance of this transfer and check that this approach is relevant.
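The training and serialization steps can be sketched as follows. This is an assumed setup, not the author's: `train_classifier` and `save_model` are hypothetical helpers, SGD with momentum and `CrossEntropyLoss` are common choices for this kind of fine-tuning, and only the parameters that still require gradients (the new classifier) are handed to the optimizer. The model is serialized with `torch.save` on its `state_dict`.

```python
import os
import tempfile

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def train_classifier(model: nn.Module, loader: DataLoader, epochs: int = 1,
                     lr: float = 1e-3, device: str = "cpu") -> list[float]:
    """Fit the trainable (unfrozen) parameters; return the mean loss of each epoch."""
    model.to(device).train()
    # Frozen layers have requires_grad=False, so only the new classifier is optimized
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params, lr=lr, momentum=0.9)
    criterion = nn.CrossEntropyLoss()
    history = []
    for _ in range(epochs):
        running, seen = 0.0, 0
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()
            running += loss.item() * labels.size(0)
            seen += labels.size(0)
        history.append(running / seen)
    return history


def save_model(model: nn.Module, path: str) -> None:
    """Serialize the learned weights (state_dict) for later evaluation."""
    torch.save(model.state_dict(), path)


# Offline smoke test: toy model with a frozen first layer, random data
torch.manual_seed(0)
toy = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))
for p in toy[0].parameters():
    p.requires_grad = False
frozen_before = toy[0].weight.detach().clone()
head_before = toy[2].weight.detach().clone()
toy_data = TensorDataset(torch.randn(32, 4), torch.randint(0, 3, (32,)))
losses = train_classifier(toy, DataLoader(toy_data, batch_size=8), epochs=2, lr=0.1)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "toy_model.pt")
    save_model(toy, path)
    reloaded = torch.load(path)
print(losses, sorted(reloaded))
```

On real data, the loader would yield transformed 3×224×224 MNIST images and the model would be the retargeted AlexNet; the toy run only checks that frozen weights stay fixed while the head is updated.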

Evaluation

scikit-learn provides two handy methods for evaluating the performance of a classification model: confusion_matrix and classification_report.

The first one counts, for each true class, how many samples were assigned to each predicted class. The second one computes standard metrics such as precision, recall, and F1-score for each class.
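To see what each method returns, here is a small sketch on made-up labels (the arrays are illustrative, not results from this post). In the confusion matrix, row i counts the samples whose true label is i, and column j how many of them were predicted as j.

```python
from sklearn.metrics import classification_report, confusion_matrix

# Illustrative true and predicted labels for a 3-class problem
y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]

cm = confusion_matrix(y_true, y_pred)
print(cm)
# [[1 1 0]
#  [0 2 0]
#  [1 0 1]]

# Per-class precision, recall, F1-score and support, plus overall averages
print(classification_report(y_true, y_pred, digits=3))
report = classification_report(y_true, y_pred, output_dict=True)
```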

We are going to apply these two methods to the test dataset provided by torchvision. This dataset contains 10k handwritten digits, which is a decent size for validation.

The evaluation code is similar to the one used for training, and the same transform is reused to resize the test images.
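A sketch of that evaluation loop, under the same caveats (hypothetical `predict` helper, not the author's code): the model is switched to `eval()` mode so dropout is disabled, gradients are turned off with `torch.no_grad()`, and the predicted class is the arg-max of the logits. The two lists it returns are what `confusion_matrix` and `classification_report` expect.

```python
import torch
import torch.nn as nn
from sklearn.metrics import classification_report, confusion_matrix
from torch.utils.data import DataLoader, TensorDataset


def predict(model: nn.Module, loader: DataLoader, device: str = "cpu"):
    """Return (true labels, predicted labels) over the whole loader."""
    model.to(device).eval()  # disable dropout for inference
    y_true, y_pred = [], []
    with torch.no_grad():
        for images, labels in loader:
            logits = model(images.to(device))
            y_pred.extend(logits.argmax(dim=1).cpu().tolist())
            y_true.extend(labels.tolist())
    return y_true, y_pred


# Offline demo: a toy model whose bias always favors class 1 (random inputs)
toy = nn.Linear(4, 3)
with torch.no_grad():
    toy.weight.zero_()
    toy.bias.copy_(torch.tensor([0.0, 1.0, 0.0]))
labels = torch.tensor([0, 1, 1, 2])
loader = DataLoader(TensorDataset(torch.randn(4, 4), labels), batch_size=2)
y_true, y_pred = predict(toy, loader)
print(confusion_matrix(y_true, y_pred, labels=[0, 1, 2]))
print(classification_report(y_true, y_pred, labels=[0, 1, 2], zero_division=0))
```

With the real pipeline, the loader would wrap the 10k-image MNIST test split with the training transform.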

The output of the program shows that the performance is quite good.

Interestingly, the performance is slightly better than the one achieved in my previous post using XGBoost. This comes at the cost of a longer training time. XGBoost performance could, however, be further improved with proper hyperparameter tuning: in the previous post, the default classification settings were used.

Conclusion

Transfer learning using pretrained models is a very efficient and fast way to get an accurate model, for classification as well as for regression.

Using PyTorch and torchvision, this can be done in very few lines of code.

In this post, we used knowledge acquired on images and applied it to other images. I'm pretty sure we could also achieve decent results on other kinds of signals, such as 1D signals.

