New paper: Neural Rough Differential Equations!

Greatly improve performance on long time series by using the mathematics of rough path theory.

arxiv.org/abs/2009.08295
github.com/jambo6/neuralR…

Accepted at #ICML2021!

🧵: 1/n
(including a lot about what makes RNNs work)
(So first of all, yes, it's another Neural XYZ Differential Equations paper.

At some point we're going to run out of XYZ differential equations to put the word "neural" in front of.)

2/n
As for what's going on here!

We already know that RNNs are basically differential equations.

Neural CDEs are the example closest to my heart. These are the true continuous-time limit of generic RNNs:
arxiv.org/abs/2005.08926
github.com/patrick-kidger…

3/n
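
(For reference, the rough shape of a neural CDE, as I recall it from that paper -- the hidden state y is driven by a continuous interpolation X of the data, via a learnt vector field f_θ:)

```latex
% A neural CDE, schematically: the hidden state y is driven by the data path X.
y(t) = y(0) + \int_0^t f_\theta\bigl(y(s)\bigr)\,\mathrm{d}X(s)
```
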
But there are lots of other examples too, e.g. the antisymmetric RNN, which uses differential equations as inspiration for an RNN cell:
arxiv.org/abs/1902.09689

There's also been lots of other work on this going back to the 90s.
sciencedirect.com/science/articl…

4/n
Best of all, if you look at the update mechanisms for a GRU or an LSTM -- but not simple "h_{n+1} = tanh(W_h h_n + W_x x_n + b)" Elman-like RNNs -- then they look suspiciously like discretised differential equations...

5/n
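
(To spell that out for the GRU -- standard GRU update, up to the usual convention on which gate is z versus 1 - z; my illustration, not a figure from the paper:)

```latex
% GRU hidden state update:
h_{n+1} = (1 - z_n) \odot h_n + z_n \odot \hat{h}_n
% Rearranging:
h_{n+1} = h_n + z_n \odot (\hat{h}_n - h_n),
% which is an explicit Euler step (unit step size) of the ODE
\frac{\mathrm{d}h}{\mathrm{d}t} = z(t) \odot \bigl(\hat{h}(t) - h(t)\bigr).
```
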
First moral of the story: if you want to understand RNNs, or build better RNNs, then differential equations are the way to do it. <3

6/n
Alright, how do rough differential equations come into it?

These are differential equations which (intuitively speaking...) are driven by an input signal -- but the input signal has wiggly fine-scale structure.

[Brownian motion, and SDEs, are a special case.]

7/n
Over in machine learning, what else has wiggly fine-scale structure?

Long time series!

Realistically, we probably don't care about _every_ data point in that long time series.

But each of those data points _is_ still giving us some small amount of information.

8/n
So, what to do? Throw away data? Selectively skip some data? Bin the data?

The answer (at least here) is that last one. Bin the data, very carefully, so that you get the most information out of the wiggles.

(...the least technical way possible of describing my PhD...)

9/n
In particular, we use the "logsignature". This is a map that extracts the information that's most important for describing how an input path drives a differential equation.

...and as we've already discussed, RNNs are differential equations! <3

10/n
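
(Here's a minimal sketch of the binning idea, using the signatory library -- written from memory, so treat the exact function names as assumptions and check the signatory docs:)

```python
import torch
import signatory  # github.com/patrick-kidger/signatory

# A (fake) long time series: 32 series, 17000 steps, 6 channels.
batch, length, channels = 32, 17000, 6
depth, num_windows = 3, 100
x = torch.randn(batch, length, channels)

window_length = length // num_windows  # 170 steps per window
features = []
for i in range(num_windows):
    window = x[:, i * window_length : (i + 1) * window_length, :]
    # Summarise each window by its depth-3 logsignature,
    # shape (batch, logsignature_channels).
    features.append(signatory.logsignature(window, depth))
features = torch.stack(features, dim=1)

# The result is a much shorter sequence of informative features, which can
# be fed into any RNN-like model (or a Neural CDE).
print(features.shape)  # (32, 100, signatory.logsignature_channels(channels, depth))
```
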
In fact, if you work through the mathematics...

(...see the appendix if you're really curious...)

...then doing all of the above corresponds to a particular numerical solver for differential equations, called the "log-ODE method".

11/n
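
(Roughly -- see the paper's appendix for the precise statement -- one step of the log-ODE method over a window [t_i, t_{i+1}] solves an ODE driven by that window's logsignature:)

```latex
% Solve, for u in [0, 1],
\frac{\mathrm{d}z}{\mathrm{d}u}
  = \hat{f}\bigl(z(u)\bigr)\,\mathrm{LogSig}^{N}_{[t_i,\,t_{i+1}]}(X),
\qquad z(0) = y(t_i),
% and set y(t_{i+1}) = z(1), where \hat{f} extends f to act on depth-N
% logsignatures. At depth N = 1 the logsignature is just the increment,
% and this reduces to an Euler step:
y(t_{i+1}) \approx y(t_i) + f\bigl(y(t_i)\bigr)\,\bigl(X(t_{i+1}) - X(t_i)\bigr).
```
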
Now remember, GRUs were just Euler-discretised differential equations. So were ResNets. So seeing another numerical solver shouldn't be too surprising!

12/n
In fact nearly anything worth using seems to look like a discretised differential equation...

A popular example from Twitter yesterday: MomentumNets. (@PierreAblin arxiv.org/abs/2102.07870)

13/n
Returning to logsignatures: we get a way to feed really long data into anything RNN-like.

In this case we go full-continuous and use a Neural CDE, but any old RNN would work as well.

And because we've made things much shorter -- "extracted information from the wiggles" --

14/n
-- then all the standard problems with learning on long time series are alleviated.

No vanishing gradients.
No exploding gradients.
No taking a really really long time just to run the model...

(...looking at you, everyone I have to share GPU resources with...)

15/n
In our experiments we successfully handle time series up to 17k steps in length.

(We could probably go longer too, but that was just what we had lying around...)

16/n
Anyway, let's wrap this up.

If you're interested in RNNs / time series...
If you're interested in long time series...
If you're interested in (neural) differential equations...
...then this might be interesting for you.

17/n
The paper is here: arxiv.org/abs/2009.08295
The code is here: github.com/jambo6/neuralR…

If you want to use this yourself then the necessary tools have been implemented for you in torchcde:

github.com/patrick-kidger…

18/n
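
(To give a flavour of the tooling: a minimal sketch written from memory of the torchcde README -- the function names are my best recollection and may have changed, so check the repo. This shows basic Neural CDE usage rather than the full logsignature pipeline.)

```python
import torch
import torchcde  # github.com/patrick-kidger/torchcde

# A batch of time series: shape (batch, length, channels).
x = torch.randn(32, 100, 3)

# Interpolate the data to get a continuous control path X.
coeffs = torchcde.hermite_cubic_coefficients_with_backward_differences(x)
X = torchcde.CubicSpline(coeffs)

# The CDE vector field f_theta: maps the hidden state to a matrix,
# of shape (hidden_channels, input_channels).
class CDEFunc(torch.nn.Module):
    def __init__(self, hidden, channels):
        super().__init__()
        self.hidden = hidden
        self.channels = channels
        self.linear = torch.nn.Linear(hidden, hidden * channels)

    def forward(self, t, z):
        return self.linear(z).tanh().view(-1, self.hidden, self.channels)

hidden, channels = 8, 3
func = CDEFunc(hidden, channels)
z0 = torch.zeros(x.size(0), hidden)

# Solve the CDE along the control path.
zt = torchcde.cdeint(X=X, func=func, z0=z0, t=X.interval)
print(zt.shape)  # (batch, 2, hidden): the solution at the interval endpoints
```
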
Thank you for coming to my TED talk about RNNs and differential equations. :)

19/20
A huge thanks to James Morrill for taking the lead on this paper.

(He unfortunately doesn't have a Twitter account I can link to, so he's asked me to do this instead :) )

20/20
Appendix for the physicist:

If you have a physics background, then you may know of "Magnus expansions". These are actually a special case of the log-ODE method.

en.wikipedia.org/wiki/Magnus_ex…

21/20
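
(For concreteness -- a standard textbook statement, not taken from the paper: for a linear ODE the Magnus expansion writes the solution as the exponential of a series of iterated integrals of commutators, which is exactly the flavour of object a logsignature is.)

```latex
% Magnus expansion for the linear ODE y'(t) = A(t) y(t):
y(t) = \exp\bigl(\Omega(t)\bigr)\, y(0),
\qquad
\Omega(t) = \int_0^t A(s)\,\mathrm{d}s
  + \tfrac{1}{2} \int_0^t \!\!\int_0^{s_1} \bigl[A(s_1), A(s_2)\bigr]\,\mathrm{d}s_2\,\mathrm{d}s_1
  + \cdots
```
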
Appendix for the stochastic analyst:

The rough path theory used here is super cool.

It gives you a pathwise notion of solution to SDEs:

en.wikipedia.org/wiki/Rough_pat…

And gives an easy universal approximation theorem for RNNs/CDEs!

arxiv.org/abs/2005.08926 (Appendix B)

22/20
