There's been some back-and-forth about this paper on getting gradients without backpropagation, so I took a minute to write up an analysis of what breaks and how it might be fixed.
tl;dr: the estimated gradients are _really_ noisy! like wow
The main result I claim is an extension of Thm 1 in the paper. They prove that the _expected value_ of the gradient estimate is the true gradient, and I worked out the _variance_ of the estimate.
It's big! Each entry's variance is at least the squared norm of the entire true gradient😬
(Sketch of the proof: the entries of the random direction are independent, zero-mean, and symmetric around the origin, so all the cross terms vanish. The only terms left are chi-squared r.v.s with known variances, scaled by the squared gradient entries. Gaussians are fun!)
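Here's that sketch written out. I'm assuming the tangent directions are standard normal, v ~ N(0, I), as in the paper's Thm 1 setup. The symbol a is just my name for the true gradient, and g is the forward gradient. The cross terms with j ≠ k drop because each one keeps an odd power of some zero-mean entry. The v_i^4 term is where the chi-squared variance shows up: v_i^2 ~ χ²₁ has mean 1 and variance 2, so E[v_i^4] = 3.

```latex
\begin{aligned}
g &= (a^\top v)\, v, \qquad v \sim \mathcal{N}(0, I_n), \qquad g_i = \sum_j a_j v_j v_i \\
\mathbb{E}[g_i] &= \sum_j a_j\, \mathbb{E}[v_j v_i] = a_i \\
\mathbb{E}[g_i^2] &= \sum_{j,k} a_j a_k\, \mathbb{E}[v_j v_k v_i^2]
  = a_i^2\, \mathbb{E}[v_i^4] + \sum_{j \neq i} a_j^2\, \mathbb{E}[v_j^2]\, \mathbb{E}[v_i^2]
  = 3a_i^2 + \left(\lVert a \rVert^2 - a_i^2\right) \\
\operatorname{Var}[g_i] &= \mathbb{E}[g_i^2] - \mathbb{E}[g_i]^2 = \lVert a \rVert^2 + a_i^2
\end{aligned}
```

Since a_i² ≤ ‖a‖², every entry's variance lands between ‖a‖² and 2‖a‖².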
Informally, we say that "noisy gradients" are bad and slow down learning.
So I looked at the "signal-to-noise ratio" between each true gradient entry and the variance of its estimate.
It's bad! If you're scaling your gradients properly, it gets worse as you add parameters.
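A quick way to see the scaling is sketched below. I'm assuming two things here. First, the SNR is defined as a_i² / Var[g_i]. Second, the variance formula is Var[g_i] = ‖a‖² + a_i², which holds for standard-normal directions. The name forward_grad_snr is just mine for illustration.

```python
import numpy as np

def forward_grad_snr(a):
    """Per-entry SNR a_i^2 / Var[g_i], assuming Var[g_i] = ||a||^2 + a_i^2
    (forward gradient g = (a . v) v with v ~ N(0, I))."""
    a = np.asarray(a, dtype=float)
    return a**2 / (np.sum(a**2) + a**2)

for n in [10, 1_000, 100_000]:
    a = np.full(n, 1 / np.sqrt(n))    # unit-norm gradient spread evenly over n entries
    print(n, forward_grad_snr(a)[0])  # ~ 1 / (n + 1)
```

Under this definition the ratio doesn't change if you rescale a, so what matters is how the gradient is spread across entries. For an evenly spread gradient it falls like 1/(n + 1).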
(FYI, I sanity-checked my result by pulling gradients from a PyTorch MNIST example and comparing the true gradient's squared norm against the average variance of each entry, which should be nearly equal. And they were super close!)
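If you want to poke at the same check without the MNIST setup, a minimal Monte Carlo version is below. A random Gaussian vector stands in for the true gradient (an assumption; the PyTorch example isn't reproduced here). The averaged predicted variance is ‖a‖²(1 + 1/n), which is why "nearly equal" is the right expectation.

```python
import numpy as np

def forward_gradient_samples(a, num_samples, rng):
    """Draw forward-gradient estimates g = (a . v) v with v ~ N(0, I)."""
    v = rng.standard_normal((num_samples, a.size))
    directional = v @ a              # (num_samples,) directional derivatives a . v
    return directional[:, None] * v  # (num_samples, n) gradient estimates

rng = np.random.default_rng(0)
a = rng.standard_normal(50)          # stand-in for a true gradient (not MNIST)
g = forward_gradient_samples(a, 100_000, rng)

empirical_var = g.var(axis=0)        # per-entry variance across samples
predicted_var = np.sum(a**2) + a**2  # ||a||^2 + a_i^2

print("max |mean(g) - a|:    ", np.abs(g.mean(axis=0) - a).max())
print("average empirical var:", empirical_var.mean())
print("||a||^2:              ", np.sum(a**2))
```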
I give some intuitions for the variance, and for the general distribution of the forward gradients (g), based on product distributions and large random vectors.
In that paragraph I mention some simulations (related to the sanity check above). I didn't include the plots, but here they are! The alignment between the forward grad and the true gradient is all over the place; it's way worse than the randomness from minibatch effects.
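The original plots aren't reproduced here. A rough stand-in for the alignment measurement is sketched below, with random Gaussian "true gradients" in place of real ones and no minibatch comparison. One fact the sketch relies on: for g = (a·v)v, cos(g, a) works out to |a·v| / (‖a‖‖v‖). That is just the alignment of a random direction, so its mean shrinks roughly like √(2/(πn)).

```python
import numpy as np

def forward_grad_cosine(a, num_samples, rng):
    """Cosine similarity between each forward-gradient estimate g = (a . v) v and a."""
    v = rng.standard_normal((num_samples, a.size))
    g = (v @ a)[:, None] * v
    return (g @ a) / (np.linalg.norm(g, axis=1) * np.linalg.norm(a))

rng = np.random.default_rng(0)
for n in [10, 1_000, 10_000]:
    a = rng.standard_normal(n)  # stand-in true gradient
    cos = forward_grad_cosine(a, 500, rng)
    # for large n the mean should sit close to sqrt(2 / (pi * n))
    print(f"n={n:>6}  mean cos={cos.mean():.4f}  sqrt(2/(pi n))={np.sqrt(2 / (np.pi * n)):.4f}")
```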
More could've been said about the weaknesses of FG in the paper, but I don't think it's a useless idea.
So I wrote some suggestions. For example, if you already have a good prior about the gradient direction, maybe you could sample from it instead of a unit normal?
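One possible reading of the "sample from a prior" idea is sketched below. It's a hypothetical sketch, not anything from the paper. The directions are drawn from N(0, I + s·uuᵀ), where u is a noisy guess at the gradient direction and s is a made-up strength parameter. The names sample_directions and mean_alignment are illustrative only.

```python
import numpy as np

def sample_directions(num_samples, n, rng, prior_dir=None, strength=0.0):
    """Draw v ~ N(0, I + strength * u u^T), with u the unit-normalized prior_dir.
    With prior_dir=None this is the usual standard-normal sampling."""
    v = rng.standard_normal((num_samples, n))
    if prior_dir is not None:
        u = prior_dir / np.linalg.norm(prior_dir)
        z = rng.standard_normal((num_samples, 1))
        v = v + np.sqrt(strength) * z * u  # extra variance `strength` along u
    return v

def mean_alignment(a, v):
    """Mean cosine similarity between forward-gradient estimates (a . v) v and a."""
    g = (v @ a)[:, None] * v
    return np.mean((g @ a) / (np.linalg.norm(g, axis=1) * np.linalg.norm(a)))

rng = np.random.default_rng(0)
n = 1_000
a = rng.standard_normal(n)                # stand-in true gradient
guess = a + 2.0 * rng.standard_normal(n)  # hypothetical noisy prior about the direction

baseline = mean_alignment(a, sample_directions(2_000, n, rng))
shaped = mean_alignment(a, sample_directions(2_000, n, rng, prior_dir=guess, strength=float(n)))
print(f"standard normal: {baseline:.3f}   prior-shaped: {shaped:.3f}")
```

There's a caveat with v ~ N(0, Σ): E[(a·v)v] = Σa rather than a. So this estimate is biased toward the prior direction unless you correct for it, e.g. by multiplying by Σ⁻¹.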
@theshawwn i saw you expressing interest in the forward gradient stuff and reasonable skepticism about the value of MNIST experiments
this is a fairly rigorous argument that the gradient noise is too high for fwd grads, as is, to work in large models
Last week @brad19brown, @jordanjuravsky, & co-authors released a paper on inference-time scaling laws that enable small LMs to beat the big boys.
So this weekend, @HowardHalim & I dropped everything to run their analysis on a new model + new data.
Success 😎
Why this matters:
Details of our work and repro code on the Modal blog.
All you need are @modal_labs and @huggingface credentials! And it's free: it fits within the $30/month in Modal's free tier. modal.com/blog/llama-hum…
First: we are bad at using language models.
They are statistical models of Unicode sequences. We know that sequential sampling is hard, but (driven by the economics of inference service providers) we ignore that when sampling from LMs and sample a single sequence greedily.
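To make "sample more than one sequence" concrete, here is a common way to score repeated sampling: the unbiased pass@k estimator from Chen et al. (2021), the Codex paper. Given n samples for a problem, of which c are correct, it estimates the chance that at least one of k samples is correct. I'm assuming this is close to the coverage metric in the inference-time scaling paper, but check their code for the exact definition. The 3-out-of-100 numbers below are made up for illustration.

```python
from math import comb

def pass_at_k(n: int, c: int, k: int) -> float:
    """Unbiased estimate of P(at least one of k samples is correct),
    given c correct answers among n samples (Chen et al., 2021)."""
    if n - c < k:
        return 1.0  # every size-k subset must contain a correct sample
    return 1.0 - comb(n - c, k) / comb(n, k)

# Hypothetical problem: 3 correct answers out of 100 samples.
for k in [1, 10, 100]:
    print(k, pass_at_k(100, 3, k))
```

A single greedy sample corresponds to the k = 1 end of this curve (and greedy decoding isn't even random). The scaling-law story is about how fast the curve climbs as k grows.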
I had a delightful session talking through the paper "In-Context Learning and Induction Heads" with author @NeelNanda5.
It's part of a long research thread, one of my favorites over the last five years, on "reverse engineering" DNNs.
The core claim of the paper is that a large fraction of the in-context learning behavior that makes contemporary transformer LLMs so effective comes from a surprisingly simple type of circuit they call an _induction head_.
In the video, Neel and I talk through the context of this claim and some of the phenomenological evidence for it.
In the process, I was delighted to discover that we share a deep love for and perspective informed by the natural sciences.
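For a feel of what an induction head computes, here's a lookup-table caricature of the pattern [A][B] ... [A] → [B]. The rule: find the previous occurrence of the current token and predict whatever followed it. This is my toy illustration, not code from the paper. Real induction heads do this softly with attention (a previous-token head composed with an induction head), over learned representations rather than exact token matches.

```python
def induction_head_prediction(tokens):
    """Toy version of the induction pattern [A][B] ... [A] -> [B]:
    find the most recent earlier occurrence of the last token and
    return the token that followed it, or None if there is no match."""
    if not tokens:
        return None
    current = tokens[-1]
    # scan earlier positions from right to left, skipping the last token itself
    for i in range(len(tokens) - 2, -1, -1):
        if tokens[i] == current:
            return tokens[i + 1]
    return None

print(induction_head_prediction(["the", "cat", "sat", "on", "the"]))  # → cat
```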
I recently presented a series of four reports on system failure spanning 40 years, from a 1985 typewritten white paper on mainframe database crashes to a 2021 Zoom talk on outages in one of Google's ML-based ranking systems.
Here's a summary, with connections to reliable ML.
Each report was a post-hoc meta-analysis of post-mortem analyses: which "root causes" come up most often? Which take the most time to resolve?
Each covers 100 or more outages from a system built with the best practices of its era & modality, running at the largest scale.
"Why Do Computers Stop" was the first in the series, by Jim Gray (standing, center), who pioneered transactional databases and the ACID principle in the 80s.
It's clear that these ideas were informed by his close engagement with actual failure data.
from @DivitaVohra, an overview of @Spotify's ML platform. super cool to hear how a product manager thinks about the problem of supporting ML systems
from @jeffboudier, an overview of the awesome work being done at @huggingface, with a focus on democratization of best practices, e.g. fast inference with Infinity