Deep Learning and Medical Image Analysis for Malaria Detection with fastai

Learn to classify blood smear images using a high-level deep learning environment

Aldo von Wangenheim
Towards Data Science


Jimmy Chan/Pexels free images

After the Lister Hill National Center for Biomedical Communications (LHNCBC), part of the National Library of Medicine (NLM), made available an annotated dataset of healthy and infected blood smear malaria images, several postings and papers were published showing how to use image classification with convolutional neural networks to learn and classify those images.

The images in this dataset look like the ones shown below: parasitized smears show some colored dots, while uninfected ones tend to be homogeneously colored.

Classifying these smears should not be an excessively difficult task. The results described in the postings listed below show that a classification accuracy of 96% to 97% is feasible.

In this posting we will show how to use the fast.ai library to learn to classify these malaria smears. Fast.ai is a library built on top of PyTorch that makes writing machine learning applications much faster and simpler. Fast.ai also offers an online course covering both the library and deep learning in general. Compared to lower-level libraries such as Keras, TensorFlow.Keras or pure PyTorch, fast.ai dramatically reduces the amount of boilerplate code needed to produce state-of-the-art neural network applications.
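To give a sense of how compact that gets, here is a minimal sketch of a complete fast.ai v1 image-classification pipeline (the folder name and epoch count are illustrative; we will build the real version step by step below):

from fastai.vision import *

# Labels are inferred from the sub-folder names under ./data/train
data = ImageDataBunch.from_folder(Path('./data'), valid_pct=0.2, size=128)
learn = cnn_learner(data, models.resnet34, metrics=error_rate)
learn.fit_one_cycle(4)  # transfer learning with the 1cycle policy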

This posting is based upon the following material:

  1. PyImagesearch::Deep Learning and Medical Image Analysis with Keras — Example with malaria images by Adrian Rosebrock on December 3, 2018;
  2. Malaria Datasets from the NIH — a repository of segmented cells from the thin blood smear slide images from the Malaria Screener research activity;
  3. TowardsDataScience::Detecting Malaria with Deep Learning — AI for Social Good — A Healthcare Case Study — a review of the above by Dipanjan (DJ) Sarkar;
  4. PeerJ::Pre-trained convolutional neural networks as feature extractors toward improved malaria parasite detection in thin blood smear images by Sivaramakrishnan Rajaraman et al. — the original scientific paper the authors of the postings above based their work upon.

I adapted this material for use with PyTorch/fast.ai in May 2019. The commented code is freely available as a Google Colaboratory notebook.

Let’s go to the code…

Initializations

Run the code in this section once each time you start this notebook…

%reload_ext autoreload
%autoreload 2
%matplotlib inline

Testing your virtual machine on Google Colab…

Just to be sure, check which CUDA driver and which GPU Colab has made available to you. The GPU will typically be either:

  • a K80 with 11 GB RAM or (if you're really lucky)
  • a Tesla T4 with 14 GB RAM

If Google's servers are crowded, you may end up with access to only part of a GPU. If your GPU is shared with another Colab notebook, you'll see a smaller amount of memory made available to you.

Tip: avoid the peak hours of the US west coast. I live at GMT-3, two hours ahead of the US east coast, so I always try to perform heavy processing in the morning.

!/opt/bin/nvidia-smi
!nvcc --version

When I started running the experiments described here, I was lucky: I had a whole T4 with 15079 MB RAM! My output looked like this:

Thu May  2 07:36:26 2019       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 410.79 Driver Version: 410.79 CUDA Version: 10.0 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
|===============================+======================+======================|
| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 |
| N/A 63C P8 17W / 70W | 0MiB / 15079MiB | 0% Default |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes: GPU Memory |
| GPU PID Type Process name Usage |
|=============================================================================|
| No running processes found |
+-----------------------------------------------------------------------------+
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2018 NVIDIA Corporation
Built on Sat_Aug_25_21:08:01_CDT_2018
Cuda compilation tools, release 10.0, V10.0.130
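If you prefer to check the same information from Python instead of calling nvidia-smi, here is a minimal sketch using PyTorch's CUDA API (torch comes pre-installed on Colab):

import torch

print(torch.cuda.is_available())      # True if a GPU was assigned to you
print(torch.cuda.get_device_name(0))  # e.g. 'Tesla T4'
print(torch.cuda.get_device_properties(0).total_memory // 2**20, 'MiB')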

Library imports

Here we import all the necessary packages. We are going to work with the fast.ai V1 library, which sits on top of PyTorch 1.0. The fast.ai library provides many useful functions that enable us to quickly and easily build neural networks and train our models.

from fastai.vision import *
from fastai.metrics import error_rate
from fastai.callbacks import SaveModelCallback
# Imports for diverse utilities
from shutil import copyfile
import matplotlib.pyplot as plt
import operator
from PIL import Image
from sys import intern # For the symbol definitions

Utility Functions: Export and restoration functions

Export network for deployment and create a copy

def exportStageTo(learn, path):
    learn.export()
    # Make a separate backup of the exported model
    copyfile(path/'export.pkl', path/'export-malaria.pkl')

#exportStageTo(learn, path)

Restore a deployment model, for example in order to continue fine-tuning:

def restoreStageFrom(path):
    # Restore the backup
    copyfile(path/'export-malaria.pkl', path/'export.pkl')
    return load_learner(path)

#learn = restoreStageFrom(path)

Download the Malaria Data

The authors I listed above employed the NIH Malaria Dataset. Let’s do the same:

!wget  --backups=1 -q https://ceb.nlm.nih.gov/proj/malaria/cell_images.zip
!wget --backups=1 -q https://ceb.nlm.nih.gov/proj/malaria#/malaria_cell_classification_code.zip
# List what you've downloaded:
!ls -al

The --backups=1 parameter of wget allows you to repeat the command line several times, for example if a download fails, without creating a lot of new versions of the files.

The last line should produce the following output:

total 345208
drwxr-xr-x 1 root root 4096 May 2 07:45 .
drwxr-xr-x 1 root root 4096 May 2 07:35 ..
-rw-r--r-- 1 root root 353452851 Apr 6 2018 cell_images.zip
drwxr-xr-x 1 root root 4096 Apr 29 16:32 .config
-rw-r--r-- 1 root root 12581 Apr 6 2018 malaria_cell_classification_code.zip
drwxr-xr-x 1 root root 4096 Apr 29 16:32 sample_data

Now unzip the NIH malaria cell images dataset:

!unzip cell_images.zip

This will produce a very long, verbose output that will look like this:

Archive:  cell_images.zip
creating: cell_images/
creating: cell_images/Parasitized/
extracting: cell_images/Parasitized/C100P61ThinF_IMG_20150918_144104_cell_162.png
extracting: cell_images/Parasitized/C100P61ThinF_IMG_20150918_144104_cell_163.png
extracting: cell_images/Parasitized/C100P61ThinF_IMG_20150918_144104_cell_164.png
extracting: cell_images/Parasitized/C100P61ThinF_IMG_20150918_144104_cell_165.png
extracting: cell_images/Parasitized/C100P61ThinF_IMG_20150918_144104_cell_166.png
extracting: cell_images/Parasitized/C100P61ThinF_IMG_20150918_144104_cell_167.png
extracting: cell_images/Parasitized/C100P61ThinF_IMG_20150918_144104_cell_168.png
extracting: cell_images/Parasitized/C100P61ThinF_IMG_20150918_144104_cell_169.png
extracting: cell_images/Parasitized/C100P61ThinF_IMG_20150918_144104_cell_170.png
extracting: cell_images/Parasitized/C100P61ThinF_IMG_20150918_144104_cell_171.png
extracting: cell_images/Parasitized/C100P61ThinF_IMG_20150918_144348_cell_138.png
extracting: cell_images/Parasitized/C100P61ThinF_IMG_20150918_144348_cell_139.png
extracting: cell_images/Parasitized/C100P61ThinF_IMG_20150918_144348_cell_140.png
extracting: cell_images/Parasitized/C100P61ThinF_IMG_20150918_144348_cell_141.png
extracting: cell_images/Parasitized/C100P61ThinF_IMG_20150918_144348_cell_142.png
extracting: cell_images/Parasitized/C100P61ThinF_IMG_20150918_144348_cell_143.png
extracting: cell_images/Parasitized/C100P61ThinF_IMG_20150918_144348_cell_144.png
extracting: cell_images/Parasitized/C100P61ThinF_IMG_20150918_144823_cell_157.png
extracting: cell_images/Parasitized/C100P61ThinF_IMG_20150918_144823_cell_158.png
extracting: cell_images/Parasitized/C100P61ThinF_IMG_20150918_144823_cell_159.png
....
....
....
...and so on...

Prepare your data

Rename the cell_images folder to train and then move it into a new data folder, so fast.ai can use it to automatically generate training, validation and test sets without any further fuss…

!mv cell_images train
!mkdir data
!mv train data

Look at your data folder

Use apt install tree if you don’t have the tree command installed yet:

!apt install tree

Now run it:

!tree ./data --dirsfirst --filelimit 10

This will show the structure of your files tree:

./data
└── train
    ├── Parasitized [13780 exceeds filelimit, not opening dir]
    └── Uninfected [13780 exceeds filelimit, not opening dir]

3 directories, 0 files

Do not forget to set a filelimit, otherwise you’ll have a huge amount of output…

Initialize a few variables

bs = 256              # Batch size: 256 for small images on a T4 GPU...
size = 128            # Image size: 128x128 is a bit smaller than most
                      # of the images...
path = Path("./data") # The path to the 'data' folder containing 'train'...

Create your training and validation data bunches

In the original material from PyImagesearch, which employs Keras, there's a long routine to create training, validation and test folders from the data. With fast.ai this is not necessary: if you only have a 'train' folder, you can split it automatically while creating the DataBunch simply by passing a few parameters. We will split the data into a training set (80%) and a validation set (20%). This is done with the valid_pct=0.2 parameter of the ImageDataBunch.from_folder() constructor method:

# Limit your augmentations: it's medical data!
# You do not want to invent data...
# Warping, for example, will leave your images badly distorted,
# so don't do it!
# This dataset is big, so don't rotate the images either.
# Let's stick to flipping...
tfms = get_transforms(max_rotate=None, max_warp=None, max_zoom=1.0)
# Create the DataBunch!
# Remember that you'll have images that are bigger than 128x128
# and images that are smaller, so squish them all in order to
# occupy exactly 128x128 pixels...
data = ImageDataBunch.from_folder(path, ds_tfms=tfms, size=size,
                                  resize_method=ResizeMethod.SQUISH,
                                  valid_pct=0.2, bs=bs)
#
print('Transforms = ', len(tfms))
# Save the DataBunch in case the training goes south...
# so you won't have to regenerate it..
# Remember: this DataBunch is tied to the batch size you selected.
data.save('imageDataBunch-bs-'+str(bs)+'-size-'+str(size)+'.pkl')
# Show the statistics of the Bunch...
print(data.classes)
data

The print() will output the transforms and the classes:

Transforms =  2
['Parasitized', 'Uninfected']

The last line, data, simply outputs a summary of the ImageDataBunch instance:

ImageDataBunch;Train: LabelList (22047 items)
x: ImageList
Image (3, 128, 128),Image (3, 128, 128),Image (3, 128, 128),Image (3, 128, 128),Image (3, 128, 128)
y: CategoryList
Uninfected,Uninfected,Uninfected,Uninfected,Uninfected
Path: data;
Valid: LabelList (5511 items)
x: ImageList
Image (3, 128, 128),Image (3, 128, 128),Image (3, 128, 128),Image (3, 128, 128),Image (3, 128, 128)
y: CategoryList
Parasitized,Uninfected,Parasitized,Uninfected,Uninfected
Path: data;
Test: None

Look at your DataBunch to see if the augmentations are acceptable…

data.show_batch(rows=5, figsize=(15,15))

Training: resnet34

If you do not know which model to use, a residual network with 34 layers is a good starting point: not too small and not too big… In the tutorials listed above the authors used:

  • a custom, but small, ResNet (PyImagesearch)
  • a VGG19 (TowardsDataScience)

We will employ off-the-shelf fast.ai residual networks (ResNets). Let’s create our first network:

learn = cnn_learner(data, models.resnet34, metrics=error_rate)
learn.model

The last line will output the architecture of the network as a text stream. It will look like this:

Sequential(
(0): Sequential(
(0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace)
(3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
(4): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(1): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(2): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
....and so on....

Even a “small” network such as a ResNet34 is still very large. Do not bother trying to understand the output. You can read more about residual networks later. There are lots of introductory postings about ResNets.

Training strategy

Here comes one of the great strengths of fast.ai: easy-to-use HYPOs (hyperparameter optimization strategies). Hyperparameter optimization is a somewhat arcane subdiscipline of CNNs: these networks have so many hyperparameters, and choosing which ones to set to non-standard values in order to improve the network's performance is a complex issue and a field of study in itself. The fast.ai library provides a few very advanced yet easy-to-use HYPOs that help immensely in building better CNNs quickly.

We will employ the fit_one_cycle method, based on the 1cycle policy developed by Leslie N. Smith — see below for details:

Since this method is fast, we will run only 10 epochs in this first transfer learning stage. We will also save the network at the end of each epoch, so we can later recover the best-performing one: https://docs.fast.ai/callbacks.html#SaveModelCallback

learn.fit_one_cycle(10, callbacks=[SaveModelCallback(learn, every='epoch', monitor='accuracy', name='malaria-1')])
# Save it!
learn.save('malaria-stage-1')
# Deploy it!
exportStageTo(learn, path)

This will produce a table like this as an output:

The table above shows an accuracy of 96.4% on the validation set, and this with transfer learning only! Note that the error_rate fast.ai shows is always computed on the validation set. As a comparison, consider that Adrian Rosebrock achieved 97% with his custom ResNet in the PyImagesearch posting.
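If you would rather read accuracy directly in the training table instead of converting error_rate in your head (accuracy = 1 - error_rate), you can pass both metrics when creating the learner. A small sketch, assuming you re-create the learner before training:

from fastai.metrics import accuracy, error_rate

# Both metrics are computed on the validation set after each epoch
learn = cnn_learner(data, models.resnet34, metrics=[accuracy, error_rate])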

Results for ResNet34

Let's see what additional results we have. We will first look at which instances the model confused with one another, and try to see whether the model's mistakes are reasonable. In this case the mistakes look reasonable (none of them seems obviously naive), which is an indicator that our classifier is working correctly.

Furthermore, we will plot the confusion matrix. This is also very simple in fast.ai.

interp = ClassificationInterpretation.from_learner(learn)
losses, idxs = interp.top_losses()
len(data.valid_ds) == len(losses) == len(idxs)

Look at your 9 worst results (without employing a heatmap, at first):

interp.plot_top_losses(9, figsize=(20,11), heatmap=False)

Now do the same, but use a heatmap to highlight what induced the wrong classification:

interp.plot_top_losses(9, figsize=(20,11), heatmap=True)

Show the confusion matrix

fast.ai's ClassificationInterpretation class has a high-level instance method that allows for fast and easy plotting of confusion matrices, giving you a better picture of how well the CNN has performed. It doesn't make much sense with only two classes, but we'll do it anyway: it generates beautiful pictures… You can set the size and the resolution of the resulting plot. We'll use 5x5 inches at 100 dpi.

interp.plot_confusion_matrix(figsize=(5,5), dpi=100)
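With only two classes the matrix is tiny, but the same ClassificationInterpretation object can also list the most frequent confusions directly. A small sketch using its most_confused() method:

# Returns (actual, predicted, count) tuples, sorted by count
interp.most_confused(min_val=2)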

Show your learning curves:

It is interesting to look at the training and validation curves. They show us whether the network learned in a steady way or oscillated (which can indicate bad quality data), and whether the result is acceptable or we are overfitting or underfitting the network.

Again, fast.ai has high-level methods that help us. Each fast.ai cnn_learner comes with an automatically created instance of a Recorder, which records epoch, loss, optimizer and metric data during training. The plot_losses() method creates a graph with the training and validation curves:

learn.recorder.plot_losses()

This result looks so good that it hardly makes sense to fine-tune the network. If we look attentively, we'll see that the validation loss becomes worse than the training loss at about 500 batches, indicating that the network is probably starting to overfit at this point. This is an indication that we have definitely trained enough, at least for this ResNet model.
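The Recorder can also plot the metric per epoch and the learning-rate/momentum schedules that the 1cycle policy actually applied during this run. A sketch using two other standard Recorder methods:

# Plot the recorded metric (error_rate) per epoch
learn.recorder.plot_metrics()
# Plot the 1cycle learning-rate and momentum schedules
learn.recorder.plot_lr(show_moms=True)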

The overfitting we observed could be an indication that the network model is overkill for the complexity of the data, meaning that we are training a network that learns individual instances rather than a generalization of the dataset. A very simple and practical way to test this hypothesis is to try to learn the dataset with a simpler network and see what happens.

Let’s employ a smaller network and try it all again…

ResNet18

This network is much simpler. Let’s see if it does the job.

The ResNet18 is much smaller, so we'll have more GPU RAM available. We will create the DataBunch again, this time with a bigger batch size…

# Limit your augmentations: it's medical data!
# You do not want to invent data...
# Warping, for example, will leave your images badly distorted,
# so don't do it!
# This dataset is big, so don't rotate the images either.
# Let's stick to flipping...
tfms = get_transforms(max_rotate=None, max_warp=None, max_zoom=1.0)
# The ResNet18 is smaller, so we can afford a bigger batch size...
bs = 512
# Create the DataBunch!
# Remember that you'll have images that are bigger than 128x128
# and images that are smaller, so squish them to occupy
# exactly 128x128 pixels...
data = ImageDataBunch.from_folder(path, ds_tfms=tfms, size=size,
                                  resize_method=ResizeMethod.SQUISH,
                                  valid_pct=0.2, bs=bs)
#
print('Transforms = ', len(tfms))
# Save the DataBunch in case the training goes south...
# so you won't have to regenerate it...
# Remember: this DataBunch is tied to the batch size you selected.
data.save('imageDataBunch-bs-'+str(bs)+'-size-'+str(size)+'.pkl')
# Show the statistics of the Bunch...
print(data.classes)
data

Observe that we stuck with valid_pct=0.2: we still let fast.ai randomly choose 20% of the dataset as the validation set.

The code above will output something like this:

Transforms =  2
['Parasitized', 'Uninfected']

and:

ImageDataBunch;Train: LabelList (22047 items)
x: ImageList
Image (3, 128, 128),Image (3, 128, 128),Image (3, 128, 128),Image (3, 128, 128),Image (3, 128, 128)
y: CategoryList
Uninfected,Uninfected,Uninfected,Uninfected,Uninfected
Path: data;
Valid: LabelList (5511 items)
x: ImageList
Image (3, 128, 128),Image (3, 128, 128),Image (3, 128, 128),Image (3, 128, 128),Image (3, 128, 128)
y: CategoryList
Parasitized,Uninfected,Parasitized,Uninfected,Parasitized
Path: data;
Test: None
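A caveat: since the 20% validation split is drawn at random, the ResNet18 and the ResNet34 will not be validated on exactly the same images. If you want a reproducible split, a common fast.ai v1 trick is to fix NumPy's random seed right before building the DataBunch; a sketch (the seed value 42 is an arbitrary choice):

import numpy as np

np.random.seed(42)  # make the random 20% validation split reproducible
data = ImageDataBunch.from_folder(path, ds_tfms=tfms, size=size,
                                  resize_method=ResizeMethod.SQUISH,
                                  valid_pct=0.2, bs=512)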

Now, create the learner:

learn18 = cnn_learner(data, models.resnet18, metrics=error_rate)

If your Colab environment doesn't have the pretrained weights for the ResNet18, fast.ai will automatically download them:

Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to /root/.torch/models/resnet18-5c106cde.pth
46827520it [00:01, 28999302.58it/s]

Look at the model:

learn18.model

This will list the structure of your network. It is much smaller than the ResNet34, but still has a lot of layers. The output will look like this:

Sequential(
(0): Sequential(
(0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace)
(3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
(4): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(1): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(5): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
...and so on...

Let’s train it

We will again use the fit_one_cycle HYPO training strategy. Limit the training to 10 epochs to see how this smaller network behaves:

learn18.fit_one_cycle(10, callbacks=[SaveModelCallback(learn18, every='epoch', monitor='accuracy', name='malaria18-1')])
# Save the network
learn18.save('malaria18-stage-1')
# Deploy it also
exportStageTo(learn18, path)

This table shows that the network reached an accuracy of roughly 96.1% and suggests that it should not be trained much further: between epochs #8 and #9 the loss decreased by only 0.005 while the accuracy remained the same, suggesting that the network is reaching its limit and may start to overfit.
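If you would rather not eyeball the training table to decide when to stop, fast.ai v1 also provides an EarlyStoppingCallback in fastai.callbacks that can be combined with SaveModelCallback. A sketch (the min_delta and patience values are arbitrary choices):

from fastai.callbacks import EarlyStoppingCallback, SaveModelCallback

learn18.fit_one_cycle(10, callbacks=[
    # Keep only the model with the best (lowest) validation loss
    SaveModelCallback(learn18, every='improvement', monitor='valid_loss',
                      name='malaria18-best'),
    # Stop if the validation loss has not improved for 3 epochs
    EarlyStoppingCallback(learn18, monitor='valid_loss', min_delta=0.001,
                          patience=3)])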

Let’s generate a ClassificationInterpretation and look at the confusion matrix and the loss curves.

interp = ClassificationInterpretation.from_learner(learn18)
losses,idxs = interp.top_losses()
interp.plot_confusion_matrix(figsize=(5,5), dpi=100)

This confusion matrix is slightly, but only very slightly, worse than the one we generated for the ResNet34. Is the ResNet18 less well suited for this problem?

Let’s look at the losses:

This graph shows that the ResNet18 started overfitting a bit after about 290 batches. Remember that our batch size is 512 here, while it was 256 for the ResNet34.

Let's see if we can do better by fine-tuning the network.

Fine-Tune it!

Here we will introduce another fast.ai HYPO: automatically chosen, variable learning rates. We provide a range of learning rates we consider adequate and let fast.ai choose which learning rate to use at each point of the training cycle and for each layer group. We will train the network for 30 epochs.
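Before committing to a learning-rate range, it is worth running fast.ai's learning-rate finder on the unfrozen network and picking the range from the resulting plot; the slice(1e-4,1e-5) used below was chosen in this spirit. A sketch using the standard lr_find() call:

# LR range test: plot loss vs. learning rate and pick max_lr from the
# region where the loss is still clearly decreasing
learn18.unfreeze()
learn18.lr_find()
learn18.recorder.plot()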

# Unfreeze the network
learn18.unfreeze()
# Learning rate range: max_lr=slice(1e-4,1e-5)
learn18.fit_one_cycle(30, max_lr=slice(1e-4,1e-5),
                      callbacks=[SaveModelCallback(learn18,
                                 every='epoch', monitor='accuracy',
                                 name='malaria18')])
# Save as stage 2...
learn18.save('malaria18-stage-2')
# Deploy
exportStageTo(learn18, path)

97% accuracy! That is exactly what Adrian Rosebrock also achieved with his custom Keras ResNet implementation in the PyImagesearch posting, which presents the best accuracy results among the three references above.

The validation loss, however, got worse during the last epochs, indicating that we have been overfitting from about epoch #20 on. If you want to deploy this network, I would suggest loading the results from epoch #20 and generating a deployment network from them. The results do not seem to get any better, at least not with this network.
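A sketch of what that could look like, assuming SaveModelCallback wrote the per-epoch checkpoints as malaria18_<epoch#> under data/models/ (check the actual file names it produced):

# Reload the weights saved at epoch #20 and export them for deployment
learn18.load('malaria18_20')
exportStageTo(learn18, path)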

Look at the results

interp = ClassificationInterpretation.from_learner(learn18)
losses,idxs = interp.top_losses()
interp.plot_confusion_matrix(figsize=(5,5), dpi=100)

This is better than what we’ve had before. Let’s look at the loss curves:

Here we see that the network seems to start to overfit after about 500 batches, which confirms the suspicion we inferred from the results table above: the validation loss starts to grow in the last third of the training, suggesting that this part of the training only overfitted the network.

What do I do if the training was interrupted halfway through?

What do you do if your training was interrupted? This can happen because you reached your 12 hours of continuous "free" operation on a Google Colab notebook or because your computer stopped for some reason. I live in Brazil, where power outages are common…

The fit_one_cycle method works with varying, adaptive learning rates, following a curve where the rate first increases and then decreases. If you interrupt a training at epoch #10 of, say, 20 epochs and then start again for 9 more epochs, you will not get the same result as training uninterruptedly for 20 epochs. You have to record where you stopped and then resume the training cycle from that point, with the correct hyperparameters for that part of the cycle.

A fit_one_cycle training session divided into three subsessions. Image by PPW@GitHub

The first thing you have to do is to save your network every epoch:

learn.fit_one_cycle(20, max_lr=slice(1e-5,1e-6),
                    callbacks=[SaveModelCallback(learn, every='epoch',
                               monitor='accuracy', name=<callback_save_file>)])

This will save your network every epoch, with the name you provided followed by _<epoch#>. So, if you used name='saved_net', at epoch #3 the file saved_net_3.pth will be written. You can load this file after you have:

  • re-created the DataBunch and
  • re-instantiated the network with it.

After reloading the .pth file, you can restart your training. You still tell fit_one_cycle to consider 20 epochs, but to start training from epoch #4.

To learn how this is done, look here:

How do you do it?

The fit_one_cycle method in fast.ai has been developed to allow you to tell it from which part of the cycle to resume an interrupted training. The code for resuming a training will look like this:

# Create a new learner if training was interrupted and you had to
# restart your Colab session
learn = cnn_learner(data, models.<your_model_here>,
                    metrics=[accuracy, error_rate])
# When resuming, indicating the epoch from which to resume with
# start_epoch=<epoch#> will load the last saved .pth automatically;
# it is not necessary to reload it explicitly. Just do NOT change
# the name given in name=<callback_save_file>: when resuming,
# fast.ai will try to reload <callback_save_file>_<previous_epoch#>.pth
# Unfreeze the network
learn.unfreeze()
# Use start_epoch=<next_epoch#> to resume training...
learn.fit_one_cycle(20, max_lr=slice(1e-5,1e-6),
                    start_epoch=<next_epoch#>,
                    callbacks=[SaveModelCallback(learn,
                               every='epoch', monitor='accuracy',
                               name=<callback_save_file>)])

fast.ai will tell you “Loaded <callback_save_file>_<previous_epoch#>” and resume training.

You can look at all parameters supported by the fit_one_cycle method here:

What have we learned?

Compared to the other approaches above, where the authors employ pure Keras or TensorFlow.Keras, with fast.ai we were able to solve the same malaria blood smear classification problem with much less code, while using high-level hyperparameter optimization strategies that allowed us to train much faster. At the same time, a set of high-level functions makes it easy to inspect the results both as tables and as graphs.

With off-the-shelf residual network models pre-trained on ImageNet and provided by fast.ai, we obtained accuracy results that are about 1% better than two of the three previous works above (including the scientific paper published in PeerJ) and equal to the best-performing one. This shows that fast.ai is a very promising alternative to more traditional CNN frameworks, especially when the task at hand is a "standard" deep learning task such as image classification, object detection or semantic segmentation that can be solved by fine-tuning off-the-shelf pre-trained network models.

Want to see a more ambitious example?

Look at our skin cancer detection tutorial with fast.ai:

In that tutorial we show how to work with larger datasets and how to choose better learning rates.


Computer Sciences Professor at UFSC. I do research on image processing, artificial intelligence and telemedicine. https://www.inf.ufsc.br/~aldo.vw/