Announcing Equinox!

github.com/patrick-kidger…

A JAX neural network library with
- a PyTorch-like class API for model building
- whilst *also* being functional (no stored state)

It leverages two tricks: *filtered transformations* and *callable PyTrees*.

1/n🧵
First of all, I know what you're thinking. We already have e.g. Flax and Haiku (+ a few others as well).

What's new, and do we really need another?

To the best of my knowledge, Equinox overcomes some of the core technical difficulties faced in previous libraries.

2/n
We love having a PyTorch-like class API for model building.

We love having JAX-like functional programming.

But these seem like completely different paradigms, and making them work together is tricky.

3/n
This tension means that existing libraries:
- need their own notions of parameter groups
- need class-to-functional transformations
- need to wrap transforms like "jax.jit" into "library.jit"

Which is "fine" -- but isn't the functional programming paradise JAX promises.

4/n
The first secret ingredient behind what Equinox does differently is *filtered transformations*:

These are very thin wrappers around jax.grad and jax.jit.

(And they are different to other "library.jit"s; see below.)

5/n
Instead of specifying whole arguments to JIT/differentiate, e.g.

jax.grad(..., argnums=...)

with Equinox you specify a *filter function* that selects which *PyTree leaves* to JIT/differentiate. That is -- only *part* of each input.

6/n
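To make the filtering idea concrete, here is a minimal sketch. It uses a hypothetical `filter_grad(fun, filter_fn)` helper written with plain `jax.tree_util`; it is not Equinox's actual API. Where `jax.grad(..., argnums=...)` picks whole arguments, this helper applies `filter_fn` to every leaf of the first argument and differentiates only the leaves it accepts. The `is_inexact_array` filter and the toy `model` dict are illustrative assumptions.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import tree_flatten, tree_unflatten


def is_inexact_array(leaf):
    # Hypothetical filter: accept floating-point JAX arrays, reject everything else.
    return isinstance(leaf, jax.Array) and jnp.issubdtype(leaf.dtype, jnp.inexact)


def filter_grad(fun, filter_fn):
    """Hypothetical helper: gradient of `fun` w.r.t. only those leaves of its
    first argument for which `filter_fn(leaf)` is True."""

    def wrapper(tree, *args):
        leaves, treedef = tree_flatten(tree)
        mask = [bool(filter_fn(leaf)) for leaf in leaves]
        diff_leaves = [leaf for leaf, m in zip(leaves, mask) if m]

        def inner(diff):
            # Re-insert the differentiated leaves among the untouched ones.
            diff_iter = iter(diff)
            merged = [next(diff_iter) if m else leaf for leaf, m in zip(leaves, mask)]
            return fun(tree_unflatten(treedef, merged), *args)

        grads = jax.grad(inner)(diff_leaves)
        # Gradient tree with the same structure; unselected leaves become None.
        grad_iter = iter(grads)
        return tree_unflatten(treedef, [next(grad_iter) if m else None for m in mask])

    return wrapper


def square(x):
    return x * x


# The "model" mixes an array, a Python function and a Python int.
model = {"w": jnp.array(3.0), "activation": square, "scale": 2}


def loss(m, x):
    return m["scale"] * m["activation"](m["w"] * x)


grads = filter_grad(loss, is_inexact_array)(model, 1.0)
```

Only `model["w"]` is differentiated; the function and the integer pass through untouched and show up as `None` in `grads`.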
(Recall that a "PyTree" is what JAX calls nested combinations of lists, dictionaries etc. The leaves of the PyTree can be arbitrary Python objects.

Pretty much everything in JAX is about functions between PyTrees: function inputs are PyTrees, function outputs are PyTrees.)

7/n
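A small illustration of PyTrees using only standard `jax.tree_util` functions. The particular `tree` below is an arbitrary example.

```python
import jax
import jax.numpy as jnp

# A PyTree: nested containers (dicts, lists, tuples) whose leaves may be anything.
tree = {"weights": [jnp.ones(2), jnp.zeros(3)], "activation": "relu", "extra": None}

leaves, treedef = jax.tree_util.tree_flatten(tree)
# Dict keys are flattened in sorted order, and None is treated as an empty
# subtree (not a leaf), so the leaves are: "relu", ones(2), zeros(3).

rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)


# A function between PyTrees: a PyTree goes in, a PyTree comes out.
def double_weights(t):
    return {"weights": [2 * w for w in t["weights"]]}


out = jax.jit(double_weights)({"weights": tree["weights"]})
```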
This is a game-changer. You can represent your whole model as a single PyTree -- not just its parameters, but everything else as well!

Then filter out any arbitrary Python objects (e.g. activation functions) -- and keep just the parameters you want to JIT or differentiate.

8/n
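The same trick works for JIT. Below is a sketch of a hypothetical `filter_jit` helper (again not Equinox's API): array leaves are traced, and every other leaf, such as an activation function, is treated as a static compile-time constant. It assumes those non-array leaves are hashable, since `jax.jit` requires that of `static_argnums` values.

```python
import functools

import jax
import jax.numpy as jnp
from jax.tree_util import tree_flatten, tree_unflatten

_DYNAMIC = object()  # sentinel marking "this leaf is traced, not static"


def is_array(leaf):
    return isinstance(leaf, jax.Array)


def filter_jit(fun):
    """Hypothetical helper: JIT-compile `fun`, tracing the array leaves of its
    arguments and treating all other (hashable) leaves as static."""

    @functools.partial(jax.jit, static_argnums=(1, 2))
    def jitted(dynamic, static, treedef):
        leaves = [d if s is _DYNAMIC else s for d, s in zip(dynamic, static)]
        return fun(*tree_unflatten(treedef, leaves))

    def wrapper(*args):
        leaves, treedef = tree_flatten(args)
        dynamic = [leaf if is_array(leaf) else None for leaf in leaves]
        static = tuple(_DYNAMIC if is_array(leaf) else leaf for leaf in leaves)
        return jitted(dynamic, static, treedef)

    return wrapper


trace_count = []


def square(x):
    return x * x


def forward(model, x):
    trace_count.append(None)  # Python side effect: runs only while tracing
    return model["activation"](model["w"] * x)


model = {"w": jnp.array(2.0), "activation": square}
fast_forward = filter_jit(forward)
y1 = fast_forward(model, jnp.array(3.0))  # traces and compiles
y2 = fast_forward(model, jnp.array(1.0))  # same shapes and static leaves: no retrace
```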
This brings us to the second secret ingredient of Equinox: we still need to build such PyTrees-that-are-models.

Each layer type (Linear, Dropout, etc.) is both a class *and* a PyTree.

9/n
Because it's a class, we can define methods.

Each method is now a parameterised function: parameterised by stored weights, sublayers, etc.

*This is great for representing models*, because model forward passes are nothing other than (weight-)parameterised functions.

10/n
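One way to get a class that is also a PyTree is JAX's own `jax.tree_util.register_pytree_node_class`. The `Linear` below is a hypothetical sketch of the idea, not Equinox's implementation (Equinox may register its modules differently).

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Linear:
    """Hypothetical layer: a class (so it has methods) and a PyTree (so JAX
    can flatten it into its weight and bias leaves)."""

    def __init__(self, weight, bias):
        self.weight = weight
        self.bias = bias

    def __call__(self, x):
        # The forward pass: a function parameterised by the stored weights.
        return self.weight @ x + self.bias

    def tree_flatten(self):
        # Children (leaves or sub-PyTrees) first, then static auxiliary data.
        return (self.weight, self.bias), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


layer = Linear(jnp.array([[1.0, 2.0], [3.0, 4.0]]), jnp.array([0.5, -0.5]))
y = layer(jnp.array([1.0, 1.0]))
leaves = jax.tree_util.tree_leaves(layer)
```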
Because each layer is also a PyTree, we can pass it around through JAX functions -- like JIT and autodifferentiation.

(And this includes plain jax.jit and grad if you don't want to filter anything! You don't have to use a special "library.jit" to make things work.)

11/n
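As a sketch of using plain `jax.jit` and `jax.grad` on a class-as-PyTree, here is a hypothetical `MLP`. It stores its activation function as static auxiliary data, so the only leaves JAX sees are arrays and no filtering is needed. This is one design choice among several; filtered transformations are what let such a function be an ordinary leaf instead.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class MLP:
    """Hypothetical two-layer model. The weight arrays are PyTree children;
    the activation function is static auxiliary data."""

    def __init__(self, w1, w2, activation):
        self.w1 = w1
        self.w2 = w2
        self.activation = activation

    def __call__(self, x):
        return self.w2 @ self.activation(self.w1 @ x)

    def tree_flatten(self):
        return (self.w1, self.w2), self.activation

    @classmethod
    def tree_unflatten(cls, activation, children):
        return cls(*children, activation)


def loss(model, x, y):
    return jnp.sum((model(x) - y) ** 2)


def identity(x):
    return x


model = MLP(jnp.eye(2), jnp.ones((1, 2)), identity)
x, y = jnp.array([1.0, 2.0]), jnp.array([0.0])

value = jax.jit(loss)(model, x, y)   # plain jax.jit, no library wrapper
grads = jax.grad(loss)(model, x, y)  # plain jax.grad; grads is itself an MLP
```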
This is so important that I'm going to emphasise it: *Equinox allows you to represent parameterised functions as data*.

The upshot is that you can input them into higher-order functions like JIT and autodifferentiation, and JAX will handle them just like it always does.

12/n
There's no special "library.Module" that has to be used in certain ways.

There's no "library.jit" or "library.grad" needed specifically to work with "library.Module"s.

13/n
There's nothing special about Equinox modules. They're just PyTrees.

There's nothing special about filtered transformations. They just operate on PyTrees.

Equinox is all just regular JAX -- PyTrees and transformations!

14/n
Equinox allows us to build parameterised models without compromising on JAX's functional principles.

(Moreover, we can do so without introducing anything we need to handle differently to regular JAX. Equinox is not a framework.)

15/n
In summary, Equinox offers:
- a PyTorch-like API for building models
- joined with JAX-style functional programming
- without any magic behind the scenes: the Equinox library is only a few lines of code!

16/n
At this stage, Equinox is still half tech-demo and half neural network library.

So I'd love for some feedback!

Is Equinox useful to you? What do you think about the above discussion?

Check it out here:

github.com/patrick-kidger…

17/17
Appendix:

If any JAX/Haiku/Flax/Stax/Objax developers are reading this, I'd be interested to know what you think.

I'm aware of similarities to flax.linen, with its dataclasses-as-PyTrees, but as far as I can see, it stops short of going all the way to treating...

18/17
...parameterised functions as data like anything else (wrt higher-order functions).

I'm also guessing that the basic design principles of Equinox are probably present elsewhere, and are just new to JAX. (e.g. I also see similarities to functors in Flux.jl.)

19/17
Still, AFAIK this is new to the JAX ecosystem at least.

So does the way Equinox solves things seem reasonable?

What criticisms would you offer?

What motivated the design choices made before?

LMK!

20/17