AI Models Collapse
the paper considers what may happen to GPT-{n} once LLMs contribute much of the text found online.
the paper finds that indiscriminate use of model-generated content in training causes irreversible defects in the resulting models, where tails of the original content distribution disappear
the paper refers to this effect as model collapse and shows that it can occur in LLMs as well as in variational autoencoders (VAEs) and Gaussian mixture models (GMMs)
two closely related concepts to model collapse from the existing literature: catastrophic forgetting, arising in the framework of task-free continual learning, and data poisoning, maliciously leading to unintended behaviour
What is model collapse?
model collapse refers to a degenerate learning process in which models start forgetting improbable events over time, as the model becomes poisoned with its own projection of reality.
two special cases:
- early model collapse
  - the model begins losing information about the tails of the distribution
- late model collapse
  - the model converges to a distribution that carries little resemblance to the original one, often with substantially reduced variance
this process occurs owing to three specific sources of error compounding over generations and causing deviation from the original model
- statistical approximation error: primary type of error, arising because the number of samples is finite; there is a non-zero probability that information is lost at every resampling step
- functional expressivity error: secondary type of error, arising owing to limited function approximator expressiveness. A neural network can introduce nonzero likelihood outside the support of the original distribution or zero likelihood inside the support of the original distribution
- functional approximation error: secondary type of error, arising primarily from the limitations of learning procedures, e.g., structural bias of stochastic gradient descent or choices of objective
other types of error exist, e.g., computers have limited precision in practice
Theoretical intuition
examine two mathematical models
- a discrete distribution in the absence of functional expressivity and approximation errors
- a multidimensional Gaussian approximation
the overall stochastic process, which the paper calls learning with generational data, is as follows:
- dataset at generation $i$: $\cD_i$, comprising i.i.d. random variables $X_j^i$ with distribution $p_i$, where $j \in \{1,\ldots, M_i\}$ and $M_i$ denotes the size of the dataset
going from generation $i$ to generation $i+1$, aim to estimate the distribution of samples in $\cD_i$, with an approximation $p_{\theta_{i+1}}$
this is functional approximation
\[p_{\theta_{i+1}} = \cF_\theta(p_i)\]
the dataset $\cD_{i+1}$ is then generated by sampling from
\[p_{i+1} = \alpha_i p_{\theta_{i+1}} + \beta_i p_i + \gamma_i p_0\]
with non-negative parameters $\alpha_i, \beta_i, \gamma_i$ summing to 1, where $p_0$ denotes the original data distribution
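To make the recursion concrete, here is a minimal sketch (my own toy setup, not the paper's code) in which $\cF_\theta$ fits a one-dimensional Gaussian to generation $i$ and the next dataset is drawn from the mixture $\alpha_i p_{\theta_{i+1}} + \beta_i p_i + \gamma_i p_0$, with the empirical datasets standing in for $p_i$ and $p_0$:

```python
# Toy sketch of "learning with generational data": F_theta fits a 1-D Gaussian
# to the previous generation, and D_{i+1} is sampled from the mixture
# alpha*p_theta_{i+1} + beta*p_i + gamma*p_0.
import numpy as np

rng = np.random.default_rng(0)

def fit_gaussian(samples):
    """Functional approximation F_theta: estimate a mean and std from data."""
    return samples.mean(), samples.std(ddof=1)

def next_generation(samples_i, samples_0, alpha=0.9, beta=0.05, gamma=0.05, m=1000):
    """Sample the dataset D_{i+1} of size m from the three-component mixture."""
    mu, sigma = fit_gaussian(samples_i)
    n_theta, n_prev, n_orig = rng.multinomial(m, [alpha, beta, gamma])
    return np.concatenate([
        rng.normal(mu, sigma, n_theta),   # fresh draws from p_theta_{i+1}
        rng.choice(samples_i, n_prev),    # re-drawn from the previous dataset (stand-in for p_i)
        rng.choice(samples_0, n_orig),    # re-drawn from the original dataset (stand-in for p_0)
    ])

samples_0 = rng.standard_normal(1000)     # original dataset D_0
samples = samples_0
for i in range(10):
    samples = next_generation(samples, samples_0)
    print(f"generation {i + 1}: fitted (mu, sigma) = {fit_gaussian(samples)}")
```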
Discrete distribution with exact approximation
consider a discrete probability distribution in the absence of functional approximation and expressivity errors, that is, $\cF(p) = p$
in this case, model collapse arises only because of statistical errors from the sampling step
- at first, the tails (low-probability events) begin to disappear as a result of the low probability of sampling them
- and over time, support of the distribution shrinks
denote the sample size by $M$; if we consider a state $i$ with probability $q \le 1/M$, the expected number of samples with value $i$ coming from those events will be at most 1.
more generally, considering some state $i$ with probability $q$, one can show that the probability of losing information about it (that is, sampling no data with value $i$ at some generation) is equal to $1-q$, implying that the distribution must converge to a delta function positioned at some state
consider the process $X^i \rightarrow \cF \rightarrow p_{i+1} \rightarrow X^{i+1}$ as a Markov chain: if at some generation all samples take the same value, the approximated distribution becomes a delta function and every later generation keeps that value, so the delta functions are absorbing states of the chain
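A small simulation of this discrete case (a uniform starting distribution over 20 states is my assumption for illustration) shows the support shrinking as low-probability states fail to be re-sampled and are then gone for good:

```python
# Discrete case with exact approximation F(p) = p: each generation re-estimates p
# from the empirical frequencies of M samples, so lost states never come back and
# the support shrinks toward a delta function.
import numpy as np

rng = np.random.default_rng(0)
M, K = 100, 20                       # sample size and number of states
p = np.full(K, 1.0 / K)              # original distribution p_0 (uniform here)

for gen in range(201):
    counts = rng.multinomial(M, p)   # sample the dataset D_gen
    p = counts / M                   # "exact" refit: the empirical distribution
    if gen % 50 == 0:
        print(f"gen {gen:3d}: support size = {np.count_nonzero(p)}")
```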
Multidimensional Gaussian
Assume the original data are sampled from distribution $\cD_0$ (not necessarily Gaussian), with non-zero sample variance. Assume $X^n$ are fit recursively using the unbiased sample mean and variance estimators from the previous generation,
\[X_j^n \mid \mu_n, \Sigma_n \sim N(\mu_n, \Sigma_n)\]
with a fixed sample size. Then
\[\bbE[\bbW_2^2(N(\mu_n, \Sigma_n), \cD_0)] \rightarrow \infty; \quad \Sigma_n \overset{a.s.}{\rightarrow} 0 \text{ as } n \rightarrow \infty\]
in which $\bbW_2$ denotes the Wasserstein-2 distance between the true distribution and its approximation at generation $n$
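A one-dimensional toy run (a sketch of mine, not the paper's experiments) illustrates both claims: in any single run the fitted variance collapses almost surely while the mean freezes at a random offset; the divergence of $\bbE[\bbW_2^2]$ is a statement about the expectation across runs.

```python
# Recursive Gaussian fitting with a fixed sample size per generation:
# draw M samples from N(mu_n, sigma_n^2), refit, repeat.
import numpy as np

rng = np.random.default_rng(0)
M = 50                                   # fixed sample size per generation
x = rng.standard_normal(M)               # original data from D_0 = N(0, 1)
mu, sigma = x.mean(), x.std(ddof=1)

for n in range(1, 2001):
    x = rng.normal(mu, sigma, M)         # X^n ~ N(mu_n, sigma_n^2)
    mu, sigma = x.mean(), x.std(ddof=1)  # refit for the next generation
    if n % 500 == 0:
        # squared Wasserstein-2 distance between N(mu, sigma^2) and N(0, 1)
        w2_sq = mu**2 + (sigma - 1.0)**2
        print(f"n={n:4d}  sigma={sigma:.3e}  W2^2 to N(0,1)={w2_sq:.4f}")
```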
Model collapse in language models
fine-tune the OPT-125m causal language model on the wikitext2 dataset
- use five-way beam search
- block training sequences to be 64 tokens long
- for each token sequence in the training set, ask the model to predict the next 64 tokens
- go through all of the original training dataset and produce an artificial dataset of the same size
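A hedged sketch of that generation step, assuming the Hugging Face `transformers` API and the `facebook/opt-125m` checkpoint (the paper does not publish this exact code, so names and arguments here are illustrative):

```python
# Generate a synthetic continuation for one 64-token block with 5-way beam search.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
model.eval()

def generate_block(prompt_ids):
    """Continue a 64-token block with 64 model-generated tokens (5-way beam search)."""
    with torch.no_grad():
        out = model.generate(
            prompt_ids.unsqueeze(0),
            num_beams=5,
            max_new_tokens=64,
            do_sample=False,
        )
    return out[0, prompt_ids.shape[0]:]  # keep only the generated continuation

# Example: one 64-token block taken from the fine-tuning corpus.
text = "The tower is 324 metres tall, about the same height as an 81-storey building."
ids = tokenizer(text, return_tensors="pt").input_ids[0][:64]
print(tokenizer.decode(generate_block(ids)))
```

Repeating this over the whole corpus yields an artificial dataset of the same size as the original, which is then used to fine-tune the next generation.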
consider two different settings
- five epochs, no original training data
- ten epochs, 10\% of original training data preserved
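A hypothetical helper (names are mine, not the paper's) for assembling each generation's fine-tuning corpus in these two settings:

```python
# Assemble the fine-tuning corpus for one generation: either purely synthetic data,
# or synthetic data with a fraction of the original wikitext2 training samples mixed in.
import random

def build_finetune_corpus(synthetic, original, keep_original_fraction=0.0, seed=0):
    """Return a shuffled corpus of len(synthetic) examples."""
    rng = random.Random(seed)
    n_orig = int(keep_original_fraction * len(synthetic))
    corpus = rng.sample(original, n_orig) + rng.sample(synthetic, len(synthetic) - n_orig)
    rng.shuffle(corpus)
    return corpus

# Setting 1: 5 epochs, no original data kept       -> keep_original_fraction=0.0
# Setting 2: 10 epochs, 10% of original data kept  -> keep_original_fraction=0.1
```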
Discussion
preserving the ability of LLMs to model low-probability events is essential to the fairness of their predictions: such events are often relevant to marginalized groups. Low-probability events are also vital to understanding complex systems.
training on samples from another generative model can induce a distribution shift, which, over time, causes model collapse.
Paper 2: Is Model Collapse Inevitable?
the studies on model collapse largely assumed that new data replace old data over time, but a more realistic assumption is that data accumulate over time.
the paper asks: what effect does accumulating data have on model collapse?
the paper
- confirms that replacing the original real data with each generation’s synthetic data does indeed tend towards model collapse
- demonstrates that accumulating the successive generations of synthetic data alongside the original real data avoids model collapse
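To see the qualitative difference between the two regimes, here is an illustrative sketch on the Gaussian toy model used earlier (not the paper's experiments): "replace" refits each generation only on the newest synthetic data, while "accumulate" refits on the original real data plus every past generation's synthetic data.

```python
# Contrast the "replace" and "accumulate" data regimes on the 1-D Gaussian toy model.
import numpy as np

rng = np.random.default_rng(0)
M = 100
real = rng.standard_normal(M)                      # original real data

def run(accumulate: bool, generations: int = 1000) -> float:
    pool = real.copy()
    mu, sigma = pool.mean(), pool.std(ddof=1)
    for _ in range(generations):
        synthetic = rng.normal(mu, sigma, M)       # data from the current model
        pool = np.concatenate([pool, synthetic]) if accumulate else synthetic
        mu, sigma = pool.mean(), pool.std(ddof=1)  # refit on the chosen pool
    return sigma

print("replace:    final sigma =", run(accumulate=False))  # typically collapses toward 0
print("accumulate: final sigma =", run(accumulate=True))   # stays close to 1
```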