Learn how to achieve good accuracy on a classification task with just a few samples per class and an imbalanced class distribution

Nowadays, deep learning models like BERT, GANs, and U-Nets are achieving state-of-the-art performance on tasks like language modeling, image generation, and image segmentation. Hardly a day goes by without a new innovation in machine learning. Tech giants like Google, Microsoft, and Amazon keep coming up with complex deep learning architectures that achieve human-like performance. But one problem with these models is that they require a ton of labeled data. Sometimes that much data is simply not available for a specific task. With too little data, a deep learning model cannot model the different classes properly and will perform poorly. This is where Siamese networks come to the rescue: they help build models with good accuracy even with few samples per class and an imbalanced class distribution.
What is a Siamese Network?

A Siamese network is a class of neural network architectures that contains two or more identical subnetworks. We feed a pair of inputs to these networks: each subnetwork computes the features of one input, and then the similarity of the two feature vectors is computed using their distance or dot product. For input pairs of the same class, the target output is 1; for pairs of different classes, it is 0.
Remember, both networks share the same parameters and weights; if not, the network is not Siamese.
By doing this, we have converted the classification problem into a similarity problem: we train the network to minimize the distance between samples of the same class while increasing the distance between classes. There are several loss functions with which a Siamese network can be trained, such as contrastive loss, triplet loss, and circle loss.
- Contrastive loss: Contrastive loss works on pairs of images. For same-class pairs, the distance between the two feature vectors should be small; for different-class pairs, it should be large. Although binary cross-entropy seems like a natural loss function for this problem, contrastive loss does a better job of differentiating between image pairs. It is defined as L = Y*D^2 + (1 - Y)*max(margin - D, 0)^2, where D is the distance between the image features, Y is 1 for same-class pairs and 0 otherwise, and 'margin' is a parameter that helps us push different classes apart. (A code sketch of both losses follows this list.)

- Triplet loss: Triplet loss was popularized by Google's FaceNet paper in 2015 for face recognition. Here, the model takes three inputs: an anchor, a positive, and a negative. The anchor is a reference input, the positive belongs to the same class as the anchor, and the negative belongs to a class other than the anchor's.

The idea behind the triplet loss function is to minimize the distance between the anchor and the positive sample while maximizing the distance between the anchor and the negative sample.
Let's derive the formula for triplet loss. Let d denote a distance metric. The distance between the anchor and the positive should be smaller than the distance between the anchor and the negative, i.e. d(a,p) - d(a,n) < 0. Penalizing only violations of this constraint gives the loss L = max(d(a,p) - d(a,n), 0).
To further increase the separation between positive and negative, a parameter 'margin' is introduced.
Now, L = max(d(a,p) - d(a,n) + margin, 0).
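Both loss functions are short enough to write out directly. Here is a minimal sketch in TensorFlow; the squared Euclidean distance and the default margin values are my own choices, not taken from a reference implementation:

```python
import tensorflow as tf

def contrastive_loss(y_true, distance, margin=1.0):
    # L = Y*D^2 + (1 - Y)*max(margin - D, 0)^2, with Y = 1 for same-class pairs
    y_true = tf.cast(y_true, distance.dtype)
    same_term = y_true * tf.square(distance)
    diff_term = (1.0 - y_true) * tf.square(tf.maximum(margin - distance, 0.0))
    return tf.reduce_mean(same_term + diff_term)

def triplet_loss(anchor, positive, negative, margin=0.2):
    # L = max(d(a,p) - d(a,n) + margin, 0), with squared Euclidean distance as d
    d_ap = tf.reduce_sum(tf.square(anchor - positive), axis=-1)
    d_an = tf.reduce_sum(tf.square(anchor - negative), axis=-1)
    return tf.reduce_mean(tf.maximum(d_ap - d_an + margin, 0.0))
```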
Pros and Cons
Before moving to implementation, let’s talk about a few advantages and disadvantages of this method.
Pros:
- Fewer data samples required.
- Can work with highly imbalanced data.
Cons:
- Training is slower than for an ordinary classifier, because the number of possible training pairs grows quadratically with the number of samples.
- Not directly generalizable: a model trained for one similarity task cannot be reused for a different task.
- The similarity score can be sensitive to variations in the input.
Implementation
Problem statement: This is a very simple problem to demonstrate the concept. We have a database of different shapes (triangles, circles, and rectangles) in 3 different colors (red, green, and blue). The shapes come in different sizes, at different locations, and rotated at different angles.
We want to classify images based on their colors by learning a color encoding for images. We have 20 samples from each color. Here is a GitHub repository containing this dataset.
There are several other datasets to try as well, such as the famous Omniglot dataset, MNIST, etc.
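If you would rather not download anything, a comparable toy dataset can also be generated on the fly. Below is a rough sketch using Pillow; the exact sizes, positions, and rotation ranges are my own guesses, not the parameters used in the linked repository.

```python
import random
from PIL import Image, ImageDraw

COLORS = {"red": (255, 0, 0), "green": (0, 255, 0), "blue": (0, 0, 255)}
SHAPES = ["triangle", "circle", "rectangle"]

def random_shape_image(color_name, size=28):
    # Draw one randomly sized, placed, and rotated shape in the given color
    img = Image.new("RGB", (size, size), "white")
    draw = ImageDraw.Draw(img)
    shape = random.choice(SHAPES)
    s = random.randint(6, 10)  # half-size of the shape
    x = random.randint(s, size - s)  # random position, kept inside the canvas
    y = random.randint(s, size - s)
    if shape == "circle":
        draw.ellipse([x - s, y - s, x + s, y + s], fill=COLORS[color_name])
    elif shape == "rectangle":
        draw.rectangle([x - s, y - s, x + s, y + s], fill=COLORS[color_name])
    else:  # triangle
        draw.polygon([(x, y - s), (x - s, y + s), (x + s, y + s)], fill=COLORS[color_name])
    return img.rotate(random.randint(0, 359), fillcolor="white")

# 20 samples of each color, matching the dataset described above
dataset = [(random_shape_image(c), c) for c in COLORS for _ in range(20)]
```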
Part 1: Prepare data
Read the images and generate positive and negative pairs. The result is two NumPy arrays containing the image pairs, each of size N x 28 x 28 x 3, plus a target label that is 1 if the pair has the same color and 0 if the colors differ. For example:
Input: red_triangle, red_square (same color objects)
Output: 1
Input: red_triangle, blue_triangle (different color objects)
Output: 0
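Here is one possible way to build those arrays, assuming the images are already loaded as 28x28x3 NumPy arrays, each with a color label. The balanced one-positive/one-negative sampling per image is my own choice:

```python
import random
import numpy as np

def make_pairs(images, color_labels):
    # Build (pair_a, pair_b, target) arrays: target 1 for same color, 0 otherwise
    by_color = {}
    for img, color in zip(images, color_labels):
        by_color.setdefault(color, []).append(img)
    colors = list(by_color)
    pairs_a, pairs_b, targets = [], [], []
    for img, color in zip(images, color_labels):
        # positive pair: another image of the same color
        pairs_a.append(img)
        pairs_b.append(random.choice(by_color[color]))
        targets.append(1)
        # negative pair: an image of a different color
        other = random.choice([c for c in colors if c != color])
        pairs_a.append(img)
        pairs_b.append(random.choice(by_color[other]))
        targets.append(0)
    # two N x 28 x 28 x 3 arrays plus an N-vector of 0/1 labels
    return np.array(pairs_a), np.array(pairs_b), np.array(targets)
```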
Part 2: Create the model
In lines 21–32 of the gist, we create the model for a single image. First comes the input layer, followed by Conv2D and MaxPool layers. After that, we flatten the output and add a dense layer. We normalize the output of this dense layer because it will act as the feature vector for the image. The features are of length 16. Later, when we connect the two models, we take the dot product of these feature vectors; since the features are already normalized, the dot product is bounded, so we can compare it directly to our target labels.
In lines 36 and 37, we create two instances of the same model and pass the inputs to them. In line 45, we take the outputs of both models and compute their dot product, which tells us how similar the two images are.
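The gist is embedded above with its own numbering; as an approximation, the architecture it describes looks roughly like this in Keras. The filter counts and kernel sizes are my guesses, while the 16-dimensional normalized feature vector and the dot-product head follow the text:

```python
import tensorflow as tf
from tensorflow.keras import layers, Model

def build_encoder():
    # Single-image encoder: Conv/MaxPool stack -> flatten -> 16-d L2-normalized features
    inp = layers.Input(shape=(28, 28, 3))
    x = layers.Conv2D(32, 3, activation="relu")(inp)
    x = layers.MaxPool2D()(x)
    x = layers.Conv2D(64, 3, activation="relu")(x)
    x = layers.MaxPool2D()(x)
    x = layers.Flatten()(x)
    x = layers.Dense(16)(x)
    x = layers.Lambda(lambda t: tf.math.l2_normalize(t, axis=1))(x)
    return Model(inp, x, name="encoder")

encoder = build_encoder()
input_a = layers.Input(shape=(28, 28, 3))
input_b = layers.Input(shape=(28, 28, 3))
# Both branches share weights because they are the same encoder instance
feat_a, feat_b = encoder(input_a), encoder(input_b)
similarity = layers.Dot(axes=1)([feat_a, feat_b])  # dot product of the normalized features
siamese = Model([input_a, input_b], similarity)
siamese.compile(optimizer="adam", loss="mse")
```

Training is then a single call on the arrays from Part 1, e.g. siamese.fit([pairs_a, pairs_b], targets, epochs=10).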
This is a simple example, so we get accuracy as high as 99%. The results will not be that good for complex images and tasks, but the goal of this article is to demonstrate Siamese networks with an elementary example.
Here is a diagram of a single encoder from the Siamese network:

This is what the complete Siamese model looks like:

Part 3: Test the model
Load the model and test it on unseen images. We can do the following things to check the accuracy and the separation between classes:
a) First, we can use the single-encoder model to encode an image and get its features for plotting. We can make a scatter plot of these features to see how well separated the classes are. (The features are of length 16.)
b) Second, we can use the Siamese network output (between 0 and 1) to build a confusion matrix.
Both are shown below.
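Both checks take only a few lines. Here is a sketch, assuming the trained encoder and siamese models from Part 2 and hypothetical held-out arrays test_images, test_color_ids, test_pairs_a, test_pairs_b, and test_targets; the 0.5 decision threshold is my choice:

```python
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

# a) encode unseen images and scatter-plot the first 3 of the 16 features
features = encoder.predict(test_images)  # shape: (N, 16)
ax = plt.figure().add_subplot(projection="3d")
ax.scatter(features[:, 0], features[:, 1], features[:, 2], c=test_color_ids)
plt.show()

# b) threshold the similarity scores and build a confusion matrix
scores = siamese.predict([test_pairs_a, test_pairs_b]).ravel()
predictions = (scores > 0.5).astype(int)  # 1 = same color, 0 = different
print(confusion_matrix(test_targets, predictions))
```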
Here is a visualization of the first 3 features in 3D:




Here is the link to the complete repository:
Conclusion
- This model is useful when little data is available and the classes are imbalanced.
- It has applications in image classification, object detection, text classification, and voice classification.
- Siamese networks can also be used to encode a particular feature. A similar model could, for example, be trained to classify the different shapes instead of the colors.
- One-shot learning also uses Siamese networks; it pushes this idea further by using just one sample from each class. Zero-shot learning is similar, except the model has to recognize classes it has never seen during training. These are fairly new topics that have gained a lot of attention.
Future Project Ideas
After the Siamese model is trained, the final layer can be discarded and the features can be used with a variational autoencoder to reconstruct the input. This way, you can also change the style and other attributes of an image by reconstructing the original image using another image's features, for example converting the shape or the color of the image.
Next, I am going to write about voice classification using Siamese networks and about combining VAEs with Siamese networks.
Thanks for reading! I hope it’s helpful in understanding the underlying logic.
If you enjoyed this article please recommend and share.