📱 Mobile Machine Learning

Deploying Scikit-Learn Models In Android Apps With ONNX

Using scikit-learn models to perform inference in Android apps

Shubham Panchal
Towards Data Science
7 min read · Aug 22, 2022


Photo by Plann on Pexels

Scikit-learn is a package that truly revolutionized machine learning and data science, and it remains a fundamental prerequisite for any ML/DS role. But as machine learning makes its way from research to industry, deployment of ML models now plays a crucial role in the software development cycle.

Yet most of us have only run scikit-learn models in Python scripts or Jupyter notebooks, and only a limited number of blogs/videos discuss their deployment. Deployment is straightforward with web frameworks like Flask, Django, or FastAPI, which can help build an API for interfacing the app code with the ML code.
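For context, here's a minimal sketch of what such an API could look like with FastAPI, one of the frameworks mentioned above; the saved model file, endpoint path, and response field are illustrative assumptions, not part of this tutorial.

from fastapi import FastAPI
import joblib

app = FastAPI()

# Load a previously saved scikit-learn model ( hypothetical file name )
regressor = joblib.load( "regressor.joblib" )

@app.get( "/predict" )
def predict( hours : float ):
    # Return the predicted score for the given number of study hours
    return { "score" : float( regressor.predict( [ [ hours ] ] )[ 0 ] ) }

Serving this with uvicorn exposes a /predict endpoint that any client, including a mobile app, could call over HTTP. In this story, though, we'll run the model on the device itself.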

Deploying ML models that bring innovative features to mobile apps is a big deal, as smartphones are the devices that stay with users throughout the day and handle most of their workload. So, in this story, we'll discuss how to deploy scikit-learn models in Android apps using ONNX, which acts as a bridge between the two worlds. If you aren't aware of ONNX, don't worry; even I wasn't before writing this story!

As an Android developer, I can discuss the deployment of an ML model on Android from head to toe. For iOS or cross-platform frameworks, a procedure similar to the one we'll discuss can be followed too.

First, we'll convert the scikit-learn model to ONNX and then use this ONNX model in an Android app.

The source code for the Android app can be found here ->


1. Getting the ONNX model ( in Python )

What is ONNX?

ONNX stands for Open Neural Network Exchange, an intermediate representation that makes it easy to convert ML models from one framework to another. It is an open-source project co-developed by Microsoft, Amazon, and Meta that makes deploying and running ML models more robust and framework-agnostic.

For instance, if you had to convert a PyTorch model to a TensorFlow model, you could first export it to an ONNX model and then convert that ONNX model to TensorFlow with a dedicated converter. Here's a detailed blog on the ONNX framework,
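To make that first step concrete, here's a minimal sketch of exporting a PyTorch model to ONNX; it assumes PyTorch is installed, and the tiny untrained linear model and file name are only illustrative.

import torch

# Any torch.nn.Module works here; a one-unit linear layer keeps the sketch short
model = torch.nn.Linear( 1 , 1 )
# A dummy input that fixes the input shape of the exported graph
dummy_input = torch.randn( 1 , 1 )
torch.onnx.export( model , dummy_input , "torch_model.onnx" )

The resulting torch_model.onnx could then be fed to an ONNX-aware converter such as onnx-tf to obtain a TensorFlow model.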

A. Building the Scikit-Learn model

The first step is to build a scikit-learn model in Python. To demonstrate the deployment process, we'll use a simple LinearRegression model to make predictions. You may download the student-study-hours dataset by Himanshu Nakrani from Kaggle. The dataset is available under the CC0 1.0 Public Domain license.

from sklearn.linear_model import LinearRegression
import pandas as pd
import numpy as np

# Reading the data from the CSV file
data = pd.read_csv( 'score.csv' )

X , y = data.values[ : , 0 ] , data.values[ : , 1 ]
X = np.expand_dims( X , axis=1 )

# Fitting the linear regression model
regressor = LinearRegression()
regressor.fit( X , y )

# Make predictions
print( f'Prediction for 8.5 hours is {regressor.predict([[8.5]])[0]}' )

Run the above snippet and note down the output. We'll use the same input value, i.e. 8.5 hours, to make a prediction in the Android app.

B. Installing sklearn-onnx and converting to ONNX

We'll install sklearn-onnx, which will help us convert the scikit-learn model to an ONNX model ( .onnx ).

pip install skl2onnx

sklearn-onnx can convert most scikit-learn models and preprocessing components, and you can find the list of supported ones here. Next, we convert the scikit-learn model,

from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType

# Specify an initial type for the model ( similar to input shape for the model )
initial_type = [
( 'input_study_hours' , FloatTensorType( [None,1] ) )
]

# Convert the scikit-learn model to ONNX
converted_model = convert_sklearn( regressor , initial_types=initial_type )

# Write the ONNX model to disk
with open( "sklearn_model.onnx" , "wb" ) as f:
    f.write( converted_model.SerializeToString() )

C. Conversion to the ORT format ( Optional )

This step is optional, as we can run the .onnx model directly in Android. Scott McKay highlighted this point in the Scikit_Learn_Android_Demo discussion:

Its ( ORT format ) main benefit is allowing usage of the smaller build (onnxruntime-mobile android package) if binary size is a big concern. Same graph optimizations are done by ONNX Runtime for either model format — they’re just done ahead of time when creating the .ort format model so we don’t have to include optimizer code in the smaller build.

However, there’s a chance your model might not be supported by the onnxruntime-mobile package due to the operators/types it uses, so using a .onnx model with onnxruntime-android is the simplest and safest option.

If binary size is an issue and your model is not supported by onnxruntime-mobile, you can do a custom build to create a package that only includes the operators and types your model requires. This provides the smallest possible binary size.

Note that so far we've only converted the scikit-learn model to ONNX, and ONNX is just an open format for representing machine learning models. Running ONNX models on different systems or devices, or accelerating their inference, requires a runtime that can parse the ONNX model format and make predictions from it. One such runtime is onnxruntime, an open-source project from Microsoft.
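As a quick sanity check, and a preview of the runtime we'll use on Android, here's a minimal sketch of running the exported .onnx model with onnxruntime in Python; it assumes onnxruntime is installed ( see the pip command below ) and reuses the input name from step 1-B.

import numpy as np
import onnxruntime as ort

# Load the exported ONNX model with the CPU execution provider
session = ort.InferenceSession( "sklearn_model.onnx" , providers=[ "CPUExecutionProvider" ] )
# The input name matches the one given in initial_type ( 'input_study_hours' )
onnx_inputs = { "input_study_hours" : np.array( [ [ 8.5 ] ] , dtype=np.float32 ) }
onnx_prediction = session.run( None , onnx_inputs )[ 0 ]
print( f'ONNX prediction for 8.5 hours is {onnx_prediction[ 0 ][ 0 ]}' )

The printed value should match the scikit-learn prediction from step 1-A, up to float32 precision.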

We’ll now convert the .onnx model to the .ort format. The official docs say,

The ORT format is the format supported by reduced size ONNX Runtime builds. Reduced size builds may be more appropriate for use in size-constrained environments such as mobile and web applications.

pip install onnxruntime

Next, we need to use the convert_onnx_models_to_ort script to perform the conversion, like,

!python -m onnxruntime.tools.convert_onnx_models_to_ort sklearn_model.onnx

You can also run the .ort model in Python. See this official tutorial,
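For reference, here's a minimal sketch of doing that, assuming the converter produced sklearn_model.ort next to the original .onnx file; the ORT-format model loads through the same InferenceSession API.

import numpy as np
import onnxruntime as ort

# InferenceSession accepts ORT-format models as well as .onnx files
session = ort.InferenceSession( "sklearn_model.ort" , providers=[ "CPUExecutionProvider" ] )
print( session.run( None , { "input_study_hours" : np.array( [ [ 8.5 ] ] , dtype=np.float32 ) } ) )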

Next, we move on to the Android part and discuss how we can make predictions using the .ort model.

2. Using the model in the Android app ( in Kotlin )

You may open a new Android Studio project with your desired configuration and add the onnxruntime Maven dependency to the module-level build.gradle file,

implementation 'com.microsoft.onnxruntime:onnxruntime-android:latest.release'

A. Create an OrtSession

First, we need to copy the .ort or .onnx file to the app/src/main/res/raw folder. This folder isn't created by default when you create a new Android app project; refer to this SO answer if you haven't created the raw folder before.

The ORT file in the ‘raw’ folder. Image Source: Author

Next, define a method in MainActivity.kt to create an OrtSession with an OrtEnvironment.

// Create an OrtSession with the given OrtEnvironment
private fun createORTSession( ortEnvironment: OrtEnvironment ) : OrtSession {
    val modelBytes = resources.openRawResource( R.raw.sklearn_model ).readBytes()
    return ortEnvironment.createSession( modelBytes )
}

B. Performing inference

Next, we need to run the ORT/ONNX model using the OrtSession that we created in the previous step. Here's a method runPrediction to do so,

// Make predictions with given inputs
private fun runPrediction( input : Float , ortSession: OrtSession , ortEnvironment: OrtEnvironment ) : Float {
    // Get the name of the input node
    val inputName = ortSession.inputNames?.iterator()?.next()
    // Make a FloatBuffer of the inputs
    val floatBufferInputs = FloatBuffer.wrap( floatArrayOf( input ) )
    // Create input tensor with floatBufferInputs of shape ( 1 , 1 )
    val inputTensor = OnnxTensor.createTensor( ortEnvironment , floatBufferInputs , longArrayOf( 1, 1 ) )
    // Run the model
    val results = ortSession.run( mapOf( inputName to inputTensor ) )
    // Fetch and return the results
    val output = results[0].value as Array<FloatArray>
    return output[0][0]
}

We create an OnnxTensor with the shape that we provided as initial_type in step 1-B. Then, we simply call ortSession.run to run the ORT model. To feed in input from the user in Android, we can add the following lines to the onCreate method in MainActivity.kt,

// Initialize the views
val inputEditText = findViewById<EditText>( R.id.input_edittext )
val outputTextView = findViewById<TextView>( R.id.output_textview )
val button = findViewById<Button>( R.id.predict_button )

button.setOnClickListener {
    // Parse input from inputEditText
    val inputs = inputEditText.text.toString().toFloatOrNull()
    if ( inputs != null ) {
        val ortEnvironment = OrtEnvironment.getEnvironment()
        val ortSession = createORTSession( ortEnvironment )
        val output = runPrediction( inputs , ortSession , ortEnvironment )
        outputTextView.text = "Output is ${output}"
    }
    else {
        Toast.makeText( this , "Please check the inputs" , Toast.LENGTH_LONG ).show()
    }
}

Run the app and see the magic for yourself. The app produces the same output for 8.5 hours of study ( the input value ) as it did in Python,

Demo of the app. Image Source: Author

The End

As you might have realized by now, the conversion process is quite easy and doesn't involve any intermediate cross-platform conversions. Many developers, startups, and ML practitioners may have thought of such a conversion but been unaware of this simple process.

This may be a glorified hack, so if you’re facing any issues with this workaround, do share them in the comments so that other readers know about them. You can also open an issue on the GitHub repository.

Hope you found this story interesting. If you have any queries, do let me know in the comments or send me a message at equipintelligence@gmail. Keep learning and have a nice day ahead!
