WeiYa's Work Yard

A dog, who fell into the ocean of statistics, tries to write down his ideas and notes to save himself.

Invariant Risk Minimization

Tags: Causal Inference, Invariance

This note is for Arjovsky, M., Bottou, L., Gulrajani, I., & Lopez-Paz, D. (2020). Invariant Risk Minimization. arXiv:1907.02893 [cs, stat].

Understanding which patterns are useful has been previously studied as a correlation-versus-causation dilemma, since spurious correlations stemming from data biases are unrelated to the causal explanation of interest.

The paper proposes Invariant Risk Minimization (IRM), a novel learning paradigm that estimates nonlinear, invariant, causal predictors from multiple training environments to enable out-of-distribution (OOD) generalization.

The IRM principle is:

To learn invariances across environments, find a data representation such that the optimal classifier on top of that representation matches for all environments.

Many faces of generalization

Consider datasets $D_e:=\{(x_i^e, y_i^e)\}_{i=1}^{n_e}$ collected under multiple training environments $e\in \cE_{tr}$, where each dataset $D_e$ contains examples drawn iid from some probability distribution $P(X^e,Y^e)$.

The goal is to use these multiple datasets to learn a predictor $Y\approx f(X)$ that performs well across a large set of unseen but related environments $\cE_{all} \supset \cE_{tr}$; namely, to minimize

\[R^{OOD}(f) = \max_{e\in \cE_{all}} R^e(f)\]

where $R^e(f) = E_{X^e, Y^e}[\ell(f(X^e), Y^e)]$ is the risk under environment $e$.
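To make the two risks concrete, here is a minimal sketch assuming the squared loss and scalar inputs; the environment names and data points are made up for illustration. Because $\cE_{all}$ is unseen, the maximum can only be evaluated over environments we actually have data for, so `worst_case_risk` is a training-environment proxy for $R^{OOD}$, not $R^{OOD}$ itself.

```python
from typing import Callable, Dict, List, Tuple

# One environment's dataset D_e: a list of (x, y) pairs (hypothetical toy data)
Dataset = List[Tuple[float, float]]


def empirical_risk(f: Callable[[float], float], data: Dataset) -> float:
    """Sample estimate of R^e(f) = E[(f(X^e) - Y^e)^2] under the squared loss."""
    return sum((f(x) - y) ** 2 for x, y in data) / len(data)


def worst_case_risk(f: Callable[[float], float], envs: Dict[str, Dataset]) -> float:
    """max_e R^e(f), taken only over the environments we have data for."""
    return max(empirical_risk(f, data) for data in envs.values())


envs = {
    "e1": [(0.0, 0.0), (1.0, 1.0), (2.0, 2.0)],
    "e2": [(0.0, 1.0), (1.0, 2.0)],
}


def f(x: float) -> float:
    return x  # candidate predictor Y ≈ f(X)


print(empirical_risk(f, envs["e1"]))  # 0.0
print(empirical_risk(f, envs["e2"]))  # 1.0
print(worst_case_risk(f, envs))       # 1.0
```

A predictor that is perfect on `e1` still has worst-case risk 1.0 here, which the max objective penalizes fully while an average over the two environments (0.5) would dilute it.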

Four techniques commonly discussed in prior work:

  • empirical risk minimization (ERM)
  • robust learning objective
  • domain adaptation strategy
  • invariant causal prediction techniques

We say that a data representation $\Phi:\cX\rightarrow \cH$ elicits an invariant predictor $w\circ \Phi$ across environments $\cE$ if there is a classifier $w:\cH\rightarrow \cY$ simultaneously optimal for all environments, that is

\[w\in \arg\min_{\bar w:\cH\rightarrow\cY} R^e(\bar w\circ \Phi)\]

for all $e\in\cE$.
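For a one-dimensional representation and the squared loss, the inner $\arg\min$ has a closed form, $w_e = E[\Phi(X^e)Y^e]/E[\Phi(X^e)^2]$, so checking the definition amounts to checking that all per-environment minimizers coincide. The sketch below assumes $\cH=\cY=\IR$, a linear classifier $w$ without intercept, and that $\Phi$ has already been applied to the inputs; the data are toy values.

```python
from typing import Dict, List, Tuple

# Each environment: (representation values h_i = Phi(x_i), targets y_i)
Env = Tuple[List[float], List[float]]


def optimal_w(h: List[float], y: List[float]) -> float:
    """Least-squares argmin_w mean((w*h - y)^2) = sum(h*y) / sum(h*h)."""
    return sum(hi * yi for hi, yi in zip(h, y)) / sum(hi * hi for hi in h)


def elicits_invariant_predictor(envs: Dict[str, Env], tol: float = 1e-9) -> bool:
    """True if a single scalar w is simultaneously optimal in every environment."""
    ws = [optimal_w(h, y) for h, y in envs.values()]
    return max(ws) - min(ws) <= tol


invariant_envs = {
    "e1": ([1.0, 2.0], [2.0, 4.0]),  # optimal w = 2.0
    "e2": ([1.0, 3.0], [2.0, 6.0]),  # optimal w = 2.0
}
shifted_envs = {
    **invariant_envs,
    "e3": ([1.0, 2.0], [1.0, 4.0]),  # optimal w = 1.8
}

print(elicits_invariant_predictor(invariant_envs))  # True
print(elicits_invariant_predictor(shifted_envs))    # False
```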

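For loss functions such as the mean squared error and the cross-entropy, the optimal classifier can be written as a conditional expectation. In these cases, a data representation function $\Phi$ elicits an invariant predictor across environments $\cE$ iff for all $h$ in the intersection of the supports of $\Phi(X^e)$, we have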

\[\E[Y^e\mid \Phi(X^e)=h] = \E[Y^{e'}\mid \Phi(X^{e'})=h]\]

for all $e,e'\in\cE$.
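When $\Phi$ takes finitely many values, this condition can be checked by comparing empirical conditional means of $Y^e$ on the values $h$ that every environment shares. This is a sketch assuming a discrete representation and using sample averages as stand-ins for the conditional expectations; the toy pairs $(h, y)$ are made up.

```python
from collections import defaultdict
from typing import Dict, Hashable, List, Tuple

Pairs = List[Tuple[Hashable, float]]  # (h = Phi(x), y) pairs from one environment


def conditional_means(pairs: Pairs) -> Dict[Hashable, float]:
    """Empirical E[Y | Phi(X) = h] for every observed value h."""
    sums: Dict[Hashable, float] = defaultdict(float)
    counts: Dict[Hashable, int] = defaultdict(int)
    for h, y in pairs:
        sums[h] += y
        counts[h] += 1
    return {h: sums[h] / counts[h] for h in sums}


def same_conditional_means(envs: Dict[str, Pairs], tol: float = 1e-9) -> bool:
    """Compare E[Y^e | Phi(X^e) = h] across environments on the common support."""
    means = [conditional_means(p) for p in envs.values()]
    common_support = set.intersection(*(set(m) for m in means))
    return all(
        max(m[h] for m in means) - min(m[h] for m in means) <= tol
        for h in common_support
    )


envs = {
    "e1": [(0, 0.0), (0, 1.0), (1, 1.0)],  # E[Y|h=0] = 0.5, E[Y|h=1] = 1.0
    "e2": [(0, 0.5), (1, 1.0), (2, 3.0)],  # h = 2 lies outside the common support
}
print(same_conditional_means(envs))                                  # True
print(same_conditional_means({**envs, "e3": [(0, 1.0), (1, 1.0)]}))  # False
```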

Two goals for the data representation $\Phi$:

  • be useful to predict well
  • elicit an invariant predictor across $\cE_{tr}$

These two goals lead to the IRM objective, a constrained optimization problem:

\[\begin{align*} \min_{\Phi:\cX\rightarrow\cH; w:\cH\rightarrow\cY} & \sum_{e\in \cE_{tr}} R^e(w\circ \Phi)\\ \text{subject to} & w\in \arg\min_{\bar w:\cH\rightarrow\cY} R^e(\bar w\circ \Phi), \text{for all $e\in\cE_{tr}$} \end{align*}\]

This is a challenging bi-level optimization problem, since each constraint requires solving an inner optimization problem (finding the optimal classifier for environment $e$).

To make this tractable, the paper instantiates IRM into a practical version, IRMv1:

\[\min_{\Phi:\cX\rightarrow\cY} \sum_{e\in \cE_{tr}} \Big[R^e(\Phi) + \lambda \Vert \nabla_{w\mid w=1.0}R^e(w\cdot\Phi)\Vert^2\Big]\]


  • $\Phi$ becomes the entire invariant predictor
  • $w=1.0$ is a scalar and fixed “dummy” classifier
  • the gradient norm penalty is used to measure the optimality of the dummy classifier at each environment $e$
  • $\lambda\in [0,\infty)$ is a regularization weight balancing predictive power (the ERM term) against the invariance of the predictor $1\cdot\Phi(x)$.
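For the squared loss, the IRMv1 penalty can be written out by hand: $R^e(w\cdot\Phi)=E[(w\,\Phi(X^e)-Y^e)^2]$, whose derivative at $w=1.0$ is $2E[(\Phi(X^e)-Y^e)\Phi(X^e)]$. The sketch below assumes scalar outputs $\Phi(x)$ that are already computed, uses sample averages per environment, and checks the hand-derived gradient against a central finite difference; in practice one would obtain this gradient from automatic differentiation instead.

```python
from typing import Dict, List, Tuple

Env = Tuple[List[float], List[float]]  # (phi_i = Phi(x_i), y_i) for one environment


def risk(phi: List[float], y: List[float], w: float = 1.0) -> float:
    """Empirical squared-loss risk R^e(w * Phi)."""
    return sum((w * p - t) ** 2 for p, t in zip(phi, y)) / len(y)


def dummy_grad(phi: List[float], y: List[float]) -> float:
    """d/dw R^e(w * Phi) at w = 1.0, i.e. mean(2 * (phi - y) * phi)."""
    return sum(2.0 * (p - t) * p for p, t in zip(phi, y)) / len(y)


def irmv1_objective(envs: Dict[str, Env], lam: float) -> float:
    """sum_e [ R^e(Phi) + lam * (gradient at the dummy w = 1.0)^2 ]."""
    return sum(risk(phi, y) + lam * dummy_grad(phi, y) ** 2 for phi, y in envs.values())


envs = {
    "e1": ([1.0, 2.0], [1.0, 2.0]),  # w = 1.0 already optimal: zero penalty
    "e2": ([1.0, 2.0], [2.0, 4.0]),  # optimal w here would be 2.0: nonzero penalty
}
print(dummy_grad(*envs["e2"]))         # -5.0
print(irmv1_objective(envs, lam=1.0))  # 27.5

# Central finite-difference check of the hand-written gradient
eps = 1e-6
fd = (risk(*envs["e2"], w=1.0 + eps) - risk(*envs["e2"], w=1.0 - eps)) / (2 * eps)
print(abs(fd - dummy_grad(*envs["e2"])) < 1e-6)  # True
```

With `lam=0.0` the objective reduces to the pooled ERM risk (2.5 here); increasing `lam` pushes $\Phi$ toward representations for which the dummy classifier $w=1.0$ is optimal in every environment.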

From IRM to IRMv1

Phrasing the constraints as a penalty

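\[L_{IRM}(\Phi, w) = \sum_{e\in \cE_{tr}} \Big[R^e(w\circ \Phi) +\lambda D(w,\Phi, e)\Big]\]

where $D(w,\Phi,e)$ measures how close $w$ is to minimizing $R^e(w\circ\Phi)$, and $\lambda\in[0,\infty)$ balances predictive power against invariance.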
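The penalty $D$ is left abstract at this point. As an illustration only, the sketch below plugs in a hypothetical choice that the paper does not use, the excess risk $D(w,\Phi,e)=R^e(w\circ\Phi)-\min_{\bar w}R^e(\bar w\circ\Phi)$, which is nonnegative and zero exactly when $w$ is optimal for environment $e$. It assumes the squared loss, a scalar representation, and a scalar linear classifier; the data are toy values.

```python
from typing import Callable, Dict, List, Tuple

Env = Tuple[List[float], List[float]]  # (phi_i = Phi(x_i), y_i) for one environment
Penalty = Callable[[float, List[float], List[float]], float]


def risk(w: float, phi: List[float], y: List[float]) -> float:
    """Empirical squared-loss risk R^e(w ∘ Phi) for a scalar linear classifier w."""
    return sum((w * p - t) ** 2 for p, t in zip(phi, y)) / len(y)


def excess_risk(w: float, phi: List[float], y: List[float]) -> float:
    """Hypothetical D: R^e(w ∘ Phi) - min_wbar R^e(wbar ∘ Phi), via the least-squares wbar."""
    w_best = sum(p * t for p, t in zip(phi, y)) / sum(p * p for p in phi)
    return risk(w, phi, y) - risk(w_best, phi, y)


def l_irm(w: float, envs: Dict[str, Env], lam: float, penalty: Penalty = excess_risk) -> float:
    """sum_e [ R^e(w ∘ Phi) + lam * D(w, Phi, e) ]."""
    return sum(risk(w, phi, y) + lam * penalty(w, phi, y) for phi, y in envs.values())


envs = {
    "e1": ([1.0, 2.0], [2.0, 4.0]),  # best w = 2.0
    "e2": ([1.0, 2.0], [1.0, 2.0]),  # best w = 1.0
}
print(l_irm(2.0, envs, lam=1.0))  # 5.0: w = 2.0 is optimal in e1 only
```

No single $w$ makes both penalties vanish here, because the two environments disagree on the optimal classifier; that disagreement is what the penalty is meant to expose.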

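Choosing a penalty $D$ for linear classifiers $w$

Consider an invariant predictor $w\circ\Phi$ where $w$ is a linear least-squares regression on top of a (possibly nonlinear) data representation $\Phi$. Such a $w$ minimizes $R^e(w\circ\Phi)$ exactly when it satisfies the normal equations $E_{X^e}[\Phi(X^e)\Phi(X^e)^T]w = E_{X^e,Y^e}[\Phi(X^e)Y^e]$, so the penalty measures how far $w$ is from satisfying them: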

\[D_{lin}(w,\Phi, e) = \Vert E_{X^e}[\Phi(X^e)\Phi(X^e)^T]w-E_{X^e,Y^e}[\Phi(X^e)Y^e]\Vert^2\]
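With expectations replaced by sample averages over one environment, $D_{lin}$ is cheap to compute. The sketch below (numpy, toy data, rows of `Phi` standing for $\Phi(x_i^e)$) also checks two facts about it under the squared loss: it vanishes at the least-squares solution, and it equals one quarter of the squared norm of $\nabla_w R^e(w\circ\Phi)$, since that gradient is $2(E[\Phi\Phi^T]w-E[\Phi Y])$.

```python
import numpy as np


def d_lin(w: np.ndarray, Phi: np.ndarray, y: np.ndarray) -> float:
    """Empirical D_lin = || E[Phi Phi^T] w - E[Phi Y] ||^2.

    Phi: (n, d) array whose i-th row is Phi(x_i^e); y: (n,) targets; w: (d,) classifier.
    """
    n = Phi.shape[0]
    second_moment = Phi.T @ Phi / n  # E[Phi(X) Phi(X)^T], shape (d, d)
    cross_moment = Phi.T @ y / n     # E[Phi(X) Y], shape (d,)
    residual = second_moment @ w - cross_moment
    return float(residual @ residual)


def risk_grad(w: np.ndarray, Phi: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Gradient of the squared-loss risk mean((Phi w - y)^2) with respect to w."""
    return 2.0 * Phi.T @ (Phi @ w - y) / Phi.shape[0]


Phi = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
y = np.array([1.0, 2.0, 3.0])

w_star, *_ = np.linalg.lstsq(Phi, y, rcond=None)  # least-squares classifier, ≈ [1, 2]
w_other = np.array([0.5, -1.0])

print(np.isclose(d_lin(w_star, Phi, y), 0.0))  # True
print(np.isclose(d_lin(w_other, Phi, y), np.sum(risk_grad(w_other, Phi, y) ** 2) / 4))  # True
```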

Fixing the linear classifier $w$

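The pair $(w,\Phi)$ is over-parametrized: for any invertible mapping $\Psi$, we can rewrite the invariant predictor as $w\circ\Phi = (w\circ\Psi^{-1})\circ(\Psi\circ\Phi)$, so $w$ can be given any nonzero value $\tilde w$ of our choosing. This relaxes the recipe for invariance into:

Find a data representation such that the optimal classifier, on top of that data representation, is $\tilde w$ for all environments.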
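The over-parametrization argument is easy to check numerically for linear maps. In the sketch below (numpy, random synthetic data with a fixed seed, row-vector convention $\Phi(x)=x\,\Phi$), an arbitrary invertible $\Psi$ is absorbed into the representation and undone in the classifier, leaving every prediction unchanged; the one-dimensional case shows a nonzero scalar $w$ being rescaled to $\tilde w = 1.0$.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(5, 3))    # 5 synthetic inputs with 3 raw features
Phi = rng.normal(size=(3, 2))  # linear representation Phi(x) = x @ Phi, with H = R^2
w = rng.normal(size=2)         # linear classifier on H

Psi = np.array([[2.0, 1.0], [0.0, 3.0]])  # an invertible map Psi: H -> H

preds = X @ Phi @ w
Phi_new = Phi @ Psi              # Psi ∘ Phi
w_new = np.linalg.solve(Psi, w)  # w ∘ Psi^{-1}: applies Psi^{-1} before w
preds_new = X @ Phi_new @ w_new
print(np.allclose(preds, preds_new))  # True

# Scalar case: taking Psi = w_1d moves a nonzero scalar classifier into the representation
phi_1d = X @ Phi[:, 0]  # a one-dimensional representation
w_1d = 2.5              # any nonzero scalar classifier
w_tilde = w_1d / w_1d   # classifier left after absorbing Psi = w_1d, i.e. 1.0
print(np.allclose(w_1d * phi_1d, w_tilde * (w_1d * phi_1d)))  # True
```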

Scalar fixed classifiers $\tilde w$ are sufficient to monitor invariance

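In the linear setting, it suffices to restrict the search to matrices $\Phi\in \IR^{1\times d}$ and to let $\tilde w\in\IR^1$ be the fixed scalar 1.0, so that $\Phi$ is the entire predictor and $\tilde w$ only serves to monitor invariance.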
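As a worked check, the following derivation is a sketch assuming the squared loss and treating $X^e$ as a column vector, so that $\Phi X^e$ is a scalar. Plugging $\tilde w = 1.0$ and $\Phi\in\IR^{1\times d}$ into $D_{lin}$, and differentiating the squared-loss risk at the dummy classifier, gives

\[\begin{align*} D_{lin}(1.0,\Phi,e) &= \Vert \Phi\, E[X^e (X^e)^T]\,\Phi^T - \Phi\, E[X^e Y^e]\Vert^2,\\ \nabla_{w\mid w=1.0} R^e(w\cdot\Phi) &= \nabla_{w\mid w=1.0}\, E\big[(w\,\Phi X^e - Y^e)^2\big] = 2\,E\big[(\Phi X^e - Y^e)\,\Phi X^e\big]\\ &= 2\big(\Phi\, E[X^e (X^e)^T]\,\Phi^T - \Phi\, E[X^e Y^e]\big), \end{align*}\]

so $\Vert\nabla_{w\mid w=1.0}R^e(w\cdot\Phi)\Vert^2 = 4\,D_{lin}(1.0,\Phi,e)$. Under the squared loss, the IRMv1 gradient penalty is therefore $D_{lin}$ up to a constant factor of 4, which can be absorbed into $\lambda$.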

Extending to general losses and multivariate outputs
