Non-official implementation of SNODE: Spectral Discretization of Neural ODEs

December 9, 2024

Today we’re going to implement the aforementioned paper in code. Reason 1: They did not release code. Reason 2: I am using their method in my research so I had to code it up. Reason 3: Save other people’s time? Reason 4: Noooo codddieeeee.

At a high level, the authors of SNODE propose a scientific computing-esque way to train neural ODEs. Neural ODEs take the form:

\[\frac{du(s)}{ds} = f_{\theta}(u(s), s)\]

where \(f_{\theta}\) tends to be the machine learning piece, e.g. an MLP, and it is easiest to think of \(s\) and \(u\) as depth and layer output respectively. Thus this equation effectively states: the infinitesimal change in output relative to an infinitesimal change in depth is dictated by a neural network.

The current prevailing methods to train this are: (1) basically brute force, i.e. forward solve the ODE and use an automatic differentiation library to backprop through each step of the solve; and (2) instead solve a backwards-in-time ODE, called the adjoint, to compute the gradients.
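To make option (1) concrete, here is a minimal sketch in plain JAX. It assumes a hand-rolled forward Euler loop and a toy single-layer vector field, neither of which comes from the paper; in practice a proper solver library would replace `euler_solve`. All names here are hypothetical.

```python
import jax
import jax.numpy as jnp

def vector_field(theta, u, s):
    # Toy stand-in for f_theta(u, s): one tanh layer acting on [u, s].
    W, b = theta
    x = jnp.concatenate([u, jnp.atleast_1d(s)])
    return jnp.tanh(W @ x + b)

def euler_solve(theta, u0, s0=0.0, s1=1.0, num_steps=100):
    # Forward Euler from s0 to s1; every step stays on the autodiff tape.
    ds = (s1 - s0) / num_steps

    def step(u, k):
        s = s0 + k * ds
        u_next = u + ds * vector_field(theta, u, s)
        return u_next, u_next

    u_final, _ = jax.lax.scan(step, u0, jnp.arange(num_steps))
    return u_final

def loss(theta, u0, target):
    # Squared error between the end of the trajectory and the target.
    return jnp.sum((euler_solve(theta, u0) - target) ** 2)

key = jax.random.PRNGKey(0)
theta = (0.1 * jax.random.normal(key, (2, 3)), jnp.zeros(2))
u0 = jnp.array([1.0, 0.0])
target = jnp.array([0.0, 1.0])

# Backprop through all 100 solver steps in one call.
grads = jax.grad(loss)(theta, u0, target)
```

The cost to notice is that memory and compute for the gradient grow with the number of solver steps, which is part of what motivates the alternatives.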

To set up the intuition for SNODE, we can say that these two methods enforce the ODE by explicitly using a solver, and fit the data by updating trainable parameters w.r.t. an objective function. Once all is said and done, given an initial condition (an example from the input data), the solver "draws" out a trajectory which is amenable to both the ODE and the data/given objective function.

How else could we draw a trajectory? Well… we could propose one using some well-understood basis functions, then condition it to satisfy both the data and the ODE.

Concretely,

\[u(s) \approx h(s) = \sum_{i=1}^N c_i \phi_i(s)\]

and we fit the data by imposing it wherever we have it:

\[
u(s_j) = h(s_j) \quad \text{for } j = 1, 2, \ldots, M, \]

or, when an exact match is not possible, in the least-squares sense:

\[
\min_{h(s)} \quad \left[ h(s_j) - u(s_j) \right]^2 \quad \text{for } j = 1, 2, \ldots, M. \]

This comes down to solving a linear system of equations for the coefficients \(c_i\). If there are more data points than basis functions (\(M > N\)), the system is overdetermined; if there are fewer, the system is underdetermined; and if they are equal we have a dandy square system to solve. In the paper, Legendre polynomials are used, but we will swap them out for kernels because we prefer those.

Solving this system gives us all of the coefficients \(c\), which ensure we’re matching our data! Now, given the coefficients that let us fit the data, we can look for trainable parameters that also enforce the ODE.
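As a rough illustration of the coefficient solve, the sketch below fits a sine curve with a Gaussian kernel basis and `jnp.linalg.lstsq`. The choice of Gaussian kernel, the centers, the lengthscale and the toy data are all assumptions made for this example, not something prescribed by the paper.

```python
import jax.numpy as jnp

def gaussian_basis(s, centers, lengthscale):
    # Phi[j, i] = phi_i(s_j) = exp(-(s_j - center_i)^2 / (2 * lengthscale^2))
    diff = s[:, None] - centers[None, :]
    return jnp.exp(-0.5 * (diff / lengthscale) ** 2)

# Toy data: M = 20 samples of a scalar trajectory u(s).
s_data = jnp.linspace(0.0, 1.0, 20)
u_data = jnp.sin(2.0 * jnp.pi * s_data)

# N = 8 basis functions, so M > N and the system is overdetermined.
centers = jnp.linspace(0.0, 1.0, 8)
lengthscale = 0.15

Phi = gaussian_basis(s_data, centers, lengthscale)   # shape (M, N)
coeffs, *_ = jnp.linalg.lstsq(Phi, u_data)           # least-squares solve for c

h_data = Phi @ coeffs                                # h(s_j) at the data points
max_error = jnp.max(jnp.abs(h_data - u_data))
```

With \(M = N\) the same call solves the square system, and with \(M < N\) it returns a minimum-norm solution, so one routine covers all three cases.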

We can easily express the derivative of our trajectory approximation \(h(s)\) by differentiating the basis:

\[ \frac{d h(s)}{ds} = \sum_{i=1}^N c_i \frac{ d \phi_i(s)}{d s} \]
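Differentiating the basis is something autodiff can do for us. The sketch below assumes a Gaussian kernel basis and placeholder coefficients (both hypothetical choices for illustration), and builds the matrices of \(\phi_i(s_j)\) and \(d\phi_i/ds\) at \(s_j\) with `jax.grad` and `jax.vmap`.

```python
import jax
import jax.numpy as jnp

def phi(s, center, lengthscale):
    # One Gaussian basis function evaluated at a scalar s.
    return jnp.exp(-0.5 * ((s - center) / lengthscale) ** 2)

# d phi / d s via automatic differentiation w.r.t. the first argument.
dphi = jax.grad(phi, argnums=0)

def over_points_and_centers(fn):
    # Inner vmap runs over centers, outer vmap over evaluation points.
    inner = jax.vmap(fn, in_axes=(None, 0, None))
    return jax.vmap(inner, in_axes=(0, None, None))

basis_matrix = over_points_and_centers(phi)     # returns shape (S, N)
dbasis_matrix = over_points_and_centers(dphi)   # returns shape (S, N)

centers = jnp.linspace(0.0, 1.0, 8)
lengthscale = 0.15
coeffs = jnp.ones(8)                  # placeholder; normally from the data fit
s_eval = jnp.linspace(0.0, 1.0, 50)   # any points we like

h = basis_matrix(s_eval, centers, lengthscale) @ coeffs        # h(s)
dh_ds = dbasis_matrix(s_eval, centers, lengthscale) @ coeffs   # dh/ds
```

For a basis with a simple closed-form derivative one could hard-code it instead; autodiff just keeps the sketch agnostic to the kernel.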

What we need to do now is find the best \(\theta\) so that \(f_{\theta}\) matches \(\frac{d h(s)}{ds}\). And we can do this at any \(s\), not just where we have data, because we can evaluate our approximation anywhere! Quadrature nodes are not a bad choice:

\[
\min_\theta \quad \left[ f_{\theta}(h(s_i), s_i) - \frac{d h(s_i)}{ds} \right]^2 \quad \text{for } i = 1, 2, \ldots, M. \]
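Here is a minimal sketch of that second stage in plain JAX. It assumes a tiny hand-written MLP for \(f_\theta\), Gauss-Legendre nodes mapped to \([0, 1]\), plain gradient descent, and a stand-in trajectory \(h(s) = (\sin s, \cos s)\) whose derivative is known in closed form. None of these choices come from the paper, and every name is hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np

def init_mlp(key, sizes):
    # One (W, b) pair per layer, with a simple scaled-normal init.
    keys = jax.random.split(key, len(sizes) - 1)
    return [(jax.random.normal(k, (n_out, n_in)) / jnp.sqrt(n_in), jnp.zeros(n_out))
            for k, n_in, n_out in zip(keys, sizes[:-1], sizes[1:])]

def f_theta(params, u, s):
    # MLP vector field acting on the concatenation [u, s].
    x = jnp.concatenate([u, jnp.atleast_1d(s)])
    for W, b in params[:-1]:
        x = jnp.tanh(W @ x + b)
    W, b = params[-1]
    return W @ x + b

def residual_loss(params, h_vals, dh_vals, s_nodes, weights):
    # Quadrature-weighted squared ODE residual: sum_i w_i |f(h_i, s_i) - h'_i|^2
    pred = jax.vmap(f_theta, in_axes=(None, 0, 0))(params, h_vals, s_nodes)
    return jnp.sum(weights * jnp.sum((pred - dh_vals) ** 2, axis=1))

# Gauss-Legendre nodes and weights, mapped from [-1, 1] to [0, 1].
nodes, w = np.polynomial.legendre.leggauss(32)
s_nodes = jnp.asarray(0.5 * (nodes + 1.0))
weights = jnp.asarray(0.5 * w)

# Stand-in for the fitted trajectory and its derivative at the nodes.
h_vals = jnp.stack([jnp.sin(s_nodes), jnp.cos(s_nodes)], axis=1)
dh_vals = jnp.stack([jnp.cos(s_nodes), -jnp.sin(s_nodes)], axis=1)

params = init_mlp(jax.random.PRNGKey(0), [3, 16, 2])
loss_and_grad = jax.jit(jax.value_and_grad(residual_loss))

initial_loss, _ = loss_and_grad(params, h_vals, dh_vals, s_nodes, weights)
for _ in range(300):
    _, grads = loss_and_grad(params, h_vals, dh_vals, s_nodes, weights)
    params = jax.tree_util.tree_map(lambda p, g: p - 0.02 * g, params, grads)
final_loss = residual_loss(params, h_vals, dh_vals, s_nodes, weights)
```

Note that no ODE solver appears anywhere in the loop: the trajectory is fixed by the basis fit, and the network only has to regress onto its derivative.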

On to code!

We’ll use the libraries Equinox and Diffrax. The former is a pleasant neural network library on top of JAX, and the latter is a numerical differential equation solver library in JAX (written by the Equinox author, so the two work together seamlessly).

We’ll also use the deformed sines/cosines dataset from the Diffrax neural ODE tutorial, because it makes it easy to sanity check our results against how the vanilla NODE does.

```