Announcing Equinox v0.1.0! Lots of new goodies for your neural networks in JAX.
-The big one: models that work directly with native jax.jit and jax.grad!
-filter, partition, and combine, for manipulating PyTrees
-new filter functions
-much-improved documentation
-PyPI availability!
A thread: 1/n 🧵
First: simple models can be used directly with jax.jit and jax.grad. This is because Equinox models are just PyTrees like any other. And JAX understands PyTrees.
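For instance -- a minimal sketch, where the "Linear" model is hypothetical rather than a built-in:

    import jax
    import jax.numpy as jnp
    import equinox as eqx

    class Linear(eqx.Module):  # hypothetical model; every leaf is a JAX array
        weight: jax.Array
        bias: jax.Array

        def __call__(self, x):
            return self.weight @ x + self.bias

    @jax.jit
    @jax.grad
    def loss(model, x, y):
        return jnp.mean((model(x) - y) ** 2)

    model = Linear(weight=jnp.ones((3, 2)), bias=jnp.zeros(3))
    grads = loss(model, jnp.ones(2), jnp.ones(3))  # grads is itself a Linear PyTree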
2/n
More complex models might have arbitrary Python types in their PyTrees -- we don't limit you to just JAX arrays.
In this case, filter/partition/combine offer a succinct way to split one PyTree into two, and then recombine them.
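For example -- a sketch, where "MLP" is hypothetical and eqx.is_array is used as the filter:

    import jax
    import jax.numpy as jnp
    import equinox as eqx

    class MLP(eqx.Module):  # hypothetical model with a non-array leaf
        weight: jax.Array
        activation: callable  # an arbitrary Python object, not a JAX array

        def __call__(self, x):
            return self.activation(self.weight @ x)

    model = MLP(weight=jnp.ones((3, 2)), activation=jnp.tanh)
    params, static = eqx.partition(model, eqx.is_array)  # two PyTrees: arrays, and the rest
    model = eqx.combine(params, static)                  # ...and back together again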
3/n
As this often happens around JIT/grad, we provide some convenient wrappers. (You can still use the previous explicit version if you prefer.)
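Roughly like this (a sketch, reusing the hypothetical MLP model from the previous snippet):

    import jax.numpy as jnp
    import equinox as eqx

    @eqx.filter_jit   # non-array leaves (like the activation) are treated as static
    @eqx.filter_grad  # differentiates w.r.t. the floating-point array leaves of `model`
    def loss(model, x, y):
        return jnp.mean((model(x) - y) ** 2)

    grads = loss(model, jnp.ones(2), jnp.ones(3))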
How about that for easy-to-use syntax!
4/n
(For those who've used Equinox before: "filter_jit" and "filter_grad" are the successors to the old "jitf" and "gradf" functions.
We've tidied up the interface a bit and, as in tweet 3, optionally separated the filtering from the transformation.)
5/n
New filter functions, covering all the common use cases:
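(A sketch -- is_array_like is named just below; the other names follow the documentation:)

    import jax.numpy as jnp
    import equinox as eqx

    eqx.is_array(jnp.ones(2))          # True: a JAX array
    eqx.is_array(1.0)                  # False: a Python float
    eqx.is_array_like(1.0)             # True: convertible to a JAX array
    eqx.is_inexact_array(jnp.ones(2))  # True: a floating-point JAX array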
(is_array_like has also gotten a huge speed improvement)
6/n
Much improved documentation! In fact, the code snippets above are all adapted from the documentation. :)
This also includes new examples, such as how to use Equinox in the "classical" init/apply way, e.g. in conjunction with other libraries.
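Roughly, that pattern looks like this -- a sketch reusing the hypothetical MLP from tweet 3, not the documentation's exact example:

    import jax
    import jax.numpy as jnp
    import equinox as eqx

    def init(key):  # build a model, then split off its parameters
        weight = jax.random.normal(key, (3, 2))
        model = MLP(weight=weight, activation=jnp.tanh)
        return eqx.partition(model, eqx.is_array)

    params, static = init(jax.random.PRNGKey(0))

    def apply(params, x):
        model = eqx.combine(params, static)  # rebuild the model from its parameters
        return model(x)

    out = apply(params, jnp.ones(2))  # a pure function of (params, x), as such libraries expect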
7/n
And finally -- sound the trumpets! -- Equinox is now available via "pip"!
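That is:

    pip install equinox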
(Huge thanks to the guy who previously had "equinox" registered on PyPI, and agreed to let us use it.)
8/n
Equinox demonstrates how to use a PyTorch-like class-based API without compromising on JAX-like functional programming.
It is half tech-demo, half neural network library, and comes with no behind-the-scenes magic, guaranteed.