Chinese Zodiac Sign Classification Challenge with PyTorch

One of the Bertelsmann AI Udacity Scholarship's challenges: this is one way to classify Chinese zodiac signs with PyTorch

In the Bertelsmann AI Udacity Scholarship, the scholars not only have to finish the AI Udacity course, they also challenge each other to apply the skills and knowledge they pick up along the way. One of these challenges is the Chinese Zodiac Sign Classification Challenge, on the occasion of, you know, the Chinese Lunar New Year.

The Chinese zodiac is a twelve-year cycle, with each year identified by a particular animal.

I thought this would be an interesting challenge, so I dived right in.

For those of you who want to try this challenge out, you can get the data here.


Set up the environment

First things first. Using a GPU would make the task easier, so I tried both Google Colab and Kaggle Notebooks to build this model. But since I like to be able to edit (add, change, move around, or delete) my data, I prefer Google Colab, where the data is stored in my Google Drive.
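If you are following along in Colab, mounting Google Drive is a one-liner. The dataset folder name below is just my assumption for illustration:

```python
# Mount Google Drive in Colab so the dataset can be read (and edited) in place.
from google.colab import drive

drive.mount('/content/drive')

# Hypothetical location of the challenge data inside my Drive.
data_dir = '/content/drive/My Drive/chinese_zodiac'
```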

Load and check the data

We have 12 classes of images we would like our PyTorch model to classify. It is a good idea to divide the data into three groups: training data, on which we train the model, and validation and test data, which we use to make sure the model does not overfit.

Let’s first check the distribution of these data.

Zodiac signs image distribution

OK, so for each zodiac sign, we have 600 training images (one short for goat), 54 validation images (one extra for goat, likely because that one image got misplaced), and 54-55 test images.

Looks good to me.

Let’s see a sample of those images.

Sample images

Looks good to me as well. We can see that drawings are included too (because dragons don't exist and all).

Creating the model using PyTorch

One of the good things about PyTorch (as with other machine learning/deep learning frameworks) is that it takes care of a lot of boilerplate code for us. One example is loading the train-val-test data.
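Here is a minimal sketch of how those loaders can be set up with torchvision, assuming the data sits in train/valid/test subfolders with one folder per zodiac sign (the folder names and batch size are my assumptions):

```python
import torch
from torchvision import datasets, transforms

# Hypothetical path from the Drive-mount sketch above.
data_dir = '/content/drive/My Drive/chinese_zodiac'

# ResNet expects 224x224 inputs normalized with the ImageNet statistics.
normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
train_tf = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])
eval_tf = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize,
])

# Assumed layout: data_dir/{train,valid,test}/<zodiac_sign>/*.jpg
image_datasets = {
    'train': datasets.ImageFolder(f'{data_dir}/train', train_tf),
    'valid': datasets.ImageFolder(f'{data_dir}/valid', eval_tf),
    'test': datasets.ImageFolder(f'{data_dir}/test', eval_tf),
}
dataloaders = {
    split: torch.utils.data.DataLoader(ds, batch_size=32, shuffle=(split == 'train'))
    for split, ds in image_datasets.items()
}
```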

Now to build the model.

I will be using the pretrained ResNet34 model and transfer its learning to build this classification model. I also tried other pretrained models like ResNet101 and VGG 19 with BatchNorm, but ResNet34 gives me pretty good performance, so I'm going with that. ResNet34 requires the input images to have a width and height of 224.

The full model architecture is below.

I won’t use anything very complex here. Just two additional FC layers with 512 neurons each, and one output layer with 12 neurons (one for each zodiac sign class, of course).
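As a rough sketch, swapping in that classifier head looks something like the following. Whether to freeze the backbone, and the dropout values, are my own assumptions, not something spelled out above:

```python
import torch
import torch.nn as nn
from torchvision import models

# Start from the ImageNet-pretrained ResNet34 and freeze the convolutional backbone
# (freezing is my assumption here).
model = models.resnet34(pretrained=True)
for param in model.parameters():
    param.requires_grad = False

# Replace the final layer: two 512-unit FC layers, then a 12-way output
# (one neuron per zodiac sign).
model.fc = nn.Sequential(
    nn.Linear(model.fc.in_features, 512),
    nn.ReLU(),
    nn.Dropout(0.3),   # dropout is my own addition, not from the article
    nn.Linear(512, 512),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(512, 12),
)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
```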

Training the model

Now comes the first exciting part, training the model.

We simply need to iterate over the train data loader while doing (a) a forward and backward pass on the model, and (b) measuring the current/running performance of the model. I chose to do point (b) every 100 minibatches.
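A minimal version of that training loop might look like this. The optimizer choice is my assumption, the learning rate matches the 0.001 used for the first model, and the validation pass is left out for brevity:

```python
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.fc.parameters(), lr=0.001)

epochs = 7
print_every = 100  # report running performance every 100 minibatches

for epoch in range(epochs):
    model.train()
    running_loss, running_correct, seen = 0.0, 0, 0
    for step, (images, labels) in enumerate(dataloaders['train'], start=1):
        images, labels = images.to(device), labels.to(device)

        # (a) forward and backward pass
        optimizer.zero_grad()
        logits = model(images)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        # (b) running performance on the training data
        running_loss += loss.item() * images.size(0)
        running_correct += (logits.argmax(dim=1) == labels).sum().item()
        seen += images.size(0)

        if step % print_every == 0:
            print(f'epoch {epoch + 1}, step {step}: '
                  f'loss {running_loss / seen:.3f}, acc {running_correct / seen:.1%}')
            running_loss, running_correct, seen = 0.0, 0, 0
```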

I chose to train the model over 7-15 epochs. You'll see why in the charts below.

Training and Validation Losses and Accuracies (ResNet50 try 1)

Looking at the metrics logged every 100 minibatches, the model definitely improves over time on the training dataset. But when we look at the performance on the validation dataset, it does not seem to improve that much (the accuracy does creep up a tiny bit over time, but I don't think it is enough).

And when we take a look at other models with different architectures, the same thing happens.

Training and Validation Losses and Accuracies (ResNet50 try 2)
Training and Validation Losses and Accuracies (VGG 19 with Batch Norm)

The second model is a ResNet50, pretty much the same as the first except that I changed the learning rate from 0.001 to 0.003. The third uses VGG 19 with Batch Norm and a learning rate of 0.001.

Three different models and parameter settings tell the same story: validation accuracy does not improve over multiple epochs nearly as much as training accuracy does (especially for the last two models).

We do not concern ourselves too much with the models' loss, as it mostly measures how 'confident' the model is in its predictions; we are focusing more on accuracy.

Test the model

Let's see if the models are actually good, or break apart when they meet the test dataset.

  • ResNet (lr 0.001) – loss: 0.355 acc: 90.5%
  • ResNet (lr 0.003) – loss: 0.385 acc: 90.6%
  • VGG 19 with Batch Norm – loss: 0.586 acc: 90.8%

Looking at the accuracy, they are pretty much the same. For loss, the ResNet with a 0.001 learning rate reigns supreme. These numbers are almost the same as our training and validation results, so we can say our model does not overfit (or at least we stopped it just before it did by keeping the epoch count small) and works pretty well at classifying Chinese zodiac signs.
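For completeness, a test-set evaluation along these lines produces that kind of loss/accuracy pair (a sketch reusing the loaders, model, and criterion assumed in the snippets above):

```python
# Evaluate the trained model on the held-out test set.
model.eval()
test_loss, correct, total = 0.0, 0, 0
with torch.no_grad():
    for images, labels in dataloaders['test']:
        images, labels = images.to(device), labels.to(device)
        logits = model(images)
        test_loss += criterion(logits, labels).item() * images.size(0)
        correct += (logits.argmax(dim=1) == labels).sum().item()
        total += images.size(0)

print(f'test loss: {test_loss / total:.3f}  test acc: {correct / total:.1%}')
```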

I am curious as to which images it has the most trouble with.
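In case you want to reproduce this, the confusion matrix below can be tallied with a few lines over the test loader (again a sketch that reuses the model and loaders assumed earlier):

```python
import numpy as np

num_classes = 12
confusion = np.zeros((num_classes, num_classes), dtype=int)

model.eval()
with torch.no_grad():
    for images, labels in dataloaders['test']:
        preds = model(images.to(device)).argmax(dim=1).cpu()
        for true, pred in zip(labels.tolist(), preds.tolist()):
            confusion[true, pred] += 1  # rows: true class, columns: predicted class

class_names = image_datasets['test'].classes  # zodiac sign folder names
```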

Confusion Matrix

Loving the numbers in that confusion matrix.

We can see that the model rarely mistakes a goat, but when it does, it is always for an ox (horns).

Goat images mistaken for an Ox by the model

We can also see that the model had a bit of difficulty distinguishing dragons from other zodiac signs, most notably oxen (horns) and snakes (body).


Final words

Working on the Chinese Zodiac Sign Classification Challenge during the Bertelsmann AI Udacity Scholarship was awesome and refreshing. It pushed and motivated me to do what I love: experimenting and building in the field of AI/machine learning/deep learning.

More challenges will come, and I can't wait for them.

You can check out my notebooks on my GitHub page. Please review and follow :))

