What is the Expectation-Maximisation algorithm?


Many systems can be described in a deterministic fashion. That is to say, given some collection of information about a system, you may reliably and consistently predict some other information about the system.

For other systems, we either lack the information we need, or we lack knowledge of an accurate deterministic model that can describe the system.

Examples of these two types of systems can be seen across almost all disciplines. In physics, a good example of the first would be Newtonian (classical) mechanics. A good example of the second type would be statistical mechanics. In statistical mechanics, we surrender ourselves to the fact that we are unable to obtain the information necessary to fully predict the evolution of certain systems, e.g. a hot chamber filled with multitudes of particles, each with its own position and velocity.

To predict how such a system evolves, we would have to go around collecting information about each particle’s position and velocity, which is obviously untenable, unless of course you’re Maxwell’s demon. In 1867, physicist James Clerk Maxwell suggested a thought experiment in which a tiny enough demon in charge of a tiny enough door between two systems would be capable of decreasing entropy and thus violating the second law of thermodynamics.

In these situations, instead of describing the state of the system from a deterministic perspective, we can describe it from a probabilistic perspective.

How exactly you describe a system probabilistically, however, can vary quite dramatically.

Let’s say your system consists of a single quantity which you can make observations of. In this toy example, we’ll consider this quantity to be the weather on any particular day. Even though the true distribution of this quantity (the weather) may be quite complex, one way to get close to it is to first assume that it corresponds to a well-known distribution that we’re comfortable with. We can then fine-tune this known probability distribution until it does a good job of describing the quantities we’ve observed.

Take for example the normal distribution. This distribution is popular because it has all sorts of nice properties, and it also has two convenient knobs/parameters (mean and variance) that we can fine-tune until it starts predicting data well.

Fine-tuning your models this way is often called maximum likelihood estimation. We are fine-tuning the parameters of our model (probability distribution) until we’ve maximised the likelihood of seeing the data, i.e. the quantity

$$p(\text{data}|\text{model parameters})$$
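To make the knob-turning concrete, here is a minimal sketch of maximum likelihood for a normal distribution. The data are synthetic, and treating the weather as a temperature in degrees is an assumption made purely for illustration. For the normal distribution the best settings of the two knobs have a closed form (the sample mean and the sample variance), so the sketch simply checks that moving either knob away from those values lowers the log-likelihood.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical observations: daily temperatures drawn from a normal distribution.
data = rng.normal(loc=15.0, scale=4.0, size=1000)

def gaussian_log_likelihood(x, mean, var):
    """log p(data | mean, var) under a normal distribution."""
    return np.sum(-0.5 * np.log(2 * np.pi * var) - (x - mean) ** 2 / (2 * var))

# Closed-form maximum likelihood estimates for the two knobs.
mle_mean = data.mean()
mle_var = data.var()  # divides by N, which is the maximum likelihood estimate

best = gaussian_log_likelihood(data, mle_mean, mle_var)
# Nudging either knob away from its estimate lowers the log-likelihood.
nudged_mean = gaussian_log_likelihood(data, mle_mean + 0.5, mle_var)
nudged_var = gaussian_log_likelihood(data, mle_mean, mle_var * 1.2)

print(best, nudged_mean, nudged_var)
```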

Now let’s say that, as a domain expert, you want to incorporate some of your knowledge of the weather into your model. You know that the seasons affect the weather: perhaps during summer and spring the weather is more likely to be one way, and during autumn and winter it’s more likely to be another way. We’ll call this season variable S.

We’ve now updated our model of the world to be the following: [S –> W], parameterised by the probability distributions p(W|S) and p(S).
Our very simple model of the world.

Our likelihood is then simply

$$ p(w,s|\theta) = p(w|s,\theta_{ws})p(s|\theta_s) $$

Now we can take logs:

$$ \begin{align} \log(p(w,s|\theta)) &= \log(p(w|s,\theta_{ws})p(s|\theta_s)) \\ &= \log(p(w|s,\theta_{ws})) + \log(p(s|\theta_s)) \end{align} $$

This conveniently lets us optimise the likelihood with respect to \(\theta\) by optimising with respect to \(\theta_{ws}\) and \(\theta_s\) separately. We say that the parameters have decoupled.
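As a small illustration of what decoupling buys us, here is a sketch for the case where both the season and the weather are discrete and fully observed. The observations are made up, and the two-valued weather is an assumption for the example. Each set of parameters is estimated by counting, without ever looking at the other.

```python
from collections import Counter

# Hypothetical fully observed data: (season, weather) pairs.
observations = [
    ("summer", "sunny"), ("summer", "sunny"), ("summer", "rainy"),
    ("winter", "rainy"), ("winter", "rainy"), ("winter", "sunny"),
    ("summer", "sunny"), ("winter", "rainy"),
]

season_counts = Counter(s for s, _ in observations)
pair_counts = Counter(observations)

# theta_s: maximum likelihood estimate of p(s), from the season counts alone.
theta_s = {s: n / len(observations) for s, n in season_counts.items()}

# theta_ws: maximum likelihood estimate of p(w | s), from counts within each season.
theta_ws = {(w, s): n / season_counts[s] for (s, w), n in pair_counts.items()}

print(theta_s)
print(theta_ws)
```

Because the log-likelihood is a sum of a term in \(\theta_s\) and a term in \(\theta_{ws}\), each estimate is just a normalised count.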

The issue arises when we realise there may be hidden quantities that we cannot observe, which may themselves be modulating the probability distributions of the variables we can observe.

Let’s say that, being the smart domain experts we are, we know that the ocean currents (O) affect the weather, and that the ocean currents are in turn affected by the season. (I’m not a weather expert, so this may or may not be true.) Unfortunately, we’re not smart enough to measure these ocean currents, so for the most part they remain an elusive hidden variable in our model, whose effects we can only observe indirectly via the weather.
Our still very simple but slightly less simple model of the world.

Now if we wanted to go back and ignore this hidden variable, we would do a poor job of finding the correct parameters that describe the data, because the true underlying distribution may depend strongly on the value of this hidden variable (the ocean currents), and this variation would not be captured.

So naturally we think about including this hidden variable in our model, and therefore in our maximum likelihood calculation. One natural way of accounting for an unknown variable in a distribution is to marginalise over it, and to maximise the resulting quantity with respect to our parameters instead.

That is, we maximise the “marginal likelihood”: $$ \begin{align} \int_{o}p(w,s,o|\theta)\enspace do &= \int_{o}p(w|o,s,\theta_w)p(o|s,\theta_o)p(s|\theta_s)\enspace do \\ &= p(s|\theta_s)\int_{o}p(w|o,s,\theta_w)p(o|s,\theta_o)\enspace do \end{align} $$

This quantity is called the marginal likelihood because we have marginalised out our hidden variables (in this case, the ocean currents).

But now we’ve ended up somewhere a little undesirable. Where we could once optimise our individual \(\theta\)s separately, we must now optimise our \(\theta\)s together. That is, summing across the joint distribution for all the values of the hidden variable couples together any parameters that specify a distribution containing that hidden variable.

This coupling means maximising the likelihood is more difficult than it would be if we could observe all our variables. To be explicit, the difficulty is that the log of a sum (or integral) over a product of probability distributions does not split into separate terms, so we cannot maximise the likelihood with respect to each parameter separately.

What if there was a way to decouple these parameters again, and therefore allow us to calculate our log likelihoods more easily? Well, being the smart folks we are, we’re aware of something called Jensen’s inequality. Jensen’s inequality roughly states that the expected value of a concave function is less than or equal to the concave function applied to the expected value: for a concave function \(f\), \(\mathbb{E}[f(X)] \le f(\mathbb{E}[X])\).
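A quick numerical sanity check of the inequality for the concave function we care about, the log. The uniform samples below are an arbitrary choice for illustration; any positive random variable would do.

```python
import numpy as np

rng = np.random.default_rng(1)
x = rng.uniform(0.1, 5.0, size=10_000)  # samples of a positive random variable

expected_of_log = np.mean(np.log(x))  # E[log X]
log_of_expected = np.log(np.mean(x))  # log E[X]

# Jensen's inequality for the concave log: E[log X] <= log E[X].
print(expected_of_log, log_of_expected)
```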

Sounds pretty abstract, so why does this help us?

Recall for a second the sum over a product of probabilities that we couldn’t conveniently separate. If there was a way to get a log inside that integral, we’d conveniently decouple our parameters once again. And hey, an expectation is kind of an integral, and a log is a concave function, so maybe by rewriting our marginal likelihood as an expectation, we can use Jensen’s inequality and push that log inside our pesky integral.

So that’s exactly what we do:

$$ \begin{align} \overbrace{\log p(w,s|\theta)}^{\text{marginal likelihood}} &= \log\left[ \int_{o}p(w,s,o|\theta)\enspace do \right] \\ &= \log\left[ \int_{o}p(w|o,s,\theta_w)p(o|s,\theta_o)p(s|\theta_s) \enspace do \right] \\ &= \log\left[ \int_{o}\frac{q(o)}{q(o)}p(w|o,s,\theta_w)p(o|s,\theta_o)p(s|\theta_s) \enspace do \right] \\ &= \log\left[ \int_{o}q(o)\frac{p(w|o,s,\theta_w)p(o|s,\theta_o)p(s|\theta_s)}{q(o)} \enspace do \right] \\ &= \log\left[ \mathbb{E}_q \left[ \frac{p(w|o,s,\theta_w)p(o|s,\theta_o)p(s|\theta_s)}{q(o)} \right ] \right] \\ &\ge \mathbb{E}_q \left[\log \frac{p(w|o,s,\theta_w)p(o|s,\theta_o)p(s|\theta_s)}{q(o)} \right] \\ &= \mathbb{E}_q \left[\log \frac{p(w,o,s|\theta)}{q(o)}\right] \end{align} $$

And hey presto, we’ve got our logs inside our integral.

Of course, it’s not all sunshine and rainbows. In interpreting our integral as an expectation, we’ve sneaked in a ‘variational’ distribution \(q(o)\), and we’ve replaced the quantity we care about with a lower bound on it. How close is this lower bound really, and does optimising it in turn optimise the all-elusive marginal likelihood?

We can answer these questions by rewriting the above lower bound in the following way:

$$ \begin{align} \text{Lower Bound} &= \mathbb{E}_{q} \left [ \log \left [ p(w,s|\theta) \frac{p(o|w,s,\theta)}{q(o)} \right ] \right] \\ &= \log p(w,s|\theta) + \mathbb{E}_{q} \left [ \log \frac{p(o|w,s,\theta)}{q(o)} \right] \\ &= \underbrace{\log p(w,s|\theta)}_{\text{log marginal likelihood}} - \underbrace{D_{KL}(q(o) \,\|\, p(o|w,s,\theta))}_{\text{KL divergence}} \end{align} $$

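Here the KL term is the Kullback-Leibler divergence between our q distribution and the posterior distribution of the hidden states. The Kullback-Leibler divergence can be thought of as a measure of how much one probability distribution differs from another. It is never negative, and it is zero only when the two distributions are identical, so the gap between the log marginal likelihood and our lower bound is exactly this divergence.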

Note that the optimal value of q, i.e. the value that maximises our bound for fixed p and simultaneously ensures that the bound is ‘tight’, occurs if we set \( q(o) = p(o|w,s,\theta)\). In this situation, the KL divergence between our variational distribution and the posterior becomes 0, because they are one and the same.
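The relationship between the bound, the log marginal likelihood and the KL divergence can be checked numerically. The sketch below assumes a single observed \((w,s)\) pair and a hidden ocean-current variable with three discrete states; all the probabilities are made-up numbers. It verifies that the gap between the log marginal likelihood and the lower bound equals \(D_{KL}(q(o) \,\|\, p(o|w,s,\theta))\), and that the gap closes when q is the posterior.

```python
import numpy as np

# Hypothetical numbers for one observed (w, s) pair and three ocean-current states.
p_s = 0.5                                 # p(s | theta_s)
p_o_given_s = np.array([0.2, 0.5, 0.3])   # p(o | s, theta_o)
p_w_given_os = np.array([0.9, 0.4, 0.1])  # p(w | o, s, theta_w)

joint = p_w_given_os * p_o_given_s * p_s  # p(w, s, o | theta) for each o
marginal = joint.sum()                    # p(w, s | theta)
posterior = joint / marginal              # p(o | w, s, theta)

def lower_bound(q):
    """E_q[log p(w, s, o | theta) - log q(o)]."""
    return np.sum(q * (np.log(joint) - np.log(q)))

def kl(q, p):
    """D_KL(q || p) for discrete distributions with full support."""
    return np.sum(q * (np.log(q) - np.log(p)))

q_arbitrary = np.array([1 / 3, 1 / 3, 1 / 3])
gap = np.log(marginal) - lower_bound(q_arbitrary)

print(gap, kl(q_arbitrary, posterior))           # the two numbers agree
print(lower_bound(posterior), np.log(marginal))  # the bound is tight at the posterior
```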

In the classical EM scheme, this is generally how we proceed. We set \(q(o)\) equal to \(p(o|w,s,\theta)\) - this is the E step. Certain types of probabilistic models make this distribution easy to calculate; see, for example, the Baum-Welch algorithm for HMMs.

Once we’ve set q to its optimal value, we keep it fixed, and proceed to maximise our lower bound with respect to \(\theta\) - the eponymous M step. We then fix \(\theta\), reset q to its new optimal value, and round and round we go until we converge onto a local maximum.

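Because the second term (our KL divergence) is 0 after our E step, the bound touches the log likelihood at the current \(\theta\). Any increase in the bound that we achieve in the M step is therefore guaranteed to increase the log likelihood by at least as much, so the log likelihood never decreases from one iteration to the next.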
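Putting the two steps together, here is a minimal sketch of the EM loop. To keep it short it makes several simplifying assumptions that are not part of the model above: the season is dropped, the hidden ocean current takes one of two states, and the weather is a temperature that is normally distributed given the ocean state (in other words, a two-component Gaussian mixture). The data and initial parameter values are made up.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical data: temperatures w generated under two hidden ocean-current states o.
w = np.concatenate([rng.normal(10.0, 2.0, 300), rng.normal(20.0, 3.0, 200)])

def joint_density(w, weights, means, variances):
    """p(w, o | theta) = p(o) p(w | o): one row per data point, one column per state o."""
    diff = w[:, None] - means[None, :]
    normal = np.exp(-0.5 * diff ** 2 / variances) / np.sqrt(2 * np.pi * variances)
    return weights * normal

# Initial guess for theta: p(o), plus the mean and variance of p(w | o).
weights = np.array([0.5, 0.5])
means = np.array([5.0, 25.0])
variances = np.array([10.0, 10.0])

log_likelihoods = []
for _ in range(50):
    joint = joint_density(w, weights, means, variances)
    # Log marginal likelihood of the data under the current theta.
    log_likelihoods.append(np.sum(np.log(joint.sum(axis=1))))

    # E step: set q(o) to the posterior p(o | w, theta) for each data point.
    q = joint / joint.sum(axis=1, keepdims=True)

    # M step: maximise the lower bound over theta with q held fixed.
    totals = q.sum(axis=0)
    weights = totals / len(w)
    means = (q * w[:, None]).sum(axis=0) / totals
    variances = (q * (w[:, None] - means) ** 2).sum(axis=0) / totals

print(means, variances, weights)
print(log_likelihoods[0], log_likelihoods[-1])
```

For this particular model both steps have closed forms, and the recorded log likelihood never decreases between iterations, which is the guarantee described above.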

Finally, note that there are times where you cannot simply set \(q(o)\) equal to the posterior \(p(o|w,s,\theta)\), because either it’s difficult to compute, or the energy term \(\mathbb{E}_q[\log p(w,o,s|\theta)]\) under this distribution is difficult to compute. In these situations, we sometimes restrict q to a family of distributions which does not contain the true posterior, but which is more tractable. Also note that in this scenario our likelihood is not guaranteed to increase, because the KL divergence is not 0 and therefore our bound is not tight.