Deep Learning and Medical Image Analysis for Malaria Detection with fastai
Learn to classify blood smear images using a high-level deep learning environment
Since the Lister Hill National Center for Biomedical Communications (LHNCBC), part of the National Library of Medicine (NLM), made available an annotated dataset of healthy and infected malaria blood smear images, various postings and papers have shown how to use convolutional neural networks to learn and classify those images.
The images in this dataset look like the ones we collected below: Parasitized smears will show some colored dots and uninfected ones will tend to be homogeneously colored.
It should not be an excessively difficult task to classify these smears. The results described in the postings listed below show that a classification accuracy of 96% to 97% is feasible.
In this posting we will show how to use the fast.ai CNN library to learn to classify these malaria smears. Fast.ai is a library, built on PyTorch, that makes writing machine learning applications much faster and simpler. Fast.ai also offers an online course covering the use of fast.ai and deep learning in general. Compared to lower-level libraries, including ones usually described as “high-level” such as Keras and TensorFlow.Keras, or pure PyTorch, fast.ai dramatically reduces the amount of boilerplate code you’ll need to produce state-of-the-art neural network applications.
This posting is based upon the following material:
- PyImagesearch::Deep Learning and Medical Image Analysis with Keras — Example with malaria images by Adrian Rosebrock on December 3, 2018;
- Malaria Datasets from the NIH — a repository of segmented cells from the thin blood smear slide images from the Malaria Screener research activity;
- TowardsDataScience::Detecting Malaria with Deep Learning — AI for Social Good — A Healthcare Case Study — a review of the above by Dipanjan (DJ) Sarkar;
- PeerJ::Pre-trained convolutional neural networks as feature extractors toward improved malaria parasite detection in thin blood smear images by Sivaramakrishnan Rajaraman et al. — the original scientific paper the authors of the postings above based their work upon.
I adapted this material for use with PyTorch/fast.ai in May 2019. The commented code is freely available as a Google Colaboratory Notebook.
Let’s go to the code…
Initializations
Run the cells in this section once each time you start this notebook…
%reload_ext autoreload
%autoreload 2
%matplotlib inline
Testing your virtual machine on Google Colab…
Just to be sure, check which CUDA driver and which GPU Colab has made available to you. The GPU will typically be either:
- a K80 with 11 GB RAM or (if you’re really lucky)
- a Tesla T4 with 14 GB RAM
If Google’s servers are crowded, you may get access to only part of a GPU. If your GPU is shared with another Colab notebook, you’ll see a smaller amount of memory made available to you.
Tip: Avoid US West Coast peak times. I live at GMT-3, two hours ahead of the US East Coast, so I always try to perform heavy processing in the morning hours.
!/opt/bin/nvidia-smi
!nvcc --version
When I started running the experiments described here, I was lucky: I had a whole T4 with 15079 MB RAM! My output looked like this:
Thu May 2 07:36:26 2019
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 410.79 Driver Version: 410.79 CUDA Version: 10.0 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
|===============================+======================+======================|
| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 |
| N/A 63C P8 17W / 70W | 0MiB / 15079MiB | 0% Default |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes: GPU Memory |
| GPU PID Type Process name Usage |
|=============================================================================|
| No running processes found |
+-----------------------------------------------------------------------------+
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2018 NVIDIA Corporation
Built on Sat_Aug_25_21:08:01_CDT_2018
Cuda compilation tools, release 10.0, V10.0.130
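If you would rather check the memory figures programmatically than by eye, they can be pulled out of the nvidia-smi text. The following is only a sketch: gpu_memory_mib is a hypothetical helper, and it assumes the "<used>MiB / <total>MiB" layout shown in the output above.

```python
import re

def gpu_memory_mib(smi_text):
    """Return a list of (used_mib, total_mib) pairs, one per GPU row found."""
    pattern = re.compile(r"(\d+)MiB\s*/\s*(\d+)MiB")
    return [(int(used), int(total)) for used, total in pattern.findall(smi_text)]

# Row copied from the nvidia-smi output above.
sample = "| N/A 63C P8 17W / 70W | 0MiB / 15079MiB | 0% Default |"
print(gpu_memory_mib(sample))  # [(0, 15079)]
```

A total that is clearly below the card's nominal memory would suggest that you are sharing the GPU.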
Library imports
Here we import all the necessary packages. We are going to work with the fast.ai V1 library, which sits on top of PyTorch 1.0. The fast.ai library provides many useful functions that enable us to quickly and easily build neural networks and train our models.
from fastai.vision import *
from fastai.metrics import error_rate
from fastai.callbacks import SaveModelCallback
# Imports for diverse utilities
from shutil import copyfile
import matplotlib.pyplot as plt
import operator
from PIL import Image
from sys import intern # For the symbol definitions
Utility Functions: Export and restoration functions
Export network for deployment and create a copy
def exportStageTo(learn, path):
    # Write export.pkl for deployment
    learn.export()
    # Keep a separately named backup copy
    copyfile(path/'export.pkl', path/'export-malaria.pkl')
#exportStageTo(learn, path)
Restoration of a deployment model, for example in order to continue fine-tuning
def restoreStageFrom(path):
    # Restore a backup
    copyfile(path/'export-malaria.pkl', path/'export.pkl')
    return load_learner(path)
#learn = restoreStageFrom(path)
Download the Malaria Data
The authors I listed above employed the NIH Malaria Dataset. Let’s do the same:
!wget --backups=1 -q https://ceb.nlm.nih.gov/proj/malaria/cell_images.zip
!wget --backups=1 -q https://ceb.nlm.nih.gov/proj/malaria#/malaria_cell_classification_code.zip
# List what you've downloaded:
!ls -al
The --backups=1 option of wget lets you repeat the command as many times as needed if a download fails, without piling up new versions of the files.
The last line should produce the following output:
total 345208
drwxr-xr-x 1 root root 4096 May 2 07:45 .
drwxr-xr-x 1 root root 4096 May 2 07:35 ..
-rw-r--r-- 1 root root 353452851 Apr 6 2018 cell_images.zip
drwxr-xr-x 1 root root 4096 Apr 29 16:32 .config
-rw-r--r-- 1 root root 12581 Apr 6 2018 malaria_cell_classification_code.zip
drwxr-xr-x 1 root root 4096 Apr 29 16:32 sample_data
Now unzip the NIH malaria cell images dataset:
!unzip cell_images.zip
This will produce very long, verbose output that looks like this:
Archive: cell_images.zip
creating: cell_images/
creating: cell_images/Parasitized/
extracting: cell_images/Parasitized/C100P61ThinF_IMG_20150918_144104_cell_162.png
extracting: cell_images/Parasitized/C100P61ThinF_IMG_20150918_144104_cell_163.png
extracting: cell_images/Parasitized/C100P61ThinF_IMG_20150918_144104_cell_164.png
extracting: cell_images/Parasitized/C100P61ThinF_IMG_20150918_144104_cell_165.png
extracting: cell_images/Parasitized/C100P61ThinF_IMG_20150918_144104_cell_166.png
extracting: cell_images/Parasitized/C100P61ThinF_IMG_20150918_144104_cell_167.png
extracting: cell_images/Parasitized/C100P61ThinF_IMG_20150918_144104_cell_168.png
extracting: cell_images/Parasitized/C100P61ThinF_IMG_20150918_144104_cell_169.png
extracting: cell_images/Parasitized/C100P61ThinF_IMG_20150918_144104_cell_170.png
extracting: cell_images/Parasitized/C100P61ThinF_IMG_20150918_144104_cell_171.png
extracting: cell_images/Parasitized/C100P61ThinF_IMG_20150918_144348_cell_138.png
extracting: cell_images/Parasitized/C100P61ThinF_IMG_20150918_144348_cell_139.png
extracting: cell_images/Parasitized/C100P61ThinF_IMG_20150918_144348_cell_140.png
extracting: cell_images/Parasitized/C100P61ThinF_IMG_20150918_144348_cell_141.png
extracting: cell_images/Parasitized/C100P61ThinF_IMG_20150918_144348_cell_142.png
extracting: cell_images/Parasitized/C100P61ThinF_IMG_20150918_144348_cell_143.png
extracting: cell_images/Parasitized/C100P61ThinF_IMG_20150918_144348_cell_144.png
extracting: cell_images/Parasitized/C100P61ThinF_IMG_20150918_144823_cell_157.png
extracting: cell_images/Parasitized/C100P61ThinF_IMG_20150918_144823_cell_158.png
extracting: cell_images/Parasitized/C100P61ThinF_IMG_20150918_144823_cell_159.png
....
....
....
...and so on...
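If the long listing gets in the way, the archive can also be extracted silently. On the command line, unzip -q does that; below is a minimal Python sketch of the same idea using the standard zipfile module. The helper name extract_quietly is hypothetical, and the demo runs on a tiny throwaway archive rather than on the real dataset.

```python
import tempfile
import zipfile
from pathlib import Path

def extract_quietly(zip_path, dest="."):
    """Extract every member of zip_path into dest and return the member count."""
    with zipfile.ZipFile(zip_path) as archive:
        archive.extractall(dest)
        return len(archive.namelist())

# Demo on a throwaway archive; for the real data you would call
# extract_quietly("cell_images.zip") instead.
with tempfile.TemporaryDirectory() as tmp:
    demo_zip = Path(tmp) / "demo.zip"
    with zipfile.ZipFile(demo_zip, "w") as archive:
        archive.writestr("cell_images/Parasitized/a.png", b"fake")
        archive.writestr("cell_images/Uninfected/b.png", b"fake")
    n_members = extract_quietly(demo_zip, Path(tmp) / "out")
    extracted = sorted(p.name for p in (Path(tmp) / "out").rglob("*.png"))
    print(n_members, extracted)
```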
Prepare your data
Rename the cell_images folder to train, and then move it into a new data folder, so fast.ai can use it to automatically generate train, validation and test sets without any further fuss…
!mv cell_images train
!mkdir data
!mv train data
Look at your data folder
Use apt install tree if you don’t have the tree command installed yet:
!apt install tree
Now run it:
!tree ./data --dirsfirst --filelimit 10
This will show the structure of your file tree:
./data
└── train
    ├── Parasitized [13780 exceeds filelimit, not opening dir]
    └── Uninfected [13780 exceeds filelimit, not opening dir]

3 directories, 0 files
Do not forget to set a filelimit, otherwise you’ll have a huge amount of output…
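As an alternative to tree, you can count the files per class folder in Python. This sketch uses only pathlib. count_per_class is a hypothetical helper, and the demo builds a small temporary folder so that it runs anywhere. On the real data you would pass Path('./data/train').

```python
import tempfile
from pathlib import Path

def count_per_class(train_dir, pattern="*.png"):
    """Map each class sub-folder name to the number of files matching pattern."""
    return {
        sub.name: sum(1 for _ in sub.glob(pattern))
        for sub in sorted(Path(train_dir).iterdir())
        if sub.is_dir()
    }

# Demo with a made-up folder layout (3 and 2 empty placeholder files).
with tempfile.TemporaryDirectory() as tmp:
    for cls, n in (("Parasitized", 3), ("Uninfected", 2)):
        folder = Path(tmp) / cls
        folder.mkdir()
        for i in range(n):
            (folder / f"cell_{i}.png").touch()
    counts = count_per_class(tmp)
    print(counts)  # {'Parasitized': 3, 'Uninfected': 2}
```

Counting only *.png files is deliberate. tree counts every directory entry, so its figure can be higher than the number of images if a folder contains other files. That would be consistent with the 2 x 13780 = 27560 entries reported above versus the 22047 + 5511 = 27558 images that the DataBunch created later in this posting reports.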
Initialize a few variables
bs = 256 # Batch size, 256 for small images on a T4 GPU...
size = 128 # Image size, 128x128 is a bit smaller than most
# of the images...
path = Path("./data") # The path to the 'train' folder you created...
Create your training and validation data bunches
In the original material from PyImagesearch, which employs Keras, there’s a long routine to create training, validation and test folders from the data. With fast.ai it is not necessary: if you only have a ‘train’ folder, you can automatically split it while creating the DataBunch by simply passing a few parameters. We will split the data into a training set (80%) and a validation set (20%). This is done with the valid_pct = 0.2 parameter in the ImageDataBunch.from_folder() constructor method:
# Limit your augmentations: it's medical data!
# You do not want to fabricate data...
# Warping, for example, will leave your images badly distorted,
# so don't do it!
# This dataset is big, so don't rotate the images either.
# Let's stick to flipping...
tfms = get_transforms(max_rotate=None, max_warp=None, max_zoom=1.0)
# Create the DataBunch!
# Remember that you'll have images that are bigger than 128x128
# and images that are smaller, so squish them all in order to
# occupy exactly 128x128 pixels...
data = ImageDataBunch.from_folder(path, ds_tfms=tfms, size=size, resize_method=ResizeMethod.SQUISH, valid_pct = 0.2, bs=bs)
#
print('Transforms = ', len(tfms))
# Save the DataBunch in case the training goes south...
# so you won't have to regenerate it..
# Remember: this DataBunch is tied to the batch size you selected.
data.save('imageDataBunch-bs-'+str(bs)+'-size-'+str(size)+'.pkl')
# Show the statistics of the Bunch...
print(data.classes)
data
The print() will output the transforms and the classes:
Transforms = 2
['Parasitized', 'Uninfected']
The last line, data, simply outputs the text representation of the ImageDataBunch instance:
ImageDataBunch;

Train: LabelList (22047 items)
x: ImageList
Image (3, 128, 128),Image (3, 128, 128),Image (3, 128, 128),Image (3, 128, 128),Image (3, 128, 128)
y: CategoryList
Uninfected,Uninfected,Uninfected,Uninfected,Uninfected
Path: data;

Valid: LabelList (5511 items)
x: ImageList
Image (3, 128, 128),Image (3, 128, 128),Image (3, 128, 128),Image (3, 128, 128),Image (3, 128, 128)
y: CategoryList
Parasitized,Uninfected,Parasitized,Uninfected,Uninfected
Path: data;

Test: None
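The 22047 and 5511 items in this output are what an 80/20 split of the whole folder yields. As a rough sketch of what a random percentage split does (this is not fast.ai's actual code; random_split and the seed are assumptions made for illustration), shuffle the indices and cut off the first valid_pct share, rounding down:

```python
import random

def random_split(n_items, valid_pct=0.2, seed=42):
    """Return (train_idx, valid_idx) taken from a shuffled range of n_items."""
    idx = list(range(n_items))
    random.Random(seed).shuffle(idx)
    cut = int(valid_pct * n_items)  # round down to a whole number of items
    return idx[cut:], idx[:cut]

n_total = 22047 + 5511  # items reported in the output above
train_idx, valid_idx = random_split(n_total)
print(len(train_idx), len(valid_idx))  # 22047 5511
```

Because the split is random, a different seed gives a different validation set, so accuracy figures from separate runs are only approximately comparable.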
Look at your DataBunch to see if the augmentations are acceptable…
data.show_batch(rows=5, figsize=(15,15))
Training: resnet34
If you do not know what to use, it is a good choice to start with a Residual Network with 34 layers. Not too small and not too big… In the tutorials listed above the authors used:
- a custom, but small, ResNet (PyImagesearch)
- a VGG19 (TowardsDataScience)
We will employ off-the-shelf fast.ai residual networks (ResNets). Let’s create our first network:
learn = cnn_learner(data, models.resnet34, metrics=[accuracy, error_rate])
learn.model
The last line will output the architecture of the network as a text stream. It will look like this:
Sequential(
(0): Sequential(
(0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace)
(3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
(4): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(1): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(2): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
....and so on....
Even a “small” network such as a ResNet34 is still very large. Do not bother trying to understand the output. You can read more about residual networks later. There are lots of introductory postings about ResNets.
Training strategy
Here comes one of the great distinguishing features of fast.ai: easy-to-use HYPOs (hyperparameter optimization strategies). Hyperparameter optimization is a somewhat arcane subdiscipline of CNNs. CNNs have so many hyperparameters that choosing which ones to change from their standard values in order to improve the network’s performance is a very complex issue and a field of study in its own right. The fast.ai library provides a few very advanced yet easy-to-use HYPOs that help immensely in implementing much better CNNs quickly.
We will employ the fit_one_cycle method, which implements the one-cycle policy developed by Leslie N. Smith; see below for details:
- https://docs.fast.ai/callbacks.one_cycle.html
- A disciplined approach to neural network hyper-parameters: Part 1 — learning rate, batch size, momentum, and weight decay — https://arxiv.org/abs/1803.09820
- Super-Convergence: Very Fast Training of Residual Networks Using Large Learning Rates — https://arxiv.org/abs/1708.07120
- There is a very interesting article from Nachiket Tanksale called Finding Good Learning Rate and The One Cycle Policy where cyclic learning rates and momentum are discussed.
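To make the shape of the one-cycle schedule concrete, here is a minimal sketch of the learning rate curve only (the accompanying momentum schedule is left out). It is an illustration, not fast.ai's implementation: the function names are mine, and max_lr=3e-3, div_factor=25 and pct_start=0.3 are assumptions modelled on what I understand the fast.ai v1 defaults to be, so check the documentation linked above before relying on them.

```python
import math

def annealing_cos(start, end, pct):
    """Cosine interpolation from start (pct=0) to end (pct=1)."""
    return end + (start - end) / 2 * (math.cos(math.pi * pct) + 1)

def one_cycle_lr(step, total_steps, max_lr=3e-3, div_factor=25.0, pct_start=0.3):
    """Learning rate at a given step of an illustrative one-cycle schedule."""
    warmup_steps = int(total_steps * pct_start)
    start_lr = max_lr / div_factor          # low rate at the very beginning
    final_lr = max_lr / (div_factor * 1e4)  # much lower rate at the very end
    if step < warmup_steps:
        # Phase 1: ramp up from start_lr to max_lr
        return annealing_cos(start_lr, max_lr, step / warmup_steps)
    # Phase 2: anneal down from max_lr to final_lr
    pct = (step - warmup_steps) / (total_steps - warmup_steps)
    return annealing_cos(max_lr, final_lr, pct)

total = 100
schedule = [one_cycle_lr(s, total) for s in range(total + 1)]
print(f"start={schedule[0]:.2e} peak={max(schedule):.2e} end={schedule[-1]:.2e}")
```

The rate climbs during roughly the first third of training, peaks at max_lr, and then decays to a value far below the starting rate. This is why the position inside the cycle matters when a training run is interrupted.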
Since this method is fast, we will employ only 10 epochs in this first transfer learning stage. We will also save the network after every epoch, so that the best-performing epoch can be reloaded later: https://docs.fast.ai/callbacks.html#SaveModelCallback
learn.fit_one_cycle(10, callbacks=[SaveModelCallback(learn, every='epoch', monitor='accuracy', name='malaria-1')])
# Save it!
learn.save('malaria-stage-1')
# Deploy it!
exportStageTo(learn, path)
This will produce a table like this as an output:
The table above shows an accuracy of 96.4% on the validation set, and this with transfer learning only! Note that the error_rate fast.ai shows is always the one computed on the validation set. As a comparison, consider that Adrian Rosebrock achieved 97% with his custom ResNet in the PyImagesearch posting.
Results for ResNet34
Let’s see what additional results we have got. We will first look at which instances the model confused with one another and check whether its predictions were reasonable. In this case the mistakes look reasonable (none of them seems obviously naive). This is an indicator that our classifier is working correctly.
Furthermore, we will plot the confusion matrix. This is also very simple in fast.ai.
interp = ClassificationInterpretation.from_learner(learn)
losses,idxs = interp.top_losses()
len(data.valid_ds)==len(losses)==len(idxs)
Look at your 9 worst results (without employing a heatmap, at first):
interp.plot_top_losses(9, figsize=(20,11), heatmap=False)
Now, do the same but highlight using a heatmap what induced the wrong classification:
interp.plot_top_losses(9, figsize=(20,11), heatmap=True)
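Conceptually, ranking the "top losses" means computing the loss of every validation sample separately and sorting the samples from the largest loss to the smallest. The following plain-Python sketch illustrates the idea with made-up probabilities. The cross_entropy and top_losses functions here are illustrative helpers, not the fast.ai methods of similar name.

```python
import math

def cross_entropy(prob_of_true_class):
    """Per-sample cross-entropy given the probability assigned to the true class."""
    return -math.log(prob_of_true_class)

def top_losses(probs_of_true_class, k=None):
    """Return (losses, indices) sorted from the largest loss to the smallest."""
    losses = [cross_entropy(p) for p in probs_of_true_class]
    order = sorted(range(len(losses)), key=lambda i: losses[i], reverse=True)
    if k is not None:
        order = order[:k]
    return [losses[i] for i in order], order

# Made-up probabilities the model assigned to the correct class of five samples.
probs = [0.99, 0.10, 0.80, 0.45, 0.02]
worst_losses, worst_idxs = top_losses(probs, k=3)
print(worst_idxs)  # [4, 1, 3]
```

A sample that is classified wrongly with high confidence (a probability close to 0 for the true class) ends up at the top of the list, which is why these plots are a quick way to spot both model errors and mislabeled images.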
Show the confusion Matrix
fast.ai’s ClassificationInterpretation class has a high-level instance method that allows for fast and easy plotting of confusion matrices, which show you more clearly how well the CNN has performed. It doesn’t make so much sense with only two classes, but we’ll do it anyway: it generates beautiful pictures… You can set the size and the resolution of the resulting plot. We’ll set 5x5 inches with 100 dpi.
interp.plot_confusion_matrix(figsize=(5,5), dpi=100)
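With two classes, the four cells of the confusion matrix also give you sensitivity and specificity, which for a diagnostic task can be more informative than accuracy alone. The sketch below computes them in plain Python from made-up labels. The function names and the choice of 'Parasitized' as the positive class are assumptions for illustration.

```python
def confusion_counts(y_true, y_pred, positive="Parasitized"):
    """Count (tp, fp, tn, fn) for a two-class problem."""
    tp = fp = tn = fn = 0
    for truth, pred in zip(y_true, y_pred):
        if pred == positive:
            if truth == positive:
                tp += 1
            else:
                fp += 1
        elif truth == positive:
            fn += 1
        else:
            tn += 1
    return tp, fp, tn, fn

def summarize(tp, fp, tn, fn):
    """Derive accuracy, sensitivity and specificity from the four counts."""
    return {
        "accuracy": (tp + tn) / (tp + fp + tn + fn),
        "sensitivity": tp / (tp + fn),  # share of infected cells that were found
        "specificity": tn / (tn + fp),  # share of healthy cells left alone
    }

# Made-up labels for eight cells, for illustration only.
y_true = ["Parasitized"] * 4 + ["Uninfected"] * 4
y_pred = ["Parasitized"] * 3 + ["Uninfected"] * 4 + ["Parasitized"]
cm_counts = confusion_counts(y_true, y_pred)
cm_summary = summarize(*cm_counts)
print(cm_counts, cm_summary)
```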
Show your learning curve:
It is interesting to look at the learning and the validation curves. It will show us if the network has learned in a steady way or if it oscillated (which can indicate bad quality data) and if we have a result that is OK or if we are overfitting or underfitting our network.
Again fast.ai has high-level methods that’ll help us. Each fast.ai cnn_learner has an automatically created instance of a Recorder. A recorder records epoch, loss, opt and metric data during training. The plot_losses() method will create a graphic with the train and validation curves:
learn.recorder.plot_losses()
This result looks so good that it doesn’t make sense to fine-tune the network. If we look closely, we’ll see that the validation loss rises above the training loss at about 500 batches, indicating that the network is probably starting to overfit at this point. This is an indication that we have definitely trained enough, at least for this ResNet model.
The overfitting we observed could be an indication that we are employing a network model that is an overkill for the complexity of the data, meaning that we are training a network that learns individual instances and not a generalization for the dataset. One very simple and practical way to test this hypothesis is to try to learn the dataset with a simpler network and see what happens.
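The reasoning about the loss curves can be sketched numerically: find the epoch with the lowest validation loss and look at how the gap between validation and training loss develops. The loss values below are made up for illustration (they are not the results of this posting), and both helper functions are hypothetical.

```python
def best_epoch(valid_losses):
    """Index of the epoch with the lowest validation loss."""
    return min(range(len(valid_losses)), key=lambda i: valid_losses[i])

def overfit_gap(train_losses, valid_losses):
    """Per-epoch difference valid - train; a growing positive gap hints at overfitting."""
    return [round(v - t, 4) for t, v in zip(train_losses, valid_losses)]

# Made-up per-epoch losses, for illustration only.
train = [0.40, 0.25, 0.18, 0.14, 0.11, 0.09]
valid = [0.30, 0.20, 0.16, 0.15, 0.16, 0.17]
best = best_epoch(valid)
gaps = overfit_gap(train, valid)
print(best, gaps)
```

In this toy example the validation loss bottoms out at epoch 3 and the gap turns positive and keeps growing afterwards, which is the pattern described above.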
Let’s employ a smaller network and try it all again…
ResNet18
This network is much simpler. Let’s see if it does the job.
The ResNet18 is much smaller, so we’ll have more GPU RAM for us. We will create the DataBunch again, this time with a bigger batch size…
# Limit your augmentations: it's medical data!
# You do not want to fabricate data...
# Warping, for example, will leave your images badly distorted,
# so don't do it!
# This dataset is big, so don't rotate the images either.
# Let's stick to flipping...
tfms = get_transforms(max_rotate=None, max_warp=None, max_zoom=1.0)
# Bigger batch size for the smaller network; updating bs also keeps
# the file name used in data.save() below in sync
bs = 512
# Create the DataBunch!
# Remember that you'll have images that are bigger than 128x128
# and images that are smaller, so squish them to occupy
# exactly 128x128 pixels...
data = ImageDataBunch.from_folder(path, ds_tfms=tfms, size=size, resize_method=ResizeMethod.SQUISH, valid_pct = 0.2, bs=bs)
#
print('Transforms = ', len(tfms))
# Save the DataBunch in case the training goes south... so you won't have to regenerate it..
# Remember: this DataBunch is tied to the batch size you selected.
data.save('imageDataBunch-bs-'+str(bs)+'-size-'+str(size)+'.pkl')
# Show the statistics of the Bunch...
print(data.classes)
data
Observe that we stuck with our valid_pct = 0.2: we will still let fast.ai randomly choose 20% of the dataset as validation set.
The code above will output something like this:
Transforms = 2
['Parasitized', 'Uninfected']
and:
ImageDataBunch;

Train: LabelList (22047 items)
x: ImageList
Image (3, 128, 128),Image (3, 128, 128),Image (3, 128, 128),Image (3, 128, 128),Image (3, 128, 128)
y: CategoryList
Uninfected,Uninfected,Uninfected,Uninfected,Uninfected
Path: data;

Valid: LabelList (5511 items)
x: ImageList
Image (3, 128, 128),Image (3, 128, 128),Image (3, 128, 128),Image (3, 128, 128),Image (3, 128, 128)
y: CategoryList
Parasitized,Uninfected,Parasitized,Uninfected,Parasitized
Path: data;

Test: None
Now, create the learner:
learn18 = cnn_learner(data, models.resnet18, metrics=[accuracy, error_rate])
If your Colab environment doesn’t have the pretrained weights for the ResNet18, fast.ai will automatically download them:
Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to /root/.torch/models/resnet18-5c106cde.pth
46827520it [00:01, 28999302.58it/s]
Look at the model:
learn18.model
This will list the structure of your net. It is much smaller than the ResNet34, but still has a lot of layers. The output will look like this:
Sequential(
(0): Sequential(
(0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace)
(3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
(4): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(1): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(5): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
...and so on...
Let’s train it
We will again use the fit_one_cycle HYPO training strategy. Limit the training to 10 epochs to see how this smaller network behaves:
learn18.fit_one_cycle(10, callbacks=[SaveModelCallback(learn18, every='epoch', monitor='accuracy', name='malaria18-1')])
# Save the network
learn18.save('malaria18-stage-1')
# Deploy it also
exportStageTo(learn18, path)
This table shows that the network learned to an accuracy of roughly 96.1% and suggests that the network should not be trained further: the loss between epoch #8 and #9 shows a 0.005 decrease but the accuracy has remained the same, suggesting that the network has started overfitting.
Let’s generate a ClassificationInterpretation and look at the confusion matrix and the loss curves.
interp = ClassificationInterpretation.from_learner(learn18)
losses,idxs = interp.top_losses()
interp.plot_confusion_matrix(figsize=(5,5), dpi=100)
This confusion matrix is slightly, but only very slightly worse than the one we generated for the ResNet34. Is the ResNet18 less well suited for this problem?
Let’s look at the losses:
learn18.recorder.plot_losses()
This graph shows that the ResNet18 started overfitting a bit after about 290 batches. Remember that our bs is 512 here and was 256 for the ResNet34.
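Because the two runs use different batch sizes, the batch counts on the x-axis of the two loss plots are not directly comparable. Converting them to epochs helps. The sketch below does that for the 22047 training images. The helper names are hypothetical, and it assumes that an incomplete last training batch is dropped in each epoch.

```python
def batches_per_epoch(n_train, bs, drop_last=True):
    """Number of training batches in one epoch."""
    full, rest = divmod(n_train, bs)
    return full if drop_last or rest == 0 else full + 1

def batch_to_epoch(batch_index, n_train, bs):
    """Convert a batch count into a (fractional) epoch number."""
    return batch_index / batches_per_epoch(n_train, bs)

n_train = 22047  # size of the training set reported by the DataBunch
print(batches_per_epoch(n_train, 256), round(batch_to_epoch(500, n_train, 256), 1))  # 86 5.8
print(batches_per_epoch(n_train, 512), round(batch_to_epoch(290, n_train, 512), 1))  # 43 6.7
```

Under these assumptions, 500 batches at bs=256 and 290 batches at bs=512 both fall around the sixth to seventh of the 10 epochs, so the two networks started to overfit at a comparable point in training.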
Let’s see if we can improve this by fine-tuning the network.
Fine-Tune it!
Here we will introduce another fast.ai HYPO: automatically chosen variable learning rates. We provide a range of learning rates we consider adequate and let fast.ai choose which learning rate to use for each epoch and each layer. We will train the network for 30 epochs.
# Unfreeze the network
learn18.unfreeze()
# Learning rates range: max_lr=slice(1e-4,1e-5)
learn18.fit_one_cycle(30, max_lr=slice(1e-4,1e-5),
                      callbacks=[SaveModelCallback(learn18,
                                 every='epoch', monitor='accuracy',
                                 name='malaria18')])
# Save as stage 2...
learn18.save('malaria18-stage-2')
# Deploy
exportStageTo(learn18, path)
97% accuracy! That is exactly what Adrian Rosebrock also achieved with his custom Keras ResNet implementation in the PyImagesearch posting, which presents the best accuracy results among the three references above.
The validation loss, however, got worse over the last epochs. This indicates that we have been overfitting from about epoch #20 on. If you want to deploy this network, I suggest loading the weights saved at epoch 20 and generating a deployment network from them. It does not seem to get any better, at least not with this network.
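A note on the max_lr=slice(1e-4,1e-5) argument used in the fine-tuning call: as far as I understand fast.ai v1, a slice with both ends given is spread over the network's layer groups on a multiplicative scale, with the first group receiving the start value and the last group the stop value. The sketch below is a standalone re-implementation of that spacing for illustration. The function is written from scratch here, and the three layer groups are an assumption.

```python
def even_mults(start, stop, n):
    """n values from start to stop, evenly spaced on a multiplicative scale."""
    if n == 1:
        return [stop]
    step = (stop / start) ** (1 / (n - 1))
    return [start * step ** i for i in range(n)]

# Assumed: three layer groups, earliest layers first.
lrs = even_mults(1e-4, 1e-5, 3)
print([f"{lr:.2e}" for lr in lrs])  # ['1.00e-04', '3.16e-05', '1.00e-05']
```

If that reading is right, slice(1e-4,1e-5) gives the larger rate to the earliest layers. The more common convention is the reverse order, with smaller rates for the early, more generic layers, so it is worth checking which order you intend.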
Look at the results
interp = ClassificationInterpretation.from_learner(learn18)
losses,idxs = interp.top_losses()
interp.plot_confusion_matrix(figsize=(5,5), dpi=100)
This is better than what we’ve had before. Let’s look at the loss curves:
learn18.recorder.plot_losses()
Here we see that the network seems to start to overfit after 500 batches, which would confirm our suspicion inferred from the results table above. If you look at the curves above, you’ll see that the validation loss starts to grow in the last third of the training, suggesting that this part of the training only overfitted the network.
What do I do if the training was interrupted halfway through?
What do you do if your training was interrupted? This can happen because you reached your 12 hours of continuous “free” operating time on a Google Colab notebook or because your computer stopped for some reason. I live in Brazil, where power outages are common…
The fit_one_cycle method works with varying, adaptive learning rates, following a curve where the rate is first increased and then decreased. If you interrupt a training run in epoch #10 of, say, 20 epochs and then start again for 9 more epochs, you will not get the same result as training uninterruptedly for 20 epochs. You have to be able to record where you stopped and then resume the training cycle from that point, with the correct hyperparameters for that part of the cycle.
The first thing you have to do is to save your network every epoch:
learn.fit_one_cycle(20, max_lr=slice(1e-5,1e-6),
                    callbacks=[SaveModelCallback(learn, every='epoch',
                               monitor='accuracy', name=<callback_save_file>)])
This saves your network every epoch under the name you provided, followed by _<epoch#>. So with the name saved_net, the file saved_net_3.pth will be written at epoch #3. You can load this file after you have:
- re-created the DataBunch and
- re-instantiated the network with it.
After reloading the .pth file, you can restart your training; you just tell fit_one_cycle to consider 20 epochs but to start training from epoch #4.
To learn how this is done, look here:
How do you do it?
The fit_one_cycle method in fast.ai has been developed to allow you to tell it from which part of the cycle to resume an interrupted training. The code for resuming a training will look like this:
# Create a new net if training was interrupted and you had to
# restart your Colab session
learn = cnn_learner(data, models.<your_model_here>,
                    metrics=[accuracy, error_rate])
# If you're resuming, just indicate the epoch from which to resume
# with start_epoch=<epoch#>; this will load the last saved .pth.
# It is not necessary to explicitly reload the last epoch, but you
# should NOT change the name given in name=<callback_save_file>:
# when resuming, fast.ai will try to reload
# <callback_save_file>_<previous_epoch>.pth
# Unfreeze the network
learn.unfreeze()
# Use start_epoch=<some_epoch> to resume training...
learn.fit_one_cycle(20, max_lr=slice(1e-5,1e-6),
                    start_epoch=<next_epoch#>,
                    callbacks=[SaveModelCallback(learn,
                               every='epoch', monitor='accuracy',
                               name=<callback_save_file>)])
fast.ai will tell you “Loaded <callback_save_file>_<previous_epoch#>” and resume training.
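If you are not sure which epoch was the last one saved, you can work it out from the checkpoint files. The sketch below is a hypothetical helper, not part of fast.ai. It assumes the <name>_<epoch>.pth naming described above and that you point it at the folder where the checkpoints were written (typically a models folder under your data path).

```python
import re
import tempfile
from pathlib import Path

def next_start_epoch(models_dir, name):
    """Return last saved epoch + 1, or 0 when no matching checkpoint exists."""
    pattern = re.compile(re.escape(name) + r"_(\d+)\.pth$")
    saved = []
    for ckpt in Path(models_dir).glob("*.pth"):
        match = pattern.match(ckpt.name)
        if match:
            saved.append(int(match.group(1)))
    return max(saved) + 1 if saved else 0

# Demo with empty placeholder files standing in for real checkpoints.
with tempfile.TemporaryDirectory() as tmp:
    for epoch in range(4):                   # epochs 0..3 were saved
        (Path(tmp) / f"saved_net_{epoch}.pth").touch()
    (Path(tmp) / "other_9.pth").touch()      # unrelated file, ignored
    resume_from = next_start_epoch(tmp, "saved_net")
    print(resume_from)  # 4
```

The returned value is what you would pass as start_epoch, matching the example above where saved_net_3.pth exists and training resumes from epoch #4.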
You can look at all parameters supported by the fit_one_cycle method here:
What have we learned?
Compared to other approaches, where the authors respectively employ pure Keras and TensorFlow.Keras to solve the problem, with fast.ai we were able to solve the same malaria blood smear classification problem with much less code, while using high-level hyperparameter optimization strategies that allowed us to train much faster. At the same time, a set of high-level functions also allows us to easily inspect the results both as tables and as graphs.
With off-the-shelf residual network models, pre-trained on ImageNet and provided by fast.ai, we obtained accuracy results that are one percentage point better than two of the three previous works above (including the scientific paper published in PeerJ) and equal to the best-performing of them. This shows that fast.ai is a very promising alternative to more traditional CNN frameworks, especially if the task at hand is a “standard” deep learning task such as image classification, object detection or semantic segmentation that can be solved by fine-tuning off-the-shelf pre-trained network models.
Want to see a more ambitious example?
Look at our skin cancer detection tutorial with fast.ai:
- TowardsDataScience::Deep Learning for Diagnosis of Skin Images with fastai — Learn to identify skin cancer and other conditions from dermoscopic images by Aldo von Wangenheim
In this example we will show you how to work with larger datasets and adapt the best learning rates.