
Transfer Learning and Data Augmentation applied to the Simpsons Image Dataset

Deep Learning application using Tensorflow and Keras


1. Introduction

In the ideal scenario for Machine Learning (ML), there are abundant labeled training instances, which share the same distribution as the test data [1]. However, collecting such data can be resource-intensive or unrealistic in certain scenarios. Thus, Transfer Learning (TL) becomes a useful approach. It consists of increasing the learning ability of a model by transferring information from a different but related domain. In other words, it relaxes the hypothesis that the training and testing data must be independent and identically distributed [2]. It only works if the features to be learned are general to both tasks. Another way to work with limited data is Data Augmentation (DA), which applies a suite of transformations to inflate the dataset. Traditional ML algorithms rely significantly on feature engineering, while Deep Learning (DL) learns representations directly from the data through hierarchical feature extraction. DL often requires massive amounts of data to be trained effectively, which makes it a strong candidate for TL and DA.

Our task is to classify a series of labeled images. We are faced with two problems due to the small size of the dataset: the challenge of effectively learning the patterns in the data and the high probability of overfitting. We start by implementing a Convolutional Neural Network (CNN) model from scratch to be used as the benchmark model. Next, following the principle of TL, we use a CNN ([3], [4]) pre-trained on the ImageNet dataset. We remove its top layers to include our own deep structure suited to our problem specifications. The pre-trained CNN thus works as a feature extraction layer in the overall new model. With this approach, we address both problems: we greatly reduce the requirement for large amounts of training data while also reducing overfitting. We also conduct a second experiment, where we augment our training data by applying a suite of techniques that enhance the size and quality of the images. This method, Data Augmentation (DA), acts as a regularizer: it preserves the labels while inflating the dataset with transformed, invariant examples [5].

In this article, we use the Simpsons characters dataset [6]. We filtered the dataset to keep only classes (characters) with more than 100 images. After splitting into train, validation, and test sets, the resulting sizes are: 12411 images for training, 3091 for validation, and 950 for testing.
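This filtering and splitting step is not shown in the code below; a minimal sketch could look as follows. The raw Kaggle directory layout and the test fraction used here are assumptions for illustration, not values taken from the article.

import os
import random
import shutil

# Hypothetical paths and test fraction; adapt to your local layout.
raw_dir = "./data/simpsons_dataset/"
out_train = "./data/simpsons_data_split/train/"
out_test = "./data/simpsons_data_split/test/"
test_frac = 0.06
random.seed(1)

for character in os.listdir(raw_dir):
    images = os.listdir(os.path.join(raw_dir, character))
    if len(images) <= 100:  # keep only classes with more than 100 images
        continue
    random.shuffle(images)
    n_test = int(len(images) * test_frac)
    for subset, files in [(out_test, images[:n_test]),
                          (out_train, images[n_test:])]:
        os.makedirs(os.path.join(subset, character), exist_ok=True)
        for f in files:
            shutil.copy(os.path.join(raw_dir, character, f),
                        os.path.join(subset, character, f))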

As always, the code is available on my GitHub.

This article belongs to a series on Deep Learning using TensorFlow.

2. Data Preprocessing

Although the dataset is small from a CNN-training perspective, it is large enough to cause memory issues if loaded and transformed all at once. We use data generators to feed data to our different models in real time. Generator functions are a special type of function that returns a lazy iterator, i.e., they do not store their contents in memory. When creating the generator, we apply a transformation to normalize the data, split it into training and validation sets, and define a batch size of 32.
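As a generic illustration of this laziness (plain Python, unrelated to the Keras generators used below), nothing is computed until the iterator is advanced:

def count_batches(n):
    for i in range(n):
        print(f"producing batch {i}")  # executed only when a batch is requested
        yield i

gen = count_batches(3)  # nothing is printed yet: the function body has not run
next(gen)               # prints "producing batch 0"; values materialize on demand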

import tensorflow as tf
import tensorflow_hub as hub
import numpy as np
import os
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from tensorflow.keras import Input, layers
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import (Dense, Flatten, Conv2D, MaxPooling2D,
                                     Dropout, BatchNormalization)
from tensorflow.keras.preprocessing.image import ImageDataGenerator

directory_train = "./data/simpsons_data_split/train/"
directory_test = "./data/simpsons_data_split/test/"

def get_ImageDataGenerator(validation_split=None):
    # Normalize pixel values to [0, 1] and optionally reserve a validation split
    image_generator = ImageDataGenerator(rescale=(1/255.),
                                         validation_split=validation_split)
    return image_generator

image_gen_train = get_ImageDataGenerator(validation_split=0.2)

def get_generator(image_data_generator, directory, train_valid=None, seed=None):
    generator = image_data_generator.flow_from_directory(directory,
                                                         batch_size=32,
                                                         class_mode='categorical',
                                                         target_size=(299, 299),
                                                         subset=train_valid,
                                                         seed=seed)
    return generator

train_generator = get_generator(image_gen_train, directory_train, train_valid='training', seed=1)
validation_generator = get_generator(image_gen_train, directory_train, train_valid='validation')
Found 12411 images belonging to 19 classes.
Found 3091 images belonging to 19 classes.
image_gen_test = get_ImageDataGenerator(validation_split=None)
test_generator = get_generator(image_gen_test, directory_test)
Found 950 images belonging to 19 classes.

We can iterate through our generators to retrieve batches of images of the size defined above.

target_labels = next(os.walk(directory_train))[1]
target_labels.sort()
target_labels = np.asarray(target_labels)

batch = next(train_generator)
batch_images = np.array(batch[0])
batch_labels = np.array(batch[1])

plt.figure(figsize=(15,10))
for i in range(10):
    plt.subplot(3, 5, i+1)
    plt.imshow(batch_images[i])
    # Recover the character name from the one-hot encoded label
    plt.title(target_labels[np.where(batch_labels[i]==1)[0][0]])
    plt.axis('off')
Figure 1: Set of images yielded by the training generator.

3. The Benchmark Model

We define a simple CNN to be used as a benchmark model. It combines 2D convolution layers (performing spatial convolution over images) with max-pooling operations. These are followed by a dense layer with 128 units and ReLU activation and a dropout layer with a rate of 0.5. The final layer yields the output of the network, with a number of units equal to the number of target labels and a softmax activation function. The model is compiled with the Adam optimizer (default settings) and categorical cross-entropy loss.

def get_benchmark_model(input_shape):
    x = Input(shape=input_shape)
    h = Conv2D(32, padding='same', kernel_size=(3,3), activation='relu')(x)
    h = MaxPooling2D(pool_size=(2,2))(h)
    h = Conv2D(64, padding='same', kernel_size=(3,3), activation='relu')(h)
    h = Conv2D(64, padding='same', kernel_size=(3,3), activation='relu')(h)
    h = MaxPooling2D(pool_size=(2,2))(h)
    h = Conv2D(128, kernel_size=(3,3), activation='relu')(h)
    h = Conv2D(128, kernel_size=(3,3), activation='relu')(h)
    h = MaxPooling2D(pool_size=(2,2))(h)
    h = Flatten()(h)
    h = Dense(128, activation='relu')(h)
    h = Dropout(.5)(h)
    output = Dense(target_labels.shape[0], activation='softmax')(h)

    model = tf.keras.Model(inputs=x, outputs=output)

    model.compile(optimizer='adam',
             loss='categorical_crossentropy',
             metrics=['accuracy'])
    return model

Below, one can find the summary of our model, with the detailed list of defined layers and the number of parameters to train for each layer.

benchmark_model = get_benchmark_model((299, 299, 3))
benchmark_model.summary()
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 299, 299, 3)]     0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 299, 299, 32)      896       
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 149, 149, 32)      0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 149, 149, 64)      18496     
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 149, 149, 64)      36928     
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 74, 74, 64)        0         
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 72, 72, 128)       73856     
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 70, 70, 128)       147584    
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 35, 35, 128)       0         
_________________________________________________________________
flatten (Flatten)            (None, 156800)            0         
_________________________________________________________________
dense (Dense)                (None, 128)               20070528  
_________________________________________________________________
dropout (Dropout)            (None, 128)               0         
_________________________________________________________________
dense_1 (Dense)              (None, 19)                2451      
=================================================================
Total params: 20,350,739
Trainable params: 20,350,739
Non-trainable params: 0
_________________________________________________________________

We trained the benchmark CNN model using a callback that stops training early if the validation loss does not improve for 10 epochs. A smaller patience would normally be used; we kept 10 to show how rapidly the model starts to overfit the data.

def train_model(model, train_gen, valid_gen, epochs):
    train_steps_per_epoch = train_gen.n // train_gen.batch_size
    val_steps = valid_gen.n // valid_gen.batch_size

    # Stop training early if the validation loss does not improve for 10 epochs
    earlystopping = tf.keras.callbacks.EarlyStopping(patience=10)
    history = model.fit(train_gen,
                        steps_per_epoch=train_steps_per_epoch,
                        epochs=epochs,
                        validation_data=valid_gen,
                        validation_steps=val_steps,
                        callbacks=[earlystopping])

    return history

Every time we exhaust a generator, we need to reset or re-create it before feeding it to a model; otherwise, we would be losing batches of data. Re-creating the generators, as we do below, is the simplest option.
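Alternatively, the iterator returned by flow_from_directory exposes a reset() method that rewinds it to the first batch:

# Alternative to re-creating the generators: rewind the existing iterators.
# flow_from_directory returns a DirectoryIterator, which supports reset().
train_generator.reset()
validation_generator.reset()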

train_generator = get_generator(image_gen_train, directory_train, train_valid='training')
validation_generator = get_generator(image_gen_train, directory_train, train_valid='validation')
Found 12411 images belonging to 19 classes.
Found 3091 images belonging to 19 classes.
history_benchmark = train_model(benchmark_model, train_generator, validation_generator, 50)
Epoch 1/50
387/387 [==============================] - 747s 2s/step - loss: 2.8358 - accuracy: 0.1274 - val_loss: 2.4024 - val_accuracy: 0.2436
Epoch 2/50
387/387 [==============================] - 728s 2s/step - loss: 2.3316 - accuracy: 0.2758 - val_loss: 1.9895 - val_accuracy: 0.4170
[...]
Epoch 14/50
387/387 [==============================] - 719s 2s/step - loss: 0.3846 - accuracy: 0.8612 - val_loss: 2.4831 - val_accuracy: 0.5962
Epoch 15/50
387/387 [==============================] - 719s 2s/step - loss: 0.3290 - accuracy: 0.8806 - val_loss: 2.5545 - val_accuracy: 0.5930

The figure below shows the accuracy and loss evolution over epochs for the training and validation datasets. Our model is clearly overfitting: training accuracy approaches 90% while the validation loss actually increases over the last epochs. This is also why training ran for a reduced number of epochs.

plt.figure(figsize=(15,5))
plt.subplot(121)
plt.plot(history_benchmark.history['accuracy'])
plt.plot(history_benchmark.history['val_accuracy'])
plt.title('Accuracy vs. epochs')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['Training', 'Validation'], loc='lower right')

plt.subplot(122)
plt.plot(history_benchmark.history['loss'])
plt.plot(history_benchmark.history['val_loss'])
plt.title('Loss vs. epochs')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Training', 'Validation'], loc='upper right')
plt.show()
Figure 2: Accuracy and Loss evolution over several epochs of the benchmark model.

We can now evaluate our benchmark model on the test dataset. The results are shown below.

test_steps = test_generator.n // test_generator.batch_size
benchmark_test_loss, benchmark_test_acc = benchmark_model.evaluate(test_generator, steps=test_steps)
print('\nTest dataset:')
print("Loss: {}".format(benchmark_test_loss))
print("Accuracy: {}".format(benchmark_test_acc))
29/29 [==============================] - 9s 304ms/step - loss: 2.5011 - accuracy: 0.6272

Test dataset:
Loss: 2.5011332035064697
Accuracy: 0.6271551847457886

4. Pre-trained CNN

For the pre-trained model, we use the Xception architecture ([3], [4]), a deep CNN available in the keras.applications module, loaded with parameters learned on the ImageNet dataset. We use this pre-trained CNN as a large feature extraction layer, which we extend with an untrained set of layers specific to our multiclass classification task. This is where the principle of TL is effectively applied.

feature_extractor = tf.keras.applications.Xception(weights="imagenet")
feature_extractor.summary()
Model: "xception"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_2 (InputLayer)            [(None, 299, 299, 3) 0                                            
__________________________________________________________________________________________________
[...]     
avg_pool (GlobalAveragePooling2 (None, 2048)         0           block14_sepconv2_act[0][0]       
__________________________________________________________________________________________________
predictions (Dense)             (None, 1000)         2049000     avg_pool[0][0]                   
==================================================================================================
Total params: 22,910,480
Trainable params: 22,855,952
Non-trainable params: 54,528
_________________________________________________________________________________________________
def remove_head(feature_extractor_model):
    # Keep everything up to the global average pooling layer,
    # dropping the 1000-way ImageNet classification head
    model_input = feature_extractor_model.input
    output = feature_extractor_model.get_layer(name='avg_pool').output
    model = tf.keras.Model(inputs=model_input, outputs=output)
    return model
feature_extractor = remove_head(feature_extractor)
feature_extractor.summary()
Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_2 (InputLayer)            [(None, 299, 299, 3) 0                                            
__________________________________________________________________________________________________
[...]    
avg_pool (GlobalAveragePooling2 (None, 2048)         0           block14_sepconv2_act[0][0]       
==================================================================================================
Total params: 20,861,480
Trainable params: 20,806,952
Non-trainable params: 54,528
__________________________________________________________________________________________________
def add_new_classifier_head(feature_extractor_model):
    # Stack a new classification head (dense + dropout + softmax over the
    # 19 characters) on top of the feature extractor
    model = Sequential([
        feature_extractor_model,
        Dense(128, activation='relu'),
        Dropout(.5),
        Dense(target_labels.shape[0], activation='softmax')
    ])

    return model

Below, we can see the layers added to the head of our model.

new_model = add_new_classifier_head(feature_extractor)
new_model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
model_1 (Functional)         (None, 2048)              20861480  
_________________________________________________________________
dense_2 (Dense)              (None, 128)               262272    
_________________________________________________________________
dropout_1 (Dropout)          (None, 128)               0         
_________________________________________________________________
dense_3 (Dense)              (None, 19)                2451      
=================================================================
Total params: 21,126,203
Trainable params: 21,071,675
Non-trainable params: 54,528
_________________________________________________________________
def freeze_pretrained_weights(model):
    # Freeze the Xception feature extractor so that only the new head is trained
    model.get_layer(name='model_1').trainable = False

    model.compile(optimizer='adam',
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
    return model

We freeze the pre-trained CNN parameters to make them non-trainable: the summary now shows more than 20M non-trainable parameters in the new model. This also results in a shorter training time per epoch compared to the benchmark model.

frozen_new_model = freeze_pretrained_weights(new_model)
frozen_new_model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
model_1 (Functional)         (None, 2048)              20861480  
_________________________________________________________________
dense_2 (Dense)              (None, 128)               262272    
_________________________________________________________________
dropout_1 (Dropout)          (None, 128)               0         
_________________________________________________________________
dense_3 (Dense)              (None, 19)                2451      
=================================================================
Total params: 21,126,203
Trainable params: 264,723
Non-trainable params: 20,861,480
_________________________________________________________________
def train_model(model, train_gen, valid_gen, epochs):
    train_steps_per_epoch = train_gen.n // train_gen.batch_size
    val_steps = valid_gen.n // valid_gen.batch_size

    # Same training loop as before, but without the early-stopping callback
    history = model.fit(train_gen,
                        steps_per_epoch=train_steps_per_epoch,
                        epochs=epochs,
                        validation_data=valid_gen,
                        validation_steps=val_steps)

    return history
history_frozen_new_model = train_model(frozen_new_model, train_generator, validation_generator, 50)
Epoch 1/50
387/387 [==============================] - 564s 1s/step - loss: 2.6074 - accuracy: 0.1943 - val_loss: 2.0344 - val_accuracy: 0.4232
Epoch 2/50
387/387 [==============================] - 561s 1s/step - loss: 2.0173 - accuracy: 0.3909 - val_loss: 1.7743 - val_accuracy: 0.5118
[...]
Epoch 49/50
387/387 [==============================] - 547s 1s/step - loss: 0.4772 - accuracy: 0.8368 - val_loss: 1.7771 - val_accuracy: 0.6137
Epoch 50/50
387/387 [==============================] - 547s 1s/step - loss: 0.4748 - accuracy: 0.8342 - val_loss: 1.7402 - val_accuracy: 0.6215

We ran this model for a larger number of epochs without any callback. We do not see the same overfitting pattern, as more than 90% of the weights were frozen. In fact, we could have stopped after far fewer epochs, since there is no real improvement after roughly 10; a callback like the one sketched below would handle that automatically.
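For reference, a minimal sketch of such a variant of train_model (the function name and the patience value are illustrative choices, not taken from the original run):

def train_model_early(model, train_gen, valid_gen, epochs, patience=5):
    # Stop when validation accuracy stalls and keep the best weights seen
    earlystopping = tf.keras.callbacks.EarlyStopping(monitor='val_accuracy',
                                                     patience=patience,
                                                     restore_best_weights=True)
    history = model.fit(train_gen,
                        steps_per_epoch=train_gen.n // train_gen.batch_size,
                        epochs=epochs,
                        validation_data=valid_gen,
                        validation_steps=valid_gen.n // valid_gen.batch_size,
                        callbacks=[earlystopping])
    return history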

plt.figure(figsize=(15,5))
plt.subplot(121)
plt.plot(history_frozen_new_model.history['accuracy'])
plt.plot(history_frozen_new_model.history['val_accuracy'])
plt.title('Accuracy vs. epochs')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['Training', 'Validation'], loc='lower right')

plt.subplot(122)
plt.plot(history_frozen_new_model.history['loss'])
plt.plot(history_frozen_new_model.history['val_loss'])
plt.title('Loss vs. epochs')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Training', 'Validation'], loc='upper right')
plt.show()
Figure 3: Accuracy and Loss evolution over several epochs of the TL model.

Evaluating our new model on the test dataset results in a small accuracy increase (a little over 1 percentage point).

test_generator = get_generator(image_gen_test, directory_test)
new_model_test_loss, new_model_test_acc = frozen_new_model.evaluate(test_generator, steps=test_steps)
print('\nTest dataset')
print("Loss: {}".format(new_model_test_loss))
print("Accuracy: {}".format(new_model_test_acc))
Found 950 images belonging to 19 classes.
29/29 [==============================] - 33s 1s/step - loss: 1.6086 - accuracy: 0.6390

Test dataset
Loss: 1.6085671186447144
Accuracy: 0.639008641242981

5. Data Augmentation

As we saw above, DA is a set of methods used to inflate a dataset while reducing overfitting. We focus on generic DA, which consists of geometric and photometric transformations (see [5] for more on these and other methods). Geometric transformations alter the geometry of the image, making the CNN invariant to changes in position and orientation. Photometric transformations, on the other hand, adjust the color channels of the image, making the CNN invariant to changes in color and lighting.

def get_ImageDataGenerator_augmented(validation_split=None):
    image_generator = ImageDataGenerator(rescale=(1/255.),
                                        rotation_range=40,
                                        width_shift_range=0.2,
                                        height_shift_range=0.2,
                                        shear_range=0.2,
                                        zoom_range=0.1,
                                        brightness_range=[0.8,1.2],
                                        horizontal_flip=True,
                                        validation_split=validation_split)
    return image_generator
image_gen_train_aug = get_ImageDataGenerator_augmented(validation_split=0.2)
train_generator_aug = get_generator(image_gen_train_aug, directory_train, train_valid='training', seed=1)
validation_generator_aug = get_generator(image_gen_train_aug, directory_train, train_valid='validation')
Found 12411 images belonging to 19 classes.
Found 3091 images belonging to 19 classes.
train_generator = get_generator(image_gen_train, directory_train, train_valid='training', seed=1)
Found 12411 images belonging to 19 classes.

We can display the original and the augmented images for comparison. Notice the geometric changes, such as flipping, vertical and horizontal translation, and zooming, and the photometric ones, visible in the altered brightness of some images.

batch = next(train_generator)
batch_images = np.array(batch[0])
batch_labels = np.array(batch[1])

aug_batch = next(train_generator_aug)
aug_batch_images = np.array(aug_batch[0])
aug_batch_labels = np.array(aug_batch[1])

plt.figure(figsize=(16,5))
plt.suptitle("Original images", fontsize=16)
for i in range(10):
    plt.subplot(2, 5, i+1)
    plt.imshow(batch_images[i])
    plt.title(target_labels[np.where(batch_labels[i]==1)[0][0]])
    plt.axis('off')
plt.figure(figsize=(16,5))
plt.suptitle("Augmented images", fontsize=16)
for i in range(10):
    plt.subplot(2, 5, i+1)
    plt.imshow(aug_batch_images[i])
    # Use the augmented batch's own labels for the titles
    plt.title(target_labels[np.where(aug_batch_labels[i]==1)[0][0]])
    plt.axis('off')
Figure 4: Comparison between a set of images without any transformation and the corresponding augmented ones.
train_generator_aug = get_generator(image_gen_train_aug, directory_train, train_valid='training')
Found 12411 images belonging to 19 classes.

The augmented dataset is now fed to our custom model (the benchmark architecture defined above, which does not use ImageNet pre-trained weights). Note that, since we reuse the benchmark_model object, training continues from its current weights rather than from scratch.

benchmark_model_aug = benchmark_model  # aliases the trained benchmark; training continues from its weights
benchmark_model_aug.summary()
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 299, 299, 3)]     0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 299, 299, 32)      896       
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 149, 149, 32)      0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 149, 149, 64)      18496     
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 149, 149, 64)      36928     
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 74, 74, 64)        0         
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 72, 72, 128)       73856     
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 70, 70, 128)       147584    
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 35, 35, 128)       0         
_________________________________________________________________
flatten (Flatten)            (None, 156800)            0         
_________________________________________________________________
dense (Dense)                (None, 128)               20070528  
_________________________________________________________________
dropout (Dropout)            (None, 128)               0         
_________________________________________________________________
dense_1 (Dense)              (None, 19)                2451      
=================================================================
Total params: 20,350,739
Trainable params: 20,350,739
Non-trainable params: 0
_________________________________________________________________
history_augmented = train_model(benchmark_model_aug, train_generator_aug, validation_generator_aug, epochs=150)
Epoch 1/150
387/387 [==============================] - 748s 2s/step - loss: 2.1520 - accuracy: 0.3649 - val_loss: 1.8956 - val_accuracy: 0.4426
Epoch 2/150
387/387 [==============================] - 749s 2s/step - loss: 1.8233 - accuracy: 0.4599 - val_loss: 1.6556 - val_accuracy: 0.5273
[...]
Epoch 149/150
387/387 [==============================] - 753s 2s/step - loss: 0.2859 - accuracy: 0.9270 - val_loss: 0.6202 - val_accuracy: 0.8609
Epoch 150/150
387/387 [==============================] - 753s 2s/step - loss: 0.2830 - accuracy: 0.9259 - val_loss: 0.6289 - val_accuracy: 0.8622

Overfitting is clearly no longer a problem. Training can run for many more epochs, as it shows consistent improvements in the metrics. We could have stopped around epoch 70, but we extended the process to show that DA can indeed reduce the probability of overfitting. Notice also that the model's ability to recognize the features of each character increased significantly: validation accuracy rose to more than 86%.

plt.figure(figsize=(15,5))
plt.subplot(121)
plt.plot(history_augmented.history['accuracy'])
plt.plot(history_augmented.history['val_accuracy'])
plt.title('Accuracy vs. epochs')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['Training', 'Validation'], loc='lower right')

plt.subplot(122)
plt.plot(history_augmented.history['loss'])
plt.plot(history_augmented.history['val_loss'])
plt.title('Loss vs. epochs')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Training', 'Validation'], loc='upper right')
plt.show()
Figure 5: Accuracy and Loss evolution over several epochs of the custom model using DA.

Evaluating our custom model trained with augmented data on the test set results in a significantly higher accuracy of more than 91%.

test_generator = get_generator(image_gen_test, directory_test)
augmented_model_test_loss, augmented_model_test_acc = benchmark_model_aug.evaluate(test_generator, steps=test_steps)
print('\nTest dataset')
print("Loss: {}".format(augmented_model_test_loss))
print("Accuracy: {}".format(augmented_model_test_acc))
Found 950 images belonging to 19 classes.
29/29 [==============================] - 9s 307ms/step - loss: 0.4446 - accuracy: 0.9106

Test dataset
Loss: 0.44464701414108276
Accuracy: 0.9105603694915771

6. Results

Finally, we can compare the training, validation, and test metrics of the benchmark model, the pre-trained model built on the principle of TL, and the custom model trained with augmented data. The results show that the TL approach only slightly surpasses the benchmark model. This is probably due to the nature of the domain where the model was originally trained and how well it transfers to the Simpsons characters domain. On the other hand, the approach using augmented data captured the patterns in the data more effectively, increasing accuracy to more than 91% on the test set.

benchmark_train_loss = history_benchmark.history['loss'][-1]
benchmark_valid_loss = history_benchmark.history['val_loss'][-1]
benchmark_train_acc = history_benchmark.history['accuracy'][-1]
benchmark_valid_acc = history_benchmark.history['val_accuracy'][-1]

new_model_train_loss = history_frozen_new_model.history['loss'][-1]
new_model_valid_loss = history_frozen_new_model.history['val_loss'][-1]
new_model_train_acc = history_frozen_new_model.history['accuracy'][-1]
new_model_valid_acc = history_frozen_new_model.history['val_accuracy'][-1]

augmented_model_train_loss = history_augmented.history['loss'][-1]
augmented_model_valid_loss = history_augmented.history['val_loss'][-1]
augmented_model_train_acc = history_augmented.history['accuracy'][-1]
augmented_model_valid_acc = history_augmented.history['val_accuracy'][-1]
comparison = pd.DataFrame([['Training loss', benchmark_train_loss, new_model_train_loss, augmented_model_train_loss],
                          ['Training accuracy', benchmark_train_acc, new_model_train_acc, augmented_model_train_acc],
                          ['Validation loss', benchmark_valid_loss, new_model_valid_loss, augmented_model_valid_loss],
                          ['Validation accuracy', benchmark_valid_acc, new_model_valid_acc, augmented_model_valid_acc],
                          ['Test loss', benchmark_test_loss, new_model_test_loss, augmented_model_test_loss],
                          ['Test accuracy', benchmark_test_acc, new_model_test_acc, augmented_model_test_acc]],
                           columns=['Metric', 'Benchmark CNN', 'Transfer Learning CNN', 'Custom CNN w/ Data Augmentation'])
comparison.index=['']*6
comparison
Table 1: Results comparing the 3 models tested. The custom CNN with DA yields the best results: a test accuracy of over 91%.

To illustrate the output of the custom CNN trained with DA, we plot the categorical distribution of the predictions for random images from the test set.

test_generator = get_generator(image_gen_test, directory_test, seed=123)
predictions = benchmark_model_aug.predict(test_generator)
Found 950 images belonging to 19 classes.
test_generator = get_generator(image_gen_test, directory_test, seed=123)
# Re-create the generator with the same seed so the images align with the predictions
batch = next(test_generator)
batch_images = np.array(batch[0])
batch_labels = batch[1].astype(np.int32)
Found 950 images belonging to 19 classes.
fig, axes = plt.subplots(3, 2, figsize=(16, 17))
fig.subplots_adjust(hspace = 0.4, wspace=0.8)
axes = axes.ravel()

for i in range(3):

    inx = np.random.choice(batch_images.shape[0], 1, replace=False)[0]

    axes[0+i*2].imshow(batch_images[inx])
    axes[0+i*2].get_xaxis().set_visible(False)
    axes[0+i*2].get_yaxis().set_visible(False)
    axes[0+i*2].text(60., -8, target_labels[np.where(batch_labels[inx]==1)[0][0]], 
                    horizontalalignment='center')
    axes[1+i*2].barh(np.arange(len(predictions[inx])),predictions[inx])
    axes[1+i*2].set_yticks(np.arange(len(predictions[inx])))
    axes[1+i*2].set_yticklabels(target_labels)
    axes[1+i*2].set_title(f"Categorical distribution. Model prediction: {target_labels[np.argmax(predictions[inx])]}")

plt.show()
Figure 6: Random images (left) and the corresponding categorical distribution of the predictions yielded by the custom CNN with DA (right).

7. Conclusion

We addressed the problem of having a small dataset and a high likelihood of overfitting using two different approaches. First, we loaded a pre-trained model, removed its top layers, and added a set of layers suited to our task. The second approach used DA techniques to inflate our dataset. Our results show that the second approach overcomes both the limited-data and the overfitting problems, yielding very interesting metrics.

The TL principle could be further studied by testing other architectures or networks pre-trained on different types of datasets; in this case, a model pre-trained on a facial recognition task could yield interesting results. A sketch of swapping in a different feature extractor is shown below. As for the DA method, it could be extended with more complex transformations.
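Swapping the feature extractor for another keras.applications architecture requires only a few changes. A minimal sketch using ResNet50 follows; the choice of architecture and the pooling argument are illustrative and untested here:

# Hypothetical variant: a ResNet50 backbone instead of Xception.
# include_top=False drops the 1000-way ImageNet head, and pooling='avg'
# appends a global average pooling layer, so remove_head() is not needed.
resnet_extractor = tf.keras.applications.ResNet50(weights='imagenet',
                                                  include_top=False,
                                                  pooling='avg',
                                                  input_shape=(299, 299, 3))
resnet_extractor.trainable = False  # freeze directly instead of looking the
                                    # layer up by name as in freeze_pretrained_weights
resnet_model = add_new_classifier_head(resnet_extractor)
resnet_model.compile(optimizer='adam',
                     loss='categorical_crossentropy',
                     metrics=['accuracy'])
# Note: each keras.applications model has its own preprocess_input;
# ResNet50 expects different preprocessing than the 1/255 rescaling used here.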

Keep in touch: LinkedIn

8. References

[1] Zhuang, F., Qi, Z., Duan, K., Xi, D., Zhu, Y., Zhu, H., Xiong, H., and He, Q. (2020). A comprehensive survey on transfer learning.

[2] Tan, C., Sun, F., Kong, T., Zhang, W., Yang, C., and Liu, C. (2018). A survey on deep transfer learning.

[3] Chollet, F. (2017). Xception: Deep learning with depthwise separable convolutions.

[4] https://keras.io/api/applications/xception/

[5] https://arxiv.org/pdf/1708.06020.pdf

[6] https://www.kaggle.com/alexattia/the-simpsons-characters-dataset

