But there are lots of other examples too, e.g. the antisymmetric RNN, which uses differential equations as inspiration for an RNN cell: arxiv.org/abs/1902.09689
Best of all, if you look at the update mechanisms for a GRU or an LSTM -- but not simple "h_{n+1} = tanh(W_h h_n + W_x x_n + b)" Elman-like RNNs -- then they look suspiciously like discretised differential equations...
5/n
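(Concretely -- a minimal sketch, with made-up parameter names, of a GRU cell written so that the resemblance jumps out:)

import jax
import jax.numpy as jnp

def gru_cell(params, h, x):
    # Wz, Uz, ... are illustrative names; any GRU parameterisation works.
    Wz, Uz, bz, Wr, Ur, br, Wh, Uh, bh = params
    z = jax.nn.sigmoid(Wz @ x + Uz @ h + bz)        # update gate
    r = jax.nn.sigmoid(Wr @ x + Ur @ h + br)        # reset gate
    h_tilde = jnp.tanh(Wh @ x + Uh @ (r * h) + bh)  # candidate state
    # The usual form (1 - z) * h + z * h_tilde rearranges to
    #     h + z * (h_tilde - h)
    # i.e. an explicit Euler step h_{n+1} = h_n + dt * f(h_n, x_n),
    # with a learnt, state-dependent "step size" z.
    return h + z * (h_tilde - h)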
First moral of the story: if you want to understand RNNs, or build better RNNs, then differential equations are the way to do it. <3
6/n
Alright, how do rough differential equations come into it?
These are differential equations, which (intuitively speaking...) have an input signal driving them -- but the input signal has wiggly fine-scale structure.
[Brownian motion, and SDEs, are a special case.]
7/n
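(In symbols, a controlled/rough differential equation looks roughly like

dy(t) = f(y(t)) dX(t)

where X is the input path doing the driving. Take X = (time, Brownian motion) and you recover an SDE.)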
Over in machine learning, what else has wiggly fine-scale structure?
Long time series!
Realistically, we probably don't care about _every_ data point in that long time series.
But each of those data points _is_ still giving us some small amount of information.
8/n
So, what to do? Throw away data? Selectively skip some data? Bin the data?
The answer (at least here) is that last one. Bin the data, very carefully, so that you get the most information out of the wiggles.
(...the least technical way possible of describing my PhD...)
9/n
In particular, we use the "logsignature". This is a map that extracts the information that's most important for describing how an input path drives a differential equation.
...and as we've already discussed, RNNs are differential equations! <3
10/n
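(For illustration only -- real implementations use dedicated libraries -- here's a sketch of the depth-2 logsignature of a piecewise-linear path: the total increment, plus the antisymmetric "Lévy area" of the wiggles.)

import jax.numpy as jnp

def logsignature_depth2(path):
    # path: array of shape (length, channels), interpreted piecewise-linearly.
    increments = jnp.diff(path, axis=0)
    channels = path.shape[1]
    level1 = jnp.zeros(channels)
    level2 = jnp.zeros((channels, channels))
    # Chen's identity: build up the (depth-2 truncated) signature segment by segment.
    for dx in increments:
        level2 = level2 + jnp.outer(level1, dx) + 0.5 * jnp.outer(dx, dx)
        level1 = level1 + dx
    # Taking the tensor-algebra logarithm leaves the total increment at depth 1
    # and the antisymmetric "Lévy area" at depth 2.
    levy_area = 0.5 * (level2 - level2.T)
    return level1, levy_area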
In fact, if you work through the mathematics...
(...see the appendix if you're really curious...)
...then doing all of the above corresponds to a particular numerical solver for differential equations, called the "log-ODE method".
11/n
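(Very roughly, and glossing over details from the paper -- names here are illustrative -- one window of the log-ODE method looks like: summarise the input path over the window by its logsignature, then solve an ODE in which that logsignature acts as a constant driving term.)

import jax.numpy as jnp

def log_ode_step(f_theta, z0, logsig, dt, num_substeps=10):
    # f_theta: maps a hidden state (hidden,) to a matrix (hidden, logsig_dim).
    # z0:      hidden state at the start of the window.
    # logsig:  flattened logsignature of the input path over the window, (logsig_dim,).
    # dt:      length of the window.
    def vector_field(z):
        # Within the window the wiggly input has been replaced by its logsignature,
        # so the vector field no longer depends on the raw data.
        return f_theta(z) @ (logsig / dt)

    # Any ODE solver will do; plain Euler keeps the sketch short.
    h = dt / num_substeps
    z = z0
    for _ in range(num_substeps):
        z = z + h * vector_field(z)
    return z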
Now remember, GRUs were just Euler-discretised differential equations. So were ResNets. So seeing another numerical solver shouldn't be too surprising!
12/n
In fact nearly anything worth using seems to look like a discretised differential equation...
13/n
Returning to logsignatures: we get a way to feed really long data into anything RNN-like.
In this case we go full-continuous and use a Neural CDE, but any old RNN would work as well.
And because we've made things much shorter -- "extracted information from the wiggles" --
14/n
-- then all the standard problems with learning on long time series are alleviated.
No vanishing gradients.
No exploding gradients.
No taking a really really long time just to run the model...
(...looking at you, everyone I have to share GPU resources with...)
15/n
In our experiments we successfully handle datasets up to 17k samples in length.
(We could probably go longer too, but that was just what we had lying around...)
16/n
Anyway, let's wrap this up.
If you're interested in RNNs / time series...
If you're interested in long time series...
If you're interested in (neural) differential equations...
...then this might be interesting for you.
Announcing Equinox v0.1.0! Lots of new goodies for your neural networks in JAX.
-The big one: models using native jax.jit and jax.grad!
-filter, partition, combine to manipulate PyTrees
-new filter functions
-much-improved documentation
-PyPI availability!
A thread: 1/n 🧵
First: simple models can be used directly with jax.jit and jax.grad. This is because Equinox models are just PyTrees like any other. And JAX understands PyTrees.
2/n
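(A minimal sketch -- only eqx.Module and standard JAX are assumed:)

import jax
import jax.numpy as jnp
import equinox as eqx

class LinearModel(eqx.Module):
    weight: jnp.ndarray
    bias: jnp.ndarray

    def __call__(self, x):
        return self.weight @ x + self.bias

key = jax.random.PRNGKey(0)
model = LinearModel(weight=jax.random.normal(key, (1, 3)), bias=jnp.zeros(1))

def loss(model, x, y):
    pred = jax.vmap(model)(x)
    return jnp.mean((pred - y) ** 2)

# The model is just a PyTree of arrays, so native jax.jit / jax.grad apply directly.
grads = jax.jit(jax.grad(loss))(model, jnp.ones((8, 3)), jnp.zeros((8, 1)))
# grads has the same structure as model: another LinearModel, holding gradients.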
More complex models might have arbitrary Python types in their PyTrees -- we don't limit you to just JAX arrays.
In this case, filter/partition/combine offer a succinct way to split one PyTree into two, and then recombine them.
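(A sketch of the intended pattern -- see the docs for the full API:)

from typing import Callable

import jax
import jax.numpy as jnp
import equinox as eqx

class Model(eqx.Module):
    weight: jnp.ndarray
    activation: Callable  # an arbitrary Python object living in the PyTree

    def __call__(self, x):
        return self.activation(self.weight @ x)

model = Model(weight=jnp.ones((3, 3)), activation=jax.nn.relu)

# Split the model into its JAX arrays (to be traced/differentiated)
# and everything else...
params, static = eqx.partition(model, eqx.is_array)

# ...close over the static piece, and recombine inside the transformed function.
@jax.jit
def loss(params, x):
    model = eqx.combine(params, static)
    return jnp.sum(model(x))

grads = jax.grad(loss)(params, jnp.ones(3))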