Transformers are arguably the most impactful deep learning architecture from the last 5 yrs.
In the next few threads, we’ll cover multi-head attention, GPT and BERT, and the Vision Transformer, and write them out in code. This thread → understanding multi-head attention.
1/n
What is attention? Say you want to classify the sentiment of “attention is not too shabby.” “shabby” suggests 😞, but “not” means it's actually 😀. To classify it correctly, you need to look at all the words in the sentence. How can we achieve this?
2/n
The simplest thing we can do is input all words into the network. Is that enough? No. The net needs to not only see each word but also understand its relation to the other words. E.g. it’s crucial that “not” refers to “shabby”. This is where queries, keys, and values (Q, K, V) come in.
3/n
Let’s pass the words through a linear layer and call its outputs “values”. How do we encode relationships between values? We can mix them by summation. Now we “see” both the words and their relationships, but that’s still not quite right. What’s wrong with this approach?
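The original code isn't reproduced in this text. Below is a minimal NumPy sketch of the naive approach instead. It is an illustration, not the author's code. The random embeddings `x` and the weight matrix `W_v` are hypothetical stand-ins for the word embeddings and the “values” linear layer.

```python
import numpy as np

rng = np.random.default_rng(0)
T, D = 5, 8                      # 5 words ("attention is not too shabby"), embedding size 8
x = rng.normal(size=(T, D))      # hypothetical word embeddings, one row per word
W_v = rng.normal(size=(D, D))    # hypothetical weights of the "values" linear layer (no bias)

values = x @ W_v                 # (T, D): one value vector per word

# Naive mixing: every word's output is the unweighted sum of all values,
# i.e. multiplying by an all-ones (T, T) matrix.
mix = np.ones((T, T))
out = mix @ values               # (T, D)

# Every row is identical: each word gets the same mixture of every other word,
# so "not" <> "shabby" counts no more than "is" <> "too".
print(np.allclose(out, out[0]))  # True
```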
4/n
The issue is that naively summing the values assumes all relationships between words are equally important. E.g. the relationship between “is” and “too” counts as much as the one between “not” and “shabby”. But clearly “not” <> “shabby” is more important for sentiment analysis than “is” <> “too”.
5/n
We want the orange matrix to weight relationships based on how useful word_i is as context for word_j. So let’s create two more linear layers, called “queries” and “keys”. The weight w_ij should grow with the inner product between the i-th query and the j-th key.
6/n
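A small but important detail is that we need to rescale the inner products by 1 / sqrt(D), where D is the dimension of the queries and keys. Why this specific scaling? Why not 1 / D or 1 / T or some other constant? The reason: if the entries of q and k are independent with zero mean and unit variance, their inner product is a sum of D such terms and has variance D. Dividing by sqrt(D) therefore keeps the standard deviation of the scores roughly equal to 1, no matter how large D is.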
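A quick numerical check of that argument. This sketch assumes that query and key entries are independent standard normal samples, which is an idealization, since real learned queries and keys won't be exactly like this. It compares the spread of raw, 1/sqrt(D)-scaled, and 1/D-scaled inner products.

```python
import numpy as np

rng = np.random.default_rng(0)
N, D = 20_000, 64            # number of sampled (q, k) pairs; query/key dimension D

q = rng.normal(size=(N, D))  # entries: independent, zero mean, unit variance (our assumption)
k = rng.normal(size=(N, D))

raw = np.einsum("nd,nd->n", q, k)   # N inner products q . k
scaled = raw / np.sqrt(D)           # the 1 / sqrt(D) rescaling

print(raw.std())        # close to sqrt(64) = 8: grows with D
print(scaled.std())     # close to 1
print((raw / D).std())  # close to 1 / 8: dividing by D over-shrinks the scores
```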
7/n
Finally, we need to normalize the weights along the axis that will be summed, so we use a softmax. Intuitively, Q is a question “how useful am I for word K?” High / low inner product means very / not very useful. With that we are done - this is attention!
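For reference, steps 6-8 combine into the standard scaled dot-product attention formula. The notation here is our own: X is the T × D_in matrix of word embeddings, W_Q, W_K, W_V are the three linear layers (biases omitted), q_i, k_j, v_j are rows of Q, K, V, and D is the query/key dimension from above. The softmax runs over j, the axis being summed.

```latex
\begin{aligned}
Q &= X W_Q, \quad K = X W_K, \quad V = X W_V \\
w_{ij} &= \frac{\exp\!\left(q_i \cdot k_j / \sqrt{D}\right)}{\sum_{j'=1}^{T} \exp\!\left(q_i \cdot k_{j'} / \sqrt{D}\right)} \\
\mathrm{out}_i &= \sum_{j=1}^{T} w_{ij}\, v_j
\quad\Longleftrightarrow\quad
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^\top}{\sqrt{D}}\right) V
\end{aligned}
```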
8/n
Technically what we’ve shown is called single-head self-attention. Before going to multi-head attention, let’s code up what we’ve done so far.
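The thread's own code isn't reproduced here. As a stand-in, here is a NumPy sketch (not the author's PyTorch code) that follows the steps so far: values, queries and keys from three linear layers, inner products scaled by 1 / sqrt(D), and a softmax over the summed axis. The weight names, the sizes, and the absence of biases are illustrative assumptions.

```python
import numpy as np

def softmax(scores, axis=-1):
    # Subtracting the max is for numerical stability; it doesn't change the result.
    scores = scores - scores.max(axis=axis, keepdims=True)
    e = np.exp(scores)
    return e / e.sum(axis=axis, keepdims=True)

def single_head_self_attention(x, W_q, W_k, W_v):
    """x: (T, D_in) word embeddings; W_*: (D_in, D) linear-layer weights (no bias)."""
    q = x @ W_q                         # (T, D) queries
    k = x @ W_k                         # (T, D) keys
    v = x @ W_v                         # (T, D) values
    D = q.shape[-1]
    scores = q @ k.T / np.sqrt(D)       # (T, T): scores[i, j] = q_i . k_j / sqrt(D)
    weights = softmax(scores, axis=-1)  # normalize over j, the axis that gets summed
    return weights @ v, weights         # out_i = sum_j weights[i, j] * v_j

rng = np.random.default_rng(0)
T, D_in, D = 5, 16, 8
x = rng.normal(size=(T, D_in))
W_q, W_k, W_v = (rng.normal(size=(D_in, D)) / np.sqrt(D_in) for _ in range(3))
out, weights = single_head_self_attention(x, W_q, W_k, W_v)
print(out.shape)             # (5, 8)
print(weights.sum(axis=-1))  # each row sums to 1
```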
9/n
What is multi-head attention and why do we need it? Our single-head net may overfit to the training data. In ML, ensembles are a common strategy to combat overfitting, and by initializing multiple heads we get more robust results. Multiple heads also let the model attend to several different relationships at once. Concatenating the outputs of N single heads gives multi-head attention.
10/n
So multi-head is just a small tweak to single-head attention. In practice, we also add dropout layers to further prevent overfitting and a final linear projection layer. This is what a complete vectorized multi-head self-attention block looks like in PyTorch.
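The PyTorch code itself isn't included in this text. Here is a NumPy sketch of the same structure instead. It is our own version, not the author's. It includes:

- one fused Q/K/V projection
- heads split by reshaping
- attention dropout
- concatenation of the heads
- a final linear projection, followed by output dropout

The fused `W_qkv` weight, the dropout rate of 0.1, and all names are assumptions. In PyTorch, the weight matrices would typically be `nn.Linear` layers and the dropouts `nn.Dropout`.

```python
import numpy as np

def softmax(scores, axis=-1):
    scores = scores - scores.max(axis=axis, keepdims=True)
    e = np.exp(scores)
    return e / e.sum(axis=axis, keepdims=True)

def dropout(a, p, rng, training):
    # Inverted dropout: zero each entry with probability p, rescale the rest by 1 / (1 - p).
    if not training or p == 0.0:
        return a
    keep = rng.random(a.shape) >= p
    return a * keep / (1.0 - p)

def multi_head_self_attention(x, W_qkv, W_o, n_heads, p_drop=0.1, rng=None, training=False):
    """x: (B, T, C); W_qkv: (C, 3C) fused Q/K/V weights; W_o: (C, C) output projection."""
    B, T, C = x.shape
    assert C % n_heads == 0, "channels must split evenly across heads"
    H, Dh = n_heads, C // n_heads
    qkv = x @ W_qkv                                        # (B, T, 3C): Q, K, V in one matmul
    q, k, v = np.split(qkv, 3, axis=-1)                    # each (B, T, C)
    # Split channels into heads: (B, T, C) -> (B, H, T, Dh)
    q, k, v = (t.reshape(B, T, H, Dh).transpose(0, 2, 1, 3) for t in (q, k, v))
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(Dh)     # (B, H, T, T), all heads at once
    weights = dropout(softmax(scores, axis=-1), p_drop, rng, training)  # attention dropout
    heads = weights @ v                                    # (B, H, T, Dh)
    # Concatenate heads back: (B, H, T, Dh) -> (B, T, C)
    concat = heads.transpose(0, 2, 1, 3).reshape(B, T, C)
    out = concat @ W_o                                     # final linear projection
    return dropout(out, p_drop, rng, training)             # output dropout

rng = np.random.default_rng(0)
B, T, C, H = 2, 5, 16, 4
x = rng.normal(size=(B, T, C))
W_qkv = rng.normal(size=(C, 3 * C)) / np.sqrt(C)
W_o = rng.normal(size=(C, C)) / np.sqrt(C)
y = multi_head_self_attention(x, W_qkv, W_o, n_heads=H, rng=rng, training=True)
print(y.shape)  # (2, 5, 16)
```

The “vectorized” part is the reshape into a head axis. All heads go through one batched matmul instead of a Python loop over single heads. Dropout is only applied when `training=True`.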
11/n
And there you have it - we derived attention intuitively and wrote it out in code. The main idea is quite simple.
In the next posts, I will cover Transformers, GPT & BERT, Vision Transformers, and other useful tricks / details. That was fun to write; hope it’s also fun to read!
12/n END
In our new work - Algorithm Distillation - we show that transformers can improve themselves autonomously through trial and error without ever updating their weights.
No prompting, no finetuning. A single transformer collects its own data and maximizes rewards on new tasks.
1/N
We've seen a lot of successful models showing how transformers can learn in-context.
But transformers have not been shown to *reinforcement* learn in-context. To adapt to new tasks, you need to either manually specify a prompt or finetune the model (e.g. on human preferences).
2/N
It would be great if transformers could adapt (do RL) out of the box.
Don't Decision Transformers (DTs) / Gato do RL? No!
DTs and Gato learn policies from offline data, but these policies cannot improve themselves autonomously through trial and error.
How much memory do you need to train deep neural networks? You may find the answer counterintuitive.
For example, suppose we're training a 4 megabyte MLP with batch_size = hidden_dim. How much memory do we need? 4MB? No - we need at least 8MB!
Here's why...
1/N
Consider a 1M-param MLP where each param is stored as a float32. How much memory is required to train the MLP?
You might guess that it's the number of bytes needed to store the model:
1M params * 4 bytes per float32 = 4MB.
This is wrong...
2/N
...or rather, not entirely correct.
Since we train deep nets with backpropagation, we need to store not just the model but also all of the activations from the fwd pass in order to compute gradients.
The memory needed to store activations is often >> size(model).
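Here's a back-of-the-envelope version of the 4MB vs 8MB example. It rests on assumptions that are our choice of architecture, not necessarily the author's:

- a bias-free MLP of 4 square linear layers of width 500 (4 × 500 × 500 = 1M params)
- training in float32
- batch_size = hidden_dim

It counts only the weights plus one saved input activation per layer, which backprop needs to compute that layer's weight gradient. Gradients (another model-sized copy) and optimizer state come on top of this.

```python
# Assumptions (hypothetical, chosen to match the thread's numbers):
# 4 bias-free square linear layers of width 500, float32, batch_size = hidden_dim.
# Per layer we count only its (batch, hidden) input, which backprop keeps
# to compute that layer's weight gradient.
BYTES_PER_FLOAT32 = 4
hidden_dim = 500
n_layers = 4
batch_size = hidden_dim

n_params = n_layers * hidden_dim * hidden_dim               # 1,000,000 params
param_bytes = n_params * BYTES_PER_FLOAT32                  # 4,000,000 bytes = 4 MB

n_saved_activations = n_layers * batch_size * hidden_dim    # one (batch, hidden) input per layer
activation_bytes = n_saved_activations * BYTES_PER_FLOAT32  # also 4 MB

total_bytes = param_bytes + activation_bytes
print(param_bytes / 1e6, activation_bytes / 1e6, total_bytes / 1e6)  # 4.0 4.0 8.0
```

Note that activation memory scales with batch_size while the model size doesn't. Doubling the batch here doubles activation memory to 8MB, which is how activations end up >> size(model).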
3/N
Building on parts 1 & 2, which explained multi-head attention and GPT, in part 3 of the Transformer Series we'll cover masked language models like BERT.
This thread → masked language models, diff between causal and bi-directional masked attention, finetuning, and code.
1/N
Since we'll be referencing multi-head attention and GPT, make sure to read parts 1 & 2 if you're unfamiliar with these concepts.
We saw with GPT that we can pre-train language models with a causal predict-the-future objective. Instead, BERT uses a fill-in-the-blank objective. It is called bi-directional because, unlike GPT (which is causal), it sees both past and future tokens at once.
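To make the causal vs bi-directional difference concrete, here's a small NumPy sketch. It is our own illustration, and the helper names are hypothetical. Both models compute attention the same way and differ only in which (query, key) pairs are allowed. Note that the “masked” in masked language model refers to hiding input tokens for the fill-in-the-blank objective. BERT's attention itself uses the all-ones mask below (ignoring padding).

```python
import numpy as np

def masked_softmax(scores, mask):
    # Disallowed positions (mask == False) get -inf, so they receive exactly zero weight.
    scores = np.where(mask, scores, -np.inf)
    scores = scores - scores.max(axis=-1, keepdims=True)
    e = np.exp(scores)
    return e / e.sum(axis=-1, keepdims=True)

T = 4
causal_mask = np.tril(np.ones((T, T), dtype=bool))  # GPT: word i sees words 0..i
bidirectional_mask = np.ones((T, T), dtype=bool)     # BERT: every word sees every word

scores = np.zeros((T, T))  # equal scores, so only the mask shapes the weights
causal_weights = masked_softmax(scores, causal_mask)
bidir_weights = masked_softmax(scores, bidirectional_mask)
print(causal_weights)  # row i spreads its weight evenly over words 0..i; future words get 0
print(bidir_weights)   # every entry is 1/4
```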
GPT has been a core part of the unsupervised learning revolution that’s been happening in NLP.
In part 2 of the transformer series, we’ll build GPT from the ground up. This thread → masked causal self-attention, the transformer block, tokenization & position encoding.
1/N
In part 1 we covered multi-head attention (MHA). tl;dr: attention allows a neural network to “see” all words in the input as well as their relationships. As a result, the net attends to the most important words for optimizing its objective.
So far, we haven’t defined an objective for MHA to optimize. GPT uses a very simple unsupervised objective - predict the next word in a sentence given the previous words. This objective is called unsupervised because it doesn’t require any human-provided labels: the targets are simply the next words of the text itself.
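A minimal sketch of how the next-word objective turns raw text into training pairs. It assumes the text is already tokenized into integer ids; the ids and the vocabulary size below are made up. The targets are just the inputs shifted by one position. In GPT, causal masking ensures the prediction at position t only sees tokens up to t.

```python
import numpy as np

def next_token_targets(token_ids):
    """Turn one tokenized sentence into (inputs, targets) for next-word prediction."""
    token_ids = np.asarray(token_ids)
    return token_ids[:-1], token_ids[1:]  # predict token t+1 from tokens 0..t

def cross_entropy(logits, targets):
    """Mean negative log-likelihood of the targets; logits: (T, V), targets: (T,)."""
    logits = logits - logits.max(axis=-1, keepdims=True)
    log_probs = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))
    return -log_probs[np.arange(len(targets)), targets].mean()

tokens = [5, 2, 9, 2, 7]  # hypothetical token ids for one sentence
inputs, targets = next_token_targets(tokens)
print(inputs, targets)    # [5 2 9 2] [2 9 2 7]

V = 10                                       # hypothetical vocabulary size
uniform_logits = np.zeros((len(inputs), V))  # a model that knows nothing yet
loss = cross_entropy(uniform_logits, targets)
print(loss)               # log(10), about 2.3026
```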
Patch extraction is a fundamental operation in deep learning, especially for computer vision.
By the end of this thread, you’ll know how to implement an efficient vectorized patch extractor (no for loops) in a few lines of code and learn about memory allocation in numpy.
1/n
In deep learning we often need to preprocess inputs into patches. This can mean splitting an image into overlapping or non-overlapping 2D patches or splitting a long audio or text input into smaller equally sized chunks.
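For the non-overlapping case, here's a loop-free sketch using only reshape and transpose. The helper names are our own. Dropping the ragged tail of a 1D signal is just one possible policy; padding is another. One memory detail: the 1D reshape is a free view, but the 2D transpose followed by reshape forces NumPy to copy the data.

```python
import numpy as np

def to_patches(image, p):
    """Split an (H, W, C) image into non-overlapping p x p patches -> (H//p * W//p, p, p, C)."""
    H, W, C = image.shape
    assert H % p == 0 and W % p == 0, "image size must be divisible by the patch size"
    x = image.reshape(H // p, p, W // p, p, C)  # split rows and columns into (block, offset)
    x = x.transpose(0, 2, 1, 3, 4)              # -> (patch row, patch col, p, p, C)
    return x.reshape(-1, p, p, C)               # flatten the patch grid, row-major (copies)

def to_chunks(signal, n):
    """Split a 1D NumPy array (audio samples, token ids, ...) into equal chunks of length n."""
    usable = len(signal) - len(signal) % n      # drop the ragged tail
    return signal[:usable].reshape(-1, n)       # a view: no data is copied

image = np.arange(4 * 4 * 3).reshape(4, 4, 3)
patches = to_patches(image, 2)
print(patches.shape)                                 # (4, 2, 2, 3)
print(np.array_equal(patches[1], image[0:2, 2:4]))  # True: the second patch is the top-right block

chunks = to_chunks(np.arange(10), 3)                 # 3 chunks: [0 1 2], [3 4 5], [6 7 8]
```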
2/n
Extracting patches efficiently is harder than it seems. For example, we can load an image into a numpy array, then write a for loop to index into the array and get patches. This works, but it requires extra memory and the for loop is slow. Can we do better?
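The thread's own implementation isn't shown here. One loop-free way, possibly different from the author's, is NumPy's `sliding_window_view` (available since NumPy 1.20). It builds every patch as a strided view into the original image, so no pixel data is copied. The function names below are hypothetical, and the for-loop version is included only as a reference to check against. Caveats:

- The view is read-only.
- Any later operation that needs contiguous data will allocate a copy at that point, for example reshaping overlapping patches into a batch.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view  # NumPy >= 1.20

def extract_patches(image, p, stride):
    """All p x p patches of an (H, W, C) image taken every `stride` pixels, with no Python loop.

    Returns a read-only view of shape (n_rows, n_cols, C, p, p): no pixel data is copied.
    """
    windows = sliding_window_view(image, (p, p), axis=(0, 1))  # (H-p+1, W-p+1, C, p, p)
    return windows[::stride, ::stride]                         # keep every stride-th window

def extract_patches_loop(image, p, stride):
    """Reference for-loop version: slower, and np.stack allocates new arrays."""
    H, W, _ = image.shape
    rows = []
    for i in range(0, H - p + 1, stride):
        rows.append([image[i:i + p, j:j + p] for j in range(0, W - p + 1, stride)])
    return np.stack([np.stack(r) for r in rows])               # (n_rows, n_cols, p, p, C)

image = np.random.default_rng(0).random((32, 32, 3))
fast = extract_patches(image, p=8, stride=4)  # overlapping 8x8 patches
slow = extract_patches_loop(image, p=8, stride=4)

print(fast.shape)                                   # (7, 7, 3, 8, 8)
print(fast.flags.owndata)                           # False: fast is a view into image's buffer
print(np.allclose(np.moveaxis(fast, 2, -1), slow))  # True
```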
Humans reuse skills effortlessly to learn new tasks - can robots do the same? In our new paper, we show how to pre-train robotic skills and adapt them to new tasks in a kitchen.
tl;dr you’ll have a robot chef soon. 🧑‍🍳🤖
links / details below
thread 🧵 1/10
Title: Hierarchical Few-Shot Imitation with Skill Transition Models
Paper: arxiv.org/abs/2107.08981
Site: sites.google.com/view/few-shot-…
Main idea: fit generative “skill” model on large offline dataset, adapt it to new tasks
Result: show robot a new task, it will imitate it
2/10
We introduce Few-shot Imitation with Skill Transition Models (FIST). FIST first extracts skills from a diverse offline dataset of demonstrations and then adapts them to the new downstream task. FIST has 3 steps: (1) Extraction, (2) Adaptation, (3) Evaluation.
3/10