This is not a fancy novel method. It's plain old distillation.
But we investigate it thoroughly, for model compression, through the lens of *function matching*.
We highlight two crucial principles that are often missed: Consistency and Patience. Only when applied jointly do they give good results!
0. Intuition: We want the student to replicate _the whole function_ represented by the teacher, everywhere in input space where we expect data.
This is a much stronger view than the commonly used "teacher generates better/more informative labels for the data". See pic above.
1. Consistency: to achieve this, teacher and student need to see the same view (crop) of the image. For example, this means no pre-computed teacher logits! We can generate many more views via mixup.
Other approaches may look good early, but eventually fall behind consistency.
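To make "consistency" concrete, here is a minimal PyTorch sketch of what one function-matching step with shared mixup views could look like (illustrative only, not our actual training code; all names are placeholders):

```python
import torch
import torch.nn.functional as F

def distill_step(student, teacher, images, optimizer, mixup_alpha=1.0, temperature=1.0):
    """One 'function matching' step: teacher and student see the *same* mixed crop.
    Minimal sketch: `student`/`teacher` are nn.Modules returning logits,
    `images` is a batch of already-augmented crops."""
    # Mixup on the inputs only: we want new points in input space,
    # the targets come from the teacher evaluated on the very same inputs.
    lam = torch.distributions.Beta(mixup_alpha, mixup_alpha).sample().item()
    perm = torch.randperm(images.size(0), device=images.device)
    mixed = lam * images + (1.0 - lam) * images[perm]

    # No pre-computed logits: the teacher is queried on exactly this view.
    with torch.no_grad():
        teacher_logp = F.log_softmax(teacher(mixed) / temperature, dim=-1)
    student_logp = F.log_softmax(student(mixed) / temperature, dim=-1)

    # KL(teacher || student): match the teacher's function at this input point.
    loss = F.kl_div(student_logp, teacher_logp, log_target=True, reduction="batchmean")

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```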
2. Patience: The function matching task is HARD! We need to train *a lot* longer than is typical, and we were actually not able to reach saturation yet. Overfitting does not happen: when function-matching, an "overfit" student is great! (Note: with pre-computed teacher logits, we do overfit.)
2b. Excessively long training can mean the optimizer struggles. We try more advanced optimization via Shampoo, and get 4x faster convergence.
We believe this setting is a great test-bed for optimizer research: No concern of overfitting, and reducing training error means generalizing better!
3. By distilling a couple large BiT R152x2 models into a ResNet-50, we get a ResNet-50 on ImageNet that gets 82.8% at 224px resolution, and 80.5% at 160px! 😎
No "tricks" just plain distillation, patiently matching functions.
4. Importantly, this simple strategy works on many datasets of various sizes, down to only 1020 training images, where anything else we tried overfit horribly.
Be patient, be consistent, that's it. Eventually, you'll reach or outperform your teacher!
2c. We can't stress patience enough. Multiple strategies, for example initializing the student with a pre-trained model shown here, look promising at first, but eventually plateau and are outperformed by patient, consistent function matching.
5. We have a lot more content: MobileNet students, distilling on "random other" data (shown below), very thorough baselines, a teacher ensemble, and... BiT download statistics!
PS: we are working on releasing a bunch of the models, including the best ones, ... but we're also on vacation. Watch github.com/google-researc… and stay tuned, we're aiming for next week!
o3-mini-high figured out the issue with @SakanaAILabs CUDA kernels in 11s.
The claimed 150x speedup is a bug; in reality it's 3x slower.
I literally copy-pasted their CUDA code into o3-mini-high and asked "what's wrong with this cuda code". That's it!
Proof: chatgpt.com/share/67b6f47c…
Fig1: o3-mini's answer.
Fig2: Their original code is wrong in a subtle way. The fact that they ran benchmarking TWICE and got wildly different results should have made them stop and think.
Fig3: o3-mini's fix. Code is now correct. Benchmarking results are consistent. 3x slower.
There are three real lessons to be learned here:
1) Super-straightforward CUDA code like that has NO CHANCE of ever being faster than optimized cuBLAS kernels. If it is, something is wrong.
2) If your benchmarking results are mysterious and inconsistent, something is wrong.
3) o3-mini-high is REALLY GOOD. It literally took 11 seconds to find the issue. It took me around 10 minutes to make this write-up afterwards.
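Lesson 2 in code form: CUDA launches are asynchronous, so naive timing without synchronization gives exactly this kind of mysterious number. A generic PyTorch timing sketch (not their harness, just an illustration of the sanity checks that would have caught it):

```python
import torch

def benchmark(fn, *args, warmup=10, iters=100):
    """Time a CUDA op properly: without synchronization you mostly measure
    kernel-launch overhead, not the kernel itself."""
    for _ in range(warmup):
        fn(*args)
    torch.cuda.synchronize()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(*args)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # milliseconds per call

# Run the benchmark twice; wildly different numbers mean something is wrong.
# (And before timing anything custom: check it against the reference output,
#  e.g. with torch.testing.assert_close.)
x = torch.randn(4096, 4096, device="cuda")
ms_1 = benchmark(torch.matmul, x, x)
ms_2 = benchmark(torch.matmul, x, x)
assert abs(ms_1 - ms_2) / ms_1 < 0.2, "inconsistent timings -> something is wrong"
```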
my fork of the author's colab, with the fix: colab.research.google.com/drive/1CS1g0Of…
PS: I wouldn't have found the bug myself, because it's been a literal decade since I last wrote CUDA kernel launch code myself...
I have to say that this MNIST weights figure looks suspicious as hell.
I've trained linear + softmax MNIST models and looked at the weights often, and they never look as bad as presented here. However, their score of ~92.5% is the expected one, so that's good.
I trained a plain MNIST linear model, not tuning much (no wd!), and it looks like one of the two pics below.
Add a small wd and it looks like another of the pics below.
I trained another one with their Harmonic stuff, closely following their code and the hparams in it, and it looks like the other pic below.
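For reference, the baseline I mean is literally just this (illustrative sketch, hparams not tuned; each column of W, reshaped to 28x28, gives one of the weight images discussed above):

```python
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Plain linear + softmax MNIST classifier; weight decay is the knob that
# visibly changes how the 10 weight templates look when plotted as images.
train = datasets.MNIST("/tmp/mnist", train=True, download=True, transform=transforms.ToTensor())
loader = DataLoader(train, batch_size=256, shuffle=True)

W = torch.zeros(784, 10, requires_grad=True)
b = torch.zeros(10, requires_grad=True)
opt = torch.optim.SGD([W, b], lr=0.1, weight_decay=0.0)  # try e.g. 1e-4 for the "small wd" picture

for epoch in range(10):
    for x, y in loader:
        logits = x.view(-1, 784) @ W + b
        loss = F.cross_entropy(logits, y)
        opt.zero_grad()
        loss.backward()
        opt.step()

templates = W.detach().T.reshape(10, 28, 28)  # the weight images
```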
Don't trust the headline. This 56M…
…DOESN'T include compute; it's everything else
…scattered across many unis, profs, post-docs, phds
…who collaborate in theory, but work on their individual papers in reality, cuz that's what's needed to graduate
…is unrelated to DeepSeek
What I expect to come out of it:
- a whole bunch of papers, probably some interesting ones. A lot on multilinguality and language transfer.
- a series of benchmarks for various EU languages
- (I hope) nice datasets for small languages
I think that's good overall.
What I do not expect to come out of it:
- an open-source base model at the frontier
Just had a quick look at DeepSeek's new Janus Pro paper.
I don't think it's a big deal (yet...!), but quick TL;DR below before the hype gets out of hand.
It's as straightforward an Omni model as it gets:
- a core autoregressive decoder LLM
- a SigLIP encoder for understanding (L@384, why not So400m?)
- a VQ-VAE for generation (from LlamaGen)
Three training stages:
1. new params only, on ImageNet
2. fine-tune all on the mm-mix
3. SFT mix
The main changes from Janus to -Pro are basically all data-related:
1. "Purify" stage 2 by moving ImageNet to stage 1, and make that longer.
2. Update the data for understanding by taking from DeepSeek-VL2.
3. Update the data for generation by throwing in MidJourney data.
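In pseudocode, the layout described above is roughly this (my own sketch from the description; every module and argument name is a placeholder, not Janus code):

```python
import torch.nn as nn

class OmniModelSketch(nn.Module):
    """Rough sketch of the Janus-style layout: decoder LLM core, SigLIP for
    understanding, VQ-VAE codes for generation. Placeholder modules only."""
    def __init__(self, llm, siglip_encoder, vqvae, siglip_dim, d_model, image_vocab_size):
        super().__init__()
        self.llm = llm                      # core autoregressive decoder LLM
        self.und_encoder = siglip_encoder   # SigLIP image encoder (understanding path)
        self.und_proj = nn.Linear(siglip_dim, d_model)
        self.vqvae = vqvae                  # discrete image tokenizer (generation path)
        self.gen_head = nn.Linear(d_model, image_vocab_size)  # predicts VQ code indices

    def understand(self, image, text_tokens):
        # Image -> SigLIP features -> project into the LLM embedding space,
        # then run the decoder over [image embeddings, text tokens].
        img_embeds = self.und_proj(self.und_encoder(image))
        return self.llm(prefix_embeds=img_embeds, tokens=text_tokens)

    def generate_image(self, text_tokens):
        # Autoregressively predict VQ-VAE code indices, then decode to pixels.
        h = self.llm(tokens=text_tokens)
        codes = self.gen_head(h).argmax(-1)
        return self.vqvae.decode(codes)
```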
I'd held a skeptical opinion of DiffTransformer after skimming it, but after some recent interactions I thought that was unfair and that I should give it a proper read. Did so on a recent flight.
Paper TL;DR: pair two attention heads, and do:
(softmax(Q1 K1ᵀ) − λ · softmax(Q2 K2ᵀ)) V
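In code, the TL;DR is roughly this (my own sketch from the formula; the 1/√d scaling, masking, the λ reparameterization and other details from the paper are omitted):

```python
import torch.nn.functional as F

def diff_attn_pair(q1, k1, q2, k2, v, lam):
    """Differential attention for one head pair, per the TL;DR above:
    (softmax(Q1 K1^T) - lam * softmax(Q2 K2^T)) V."""
    a1 = F.softmax(q1 @ k1.transpose(-1, -2), dim=-1)
    a2 = F.softmax(q2 @ k2.transpose(-1, -2), dim=-1)
    # Note: the difference is used as-is; it is *not* re-normalized to sum to 1.
    return (a1 - lam * a2) @ v
```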
The motivation in Fig1 is very solid a priori: as context gets long, the sum of (small) attention on irrelevant tokens might be more than the attention on the few individual relevant tokens, thus drowning them out.
However, it is just an illustration, and I'm still not sure how much this is really a problem for well trained models. Instabilities usually occur because attention logits grow, which makes the attention look much more like the green one already. This is not often talked about, but evidence is spread across some papers like ViT22B or "small scale proxies". If max attn is ~0.5, then we're already in the green.
I would have liked some plots about attn distribution/entropy in the DiffTransformer paper to actually justify the illustration. (We'll get back to this)
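To make the "drowning" argument concrete, a toy calculation (my own numbers, not from the paper):

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

n_irrelevant = 4095
# One relevant token with a modestly higher logit than 4095 irrelevant ones.
logits_small = np.array([2.0] + [0.0] * n_irrelevant)
# Same pattern, but with grown logits, as often seen in well-trained models.
logits_large = logits_small * 4.0

print(softmax(logits_small)[0])  # ~0.0018: the relevant token is drowned out
print(softmax(logits_large)[0])  # ~0.42: close to the "max attn ~0.5" green regime
```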
Next, the core idea is very simple and nice, but I notice a few details making it immediately less "neat" somehow?
The DiffAttn actually does _not_ re-normalize the diff, unlike what happens in the Fig1 motivating illustration. This confuses me a lot: how does this then amplify the "relevant" scores? The second head clearly has a different job now: assuming lambda is positive, it must learn what to suppress from the first one.
I would have loved some plots and experiments looking at what this really looks like in trained DiffTransformers, instead of just the mismatched Fig1 illustration.
Learns an (R, W, b) per layer and _per position_ in the (prompt) sequence.
However, they hparam-search which subset of layers and positions to apply it to.
Even more, they suggest (sometimes) tying the (R, W, b) parameters of a layer across positions.
2/5
It seems best to apply it to all layers, but only a few positions. For decoder models, tying across positions, but not for decoders. Rank can be lower for smaller models.
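For intuition only, a per-layer, per-position (R, W, b) edit of hidden states could look roughly like this (my guess at the shape of the thing from the description above; the paper's exact formula and parameterization may differ):

```python
import torch
import torch.nn as nn

class PerPositionIntervention(nn.Module):
    """Illustrative only: one (R, W, b) triple per targeted position, rank r.
    Applies a low-rank edit to the hidden state at selected prompt positions."""
    def __init__(self, d_model, rank, positions, tie_across_positions=False):
        super().__init__()
        n = 1 if tie_across_positions else len(positions)
        self.positions = positions
        self.tie = tie_across_positions
        self.R = nn.Parameter(torch.randn(n, rank, d_model) * 0.02)
        self.W = nn.Parameter(torch.randn(n, rank, d_model) * 0.02)
        self.b = nn.Parameter(torch.zeros(n, rank))

    def forward(self, hidden):  # hidden: (batch, seq, d_model)
        hidden = hidden.clone()
        for i, pos in enumerate(self.positions):
            j = 0 if self.tie else i
            h = hidden[:, pos]                                    # (batch, d_model)
            edit = h @ self.W[j].T + self.b[j] - h @ self.R[j].T  # rank-r residual
            hidden[:, pos] = h + edit @ self.R[j]                 # project back up via R
        return hidden
```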