Tri Dao
Nov 29 • 6 tweets • 2 min read
We're releasing an optimized implementation of GPT2/GPT3 with FlashAttention🚀!
This trains 3-5x faster than the Huggingface version, reaching up to 189 TFLOPs/sec per A100, i.e. 60.6% model FLOPs utilization (MFU) of the theoretical maximum. 1/6
github.com/HazyResearch/f…
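For context, MFU here is achieved model FLOPs divided by the hardware peak. A quick check of the quoted number, assuming the usual A100 reference peak of 312 TFLOPs/sec for dense FP16/BF16 tensor-core math:

```python
achieved_tflops = 189      # measured model TFLOPs/sec per A100 (from the thread)
a100_peak_tflops = 312     # A100 peak dense FP16/BF16 tensor-core throughput
print(f"{achieved_tflops / a100_peak_tflops:.1%}")  # -> 60.6%
```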
The main ingredient is FlashAttention, which computes attention fast (2-4x speedup) and with 10x less memory, without any approximation. This means we don't need any activation checkpointing.
2/6
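A minimal usage sketch of calling FlashAttention directly from PyTorch. The flash_attn_func name and argument layout below follow recent releases of the flash-attn package; exact module paths and signatures may differ between versions, so treat this as illustrative rather than the definitive API:

```python
import torch
from flash_attn import flash_attn_func  # interface in recent flash-attn releases

batch, seqlen, nheads, headdim = 2, 2048, 16, 64
q, k, v = (
    torch.randn(batch, seqlen, nheads, headdim, dtype=torch.float16, device="cuda")
    for _ in range(3)
)

# Exact (non-approximate) attention; the (seqlen x seqlen) score matrix is
# never materialized in GPU HBM, which is what removes the need for
# activation checkpointing around the attention layer.
out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)
print(out.shape)  # (batch, seqlen, nheads, headdim)
```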
We also provide optimized implementations of other layers (unfused PyTorch references of the math are sketched below):
- Fused matmul + bias + gelu for the MLP (based on Apex and cuBLASLt)
- Optimized cross entropy loss (based on Apex)
- Fused rotary embedding
- Fused dropout + residual + LayerNorm (building on Apex's FastLayerNorm)
3/6
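For reference, here is roughly the math two of these fused kernels compute, written as plain unfused PyTorch. The function names are purely illustrative (not modules from the repo); the point of the fused versions is to do each line in a single kernel launch instead of several:

```python
import torch
import torch.nn.functional as F

def mlp_fused_reference(x, weight, bias):
    # What "fused matmul + bias + gelu" computes, as separate PyTorch ops.
    return F.gelu(x @ weight.t() + bias)

def dropout_add_layernorm_reference(x, residual, layer_norm, p=0.1, training=True):
    # What "fused dropout + residual + LayerNorm" computes, as separate ops.
    return layer_norm(residual + F.dropout(x, p=p, training=training))
```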
We include training scripts to train GPT2 on OpenWebText and GPT3 on The Pile as examples.
On one node, our implementation is 3-5x faster than Huggingface and 1.3-2x faster than Megatron-LM!
4/6
Of course, Megatron-LM is great for large-scale distributed training, and Huggingface has such a friendly API to play with all the latest models. We're planning to integrate our work into these libraries.
5/6
I'll be at #NeurIPS2022 presenting FlashAttention (poster Session 4 Hall J #917, Wednesday 4-6 PM). Drop by and say hi!
6/6

More from @tri_dao

May 31
Announcing FlashAttention, a fast and memory-efficient attention algorithm with no approximation! 📣 w/ @realDanFu

By reducing GPU memory reads/writes, FlashAttention runs 2-4x faster & requires 5-20x less memory than PyTorch standard attention, & scales to seq. length 64K. 1/
Transformers have grown larger and deeper, but longer context remains difficult, since self-attention has time and memory quadratic in sequence length. Approximate attention methods attempt to address this by trading off quality for compute, but often don't achieve wall-clock speedup. 2/
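A back-of-the-envelope number for why the quadratic memory matters: at sequence length 64K, a single head's attention score matrix alone is already enormous:

```python
seqlen = 64 * 1024                 # 64K tokens
bytes_per_elem = 2                 # fp16/bf16
score_matrix_bytes = seqlen * seqlen * bytes_per_elem
print(score_matrix_bytes / 2**30)  # -> 8.0 GiB, per head, per batch element
```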
We argue that a missing principle is making attention algorithms *IO-aware:* accounting for reads and writes between levels of GPU memory like large but (relatively) slow high-bandwidth memory (HBM), vs small but fast SRAM. 3/
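To make the idea concrete, here is a simplified single-head sketch of the tiling + online-softmax trick in plain PyTorch: keys and values are processed block by block, and the softmax normalizer is rescaled on the fly, so the full (N x N) score matrix is never stored. This is only an illustration of the math (tiled_attention is a hypothetical helper, not the repo's kernel); the real FlashAttention does the same blocking inside one CUDA kernel so each block stays in SRAM, which is where the IO savings come from:

```python
import torch

def tiled_attention(q, k, v, block_size=128):
    """Single-head softmax(q @ k^T / sqrt(d)) @ v, computed block by block
    with an online softmax. Illustrative only."""
    n, d = q.shape
    scale = d ** -0.5
    out = torch.zeros_like(q)
    row_max = torch.full((n, 1), float("-inf"))
    row_sum = torch.zeros(n, 1)
    for start in range(0, n, block_size):
        kb = k[start:start + block_size]                 # block of keys
        vb = v[start:start + block_size]                 # block of values
        s = (q @ kb.T) * scale                           # scores for this block only
        new_max = torch.maximum(row_max, s.max(dim=-1, keepdim=True).values)
        p = torch.exp(s - new_max)                       # unnormalized block probabilities
        correction = torch.exp(row_max - new_max)        # rescale earlier accumulators
        row_sum = row_sum * correction + p.sum(dim=-1, keepdim=True)
        out = out * correction + p @ vb
        row_max = new_max
    return out / row_sum

# Sanity check against standard (materialize-everything) attention.
q, k, v = (torch.randn(512, 64) for _ in range(3))
ref = torch.softmax((q @ k.T) / 64 ** 0.5, dim=-1) @ v
assert torch.allclose(tiled_attention(q, k, v), ref, atol=1e-4)
```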