Statistical Review on Variational Inference
Variational inference is widely used to approximate posterior densities for Bayesian models; it is an alternative strategy to MCMC sampling.
Compared to MCMC, variational inference tends to be faster and easier to scale to large data.
Consider a joint density of latent variables $z$ and observations $x$,
\[p(z, x) = p(z)p(x\mid z)\]
Inference in a Bayesian model amounts to conditioning on data and computing the posterior $p(z\mid x)$.
Rather than using sampling, the main idea behind variational inference is to use optimization.
First, posit a family $Q$ of approximate densities over the latent variables.
Then, find the member of that family that minimizes the KL divergence to the exact posterior:
\[q^\star(z) = \argmin_{q(z)\in Q} KL(q(z)\Vert p(z\mid x))\]

Variational Inference
The Problem of Approximate Inference
Bayesian mixture of Gaussians
Consider a Bayesian mixture of unit-variance univariate Gaussians with $K$ mixture components.
The full hierarchical model is
\[\begin{aligned} \mu_k &\sim N(0, \sigma^2)\\ c_i &\sim \text{categorical}(1/K, \ldots, 1/K)\\ x_i \mid c_i, \mu &\sim N(c_i^T\mu, 1) \end{aligned}\]
For a sample of size $n$, the joint density of latent and observed variables is
\[p(\mu, c, x) = p(\mu) \prod_{i=1}^n p(c_i) p(x_i\mid c_i, \mu)\]
The latent variables are $z = \{\mu, c\}$: the $K$ class means and the $n$ class assignments.
the evidence is
\[p(x) = \int p(\mu) \prod_{i=1}^n \sum_{c_i} p(c_i) p(x_i\mid c_i, \mu)\, d\mu\]
The integrand does not contain a separate factor for each $\mu_k$ (each $\mu_k$ appears in all $n$ factors), so the integral does not reduce to a product of one-dimensional integrals.
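To see the difficulty another way, expand the product of sums and pull the sum over assignments outside the integral:
\[p(x) = \sum_{c} p(c) \int p(\mu) \prod_{i=1}^n p(x_i\mid c_i, \mu)\, d\mu\]
Each inner integral is a tractable Gaussian integral, but there are $K^n$ terms in the sum, so evaluating the evidence exactly is exponential in $n$.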
The Evidence Lower Bound
\[q^\star(z) = \argmin_{q(z)\in Q} KL(q(z)\Vert p(z\mid x))\]
Recall that the KL divergence is
\[KL(q(z)\Vert p(z\mid x)) = \bbE[\log q(z)] - \bbE[\log p(z\mid x)]\]
where all expectations are taken w.r.t. $q(z)$. Expanding the conditional $p(z\mid x) = p(z, x)/p(x)$, we have
\[KL(q(z)\Vert p(z\mid x)) = \bbE[\log q(z)] - \bbE[\log p(z, x)] + \log p(x)\]
Because $\log p(x)$ does not depend on $q$, we can optimize an alternative objective that equals the negative KL divergence up to an added constant:
\[ELBO(q) = \bbE[\log p(z, x)] - \bbE[\log q(z)]\]
We can rewrite the ELBO as the expected log-likelihood of the data minus the KL divergence between $q(z)$ and the prior $p(z)$,
\[ELBO(q) = \bbE[\log p(z)] + \bbE[\log p(x\mid z)] - \bbE[\log q(z)] = \bbE[\log p(x\mid z)] - KL(q(z)\Vert p(z))\]
so the variational objective mirrors the usual balance between likelihood and prior.
Another property of the ELBO is that it lower-bounds the log evidence:
\[\log p(x) = KL(q(z)\Vert p(z\mid x)) + ELBO(q) \ge ELBO(q)\]
Unlike variational inference, EM assumes the expectation under $p(z\mid x)$ is computable.
Unlike EM, variational inference does not estimate fixed model parameters; it is often used in a Bayesian setting where classical parameters are treated as latent variables.
Variational EM is the EM algorithm with a variational E-step, that is, a computation of an approximate conditional
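To make the lower-bound property concrete, here is a minimal Monte Carlo sketch of the ELBO under a toy conjugate model of my own choosing ($z\sim N(0,1)$, $x\mid z\sim N(z,1)$, so the exact evidence $p(x) = N(x; 0, 2)$ is available for comparison); the estimator simply averages $\log p(z, x) - \log q(z)$ over samples $z\sim q$.

```python
# Monte Carlo estimate of the ELBO for a toy conjugate model (not the mixture model above):
#   z ~ N(0, 1),  x | z ~ N(z, 1)  =>  exact evidence p(x) = N(x; 0, 2).
import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(0)
x_obs = 1.3                     # a single observation
m, s = 0.9, 0.8                 # parameters of the variational factor q(z) = N(m, s^2)

z = rng.normal(m, s, size=100_000)                      # samples z ~ q(z)
log_p_z = norm.logpdf(z, loc=0.0, scale=1.0)            # log prior
log_p_x_given_z = norm.logpdf(x_obs, loc=z, scale=1.0)  # log likelihood
log_q_z = norm.logpdf(z, loc=m, scale=s)                # log variational density

elbo = np.mean(log_p_z + log_p_x_given_z - log_q_z)
log_evidence = norm.logpdf(x_obs, loc=0.0, scale=np.sqrt(2.0))
print(f"ELBO estimate {elbo:.4f} <= log p(x) = {log_evidence:.4f}")
```

The estimate sits strictly below $\log p(x)$ because $q$ is not the exact posterior; the gap is exactly $KL(q(z)\Vert p(z\mid x))$.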
Mean-Field Variational Family
In the mean-field variational family, the latent variables are mutually independent and each governed by a distinct factor in the variational density:
\[q(z) = \prod_{j=1}^m q_j(z_j)\]

Bayesian mixture of Gaussians (continued)

For the Bayesian mixture of Gaussians, the joint density of latent and observed variables is
\[p(\mu, c, x) = p(\mu) \prod_{i=1}^n p(c_i) p(x_i\mid c_i,\mu)\]
The mean-field variational family contains approximate posterior densities of the form
\[q(\mu, c) = \prod_{k=1}^K q(\mu_k; m_k, s_k^2) \prod_{i=1}^n q(c_i; \varphi_i)\]
Each latent variable is governed by its own variational factor:
- the factor $q(\mu_k; m_k, s_k^2)$ is a Gaussian distribution on the $k$-th mixture component’s mean parameter; its mean is $m_k$ and its variance is $s_k^2$
- the factor $q(c_i; \varphi_i)$ is a distribution on the $i$-th observation’s mixture assignment; its assignment probabilities are a $K$-vector $\varphi_i$
these are the optimal forms of the mean-field variational density for the mixture of Gaussians
Coordinate Ascent Variational Inference (CAVI)
Consider the $j$-th latent variable $z_j$. The complete conditional of $z_j$ is its conditional density given all of the other latent variables in the model and the observations, $p(z_j\mid z_{-j}, x)$.
Fix the other variational factors $q_{\ell}(z_\ell)$, $\ell\neq j$. The optimal $q_j(z_j)$ is then proportional to the exponentiated expected log of the complete conditional:
\[q_j^\star(z_j)\propto \exp[\bbE_{-j}[\log p(z_j\mid z_{-j}, x)]]\]
Equivalently, it is proportional to
\[q_j^\star(z_j)\propto \exp[\bbE_{-j}[\log p(z_j, z_{-j}, x)]]\]
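As a worked instance of this rule for the Bayesian mixture of Gaussians above (stated without derivation; these follow from the model's complete conditionals, writing $\varphi_{ik}$ for the $k$-th component of $\varphi_i$), the update for each assignment factor is
\[\varphi_{ik} \propto \exp\{\bbE[\mu_k]\, x_i - \bbE[\mu_k^2]/2\} = \exp\{m_k x_i - (m_k^2 + s_k^2)/2\}\]
and the update for each mixture-component factor is the Gaussian with
\[s_k^2 = \frac{1}{1/\sigma^2 + \sum_{i=1}^n \varphi_{ik}}, \qquad m_k = \frac{\sum_{i=1}^n \varphi_{ik}\, x_i}{1/\sigma^2 + \sum_{i=1}^n \varphi_{ik}}\]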
Practicalities

- Initialization
- Assessing convergence
- Numerical stability
log-sum-exp trick: for any constant $\alpha$ (typically $\alpha = \max_i x_i$),
\[\log \Big[\sum_i \exp(x_i)\Big] = \alpha + \log\Big[\sum_i \exp(x_i-\alpha)\Big]\]
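A minimal sketch of the trick in code (the helper name `log_sum_exp` is my own; `scipy.special.logsumexp` provides the same functionality):

```python
import numpy as np

def log_sum_exp(x):
    """Numerically stable log(sum(exp(x))), shifting by the largest element."""
    x = np.asarray(x, dtype=float)
    alpha = np.max(x)                 # the shift constant
    return alpha + np.log(np.sum(np.exp(x - alpha)))

# e.g. normalizing log-probabilities that would underflow if exponentiated directly
log_weights = np.array([-1000.0, -1001.0, -1002.5])
print(np.exp(log_weights - log_sum_exp(log_weights)))  # valid probabilities summing to 1
```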
A Complete Example: Bayesian Mixture of Gaussians

Consider $K$ mixture components and $n$ real-valued data points $x_{1:n}$.
The latent variables are the $K$ real-valued mean parameters $\mu = \mu_{1:K}$ and the $n$ latent-class assignments $c = c_{1:n}$, where $c_i$ is an indicator $K$-vector.
The variance $\sigma^2$ of the normal prior on the $\mu_k$’s is a fixed hyperparameter.
Assume the observation variance is one and take a uniform prior over the mixture components.
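A minimal CAVI sketch for this model, assuming the coordinate updates stated after the CAVI section above; the function and variable names (`cavi_gmm`, `elbo`, `phi`, `m`, `s2`) and the choices of $\sigma^2$, the simulated data, and the number of iterations are my own.

```python
# CAVI for the Bayesian mixture of unit-variance univariate Gaussians (a sketch).
import numpy as np

def cavi_gmm(x, K, sigma2=10.0, n_iters=60, seed=0):
    """Coordinate ascent variational inference with
    q(mu_k) = N(m[k], s2[k]) and q(c_i) = categorical(phi[i])."""
    rng = np.random.default_rng(seed)
    n = len(x)
    m = rng.normal(0.0, 1.0, size=K)      # initialize variational means randomly
    s2 = np.ones(K)
    phi = np.full((n, K), 1.0 / K)
    elbos = []
    for _ in range(n_iters):
        # update the assignment factors q(c_i; phi_i)
        log_phi = np.outer(x, m) - 0.5 * (m**2 + s2)       # shape (n, K)
        log_phi -= log_phi.max(axis=1, keepdims=True)      # log-sum-exp shift
        phi = np.exp(log_phi)
        phi /= phi.sum(axis=1, keepdims=True)
        # update the mixture-mean factors q(mu_k; m_k, s_k^2)
        nk = phi.sum(axis=0)
        s2 = 1.0 / (1.0 / sigma2 + nk)
        m = s2 * (phi * x[:, None]).sum(axis=0)
        elbos.append(elbo(x, K, sigma2, m, s2, phi))
    return m, s2, phi, np.array(elbos)

def elbo(x, K, sigma2, m, s2, phi):
    """ELBO = E[log p(mu)] + E[log p(c)] + E[log p(x|c,mu)] - E[log q(mu)] - E[log q(c)]."""
    e_log_p_mu = np.sum(-0.5 * np.log(2 * np.pi * sigma2) - (m**2 + s2) / (2 * sigma2))
    e_log_p_c = -len(x) * np.log(K)
    e_log_lik = np.sum(phi * (-0.5 * np.log(2 * np.pi)
                              - 0.5 * (x[:, None]**2 - 2 * x[:, None] * m + m**2 + s2)))
    e_log_q_mu = np.sum(-0.5 * np.log(2 * np.pi * s2) - 0.5)
    e_log_q_c = np.sum(phi * np.log(phi + 1e-12))
    return e_log_p_mu + e_log_p_c + e_log_lik - e_log_q_mu - e_log_q_c

# simulate data from three well-separated components and run CAVI
rng = np.random.default_rng(1)
x = np.concatenate([rng.normal(mu, 1.0, 200) for mu in (-4.0, 0.0, 5.0)])
m, s2, phi, elbos = cavi_gmm(x, K=3)
print("variational means:", np.sort(m))
print("ELBO nondecreasing:", bool(np.all(np.diff(elbos) >= -1e-6)))
```

With well-separated simulated components, the ELBO should be nondecreasing across sweeps and the variational means should land near the true component means; poor initializations can still find local optima, which is why the initialization and convergence-assessment points above matter in practice.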