Deriving variational bounds

Variational approximations such as the ones used in EM and mean-field algorithms all boil down to the application of one simple bound: we start with a probability density formula that contains an intractable sum (or integral) of the form

L(T) = sum_i exp(T(i))

for some (large) list of terms T(i).  Then, we bound this sum using the inequality

log L(T) >= E_Q(T(i)) - H(Q)

where Q is an arbitrary probability distribution over the indices of the terms in the sum, and H(Q) is the negentropy

H(Q) = sum_i Q(i) log Q(i)

This bound holds for arbitrary distributions Q, so in practice we will pick some simple form for Q and then find the Q of this form that makes our bound as tight as possible.  Optimizing Q this way has two effects: first, it gives us a good estimate of log L(T) and therefore a good estimate of our original probability density.  And second, as we shall see below, it finds a Q which is a good approximation of the distribution

Q*(i) = exp(T(i)) / Z

(where Z is defined so that Q* sums to 1).  The distribution Q* often turns out to be of independent interest: for example, in clustering it is the posterior distribution over data assignment hypotheses, and in image segmentation it is the posterior distribution over pixel classifications.
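
To make the bound concrete, here is a small numeric sketch (Python with numpy; the vector T and the helper name bound_value are illustrative choices, not anything fixed by the derivation).  It evaluates the right-hand side of the bound for an arbitrary Q and for Q*, and compares both to log L(T):

import numpy as np

rng = np.random.default_rng(0)
T = rng.normal(size=5)                      # illustrative terms T(i)

log_L = np.log(np.sum(np.exp(T)))           # log L(T); tractable here because T is tiny

def bound_value(Q, T):
    # E_Q(T(i)) - H(Q), with H(Q) = sum_i Q(i) log Q(i)
    return np.sum(Q * T) - np.sum(Q * np.log(Q))

Q_uniform = np.full(5, 1.0 / 5)             # an arbitrary Q: bound holds but is loose
Q_star = np.exp(T) / np.sum(np.exp(T))      # Q*(i) = exp(T(i)) / Z: bound is tight

print(log_L, bound_value(Q_uniform, T))     # the first number is >= the second
print(log_L, bound_value(Q_star, T))        # the two numbers agree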

To prove our bound, we can explicitly find the Q which makes it as tight as possible, and show that it holds for this Q.  Since Q is restricted to be a probability distribution, we can use a Lagrange multiplier k in our optimization to enforce the constraint that it sums to 1.  So, we want to solve the optimization

min_Q [ log(L(T)) + H(Q) - E_Q(T(i)) ]
= log(L(T)) + min_Q max_k [ H(Q) - E_Q(T(i)) + k (1 - sum_i Q(i)) ]

We will start by setting the derivative wrt Q(i) to 0. Since the derivative of x log(x) is log(x)+1, we have:

0 = log(Q(i)) + 1 - T(i) - k
Q(i) = exp(T(i) + k - 1)

Now we can solve for k: since Q(i) sums to 1, exp(1-k) must act as a normalizing constant.

1 - k = log sum_i exp(T(i)) = log L(T)

Finally, we can substitute back into our original expression to find that

min_Q [ log(L(T)) + H(Q) - E_Q(T(i)) ]
= log(L(T)) + sum_i Q(i) [ log(Q(i)) - T(i) ]
= log(L(T)) + sum_i Q(i) [ k - 1 ]
= log(L(T)) + [ k - 1 ]
= 0

so the best possible Q makes our bound perfectly tight; and since every other Q can only make the gap larger, the bound is never violated for any Q.

Looking back at the derivation, we can pick out some interesting facts.  First, the Q for which the bound is tightest is just the Q* distribution mentioned above.  This Q* is obtained by the multidimensional analog of a sigmoid transformation (a softmax) of the T(i) terms: if there are only two terms T(0) and T(1), then

Q*(0) = exp(T(0)) / [ exp(T(0)) + exp(T(1)) ]
= 1 / ( 1 + exp(T(1)-T(0)) )

which is the standard one-dimensional sigmoid.  Another useful representation of the above equality is in terms of the hyperbolic tangent function

tanh(z) = (exp(z) - exp(-z)) / (exp(z) + exp(-z))

Using this definition, we have

Q*(1) - Q*(0) = [ exp(T(1)) - exp(T(0)) ] / [ exp(T(1)) + exp(T(0)) ]
= tanh((T(1) - T(0)) / 2)
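
Both identities are easy to check numerically; here is a tiny Python/numpy sketch (the particular values of T(0) and T(1) are arbitrary):

import numpy as np

T0, T1 = -0.3, 1.2                                       # two illustrative terms
Q_star = np.exp([T0, T1]) / np.sum(np.exp([T0, T1]))

print(Q_star[0], 1.0 / (1.0 + np.exp(T1 - T0)))          # Q*(0) vs. the sigmoid
print(Q_star[1] - Q_star[0], np.tanh((T1 - T0) / 2))     # Q*(1) - Q*(0) vs. the tanh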

Second, if we are picking Q from a restricted set, we won't get Q* but instead the closest legal Q (in KL divergence):

argmin_Q sum_i Q(i) [ log(Q(i)) - T(i) ]
= argmin_Q sum_i Q(i) [ log(Q(i)) - T(i) + log(Z) ]
= argmin_Q sum_i Q(i) [ log(Q(i)) - log(Q*(i)) ]
= argmin_Q KL(Q || Q*)

The first equality holds because log(Z) is a constant independent of Q.  The second holds by definition of Q*, and the last by definition of KL divergence.
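
Since log(Z) = log L(T), the same chain also says that the slack in the bound is exactly the KL divergence: log L(T) + H(Q) - E_Q(T(i)) = KL(Q || Q*).  A quick numeric check (Python/numpy; the terms T and the candidate Q are just random illustrative choices):

import numpy as np

rng = np.random.default_rng(1)
T = rng.normal(size=4)                       # illustrative terms T(i)
Q = rng.dirichlet(np.ones(4))                # a stand-in for some restricted-family Q

log_L = np.log(np.sum(np.exp(T)))
Q_star = np.exp(T - log_L)                   # Q*(i) = exp(T(i)) / Z with Z = L(T)

gap = log_L + np.sum(Q * np.log(Q)) - np.sum(Q * T)      # log L(T) + H(Q) - E_Q(T(i))
kl = np.sum(Q * (np.log(Q) - np.log(Q_star)))            # KL(Q || Q*)
print(gap, kl)                               # the two agree up to rounding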

Finally, suppose our restricted set of Qs is one which factors into a product such as

Q(i, j) = Q1(i) Q2(j)

where now our sum is over pairs i, j.  Also suppose that our terms T are of the form

T(i, j) = T1(i) + T2(j)

This form is very common in practice; for example, it holds in k-means clustering (where the factorization is over the data association variables) and in mean-field image segmentation (where the factorization is over the classifications of individual pixels).

In this case we can optimize for Q1 and Q2 separately: because

H(Q) = H(Q1) + H(Q2)
E_Q(T1(i) + T2(j)) = E_Q1(T1(i)) + E_Q2(T2(j))

we can choose Q1 and Q2 to minimize

H(Q1) - E_Q1(T1(i))
H(Q2) - E_Q2(T2(j))

respectively.
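
Here is a small Python/numpy sketch of the factored case (the sizes and values of T1 and T2 are arbitrary).  Because T itself factors, the product of the two separately optimized factors is exactly Q*, so in this example the bound comes out tight:

import numpy as np

rng = np.random.default_rng(2)
T1 = rng.normal(size=3)                      # illustrative T1(i)
T2 = rng.normal(size=4)                      # illustrative T2(j)
T = T1[:, None] + T2[None, :]                # T(i, j) = T1(i) + T2(j)

log_L = np.log(np.sum(np.exp(T)))            # log L(T), summing over all pairs (i, j)

# Each factor is optimized on its own: Qk(i) = exp(Tk(i)) / sum_i exp(Tk(i))
Q1 = np.exp(T1) / np.sum(np.exp(T1))
Q2 = np.exp(T2) / np.sum(np.exp(T2))
Q = Q1[:, None] * Q2[None, :]                # Q(i, j) = Q1(i) Q2(j)

bound = np.sum(Q * T) - np.sum(Q * np.log(Q))   # E_Q(T(i, j)) - H(Q)
print(log_L, bound)                          # the two numbers agree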

This page is maintained by Geoff Gordon and was last modified in November 2002.