How much does a language model forget when finetuned on new tasks? We show that both model size and optimization matter, and that forgetting can be nearly eliminated with self-generated replay! 1/8
We view forgetting as drift in the model's predictions on old data. So the fix is simple: use a KL penalty on past (pretraining) data to keep old outputs fixed while the model fits the new data. 2/8
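A minimal sketch of this objective in plain Python, working on raw logits for single next-token predictions. The names (`finetune_loss`, `kl_weight`) are hypothetical. We also assume the forward direction KL(p_old ‖ p_new), with the frozen pretrained model as the target, because the thread does not specify the direction or the weighting.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax over a list of logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]

def kl_divergence(old_logits, new_logits):
    """KL(p_old || p_new) for one next-token distribution (assumed direction)."""
    log_p = log_softmax(old_logits)
    log_q = log_softmax(new_logits)
    return sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))

def cross_entropy(logits, target):
    """Negative log-likelihood of the target token index."""
    return -log_softmax(logits)[target]

def finetune_loss(new_task_batch, replay_batch, kl_weight=1.0):
    """Hypothetical combined objective: fit new data, penalize drift on old data.

    new_task_batch: list of (current_model_logits, target_token)
    replay_batch:   list of (frozen_old_model_logits, current_model_logits)
    """
    task = sum(cross_entropy(l, t) for l, t in new_task_batch) / len(new_task_batch)
    kl = sum(kl_divergence(o, n) for o, n in replay_batch) / len(replay_batch)
    return task + kl_weight * kl
```

When the current model still matches the old one on the replay data, the KL term is zero and only the new-task loss remains; any drift in predictions on old data is charged in proportion to `kl_weight`.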
Unfortunately, pretraining data is often unavailable! But since LLMs are generative models, we can sample data directly from them. In our continual learning experiment with a 2M-parameter language model, self-generated replay entirely eliminates forgetting. 3/8
We can even generate replay data from an instruction-tuned LLM. For example, when finetuning Llama-3.2-1B, we can prompt the model with a BOS token (without a chat template) and generate pretraining-like data. With a KL penalty, this data significantly reduces forgetting. 4/8
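The sampling step can be sketched with a toy autoregressive model. The bigram table below is a hypothetical stand-in for the frozen pretrained model (a real LLM conditions on the whole prefix, not just the last token). Generation starts from the BOS token alone, with no chat template, and stops at EOS or after `max_len` tokens.

```python
import random

BOS, EOS = "<bos>", "<eos>"

# Hypothetical frozen "pretrained" model: maps the previous token to a
# next-token distribution. A real LLM would condition on the full prefix.
BIGRAM = {
    BOS: {"the": 0.6, "a": 0.4},
    "the": {"cat": 0.5, "dog": 0.5},
    "a": {"cat": 0.3, "dog": 0.7},
    "cat": {"sat": 0.8, EOS: 0.2},
    "dog": {"ran": 0.8, EOS: 0.2},
    "sat": {EOS: 1.0},
    "ran": {EOS: 1.0},
}

def next_token_probs(prefix):
    """Next-token distribution given the sequence generated so far."""
    return BIGRAM[prefix[-1]]

def sample_replay(num_sequences, max_len=10, seed=0):
    """Generate replay sequences by prompting with BOS only (no chat template)."""
    rng = random.Random(seed)
    data = []
    for _ in range(num_sequences):
        seq = [BOS]
        while len(seq) < max_len and seq[-1] != EOS:
            probs = next_token_probs(seq)
            tokens = list(probs)
            seq.append(rng.choices(tokens, weights=[probs[t] for t in tokens])[0])
        data.append(seq)
    return data
```

Calling, say, `sample_replay(1000)` yields pretraining-like sequences; the frozen model's next-token distributions on them would then act as the targets of the KL penalty, in place of the unavailable pretraining data.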
When does forgetting still happen? When the model has no spare capacity. Small models trained to saturation cannot absorb new information without overwriting old information. 5/8
Learning rate matters too. Forgetting can be reduced by using a high pretraining learning rate, making it possible to release pretrained models that are less prone to downstream forgetting. A small finetuning learning rate also mitigates forgetting. 6/8
However, a small finetuning learning rate is expensive, increasing the optimizer steps required to reach a target loss. Using replay data in finetuning breaks this tradeoff, enabling the use of a high learning rate while minimizing forgetting! 7/8
Much more in the paper! As models are increasingly being adapted to new settings, it’s especially crucial to understand forgetting. This was an incredible effort with an amazing team led by @mrtnm. Code is available at: github.com/martin-marek/f… 8/8
• • •
We introduce epiplexity, a new measure of information that provides a foundation for how to select, generate, or transform data for learning systems. We have been working on this for almost 2 years, and I cannot contain my excitement! 1/7
Information theory comes up empty-handed on key questions: can we learn more from data than existed in the generating process? Can new information be created from deterministic transformations? Can the learnable content in data be evaluated for broad generalization? 2/7
In particular, we present three paradoxes in information theory, statements which can be justified mathematically but are in tension with intuitions and empirics. This tension arises in part from assuming unbounded computation and from failing to target useful information. 3/7
My new paper "Deep Learning is Not So Mysterious or Different": arxiv.org/abs/2503.02113. Generalization behaviours in deep learning can be intuitively understood through a notion of soft inductive biases, and formally characterized with countable hypothesis bounds! 1/12
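For reference, a standard countable hypothesis (Occam) bound is obtained from Hoeffding's inequality plus a union bound weighted by a prior P(h) over a countable class. The paper's exact bound may differ. The sketch assumes a loss bounded in [0, 1] and a prior P(h) = 2^(-L(h)), where L(h) is the length in bits of a prefix-free description of h; `occam_bound` and `code_length_bits` are hypothetical names.

```python
import math

def occam_bound(train_error, n, code_length_bits, delta=0.05):
    """Upper bound on expected error for a hypothesis h from a countable class.

    Assumes a loss in [0, 1], n i.i.d. examples, and a prior
    P(h) = 2**(-code_length_bits) from a prefix-free code (so sum_h P(h) <= 1).
    Holds with probability >= 1 - delta simultaneously for all h.
    """
    log_inv_prior = code_length_bits * math.log(2)  # ln(1 / P(h))
    slack = math.sqrt((log_inv_prior + math.log(1 / delta)) / (2 * n))
    return train_error + slack
```

The slack grows only with the log of the inverse prior, so a hypothesis drawn from an enormous class can still receive a tight bound if it is compressible, that is, if it has high prior weight.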
What makes deep learning different? Not overparametrization, benign overfitting, or double descent, which can be reproduced with other models and explained with old generalization frameworks. Understanding DL doesn't require rethinking generalization -- and it never did! 2/12
Rather than restricting the solutions a model can represent, specify a preference for certain solutions over others, through _soft_ inductive biases. This approach guides us towards structure where it exists, without significant penalty where it doesn't. 3/12
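A small illustration of a soft inductive bias, using a polynomial regression setup chosen here for illustration: every polynomial up to degree 10 remains representable, but coefficient j is penalized by `base_penalty * growth**j`, so the fit prefers low-order structure when the data allow it. The function names and the penalty schedule are hypothetical.

```python
def solve(A, b):
    """Solve A w = b by Gaussian elimination with partial pivoting."""
    n = len(b)
    M = [row[:] + [bi] for row, bi in zip(A, b)]
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(M[r][col]))
        M[col], M[pivot] = M[pivot], M[col]
        for r in range(col + 1, n):
            f = M[r][col] / M[col][col]
            for c in range(col, n + 1):
                M[r][c] -= f * M[col][c]
    w = [0.0] * n
    for r in range(n - 1, -1, -1):
        w[r] = (M[r][n] - sum(M[r][c] * w[c] for c in range(r + 1, n))) / M[r][r]
    return w

def fit_polynomial(xs, ys, degree, base_penalty=1e-3, growth=10.0):
    """Penalized least squares with penalty base_penalty * growth**j on coefficient j.

    No polynomial up to `degree` is ruled out (no hard restriction), but
    higher-order terms cost more: a soft preference for simple solutions.
    """
    d = degree + 1
    X = [[x ** j for j in range(d)] for x in xs]
    # Normal equations: (X^T X + diag(penalties)) w = X^T y
    A = [[sum(row[i] * row[j] for row in X) for j in range(d)] for i in range(d)]
    for j in range(d):
        A[j][j] += base_penalty * growth ** j
    b = [sum(row[i] * y for row, y in zip(X, ys)) for i in range(d)]
    return solve(A, b)
```

Fitting `xs = [i / 10 for i in range(-10, 11)]` with `ys = [1 + 2 * x for x in xs]` at degree 10 recovers roughly 1 + 2x and keeps the higher-order coefficients near zero: the model is steered away from them rather than forbidden from using them.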
Naively using LLMs like GPT-3 for time series extrapolation can fail out of the box because of suboptimal tokenization and preprocessing. We show that if we tokenize numbers to individual digits, LLMs really shine! 2/7
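A minimal sketch of digit-level serialization. The fixed precision, the dropped decimal point, the comma between timesteps, and the assumption that values are already rescaled to be non-negative are all illustrative choices. Separating digits with spaces is one way to force a BPE tokenizer to emit one token per digit.

```python
def encode_series(values, precision=2):
    """Serialize non-negative floats as space-separated digits, timesteps split by commas.

    Assumptions (illustrative): values are already rescaled to be non-negative,
    each value keeps a fixed number of decimal places, and the decimal point is dropped.
    """
    steps = []
    for v in values:
        digits = str(round(v * 10 ** precision))  # fixed precision, no decimal point
        steps.append(" ".join(digits))            # one token per digit
    return " , ".join(steps)

def decode_series(text, precision=2):
    """Invert encode_series: join the digits of each timestep and restore the scale."""
    return [int(step.replace(" ", "")) / 10 ** precision for step in text.split(",")]
```

For example, `encode_series([0.12, 1.5])` gives `"1 2 , 1 5 0"`. Because leading zeros are dropped (0.07 becomes `7`), decoding relies on the fixed precision.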
We also show that language models are surprisingly natural probabilistic models of continuous data, acting like hierarchical softmax distributions over numbers when tokenized into individual digits. This allows them to fit challenging distributions common in time series. 3/7
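To see the hierarchical-softmax view, treat each digit as one level of a 10-way tree: the probability of a digit string is the product of next-digit probabilities, and dividing by the bin width 10^(-precision) turns it into a density. The sketch assumes numbers in [0, 1) written with a fixed number of digits. `digit_model` is a hypothetical stand-in for the language model's next-digit distribution, and the two example models are made up for illustration.

```python
def number_probability(digits, digit_model):
    """P(d1 d2 ... dk) as a product of next-digit probabilities (chain rule)."""
    prob = 1.0
    for i, d in enumerate(digits):
        prob *= digit_model(digits[:i])[int(d)]
    return prob

def continuous_density(x, digit_model, precision=3):
    """Density at x in [0, 1): each digit string covers a bin of width
    10**-precision, so density = bin probability / bin width."""
    digits = str(int(x * 10 ** precision)).zfill(precision)
    return number_probability(digits, digit_model) * 10 ** precision

def uniform_digits(prefix):
    """Hypothetical digit model: every next digit equally likely."""
    return [0.1] * 10

def skewed_first_digit(prefix):
    """Hypothetical digit model: half the mass on a leading 0, later digits uniform."""
    return [0.5] + [0.5 / 9] * 9 if len(prefix) == 0 else [0.1] * 10
```

With `uniform_digits` the density is 1 everywhere on [0, 1); with `skewed_first_digit` the bins in [0, 0.1) carry half of the mass, so the same mechanism can represent sharply non-uniform distributions.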
Last year at ICML, we presented marginal likelihood pathologies in model selection and hyperparameter learning. We now have a 60-page JMLR extension featuring: 1) should we be comforted by connections with PAC-Bayes? 2) approximations; 3) architecture search. 1/16
To recap, the marginal likelihood answers the question "how likely is my prior to generate the training data?" which is fundamentally different from "will my trained model provide good generalization?", leading to many discrepancies. 2/16
In short, the log marginal likelihood (LML) can underfit, overfit, and heavily penalize diffuse priors that provide good generalization. The decomposition of the LML into a sum of log p(D_i | D_{<i}) suggests a partial remedy: the conditional LML (CLML), which removes the first terms. 3/16
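The decomposition can be made concrete with a Beta-Bernoulli model, where each predictive term has a closed form. This is an illustrative sketch, not the paper's setup: it uses a single fixed ordering of the data, and `m` (the number of leading terms the CLML drops) is a hypothetical parameter name.

```python
import math

def log_predictives(data, a=1.0, b=1.0):
    """log p(d_i | d_<i) for binary data under a Beta(a, b)-Bernoulli model."""
    heads, out = 0, []
    for i, d in enumerate(data):
        p_one = (a + heads) / (a + b + i)  # posterior predictive after i observations
        out.append(math.log(p_one if d == 1 else 1.0 - p_one))
        heads += d
    return out

def lml(data, a=1.0, b=1.0):
    """Log marginal likelihood via the chain rule: sum_i log p(d_i | d_<i)."""
    return sum(log_predictives(data, a, b))

def clml(data, m, a=1.0, b=1.0):
    """Conditional LML: drop the first m terms, scoring only the later predictions."""
    return sum(log_predictives(data, a, b)[m:])
```

For `data = [1, 1, 0, 1]` with a Beta(1, 1) prior, `lml(data)` equals log(6/120), matching the closed form 3! 1! / 5!. The early terms are where a prior that has not yet seen any data pays the largest penalty, and these are the terms the CLML drops.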
The marginal likelihood (evidence) provides an elegant approach to hypothesis testing and hyperparameter learning, but it has fascinating limits as a generalization proxy. We discuss these limits, along with resolutions. 1/23
The search for scientific truth is elusive. How do we select between theories which are entirely consistent with any data we observe? The marginal likelihood p(D|M) -- the probability we would generate our observations from our prior model -- provides a compelling approach. 2/23
MacKay's book, Ch. 28, makes a nice case: a simple model can't generate many datasets, but since p(D|M) is a normalized probability density, it gives high probability to the data it can generate. For a given dataset, the most constrained model wins, encoding "Occam's razor". 3/23
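A classic coin example, in the spirit of this argument but with models chosen here for illustration, makes it concrete. M0 is a fair coin with no free parameters; M1 has an unknown bias with a uniform prior, integrated out exactly. Both evidences are normalized over all possible sequences, so the more flexible M1 spreads its mass more thinly.

```python
from fractions import Fraction
from math import factorial

def evidence_fair_coin(heads, tails):
    """p(D | M0) for one specific sequence: a fixed fair coin, no free parameters."""
    return Fraction(1, 2) ** (heads + tails)

def evidence_unknown_bias(heads, tails):
    """p(D | M1) for one specific sequence, with bias theta ~ Uniform(0, 1) integrated out.

    The integral of theta^h (1 - theta)^t over [0, 1] is h! t! / (h + t + 1)!.
    """
    return Fraction(factorial(heads) * factorial(tails), factorial(heads + tails + 1))
```

For 5 heads and 5 tails, the fair coin has evidence 1/1024 while the flexible model has 1/2772, so the more constrained model wins; for 10 heads and no tails, the flexible model wins (1/11 versus 1/1024).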
Suppose for instance there are dead pixels in an image. The weights attached to these pixels don’t affect the predictions, and so MAP (regularized optimization) drives them to zero. A BMA instead samples these weights from the prior... 2/5
...For in-distribution test data, this behaviour doesn’t hurt generalization. But now suppose, for example, that some corruption is added to the image. Now the BMA is activating connections that should be dead, hurting OOD generalization! 3/5
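A linear-Gaussian toy version of the dead-pixel argument, assumed here for illustration rather than taken from the thread (which concerns neural networks). The second input is always zero in training, so the data say nothing about its weight: the MAP estimate (here equal to the posterior mean) sets it to zero, while its posterior variance stays at the prior value 1/alpha. In this linear case the BMA mean is unchanged by corruption, but every posterior sample is shifted by c times the dead weight, adding c^2/alpha to the predictive variance; in a nonlinear network such shifts need not cancel. The function names and hyperparameters are hypothetical.

```python
import random

def posterior_2d(X, y, alpha=1.0, beta=25.0):
    """Posterior N(mean, cov) for 2-weight Bayesian linear regression.

    Prior w ~ N(0, I / alpha); likelihood y ~ N(w . x, 1 / beta).
    """
    # Posterior precision A = alpha * I + beta * X^T X (2 x 2), inverted in closed form
    a11 = alpha + beta * sum(x[0] * x[0] for x in X)
    a12 = beta * sum(x[0] * x[1] for x in X)
    a22 = alpha + beta * sum(x[1] * x[1] for x in X)
    det = a11 * a22 - a12 * a12
    cov = [[a22 / det, -a12 / det], [-a12 / det, a11 / det]]
    r = [beta * sum(x[0] * t for x, t in zip(X, y)), beta * sum(x[1] * t for x, t in zip(X, y))]
    mean = [cov[0][0] * r[0] + cov[0][1] * r[1], cov[1][0] * r[0] + cov[1][1] * r[1]]
    return mean, cov

def predictive_spread(x_star, cov):
    """Variance of w . x_star over posterior samples: x*^T cov x*."""
    return (x_star[0] ** 2 * cov[0][0] + 2 * x_star[0] * x_star[1] * cov[0][1]
            + x_star[1] ** 2 * cov[1][1])

rng = random.Random(0)
xs = [rng.uniform(-1, 1) for _ in range(50)]
X = [(x, 0.0) for x in xs]  # second input is a "dead pixel": always 0 in training
y = [2.0 * x + rng.gauss(0, 0.2) for x in xs]
mean, cov = posterior_2d(X, y)
```

Corrupting the dead input to 3.0 adds 9 / alpha = 9 to the variance of the sampled predictions, even though the training data never involved that input.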