📱 Mobile Machine Learning

Deploying Pretrained TF Object Detection Models on Android

From trained checkpoints to an Android app

Shubham Panchal
Towards Data Science
11 min read · Oct 21, 2021

Photo by Sebastian Bednarek on Unsplash

Deploying machine learning models on mobile devices is a new phase of ML that’s just getting started. Vision models, mostly object detection models, have already made their way to mobile devices, along with speech recognition, image classification, text completion, etc. These models, which usually run on GPU-enabled computers, have many use cases when deployed on mobile devices.

To demonstrate, end to end, how to bring an ML model (specifically an object detection model) to Android, we’ll use Victor Dibia’s hand detection models from the victordibia/handtracking repo. The model detects human hands in an image and was built with the TensorFlow Object Detection API. We’ll take the trained checkpoints from Victor Dibia’s repo and convert them to the TensorFlow Lite (TFLite) format, which can be used to run the model on Android (or even iOS and Raspberry Pi).

Next, we move on to the Android app and create all the classes/methods required to run the model and display its predictions (the bounding boxes) over the live camera feed.

Let’s get started!

Contents

Conversion of Model Checkpoints to TFLite

👉 1. Setting Up the TF Object Detection API

👉 2. Converting the Checkpoints to Frozen Graphs

👉 3. Converting the Frozen Graphs to the TFLite Buffer

Integrating the TFLite model in Android

👉 1. Adding Dependencies For CameraX, Coroutines and TF Lite

👉 2. Initializing CameraX and ImageAnalysis.Analyzer

👉 3. Implementing the Hand Detection Model

👉 4. Drawing the Bounding Boxes over the Camera Feed

Projects/Blogs from the author

Conversion of Model Checkpoints to TFLite

Our first step will be to convert the trained model checkpoints, provided in Victor Dibia’s repo (MIT License), to the TensorFlow Lite format. TensorFlow Lite provides an efficient gateway to run TensorFlow models on Android, iOS and microcontroller devices. To run the conversion scripts, we need to set up the TensorFlow Object Detection API on our machine. You may also use this Colab Notebook to perform all conversions in the cloud.

I recommend using the Colab Notebook (especially on Windows), as I personally ran into a number of errors when setting things up locally.

1. Setting Up the TF Object Detection API

The TensorFlow Object Detection API provides a number of pretrained object detection models which can be fine-tuned on custom datasets and deployed to mobile, the web, or the cloud. We’ll only need its conversion scripts, which help us convert the model checkpoints into a TFLite buffer.

The hand detection model itself was built using the TF OD API with TensorFlow 1.x. So, we first need to install TensorFlow 1.x, specifically TF 1.15.0 (1.15 is the final release line of the 1.x family), and then clone the tensorflow/models repo, which contains the TF OD API.

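# Installing TF 1.15.0
!pip install tensorflow==1.15.0

# Cloning the tensorflow/models repo
!git clone https://github.com/tensorflow/models

%%bash
# Installing the TF OD API (run this in its own cell: %%bash must be the cell's first line)
cd models/research/
# Compile the protobuf definitions used by the TF OD API
protoc object_detection/protos/*.proto --python_out=.
# Install the TF1 flavour of the object_detection package
cp object_detection/packages/tf1/setup.py .
python -m pip install .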

We’ll also clone Victor Dibia’s handtracking repo to get the model checkpoints:

!git clone https://github.com/victordibia/handtracking

2. Converting the Checkpoints to Frozen Graphs

Now, in the models/research/object_detection directory, you’ll find a Python script, export_tflite_ssd_graph.py, which we’ll use to convert the model checkpoints to a TFLite-compatible graph. The checkpoints can be found in the handtracking/model-checkpoint directory. In the checkpoint folder name, ssd stands for ‘Single Shot Detector’, the architecture of the hand detection model, whereas mobilenet denotes the MobileNet backbone (v1 or v2), a CNN architecture specialized for mobile devices.

Workflow -> Conversion of model checkpoints to TFLite buffers. ( Image Source: The Author )

The exported TFLite graph contains fixed input and output nodes. We can find the names and shapes of these nodes (or tensors) in the comments of the export_tflite_ssd_graph.py script. Using the script, we’ll convert the model checkpoints to a TFLite-compatible graph, given the following arguments:

  1. pipeline_config_path : Path to the .config file which contains the configuration of the SSDLite model used.
  2. trained_checkpoint_prefix : Prefix of the trained model checkpoints we wish to convert.
  3. output_directory : Directory where the exported graph files are written.
  4. max_detections : The maximum number of bounding boxes to be predicted. This matters because it is a parameter of the non-maximum suppression (NMS) post-processing operation added to the graph.

!python models/research/object_detection/export_tflite_ssd_graph.py \
--pipeline_config_path handtracking/model-checkpoint/ssdlitemobilenetv2/data_ssdlite.config \
--trained_checkpoint_prefix handtracking/model-checkpoint/ssdlitemobilenetv2/out_model.ckpt-19040 \
--output_directory outputs \
--max_detections 10

After the script has run, we’re left with two files in the outputs directory, tflite_graph.pb and tflite_graph.pbtxt, which hold the TFLite-compatible frozen graph in binary and text form respectively.

3. Converting the Frozen Graphs to the TFLite Buffer

Now we’ll use a second script (or more precisely, a utility) to convert the frozen graph generated in step 2 into a TFLite buffer (.tflite). As TensorFlow 2.x dropped Session and placeholders, its default converter doesn’t accept frozen graphs. This is one of the reasons why we installed TensorFlow 1.x in step 1.

We’ll use the tflite_convert utility to convert the frozen graph to a TFLite buffer. We can also use the tf.lite.TFLiteConverter API, but we’ll stick to the command line utility for now.

!tflite_convert \
--graph_def_file=/content/outputs/tflite_graph.pb \
--output_file=/content/outputs/model.tflite \
--output_format=TFLITE \
--input_arrays=normalized_input_image_tensor \
--input_shapes=1,300,300,3 \
--inference_type=FLOAT \
--output_arrays="TFLite_Detection_PostProcess,TFLite_Detection_PostProcess:1,TFLite_Detection_PostProcess:2,TFLite_Detection_PostProcess:3" \
--allow_custom_ops
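The same conversion can also be done through the Python API mentioned above. The sketch below assumes TensorFlow 1.15, where tf.lite.TFLiteConverter.from_frozen_graph is available; the helper name convert_frozen_graph and the module-level constants are illustrative and not part of the original project. TensorFlow is imported inside the function so that the constants can be inspected without it.

```python
# Names mirror the tflite_convert flags used above.
INPUT_ARRAY = "normalized_input_image_tensor"
INPUT_SHAPE = [1, 300, 300, 3]
# TFLite_Detection_PostProcess exposes four outputs:
# boxes, classes, scores and the number of detections.
OUTPUT_ARRAYS = ["TFLite_Detection_PostProcess"] + [
    "TFLite_Detection_PostProcess:%d" % i for i in range(1, 4)
]


def convert_frozen_graph(graph_def_file, output_file):
    """Convert the exported frozen graph to a .tflite file (assumes TF 1.15)."""
    import tensorflow as tf  # imported lazily; requires TensorFlow 1.15

    converter = tf.lite.TFLiteConverter.from_frozen_graph(
        graph_def_file,
        input_arrays=[INPUT_ARRAY],
        output_arrays=OUTPUT_ARRAYS,
        input_shapes={INPUT_ARRAY: INPUT_SHAPE},
    )
    # Equivalent of --allow_custom_ops: the post-processing op is a custom op.
    converter.allow_custom_ops = True
    tflite_model = converter.convert()  # float inference is the default
    with open(output_file, "wb") as f:
        f.write(tflite_model)
```

In the Colab setup above, this would be called as convert_frozen_graph("/content/outputs/tflite_graph.pb", "/content/outputs/model.tflite").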

Once the execution is done, you’ll see a model.tflite file in the outputs directory. In order to check the input/output shapes, we’ll load the TFLite model using tf.lite.Interpreter and call .get_input_details() or .get_output_details() to get the input and output details respectively.

Tip: use pprint for more readable output.

import tensorflow as tf
import pprint

interpreter = tf.lite.Interpreter( '/content/outputs/model.tflite' )
interpreter.allocate_tensors()

pprint.pprint( interpreter.get_input_details())
pprint.pprint( interpreter.get_output_details() )
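To make this check less manual, you could compare the reported output shapes with the ones the SSD post-processing op should produce for max_detections = 10. The helpers below are hypothetical; they work on the list of dicts returned by get_input_details() or get_output_details() and assume the outputs are listed in the order given to --output_arrays (boxes, classes, scores, number of detections).

```python
def summarize_details(details):
    """Return (name, shape) pairs from interpreter.get_input_details() or
    interpreter.get_output_details(); shapes become plain int lists."""
    return [(d["name"], [int(s) for s in d["shape"]]) for d in details]


def expected_output_shapes(max_detections=10):
    """Shapes of the four TFLite_Detection_PostProcess outputs, in output order."""
    return [
        [1, max_detections, 4],  # bounding boxes, normalized [ymin, xmin, ymax, xmax]
        [1, max_detections],     # class indices
        [1, max_detections],     # confidence scores
        [1],                     # number of detections
    ]


def outputs_look_right(output_details, max_detections=10):
    """True if the reported output shapes match the expected ones, in order."""
    shapes = [shape for _, shape in summarize_details(output_details)]
    return shapes == expected_output_shapes(max_detections)
```

With the interpreter from the snippet above, outputs_look_right(interpreter.get_output_details()) should return True, and summarize_details(interpreter.get_input_details()) should report a [1, 300, 300, 3] input.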

Integrating the TFLite model in Android

Once we’ve got our TFLite model with all the details of its input and output shapes, we’re ready to run it in an Android app. Create a new project in Android Studio or feel free to fork/clone the GitHub repo to get started!

1. Adding Dependencies For CameraX, Coroutines and TF Lite

As we’re going to detect hands in a live camera feed, we’ll need to add the CameraX dependencies to our Android app. Similarly, to run TFLite models, we’ll require the tensorflow-lite dependency, along with the Kotlin Coroutines dependency, which helps us run the model asynchronously. In the app-level build.gradle file, we’ll add the following dependencies:

plugins {
    ...
}

android {

    ...

    aaptOptions {
        noCompress "tflite"
    }

    ...

}

dependencies {
    ...

    // CameraX dependencies
    implementation "androidx.camera:camera-camera2:1.0.1"
    implementation "androidx.camera:camera-lifecycle:1.0.1"
    implementation "androidx.camera:camera-view:1.0.0-alpha28"
    implementation "androidx.camera:camera-extensions:1.0.0-alpha28"

    // TensorFlow Lite dependencies
    implementation 'org.tensorflow:tensorflow-lite:2.4.0'
    implementation 'org.tensorflow:tensorflow-lite-gpu:2.4.0'
    implementation 'org.tensorflow:tensorflow-lite-support:0.1.0'

    // Kotlin Coroutines
    implementation 'org.jetbrains.kotlinx:kotlinx-coroutines-core:1.4.1'
    implementation 'org.jetbrains.kotlinx:kotlinx-coroutines-android:1.4.1'

    ...
}

Make sure you add aaptOptions { noCompress "tflite" } so that the .tflite file is stored uncompressed in the APK. Loading the model as a memory-mapped file from the assets (for example with FileUtil.loadMappedFile from the TFLite Support Library) only works for uncompressed assets. Now, to place the TFLite model in our app, we’ll create an assets folder under app/src/main and paste the TFLite file (.tflite) into it.

Placing ‘model.tflite’ in the assets folder. ( Image Source: The Author )

2. Initializing CameraX and ImageAnalysis.Analyzer

We’ll use a PreviewView from the CameraX package to display the live camera feed to the user. On top of it, we’ll place an overlay, called BoundingBoxOverlay, to draw the bounding boxes over the camera feed. I won’t be discussing its implementation here, but you can learn about it from the source code or this story of mine.

As we’re going to predict bounding boxes for hands on live frames, we’ll also need an ImageAnalysis.Analyzer, which receives every frame from the live camera feed. See this snippet from FrameAnalyser.kt:

// Image Analyser for performing hand detection on camera frames.
class FrameAnalyser(
    private val handDetectionModel: HandDetectionModel ,
    private val boundingBoxOverlay: BoundingBoxOverlay ) : ImageAnalysis.Analyzer {

    private var frameBitmap : Bitmap? = null
    private var isFrameProcessing = false


    override fun analyze(image: ImageProxy) {
        // If a frame is being processed, drop the current frame.
        if ( isFrameProcessing ) {
            image.close()
            return
        }
        isFrameProcessing = true

        // Get the `Bitmap` of the current frame ( with corrected rotation ).
        frameBitmap = BitmapUtils.imageToBitmap( image.image!! , image.imageInfo.rotationDegrees )
        image.close()

        // Configure frameHeight and frameWidth for output2overlay transformation matrix.
        if ( !boundingBoxOverlay.areDimsInit ) {
            Logger.logInfo( "Passing dims to overlay..." )
            boundingBoxOverlay.frameHeight = frameBitmap!!.height
            boundingBoxOverlay.frameWidth = frameBitmap!!.width
        }

        CoroutineScope( Dispatchers.Main ).launch {
            runModel( frameBitmap!! )
        }
    }


    private suspend fun runModel( inputImage : Bitmap ) = withContext( Dispatchers.Default ) {
        ...
    }

}

BitmapUtils contains some useful static methods to manipulate Bitmaps. isFrameProcessing is a Boolean flag which determines whether an incoming frame is dropped or passed to the model. Because runModel() switches to Dispatchers.Default with withContext, inference runs off the main thread, so the camera preview doesn’t lag while the model produces its predictions.
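The frame-dropping idea behind isFrameProcessing can be sketched in a few lines of Python. This is a language-neutral illustration, not the app’s code: FrameGate is a hypothetical name, and it assumes the flag is cleared once inference for the accepted frame has finished (in FrameAnalyser that reset would happen in the elided runModel() body). The lock makes the check-and-set atomic, since frames arrive on one thread while inference finishes on another.

```python
import threading


class FrameGate:
    """Drops incoming frames while a previous frame is still being processed."""

    def __init__(self):
        self._lock = threading.Lock()
        self._busy = False
        self.dropped = 0  # number of frames rejected so far

    def try_begin(self):
        """Return True if the caller may process this frame, False to drop it."""
        with self._lock:
            if self._busy:
                self.dropped += 1
                return False
            self._busy = True
            return True

    def end(self):
        """Call once inference for the accepted frame has finished."""
        with self._lock:
            self._busy = False


gate = FrameGate()
accepted = [gate.try_begin() for _ in range(3)]  # first frame accepted, next two dropped
gate.end()
accepted_after_end = gate.try_begin()            # a new frame is accepted again
```

In Kotlin, an AtomicBoolean with compareAndSet (or a @Volatile flag) plays the same role. CameraX’s ImageAnalysis.STRATEGY_KEEP_ONLY_LATEST backpressure strategy also limits how many frames queue up, but it does not by itself stop new frames from being analyzed while an earlier coroutine is still running.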

3. Implementing the Hand Detection Model

Next, we’ll create a class called HandDetectionModel which will handle all the TFLite operations and return the predictions given an image ( as a Bitmap ).

// Helper class for Hand detection TFLite model
class HandDetectionModel( context: Context ) {

    // I/O details for the hand detection model.
    // Refer to the comments of this script ->
    // https://github.com/tensorflow/models/blob/master/research/object_detection/export_tflite_ssd_graph.py
    // For quantization, use the tflite_convert utility as described in the conversion notebook ( README ).
    private val modelInputImageDim = 300
    private val isQuantized = false
    private val maxDetections = 10
    private val boundingBoxesTensorShape = intArrayOf( 1 , maxDetections , 4 ) // [ 1 , 10 , 4 ]
    private val confidenceScoresTensorShape = intArrayOf( 1 , maxDetections ) // [ 1 , 10 ]
    private val classesTensorShape = intArrayOf( 1 , maxDetections ) // [ 1 , 10 ]
    private val numBoxesTensorShape = intArrayOf( 1 ) // [ 1 ]
    // Input tensor processors for quantized and non-quantized versions of the model.
    // A quantized ( uint8 ) model takes raw pixel values, so the image is only resized.
    private val inputImageProcessorQuantized = ImageProcessor.Builder()
        .add( ResizeOp( modelInputImageDim , modelInputImageDim , ResizeOp.ResizeMethod.BILINEAR ) )
        .add( CastOp( DataType.UINT8 ) )
        .build()
    // A float model expects pixel values in [ -1 , 1 ], i.e. ( x - 127.5 ) / 127.5.
    private val inputImageProcessorNonQuantized = ImageProcessor.Builder()
        .add( ResizeOp( modelInputImageDim , modelInputImageDim , ResizeOp.ResizeMethod.BILINEAR ) )
        .add( NormalizeOp( 127.5f , 127.5f ) )
        .build()

    // See app/src/main/assets for the TFLite model.
    private val modelName = "model.tflite"
    private val numThreads = 4
    private var interpreter : Interpreter
    // Confidence threshold used to filter predictions ( NMS is already part of the model graph ).
    private val outputConfidenceThreshold = 0.9f

    ...

Let’s go through each of the terms in the above snippet:

  1. modelInputImageDim is the side length of the model’s square input. Our model takes in an image of size 300 × 300.
  2. maxDetections is the maximum number of predictions made by our model. It determines the shapes of boundingBoxesTensorShape, confidenceScoresTensorShape and classesTensorShape (numBoxesTensorShape is always [ 1 ]).
  3. outputConfidenceThreshold is used to filter the predictions made by our model. This is not NMS (the graph already applies NMS); we simply keep only the boxes whose score is at least this threshold.
  4. inputImageProcessorQuantized and inputImageProcessorNonQuantized are ImageProcessor instances which resize the given images to modelInputImageDim × modelInputImageDim. For the non-quantized (float) model, we also normalize the given image with mean and standard deviation both equal to 127.5.
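As a quick arithmetic check of point 4: a NormalizeOp with mean μ and standard deviation σ computes (x - μ) / σ for each value, so μ = σ = 127.5 maps 8-bit pixel values from [0, 255] onto [-1, 1], the input range MobileNet-based SSD models are trained with. A minimal Python sketch of the same arithmetic (function names are illustrative):

```python
MEAN = 127.5
STD = 127.5


def normalize_pixel(value):
    """Map a raw 8-bit pixel value in [0, 255] to [-1, 1]: (value - MEAN) / STD."""
    return (value - MEAN) / STD


def normalize_image(pixels):
    """Apply normalize_pixel to a flat list of channel values."""
    return [normalize_pixel(p) for p in pixels]


black, mid_grey, white = normalize_pixel(0), normalize_pixel(127.5), normalize_pixel(255)
# black == -1.0, mid_grey == 0.0, white == 1.0
```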

Now, we’ll implement a method run() which takes a Bitmap image and outputs the bounding boxes as a List<Prediction>. Prediction is a class which holds the predicted data, like the confidence score and bounding box coordinates.

// Run the model on a frame and return the predictions above the confidence threshold.
fun run( inputImage : Bitmap ) : List<Prediction> {
    // Store the width and height of the input frames as they will be used for future transformations.
    inputFrameWidth = inputImage.width
    inputFrameHeight = inputImage.height

    // Resize ( and, for the float model, normalize ) the frame.
    var tensorImage = TensorImage.fromBitmap( inputImage )
    tensorImage = if ( isQuantized ) {
        inputImageProcessorQuantized.process( tensorImage )
    }
    else {
        inputImageProcessorNonQuantized.process( tensorImage )
    }

    // Output buffers, indexed in the order of the TFLite_Detection_PostProcess outputs.
    val confidenceScores = TensorBuffer.createFixedSize( confidenceScoresTensorShape , DataType.FLOAT32 )
    val boundingBoxes = TensorBuffer.createFixedSize( boundingBoxesTensorShape , DataType.FLOAT32 )
    val classes = TensorBuffer.createFixedSize( classesTensorShape , DataType.FLOAT32 )
    val numBoxes = TensorBuffer.createFixedSize( numBoxesTensorShape , DataType.FLOAT32 )
    val outputs = mapOf(
        0 to boundingBoxes.buffer ,
        1 to classes.buffer ,
        2 to confidenceScores.buffer ,
        3 to numBoxes.buffer
    )

    val t1 = System.currentTimeMillis()
    interpreter.runForMultipleInputsOutputs( arrayOf( tensorImage.buffer ) , outputs )
    Logger.logInfo( "Model inference time -> ${System.currentTimeMillis() - t1} ms." )

    return processOutputs( confidenceScores , boundingBoxes )
}

confidenceScores, boundingBoxes, classes and numBoxes are the four tensors which hold the outputs of our model. The processOutputs method filters the bounding boxes and returns only those boxes whose confidence score is at least the threshold.

private fun processOutputs( scores : TensorBuffer ,
                            boundingBoxes : TensorBuffer ) : List<Prediction> {
    // Flattened version of array of shape [ 1 , maxDetections ] ( size = maxDetections )
    val scoresFloatArray = scores.floatArray
    // Flattened version of array of shape [ 1 , maxDetections , 4 ] ( size = maxDetections * 4 )
    val boxesFloatArray = boundingBoxes.floatArray
    val predictions = ArrayList<Prediction>()
    for ( i in boxesFloatArray.indices step 4 ) {
        // Store predictions which have a confidence >= threshold
        if ( scoresFloatArray[ i / 4 ] >= outputConfidenceThreshold ) {
            predictions.add(
                Prediction(
                    getRect( boxesFloatArray.sliceArray( i..i+3 ) ) ,
                    scoresFloatArray[ i / 4 ]
                )
            )
        }
    }
    return predictions.toList()
}


// Transform the normalized bounding box coordinates relative to the input frames.
// Each box is [ ymin , xmin , ymax , xmax ]; max / min come from kotlin.math.
private fun getRect( coordinates : FloatArray ) : Rect {
    return Rect(
        max( (coordinates[ 1 ] * inputFrameWidth).toInt() , 1 ),
        max( (coordinates[ 0 ] * inputFrameHeight).toInt() , 1 ),
        min( (coordinates[ 3 ] * inputFrameWidth).toInt() , inputFrameWidth ),
        min( (coordinates[ 2 ] * inputFrameHeight).toInt() , inputFrameHeight )
    )
}
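The indexing in processOutputs() and getRect() is easier to follow in a short Python mirror: box i occupies elements 4i to 4i + 3 of the flattened box array, in the order [ymin, xmin, ymax, xmax], and its score is element i of the score array. This is an illustrative sketch (filter_predictions and to_frame_rect are hypothetical names) using the same 0.9 threshold and the same clamping as the Kotlin code.

```python
def to_frame_rect(box, frame_width, frame_height):
    """Convert a normalized [ymin, xmin, ymax, xmax] box to pixel
    (left, top, right, bottom), clamped like getRect()."""
    ymin, xmin, ymax, xmax = box
    return (
        max(int(xmin * frame_width), 1),
        max(int(ymin * frame_height), 1),
        min(int(xmax * frame_width), frame_width),
        min(int(ymax * frame_height), frame_height),
    )


def filter_predictions(scores, boxes, frame_width, frame_height, threshold=0.9):
    """Mirror of processOutputs(): keep boxes whose score >= threshold.

    scores: flat list of length max_detections (tensor shape [1, max_detections]).
    boxes:  flat list of length 4 * max_detections (tensor shape [1, max_detections, 4]).
    Returns (rect, score) tuples with rect in frame pixels.
    """
    predictions = []
    for i in range(0, len(boxes), 4):
        score = scores[i // 4]
        if score >= threshold:
            rect = to_frame_rect(boxes[i:i + 4], frame_width, frame_height)
            predictions.append((rect, score))
    return predictions


# Two detections on a 640 x 480 frame; only the first passes the threshold.
kept = filter_predictions([0.95, 0.5], [0.25, 0.5, 0.75, 1.0, 0.0, 0.0, 1.0, 1.0], 640, 480)
# kept == [((320, 120, 640, 360), 0.95)]
```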

4. Drawing the Bounding Boxes over the Camera Feed

Once we’ve received the bounding boxes, we’d like to draw them over the camera feed, just as we would with OpenCV. We’ll create a new class, BoundingBoxOverlay, and add it to activity_main.xml. The class looks like this:

class BoundingBoxOverlay(context : Context, attributeSet : AttributeSet)
    : SurfaceView( context , attributeSet ) , SurfaceHolder.Callback {

    // Variables used to compute output2overlay transformation matrix
    // These are assigned in FrameAnalyser.kt
    var areDimsInit = false
    var frameHeight = 0
    var frameWidth = 0

    // This var is assigned in FrameAnalyser.kt
    var handBoundingBoxes: List<Prediction>? = null

    // This var is assigned in MainActivity.kt
    var isFrontCameraOn = false

    private var output2OverlayTransform: Matrix = Matrix()
    private val boxPaint = Paint().apply {
        color = Color.YELLOW
        style = Paint.Style.STROKE
        strokeWidth = 16f
    }
    private val textPaint = Paint().apply {
        strokeWidth = 2.0f
        textSize = 32f
        color = Color.YELLOW
    }

    private val displayMetrics = DisplayMetrics()

    ...

    override fun onDraw(canvas: Canvas?) {
        if ( handBoundingBoxes == null ) {
            return
        }
        if (!areDimsInit) {
            ...
        }
        else {
            for ( prediction in handBoundingBoxes!! ) {
                val rect = prediction.boundingBox.toRectF()
                output2OverlayTransform.mapRect( rect )
                canvas?.drawRoundRect( rect , 16f, 16f, boxPaint )
                canvas?.drawText(
                    prediction.confidence.toString(),
                    rect.centerX(),
                    rect.centerY(),
                    textPaint
                )
            }
        }
    }

}
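The onDraw() snippet elides how output2OverlayTransform is computed from frameWidth and frameHeight. One plausible approach is sketched below in Python, assuming the PreviewView uses its default FILL_CENTER scale type (scale uniformly until the frame covers the view, then center-crop) and that the front-camera preview is mirrored horizontally. frame_to_view_rect is a hypothetical helper, not the author’s code; in the app, the same scale, translation and optional mirroring would be composed into the android.graphics.Matrix.

```python
def frame_to_view_rect(rect, frame_size, view_size, mirror=False):
    """Map (left, top, right, bottom) from frame pixels to overlay pixels,
    assuming FILL_CENTER scaling: uniform scale so the frame covers the view,
    then centering (the overflow is cropped equally on both sides)."""
    frame_w, frame_h = frame_size
    view_w, view_h = view_size
    scale = max(view_w / frame_w, view_h / frame_h)
    dx = (view_w - frame_w * scale) / 2.0  # negative when width is cropped
    dy = (view_h - frame_h * scale) / 2.0  # negative when height is cropped
    left, top, right, bottom = rect
    left, right = left * scale + dx, right * scale + dx
    top, bottom = top * scale + dy, bottom * scale + dy
    if mirror:
        # Flip horizontally to match a mirrored front-camera preview.
        left, right = view_w - right, view_w - left
    return left, top, right, bottom


# A 480 x 640 frame shown in a 1080 x 1920 view: scale 3.0, 180 px cropped on each side.
back = frame_to_view_rect((100, 100, 200, 200), (480, 640), (1080, 1920))
front = frame_to_view_rect((100, 100, 200, 200), (480, 640), (1080, 1920), mirror=True)
```

Scaling with max() rather than min() is what makes the frame fill the view; min() would instead correspond to a letterboxed FIT_CENTER preview.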

That’s all! We’ve just implemented a hand detector in an Android app! You can run the app after reviewing all the code.

The Android app running the hand detection model. The text in the center of each box denotes the confidence of that prediction.

The End

Hope you enjoyed this story! Feel free to express your thoughts at equipintelligence@gmail.com or in the comments below.

Have a nice day ahead, dear developer!
