WeiYa's Work Yard

A traveler with endless curiosity, who fell into the ocean of statistics, tries to write down his ideas and notes to save himself.

Concrete Distribution: Relaxation of Discrete Random Variables

Tags: Reparameterization, Discrete, Relaxation

This note is for Maddison, C. J., Mnih, A., & Teh, Y. W. (2017). The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables (No. arXiv:1611.00712). arXiv. http://arxiv.org/abs/1611.00712

The reparameterization trick enables optimizing large-scale stochastic computation graphs via gradient descent.

essence: refactor each stochastic node into a differentiable function of its parameters and a random variable with fixed distribution

discrete random variables lack useful reparameterizations due to the discontinuous nature of discrete states.

the paper introduces Concrete random variables: continuous relaxations of discrete random variables.

the concrete distribution is a new family of distributions with closed-form densities and a simple reparameterization

whenever a discrete stochastic node of a computation graph can be refactored into a one-hot bit representation that is treated continuously, Concrete stochastic nodes can be used with automatic differentiation to produce low-variance biased gradients of objectives on the corresponding discrete graph

Introduction

the work is inspired by the observation that many architectures treat discrete nodes continuously, and gradients rich with counterfactual information are available for each of their possible states

the concrete distribution is a new parametric family of continuous distributions on the simplex with closed-form densities

sampling from the concrete distribution is as simple as taking the softmax of logits perturbed by fixed additive noise.

every discrete random variable corresponds to the zero temperature limit of a Concrete one.

optimizing an objective over an architecture with discrete stochastic nodes can be accomplished by gradient descent on the samples of the corresponding Concrete relaxation.

When the objective depends, as in variational inference, on the log-probability of discrete nodes, the Concrete density is used during training in place of the discrete mass. At test time, the graph with discrete nodes is evaluated.

Background

Optimizing Stochastic Computation Graphs

Stochastic computation graphs (SCGs) provide a formalism for specifying input-output mappings, potentially stochastic, with learnable parameters using directed acyclic graphs.

many training objectives in supervised, unsupervised, and reinforcement learning can be expressed in terms of SCGs

for graphs with a single stochastic node $X$, interpret the forward pass in the graph as first sampling $X$ from the conditional distribution $p_\phi(x)$ of the stochastic node given its parents, then evaluating a deterministic function $f_\theta(x)$ at $X$

think of $f_\theta(X)$ as a noisy objective, and we are interested in optimizing its expected value $L(\theta, \phi) = \bbE_{X\sim p_\phi(x)}[f_\theta(X)]$ w.r.t. parameters $\theta, \phi$.

in general, both the objective and its gradients are intractable

the gradient w.r.t. the parameters $\theta$ has the form

\[\nabla_\theta L(\theta, \phi) = \nabla_\theta\bbE_{X\sim p_\phi(x)}[f_\theta(X)] = \bbE_{X\sim p_\phi(x)}[\nabla_\theta f_\theta(X)]\]

this can be easily estimated via Monte Carlo

the more challenging task is to compute the gradient w.r.t. the parameters $\phi$ of $p_\phi(x)$. The expression obtained by differentiating the expected objective

\[\nabla_\phi L(\theta, \phi) = \nabla_\phi \int p_\phi(x) f_\theta(x)dx = \int f_\theta(x) \nabla_\phi p_\phi(x)dx\]

does not have the form of an expectation w.r.t. $x$

Score Function Estimators

the score function estimator, also known as the REINFORCE or likelihood-ratio estimator, is based on the identity

\[\nabla_\phi p_\phi(x) = p_\phi(x) \nabla_\phi \log p_\phi(x)\]

then

\[\nabla_\phi L(\theta, \phi) = \bbE_{X\sim p_\phi(x)}[f_\theta(X)\nabla_\phi \log p_\phi(X)]\]
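a quick numerical check of the estimator (the categorical distribution, reward values $f$, and sample size below are illustrative choices of mine, not from the paper):

```python
import numpy as np

def score_function_grad(logits, f_values, n_samples, rng):
    """Monte Carlo REINFORCE estimate of d/d(logits) E[f(X)]."""
    p = np.exp(logits - logits.max())
    p /= p.sum()
    x = rng.choice(len(p), size=n_samples, p=p)   # samples of X
    # for softmax logits phi: grad_phi log p(x) = one_hot(x) - p
    one_hot = np.eye(len(p))[x]
    return (f_values[x][:, None] * (one_hot - p)).mean(axis=0)

rng = np.random.default_rng(0)
logits = np.array([0.2, -0.5, 1.0])
f_values = np.array([1.0, 2.0, 3.0])

est = score_function_grad(logits, f_values, 200_000, rng)

# analytic gradient for comparison: p_k * (f_k - E[f])
p = np.exp(logits) / np.exp(logits).sum()
exact = p * (f_values - p @ f_values)
print(est, exact)
```

the estimate matches the analytic gradient to Monte Carlo error, but for general $f$ the variance of this estimator is often large, which motivates the reparameterization trick below.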

Reparameterization Trick

in many cases we can sample from $p_\phi(x)$ by first sampling $Z$ from some fixed distribution $q(z)$ and then transforming the sample using some function $g_\phi(z)$

e.g., $N(\mu, \sigma^2)$

sample $Z\sim N(0, 1)$, then transform it using $g_{\mu, \sigma}(Z) = \mu + \sigma Z$

reparameterization trick: transfer the dependence on $\phi$ into $f$ by writing $f_\theta(x) = f_\theta(g_\phi(z))$ for $x = g_\phi(z)$, making it possible to reduce the problem of estimating the gradient w.r.t. parameters of a distribution to the simpler problem of estimating the gradient w.r.t. parameters of a deterministic function

having reparameterized $p_\phi(x)$, we can now express the objective as an expectation w.r.t. $q(z)$

\[L(\theta, \phi) = \bbE_{X\sim p_\phi(x)}[f_\theta(X)] = \bbE_{Z\sim q(z)}[f_\theta(g_\phi(Z))]\]

assuming differentiability of $f_\theta(x)$ w.r.t. $x$ and of $g_\phi(z)$ w.r.t. $\phi$ and using the chain rule gives

\[\nabla_\phi L(\theta, \phi) = \bbE_{Z\sim q(z)}[\nabla_\phi f_\theta(g_\phi(Z))] = \bbE_{Z\sim q(z)}[f'_\theta(g_\phi(Z))\nabla_\phi g_\phi(Z)]\]
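a minimal sketch of the pathwise estimator for the Gaussian example above, taking $f(x) = x^2$ so that $L = \mu^2 + \sigma^2$ and the exact gradients $2\mu$ and $2\sigma$ are available for comparison (the parameter values are mine):

```python
import numpy as np

def reparam_grad(mu, sigma, n_samples, rng):
    """Pathwise gradient estimate of E[f(mu + sigma*Z)], f(x) = x^2, Z ~ N(0,1)."""
    z = rng.standard_normal(n_samples)
    x = mu + sigma * z          # x = g_phi(z)
    fprime = 2.0 * x            # f'(x)
    # chain rule: grad_mu g = 1, grad_sigma g = z
    return fprime.mean(), (fprime * z).mean()

rng = np.random.default_rng(1)
mu, sigma = 1.5, 0.5
g_mu, g_sigma = reparam_grad(mu, sigma, 200_000, rng)
# exact gradients: dL/dmu = 2*mu = 3.0, dL/dsigma = 2*sigma = 1.0
print(g_mu, g_sigma)
```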

Application: Variational Training of Latent Variable Models

such models assume that each observation $x$ is obtained by first sampling a vector of latent variables $Z$ from the prior $p_\theta(z)$ before sampling the observation itself from $p_\theta(x\mid z)$, thus the marginal probability is $p_\theta(x) = \sum_z p_\theta(z)p_\theta(x\mid z)$.

maximum likelihood training is infeasible, because the LL objective $L(\theta) = \log p_\theta(x) = \log \bbE_{Z\sim p_\theta(z)}[p_\theta(x\mid Z)]$ is typically intractable and does not fit into the above framework due to the expectation being inside the log

the multi-sample variational objective

\[L_m(\theta, \phi) = \bbE_{Z^i \sim q_\phi(z\mid x)}\left[ \log\left( \frac 1m \sum_{i=1}^m \frac{p_\theta(Z^i, x)}{q_\phi(Z^i\mid x)} \right) \right]\]
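one useful sanity check: when $q_\phi$ is the true posterior, the importance weights $p_\theta(Z^i, x)/q_\phi(Z^i\mid x)$ all equal $p_\theta(x)$, so $L_m = \log p_\theta(x)$ exactly for any $m$. a toy check with a binary latent (the numbers are my own):

```python
import numpy as np

# toy model: latent z in {0,1}, observation x in {0,1}
p_z = np.array([0.4, 0.6])          # prior p(z)
p_x_given_z = np.array([0.9, 0.2])  # p(x=1 | z)
x = 1

p_joint = p_z * p_x_given_z         # p(z, x=1)
p_x = p_joint.sum()                 # marginal p(x=1)
q = p_joint / p_x                   # true posterior p(z | x=1)

rng = np.random.default_rng(2)
m = 5
z = rng.choice(2, size=m, p=q)      # Z^i ~ q(z|x) i.i.d.
L_m = np.log(np.mean(p_joint[z] / q[z]))   # multi-sample bound
print(L_m, np.log(p_x))  # equal: the weights are constant when q is exact
```

with an inexact $q$ the bound is strictly below $\log p_\theta(x)$ and tightens as $m$ grows.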

The Concrete Distribution

Discrete Random Variables and the Gumbel-Max Trick

a method for sampling from discrete distributions called the Gumbel-Max trick

restrict to a representation of discrete states as vectors $d\in \{0, 1\}^n$ of bits that are one-hot, or $\sum_{k=1}^n d_k=1$.

consider an unnormalized parameterization $(\alpha_1,\ldots, \alpha_n)$, where $\alpha_k\in (0, \infty)$, of a discrete distribution $D\sim \text{Discrete}(\alpha)$; assume that states with 0 probability are excluded.

the Gumbel-Max trick:

sample $U_k \sim \text{Uniform}(0, 1)$ i.i.d. for each $k$, find the $k$ that maximizes $\log\alpha_k - \log(-\log U_k)$, set $D_k=1$ and the remaining $D_i = 0$ for $i\neq k$. Then

\[P(D_k = 1) = \frac{\alpha_k}{\sum_{i=1}^n\alpha_i}\]

the sampling of a discrete random variable can be refactored into a deterministic function (componentwise addition followed by argmax) of the parameters $\log\alpha_k$ and noise $-\log(-\log U_k)$ with fixed (Gumbel) distribution
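the trick is a few lines of code; the sampler below checks that the empirical frequencies match $\alpha_k/\sum_i\alpha_i$ (the choice of $\alpha$ and sample size is arbitrary):

```python
import numpy as np

def gumbel_max_sample(log_alpha, n_samples, rng):
    """Sample categorical indices via the Gumbel-Max trick."""
    u = rng.uniform(size=(n_samples, len(log_alpha)))
    g = -np.log(-np.log(u))                  # Gumbel(0, 1) noise
    return np.argmax(log_alpha + g, axis=1)  # index of the one-hot bit

rng = np.random.default_rng(3)
alpha = np.array([1.0, 2.0, 3.0])
ks = gumbel_max_sample(np.log(alpha), 200_000, rng)

freq = np.bincount(ks, minlength=3) / len(ks)
print(freq, alpha / alpha.sum())  # empirical vs alpha_k / sum_i alpha_i
```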


the Gumbel distribution features in extreme value theory, where it plays a central role similar to the Normal distribution: the Gumbel distribution is stable under max operations, and for some distributions, the maximum (suitably normalized) of i.i.d. draws approaches the Gumbel in distribution

the Gumbel can also be recognized as a $-\log$-transformed exponential random variable

Concrete Random Variables

the derivative of the argmax is 0 everywhere except at the boundary of state changes, where it is undefined.

For this reason, the Gumbel-Max trick is not a suitable reparameterization for use in SCGs with AD

the argmax computation returns states on the vertices of the simplex.

the idea behind Concrete random variables is to relax the state of a discrete variable from the vertices into the interior, where it is a random probability vector

To sample a Concrete random variable $X\in \Delta^{n-1}$ at temperature $\lambda \in (0, \infty)$ with parameters $\alpha_k \in (0, \infty)$, sample $G_k \sim \text{Gumbel}$ i.i.d. and set

\[X_k = \frac{\exp((\log\alpha_k + G_k)/\lambda)}{\sum_{i=1}^n \exp((\log\alpha_i + G_i)/\lambda)}\]

the softmax computation smoothly approaches the discrete argmax computation as $\lambda \rightarrow 0$ while preserving the relative order of the Gumbels $\log \alpha_k + G_k$
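a sketch of the sampler, checking both the near-one-hot behavior at low temperature and the rounding property $P(\operatorname{argmax}_k X_k = k) = \alpha_k/\sum_i\alpha_i$ (the parameter choices are mine):

```python
import numpy as np

def concrete_sample(log_alpha, lam, n_samples, rng):
    """Sample from Concrete(alpha, lambda): softmax of Gumbel-perturbed logits."""
    u = rng.uniform(size=(n_samples, len(log_alpha)))
    g = -np.log(-np.log(u))                 # Gumbel(0, 1) noise
    y = (log_alpha + g) / lam
    y -= y.max(axis=1, keepdims=True)       # stabilize the softmax
    e = np.exp(y)
    return e / e.sum(axis=1, keepdims=True)

rng = np.random.default_rng(4)
alpha = np.array([1.0, 2.0, 3.0])

x_low = concrete_sample(np.log(alpha), 0.001, 100_000, rng)  # near one-hot
x_high = concrete_sample(np.log(alpha), 5.0, 100_000, rng)   # near uniform
print(x_low.max(axis=1).mean(), x_high.max(axis=1).mean())

# rounding: argmax frequencies match alpha_k / sum_i alpha_i at any temperature
freq = np.bincount(x_low.argmax(axis=1), minlength=3) / len(x_low)
print(freq)
```

note that dividing by $\lambda$ does not change which coordinate is largest, which is why the rounding property holds at every temperature.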

the distribution of $X$ has a closed form density on the simplex

Let $\alpha \in (0, \infty)^n$ and $\lambda \in (0, \infty)$. $X\in \Delta^{n-1}$ has a Concrete distribution $X\sim \text{Concrete}(\alpha, \lambda)$ with location $\alpha$ and temperature $\lambda$, if its density is

\[p_{\alpha, \lambda}(x) = (n-1)!\lambda^{n-1} \prod_{k=1}^n \left(\frac{\alpha_kx_k^{-\lambda-1}}{\sum_{i=1}^n \alpha_i x_i^{-\lambda}}\right)\]
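the density can be checked numerically: for $n = 2$ the simplex is one-dimensional, and integrating $p_{\alpha,\lambda}((t, 1-t))$ over $t\in(0,1)$ with a midpoint rule should give 1 (the parameter values are arbitrary; $\lambda = 2$ keeps the density bounded at the boundary):

```python
import numpy as np

def concrete_logpdf(x, alpha, lam):
    """Log-density of Concrete(alpha, lambda) at a point x in the simplex interior."""
    n = len(alpha)
    log_terms = (np.log(alpha) - (lam + 1.0) * np.log(x)
                 - np.log(np.sum(alpha * x ** (-lam))))
    return (np.sum(np.log(np.arange(1, n)))   # log (n-1)!
            + (n - 1) * np.log(lam) + np.sum(log_terms))

alpha = np.array([1.0, 2.0])
lam = 2.0

# midpoint rule over the 1-d simplex {(t, 1-t) : t in (0, 1)}
N = 20_000
t = (np.arange(N) + 0.5) / N
pdf = np.exp([concrete_logpdf(np.array([ti, 1 - ti]), alpha, lam) for ti in t])
integral = pdf.sum() / N
print(integral)  # should be close to 1
```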

Let $X\sim \text{Concrete}(\alpha, \lambda)$ with location parameters $\alpha \in (0, \infty)^n$ and temperature $\lambda \in (0, \infty)$, then

  • (a) (Reparameterization) If $G_k \sim \text{Gumbel}$ i.i.d., then $X_k\overset{d}{=}\frac{\exp((\log\alpha_k + G_k)/\lambda)}{\sum_{i=1}^n\exp((\log\alpha_i+ G_i)/\lambda)}$
  • (b) (Rounding) $P(X_k > X_i \text{ for } i\neq k) = \frac{\alpha_k}{\sum_{i=1}^n\alpha_i}$
  • (c) (Zero temperature) $P(\lim_{\lambda\rightarrow 0}X_k = 1) = \alpha_k/\sum_{i=1}^n\alpha_i$
  • (d) (Convex eventually) If $\lambda \le (n-1)^{-1}$, then $p_{\alpha, \lambda}(x)$ is log-convex in $x$


Concrete Relaxations

concrete random variables may have some intrinsic value, but the paper investigates them simply as surrogates for optimizing an SCG with discrete nodes

consider the use case of optimizing a large graph with discrete stochastic nodes from samples

consider a variational autoencoder with a single discrete latent variable:

  • $P_a(d)$: the mass function of some $n$-dimensional one-hot discrete random variable with unnormalized probabilities $a\in (0, \infty)^n$
  • $p_\theta(x\mid d)$: some distribution over a data point $x$ given $d\in \{0, 1\}^n$ one-hot
  • the generative model is $p_{\theta, a}(x, d) = p_\theta(x\mid d)P_a(d)$
  • $Q_\alpha(d\mid x)$: approximating posterior over $d\in \{0, 1\}^n$ one-hot whose unnormalized probabilities $\alpha(x) \in (0, \infty)^n$ depend on $x$

the variational lower bound is

\[L_1(\theta, a, \alpha) = \bbE_{D\sim Q_\alpha(d\mid x)} \left[\log\frac{p_\theta(x\mid D)P_a(D)}{Q_\alpha(D\mid x)}\right]\]

the relaxed objective is

\[L_1(\theta, a, \alpha) = \bbE_{Z\sim q_{\alpha, \lambda_1}(z\mid x)}\left[ \log \frac{p_\theta(x\mid Z)p_{a,\lambda_2}(Z)}{q_{\alpha, \lambda_1}(Z\mid x)} \right]\]
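a single-sample evaluation of the relaxed objective for a toy linear-Bernoulli decoder; the decoder, location parameters, and temperatures below are illustrative assumptions of mine, not the paper's models:

```python
import numpy as np

rng = np.random.default_rng(5)

def sample_concrete(log_alpha, lam):
    """One draw from Concrete(alpha, lambda) via the Gumbel-softmax."""
    g = -np.log(-np.log(rng.uniform(size=len(log_alpha))))
    y = (log_alpha + g) / lam
    e = np.exp(y - y.max())
    return e / e.sum()

def concrete_logpdf(z, alpha, lam):
    """Closed-form Concrete log-density at z in the simplex interior."""
    n = len(alpha)
    return (np.sum(np.log(np.arange(1, n))) + (n - 1) * np.log(lam)
            + np.sum(np.log(alpha) - (lam + 1) * np.log(z))
            - n * np.log(np.sum(alpha * z ** (-lam))))

# toy graph: relaxed latent Z in the 2-simplex, Bernoulli likelihood for x
a = np.array([1.0, 1.0, 1.0])        # prior location parameters
alpha_x = np.array([2.0, 0.5, 1.0])  # posterior locations alpha(x)
lam1, lam2 = 2.0 / 3.0, 0.5          # posterior / prior temperatures
x = 1.0
w = np.array([0.9, 0.1, 0.5])        # toy decoder: p(x=1 | z) = w @ z

z = sample_concrete(np.log(alpha_x), lam1)
p = w @ z
log_lik = x * np.log(p) + (1 - x) * np.log(1 - p)
elbo = (log_lik + concrete_logpdf(z, a, lam2)
        - concrete_logpdf(z, alpha_x, lam1))
print(z, elbo)
```

in a real model the gradient of this scalar w.r.t. $\theta$, $a$, and $\alpha$ would flow through the sample via automatic differentiation; at test time the discrete graph is evaluated instead.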

(unofficial) implementation: https://github.com/kampta/pytorch-distributions

