Neural ODEs with PyTorch Lightning and TorchDyn

Effortless, Scalable Training of Neural Differential Equations

Michael Poli
Towards Data Science

--

Various classes of Neural ODEs.

Traditional neural network models are composed of a finite number of layers. Neural Differential Equations (NDEs), a core model class of the so-called continuous-depth learning framework, challenge this notion by defining forward inference passes as the solution of an initial value problem. This effectively means that NDEs can be thought of as being composed of a continuum of layers, where the vector field itself is parametrized by an arbitrary neural network. Since the seminal work that initially popularized the idea, the framework has grown quite large, seeing applications in control, generative modeling and forecasting.

A Neural Ordinary Differential Equation (Neural ODE) with parameters, and thus vector field, varying in “depth” (s), trained to perform a binary classification task.

NDEs are typically the correct choice whenever the underlying dynamics or model to approximate are known to evolve according to differential equations. Another domain in which continuous-depth models have proven beneficial is density estimation and generative modeling, where properties of the continuous formulation are used directly to reduce the computational cost of normalizing flows.

Here is a self-contained video introduction on the topic.

TorchDyn: A library for neural differential equations

Unlocking the full potential of continuous-depth models requires a specialized software ecosystem. For example, inference of NDEs is carried out via numerical differential equation solvers.

TorchDyn, part of the broader DiffEqML software ecosystem, offers an intuitive access point to model design for continuous-depth learning. The library follows the core design ideals driving the success of modern deep learning frameworks: it is modular, object-oriented, and focused on GPUs and batched operations.

Some of the building blocks of TorchDyn models; see the TorchDyn documentation and tutorials for more information.

The API will be immediately familiar to anyone who has worked with PyTorch:
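As a minimal sketch (assuming the NeuralDE constructor and its sensitivity/solver keyword arguments from the TorchDyn releases of that period), a Neural ODE for 2D data can be defined as:

```python
import torch
import torch.nn as nn
from torchdyn.models import NeuralDE

# Vector field: a plain nn.Module mapping the 2D state to its depth-derivative
f = nn.Sequential(
    nn.Linear(2, 16),
    nn.Tanh(),
    nn.Linear(16, 2)
)

# Wrap the vector field in a NeuralDE, picking the sensitivity algorithm
# (adjoint) and the underlying ODE solver (dopri5, an adaptive-step method)
model = NeuralDE(f, sensitivity='adjoint', solver='dopri5')
```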

NeuralDE is the primary model class and can be interacted with in the usual PyTorch fashion. Internally, DEFunc performs the auxiliary operations required to preserve compatibility across NeuralDE variants, such as handling higher-order dynamics or the additional dimensions needed for integral cost functions distributed across the whole depth of the NDE.

What’s more, TorchDyn models leverage PyTorch-Lightning to allow for effortless and scalable training.

Below, we’ll step through Neural Ordinary Differential Equations (Neural ODEs) training with TorchDyn and PyTorch Lightning. At the end, we’ll dive deeper into recent advances and show how inference can be sped up through Hypersolvers, also trained with PyTorch Lightning.

Defining Neural ODE models

We will start with a Neural ODE for a binary classification problem. In particular, our objective is to separate these two classes of points. The dataset contains pairs (x, y) of 2D points and their labels, indicated below in color.

The data is generated as follows, using TorchDyn's ToyDataset:
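A sketch of the data generation, assuming the ToyDataset.generate interface (with the n_samples, noise and dataset_type arguments) used in the TorchDyn tutorials of the time; the device handling and full-batch dataloader are illustrative choices:

```python
import torch
import torch.utils.data as data
from torchdyn.datasets import ToyDataset

d = ToyDataset()
# 'moons': two interleaved half-circles, which are not linearly separable
X, yn = d.generate(n_samples=512, noise=1e-1, dataset_type='moons')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
X_train = torch.Tensor(X).to(device)
y_train = torch.LongTensor(yn.long()).to(device)

# Full-batch dataloader: the whole dataset is a single batch
train = data.TensorDataset(X_train, y_train)
trainloader = data.DataLoader(train, batch_size=len(X), shuffle=True)
```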

The next step involves defining a standard pl.LightningModule for classification tasks, whose forward pass will be handled by the NeuralDE. A LightningModule is just an nn.Module with added hooks to structure your model: it leaves the core training logic to you and automates the engineering. Note how the NeuralDE object takes as input a neural network f defining the vector field:
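A minimal Learner sketch in the spirit of the TorchDyn tutorials; the loss, learning rate, and dataloader wiring are illustrative choices rather than fixed requirements:

```python
import pytorch_lightning as pl

class Learner(pl.LightningModule):
    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model  # the NeuralDE wrapping the vector field f

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = nn.CrossEntropyLoss()(y_hat, y)
        return {'loss': loss}

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=0.01)

    def train_dataloader(self):
        return trainloader

learner = Learner(model)
```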

We then proceed as usual, creating a Trainer and leveraging any PyTorch-Lightning feature needed for training or logging, such as compatibility with both GPUs and TPUs (including multi-node setups), 16-bit precision training, and early stopping.
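A sketch of the training call; the epoch count and the gpus flag (part of the Lightning Trainer API of the time) are illustrative:

```python
trainer = pl.Trainer(max_epochs=200, gpus=1 if torch.cuda.is_available() else 0)
trainer.fit(learner)
```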

We can visualize the NDE flows, i.e. the evolution of the input data across depth, through the trajectory method. s_span indicates the depth interval, discretized as a mesh of points, over which we are interested in the solution.
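A sketch of the visualization, assuming trajectory returns a tensor of shape [len(s_span), batch_size, dim]; the plotting details are illustrative:

```python
import matplotlib.pyplot as plt

s_span = torch.linspace(0, 1, 100).to(device)            # depth mesh
trajectory = model.trajectory(X_train, s_span).detach().cpu()

# One curve per data point, colored by class, tracing its evolution in depth
colors = ['orange', 'blue']
fig, ax = plt.subplots(figsize=(6, 6))
for i in range(len(X_train)):
    ax.plot(trajectory[:, i, 0], trajectory[:, i, 1],
            color=colors[int(y_train[i])], alpha=0.2)
ax.set_xlabel('dim 0'); ax.set_ylabel('dim 1')
plt.show()
```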

Flows of the trained Neural ODE. The points of each class are pulled apart. The final points are then linearly separable.

Hypersolvers in PyTorch Lightning: Faster Neural Differential Equations

Neural Differential Equation inference is typically slower than that of comparable discrete neural networks, since these continuous models come with the additional overhead of solving a differential equation. Various approaches have been proposed to alleviate these limitations, e.g. regularizing the vector field so that it is easier to solve. In practice, an easier-to-solve NDE translates into a smaller number of function evaluations (with adaptive-step solvers), or in other words, fewer calls to the neural network parametrizing the vector field.

However, regularizing the vector field may not always be an option, particularly in situations where the differential equation is partially known and specified a priori, as is the case, for example, in control applications.

Euler method versus its Hypersolver variant — solving a Neural ODE for MNIST image classification. HyperEuler maintains a lower error across the solution, in turn preserving more accurate class predictions.

The framework of hypersolvers considers instead the interplay between model and solver, analyzing ways in which the solver and the training strategies can be adapted to maximize NDE speedups.

In their simplest form, hypersolvers take a base solver formulation and enhance its performance on an NDE with an additional learning component, trained by approximating local errors of the solver. Other techniques are available, for example, adversarial training to exploit base solver weaknesses and aid in its generalization across dynamics.

We will use a HyperEuler implementation to speed up inference of the Neural ODE from the previous section. As the hypersolver network g, we will use a tiny neural network made up of a single linear layer:
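The original post embeds a specific HyperEuler implementation; since that embed is not reproduced here, the following is a self-contained sketch of the idea. The step rule x_{k+1} = x_k + eps * f(x_k) + eps^2 * g(s, x, f(x)) follows the hypersolver formulation, but the class name, the way g consumes its inputs, and the trajectory rollout are illustrative assumptions rather than TorchDyn's actual API:

```python
class HyperEuler(nn.Module):
    """Euler step augmented with a learned correction g approximating
    the local truncation error:
        x_{k+1} = x_k + eps * f(x_k) + eps**2 * g([s_k, x_k, f(x_k)])
    """
    def __init__(self, f, g):
        super().__init__()
        self.f, self.g = f, g

    def step(self, s, x, eps):
        fx = self.f(x)
        s_in = s * torch.ones(x.shape[0], 1, device=x.device)
        return x + eps * fx + eps**2 * self.g(torch.cat([s_in, x, fx], dim=-1))

    def trajectory(self, x, s_span):
        # Roll the hypersolver forward over a fixed depth mesh
        eps = s_span[1] - s_span[0]
        traj = [x]
        for s in s_span[:-1]:
            x = self.step(s, x, eps)
            traj.append(x)
        return torch.stack(traj)

# Tiny hypersolver network: one linear layer over [s, x, f(x)] (12 parameters)
g = nn.Linear(5, 2).to(device)
hypersolver = HyperEuler(f, g)
```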

Do not be alarmed by the high number of epochs: these are cheap training iterations, since we are doing full-batch training.
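In the post, the hypersolver is also trained through PyTorch Lightning; for brevity, here is a hedged sketch using a plain optimization loop instead. g is fitted by regressing the residual left over by a plain Euler step against a reference trajectory obtained from the already-trained Neural ODE with its adaptive-step solver; the mesh size, learning rate, and number of iterations are illustrative:

```python
s_span = torch.linspace(0, 1, 20).to(device)
eps = s_span[1] - s_span[0]
with torch.no_grad():
    # Reference solution on the same depth mesh, from the adaptive-step solver
    x_ref = model.trajectory(X_train, s_span)    # [len(s_span), batch, 2]

opt = torch.optim.Adam(g.parameters(), lr=1e-3)
for epoch in range(1000):                        # cheap full-batch iterations
    loss = 0.
    for k in range(len(s_span) - 1):
        xk = x_ref[k]
        fx = f(xk).detach()
        # Local error of a plain Euler step, rescaled by eps**2: the target for g
        target = (x_ref[k + 1] - xk - eps * fx) / eps**2
        s_in = s_span[k] * torch.ones(xk.shape[0], 1, device=xk.device)
        pred = g(torch.cat([s_in, xk, fx], dim=-1))
        loss = loss + torch.mean((pred - target)**2)
    opt.zero_grad(); loss.backward(); opt.step()
```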

After training, we can visualize the flows of the Neural ODE solved with our hypersolver variant of the Euler method, HyperEuler, and verify whether they match those obtained with the adaptive-step method:
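Continuing the sketch above, the hypersolver flows can be rolled out and plotted with the same code used earlier for the adaptive-step solution:

```python
s_span = torch.linspace(0, 1, 20).to(device)
hyper_traj = hypersolver.trajectory(X_train, s_span).detach().cpu()
# Plot hyper_traj exactly as before, reusing the same per-class coloring
```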

Flows of the trained Neural ODE, solved with HyperEuler. The flows are almost identical to the solutions shown above, but are computed with 50% of the function evaluations.

At the cost of only 12 additional parameters, we are able to successfully reduce the number of function evaluations to less than 20, starting from 40+. Not bad!

Conclusion

We’ve shown how to train Neural ODEs through TorchDyn and PyTorch-Lightning, including how to speed them up with hypersolvers. Much more is possible in the continuous-depth framework, we suggest the following set of tutorials for those interested in a deeper dive.

The DiffEqML continuous-depth ecosystem is expanding rapidly, and TorchDyn itself is currently close to a new release including some of the latest advances from NeurIPS 2020. Our goal with DiffEqML is to take these models to the next level of practical relevance by allowing for effortless and scalable training of NDEs and hybrids of traditional neural networks and NDEs across GPUs and GPU clusters. We will continue to base our efforts on the solid foundation provided by PyTorch-Lightning, and we are excited to strengthen our collaboration with even better integration.

For those interested in more content related to NDEs, follow us on Twitter for updates. Feel free to check out our poster sessions at the upcoming NeurIPS 2020 to ask us more questions.

Michael Poli and Stefano Massaroli, DiffEqML
