Implementing Capsule Network in TensorFlow

A guide to implementing a Capsule Network and visualizing its features

Parth Rajesh Dedhia
Towards Data Science


Photo by Fotis Fotopoulos on Unsplash

We are well aware that Convolutional Neural Networks (CNNs) have outperformed humans in many computer vision tasks. Most CNN-based models share the same base architecture: convolution layers followed by pooling layers, with intermediate Batch Normalization layers for normalizing the batch in the forward pass and controlling the gradients in the backward pass.

However, CNNs have a couple of drawbacks, primarily the max pooling layer, as it does not consider the relation between the pixel having the maximum value and its immediate neighbors. To solve this problem, Hinton came up with the idea of the Capsule Network and an algorithm called “Dynamic Routing Between Capsules”. Many resources have explained the intuition and the architecture of the model; you can have a look at them in the series of blog posts here.

In this post, I explain the implementation details of the model. It assumes a good understanding of tensors and TensorFlow custom layers and models.

This post has been structured as follows:

  • Essential TensorFlow Operations
  • Capsule Layer Class
  • Miscellaneous Details
  • Results and Feature Visualization

TensorFlow Operations

Building a model in TensorFlow 2.3 with the Functional API or a Sequential model takes very few lines of code. However, in this Capsule Network implementation, we use the Functional API as well as some custom operations, decorating them with @tf.function for optimization. In this section, I am just going to highlight the tf.matmul function for higher-dimensional tensors. If you are familiar with this, you can skip this section and move ahead to the next one.

tf.matmul

For 2D matrices, the matmul operation performs matrix multiplication, provided the shape signatures are respected. However, for tensors of rank r > 2, the operation becomes a combination of two operations: broadcasting over the leading (batch) dimensions and matrix multiplication over the last two.

For rank r = 4 tensors, it first performs broadcasting along axes [0, 1] to make the leading shapes equal. The last two axes ([2, 3]) then undergo matrix multiplication, which requires the last dimension of the first tensor to match the second-to-last dimension of the second tensor. The example below illustrates this; for brevity I have only printed the shapes, but feel free to print and inspect the numbers on the console.

>>> w = tf.reshape(tf.range(48), (1,8,3,2))
>>> x = tf.reshape(tf.range(40), (5,1,2,4))
>>> tf.matmul(w, x).shape
TensorShape([5, 8, 3, 4])

w is broadcast along axis=0 and x is broadcast along axis=1, and the remaining two dimensions are matrix-multiplied. Let’s check out the transpose_a/transpose_b parameters of matmul. Calling tf.transpose on a tensor reverses all of its dimensions. For example,

>>> a = tf.reshape(tf.range(48), (1,8,3,2))
>>> tf.transpose(a).shape
TensorShape([2, 3, 8, 1])

So let’s see how this works in tf.matmul:

>>> w = tf.ones((1,10,16,1))
>>> x = tf.ones((1152,1,16,1))
>>> tf.matmul(w, x, transpose_a=True).shape
TensorShape([1152, 10, 1, 1])

Wait !!! I was expecting an error but it worked out fine. How ???

What TensorFlow did was first broadcast along the first two dimensions and then treat the tensors as stacks of 2D matrices. You can visualize the transpose as being applied only to the last two dimensions of the first tensor. The shape of the first tensor after broadcasting and transposing is [1152, 10, 1, 16], and the matrix multiplication is then applied. By the way, transpose_a=True means this transpose operation is applied to the first argument of matmul. Refer to the docs for more details.

Okay!! That’s enough to get through this post. We can now check out the code for Capsule Layer.

Capsule Layer Class

Let’s see what’s happening in the code.

Note: All the hyper-parameters are the same as those in the paper.

Convolution operations

We have used the tf.keras Functional API to create the primary capsule outputs. These layers just perform plain convolution operations on the input image input_x in the forward pass. At this point we have 256 (32 * 8) feature maps, each of size 6 x 6.
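Since the gist itself is not reproduced here, a minimal sketch of this convolutional front-end could look like the following. The hyper-parameters (9 x 9 kernels, 256 filters, strides 1 and 2) follow the paper, while the layer names are illustrative, not the author's exact code.

import tensorflow as tf

input_x = tf.keras.Input(shape=(28, 28, 1))

# Conv1: 256 feature maps, 9x9 kernels, stride 1, ReLU -> output (20, 20, 256)
conv1 = tf.keras.layers.Conv2D(256, kernel_size=9, strides=1,
                               activation='relu', name='conv1')(input_x)

# PrimaryCaps convolution: 32 capsules x 8D = 256 maps, 9x9 kernels, stride 2 -> (6, 6, 256)
# (the squash non-linearity is applied later, after reshaping into capsules)
primary_caps = tf.keras.layers.Conv2D(256, kernel_size=9, strides=2,
                                      name='primary_caps_conv')(conv1)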

Now, instead of viewing the above feature maps as plain convolution output, we re-imagine them as 32 blocks of 6 x 6 x 8, i.e., 8D vectors piled along the last axis. Hence, we can obtain 6 * 6 * 32 = 1152 8D vectors just by reshaping. Each of these vectors is multiplied by a weight matrix which encapsulates the relation between these lower-level features and the higher-level features. The dimension of the output features of the Primary Capsule layer is 8D, and the input to the Digit Caps layer is 16D. So basically we have to multiply each vector with a 16 x 8 matrix. Okay, that was easy!! But wait, there are 1152 vectors in the Primary Capsule layer, which implies we will have 1152 such 16 x 8 matrices.

So are we cool now?? Nope, you forgot the number of Digit Capsules.

We have 10 Digit Capsules in the next layer, and hence we will have 10 sets of these 1152 16 x 8 matrices. So basically we get a weight tensor of shape [1152, 10, 16, 8]. Each of the 1152 8D vectors of the primary capsule output contributes to each of the 10 Digit Capsules, so we can simply reuse the same 8D vector for every capsule in the Digit Capsule layer. More simply, we can just add 2 new axes to the 1152 8D vectors, converting them into the shape [1152, 1, 8, 1]. Okay! I see what you did there, you are going to use the broadcasting in tf.matmul described above.

Great !! That’s correct.

Preparing input for Digit Capsules

Note: The shape of variable W has an extra dimension of 1 along the first axis since then the same weight has to be broadcasted for the entire batch.
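Since that gist is not reproduced here, a rough sketch of this preparation step might look like the following; the shapes follow the text, and the names u, W, and u_hat are illustrative.

# primary_caps is the (6, 6, 256) convolution output from the sketch above
u = tf.reshape(primary_caps, (-1, 1152, 8))               # 1152 8D vectors per image
u = tf.expand_dims(tf.expand_dims(u, axis=2), axis=-1)    # (None, 1152, 1, 8, 1)

# Weight tensor: the leading 1 lets the same weights broadcast over the whole batch
W = tf.Variable(tf.random.normal((1, 1152, 10, 16, 8), stddev=0.01), trainable=True)

# Broadcasted matmul: (1, 1152, 10, 16, 8) x (None, 1152, 1, 8, 1) -> (None, 1152, 10, 16, 1)
u_hat = tf.matmul(W, u)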

In u_hat, the last dimension is extraneous; it was added only for the correctness of the matrix multiplication and can now be removed using the squeeze function. The (None) in the above shapes is the batch_size, which is determined at training time.

Removing extraneous dimensions

Let’s move to the next step.

Dynamic Routing — This is where the magic begins!

Squash function from the paper

Before exploring the algorithm, let’s just write the squash function and keep it for further use. I have added a small epsilon value to the denominator to prevent the gradients from exploding in case it sums to zero.

Code for Squash Function.
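As the gist is not shown here, a minimal version might look like this; the exact epsilon value is my assumption.

def squash(s, axis=-1, epsilon=1e-7):
    # ||s||^2 / (1 + ||s||^2) scales the vector, s / ||s|| keeps its direction
    squared_norm = tf.reduce_sum(tf.square(s), axis=axis, keepdims=True)
    norm = tf.sqrt(squared_norm + epsilon)   # epsilon keeps the denominator away from zero
    return (squared_norm / (1.0 + squared_norm)) * (s / norm)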

In this step, the input to the Digit Capsule layer is the 16D vector (u_hat), and the number of routing iterations (r = 3) is used as specified by the paper.

Routing algorithm from paper

There is not much tweaking in the dynamic routing algorithm; the code is pretty much a direct implementation of the algorithm in the paper. Have a look at the snippet below.

Dynamic Routing

Some key points should be highlighted.

  • The c represents the probability distribution of the u_hat values, and for a particular capsule in the primary capsule layer it sums to 1. Simply speaking, the values of u_hat are distributed among the Digit Capsules based on the variable c, which is refined in the routing algorithm.
  • The Σ cij ûj|i term is the weighted summation of all the lower-level vectors that are input to a digit capsule. Since there are 1152 lower-level vectors, the reduce_sum function is applied across that dimension. Setting keepdims=True just makes the further computation easier.
  • The squash non-linearity is applied across the 16D vector of each Digit Capsule to normalize the values.
  • The next step has a subtle implementation where the dot product between the input and output of the digit capsule layer is calculated. This dot product governs the “agreement” between lower and higher-level capsules. You can understand the reasoning and intuition behind this step here.

The above loop is iterated 3 times, and the resulting values of v are then used in the reconstruction network.
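Since the routing gist is not reproduced above, here is a rough sketch of what those iterations might look like, assuming the squash function from earlier and a u_hat of shape (None, 1152, 10, 16); the helper name dynamic_routing is illustrative.

def dynamic_routing(u_hat, routing_iterations=3):
    # b holds one routing logit per (primary capsule, digit capsule) pair: (None, 1152, 10, 1)
    b = tf.zeros_like(u_hat[..., :1])
    for _ in range(routing_iterations):
        c = tf.nn.softmax(b, axis=2)                                   # sums to 1 over the 10 digit capsules
        s = tf.reduce_sum(c * u_hat, axis=1, keepdims=True)            # weighted sum -> (None, 1, 10, 16)
        v = squash(s, axis=-1)                                         # normalized digit-capsule vectors
        agreement = tf.reduce_sum(u_hat * v, axis=-1, keepdims=True)   # dot-product "agreement"
        b = b + agreement                                              # route more towards agreeing capsules
    return v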

Wow !! Great. You have just completed most of the difficult part. Now, it’s relatively simple.

Reconstruction Network

The Reconstruction network is a kind of regularizer that regenerates the images from the features of the Digit Capsule layer. During back-propagation it has an impact on the entire network, thus making the features good for both prediction and regeneration. During training, the model uses the actual label of the input image to mask the Digit Caps values to zeros, except the one corresponding to the label (shown in the figure below).

Reconstruction network from the paper

The v tensor from the above network is of shape (None, 1, 10, 16); we broadcast the label along the 16D vector of the Digit Caps layer and apply the masking.

Note: The one-hot encoded label is used for masking.

This v_masked is then sent to the reconstruction network, which is used to regenerate the entire image. The reconstruction network is just 3 Dense layers, shown in the gist below.

Reconstruction from features
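A rough sketch of the masking step and the 3-layer decoder could look like this; the layer sizes (512, 1024, 784) follow the paper, and the variable names are illustrative.

# v: digit-capsule output of shape (None, 1, 10, 16); y: one-hot labels of shape (None, 10)
mask = tf.reshape(y, (-1, 1, 10, 1))                     # broadcast the label across the 16D axis
v_masked = tf.multiply(v, mask)                          # zero out every capsule except the true class

decoder = tf.keras.Sequential([
    tf.keras.layers.Dense(512, activation='relu', input_shape=(160,)),
    tf.keras.layers.Dense(1024, activation='relu'),
    tf.keras.layers.Dense(784, activation='sigmoid'),    # flattened 28 x 28 reconstruction
])
reconstructed_image = decoder(tf.reshape(v_masked, (-1, 160)))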

We will now convert the above code into a CapsuleNetwork class which inherits from tf.keras.Model. You can directly use the class with your custom training loop and for prediction.

Capsule Network Class

As you may have noticed, I have added two extra functions, predict_capsule_output() and regenerate_image(), which predict the Digit Caps vectors and regenerate the image respectively. The first function helps with predicting the digits at test time, and the second one regenerates an image from a given set of input features (it will be used in the visualization).
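The full class is in the gist and the repository; a skeleton of its structure might look roughly like this, where the method bodies reuse the pieces sketched above and are not the exact repository code.

class CapsuleNetwork(tf.keras.Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # the convolution layers, the weight tensor W, and the decoder Dense layers are created here

    def call(self, inputs):
        input_x, y = inputs                                  # image and its one-hot label
        v = self.predict_capsule_output(input_x)             # conv -> reshape -> matmul -> routing
        v_masked = v * tf.reshape(y, (-1, 1, 10, 1))         # mask with the true label while training
        reconstructed_image = self.regenerate_image(v_masked)
        return v, reconstructed_image

    def predict_capsule_output(self, input_x):
        # returns the (None, 1, 10, 16) digit-capsule vectors, used for test-time prediction
        ...

    def regenerate_image(self, v_masked):
        # flattens the masked capsule features and passes them through the decoder
        ...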

Parameters used by the model

So one last thing remains, and that's the loss function. The paper uses a margin loss for classification and the squared difference for reconstruction, weighting the reconstruction loss by 0.0005. The parameters m+, m-, and lambda are described in the gist above, and the loss function is in the gist below.

Margin Loss from the paper
Loss function implementation

Here v is the unmasked Digit Caps vector, y is the one-hot encoded label, and y_image is the actual image sent as input to the model. The safe_norm function is similar to the TensorFlow norm function but contains an epsilon to keep the value from becoming exactly 0.

Safe Norm Function
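As those gists are not reproduced here, a sketch of both functions might look like the following. The constants m+ = 0.9, m- = 0.1, lambda = 0.5 and the reconstruction weight 0.0005 are the paper's values; the epsilon and variable names are my own.

m_plus, m_minus, lambda_ = 0.9, 0.1, 0.5

def safe_norm(v, axis=-1, epsilon=1e-7):
    # like tf.norm, but the epsilon keeps the result (and its gradient) away from exactly 0
    return tf.sqrt(tf.reduce_sum(tf.square(v), axis=axis, keepdims=True) + epsilon)

def loss_function(v, reconstructed_image, y, y_image):
    # v: unmasked digit-capsule output (None, 1, 10, 16); y: one-hot labels (None, 10)
    v_norm = tf.squeeze(safe_norm(v), axis=[1, -1])                        # capsule lengths (None, 10)
    present = y * tf.square(tf.maximum(0.0, m_plus - v_norm))              # loss for the true class
    absent = lambda_ * (1.0 - y) * tf.square(tf.maximum(0.0, v_norm - m_minus))
    margin_loss = tf.reduce_mean(tf.reduce_sum(present + absent, axis=-1))

    # squared-difference reconstruction loss, scaled down so it does not dominate the margin loss
    y_image_flat = tf.reshape(y_image, (-1, 784))
    reconstruction_loss = tf.reduce_sum(tf.square(y_image_flat - reconstructed_image), axis=-1)
    return margin_loss + 0.0005 * tf.reduce_mean(reconstruction_loss)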

Let’s check the summary of the model.

Summary of the model

Congratulations!!! We have completed the model architecture. The model summary shows 8,215,568 parameters, which corroborates the paper, where they say the model with reconstruction has 8.2M parameters. However, the blog posts linked earlier report 8,238,608 parameters. The reason for the difference is that TensorFlow counts only tf.Variable resources as trainable params. If we also count the 1152 * 10 b values and the 1152 * 10 c values from the routing algorithm as trainable, we get the same number.

8215568 + 11520 + 11520 = 8238608

That’s the same number. Yipee!!

Miscellaneous Details

We will use tf.GradientTape to compute the gradients and the Adam optimizer to apply them.

Training Step

Since we have subclassed our model from tf.keras.Model, we can simply call model.trainable_variables and apply the gradients.
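A minimal sketch of such a training step is shown below, assuming the loss_function from the earlier sketch and a forward pass that takes the image with its one-hot label and returns (v, reconstructed_image).

optimizer = tf.keras.optimizers.Adam()

@tf.function
def train_step(model, x_batch, y_batch):
    with tf.GradientTape() as tape:
        v, reconstructed_image = model([x_batch, y_batch])
        loss = loss_function(v, reconstructed_image, y_batch, x_batch)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss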

Prediction Function

I have made a custom prediction function that takes the input image as well as the model as parameters. The purpose of passing the model as a parameter is that a checkpointed model can also be used later for prediction.
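A sketch of such a helper is shown below; it relies on the predict_capsule_output() method and the safe_norm() function described above, and reads the class off the longest digit-capsule vector.

def predict(model, x):
    v = model.predict_capsule_output(x)               # (None, 1, 10, 16) digit-capsule vectors
    v_norm = tf.squeeze(safe_norm(v), axis=[1, -1])   # (None, 10) capsule lengths
    return tf.argmax(v_norm, axis=-1)                 # predicted digit for each sample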

Phew !!! We are done. Congratulations!

So, you can now try writing your own code with this explanation, or use the one in my repository. You can simply run the notebook on your local system or on Google Colab. To obtain good prediction accuracy, even 10 epochs suffice. In the repository, I have added only a single notebook that trains the features for 50 epochs. However, to tweak and visualize the features, you may need to train them for up to 100 epochs.

Note: The training of the model takes a lot of time even on Google Colab’s GPU. So put the model on training and take a break.

Results and Feature Visualization

The model produces a training accuracy of 99% and a testing accuracy of 98%. However, for some checkpoints the accuracy is 98.4%, while for some others it is 97.7%.

In the gist below, index_ refers to a particular sample number in the test set, and index is the actual digit which the sample y_test[index_] represents.

Reconstruction image for each extracted feature

The code below tweaks each of the 16 features, varying them in the range [-0.25, 0.25] with an increment of 0.05. At each step, images are generated and stored in an array. Thus we can see how each feature contributes to the reconstruction of an image.

Tweaking features for reconstruction
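A rough sketch of this tweaking loop is given below; regenerate_image() is the model method mentioned earlier, v_masked is the masked capsule output for one test sample, and index is its true digit.

import numpy as np

tweaked_images = []
for feature_idx in range(16):                            # each of the 16 capsule dimensions
    row = []
    for delta in np.linspace(-0.25, 0.25, 11):           # increments of 0.05
        v_tweaked = v_masked.numpy().copy()              # v_masked has shape (1, 1, 10, 16)
        v_tweaked[0, 0, index, feature_idx] += delta     # nudge a single feature of the true capsule
        image = model.regenerate_image(tf.convert_to_tensor(v_tweaked))
        row.append(tf.reshape(image, (28, 28)))
    tweaked_images.append(row)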

See some samples of reconstruction in the image below. As we can see, some of the features control the brightness, angle of rotation, thickness, skew, etc.

Reconstruction of Images by tweaking the features (Image by author)

Conclusion

In this article, we have tried to reproduce the results as well as visualize the features described in the paper. The training accuracy is 99% and the testing accuracy is almost 98%, which is really great. Although the model takes a lot of time to train, the features are very intuitive.

Github Repository: https://github.com/dedhiaparth98/capsule-network

References

S. Sabour, N. Frosst, G. Hinton, Dynamic Routing Between Capsules (2017), arXiv.
