Samip
solving generalization at https://t.co/zsptJBlblS

Jun 4, 8 tweets

1/ Now that we're running out of data, how do you optimally scale multi-epoch pretraining to hundreds of epochs?

Our first paper from Q! q0 trains a population of models instead of a single model that saturates quickly, reaching a dramatically lower loss at *every* epoch budget.

w/ @bishmdl76 @akshayvegesna @ShmuelBerman

2/ Paper: arxiv.org/abs/2606.03938

q0 is built on one intuition, motivated by Solomonoff induction: instead of training one perfect model, train a population of diverse models and aggregate predictions. Everything in the algorithm follows from this one goal of efficiently training a population. It comes down to three core primitives:

3/ Primitive 1: fast exploration of weight space. Training many models from scratch to build a population is too expensive. Inspired by FGE (Fast Geometric Ensembling), we collect many models along a few parallel cyclic trajectories. The key mechanism is anti-correlating weight decay with the learning rate (LR): each cycle explores early (high LR, low WD) and then settles into a low-norm basin right before we take a snapshot.
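Here is a minimal sketch of an anti-correlated schedule. It assumes a cosine LR decay inside each cycle and a weight decay that rises as the LR falls. The function name `cyclic_lr_wd`, the cosine shape, and all numeric values are illustrative assumptions, not the paper's settings. The thread only says that WD moves opposite to the LR and that a snapshot is taken at the end of each cycle.

```python
import math


def cyclic_lr_wd(step, cycle_len, lr_max, lr_min, wd_max, wd_min):
    """Hypothetical schedule: returns (lr, wd, take_snapshot) for a global step.

    Assumes a cosine LR decay that restarts every `cycle_len` steps, with
    weight decay moving in the opposite direction to the LR.
    """
    pos = step % cycle_len
    t = pos / cycle_len  # position within the current cycle, in [0, 1)
    # Cosine decay from lr_max (start of cycle) toward lr_min (end of cycle).
    lr = lr_min + 0.5 * (lr_max - lr_min) * (1.0 + math.cos(math.pi * t))
    # Share of the LR range still remaining: 1.0 at cycle start, near 0.0 at the end.
    frac = (lr - lr_min) / (lr_max - lr_min)
    # Anti-correlated weight decay: low while exploring, high while settling.
    wd = wd_min + (wd_max - wd_min) * (1.0 - frac)
    # Snapshot on the last step of each cycle, when LR is lowest and WD highest.
    take_snapshot = pos == cycle_len - 1
    return lr, wd, take_snapshot


if __name__ == "__main__":
    snapshots = [s for s in range(300) if cyclic_lr_wd(s, 100, 1e-3, 1e-5, 0.1, 0.0)[2]]
    print(snapshots)  # [99, 199, 299]
```

Each of the "few parallel cyclic trajectories" would run its own copy of this schedule. Every snapshot step contributes one population member.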

Primitive 2: compounding model capability via chain distillation. Independently trained models all come out about equally good, so adding more of them doesn't lift quality. We train each model against its predecessor as a frozen teacher (KL divergence on soft targets), so every model improves on the last and the population's capability compounds.
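Below is a minimal, framework-free sketch of the per-token loss and the chain structure. It assumes the KL term on temperature-softened targets is mixed with the usual next-token cross-entropy through a weight `alpha`, as in standard knowledge distillation. The mixing, the temperature, and the helper names (`chain_distill_loss`, `train_chain`, `train_fn`) are illustrative assumptions. The thread only says that each model is trained against its frozen predecessor with a KL loss on soft targets.

```python
import math


def softmax(logits, temperature=1.0):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp((z - m) / temperature) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p, q):
    """KL(p || q) for two discrete distributions given as lists."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def chain_distill_loss(student_logits, teacher_logits, label, alpha=0.5, temperature=1.0):
    """Assumed loss for one token: (1 - alpha) * CE + alpha * T^2 * KL(teacher || student)."""
    ce = -math.log(softmax(student_logits)[label])
    p_teacher = softmax(teacher_logits, temperature)  # frozen predecessor's soft targets
    p_student = softmax(student_logits, temperature)
    kd = (temperature ** 2) * kl_divergence(p_teacher, p_student)
    return (1.0 - alpha) * ce + alpha * kd


def train_chain(n_models, train_fn):
    """Train models in sequence, each one distilled from the previous (frozen) one.

    `train_fn(teacher)` is a hypothetical callback that trains one model and
    returns it; `teacher` is None for the first model in the chain.
    """
    population, teacher = [], None
    for _ in range(n_models):
        student = train_fn(teacher)
        population.append(student)
        teacher = student  # the new model becomes the next model's frozen teacher
    return population
```

With a toy stub such as `train_chain(3, lambda t: (t or 0) + 1)`, the chain returns `[1, 2, 3]`. Each "model" builds on its teacher, which stands in for the compounding the thread describes.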

Primitive 3: a learned generalization prior. Uniform averaging wastes the good members by weighting them no more than the weak ones. We fit a single softmax weighting over models on a held-out set by minimizing the ensemble loss, then reuse it to select and weight the best K models for any inference budget.
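The sketch below fits such a weighting. It assumes the ensemble averages member probabilities and is scored by held-out negative log-likelihood. The input `q[k][i]` is the probability that model `k` assigns to the correct token of held-out example `i`. The function names, plain gradient descent, learning rate, and step count are all illustrative assumptions.

Define the ensemble probability for example i as P_i = Σ_k a_k q[k][i], with a = softmax(w), and the loss as L = -(1/n) Σ_i log P_i. The gradient then works out to ∂L/∂w_j = (1/n) Σ_i a_j (1 - q[j][i] / P_i).

```python
import math


def softmax(w):
    """Numerically stable softmax over a list of weight logits."""
    m = max(w)
    exps = [math.exp(x - m) for x in w]
    total = sum(exps)
    return [e / total for e in exps]


def ensemble_nll(weights, q):
    """Mean held-out NLL of the probability-averaged ensemble."""
    n = len(q[0])
    return -sum(
        math.log(sum(a * q[k][i] for k, a in enumerate(weights))) for i in range(n)
    ) / n


def fit_weights(q, lr=1.0, steps=500):
    """Fit softmax weights over models by gradient descent on ensemble NLL."""
    n_models, n = len(q), len(q[0])
    w = [0.0] * n_models  # start from uniform weights
    for _ in range(steps):
        a = softmax(w)
        grad = [0.0] * n_models
        for i in range(n):
            p = sum(a[k] * q[k][i] for k in range(n_models))  # ensemble prob of the target
            for j in range(n_models):
                grad[j] += a[j] * (1.0 - q[j][i] / p) / n
        w = [wj - lr * g for wj, g in zip(w, grad)]
    return softmax(w)


def top_k(weights, k):
    """Keep the K highest-weighted models and renormalize their weights."""
    idx = sorted(range(len(weights)), key=lambda j: weights[j], reverse=True)[:k]
    total = sum(weights[j] for j in idx)
    return {j: weights[j] / total for j in idx}
```

You fit the weights once, then call `top_k(weights, k)` with a different `k` for each inference budget. This mirrors the thread's point that one weighting is reused across budgets. Renormalizing the kept weights is itself an assumption.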

4/ Now for the results at 256 epochs. A single model saturates after ~16 epochs. Our strong ensembling baseline pushes past that but converges slowly.

q0 matches the ensemble baseline's 256-epoch loss after only ~56 epochs of training (4.6x fewer) and keeps improving through 256 epochs, reaching a val loss of 3.003 vs. the baseline's 3.048. The gains carry over to downstream benchmarks too.

5/ Importantly, these gains hold at every epoch budget, from one to hundreds. But the optimal allocation shifts with scale. Each budget is split across three knobs: parallel base models, cycles per model, and cycle length.

Small budgets favor a single base model, with frequent cycles packed toward the end of training. As the budget grows, adding parallel base models starts to pay off: roughly one more base each time the epoch count doubles (one base up to ~128 epochs, two up to ~256, three up to ~512).
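The doubling rule of thumb can be written as a one-line heuristic. This sketch simply interpolates the three data points quoted above (one base up to ~128 epochs, two up to ~256, three up to ~512). The function name and its behavior between or beyond those points are assumptions, not results from the paper.

```python
import math


def suggested_base_models(epochs):
    """Hypothetical rule of thumb: one base up to 128 epochs, then one more per doubling.

    Matches the quoted points (<=128 -> 1, <=256 -> 2, <=512 -> 3); anything
    beyond 512 epochs is an extrapolation.
    """
    if epochs <= 128:
        return 1
    return math.ceil(math.log2(epochs / 64))
```

The heuristic leaves out the other two knobs (cycles per model and cycle length), because the thread gives only qualitative guidance for them.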

6/ I'm confident this beats standard pretraining at any budget, even a single epoch, but the biggest limitation is inference cost. An ensemble of K models means K forward passes. It effectively grows the combined model's parameter count, like scaling depth but without the saturation that depth scaling runs into.

As with any large model, the fix is distillation into a single model, which tends to work magically well, but we leave that to future work.

7/ Looking beyond this paper: scaling compute against a fixed, limited pool of data will need new primitives. Searching over a population of models is a different problem from standard gradient-descent training, and we've barely scratched the surface. We hope q0 pushes people toward crazy ideas in multi-epoch training and scaling compute in general!!

8/ Huge thanks to Andrew Gordon Wilson (@andrewgwils) for feedback on the paper!

Code at Slowrun: github.com/qlabs-eng/slow…
