Bayesian Inference and Transformers

Can Transformers help with Bayesian inference? A simple review of “Transformers Can Do Bayesian Inference”

Kaan Bıçakcı
Towards Data Science


Photo by Arseny Togulev on Unsplash

This is the last post of the Uncertainty in Deep Learning series.

This last part is a little different from the other parts of the series. I will describe the method proposed in the paper Transformers Can Do Bayesian Inference [1] and the intuition behind it, in order to compare it with Variational Inference.

Note: this is a brief, informal explanation of the paper, not an official summary.

This article is organized as follows:

  • Recap
  • Meta-Learning and Bayesian Inference
  • Posterior and Posterior Predictive Distribution
  • Proposed Method: Prior-Data Fitted Networks
  • Using a Transformer and Putting the Pieces Together
  • Conclusion

Recap

Traditional Deep Learning Models

As explained in Part 2 and Part 3, standard models are deterministic: the same input always produces the same output. This is because Maximum Likelihood Estimation (MLE) yields point estimates of the weights.

Let’s illustrate this:

Image by author.

On the other hand, a probabilistic Bayesian deep learning model can output different predictions for the same input. By looking at the spread of these outputs, we can interpret the uncertainty in the predictions themselves.
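
A minimal NumPy sketch of this difference, assuming a one-layer linear model: the deterministic version uses a single point-estimate weight vector, while the Bayesian-style version draws its weights from an assumed Gaussian distribution on every call. The weight values and the standard deviation are made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(42)
x = np.array([1.0, 2.0, 3.0])            # a single hypothetical input

# Deterministic model: one point-estimate weight vector (e.g. obtained with MLE).
w_point = np.array([0.5, -0.2, 0.1])

def deterministic_predict(x):
    return float(w_point @ x)

# Bayesian-style model: weights follow an assumed Gaussian centered on w_point,
# and a fresh weight vector is sampled for every prediction.
w_std = 0.05

def bayesian_predict(x):
    w = rng.normal(loc=w_point, scale=w_std)
    return float(w @ x)

deterministic_outputs = [deterministic_predict(x) for _ in range(5)]
bayesian_outputs = np.array([bayesian_predict(x) for _ in range(1000)])

print(deterministic_outputs)                               # five identical values
print(bayesian_outputs.mean(), bayesian_outputs.std())     # the spread reflects uncertainty
```

The deterministic model repeats the same number every time, while the sampled model produces a cloud of predictions whose standard deviation can be read as a measure of uncertainty.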

Bayesian Approach in Deep Learning

The Bayesian approach to statistics uses a set of data to infer the underlying probability distribution that is consistent with that data.

This approach offers several advantages over more traditional methods, such as the ability to incorporate prior information about the distribution into the analysis and to compute the degree of belief in the various possible distributions.

In Bayesian Deep Learning, we want to calculate the posterior distribution over the model parameters:

Image by author.

However, the true posterior is difficult, and in most cases impossible, to compute exactly, so it is approximated. A common way to approximate it is Variational Inference (also called Variational Bayes). However, these methods are computationally expensive.
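
To see why exact computation breaks down, here is a sketch for a model with a single parameter, where the posterior can still be computed by brute force on a grid. The setup is an assumption for illustration: θ is a coin's probability of heads, the prior is Beta(2, 2), and the observed data is 7 heads in 10 flips.

```python
import numpy as np

# Grid (midpoints) over the single unknown parameter θ in (0, 1).
theta = (np.arange(1000) + 0.5) / 1000
prior = theta * (1 - theta)                     # Beta(2, 2) prior, up to a constant
heads, flips = 7, 10                            # hypothetical observed data
likelihood = theta**heads * (1 - theta) ** (flips - heads)

unnormalized = likelihood * prior
evidence = unnormalized.sum()                   # grid version of P(x), up to constant factors
posterior = unnormalized / evidence             # posterior probabilities on the grid, summing to 1

posterior_mean = float(np.sum(theta * posterior))
exact_mean = (2 + heads) / (2 + 2 + flips)      # conjugate result: Beta(9, 5) has mean 9/14

print(posterior_mean, exact_mean)
```

With 1,000 grid points per parameter, a model with d parameters needs 1000^d evaluations; even 10 weights would already require 10^30 points, which is why neural network posteriors must be approximated.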

Meta-Learning and Bayesian Inference

Before we start, it helps to briefly define what meta-learning is. Computers and humans learn differently. As a concrete example:

Photo by Fede de Rodt on Unsplash

If we were to label this image, the label would be “a cup of coffee”. The fun part is that a human can easily learn what a cup of coffee is from just a few samples. This does not hold for computers.

Consider these images:

Images generated by DALLE-2 (https://openai.com/dall-e-2/)

A human who has learned from a few images can easily classify these correctly, whereas a computer trained on only a few images may struggle to classify them.

Images generated by DALLE-2 (https://openai.com/dall-e-2/)

Or take these images: each of them represents a cup of coffee, and a human should be able to classify them correctly, whereas a computer might misclassify them.

It would be very useful if computers could also generalize from only a few samples. That is the main idea of meta-learning, which is also called learning-to-learn.

For example, we could train a model on images that do not contain any dogs; after seeing a few cute dog photos, it should then be able to tell us whether a given image contains a dog.

Images generated by DALLE-2 (https://openai.com/dall-e-2/)

There is a variety of approaches in meta-learning. The paper discussed here uses a research-oriented meta-learning technique that tries to let models learn to learn in a single forward pass.

Posterior and Posterior Predictive Distribution

Now, let’s distinguish between these two terms:

  • Posterior
  • Posterior Predictive Distribution
Image by author. (Posterior)

To find a distribution over the model parameters, we rewrite the posterior using Bayes’ theorem. The evidence, or marginal likelihood, P(x) acts as a normalizing constant.

Here the posterior is a distribution over θ, the unknown parameter.
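
Written out in standard notation (assumed here, with x the observed data and θ the parameters), Bayes’ theorem and the evidence term read:

```latex
P(\theta \mid x) = \frac{P(x \mid \theta)\, P(\theta)}{P(x)},
\qquad
P(x) = \int P(x \mid \theta)\, P(\theta)\, d\theta
```

Because P(x) does not depend on θ, it only rescales the numerator so that the posterior integrates to one.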

Predictive distributions are different. First let’s see prior predictive distribution:

Image by author. (Prior Predictive Distribution)

The integral on the right-hand side averages the likelihood over all possible values of θ, weighted by the prior. That’s nice, but why?

Before training starts, we only have the prior distribution, which means the only way to measure uncertainty in the model parameters is to use the prior!

Before the model sees any data, we therefore average over all possible values of θ, because we are interested in understanding the distribution of the data itself, which is given by:

Image by author.

After some samples are observed and training is performed, the posterior gives a better representation of the uncertainty in the model parameters. Following the same logic as above, the posterior predictive distribution becomes:

Image by author.
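
Both predictive distributions can be estimated by Monte Carlo: draw θ from the prior (or the posterior) and average the likelihood of the new point. The sketch below assumes a coin-flip model with a Beta(2, 2) prior and 7 heads in 10 observed flips, chosen because the exact answers are known in closed form for comparison.

```python
import numpy as np

rng = np.random.default_rng(0)
a, b = 2.0, 2.0            # assumed Beta prior hyperparameters
heads, flips = 7, 10       # hypothetical observed data

# Prior predictive P(x = 1): average the likelihood P(x = 1 | θ) = θ over prior samples.
prior_samples = rng.beta(a, b, size=200_000)
prior_predictive = prior_samples.mean()

# Posterior predictive P(x_new = 1 | data): the same average, but over posterior samples.
# With a Beta prior and a Bernoulli likelihood, the posterior is Beta(a + heads, b + tails).
posterior_samples = rng.beta(a + heads, b + flips - heads, size=200_000)
posterior_predictive = posterior_samples.mean()

print(prior_predictive, a / (a + b))                           # Monte Carlo vs exact 0.5
print(posterior_predictive, (a + heads) / (a + b + flips))     # Monte Carlo vs exact 9/14
```

The only change between the two estimates is which distribution over θ we average against, which mirrors the step from the prior predictive formula to the posterior predictive formula.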

Here Xn denotes the data points, which are assumed to be independent given the parameters. A few conclusions follow:

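  • If we choose a good prior, the observed training data should closely follow the prior predictive distribution.
  • The posterior predictive combines the noise of the new data point with the remaining uncertainty about the parameters, so its variance is usually greater than that of a prediction made with a single point estimate of the parameters.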
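
To make the variance point concrete, the sketch below assumes a Gaussian observation model with known noise σ and a Gaussian prior over its mean μ (a conjugate setup, so the posterior has a closed form). It compares three variances: a plug-in prediction with μ fixed, the posterior predictive, and the prior predictive.

```python
import numpy as np

rng = np.random.default_rng(1)
sigma = 1.0                             # known observation noise (assumed)
mu0, tau0 = 0.0, 2.0                    # assumed Gaussian prior N(mu0, tau0^2) over the mean μ
data = rng.normal(1.5, sigma, size=5)   # small hypothetical dataset

# Conjugate Normal-Normal update: the posterior over μ is N(mu_n, tau_n2).
n = data.size
tau_n2 = 1.0 / (1.0 / tau0**2 + n / sigma**2)
mu_n = tau_n2 * (mu0 / tau0**2 + data.sum() / sigma**2)

# Posterior predictive by Monte Carlo: draw μ from the posterior, then x_new ~ N(μ, σ²).
mu_samples = rng.normal(mu_n, np.sqrt(tau_n2), size=200_000)
x_new = rng.normal(mu_samples, sigma)

plug_in_var = sigma**2                       # variance if μ were fixed at a point estimate
posterior_predictive_var = x_new.var()       # close to sigma**2 + tau_n2
prior_predictive_var = sigma**2 + tau0**2    # variance before seeing any data

print(plug_in_var, posterior_predictive_var, prior_predictive_var)
```

In this model the posterior predictive variance is σ² + τ_n²: larger than the plug-in variance σ² because of the leftover parameter uncertainty τ_n², but smaller than the prior predictive variance because the data has shrunk that uncertainty.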

Summarizing

To sum up, the posterior is a distribution over the model parameters, whereas the posterior predictive distribution describes the distribution of a data point we will observe in the future (Xn).

So the main idea of the posterior predictive is this: we are uncertain about the model parameters (epistemic uncertainty), so to predict a new data point we average the likelihood over all possible values of the model parameters, weighted by their posterior probability.

The crucial distribution for prediction, the posterior predictive distribution (PPD), can then be inferred as [1]:

Posterior predictive formula from the paper [1].
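
In text form, the standard posterior predictive for a supervised dataset D, a query x with label y, and a latent variable φ describing the data-generating process is shown below. The symbols are assumed here, and the paper’s exact notation may differ slightly.

```latex
p(y \mid x, D) = \int_{\Phi} p(y \mid x, \phi)\, p(\phi \mid D)\, d\phi
\;\propto\; \int_{\Phi} p(y \mid x, \phi)\, p(D \mid \phi)\, p(\phi)\, d\phi
```

The proportionality follows from Bayes’ theorem, since p(φ | D) = p(D | φ) p(φ) / p(D) and p(D) does not depend on y.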

Proposed Method: Prior-Data Fitted Networks

In Bayesian Deep Learning, the true posterior is usually intractable, so it is approximated. Classic approaches therefore work with the right-hand side of the posterior predictive formula: they approximate the posterior and then integrate over it.

Interestingly, this paper proposes a new method, called Prior-Data Fitted Networks (PFNs), which instead trains a model to directly approximate the term on the left, the posterior predictive distribution itself. This means that, instead of only a dataset, the model also takes a query as input.
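
One way to write this training objective, following the paper’s formulation as I understand it (notation assumed): q_θ is the network’s predicted distribution, and the expectation runs over datasets sampled from the prior, with one example (x, y) held out as the query.

```latex
\ell_{\theta} = \mathbb{E}_{D \cup \{(x, y)\} \sim p(\mathcal{D})}
\left[ -\log q_{\theta}(y \mid x, D) \right]
```

Because the negative log-likelihood is a proper scoring rule, minimizing this loss pushes q_θ(y | x, D) toward the true posterior predictive p(y | x, D) under the chosen prior.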

The authors illustrate the process as follows:

“A visualization of Prior-Data Fitted Networks (PFNs). We sample datasets from a prior and fit a PFN on hold-out examples of these datasets. Given an actual dataset, we feed it and a test point to the PFN and obtain an approximation to Bayesian inference in a single forward propagation.” [1]

Let’s break this down into smaller pieces:

  • First, we pick a prior.
  • Sample datasets from the prior.
  • Start the meta-learning process by maximizing the likelihood of held-out examples. This is equivalent to minimizing the negative likelihood. The logarithm is applied because it turns products into sums, which is easier to optimize and numerically stable, so in practice the negative log-likelihood is minimized.
  • The sampled dataset is fed into a parameterized model.

In other words, they generate a large amount of synthetic data from the prior (ideally infinitely many datasets) and train the parameterized model to predict the held-out labels of the generated datasets.
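
The point about the logarithm in the third step can be checked numerically. The sketch below uses made-up probabilities that a model might assign to the correct labels of 2,000 held-out examples: the raw likelihood underflows to zero in double precision, while the log-likelihood stays finite.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical probabilities assigned to the correct label of 2,000 held-out examples.
probs = rng.uniform(0.01, 0.5, size=2000)

likelihood = np.prod(probs)              # a product of many numbers < 1 underflows to 0.0
log_likelihood = np.sum(np.log(probs))   # the log turns the product into a sum, which stays finite
nll = -log_likelihood                    # maximizing the likelihood == minimizing the NLL

print(likelihood, nll)
```

Since every probability here is at most 0.5, the product is below 0.5^2000 (about 10^-602), far smaller than the tiniest representable double, whereas the sum of logs is an ordinary number that gradients can work with.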

Algorithm from the paper [1].
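
To make the algorithm concrete, here is a deliberately tiny NumPy sketch of prior-fitting. Everything in it is an illustrative assumption rather than the paper’s setup: the prior is a Beta(2, 5) distribution over a coin’s bias, each sampled dataset is up to 10 flips plus one held-out flip, and the “network” is a hypothetical two-parameter stand-in for the Transformer that predicts (k + u) / (n + u + v) from the number of heads k among n context flips. Training minimizes the held-out negative log-likelihood with full-batch gradient descent.

```python
import numpy as np

rng = np.random.default_rng(0)
A, B = 2.0, 5.0        # assumed Beta(A, B) prior over a coin's bias
MAX_N = 10             # context sizes are drawn from 0..MAX_N
NUM_TASKS = 20_000     # number of synthetic datasets sampled from the prior


def sample_prior_datasets(num_tasks):
    """Sample datasets from the prior: a bias, n context flips, and one held-out flip."""
    theta = rng.beta(A, B, size=num_tasks)
    n = rng.integers(0, MAX_N + 1, size=num_tasks)
    k = rng.binomial(n, theta)           # number of heads among the n context flips
    y = rng.binomial(1, theta)           # held-out query label the model must predict
    return k.astype(float), n.astype(float), y.astype(float)


def predict(params, k, n):
    """Toy stand-in for the PFN: maps a dataset (k heads in n flips) to P(next flip = 1)."""
    u, v = np.exp(params)                # log-parameterized so pseudo-counts stay positive
    return (k + u) / (n + u + v)


def nll_and_grad(params, k, n, y):
    """Prior-data negative log-likelihood of the held-out flips and its gradient."""
    u, v = np.exp(params)
    total = n + u + v
    p = (k + u) / total
    loss = -np.mean(y * np.log(p) + (1 - y) * np.log(1 - p))
    dloss_dp = -y / p + (1 - y) / (1 - p)
    # dp/du = (1 - p) / total and dp/dv = -p / total; the factor u (or v) comes from u = exp(s_u).
    grad_u = np.mean(dloss_dp * (1 - p) / total) * u
    grad_v = np.mean(dloss_dp * (-p) / total) * v
    return loss, np.array([grad_u, grad_v])


k, n, y = sample_prior_datasets(NUM_TASKS)
params = np.zeros(2)                     # start at u = v = 1
initial_params = params.copy()
initial_loss, _ = nll_and_grad(params, k, n, y)

for _ in range(2000):                    # full-batch gradient descent on the prior-data NLL
    _, grad = nll_and_grad(params, k, n, y)
    params -= 0.5 * grad

final_loss, _ = nll_and_grad(params, k, n, y)

# Compare with the exact Bayesian posterior predictive (k + A) / (n + A + B) on all small datasets.
grid = np.array([(kk, nn) for nn in range(MAX_N + 1) for kk in range(nn + 1)], dtype=float)
gk, gn = grid[:, 0], grid[:, 1]
exact = (gk + A) / (gn + A + B)
mae_before = np.mean(np.abs(predict(initial_params, gk, gn) - exact))
mae_after = np.mean(np.abs(predict(params, gk, gn) - exact))

print("learned pseudo-counts (u, v):", np.exp(params))
print(f"NLL {initial_loss:.4f} -> {final_loss:.4f}; MAE vs exact PPD {mae_before:.3f} -> {mae_after:.3f}")
```

Because the exact posterior predictive under this prior is (k + 2) / (n + 7), the learned pseudo-counts should end up roughly near 2 and 5: fitting the prior-data likelihood recovers Bayesian predictions without ever computing a posterior. A real PFN replaces the two pseudo-counts with a Transformer that reads the raw dataset, and would typically sample fresh datasets at every step instead of reusing one fixed batch.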

The model can be a convolutional network or any other type of neural network architecture.

The authors chose a Transformer as the PFN, because a Transformer is all we need, after all.

Using a Transformer and Putting the Pieces Together

The original Transformer architecture [2] consists of an encoder and a decoder. The authors slightly modified this architecture to make it permutation invariant, which means that positional encodings are no longer used.

“A visualization of the Transformer for n = 3 input pairs and m = 2 queries. Every bar is the representation of one input and the arrows show what each representation can attend to.” [1]
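
The attention pattern and the missing positional encodings can be sketched with a single attention layer in NumPy. The mask built by the hypothetical helper `pfn_attention_mask` is an assumption based on the figure: every position may attend to all n training pairs, and each query may additionally attend to itself but not to other queries. Random weight matrices stand in for learned ones.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 8                                    # embedding size (hypothetical)


def pfn_attention_mask(n, m):
    """Boolean mask (True = may attend) for n training pairs followed by m queries."""
    size = n + m
    mask = np.zeros((size, size), dtype=bool)
    mask[:, :n] = True                   # every position sees all training pairs
    mask[np.arange(size), np.arange(size)] = True   # each position also sees itself
    return mask


def masked_attention(h, mask, Wq, Wk, Wv):
    """Single-head scaled dot-product attention with no positional encodings."""
    q, k, v = h @ Wq, h @ Wk, h @ Wv
    scores = q @ k.T / np.sqrt(h.shape[1])
    scores = np.where(mask, scores, -np.inf)
    weights = np.exp(scores - scores.max(axis=1, keepdims=True))
    weights /= weights.sum(axis=1, keepdims=True)
    return weights @ v


Wq, Wk, Wv = (rng.normal(size=(d, d)) for _ in range(3))
n, m = 3, 2                              # as in the figure: 3 input pairs, 2 queries
context = rng.normal(size=(n, d))        # stand-ins for embedded (x, y) pairs
queries = rng.normal(size=(m, d))        # stand-ins for embedded query inputs
mask = pfn_attention_mask(n, m)
out = masked_attention(np.vstack([context, queries]), mask, Wq, Wk, Wv)

# 1) Shuffling the training pairs leaves the query outputs unchanged.
perm = rng.permutation(n)
out_perm = masked_attention(np.vstack([context[perm], queries]), mask, Wq, Wk, Wv)

# 2) Changing the second query does not affect the first query's output.
queries_changed = queries.copy()
queries_changed[1] = rng.normal(size=d)
out_changed = masked_attention(np.vstack([context, queries_changed]), mask, Wq, Wk, Wv)
```

Because nothing in the computation depends on position, the training pairs behave like a set, and because queries never attend to each other, each query’s output depends only on the dataset and that query. Stacking more such layers keeps both properties.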

The proposed model returns an approximation of the posterior predictive distribution for each given query x, and this output depends only on the dataset and the query itself.

During training, the total number of inputs is kept fixed at N. These inputs are split into two parts:

  • N = m + n

where n is the number of training inputs (the dataset) and m is the number of queries. In other words, the Transformer should learn to work with datasets of different sizes.
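
A sketch of how one sampled sequence of N (x, y) pairs could be split so that n varies while N stays fixed. The value of N and the uniform choice of n are assumptions for illustration, and `split_into_context_and_queries` is a hypothetical helper.

```python
import numpy as np

N = 50  # fixed total number of inputs per sequence (hypothetical value)


def split_into_context_and_queries(xs, ys, rng):
    """Split N (x, y) pairs into n training pairs and m = N - n queries.

    n is drawn uniformly from 1..N-1 (an assumed choice), so the dataset size
    seen by the model changes from batch to batch while N stays fixed."""
    assert len(xs) == len(ys) == N
    n = int(rng.integers(1, N))          # upper bound is exclusive, so 1 <= n <= N - 1
    m = N - n
    context = (xs[:n], ys[:n])           # the model sees these labels
    queries = xs[n:]                     # the model must predict labels for these inputs
    targets = ys[n:]                     # used only in the loss
    return context, queries, targets, n, m


rng = np.random.default_rng(0)
xs, ys = rng.normal(size=N), rng.normal(size=N)
results = [split_into_context_and_queries(xs, ys, rng) for _ in range(100)]
print(sorted({r[3] for r in results})[:10])   # a few of the dataset sizes that were drawn
```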

Since PFNs can be used for a variety of tasks, the most popular being regression and classification, the output layer is flexible and is changed according to the task.
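
A sketch of two task-specific output heads applied to one query’s Transformer output. The classification head is a linear layer followed by a softmax. For regression, the design shown is an assumption for illustration: the target range [-3, 3] is split into equal-width bins, and the model predicts a probability per bin, which gives a piecewise-constant density over y. Sizes and weights are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 16  # hidden size at a query position (hypothetical)


def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


# Classification head: linear layer -> softmax over classes.
num_classes = 3
W_cls = rng.normal(size=(d, num_classes))


def classification_head(h):
    return softmax(h @ W_cls)                # P(y = c | x, D) for each class c


# Regression head (assumed design): one probability per bin of an assumed target range.
bin_edges = np.linspace(-3.0, 3.0, 51)       # 50 equal-width bins
W_reg = rng.normal(size=(d, len(bin_edges) - 1))


def regression_density(h, y):
    """Piecewise-constant density p(y | x, D); targets outside the range are clamped to the edge bins."""
    probs = softmax(h @ W_reg)
    idx = np.clip(np.searchsorted(bin_edges, y, side="right") - 1, 0, len(probs) - 1)
    width = bin_edges[idx + 1] - bin_edges[idx]
    return probs[idx] / width


h = rng.normal(size=d)                       # stand-in for one query's Transformer output
print(classification_head(h))
print(regression_density(h, 0.7))            # -log of this value would be the regression NLL
```

Either head produces a proper distribution over y, so the same negative log-likelihood objective can be used for both tasks.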

Conclusion

The paper proposes a method that trains a special Transformer using meta-learning techniques, so that it approximates the posterior predictive distribution in a single forward pass. This makes posterior inference much faster than Variational Inference (1000x faster).

The benchmark results support this approach:

“Aggregated performance on subsets of datasets with 30 training samples. Our novel PFN-BNN performs strongest overall. See Table 7 in the appendix for per-dataset results.” [1]

You can find the official implementation of the paper here.

References

[1]: Samuel Müller, Noah Hollmann, Sebastian Pineda, Josif Grabocka, Frank Hutter. Transformers Can Do Bayesian Inference. ICLR 2022.

[2]: Ashish Vaswani et al. Attention Is All You Need. Advances in Neural Information Processing Systems, 2017.
