
PonderNet explained

Implementing a pondering network for the MNIST dataset


You can find the code for this article in this GitHub repo. Clone it now!

You can run the code as you read with this Colab notebook. Open it now!


Photo by Tingey Injury Law Firm on Unsplash

If we are serious about reaching a dystopian future with robot overlords, we have to recognize that current AI is never going to make the cut. There is something missing in most modern-day neural networks, a key attribute that keeps them from attaining world-domination: they are not able to ponder.

Luckily, DeepMind has recently released PonderNet, a framework that could potentially allow any network to ponder. All of a sudden, the future is bright again.

Jokes aside, pondering is an important concept that could have serious repercussions in how we design new models. In this article we’re going to go through the theory (you’ll be surprised by how simple it is!) and implement a version of PonderNet that performs image classification on the MNIST dataset. In the later sections we will perform some experiments to determine how impactful the ability to ponder can be.

We will be using PyTorch Lightning as our framework (because it is simply just great); if you’re unfamiliar with it, now is a great time to learn the basics! For logging we will use Weights & Biases (also because it is simply just great). I strongly recommend you try it out, but if you prefer a different logger you will only need to change one line of code (because, again, PyTorch Lightning is simply just great).

From PyTorch to PyTorch Lightning – A gentle introduction


1. Motivation

This is all well and good, but what does it even mean to ponder? The authors put it in the following way:

To ponder is to adjust the computational budget based on the complexity of the task.

Given this definition, it becomes clear that machine learning researchers and engineers have been pondering all along: every time they choose a specific number of hidden layers, select different GPUs to train their models on, or make any decision that affects the architecture of a network.

It is also clear that most networks are not able to ponder. We could argue that images of the digit 6 are really easy to recognize, while one may need some more time to tell a 1 and a 7 apart; nevertheless, a CNN will spend the same amount of resources predicting the label for both kinds of images. This is mainly due to the rigid structure of neural networks, and how they act as black-box mappings.

PonderNet improves on previous research and is able to allocate more resources to inputs that it thinks need them. This is a key attribute if we want models to learn beyond the current state-of-the-art. It’s also a step towards thinking of neural networks in the context of classical algorithms, which can bring a lot of fresh ideas to the field (do not miss the beautiful appendix E).


2. The PonderNet framework

In this section we’re going to lay down all the theory behind PonderNet. For starters, let us try to understand, at a high level, how it is able to adjust its computational budget.

Intuition

Suppose we have a task we want to solve (e.g. classifying digits in MNIST!) and a model to solve it (e.g. a good ol’ CNN). A conventional approach would simply process the inputs once and produce an output. In contrast, the PonderNet framework allows the input to be processed multiple times, and is able to find the appropriate time to stop and output a result.

Figure 1: The options of PonderNet. (Image by author)

In broad terms, PonderNet does the following:

  • Process the original input.
  • Produce a prediction and a probability to stop the computation at the current step.
  • Flip a coin to decide whether to "Halt" or "Continue".
  • If we halt, output the latest prediction; otherwise process the input again along with some contextual information.

Figure 1 summarizes this idea by representing the halting options as a binary tree.

Formal definition

More formally, we can define the PonderNet framework as a step function s (usually a neural network) that satisfies the following equation:
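$$\hat{y}_n,\ h_{n+1},\ \lambda_n = s(x,\ h_n)$$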

Here, x represents the original input, the two h’s represent the hidden state that is propagated through the different steps, y is the output for the current step and λ is the probability of halting at the current step.

In words, this means that at each step PonderNet takes the original input and the latest hidden state, and produces a prediction, a probability of halting and the updated hidden state for the new step. It will flip a biased coin with probability λ to decide whether to halt and output y or to continue by propagating h further.

Figure 2: PonderNet unrolled. (Image by author)

Figure 2 is great to visualize how the hidden state flows through all the steps, and how each step produces a pair of output and probability of halting. If you’re familiar with RNNs, you will find some similarities in the structure.

Recurrent Neural Networks (RNNs)

The meaning of λ

It is important to note that, technically, λ is the probability of halting at the current step given that no halting happened in the previous steps. This is what allows us to treat λ as the probability for a Bernoulli random variable that tells us whether we should halt at the current step or not. If we wish to find the unconditional probability of halting at the current step (henceforth p), it will also have to involve the probability of not halting in the previous steps:
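$$p_n = \lambda_n \prod_{j=1}^{n-1} \left(1 - \lambda_j\right)$$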

Pondering steps

We will usually only allow a specified maximum number of "pondering steps". This means that we will force λ to be 1 at the last step, so that halting is guaranteed.

During inference, we do not explicitly need such a limit, since we can let the network run indefinitely until one of the coin flips makes it halt naturally; we would then output the y obtained in that step. Nevertheless, it is still a good idea to have a bound on the pondering steps, since theoretically it could run forever.

During training, the bound on pondering steps is required due to the loss function, as we will see now.

Training PonderNet

Like virtually every neural network, PonderNet tries to optimize a loss function L:
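$$L = L_{\mathrm{Rec}} + \beta\, L_{\mathrm{Reg}}, \qquad L_{\mathrm{Rec}} = \sum_{n=1}^{N} p_n\, \mathcal{L}\left(y, \hat{y}_n\right), \qquad L_{\mathrm{Reg}} = \mathrm{KL}\left(p_n \,\|\, p_G(\lambda_p)\right)$$

where $p_G(\lambda_p)$ is a geometric prior distribution over the halting step, $p_G(n) = \lambda_p (1 - \lambda_p)^{n-1}$, truncated to the maximum number of steps N.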

Let us break this down. L can be separated into two losses, the reconstruction loss and the regularization loss (similarly to VAEs); the trade-off between these two is regulated by the hyperparameter β. In both cases, the losses can be separated into N terms; that’s our maximum number of pondering steps.

The reconstruction loss is very intuitive: for every pondering step we made we calculate the loss of the output from that step weighted by the unconditional probability that we stopped in it. In this sense, the reconstruction loss is nothing more than the expected loss across all steps. Note that throughout these calculations we are using the underlying loss function for the task we are trying to solve (in the case of multi-class classification, cross-entropy).

The regularization loss introduces bias in terms of how halting should behave. It tries to minimize the KL-divergence between the halting distribution generated by PonderNet and a prior geometric distribution with some hyperparameter λp. For those unfamiliar with the KL-divergence, feel free to check out the link below; intuitively, what we are saying is that we want all the λ produced by the network to be close to λp.

Intuitive Guide to Understanding KL Divergence

The repercussions of this regularization term are two-fold. On one hand, we bias the expected number of pondering steps towards 1/λp (since that is the expected value of a geometric distribution). On the other hand, it promotes exploration by giving positive probability to halting in any step, no matter how far it is.


3. Implementing PonderNet for MNIST

That wasn’t so bad, was it? We’re now ready to start getting our hands dirty! Keep in mind that you can run the code as you read by following the article’s Colab notebook. If you’d rather run it locally, all the code (and some extra things!) can be found in this GitHub repo; if you find any bugs don’t hesitate to open an issue!

The Data Module

Let us first get the Data Module out of the way. There is nothing fancy going on in here, except maybe for the fact that this Data Module allows you to have multiple test dataloaders (it’s a surprise tool that will help us later!).
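As a rough sketch of what such a Data Module can look like (class and argument names here are my own, not necessarily the ones used in the repo):

```python
import pytorch_lightning as pl
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST


class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, data_dir="./data", batch_size=128,
                 train_transform=None, test_transforms=None):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.train_transform = train_transform or transforms.ToTensor()
        # one transform per test dataloader (the "surprise tool" for later)
        self.test_transforms = test_transforms or [transforms.ToTensor()]

    def prepare_data(self):
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        if stage in (None, "fit"):
            full = MNIST(self.data_dir, train=True, transform=self.train_transform)
            self.train_set, self.val_set = random_split(full, [55000, 5000])
        if stage in (None, "test"):
            # one test set per transform -> multiple test dataloaders
            self.test_sets = [MNIST(self.data_dir, train=False, transform=t)
                              for t in self.test_transforms]

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_set, batch_size=self.batch_size)

    def test_dataloader(self):
        return [DataLoader(ts, batch_size=self.batch_size) for ts in self.test_sets]
```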

Losses

We will model each term in the loss separately, as two different classes extending nn.Module. Let us begin with the reconstruction loss. The only thing we need to do is calculate the weighted average, which is easily achieved:
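A minimal sketch of that idea (the naming is my own; the underlying task loss is assumed to be built with reduction="none" so that every element can be weighted by its own p):

```python
import torch.nn as nn


class ReconstructionLoss(nn.Module):
    """Expected task loss across pondering steps, weighted by the p_n."""

    def __init__(self, loss_func: nn.Module):
        super().__init__()
        self.loss_func = loss_func  # e.g. nn.CrossEntropyLoss(reduction="none")

    def forward(self, p, y_hat_steps, y_true):
        # p:           (max_steps, batch_size) unconditional halting probabilities
        # y_hat_steps: (max_steps, batch_size, n_classes) predictions at every step
        # y_true:      (batch_size,) ground-truth labels
        total = p.new_tensor(0.0)
        for n in range(p.shape[0]):
            step_loss = self.loss_func(y_hat_steps[n], y_true)  # (batch_size,)
            total = total + (p[n] * step_loss).mean()
        return total
```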

As for the regularization loss, there are some extra details to talk about. In order to calculate the KL-divergence between the generated halting distribution and our prior, we will first have to manufacture the values of our prior, which is done in the initialization function.
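A sketch of how this can look, again with my own naming. Note that torch.nn.KLDivLoss expects log-probabilities as its first argument, so the halting distribution produced by the network is passed there:

```python
import torch
import torch.nn as nn


class RegularizationLoss(nn.Module):
    """KL-divergence between the halting distribution and a geometric prior."""

    def __init__(self, lambda_p: float, max_steps: int):
        super().__init__()
        # manufacture the truncated geometric prior p_G(n) = lambda_p * (1 - lambda_p)^(n - 1)
        not_halted = 1.0
        prior = torch.zeros(max_steps)
        for n in range(max_steps):
            prior[n] = not_halted * lambda_p
            not_halted = not_halted * (1.0 - lambda_p)
        self.register_buffer("prior", prior)
        self.kl = nn.KLDivLoss(reduction="batchmean")

    def forward(self, p):
        # p: (max_steps, batch_size) halting distribution produced by PonderNet
        p = p.transpose(0, 1)                                   # (batch_size, max_steps)
        prior = self.prior[: p.shape[1]].unsqueeze(0).expand_as(p)
        return self.kl(p.log(), prior)
```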

Finally, we create a class to wrap the two losses together so that they can be passed around functions compactly.
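Something as small as this does the trick (a sketch that simply bundles the two modules above together with β):

```python
import torch.nn as nn


class Loss(nn.Module):
    """Bundle of the reconstruction and regularization losses plus their trade-off beta."""

    def __init__(self, rec_loss: nn.Module, reg_loss: nn.Module, beta: float):
        super().__init__()
        self.rec_loss = rec_loss
        self.reg_loss = reg_loss
        self.beta = beta

    def forward(self, p, y_hat_steps, y_true):
        # total loss L = L_Rec + beta * L_Reg
        return self.rec_loss(p, y_hat_steps, y_true) + self.beta * self.reg_loss(p)
```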

Helper modules

To make our life easier, we create a class for a simple CNN. This will be used to embed the image into a vector representation within PonderNet.
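A possible sketch of such an encoder (the exact architecture here is a placeholder of my own; any small CNN that maps a 28×28 MNIST image to a fixed-size vector will do):

```python
import torch.nn as nn


class CNN(nn.Module):
    """Small convolutional encoder that maps an MNIST image to a vector."""

    def __init__(self, n_hidden: int):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),                  # 28x28 -> 14x14
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),                  # 14x14 -> 7x7
        )
        self.fc = nn.Linear(32 * 7 * 7, n_hidden)

    def forward(self, x):
        # x: (batch_size, 1, 28, 28) -> (batch_size, n_hidden)
        return self.fc(self.conv(x).flatten(start_dim=1))
```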

Similarly, we create a basic multi-layer perceptron to combine the image embedding with the hidden state inside of PonderNet.
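Again, a minimal sketch (the layer sizes are placeholders):

```python
import torch.nn as nn


class MLP(nn.Module):
    """Combines the image embedding with the previous hidden state."""

    def __init__(self, in_dim: int, hidden_dim: int, out_dim: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, x):
        return self.net(x)
```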

PonderNet

Finally, we get to the important stuff. In order to discuss the different parts of the LightningModule in more detail, some of the functions within it will be displayed in separate code snippets; just keep in mind they are all under the same class!

Our particular implementation of PonderNet consists of multiple submodules. First, a CNN embeds the image into a vector representation. This vector is concatenated with the hidden state corresponding to the previous step, and fed through an MLP to obtain the hidden state for the current step. This is in turn pushed through two different linear layers to obtain the logits for the predictions on one hand, and the logits for the lambdas on the other. The snippet below shows how these submodules are declared to become part of PonderNet, as well as the losses and some metrics:
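A sketch of that declaration, reusing the helper classes sketched above (the hyperparameter defaults are placeholders, not necessarily the values behind the reported results):

```python
import torch
import torch.nn as nn
import pytorch_lightning as pl
import torchmetrics


class PonderNet(pl.LightningModule):
    def __init__(self, n_classes=10, n_hidden=64, max_steps=20,
                 lambda_p=0.2, beta=0.01, lr=3e-4):
        super().__init__()
        self.save_hyperparameters()

        # submodules: image encoder, recurrent MLP and the two output heads
        self.cnn = CNN(n_hidden)                          # image -> embedding
        self.mlp = MLP(2 * n_hidden, n_hidden, n_hidden)  # [embedding, h] -> new h
        self.out_layer = nn.Linear(n_hidden, n_classes)   # h -> prediction logits
        self.lambda_layer = nn.Linear(n_hidden, 1)        # h -> halting logit

        # losses and metrics
        self.loss = Loss(
            ReconstructionLoss(nn.CrossEntropyLoss(reduction="none")),
            RegularizationLoss(lambda_p, max_steps),
            beta,
        )
        self.accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=n_classes)  # torchmetrics >= 0.11

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
```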

We have now reached what is arguably the most complex part of the code: the forward pass. Obtaining h, y and λ is straightforward, keeping in mind that to obtain λ you have to use the sigmoid function, or else you are not forcing it to be a probability.

Since we operate in batches, we need to find a way to keep track of each element individually when it comes to the ponder step where the network stopped. To this end, we maintain a vector halting_step that for each element is 0 if PonderNet has not halted on it yet, and the step where it halted otherwise. Of course, updating halting_step includes drawing from a Bernoulli distribution at each step to determine whether the network should stop.

Finally, we also maintain a vector of un_halted_prob that helps us obtain the value of p for each step in a cheap and fast way.

At the end we return, for every batch element, the predictions at all steps, the values of p at all steps, and the halting step; the sketch below puts all of these pieces together.
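Concretely, the forward pass can look roughly like this (a sketch, written as a method of the PonderNet class above):

```python
    # inside the PonderNet LightningModule
    def forward(self, x):
        batch_size = x.shape[0]
        device = x.device

        embedding = self.cnn(x)                                      # (B, n_hidden)
        h = torch.zeros(batch_size, self.hparams.n_hidden, device=device)

        un_halted_prob = torch.ones(batch_size, device=device)       # product of (1 - lambda_j)
        halting_step = torch.zeros(batch_size, dtype=torch.long, device=device)

        y_steps, p_steps = [], []

        for n in range(1, self.hparams.max_steps + 1):
            h = self.mlp(torch.cat([embedding, h], dim=1))           # new hidden state
            y_n = self.out_layer(h)                                  # prediction logits for this step

            if n == self.hparams.max_steps:
                # force halting at the last step
                lambda_n = torch.ones(batch_size, device=device)
            else:
                # sigmoid so that lambda is a valid probability
                lambda_n = torch.sigmoid(self.lambda_layer(h)).squeeze(1)

            # unconditional probability of halting exactly at step n
            p_n = un_halted_prob * lambda_n
            un_halted_prob = un_halted_prob * (1.0 - lambda_n)

            # flip the biased coin; record the step only for elements that are still running
            halt_now = torch.bernoulli(lambda_n).bool() & (halting_step == 0)
            halting_step = torch.where(halt_now,
                                       torch.full_like(halting_step, n),
                                       halting_step)

            y_steps.append(y_n)
            p_steps.append(p_n)

        # shapes: (max_steps, B, n_classes), (max_steps, B), (B,)
        return torch.stack(y_steps), torch.stack(p_steps), halting_step
```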

Since the output of the forward pass needs to be processed in order to obtain the loss (and we will want to do this on several occasions), we define a helper function that calculates the loss, the predictions and some other metrics.
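A sketch of that helper, again as a method of the same class (the filtering at the top is explained right below):

```python
    # inside the PonderNet LightningModule
    def _get_loss_and_metrics(self, batch):
        x, y_true = batch
        y_hat_steps, p, halting_step = self(x)

        # ignore elements whose p is exactly 0 at some step (see the note below)
        valid = (p > 0).all(dim=0)
        if not valid.all():
            p = p[:, valid]
            y_hat_steps = y_hat_steps[:, valid]
            y_true = y_true[valid]
            halting_step = halting_step[valid]

        rec = self.loss.rec_loss(p, y_hat_steps, y_true)
        reg = self.loss.reg_loss(p)
        total = rec + self.loss.beta * reg

        # the prediction for each element is the output of the step where it halted
        batch_idx = torch.arange(y_true.shape[0], device=y_true.device)
        preds = y_hat_steps[halting_step - 1, batch_idx].argmax(dim=1)

        acc = self.accuracy(preds, y_true)
        avg_steps = halting_step.float().mean()
        return total, rec, reg, acc, avg_steps
```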

The if condition helps circumvent a technical issue I personally came across. In early stages of training, some of the p values can be 0 for selected steps. When these are put through the regularization loss, the zeros become minus infinities (through the logarithm inside the KL-divergence) and the network eventually starts returning NaNs. Ignoring the elements that contain p values of 0 solves the issue, and these same elements are correctly classified in later epochs.

Finally, the rest of the usual functions needed in a PyTorch Lightning Module can be found below. These include the logging and some callbacks to make the training more user-friendly.
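For completeness, here is a minimal version of those methods (the logging keys are my own; when several test dataloaders are used, Lightning automatically suffixes the logged test metrics with the dataloader index):

```python
    # inside the PonderNet LightningModule
    def training_step(self, batch, batch_idx):
        loss, rec, reg, acc, steps = self._get_loss_and_metrics(batch)
        self.log_dict({"train/loss": loss, "train/rec_loss": rec, "train/reg_loss": reg,
                       "train/accuracy": acc, "train/ponder_steps": steps})
        return loss

    def validation_step(self, batch, batch_idx):
        loss, _, _, acc, steps = self._get_loss_and_metrics(batch)
        self.log_dict({"val/loss": loss, "val/accuracy": acc, "val/ponder_steps": steps})

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        _, _, _, acc, steps = self._get_loss_and_metrics(batch)
        self.log_dict({"test/accuracy": acc, "test/ponder_steps": steps})
```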

Running an interpolation experiment

We’re basically done! That’s the beauty of PyTorch Lightning. With the snippet below we will be able to run a basic experiment on the vanilla MNIST; our logger will keep track of the losses (together and individually), the accuracy and the pondering steps throughout training. If you want to use a different logger, you just need to import the supported class from PyTorch Lightning and instantiate it accordingly in line 40; it’s that easy!
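A sketch of what such a run can look like (the project name, hyperparameters and callbacks are placeholders; the exact snippet is in the repo and the Colab notebook):

```python
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

dm = MNISTDataModule(batch_size=128)
model = PonderNet(n_classes=10, n_hidden=64, max_steps=20, lambda_p=0.2, beta=0.01)

# swap WandbLogger for any other logger supported by PyTorch Lightning
logger = WandbLogger(project="pondernet-mnist")

trainer = pl.Trainer(
    logger=logger,
    max_epochs=10,
    # add accelerator/device flags as needed for your setup
    callbacks=[
        EarlyStopping(monitor="val/accuracy", mode="max", patience=3),
        ModelCheckpoint(monitor="val/accuracy", mode="max"),
    ],
)

trainer.fit(model, datamodule=dm)
trainer.test(model, datamodule=dm)
```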

Below is the accuracy and average number of ponder steps we obtain when evaluating on the MNIST test set. As we can see, PonderNet does pretty well. More interestingly though, the average number of ponder steps is very close to 5, which is the value of 1/λp, the expected number of steps; this means that our regularization is indeed working!

4. Are we really pondering?

Running the code above should convince you that PonderNet can at least attain a respectable accuracy on MNIST… But are we using any of its properties? Couldn’t we achieve a similar result with a simple CNN? How can we know that it does indeed ponder longer for harder inputs? We will try to answer some of these questions by means of an extrapolation experiment.

Reading the original paper, one can get a bit frustrated with how the experiments are carried out either on toy datasets or on tasks so complex that it is impossible to interpret anything. The MNIST dataset provides a middle ground: the task has definitely not been engineered to produce the desired results with PonderNet, and its results can be interpreted to some extent.

Our premise for the experiment will be that rotated images are harder to classify. Following a similar framework to the one presented in the PonderNet paper, we will train on "slightly hard" inputs and evaluate on a range of "harder" inputs. In particular, we will train on images that have been rotated up to 22.5 degrees, and evaluate on rotations up to 22.5, 45, 67.5 and 90 degrees. We expect the accuracy to fall gradually (some images are probably even impossible to classify once rotated to that extent), but hopefully we see an increase in the number of steps, signalling that the network finds the task harder and decides to allocate more resources to it. Below is the code for such an experiment:
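A sketch of the experiment, reusing the Data Module sketched earlier; the only thing that changes with respect to the interpolation run is the set of transforms:

```python
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from torchvision import transforms

# train on rotations of up to 22.5 degrees...
train_transform = transforms.Compose([
    transforms.RandomRotation(22.5),
    transforms.ToTensor(),
])
# ...and evaluate on rotations of up to 22.5, 45, 67.5 and 90 degrees
test_transforms = [
    transforms.Compose([transforms.RandomRotation(deg), transforms.ToTensor()])
    for deg in (22.5, 45, 67.5, 90)
]

dm = MNISTDataModule(batch_size=128,
                     train_transform=train_transform,
                     test_transforms=test_transforms)
model = PonderNet(n_classes=10, n_hidden=64, max_steps=20, lambda_p=0.2, beta=0.01)

trainer = pl.Trainer(logger=WandbLogger(project="pondernet-extrapolation"),
                     max_epochs=10)
trainer.fit(model, datamodule=dm)

# one set of metrics per test dataloader
trainer.test(model, datamodule=dm)
```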

The results we obtain after running this experiment are somewhat mixed. On one hand, testing on any of the rotated datasets does indeed require an average number of steps that is significantly higher than the one needed for the interpolation experiment, meaning that PonderNet recognizes classifying rotated images as a harder task.

On the other hand, the network seems to need fewer steps for more pronounced rotations, which is counter-intuitive. We would expect PonderNet to be unsure about highly tilted images, thus allocating more resources to their classification and pondering for longer, but it seems like it instead becomes confidently incorrect. The accuracy also decreases, although in this case that is expected.

Note: here the dataloaders correspond to 22.5, 45, 67.5 and 90 degree rotations respectively.

5. Conclusions

PonderNet is a nice addition to the Deep Learning landscape. Its mathematically-backed design choices are reason enough to be excited about its possibilities, and the results presented in the original paper are encouraging. Sadly, though, there hasn’t been much effort in honestly interpreting what this network is really capable of.

We tried to shed some light on this issue by implementing PonderNet for MNIST, a common benchmark in the field, and performing experiments by training the network on tasks of varying difficulty.

Our results are inconclusive since in some ways they agree with what we expected, and in others they don’t. There may be multiple reasons for this. On one hand it is perfectly possible that the degree of rotation is not a good heuristic for complexity in this particular task, and my human bias is interfering with the experiment design. On the other hand, we cannot discard the possibility that PonderNet is just not the all-powerful framework we would wish it to be; after all, there is still a lot of research to be done in this area.

All in all, I hope you learned something today; maybe I inspired you to try PonderNet for your own project! I would love to hear any comments, questions, or suggestions you may have. Thank you for reading!

6. Acknowledgements

This blog post would not have been possible without three main sources.

  • The first one is this YouTube video by Yannic Kilcher explaining the basics of PonderNet; if you’re not subscribed to his channel I don’t know what you’re doing!
  • An important part of the PonderNet implementation is borrowed from this GitHub repo by MildlyOverfitted. Although I had a prototype running before I came across his code, I couldn’t resist using his cool tricks!
  • Finally, some boilerplate code is taken from this tutorial on PyTorch Lightning with Weights & Biases.

As a last note, I would like to thank my supervisor Fabian Laumer for encouraging me to try different things with this network as part of my seminar presentation. I may have lost dozens of hours over this little project, but I certainly learned a lot!

References

[1] A. Banino, J. Balaguer, C. Blundell, PonderNet: Learning to Ponder (2021), arXiv: 2107.05407.

[2] A. Graves, Adaptive Computation Time for Recurrent Neural Networks (2017), arXiv: 1603.08983.

