How to cluster images based on visual similarity

Use a pre-trained neural network for feature extraction and cluster images using K-means.

Gabe Flomo
Towards Data Science


Objective

In this tutorial, I'm going to walk you through using a pre-trained neural network to extract a feature vector from images and cluster the images based on how similar the feature vectors are.

The model

The pre-trained model used in this tutorial is the VGG16 convolutional neural network (CNN), a classic architecture that set the state of the art for image recognition when it was introduced. We are going to use this model as a feature extractor only, meaning that we will remove the final (prediction) layer so that the model outputs a feature vector instead of class probabilities.

The Data

This implementation will use the flowers dataset from Kaggle, which you can download here. The dataset contains 210 images of 10 different species of flowers, provided as png files.

Imports

Before we get started, we need to import the modules needed in order to load/process the images along with the modules to extract and cluster our feature vectors.

import statements
  • load_img allows us to load an image from a file as a PIL object
  • img_to_array allows us to convert the PIL object into a NumPy array
  • preprocess_input prepares an image in the format the model requires. You should load images with the Keras load_img function so that the images you load are guaranteed to be compatible with preprocess_input.
  • VGG16 is the pre-trained model we're going to use
  • KMeans is the clustering algorithm we're going to use
  • PCA is for reducing the dimensions of our feature vector
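
Based on the list above, the imports might look like this. This is a sketch assuming the standalone Keras package; in newer TensorFlow releases these modules live under tensorflow.keras instead:

```python
# image loading and preprocessing
from keras.preprocessing.image import load_img, img_to_array
from keras.applications.vgg16 import VGG16, preprocess_input
from keras.models import Model

# clustering and dimensionality reduction
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA

import os
import numpy as np
```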

Loading the data

Now that the data is downloaded to your computer, we want Python to point to the directory where the images are located. This way, instead of writing out a full file path every time, we can refer to each image by its filename alone.

loading the data
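
A minimal sketch of this step, assuming the images sit in a local folder (the path below is hypothetical):

```python
# point Python at the folder that holds the images (hypothetical path)
path = r"./flower_images"
os.chdir(path)

# collect only the .png filenames
flowers = [file for file in os.listdir(path) if file.endswith('.png')]
```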
```python
# view the first 10 flower entries
print(flowers[:10])
```

output:

```
['0001.png', '0002.png', '0003.png', '0004.png', '0005.png', '0006.png', '0007.png', '0008.png', '0009.png', '0010.png']
```

Now that we have all of the filenames loaded into the flowers list, we can start preprocessing the images.

Data Preprocessing

This is where we put the load_img() and preprocess_input() methods to use. When loading the images we are going to set the target size to (224, 224) because the VGG model expects the images it receives to be 224x224 NumPy arrays.

loading the images
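
A sketch of what loading a single image looks like:

```python
# load the image as a 224x224 PIL object, then convert it to a NumPy array
img = load_img(flowers[0], target_size=(224, 224))
img = img_to_array(img)

print(img.shape)
# (224, 224, 3)
```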

Currently, our array has only 3 dimensions (rows, columns, channels) and the model operates in batches of samples. So we need to expand our array to add the dimension that will let the model know how many images we are giving it (num_of_samples, rows, columns, channels).

image reshaping
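
A sketch of the reshape, with prints that confirm the new batch shape:

```python
# the model works on batches, so add a leading batch dimension:
# (224, 224, 3) -> (1, 224, 224, 3)
reshaped_img = img.reshape(1, 224, 224, 3)

print(f"Number of dimensions: {reshaped_img.ndim}")
print(f"Number of images (batch size): {reshaped_img.shape[0]}")
print(f"Number of rows (0th axis): {reshaped_img.shape[1]}")
print(f"Number of columns (1st axis): {reshaped_img.shape[2]}")
print(f"Number of channels (rgb): {reshaped_img.shape[3]}")
```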

```
Number of dimensions: 4
Number of images (batch size): 1
Number of rows (0th axis): 224
Number of columns (1st axis): 224
Number of channels (rgb): 3
```

The last step is to pass the reshaped array to the preprocess_input method and our image is ready to be loaded into the model.

preprocessing the input
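
One line does the job:

```python
# subtract the ImageNet channel means so the input matches
# what VGG16 saw during training
x = preprocess_input(reshaped_img)
```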

The Model

Now we can load the VGG model and remove the output layer manually. This means that the new final layer is a fully-connected layer with 4,096 output nodes. This vector of 4,096 numbers is the feature vector that we will use to cluster the images.

Now that the final layer is removed, we can pass our image through the predict method to get our feature vector.
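
A sketch of loading and trimming the model; layers[-2] is the 4,096-node fully-connected layer that becomes the new output:

```python
# load VGG16 with its ImageNet weights, then rebuild it
# without the final 1000-way classification layer
model = VGG16()
model = Model(inputs=model.inputs, outputs=model.layers[-2].output)

# the "prediction" is now a 4096-dimensional feature vector
features = model.predict(x)
print(features.shape)
# (1, 4096)
```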

Here's all the code in a single function:

feature extraction pipeline
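
A sketch of that function, combining the steps above:

```python
def extract_features(file, model):
    # load the image as a 224x224 array and make it a batch of one
    img = load_img(file, target_size=(224, 224))
    img = img_to_array(img)
    reshaped_img = img.reshape(1, 224, 224, 3)
    # preprocess and run it through the trimmed model
    imgx = preprocess_input(reshaped_img)
    features = model.predict(imgx)
    return features
```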

Now we can use this feature_extraction function to extract the features from all of the images and store the features in a dictionary with filename as the keys.
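
Something like this loop would do it (data is just the name I'm using for the dictionary):

```python
data = {}

# extract the features for every image and key them by filename
for flower in flowers:
    feat = extract_features(flower, model)
    data[flower] = feat
```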

```
Wall time: 56.2 s
```

Dimensionality Reduction (PCA)

Since our feature vector has over 4,000 dimensions, your computer will thank you if you reduce the number of dimensions from 4,096 to a much smaller number. We can't simply shorten the vector by slicing it or taking some subset of it, because we would lose information. If only there were a way to reduce the dimensionality while keeping as much information as possible.

Enter the realm of principal component analysis.

I'm not going to waste time explaining what PCA is because there are already tons of articles explaining it, which I’ll link here.

Simply put, if you are working with data and have a lot of variables to consider (in our case 4096), PCA allows you to reduce the number of variables while preserving as much information from the original set as possible.

The number of dimensions to reduce down to is up to you. There are principled ways to pick the number of components (for example, inspecting how much variance each component explains), but for this case I just chose 100 as an arbitrary number.
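
A sketch of fitting PCA, assuming the features dictionary from the extraction step; random_state simply makes the results reproducible:

```python
# stack the stored vectors into a (num_images, 4096) array
filenames = np.array(list(data.keys()))
feat = np.array(list(data.values())).reshape(-1, 4096)

# reduce from 4096 dimensions down to 100
pca = PCA(n_components=100, random_state=22)
pca.fit(feat)
x = pca.transform(feat)
```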

print(f"Components before PCA: {f.shape[1]}")
print(f"Components after PCA: {pca.n_components}")
Components before PCA: 4096
Components after PCA: 100

Now that we have a smaller feature set, we are ready to cluster our images.

KMeans clustering

You’ll define a target number k, which refers to the number of centroids you need in the dataset. A centroid is the imaginary or real location representing the center of the cluster.

This algorithm will allow us to group our feature vectors into k clusters. Each cluster should contain images that are visually similar. In this case, we know there are 10 different species of flowers so we can have k = 10.
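
Fitting the model takes two lines:

```python
# cluster the reduced feature vectors into k = 10 groups
kmeans = KMeans(n_clusters=10, random_state=22)
kmeans.fit(x)
```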

```python
kmeans.labels_
```

```
[6, 6, 8, 6, 6, 5, 4, 6, 5, 6, 4, 6, 6, 3, 3, 5, 6, 6, 4, 4, 8, 1,
 3, 8, 4, 2, 8, 4, 2, 6, 9, 7, 4, 4, 0, 5, 4, 9, 8, 5, 9, 5, 3, 6,
 5, 1, 3, 9, 6, 5, 0, 1, 3, 9, 6, 7, 4, 6, 4, 5, 8, 5, 3, 6, 5, 4,
 6, 5, 2, 1, 4, 3, 9, 5, 4, 6, 2, 4, 5, 0, 5, 1, 2, 9, 5, 4, 8, 1,
 7, 1, 3, 5, 4, 8, 5, 4, 6, 9, 5, 9, 5, 8, 1, 4, 9, 8, 5, 4, 5, 6,
 4, 1, 8, 9, 4, 6, 5, 7, 5, 6, 4, 8, 1, 4, 5, 5, 8, 6, 5, 2, 4, 8,
 5, 1, 1, 6, 6, 7, 8, 1, 9, 1, 6, 4, 8, 3, 6, 1, 0, 0, 8, 1, 3, 4,
 9, 9, 0, 4, 0, 6, 4, 9, 0, 3, 5, 0, 3, 9, 9, 4, 9, 5, 0, 9, 5, 4,
 5, 1, 8, 3, 6, 4, 5, 2, 6, 6, 9, 5, 0, 3, 1, 3, 5, 4, 5, 0, 9, 4,
 2, 1, 0, 9, 4, 9, 1, 2, 6, 1, 6, 0]
```

Each label in this list is a cluster identifier for the image at the same position in our dataset. The order of the labels is parallel to the order of the filenames, so we can use it to group the images into their clusters.
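
A minimal sketch of the grouping, using a plain dictionary:

```python
# map each cluster label to the list of filenames assigned to it
groups = {}
for file, cluster in zip(filenames, kmeans.labels_):
    groups.setdefault(cluster, []).append(file)
```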

```python
# view the filenames in cluster 0
groups[0]
```

output:

```
['0035.png',
 '0051.png',
 '0080.png',
 '0149.png',
 '0150.png',
 '0157.png',
 '0159.png',
 '0163.png',
 '0166.png',
 '0173.png',
 '0189.png',
 '0196.png',
 '0201.png',
 '0210.png']
```

All we have left to do is view the clusters to see how well our model did.
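
A sketch of a helper that plots up to 30 images from a given cluster; the figure and grid sizes are arbitrary choices:

```python
import matplotlib.pyplot as plt

# display up to 30 images from one cluster in a 10x10 grid
def view_cluster(cluster):
    plt.figure(figsize=(25, 25))
    files = groups[cluster]
    for index, file in enumerate(files[:30]):
        plt.subplot(10, 10, index + 1)
        img = load_img(file)
        plt.imshow(img)
        plt.axis('off')
    plt.show()

view_cluster(0)
```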

Cluster 0

Cluster 1

Cluster 2

Here we can see that our model did pretty well at clustering the flower images. We can even see that cluster 0 and cluster 2 both contain yellow flowers, yet the flowers in each cluster are different species.

Conclusion

Here is the whole process in one file.

Hope you all learned something new, leave a comment if you have any questions or had an aha moment :)
