Boost any Machine Learning model with ONNX conversion

A simple understanding on how to convert a model to ONNX

Thomas Chaigneau
Towards Data Science

The Open Neural Network Exchange (ONNX) is an open-source ecosystem that aims to standardize and optimize artificial intelligence models across a variety of platforms.

It is a machine-readable format that can be used to exchange information between different software applications and frameworks (e.g. TensorFlow, PyTorch, etc.).

Historically, the ONNX format was named Toffee and was developed by the PyTorch team at Facebook. The framework was released at the end of 2017 and co-authored by Microsoft and Facebook.

Since then, the ONNX format has been supported by several other companies, including Intel, AMD, and IBM.

I have personally been working with ONNX Runtime (`onnxruntime`) for months now, and I am very happy that it is supported by such a large number of frameworks and projects.

Interoperability is a key feature of ONNX and one of the tool’s best assets, which makes it well worth including in any ML project.

“ONNX Runtime is a cross-platform inference and training machine-learning accelerator” — ONNX Runtime

In this post, I will show you the steps you need to follow and understand in order to convert your model to ONNX. I will also list all available libraries that can be used to convert your model to ONNX depending on the framework you are using.

Furthermore, I will include an example of converting a PyTorch LSTM model to ONNX.

👣 Conversion steps

Each framework has its own way of converting models to ONNX, but there are some common steps that you need to follow.

Let’s look at these steps.

- 🏋️ Training your model

It seems pretty obvious that you need a trained model to convert it to ONNX, but you also need to make sure that both the model architecture and the model weights are available.

If you only save the model weights, you will not be able to convert the model to ONNX, because the model architecture is required for the conversion.

With the model architecture, the converter can trace the different layers of your model and turn them into a graph (also called an intermediate representation).

Model weights are the parameters of the different layers, which are used to compute the output of the model, so they are just as important for a successful conversion.
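To make the architecture point concrete, here is a minimal PyTorch sketch, assuming a hypothetical toy model `TinyNet`. A `state_dict` holds only the weights, so the class definition is still needed to rebuild a model that can be exported.

```python
import io

import torch
from torch import nn


# Hypothetical toy model, used only to illustrate the point.
class TinyNet(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(4, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


trained = TinyNet()  # pretend this model has been trained

# Saving only the state_dict stores the weights, not the architecture.
buffer = io.BytesIO()
torch.save(trained.state_dict(), buffer)
buffer.seek(0)

# To get an exportable model back, the class definition is required:
# instantiate the architecture first, then load the weights into it.
restored = TinyNet()
restored.load_state_dict(torch.load(buffer, weights_only=True))
restored.eval()  # switch to inference mode before exporting

sample = torch.randn(2, 4)
print(torch.allclose(trained(sample), restored(sample)))  # True
```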

- 📝 Input names and output names

You will need to define the input names and the output names of your model. This metadata is used to describe the inputs and outputs of your model.

- 🧊 Input sample

As said before, ONNX will trace the different layers of your model in order to create a graph of these layers.

While tracing the layers, ONNX will also need an input sample to understand how the model is working and what operators are used to compute the outputs.

The selected sample will be the input of the first layer of the model and is also used to define the input shape of the model.

- 🤸‍♀️ Dynamic axes

Next, ONNX needs to know the dynamic axes of your model. Most of the time, you will use a batch size of 1 during the conversion.

But if you want the model to accept a batch of N samples, you need to declare the batch axis as dynamic.

For example, a model exported with an input sample of shape [1, 1, 224, 224] and a dynamic first axis will accept inputs of shape [batch_size, 1, 224, 224], where `batch_size` can be any value.

- 🔄 Conversion evaluation

Finally, you need to evaluate the converted model to make sure that it is a valid ONNX model and that it works as expected. There are two separate steps to evaluate the converted model.

The first step is to use the ONNX API to check the model’s validity, by calling the onnx.checker.check_model function. This verifies the model’s structure and confirms whether the model follows a valid ONNX schema.

Each node in the model is checked, including its inputs and outputs.

The second step is to compare the output of the converted model with the output of the original model, using the numpy.testing.assert_allclose function.

This function compares the two outputs element-wise and raises an error if they differ by more than the tolerance set by the rtol and atol parameters, i.e. whenever |actual - desired| > atol + rtol * |desired|.

It’s common to use an rtol of 1e-03 and an atol of 1e-05 for the comparison, where rtol is the relative tolerance and atol is the absolute tolerance.

🔧 Tools to convert your model to ONNX

For each framework, there are different tools to convert your model to ONNX. We will list the tools that are available for each framework.

The list can change over time, so please let me know in the comments if you find a tool that is not listed.

Now that you know how conversion works and you have all the tools to convert your model to ONNX, you are ready to experiment and see if it helps you in your ML projects. I bet you will find it useful!

You can check the full PyTorch example I provide in the next section or directly go to the conclusion of this post.

🧪 Full PyTorch example

I recently needed to convert a PyTorch model to ONNX. The model is a simple LSTM forecasting model that predicts the next value of an input sequence.

The model comes from a larger open-source project I worked on that aims to make training and serving forecasting models automatic. That project is called Make Us Rich.

It’s important to note that I use PyTorch Lightning to train the model. This framework is a wrapper around PyTorch and allows you to boost your training process.

If you want to know more about PyTorch Lightning, you can check the PyTorch Lightning documentation.

First, download the example data that we will use to train and convert the model. You can get it with this simple command:

$ wget https://raw.githubusercontent.com/ChainYo/make-us-rich/master/example.csv

This file contains the hourly value of BTC (Bitcoin) over 365 days. We will use this data to create the sequences used to train our LSTM model.

To train an LSTM model, the input data must be converted into sequences. These are the preprocessing functions we will use to build the sequences used by the model:

All preprocessing functions needed to create sequences
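The original functions are not reproduced here. As a rough sketch of what such preprocessing can look like, here is a hypothetical sliding-window helper, `create_sequences`, that turns a 1-D series into (window, next value) pairs. It assumes the target is the value right after each window and does no scaling, which the author’s functions may handle differently.

```python
import numpy as np


def create_sequences(values: np.ndarray, seq_len: int) -> list[tuple[np.ndarray, float]]:
    """Hypothetical helper: slide a window of `seq_len` steps over a 1-D
    series; each window is an input and the value right after it is the target."""
    sequences = []
    for start in range(len(values) - seq_len):
        window = values[start : start + seq_len]
        target = values[start + seq_len]
        sequences.append((window, float(target)))
    return sequences


series = np.arange(10, dtype=np.float32)  # stand-in for hourly prices
seqs = create_sequences(series, seq_len=3)
print(len(seqs))  # 7
print(seqs[0])    # (array([0., 1., 2.], dtype=float32), 3.0)
```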

We will use these functions to get our train_sequences, val_sequences and test_sequences.

Creating required sequences from data
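A hypothetical sketch of the split step, assuming a chronological split with illustrative 80/10/10 proportions. Keeping the order (no shuffling) ensures that validation and test data come after the training data in time, which matters for forecasting.

```python
def split_sequences(sequences, val_ratio=0.1, test_ratio=0.1):
    """Hypothetical helper: chronological split, no shuffling, so the
    validation and test sets always come after the training data in time."""
    n = len(sequences)
    n_test = int(n * test_ratio)
    n_val = int(n * val_ratio)
    train_end = n - n_val - n_test
    return (
        sequences[:train_end],
        sequences[train_end : train_end + n_val],
        sequences[train_end + n_val :],
    )


train_sequences, val_sequences, test_sequences = split_sequences(list(range(100)))
print(len(train_sequences), len(val_sequences), len(test_sequences))  # 80 10 10
```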

Now that our sequences are ready to be used by the model, let’s look in detail at the model architecture and the Dataset and DataLoader classes.

Dataset and DataLoader classes to handle data used by the model
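The author’s classes are not reproduced here. As a minimal sketch, assuming plain PyTorch (no Lightning `LightningDataModule`) and (window, target) pairs like those a sliding-window step produces, a dataset could look like this. `SequenceDataset` is a hypothetical name.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


class SequenceDataset(Dataset):
    """Hypothetical dataset over (window, target) pairs."""

    def __init__(self, sequences):
        self.sequences = sequences

    def __len__(self) -> int:
        return len(self.sequences)

    def __getitem__(self, idx):
        window, target = self.sequences[idx]
        # An LSTM with batch_first=True expects (seq_len, n_features) per sample.
        x = torch.tensor(window, dtype=torch.float32).unsqueeze(-1)
        y = torch.tensor([target], dtype=torch.float32)
        return x, y


series = np.arange(20, dtype=np.float32)
pairs = [(series[i : i + 5], float(series[i + 5])) for i in range(len(series) - 5)]
loader = DataLoader(SequenceDataset(pairs), batch_size=4, shuffle=False)

x_batch, y_batch = next(iter(loader))
print(x_batch.shape, y_batch.shape)  # torch.Size([4, 5, 1]) torch.Size([4, 1])
```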

These two classes are used to load the data that will be used to train the model. I don’t want to describe all the details about loading data in PyTorch and PyTorch Lightning because it’s not the focus of this tutorial.

If you need more explanations about Dataset and DataLoader classes, here is a great article about it.

Here is the LSTM model architecture we are going to use:

Model architecture for the example
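As a stand-in for the architecture, here is a hypothetical LSTM forecaster: stacked LSTM layers followed by a linear head applied to the last time step. The layer sizes are illustrative assumptions, not the author’s values.

```python
import torch
from torch import nn


class LSTMRegressor(nn.Module):
    """Hypothetical one-step-ahead forecaster; sizes are illustrative."""

    def __init__(self, n_features: int = 1, hidden_size: int = 32, num_layers: int = 2, dropout: float = 0.2) -> None:
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=n_features,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout,  # applied between stacked LSTM layers only
            batch_first=True,  # inputs are (batch, seq_len, n_features)
        )
        self.head = nn.Linear(hidden_size, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        output, _ = self.lstm(x)            # (batch, seq_len, hidden_size)
        return self.head(output[:, -1, :])  # last time step -> (batch, 1)


model = LSTMRegressor()
print(model(torch.randn(4, 24, 1)).shape)  # torch.Size([4, 1])
```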

Now that we have defined the model architecture and the way data is loaded, we can start training the model.

Here is the training loop I used:

Training loop created with PyTorch Lightning
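The original loop relies on PyTorch Lightning. As a substitute that runs without Lightning, here is a plain PyTorch training loop over synthetic data (a noisy sine wave standing in for the BTC prices). The `TinyLSTM` model, the hyperparameters, and the `model.pt` checkpoint name are all assumptions for illustration.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


class TinyLSTM(nn.Module):
    # Hypothetical minimal forecaster so this block runs on its own.
    def __init__(self, hidden_size: int = 16) -> None:
        super().__init__()
        self.lstm = nn.LSTM(1, hidden_size, batch_first=True)
        self.head = nn.Linear(hidden_size, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out, _ = self.lstm(x)
        return self.head(out[:, -1, :])


torch.manual_seed(0)
# Synthetic data: predict the next point of a noisy sine wave.
t = torch.linspace(0, 20, 500)
series = torch.sin(t) + 0.05 * torch.randn_like(t)
seq_len = 24
x = torch.stack([series[i : i + seq_len] for i in range(len(series) - seq_len)]).unsqueeze(-1)
y = series[seq_len:].unsqueeze(-1)
loader = DataLoader(TensorDataset(x, y), batch_size=32, shuffle=True)

model = TinyLSTM()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
loss_fn = nn.MSELoss()
history = []

for epoch in range(5):
    model.train()
    total = 0.0
    for xb, yb in loader:
        optimizer.zero_grad()
        loss = loss_fn(model(xb), yb)
        loss.backward()
        optimizer.step()
        total += loss.item() * len(xb)
    history.append(total / len(loader.dataset))
    print(f"epoch {epoch}: train MSE {history[-1]:.4f}")

torch.save(model.state_dict(), "model.pt")  # checkpoint to convert later
```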

After training, we now have a model checkpoint that we want to convert to ONNX. We will load the model from the checkpoint file, get a data sample loaded via the DataLoader and convert the model to ONNX.

Here is the code for conversion:

You should now have a file named model.onnx in your current directory.

The last step is to evaluate the converted model to make sure that it is a valid ONNX model and that it works as expected.

If everything went well, you should see the following output:

 🎉 ONNX model is valid. 🎉

You have successfully converted your model to ONNX. You can now use the converted model in any framework that supports ONNX and on any machine.

Conclusion

In this post, we have covered the basics of converting your model to ONNX and evaluating the converted model. We have also walked through an example with a PyTorch LSTM forecasting model.

I hope this post helped you understand the process of converting a model to ONNX. I also hope that ONNX will boost your ML projects as much as it has boosted mine.

The ONNX ecosystem also includes many other features that are not covered in this post, such as optimizing training and inference for your model and your hardware. This could be the topic of a future post; let me know if it interests you.

If you have any questions or face any issues, please feel free to contact me. I will be happy to help you.
