
AI Models Collapse

Tags: Large Language Model

This note is for Shumailov, I., Shumaylov, Z., Zhao, Y., Papernot, N., Anderson, R., & Gal, Y. (2024). AI models collapse when trained on recursively generated data. Nature, 631(8022), 755–759. https://doi.org/10.1038/s41586-024-07566-y

the paper considers what may happen to GPT-{n} once LLMs contribute much of the text found online.

the paper finds that indiscriminate use of model-generated content in training causes irreversible defects in the resulting models, in which tails of the original content distribution disappear

the paper refers to this effect as model collapse and shows that it can occur in LLMs as well as in variational autoencoders (VAEs) and Gaussian mixture models (GMMs)

two concepts from the existing literature are close to model collapse: catastrophic forgetting, which arises in the framework of task-free continual learning, and data poisoning, in which malicious data lead to unintended behaviour

What is model collapse?

model collapse refers to a degenerate learning process in which models start forgetting improbable events over time, as the model becomes poisoned with its own projection of reality.


two special cases:

  • early model collapse
    • model begins losing information about the tails of the distribution
  • late model collapse
    • the model converges to a distribution that carries little resemblance to the original one, often with substantially reduced variance

this process occurs owing to three specific sources of error compounding over generations and causing deviation from the original model

  1. statistical approximation error: the primary type of error, arising because the number of samples is finite; there is a non-zero probability that information is lost at every resampling step
  2. functional expressivity error: a secondary type of error, arising from the limited expressiveness of the function approximator; for example, a neural network can introduce non-zero likelihood outside the support of the original distribution, or zero likelihood inside it
  3. functional approximation error: a secondary type of error, arising primarily from limitations of the learning procedure, e.g., the structural bias of stochastic gradient descent or the choice of objective

other types of error exist, e.g., computers have limited precision in practice

Theoretical intuition

examine two mathematical models

  • a discrete distribution in the absence of functional expressivity and approximation errors
  • a multidimensional Gaussian approximation

the overall stochastic process, which is called learning with generational data, is as follows:

  • dataset at generation $i$: $\cD_i$, comprising i.i.d. random variables $X_j^i$ with distribution $p_i$, where $j \in \{1,\ldots, M_i\}$ and $M_i$ denotes the size of the dataset

going from generation $i$ to generation $i+1$, we aim to estimate the distribution of samples in $\cD_i$ with an approximation $p_{\theta_{i+1}}$

this is the functional approximation step

\[p_{\theta_{i+1}} = \cF_\theta(p_i)\]

the dataset $\cD_{i+1}$ is then generated by sampling from

\[p_{i+1} = \alpha_i p_{\theta_{i+1}} + \beta_i p_i + \gamma_i p_0\]

with non-negative parameters $\alpha_i,\beta_i, \gamma_i$ summing to 1
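
As a concrete reading of these definitions, here is a minimal sketch in Python/NumPy (my own illustration, not from the paper; an empirical categorical fit stands in for $\cF_\theta$, and the distributions and mixing weights are made up) of a single generation step:

```python
import numpy as np

rng = np.random.default_rng(0)

def fit_theta(samples, k):
    """Functional approximation F_theta: here simply the empirical frequencies
    of the samples in D_i (a stand-in for fitting a generative model)."""
    return np.bincount(samples, minlength=k) / len(samples)

def generation_step(data_i, p_i, p_0, m_next, alpha, beta, gamma):
    """One step i -> i+1: fit p_{theta_{i+1}} on D_i, then sample D_{i+1} from
    p_{i+1} = alpha * p_{theta_{i+1}} + beta * p_i + gamma * p_0."""
    k = len(p_0)
    p_theta = fit_theta(data_i, k)
    p_next = alpha * p_theta + beta * p_i + gamma * p_0
    data_next = rng.choice(k, size=m_next, p=p_next)
    return data_next, p_next

# one step starting from a uniform categorical original distribution p_0
p_0 = np.full(8, 1.0 / 8)
data_0 = rng.choice(8, size=100, p=p_0)
data_1, p_1 = generation_step(data_0, p_0, p_0, m_next=100, alpha=0.9, beta=0.0, gamma=0.1)
```

here $\alpha_i, \beta_i, \gamma_i$ control how much the next dataset draws on the fitted model, the previous data distribution, and the original data, respectively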

Discrete distribution with exact approximation

consider a discrete probability distribution in the absence of functional approximation and expressivity errors, that is, $\cF(p) = p$

in this case, model collapse arises only because of statistical errors from the sampling step

  1. at first, the tails (low-probability events) begin to disappear as a result of the low probability of sampling them
  2. and over time, support of the distribution shrinks

denote the sample size by $M$; if we consider a state $i$ with probability $q\le 1/M$, the expected number of samples with value $i$ coming from those events will be less than 1

more generally, for some state $i$ with probability $q$, one can show that the probability of losing information (that is, sampling no data at some generation) is equal to $1-q$, implying that the distribution must converge to a delta function positioned at some state

consider the process $X^i \rightarrow \cF_\theta \rightarrow p_{i+1} \rightarrow X^{i+1}$ as a Markov chain: since delta functions are absorbing states of this chain, the process is eventually absorbed into one of them
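
this absorption is easy to check numerically; below is a toy simulation (my own sketch, assuming a small categorical distribution and a fixed sample size $M$) of the pure resampling case $\alpha_i = 1, \beta_i = \gamma_i = 0$ with $\cF(p) = p$, run until the support collapses to a single state

```python
import numpy as np

rng = np.random.default_rng(1)

k, m = 10, 100                                   # number of states, sample size M
p = np.full(k, 1.0 / k)                          # start from a uniform categorical p_0
generation = 0
while np.count_nonzero(p) > 1:                   # delta functions are absorbing states
    samples = rng.choice(k, size=m, p=p)         # X^i ~ p_i
    p = np.bincount(samples, minlength=k) / m    # p_{i+1} = empirical distribution of X^i
    generation += 1
print(f"collapsed to a delta function after {generation} generations")
```

low-probability states tend to vanish first, and once a state's probability hits zero it never returns, which is exactly the tail-then-support shrinkage described above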

Multidimensional Gaussian

Assume the original data are sampled from distribution $\cD_0$ (not necessarily Gaussian), with non-zero sample variance. Assume $X^n$ are fit recursively using the unbiased sample mean and variance estimators from the previous generation,

\[X_j^n \mid \mu_n, \Sigma_n \sim N(\mu_n, \Sigma_n)\]

with a fixed sample size. Then

\[\bbE[\bbW_2^2(N(\mu_n, \Sigma_n), \cD_0)] \rightarrow \infty; \Sigma_n\overset{a.s.}{\rightarrow} 0 \text{ as } n\rightarrow\infty\]

in which $\bbW_2$ denotes the Wasserstein-2 distance between the true distribution and its approximation at generation $n$
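
a quick toy simulation (my own sketch, not the paper's; assuming a one-dimensional Gaussian and a fixed sample size) illustrates the almost-sure collapse of the variance:

```python
import numpy as np

rng = np.random.default_rng(2)

m = 100                       # fixed sample size per generation
mu, sigma2 = 0.0, 1.0         # original distribution: N(0, 1)
for n in range(1, 2001):
    x = rng.normal(mu, np.sqrt(sigma2), size=m)   # X^n ~ N(mu_n, sigma^2_n)
    mu, sigma2 = x.mean(), x.var(ddof=1)          # unbiased sample mean and variance
    if n % 500 == 0:
        print(f"generation {n}: mu = {mu:+.3f}, sigma^2 = {sigma2:.3e}")
```

the printed variance shrinks by orders of magnitude over the generations, consistent with $\Sigma_n \overset{a.s.}{\rightarrow} 0$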

Model Collapse in language models

fine-tune the OPT-125m causal language model on the wikitext2 dataset

  • use five-way beam search
  • block training sequences to be 64 tokens long
  • for each token in the training set, ask the model to predict the next 64 tokens
  • go through all of the original training dataset and produce an artificial dataset of the same size (a sketch of this generation step follows the list)
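
a hedged sketch of the artificial-data generation step, assuming the Hugging Face transformers API; the generation arguments are my reading of the setup rather than the authors' code, and the 64-token blocking of the training set is not shown (a placeholder prompt is used instead):

```python
# Sketch of producing one artificial continuation per training block with
# five-way beam search; not the authors' original code.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
model.eval()

def generate_block(prompt_ids: torch.Tensor) -> torch.Tensor:
    """Predict the next 64 tokens for a training block using five-way beam search."""
    with torch.no_grad():
        out = model.generate(
            prompt_ids,
            max_new_tokens=64,
            num_beams=5,          # five-way beam search
            do_sample=False,
        )
    return out[:, prompt_ids.shape[1]:]   # keep only the newly generated tokens

text = "The history of statistics begins with"   # placeholder for a 64-token block
ids = tokenizer(text, return_tensors="pt").input_ids
print(tokenizer.decode(generate_block(ids)[0], skip_special_tokens=True))
```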

consider two different settings

  • five epochs, no original training data


  • ten epochs, 10% of original training data preserved


Discussion

preserving the ability of LLMs to model low-probability events is essential to the fairness of their predictions: such events are often relevant to marginalized groups. Low-probability events are also vital to understanding complex systems.

training on samples from another generative model can induce a distribution shift, which, over time, causes model collapse.

Paper 2: Is Model Collapse Inevitable?

existing studies on model collapse largely assume that new data replace old data over time, but a more realistic assumption is that data accumulate over time.

the paper asks: what effect does accumulating data have on model collapse?

the paper

  • confirms that replacing the original real data with each generation's synthetic data does indeed tend towards model collapse
  • demonstrates that accumulating the successive generations of synthetic data alongside the original real data avoids model collapse (see the toy sketch below)
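
the contrast between replacing and accumulating can be reproduced with the Gaussian toy model from above (again my own sketch, not the paper's experiments):

```python
import numpy as np

rng = np.random.default_rng(3)

def run(n_generations=500, m=100, accumulate=False):
    """Recursively fit a 1-D Gaussian; either replace the data each generation
    or accumulate every generation's synthetic data alongside the original."""
    pool = rng.normal(0.0, 1.0, size=m)          # original real data
    for _ in range(n_generations):
        mu, sigma = pool.mean(), pool.std(ddof=1)
        synthetic = rng.normal(mu, sigma, size=m)
        pool = np.concatenate([pool, synthetic]) if accumulate else synthetic
    return pool.std(ddof=1)

print("replace   :", run(accumulate=False))   # variance shrinks towards 0
print("accumulate:", run(accumulate=True))    # variance stays close to the original
```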



