Variational Approximations in Graphical Models

Speaker: Prof. Yee Whye Teh (notes kindly prepared by the speaker and edited by TL).

1. Variational Inference and Learning

1.1 Role of the gradient of the log partition function

The computations associated with inference and learning can be expressed in terms of computations associated with the gradient of the log partition function \(\nabla A\).

On the one hand, given some observations of some random variables in the graphical model, we are interested in computing the posterior distribution over the unobserved variables (in machine learning this is called inference, although the word means something else in statistics).
The posterior distribution can be expressed as an exponential family distribution with known natural parameter \(\theta^{\circ}\) associated with a factor graph over the unobserved variables.
Specifically we are often interested in posterior expectations, e.g. marginal probabilities of unobserved variables, which are simply entries of the mean parameter \(\mu^{\sharp} = {\nabla A}(\theta^{\circ})\).

On the other hand, given a dataset of observations, we are also interested in learning via maximum likelihood estimation in an exponential family. Suppose that the dataset is fully observed, with empirical average sufficient statistics \({\hat\mu} = \frac{1}{N} \sum_{k=1}^N \phi(x^{(k)})\). Then the average log likelihood is

\[ \begin{align} \frac{1}{N} \sum_{k=1}^N \left\langle \theta,\phi(x^{(k)})\right\rangle - A(\theta) \,\,=\,\, \left\langle \theta,\hat\mu\right\rangle - A(\theta). \end{align} \]

Setting the gradient with respect to \(\theta\) to zero gives the condition satisfied by the MLE \(\theta^{\sharp}\):

\[ \begin{align} {\hat\mu} - \nabla A(\theta^{\sharp}) = 0 \end{align} \]

So any element of the preimage \((\nabla A)^{-1}(\hat\mu)\) is a maximum likelihood estimator.

Summarizing we have that

  • posterior computation amounts to computing \(\nabla A\),

  • ML estimation amounts to computing \((\nabla A)^{-1}\).
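To make the two roles of \(\nabla A\) concrete, here is a minimal sketch for the simplest possible case, a single Bernoulli variable with \(\phi(x)=x\), for which \(A(\theta)=\log(1+e^{\theta})\). The choice of family, the function names and the toy data are assumptions made for illustration; they are not part of the lecture.

```python
import math

def log_partition(theta):
    """A(theta) = log(1 + exp(theta)) for a Bernoulli with phi(x) = x."""
    return math.log1p(math.exp(theta))

def grad_A(theta):
    """Forward map: natural parameter -> mean parameter (the sigmoid)."""
    return 1.0 / (1.0 + math.exp(-theta))

def grad_A_inverse(mu):
    """Inverse map: mean parameter -> natural parameter (the logit)."""
    return math.log(mu / (1.0 - mu))

# Inference: a known natural parameter gives the marginal P(X = 1).
theta_known = 0.8
mu = grad_A(theta_known)

# Learning: the empirical mean of the data gives the ML estimate.
data = [1, 0, 1, 1, 0, 1, 1, 1]  # toy observations (made up)
mu_hat = sum(data) / len(data)
theta_mle = grad_A_inverse(mu_hat)

# Stationarity condition of the average log likelihood.
residual = mu_hat - grad_A(theta_mle)
```

In this one-dimensional case both maps are available in closed form; the difficulty discussed in these notes arises in graphical models, where \(A\) involves a sum over exponentially many configurations.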

1.2 Convex Duality

In fact, the process of maximizing the log likelihood is simply the operation of taking the convex (Fenchel-Legendre) conjugate of the log partition function \(A\):

\[ \begin{align} A^\star(\mu) \,\,:=\,\, \sup_{\theta\in\Theta} \left\langle \theta,\mu\right\rangle - A(\theta) \end{align} \]

For those familiar with convex analysis, this is simply because for a convex function \(A\), \(\nabla A^{\star} = (\nabla A)^{-1}\) when considering the generalized gradient.
Remark: If \(\mu\not\in \mathcal M\), then it can be shown that \(A^\star(\mu)=\infty\).

To see this, note that for \(\mu\in \text{ri}\,\mathcal M\) (the relative interior of \(\mathcal M\)) the supremum is attained at values \(\theta^{\sharp}\in (\nabla A)^{-1}(\mu)\). At these values, we have

\[ \begin{align} A^\star(\mu) = \left\langle \theta^{\sharp},\mu\right\rangle - A(\theta^{\sharp}) \end{align} \]

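Taking the gradient in \(\mu\) (the dependence of \(\theta^{\sharp}\) on \(\mu\) contributes nothing to first order, since \(\theta^{\sharp}\) is a stationary point of the objective) thus yields \(\nabla A^{\star}(\mu)=\theta^{\sharp}=(\nabla A)^{-1}(\mu)\).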

Now observe that for a distribution in the exponential family, we have \(\log p_{\theta}(X=x)=\left\langle \theta,\phi(x)\right\rangle-A(\theta)\) and we defined the mean parameter as \(\mu=\mathbb E_{\theta}[\phi(x)]\) so that

\[ \begin{eqnarray} A^{\star}(\mu) \,\,=\,\, \left\langle \theta^{\sharp},\mu\right\rangle - A(\theta^{\sharp}) \,\,=\,\, -\mathbb E_{\theta^{\sharp}}[-\log p_{\theta^{\sharp}}(x)] \,\,=\,\, - H(p_{\theta^{\sharp}}) \end{eqnarray} \]

where \(H\) denotes the (Boltzmann-Shannon) entropy of a distribution.
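The identity \(A^\star(\mu)=-H(p_{\theta^{\sharp}})\) can be checked numerically in a small case. The sketch below assumes a Bernoulli family with \(A(\theta)=\log(1+e^{\theta})\) and approximates the supremum by a ternary search, which is valid because the objective is concave in \(\theta\). The search interval, iteration count and the value of \(\mu\) are arbitrary illustrative choices.

```python
import math

def log_partition(theta):
    """A(theta) = log(1 + exp(theta)) for a Bernoulli with phi(x) = x."""
    return math.log1p(math.exp(theta))

def conjugate_numeric(mu, lo=-30.0, hi=30.0, iters=200):
    """Approximate sup_theta (theta * mu - A(theta)) by ternary search."""
    def objective(t):
        return t * mu - log_partition(t)
    for _ in range(iters):
        m1 = lo + (hi - lo) / 3.0
        m2 = hi - (hi - lo) / 3.0
        if objective(m1) < objective(m2):
            lo = m1
        else:
            hi = m2
    return objective(0.5 * (lo + hi))

def negative_entropy(mu):
    """-H of a Bernoulli with mean mu."""
    return mu * math.log(mu) + (1.0 - mu) * math.log(1.0 - mu)

mu = 0.3
gap = abs(conjugate_numeric(mu) - negative_entropy(mu))
```

Up to numerical error, `gap` should vanish for any `mu` strictly between 0 and 1.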

Since \(A\) is convex and lower semi-continuous, \((A^\star)^\star=A\), so that

\[ \begin{align} A(\theta) \,\,=\,\, \sup_{\mu\in \mathcal M} \left\langle \theta,\mu\right\rangle - A^\star(\mu). \label{eq:var} \end{align} \]

The supremum is achieved at \(\mu^{\sharp} = \mathbb E_\theta[\phi(x)]\). This is most easily seen for minimal exponential families. Indeed, in that case \(\nabla A\) is injective (by strict convexity) so that \((\nabla A)^{-1}=\nabla A^{\star}\) is single-valued. We can thus set the gradient equal to zero and have:

\[ \begin{eqnarray} \theta - \nabla A^{\star}(\mu^{\sharp}) &=& 0 \end{eqnarray} \]

so that \(\mu^{\sharp}=\nabla A(\theta)=\mathbb E_{\theta}[\phi(x)]\) (the last equality coming from the properties of \(A\) introduced in the first part).
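As a worked illustration (an example added here, assuming a single Bernoulli variable with \(\phi(x)=x\), so that \(\mathcal M=[0,1]\) and \(A^\star(\mu)=\mu\log\mu+(1-\mu)\log(1-\mu)\)), the variational problem reads

\[ \begin{align} A(\theta) \,\,=\,\, \sup_{\mu\in[0,1]} \,\theta\mu - \mu\log\mu - (1-\mu)\log(1-\mu). \end{align} \]

Setting the derivative in \(\mu\) to zero gives \(\theta - \log\frac{\mu^{\sharp}}{1-\mu^{\sharp}} = 0\), that is \(\mu^{\sharp} = \frac{e^{\theta}}{1+e^{\theta}}\). Substituting \(\log\mu^{\sharp} = \theta - \log(1+e^{\theta})\) and \(\log(1-\mu^{\sharp}) = -\log(1+e^{\theta})\) back into the objective gives

\[ \begin{align} \theta\mu^{\sharp} - \mu^{\sharp}\left(\theta - \log(1+e^{\theta})\right) + (1-\mu^{\sharp})\log(1+e^{\theta}) \,\,=\,\, \log(1+e^{\theta}), \end{align} \]

which is indeed the log partition function of the Bernoulli family, with maximizer \(\mu^{\sharp}=\mathbb E_{\theta}[x]\).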

Gathering results, we have that

  • for a known \(\theta\), we can obtain the corresponding mean parameter \(\mu(\theta)=\nabla A(\theta)=\mathbb E_{\theta}[\phi(x)]\) (computation of posterior expectations or posterior marginal probabilities),

  • for an estimated \(\mu\), we can obtain the corresponding parameter \(\theta(\mu)=\nabla A^{\star}(\mu)\) (maximum likelihood estimation).

Both of these problems can, as we showed, be cast as optimization problems, so that we are effectively linking inference and optimization.

2. Mean field and Bethe approximations

The equation \(\eqref{eq:var}\) is the core of variational inference, and a number of variational approaches can be related to it.

2.1 Mean Field Approximation

The framework described in the previous section is a new perspective based on convex analysis by Martin Wainwright, which culminated in an in-depth survey paper (Wainwright and Jordan, 2008). The classical variational principle has a different starting point.

Suppose we have a distribution \(p_{\theta}\) in the exponential family (say). For example this can be the posterior distribution in a graphical model given observations. Consider now the optimization problem of finding the distribution \(q\) over the set of all possible distributions that minimizes the Kullback-Leibler (KL) divergence \(\text{KL}(q\|p_{\theta})\). Recall that it is defined as

\[ \begin{eqnarray} \text{KL}(q\|p) &=& \mathbb E_{q}[\log q]-\mathbb E_{q}[\log p]. \end{eqnarray} \]

Since we haven't restricted the set of distributions for \(q\), the solution is obviously \(q=p_\theta\) itself. Writing this, we have:

\[ \begin{eqnarray} 0\,\,=\,\, \inf_q \,\,\text{KL}(q\|p_{\theta}) &=& \inf_q\, \mathbb E_q[-\log p_\theta(x)] - \mathbb E_q[-\log q(x)] \nonumber\\ &=& \inf_q\, \mathbb E_q[ - \left\langle \theta,\phi(x)\right\rangle + A(\theta)] - H(q), \label{eq:classical} \end{eqnarray} \]

where, in the second line, we have used the fact that \(p_{\theta}\) is in the exponential family and the definition of the entropy.

The \(A(\theta)\) term can be ignored as it does not vary with \(q\). The first term \(\mathbb E_q[-\left\langle \theta,\phi(x)\right\rangle]\) is an expected energy and \(H(q)\) is the entropy. The interpretation in statistical physics is that we have a system at a fixed temperature. The system prefers low energy states. However because of thermal fluctuations it cannot settle into the minimum energy state which typically has low entropy. As a result at equilibrium it settles into a distribution over states that compromises between the two.

The result of the optimization is not very interesting as we haven't learnt anything. However it shows that the distribution we are interested in can be posed as the result of an optimization problem.
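The decomposition of the KL divergence into expected energy, entropy and \(A(\theta)\) can be verified by brute force on a tiny model. The sketch below assumes two binary variables with a single pairwise term; the parameter values and all function names are made up for illustration.

```python
import itertools
import math

# Toy model (assumed): log p(x) = th1*x1 + th2*x2 + th12*x1*x2 - A(theta)
th1, th2, th12 = 0.5, -0.3, 1.2
states = list(itertools.product([0, 1], repeat=2))

def score(x):
    """<theta, phi(x)> for this toy model."""
    return th1 * x[0] + th2 * x[1] + th12 * x[0] * x[1]

# Exact log partition function and distribution by enumeration.
A = math.log(sum(math.exp(score(x)) for x in states))
p = {x: math.exp(score(x) - A) for x in states}

def free_energy(q):
    """E_q[-<theta, phi(x)>] - H(q), the q-dependent part of KL(q || p)."""
    energy = -sum(q[x] * score(x) for x in states)
    entropy = -sum(q[x] * math.log(q[x]) for x in states if q[x] > 0)
    return energy - entropy

def kl(q):
    """KL(q || p_theta) written as free energy plus A(theta)."""
    return free_energy(q) + A

uniform = {x: 0.25 for x in states}
kl_uniform = kl(uniform)  # strictly positive, since uniform != p
kl_exact = kl(p)          # zero up to rounding
```

Equivalently, the free energy is bounded below by \(-A(\theta)\), with equality exactly at \(q=p_\theta\).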

The starting point of the classical mean field variational principle is that if the optimization problem \(\eqref{eq:classical}\) cannot be solved, then we can solve a simpler problem by using a simpler form for \(q\), for example a factorized distribution:

\[ \begin{align} q(x) = \prod_i q_i(x_i; \gamma_i) \end{align} \]

where \(\gamma_i\) are variational parameters to be optimized. Structured variational approximations are possible, e.g. with factorization across groups of variables. Graphically, this can be thought of as simplifying the model by removing edges from a complete graph.
This mean field approach has also been extended to Bayesian learning, by treating the parameters \(\theta\) as random variables, like any other node in the graphical model (Ghahramani and Beal (2001), Beal (2003)).

In the past, optimization of the variational parameters has been done with coordinate-wise closed-form analytic updates. This is often achievable in exponential family models, but has limited the approach to exponential family models, often with conjugate priors. In recent years there has been a spate of approaches to optimizing variational parameters using stochastic approximation, coupled with control variates and other techniques for reducing the variance of the stochastic gradients (Salimans and Knowles (2012), Hoffman, Blei et al. (2013), Mnih and Gregor (2014), Kingma and Welling (2014), Rezende, Mohamed and Wierstra (2014)).
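As an illustration of coordinate-wise closed-form updates, here is a minimal sketch of naive mean field for a small binary pairwise model. It assumes the minimal parameterization \(\log p(x)=\sum_i\theta_i x_i+\sum_{i<j}W_{ij}x_ix_j-A\) with \(x_i\in\{0,1\}\), for which the update for \(q_i(x_i=1)=m_i\) is a sigmoid of the local field. The model, its parameter values and the function names are assumptions for illustration only.

```python
import itertools
import math

theta = [0.2, -0.4, 0.1]                           # made-up unary parameters
W = {(0, 1): 0.8, (1, 2): -0.5, (0, 2): 0.3}       # made-up couplings

def coupling(i, j):
    return W.get((min(i, j), max(i, j)), 0.0)

def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))

def binary_entropy(m):
    return -(m * math.log(m) + (1 - m) * math.log(1 - m))

def lower_bound(m):
    """<theta, mu(q)> + H(q) for the factorized q with means m."""
    lin = sum(theta[i] * m[i] for i in range(3))
    quad = sum(w * m[i] * m[j] for (i, j), w in W.items())
    return lin + quad + sum(binary_entropy(mi) for mi in m)

def score(x):
    return (sum(theta[i] * x[i] for i in range(3))
            + sum(w * x[i] * x[j] for (i, j), w in W.items()))

# Exact log partition function by brute-force enumeration.
A_exact = math.log(sum(math.exp(score(x))
                       for x in itertools.product([0, 1], repeat=3)))

# Coordinate ascent: each update maximizes the bound in m[i] alone.
m = [0.5, 0.5, 0.5]
bounds = [lower_bound(m)]
for sweep in range(50):
    for i in range(3):
        field = theta[i] + sum(coupling(i, j) * m[j] for j in range(3) if j != i)
        m[i] = sigmoid(field)
    bounds.append(lower_bound(m))
```

Each sweep cannot decrease the bound, and the bound never exceeds `A_exact`, since the gap between them is the KL divergence from the factorized \(q\) to \(p_\theta\).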

The classical variational principle \(\eqref{eq:classical}\) is related to the convex framework by letting \(\mu=\mathbb E_q[\phi(x)]\) be the expected sufficient statistics under \(q\), in which case it becomes

\[ \begin{align} \inf_{\mu\in\mathcal M} -\left\langle \theta,\mu\right\rangle + A^\star(\mu) + A(\theta) \end{align} \]

Noting that the minimum when \(q=p_\theta\) is 0, the above is equivalent to \(\eqref{eq:var}\). The difference is that the optimization, rather than being over the set of all distributions (a very high or infinite dimensional space), is now over a lower, finite dimensional space. This is the good news. The bad news is that the space of feasible \(\mu\), i.e. the polytope \(\mathcal M\), is complicated to work with, as it can require exponentially many constraints to specify.

Another difficulty is that the function \(A^\star(\mu)\) is typically not easy to compute: we do know that it is the negative entropy \(-H(p_{\nabla A^\star(\mu)})\), but \(\nabla A^\star(\mu)\) is just as difficult to solve for as \(\eqref{eq:var}\), as is the computation of the entropy itself, which defeats the purpose. Various variational approximations involve approximating one or both of the polytope \(\mathcal M\) and \(A^\star\).

For example, the mean field approximation is an inner approximation to the set \(\mathcal M\), since instead of optimizing over all \(\mu\) corresponding to all distributions \(q\), we are now optimizing only over those \(\mu\) corresponding to tractable factorized distributions. It does not involve approximations to the entropy, since factorized distributions typically have easily computed entropies. However, this set of tractable mean parameters is not a convex set: for example, the average of two factorized distributions is in general not factorized. As a result, the optimization problem can have local optima, multiple restarts are sometimes necessary, and we are never guaranteed to find the global optimum.
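The non-convexity can be seen in a two-variable example. The sketch below (an illustration with arbitrarily chosen numbers) mixes two factorized distributions over \((x_1,x_2)\in\{0,1\}^2\) and checks whether the mixture equals the product of its own marginals.

```python
import itertools

def product_dist(m1, m2):
    """Factorized distribution over {0,1}^2 with P(x_i = 1) = m_i."""
    return {(a, b): (m1 if a else 1 - m1) * (m2 if b else 1 - m2)
            for a, b in itertools.product([0, 1], repeat=2)}

def is_factorized(q, tol=1e-12):
    """True if q equals the product of its own marginals."""
    m1 = q[(1, 0)] + q[(1, 1)]
    m2 = q[(0, 1)] + q[(1, 1)]
    ref = product_dist(m1, m2)
    return all(abs(q[x] - ref[x]) <= tol for x in q)

q_a = product_dist(0.9, 0.9)
q_b = product_dist(0.1, 0.1)

# Equal-weight mixture of the two factorized distributions.
mixture = {x: 0.5 * q_a[x] + 0.5 * q_b[x] for x in q_a}

factorized_flags = (is_factorized(q_a), is_factorized(q_b), is_factorized(mixture))
```

In mean parameter terms, the two factorized points are \((0.9, 0.9, 0.81)\) and \((0.1, 0.1, 0.01)\) in the coordinates \((\mathbb E[x_1], \mathbb E[x_2], \mathbb E[x_1x_2])\); their midpoint \((0.5, 0.5, 0.41)\) would need \(\mathbb E[x_1x_2]=0.25\) to be factorized.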

2.2 Bethe Approximation

Another approximation is called the Bethe approximation, and it involves approximations to both the polytope and the entropy. We will restrict our discussion to undirected graphical models with pairwise potentials here for simplicity. That is, the (non-minimal) exponential family is:

\[ \begin{align} p(X=x) &=\exp\left(\sum_{i;a} \theta_{i;a} \phi_{i;a}(x) + \sum_{ij;ab} \theta_{ij;ab}\phi_{ij;ab}(x) - A(\theta) \right) \end{align} \]

where the edges \((ij)\) can form cycles. The marginal polytope is in general not easily characterized.

Recall our tree example. The marginal polytope in that case is easily characterized by local marginalization and non-negativity constraints:

\[ \begin{align} \forall &i \in V, \,\forall a\in\mathcal X_{i} & \mu_{i;a} &\ge 0 \nonumber \\ \forall& (ij)\in E,\,\forall (a,b)\in\mathcal X_{i}\times \mathcal X_{j} & \mu_{ij;ab} &\ge 0 \nonumber \\ \forall& i\in V & \sum_a \mu_{i;a} &= 1 \nonumber \\ \forall& (ij)\in E,\,\forall a\in\mathcal X_{i} & \sum_b \mu_{ij;ab} &= \mu_{i;a} \nonumber \\ \forall& (ij)\in E,\,\forall b\in\mathcal X_{j} & \sum_a \mu_{ij;ab} &= \mu_{j;b} \end{align} \]

The set of vectors \(\mathcal{L}\) satisfying the local constraints includes all vectors in the marginal polytope, but on graphs with cycles it also includes other vectors that do not fall into the marginal polytope. In other words, these extra vectors are not the marginals of any distribution. They are sometimes referred to as pseudomarginals. As an example, the pseudomarginals for a graphical model with a 3-node cycle given below are not globally consistent (with some distribution):

[Figure not reproduced: pseudomarginals on a 3-node cycle.]
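A concrete instance can be checked by brute force. The sketch below uses a standard construction, which is not necessarily the one the lecture had in mind: three binary nodes on a cycle, uniform singleton pseudomarginals, and pairwise pseudomarginals asserting that the endpoints of every edge always disagree. The variable and function names are illustrative.

```python
import itertools

nodes = [0, 1, 2]
edges = [(0, 1), (1, 2), (0, 2)]

# Singleton pseudomarginals: every node is uniform.
mu_node = {i: [0.5, 0.5] for i in nodes}
# Pairwise pseudomarginals: each edge says its endpoints always disagree.
mu_edge = {e: [[0.0, 0.5], [0.5, 0.0]] for e in edges}

def locally_consistent(tol=1e-12):
    """Check non-negativity, normalization and marginalization constraints."""
    for i in nodes:
        if min(mu_node[i]) < 0 or abs(sum(mu_node[i]) - 1.0) > tol:
            return False
    for (i, j) in edges:
        table = mu_edge[(i, j)]
        if min(min(row) for row in table) < 0:
            return False
        for a in (0, 1):  # summing out x_j must give the marginal of x_i
            if abs(sum(table[a]) - mu_node[i][a]) > tol:
                return False
        for b in (0, 1):  # summing out x_i must give the marginal of x_j
            if abs(table[0][b] + table[1][b] - mu_node[j][b]) > tol:
                return False
    return True

# A joint with these edge marginals may only put mass on configurations
# that are allowed (non-zero) on all three edges.
supported = [x for x in itertools.product([0, 1], repeat=3)
             if all(mu_edge[(i, j)][x[i]][x[j]] > 0 for (i, j) in edges)]
```

The local constraints hold, yet `supported` is empty: three binary variables cannot all disagree pairwise, so no joint distribution has these marginals.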

 

Returning to trees, the negative entropy \(A^\star(\mu)\) can also be easily calculated. Using

\[ \begin{align} p(X=x) \,\,=\,\, \prod_i p(X_i=x_i) \prod_{(ij)\in E} \frac{p(X_i=x_i, X_j=x_j)}{p(X_i=x_i)p(X_j=x_j)} \end{align} \]

we get:

\[ \begin{align} A^\star(\mu) \,\,=\,\, \sum_{i,a} \mu_{i;a}\log \mu_{i;a} + \sum_{ij,ab} \mu_{ij;ab}\log\frac{\mu_{ij;ab}}{\mu_{i;a}\mu_{j;b}} \label{eq:betheentropy} \end{align} \]
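The tree formula can be verified numerically on a small example. The sketch below assumes a three-node binary chain with made-up parameters, computes the exact marginals by enumeration, and compares the negative entropy of the joint with the expression in terms of node and edge marginals.

```python
import itertools
import math

# Toy chain 0 - 1 - 2 over binary variables (parameters are made up).
theta = [0.3, -0.2, 0.5]
W = {(0, 1): 1.0, (1, 2): -0.7}
states = list(itertools.product([0, 1], repeat=3))

def score(x):
    return (sum(theta[i] * x[i] for i in range(3))
            + sum(w * x[i] * x[j] for (i, j), w in W.items()))

A = math.log(sum(math.exp(score(x)) for x in states))
p = {x: math.exp(score(x) - A) for x in states}

# Exact node and edge marginals by enumeration.
mu_node = [[sum(p[x] for x in states if x[i] == a) for a in (0, 1)]
           for i in range(3)]
mu_edge = {(i, j): [[sum(p[x] for x in states if x[i] == a and x[j] == b)
                     for b in (0, 1)] for a in (0, 1)]
           for (i, j) in W}

# Negative entropy computed directly from the joint distribution.
neg_entropy_joint = sum(p[x] * math.log(p[x]) for x in states)

# Negative entropy from the tree formula (node terms plus edge terms).
node_term = sum(m * math.log(m) for marg in mu_node for m in marg)
edge_term = sum(mu_edge[(i, j)][a][b]
                * math.log(mu_edge[(i, j)][a][b] / (mu_node[i][a] * mu_node[j][b]))
                for (i, j) in W for a in (0, 1) for b in (0, 1))
neg_entropy_tree = node_term + edge_term
```

The two quantities agree up to rounding because the chain is a tree; adding the edge \((0,2)\) to close the cycle would in general break the equality.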

The Bethe approximation involves approximating the marginal polytope with the set \(\mathcal{L}\) defined by the local constraints (an outer approximation) and approximating the negative entropy \(A^\star\) with the tree expression \(\eqref{eq:betheentropy}\), even on cyclic graphs. The variational problem \(\eqref{eq:var}\) is then:

\[ \begin{align} \sup_{\mu\in\mathcal{L}}\,\, \sum_{i,a} \mu_{i;a} \theta_{i;a} + \sum_{ij,ab} \mu_{ij;ab}\theta_{ij;ab} -\sum_{i,a} \mu_{i;a} \log\mu_{i;a}-\sum_{ij,ab} \mu_{ij;ab}\log\frac{\mu_{ij;ab}}{\mu_{i;a}\mu_{j;b}} \label{eq:bethe} \end{align} \]

One can optimise the above in different ways. The standard way is to derive a fixed point equation whose fixed points are stationary points of \(\eqref{eq:bethe}\). This gives the loopy belief propagation algorithm (Frey and MacKay (1997), Weiss (1997), Murphy, Weiss and Jordan (1999), Kschischang, Frey and Loeliger (2001), Yedidia, Freeman and Weiss (2001a,2001b,2001c)).
This is not guaranteed to converge, but often converges quickly. Another is to derive algorithms that are guaranteed to locally maximize \(\eqref{eq:bethe}\) (Welling and Teh (2001), Yuille (2002)), but these tend to be "double loop" algorithms that take longer to converge.
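For concreteness, here is a minimal sketch of sum-product loopy belief propagation with synchronous updates on a 3-node cycle of binary variables, compared against exact marginals obtained by enumeration. It assumes the minimal parameterization \(\log p(x)=\sum_i\theta_ix_i+\sum_{i<j}W_{ij}x_ix_j-A\); the parameter values, the update schedule and the stopping tolerance are illustrative choices, not taken from the lecture.

```python
import itertools
import math

theta = [0.2, -0.4, 0.1]                            # made-up unary parameters
W = {(0, 1): 0.3, (1, 2): -0.2, (0, 2): 0.25}       # made-up weak couplings
n = 3
neighbors = {i: [j for j in range(n)
                 if j != i and (min(i, j), max(i, j)) in W]
             for i in range(n)}

def node_pot(i, a):
    return math.exp(theta[i] * a)

def edge_pot(i, j, a, b):
    return math.exp(W[(min(i, j), max(i, j))] * a * b)

# msgs[(i, j)][b] is the message from node i to node j evaluated at x_j = b.
msgs = {(i, j): [0.5, 0.5] for i in range(n) for j in neighbors[i]}

for iteration in range(200):
    new = {}
    for (i, j) in msgs:
        out = []
        for b in (0, 1):
            total = 0.0
            for a in (0, 1):
                incoming = 1.0
                for k in neighbors[i]:
                    if k != j:  # exclude the message coming from the target
                        incoming *= msgs[(k, i)][a]
                total += node_pot(i, a) * edge_pot(i, j, a, b) * incoming
            out.append(total)
        z = sum(out)
        new[(i, j)] = [v / z for v in out]  # normalize for numerical stability
    delta = max(abs(new[e][b] - msgs[e][b]) for e in msgs for b in (0, 1))
    msgs = new
    if delta < 1e-12:
        break

def belief(i):
    """Normalized product of the node potential and all incoming messages."""
    raw = [node_pot(i, a) * math.prod(msgs[(k, i)][a] for k in neighbors[i])
           for a in (0, 1)]
    z = sum(raw)
    return [v / z for v in raw]

beliefs = [belief(i) for i in range(n)]

# Exact marginals P(x_i = 1) by brute-force enumeration, for comparison.
states = list(itertools.product([0, 1], repeat=n))
def score(x):
    return (sum(theta[i] * x[i] for i in range(n))
            + sum(w * x[i] * x[j] for (i, j), w in W.items()))
Z = sum(math.exp(score(x)) for x in states)
exact = [sum(math.exp(score(x)) for x in states if x[i] == 1) / Z
         for i in range(n)]
```

With couplings this weak the beliefs are close to, but not exactly equal to, the exact marginals; with strong couplings the same iteration may oscillate or converge to a poor fixed point.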

There has been a lot of work in this area, including alternative approximations, e.g. convexified Bethe approximations (Wainwright, Jaakkola and Willsky (2002)), approximations to find the MAP configuration etc. See also (Wainwright and Jordan (2008)) for an in-depth review.

3. Historical Notes and Bibliography

A great in-depth survey of variational inference in machine learning is the report by Wainwright and Jordan (2008). These notes are based on an earlier note (Wainwright and Jordan (2003)).

The first variational inference methods in machine learning were mean-field methods, and were introduced from the statistical physics literature by Hinton and van Camp (1993) and later popularized by Saul, Jaakkola and Jordan (1996) and Jordan, Ghahramani, Jaakkola and Saul (1999). The use of variational methods for Bayesian inference started with variational Bayes (Ghahramani and Beal (2001), Beal (2003)).

Loopy belief propagation was studied first in the coding community in the context of error-correction codes and then in graphical model inference by Frey and MacKay (1997), Murphy, Weiss and Jordan (1999), Kschischang, Frey and Loeliger (2001). Its link to the Bethe free energy was observed by Yedidia, Freeman and Weiss (2001a,b,c), who also extended it to generalized belief propagation. Expectation propagation was proposed by Minka (2001), who over the years has produced a significant literature on the topic.

3.1 Books and Papers

  • Beal, Variational Algorithms for Approximate Bayesian Inference, PhD thesis, Neuroscience Unit, UCL, 2003.

  • Frey and MacKay, A revolution: Belief propagation in graphs with cycles, NIPS, 1997.

  • Hinton and van Camp, Keeping neural networks simple by minimizing the description length of the weights, in Proc. ACM Conf. Comp. Learn. Th., 1993.

  • Hoffman, Blei, Wang and Paisley, Stochastic variational inference, JMLR, 2013.

  • Jordan, Ghahramani, Jaakkola and Saul, An introduction to variational methods for graphical models, in Learning in Graphical Models, Kluwer, 1999.

  • Kingma and Welling, Auto-encoding variational Bayes, ICLR, 2014.

  • Kschischang, Frey and Loeliger, Factor graphs and the sum-product algorithm, IEEE Transactions on Information Theory, 2001.

  • Minka, A Family of Algorithms for Approximate Bayesian Inference, PhD thesis, MIT, 2001.

  • Mnih and Gregor, Neural variational inference and learning in belief networks, ICML, 2014.

  • Murphy, Weiss and Jordan, Loopy belief propagation for approximate inference: an empirical study, UAI, 1999.

  • Rezende, Mohamed and Wierstra, Stochastic backpropagation and approximate inference in deep generative models, ICML, 2014.

  • Salimans and Knowles, Fixed-form variational posterior approximation through stochastic linear regression, arXiv, 2012.

  • Saul, Jaakkola and Jordan, Mean field theory for sigmoid belief networks, JAIR, 1996.

  • Wainwright, Jaakkola and Willsky, A new class of upper bounds on the log partition function, UAI, 2002.

  • Wainwright and Jordan, Graphical models, exponential families and variational inference, Technical Report, UC Berkeley, 2003.

  • Wainwright and Jordan, Graphical models, exponential families and variational inference, Foundations and Trends in Machine Learning, 2008.

  • Weiss, Belief propagation and revision in networks with loops, Technical Report, MIT, 1997.

  • Welling and Teh, Belief optimization for binary networks: a stable alternative to loopy belief propagation, UAI, 2001.

  • Yedidia, Freeman and Weiss, Bethe free energy, Kikuchi approximations, and belief propagation algorithms, Technical Report, MERL, 2001a.

  • Yedidia, Freeman and Weiss, Generalized belief propagation, in Advances in Neural Information Processing Systems, 2001b.

  • Yedidia, Freeman and Weiss, Understanding belief propagation and its generalizations, Technical Report, MERL, 2001c.

  • Yuille, CCCP algorithms to minimize the Bethe and Kikuchi free energies: Convergent alternatives to belief propagation, Neural Computation, 2002.