AI for Good — Medical Image Analyses for Malaria Detection

Shubham Goyal
Towards Data Science
14 min read · Nov 13, 2019


In this blog, we will talk about why it is important to detect the early presence of parasitized cells in a thin blood smear for malaria, and then get hands-on with building a detector for the same.

Introduction

Malaria is a deadly, infectious mosquito-borne disease caused by Plasmodium parasites. These parasites are transmitted by the bites of infected female Anopheles mosquitoes. While we won’t get into details about the disease, there are five main types of malaria.

Let's now look at how prevalent and deadly this disease is in the following plot.

It is pretty clear that malaria is prevalent across the globe, particularly in tropical regions. The motivation for this project is the nature and fatality of this disease. If an infected mosquito bites you, the parasites it carries enter your blood and begin destroying oxygen-carrying red blood cells (RBCs). The first symptoms of malaria are usually similar to those of the flu or a virus, and you typically start feeling sick within a few days or weeks of the mosquito bite. However, these deadly parasites can live in your body for over a year without causing any symptoms, so a delay in the right treatment can lead to complications and even death. Early and effective testing and detection of malaria can therefore save lives.

Approach to the solution

Although the malaria parasite doesn't take the form of a mutant mosquito, it sure feels like a mutant problem. The deadly disease has reached epidemic, even endemic, proportions in different parts of the world, killing around 400,000 people annually. In other areas of the world, it is virtually nonexistent. Some areas are simply more prone to an outbreak; there are certain factors that make an area more likely to be affected by malaria:

  • High poverty levels
  • Lack of access to proper healthcare
  • Political instability
  • Presence of disease transmission vectors (ex. mosquitos) [6]

Given this mixture of problems, we must keep some things in mind when building our model:

  • There may be a lack of a reliable power source
  • Battery-powered devices have less computational power
  • There may be a lack of Internet connection (so training/storing on the cloud may be hard!)

Traditional Methods for Malaria Detection

There are several methods and tests which can be used for malaria detection and diagnosis.

These include, but are not limited to, thick and thin blood smear examinations, polymerase chain reaction (PCR) tests and rapid diagnostic tests (RDT). I am not going to cover all of these methods here; the key point is that the alternative tests are typically used where good-quality microscopy services cannot be readily provided.

Microscopic examination of blood is the best known method for diagnosis of malaria. A patient’s blood is smeared on a glass slide and stained with a contrasting agent that facilitates identification of parasites within red blood cells.

A trained clinician examines 20 microscopic fields of view at 100 X magnification, counting red blood cells that contain the parasite out of 5,000 cells (WHO protocol).
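To make the arithmetic behind this protocol concrete, here is a tiny, purely illustrative Python sketch; the counts below are made-up example numbers, not WHO reference values:

# Illustrative only: estimate parasitemia from manual thin-smear counts.
def parasitemia_percent(parasitized_rbcs, total_rbcs_counted):
    # Fraction of counted red blood cells that contain parasites, as a percentage.
    return 100.0 * parasitized_rbcs / total_rbcs_counted

# Example: a reader finds 120 parasitized cells among the ~5,000 RBCs examined.
print(parasitemia_percent(120, 5000))   # -> 2.4 (% parasitemia)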

Thanks to Carlos Atico for his wonderful blog on data science insights.

Thus, malaria detection is definitely an intensive manual process that can perhaps be automated using deep learning, and that is the basis of this blog.

Deep learning for Malaria Detection

Deep Learning models, or to be more specific, Convolutional Neural Networks (CNNs), have proven to be really effective in a wide variety of computer vision tasks. We assume that you have some knowledge of CNNs; in case you don't, feel free to dive deeper into them by checking out this article here. Briefly, the key layers in a CNN model include convolution and pooling layers, as depicted in the following figure.

Convolutional neural networks (CNNs) can automatically extract features and learn filters. In previous machine learning solutions, features had to be manually engineered, for example the size, color and morphology of the cells. Utilizing CNNs will greatly speed up prediction time while matching (or even exceeding) the accuracy of clinicians.

CNNs learn hierarchical patterns from our data and are thus able to capture different aspects of images. For example, the first convolution layer will learn small, local patterns such as edges and corners; a second convolution layer will learn larger patterns built from the features of the first layer; and so on.
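As a concrete illustration of such a convolution-plus-pooling hierarchy, here is a minimal PyTorch sketch of a toy model (for intuition only; it is not the network we train later in this post):

import torch
import torch.nn as nn

# Toy CNN: two conv/pool stages followed by a tiny classifier head.
class ToyCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),   # learns small, local patterns (edges, corners)
            nn.ReLU(),
            nn.MaxPool2d(2),                              # 128x128 -> 64x64
            nn.Conv2d(16, 32, kernel_size=3, padding=1),  # learns larger patterns built from the first layer's features
            nn.ReLU(),
            nn.MaxPool2d(2),                              # 64x64 -> 32x32
        )
        self.classifier = nn.Linear(32 * 32 * 32, 2)      # e.g. Parasitized vs. Uninfected

    def forward(self, x):
        x = self.features(x)
        return self.classifier(x.view(x.size(0), -1))

x = torch.randn(1, 3, 128, 128)   # one fake 128x128 RGB cell image
print(ToyCNN()(x).shape)          # -> torch.Size([1, 2])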

You can go through the very interesting research paper 'Pre-trained convolutional neural networks as feature extractors toward improved parasite detection in thin blood smear images' by Rajaraman et al. It evaluates six pre-trained models on the dataset we use below and obtains an accuracy of 95.9% in detecting malaria versus non-infected samples.

Dataset Explanation

Let's see what data we are using for this problem. I am very thankful to the researchers at the Lister Hill National Center for Biomedical Communications (LHNCBC), part of the National Library of Medicine (NLM), who have carefully collected and annotated this dataset of healthy and infected blood smear images. You can download these images from the official site.

They have also developed a mobile application that runs on a standard Android smartphone attached to a conventional light microscope (Poostchi et al., 2018). Giemsa-stained thin blood smear slides from 150 P. falciparum-infected and 50 healthy patients were collected and photographed at Chittagong Medical College Hospital, Bangladesh. The smartphone's built-in camera acquired images of the slides for each microscopic field of view. The images were manually annotated by an expert slide reader at the Mahidol-Oxford Tropical Medicine Research Unit in Bangkok, Thailand.

Note: Before we begin, I'd like to point out that I am neither a doctor nor a healthcare researcher, and I'm nowhere near as qualified as they are. I do, however, have an interest in applying AI to healthcare research. The intent of this article is not to feed the hype that AI will be replacing jobs and taking over the world, but to showcase how AI can be useful in assisting with malaria detection and diagnosis and in reducing manual labor, with low-cost, effective and accurate open-source solutions.

This is what our training data looks like:

Note: I am using Google Colab, and the code below is written accordingly. I recommend you use Google Colab too, for ease of setup.

Let's get to the code:

Code

Initialization

%reload_ext autoreload
%autoreload 2
%matplotlib inline

Virtual machine testing on Google Colab

If Google’s servers are crowded, you’ll eventually have access to only part of a GPU. If your GPU is shared with another Colab notebook, you’ll see a smaller amount of memory made available for you.

!/opt/bin/nvidia-smi
!nvcc --version

When I started running the experiments described here, I was lucky: I had the full 11441 MiB of GPU memory to myself! My output looked like this:

Mon Nov  4 05:40:26 2019       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 418.67 Driver Version: 418.67 CUDA Version: 10.1 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
|===============================+======================+======================|
| 0 Tesla K80 Off | 00000000:00:04.0 Off | 0 |
| N/A 49C P8 30W / 149W | 0MiB / 11441MiB | 0% Default |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes: GPU Memory |
| GPU PID Type Process name Usage |
|=============================================================================|
| No running processes found |
+-----------------------------------------------------------------------------+
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2018 NVIDIA Corporation
Built on Sat_Aug_25_21:08:01_CDT_2018
Cuda compilation tools, release 10.0, V10.0.130

Import libraries

Here we import all the necessary packages. We are going to work with the fast.ai V1 library which sits on top of Pytorch 1.0. The fast.ai library provides many useful functions that enable us to quickly and easily build neural networks and train our models.

from fastai.vision import *
from fastai.metrics import error_rate, accuracy
from fastai.callbacks import SaveModelCallback
# Imports for diverse utilities
from shutil import copyfile
import matplotlib.pyplot as plt
import operator
from PIL import Image
from sys import intern # For the symbol definitions

Some util functions: Export and restoration

Now here is an export network for deployment and one for creating a copy

# Export network for deployment and create a copy
def exportStageTo(learn, path):
    learn.export()
    # Make a differentiated backup of the exported model
    copyfile(path/'export.pkl', path/'export-malaria.pkl')

#exportStageTo(learn, path)

Restoration of a deployment model, for example in order to continue fine-tuning

# Restoration of a deployment model, for example in order to continue fine-tuning
def restoreStageFrom(path):
    # Restore the backup
    copyfile(path/'export-malaria.pkl', path/'export.pkl')
    return load_learner(path)

#learn = restoreStageFrom(path)

Now let’s download the NIH dataset, on which we will work today:

!wget  --backups=1 -q https://ceb.nlm.nih.gov/proj/malaria/cell_images.zip
!wget --backups=1 -q https://ceb.nlm.nih.gov/proj/malaria/malaria_cell_classification_code.zip
!ls -al

The backups=1 parameter of wget will allow you to repeat the command line many times, if a download fails, without creating a lot of new versions of the files.

The output of the last line should be like this:

total 690400
drwxr-xr-x 1 root root 4096 Nov 4 10:09 .
drwxr-xr-x 1 root root 4096 Nov 4 05:34 ..
-rw-r--r-- 1 root root 353452851 Apr 6 2018 cell_images.zip
-rw-r--r-- 1 root root 353452851 Apr 6 2018 cell_images.zip.1
drwxr-xr-x 1 root root 4096 Oct 30 15:14 .config
drwxr-xr-x 5 root root 4096 Nov 4 07:26 data
-rw-r--r-- 1 root root 12581 Apr 6 2018 malaria_cell_classification_code.zip
-rw-r--r-- 1 root root 12581 Apr 6 2018 malaria_cell_classification_code.zip.1
drwxr-xr-x 1 root root 4096 Oct 25 16:58 sample_data

Let's unzip the NIH cell images dataset:

!unzip cell_images.zip

This will produce a very large, verbose output that will look like this:

Archive:  cell_images.zip
creating: cell_images/
creating: cell_images/Parasitized/
extracting: cell_images/Parasitized/C100P61ThinF_IMG_20150918_144104_cell_162.png
extracting: cell_images/Parasitized/C100P61ThinF_IMG_20150918_144104_cell_163.png
extracting: cell_images/Parasitized/C100P61ThinF_IMG_20150918_144104_cell_164.png
extracting: cell_images/Parasitized/C100P61ThinF_IMG_20150918_144104_cell_165.png
extracting: cell_images/Parasitized/C100P61ThinF_IMG_20150918_144104_cell_166.png
extracting: cell_images/Parasitized/C100P61ThinF_IMG_20150918_144104_cell_167.png
extracting: cell_images/Parasitized/C100P61ThinF_IMG_20150918_144104_cell_168.png
extracting: cell_images/Parasitized/C100P61ThinF_IMG_20150918_144104_cell_169.png
extracting: cell_images/Parasitized/C100P61ThinF_IMG_20150918_144104_cell_170.png
extracting: cell_images/Parasitized/C100P61ThinF_IMG_20150918_144104_cell_171.png
extracting: cell_images/Parasitized/C100P61ThinF_IMG_20150918_144348_cell_138.png
extracting: cell_images/Parasitized/C100P61ThinF_IMG_20150918_144348_cell_139.png
extracting: cell_images/Parasitized/C100P61ThinF_IMG_20150918_144348_cell_140.png
extracting: cell_images/Parasitized/C100P61ThinF_IMG_20150918_144348_cell_141.png
extracting: cell_images/Parasitized/C100P61ThinF_IMG_20150918_144348_cell_142.png
.......... and so on, with one line for every extracted image

Prepare your data

Rename the cell_images folder to train, and then move it into a new data folder, so fast.ai can use it to automatically generate training and validation sets without any further fuss:

!mv cell_images train
!mkdir data
!mv train data

Deep dive on data folder

Install the tree command if you don't have it:

!apt install tree

Let's run the command now:

!tree ./data --dirsfirst --filelimit 10

this will show a tree structure of your folder:

./data
└── train
├── Parasitized [13780 exceeds filelimit, not opening dir]
└── Uninfected [13780 exceeds filelimit, not opening dir]

3 directories, 0 files

Variable Initialization

bs = 256                # Batch size, 256 for small images on a T4 GPU...
size = 128 # Image size, 128x128 is a bit smaller than most of the images...
path = Path("./data") # The path to the 'train' folder you created...

Create training and validation data bunches

With fast.ai it is not necessary to manually split your data into separate train and validation folders: if you only have a 'train' folder, you can split it automatically while creating the DataBunch by simply passing a few parameters. We will split the data into a training set (80%) and a validation set (20%). This is done with the valid_pct = 0.2 parameter of the ImageDataBunch.from_folder() constructor method:

Limit your augmentations: it’s medical data! You do not want to fantasize about data…

Warping, for example, will leave your images badly distorted, so don't do it!

This dataset is big, so don’t rotate the images either. Let's stick to flipping…

tfms = get_transforms(max_rotate=None, max_warp=None, max_zoom=1.0)
# Create the DataBunch!
# Remember that you'll have images that are bigger than 128x128 and images that are smaller,
# so squish them all in order to occupy exactly 128x128 pixels...
data = ImageDataBunch.from_folder(path, ds_tfms=tfms, size=size, resize_method=ResizeMethod.SQUISH,
valid_pct = 0.2, bs=bs)
#
print('Transforms = ', len(tfms))
# Save the DataBunch in case the training goes south... so you won't have to regenerate it..
# Remember: this DataBunch is tied to the batch size you selected.
data.save('imageDataBunch-bs-'+str(bs)+'-size-'+str(size)+'.pkl')
# Show the statistics of the Bunch...
print(data.classes)
data

The print() will output the transforms and the classes:

Transforms =  2
['Parasitized', 'Uninfected']

and the data command on the last line will simply output the string representation of the ImageDataBunch instance:

ImageDataBunch;
Train: LabelList (22047 items)
x: ImageList
Image (3, 128, 128),Image (3, 128, 128),Image (3, 128, 128),Image (3, 128, 128),Image (3, 128, 128)
y: CategoryList
Uninfected,Uninfected,Uninfected,Uninfected,Uninfected
Path: data;
Valid: LabelList (5511 items)
x: ImageList
Image (3, 128, 128),Image (3, 128, 128),Image (3, 128, 128),Image (3, 128, 128),Image (3, 128, 128)
y: CategoryList
Uninfected,Parasitized,Parasitized,Parasitized,Parasitized
Path: data;
Test: None

A look at the augmented DataBunch

data.show_batch(rows=5, figsize=(15,15))

ResNet18

ResNet18 is a much smaller network than deeper alternatives such as ResNet34 or ResNet50, so more GPU RAM stays available for us, and we could even afford a larger batch size.

I also tried ResNet50, which gave about 92% accuracy.

But ResNet18 is a good fit for this data, so that is what we are going to use in this blog.

Now, create the learner:

learn18 = cnn_learner(data, models.resnet18, metrics=accuracy)

If your Colab environment doesn’t have the pretrained data for the ResNet18, fast.ai will automatically download it:

Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to /root/.torch/models/resnet18-5c106cde.pth
46827520it [00:01, 28999302.58it/s]

Let's have a look at the model:

learn18.model

This will list the structure of your net. It is much smaller than ResNet34, but still has a lot of layers. The output will look like this:

Sequential(
(0): Sequential(
(0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
(4): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(1): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(5): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(6): Sequential(
(0): BasicBlock(
(conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
.......... and so on

Let’s train it

We will use the fit_one_cycle training strategy, a fast.ai HYPO (hyperparameter optimization) trick. Limit the training to 10 epochs to see how this smaller network behaves:

learn18.fit_one_cycle(10, callbacks=[SaveModelCallback(learn18, every='epoch', monitor='accuracy', name='malaria18-1')])
# Save the model
learn18.save('malaria18-stage-1')
# export the model stage
exportStageTo(learn18, path)

The table below shows that the network learned to an accuracy of roughly 95.2% and suggests that it should not be trained much further: between epochs #8 and #9 the validation loss barely decreased while the accuracy actually dropped slightly, suggesting that the network is starting to overfit.

epoch 	train_loss 	valid_loss 	accuracy 	time
0 0.543975 0.322170 0.901107 01:08
1 0.354574 0.198226 0.927055 01:07
2 0.256173 0.173847 0.938487 01:08
3 0.197873 0.157763 0.943930 01:09
4 0.163859 0.148826 0.947197 01:08
5 0.143582 0.142058 0.948104 01:08
6 0.130425 0.134914 0.949918 01:09
7 0.118313 0.132691 0.951551 01:09
8 0.112078 0.132101 0.952459 01:09
9 0.107859 0.131681 0.952096 01:09

Let’s generate a ClassificationInterpretation and look at the confusion matrix and the loss curves.

interp = ClassificationInterpretation.from_learner(learn18)
losses, idxs = interp.top_losses()
len(data.valid_ds) == len(losses) == len(idxs)

Let’s have a look at the confusion matrix:

interp.plot_confusion_matrix(figsize=(5,5), dpi=100)

Let’s look at the losses:

learn18.recorder.plot_losses()

ResNet18 started overfitting a bit after about 290 batches.

Remember that our bs here is 256.

Fine-tuning the model

Here we will introduce another fast.ai HYPO: automatically chosen variable learning rates. We will let fast.ai choose which learning rate to use for each epoch and each layer, providing a range of learning rates we consider adequate. We will train the network for 30 epochs.
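If you would rather pick that range of learning rates yourself instead of reusing the slice(1e-4, 1e-5) below, fast.ai's learning-rate finder can help. A minimal sketch, assuming fastai v1's lr_find and recorder API:

# Optional: run the learning-rate finder on the unfrozen network and
# pick a range just before the loss starts to explode in the plot.
learn18.unfreeze()
learn18.lr_find()
learn18.recorder.plot()   # e.g. this is how one might justify slice(1e-4, 1e-5)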

# Unfreeze the network
learn18.unfreeze()

learn18.fit_one_cycle(30, max_lr=slice(1e-4,1e-5),
                      callbacks=[SaveModelCallback(learn18, every='epoch', monitor='accuracy', name='malaria18')])
learn18.save('malaria18-stage-2')
# export for deployment
exportStageTo(learn18, path)

Here, our fast.ai model achieved an accuracy of approximately 97% after some fine-tuning, which is quite good.

The validation loss, however, was becoming worse for the last epochs. This indicates that we have been overfitting from about epoch #20 on.

So if we want to deploy the model, we should choose the checkpoint saved around epoch #20. Note that this is about the best accuracy we can get with this network; a different architecture could be tried for better results.
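A minimal sketch of how one could reload that checkpoint, assuming SaveModelCallback kept its default '{name}_{epoch}' file naming (epoch indices are 0-based, so the 20th epoch is saved as epoch 19):

# Reload the weights saved at (0-based) epoch 19, i.e. models/malaria18_19.pth,
# and export that version for deployment using the helper defined earlier.
learn18.load('malaria18_19')
exportStageTo(learn18, path)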

Results

interp = ClassificationInterpretation.from_learner(learn18)
losses, idxs = interp.top_losses()
len(data.valid_ds) == len(losses) == len(idxs)
interp.plot_confusion_matrix(figsize=(5,5), dpi=100)
learn18.recorder.plot_losses()

This is better than what we’ve had before. Let’s look at the loss curves:

Here we see that the network seems to start to over-fit after 500 batches, which would confirm our suspicion inferred from the results table above. If you look at the curves above, you’ll see that the validation loss starts to grow in the last third of the training, suggesting that this part of the training only over-fitted the network.

Now I will leave this task to you: tune the network, or apply techniques such as regularization or early stopping, to bring this overfitting under control.

Hint: save your model after every epoch.
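One possible way to do this, sketched with fastai v1's tracker callbacks; the monitor, min_delta and patience values below are reasonable guesses rather than tuned settings:

from fastai.callbacks import EarlyStoppingCallback, SaveModelCallback

# Keep only the best checkpoint and stop once the validation loss
# has not improved for a few epochs.
learn18.fit_one_cycle(
    30, max_lr=slice(1e-4, 1e-5),
    callbacks=[
        SaveModelCallback(learn18, every='improvement',
                          monitor='valid_loss', name='malaria18-best'),
        EarlyStoppingCallback(learn18, monitor='valid_loss',
                              min_delta=0.001, patience=3),
    ])
learn18.load('malaria18-best')   # restore the best weights before deployment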

Happy learning

Follow MachineX for more.


Shubham Goyal is a Data Scientist at Presight. He is an artificial intelligence researcher who has written a few research papers on machine learning, and a speaker.