Built by @paganpasta, it's now at feature-parity with torchvision for classification models. (With segmentation, object detection etc. on the way!)
1/2
It's been super cool watching this project grow, and a big thank-you to @paganpasta for all their upstream contributions to Equinox. (Read: fixing my bugs!🐛🪲)
2/2
I guess I'm taking a leaf out of @DynamicWebPaige's book and tweeting about #JAX ecosystem stuff... :D
If you follow me then there's a decent chance that you already know what an NDE is. (If you don't, go read the introductory Chapter 1 to my thesis haha -- it's only 6 pages long.) Put a neural network inside a differential equation, and suddenly cool stuff starts happening.
2/n
Neural differential equations are a beautiful way of building models, offering:
- high-capacity function approximation;
- strong priors on model space;
- the ability to handle irregular data;
- memory efficiency;
- a foundation of well-understood theory.
3/n
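The "neural network inside a differential equation" idea can be sketched in a few lines of plain JAX. This is a toy illustration, not Equinox/Diffrax code: the two-layer vector field and the fixed-step Euler solve are my own minimal stand-ins (a real NDE would use a proper adaptive solver, e.g. from Diffrax).

```python
import jax
import jax.numpy as jnp

def vector_field(params, y):
    # A tiny two-layer MLP f_theta(y): the "neural" part of the neural ODE.
    h = jnp.tanh(params["w1"] @ y + params["b1"])
    return params["w2"] @ h + params["b2"]

def neural_ode(params, y0, t0=0.0, t1=1.0, steps=100):
    # Fixed-step Euler solve of dy/dt = f_theta(y) from t0 to t1.
    dt = (t1 - t0) / steps
    def step(y, _):
        return y + dt * vector_field(params, y), None
    y1, _ = jax.lax.scan(step, y0, None, length=steps)
    return y1

# Toy parameters for a 2-dimensional state and hidden width 4.
params = {
    "w1": 0.1 * jnp.ones((4, 2)), "b1": jnp.zeros(4),
    "w2": 0.1 * jnp.ones((2, 4)), "b2": jnp.zeros(2),
}
y1 = neural_ode(params, jnp.array([1.0, -1.0]))
```

Because the whole solve is just JAX operations, `jax.grad` differentiates straight through it -- which is what makes NDEs trainable models rather than just ODE solvers.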
Announcing Equinox v0.1.0! Lots of new goodies for your neural networks in JAX.
-The big one: models work natively with jax.jit and jax.grad!
-filter, partition, combine, to manipulate PyTrees
-new filter functions
-much-improved documentation
-PyPI availability!
A thread: 1/n 🧵
First: simple models can be used directly with jax.jit and jax.grad. This is because Equinox models are just PyTrees like any other. And JAX understands PyTrees.
2/n
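The "models are just PyTrees" point can be seen with a minimal sketch: here the model is a plain dict of parameters rather than an actual `eqx.Module` (which registers itself as a PyTree for you), but the mechanics are the same -- `jax.jit` and `jax.grad` handle any PyTree of arrays natively.

```python
import jax
import jax.numpy as jnp

# A "model" as a plain PyTree of parameters.
params = {"w": jnp.array([[1.0, 2.0]]), "b": jnp.array([0.5])}

@jax.jit  # jit works directly: the PyTree is just another argument.
def loss(params, x, y):
    pred = x @ params["w"].T + params["b"]
    return jnp.mean((pred - y) ** 2)

# grad returns a PyTree of gradients with the same structure as params.
grads = jax.grad(loss)(params, jnp.ones((3, 2)), jnp.zeros((3, 1)))
```

No special wrappers or parameter-flattening step needed -- the transformation just maps over the tree's leaves.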
More complex models might have arbitrary Python types in their PyTrees -- we don't limit you to just JAX arrays.
In this case, filter/partition/combine offer a succinct way to split one PyTree into two, and then recombine them.
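To show the idea behind partition/combine (this is a conceptual sketch using only `jax.tree_util`, not Equinox's actual implementation): split a PyTree into two trees of the same shape, with `None` marking the slots that went to the other half, then merge them back.

```python
import jax.numpy as jnp
import jax.tree_util as jtu

def partition(pytree, pred):
    # Split into (selected, rest); unselected slots become None.
    left = jtu.tree_map(lambda x: x if pred(x) else None, pytree)
    right = jtu.tree_map(lambda x: None if pred(x) else x, pytree)
    return left, right

def combine(left, right):
    # Merge two complementary trees; None is treated as a leaf so
    # both trees have matching structure.
    return jtu.tree_map(lambda a, b: b if a is None else a,
                        left, right, is_leaf=lambda x: x is None)

# A PyTree mixing JAX arrays with arbitrary Python objects.
tree = {"w": jnp.ones(2), "activation": jnp.tanh, "name": "layer"}
arrays, static = partition(tree, lambda x: isinstance(x, jnp.ndarray))
recombined = combine(arrays, static)
```

The typical use: partition out just the arrays, pass that half through `jit`/`grad`, and recombine with the static half afterwards.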