Concrete Distribution: Relaxation of Discrete Random Variables
The reparameterization trick enables optimizing large-scale stochastic computation graphs via gradient descent.
essence: refactor each stochastic node into a differentiable function of its parameters and a random variable with a fixed distribution
discrete random variables lack useful reparameterizations due to the discontinuous nature of discrete states.
the paper introduces Concrete random variables: continuous relaxations of discrete random variables.
the Concrete distribution is a new family of distributions with closed-form densities and a simple reparameterization
whenever a discrete stochastic node of a computation graph can be refactored into a one-hot bit representation that is treated continuously, Concrete stochastic nodes can be used with automatic differentiation to produce low-variance, biased gradients of objectives on the corresponding discrete graph
Introduction
the work is inspired by the observation that many architectures treat discrete nodes continuously, and gradients rich with counterfactual information are available for each of their possible states
the Concrete distribution is a new parametric family of continuous distributions on the simplex with closed-form densities
sampling from the concrete distribution is as simple as taking the softmax of logits perturbed by fixed additive noise.
every discrete random variable corresponds to the zero temperature limit of a Concrete one.
optimizing an objective over an architecture with discrete stochastic nodes can be accomplished by gradient descent on the samples of the corresponding Concrete relaxation.
When the objective depends, as in variational inference, on the log-probability of discrete nodes, the Concrete density is used during training in place of the discrete mass. At test time, the graph with discrete nodes is evaluated.
Background
Optimizing Stochastic Computation Graphs
Stochastic computation graphs (SCGs) provide a formalism for specifying potentially stochastic input-output mappings with learnable parameters as directed acyclic graphs.
many training objectives in supervised, unsupervised, and reinforcement learning can be expressed in terms of SCGs
for graphs with a single stochastic node $X$, interpret the forward pass in the graph as first sampling $X$ from the conditional distribution $p_\phi(x)$ of the stochastic node given its parents, then evaluating a deterministic function $f_\theta(x)$ at $X$
think of $f_\theta(X)$ as a noisy objective, and we are interested in optimizing its expected value $L(\theta, \phi) = \bbE_{X\sim p_\phi(x)}[f_\theta(X)]$ w.r.t. parameters $\theta, \phi$.
in general, both the objective and its gradients are intractable
the gradient w.r.t. the parameters $\theta$ has the form
\[\nabla_\theta L(\theta, \phi) = \nabla_\theta\bbE_{X\sim p_\phi(x)}[f_\theta(X)] = \bbE_{X\sim p_\phi(x)}[\nabla_\theta f_\theta(X)]\]this can be easily estimated via Monte Carlo
the more challenging task is to compute the gradient w.r.t. the parameters $\phi$ of $p_\phi(x)$. The expression obtained by differentiating the expected objective
\[\nabla_\phi L(\theta, \phi) = \nabla_\phi \int p_\phi(x) f_\theta(x)dx = \int f_\theta(x) \nabla_\phi p_\phi(x)dx\]does not have the form of an expectation w.r.t. $x$
Score Function Estimators
the score function estimator, also known as the REINFORCE or likelihood-ratio estimator, is based on the identity
\[\nabla_\phi p_\phi(x) = p_\phi(x) \nabla_\phi \log p_\phi(x)\]then
\[\nabla_\phi L(\theta, \phi) = \bbE_{X\sim p_\phi(x)}[f_\theta(X)\nabla_\phi \log p_\phi(X)]\]
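a minimal sketch of the score-function estimator in PyTorch, under an assumed toy setup: a fixed per-state objective `f` and logits playing the role of $\phi$ (both illustrative, not from the paper). The gradient of the surrogate loss equals the estimator above.

```python
import torch

# toy setup: f depends only on the sampled discrete state; logits play phi
f = torch.tensor([1.0, 2.0, 3.0])
logits = torch.zeros(3, requires_grad=True)

dist = torch.distributions.Categorical(logits=logits)
x = dist.sample((1000,))                       # sampling itself is non-differentiable
log_prob = dist.log_prob(x)                    # but log p_phi(x) is differentiable

# gradient of this surrogate equals the score-function estimate of grad_phi E[f(X)]
surrogate = (f[x] * log_prob).mean()
surrogate.backward()
print(logits.grad)                             # unbiased but typically high-variance
```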
Reparameterization Trick
in many cases we can sample from $p_\phi(x)$ by first sampling $Z$ from some fixed distribution $q(z)$ and then transforming the sample using a function $g_\phi(z)$
e.g., to sample from $N(\mu, \sigma^2)$: sample $Z\sim N(0, 1)$, then transform it using $g_{\mu, \sigma}(Z) = \mu + \sigma Z$
reparameterization trick: transfer the dependence on $\phi$ into $f$ by writing $f_\theta(x) = f_\theta(g_\phi(z))$ for $x = g_\phi(z)$. this reduces the problem of estimating the gradient w.r.t. the parameters of a distribution to the simpler problem of estimating the gradient w.r.t. the parameters of a deterministic function
having reparameterized $p_\phi(x)$, we can now express the objective as an expectation w.r.t. $q(z)$
\[L(\theta, \phi) = \bbE_{X\sim p_\phi(x)}[f_\theta(X)] = \bbE_{Z\sim q(z)}[f_\theta(g_\phi(Z))]\]assuming differentiability of $f_\theta(x)$ w.r.t. $x$ and of $g_\phi(z)$ w.r.t. $\phi$ and using the chain rule gives
\[\nabla_\phi L(\theta, \phi) = \bbE_{Z\sim q(z)}[\nabla_\phi f_\theta(g_\phi(Z))] = \bbE_{Z\sim q(z)}[f'_\theta(g_\phi(Z))\nabla_\phi g_\phi(Z)]\]
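a sketch of the pathwise (reparameterized) estimator for the $N(\mu, \sigma^2)$ example above; the quadratic objective is a stand-in.

```python
import torch

mu = torch.tensor(0.5, requires_grad=True)       # phi = (mu, log_sigma)
log_sigma = torch.tensor(0.0, requires_grad=True)

z = torch.randn(1000)                            # Z ~ q(z) = N(0, 1), fixed
x = mu + log_sigma.exp() * z                     # x = g_phi(z), differentiable in phi

loss = (x ** 2).mean()                           # stand-in f_theta(x) = x^2
loss.backward()                                  # pathwise gradients flow through x
print(mu.grad, log_sigma.grad)                   # low-variance estimates
```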
Application: Variational Training of Latent Variable Models
such models assume that each observation $x$ is obtained by first sampling a vector of latent variables $Z$ from the prior $p_\theta(z)$ and then sampling the observation itself from $p_\theta(x\mid z)$; thus the marginal probability is $p_\theta(x) = \sum_z p_\theta(z)p_\theta(x\mid z)$.
maximum likelihood training is infeasible, because the log-likelihood objective $L(\theta) = \log p_\theta(x) = \log \bbE_{Z\sim p_\theta(z)}[p_\theta(x\mid Z)]$ is typically intractable and does not fit into the above framework, since the expectation sits inside the log
the multi-sample variational objective
\[L_m(\theta, \phi) = \bbE_{Z^i \sim q_\phi(z\mid x)}\left[ \log\left( \frac 1m \sum_{i=1}^m \frac{p_\theta(Z^i, x)}{q_\phi(Z^i\mid x)} \right) \right]\]
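a sketch of computing $L_m$ from $m$ log-weights, using logsumexp for numerical stability; `log_joint` and `log_q` are placeholder tensors, not a real model.

```python
import math
import torch

def multi_sample_bound(log_joint, log_q):
    """L_m = log( (1/m) sum_i p(Z^i, x) / q(Z^i | x) ) from m log-weights.

    log_joint[i] = log p_theta(Z^i, x), log_q[i] = log q_phi(Z^i | x).
    """
    log_w = log_joint - log_q                      # log importance weights
    m = log_w.shape[0]
    return torch.logsumexp(log_w, dim=0) - math.log(m)

# illustrative log-densities for m = 4 posterior samples
print(multi_sample_bound(torch.randn(4), torch.randn(4)))
```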
The Concrete Distribution
Discrete Random Variables and the Gumbel-Max Trick
a method for sampling from discrete distributions called the Gumbel-Max trick
restrict to a representation of discrete states as one-hot vectors of bits $d\in \{0, 1\}^n$, i.e., $\sum_{k=1}^n d_k=1$.
consider an unnormalized parameterization $(\alpha_1,\ldots, \alpha_n)$ with $\alpha_k\in (0, \infty)$ of a discrete distribution $D\sim \text{Discrete}(\alpha)$ (states with 0 probability are excluded).
the Gumbel-Max trick:
sample $U_k \sim \text{Uniform}(0, 1)$ i.i.d. for each $k$, find the $k$ that maximizes $\log\alpha_k - \log(-\log U_k)$, set $D_k=1$ and $D_i = 0$ for $i\neq k$. Then
\[P(D_k = 1) = \frac{\alpha_k}{\sum_{i=1}^n\alpha_i}\]the sampling of a discrete random variable can be refactored into a deterministic function (componentwise addition followed by argmax) of the parameters $\log\alpha_k$ and fixed Gumbel noise $G_k = -\log(-\log U_k)$
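a sketch of the Gumbel-Max trick as described above; the frequency check at the end is a quick sanity test of $P(D_k = 1) = \alpha_k/\sum_i \alpha_i$.

```python
import torch
import torch.nn.functional as F

def gumbel_max_sample(log_alpha):
    """one-hot sample from Discrete(alpha) via the Gumbel-Max trick"""
    u = torch.rand_like(log_alpha)            # U_k ~ Uniform(0, 1) i.i.d.
    g = -torch.log(-torch.log(u))             # G_k = -log(-log U_k) ~ Gumbel
    k = torch.argmax(log_alpha + g)           # argmax of perturbed log-parameters
    return F.one_hot(k, num_classes=log_alpha.shape[-1])

log_alpha = torch.log(torch.tensor([1.0, 2.0, 3.0]))
# empirical frequencies should approach alpha / sum(alpha) = (1/6, 1/3, 1/2)
counts = sum(gumbel_max_sample(log_alpha) for _ in range(6000))
print(counts / 6000.0)
```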
the Gumbel distribution features in extreme value theory, where it plays a central role similar to that of the Normal distribution: the Gumbel distribution is stable under max operations, and for some distributions, the maxima (suitably normalized) of i.i.d. draws approach the Gumbel in distribution
the Gumbel can also be recognized as a $-\log$-transformed exponential random variable
Concrete Random Variables
the derivative of the argmax is 0 everywhere except at the boundary of state changes, where it is undefined.
For this reason, the Gumbel-Max trick is not a suitable reparameterization for use in SCGs with AD
the argmax computation returns states on the vertices of the simplex.
the idea behind Concrete random variables is to relax the state of a discrete variable from the vertices of the simplex into its interior, where it becomes a random probability vector
To sample a Concrete random variable $X\in \Delta^{n-1}$ at temperature $\lambda \in (0, \infty)$ with parameters $\alpha_k \in (0, \infty)$, sample $G_k \sim \text{Gumbel}$ i.i.d. and set
\[X_k = \frac{\exp((\log\alpha_k + G_k)/\lambda)}{\sum_{i=1}^n \exp((\log\alpha_i + G_i)/\lambda)}\]the softmax computation smoothly approaches the discrete argmax computation as $\lambda \rightarrow 0$ while preserving the relative order of the Gumbels $\log \alpha_k + G_k$
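a sketch of Concrete sampling exactly as in the formula above: softmax of Gumbel-perturbed logits, shown at two temperatures to illustrate the effect of $\lambda$.

```python
import torch

def concrete_sample(log_alpha, temperature):
    """X ~ Concrete(alpha, lambda): softmax of Gumbel-perturbed logits"""
    u = torch.rand_like(log_alpha)
    g = -torch.log(-torch.log(u))                           # Gumbel noise
    return torch.softmax((log_alpha + g) / temperature, dim=-1)

log_alpha = torch.log(torch.tensor([1.0, 2.0, 3.0]))
print(concrete_sample(log_alpha, 2.0))    # high temperature: diffuse over the simplex
print(concrete_sample(log_alpha, 0.1))    # low temperature: nearly one-hot
```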
the distribution of $X$ has a closed form density on the simplex
Let $\alpha \in (0, \infty)^n$ and $\lambda \in (0, \infty)$. $X\in \Delta^{n-1}$ has a Concrete distribution $X\sim \text{Concrete}(\alpha, \lambda)$ with location $\alpha$ and temperature $\lambda$, if its density is
\[p_{\alpha, \lambda}(x) = (n-1)!\lambda^{n-1} \prod_{k=1}^n \left(\frac{\alpha_kx_k^{-\lambda-1}}{\sum_{i=1}^n \alpha_i x_i^{-\lambda}}\right)\]
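a sketch of evaluating this density in log space, which is safer than multiplying the $x_k^{-\lambda-1}$ terms directly; the function name is mine.

```python
import math
import torch

def concrete_log_density(x, log_alpha, temperature):
    """log p_{alpha,lambda}(x) for x in the interior of the simplex"""
    n = x.shape[-1]
    lam = temperature
    log_norm = math.lgamma(n) + (n - 1) * math.log(lam)  # log (n-1)! + (n-1) log lambda
    log_num = log_alpha - (lam + 1.0) * torch.log(x)     # log(alpha_k * x_k^{-lambda-1})
    log_denom = torch.logsumexp(log_alpha - lam * torch.log(x), dim=-1)
    return log_norm + log_num.sum(dim=-1) - n * log_denom

x = torch.tensor([0.2, 0.3, 0.5])
log_alpha = torch.log(torch.tensor([1.0, 2.0, 3.0]))
print(concrete_log_density(x, log_alpha, temperature=0.5))
```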
Let $X\sim \text{Concrete}(\alpha, \lambda)$ with location parameters $\alpha \in (0, \infty)^n$ and temperature $\lambda \in (0, \infty)$, then (a simulation check of property (b) follows this list)
- (a) (Reparameterization) If $G_k \sim \text{Gumbel}$ i.i.d., then $X_k\overset{d}{=}\frac{\exp((\log\alpha_k + G_k)/\lambda)}{\sum_{i=1}^n\exp((\log\alpha_i+ G_i)/\lambda)}$
- (b) (Rounding) $P(X_k > X_i \text{ for } i\neq k) = \frac{\alpha_k}{\sum_{i=1}^n\alpha_i}$
- (c) (Zero temperature) $P(\lim_{\lambda\rightarrow 0}X_k = 1) = \alpha_k/\sum_{i=1}^n\alpha_i$
- (d) (Convex eventually) If $\lambda \le (n-1)^{-1}$, then $p_{\alpha, \lambda}(x)$ is log-convex in $x$
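a quick simulation of the rounding property (b): which coordinate of a Concrete sample is largest follows $\text{Discrete}(\alpha)$, at any temperature (here $\lambda = 0.7$, chosen arbitrarily).

```python
import torch

# rounding property (b): P(X_k > X_i for i != k) = alpha_k / sum_i alpha_i
alpha = torch.tensor([1.0, 2.0, 3.0])
log_alpha = torch.log(alpha)

u = torch.rand(100_000, 3)
g = -torch.log(-torch.log(u))                        # Gumbel noise
x = torch.softmax((log_alpha + g) / 0.7, dim=-1)     # Concrete samples, lambda = 0.7

freq = torch.bincount(x.argmax(dim=-1), minlength=3).float() / x.shape[0]
print(freq, alpha / alpha.sum())                     # both approx (1/6, 1/3, 1/2)
```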
Concrete Relaxations
Concrete random variables may have some intrinsic value, but the paper investigates them simply as surrogates for optimizing an SCG with discrete nodes
consider the use case of optimizing a large graph with discrete stochastic nodes from samples
for a variational autoencoder with a single discrete latent variable:
- $P_a(d)$: the mass function of some $n$-dimensional one-hot discrete random variable with unnormalized probabilities $a\in (0, \infty)^n$
- $p_\theta(x\mid d)$: some distribution over a data point $x$ given $d\in (0, 1)^n$ one-hot
- the generative model is $p_{\theta, a}(x, d) = p_\theta(x\mid d)P_a(d)$
- $Q_\alpha(d\mid x)$: approximating posterior over $d\in (0, 1)^n$ one-hot whose unnormalized probabilities $\alpha(x) \in (0, \infty)^n$ depend on $x$
the variational lower bound is
\[L_1(\theta, a, \alpha) = \bbE_{D\sim Q_\alpha(d\mid x)} \left[\log\frac{p_\theta(x\mid D)P_a(D)}{Q_\alpha(D\mid x)}\right]\]the relaxed objective is
\[L_1(\theta, a, \alpha) = \bbE_{Z\sim q_{\alpha, \lambda_1}(z\mid x)}\left[ \log \frac{p_\theta(x\mid Z)\,p_{a,\lambda_2}(Z)}{q_{\alpha, \lambda_1}(Z\mid x)} \right]\]where $q_{\alpha,\lambda_1}(z\mid x)$ and $p_{a,\lambda_2}(z)$ are the Concrete densities relaxing $Q_\alpha(d\mid x)$ and $P_a(d)$ at temperatures $\lambda_1$ and $\lambda_2$
(unofficial) implementation: https://github.com/kampta/pytorch-distributions
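a minimal sketch of the relaxed bound using `torch.distributions.RelaxedOneHotCategorical` (the Concrete distribution under PyTorch's name, which the linked repo also builds on); the posterior logits, the uniform relaxed prior, the temperatures, and the decoder term are all placeholders, not the paper's setup.

```python
import torch
from torch.distributions import RelaxedOneHotCategorical

n = 4
lambda1, lambda2 = 2.0 / 3.0, 0.5            # posterior / prior temperatures (illustrative)

# placeholder pieces: posterior logits alpha(x) from an encoder, a uniform
# relaxed prior for P_a, and a stand-in decoder term log p_theta(x | z)
posterior_logits = torch.randn(n, requires_grad=True)
q = RelaxedOneHotCategorical(torch.tensor(lambda1), logits=posterior_logits)
p = RelaxedOneHotCategorical(torch.tensor(lambda2), logits=torch.zeros(n))

z = q.rsample()                              # reparameterized Concrete sample
log_px_given_z = -(z - 1.0).pow(2).sum()     # placeholder for log p_theta(x | z)

# relaxed bound: Concrete densities stand in for the discrete masses during training
elbo = log_px_given_z + p.log_prob(z) - q.log_prob(z)
(-elbo).backward()                           # a gradient step on all parameters would follow
print(posterior_logits.grad)
```

at test time, per the section above, the discrete graph is evaluated instead: sample a one-hot $D\sim Q_\alpha(d\mid x)$ and score with the discrete masses.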