Training Hidden Markov Models

The Baum-Welch and Forward-Backward Algorithms

Field Cady
Towards Data Science


In my previous article I introduced Hidden Markov Models (HMMs) — one of the most powerful (but underappreciated) tools for modeling noisy sequential data. If you have an HMM that describes your process, the Viterbi algorithm can turn a noisy stream of observations into a high-confidence guess of what’s going on at each timestep.

This assumes, however, that you already have an HMM with all of its parameters tuned. Sometimes that’s the case — I remember from my time at Google that we had a very fancy one that modeled word order in the English language — but more often you just have raw data and you must fit an HMM to it yourself. This post will discuss that situation.

The TL;DR is this: if you truly have no labeled data and no prior knowledge, you can use the Baum-Welch algorithm to fit an HMM. But for technical reasons (chiefly its tendency to get stuck in local optima) the Baum-Welch algorithm doesn’t always give the right answer. On the other hand, if you do have some knowledge and a little bit of time, there are myriad ways to hack the training process.

Two Parts to Train: the Markov Chain and the Observations

Recall from the previous article that an HMM has two parts to it:

  • An underlying Markov chain that describes how likely you are to transition between different states (or stay in the same state). Typically this underlying state is the thing you’re really interested in. If there are k states in the HMM, then the Markov chain consists of 1) a k×k matrix saying how likely you are to transition from a state S1 to a state S2, and 2) a length-k vector saying how likely you are to start off in each of the states.
  • A probability model that lets you compute Pr[O|S], the probability of seeing observation O if we assume that the underlying state is S. Unlike the Markov chain, which has a fixed format, the model for Pr[O|S] can be arbitrarily complex. In many HMMs, though, Pr[O|S] is pretty simple: each state S is a different loaded die, and Pr[O|S] gives that die’s probability of landing on each side.
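To make the two parts concrete, here is a minimal sketch in Python with NumPy for the simple loaded-dice case: a 2-state HMM over 3 possible observations. The numbers, and the names `start`, `trans`, `emit` and `sample_hmm`, are illustrative assumptions rather than any particular library’s API. Sampling from the model shows how the two parts interact: the Markov chain picks each state, and that state’s die picks the observation.

```python
import numpy as np

# A toy 2-state HMM with 3 possible observations (all numbers are made up).
start = np.array([0.6, 0.4])              # start[i] = Pr[first state = i]
trans = np.array([[0.9, 0.1],             # trans[i, j] = Pr[next state = j | state = i]
                  [0.2, 0.8]])
emit = np.array([[0.7, 0.2, 0.1],         # emit[i, o] = Pr[O = o | S = i]: one "loaded die" per state
                 [0.1, 0.3, 0.6]])

def sample_hmm(start, trans, emit, length, rng):
    """Draw a (states, observations) pair of the given length from the HMM."""
    k, m = emit.shape
    states = np.empty(length, dtype=int)
    obs = np.empty(length, dtype=int)
    states[0] = rng.choice(k, p=start)
    obs[0] = rng.choice(m, p=emit[states[0]])
    for t in range(1, length):
        # The Markov chain picks the next state; that state's die picks the observation.
        states[t] = rng.choice(k, p=trans[states[t - 1]])
        obs[t] = rng.choice(m, p=emit[states[t]])
    return states, obs

rng = np.random.default_rng(0)
states, obs = sample_hmm(start, trans, emit, length=10, rng=rng)
```

In real use only `obs` would be visible; `states` is what the HMM machinery tries to recover.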

To a large degree, these two moving parts can be considered independently. You might even have external knowledge that tells you what one of them is but not the other. For example, say you are trying to make a transcript of an audio recording; the states are English words, and the observations are snippets of sound. You might use an existing corpus of English texts to train your Markov chain, and then only tune the parameters for Pr[O|S] when it comes time to train your HMM.

If you have a large amount of labeled data, meaning you have both the sequence of observations and knowledge of the underlying state at each timestep, training the HMM truly breaks down into two independent problems. First you use the labels to train the Markov chain. Then you divvy up the observations based on what state they were in and train Pr[O|S] for each state S. With reliable state labels, training the HMM is straightforward.
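With labels in hand, both fits reduce to counting and normalizing. Here is a minimal sketch for discrete (loaded-die) observations, assuming NumPy and integer-coded states and observations; `fit_supervised` is a hypothetical helper, and the optional pseudocount smoothing is an added assumption rather than something the article prescribes.

```python
import numpy as np

def fit_supervised(state_seqs, obs_seqs, k, m, pseudocount=1.0):
    """Estimate HMM parameters by counting, given labeled sequences.

    state_seqs / obs_seqs: lists of equal-length integer sequences.
    pseudocount: add-one style smoothing (an assumption) so unseen transitions
    or observations don't get probability zero. With pseudocount=0, every
    state must appear in the labels or its row would divide by zero.
    """
    start = np.full(k, pseudocount)
    trans = np.full((k, k), pseudocount)
    emit = np.full((k, m), pseudocount)
    for states, obs in zip(state_seqs, obs_seqs):
        start[states[0]] += 1
        for a, b in zip(states[:-1], states[1:]):
            trans[a, b] += 1          # Markov chain: count state-to-state moves
        for s, o in zip(states, obs):
            emit[s, o] += 1           # Pr[O|S]: divvy up observations by state
    # Normalize counts into probabilities (each row sums to 1).
    return (start / start.sum(),
            trans / trans.sum(axis=1, keepdims=True),
            emit / emit.sum(axis=1, keepdims=True))

states = [0, 0, 1, 1, 1, 0]
obs    = [0, 0, 2, 2, 1, 0]
start, trans, emit = fit_supervised([states], [obs], k=2, m=3, pseudocount=0.0)
```

Here state 0 moves to 0 once and to 1 once, so its transition row comes out as [0.5, 0.5]; state 1 stays put twice and leaves once, giving [1/3, 2/3].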

In practice, though, we usually just have the sequence of observations, with no rock-solid knowledge of what state the system was in. So here is the intuition: we guess at what the state labels are and train an HMM using those guesses. It will be a pretty pitiful HMM, but it will probably have some bit of real pattern in it, even if just by dumb luck. Then we use the trained HMM to make better guesses at the states, and re-train the HMM on those better guesses. This process continues until the trained HMM stabilizes. This back-and-forth, between using an HMM to guess state labels and using those labels to fit a new HMM, is the essence of the Baum-Welch algorithm.

Baum-Welch Algorithm: the Fine Print

The intuition I gave you for the Baum-Welch (BW) algorithm needs two points of clarification.

Firstly, we generally start by guessing the parameters of the HMM rather than the underlying states in our data. This lets us make sure there are patterns present in the HMM, and subsequent iterations can determine which of those patterns are real. You could start by guessing the states if you wanted, but it will converge faster this way. Plus, you will often have good intuitions about the HMM parameters that you can use to make a judicious initial guess.

The second point is a bit more subtle. What do we mean by “guessing the states” once we have an HMM in hand? In the previous post we talked about the Viterbi algorithm, which takes a sequence of observations and gives us a single best-fit sequence of states for them. The problem, though, is that some of those states might be guessed reliably while others could be quite ambiguous, and the Viterbi algorithm doesn’t distinguish between high- and low-confidence guesses.

To remedy this second problem, we ditch the Viterbi algorithm and its faux confidence. Instead we use its lesser-known cousin, the forward-backward algorithm, which gives us probabilistic guesses whose confidence we can gauge.

The Forward-Backward Algorithm

Let’s get technical for a minute. The Viterbi algorithm doesn’t just “decode the observations.” It solves a very specific math problem: given a sequence o¹o²… of observations, it finds the single sequence s¹s²… of states that maximizes the joint probability of those states and observations:

Pr[s¹s²…]*Pr[o¹o²…|s¹s²…]
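As a reference point, here is a minimal log-space Viterbi sketch that maximizes exactly the quantity above, checked against brute-force enumeration of every state sequence. It assumes discrete observations and a toy 2-state, 3-symbol HMM with made-up parameters; `joint_log_prob` and `viterbi` are illustrative names, not a library API. Working with logs is a standard way to avoid underflow when multiplying many small probabilities.

```python
import itertools
import numpy as np

def joint_log_prob(states, obs, start, trans, emit):
    """log( Pr[s1 s2 ...] * Pr[o1 o2 ... | s1 s2 ...] ) for one state sequence."""
    lp = np.log(start[states[0]]) + np.log(emit[states[0], obs[0]])
    for t in range(1, len(obs)):
        lp += np.log(trans[states[t - 1], states[t]]) + np.log(emit[states[t], obs[t]])
    return lp

def viterbi(obs, start, trans, emit):
    """Dynamic-programming search for the state sequence maximizing joint_log_prob."""
    log_t, log_e = np.log(trans), np.log(emit)
    T, k = len(obs), len(start)
    score = np.log(start) + log_e[:, obs[0]]   # best log-prob of any path ending in each state
    back = np.zeros((T, k), dtype=int)         # back[t, j] = best predecessor of state j at time t
    for t in range(1, T):
        cand = score[:, None] + log_t          # cand[i, j]: be in state i, then move to j
        back[t] = cand.argmax(axis=0)
        score = cand.max(axis=0) + log_e[:, obs[t]]
    path = [int(score.argmax())]
    for t in range(T - 1, 0, -1):              # follow back-pointers from the end
        path.append(int(back[t, path[-1]]))
    return path[::-1]

start = np.array([0.6, 0.4])
trans = np.array([[0.9, 0.1], [0.2, 0.8]])
emit = np.array([[0.7, 0.2, 0.1], [0.1, 0.3, 0.6]])
obs = [0, 1, 2, 2, 1, 0]

best = viterbi(obs, start, trans, emit)
# Brute force over all 2**6 state sequences (only feasible for tiny examples).
brute = max(itertools.product(range(2), repeat=len(obs)),
            key=lambda s: joint_log_prob(s, obs, start, trans, emit))
```

The two searches should reach the same maximum joint probability; Viterbi just gets there in time linear in the sequence length instead of exponential.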

It’s optimizing for the entire sequence of states, not for any one particular state. But let’s say instead that we are only interested in a particular state s⁵. It could be that the single best sequence of states (the one found by the Viterbi algorithm) has s⁵=0, but there are many slightly worse sequences for which s⁵=1. In this case the slightly worse sequences add up: their combined probability exceeds that of all the sequences with s⁵=0, so the best guess for s⁵ in isolation is 1.
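The “slightly worse sequences add up” idea can be made literal for a tiny HMM: enumerate every state sequence, and credit each sequence’s joint probability to the state it assigns at each timestep. This brute-force sketch is exponential in the sequence length, so it is for illustration only; the parameters are made up and `brute_force_marginals` is a hypothetical name. Each row of the result is that timestep’s probability distribution over states, and its per-timestep argmax need not match the single best sequence.

```python
import itertools
import numpy as np

def brute_force_marginals(obs, start, trans, emit):
    """Pr[s_t = j | observations] by summing over every state sequence (tiny HMMs only)."""
    T, k = len(obs), len(start)
    marg = np.zeros((T, k))
    for states in itertools.product(range(k), repeat=T):
        # Joint probability of this particular state sequence and the observations.
        p = start[states[0]] * emit[states[0], obs[0]]
        for t in range(1, T):
            p *= trans[states[t - 1], states[t]] * emit[states[t], obs[t]]
        for t, s in enumerate(states):
            marg[t, s] += p          # every sequence votes for its state at each timestep, weighted by p
    # Each row currently sums to Pr[observations]; dividing gives posteriors.
    return marg / marg.sum(axis=1, keepdims=True)

start = np.array([0.6, 0.4])
trans = np.array([[0.9, 0.1], [0.2, 0.8]])
emit = np.array([[0.7, 0.2, 0.1], [0.1, 0.3, 0.6]])
obs = [0, 1, 2, 2, 1, 0]

marg = brute_force_marginals(obs, start, trans, emit)
per_step_guess = marg.argmax(axis=1)   # best guess for each timestep in isolation
```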

The forward-backward algorithm finds the best-guess state for each timestep. In general these will differ from the Viterbi states, although in practice they’re typically very close (for example, in the data I’m working with right now they disagree 5% of the time, and I was surprised it was that high).
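Here is a minimal sketch of the forward-backward algorithm for discrete observations, assuming NumPy and the parameter layout `start[i]`, `trans[i, j]`, `emit[i, o]` (illustrative naming, not a library’s). The forward pass carries the probability of the observations so far, ending in each state; the backward pass carries the probability of the remaining observations given each state. Their product, normalized, is each timestep’s probability of being in each state. Rescaling at every step is a standard trick to avoid numerical underflow on long sequences.

```python
import numpy as np

def forward_backward(obs, start, trans, emit):
    """Return gamma[t, j] = Pr[state at t is j | all observations], plus log Pr[observations]."""
    T, k = len(obs), len(start)
    alpha = np.zeros((T, k))
    beta = np.ones((T, k))
    c = np.zeros(T)                               # per-timestep scaling factors
    alpha[0] = start * emit[:, obs[0]]
    c[0] = alpha[0].sum()
    alpha[0] /= c[0]
    for t in range(1, T):                         # forward pass
        alpha[t] = (alpha[t - 1] @ trans) * emit[:, obs[t]]
        c[t] = alpha[t].sum()
        alpha[t] /= c[t]
    for t in range(T - 2, -1, -1):                # backward pass, using the same scaling
        beta[t] = trans @ (emit[:, obs[t + 1]] * beta[t + 1]) / c[t + 1]
    gamma = alpha * beta
    # With this scaling the rows already sum to 1; normalizing again guards against round-off.
    return gamma / gamma.sum(axis=1, keepdims=True), np.log(c).sum()

start = np.array([0.6, 0.4])
trans = np.array([[0.9, 0.1], [0.2, 0.8]])
emit = np.array([[0.7, 0.2, 0.1], [0.1, 0.3, 0.6]])
obs = [0, 1, 2, 2, 1, 0]

gamma, loglik = forward_backward(obs, start, trans, emit)
best_guess = gamma.argmax(axis=1)   # per-timestep best guess
confidence = gamma.max(axis=1)      # how sure we are about each guess
```

`confidence` is exactly what the Viterbi path alone can’t tell you: a timestep at 0.99 is a solid guess, while one near 0.5 is a coin flip.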

For our purposes, the value of the forward-backward algorithm isn’t the places where it differs from the Viterbi algorithm. It’s the fact that it rigorously calculates, for each timestep, the probability of being in each state. This is what we use for the Baum-Welch algorithm! We don’t train our HMM on a sequence of states that we know for certain; we train it so that it is, on average, the best fit over all these probabilistic guesses.

I won’t give pseudo-code for the Baum-Welch algorithm here. Instead I will note that it is a specific case of the EM (expectation-maximization) algorithm, which is used to fit models when there are unknown “hidden variables” (in this case, which state the system is in at each timestep) that can be guessed probabilistically. The Achilles’ heel of the Baum-Welch algorithm is the same thing that plagues the EM algorithm in general: solutions that are only locally optimal.
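For readers who want to see the EM structure concretely, here is a minimal NumPy sketch, assuming a single sequence of discrete observations; the function names are hypothetical and the data is random placeholder data. The E-step runs forward-backward to get each timestep’s state probabilities (`gamma`) and the expected number of each state-to-state transition (`xi_sum`). The M-step re-estimates the parameters from those soft counts, just like counting with labeled data but with fractional counts. EM guarantees that the log-likelihood never decreases from one iteration to the next, but only up to a local optimum.

```python
import numpy as np

def e_step(obs, start, trans, emit):
    """Scaled forward-backward: state posteriors, expected transition counts, log Pr[obs]."""
    T, k = len(obs), len(start)
    alpha, beta, c = np.zeros((T, k)), np.ones((T, k)), np.zeros(T)
    alpha[0] = start * emit[:, obs[0]]
    c[0] = alpha[0].sum()
    alpha[0] /= c[0]
    for t in range(1, T):
        alpha[t] = (alpha[t - 1] @ trans) * emit[:, obs[t]]
        c[t] = alpha[t].sum()
        alpha[t] /= c[t]
    for t in range(T - 2, -1, -1):
        beta[t] = trans @ (emit[:, obs[t + 1]] * beta[t + 1]) / c[t + 1]
    gamma = alpha * beta                      # gamma[t, i] = Pr[s_t = i | obs]
    xi_sum = np.zeros((k, k))                 # expected number of i -> j transitions
    for t in range(T - 1):
        xi_sum += (alpha[t][:, None] * trans
                   * (emit[:, obs[t + 1]] * beta[t + 1])[None, :]) / c[t + 1]
    return gamma, xi_sum, np.log(c).sum()

def m_step(obs, gamma, xi_sum, m):
    """Re-estimate parameters from soft (fractional) counts."""
    start = gamma[0]
    trans = xi_sum / xi_sum.sum(axis=1, keepdims=True)
    emit = np.zeros((gamma.shape[1], m))
    for t, o in enumerate(obs):
        emit[:, o] += gamma[t]                # each timestep votes for every state, weighted by gamma
    emit /= emit.sum(axis=1, keepdims=True)
    return start, trans, emit

def baum_welch(obs, start, trans, emit, n_iter=200, tol=1e-8):
    """Alternate E and M steps until the log-likelihood stops improving."""
    history = []
    for _ in range(n_iter):
        gamma, xi_sum, loglik = e_step(obs, start, trans, emit)
        history.append(loglik)
        if len(history) > 1 and history[-1] - history[-2] < tol:
            break
        start, trans, emit = m_step(obs, gamma, xi_sum, emit.shape[1])
    return start, trans, emit, history

rng = np.random.default_rng(1)
obs = rng.integers(0, 3, size=200)            # placeholder data: 200 rolls of a 3-sided die
start, trans, emit, history = baum_welch(
    obs,
    start=np.array([0.5, 0.5]),
    trans=np.array([[0.8, 0.2], [0.3, 0.7]]),
    emit=np.array([[0.5, 0.3, 0.2], [0.2, 0.3, 0.5]]),
)
```

`history` records the log-likelihood at each iteration, which is a cheap sanity check: if it ever goes down by more than round-off, something is wrong.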

The Great Boogeyman: Local Optima

I mentioned that the first step in the Baum-Welch algorithm is to guess at the parameters of the HMM. That means randomly initializing the parameters (a Dirichlet distribution is a popular choice for HMMs, but that’s outside our scope here). The problem, though, is that different random starting points can converge on different final HMMs; there is no single best answer that you reach regardless of the starting point.
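A minimal sketch of Dirichlet initialization, using NumPy’s `Generator.dirichlet`: every probability vector (the start vector, each row of the transition matrix, and each state’s observation distribution) is drawn independently. `random_hmm` and its `concentration` parameter are hypothetical names. Running Baum-Welch from several such draws and comparing the final log-likelihoods is a common way to see whether different starts land in different local optima.

```python
import numpy as np

def random_hmm(k, m, rng, concentration=1.0):
    """Random starting point for Baum-Welch: every probability vector is a Dirichlet draw.

    concentration=1.0 is uniform over all probability vectors; larger values
    pull each draw toward the uniform distribution.
    """
    conc_k = np.full(k, concentration)
    start = rng.dirichlet(conc_k)                             # Pr[first state]
    trans = rng.dirichlet(conc_k, size=k)                     # row i: Pr[next state | state i]
    emit = rng.dirichlet(np.full(m, concentration), size=k)   # row i: state i's "loaded die"
    return start, trans, emit

# Five different starting points for a 2-state HMM over 7 possible observations.
inits = [random_hmm(k=2, m=7, rng=np.random.default_rng(seed)) for seed in range(5)]
```

The k=2, m=7 shape mirrors a 2-state model over 7 categories, like the app example that follows.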

The various HMMs that you might converge to are called “local optima”. Take a local optimum X. You can think of it as a point in a high-dimensional vector space, where the dimensions are the different parameters of the HMM. X is better than any other HMM within a certain distance of it, and if your initial guess is close enough to X, then the Baum-Welch algorithm will converge on X. But there can easily be another X’ that is better than X but far away from it.

I ran into an interesting example of this recently in my own work. I was modeling which of 7 desktop applications a person was using at each timestep, and tried to fit a 2-state HMM to it. My hope was to find a “natural” behavior pattern where they switched rapidly between several apps. Instead, the result depended on where training started. I fitted the HMM twice, and each time made a pie chart of how much time the user spent in each app within each state. Here are the results from the two runs:

Different initial conditions in the Baum-Welch algorithm, even training on the same data, can yield wildly different final outputs due to the presence of many local optima.

You can see that one time the purple app got combined with the brown and pink apps in a single state, while the other time it got merged in with the orange/blue/red/green apps. Additional runs yielded other outcomes. I consistently found that a given app was used almost exclusively within one state. But there were 7 apps to divvy up between two states, and each way of divvying them up was its own local optimum.

There is no general way to avoid this issue: fitting an HMM is a “non-convex” optimization problem, and such problems usually suffer from local optima. There are, however, several ways to mitigate it:

  • Simplify your HMM. The number of local optima can grow exponentially with the number of parameters in your HMM, so reducing the number of parameters reduces the number of local optima there are to fall into. This could mean using a 2-state rather than a 3-state HMM or, if your observations are multinomial, reducing the number of possible observations.
  • Use business intuition to build an initial HMM that is similar to what you expect the final result to be. Different local optima often differ qualitatively in large ways, as in the example I encountered, so at a gross level the HMM you ultimately converge on is likely to resemble the one you started with. Do make sure your guess is partly random, though: there could be several local optima that are consistent with your business intuition, and if so you want to know about it.
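Both suggestions can be sketched briefly. Counting free parameters (a length-n probability vector has n − 1 free entries, since it must sum to 1) shows how quickly the model grows, and mixing a hand-built guess with Dirichlet noise keeps an initial HMM close to your intuition while leaving it partly random. The function names, the “sticky” transition guess and the 0.8 mixing weight are hypothetical choices for illustration, not a standard recipe.

```python
import numpy as np

def n_free_params(k, m):
    """Free parameters of a k-state HMM with m possible (multinomial) observations."""
    return (k - 1) + k * (k - 1) + k * (m - 1)   # start vector + transition rows + emission rows

def jittered_guess(guess, rng, weight=0.8, concentration=1.0):
    """Mix a hand-built probability table with Dirichlet noise, row by row.

    A convex combination of probability vectors is still a probability vector,
    so every row of the result still sums to 1.
    """
    guess = np.asarray(guess, dtype=float)
    rows = np.atleast_2d(guess)
    noise = rng.dirichlet(np.full(rows.shape[1], concentration), size=rows.shape[0])
    return (weight * rows + (1 - weight) * noise).reshape(guess.shape)

rng = np.random.default_rng(0)
sticky_guess = [[0.95, 0.05],   # hypothetical intuition: states tend to persist
                [0.05, 0.95]]
trans0 = jittered_guess(sticky_guess, rng)
```

With 7 possible observations, going from 2 to 3 states raises the count from 15 to 26 free parameters. Drawing several jittered guesses (different seeds) and training from each is one way to check whether more than one local optimum is consistent with your intuition.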

Note that local optima are a different problem from overfitting (which, of course, also applies!). In the example I ran into, each distribution of apps among states would be its own local optimum regardless of how much training data I had.

Parting Words

I’m not gonna lie: using a trained HMM is more straightforward than fitting it in the first place. On the other hand (at least if you’re like me), that’s part of the fun! We run into this a lot at Zeitworks when applying HMMs to client data. We need a model that is simple enough to mitigate local optima but complicated enough to describe the process we are studying. Out-of-the-box tools often don’t handle edge cases like partially known ground truth, and we often find ourselves tinkering with core algorithms.

Still, this complexity pales next to other techniques you might consider, like recurrent neural networks. Plus, HMMs are easier to make business sense of, more debuggable, and often perform better when training data is limited. They don’t solve all problems (nothing does), but they are an irreplaceable tool in the arsenal.


Field Cady has written two books on data science, and consulted for companies of all sizes. His background is in mathematical physics, especially stochastics.