Memory-Efficient Coding in #PyTorch ⚡
21 tricks to optimize your PyTorch code

Let's break some bad habits in how we write PyTorch code 👇

A thread 🧵
1. The PyTorch DataLoader supports asynchronous data loading in separate worker subprocesses (set num_workers > 0). Set pin_memory=True to make the DataLoader use pinned (page-locked) host memory, which enables faster and asynchronous copies from host to GPU
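A minimal sketch with a toy TensorDataset (the dataset and sizes are placeholders):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(64, 10), torch.randint(0, 2, (64,)))

loader = DataLoader(
    dataset,
    batch_size=16,
    num_workers=2,    # load batches in background worker processes
    pin_memory=True,  # page-locked host memory -> faster async host-to-GPU copies
)

device = "cuda" if torch.cuda.is_available() else "cpu"
for xb, yb in loader:
    # non_blocking=True lets the copy overlap with compute when memory is pinned
    xb = xb.to(device, non_blocking=True)
    yb = yb.to(device, non_blocking=True)
```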
2. Disable gradient calculation for validation and inference. Gradients aren't needed there, so run those forward passes inside the torch.no_grad() context manager to save memory and compute.
3. You can even use torch.no_grad() as a function decorator if you run forward passes inside a function repeatedly
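Both forms in one sketch (the model and inputs here are placeholders):

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 2)  # placeholder model
x = torch.randn(4, 10)

# As a context manager: no autograd graph is built inside the block
with torch.no_grad():
    out = model(x)

# As a decorator: every call runs without gradient tracking
@torch.no_grad()
def predict(m, inputs):
    return m(inputs)

out2 = predict(model, x)
```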
4. Instead of zeroing out the gradients, set them to None 👇
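Since PyTorch 1.7, zero_grad() accepts set_to_none=True (and it became the default in 2.0). Releasing the gradient tensors skips a memset and frees their memory:

```python
import torch

model = torch.nn.Linear(10, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

loss = model(torch.randn(4, 10)).sum()
loss.backward()

# Instead of opt.zero_grad() (which fills grads with zeros),
# release the gradient tensors entirely:
opt.zero_grad(set_to_none=True)
```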
5. Disable bias for convolutions directly followed by batch norm

nn.Conv2d() defaults to bias=True, but if a conv layer is immediately followed by a BatchNorm layer, the bias is redundant: batch norm subtracts the mean, which cancels out any constant offset the bias adds.
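A small sketch of the pattern (layer sizes are arbitrary):

```python
import torch
import torch.nn as nn

# Bias is redundant here: BatchNorm2d's mean subtraction absorbs it,
# and its own affine beta plays the same role
block = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(16),
    nn.ReLU(inplace=True),
)

out = block(torch.randn(1, 3, 8, 8))
```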
6. Fuse pointwise operations

Pointwise operations (element-wise addition, multiplication, math functions) can be fused into a single kernel to reduce memory accesses and kernel launch overhead
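One common way to get this fusion is TorchScript; a sketch (the GELU-style function is just an illustration of a chain of pointwise ops):

```python
import torch

@torch.jit.script
def fused_gelu(x):
    # A chain of pointwise ops (mul, erf, add) that the JIT can fuse
    # into fewer kernels on GPU; 0.7071... is 1/sqrt(2)
    return 0.5 * x * (1.0 + torch.erf(x * 0.7071067811865476))

x = torch.randn(1024)
out = fused_gelu(x)
```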
7. Avoid frequent CPU-to-GPU (and GPU-to-CPU) transfers; each one can force a synchronization that stalls the GPU
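A common example of this: calling .item() on a GPU tensor every step forces a device-to-host sync each time. A sketch (the per-step losses are stand-ins):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
losses = torch.randn(100, device=device).abs()  # stand-in for per-step losses

# Anti-pattern: .item() every step syncs GPU -> CPU each iteration
# total = 0.0
# for l in losses:
#     total += l.item()

# Better: accumulate on-device, transfer once at the end
total = losses.sum().item()
```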
8. Use .detach() to cut a tensor loose from its computation graph. The detached tensor shares the same data but carries no graph, so the graph behind it can be freed.
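A typical use is metric logging; a sketch with a placeholder model:

```python
import torch

model = torch.nn.Linear(10, 1)
running = 0.0
for _ in range(3):
    loss = model(torch.randn(4, 10)).pow(2).mean()
    loss.backward()
    # Accumulate a detached copy; accumulating `loss` itself would keep
    # each step's whole autograd graph alive
    running += loss.detach()
avg = running / 3
```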
9. Construct tensors directly on the target device (GPU)
Creating a tensor on the CPU first and then transferring it to the GPU pays for an extra allocation plus a host-to-device copy, which is slow. Pass device= at construction instead.
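The two cases side by side:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Slow: allocate on CPU, then copy to the device
a = torch.rand(1000, 1000).to(device)

# Fast: allocate directly on the device
b = torch.rand(1000, 1000, device=device)
```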
10. Use DistributedDataParallel (DDP) instead of DataParallel (DP) when training on multiple GPUs. DP runs in a single process and replicates the model across GPUs on every forward pass, so it is throttled by the Python GIL; DDP runs a siloed copy of the model in its own process per GPU and synchronizes only the gradients.
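A minimal single-process sketch using the CPU "gloo" backend just to show the wrapping; real multi-GPU training would launch one process per GPU (e.g. with torchrun) and use the "nccl" backend:

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Single-process setup purely for illustration
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

model = torch.nn.Linear(8, 2)
ddp_model = DDP(model)  # gradients are all-reduced across processes
out = ddp_model(torch.randn(4, 8))

dist.destroy_process_group()
```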
11. Use Automatic Mixed Precision (AMP). AMP is natively available in PyTorch since the 1.6 release. It can speed up training substantially (up to ~3x) on GPUs with Tensor Cores, i.e. the Volta (V100), Turing (2080 Ti), and Ampere architectures. It also reduces memory usage, so you can use a larger batch size.
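A rough sketch of the standard autocast + GradScaler loop (model and data are placeholders); on CPU both pieces are disabled so it degrades to a normal step:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(10, 2).to(device)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))

x = torch.randn(4, 10, device=device)
with torch.autocast(device_type=device, enabled=(device == "cuda")):
    loss = model(x).pow(2).mean()  # forward runs in float16 on GPU

scaler.scale(loss).backward()  # scale the loss to avoid fp16 underflow
scaler.step(opt)
scaler.update()
```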
12. Delete tensors, objects, lists, and dictionaries you no longer need, using Python's del keyword
13. Use gc.collect() to free host RAM held by unreachable objects stuck in reference cycles. Note that it only collects objects nothing refers to anymore, so it won't delete live variables; but calling it frequently inside the training loop adds overhead, so use it sparingly.
14. Use torch.cuda.empty_cache() to release PyTorch's cached GPU memory back to the driver at the end of a training or inference routine
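The three cleanup tips above combined into one sketch (the big tensor is a stand-in for whatever you're done with):

```python
import gc
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
big = torch.randn(1000, 1000, device=device)

del big                        # drop the last reference (tip 12)
gc.collect()                   # collect any lingering reference cycles (tip 13)
if torch.cuda.is_available():
    torch.cuda.empty_cache()   # return cached GPU blocks to the driver (tip 14)
```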
15. Use the largest batch size you can fit in memory; your GPU utilization should be near 100% most of the time to reduce bottlenecks. However, a much larger batch size can hurt model quality (in terms of metrics), so see what works best for you
16. Optimize data loading and preprocessing as much as possible; your GPU should be kept working at all times, never sitting idle waiting for data
17. Use OpenCV to read images; it is typically faster than torchvision's default PIL backend
18. Turn on cuDNN benchmarking; this helps if your model makes heavy use of convolutional layers and the input shapes stay fixed
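It's a one-line flag:

```python
import torch

# Let cuDNN time several convolution algorithms on the first run and cache
# the fastest; it pays off when input shapes stay constant across iterations
torch.backends.cudnn.benchmark = True
```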
19. Profile your code with the PyTorch profiler, which tells you where most of the time is spent, so you can optimize accordingly
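A minimal CPU-only sketch (add ProfilerActivity.CUDA when profiling GPU runs; the model is a placeholder):

```python
import torch
from torch.profiler import profile, ProfilerActivity

model = torch.nn.Linear(512, 512)
x = torch.randn(64, 512)

with profile(activities=[ProfilerActivity.CPU]) as prof:
    model(x)

# Summary table of the most expensive ops
report = prof.key_averages().table(sort_by="cpu_time_total", row_limit=5)
print(report)
```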
20. If you are processing huge CSV files, pandas can be slow because it runs on the CPU. Use @RAPIDSai cuDF instead to preprocess data on the GPU, which is a lot faster.
21. If you want all of the above optimizations without the overhead of wiring them up yourself, use @PyTorchLightnin (PyTorch Lightning); it does it all for you :)
Hope this thread helped you. With these optimizations you can cut training time and avoid bottlenecks, which in turn raises your rate of experimentation. Happy PyTorching!

Thread by Atharva Ingle (@AtharvaIngle7)
