
Statistical Review on Variational Inference

Tags: KL-divergence, Variational Inference

This note is for Blei, D. M., Kucukelbir, A., & McAuliffe, J. D. (2017). Variational Inference: A Review for Statisticians. Journal of the American Statistical Association, 112(518), 859–877. https://doi.org/10.1080/01621459.2017.1285773

Variational inference is widely used to approximate posterior densities for Bayesian models; it is an alternative strategy to MCMC sampling.

Compared to MCMC, variational inference tends to be faster and easier to scale to large data

consider a joint density of latent variable $z$ and observation $x$

\[p(z, x) = p(z)p(x\mid z)\]

inference in a Bayesian model amounts to conditioning on data and computing the posterior $p(z\mid x)$

rather than using sampling, the main idea behind variational inference is to use optimization

first, posit a family of approximate densities

then, try to find the member of that family that minimizes the KL divergence to the exact posterior

\[q^\star(z) = \argmin_{q(z)\in Q} KL (q(z)\Vert p(z\mid x))\]

Variational Inference

The Problem of Approximate Inference

Bayesian mixture of Gaussians

Consider a Bayesian mixture of unit-variance univariate Gaussians.

$K$ mixture components

the full hierarchical model is

\[\mu_k \sim N(0, \sigma^2)\\ c_i\sim \text{categorical}(1/K, \ldots, 1/K)\\ x_i \mid c_i, \mu \sim N(c_i^T\mu, 1)\]
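As a quick sanity check of the generative story, here is a minimal simulation sketch in numpy (the values of `K`, `n`, and `sigma`, and the use of an integer label in place of the indicator vector $c_i$, are my own choices for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
K, n, sigma = 3, 500, 5.0             # number of components, sample size, prior std (arbitrary)

mu = rng.normal(0.0, sigma, size=K)   # mu_k ~ N(0, sigma^2)
c = rng.integers(0, K, size=n)        # c_i ~ Categorical(1/K, ..., 1/K), stored as an index
x = rng.normal(mu[c], 1.0)            # x_i | c_i, mu ~ N(c_i^T mu, 1) = N(mu_{c_i}, 1)
```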

for a sample of size $n$, the joint density of latent and observed variables is

\[p(\mu, c, x) = p(\mu) \prod_{i=1}^n p(c_i) p(x_i\mid c_i, \mu)\]

The latent variables are $z = \{\mu, c\}$, the $K$ class means and $n$ class assignments

the evidence is

\[p(x) = \int p(\mu) \prod_{i=1}^n \sum_{c_i} p(c_i) p(x_i\mid c_i, \mu)d\mu\]

the integrand does not contain a separate factor for each $\mu_k$, so the integral does not decompose into per-component pieces and the evidence is hard to compute
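One way to see the cost is to exchange the sum and the integral: conditional on a fixed assignment configuration $c$, the integral over $\mu$ is a tractable Gaussian integral, but there are $K^n$ such configurations,

\[p(x) = \sum_{c} p(c) \int p(\mu)\prod_{i=1}^n p(x_i\mid c_i, \mu)\, d\mu\]

so evaluating the evidence exactly scales exponentially in $n$.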

The Evidence Lower Bound

\[q^\star(z) = \argmin_{q(z)\in Q} KL(q(z)\Vert p(z\mid x))\]

recall that KL divergence is

\[KL(q(z)\Vert p(z\mid x)) = \bbE[\log q(z)] - \bbE[\log p(z\mid x)]\]

where all expectations are taken w.r.t. $q(z)$, and we have

\[KL(q(z)\Vert p(z\mid x)) = \bbE[\log q(z)] - \bbE[\log p(z, x)] + \log p(x)\]

optimize an alternative objective that is equivalent to the KL up to an added constant

\[ELBO(q) = \bbE[\log p(z, x)] - \bbE[\log q(z)]\]

rewrite the ELBO as a sum of the expected log-likelihood of the data and the KL divergence between the prior $p(z)$ and $q(z)$

\[ELBO(q) = \bbE[\log p(z)] + \bbE[\log p(x\mid z)] - \bbE[\log q(z)] = \bbE[\log p(x\mid z)] - KL(q(z)\Vert p(z))\]

the variational objective mirrors the usual balance between likelihood and prior

another property of the ELBO is that it lower-bounds the log evidence

\[\log p(x) = KL(q(z)\Vert p(z\mid x)) + ELBO(q) \ge ELBO(q)\]
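As a small numerical illustration of this identity, the following sketch uses a conjugate normal-normal model where the exact posterior and evidence are available in closed form (prior $z\sim N(0,1)$, likelihood $x\mid z\sim N(z, 1)$); the observation $x_0$, the Gaussian variational factor, and its parameters are arbitrary choices of mine:

```python
import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(0)
x0 = 1.3                                   # a single observation (arbitrary)

# model: z ~ N(0, 1), x | z ~ N(z, 1)  =>  p(x) = N(x; 0, 2), p(z | x) = N(x/2, 1/2)
log_evidence = norm.logpdf(x0, loc=0.0, scale=np.sqrt(2.0))

# a (deliberately sub-optimal) Gaussian variational factor q(z) = N(m, s^2)
m, s = 0.3, 0.9
z = rng.normal(m, s, size=200_000)         # Monte Carlo samples from q

log_joint = norm.logpdf(z, 0.0, 1.0) + norm.logpdf(x0, z, 1.0)
log_q = norm.logpdf(z, m, s)
elbo = np.mean(log_joint - log_q)

kl = np.mean(log_q - norm.logpdf(z, x0 / 2, np.sqrt(0.5)))  # KL(q || p(z|x)) by Monte Carlo

print(elbo, kl, elbo + kl, log_evidence)   # elbo + kl ≈ log_evidence and elbo <= log_evidence
```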

Unlike variational inference, EM assumes the expectation under $p(z\mid x)$ is computable

Unlike EM, variational inference does not estimate fixed model parameters; it is often used in a Bayesian setting where classical parameters are treated as latent variables

Variational EM is the EM algorithm with a variational E-step, that is, a computation of an approximate conditional

Mean-Field Variational Family

the latent variables are mutually independent and each governed by a distinct factor in the variational density

\[q(z) = \prod_{j=1}^m q_j(z_j)\]

for the Bayesian mixture of Gaussians, the joint density of latent and observed variables is

\[p(\mu, c, x) = p(\mu) \prod_{i=1}^n p(c_i) p(x_i\mid c_i,\mu)\]

Bayesian mixture of Gaussians (continued)

the mean-field variational family contains approximate posterior densities of the form

\[q(\mu, c) = \prod_{k=1}^K q(\mu_k; m_k, s_k^2) \prod_{i=1}^n q(c_i; \varphi_i)\]

each latent variable is governed by its own variational factor

  • the factor $q(\mu_k; m_k, s_k^2)$ is a Gaussian distribution on the $k$-th mixture component’s mean parameter; its mean is $m_k$ and its variance is $s_k^2$
  • the factor $q(c_i; \varphi_i)$ is a distribution on the $i$-th observation’s mixture assignment; its assignment probabilities are a K-vector $\varphi_i$

these are the optimal forms of the mean-field variational density for the mixture of Gaussians
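Concretely, the mean-field family for this model is just a collection of free parameters, one block per latent variable; a minimal sketch of how they might be stored and initialized (the initialization scheme is my own choice, not from the paper):

```python
import numpy as np

rng = np.random.default_rng(1)
K, n = 3, 500                            # match the data-generating sketch above

m = rng.normal(0.0, 1.0, size=K)         # m_k: mean of q(mu_k; m_k, s_k^2)
s2 = np.ones(K)                          # s_k^2: variance of q(mu_k; m_k, s_k^2)
phi = rng.dirichlet(np.ones(K), size=n)  # phi_i: assignment probabilities of q(c_i; phi_i), rows sum to 1
```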

Coordinate Ascent Variational Inference (CAVI)

(Algorithm 1 of the paper: the generic CAVI algorithm)

Consider the $j$-th latent variable $z_j$. The complete conditional of $z_j$ is its conditional density given all of the other latent variables in the model and the observations, $p(z_j\mid z_{-j}, x)$.

fix the other variational factors, $q_{\ell}(z_\ell), \ell\neq j$. The optimal $q_j(z_j)$ is then proportional to the exponentiated expected log of the complete conditional

\[q_j^\star(z_j)\propto \exp[\bbE_{-j}[\log p(z_j\mid z_{-j}, x)]]\]

equivalently, since $p(z_j\mid z_{-j}, x) \propto p(z_j, z_{-j}, x)$ as a function of $z_j$ (the normalizer $p(z_{-j}, x)$ does not depend on $z_j$), it is proportional to

\[q_j^\star(z_j)\propto \exp[\bbE_{-j}[\log p(z_j, z_{-j}, x)]]\]

Practicalities

  • Initialization
  • Assessing convergence
  • Numerical stability

log-sum-exp trick

\[\log [\sum_i \exp(x_i)] = \alpha + \log[\sum_i \exp(x_i-\alpha)]\]
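A common choice is $\alpha = \max_i x_i$, which keeps every exponent non-positive; a minimal sketch (scipy also provides `scipy.special.logsumexp`):

```python
import numpy as np

def log_sum_exp(x):
    """Numerically stable log(sum(exp(x))) using the max as the shift."""
    alpha = np.max(x)
    return alpha + np.log(np.sum(np.exp(x - alpha)))

x = np.array([-1000.0, -1001.0, -1002.0])
print(log_sum_exp(x))   # ≈ -999.59, while np.log(np.sum(np.exp(x))) underflows to -inf
```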

A Complete Example: Bayesian Mixture of Gaussians

consider $K$ mixture components and $n$ real-valued data points $x_{1:n}$

the latent variables are $K$ real-valued mean parameters $\mu = \mu_{1:K}$ and $n$ latent-class assignments $c = c_{1:n}$, where $c_i$ is an indicator $K$-vector

a fixed hyperparameter $\sigma^2$, the variance of the normal prior on the $\mu_k$’s

assume the observation variance is one and take a uniform prior over the mixture components

CAVI for the Mixture of Gaussians

(Algorithm 2 of the paper: CAVI for the Bayesian mixture of Gaussians)
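The figure above is the per-model instantiation of CAVI; below is a compact code sketch of the updates as I read them from the paper (unit observation variance, uniform prior over assignments; variable names, the initialization, and the fixed iteration count are my own):

```python
import numpy as np

def cavi_gmm(x, K, sigma2, n_iters=100, seed=0):
    """CAVI for a Bayesian mixture of unit-variance univariate Gaussians
    with a N(0, sigma2) prior on each component mean and a uniform prior
    over assignments (coordinate updates as in Blei et al., 2017)."""
    rng = np.random.default_rng(seed)
    n = x.shape[0]
    m = rng.normal(0.0, 1.0, size=K)          # variational means of q(mu_k)
    s2 = np.ones(K)                           # variational variances of q(mu_k)
    phi = rng.dirichlet(np.ones(K), size=n)   # variational assignment probabilities

    for _ in range(n_iters):
        # update each q(c_i; phi_i): phi_ik ∝ exp(E[mu_k] x_i - E[mu_k^2] / 2)
        logits = np.outer(x, m) - 0.5 * (s2 + m ** 2)      # shape (n, K)
        logits -= logits.max(axis=1, keepdims=True)        # log-sum-exp shift for stability
        phi = np.exp(logits)
        phi /= phi.sum(axis=1, keepdims=True)

        # update each q(mu_k; m_k, s_k^2)
        denom = 1.0 / sigma2 + phi.sum(axis=0)
        m = (phi * x[:, None]).sum(axis=0) / denom
        s2 = 1.0 / denom

    return m, s2, phi
```

On data simulated from the sketch earlier in the note (where $\sigma = 5$), `cavi_gmm(x, K=3, sigma2=25.0)` usually recovers the component means up to a permutation of the labels, though CAVI only finds a local optimum, so the result depends on the initialization.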

