Now, how fast should we be decoding in theory? Well, using tp=2 on two A100s, we're reading 81GB per token, or ~40GB per A100.
Assuming we can achieve 70% of the 2TB/s bandwidth, that’s 35 tok/s.
(5/10)
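To make the arithmetic concrete, here's a minimal sketch of that bandwidth-bound estimate (the 81GB/token, 2TB/s, and 70% figures are the ones quoted above):

```python
# Memory-bandwidth-bound decode estimate, using the numbers from this thread.
BYTES_READ_PER_TOKEN = 81e9   # weights + KV cache read per decoded token (total across tp=2)
NUM_GPUS = 2                  # two A100s, tensor-parallel
PEAK_BW_PER_GPU = 2e12        # ~2 TB/s HBM bandwidth per A100
BW_EFFICIENCY = 0.70          # assume we only achieve 70% of peak bandwidth

effective_bw = NUM_GPUS * PEAK_BW_PER_GPU * BW_EFFICIENCY   # bytes/s across both GPUs
print(f"{effective_bw / BYTES_READ_PER_TOKEN:.1f} tok/s")   # ~34.6 tok/s, i.e. the ~35 tok/s above
```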
Let’s swap out codellama for our multi-query model. Notice that we can directly trade context length for heads in our KV cache memory.
We multiply our context window by 32 and divide the number of KV heads by 32 (down to 1), giving us exactly the same KV cache size.
(6/10)
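Here's a rough sketch of that trade, using the shapes quoted in footnote [1] below (32 layers, 128-dim heads, batch size 16, 2 bytes/element for bf16); the 8K baseline context is just 256K / 32 and is my assumption, not a number from the thread:

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, batch, bytes_per_elem=2):
    # 2x for keys and values; bf16/fp16 means 2 bytes per element
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * batch * bytes_per_elem

# MHA baseline: 32 KV heads at an (assumed) 8K context
mha = kv_cache_bytes(n_layers=32, n_kv_heads=32, head_dim=128, seq_len=8_192, batch=16)

# MQA: 1 KV head, context multiplied by 32 -> 256K
mqa = kv_cache_bytes(n_layers=32, n_kv_heads=1, head_dim=128, seq_len=262_144, batch=16)

print(mha == mqa)             # True -- the KV cache is exactly the same size
print(f"{mha / 1e9:.1f} GB")  # ~68.7 GB either way
```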
This means that for our 256K-token prompt, we're still decoding at 35 tok/s at a reasonable batch size
(7/10)
Now, prefill is a whole different story, and actually training the model to pay attention to 256K tokens is another can of worms...
(8/10)
But MQA models give such massive inference wins for long context that they may be worth it despite slightly worse scaling perf than MHA or GQA.
Plus, we’re already working with an MQA model that beats codellama 7b on code evals.
(9/10)
Finally, 256K is not the limit of how far you can take things with vanilla transformers for reasonable inference perf.
There are a few more long context tricks that, in theory, should be able to preserve that same perf up to 1M+ ;)
(10/10)
[1] Simple model flops calc
We can also account for attention FLOPs, but notice that the compute time they add is still insignificant compared to the time spent reading the KV cache:
For each attention head and each sequence in the batch, we're doing a 1 x 128 (query) times 128 x 256K (keys) matmul, i.e. 2 * 256K * 128 FLOPs.
Then the attention weights, a 1 x 256K vector, multiply the 256K x 128 value matrix for another 2 * 256K * 128 FLOPs. Multiplying by batch size 16, 32 query heads, and 32 layers, this comes to roughly 2 TFLOPs per token.
On our 2 GPUs, compute alone would allow roughly 300 TPS (probably closer to 200 TPS in practice).
But we're still bottlenecked by memory bandwidth, which keeps us at 35 TPS
* Actually, we need to spend ~2 TFLOPs/token when accounting for attention, but on our 2 A100s that would be 200-300 TPS, which the memory bandwidth bottleneck prevents us from hitting anyway
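A sketch of that estimate; the only number below not taken from the thread is the ~312 TFLOPS bf16 dense peak of an A100:

```python
# Per-token attention FLOPs at a 256K context, using the shapes from footnote [1].
seq_len, head_dim = 262_144, 128
batch, n_heads, n_layers = 16, 32, 32

# QK^T (1x128 @ 128x256K) plus attn @ V (1x256K @ 256Kx128), each 2*seq_len*head_dim FLOPs
flops_per_head = 2 * (2 * seq_len * head_dim)
attn_flops_per_token = flops_per_head * batch * n_heads * n_layers
print(f"{attn_flops_per_token / 1e12:.1f} TFLOPs/token")  # ~2.2 TFLOPs

# If compute were the only limit on our 2 A100s (~312 TFLOPS bf16 dense each):
compute_bound_tps = 2 * 312e12 / attn_flops_per_token
print(f"{compute_bound_tps:.0f} TPS")  # ~284 TPS -- but memory bandwidth caps us at ~35
```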
[2] We need to read the model weights to generate each token, and we're assuming bf16 or fp16 weights, hence the 2-bytes-per-weight term there
Standard prompting libraries use variants of “f-strings” with subbed-in inputs.
For us, a prompt is defined as a function that maps some set of inputs X and a token budget n to some string, s:
p(X, n) = s
We call this operation "rendering"
(2/12)
For example, my inputs, X, could include conversation history, contents of the current file, chunks of documentation, and codebase context we deem relevant.
This sums to 100K tokens. But the budget we are working with may just be 4000 tokens.
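As a minimal sketch of what such a render function might look like (the priority scheme, the item names, and the `approx_tokens` helper are all hypothetical; a real implementation would count tokens with the model's actual tokenizer):

```python
from typing import Callable

def render(inputs: list[tuple[int, str]], budget: int,
           count_tokens: Callable[[str], int]) -> str:
    """p(X, n) -> s: greedily pack the highest-priority items that fit in the budget."""
    pieces, remaining = [], budget
    for _priority, text in sorted(inputs, key=lambda item: item[0], reverse=True):
        cost = count_tokens(text)
        if cost <= remaining:
            pieces.append(text)
            remaining -= cost
    return "\n\n".join(pieces)

# Toy usage: ~100K tokens of candidate context, but only a 4,000-token budget.
approx_tokens = lambda s: max(1, len(s) // 4)  # crude stand-in for a real tokenizer
X = [
    (3, "conversation history ..."),
    (2, "contents of the current file ..."),
    (1, "chunks of documentation ..."),
    (0, "other codebase context ..."),
]
prompt = render(X, budget=4_000, count_tokens=approx_tokens)
```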
We've seen two key advantages of Turbopuffer, with no perf degradation:
1. Normal vector database pricing makes no sense for our workloads (lots of moderate-sized indices).
2. The normal "pods" or cluster-based indices (of Pinecone, for example) add unnecessary complexity.
(2/10)
Most vector databases store the indices in memory.
For older use cases, this made sense: a given customer will have several large vector indices with consistently high usage on each index.
And the index should be in memory for high-throughput/low-latency querying.
People claim LLM knowledge distillation is trivial with logprobs, but that's not quite right...
It's very tricky to distill between different tokenizers. [1]
Internally, we've solved this with a clever algorithm we called tokenization transfer
(1/7)
To start, we needed to build a sophisticated primitive called the "Logmass Trie"
It's an extended Trie where each edge contains not only a character but also a weight that represents the "log probability" of that character conditional on the string thus far
(2/7)
This edge weight is just an estimate.
But it must satisfy the constraint that for a contained string X, summing the log probabilities of the edges on the path to X gives the log probability of X
(3/7)
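A bare-bones sketch of such a structure (the class and method names are made up, and the real algorithm's scheme for spreading log mass across edges isn't described here, so this version simply dumps the residual onto the final edge of each inserted string):

```python
class LogmassTrieNode:
    def __init__(self):
        self.children = {}  # char -> (edge_log_weight, child_node)

class LogmassTrie:
    def __init__(self):
        self.root = LogmassTrieNode()

    def insert(self, s: str, log_prob: float) -> None:
        """Insert s so that the edge weights along its path sum to log_prob."""
        if not s:
            return
        node, consumed, parent = self.root, 0.0, None
        for ch in s:
            if ch not in node.children:
                node.children[ch] = (0.0, LogmassTrieNode())
            weight, child = node.children[ch]
            consumed += weight
            parent, node = node, child
        # Put the residual log mass on the final edge so the path sums to log_prob.
        # (Reconciling this with other strings that share the edge is the hard part,
        # and is omitted from this sketch.)
        last_weight, last_child = parent.children[s[-1]]
        parent.children[s[-1]] = (last_weight + (log_prob - consumed), last_child)

    def log_prob(self, s: str) -> float:
        """The trie's estimate of log P(s): the sum of edge weights on the path to s."""
        node, total = self.root, 0.0
        for ch in s:
            weight, node = node.children[ch]
            total += weight
        return total
```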