Let's consider the widely used 7B-param LLaMA architecture.
It has 32 layers, 32 attention heads, and d_k, d_v sizes of 128.
The key issue with multi-head attention is the cost of repeatedly accessing the previously computed attention keys and values during inference.
Why? (2/12)
How does multi-head attention work when generating the Nth token?
To generate each token, we calculate a query for each head. Then, we look at the keys and values for prev tokens and apply the op:
softmax(q^T K / sqrt(d_k)) V^T
where, per head, q is (d_k, 1), K is (d_k, N), and V is (d_v, N)
(3/12)
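To make the shapes concrete, here's a minimal numpy sketch of one decode step of multi-head attention against a KV cache (the names and toy sizes are mine, not from any particular implementation):

```python
import numpy as np

d_k = d_v = 128   # per-head dims for the 7B LLaMA config
n_heads = 32

def decode_step_mha(q, K_cache, V_cache):
    """One head's attention for the newest token.
    q: (d_k,) query; K_cache: (d_k, N) cached keys; V_cache: (d_v, N) cached values."""
    scores = q @ K_cache / np.sqrt(d_k)      # (N,) one score per previous token
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()                  # softmax over the N previous tokens
    return V_cache @ weights                  # (d_v,) weighted sum of cached values

# every one of the 32 heads reads its *own* (128 x N) K and V cache
N = 1024
K_cache = np.random.randn(n_heads, d_k, N)
V_cache = np.random.randn(n_heads, d_v, N)
q = np.random.randn(n_heads, d_k)
out = np.stack([decode_step_mha(q[h], K_cache[h], V_cache[h]) for h in range(n_heads)])
```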
When generating the N+1th token, each of the 32 layers and 32 heads needs to access the (128xN) dimensional K and V matrices
For our 7B param model, each new token needs to read 32*32*128*2*N cached values. At 2 bytes each (fp16), that's (32*32*128*2)*2*N bytes ≈ 520KB*N
Here we hit the wall on memory bandwidth (4/12)
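Spelling that arithmetic out (assuming fp16, i.e. 2 bytes per cached value):

```python
layers, heads, d_head = 32, 32, 128
k_and_v, bytes_per_val = 2, 2              # one K and one V vector per head; fp16
per_token = layers * heads * d_head * k_and_v * bytes_per_val
print(per_token)                            # 524288 bytes ≈ 520KB read per cached token
```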
This gets out of hand for large N. For a 64K-context model, that means ~33GB of KV cache data read from RAM for every token generated!
To achieve reasonable throughput, higher batch sizes are a must. With a batch of 16, that makes the KV cache 528GB! (5/12)
To fit in memory, we’d need to split the cache across several GPUs!
With a naive implementation, we are limited by the 1.5 TB/s DRAM or 1.6TB/s inter-GPU bandwidth.
So our time per token is... (528GB+14GB)/1.6TB/s = 340 ms/token! [1]
(6/12)
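The rough arithmetic behind those numbers (treating 64K as 64,000 tokens and using the thread's bandwidth figures):

```python
per_ctx_token = 32 * 32 * 128 * 2 * 2        # ≈ 520KB of K/V per cached token (fp16)
ctx, batch = 64_000, 16
kv_cache = per_ctx_token * ctx                # ≈ 33GB for one 64K sequence
kv_cache_batched = kv_cache * batch           # ≈ 530GB for a batch of 16
weights = 14e9                                # fp16 weights of a 7B model
print((kv_cache_batched + weights) / 1.6e12)  # ≈ 0.34 s per decoding step at 1.6 TB/s
```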
[1] - 14 GB for the model weights
Multi-query attention gives a massive speedup here.
We use the same attention formula as before, but this time K,V are shared across all 32 attention heads!
This means a 32x reduction in our KV cache size. When memory bound, this gives up to a 32x speedup!
(7/12)
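A sketch of the same decode step with multi-query attention (again, toy numpy code of my own):

```python
import numpy as np

d_k = d_v = 128
n_heads = 32

def decode_step_mqa(q, K_cache, V_cache):
    """Every head keeps its own query, but all heads share ONE K/V cache per layer.
    q: (n_heads, d_k); K_cache: (d_k, N); V_cache: (d_v, N)."""
    scores = q @ K_cache / np.sqrt(d_k)             # (n_heads, N)
    scores -= scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)  # per-head softmax
    return weights @ V_cache.T                      # (n_heads, d_v)

N = 1024
K_cache = np.random.randn(d_k, N)   # one cache per layer instead of 32
V_cache = np.random.randn(d_v, N)
q = np.random.randn(n_heads, d_k)
out = decode_step_mqa(q, K_cache, V_cache)   # same output shape as before
```

The per-layer cache drops from 32 K and 32 V matrices to one of each, which is where the 32x comes from.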
Let's work out some concrete numbers: the KV cache shrinks from 528GB to 16.5GB, and the model weights are still 14GB.
That's ~30GB of total data being moved at 1.5 TB/s, giving ~20ms/token - a 17x speedup...
(8/12)
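Same back-of-the-envelope calc as before:

```python
kv_cache_mqa = 528e9 / 32          # K/V shared across 32 heads -> ~16.5GB
total = kv_cache_mqa + 14e9        # plus the fp16 weights ≈ 30GB moved per step
print(total / 1.5e12)              # ≈ 0.02 s/token at 1.5 TB/s, vs 0.34 s before
```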
And the larger we make our batch size, the more substantial this speedup will be.
In the limit, the speedup should approach 32x.
In PaLM, they tested this with context windows up to 40K and massive batch sizes
(9/12)
As a caveat, these are all incredibly crude estimates/calculations, and some of my math may be way off.
For example, in practice, memory bandwidth utilization will not be as high as the peak of 1.5 TB/s.
(10/12)
Furthermore, many naive transformer implementations will end up moving more total data than just the model weights + the KV cache.
This is often due to unnecessary repeated work (for example, not fusing ops like the softmax).
(11/12)
But in the original multi-query attention paper (by the goat Noam Shazeer), they see decoding speedups of over 10x at much smaller sequence lengths.
And PaLM (with much longer sequence lengths, up to 40K) sees speedups far more substantial than even 30x.
(12/12)
The size of all code/history in GitHub's public repos is 92TB
The size of Google's monorepo in 2015 was 86TB (of much higher quality code)
If Google were willing to deploy code models trained on their own data, they'd have a noticeable advantage over everyone else.
To be fair, they would have to use the diff/history data much more than the raw code.
There is almost certainly more code at GitHub HEAD, but Google may benefit from a richer commit history.
And I'd suspect the size of their monorepo has substantially increased since then.
Furthermore, these high-quality code tokens aren't just good for code, but incredibly useful for general language modeling performance:
There are times and places for training your own models... With the release of OpenAI's ChatGPT API, coding is looking less like one of them.
The HumanEval pass@1 rate of ChatGPT is as good as the best open-source model's pass@100 rate.
And this is still just GPT 3.5...
Not only that, but 10x better pricing than text-davinci, and far lower latency.
After seeing this news today, I really would not want to be one of OpenAI's competitors
For those unfamiliar with pass@k, this means if I took the best open-source code model (CodeGen 16B) and sampled 100 generations, the probability that 1 of those 100 generations was correct is the same as the probability ChatGPT gets it right on the first try.
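For reference, the standard unbiased pass@k estimator from the HumanEval/Codex paper (the sample counts below are made up for illustration, not CodeGen's or ChatGPT's actual numbers):

```python
import numpy as np

def pass_at_k(n, c, k):
    """n = samples drawn per problem, c = samples that pass, k = attempt budget."""
    if n - c < k:
        return 1.0
    return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))

# e.g. if 5 of 200 samples solve a problem: pass@1 ≈ 0.025, pass@100 ≈ 0.97
print(pass_at_k(200, 5, 1), pass_at_k(200, 5, 100))
```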