
A Bayesian Approach to Linear Mixed Models (LMM) in R/Python

Implementing these can be simpler than you think

Eduardo Coronado Sroka
Towards Data Science
12 min read · Jun 22, 2020


There seems to be a general misconception that Bayesian methods are harder to implement than Frequentist ones. Sometimes this is true, but more often existing R and Python libraries can help simplify the process.

Simpler to implement ≠ throw in some data and see what sticks. (We already have machine learning for that. :P)

People make Bayesian methods sound more complex than they are, mostly because there’s a lot of jargon involved (e.g. weak priors, posterior predictive distributions, etc.) which isn’t intuitive unless you have previously worked with these methods.

Adding more to that misconception is a clash of ideologies — “Frequentist vs. Bayesian.” (If you weren’t aware, well, now you know.) The problem is people, especially statisticians, commonly quarrel over which methods are more powerful when in fact the answer is “it depends.”

Exhibit A: A gentle example of why this quarrel can be pointless https://twitter.com/RevBayes/status/506577193804111872

Bayesian methods, like any others, are just tools at our disposal. They have advantages and disadvantages.

So, with some personal “hot takes” out of the way, let’s move on to the fun stuff and implement some Bayesian linear mixed models (LMMs)!

At a glance

Here’s what I’ll cover (both in R and Python):

  1. Practical methods to select priors (needed to define a Bayesian model)
  2. A step-by-step guide on how to implement a Bayesian LMM using R and Python (with brms and pymc3, respectively)
  3. Quick MCMC diagnostics to help you catch potential problems early on in the process

Bayesian model checking, comparison and evaluation methods aren’t covered in this article. (There are more ways to evaluate a model than RMSE.) I’ve published a subsequent article covering these in more detail.

Python tutorial
R tutorial

Setup

If you are unfamiliar with mixed models I recommend you first review some foundations covered here. Similarly, if you’re not very familiar with Bayesian inference I recommend Aerin Kim’s amazing article before moving forward.

Let’s just dive back into the marketing example I covered in my previous post. Briefly, our dataset is composed of simulated website bounce times (i.e. the length of time customers spend on a website), and the overall goal was to find out whether younger people spent more time on a website than older ones.

The dataset has 613 observed “bounce times” (bounce_time, secs) collected across 8 locations (county), each with an associated age which was later center-scaled (std_age). However, the number of observations per location varies (e.g. one has 150 observations while another has only 7).

Step 1: Exploratory data analysis (EDA)

EDA is the unsung hero of data analysis and you shouldn’t think of it as “just plotting data.”

(I’d stay away from posts implying it’s the “boring stuff” or those just looking to automate it so you can dive straight into modeling, even if you’re using an ML algorithm. If you’re doing so, you’re missing out on a very powerful tool.) Personally, I think it is one of the most crucial steps in your analytics workflow. It can work as an iterative tool to help you adapt your modeling strategy, especially when your strategy is to start simple and build up to models with higher complexity.

EDA is more than just plotting data, it acts as an iterative tool that can help you adapt your modeling strategy.

For example, in our case the simplest model we can fit is a basic linear regression using sklearn (Python) or lm (R), and then check how well it captures the variability in our data.
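As a rough sketch of that baseline (assuming a pandas DataFrame named df with the bounce_time and std_age columns described above; the names and the file path are illustrative, not the exact code from the linked notebooks):

```python
import pandas as pd
from sklearn.linear_model import LinearRegression

# df = pd.read_csv("bounce_times.csv")  # hypothetical path to the simulated data

# Pooled linear regression: one global intercept and slope, ignoring county
X = df[["std_age"]].values
y = df["bounce_time"].values

baseline = LinearRegression().fit(X, y)
print(baseline.intercept_, baseline.coef_[0])
print("R^2:", baseline.score(X, y))  # rough check of how much variability it captures
```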

We could also consider a more complex model such as a linear mixed effects model. Again with some EDA we see that such a model captures group variability better and thus might be a better strategy. We can use seaborn.lmplot or ggplot2’s geom_smooth to quickly build some intuitive EDA plots. Here it seems that a varying-intercept model and a varying-intercept / varying-slope model might be good candidates to explore. (Our EDA also brings to light an example of Simpson’s Paradox, where a global regression line doesn’t match the individual county trends.)
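A minimal sketch of such a plot in Python (again assuming the df with county, std_age, and bounce_time columns from above):

```python
import seaborn as sns

# One regression line per county: makes the group-level trends (and their
# disagreement with the pooled fit, i.e. Simpson's Paradox) easy to spot.
sns.lmplot(x="std_age", y="bounce_time", hue="county", data=df, ci=None)
```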

We’ll proceed with three candidate models: a linear regression, a random intercept model, and a random intercept + slope model.
(The simple linear model is considered a baseline and will help exemplify the modeling process.)

Step 2: Selecting priors with help of some fake data

NOTE: At this point I’m assuming you’re familiar with Bayesian inference and won’t dive into its main components. If you’re not, stop, and go check out the suggested readings above.

Bayesian inference requires us to select prior distributions (“priors”) for our parameters before we can fully define our model. It’s tempting to use uninformative priors (e.g. normal distributions with extremely large variance) so that prior assumptions don’t damp down signals from the data or add bias, but you should instead aim to select weakly informative priors.

But how can you define what a “weak” prior is?

Good weakly informative priors take into account your data and the question being asked to generate plausible datasets.

A common strategy to define weakly informative priors is to create a “flip-book” of simulated data from your likelihood and prior combination. (This procedure is referred to as a prior predictive check.) Instead of thinking of priors as standalone (marginal) entities, simulating data with them allows us to understand how they interact with the likelihood and whether this pair generates plausible data points.

Probabilistic definition of our varying-intercepts model and priors

For example, let’s try choosing weak priors for our random intercepts model parameters (β0, β1, b_0, τ_0, and σ²) and simulating some data.
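Written out, the varying-intercepts model takes roughly the following form (my notation for the figure above; the specific prior families shown are illustrative assumptions, not the only reasonable choices):

```latex
\begin{aligned}
\mathrm{bounce\_time}_{ij} &\sim \mathcal{N}\big(\beta_0 + b_{0j} + \beta_1\,\mathrm{std\_age}_{ij},\; \sigma^2\big) && \text{(obs. } i \text{ in county } j\text{)}\\
b_{0j} &\sim \mathcal{N}(0,\; \tau_0^2), \qquad j = 1,\dots,8 && \text{(county-specific intercept offsets)}\\
\beta_0,\ \beta_1 &\sim \mathcal{N}(\cdot,\ \cdot), \qquad \tau_0,\ \sigma \sim \text{Half-Normal}(\cdot) && \text{(weakly informative priors)}
\end{aligned}
```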

I chose the above priors via an iterative strategy that combined information from the EDA step on population and group trends with group-specific empirical mean and variance estimates (see below). You’ll have to explore some options based on your own data and the questions being considered.

R Simulations

Python Simulations
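As a minimal numpy sketch of one “page” of such a flip-book (the prior values here are illustrative assumptions, not necessarily the ones used in the article’s own simulations):

```python
import numpy as np

rng = np.random.default_rng(0)
n_counties, n_per_county = 8, 20          # hypothetical group sizes

# Draw one set of parameters from the assumed priors
beta0 = rng.normal(200, 25)               # population-level intercept
beta1 = rng.normal(0, 10)                 # slope on centered age
tau0 = abs(rng.normal(0, 20))             # between-county sd (half-normal)
sigma = abs(rng.normal(0, 30))            # residual sd (half-normal)
b0 = rng.normal(0, tau0, n_counties)      # county-specific intercept offsets

# Simulate one fake dataset from the likelihood and eyeball its plausibility
std_age = rng.normal(0, 1, (n_counties, n_per_county))
fake_bounce = rng.normal(beta0 + b0[:, None] + beta1 * std_age, sigma)
print(fake_bounce.min(), fake_bounce.max())  # do these look like believable bounce times?
```

Repeating this draw many times (one plot per draw) gives you the flip-book described above.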

A common strategy to define weakly informative priors is to create a “flip-book” of simulated data (also referred to as a prior predictive check).

Below is an example of what a flip-book might look like (left animation). We see that most of our simulated data points are “reasonable,” meaning most fake points fall within our observed bounce time range (≈160–250 secs). Others are outside the range of our data but still within a believable range (e.g. up to 500–600 secs). If instead we had used generic weak priors [e.g. N(0, 100) and Inv-Gamma(1,100)] we’d see a lot of simulated bounce times <100 secs or even <0, which is obviously less probable (right).

If you’re interested in learning more about how to select priors checkout this resource page from the Stan Development Team.

Step 3: Fitting a Bayesian Model

Having selected some reasonable weak priors on the previous step, we can now focus on building our candidate models from Step 1. As I mentioned at the beginning of this article, there are existing libraries in R and Python that can greatly simplify fitting Bayesian linear mixed models.

1) brms: an R-package that runs on Stan

If you’re familiar with lme4 and the lmer function’s formula builder you’re 90% of the way there. Yes, that simple! Other than some additional options to specify Bayesian settings, brms offers straightforward functionalities. It is an outstanding library if you dive into all its capabilities — fitting GAMs, modeling Gaussian Processes, or even imputing data via mice. I won’t cover these here but I encourage you to read more.

Below is an example implementing a varying-intercept model with brms using our bounce time data. First, the package makes it easy to define priors with the built-in prior function, which assigns a distribution to a selected model parameter class, group, or coefficient (prior(distribution(...), class=..., group=..., coef=...)). For example, we can set a normal prior for our intercept as prior(normal(...), class=Intercept). You can use the built-in function get_prior() to understand how to set these flags. Similarly, you can choose priors from a wide array of available distributions.

Next, you can fit the model with the brm function, which offers a formula builder and options similar to those of lmer. You can further modify options specific to Bayesian methods such as the number of MCMC chains to run and the run length.

Extracting basic information from the model is simple. The library provides functions to print out a summary of the model fit (summary()) and of the posterior estimates (posterior_summary()).

NOTE: When using the bounce_time ~ 1 + std_age formulation, brms assumes the predictor(s) are mean-centered. If your predictors are not mean-centered you should use bounce_time ~ 0 + intercept + std_age instead. You can read more here (pages 38–39).

2) PyMC3: a Python library that runs on Theano

Although there are multiple libraries available to fit Bayesian models, PyMC3 without a doubt provides the most user-friendly syntax in Python. A new version is in the works (PyMC4, running on TensorFlow), but most of the functionality in this library will continue to work in future versions.

Compared to the brms package in R, PyMC3 requires you to explicitly state each individual component of the model — from priors to likelihood.

Complete specification isn’t as bad as it sounds. Once you specify the model (shown left) it can be straightforward to build using a context manager.

Well, most of the time 😉.

The only caveat to this package is that your model’s parameterization can have a big impact on its sampling efficiency.

For example, the model code below uses a slightly different parameterization from the model shown above (i.e. a centered parameterization).

First, we initialize the model with a context manager and the pm.Model() method. This will basically encapsulate the variables and likelihood factors of a model. Next, we can start assigning distributions to each of our components (priors, hyperpriors, and likelihood) using a wide variety of built-in distributions.

Setting up model distributions is simple. First define the distribution class you want to use via pm.your_distribution_class(), then assign it a label and set its distributional parameters, and finally assign it to a context variable
(e.g. beta1 = pm.Normal('std_age', mu=mu0, sigma=tau0) defines beta1 as a normal with label std_age and parameters mu0 and tau0).

Once your model components are set, you just need to call pm.sample(), which runs the sampler and returns a trace object storing your posterior samples. Similar to brms, you can also further modify options specific to Bayesian methods such as the number of MCMC chains to run and the run length.

Since PyMC3 relies on arviz for post-hoc visualizations, it is worthwhile to save the model output as an InferenceData object for later use in model diagnostics and posterior checks.
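Putting that together, a minimal sketch of the varying-intercepts model might look like the following (the data arrays, variable names, and prior values are illustrative assumptions; see the linked notebook for the exact code):

```python
import numpy as np
import pymc3 as pm

# Hypothetical arrays standing in for the real dataset; in practice these would
# come from the DataFrame (bounce_time, std_age, and a county index per row).
rng = np.random.default_rng(1)
n_counties = 8
county_idx = rng.integers(0, n_counties, size=613)  # which county each row belongs to
age = rng.normal(0, 1, size=613)                    # centered-scaled age
y = rng.normal(220, 40, size=613)                   # bounce times (secs)

with pm.Model() as varying_intercepts:
    # Hyperpriors for the county-level intercepts
    mu_b0 = pm.Normal("mu_b0", mu=200.0, sigma=25.0)   # overall intercept
    tau0 = pm.HalfNormal("tau0", sigma=20.0)           # between-county sd

    # County-specific intercepts (centered parameterization)
    b0 = pm.Normal("b0", mu=mu_b0, sigma=tau0, shape=n_counties)

    # Population-level slope for centered age, and residual sd
    beta1 = pm.Normal("std_age", mu=0.0, sigma=10.0)
    sigma = pm.HalfNormal("sigma", sigma=30.0)

    # Likelihood: each observation uses its own county's intercept
    mu = b0[county_idx] + beta1 * age
    y_obs = pm.Normal("y_obs", mu=mu, sigma=sigma, observed=y)

    # NUTS sampling; return an arviz InferenceData object for later diagnostics
    trace = pm.sample(2000, tune=1000, chains=4, return_inferencedata=True)
```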

Output from above random intercepts model; alpha represents the county-specific intercepts

Step 3.a: MCMC Diagnostics

Phew! That was a lot of information. However, no modeling endeavor is complete without some good ol’ diagnostics. (NOTE: Model checking, evaluation and comparison are also part of this process but I cover them in more detail in a separate article.)

Traditionally, two diagnostic visualizations — MCMC trace plots and autocorrelation function (ACF) diagrams — help us understand whether there were problems with sampling. In other words, how reliable are our samples and thus our inference? These tools appeal to some theoretical properties of Markov chains and are good sanity checks of the chains’ performance. Yet, brms and pymc3 use advanced MCMC methods for which additional diagnostics are available that provide a more granular picture of their performance.

Aside from conventional diagnostics, additional methods are available for advanced techniques such as Hamiltonian Monte Carlo (HMC) and can provide a more detailed picture of the Markov chain performance.

Let’s dive into an example implementing conventional and HMC-related diagnostics.

Conventional MCMC Diagnostics

The two common visualizations highlighted above help us understand whether

  1. The Markov chains had good mixing (i.e. they appear to have converged to the stationary distribution)
  2. Samples aren’t strongly autocorrelated (i.e. the autocorrelation dies out after the first few lags), so they carry enough independent information

In R you can easily generate these using the mcmc_trace and mcmc_acf functions in bayesplot with your fitted model.

Example for Linear Regression
Trace plots (left) and Autocorrelation plots (right) for linear regression model in R

In Python, pymc3 has a built-in integration with arviz to visualize these; you just need the trace object returned by the sampler.
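For instance, assuming trace is the InferenceData object returned by the sketch above, a minimal example would be:

```python
import arviz as az

# Trace plots: check that the chains mix well and overlap each other
az.plot_trace(trace, var_names=["std_age", "sigma"])

# ACF plots: check that autocorrelation dies out after the first few lags
az.plot_autocorr(trace, var_names=["std_age", "sigma"])
```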

Trace plots (left) and Autocorrelation plots (right) for linear regression model in Python

Hamiltonian Monte Carlo Diagnostics

As mentioned, Hamiltonian Monte Carlo offers additional diagnostics we can use. Both pymc3 and brms run on a very efficient HMC sampler called No-U-Turn Sampling (NUTS).

Here’s a crude one sentence, high-level summary:

HMC uses first-order gradients to direct sampling in useful directions, which helps it converge to the posterior faster than traditional MCMC methods, and NUTS dynamically tunes the parameters that affect HMC’s performance (i.e. the step size ε and the number of steps L).

Nuts, right?! (I hope you can forgive the bad pun). Just to clarify that last point, think about what’s happening in terms of gradient-like updates: if you choose a step size that’s too small or a number of steps that’s too high, you might waste compute time. Although this is a useful analogy, the specifics of how this algorithm works are more complex and elegant.

Back to “So what, dude?” Well, in cases where the chains aren’t mixing or the ACF shows high dependence between samples, it can be hard to diagnose what’s happening.

However, if you’re using HMC or NUTS, you can look at divergent transitions which can help you understand the geometry of your posterior (!!). These arise when the proposed next step(s) deviate from the most efficient path set by the sampler.

Divergence is a powerful diagnostic that allows you to understand the posterior geometry.

Visualizing these can help in two ways:

  1. To detect “where” your sampler is having trouble traversing the posterior “landscape” (potentially due to local high variability or curved geometry). In such cases, you might consider re-parameterizing your model to “smooth” those high-curvature regions.
  2. To troubleshoot “false positive” divergences (where flagged divergent transitions behave the same as non-divergent ones)

Below is an example of how divergences could help identify places where the chains get “stuck” or are incorrectly flagged as divergent (green). In both the left and right diagrams, regions where green points or lines are highly concentrated indicate convergence problems (i.e. you can see this funnel-like concentration for tau below). Here the sampler wasn’t able to adequately explore the posterior region defined by tau, and thus our inferences on tau are unreliable. Conversely, we can also gain confidence in other parameter inferences (e.g. theta1, ..., theta8) where trajectories were incorrectly flagged as divergent but behaved like non-divergent ones.

Gabry, Jonah, et al. (2019)

To learn more about interpreting these powerful diagnostic plots, check out this write-up from Michael Betancourt. To run this diagnostic in R use the mcmc_scatter function from bayesplot.

In Python you can use PyMC3’s built-in plot_parallel function.
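A minimal sketch using arviz directly (again assuming the trace object from the model above; the parameter names are those of the earlier sketch):

```python
import arviz as az

# Parallel-coordinates plot: divergent draws (if any) are highlighted so you can
# see which parameter combinations they cluster around
az.plot_parallel(trace)

# Pairwise scatter with divergences marked; useful for spotting funnel-shaped
# regions like the tau example above (tau0 vs. the county intercepts b0)
az.plot_pair(trace, var_names=["tau0", "b0"], divergences=True)
```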

In this case, we can see there aren’t any divergent trajectories like we saw above and thus our chains have explored the posterior geometry without trouble (i.e. converged).

Divergent diagnostic in Python (left) and R (right)

Conclusion

Fitting Bayesian models has been greatly simplified by packages such as pymc3 and brms. However, being simpler to implement doesn’t mean we should forget about key steps in our modeling process that can help us build better models.

I hope you’ve learned some new approaches for selecting good prior distributions, how to troubleshoot Bayesian mixed models, and why EDA can be a powerful tool (not just in a Bayesian setting).

Finally, my hope is to keep people from shying away from Bayesian approaches just because “they sound hard,” and instead provide them with new methods they can add to their toolset.

References

  1. Gelman, A., Carlin, J. B., Stern, H. S., Dunson, D. B., Vehtari, A., & Rubin, D. B. (2013). Bayesian Data Analysis. CRC Press.
  2. Gabry, J., Simpson, D., Vehtari, A., Betancourt, M., & Gelman, A. (2019). Visualization in Bayesian workflow. Journal of the Royal Statistical Society: Series A (Statistics in Society), 182(2), 389–402.
