Uncovering what neural nets “see” with FlashTorch

Open source feature visualisation toolkit for neural networks in PyTorch

Misa Ogura
Towards Data Science


Visualisation of what AlexNet “sees” in these images of birds, using FlashTorch. Source

Setting the scene

A couple of weeks ago, I gave a talk at Hopperx1 London organised by AnitaB.org as part of London Tech Week. The slide deck is available here.

I received such positive feedback after the talk that I decided to write a slightly longer version of it to formally introduce FlashTorch to the world :)

The package is available to install via pip. Check out the GitHub repo for the source code. You can also play around with it in this notebook hosted on Google Colab, without needing to install anything!

But first, I’ll briefly go through the history of feature visualisation to give you a better context as to the what & why.

Introduction to feature visualisation

Feature visualisation is an active area of research which aims to understand how neural networks perceive images, by exploring ways in which we can look “through their eyes”. It has emerged and evolved in response to an increasing desire to make neural networks more interpretable to humans.

The earliest work includes analysing what neural networks pay attention to within input images. For example, image-specific class saliency maps visualise the regions within an input image that contribute most to the corresponding output, by calculating the gradient of a class output with respect to the input image via backpropagation (more on saliency maps later in the post).

Earliest work on image-specific class saliency maps. Source

Another strand of feature visualisation is activation maximisation. It iteratively updates an input image (initialised with random noise) to generate an image that maximally activates a target neuron, which gives some intuition about how individual neurons respond to inputs. It’s the technique behind the so-called Deep Dream, popularised by Google.
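To make the idea concrete, here is a minimal sketch of activation maximisation in NumPy. It assumes a toy one-layer “network” with random weights in place of a real CNN; W, b, pre_activation and maximise_neuron are hypothetical names, not part of FlashTorch or any other library. It maximises the neuron’s pre-ReLU response (so the gradient does not vanish if the neuron starts out inactive) and adds an L2 penalty so the “image” stays bounded. Real implementations use autograd and stronger regularisers.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy "layer" (hypothetical): 8 neurons, each looking at a 16-pixel input.
W = rng.normal(size=(8, 16))
b = rng.normal(size=8)

def pre_activation(x, k):
    """Neuron k's response before the ReLU."""
    return W[k] @ x + b[k]

def maximise_neuron(k, x0, steps=500, lr=0.1, l2=0.05):
    """Gradient ascent on the input to maximise neuron k,
    penalised by l2 * ||x||^2 so the input cannot grow without bound."""
    x = x0.copy()
    for _ in range(steps):
        # d/dx [W[k] @ x + b[k] - l2 * ||x||^2] = W[k] - 2 * l2 * x
        grad = W[k] - 2 * l2 * x
        x += lr * grad
    return x

noise = rng.normal(scale=0.1, size=16)  # start from random noise
dream = maximise_neuron(3, noise)
```

For this toy objective the optimum is exactly W[k] / (2 * l2): the input converges to the pattern the neuron is tuned to. With a deep network the same loop produces the far richer images shown above.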

Deep Dream: what does the network see in the sky? Source

This was a huge step forward, but it had a shortcoming: it doesn’t provide much insight into how the entire network operates, because neurons don’t operate in isolation. This led to efforts to visualise interactions between neurons. Olah et al. demonstrated arithmetic properties of the activation space by adding, or interpolating between, two neurons.

Neuron arithmetic. Source

Then Olah et al. went further to define a more meaningful unit of visualisation, by analysing how much each neuron fires within a hidden layer when given a particular input. Visualising groups of neurons that are strongly activated together revealed that there seem to be groups of neurons responsible for capturing concepts such as floppy ears, furry legs and grass.
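In Olah et al.’s work the neuron groups came from matrix factorisation of a layer’s activations. As a much simpler stand-in (our substitution, not their method), the NumPy sketch below groups channels of one hidden layer whose activation patterns across spatial positions are strongly correlated for a single input. The function name co_activated_groups, the made-up activations and the 0.8 threshold are all hypothetical illustration choices.

```python
import numpy as np

def co_activated_groups(acts, threshold=0.8):
    """Group channels that fire together.

    acts: array of shape (channels, positions) holding one hidden layer's
    activations for a single input image (spatial dims flattened).
    Channels whose activation patterns correlate above `threshold` end up
    in the same group (connected components of the correlation graph).
    """
    corr = np.corrcoef(acts)  # (channels, channels); NaN rows for dead channels
    remaining = set(range(acts.shape[0]))
    groups = []
    while remaining:
        seed = min(remaining)
        remaining.discard(seed)
        group, frontier = {seed}, [seed]
        while frontier:
            c = frontier.pop()
            for other in sorted(remaining):
                if corr[c, other] > threshold:  # NaN compares as False
                    remaining.discard(other)
                    group.add(other)
                    frontier.append(other)
        groups.append(sorted(group))
    return groups

# Made-up activations: channels 0-2 fire on one image region
# (say, the ears), channels 3-5 on another (say, the grass).
rng = np.random.default_rng(0)
ears = np.zeros(20)
ears[:10] = rng.uniform(0.5, 1.0, 10)
grass = np.zeros(20)
grass[10:] = rng.uniform(0.5, 1.0, 10)
acts = np.stack([ears, 0.7 * ears, 0.9 * ears, grass, 0.5 * grass, 0.8 * grass])

print(co_activated_groups(acts))  # [[0, 1, 2], [3, 4, 5]]
```

Correlation only captures pairwise co-firing; factorisation methods can also express how strongly each channel belongs to each group, which is why they suit real layers with hundreds of channels better.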

A group of neurons detecting floppy ears. Source

One of the most recent developments in the field is Activation Atlas (Carter et al., 2019). In this study, the authors addressed a major weakness of visualising filter activations: it only gives a limited view of how the network responds to a single input. To see the big picture of how the network perceives a myriad of objects, and how these objects relate to one another in the network’s view of the world, they devised a way to create “a global map seen through the eye of the network” by showing common combinations of neurons.

Different ways to visualise the network. Source

Motivation behind FlashTorch

When I discovered the world of feature visualisation, I was immediately drawn to its potential for making neural nets more interpretable and explainable. Then I quickly realised that there was no tool available to easily apply these techniques to neural networks I’d built in PyTorch.

So I decided to build one — FlashTorch, which is now available to install via pip! The first feature visualisation technique I implemented is saliency maps.

We’re going to look at what saliency maps are in more detail below, along with how to use FlashTorch to implement them with your neural networks.

Saliency maps

Saliency, in human visual perception, is a subjective quality that makes certain things in the field of view stand out and grab our attention. Saliency maps in computer vision can give indications of the most salient regions within images.

Examples of saliency maps. Source

The method of creating saliency maps from convolutional neural networks (CNNs) was first introduced in 2013 in the paper Deep Inside Convolutional Networks: Visualising Image Classification Models and Saliency Maps. The authors showed that, by calculating the gradients of a target class with respect to an input image, we can visualise the regions within the input image that affect the prediction value of that class.
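The core computation is just the chain rule. Here is a minimal NumPy sketch with backpropagation written out by hand, assuming a toy two-layer network (a flattened 3x4x4 “image”, one ReLU hidden layer, 5 class scores) as a stand-in for a CNN; class_scores, saliency and the weight names are hypothetical. In PyTorch, autograd computes the same gradient for you.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy stand-in for a CNN (hypothetical): a 3x4x4 "image" is flattened,
# passed through one ReLU hidden layer, and mapped to 5 class scores.
SHAPE = (3, 4, 4)
D, HIDDEN, CLASSES = 3 * 4 * 4, 32, 5
W1 = rng.normal(size=(HIDDEN, D)) / np.sqrt(D)
b1 = rng.normal(size=HIDDEN) * 0.1
W2 = rng.normal(size=(CLASSES, HIDDEN)) / np.sqrt(HIDDEN)

def class_scores(image):
    hidden = np.maximum(W1 @ image.ravel() + b1, 0.0)
    return W2 @ hidden

def saliency(image, target_class):
    """d(score of target_class) / d(image), by backpropagating by hand."""
    pre = W1 @ image.ravel() + b1
    # The target score depends on the hidden units through W2[target_class];
    # the ReLU only passes gradient where its input was positive.
    grad_hidden = W2[target_class] * (pre > 0)
    grad_input = W1.T @ grad_hidden
    return grad_input.reshape(image.shape)

image = rng.normal(size=SHAPE)
grads = saliency(image, target_class=2)
print(grads.shape)  # (3, 4, 4)
```

The gradient has the same shape as the image: one value per pixel per colour channel, saying how much nudging that pixel would change the class score.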

Saliency maps using FlashTorch

Without further ado, let’s use FlashTorch and visualise saliency maps ourselves!

FlashTorch also comes with some utility functions to make data handling a bit easier. We’re going to use this image of a great grey owl as an example.

Then we’ll apply some transformations to the image to make its shape, type and values suitable as an input to a CNN.

I’m going to use AlexNet, pre-trained on the ImageNet classification dataset, for this visualisation. In fact, FlashTorch supports all the models that come with torchvision out of the box, so I encourage you to try out other models too!

The Backprop class is at the core of creating saliency maps.

On instantiation, Backprop(model) takes in a model and registers custom hooks on layers of interest within the network, so that we can grab the intermediate gradients out of the computational graph for visualisation. These intermediate gradients are not immediately available to us, due to how PyTorch is designed. FlashTorch sorts this out for you :)

Now, one final thing we need before calculating the gradients — the target class index.

To recap, we’re interested in the gradients of the target class with respect to the input image. However, the model is pre-trained on the ImageNet dataset, so its output is a vector of 1000 scores, one per class (a softmax turns these into a probability distribution). We want to pinpoint the value for the target class (in our case, great grey owl) among these 1000 values, to avoid unnecessary computation and to focus only on the relationship between the input image and the target class.
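In autograd terms, singling out one class usually means seeding backpropagation with a one-hot vector. A minimal NumPy sketch, assuming random numbers as a stand-in for AlexNet’s 1000 output scores; in PyTorch the equivalent seeding would be output.backward(gradient=one_hot), though this is not necessarily how FlashTorch does it internally.

```python
import numpy as np

NUM_CLASSES = 1000
TARGET_CLASS = 24  # "great grey owl" in ImageNet, as used in this post

# Random stand-in for one forward pass: a score (logit) per ImageNet class.
logits = np.random.default_rng(0).normal(size=NUM_CLASSES)

# Softmax turns the 1000 scores into a probability distribution.
probs = np.exp(logits - logits.max())
probs /= probs.sum()

# A one-hot vector selects the target class. Used as the starting gradient
# for backpropagation, it gives every other class zero weight, so only
# d(score of class 24) / d(input) is computed.
one_hot = np.zeros(NUM_CLASSES)
one_hot[TARGET_CLASS] = 1.0

target_score = one_hot @ logits  # same value as logits[TARGET_CLASS]
```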

For this, I also implemented a class called ImageNetIndex. If you don’t want to download the whole dataset and just want to find class indices from class names, this is a handy tool. Give it a class name and it will find the corresponding class index: target_class = imagenet['great grey owl']. If you do want to download the dataset, use the ImageNet class provided in torchvision==0.3.0 (the latest release at the time of writing).
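The idea behind a name-to-index lookup fits in a few lines. The sketch below is a hypothetical, much smaller stand-in, not FlashTorch’s ImageNetIndex code: a dictionary holding a handful of ImageNet classes, an exact (case-insensitive) match first, then a fallback to a unique partial match. class_index and IMAGENET_EXCERPT are illustrative names.

```python
# Hypothetical excerpt of the ImageNet label mapping (index -> class name)
IMAGENET_EXCERPT = {
    0: "tench",
    1: "goldfish",
    22: "bald eagle",
    24: "great grey owl",
}

def class_index(name, labels=IMAGENET_EXCERPT):
    """Return the index whose class name matches `name`.

    Tries an exact (case-insensitive) match first, then a unique partial match.
    """
    query = name.lower()
    exact = [i for i, label in labels.items() if label == query]
    if exact:
        return exact[0]
    partial = [i for i, label in labels.items() if query in label]
    if len(partial) == 1:
        return partial[0]
    if not partial:
        raise KeyError(f"no class matches {name!r}")
    raise ValueError(f"{name!r} is ambiguous: matches {partial}")

print(class_index("great grey owl"))  # 24
print(class_index("owl"))             # 24 (the only partial match here)
```

Raising on ambiguous partial matches, rather than silently picking one, avoids computing gradients for the wrong class.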

Now, we have the input image and the target class index (24), so we’re ready to calculate gradients!

These two lines are the key:

gradients = backprop.calculate_gradients(input_, target_class)

max_gradients = backprop.calculate_gradients(input_, target_class, take_max=True)

By default, gradients are calculated per colour channel, so their shape is the same as the input image’s: (3, 224, 224) in our case. Sometimes it’s easier to visualise the gradients if we take the maximum gradient across colour channels, which we can do by passing take_max=True to the method call. The shape of the gradients then becomes (1, 224, 224).
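The channel reduction itself is a one-liner. This NumPy sketch follows the original saliency paper (absolute value first, then the maximum over channels) and uses random numbers in place of real gradients. Whether take_max=True also takes the absolute value isn’t stated here, so treat that detail as an assumption; the min-max rescaling for display is also our addition.

```python
import numpy as np

# Random stand-in for per-channel gradients of one 224x224 RGB image
gradients = np.random.default_rng(0).normal(size=(3, 224, 224))

# Magnitude of the gradient, then the maximum over the colour channels,
# keeping a leading channel axis of size 1.
max_gradients = np.abs(gradients).max(axis=0, keepdims=True)

# Rescale to [0, 1] so the map can be shown as a greyscale image.
lo, hi = max_gradients.min(), max_gradients.max()
display = (max_gradients - lo) / (hi - lo)

print(gradients.shape)      # (3, 224, 224)
print(max_gradients.shape)  # (1, 224, 224)
```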

Finally, let’s visualise what we’ve got!

From far left: input image, gradients across colour channels, max gradients, an overlay of input image and max gradients

We can appreciate that pixels in the area where the animal is present have the strongest effects on the value of the prediction.

But this is kind of noisy… the signal is spread out, and it doesn’t tell us much about the neural network’s perception of an owl.

Is there a way to improve this?

Guided backpropagation to the rescue

The answer is yes!

In the paper Striving for Simplicity: The All Convolutional Net, the authors introduced an ingenious way to reduce noise in gradient calculation.

Guided backpropagation. Source

In essence, guided backpropagation masks out neurons that have no effect or a negative effect on the prediction value of the target class: at each ReLU, a gradient is passed back only if the neuron was active in the forward pass and the gradient itself is positive. Blocking the flow of gradients through all other neurons results in much less noise.
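The masking rule lives in the backward pass through each ReLU. Here is a minimal NumPy sketch of that rule; relu_backward is a hypothetical helper, and in a real network the rule would be applied at every ReLU layer as gradients flow back towards the image.

```python
import numpy as np

def relu_backward(grad_out, forward_input, guided=False):
    """Backward pass through one ReLU layer.

    Vanilla backprop passes the gradient wherever the ReLU's forward input
    was positive. Guided backprop additionally blocks negative gradients,
    so only signals that would increase the target score flow back.
    """
    mask = forward_input > 0
    if guided:
        mask &= grad_out > 0
    return np.where(mask, grad_out, 0.0)

forward_input = np.array([1.0, -1.0, 2.0, 3.0])  # ReLU inputs from the forward pass
grad_out = np.array([0.5, 0.5, -0.2, 0.3])       # gradients arriving from the layer above

print(relu_backward(grad_out, forward_input).tolist())               # [0.5, 0.0, -0.2, 0.3]
print(relu_backward(grad_out, forward_input, guided=True).tolist())  # [0.5, 0.0, 0.0, 0.3]
```

The only difference is the third entry: vanilla backprop lets the negative gradient through an active neuron, while guided backprop drops it. Repeated across every layer, this is what removes most of the speckle from the saliency map.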

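You can use guided backpropagation in FlashTorch by passing guided=True to calculate_gradients, like so:

guided_gradients = backprop.calculate_gradients(input_, target_class, guided=True)

max_guided_gradients = backprop.calculate_gradients(input_, target_class, take_max=True, guided=True)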

Let’s visualise guided gradients.

The difference is striking!

Now we can clearly see that the network is paying attention to the sunken eyes and the round head of the owl. These are the characteristics that “convinced” the network to classify the object as a great grey owl.

But it doesn’t always focus on the eyes or the heads…

As you can see, the network has learnt to focus on traits which are largely in line with what we would deem the most distinguishing things about these birds.

Applications of feature visualisation

With feature visualisation, not only can we gain a better understanding of what the neural network has learnt about objects, but we are also better equipped to:

  • Diagnose what the network gets wrong and why
  • Spot and correct biases in algorithms
  • Move beyond looking only at accuracy
  • Understand why the network behaves in the way it does
  • Elucidate mechanisms of how neural nets learn

Use FlashTorch today!

If you have projects which utilise CNNs in PyTorch, FlashTorch can help you make your projects more interpretable and explainable.

Please let me know what you think if you use it! I would really appreciate your constructive comments, feedback and suggestions 🙏

Thanks, and happy coding!

Senior Machine Learning Engineer @Healx | Creator of github.com/MisaOgura/flashtorch | Published Scientist | Co-founder of @womendrivendev