Neural ODEs with PyTorch Lightning and TorchDyn
Effortless, Scalable Training of Neural Differential Equations
Traditional neural network models are composed of a finite number of layers. Neural Differential Equations (NDEs), a core model class of the so-called continuous-depth learning framework, challenge this notion by defining forward inference passes as the solution of an initial value problem. This effectively means that NDEs can be thought of as consisting of a continuum of layers, where the vector field itself is parametrized by an arbitrary neural network. Since the seminal work that popularized the idea, the framework has grown considerably, with applications in control, generative modeling, and forecasting.
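In symbols (the notation is ours): let s be the depth variable, z the hidden state and f_θ the neural network. The forward pass of a Neural ODE then solves the initial value problem below. In the simplest case the final state is read off as the output. A single fixed-step Euler step of size ε already shows the link to residual networks.

```latex
\begin{aligned}
\dot{z}(s) &= f_\theta\big(s, z(s)\big), \qquad s \in [0, S],\\
z(0) &= x, \qquad \hat{y} = z(S),\\
z_{k+1} &= z_k + \epsilon\, f_\theta(s_k, z_k), \qquad s_{k+1} = s_k + \epsilon .
\end{aligned}
```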
NDEs are typically the right choice whenever the underlying dynamics or model to approximate are known to evolve according to differential equations. Another domain in which continuous-depth models have proven beneficial is density estimation and generative modeling, where properties of the continuous formulation are used directly to reduce the computational cost of normalizing flows.
TorchDyn: A library for neural differential equations
Unlocking the full potential of continuous-depth models requires a specialized software ecosystem. For example, inference of NDEs is carried out via numerical differential equation solvers.
TorchDyn, part of the broader DiffEqML software ecosystem, offers an intuitive access point to model design for continuous-depth learning. The library follows the core design ideals driving the success of modern deep learning frameworks: it is modular, object-oriented, and focused on GPUs and batched operations.
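NDE inference is a numerical solve. A minimal picture of it is a fixed-step solver applied to a whole batch at once. The NumPy sketch below is our own illustration, not TorchDyn code. It integrates a small tanh MLP vector field with the classic fourth-order Runge-Kutta method, vectorized over the batch dimension. The weights are random and untrained, and the layer sizes are assumptions.

```python
import numpy as np

def mlp_vector_field(params):
    """Return f(s, z) for a one-hidden-layer tanh MLP (illustrative weights)."""
    W1, b1, W2, b2 = params
    def f(s, z):
        # z has shape (batch, dim); this field is autonomous, so s is unused
        return np.tanh(z @ W1 + b1) @ W2 + b2
    return f

def odeint_rk4(f, z0, s0=0.0, s1=1.0, steps=20):
    """Integrate dz/ds = f(s, z) from s0 to s1 with fixed-step classic RK4.

    The whole batch is advanced at once: z0 has shape (batch, dim)."""
    h = (s1 - s0) / steps
    z, s = z0.astype(float), s0
    for _ in range(steps):
        k1 = f(s, z)
        k2 = f(s + h / 2, z + h / 2 * k1)
        k3 = f(s + h / 2, z + h / 2 * k2)
        k4 = f(s + h, z + h * k3)
        z = z + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        s = s + h
    return z

rng = np.random.default_rng(0)
params = (rng.normal(size=(2, 16)), np.zeros(16), 0.1 * rng.normal(size=(16, 2)), np.zeros(2))
x = rng.normal(size=(128, 2))            # a batch of 128 two-dimensional inputs
y_hat = odeint_rk4(mlp_vector_field(params), x)
print(y_hat.shape)  # (128, 2)
```

A fixed step count keeps the cost predictable. Adaptive solvers instead choose the step size from an error estimate.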
The API will be immediately familiar to anyone who has worked with PyTorch: NeuralDEs represent the primary model class and can be interacted with in the usual PyTorch fashion. Internally, DEFunc objects perform the auxiliary operations required to preserve compatibility across NeuralDE variants. Examples include handling higher-order dynamics, or handling the additional dimensions needed for integral cost functions distributed over the whole depth of the NDE.
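The two auxiliary operations mentioned above can be illustrated without TorchDyn. The helpers below are hypothetical: `second_order_field`, `with_integral_cost` and a plain `euler` integrator are our names, not DEFunc's API. First, a second-order system z'' = a(z, z') is rewritten as a first-order one on the stacked state [z, z']. Second, an integral cost is tracked by appending one extra state dimension whose derivative is the running cost.

```python
import numpy as np

def second_order_field(accel):
    """Rewrite z'' = accel(z, v) as a first-order system on the state [z, v]."""
    def f(s, state):
        d = state.shape[1] // 2
        z, v = state[:, :d], state[:, d:]
        return np.concatenate([v, accel(z, v)], axis=1)
    return f

def with_integral_cost(f, cost):
    """Append one state dimension c with dc/ds = cost(s, z); after integration,
    c holds the integral of the cost over the whole depth interval."""
    def g(s, state):
        z = state[:, :-1]                       # original state, accumulator dropped
        return np.concatenate([f(s, z), cost(s, z)[:, None]], axis=1)
    return g

def euler(f, z0, s0, s1, steps):
    """Plain fixed-step Euler integrator used for the demonstration."""
    h = (s1 - s0) / steps
    z, s = z0.astype(float), s0
    for _ in range(steps):
        z, s = z + h * f(s, z), s + h
    return z

field = lambda s, z: -z                                   # simple first-order field
energy = lambda s, z: np.sum(field(s, z) ** 2, axis=1)    # running cost: squared speed
x0 = np.array([[1.0, 0.0]])
aug0 = np.concatenate([x0, np.zeros((1, 1))], axis=1)     # accumulator starts at 0
out = euler(with_integral_cost(field, energy), aug0, 0.0, 1.0, 1000)
print(out[0, -1])   # Euler approximation of the exact integral (1 - e^-2) / 2
```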
What’s more, TorchDyn models leverage PyTorch Lightning to allow for effortless and scalable training.
Below, we’ll step through Neural Ordinary Differential Equations (Neural ODEs) training with TorchDyn and PyTorch Lightning. At the end, we’ll dive deeper into recent advances and show how inference can be sped up through Hypersolvers, also trained with PyTorch Lightning.
Defining Neural ODE models
We will start with a Neural ODE for a binary classification problem: our objective is to separate two classes of points. The dataset contains pairs (x, y) of 2D points x and their labels y, indicated below in color. It is generated using the ToyDataset class of TorchDyn.
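The ToyDataset call itself is not reproduced here. As a stand-in, the NumPy function below generates two interleaving half-circles of 2D points with integer labels 0 and 1. `make_moons` is a hypothetical helper, and the two-moons layout, sample count and noise level are assumptions.

```python
import numpy as np

def make_moons(n_samples=512, noise=0.1, seed=0):
    """Two interleaving half-circles labelled 0 and 1 (stand-in for ToyDataset)."""
    rng = np.random.default_rng(seed)
    n0 = n_samples // 2
    n1 = n_samples - n0
    t0 = rng.uniform(0, np.pi, n0)
    t1 = rng.uniform(0, np.pi, n1)
    upper = np.stack([np.cos(t0), np.sin(t0)], axis=1)            # class 0
    lower = np.stack([1 - np.cos(t1), 0.5 - np.sin(t1)], axis=1)  # class 1
    X = np.concatenate([upper, lower]) + noise * rng.normal(size=(n_samples, 2))
    y = np.concatenate([np.zeros(n0, dtype=int), np.ones(n1, dtype=int)])
    perm = rng.permutation(n_samples)                             # shuffle the pairs
    return X[perm], y[perm]

X, y = make_moons(512, noise=0.1)
print(X.shape, y.shape, np.bincount(y))  # (512, 2) (512,) [256 256]
```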
The next step is to define a standard pl.LightningModule for classification tasks, whose forward pass is handled by a NeuralDE. A LightningModule is just an nn.Module with added hooks to structure your model: it leaves the core training logic to you and automates the engineering. Note how the NeuralDE object takes as input a neural network f defining the vector field.
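The model code is not reproduced here, so the sketch below mirrors its structure in plain NumPy, without Lightning or TorchDyn. It has three parts:

- a 2 → 16 → 2 tanh vector field (the layer sizes are our assumption);
- a fixed-step Euler forward pass standing in for NeuralDE's solver;
- the cross-entropy loss that a training_step would return, with the final 2D state used as class logits.

Gradients and optimizer updates are omitted. In the real setup, PyTorch autograd and Lightning's Trainer handle them.

```python
import numpy as np

def vector_field(params, z):
    """f(z) = tanh(z W1 + b1) W2 + b2: a 2 -> 16 -> 2 MLP (illustrative sizes)."""
    W1, b1, W2, b2 = params
    return np.tanh(z @ W1 + b1) @ W2 + b2

def neural_ode_forward(params, x, steps=10):
    """Forward pass: integrate dz/ds = f(z) over s in [0, 1] with Euler steps."""
    z, h = x.astype(float), 1.0 / steps
    for _ in range(steps):
        z = z + h * vector_field(params, z)
    return z

def cross_entropy(logits, labels):
    """Mean cross-entropy using a numerically stable log-softmax."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()

def training_step(params, batch):
    """The loss a LightningModule.training_step would compute for one batch."""
    x, y = batch
    y_hat = neural_ode_forward(params, x)   # final ODE state used as logits
    return cross_entropy(y_hat, y)

rng = np.random.default_rng(0)
params = (rng.normal(size=(2, 16)), np.zeros(16), 0.1 * rng.normal(size=(16, 2)), np.zeros(2))
x = rng.normal(size=(64, 2))
y = (x[:, 0] > 0).astype(int)               # toy labels for the demonstration
print(training_step(params, (x, y)))
```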
We then proceed as usual: we create a Trainer and use any PyTorch Lightning feature needed for training or logging. Lightning offers a long list of such features, including GPU and TPU support (also multi-node), 16-bit precision training, and early stopping.
We can visualize the NDE flows, i.e. the evolution of the input data across depth, through the trajectory method. Its s_span argument indicates the depth interval we are interested in, given as a mesh of points.
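The idea behind trajectory can be sketched as follows. This assumes a plain Euler integrator, and the helper name is ours. The solver steps between consecutive points of the s_span mesh, and the state at every mesh point is stacked into an array of shape (len(s_span), batch, dim), ready for plotting the flows.

```python
import numpy as np

def trajectory(f, x, s_span, substeps=10):
    """Return the states at every depth in s_span, shape (len(s_span), batch, dim).

    Between consecutive mesh points, `substeps` Euler steps are taken."""
    states = [x.astype(float)]
    z = states[0]
    for s0, s1 in zip(s_span[:-1], s_span[1:]):
        h = (s1 - s0) / substeps
        for k in range(substeps):
            z = z + h * f(s0 + k * h, z)
        states.append(z)
    return np.stack(states)

rotation = lambda s, z: z @ np.array([[0.0, 1.0], [-1.0, 0.0]])  # rotates points in the plane
s_span = np.linspace(0, 1, 100)                                 # 100 depth points in [0, 1]
traj = trajectory(rotation, np.random.default_rng(0).normal(size=(256, 2)), s_span)
print(traj.shape)  # (100, 256, 2)
```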
Hypersolvers in PyTorch Lightning: Faster Neural Differential Equations
Inference with Neural Differential Equations is typically slower than with comparable discrete neural networks, since these continuous models carry the additional overhead of solving a differential equation. Various approaches have been proposed to alleviate this, e.g., regularizing the vector field so that it is easier to solve. In practice, with adaptive-step solvers, an easier-to-solve NDE needs fewer function evaluations. In other words, it makes fewer calls to the neural network parametrizing the vector field.
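A textbook adaptive scheme shows the link between solving difficulty and cost. The sketch below uses an embedded Heun/Euler pair, a simple stand-in we chose rather than the solver used in the article. It accepts or rejects each step based on a local error estimate and counts every call to the vector field. A stiffer field forces smaller steps and therefore more function evaluations (NFE).

```python
import numpy as np

def adaptive_heun_euler(f, z0, s0=0.0, s1=1.0, rtol=1e-4, atol=1e-6, h=0.1):
    """Embedded Heun (2nd order) / Euler (1st order) pair with step-size control.

    Returns the final state and the number of function evaluations (NFE)."""
    z, s, nfe = np.asarray(z0, dtype=float), s0, 0
    while s < s1:
        h = min(h, s1 - s)
        k1 = f(s, z)
        k2 = f(s + h, z + h * k1)
        nfe += 2
        z_high = z + h / 2 * (k1 + k2)          # Heun estimate
        z_low = z + h * k1                      # Euler estimate
        scale = atol + rtol * np.maximum(np.abs(z), np.abs(z_high))
        err = np.sqrt(np.mean(((z_high - z_low) / scale) ** 2))
        if err <= 1.0:                          # accept the step
            z = z_high
            s = s1 if h >= s1 - s else s + h
        # the error estimate scales like h^2, hence the exponent 1/2
        h *= min(5.0, max(0.2, 0.9 * (1.0 / max(err, 1e-12)) ** 0.5))
    return z, nfe

z_easy, nfe_easy = adaptive_heun_euler(lambda s, z: -z, np.ones(2))
z_hard, nfe_hard = adaptive_heun_euler(lambda s, z: -50.0 * z, np.ones(2))
print(nfe_easy, nfe_hard)   # the stiffer field needs many more evaluations
```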
However, regularizing the vector field may not always be an option, particularly when the differential equation is partially known and specified a priori, as is the case, for example, in control applications.
The framework of hypersolvers considers instead the interplay between model and solver, analyzing ways in which the solver and the training strategies can be adapted to maximize NDE speedups.
In their simplest form, hypersolvers take a base solver formulation and enhance its performance on an NDE with an additional learning component, trained by approximating the local errors of the solver. Other techniques are available, for example adversarial training, which exploits base solver weaknesses and aids generalization across dynamics.
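Here is the hypersolver for a first-order base solver such as Euler, written in our own notation. It adds a learned correction g_ω, scaled by ε², to each step. The correction is fitted to the residual R_k that the base step misses. The exact step z(s_{k+1}), starting from z_k, is computed by an accurate reference solver.

```latex
\begin{aligned}
z_{k+1} &= z_k + \epsilon\, f_\theta(s_k, z_k) + \epsilon^{2}\, g_\omega(\epsilon, s_k, z_k),\\
R_k &= \frac{z(s_{k+1}) - z_k - \epsilon\, f_\theta(s_k, z_k)}{\epsilon^{2}},\\
\ell(\omega) &= \frac{1}{K}\sum_{k=0}^{K-1} \big\lVert R_k - g_\omega(\epsilon, s_k, z_k) \big\rVert_2^2 .
\end{aligned}
```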
We will use a HyperEuler hypersolver to speed up inference of the Neural ODE from the previous section. As the hypersolver network g, we use a tiny neural network made up of a single linear layer.
Do not be alarmed by the high number of epochs: since we are doing full-batch training, each epoch is a single, cheap training iteration.
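Below is a self-contained NumPy sketch of this procedure. Everything in it is an assumption made for illustration:

- The fixed field f, with hypothetical weights A, stands in for the trained Neural ODE.
- The linear g takes the state and f(z) as inputs. That gives 10 parameters here. The inputs of the article's 12-parameter layer are not shown.
- The reference steps come from many small RK4 steps.

Training is full-batch gradient descent on the squared residual error. The script then compares the one-step errors of plain Euler and HyperEuler on held-out points.

```python
import numpy as np

A = np.array([[-0.5, -2.0], [2.0, -0.5]])   # hypothetical weights of the "trained" field

def f(z):
    """Fixed vector field standing in for the trained Neural ODE's network."""
    return np.tanh(z @ A)

def rk4_reference(z, eps, substeps=100):
    """Accurate reference for z(s + eps), using many small RK4 steps."""
    h = eps / substeps
    for _ in range(substeps):
        k1 = f(z)
        k2 = f(z + h / 2 * k1)
        k3 = f(z + h / 2 * k2)
        k4 = f(z + h * k3)
        z = z + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
    return z

def features(z):
    """Inputs of the hypersolver network g: the state and the vector field value."""
    return np.concatenate([z, f(z)], axis=1)

def hyper_euler_step(z, eps, W, b):
    """One HyperEuler step: Euler update plus eps^2 times a learned linear correction."""
    return z + eps * f(z) + eps ** 2 * (features(z) @ W + b)

eps = 0.1
rng = np.random.default_rng(0)
z_train = 0.5 * rng.normal(size=(1024, 2))
# Residual target: the part of the exact step that Euler misses, scaled by 1 / eps^2.
R = (rk4_reference(z_train, eps) - z_train - eps * f(z_train)) / eps ** 2

W, b = np.zeros((4, 2)), np.zeros(2)        # 4 x 2 weights + 2 biases = 10 parameters
Phi, lr = features(z_train), 0.1
for epoch in range(3000):                    # full-batch gradient descent on the MSE
    err = Phi @ W + b - R
    W -= lr * 2 * Phi.T @ err / len(R)
    b -= lr * 2 * err.mean(axis=0)

z_test = 0.5 * rng.normal(size=(256, 2))
exact = rk4_reference(z_test, eps)
euler_err = np.mean(np.sum((z_test + eps * f(z_test) - exact) ** 2, axis=1))
hyper_err = np.mean(np.sum((hyper_euler_step(z_test, eps, W, b) - exact) ** 2, axis=1))
print(f"one-step MSE  Euler: {euler_err:.2e}  HyperEuler: {hyper_err:.2e}")
```

Because g is linear here, the training problem is a convex least-squares fit. With a nonlinear g, the same residual loss would instead be minimized with autograd and an optimizer.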
After training, we can visualize the flows of the Neural ODE solved with HyperEuler, our hypersolver variant of the Euler method, and verify whether they match those obtained with the adaptive-step method.
At the cost of only 12 additional parameters, we reduce the number of function evaluations from over 40 to fewer than 20. Not bad!
Conclusion
We’ve shown how to train Neural ODEs with TorchDyn and PyTorch Lightning, including how to speed them up with hypersolvers. Much more is possible in the continuous-depth framework. For a deeper dive, we suggest the TorchDyn tutorials.
The DiffEqML continuous-depth ecosystem is expanding rapidly, and TorchDyn itself is close to a new release that includes some of the latest advances from NeurIPS 2020. Our goal with DiffEqML is to take these models to the next level of practical relevance. We want to allow effortless and scalable training, across GPUs and GPU clusters, of both NDEs and hybrids of traditional neural networks and NDEs. We will continue to base our efforts on the solid foundation provided by PyTorch Lightning, and we are excited to strengthen our collaboration through even better integration.
For more content related to NDEs, follow us on Twitter for updates. Feel free to visit our poster sessions at the upcoming NeurIPS 2020 to ask us more questions.
Michael Poli and Stefano Massaroli, DiffEqML