NEW PAPER đź§µ: Deep neural networks are complicated, but looking inside them ought to be simple. In this paper we introduce TorchLens, a package for extracting all activations and metadata from any #PyTorch model, and visualizing its structure, in as little as one line of code. 1/
Paper: nature.com/articles/s4159…
GitHub repo: github.com/johnmarktaylor…
Colab Tutorial: colab.research.google.com/drive/1ORJLGZP…
"Model Menagerie" of example visualizations: drive.google.com/drive/u/0/fold…
2/
The core function is log_forward_pass. Just pass in any PyTorch model (as-is, no changes needed) along with an input, and this one line of code returns a data structure with the activations and metadata of every layer, plus an automatic visualization of the model’s computational graph. 3/
What’s new here? Until now, there have been arbitrary roadblocks to fetching hidden activations from PyTorch models: for example, some methods can get activations from “modules” but not from all tensor operations, or fail for models with if-then branching ("dynamic models"). 4/
TorchLens works by (temporarily!) mutating all elementary PyTorch functions that return a tensor such that every function call is logged. Since all PyTorch models are built from these "building block" functions, the approach it uses should work for any PyTorch model whatsoever 5/
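The wrapping trick itself can be illustrated with a toy pure-Python stand-in (this is a sketch of the general technique, not TorchLens's actual implementation, and the math functions here just play the role of PyTorch's tensor ops):

```python
import math

call_log = []

def make_logged(fn):
    """Wrap a function so every call is recorded before executing normally."""
    def wrapper(*args, **kwargs):
        result = fn(*args, **kwargs)
        call_log.append((fn.__name__, args, result))
        return result
    return wrapper

# Temporarily swap the "building block" functions for logging versions...
orig_exp, orig_tanh = math.exp, math.tanh
math.exp, math.tanh = make_logged(math.exp), make_logged(math.tanh)

def tiny_model(x):  # any code built from the patched functions gets traced
    return math.tanh(math.exp(x))

y = tiny_model(0.5)

# ...then restore the originals, so nothing is permanently mutated.
math.exp, math.tanh = orig_exp, orig_tanh

print([name for name, _, _ in call_log])  # ['exp', 'tanh']
```

Because every call funnels through the wrappers, the log captures the full sequence of operations regardless of how the model's forward code is written.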
So far, it's been tested on hundreds of models, spanning different inputs (images/videos/audio/language/multimodal) & architectures (feedforward, recurrent, transformers, graph neural networks). You can see example visuals in this “model menagerie”:
6/ drive.google.com/drive/u/0/fold…
The visuals are meant to make the salient aspects of the network “pop out”: inputs are green, outputs are red, layers with trainable parameters are in gray, and layers with internally generated tensors (e.g., that generate random noise) are shown with dashed lines. 7/
TorchLens will automatically find and mark the loops in recurrent models for you, and you can visualize them in either rolled or unrolled format. Check out the CORnet models in the model menagerie for some real-life examples:
8/ drive.google.com/drive/u/0/fold…
For models with a hierarchical “Russian nesting dolls” structure that have modules within modules, you can specify how deep you want to visualize. For example, here's an "inception" module from GoogLeNet shown with different levels of nesting. 9/
If the model's operations change based on the input (a stumbling block for some approaches to PyTorch feature extraction), TorchLens simply traces what actually happens in the forward pass for that input, and marks the operations involved in evaluating the "if" statement. 10/
TorchLens even works for massive models. For instance, here's the visual for the swin_v2b transformer model, which has nearly 2000 tensor operations, all of which are logged and visualized. 11/
…and here’s the visual for the Keypoint R-CNN detection model with nearly 8000 tensor operations (this one took a while to render!)
12/ drive.google.com/file/d/1bBzDYf…
In addition to the activations and visuals, TorchLens gives you exhaustive metadata both about the overall model, and about each layer. You can see a mostly-complete list of available metadata here:
13/ static-content.springer.com/esm/art%3A10.1…
Metadata includes info about the tensors & trainable parameters for each layer, model graph info, info about the function executed in each layer (e.g., the runtime—useful for profiling), & much more. I included everything I could think of, but let me know if I missed anything 14/
Other features include being able to inspect the actual code used to execute each layer, and being able to save the gradients from a backward pass—see the Colab for how to do this. 15/
One principle of TorchLens is "no silent errors". Since there are infinitely many possible DNNs (and potential edge cases we didn't predict), TorchLens has a built-in validation procedure to algorithmically verify the accuracy of saved activations—useful for new architectures 16/
Since its initial release, TorchLens has greatly benefitted from user feedback. For example, it now has a function to quickly save activations to new model inputs, and to apply a user-provided function (e.g., averaging conv layers over space) to activations before saving them 17/
Many thanks to members of the @KriegeskorteLab, @CogCompNeuro and @VSSMtg attendees, @alfcnz, @grez72, @_jacobprince_, and Colin Conwell (not on Twitter) for their invaluable suggestions. Since TorchLens remains in active development, we welcome your further feedback.
18/
TorchLens should be ready to go—just type the following into your terminal to install it (graphviz is required for making the visuals), and you should be set:
sudo apt install graphviz
pip install torchlens
19/
For an interactive tutorial of TorchLens, check out this Google Colab walkthrough:
20/ colab.research.google.com/drive/1ORJLGZP…
We hope that TorchLens will complement the existing outstanding packages (e.g. ThingsVision, Net2Brain, DeepDive) for DNN feature extraction: these have excellent functionality for loading models and stimuli and analyzing the resulting activations, which TorchLens lacks. 21/
You should also check out our lab’s recently-updated RSA Toolbox, which provides a complete solution for performing RSA on DNN activations and neural/behavioral data.
22/ github.com/rsagroup/rsato…
We envision TorchLens being useful in various ways:
1. For neuroscience, to streamline comparing DNNs and brains,
2. For engineering, to aid in debugging, profiling, and visualizing models, and
3. For teaching, to help translate between concepts and code
23/
Understanding the brain and advancing #AI is hard enough as it is. We hope TorchLens removes all arbitrary barriers to seeing what's going on in your network, lighting up the black box so you can focus on the more interesting problems. Tell us how we can make it better!
24/