
10 Minutes to Building a Binary Image Classifier By Applying Transfer Learning to MobileNet in TensorFlow

How to build a binary image classifier by training on top of the MobileNet model

Binh Phan
Towards Data Science
5 min read · Jul 13, 2020


This is a short introduction to computer vision — namely, how to build a binary image classifier using transfer learning on the MobileNet model, geared mainly towards new users. This easy-to-follow tutorial is broken down into 3 sections:

  1. The data
  2. The model architecture
  3. The accuracy, ROC curve, and AUC

Requirements: Nothing! All you need to follow this tutorial is this Google Colab notebook containing the data and code. Google Colab allows you to write and run Python code in-browser without any setup, and includes free GPU access! To run this code, simply go to File -> Make a copy to create a copy of the notebook that you can run and edit.

1. The Data

We’re going to build a dandelion and grass image classifier. I’ve created a small image dataset using images from Google Images, which you can download and parse in the first 8 cells of the tutorial.

By the end of those 8 cells, you will be able to visualize a sample of your image dataset in the notebook.

Note how some of the images in the dataset aren’t perfect representations of grass or dandelions. For simplicity’s sake, we’ll accept this noise and move on to how to easily create our training and validation datasets.

The data that we fetched earlier is divided into two folders, train and valid. Inside each of those, the folders dandelion and grass contain the images of each class. To build our training and validation datasets and normalize the data, let’s use the keras.preprocessing.image.ImageDataGenerator class. This class reads the images and labels them automatically based on the folder names, letting us create each dataset in just one line, as shown in the sketch below.
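For reference, here is a minimal sketch of what those lines might look like; the batch size is an assumption, and the directory names follow the train/valid layout described above:

from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Rescale pixel values from [0, 255] to [0, 1]
train_datagen = ImageDataGenerator(rescale=1.0 / 255)
valid_datagen = ImageDataGenerator(rescale=1.0 / 255)

# Labels are inferred from the dandelion/ and grass/ subfolder names
train_generator = train_datagen.flow_from_directory(
    'train',
    target_size=(200, 200),   # matches the model's input_shape below
    batch_size=16,            # assumed batch size
    class_mode='binary')

validation_generator = valid_datagen.flow_from_directory(
    'valid',
    target_size=(200, 200),
    batch_size=16,            # assumed batch size
    class_mode='binary',
    shuffle=False)            # keep order fixed so labels line up with predictions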

2. The Model Architecture

At the beginning of this section, we first import TensorFlow.
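For reference, the snippets in the rest of this section assume imports along these lines (the notebook may organize them differently):

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.optimizers import RMSprop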

Now, let’s add the MobileNet model. Make sure to include the include_top parameter and set it to False. This removes the model’s final classification layer so that we can add and train our own layer in its place. This is called transfer learning! We will then add a GlobalAveragePooling2D layer to reduce the size of the output that we feed into our last layer. For that last layer, we will add a Dense layer with a sigmoid activation for binary classification. This is important: we must set the MobileNet layers’ trainable parameter to False so that we don’t end up training the entire model, since we only need to train the last layer!

Here is the model that we have built:

model = Sequential()
model.add(MobileNetV2(include_top=False, weights="imagenet", input_shape=(200, 200, 3)))
model.add(tf.keras.layers.GlobalAveragePooling2D())
model.add(Dense(1, activation='sigmoid'))
model.layers[0].trainable = False

Let’s see a summary of the model we have built:

Model: "sequential" _________________________________________________________________ Layer (type)                 Output Shape              Param #    ================================================================= mobilenetv2_1.00_224 (Model) (None, 7, 7, 1280)        2257984    _________________________________________________________________ global_average_pooling2d (Gl (None, 1280)              0          _________________________________________________________________ dense (Dense)                (None, 1)                 1281       ================================================================= Total params: 2,259,265 Trainable params: 1,281 Non-trainable params: 2,257,984

Next, we’ll configure the specifications for model training. We will train our model with the binary_crossentropy loss and the RMSProp optimizer. RMSProp is a sensible choice because it adapts the learning rate during training for us (alternatively, Adam or Adagrad would give similar results). We will add accuracy to metrics so that the model monitors accuracy during training.

model.compile(optimizer=RMSprop(learning_rate=0.01),
              loss='binary_crossentropy',
              metrics=['accuracy'])

Let’s train for 15 epochs:

history = model.fit(train_generator,
                    steps_per_epoch=8,
                    epochs=15,
                    verbose=1,
                    validation_data=validation_generator,
                    validation_steps=8)

3. The Accuracy, ROC Curve, and AUC

Let’s evaluate the accuracy of our model:

model.evaluate(validation_generator)

Now, let’s calculate our ROC curve and plot it.

First, let’s make predictions on our validation set. When using generators to make predictions, we must first turn off shuffle (as we did when we created validation_generator) and reset the generator:

STEP_SIZE_TEST = validation_generator.n // validation_generator.batch_size
validation_generator.reset()
preds = model.predict(validation_generator, verbose=1)

To create the ROC curve and AUC, we’ll need to compute the false-positive rate and the true-positive rate:

from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt

fpr, tpr, _ = roc_curve(validation_generator.classes, preds)
roc_auc = auc(fpr, tpr)

plt.figure()
lw = 2
plt.plot(fpr, tpr, color='darkorange', lw=lw,
         label='ROC curve (area = %0.2f)' % roc_auc)
plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver operating characteristic example')
plt.legend(loc="lower right")
plt.show()
ROC curve for our model

The ROC curve is a probability curve plotting the true-positive rate (TPR) against the false-positive rate (FPR). In this curve, the diagonal line is the curve for random guessing, e.g. coin flipping, so the ROC curve above shows that our model does pretty well on classification!

Similarly, the AUC (area under the curve), as shown in the legend above, measures how well our model can distinguish between our two classes, dandelions and grass. The higher the AUC, the better our model is at classification. An AUC of 0.96 is considered pretty good!

Finally, at the end of the notebook, you’ll have a chance to make predictions on your own images!
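If you want to try this outside the notebook, here is a hypothetical sketch of predicting on a single image; the file name is made up, and the class ordering assumes flow_from_directory’s alphabetical labeling (dandelion = 0, grass = 1):

import numpy as np
from tensorflow.keras.preprocessing import image

img = image.load_img('my_photo.jpg', target_size=(200, 200))  # hypothetical file
x = image.img_to_array(img) / 255.0   # apply the same rescaling as the generators
x = np.expand_dims(x, axis=0)         # the model expects a batch dimension
prob = model.predict(x)[0][0]         # sigmoid output between 0 and 1
print('grass' if prob > 0.5 else 'dandelion')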


I hope this gives you a gentle introduction to building a simple binary image classifier using transfer learning on the MobileNet model! If you are interested in tutorials like this one, please check out my other stories.
