Bayesian Inference Algorithms: MCMC and VI

Intuition and diagnostics

Wicaksono Wijono
Towards Data Science


Unlike other areas of machine learning (ML), Bayesian ML requires us to know when an output is not trustworthy. When you train a regression or xgboost model, the model can be taken at face value given the settings and data. With Bayesian ML, the output is not guaranteed to be correct.

Bayesian workflow can be split into three major components: modeling, inference, and criticism. Even when we have written a sensible probabilistic model, the results can be misleading due to the inference algorithm, whether because the algorithm has failed or because we have chosen an inappropriate algorithm. This article will explain how each algorithm works, discuss the pros and cons of each, and show how to diagnose their performance.

The major goal of Bayesian computation is to find a workaround to the denominator in Bayes’ Theorem:

p(θ|X) = p(X|θ) p(θ) / p(X),   where   p(X) = ∫ p(X|θ) p(θ) dθ

Except for the simplest of models, that integral is impossible to compute. Inference algorithms are ways to get p(θ|X) without ever having to evaluate that integral.

MCMC

Markov Chain Monte Carlo, as the name implies, runs a Monte Carlo simulation using a Markov Chain that must satisfy some conditions so we always end up at our desired stationary distribution (the posterior) regardless of starting point.

Imagine the posterior distribution is some kind of hilly terrain. You want to explore the terrain and spend time at any spot proportional to the height of the mound. The caveat is that it’s so foggy you can’t see anything. Even if you stand on top of a hill, you don’t know whether it is a tall hill or a short hill compared to the others. You can know that you are standing at an elevation of 2km, but are the other hills a mere 1km high, or are there 5km hills? Equipped with only a device to measure altitude, you need to come up with a rule to achieve your goal.

The Markov Chain must have:

  • ergodicity. You can get from any state to any other state (and back again) in finite time, and there is no deterministic cycle.
  • detailed balance. p(x|y)q(y) = p(y|x)q(x), where p(·|·) is the transition probability and q(·) is the stationary (target) density.

Some rules are more efficient than others. In practice, most MCMC is NUTS (No-U-Turn Sampler) with some Gibbs thrown in when needed (mostly when parameters are discrete).

Gibbs

Suppose we want to find the posterior distribution of two parameters, x and y. We can jump from spot to spot without restriction, or we can decompose each jump into a horizontal movement + a vertical movement. This is what Gibbs does. We toss our hands up and say: we don’t know how to come up with a rule for the former, but we can create rules for the latter.


Gibbs requires conditional conjugacy. We go through all the parameters one by one; the order we cycle through them doesn’t matter, and the loose intuition (not actually a proof) is that horizontal + vertical covers the same ground as vertical + horizontal. We sample each parameter by analytically solving for its conditional posterior, holding all other parameters fixed.

As a simple example, suppose we want to estimate mean μ and precision λ = 1/σ² of a normal distribution. We place the priors μ~N(0,1) and λ~Gamma(1,1). Then the conditional posteriors are

μ | λ, X ~ N( (prior precision × prior mean + λ Σx_i) / (prior precision + nλ), 1 / (prior precision + nλ) )
λ | μ, X ~ Gamma( prior shape + n/2, prior rate + Σ(x_i − μ)² / 2 )

where the second argument of N is the variance and n is the number of observations.

We want to alternate between sampling from these two conditional posteriors. Here’s R code to show how Gibbs sampling works for this model:

library(ggplot2)

num_sample <- 5000      # number of samples for Gibbs
burn_in <- 1000         # first n samples we discard
prior_mean <- 0         # the prior on mu
prior_precision <- 1    # the prior on mu
prior_shape <- 1        # alpha in prior for precision
prior_rate <- 1         # beta in prior for precision
num_obs <- 30           # size of our data
true_mean <- 3
true_precision <- 0.25

set.seed(9)
X <- rnorm(num_obs, true_mean, 1 / sqrt(true_precision))

mu <- rep(NA, num_sample)
lambda <- rep(NA, num_sample)

# initialize some values
mu[1] <- 0
lambda[1] <- 1

for(i in 2:num_sample){
  if(i %% 2){
    # update mu from its conditional posterior, holding lambda fixed
    mu[i] <- rnorm(
      1,
      mean = (prior_precision * prior_mean + lambda[i-1] * sum(X)) /
        (prior_precision + num_obs * lambda[i-1]),
      sd = sqrt(1 / (prior_precision + num_obs * lambda[i-1]))
    )
    lambda[i] <- lambda[i-1]
  } else {
    # update lambda from its conditional posterior, holding mu fixed
    mu[i] <- mu[i-1]
    lambda[i] <- rgamma(
      1,
      shape = prior_shape + num_obs / 2,
      rate = prior_rate + sum((X - mu[i])^2) / 2
    )
  }
}

posterior <- data.frame(mu, lambda)[(burn_in+1):num_sample,]

ggplot(posterior) +
  geom_point(aes(x = mu, y = lambda)) +
  geom_path(aes(x = mu, y = lambda), alpha = 0.3) +
  ggtitle('Gibbs sampling') +
  xlab(expression(mu)) +
  ylab(expression(lambda))

ggplot(posterior) +
  geom_histogram(
    aes(x = mu, y = stat(count) / sum(count)),
    alpha = 0.5) +
  geom_vline(
    aes(xintercept = quantile(posterior$mu, 0.025)),
    color = 'red') +
  geom_vline(
    aes(xintercept = quantile(posterior$mu, 0.975)),
    color = 'red') +
  ylab('Relative frequency') +
  xlab(expression(mu)) +
  ggtitle(bquote('95% credible interval of ' ~ mu))

ggplot(posterior) +
  geom_histogram(
    aes(x = lambda, y = stat(count) / sum(count)),
    alpha = 0.5) +
  geom_vline(
    aes(xintercept = quantile(posterior$lambda, 0.025)),
    color = 'red') +
  geom_vline(
    aes(xintercept = quantile(posterior$lambda, 0.975)),
    color = 'red') +
  ylab('Relative frequency') +
  xlab(expression(lambda)) +
  ggtitle(bquote('95% credible interval of ' ~ lambda))

Note the burn-in setting. MCMC hopefully will converge to the target distribution but it might take a while to get there. As a rule of thumb, we discard the first 1000 because the chain might not have reached its destination yet.

Try changing the values to get intuition for how the posterior behaves. If we trace the path of the movements (the first plot produced by the code), we see the horizontal-vertical staircase pattern of the Gibbs moves.

We can compute the credible intervals using the marginal distributions; in the last two plots, the red lines mark the 2.5% and 97.5% quantiles.

Gibbs can be preferable to other methods when your model is conditionally conjugate. For instance, trying to run NUTS on LDA does not work because there is no gradient with respect to discrete latent variables. However, running a Gibbs sampler for LDA is (comparatively) quick and easy.

Metropolis

The Metropolis algorithm looks at Bayes’ Theorem and asks “can we make the denominator cancel out?” The algorithm:

  1. Start at some random initial point θ.
  2. Draw a proposed value θ* from some distribution p(θ*|θ).
  3. If p(X|θ*)p(θ*) > p(X|θ)p(θ), accept proposed value. Otherwise, accept the proposed value with probability [p(X|θ*)p(θ*)] / [p(X|θ)p(θ)] .
  4. If accepted, move to the new spot. Otherwise, stay put. Either way, record your position.
  5. Repeat (2), (3), and (4) for a set number of iterations.

Look at the acceptance probability:

[ p(X|θ*) p(θ*) ] / [ p(X|θ) p(θ) ] = [ p(X|θ*) p(θ*) / p(X) ] / [ p(X|θ) p(θ) / p(X) ] = p(θ*|X) / p(θ|X)

Remember p(X) is an unknown constant. Because our goal is to spend time at a spot proportional to the posterior density, we can do so without ever computing p(X).

The easiest way to satisfy ergodicity and detailed balance is to sample θ* from N(θ, s²), though any symmetric continuous distribution will do.

Metropolis-Hastings (MH) generalizes this algorithm to non-symmetric proposal distributions while maintaining detailed balance: the acceptance probability becomes min(1, [ p(X|θ*) p(θ*) g(θ|θ*) ] / [ p(X|θ) p(θ) g(θ*|θ) ]), where g is the proposal distribution. Look closely at this acceptance rule and compare it to what detailed balance means.

Here’s some R code to see how the Metropolis algorithm works, using a N(0,1) prior on θ and given that we know X ~ N(θ,1) :

library(ggplot2)

num_iter <- 2000
s <- c(0.1, 1, 10)

set.seed(1)
x <- rnorm(10, 10, 1)

thetas <- list()
for(i in 1:length(s)){
  theta <- rep(NA, num_iter)
  current_theta <- rnorm(1, 0, 10)
  for(j in 1:num_iter){
    proposed_theta <- rnorm(1, current_theta, s[i])
    accept_prob <- exp(
      dnorm(proposed_theta, 0, 1, log = TRUE) +
        sum(dnorm(x, proposed_theta, 1, log = TRUE)) -
        dnorm(current_theta, 0, 1, log = TRUE) -
        sum(dnorm(x, current_theta, 1, log = TRUE))
    )
    if(runif(1) < accept_prob){
      current_theta <- proposed_theta
      theta[j] <- proposed_theta
    } else {
      theta[j] <- current_theta
    }
  }
  thetas[[i]] <- cbind(1:num_iter, theta, rep(s[i], num_iter))
}

thetas <- data.frame(do.call('rbind', thetas))
colnames(thetas) <- c('iter', 'theta', 's')
thetas$s <- factor(thetas$s)

ggplot(thetas) +
  geom_line(aes(x = iter, y = theta)) +
  facet_grid(s ~ .)

ggplot(thetas[thetas$iter > 1000,]) +
  geom_line(aes(x = iter, y = theta)) +
  facet_grid(s ~ .)

One problem with the Metropolis algorithm is that it is sensitive to our choice of proposal distribution. Using different standard deviations yields different results:

As you can see, it takes some time for the chain to reach its target distribution. Analytically, we know that the posterior for this simulated data should be roughly N(9.2, 0.3²). We should discard samples from the burn-in period, say the first 1000, to see it more clearly:

Had we made s very small, say 0.01, the chain might not have reached the target distribution even after 1000 iterations. On the other hand, a large s means we reject most of the proposals, so the chain stays put for long stretches. The best s lies somewhere between 0.1 and 1. We need to find it through trial and error, which can get cumbersome. Although we won’t use Metropolis, it is important to build intuition for how sensitive MCMC is to its settings.
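
A quick way to see this sensitivity is to compute the acceptance rate for each value of s from the samples we already have; this is just a small check layered on top of the code above, not part of the original workflow:

# fraction of iterations where the chain actually moved, per proposal sd
aggregate(theta ~ s, data = thetas, FUN = function(v) mean(diff(v) != 0))

A tiny s accepts almost everything but barely moves, while a huge s rejects almost everything; for random-walk Metropolis, acceptance rates somewhere in the ballpark of 20-50% are usually considered healthy.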

Hamiltonian / NUTS

Anything I say here will pale in comparison to McElreath’s excellent article, so I suggest reading that for animations and details. The basic idea is that Hamiltonian Monte Carlo (HMC) is a physics simulation. You have a ball rolling around some terrain that is the negative log posterior shifted up by some unknown constant. For a 2-D normal distribution, imagine flipping it over so you have a bowl. You flick the ball in a random direction with a random momentum, stop it after a certain time, and record where it ends up. You reject the sample when the total energy (potential + kinetic) differs too much from your starting energy, indicating a failure in your simulation. Stan calls these “divergent transitions”.

MH fails in extremely high dimensions because eventually you’ll get close to 0% acceptance rate. Human intuition breaks down in high dimensions. Our 3-D brain might imagine a high-dimensional multivariate normal distribution as a solid ball, but it is actually a very thin shell of a sphere due to concentration of measure. If we project this down to 2D it will look like a donut instead of a circle. The fact that HMC performs well on donuts is profound.

Large step sizes lead to bad approximation (source)

Physics simulations have to approximate trajectories by discretizing the steps. If you check every 10 “seconds”, then your simulated trajectory might differ too much from the actual trajectory. If you check every 0.00001 “seconds”, then your simulation will take a long time to compute a single trajectory even though it will be much more accurate. You need to tune the simulation settings just right to get good results out of HMC.
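
To make the discretization concrete, here is a bare-bones leapfrog HMC sketch for a single parameter with a standard normal target. It is purely illustrative (U, grad_U, and hmc_step are names I made up, and none of Stan’s adaptation or NUTS logic is included):

# toy HMC: target is a standard normal, so U(q) = q^2 / 2 up to a constant
U <- function(q) q^2 / 2
grad_U <- function(q) q

hmc_step <- function(q, eps = 0.1, L = 20){
  p <- rnorm(1)                                # flick the ball with random momentum
  q_new <- q
  p_new <- p - eps * grad_U(q_new) / 2         # half step for momentum
  for(l in 1:L){
    q_new <- q_new + eps * p_new               # full step for position
    if(l < L) p_new <- p_new - eps * grad_U(q_new)
  }
  p_new <- p_new - eps * grad_U(q_new) / 2     # final half step for momentum
  # accept or reject based on the change in total energy (potential + kinetic)
  if(runif(1) < exp(U(q) + p^2 / 2 - U(q_new) - p_new^2 / 2)) q_new else q
}

set.seed(1)
draws <- numeric(5000)
q <- 1
for(i in 1:5000){
  q <- hmc_step(q)
  draws[i] <- q
}
c(mean(draws), sd(draws))   # should be close to 0 and 1

Shrink eps and the simulated trajectory tracks the true dynamics more closely, but each draw costs more gradient evaluations, which is exactly the trade-off described above.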

NUTS automatically tunes the settings for you during the warm-up phase (instead of burn-in) and gives you better samples by preventing U-turns. If you didn’t check the article linked above, I suggest looking at it now. There is virtually no reason to use MH or HMC over NUTS in this day and age.

That said, NUTS still runs into some issues.

First, it can get stuck near a single mode of a multimodal posterior (again, refer to the article). You can detect this problem through diagnostics (run multiple chains!) and then ask NUTS to obtain more samples to more fully explore the posterior.

Second, NUTS needs to evaluate the gradient of the terrain after each step. Computing gradients is extremely costly and we must use the entire dataset, so it’s not scalable.

Neal’s funnel (source)

Third, it cannot explore something like Neal’s funnel. This is a particularly degenerate case of a hierarchical model. You can read up on the example in the Stan documentation. Essentially, y is a log scale parameter to model the variance between groups while x is the group mean parameter. NUTS learns a single step size but this kind of terrain requires different step sizes depending on where the ball is on the funnel. A step size that works in the wide section will not work in the narrow section, while a step size that works in the narrow part will explore the wide part far too slowly.

You can find this problem if you see too many divergent transitions and they all happen in the same region. The modeler needs to reparameterize the model so NUTS can work its magic, possibly by finding (relatively) uncorrelated parameterizations or rescaling the data so the parameters are on the same scale. For the funnel, the standard fix is the non-centered parameterization: draw x_raw and y_raw from N(0, 1) and set y = 3 * y_raw and x = exp(y/2) * x_raw, so the sampler explores a well-behaved unit normal instead of the funnel itself. For fixed effects models, this is done by QR decomposition on the predictor matrix.

As an aside: reparameterization is often a good idea to speed up computation. Think of it this way: the ball can roll all over the place if we provide a nice bowl-shaped surface. Exploring a straw is more difficult. Without divergent transitions, samples from both parameterizations should be correct, but runtime can be orders of magnitude faster with QR decomposition.

Diagnostics

R hat

R hat looks suspiciously like the F statistic in ANOVA, so that’s the intuition I will give: it compares the variance between chains to the variance within chains. If all of the chains are sampling the same distribution, then this “F statistic” should be 1.

However, the converse is not true. It is possible for Rhat to be 1 and yet for the chains to have converged to nothing in particular. We still need to check other diagnostics, but at the very least if Rhat > 1.00 we quickly know that something is wrong. (It’s funny that older textbooks said > 1.1, then it changed to > 1.01, and the recommendation is now moving toward > 1.005. Let’s stick with 1.00.)
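
To make the ANOVA analogy concrete, here is a minimal version of the computation (no split chains, no rank normalization, unlike the production estimators in Stan and the posterior package), assuming chains is a matrix with one column per chain:

# basic Rhat: compare between-chain and within-chain variance
rhat_basic <- function(chains){       # chains: iterations x chains matrix
  n <- nrow(chains)
  chain_means <- colMeans(chains)
  B <- n * var(chain_means)           # between-chain variance
  W <- mean(apply(chains, 2, var))    # within-chain variance
  var_hat <- (n - 1) / n * W + B / n  # pooled estimate of the posterior variance
  sqrt(var_hat / W)
}

# four well-mixed chains should give a value very close to 1
set.seed(2)
rhat_basic(matrix(rnorm(4000), ncol = 4))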

n_eff

The effective sample size is defined as:

n_eff = N / ( 1 + 2 Σ_{t=1}^{∞} ρ_t )

where N is the number of retained draws and ρ_t is the autocorrelation at lag t. In practice, we truncate the summation where we think the autocorrelation has gone to 0.
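
As a rough single-chain illustration (Stan’s own estimator pools chains and uses a more careful truncation rule), here is the formula applied to the sticky s = 0.1 Metropolis chain from earlier:

# crude effective sample size for a single chain
n_eff_basic <- function(theta, max_lag = 100){
  rho <- acf(theta, lag.max = max_lag, plot = FALSE)$acf[-1]  # drop lag 0
  first_neg <- which(rho < 0)[1]          # truncate at the first negative autocorrelation
  if(!is.na(first_neg)) rho <- rho[seq_len(first_neg - 1)]
  length(theta) / (1 + 2 * sum(rho))
}

# the s = 0.1 chain moves slowly, so its effective sample size is far below
# the 1000 retained draws
n_eff_basic(thetas$theta[thetas$s == "0.1" & thetas$iter > 1000])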

MCMC does not draw independent samples because of the Markov property — at the very least your samples depend on the previous sample. Historically people skirted this issue by thinning their samples, e.g. keeping only every tenth sample. Now we know better; we should keep all the samples and use n_eff for CLT-like purposes.

This is a useful diagnostic because if your n_eff is much lower than your total number of samples (minus the burn-in / warm-up), then something has gone terribly wrong (or you need to draw more samples).

One of my favorite passages in a textbook is from McElreath’s Statistical Rethinking 2:

When people start using Stan, or some other Hamiltonian sampler, they often find that models they used to fit in Metropolis-Hastings and Gibbs samplers…no longer work well. The chains are slow. There are lots of warnings. … Those problems were probably always there, even in the other tools. But since Gibbs doesn’t use gradients, it doesn’t notice some issues that a Hamiltonian engine will. A culture has evolved in applied statistics of just running bad chains for a very long time — for millions of iterations — and then thinning aggressively, praying, and publishing. This must stop. [example [redacted] — 5 million samples, neff of 66!]

He publicly shamed a paper; that paper deserves every bit of shaming it gets. Savage.

This point can be a strong argument for using NUTS as the default. I suspect that Gibbs can draw seemingly reasonable samples while stuck in some local region, so the diagnostics look fine, whereas NUTS will try to explore the other regions and throw a warning.

Sometimes n_eff will be higher than your number of samples. That’s no cause for alarm; it just means the draws are negatively autocorrelated, which is even better than independent sampling.

ELPD

Think of expected log pointwise predictive density (ELPD) as a generalization of log-likelihood for the Bayesian case. Bayesian models output probability distributions, while metrics like RMSE / cross-entropy evaluate the performance of point predictions. ELPD evaluates the entire predicted distribution.

Other similar metrics exist, but ELPD is all you need. AIC and BIC evaluate point predictions. WAIC has nice asymptotic properties but has erratic behavior for smaller samples. Their similarity is in the interpretation. ELPD is meaningless on its own, but it can be used to compare different models like you would use AIC.

Because we care about how well the model generalizes, we want to obtain ELPD from cross-validation. Otherwise, we will be overly optimistic about model performance — evaluating a model on the training set will yield optimistically low error estimates.

Contrary to conventional wisdom in machine learning, leave-one-out cross-validation (LOOCV) is much more computationally efficient than k-fold cross-validation for Bayesian models. K-fold requires us to refit the model k times, but fitting the model is the expensive part. We can approximate LOOCV using importance sampling on the samples we have already obtained from the posterior, so we don’t need to refit the model.

Nowadays, people use Pareto-Smoothed Importance Sampling (PSIS-LOOCV). These acronyms keep getting longer. Other than improving the stability of ELPD estimates, PSIS-LOOCV provides an additional diagnostic: k. This algorithm takes the 20% highest importance weights and fits a generalized Pareto distribution to it. When k > 0.5, the GPD has infinite variance and signals that the ELPD estimates might be untrustworthy, though from empirical testing the approximation isn’t that bad until k > 0.7. When k is large, it can indicate highly influential observations that mess with the model.
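
In practice you rarely do any of this by hand; the loo R package computes PSIS-LOOCV from posterior draws. A minimal sketch, assuming a stanfit object called fit whose Stan program saves the pointwise log-likelihood in a generated quantity named log_lik (both the object and that name are assumptions here, not something defined earlier in this article):

library(loo)

# log_lik: iterations x chains x observations array of pointwise log-likelihoods
log_lik <- extract_log_lik(fit, parameter_name = "log_lik", merge_chains = FALSE)
r_eff <- relative_eff(exp(log_lik))       # accounts for autocorrelation in the draws
loo_fit <- loo(log_lik, r_eff = r_eff)

print(loo_fit)             # elpd_loo estimate plus a table of Pareto k diagnostics
pareto_k_values(loo_fit)   # per-observation k; values above 0.7 flag unreliable points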

Rank plots

A cousin of trace plots. While trace plots can be useful to detect degenerate cases, they can be hard to interpret. Rank plots are easier to inspect. Using the code for Metropolis from before, and creating four chains with s = 0.2:

library(data.table)

sampled_thetas <- data.table(thetas)[iter > 1000]
sampled_thetas[, rank := frank(theta)]

ggplot(sampled_thetas) +
  geom_histogram(aes(x = rank)) +
  facet_grid(. ~ s)

ggplot(sampled_thetas[iter < 1100]) +
  geom_step(aes(x = iter, y = rank, color = s))

If the chains are mixing well, the rank histograms should look roughly uniform and the rank traces should keep crossing one another. Otherwise, something has gone wrong.

VI

With MCMC, understanding the algorithm is key. With VI, I think understanding the objective function is more important than the algorithms.

Variational Inference (VI) takes a different approach from MCMC but still uses Monte Carlo in most applications. Instead of sampling from the posterior, we propose a simpler, tractable family of distributions to approximate the posterior. The question is then framed as an optimization problem: find the member of that family closest to the posterior. Thus, VI can be scaled to big data, whereas NUTS cannot.

While several objective functions exist (the main alternative is Rényi divergence), the most commonly used one is the Kullback-Leibler divergence (KL divergence), defined as:

KL(q || p) = E_q[ log q(θ) − log p(θ) ] = ∫ q(θ) log[ q(θ) / p(θ) ] dθ

The exact posterior is typically denoted p while our variational approximation is denoted q. Understanding the properties of KL divergence is vital for working with VI, and we will start with two:

First, it is not symmetric. KL(p||q) requires us to take the expectation w.r.t. p, while KL(q||p) requires us to take the expectation w.r.t. q. Hence, it is not a distance metric.

Second, it is nonnegative. From the expression this is not immediately obvious; the standard proof applies Jensen’s inequality to the convex function −log.

KL(p||q) is called the forward-KL and it is intractable because we need to integrate over p (if we know p, why are we even doing this?). Instead, VI seeks to minimize KL(q||p), the reverse-KL. For example, we might want to approximate a highly complex posterior with a normal distribution; we seek the variational parameters μ and σ² that will minimize reverse-KL. In more precise notation, letting ν be the variational parameters, we want to find the ν that minimizes:

KL( q_ν || p(·|X) ) = E_{q_ν}[ log q_ν(θ) − log p(θ|X) ]

But we still have that pesky unknown posterior in the denominator, so we cannot directly work with the KL. Just like with Metropolis, we apply a trick so we never have to compute p(X): substitute p(θ|X) = p(θ, X) / p(X) to get

KL( q_ν || p(·|X) ) = E_{q_ν}[ log q_ν(θ) − log p(θ, X) ] + log p(X)

Recall that p(X) is a constant so we can take it out of the expectation. Rearranging gives us the Evidence Lower BOund (ELBO), the objective function of VI:

ELBO(ν) = E_{q_ν}[ log p(θ, X) − log q_ν(θ) ] = log p(X) − KL( q_ν || p(·|X) )

so maximizing the ELBO is equivalent to minimizing the reverse-KL.

Because KL is nonnegative, the maximum possible value of the ELBO is log(p(X)), the log evidence. Hence why it’s called ELBO: the log evidence must be at least as high as the ELBO. However, it can be a very loose bound and the gap will vary across hypotheses and models. Never compare hypotheses using ELBO. Instead, compute the Bayes factors using the fitted posteriors.

The ELBO has two important properties:

First, the ELBO is the entropy of q minus the cross-entropy between q and the joint p(θ, X). Think about it for a bit: the entropy term wants q to spread out as much as possible, while the cross-entropy term wants q to collapse onto a point mass at the mode of p. This tension has the same weighted-average feel as the prior and the MLE, a recurring theme in Bayesian statistics.

Second, the ELBO encourages q to have variances that are too low. The reverse-KL heavily penalizes q for putting mass where p has little density, because the ratio q/p blows up there, so q prefers to hug a single high-density region. To compound this, the expectation is taken w.r.t. q, so whatever regions q ignores contribute almost nothing to the objective, and the fit is never punished for missing them.

VI has trouble with multimodal posteriors and highly correlated posteriors.

As an illustration, let’s try both NUTS and ADVI on a multimodal posterior. It commonly shows up in mixture models, but we’ll use the simplest example possible: put a N(0, 1) prior on μ and model the data as x_i ~ N(|μ|, 1). Because the likelihood only sees |μ|, the posterior for μ has two mirror-image modes (near ±2 for the data simulated below).

Stan file:

data {
  int<lower=0> N;
  real x[N];
}
parameters {
  real mu;
}
model {
  mu ~ normal(0, 1);
  for (n in 1:N)
    x[n] ~ normal(fabs(mu), 1);
}

R code:

library(rstan)

set.seed(555)
x <- rnorm(100, 2, 1)
data <- list(x = x, N = length(x))

model <- stan_model(file = 'bimodal_example.stan')
mcmc <- sampling(
  model,
  data,
  iter = 2000,
  warmup = 1000,
  chains = 4,
  seed = 1)
advi <- vb(
  model,
  data,
  output_samples = 4000,
  seed = 3
)

mcmc_samples <- sapply(
  mcmc@sim$samples,
  function(x) x[['mu']][1001:2000]
)
advi_samples <- advi@sim$samples[[1]][['mu']]

hist(mcmc_samples, breaks = 100, xlim = c(-3, 3))
hist(advi_samples, breaks = 30, xlim = c(-3, 3))

NUTS reports an Rhat of 1.53, warning us that something is wrong (the chains are stuck in different modes). ADVI converged without any warnings! Because the fitted variational distribution has almost 0 mass on one of the two modes, the region it failed to cover gets almost 0 weight in the ELBO. Thus, it reports much lower variance than it should, making us overconfident in the wrong conclusion.

Next, let’s compare NUTS vs ADVI on highly correlated posteriors.

Stan code:

data {
  int<lower=0> N;
  real y[N];
  real x1[N];
  real x2[N];
}
parameters {
  real beta0;
  real beta1;
  real beta2;
  real<lower=0> sigmasq;
}
transformed parameters {
  real<lower=0> sigma;
  sigma = sqrt(sigmasq);
}
model {
  beta0 ~ normal(0, 1);
  beta1 ~ normal(0, 1);
  beta2 ~ normal(0, 1);
  sigmasq ~ inv_gamma(1, 1);
  for (n in 1:N)
    y[n] ~ normal(beta0 + beta1 * x1[n] + beta2 * x2[n], sigma);
}

R code:

library(rstan)

set.seed(555)
x1 <- runif(100, 0, 2)
x2 <- x1 + rnorm(100, 0, 1)
y <- mapply(
  function(x1, x2) rnorm(1, x1 + 2 * x2, 1),
  x1,
  x2
)
data <- list(y = y, x1 = x1, x2 = x2, N = length(y))

model <- stan_model(file = 'correlated_example.stan')
mcmc <- sampling(
  model,
  data,
  iter = 2000,
  warmup = 1000,
  chains = 4,
  seed = 1)
advi <- vb(
  model,
  data,
  output_samples = 4000,
  seed = 3
)

nuts_beta1 <- sapply(
  mcmc@sim$samples,
  function(x) x[['beta1']][1001:2000]
)
nuts_beta2 <- sapply(
  mcmc@sim$samples,
  function(x) x[['beta2']][1001:2000]
)
advi_beta1 <- advi@sim$samples[[1]][['beta1']]
advi_beta2 <- advi@sim$samples[[1]][['beta2']]

plot(nuts_beta1, nuts_beta2)
plot(advi_beta1, advi_beta2)

NUTS managed to pick up the correlation just fine, but VI thinks the parameters are uncorrelated! What gives? (To be fair, Stan warns you that the approximation is bad.)

By default, VI speeds up computation through the mean-field assumption, i.e. the variational distribution factorizes so that the parameters are treated as independent of one another. This makes the gradient computations much simpler and faster. However, as this example demonstrates, the results can be terribly wrong!

As a rule of thumb:

  1. If your data size is reasonable, use MCMC as it’s more likely to converge to the exact posterior.
  2. If your data is too big for MCMC, try mean-field VI first.
  3. If mean-field VI fails, try full-rank VI, which allows correlation between parameters but makes computation much slower (see the snippet below for how to request it in rstan).
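
In rstan, switching between the two is just an argument to vb(). A minimal sketch reusing the model and data objects from the correlated-regression example above:

# mean-field (the default) vs full-rank ADVI on the same model
advi_mf <- vb(model, data, algorithm = 'meanfield', output_samples = 4000, seed = 3)
advi_fr <- vb(model, data, algorithm = 'fullrank', output_samples = 4000, seed = 3)

# the full-rank fit estimates a dense covariance, so it has a chance of recovering
# the negative correlation between beta1 and beta2 that mean-field throws away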

Coordinate Ascent VI

CAVI is the coordinate-ascent counterpart of vanilla gradient ascent: someone computes the analytic update for each variational factor, and we cycle through them until the ELBO converges. This is possible only for conditionally conjugate models. If it’s impossible to set up a Gibbs sampler for your model, then CAVI is also impossible.

Even simple models like Latent Dirichlet Allocation still require a good deal of mathematical know-how. (I say simple because the model can be described in maybe five lines, it is conditionally conjugate, and it is possible to compute the gradient.) Even with analytic updates, CAVI can be extremely slow to converge, but if a model is amenable to CAVI then we can use SVI.
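
To make “analytic updates” concrete, here is a small CAVI sketch for the same mean-and-precision model used in the Gibbs example, reusing X, num_obs, and the prior_* variables from that code. The mean-field factorization is q(μ)q(λ) with q(μ) normal and q(λ) Gamma; the update formulas below are my own write-up of the standard conditionally conjugate derivation, so treat this as a sketch rather than production code:

# mean-field CAVI for X_i ~ N(mu, 1/lambda), mu ~ N(0, 1), lambda ~ Gamma(1, 1)
# q(mu) = N(m, 1/tau), q(lambda) = Gamma(a, b)
a <- prior_shape + num_obs / 2    # this shape never changes across iterations
b <- prior_rate                   # initial guesses for the remaining parameters
m <- 0
tau <- 1

for(iter in 1:50){
  e_lambda <- a / b               # E[lambda] under the current q(lambda)
  # update q(mu) given E[lambda]
  tau <- prior_precision + num_obs * e_lambda
  m <- (prior_precision * prior_mean + e_lambda * sum(X)) / tau
  # update q(lambda) given E[mu] = m and Var[mu] = 1/tau
  b <- prior_rate + 0.5 * (sum((X - m)^2) + num_obs / tau)
}

c(mean_mu = m, sd_mu = sqrt(1 / tau), mean_lambda = a / b)
# the point estimates agree closely with the Gibbs draws, while the mean-field
# variance for mu will tend to come out a bit too small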

Stochastic VI

If CAVI is gradient ascent, then SVI is stochastic gradient ascent. As long as we satisfy the Robbins-Monro conditions (the step sizes must sum to infinity while their squares sum to something finite), SVI is guaranteed to converge (though it might take many many many iterations). The step size should go down slowly enough for us to fully explore the parameter space, but fast enough for the iterates to converge to a point. The LDA implementation in Spark uses SVI by default.

Initially, I thought that SVI would perform worse than CAVI, but it’s surprisingly the opposite. This paper shows that the parameters are learned much faster using SVI. A previous personal project corroborates this. A single CAVI iteration through all the documents yields worse results than SVI with only 10% of the documents.

If you can compute the natural gradients for SVI, then it should be the best algorithm for fitting the model. The main challenge: you must compute them by hand. The next two flavors of VI are used when it is too difficult or even impossible to compute natural gradients. Dubbed “black box variational inference”, they approximate the gradient using Monte Carlo methods. Pyro’s SVI class is an implementation of BBVI.

Score Gradient VI

Paper here. Derivation in the appendix. We want to do SGD by approximating the gradient of the ELBO. The intuition:

  • The ELBO is an integral and the gradient is a limit. The conditions hold for the dominated convergence theorem, so we can exchange the order of the ∇ and the integral.
  • Apply the product rule to the integrand.
  • The gradient of a constant is 0.
  • The expectation of the score function is 0.

Applying these in order gives us

∇_ν ELBO = E_{q_ν}[ ( log p(X, θ) − log q_ν(θ) ) ∇_ν log q_ν(θ) ]

We can do a Monte Carlo approximation on this expectation by sampling from our current q and evaluating each of the terms in the integrand.
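
As a toy illustration (my own sketch, not code from the paper), here is the score-gradient estimator on the simple conjugate model θ ~ N(0, 1), x_i ~ N(θ, 1), with q_ν = N(m, s²) and ν = (m, log s). The only hand-derived pieces are the two score functions of the normal:

set.seed(7)
x_toy <- rnorm(20, 3, 1)    # toy data; the exact posterior is N(sum(x_toy)/21, 1/21)

log_joint <- function(theta){
  dnorm(theta, 0, 1, log = TRUE) + sum(dnorm(x_toy, theta, 1, log = TRUE))
}

m <- 0; log_s <- 0          # variational parameters of q = N(m, exp(log_s)^2)
lr <- 0.01
n_mc <- 50                  # Monte Carlo samples per gradient estimate

for(i in 1:3000){
  s <- exp(log_s)
  theta <- rnorm(n_mc, m, s)
  f <- sapply(theta, log_joint) - dnorm(theta, m, s, log = TRUE)
  score_m <- (theta - m) / s^2              # d log q / d m
  score_log_s <- (theta - m)^2 / s^2 - 1    # d log q / d log(s)
  m <- m + lr * mean(f * score_m)           # noisy stochastic gradient ascent
  log_s <- log_s + lr * mean(f * score_log_s)
}

c(m = m, s = exp(log_s))    # compare to sum(x_toy)/21 and sqrt(1/21)
# the estimates jump around quite a bit between runs; subtracting a baseline
# such as mean(f) from f is one standard way to tame the variance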

In reality, the score gradient has too high of a variance to make it practical. Additional expertise is needed to get the variance under control. Otherwise, the BBVI will not converge in any reasonable time as we step all over the place.

Reparameterization Gradient VI

Paper here. This is the version of VI that is implemented by default on most probabilistic programming packages. I suggest reading the original paper as it’s very clearly written.

The basic idea is that we know how to do VI on a multivariate normal distribution, so why not transform all VI problems into something we already know how to solve? We treat every parameter in our model as a transformation of a multivariate normal. We can apply the chain rule through these transformation functions (reparameterizations) and compute the gradients using automatic differentiation.
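
Continuing the same toy model (again my own sketch, with a hand-derived gradient standing in for automatic differentiation), the reparameterization trick writes θ = m + s·ε with ε ~ N(0, 1) and pushes the gradient through that transformation:

# d/d(theta) of log p(x_toy, theta): prior term plus likelihood term
grad_log_joint <- function(theta) -theta + sum(x_toy - theta)

m <- 0; log_s <- 0
lr <- 0.01
n_mc <- 10

for(i in 1:3000){
  s <- exp(log_s)
  eps <- rnorm(n_mc)
  theta <- m + s * eps                 # reparameterized draws from q = N(m, s^2)
  g <- sapply(theta, grad_log_joint)
  # chain rule: d(theta)/d(m) = 1 and d(theta)/d(log_s) = s * eps;
  # the entropy of q contributes +1 to the log_s gradient
  m <- m + lr * mean(g)
  log_s <- log_s + lr * (mean(g * s * eps) + 1)
}

c(m = m, s = exp(log_s))    # noticeably steadier than the score-gradient version

With far fewer Monte Carlo samples per step than the score-gradient sketch, the estimates land essentially on top of the exact posterior, which is a big part of why this estimator is the default in practice.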

In practice, the fitted posterior is sensitive to the choice of reparameterization, and it’s not clear which transformations yield the best approximations. Again, expertise is needed for VI to work well. Algorithms that automatically find good reparameterizations are high on my wish list, and I’ll keep an eye out for them.

Diagnostics

ELBO trace plot


Just like SGD algorithms, we can inspect whether the objective function has converged or not. In reality, it’s hard to know when to stop. The ELBO can jump once it escapes a local optimum, but we can’t tell whether it’s stuck in a local optimum. While the trace plot is nice, the real diagnostic is…

Importance sampling

This is the diagnostic that Stan uses to warn users when ADVI converges to a bad fit. Suppose we want to compute E[θ] w.r.t. the posterior, but we can’t do that because we don’t know the posterior. One neat trick is

E_p[θ] = ∫ θ p(θ|X) dθ = ∫ θ [ p(θ|X) / q(θ) ] q(θ) dθ = E_q[ θ p(θ|X) / q(θ) ]

So even if we can’t sample from p, we can do a Monte Carlo approximation by sampling from a convenient q. We approximate the expectation by taking this weighted average

E_p[θ] ≈ Σ_i w_i θ_i / Σ_i w_i,   with θ_i drawn from q

where w_i is the importance weight defined by

w_i = p(X|θ_i) p(θ_i) / q(θ_i)

Just like with the Metropolis algorithm, we never have to evaluate p(X) because it cancels out when taking the weighted average (see a pattern?). Luckily we already have a convenient q to sample from: our fitted variational distribution. It is well-known that the importance weights can have infinite variance if p and q don’t overlap much.
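
Here is a tiny self-contained illustration of self-normalized importance sampling (my own example, not the Stan internals), reusing the conjugate toy model from the VI sketches above so we can check the answer against the exact posterior mean:

# target: posterior of theta given x_toy; proposal q is a deliberately mediocre normal
q_mean <- 2; q_sd <- 0.6
theta_q <- rnorm(10000, q_mean, q_sd)

log_w <- dnorm(theta_q, 0, 1, log = TRUE) +                                   # log prior
  sapply(theta_q, function(t) sum(dnorm(x_toy, t, 1, log = TRUE))) -          # plus log likelihood
  dnorm(theta_q, q_mean, q_sd, log = TRUE)                                    # minus log proposal
w <- exp(log_w - max(log_w))            # stabilize before exponentiating; constants cancel in the ratio

sum(w * theta_q) / sum(w)               # self-normalized estimate of E[theta | X]
sum(x_toy) / 21                         # exact posterior mean for comparison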

Stan fits a generalized Pareto distribution to the 20% highest importance weights. The GPD has infinite variance if k > 0.5, though in practice the fit is relatively decent until k > 0.7, as a rule of thumb. Stan warns the user when k > 0.7. This should be your primary diagnostic to assess the fit of your VI.

In Closing

Hopefully this article has provided useful information for using and diagnosing Bayesian inference algorithms. As always, if you see anything wrong or have any suggestions, please let me know so I can amend the article.
