Fastai Bag of Tricks —Experiments with a Kaggle Dataset — Part 1

Moein Shariatnia
Towards Data Science
Oct 1, 2020


Image by Anna Shvets from Pexels

In this article, I’m going to explain my experiments with the Kaggle dataset “Chest X-ray Images (Pneumonia)” and how I tackled different problems along the way. This work led to perfect accuracy on the provided validation set and about 95 percent accuracy on the merged validation and test sets. My goal is to show you different tricks you can easily use with fastai (v2), and to share my experiments with this dataset as a kind of ablation study. I assume that you know the basics of deep learning and are familiar with deep learning frameworks for Python, especially PyTorch and fastai. So let’s get started.

Photo from Kaggle

Downloading The Dataset

You can easily get the dataset with the Kaggle API. Before that, you need to create a Kaggle account if you don’t have one yet. You also need to use your own API token to run the commands.
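A minimal sketch of the credential step, assuming the standard Kaggle CLI (installed with pip install kaggle). The CLI reads your API token from ~/.kaggle/kaggle.json. The helper write_kaggle_token is hypothetical and not part of the Kaggle package. The dataset slug in the comment (paultimothymooney/chest-xray-pneumonia) is my assumption about where this dataset lives on Kaggle.

```python
import json
import os
from pathlib import Path


def write_kaggle_token(username, key, config_dir=Path.home() / ".kaggle"):
    """Hypothetical helper: store Kaggle API credentials where the Kaggle CLI looks for them."""
    config_dir = Path(config_dir)
    config_dir.mkdir(parents=True, exist_ok=True)
    token_path = config_dir / "kaggle.json"
    token_path.write_text(json.dumps({"username": username, "key": key}))
    # The Kaggle CLI warns when the token is readable by other users,
    # so restrict it to the owner (equivalent to `chmod 600`).
    os.chmod(token_path, 0o600)
    return token_path


# After writing the token, the download itself is a shell command (assumed slug):
#   kaggle datasets download -d paultimothymooney/chest-xray-pneumonia
#   unzip -q chest-xray-pneumonia.zip
```

Call it once with the username and key from the kaggle.json file you download from your Kaggle account page.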

Getting Started

I start by looking at how the dataset is structured and loading the image file names. As is common for machine learning datasets, this one comes with training, validation, and test sets.

Here I use fastai’s get_image_files to read the image file names of the three partitions of our data. Then I use fastai’s L class, a handy replacement for Python’s built-in list, to map the len function over the three lists of file names. As you can see, we have 5216, 16, and 624 images in the training, validation, and test sets, respectively. Our validation set is really small, which I’ll discuss later in the article.
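fastai’s get_image_files collects image files recursively. Below is a rough standard-library equivalent, written as a sketch rather than fastai’s implementation. It assumes that root points at the folder directly containing the train, val, and test subfolders. list_image_files and count_per_split are hypothetical helper names, and the extension set is a simplification of fastai’s list.

```python
from pathlib import Path

# Extensions treated as images in this sketch (fastai builds its list from mimetypes).
IMAGE_EXTENSIONS = {".jpeg", ".jpg", ".png"}


def list_image_files(folder):
    """Recursively collect image files under `folder`, sorted for reproducibility."""
    folder = Path(folder)
    return sorted(
        p for p in folder.rglob("*")
        if p.is_file() and p.suffix.lower() in IMAGE_EXTENSIONS
    )


def count_per_split(root, splits=("train", "val", "test")):
    """Number of image files in each split folder, like mapping len over three lists."""
    root = Path(root)
    return {split: len(list_image_files(root / split)) for split in splits}
```

On the full dataset, count_per_split should report the 5216, 16, and 624 images mentioned above.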

Let’s do a tiny bit of EDA. I want to know if the dataset is balanced, i.e. whether we have roughly the same number of items in the Normal and Pneumonia classes.
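Each image’s class is just the name of its parent folder (NORMAL or PNEUMONIA), which is what fastai’s parent_label extracts. Below is a minimal standard-library sketch of counting those labels. The helper names are hypothetical.

```python
from collections import Counter
from pathlib import Path


def parent_folder_label(path):
    """Label an image by the name of the folder that contains it."""
    return Path(path).parent.name


def class_counts(paths):
    """Count how many files fall into each class."""
    return Counter(parent_folder_label(p) for p in paths)
```

On this training set, the counts come out as 3875 PNEUMONIA vs 1341 NORMAL images, roughly 2.9 to 1.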

Number of items per class, Image by author

As is common with real-world datasets, and especially medical ones, there is a large gap between the number of items in the two categories: roughly 3:1 pneumonia to normal cases. So the first problem to keep in mind is tackling this imbalance with some kind of trick. That said, I would not recommend spending a lot of time on this problem before everything else. I’d rather suggest first building an initial model on the imbalanced data and seeing how it goes!

DataBlock and DataLoaders

Here I’m going to make my DataBlock and DataLoaders to feed the data to the model. This is really easy in fastai. I won’t go into the details of what every piece of code does in this article. Instead, I highly recommend reading the fastai book, or its Jupyter notebooks, which are freely available on the course website. For those of you who are familiar with PyTorch and fastai: I’m building my training and validation dataloaders here, along with data augmentation.
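One piece of a DataBlock that is easy to overlook is the splitter. With this folder layout, the split can follow the grandparent folder of each file: train/NORMAL/x.jpeg belongs to train, and val/NORMAL/y.jpeg belongs to validation. fastai’s GrandparentSplitter works on this idea (you would pass valid_name="val" for this dataset). The function below is a hypothetical standard-library sketch, not fastai’s code. Like fastai splitters, it returns lists of indices.

```python
from pathlib import Path


def grandparent_split(paths, train_name="train", valid_name="val"):
    """Return (train_idxs, valid_idxs) based on each file's grandparent folder name."""
    train_idxs, valid_idxs = [], []
    for i, p in enumerate(paths):
        split = Path(p).parent.parent.name  # e.g. "train" for train/NORMAL/img.jpeg
        if split == train_name:
            train_idxs.append(i)
        elif split == valid_name:
            valid_idxs.append(i)
        # Files under any other grandparent (e.g. "test") are left out of both splits.
    return train_idxs, valid_idxs
```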

A batch of images after augmentations and resizing, Image by author

The class imbalance is obvious even in this small set of images!

I’m using data augmentation here to mitigate the problem of having a small dataset. Augmentation does not change the ratio of positive to negative cases. Still, it could help with the imbalance by showing the model more varied versions of the positive cases, which may help it learn the underlying patterns more robustly. For medical datasets, make sure you only use augmentations that make sense and that do not make your images too different from real-world data. Here, I first resize the images to 512 using the “squish” method, so as not to lose potentially important parts of the image. (I’m not cropping here, but in the next part of the article we will see how cropping works too.) After that, RandomResizedCropGPU crops a random region of that 512 image and resizes it to 224. The authors of the fastai book call this technique “presizing”. Then we augment further with Rotate and Zoom.

I’m not flipping, warping, or changing the brightness and contrast of my images, even though these are all legitimate types of data augmentation for most computer vision tasks. I think the result would no longer represent an actual X-ray image. Also keep in mind that even horizontal flipping can be problematic for X-ray images in some tasks (maybe not much here), especially when the location of body organs matters. For example, dextrocardia is a rare condition in which the heart is located on the right side of the chest. Flipping the images would impair the model’s ability to detect such cases.
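To make presizing more concrete, here is a rough sketch of how a random-resized-crop box can be sampled from the 512 x 512 squished image. The crop would then be resized to 224. It follows the common RandomResizedCrop recipe: sample an area fraction and an aspect ratio, then a position. The min_scale=0.75 and ratio values are illustrative assumptions. They are not necessarily fastai’s internal defaults or the values used in this article.

```python
import math
import random


def sample_crop_box(img_size, min_scale=0.75, ratio=(3 / 4, 4 / 3), rng=random, max_tries=10):
    """Sample (left, top, width, height) of a random crop inside a square image.

    The crop covers between `min_scale` and 100% of the image area and has an
    aspect ratio drawn log-uniformly from `ratio`. In presizing, this crop is
    afterwards resized to the final training size (e.g. 224).
    """
    area = img_size * img_size
    log_lo, log_hi = math.log(ratio[0]), math.log(ratio[1])
    for _ in range(max_tries):
        target_area = area * rng.uniform(min_scale, 1.0)
        aspect = math.exp(rng.uniform(log_lo, log_hi))
        w = int(round(math.sqrt(target_area * aspect)))
        h = int(round(math.sqrt(target_area / aspect)))
        if 0 < w <= img_size and 0 < h <= img_size:
            left = rng.randint(0, img_size - w)
            top = rng.randint(0, img_size - h)
            return left, top, w, h
    # Fallback if no valid box was found: use the whole image.
    return 0, 0, img_size, img_size
```

Because the box size and position change every time an image is loaded, the model sees a slightly different view of each X-ray in each epoch.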

I also want to clarify one thing here. The images in this dataset are single-channel grayscale images. However, the ImageNet-pretrained models we are going to use later in the article expect a 3-channel RGB input, as do the other ready-to-use architectures in fastai in their default configuration. We don’t have to worry about that. The data block and the PILImage class handle it by repeating our single channel two more times, producing a three-channel image that we can train normally on any model.
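Conceptually, this conversion just copies the single gray channel three times. In practice PIL does it for you, for example with Image.convert("RGB"). The tiny standard-library sketch below represents an image as a list of pixel rows. gray_to_rgb is a hypothetical name.

```python
def gray_to_rgb(gray):
    """Turn an H x W grayscale image (list of rows) into a 3 x H x W image
    by repeating the single channel three times (channel-first, like PyTorch)."""
    # Copy each row so the three channels are independent lists.
    return [[list(row) for row in gray] for _ in range(3)]
```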

Baseline Model

I’m calling this a baseline because we are not using any fancy tricks. It is not a true baseline model, though, like the ones we would build by stacking a few Conv and Linear layers. I’ll start with the XResNet18 architecture. XResNets are very similar to vanilla ResNets but include some tweaks. You can search for them online, or read the paper Bag of Tricks for Image Classification with Convolutional Neural Networks by Tong He et al. to learn more about the differences. Because this is a baseline model, I’m going to train it from scratch rather than use a pretrained model.

lr_finder plot, Image by author

As the plot suggests, and as I expected, a learning rate of 3e-3 seems good here. So let’s train for 10 epochs with the one-cycle policy and see how it goes.

baseline model performance, Image by author

Okay! We didn’t do badly, but we are far from perfect. A few things stand out in these results. The accuracy and F1 score look unstable, for several reasons. First, we are training the model from scratch (and the data is not even normalized!), which can make training less stable. Second, our validation set is really tiny! It has only 16 items, which may not be good representatives of the data the model is trained on. Finally, I’m using one-cycle training, which increases the learning rate during the first 25 percent of iterations. That could explain the higher valid_loss early in training.
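To see why the validation loss can jump early on, here is a sketch of the learning-rate curve of the one-cycle policy. The defaults below come from my reading of fastai v2’s fit_one_cycle and may differ between versions: pct_start=0.25, div=25, div_final=1e5, and cosine annealing in both phases. The function is a standalone re-implementation, not fastai’s code. It also omits the momentum schedule that fastai cycles in the opposite direction.

```python
import math


def cosine_anneal(start, end, pos):
    """Cosine interpolation from `start` (pos=0) to `end` (pos=1)."""
    return start + (1 + math.cos(math.pi * (1 - pos))) * (end - start) / 2


def one_cycle_lr(step, total_steps, lr_max, pct_start=0.25, div=25.0, div_final=1e5):
    """Learning rate at `step` (0-based) of a one-cycle schedule."""
    pos = step / max(total_steps - 1, 1)  # training progress in [0, 1]
    if pos < pct_start:
        # Warm-up phase: lr_max / div -> lr_max
        return cosine_anneal(lr_max / div, lr_max, pos / pct_start)
    # Annealing phase: lr_max -> lr_max / div_final
    return cosine_anneal(lr_max, lr_max / div_final, (pos - pct_start) / (1 - pct_start))


lrs = [one_cycle_lr(s, 100, 3e-3) for s in range(100)]
```

With lr_max=3e-3, the rate starts 25 times lower, peaks about a quarter of the way through training, and ends close to zero. The high-learning-rate stretch in the first third is where the noisy validation loss shows up.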

Let’s look at our model from another angle. We have 3875 pneumonia vs 1341 normal cases in the training set. So a model that simply outputs “pneumonia” for every item it sees could reach about 74 percent accuracy on the training data (3875 / 5216). Our model is not doing exactly that here: the F1 score suggests that it has learned to distinguish the two classes to some extent. Still, it is useful to know what the simplest possible baseline looks like!
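The arithmetic behind that “always predict pneumonia” baseline, as a small standalone sketch. majority_baseline_accuracy is a hypothetical helper name, and the counts are the training-set numbers given above.

```python
def majority_baseline_accuracy(class_counts):
    """Accuracy of a model that always predicts the most frequent class."""
    return max(class_counts.values()) / sum(class_counts.values())


train_counts = {"PNEUMONIA": 3875, "NORMAL": 1341}
baseline = majority_baseline_accuracy(train_counts)
print(f"{baseline:.1%}")  # 74.3%
```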

All in all, I think these are good results for just 4 minutes of training on this rather small dataset. But there are a lot of cool things we will experiment with in this and the next article to make it better.

Overcoming the class imbalance problem

As I mentioned in the previous sections, one of our problems is class imbalance. The computer vision literature offers many ways to overcome this very common problem. I highly recommend reading the paper A systematic study of the class imbalance problem in convolutional neural networks, which does an excellent job of comparing the different methods for tackling it. In its abstract you can find the following:

Based on results from our experiments we conclude that (i) the effect of class imbalance on classification performance is detrimental; (ii) the method of addressing class imbalance that emerged as dominant in almost all analyzed scenarios was oversampling … (iv) as opposed to some classical machine learning models, oversampling does not cause overfitting of CNNs

Really nice! We just got our answer. Let me explain. The most common approaches to class imbalance are either downsampling the majority class (using less of its data) or oversampling the minority class. The paper suggests the latter. In our case, that means increasing the number of “normal” cases by simply repeating them in the training dataset. We can do it like this:

All we need are two lines of code! In the first line, I collect the file names of the “normal” cases in a list (an L, actually!). In the second line, I concatenate three lists:

- my previous training file names (the imbalanced normal and pneumonia cases),
- the “normal” file names repeated two times (normal_cases*2),
- and the 16 validation files.

We will separate the training and validation sets later in the data block, as in the code in the previous section. Now we have 3875 pneumonia and 4023 normal cases (3 × 1341), roughly equal. Make sure you pass this oversampled list to your data block with the new indices for the training and validation sets. Let’s train our model again on this new dataset (the model is again built from scratch, exactly as before):
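The same oversampling idea in plain Python. Placeholder file names stand in for the real lists; in the article these are fastai L objects, where normal_cases*2 likewise repeats the list. The validation files go last, so the validation indices are easy to compute for an index-based splitter such as fastai’s IndexSplitter. oversample is a hypothetical helper name.

```python
def oversample(train_files, minority_files, copies, valid_files):
    """Append `copies` extra repetitions of the minority-class files, then the
    validation files; return the combined list and the validation indices."""
    items = list(train_files) + list(minority_files) * copies + list(valid_files)
    n_train = len(items) - len(valid_files)
    valid_idx = list(range(n_train, len(items)))
    return items, valid_idx


# Placeholder names with the article's class sizes (not real files).
normal = [f"train/NORMAL/n{i}.jpeg" for i in range(1341)]
pneumonia = [f"train/PNEUMONIA/p{i}.jpeg" for i in range(3875)]
valid = [f"val/NORMAL/v{i}.jpeg" for i in range(16)]

items, valid_idx = oversample(normal + pneumonia, normal, 2, valid)
n_normal_train = sum("/NORMAL/" in f for f in items[: valid_idx[0]])
print(n_normal_train)  # 4023 = 1341 originals + 2 * 1341 copies
```

Each copy is the same file on disk. Because augmentation is applied randomly every time a file is loaded, the repeated normal cases will not look pixel-for-pixel identical during training.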

Getting 100 percent accuracy, Image by author

Wow! We just got 100 percent accuracy! As far as I know, at the time of writing this is by far the best result on the provided validation set. The training time is also much lower than that of other models suggested for this dataset.

Our model perfectly classified the validation set. With oversampling, the training process also seems more stable than before. The validation loss is now significantly lower than in the previous model and much closer to the training loss. So our results are consistent with the paper’s findings: we got much better results without overfitting.

The result is great, but don’t get too excited yet! :) Remember that our validation set is tiny, with only 16 samples, which is far too few to draw a solid conclusion about the model’s performance. In fact, if you try what I have done here, you may get different results, because with so few samples random chance plays a large role! I tested the model on the test set (I know it’s too soon to use it, but I did!) and got 85 percent accuracy. This suggests that the model does not generalize well (yet!).

So, to train a really good and robust model, from now on in this article I merge the test set provided with the dataset into the validation set. That way we can judge the behavior of our model more accurately when trying other tricks, and use that information to decide what to do next. Below is how to merge the validation and test sets:

merging val and test set
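A minimal sketch of the merge, continuing the index-based approach from the oversampling step. The test files are appended after the original validation files, so the validation indices now cover all 16 + 624 = 640 images. merge_valid is a hypothetical helper, and the file names in the usage are placeholders.

```python
def merge_valid(train_items, valid_files, test_files):
    """Combine (already oversampled) training items with the validation and
    test files; everything after the training items becomes the validation set."""
    items = list(train_items) + list(valid_files) + list(test_files)
    valid_idx = list(range(len(train_items), len(items)))
    return items, valid_idx


# Placeholder lists with the article's validation and test sizes.
items, valid_idx = merge_valid(
    [f"train_{i}.jpeg" for i in range(10)],
    [f"val_{i}.jpeg" for i in range(16)],
    [f"test_{i}.jpeg" for i in range(624)],
)
print(len(valid_idx))  # 640
```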

Here are the results of the new baseline model on the oversampled dataset with the merged validation and test sets, including two new metrics, Precision and Recall:
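For reference, precision and recall for the positive class (here, pneumonia) come from the counts of true positives (TP), false positives (FP), and false negatives (FN). fastai provides these as the Precision and Recall metrics. The sketch below is a standalone illustration, not their implementation, and the counts in the usage are made up.

```python
def precision_recall_f1(tp, fp, fn):
    """Precision, recall, and F1 for the positive class from raw counts."""
    precision = tp / (tp + fp) if tp + fp else 0.0  # of predicted positives, how many are right
    recall = tp / (tp + fn) if tp + fn else 0.0     # of actual positives, how many are found
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


p, r, f1 = precision_recall_f1(tp=90, fp=10, fn=30)  # made-up counts
```

With these counts, precision is 0.9, recall is 0.75, and F1 is about 0.818. In a medical setting, recall on pneumonia (how many sick patients we catch) is often the number to watch most closely.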

Model performance after merging the validation and test sets, Image by author

As you can see, the results are more stable and don’t jump wildly from one number to another. In each of the following sections, I add one trick to the training process and see how it affects the results.

Normalization

In this section, I’m going to normalize my data by subtracting its mean and dividing by its standard deviation. First, I need to obtain these statistics from my data. I do the following to get the mean and standard deviation of each channel:

As you can see, the three per-channel means (and the three standard deviations) are the same. This is because the channels are identical, as I explained in a previous section.
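A standard-library sketch of both steps: computing per-channel statistics over some images, then applying (x - mean) / std. The images are tiny made-up 3 x H x W nested lists with identical channels, mimicking the repeated grayscale channel. In practice you would compute the stats on a batch of tensors, and in fastai you would add Normalize.from_stats(mean, std) to batch_tfms. The helper names are hypothetical.

```python
import statistics


def channel_stats(images):
    """Per-channel mean and (population) standard deviation over a list of
    3 x H x W images given as nested lists."""
    means, stds = [], []
    for c in range(3):
        values = [v for img in images for row in img[c] for v in row]
        means.append(statistics.fmean(values))
        stds.append(statistics.pstdev(values))
    return means, stds


def normalize(img, means, stds):
    """Apply (x - mean) / std channel by channel."""
    return [[[(v - means[c]) / stds[c] for v in row] for row in img[c]] for c in range(3)]


gray = [[0.1, 0.5], [0.9, 0.3]]
img = [gray, gray, gray]  # identical channels, like a repeated grayscale image
means, stds = channel_stats([img])
```

Since the three channels hold the same values, channel_stats returns three identical means and three identical standard deviations, exactly as observed above.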

By adding the following to your code for making the data block, you’ll have your data normalized:

When I trained the same model again from scratch, I got the following:

Model performance after normalization, Image by author

It did not make a big difference here, mostly because we are training from scratch and the model can adapt itself to unnormalized data as well.

Transfer Learning and Pretrained Models

It is almost always a great idea to start your project (perhaps after a few baseline models) with transfer learning and pretrained models. They usually work much better than randomly initialized weights and help your model converge more easily and quickly. In transfer learning, it is important that your input data has characteristics similar to the data the model was trained on. So we are going to normalize our data to have a mean of zero and a standard deviation of one for every channel. Remember that we (actually fastai!) repeated our single channel to make the images look like natural three-channel images. Now we are ready to fine-tune our first model! Build your data block as before, including Normalize in your batch_tfms, and then you can build your pretrained model easily like this:

A couple of things to note here. I’ve changed the architecture to resnet18, which is very similar to the previous one. I’m also using cnn_learner, which by default loads the model’s pretrained weights and freezes most of the layers so that their weights are not updated. The exception is the model’s head, which is new and randomly initialized. You can find the learning rate and train the model as before. Keep in mind that in fastai v2 you can easily fine-tune a pretrained model with the fine_tune method. It first trains the frozen model, then unfreezes it and trains it for more epochs. Here, however, I’m doing this procedure manually:
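Freezing simply means that only some parameter groups receive weight updates. Below is a toy standalone sketch of the idea behind fastai’s freeze and unfreeze. In fastai, freeze() is freeze_to(-1) and unfreeze() is freeze_to(0). The LayerGroup class and freeze_to function here are hypothetical. In PyTorch, the same effect comes from setting requires_grad on each parameter. One difference: by default, fastai keeps BatchNorm layers trainable even inside frozen groups.

```python
from dataclasses import dataclass


@dataclass
class LayerGroup:
    name: str
    trainable: bool = True


def freeze_to(groups, n):
    """Freeze every group before index `n` (negative counts from the end)
    and make the remaining groups trainable."""
    cut = n if n >= 0 else len(groups) + n
    for i, g in enumerate(groups):
        g.trainable = i >= cut
    return groups


groups = [LayerGroup("body_early"), LayerGroup("body_late"), LayerGroup("head")]
freeze_to(groups, -1)  # like learn.freeze(): only the new head trains
```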

pretrained resnet18 performance, Image by author

Okay! About two percent more accuracy with the pretrained model. In the figure you can see that the seemingly best results come in the middle of training (epoch 6), so we probably trained for more epochs than needed. As Jeremy Howard mentioned in the 2020 fastai course, it is better to re-train the model for the number of epochs actually needed (6 in this case). That beats training for many epochs with EarlyStoppingCallback or SaveModelCallback. The reason is that with one-cycle training, the learning rate drops significantly in the last iterations. This can lead to better optima on the loss surface that help the model generalize. We lose this opportunity if we stop training in the middle, while the learning rate is still high.

Unfreezing

Now we unfreeze the model and train all the layers, using different learning rates for different layers. In transfer learning, the early layers of the model do not need much change in their weights. What they learn (edges, textures, and other low-level features) is common to most computer vision tasks. On the other hand, we use higher learning rates as we go deeper into the model, to adapt those layers to the new task. But before that, let’s find out which learning rates we should use:

And this is the result of lr_find():

Image by author

Interpreting the lr_find plot can sometimes be difficult, but in this case it is rather straightforward. I’ll pass slice(1e-6, 1e-4) as the range of learning rates for the different layers of the model. The early layers will have a maximum learning rate of 1e-6, and the deepest layers will have 1e-4. I have found this range to work reasonably well even when the lr_find plot is not very helpful.
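How does slice(1e-6, 1e-4) turn into per-layer learning rates? As far as I can tell from fastai’s source, a slice with a start value is spread geometrically across the parameter groups (via its even_mults helper). A slice with only a stop value gives stop/10 to every group except the last. The sketch below is a standalone re-implementation of that behavior. It assumes three parameter groups, which I believe is what cnn_learner uses for ResNets (two body groups plus the head).

```python
def spread_lrs(start, stop, n_groups):
    """Learning rate per parameter group for slice(start, stop).

    With a start value, rates are log-evenly spaced from start to stop;
    with start=None (i.e. slice(stop)), all groups but the last get stop / 10.
    """
    if start is None:
        return [stop / 10] * (n_groups - 1) + [stop]
    if n_groups == 1:
        return [stop]
    step = (stop / start) ** (1 / (n_groups - 1))
    return [start * step**i for i in range(n_groups)]


lrs = spread_lrs(1e-6, 1e-4, 3)  # approximately [1e-6, 1e-5, 1e-4]
```

Geometric spacing means each group trains about 10 times faster than the one before it, rather than the rates growing in equal additive steps.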

model performance after unfreezing, Image by author

We got one percent more accuracy here. As you may have noticed, I’m training for fewer epochs this time. I trained once before and found that 5 epochs were enough. Following what I said about early stopping in the last section, I then re-trained the model instead of stopping training partway through. In the next section, we will improve the model’s performance even further.

Progressive Resizing

Up until now, we trained our models on images of size 224 x 224, which is rather small compared to the original images (over 512 x 512). Here, we fine-tune the model from the last section on bigger images of size 360 x 360. This gives the model more information and helps it learn the patterns. The only thing you need to change is this:

I just changed the size to 360 like this. We need to freeze the model again and train it a little, then unfreeze it and train more. To keep the article shorter, I only show the final results:

final results, Image by author

That’s really awesome! We are getting about 95 percent accuracy on the merged validation and test set, and the precision and recall scores are all high.

In the next section, we will interpret the results of our model more thoroughly.

Model Interpretation

Here I’ll show you how to get interpretations of your model easily with fastai. Let’s take a look at the code for getting the confusion matrix:

confusion matrix, Image by author
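In fastai, the plot comes from ClassificationInterpretation.from_learner(learn) followed by plot_confusion_matrix(). Underneath, a confusion matrix is just a count of (actual, predicted) pairs. A minimal standalone sketch with made-up labels:

```python
from collections import Counter


def confusion_matrix(actual, predicted, classes):
    """Rows are actual classes, columns are predicted classes."""
    counts = Counter(zip(actual, predicted))
    return [[counts[(a, p)] for p in classes] for a in classes]


classes = ["NORMAL", "PNEUMONIA"]
actual = ["NORMAL", "NORMAL", "PNEUMONIA", "PNEUMONIA", "PNEUMONIA"]
predicted = ["NORMAL", "PNEUMONIA", "PNEUMONIA", "PNEUMONIA", "NORMAL"]
cm = confusion_matrix(actual, predicted, classes)
```

The diagonal holds the correct predictions. The off-diagonal cells are the two kinds of errors: normal cases flagged as pneumonia, and pneumonia cases that were missed.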

Fair enough! Let’s also take a look at the classification report:

classification report, Image by author

The numbers indicate that we’re doing a good job on this rather small dataset. We can now check how we do on the original validation set (those 16 images). I know this is a bit odd, because those images are already part of our merged validation set. It’s just a quick sanity check that our stronger model still achieves the same performance on the provided validation set:

Okay, perfect! 100 percent accuracy, F1 score, precision, and recall. As you may remember, at the beginning of the article we got 100 percent accuracy after oversampling, and here we repeated it. But this model is more robust. Thanks to the larger validation set we made, we could test different tricks and use the losses and metrics printed during training to decide what to do next and how to tune our models. The downside is that we no longer have a test set, so we don’t actually know how this model does on real data. It is possible that we have overfit the model to the merged validation set. Still, the results are really promising, and hopefully we won’t run into problems later on.

Next part of the article

As this article is getting really long, I will explain the other tricks and experiments in a separate article as the second part of this one. We will see how to use Grad CAM and simple heatmaps to visualize what our model is paying attention to in the images and we will try hyperparameter tuning (tuning the weight decay and learning rate) which I think you will be really excited about. After that, we will try Label Smoothing, MixUp, Half Precision training and other model architectures and data augmentation methods and we will see how each of them affects the project.

Sources: I used many ideas from the book Deep Learning for Coders with Fastai and PyTorch and from the fastai courses. I highly recommend giving them a try.

About me: I’m a medical student, and I love deep learning and the cool things we can build with it to improve our quality of life. I’m going to use its power to overcome the obstacles we face in medicine. I spend a lot of time studying deep learning alongside my medical courses, and I aim to work in the interdisciplinary fields where AI and medicine meet :)
