I'm reading about GPU shortages, and it feels like the right time to highlight tips and tricks for efficient training on one GPU (memory/speed optimizations) to squeeze the most out of it. 🧵
(this thread is a tl;dr based on the @huggingface Transformers docs huggingface.co/docs/transform…):
Pick the right batch size. For optimal resource utilization it should be of size 2^N, where N depends on dtype and hardware: for float16 on a regular GPU, NVIDIA recommends multiples of 8; moving to an A100? Multiples of 64. Find the optimal size via hyperparameter tuning.
Weights aren't the only thing stored in memory during training; there's a whole bunch of other stuff - optimizer states, gradients, forward activations saved for gradient computation, etc. Does it all fit in your GPU's memory at your perfect batch size? If not, here are some tricks:
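For illustration, here's roughly how the batch size is set with 🤗 Trainer (values are placeholders, not a recommendation):

```python
from transformers import TrainingArguments

# a multiple of 8 for fp16 tensor cores; on an A100, try multiples of 64
args = TrainingArguments(
    output_dir="out",
    per_device_train_batch_size=64,
)
```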
#1 Add gradient accumulation
TLDR: compute gradients in small increments instead of for the full batch at once, accumulate them, then do the optimization step.
This increases the effective batch size beyond what fits in memory, but can slow down training because of the additional fwd and bwd passes.
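A minimal sketch with 🤗 Trainer (numbers are illustrative):

```python
from transformers import TrainingArguments

# only 4 samples sit on the GPU at once, but gradients are accumulated over
# 16 micro-batches => effective batch size of 64
args = TrainingArguments(
    output_dir="out",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=16,
)
```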
#2 Add gradient checkpointing
TLDR: Instead of saving all activations from the forward pass, save only some strategically selected ones => less memory used.
Downside: slows down training by approximately 20%.
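With 🤗 Trainer it's a single flag (sketch):

```python
from transformers import TrainingArguments

# skipped activations are recomputed during the backward pass instead of being stored
args = TrainingArguments(
    output_dir="out",
    gradient_checkpointing=True,
)
```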
#3 Mixed precision training
TLDR: Do all variables need to be in fp32? Nah. Store some in fp16 (or less!) for faster computations.
On Ampere and newer hardware you also get bf16 (worse precision than fp16, but larger dynamic range) and tf32 (same range as fp32, lower precision).
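Roughly, in 🤗 Trainer terms (pick the flag your hardware supports):

```python
from transformers import TrainingArguments

# fp16 mixed precision on most GPUs; on Ampere+ you could use bf16=True and/or tf32=True instead
args = TrainingArguments(
    output_dir="out",
    fp16=True,
)
```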
#4 Pick a different optimizer
AdamW is great but stores rolling averages of the previous gradients (and their squares) => needs more memory.
Try adamw_apex_fused if you have NVIDIA/apex, or Adafactor, or 8-bit Adam (think of it as mixed-precision Adam)
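Switching the optimizer in 🤗 Trainer is one argument (8-bit Adam assumes the bitsandbytes package is installed):

```python
from transformers import TrainingArguments

# alternatives: optim="adamw_apex_fused" (needs NVIDIA apex) or optim="adafactor"
args = TrainingArguments(
    output_dir="out",
    optim="adamw_bnb_8bit",
)
```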
#5 Preload data
Does your main process read data fast enough? Preload data into pinned memory and spawn several workers to prefetch it faster. (DataLoader and 🤗 Trainer have corresponding arguments.)
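The corresponding 🤗 Trainer arguments look roughly like this (a plain PyTorch DataLoader takes equivalent pin_memory/num_workers arguments):

```python
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",
    dataloader_pin_memory=True,   # page-locked memory => faster host-to-GPU transfers
    dataloader_num_workers=4,     # subprocesses that prefetch batches in parallel
)
```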
#6 Leverage DeepSpeed ZeRO integration
DeepSpeed is an open-source DL optimization library that both 🤗 Transformers and 🤗 Accelerate integrate with.
It has a whole bunch of optimizations to improve the efficiency of training.
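A minimal sketch of wiring it into 🤗 Trainer; the config file name here is hypothetical, and the actual ZeRO settings live inside that JSON:

```python
from transformers import TrainingArguments

# point Trainer at a DeepSpeed config (e.g. one enabling ZeRO stage 2)
args = TrainingArguments(
    output_dir="out",
    deepspeed="ds_config_zero2.json",  # hypothetical path to your DeepSpeed config
)
```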
#7 Use torch.compile
PyTorch 2.0 introduced the torch.compile feature, which significantly speeds up training on server-class GPUs such as the A100 (not as much on a desktop-class GPU such as an NVIDIA 3090).
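With a recent 🤗 Transformers version this is again a single flag (or you can call torch.compile(model) yourself):

```python
from transformers import TrainingArguments

# compiles the model's forward/backward into optimized kernels via torch.compile
args = TrainingArguments(
    output_dir="out",
    torch_compile=True,
)
```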
All of the above tricks can be used on their own or combined, either via Trainer in 🤗 Transformers or, if you prefer writing a pure PyTorch loop, with 🤗 Accelerate.
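For the pure-PyTorch route, a minimal 🤗 Accelerate sketch combining a couple of the tricks above (the toy model/data are placeholders; fp16 assumes a CUDA GPU is available):

```python
import torch
from accelerate import Accelerator

# toy model and data just to show the wiring
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
dataset = torch.utils.data.TensorDataset(torch.randn(256, 10), torch.randint(0, 2, (256,)))
dataloader = torch.utils.data.DataLoader(dataset, batch_size=8, pin_memory=True, num_workers=2)

accelerator = Accelerator(mixed_precision="fp16", gradient_accumulation_steps=4)
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

model.train()
for inputs, labels in dataloader:
    with accelerator.accumulate(model):   # handles the gradient accumulation boundaries
        loss = torch.nn.functional.cross_entropy(model(inputs), labels)
        accelerator.backward(loss)        # scales the loss appropriately for mixed precision
        optimizer.step()
        optimizer.zero_grad()
```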
If all of the above is not enough, switch to a beefier GPU if you can. If that's still not enough, you may need to move to a multi-GPU setup, where the same approaches still apply and you also get parallelism tricks on top.
To my JVM friends looking to explore Machine Learning techniques - you don’t necessarily have to learn Python to do that. There are libraries you can use from the comfort of your JVM environment. 🧵👇
github.com/eclipse/deeple… : Deep Learning framework in Java that supports the whole cycle: from data loading and preprocessing to building and tuning a variety of deep learning networks.
github.com/linkedin/dagli Framework for defining machine learning models, including feature generation and transformations, as directed acyclic graphs (DAGs).