Want to make your inference code in PyTorch run faster? Here’s a quick thread on doing exactly that.
1. Replace torch.no_grad() with the ✨torch.inference_mode()✨ context manager.
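The swap itself is a one-liner; here's a minimal sketch (the Linear model and shapes are just placeholders for your own model):

```python
import torch

model = torch.nn.Linear(128, 10)   # placeholder model, just for illustration
x = torch.randn(32, 128)

# Before: Autograd tracking is off, but some bookkeeping still happens
with torch.no_grad():
    out = model(x)

# After: same call, wrapped in inference_mode() instead
with torch.inference_mode():
    out = model(x)
```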
2. ⏩ inference_mode() is torch.no_grad() on steroids
While NoGrad only excludes operations from being tracked by Autograd, InferenceMode goes two steps further, potentially speeding up your code (YMMV depending on model complexity and hardware)
3. ⏩ InferenceMode reduces overhead by disabling two Autograd mechanisms - version counting and metadata tracking - on all tensors created inside it ("inference tensors").
Because those mechanisms are disabled, inference tensors come with a few restrictions on how they can be used 👇
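Quick way to tell whether a tensor is an inference tensor, before we get to the restrictions (Tensor.is_inference() is available in recent PyTorch releases):

```python
import torch

with torch.inference_mode():
    t = torch.ones(3)                 # created inside -> an "inference tensor"

print(t.is_inference())               # True
print(torch.ones(3).is_inference())   # False: created outside inference_mode
```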
4. ⚠️ Inference tensors can't be used in Autograd-recorded operations outside InferenceMode.
⚠️ Inference tensors can't be modified in-place outside InferenceMode.
✅ Simply clone the inference tensor and you're good to go.
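Roughly what that looks like in practice; a small sketch of the in-place failure and the clone workaround:

```python
import torch

with torch.inference_mode():
    t = torch.ones(3)                 # inference tensor

try:
    t.add_(1)                         # in-place update outside InferenceMode
except RuntimeError as err:
    print(err)                        # not allowed on an inference tensor

safe = t.clone()                      # clone gives you back a normal tensor
safe.add_(1)                          # now fine
```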
5. ❗Note that the highest speedups come from lightweight operations that are bottlenecked by the tracking overhead.
❗If the ops are fairly heavy, disabling tracking with InferenceMode doesn't buy you much; e.g. wrapping a ResNet101 forward pass in InferenceMode won't show a big speedup.
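If you want to see this on your own machine, here's a crude micro-benchmark sketch (numbers will vary a lot by hardware and PyTorch version; swap the tiny op for a heavy model's forward pass and watch the gap shrink):

```python
import time
import torch

def bench(ctx, op, iters=10_000):
    # crude wall-clock timing on CPU; a rough sketch, not a rigorous benchmark
    with ctx():
        start = time.perf_counter()
        for _ in range(iters):
            op()
        return time.perf_counter() - start

x = torch.randn(8, 8)
tiny_op = lambda: x + x               # tiny, overhead-dominated op

print("no_grad       :", bench(torch.no_grad, tiny_op))
print("inference_mode:", bench(torch.inference_mode, tiny_op))
```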