Invariant Risk Minimization
This note is on Arjovsky, M., Bottou, L., Gulrajani, I., & Lopez-Paz, D. (2020). Invariant Risk Minimization. arXiv:1907.02893 [cs, stat].
Understanding which patterns are useful has been previously studied as a correlation-versus-causation dilemma, since spurious correlations stemming from data biases are unrelated to the causal explanation of interest.
The paper proposes Invariant Risk Minimization (IRM), a novel learning paradigm that estimates nonlinear, invariant, causal predictors from multiple training environments, in order to enable out-of-distribution (OOD) generalization.
The IRM principle is:
To learn invariance across environments, find a data representation such that the optimal classifier on top of that representation matches for all environments.
Many faces of generalization
Consider datasets $D_e:=\{(x_i^e, y_i^e)\}_{i=1}^{n_e}$ collected under multiple training environments $e\in \cE_{tr}$, each containing i.i.d. examples drawn from some probability distribution $P(X^e,Y^e)$.
The goal is to use these multiple datasets to learn a predictor $Y\approx f(X)$, which performs well across a large set of unseen but related environments $\cE_{all} \supset \cE_{tr}$. Namely, minimize
\[R^{OOD}(f) = \max_{e\in \cE_{all}} R^e(f)\]where $R^e(f) = E_{X^e, Y^e}[\ell(f(X^e), Y^e)]$ is the risk under environment $e$.
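As a minimal sketch of this objective (the function and loader names below are placeholders, not from the paper), the worst-case risk can be estimated by evaluating one fixed model on each environment and taking the maximum:

```python
import torch

def ood_risk(model, env_loaders, loss_fn):
    """Estimate R_OOD(f) = max_e R^e(f) from one data loader per environment."""
    env_risks = []
    for loader in env_loaders:                      # one loader per environment e
        total_loss, n_examples = 0.0, 0
        with torch.no_grad():
            for x, y in loader:
                # loss_fn is assumed to return the mean loss over the batch
                total_loss += loss_fn(model(x), y).item() * x.shape[0]
                n_examples += x.shape[0]
        env_risks.append(total_loss / n_examples)   # empirical R^e(f)
    return max(env_risks)                           # the worst environment dominates
```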
Four techniques commonly discussed in prior work:
- empirical risk minimization (ERM)
- robust learning objective
- domain adaptation strategy
- invariant causal prediction techniques
We say that a data representation $\Phi:\cX\rightarrow \cH$ elicits an invariant predictor $w\circ \Phi$ across environments $\cE$ if there is a classifier $w:\cH\rightarrow \cY$ simultaneously optimal for all environments, that is
\[w\in \arg\min_{\bar w:\cH\rightarrow\cY} R^e(\bar w\circ \Phi)\]for all $e\in\cE$.
For loss functions such as the mean squared error and the cross-entropy, optimal classifiers can be written as conditional expectations; in these cases, a data representation $\Phi$ elicits an invariant predictor across environments $\cE$ if and only if, for all $h$ in the intersection of the supports of $\Phi(X^e)$, we have
\[\E[Y^e\mid \Phi(X^e)=h] = \E[Y^{e'}\mid \Phi(X^{e'})=h]\]for all $e,e'\in\cE$.
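As a concrete illustration (a paraphrase of the paper's introductory structural equation model, with the constants reconstructed from memory), consider environments that differ only in a noise level $\sigma_e^2$:
\[X_1^e \sim \mathcal{N}(0,\sigma_e^2),\qquad Y^e = X_1^e + \mathcal{N}(0,\sigma_e^2),\qquad X_2^e = Y^e + \mathcal{N}(0,1).\]Then $\E[Y^e\mid X_1^e=x_1]=x_1$ in every environment, so $\Phi(x)=x_1$ elicits an invariant predictor, whereas $\E[Y^e\mid X_2^e=x_2]=\frac{\sigma_e^2}{\sigma_e^2+1/2}\,x_2$ depends on $e$, so $\Phi(x)=x_2$ does not.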
Two goals for the data representation $\Phi$:
- be useful to predict well
- elicit an invariant predictor across $\cE_{tr}$
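In symbols, these two goals give the constrained problem stated in the paper:
\[\min_{\substack{\Phi:\cX\rightarrow\cH \\ w:\cH\rightarrow\cY}} \sum_{e\in\cE_{tr}} R^e(w\circ\Phi) \quad\text{subject to}\quad w\in\arg\min_{\bar w:\cH\rightarrow\cY} R^e(\bar w\circ\Phi)\ \text{for all}\ e\in\cE_{tr}.\]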
This is a challenging, bi-level optimization problem, since each constraint calls an inner optimization routine.
The paper instantiates IRM into the practical version, IRMv1:
\[\min_{\Phi:\cX\rightarrow\cY} \sum_{e\in \cE_{tr}} \left[ R^e(\Phi) + \lambda \Vert \nabla_{w\mid w=1.0}R^e(w\cdot\Phi)\Vert^2 \right]\]where
- $\Phi$ becomes the entire invariant predictor
- $w=1.0$ is a scalar and fixed “dummy” classifier
- the gradient norm penalty is used to measure the optimality of the dummy classifier at each environment $e$
- $\lambda\in [0,\infty)$ is a regularizer balancing between predictive power (an ERM term), and the invariance of the predictor $1\cdot\Phi(x)$.
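A minimal PyTorch sketch of this objective (the function and variable names are placeholders, and a binary classification task with a cross-entropy loss is assumed):

```python
import torch
import torch.nn.functional as F

def irmv1_penalty(logits, y):
    """Squared gradient norm of the risk w.r.t. the fixed scalar dummy classifier w = 1.0."""
    w = torch.tensor(1.0, requires_grad=True)
    risk = F.binary_cross_entropy_with_logits(logits * w, y)
    grad, = torch.autograd.grad(risk, [w], create_graph=True)
    return (grad ** 2).sum()

def irmv1_objective(phi, envs, lam):
    """Sum over training environments of the ERM risk plus the weighted invariance penalty."""
    total = 0.0
    for x, y in envs:                # one (inputs, float targets in {0., 1.}) pair per environment
        logits = phi(x).squeeze(-1)  # phi is the entire predictor with a scalar output
        erm_risk = F.binary_cross_entropy_with_logits(logits, y)
        total = total + erm_risk + lam * irmv1_penalty(logits, y)
    return total
```

Note that the penalty gradient is computed with `create_graph=True`, so the penalty itself can be backpropagated through $\Phi$ when the objective is minimized.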
From IRM to IRMv1
Phrasing the constraints as a penalty
\[L_{IRM}(\Phi, w) = \sum_{e\in \cE_{tr}} \left[ R^e(w\circ \Phi) + \lambda\, D(w,\Phi, e) \right]\]Choosing a penalty $D$ for a linear classifier $w$:
\[D_{lin}(w,\Phi, e) = \Vert E_{X^e}[\Phi(X^e)\Phi(X^e)^T]w-E_{X^e,Y^e}[\Phi(X^e)Y^e]\Vert^2\]Fixing the linear classifier to some $\tilde w$ relaxes the recipe for invariance into
finding a data representation such that the optimal classifier, on top of that data representation, is $\tilde w$ for all environments.
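To make $D_{lin}$ concrete, here is a rough plug-in sketch (names are placeholders) that replaces the expectations above with averages over one batch from a single environment; it is a simple, possibly biased estimate:

```python
import torch

def d_lin_penalty(phi_x, y, w):
    """Naive plug-in estimate of D_lin(w, Phi, e) from a single batch of environment e.

    phi_x: (n, d) tensor of representations Phi(X^e)
    y:     (n,)   tensor of targets Y^e
    w:     (d,)   linear classifier
    """
    n = phi_x.shape[0]
    gram = phi_x.t() @ phi_x / n     # estimates E[Phi(X^e) Phi(X^e)^T]
    cross = phi_x.t() @ y / n        # estimates E[Phi(X^e) Y^e]
    return ((gram @ w - cross) ** 2).sum()
```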
Since scalar fixed classifiers $\tilde w$ are sufficient to monitor invariance, the search is restricted to matrices $\Phi\in \IR^{1\times d}$, with $\tilde w\in\IR^1$ the fixed scalar 1.0.