François Chollet
Co-founder @ndea. Co-founder @arcprize. Creator of Keras and ARC-AGI. Author of 'Deep Learning with Python'.

Feb 17, 2024, 6 tweets

The "aha" moment when I realized that curve-fitting was the wrong paradigm for achieving generalizable modeling of problem spaces that involve symbolic reasoning was in early 2016.

I was trying every possible way to get an LSTM/GRU-based model to classify first-order logic statements, and each new attempt showed a bit more clearly than the last that my models were completely unable to learn to perform actual first-order logic -- despite the fact that this ability was definitely part of the representable function space. Instead, the models would inevitably latch onto statistical keyword associations to make their predictions.

It has been fascinating to see this observation echo again and again over the past 8 years.

From 2013 to 2016 I was actually quite convinced that RNNs could be trained to learn any program. After all, they're Turing-complete (or at least some of them are) and they learn a highly compressed model of the input:output mapping they're trained on (rather than mere pointwise associations). Surely they could perform symbolic program synthesis in some continuous latent program space?

Nope. They do in fact learn mere pointwise associations and are completely useless for program synthesis. The problem isn't with what the function space can represent -- the problem is the learning process. It's SGD.

Ironically, Transformers are even worse in that regard -- mostly due to their strongly interpolative architecture prior. Multi-head attention literally hardcodes sample interpolation in latent space. There is also the fact that recurrence is a really helpful prior for symbolic programs, and Transformers don't have it.

Not saying that Transformers are worse than RNNs, mind you -- Transformers are *the best* at *what deep learning does* (generalizing via interpolation), specifically *because* of their strongly interpolative architecture prior (MHA). They are, however, worse at learning symbolic programs (which RNNs also largely fail at anyway).
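To make the "hardcodes sample interpolation" point concrete, here is a minimal sketch of the mixing step of a single attention head, in plain Python. It is an illustration under simplifying assumptions, not a full MHA layer: there are no learned query/key/value projections, no multiple heads, no output projection and no residual connection, and the names `softmax`, `attend`, `keys` and `values` are made up for this example. What it shows is that the softmax weights are positive and sum to 1, so the output is a convex combination of the value vectors.

```python
import math

def softmax(scores):
    # Subtract the max for numerical stability; the result is positive and sums to 1.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attend(query, keys, values):
    # Scaled dot-product scores between the query and each key.
    d = len(query)
    scores = [sum(q * k for q, k in zip(query, key)) / math.sqrt(d) for key in keys]
    weights = softmax(scores)
    # Output = weighted average of the value vectors (a convex combination).
    dim = len(values[0])
    out = [sum(w * v[i] for w, v in zip(weights, values)) for i in range(dim)]
    return weights, out

# Toy inputs (arbitrary numbers chosen for illustration only).
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
values = [[0.0, 0.0], [4.0, 0.0], [0.0, 4.0]]
weights, out = attend([2.0, -1.0], keys, values)

print(weights)  # positive weights that sum to 1
print(out)      # a point inside the triangle spanned by the three value vectors
```

Whatever the query is, the output of this step stays inside the convex hull of the value vectors. In a real Transformer the surrounding projections and residual connections move things around, so this only illustrates the attention mixing itself.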

As pointed out by @sirbayes, this paper has a formal investigation into the observation from the first tweet -- that "my models were completely unable to learn to perform actual first-order logic -- despite the fact that this ability was definitely part of the representable function space. Instead, the models would inevitably latch onto statistical keyword associations to make their predictions"

The paper concludes: "while models can attain near-perfect test accuracy on training distributions, they fail catastrophically on other distributions; we demonstrate that they have learned to exploit statistical features rather than to emulate the correct reasoning function"

Basically, SGD will always latch onto statistical correlations as a shortcut for fitting the training distribution, which prevents it from finding the generalizable form of the target program (the one that would operate outside of the training distribution), despite that program being part of the search space.

starai.cs.ucla.edu/papers/ZhangIJ…

There are essentially two main options to remedy this:

1. Find ways to perform active inference, so that the model adapts its learned program in contact with a new data distribution at test time. Would likely lead to some meaningful progress, but it isn't the ultimate solution, more of an incremental improvement.

2. Change the training mechanism to something more robust than SGD, such as the MDL principle. This would pretty much require moving away from deep learning (curve fitting) altogether and embracing discrete program search instead (which I have advocated for many years as a way to tackle reasoning problems...)
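As a rough illustration of what option 2 could look like at toy scale, here is a minimal sketch of discrete program search with a shortest-program-first preference, which is one simple reading of the MDL principle (among programs that fit the data exactly, prefer the shortest description). Everything in it is an assumption made for the example: the four-primitive DSL, the names `PRIMITIVES`, `run` and `search`, and the choice to measure description length as the number of primitives. It is not a method taken from the thread or from the linked paper.

```python
from itertools import product

# Hypothetical toy DSL: each primitive maps an integer to an integer.
PRIMITIVES = {
    "inc": lambda x: x + 1,
    "double": lambda x: x * 2,
    "square": lambda x: x * x,
    "neg": lambda x: -x,
}

def run(program, x):
    # A program is a tuple of primitive names, applied left to right.
    for name in program:
        x = PRIMITIVES[name](x)
    return x

def search(examples, max_len=4):
    # Enumerate programs from shortest to longest and return the first one
    # that reproduces every (input, output) example exactly.
    for length in range(max_len + 1):
        for program in product(PRIMITIVES, repeat=length):
            if all(run(program, x) == y for x, y in examples):
                return program
    return None

# Three training examples of y = 2 * (x + 1), all with small inputs.
examples = [(1, 4), (2, 6), (3, 8)]
program = search(examples)

print(program)            # ('inc', 'double')
print(run(program, 100))  # 202, far outside the training inputs
```

The search recovers the exact program from three examples, so it keeps working on inputs far outside the training range. The obvious cost is combinatorial: the number of candidates grows exponentially with program length, which is why this brute-force version only makes sense as a sketch.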
