We're releasing an optimized implementation of GPT2/GPT3 with FlashAttention🚀!
This trains 3-5x faster than the Huggingface version, reaching up to 189 TFLOPS per A100, i.e. 60.6% model FLOPs utilization (MFU) of the theoretical maximum. 1/6 github.com/HazyResearch/f…
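For context, the quoted utilization follows directly from the throughput number. A back-of-the-envelope check, assuming the A100's ~312 TFLOPS dense BF16/FP16 tensor-core peak (an assumption on my part, not stated in the thread):

```python
# Rough model-FLOPs-utilization (MFU) check; peak_tflops is an assumed A100 spec.
achieved_tflops = 189   # per-GPU throughput quoted above
peak_tflops = 312       # A100 dense BF16/FP16 tensor-core peak (assumption)
print(f"MFU ~ {achieved_tflops / peak_tflops:.1%}")  # -> MFU ~ 60.6%
```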
The main ingredient is FlashAttention, which computes attention 2-4x faster and with 10x less memory, without any approximation. This means that we don't need to do any activation checkpointing. 2/6
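For readers who want the baseline spelled out, here is a minimal single-head reference of the exact attention that FlashAttention reproduces (same output, no approximation). This is an illustrative sketch, not the repo's code; the two seq_len x seq_len tensors it materializes are what normally force activation checkpointing:

```python
import math
import torch

def standard_attention(q, k, v):
    """Exact single-head attention, q/k/v of shape (seq_len, head_dim).
    It materializes two (seq_len, seq_len) tensors; FlashAttention computes
    the same output without ever storing them."""
    scores = (q @ k.T) / math.sqrt(q.shape[-1])  # (seq_len, seq_len)
    probs = torch.softmax(scores, dim=-1)        # (seq_len, seq_len)
    return probs @ v                             # (seq_len, head_dim)
```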
We also provide optimized implementations of other layers:
- Fused matmul + bias + gelu for the MLP (based on Apex and cuBLASLt)
- Optimized cross entropy loss (based on Apex)
- Fused rotary embedding
- Fused dropout + residual + LayerNorm (building on Apex's FastLayerNorm)
3/6
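For reference, the unfused PyTorch equivalents of two of the patterns above look roughly like this; the repo replaces each chain of ops with a single fused kernel, so treat this as a sketch of what gets fused, not the fused code itself:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class UnfusedMLP(nn.Module):
    """matmul + bias + GELU as separate kernel launches; the fused version
    (built on Apex / cuBLASLt) collapses them into one pass."""
    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, dim)

    def forward(self, x):
        return self.fc2(F.gelu(self.fc1(x)))

def unfused_dropout_add_layernorm(x, residual, ln: nn.LayerNorm, p: float, training: bool):
    """dropout + residual add + LayerNorm, again launched as separate ops."""
    return ln(residual + F.dropout(x, p=p, training=training))
```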
We include training scripts to train GPT2 on OpenWebText and GPT3 on The Pile as examples.
On 1 node, our implementation is 3-5x faster than Huggingface and 1.3-2x faster than Megatron-LM! 4/6
Of course, Megatron-LM is great for large-scale distributed training, and Huggingface has such a friendly API to play with all the latest models. We're planning to integrate our work into these libraries.
5/6
I'll be at #NeurIPS2022 presenting FlashAttention (Poster Session 4, Hall J #917, Wednesday 4-6 PM). Drop by and say hi! 6/6
Announcing FlashAttention, a fast and memory-efficient attention algorithm with no approximation! 📣 w/ @realDanFu
By reducing GPU memory reads/writes, FlashAttention runs 2-4x faster & requires 5-20x less memory than PyTorch standard attention, & scales to seq. length 64K. 1/
Transformers have grown larger and deeper, but longer context remains difficult, since self-attention has time and memory complexity quadratic in sequence length. Approximate attention methods try to address this by trading off quality for lower compute complexity, but often don't achieve wall-clock speedup. 2/
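To make the quadratic memory concrete, here's the size of a single attention score matrix at the 64K sequence length mentioned above (illustrative arithmetic: fp16, one head, batch size 1):

```python
seq_len = 65_536                 # the 64K context mentioned above
bytes_per_elem = 2               # fp16
score_matrix_bytes = seq_len * seq_len * bytes_per_elem
print(f"{score_matrix_bytes / 2**30:.0f} GiB")  # 8 GiB for one N x N score matrix
```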
We argue that a missing principle is making attention algorithms *IO-aware:* accounting for reads and writes between levels of GPU memory, such as the large but (relatively) slow high-bandwidth memory (HBM) vs. the small but fast SRAM. 3/
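Here is a minimal pure-PyTorch sketch of the tiling idea (illustration only: a real FlashAttention kernel fuses this loop into a single CUDA kernel and keeps each block in SRAM, whereas here every tensor still lives in HBM; the single-head layout and block_size are my simplifications):

```python
import math
import torch

def tiled_attention(q, k, v, block_size=128):
    """Exact attention via online softmax over K/V blocks, never forming the
    full (seq_len, seq_len) score matrix. q/k/v: (seq_len, head_dim)."""
    n, d = q.shape
    scale = 1.0 / math.sqrt(d)
    out = torch.zeros_like(q)
    row_max = torch.full((n, 1), float("-inf"))
    row_sum = torch.zeros(n, 1)
    for start in range(0, n, block_size):
        kb = k[start:start + block_size]              # "load" one K block
        vb = v[start:start + block_size]              # "load" one V block
        s = (q @ kb.T) * scale                        # scores for this block only
        new_max = torch.maximum(row_max, s.max(dim=-1, keepdim=True).values)
        p = torch.exp(s - new_max)
        correction = torch.exp(row_max - new_max)     # rescale earlier partial results
        row_sum = row_sum * correction + p.sum(dim=-1, keepdim=True)
        out = out * correction + p @ vb
        row_max = new_max
    return out / row_sum

# Quick check against the naive computation:
q, k, v = (torch.randn(1024, 64) for _ in range(3))
ref = torch.softmax((q @ k.T) / math.sqrt(64), dim=-1) @ v
assert torch.allclose(tiled_attention(q, k, v), ref, atol=1e-4)
```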