
How to Migrate Your Python Machine Learning Model to Other Languages

A 3-step tutorial on how to migrate your ML model to Java, Go, C++ or any other language. It's easier than you may think.

Photo by Bit Cloud on Unsplash

I recently worked on a project where I needed to train a Machine Learning model that would run on the Edge, meaning that the processing and prediction occur on the device that collects the data.

As usual, I did the Machine Learning part in Python, and I hadn’t given much thought to how we would port it to the edge device, whose software was written in Java.

When the modeling part was nearing the end, I started researching how to load a LightGBM model in Java. Prior to this, a colleague had recommended that I retrain the model with XGBoost, which can be loaded in Java through the XGBoost4J dependency.

LightGBM and XGBoost are both gradient boosting libraries with a few differences. I would expect to get a similar model if I decided to retrain the model with XGBoost, but I didn’t want to rerun all the experiments as there had to be a better way.

Luckily, I found a simple way to take a Machine Learning model trained in Python and load it in many other languages.

By reading this article, you’ll learn:

  • What is PMML?
  • How to save a Python model to PMML format?
  • How to load the PMML model in Java?
  • How to make predictions with the PMML model in Java?

Meet PMML

Photo by Dom Fou on Unsplash

From Wikipedia:

The Predictive Model Markup Language (PMML) is an XML-based predictive model interchange format conceived by Dr. Robert Lee Grossman, then the director of the National Center for Data Mining at the University of Illinois at Chicago. PMML provides a way for analytic applications to describe and exchange predictive models produced by data mining and machine learning algorithms.

PMML supports:

  • Neural Networks
  • Support Vector Machines
  • Association rules
  • Naive Bayes classifier
  • Clustering models
  • Text models
  • Decision trees (Random forest)
  • Gradient Boosting (LightGBM and XGBoost)
  • Regression models

PMML lets us take a Machine Learning model trained in Python and load it in Java, Go, C++, Ruby and other languages.

1. Using PMML

Photo by JESHOOTS.COM on Unsplash

My first thought after learning about PMML was that I would need to radically refactor the code, which would make retraining the model with XGBoost the more practical option.

After thinking about it, I decided to give PMML a try. It has a well-maintained repository with clear instructions – which is always a good sign.

You can install the sklearn2pmml package with pip:

pip install sklearn2pmml

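The sklearn2pmml package exports a Python Machine Learning model to PMML format. Using it is simple: we just need to wrap the classifier with the PMMLPipeline class. Note that the export step runs a Java program under the hood, so a Java runtime must be installed on the machine that exports the model.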

To make it easier for you, I wrote a simple gist that trains a LightGBM model on the Iris dataset and exports the model to PMML format:

  1. Import required Python packages
  2. Load Iris dataset
  3. Split Iris dataset to train and test sets
  4. Train LightGBM model with PMML support – this is the only required change in your code.
  5. Measure classification accuracy of the model.
  6. And finally, save the model to the PMML format.

What does the PMML model look like?

The code above creates a PMML file, which is an XML file. The XML contains all the model details as seen in the image below.

PMML model is stored in an XML file (image by author).
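Because PMML is plain XML, you can inspect it with any XML parser. The sketch below uses only the Python standard library to list the fields declared in a PMML document. SAMPLE_PMML is a hand-written, heavily shortened stand-in for a real export (the field names and the target name "y" are illustrative assumptions), and list_data_fields is a hypothetical helper name.

```python
import xml.etree.ElementTree as ET

# Hand-written, shortened stand-in for an exported PMML file.
# A real export also contains the model itself, not only the field list.
SAMPLE_PMML = """<PMML xmlns="http://www.dmg.org/PMML-4_4" version="4.4">
  <DataDictionary>
    <DataField name="y" optype="categorical" dataType="integer"/>
    <DataField name="sepal length (cm)" optype="continuous" dataType="double"/>
    <DataField name="petal width (cm)" optype="continuous" dataType="double"/>
  </DataDictionary>
</PMML>
"""


def list_data_fields(pmml_text):
    """Return (name, optype, dataType) for every DataField element."""
    root = ET.fromstring(pmml_text)
    fields = []
    for element in root.iter():
        # Tags look like "{namespace}DataField"; keep only the local name.
        local_name = element.tag.rsplit("}", 1)[-1]
        if local_name == "DataField":
            fields.append(
                (element.get("name"), element.get("optype"), element.get("dataType"))
            )
    return fields


for name, optype, data_type in list_data_fields(SAMPLE_PMML):
    print(f"{name}: {optype} / {data_type}")
```

To inspect a real export, read the file contents (for example with `open("boosting_model.pmml").read()`) and pass them to the same function.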

2. How to load the model in Java?

Photo by Jake Young on Unsplash

We trained the model with Python and exported it to PMML format. Now we need to load it in Java.

I created a minimalistic repository, LoadPMMLModel, on GitHub that shows how to load a PMML model in Java.

The first step is to add a PMML dependency to pom.xml (I’m using Maven as the dependency manager):

<dependency>
    <groupId>org.jpmml</groupId>
    <artifactId>pmml-evaluator</artifactId>
    <version>1.5.15</version>
</dependency>

I saved the PMML file to the project’s resources folder so that the build packages it together with the application.

PMML file is in the resources folder (image by author)

Then we need to specify the path to the model:

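import java.nio.file.Path;
import java.nio.file.Paths;
// Locate the "model" folder on the classpath (it lives in the resources folder)
String modelFolder = LoadPMMLModel.class.getClassLoader().getResource("model").getPath();
String modelName = "boosting_model.pmml";
Path modelPath = Paths.get(modelFolder, modelName);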

Loading the PMML model is as simple as this (in Java, the variable holding the model is of type Evaluator):

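import org.jpmml.evaluator.Evaluator;
import org.jpmml.evaluator.LoadingModelEvaluatorBuilder;
// Parse the PMML file and build an evaluator for the model it contains
Evaluator evaluator = new LoadingModelEvaluatorBuilder()
        .load(modelPath.toFile())
        .build();
// Run the model's self-check before using it
evaluator.verify();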

3. How to make predictions in Java?

Photo by Tadeusz Lakota on Unsplash

Now let’s make a few predictions with the loaded model.

In Python, the prediction for the first sample in the test set was 1.

First prediction in Python (image by author).

Let’s use the same sample in Java:

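// Feature values of the sample, keyed by the PMML input field names
Map<String, Double> features = new HashMap<>();
features.put("sepal length (cm)", 6.1);
features.put("sepal width (cm)", 2.8);
features.put("petal length (cm)", 4.7);
features.put("petal width (cm)", 1.2);
// Input fields that the loaded model expects
List<? extends InputField> inputFields = evaluator.getInputFields();
// Convert each raw value to the type declared for its PMML field
Map<FieldName, FieldValue> arguments = new LinkedHashMap<>();
for (InputField inputField : inputFields) {
    FieldName inputName = inputField.getName();
    Double value = features.get(inputName.getValue());
    FieldValue inputValue = inputField.prepare(value);
    arguments.put(inputName, inputValue);
}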

And query the model in Java for prediction:

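// Name of the model's (single) target field, used to look up the prediction
FieldName targetName = evaluator.getTargetFields().get(0).getName();
Map<FieldName, ?> results = evaluator.evaluate(arguments);
// Extract the prediction by decoding PMML values into plain Java objects
Map<String, ?> resultRecord = EvaluatorUtil.decodeAll(results);
Integer yPred = (Integer) resultRecord.get(targetName.getValue());
System.out.printf("Prediction is %d%n", yPred);
System.out.printf("PMML output %s%n", resultRecord);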

With the code above, we get the following output:

The output of the model in Java (image by author)

Conclusion

Photo by Johannes Plenio on Unsplash

In my Machine Learning project, I used a regression boosting model. To my surprise, the exported PMML model produced the same results to the fifth decimal as the model in Python.
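If you want to run a similar check on your own model, a minimal sketch could look like the one below. It assumes you have collected the predictions from both sides into two lists of equal length. The function name predictions_agree and the numbers are hypothetical and are not results from my project.

```python
import math


def predictions_agree(python_preds, java_preds, decimals=5):
    """Check that paired predictions differ by at most 10**-decimals."""
    if len(python_preds) != len(java_preds):
        raise ValueError("Both prediction lists must have the same length")
    tolerance = 10 ** -decimals
    return all(
        # rel_tol=0.0 makes this a purely absolute comparison.
        math.isclose(p, j, rel_tol=0.0, abs_tol=tolerance)
        for p, j in zip(python_preds, java_preds)
    )


# Hypothetical values, for illustration only.
python_preds = [0.123456, 1.5, -2.25]
java_preds = [0.123457, 1.5, -2.25]
print(predictions_agree(python_preds, java_preds))
```

An absolute tolerance is used here because "the same to the fifth decimal" is a statement about absolute differences, not relative ones.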

I don’t have anything bad to say about PMML as it works reliably in production.

Remember, you don’t need to copy-paste the code from this article, as I created the LoadPMMLModel repository on GitHub.

Please let me know what your thoughts about PMML are.
