I have been training neural networks for 10 years now.
Here are 16 techniques I actively use to optimize model training:
Before we dive in, the following visual covers what we are discussing today.
Let's understand them in detail below!
These are some basic techniques:
1) Use efficient optimizers—AdamW, Adam, etc.
2) Utilize hardware accelerators (GPUs/TPUs).
3) Max out the batch size.
4) Use multi-GPU training through Model/Data/Pipeline/Tensor parallelism. Check the visual👇
5) Bayesian optimization for hyperparameter optimization:
This technique picks the next hyperparameter config to try based on the results of the previous configs, instead of searching blindly.
This way, the search converges to a good set of hyperparameters much faster.
Check these results 👇
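To make point 5 concrete, here is a minimal sketch of the idea in plain Python. It is not the setup behind the results above. It assumes a single hyperparameter (log10 of the learning rate), a made-up `validation_loss` function standing in for a real training run, a tiny Gaussian-process surrogate, and a lower-confidence-bound rule for choosing the next config. All of those names and choices are illustrative assumptions.

```python
import math

def rbf(a, b, length=0.7):
    """Squared-exponential kernel: nearby configs get similar predictions."""
    return math.exp(-0.5 * ((a - b) / length) ** 2)

def solve(A, b):
    """Solve A x = b with Gaussian elimination and partial pivoting."""
    n = len(b)
    M = [row[:] + [b[i]] for i, row in enumerate(A)]
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(M[r][col]))
        M[col], M[pivot] = M[pivot], M[col]
        for r in range(col + 1, n):
            factor = M[r][col] / M[col][col]
            for c in range(col, n + 1):
                M[r][c] -= factor * M[col][c]
    x = [0.0] * n
    for i in reversed(range(n)):
        s = sum(M[i][j] * x[j] for j in range(i + 1, n))
        x[i] = (M[i][n] - s) / M[i][i]
    return x

def gp_posterior(xs, ys, x_new, noise=1e-4):
    """Posterior mean and std at x_new for a GP fitted to mean-centered ys."""
    K = [[rbf(a, b) + (noise if i == j else 0.0) for j, b in enumerate(xs)]
         for i, a in enumerate(xs)]
    k = [rbf(a, x_new) for a in xs]
    y_mean = sum(ys) / len(ys)
    alpha = solve(K, [y - y_mean for y in ys])
    v = solve(K, k)
    mu = y_mean + sum(ki * ai for ki, ai in zip(k, alpha))
    var = max(1.0 - sum(ki * vi for ki, vi in zip(k, v)), 1e-12)
    return mu, math.sqrt(var)

def validation_loss(log_lr):
    """Hypothetical stand-in for an expensive training run."""
    return (log_lr + 2.5) ** 2

candidates = [-5.0 + 0.1 * i for i in range(51)]  # log10(lr) from -5 to 0
tried = [0, 50]  # indices of configs evaluated so far (the two extremes)
losses = [validation_loss(candidates[i]) for i in tried]

for _ in range(8):
    xs = [candidates[i] for i in tried]
    best_idx, best_score = None, float("inf")
    for i, c in enumerate(candidates):
        if i in tried:
            continue
        mu, sigma = gp_posterior(xs, losses, c)
        score = mu - 2.0 * sigma  # low predicted loss or high uncertainty
        if score < best_score:
            best_idx, best_score = i, score
    tried.append(best_idx)
    losses.append(validation_loss(candidates[best_idx]))

best_loss = min(losses)
best_log_lr = candidates[tried[losses.index(best_loss)]]
```

Each new trial is chosen using everything learned from the earlier ones, which is the "informed steps" part. In practice you would likely reach for a library rather than hand-rolling the surrogate.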
6) Initialize parameters with He or Xavier initialization.
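For reference, here is a small sketch of what the two schemes in point 6 compute, written in plain Python so the formulas are visible. It assumes the normal-distribution variants of He and Xavier initialization; the helper names (`he_std`, `xavier_std`, `init_weights`) are made up for this example.

```python
import math
import random

def he_std(fan_in):
    """He (Kaiming) normal init: std = sqrt(2 / fan_in), commonly paired with ReLU."""
    return math.sqrt(2.0 / fan_in)

def xavier_std(fan_in, fan_out):
    """Xavier (Glorot) normal init: std = sqrt(2 / (fan_in + fan_out))."""
    return math.sqrt(2.0 / (fan_in + fan_out))

def init_weights(fan_in, fan_out, std, seed=0):
    """Return a fan_out x fan_in weight matrix drawn from N(0, std^2)."""
    rng = random.Random(seed)
    return [[rng.gauss(0.0, std) for _ in range(fan_in)] for _ in range(fan_out)]

w_he = init_weights(512, 256, he_std(512))
w_xavier = init_weights(512, 256, xavier_std(512, 256))
```

In PyTorch, the same idea is available through the functions in `torch.nn.init`, so you rarely need to write this by hand.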
7) For large models, use DeepSpeed, FSDP, YaFSDP, etc.
8) Mixed precision training: Use lower-precision float16 along with float32. This leads to faster computation and lower memory usage. Check the visual 👇
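One reason float32 stays in the picture: small weight updates can vanish entirely in float16. The sketch below uses only Python's `struct` module to round values to float16 and float32, so it runs without a GPU. It is an illustration of the numerics, not a training loop, and the helper names and the update size are assumptions chosen for the demo.

```python
import struct

def to_fp16(x):
    """Round a Python float to the nearest float16 value."""
    return struct.unpack("e", struct.pack("e", x))[0]

def to_fp32(x):
    """Round a Python float to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]

update = 1e-4  # a small per-step weight update
w_fp16 = to_fp16(1.0)
w_fp32 = to_fp32(1.0)

for _ in range(1000):
    w_fp16 = to_fp16(w_fp16 + update)  # weight stored in float16
    w_fp32 = to_fp32(w_fp32 + update)  # weight stored in float32

bytes_fp16 = struct.calcsize("e")  # storage per value in float16
bytes_fp32 = struct.calcsize("f")  # storage per value in float32
```

Near 1.0, adjacent float16 values are about 0.001 apart, so an update of 0.0001 rounds away every time and `w_fp16` never moves, while `w_fp32` ends up close to 1.1. That is why mixed precision setups typically do the heavy math in float16 but keep a float32 copy of the weights for the update step.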
9) Use DistributedDataParallel, not DataParallel.
10) Use torch.rand(2, 2, device=...) to create a tensor directly on the GPU. Calling .cuda() on a freshly created tensor first builds it on the CPU and then transfers it to the GPU, which is slow.
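A quick sketch of the two patterns side by side. To keep it runnable on a machine without a GPU, it assumes a fallback to the CPU and uses `.to(device)` in place of `.cuda()`; on a CUDA machine the two are equivalent for this purpose.

```python
import torch

# Fall back to the CPU so the snippet also runs without a GPU (an assumption
# made for this example only).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Slower pattern: allocate on the CPU, then copy to the target device.
x_copied = torch.rand(2, 2).to(device)

# Preferred pattern: allocate directly on the target device.
x_direct = torch.rand(2, 2, device=device)
```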
11) Use activation checkpointing when memory is constrained👇
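Here is a minimal sketch of activation checkpointing with `torch.utils.checkpoint`. The tiny `block` and the tensor shapes are placeholders picked for illustration; the point is only that the wrapped block does not keep its intermediate activations and recomputes them during the backward pass.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)

# A small stand-in for an expensive sub-network (hypothetical sizes).
block = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 16))
x = torch.randn(4, 16, requires_grad=True)

# Activations inside `block` are not stored during the forward pass;
# they are recomputed when backward() needs them.
y = checkpoint(block, x, use_reentrant=False)
y.sum().backward()
```

The trade is extra compute for lower peak memory, so it tends to pay off on deep models where activations dominate memory use.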
12) Use gradient accumulation.
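A small sketch of why gradient accumulation works, in plain Python with a one-parameter linear model (a toy setup assumed for this example). Gradients from several micro-batches are summed, each scaled by 1 / accum_steps, and the result matches the gradient of one large batch.

```python
def grad_mse(w, xs, ys):
    """Gradient of the mean squared error of y ~ w * x with respect to w."""
    n = len(xs)
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / n

xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.0 * x for x in xs]
w = 0.5

# Gradient computed on the full batch of 8 samples at once.
full_grad = grad_mse(w, xs, ys)

# Same gradient built from 4 micro-batches of 2 samples each.
micro_batch = 2
accum_steps = len(xs) // micro_batch
accumulated = 0.0
for i in range(accum_steps):
    xb = xs[i * micro_batch:(i + 1) * micro_batch]
    yb = ys[i * micro_batch:(i + 1) * micro_batch]
    accumulated += grad_mse(w, xb, yb) / accum_steps

# Only now would the optimizer take a single step using `accumulated`.
```

So you get the effective batch size of 8 while only ever holding 2 samples' worth of activations in memory.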
13) Normalize data after transferring to GPU (for integer data, like pixels):
- Normalizing before the transfer sends 32-bit floats to the GPU.
- Normalizing after the transfer sends 8-bit integers to the GPU.
- The latter is better, since it moves 4x less data.
Check this👇
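The arithmetic behind point 13, as a quick sketch. The batch shape below is an assumption picked for illustration (32 RGB images of 224x224 pixels); the item sizes come from Python's `array` module.

```python
from array import array

# Hypothetical batch of images: 32 RGB images of 224 x 224 pixels.
batch, channels, height, width = 32, 3, 224, 224
num_values = batch * channels * height * width

bytes_per_uint8 = array("B").itemsize    # raw pixel values in 0..255
bytes_per_float32 = array("f").itemsize  # values after normalizing on the CPU

transfer_if_normalized_after = num_values * bytes_per_uint8
transfer_if_normalized_before = num_values * bytes_per_float32
ratio = transfer_if_normalized_before / transfer_if_normalized_after
```

Sending the raw 8-bit pixels and dividing by 255 on the GPU moves a quarter of the bytes across the bus for the same final tensor.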
14) Use momentum
In plain gradient descent, every parameter update depends solely on the current gradient. This can lead to unwanted oscillations during optimization.
Momentum reduces this by adding a weighted average of previous gradient updates to the update rule.
Check this 👇
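Here is the update rule from point 14 as a minimal sketch in plain Python. It assumes a toy two-parameter quadratic loss and the common "velocity" form of momentum (v = beta * v + grad, then w = w - lr * v); the function names, learning rate, and beta are illustrative choices. It shows the mechanics only and does not try to reproduce the visual.

```python
def loss(x, y):
    """Toy loss that is much steeper along y than along x."""
    return 0.5 * (x ** 2 + 25.0 * y ** 2)

def grad(x, y):
    return x, 25.0 * y

def descend(beta, lr=0.03, steps=100):
    """Gradient descent with momentum; beta=0.0 gives plain gradient descent."""
    x, y = 10.0, 1.0
    vx, vy = 0.0, 0.0
    for _ in range(steps):
        gx, gy = grad(x, y)
        # Velocity is a decaying sum of past gradients.
        vx = beta * vx + gx
        vy = beta * vy + gy
        x -= lr * vx
        y -= lr * vy
    return loss(x, y)

initial_loss = loss(10.0, 1.0)
plain_loss = descend(beta=0.0)
momentum_loss = descend(beta=0.9)
```

With beta set to 0.0 the velocity is just the current gradient, which is exactly the plain update described above.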
15-16) Set num_workers and pin_memory in DataLoader.
PyTorch's DataLoader has two terrible default settings: num_workers=0 (all data loading happens in the main process) and pin_memory=False. Update them according to your config.
Speedup is shown in the image below 👇
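A minimal sketch of overriding both defaults. The dataset is random placeholder data, and `num_workers=2` is only an example value; a sensible number depends on your CPU cores and how heavy your preprocessing is. Pinned memory only helps when copying to a GPU, so the sketch enables it only if CUDA is available.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Placeholder dataset: 64 samples with 8 features and a binary label.
dataset = TensorDataset(torch.randn(64, 8), torch.randint(0, 2, (64,)))

loader = DataLoader(
    dataset,
    batch_size=16,
    shuffle=True,
    num_workers=2,                         # default is 0 (load in main process)
    pin_memory=torch.cuda.is_available(),  # default is False
)
```

On platforms that start worker processes with spawn (Windows, macOS), iterate over the loader inside an `if __name__ == "__main__":` guard.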
Those were 16 techniques that I actively use to optimize neural network training.
If I missed something, please drop that in the replies.
Here's the visual again for your reference 👇
That's a wrap!
If you found it insightful, reshare with your network.
Find me → @akshay_pachaar ✔️
For more insights and tutorials on LLMs, AI Agents, and Machine Learning!