If you follow me then there's a decent chance that you already know what an NDE is. (If you don't, go read the introductory Chapter 1 of my thesis haha -- it's only 6 pages long.) Put a neural network inside a differential equation, and suddenly cool stuff starts happening.
2/n
Neural differential equations are a beautiful way of building models, offering:
- high-capacity function approximation;
- strong priors on model space;
- the ability to handle irregular data;
- memory efficiency;
- a foundation of well-understood theory.
3/n
You can model the evolution of unknown trajectories (cough finance cough) via neural SDEs:
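(Schematically -- with μ_θ and σ_θ standing in for neural networks; notation mine, not verbatim from the thesis:)

```latex
\mathrm{d}y(t) = \mu_\theta(t, y(t))\,\mathrm{d}t + \sigma_\theta(t, y(t))\,\mathrm{d}W(t)
```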
4/n
...or model unknown distributions via continuous normalising flows (aka the Fokker--Planck equation):
In this case the target distribution is a 2D picture. (Obtained by me and Microsoft Paint at 2am.)
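(For reference, the fact that makes CNFs tick -- again in my own notation -- is the instantaneous change-of-variables formula: push samples through an ODE and the log-density evolves alongside them.)

```latex
\frac{\mathrm{d}y(t)}{\mathrm{d}t} = f_\theta(t, y(t)),
\qquad
\frac{\mathrm{d}\log p(y(t))}{\mathrm{d}t} = -\operatorname{tr}\!\left(\frac{\partial f_\theta}{\partial y}(t, y(t))\right)
```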
5/n
Or you can study unknown physical dynamics -- here an unknown Hamiltonian system parameterised by neural kinetic and potential terms:
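(Sketching the idea in my own notation: learn H_θ(q, p) = K_θ(p) + V_θ(q) with neural networks K_θ and V_θ, then evolve under Hamilton's equations.)

```latex
\frac{\mathrm{d}q}{\mathrm{d}t} = \frac{\partial H_\theta}{\partial p},
\qquad
\frac{\mathrm{d}p}{\mathrm{d}t} = -\frac{\partial H_\theta}{\partial q}
```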
6/n
You can build "continuous time RNNs", by using the theory of controlled differential equations:
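(The one-line idea, in my notation: drive the hidden state y with the data path x, not just with time.)

```latex
\mathrm{d}y(t) = f_\theta(y(t))\,\mathrm{d}x(t)
```

Take x(t) = t and you recover a neural ODE; take x to be an interpolation of your time series and you get something that behaves like an RNN, but in continuous time.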
7/n
As a final example, you can understand neural nets via diffeqs. It's relatively famous that ResNets are the explicit Euler method applied to a neural ODE...
8/n
...but did you know that the feature distinguishing GRUs and LSTMs from generic RNNs is a very precise differential-equation-like structure? For example, a GRU has an exponential decay term.
(No wonder they struggle to learn long-term dependencies.)
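(A minimal sketch of the ResNet/Euler correspondence -- f_theta below is just a stand-in for whatever learnt network you like:)

```python
import jax.numpy as jnp

def f_theta(x):
    # Stand-in for a learnt vector field R^d -> R^d.
    return jnp.tanh(x)

def resnet_block(x):
    # One residual block: x_{n+1} = x_n + f_theta(x_n).
    return x + f_theta(x)

def euler_step(x, dt=1.0):
    # One explicit Euler step of dx/dt = f_theta(x):
    #     x(t + dt) = x(t) + dt * f_theta(x(t)).
    # With dt = 1 this is exactly the residual block above.
    return x + dt * f_theta(x)
```

And, roughly speaking, the GRU update h_t = (1 - z_t) * h_{t-1} + z_t * g_t (with g_t the candidate state) is a discretisation of dh/dt = z(t) * (g(t) - h(t)): exponential decay of the hidden state towards the candidate.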
9/n
I've hinted that my thesis includes some previously unpublished material.
For example, did you know that a neural ODE can be a universal approximator even if its vector fields are not universal approximators? (That's Section 2.4.2.)
10/n
Or that all the "adjoint methods" floating about -- for ODEs, for SDEs, whatever -- are special cases of the same thing, applied to the general notion of an "RDE", aka a rough differential equation? (That's Appendix C.3.)
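(For the ODE case -- writing dy/dt = f_θ(t, y(t)) on [0, T], a(t) = ∂L/∂y(t), and using one common sign convention -- the adjoint is the ODE solved backwards in time:)

```latex
\frac{\mathrm{d}a(t)}{\mathrm{d}t} = -a(t)^\top \frac{\partial f_\theta}{\partial y}(t, y(t)),
\qquad
\frac{\mathrm{d}L}{\mathrm{d}\theta} = \int_0^T a(t)^\top \frac{\partial f_\theta}{\partial \theta}(t, y(t))\,\mathrm{d}t
```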
11/n
If you've ever heard of SINDy -- symbolic regression for dynamical systems -- then there's *also* some unpublished material on improving on this via genetic algorithms. (That's Section 6.1.)
And of course, all accompanying code is provided -- available as the examples in the brand-new Diffrax software library! Your one-stop-shop for numerical differential equation solvers in #JAX.
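(A taster -- a minimal sketch of solving an ODE with Diffrax; the library is brand new, so exact details may shift.)

```python
import jax.numpy as jnp
import diffrax

# Toy vector field; swap in a neural network (e.g. an Equinox module) for a neural ODE.
def vector_field(t, y, args):
    return -y

solution = diffrax.diffeqsolve(
    diffrax.ODETerm(vector_field),
    diffrax.Dopri5(),
    t0=0.0,
    t1=1.0,
    dt0=0.1,
    y0=jnp.array([1.0]),
    saveat=diffrax.SaveAt(ts=jnp.linspace(0.0, 1.0, 11)),
)
print(solution.ys)  # the trajectory, evaluated at the requested times
```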
I'm planning on doing some individual posts about some highlights from the thesis over the next few days.
Diffrax is available a little early as a sneak peek. I'll be doing a proper announcement on it next week!
14/n
Credit where it's due. A doctorate doesn't happen in a vacuum.
My friends and family have been an amazing support. Chloe, thank you for all the food! Juliette, thank you for the south of France. Mum, Dad: thank you for everything.
On a more academic note:
15/n
CC allllll the people who might find this announcement interesting
(Feel free to @ everyone else who'd like to know about this!)
I have been fortunate to work with, collaborate with, or otherwise communicate with most of the people on the above list. (+Many others, both on and off Twitter.) So in a very practical way, you made this work possible.
17/n
Okay, let's wrap this up. If you're studying NDEs and want a reference text, then maybe this is it?
231 pages of everything you ever wanted to know about N ordinary DEs, N controlled DEs, N stochastic DEs, and N rough DEs.
Appendix: as a fun historical note, whilst "Neural Ordinary Differential Equations" won best paper at NeurIPS 2018 -- which is why people have heard of the field of NDEs -- I'm actually aware of work on NDEs dating back to 1991!
Also, do check out the "Comments" section at the end of each chapter. There you can find various marginalia on extensions, references, open problems, and musings about the field of NDEs in general. :)
20/18
Fin. [Both of this tweet thread... and of my doctorate!]
Announcing Equinox v0.1.0! Lots of new goodies for your neural networks in JAX.
-The big one: models work directly with native jax.jit and jax.grad!
-filter, partition, and combine, for manipulating PyTrees
-new filter functions
-much-improved documentation
-PyPI availability!
A thread: 1/n 🧵
First: simple models can be used directly with jax.jit and jax.grad. This is because Equinox models are just PyTrees like any other. And JAX understands PyTrees.
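(A minimal sketch, assuming the eqx.nn.Linear layer from the docs -- treat it as illustrating the idea rather than a version-pinned snippet:)

```python
import jax
import jax.numpy as jnp
import jax.random as jr
import equinox as eqx

# The model is itself a PyTree, so jax.jit and jax.grad apply to it directly.
model = eqx.nn.Linear(2, 1, key=jr.PRNGKey(0))

@jax.jit
@jax.grad
def loss(model, x, y):
    pred = jax.vmap(model)(x)           # batch the single-sample model over x
    return jnp.mean((pred - y) ** 2)

x = jr.normal(jr.PRNGKey(1), (8, 2))
y = jr.normal(jr.PRNGKey(2), (8, 1))
grads = loss(model, x, y)  # a PyTree of gradients with the same structure as `model`
```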
2/n
More complex models might have arbitrary Python types in their PyTrees -- we don't limit you to just JAX arrays.
In this case, filter/partition/combine offer a succinct way to split one PyTree into two, and then recombine them.
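(A sketch of the pattern -- the Model class below is a made-up example, not something from the library:)

```python
import jax
import jax.numpy as jnp
import equinox as eqx

class Model(eqx.Module):
    weight: jnp.ndarray
    activation: callable  # an arbitrary Python object, not a JAX array

    def __call__(self, x):
        return self.activation(self.weight @ x)

model = Model(weight=jnp.ones((3, 3)), activation=jax.nn.relu)

# Split the PyTree in two: array leaves in one half, everything else in the other...
params, static = eqx.partition(model, eqx.is_array)
# ...pass `params` through jit/grad/etc., then stitch the model back together.
model = eqx.combine(params, static)
```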