
How To Train Your Siamese Neural Network


Hands-on Tutorials

The easy way to work with classes not seen at train time

Photo by Sereja Ris on Unsplash

When you first start out with Machine Learning, it becomes clear from the outset that massive amounts of data are required for accurate and robust model training. And whilst this is true, when training models for purposes where a custom dataset is required, you often need to compromise on the amount of data your model sees.

This was the case for me; working in conservation tech, any models we deploy to an area are built using data collected from previous years’ surveys, which in some cases may be sparse (certainly nowhere near the levels of benchmarking datasets such as ImageNet [1]). To make matters worse, working in conservation tech means working with open-ended datasets. Because the animals we work with are free roaming, there is no guarantee that a dataset we use for model training will contain examples of everything the model will see in the field.

This results in somewhat of an uphill battle when trying to deploy models using traditional machine learning approaches. Building a model for conservation is useless if you need thousands of examples for each class and you need to retrain your model each year as the classes change. But this problem isn’t confined to conservation, lots of areas outside of benchmarking have similar issues with amounts of data and rates of change.

In this article I will discuss a type of model known as a Siamese Neural Network. Hopefully after reading, you will have a better understanding of how this architecture can help not just in conservation, but in any area where data quantities are limited and rates of class change are fast.

Prerequisites

Before getting started you should probably have an understanding of machine learning, specifically Convolutional Neural Networks. If you don’t, I found Sumit Saha‘s post A Comprehensive Guide to Convolutional Neural Networks – the ELI5 way to be a good starting point in lieu of a formal education in the area. You should probably read that first.

You should also be comfortable with Python, Keras, and TensorFlow. We will be working through examples of code during this article, as I find doing this gives a better understanding than just free form text on its own. All code in this guide was written in TensorFlow 1.14, but there is no reason why the code shouldn’t work in newer versions (possibly with a few modifications), or indeed ported to other deep learning frameworks such as PyTorch.

What is a Siamese Neural Network?

In short, a Siamese Neural Network is any model architecture which contains at least two parallel, identical Convolutional Neural Networks. We’ll call these SNNs and CNNs from now on. This parallel CNN architecture allows the model to learn similarity, which can be used instead of a direct classification. SNNs have found uptake primarily for image data, such as in facial recognition, although they do have their uses outside of this domain. For example, Selen Uguroglu gave a great talk at NeurIPS 2020 about how Netflix utilises SNNs to generate user recommendations based on film metadata. For this guide, we will focus on image data.

Each parallel CNN which forms a part of the SNN is designed to produce an embedding, or a reduced dimensional representation, of the input. For example, if we specify an embedding size of 10, we may input a high dimensional image of size width × height × channels and receive as output a float value vector of size 10 which directly represents the image.

These embeddings can then be used to optimise a Ranking Loss, and at test time used to generate a similarity score. The parallel CNNs can, in theory, take any form. One important point however is that they must be completely identical; they must share the same architecture, share the same initial and updated weights, and have the same hyperparameters. This consistency allows the model to compare the inputs it receives, usually one per CNN branch. The SigNet paper from Dey et al. [2] provides an excellent visualisation of this, which can be seen below.

The SigNet architecture. Image from Dey et al. [2].

The goal of SigNet is to determine if a given signature is genuine or a forgery. This can be achieved through the use of two parallel CNNs, trained on genuine and forged signature pairs. Each signature is fed through one branch of the SNN which generates a d-dimensional embedding for the image. It is these embeddings which are used to optimise a loss function rather than the images themselves. More recent versions of SNNs will most likely utilise triple or even quadruple branching, containing three or four parallel CNNs respectively.

What’s the Point of SNNs?

Now that we understand the make-up of an SNN, we can highlight their value. Using the generated d-dimensional embeddings, we can create a d-dimensional hyperspace in which the embeddings can be plotted, forming clusters. This hyperspace can then be projected down to 2-dimensions for plotting using Principal Component Analysis, or PCA.

A plot of an embedding hyperspace after 100 training epochs, projected down to 2-dimensions using PCA. Image generated using code based on this notebook.

This plot shows the embedding locations for a subset of the MNIST test data [3]. Here, a model has been trained to generate embeddings of images for the 10 unique classes (images of handwritten digits between 0 and 9). Notice how, even after only 100 training epochs, the model is starting to generate similar embeddings for images of the same class. This can be seen by the clusterings of dots of the same colour in the graph above. Some clusters in the plot are visualised on top of each other; this is due to the reduction down to 2-d through PCA. Other visualisations such as t-SNE plots, or reducing to a higher number of dimensions, can help in this situation.

It’s this embedding clustering that makes SNNs such a powerful tool. If we suddenly decided that we wanted to add another class to the data, then there is no need to retrain the model. The new class embeddings should be generated in such a way that, when plotted into the hyperspace, they are far away from the existing clusters, but cluster together with other examples of the new class as they are added. By using this embedding similarity, we can begin to produce likely classifications for both seen and unseen classes using very little data.

Model Training

Previously, I mentioned that SNNs consist of at least two parallel CNN branches, but modern implementations often rely on more. The number of branches in your SNN has a big influence on your model training. Not only do you need to ensure that your data is fed to the SNN in such a way that each branch receives training examples, but your choice of loss function must also take the number of branches into account. Regardless of the number of branches chosen, the type of loss function will likely stay consistent.

Ranking Losses, also known as Contrastive Losses, aim to predict relative distances between model inputs when projected onto a hyperspace. This is in comparison to more traditional losses which aim to predict some set of class labels. Ranking Losses play an important role in SNNs, although they are also useful for other tasks such as Natural Language Processing. There are many different types of Ranking Loss, but they all work (generally) in the same way.

Let’s assume we have two inputs, and we want to know how similar they are. Using a Ranking Loss, we would perform the following steps:

  1. Extract the features from the input.
  2. Embed the extracted features onto a d-dimensional hyperspace.
  3. Calculate the distance between the embeddings (e.g. using Euclidean distance) to be used as a measure of similarity.
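
As a toy illustration of these steps (all names here are my own rather than from any particular implementation), the distance in step 3 can be computed directly from two embeddings:

```python
import numpy as np

def euclidean_distance(emb_1, emb_2):
    """Euclidean distance between two d-dimensional embeddings."""
    return np.sqrt(np.sum(np.square(emb_1 - emb_2)))

# Two hypothetical 10-dimensional embeddings produced by the same CNN branch
emb_a = np.random.rand(10)
emb_b = np.random.rand(10)
print(euclidean_distance(emb_a, emb_b))  # a smaller value means more similar inputs
```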

It’s important to note here that we often don’t particularly care about the values of the embeddings themselves, just the distances between them. Taking the plot of the embeddings shown earlier, notice that all points lie between roughly -1.5 and 2.0 on the x-axis and roughly -2.0 and 2.0 on the y-axis. There is nothing inherently good or bad about a model that embeds within this range; all that matters is that the points cluster in their respective classes.

Triplet Ranking Loss

One of the more common types of Ranking Loss used for SNNs is Triplet Ranking Loss. You’ll often see SNNs using this loss function called Triplet Networks as if they are their own thing (and indeed this is how they are defined by Hoffer et al. in the paper that first conceived them [4]), but really they’re just an SNN with three branches. Because Triplet Loss is so commonplace now, and it’s the loss function we’ll be using later in this post, it’s important to understand how it works.

An example showing how triplet ranking loss works to pull embedded images of the same class closer together, and different classes further apart. Image by author.

Triplet Ranking Loss requires, as the name suggests, three inputs which we call a triplet. Each data-point in the triplet has its own job. The Anchor is data of some class C which defines which class the triplet will train the model on. The Positive is another example of the class C. The Negative is a data-point of some class which is not C. At train time, each of our triplet components is fed to its own CNN branch to be embedded. These embeddings are passed to the Triplet Loss Function, which is defined as:

L(A, P, N) = max(D(A, P) − D(A, N) + margin, 0)

where D(A,P) is the embedding distance between the Anchor and the Positive, and D(A,N) is the embedding distance between the Anchor and the Negative. We also define some margin; an often used initial value for this is 0.2, the margin used in FaceNet [5].

The purpose of this function is to minimise the distance between the Anchor and the Positive, whilst maximising the distance between the Anchor and the Negative. For a more in-depth look at Triplet Ranking Loss, I’d suggest this excellent post from Raúl Gómez.

Semi-Hard Triplet Mining

Because of the importance of the triplet components, it is imperative that our SNN is provided only with triplets which will enable it to learn. More specifically, we want to provide Negatives such that our triplets allow the model to learn, but not be so difficult that learning takes too long.

An easy way to do this is through a process known as Semi-Hard Triplet Mining. To perform this, we first define three categories of triplet:

  • Easy Triplets are those where D(A,P) + margin < D(A,N), thus L = 0.
  • Hard Triplets are those where D(A,N) < D(A,P).
  • Semi-Hard Triplets are those where D(A,P) < D(A,N) < D(A,P) + margin.

The goal is to find as many Semi-Hard Triplets as possible. These triplets have a positive loss, but the Positive embedding is still closer to the Anchor than the Negative is. This allows for fast training, but is still difficult enough for the model to actually learn something during training.
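
To make the three categories concrete, here is a minimal sketch (the function name and default margin are my own) that labels a triplet from its two distances:

```python
def categorise_triplet(d_ap, d_an, margin=0.2):
    """Categorise a triplet from its Anchor-Positive and Anchor-Negative distances."""
    if d_ap + margin < d_an:
        return 'easy'       # loss is already zero, nothing to learn
    if d_an < d_ap:
        return 'hard'       # the Negative is closer to the Anchor than the Positive
    return 'semi-hard'      # D(A,P) < D(A,N) < D(A,P) + margin
```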

Finding these Semi-Hard triplets can be performed in one of two ways. In Offline mining, the entire dataset is converted into triplets before training. In Online mining, batches of data are fed in, with random triplets generated on the fly.

As a general rule of thumb, Online mining should be performed wherever possible as it allows for much faster training due to the ability to constantly update our threshold definition of a Semi-Hard Triplet as training progresses. This can be supplemented with data augmentation, which can also be performed in an Online fashion.

Using SNNs at Inference Time

Now that we understand how SNNs are trained, we next need to understand how they can be used at inference time. During training we used all of the branches of the SNN, whereas inference can be performed using a single CNN branch.

At inference time, the input image of an unknown class is processed by the CNN branch and has its features embedded. This embedding is then plotted onto the hyperspace and compared with the other clusters. This provides us with a list of similarity scores, or relative distances between the image of unknown class and all of the existing clusters. The clusters we compare our input image against are known as the support set. Let’s take a look at an example to help understand this.

Finding the most likely family class for the test image moth, based on data from Vetrova et al. [6]. Image generated using code based on this notebook.

The plot above is the output of an SNN I created to determine the scientific family of moths. Each image in the dataset, adapted from Vetrova et al. [6], was labelled with one of four scientific family names or labelled as ‘larvae’, giving a total of five classes. For ease of visualisation (although in hindsight not necessarily ease of understanding) each of the known labelled classes is displayed in the support set, shown in the middle of the plot above, using a random example image from each class. On the left of the plot is a test image; this is a moth image unseen by the SNN, which is now tasked with determining the scientific family.

First, the SNN embeds the test image using the embedding function learned during training. Next, it compares this embedding with the support set embeddings, which provides a most likely moth family for the test image. To the right of the plot, we can see the first image in the support set has been printed again.

The code used to generate the plot above was told to show the corresponding example of the test image’s family first (the plotting code knows the correct class, the SNN does not). Because the first support set image is shown again on the right of the plot, this tells us that the SNN was correct in determining the scientific family of the test image moth! If this plot is a bit confusing to you, don’t worry as we’ll be working through creating the same plot on different, simpler, data later.

This code could be extended further to alert users if an embedding is placed in a new area of the hyperspace, for example if it exceeds some predefined class distance threshold. This could be an indication that a new moth family has been seen by the SNN for the first time.

Where Do We Measure From?

In order to determine the distance between the test image and the classes in the support set, we need a location for each class to measure from. At first glance, it might seem okay to use a randomly selected embedding from each support set class; after all, if all embeddings are perfectly clustered surely it doesn’t matter which one we use?

Whilst this assumption certainly holds if our class embeddings are perfectly clustered, in a real world system this won’t be the case. Let’s examine the toy example below.

An example embedding space with two classes, crosses and squares, and a yet undetermined class embedding represented by a triangle. Image by author.

In this example we have a two class embedding space, one for crosses and one for squares. All of the square class embeddings are clustered to the right of the plot, but one of the cross class embeddings has not been clustered with the others in the top left; this erroneous cross has instead been embedded into the space where the squares usually cluster. There is also a triangle plotted in the top-right; this is the current test image, embedded into the space but not yet assigned to a class based on its distance to the other clusters.

In order to determine if the triangle should actually be a cross or a square, we randomly select an embedding to measure from for each class; the erroneous cross and the bottom-left square are chosen (both circled). If we compare the distances from these embeddings, the chosen cross is closest, so the triangle would be labelled a cross.

However looking at the plot as a whole, it’s clear that the triangle should probably be labelled as a square, and the cross is an outlier. By selecting random embeddings to measure from we run the risk of having outliers skew the distance measurement, and thus the final outcome.

This can be solved using prototypes, an elegant and easy to understand solution to our problem. Prototypes are essentially generalised embeddings for each class, reducing the effect of outliers on the distance measurements. These can be calculated in a variety of ways, but simple techniques such as taking the median work well. Let’s update the toy example…

The same embedding space as previous, but with the inclusion of prototypes. Image by author.

Now, each class has been given a prototype near the centre of its cluster (e.g. Pₓ is the prototype for the cross class). If we select the prototypes when measuring similarity, our triangle is correctly labelled as a square. This simple solution can greatly reduce the effect outliers have when calculating similarities.
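
As a rough sketch of how such prototypes might be computed (the function and variable names here are my own), the per-dimension median of each class’s embeddings works well in practice:

```python
import numpy as np

def compute_prototypes(embeddings, labels):
    """Return a {class_label: prototype} dict, using the per-dimension median
    of each class's embeddings as a simple, outlier-resistant prototype."""
    return {c: np.median(embeddings[labels == c], axis=0)
            for c in np.unique(labels)}
```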

Determining how prototypes should be calculated is difficult, and solutions such as using the median may break down with certain datasets. For example, if all of our cross class examples formed a circle of radius 1 around the origin and the square class examples formed a circle of radius 2, the prototypes would both now be formed at the origin, resulting in equal distance measurements. We’d need to find another way to calculate the prototypes for that dataset.

Building a Siamese Neural Network

Now that we have a grasp of the underlying theory of SNNs and why they are an important tool, let’s take a look at how we build one. As mentioned previously, we’ll be using Python, Keras, and TensorFlow 1.14 for this, although there’s really nothing preventing this code from being converted for use in another framework like PyTorch; I use TensorFlow out of personal preference rather than because it’s better for making SNNs. We’re also going to stick with MNIST as our dataset, both for consistency and for ease of training.

The code here is based on a variety of sources, which I will link as we go, but the underlying construction is based on the approach described in Amit Yadav’s Coursera, which is itself based on FaceNet [5].

If you prefer to have full code rather than snippets, this is available from my Github.

Step 1: Importing packages

First, we’re going to need to import the required packages. For a complete list of package versions used on the virtual machine to run this code, see here. I tested this code with Python 3.6.7.
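
The exact import list from the original notebook isn’t reproduced here, but a minimal set that covers the sketches in the rest of this guide (using tf.keras) might look like this:

```python
import os

import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Input, Dense, Lambda, concatenate
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import TensorBoard, ModelCheckpoint
from tensorflow.keras import backend as K
```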

Step 2: Importing data

Next, we need to import a dataset for our SNN to work with. As previously mentioned we’ll be using MNIST, which can be loaded using TensorFlow’s mnist.load_data().

After the data is loaded in, it is reshaped and flattened. This allows the data to be read into the SNN more easily.
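
The loading and reshaping might look something like the sketch below; the variable names (x_train_w_h and so on) are my own, chosen to match the note that follows:

```python
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Scale pixel values to [0, 1]
x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.

# Keep a (width, height) copy for plotting, plus a flattened copy for the SNN
x_train_w_h = x_train.reshape(-1, 28, 28)
x_test_w_h = x_test.reshape(-1, 28, 28)
x_train = x_train.reshape(-1, 28 * 28)
x_test = x_test.reshape(-1, 28 * 28)
```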

Note that we only have a height and width here as MNIST is greyscale, therefore only has 1 colour channel. If we had a dataset with multiple colour channels we would need to adapt our code for this, for example using x_train_w_h_c instead.

Step 3: Create the triplets

Now we need to create our MNIST triplets. Two methods are required for this.

The first, create_batch(), generates triplets by randomly selecting two class labels, one for the Anchor/Positive and one for the Negative, before randomly selecting a class example for each.
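
A minimal sketch of create_batch(), consistent with that description (argument names are my own), might look like this:

```python
def create_batch(batch_size, x, y):
    """Create a batch of random (Anchor, Positive, Negative) triplets from (x, y)."""
    anchors = np.zeros((batch_size, 28 * 28))
    positives = np.zeros((batch_size, 28 * 28))
    negatives = np.zeros((batch_size, 28 * 28))

    for i in range(batch_size):
        # One class for the Anchor/Positive pair, a different class for the Negative
        anchor_class, negative_class = np.random.choice(10, size=2, replace=False)

        anchor_idx, positive_idx = np.random.choice(
            np.where(y == anchor_class)[0], size=2, replace=False)
        negative_idx = np.random.choice(np.where(y == negative_class)[0])

        anchors[i] = x[anchor_idx]
        positives[i] = x[positive_idx]
        negatives[i] = x[negative_idx]

    return [anchors, positives, negatives]
```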

The second, create_hard_batch(), creates a batch of random triplets using create_batch(), and embeds them using the current SNN. This allows us to determine which triplets in the batch are Semi-Hard; if they are we keep num_hard of them, populating the rest of the batch with other random triplets. By padding with random triplets, we allow for training to begin as well as ensure our batches are of a consistent size.
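
A hedged sketch of create_hard_batch() following the same logic is shown below; it relies on the embedding_model created in the later steps and uses squared Euclidean distances to match the loss we define in Step 5:

```python
def create_hard_batch(batch_size, num_hard, x, y, margin=0.2):
    """Mine up to num_hard Semi-Hard triplets from a pool of random candidates,
    padding the remainder of the batch with random triplets."""
    # Over-sample random triplets to mine from
    anchors, positives, negatives = create_batch(batch_size * 2, x, y)

    # Embed the candidates using the current state of the (global) embedding model
    a_emb = embedding_model.predict(anchors)
    p_emb = embedding_model.predict(positives)
    n_emb = embedding_model.predict(negatives)

    # Squared Euclidean distances, matching the triplet loss defined later
    d_ap = np.sum(np.square(a_emb - p_emb), axis=1)
    d_an = np.sum(np.square(a_emb - n_emb), axis=1)

    # Semi-Hard: D(A,P) < D(A,N) < D(A,P) + margin
    semi_hard = np.where((d_ap < d_an) & (d_an < d_ap + margin))[0][:num_hard]
    padding = np.setdiff1d(np.arange(batch_size * 2), semi_hard)[:batch_size - len(semi_hard)]
    keep = np.concatenate([semi_hard, padding])

    return [anchors[keep], positives[keep], negatives[keep]]
```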

Step 4: Defining the SNN

The SNN is defined in two parts. First, we must create the embedding model. This model receives an input image and generates a d-dimensional embedding. We create a very shallow embedding model here, but more complex models can be created.
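
A minimal embedding model, assuming flattened 28 × 28 inputs and an embedding size of 10, might look like the sketch below. The L2 normalisation layer mirrors FaceNet and is an assumption rather than a requirement:

```python
emb_size = 10  # dimensionality of the embedding space

def create_embedding_model(emb_size):
    """A deliberately shallow embedding network: flattened image in, embedding out."""
    inp = Input(shape=(28 * 28,))
    x = Dense(128, activation='relu')(inp)
    x = Dense(emb_size, activation=None)(x)
    # L2-normalise so embeddings lie on a unit hypersphere, as in FaceNet
    x = Lambda(lambda t: K.l2_normalize(t, axis=1))(x)
    return Model(inputs=inp, outputs=x)
```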

Next, we create a model which receives a triplet, passes it to the embedding model sequentially for embedding, then passes the resultant embeddings to the triplet loss function.
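
A sketch of this triplet wrapper is below. The key detail is that all three inputs are passed through the same embedding model instance, so the branches share weights, and the three embeddings are concatenated so the loss function can slice them apart:

```python
def create_SNN(embedding_model):
    """Wire three inputs through the *same* embedding model and concatenate the results."""
    input_anchor = Input(shape=(28 * 28,))
    input_positive = Input(shape=(28 * 28,))
    input_negative = Input(shape=(28 * 28,))

    emb_anchor = embedding_model(input_anchor)
    emb_positive = embedding_model(input_positive)
    emb_negative = embedding_model(input_negative)

    output = concatenate([emb_anchor, emb_positive, emb_negative], axis=1)
    return Model(inputs=[input_anchor, input_positive, input_negative],
                 outputs=output)
```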

Step 5: Defining the triplet loss function

In order for the SNN to train using the triplets, we need to define the triplet loss function. This mirrors the triplet loss function equation shown previously.
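
A sketch of the loss is below; it assumes the concatenated output from the create_SNN() sketch above and uses squared Euclidean distances with a margin alpha, as in FaceNet:

```python
alpha = 0.2  # the margin

def triplet_loss(y_true, y_pred):
    """Keras-style loss: y_pred is [anchor_emb | positive_emb | negative_emb]."""
    anchor = y_pred[:, :emb_size]
    positive = y_pred[:, emb_size:2 * emb_size]
    negative = y_pred[:, 2 * emb_size:]

    d_ap = K.sum(K.square(anchor - positive), axis=1)
    d_an = K.sum(K.square(anchor - negative), axis=1)

    # max(D(A,P) - D(A,N) + margin, 0), averaged over the batch by Keras
    return K.maximum(d_ap - d_an + alpha, 0.0)
```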

Step 6: Defining the data generator

In order to pass our triplets to the network, we need to create a data generator function. Both an x and a y are required here by TensorFlow, but we don’t need a y value, so we pass a filler.
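
A generator along these lines (again, a sketch rather than the exact original) yields triplet batches forever, together with a filler y that the loss never looks at:

```python
def data_generator(x, y, batch_size, num_hard=0):
    """Yield ([anchors, positives, negatives], filler_y) batches indefinitely."""
    while True:
        batch = create_hard_batch(batch_size, num_hard, x, y)
        filler_y = np.zeros((batch_size, 3 * emb_size))  # ignored by triplet_loss
        yield batch, filler_y
```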

Step 7: Setting up for training and evaluation

Now that we have defined the basics of the SNN, we can set up the model for training. First, we define our hyperparameters. Next, we create and compile the models. I specify that this is performed using the CPU, but this may not be required depending on your setup.
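
The hyperparameter values below are illustrative rather than the exact ones used for the results in this article; the CPU placement mirrors the text and can be dropped if it isn’t needed on your setup:

```python
# Hyperparameters (illustrative values)
batch_size = 256
epochs = 100
steps_per_epoch = int(x_train.shape[0] / batch_size)
val_steps = int(x_test.shape[0] / batch_size)
num_hard = int(batch_size * 0.5)  # how many Semi-Hard triplets to aim for per batch
lr = 0.0001

# Build and compile the models on the CPU
with tf.device('/cpu:0'):
    embedding_model = create_embedding_model(emb_size)
    siamese_net = create_SNN(embedding_model)
    siamese_net.compile(loss=triplet_loss, optimizer=Adam(lr))
```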

Once the models are compiled, we store a subset of the test image embeddings. The model hasn’t been trained yet, so this gives us a good baseline to show how the embeddings have changed through the training process. Embedding visualisations via PCA are based on this notebook by AdrianUng.
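
Storing and visualising the untrained embeddings might look like the sketch below; the helper function (my own naming) is reused later to plot the trained embeddings:

```python
num_vis = 500  # how many test images to embed for visualisation

def plot_embeddings_pca(embeddings, labels, title):
    """Project d-dimensional embeddings to 2-D with PCA and scatter them by class."""
    points = PCA(n_components=2).fit_transform(embeddings)
    plt.scatter(points[:, 0], points[:, 1], c=labels, cmap='tab10', s=8)
    plt.colorbar()
    plt.title(title)
    plt.show()

# Baseline: embeddings from the untrained model
untrained_embeddings = embedding_model.predict(x_test[:num_vis])
plot_embeddings_pca(untrained_embeddings, y_test[:num_vis], 'Embeddings before training')
```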

Further evaluation can be performed on our SNN. Code used in this step is heavily influenced by Eric Craeymeersch‘s One Shot Learning, Siamese Networks and Triplet Loss with Keras and this notebook from asagar60.

Let’s take a look at the evaluation of the untrained model. From the plots, we can see our model is unable to distinguish between similar and dissimilar images. This is most pronounced in the third plot, highlighting the test images and their most likely classes, with very little difference between their scores.

Image generated using code based on this notebook.

Now that we have compiled the models, we can also generate example random and Semi-Hard triplets. This code is based on a blog post by Ruochi Zang.

This produces the following:

Image generated using code based on this blog post.

Our example random triplet contains an Anchor and Positive of class 4, and a Negative of class 6. Our Semi-Hard triplet contains an Anchor and Positive of class 8, and a Negative of class 6, but note how similar they are in composition.

Step 8: Logging output from our model training

Let’s set up some logging and custom callbacks before we train our model, to help us if we need to come back at a later date. The Tensorboard logging callback is adapted from erenon’s helpful Stack Overflow answer, whilst the saving of the best model based on the validation loss is adapted from another Stack Overflow answer from OverLordGoldDragon.
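
The original callbacks are adapted from the Stack Overflow answers mentioned above; a simplified stand-in using the stock Keras callbacks might be:

```python
log_dir = './logs'
os.makedirs(log_dir, exist_ok=True)

# TensorBoard logging plus saving of the best weights based on validation loss
tensorboard_cb = TensorBoard(log_dir=log_dir)
checkpoint_cb = ModelCheckpoint(
    filepath=os.path.join(log_dir, 'snn_best_weights.h5'),
    monitor='val_loss',
    save_best_only=True,
    save_weights_only=True)

callbacks = [tensorboard_cb, checkpoint_cb]
```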

Step 9: Training the SNN

Now that all of our setup has been completed, it is time to start training! I begin by selecting the total number of GPUs available and parallelising the model training over them. You may need to amend this if you don’t have access to multiple GPUs.

Note that when running model.fit() we provide train and test data generators rather than the train and test data directly. This allows for online triplet mining to occur.
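
A sketch of the training call is below. It uses fit_generator(), which is what TensorFlow 1.14 expects; newer versions accept the generators directly in fit(). The multi-GPU parallelisation mentioned above (tf.keras.utils.multi_gpu_model in TF 1.x) is omitted here for simplicity:

```python
history = siamese_net.fit_generator(
    data_generator(x_train, y_train, batch_size, num_hard),
    steps_per_epoch=steps_per_epoch,
    epochs=epochs,
    validation_data=data_generator(x_test, y_test, batch_size, num_hard),
    validation_steps=val_steps,
    callbacks=callbacks,
    verbose=1)
```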

Step 10: Evaluating the trained model

Once the model has trained, we can then evaluate it and compare its embeddings. First, we load in the trained model. I do this by reloading the saved logging files, but if you’re just running this all in one notebook as a closed system, there isn’t really a need to reload once the model is trained.
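
With the checkpoint callback sketched earlier, reloading is just a matter of restoring the best weights (the file path is the one assumed in Step 8):

```python
# Restore the best weights saved during training; the shared embedding_model
# picks up the same weights because its layers are shared with siamese_net
siamese_net.load_weights(os.path.join(log_dir, 'snn_best_weights.h5'))
```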

Once the models are loaded, we perform the same PCA decomposition we did on the untrained model to visualise how the embeddings have changed.

At the end of the above code-block, we run evaluate() again, which produces the below graphs:

Image generated using code based on this notebook.

Note how the first plot now shows an AUC of 0.985 and an increased distance between our classes. Interestingly, when looking at the test images and their most likely classes, we can see that the 2nd and 3rd test images have been matched to the correct class (for example, taking the 2nd test image, of class 0, the lowest score across the support set classes is also at class 0). However, looking at the 1st test image, all of the scores for the support set classes are very close, indicating the trained model had difficulty classifying this image.

To confirm our model has trained correctly and that class clusters are now forming, let’s plot the PCA-decomposed embeddings we stored previously alongside embeddings produced by the trained model.
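
Reusing the plot_embeddings_pca() helper and the stored untrained_embeddings from the earlier sketches, this comparison might look like the following (the original code draws both panels side by side in a single figure):

```python
# Embed the same test subset with the trained model and compare against the baseline
trained_embeddings = embedding_model.predict(x_test[:num_vis])

plot_embeddings_pca(untrained_embeddings, y_test[:num_vis], 'Embeddings before training')
plot_embeddings_pca(trained_embeddings, y_test[:num_vis], 'Embeddings after training')
```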

This code produces the following output:

Image generated using code based on this notebook.

The left plot shows the embedding locations before training, decomposed into 2-dimensions using PCA for visualisation, with each colour representing a distinct class as shown by the legend. Note how the embeddings are all jumbled up, with no clear clustering structure; this makes sense, as the model has not yet learned to separate the classes. This is in contrast to the right plot, which shows the same data points embedded by the trained SNN. We can see clear clustering on the outskirts of the plot, but the middle still looks a bit messy. The plot indicates that the model has learned to cluster embedded images of class 1 very well, for example (the cluster in the bottom left), but still struggles with embedded images of class 5, which mostly remain in the centre. This is backed up by our previous plots, which show the model struggling to determine a most likely match for the class 5 test image.

It would be good to quantify how well our model is performing. This can be achieved using an n-way accuracy score, utilising the prototypes we discussed before. In n-way accuracy, val_steps number of randomly selected test images are compared to a support set of size n. This provides an indication of model accuracy when n is the same as the total number of classes, num_classes in the code below. MNIST has 10 classes, giving a 10-way accuracy.
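
One possible implementation, reusing the compute_prototypes() sketch from earlier and measuring squared Euclidean distance to each prototype, is shown below (names and details are assumptions rather than the exact original code):

```python
num_classes = 10  # MNIST has 10 classes, giving a 10-way accuracy

def n_way_accuracy(n, val_steps, embeddings, labels):
    """Estimate n-way accuracy: for val_steps random test images, check whether
    the nearest of n class prototypes belongs to the correct class."""
    prototypes = compute_prototypes(embeddings, labels)
    correct = 0

    for _ in range(val_steps):
        idx = np.random.randint(len(embeddings))
        query, true_class = embeddings[idx], labels[idx]

        # Support set: the true class plus n - 1 other randomly chosen classes
        other_classes = np.random.choice(
            [c for c in prototypes if c != true_class], size=n - 1, replace=False)
        support = [true_class] + list(other_classes)

        distances = {c: np.sum(np.square(query - prototypes[c])) for c in support}
        if min(distances, key=distances.get) == true_class:
            correct += 1

    return correct / val_steps

test_embeddings = embedding_model.predict(x_test)
print(n_way_accuracy(num_classes, val_steps, test_embeddings, y_test))
```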

When we run the above code, the SNN achieves a 10-way accuracy of 97.4%, a commendable score. Due to the random nature with which the test images are chosen, you could perform cross validation here if you wish.

Finally, let’s take a look at producing a support set image, similar to that shown in the previous moths example, only this time generated using MNIST. Again, we will use a 10 class support set.

This produces the following plot:

Image generated using code based on this notebook.

If the moth example discussed earlier in this article was confusing, hopefully the same plot using MNIST is clearer. The code has randomly selected a class 2 test image to classify, which is compared to the prototypes of all other classes in the support set. Again, the plotting code knows the test image is of class 2, and so the support set 2 is shown first. On the right, the same support set 2 is shown again, indicating the SNN has correctly determined a most likely class of 2 for the test image, which it should do for other test images approximately 97.4% of the time!

Conclusion

In this article, we have learned what a Siamese Neural Network is, how to train them, and how to utilise them at inference time. Even though we have utilised a toy example through the use of MNIST, I hope that it is clear how powerful SNNs can be when working with open-ended datasets where you may not have all classes available to you at the time of dataset creation, and how new unseen-at-train-time classes would be handled by the model.

I hope that I have provided you with a good balance of theoretical knowledge and practical application, and I’d like to thank everyone whom I have mentioned throughout for providing open-source code and Stack Overflow answers. This article, and indeed my own work in conservation tech, would not have been possible without it.

When I first started writing, I was worried there wouldn’t be enough content for it to be worthwhile. Now that it is finished, I realise how wrong I was! Hopefully if you have made it this far (and haven’t just skipped to the end) this article has taught you something, and if this is the case please do let me know on Twitter or LinkedIn.


References

[1] Deng, J., Dong, W., Socher, R., Li, L.J., Li, K. and Fei-Fei, L., 2009, June. Imagenet: A large-scale hierarchical image database. In 2009 IEEE conference on computer vision and pattern recognition (pp. 248–255). IEEE.

[2] Dey, S., Dutta, A., Toledo, J.I., Ghosh, S.K., Lladós, J. and Pal, U., 2017. Signet: Convolutional siamese network for writer independent offline signature verification. arXiv preprint arXiv:1707.02131.

[3] LeCun, Y., Bottou, L., Bengio, Y. and Haffner, P., 1998. Gradient-based learning applied to document recognition. Proceedings of the IEEE, 86(11), pp.2278–2324.

[4] Hoffer, E. and Ailon, N., 2015, October. Deep metric learning using triplet network. In International Workshop on Similarity-Based Pattern Recognition (pp. 84–92). Springer, Cham.

[5] Schroff, F., Kalenichenko, D. and Philbin, J., 2015. Facenet: A unified embedding for face recognition and clustering. In Proceedings of the IEEE conference on computer vision and pattern recognition (pp. 815–823).

[6] Vetrova, V., Coup, S., Frank, E. and Cree, M.J., 2018, November. Hidden features: Experiments with feature transfer for fine-grained multi-class and one-class image categorization. In 2018 International Conference on Image and Vision Computing New Zealand (IVCNZ) (pp. 1–6). IEEE.

