📱 Mobile Machine Learning
Deploying Pretrained TF Object Detection Models on Android
Right from trained checkpoints to an Android app
Deploying machine learning models on mobile devices is the next phase of ML, and it is only just beginning. Vision models, mostly object detection models, have already made their way to mobile devices, along with speech recognition, image classification, text completion and so on. These models, which usually run on GPU-enabled computers, have plenty of use cases when deployed on mobile devices.
To demonstrate an end-to-end example of bringing an ML model, specifically an object detection model, to Android, we’ll use Victor Dibia’s hand detection models from the victordibia/handtracking repo. The model can detect human hands in an image and is built with the TensorFlow Object Detection API. We’ll take the trained checkpoints from Victor Dibia’s repo and convert them to the TensorFlow Lite ( TFLite ) format, which can be used to run the model on Android ( or even iOS or a Raspberry Pi ).
Next, we move on to the Android app and create all the classes and methods required to get the model running and to display its predictions ( the bounding boxes ) over the live camera feed.
Let’s get started!
Contents
Conversion of Model Checkpoints to TFLite
👉 1. Setting Up the TF Object Detection API
👉 2. Converting the Checkpoints to Frozen Graphs
👉 3. Converting the Frozen Graphs to the TFLite Buffer
Integrating the TFLite model in Android
👉 1. Adding Dependencies For CameraX, Coroutines and TF Lite
👉 2. Initializing CameraX and ImageAnalysis.Analyzer
👉 3. Implementing the Hand Detection Model
👉 4. Drawing the Bounding Boxes over the Camera Feed
Conversion of Model Checkpoints to TFLite
Our first step will be to convert the trained model checkpoints, provided in Victor Dibia’s repo (MIT License), to the TensorFlow Lite format. TensorFlow Lite provides an efficient gateway to run TensorFlow models on Android, iOS and micro-controller devices. In order to run the conversion scripts, we need to set up the TensorFlow Object Detection API on our machine. You may also use this Colab Notebook to perform all conversions in the cloud.
I recommend using the Colab Notebook ( especially on Windows ), as I personally ran into a number of errors while setting things up locally.
1. Setting Up the TF Object Detection API
The TensorFlow Object Detection API provides a number of pretrained object detection models which can be fine-tuned on custom datasets and deployed directly into mobile, web, or the cloud. We’ll only require the conversion scripts that help us convert the model checkpoints into a TF Lite buffer.
The hand detection model was itself made using the TF OD API with TensorFlow 1.x. So, first we need to install TensorFlow 1.x, here TF 1.15.0 ( 1.15 being the last minor release in the 1.x family ), and then clone the tensorflow/models repo which contains the TF OD API.
# Installing TF 1.15.0
!pip install tensorflow==1.15.0
# Cloning the tensorflow/models repo
!git clone https://github.com/tensorflow/models
# Installing the TF OD API
%%bash
cd models/research/
protoc object_detection/protos/*.proto --python_out=.
cp object_detection/packages/tf1/setup.py .
python -m pip install .
Also, we’ll clone Victor Dibia’s handtracking repo to get the model checkpoints,
!git clone https://github.com/victordibia/handtracking
2. Converting the Checkpoints to Frozen Graphs
Now, in the models/research/object_detection directory, you’ll find a Python script, export_tflite_ssd_graph.py, which we’ll use to convert the model checkpoints to a TFLite-compatible graph. The checkpoints can be found in the handtracking/model-checkpoint directory. In the checkpoint names, ssd stands for ‘Single Shot Detector’, which is the architecture of the hand detection model, whereas mobilenet denotes the MobileNet backbone ( v1 or v2 ), a CNN architecture specialized for mobile devices.
The exported TFLite graph contains fixed input and output nodes. We can find the names and the shapes of these nodes ( or tensors ) in the export_tflite_ssd_graph.py script. Using the script, we’ll convert the model checkpoints to a TFLite-compatible graph, given three arguments,
pipeline_config_path: Path to the .config file which contains the configuration of the SSD Lite model used.
trained_checkpoint_prefix: Prefix of the trained model checkpoints we wish to convert.
max_detections: The maximum number of bounding boxes to be predicted. It matters because it is a parameter of the non-maximum suppression postprocessing operation added to the graph.
!python models/research/object_detection/export_tflite_ssd_graph.py \
--pipeline_config_path handtracking/model-checkpoint/ssdlitemobilenetv2/data_ssdlite.config \
--trained_checkpoint_prefix handtracking/model-checkpoint/ssdlitemobilenetv2/out_model.ckpt-19040 \
--output_directory outputs \
--max_detections 10
After the script has been executed, we are left with two files, tflite_graph.pb and tflite_graph.pbtxt, which are TFLite-compatible graphs.
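Before moving on, it can help to confirm that the export really produced both files. Here is a minimal sketch of such a check; the helper name missing_export_files is hypothetical, and the directory is assumed to be whatever was passed as --output_directory.

```python
from pathlib import Path

# File names reported by the export step above.
EXPECTED_FILES = ("tflite_graph.pb", "tflite_graph.pbtxt")


def missing_export_files(output_dir):
    """Return the expected export files that are absent from output_dir."""
    root = Path(output_dir)
    return [name for name in EXPECTED_FILES if not (root / name).is_file()]
```

Calling missing_export_files("outputs") should return an empty list once the export has succeeded.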
3. Converting the Frozen Graphs to the TFLite Buffer
Now we’ll use a second script ( or more precisely, a utility ) to convert the frozen graph generated in step 2 into a TFLite buffer ( .tflite ). TensorFlow 2.x dropped Session and Placeholder from its main API, so frozen graphs can’t be converted with the standard TF 2.x workflow. This was one of the reasons why we installed TensorFlow 1.x in step 1.
We’ll use the tflite_convert utility to convert the frozen graph to a TFLite buffer. We can also use the tf.lite.TFLiteConverter API, but we’ll stick to the command line utility for now.
!tflite_convert \
--graph_def_file=/content/outputs/tflite_graph.pb \
--output_file=/content/outputs/model.tflite \
--output_format=TFLITE \
--input_arrays=normalized_input_image_tensor \
--input_shapes=1,300,300,3 \
--inference_type=FLOAT \
--output_arrays="TFLite_Detection_PostProcess,TFLite_Detection_PostProcess:1,TFLite_Detection_PostProcess:2,TFLite_Detection_PostProcess:3" \
--allow_custom_ops
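For reference, the same conversion can be sketched with the tf.lite.TFLiteConverter API mentioned above. This is only a sketch, assuming TensorFlow 1.15 and the same paths as the command; the function name convert_frozen_graph is hypothetical, and TensorFlow is imported lazily so the constants can be inspected without it.

```python
# Values copied from the tflite_convert command above.
GRAPH_DEF_FILE = "/content/outputs/tflite_graph.pb"
OUTPUT_FILE = "/content/outputs/model.tflite"
INPUT_ARRAYS = ["normalized_input_image_tensor"]
INPUT_SHAPES = {"normalized_input_image_tensor": [1, 300, 300, 3]}
OUTPUT_ARRAYS = [
    "TFLite_Detection_PostProcess",
    "TFLite_Detection_PostProcess:1",
    "TFLite_Detection_PostProcess:2",
    "TFLite_Detection_PostProcess:3",
]


def convert_frozen_graph():
    """Convert the frozen graph to a .tflite file (assumes TensorFlow 1.15)."""
    import tensorflow as tf  # lazy import: only needed when converting

    converter = tf.lite.TFLiteConverter.from_frozen_graph(
        GRAPH_DEF_FILE,
        input_arrays=INPUT_ARRAYS,
        output_arrays=OUTPUT_ARRAYS,
        input_shapes=INPUT_SHAPES,
    )
    # Mirrors --allow_custom_ops: the detection postprocess is a custom op.
    converter.allow_custom_ops = True
    tflite_model = converter.convert()
    with open(OUTPUT_FILE, "wb") as f:
        f.write(tflite_model)
    return OUTPUT_FILE
```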
Once the execution is done, you’ll see a model.tflite file in the outputs directory. In order to check the input/output shapes, we’ll load the TFLite model using tf.lite.Interpreter and call .get_input_details() or .get_output_details() to get the input and output details respectively.
Tip: Use pprint to get nicely formatted output.
import tensorflow as tf
import pprint
interpreter = tf.lite.Interpreter( '/content/outputs/model.tflite' )
interpreter.allocate_tensors()
pprint.pprint( interpreter.get_input_details())
pprint.pprint( interpreter.get_output_details() )
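The detail dictionaries are fairly verbose, so a compact summary can be easier to compare against the Android code later. The sketch below reads only the name, index, shape and dtype keys; summarize_tensor_details is a hypothetical helper, and sample_outputs is a made-up stand-in for interpreter.get_output_details() ( the indices are invented, the shapes follow from --max_detections 10 ).

```python
def summarize_tensor_details(details):
    """Return (index, name, shape, dtype name) tuples for tensor details."""
    rows = []
    for d in details:
        # Works for both plain lists and NumPy arrays of dimensions.
        shape = tuple(int(dim) for dim in d["shape"])
        dtype = getattr(d["dtype"], "__name__", str(d["dtype"]))
        rows.append((d["index"], d["name"], shape, dtype))
    return rows


# Hypothetical stand-in for interpreter.get_output_details().
sample_outputs = [
    {"index": 0, "name": "TFLite_Detection_PostProcess", "shape": [1, 10, 4], "dtype": float},
    {"index": 1, "name": "TFLite_Detection_PostProcess:1", "shape": [1, 10], "dtype": float},
    {"index": 2, "name": "TFLite_Detection_PostProcess:2", "shape": [1, 10], "dtype": float},
    {"index": 3, "name": "TFLite_Detection_PostProcess:3", "shape": [1], "dtype": float},
]

for row in summarize_tensor_details(sample_outputs):
    print(row)
```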
Integrating the TFLite model in Android
Once we’ve got our TFLite model with all the details of its input and output shapes, we’re ready to run it in an Android app. Create a new project in Android Studio or feel free to fork/clone the GitHub repo to get started!
1. Adding Dependencies For CameraX, Coroutines and TF Lite
As we’re going to detect hands on a live camera feed, we’ll need to add the CameraX dependencies to our Android app. Similarly, in order to run TFLite models, we’ll require the tensorflow-lite dependency along with the Kotlin Coroutines dependency that helps us run the model asynchronously. In the app-level build.gradle file, we’ll add the following dependencies,
plugins {
...
}
android {
...
aaptOptions {
noCompress "tflite"
}
...
}
dependencies {
...
// CameraX dependencies
implementation "androidx.camera:camera-camera2:1.0.1"
implementation "androidx.camera:camera-lifecycle:1.0.1"
implementation "androidx.camera:camera-view:1.0.0-alpha28"
implementation "androidx.camera:camera-extensions:1.0.0-alpha28"
// TensorFlow Lite dependencies
implementation 'org.tensorflow:tensorflow-lite:2.4.0'
implementation 'org.tensorflow:tensorflow-lite-gpu:2.4.0'
implementation 'org.tensorflow:tensorflow-lite-support:0.1.0'
// Kotlin Coroutines
implementation 'org.jetbrains.kotlinx:kotlinx-coroutines-core:1.4.1'
implementation 'org.jetbrains.kotlinx:kotlinx-coroutines-android:1.4.1'
...
}
Make sure you add aaptOptions{ noCompress "tflite" } so that the model isn’t compressed by the build system to make your app’s size smaller. Now, in order to place the TFLite model in our app, we’ll create an assets folder under app/src/main. Paste the TFLite file ( .tflite ) in this folder.
2. Initializing CameraX and ImageAnalysis.Analyzer
We’ll use a PreviewView from the CameraX package to display the live camera feed to the user. Over it, we’ll place an overlay, called BoundingBoxOverlay, to draw the bounding boxes over the camera feed. I won’t be discussing the full implementation here, but you can learn it from the source code.
As we’re going to predict the bounding boxes for hands on live frame data, we’ll also require an ImageAnalysis.Analyzer object, which receives every frame from the live camera feed. See this snippet from FrameAnalyser.kt,
// Image Analyser for performing hand detection on camera frames.
class FrameAnalyser(
private val handDetectionModel: HandDetectionModel ,
private val boundingBoxOverlay: BoundingBoxOverlay ) : ImageAnalysis.Analyzer {
private var frameBitmap : Bitmap? = null
private var isFrameProcessing = false
override fun analyze(image: ImageProxy) {
// If a frame is being processed, drop the current frame.
if ( isFrameProcessing ) {
image.close()
return
}
isFrameProcessing = true
// Get the `Bitmap` of the current frame ( with corrected rotation ).
frameBitmap = BitmapUtils.imageToBitmap( image.image!! , image.imageInfo.rotationDegrees )
image.close()
// Configure frameHeight and frameWidth for output2overlay transformation matrix.
if ( !boundingBoxOverlay.areDimsInit ) {
Logger.logInfo( "Passing dims to overlay..." )
boundingBoxOverlay.frameHeight = frameBitmap!!.height
boundingBoxOverlay.frameWidth = frameBitmap!!.width
}
CoroutineScope( Dispatchers.Main ).launch {
runModel( frameBitmap!! )
}
}
private suspend fun runModel( inputImage : Bitmap ) = withContext( Dispatchers.Default ) {
...
}
}
BitmapUtils contains some useful static methods to manipulate Bitmaps. isFrameProcessing is a Boolean flag which determines whether an incoming frame is dropped or passed to the model. As you may observe, runModel is launched in a CoroutineScope and does its work on Dispatchers.Default, so inference stays off the main thread and the camera preview doesn’t lag while the model runs.
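The frame-dropping idea is independent of Android, so here is a minimal Python sketch of it. FrameGate and analyze are hypothetical names. The sketch also makes explicit that the busy flag has to be cleared once inference finishes; in the Kotlin snippet that reset is not shown and presumably lives in the elided runModel body.

```python
import threading


class FrameGate:
    """Lets one frame through at a time; others are dropped until release()."""

    def __init__(self):
        self._lock = threading.Lock()
        self._busy = False

    def try_acquire(self):
        """Return True and mark the gate busy, or False if a frame is in flight."""
        with self._lock:
            if self._busy:
                return False
            self._busy = True
            return True

    def release(self):
        """Mark the gate free so the next frame can be processed."""
        with self._lock:
            self._busy = False


def analyze(gate, frame, run_model):
    """Run the model on the frame, or return None if the frame is dropped."""
    if not gate.try_acquire():
        return None
    try:
        return run_model(frame)
    finally:
        # Always clear the flag, even if the model raises.
        gate.release()
```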
3. Implementing the Hand Detection Model
Next, we’ll create a class called HandDetectionModel which will handle all the TFLite operations and return the predictions given an image ( as a Bitmap ).
// Helper class for Hand detection TFLite model
class HandDetectionModel( context: Context ) {
// I/O details for the hand detection model.
// Refer to the comments of this script ->
// https://github.com/tensorflow/models/blob/master/research/object_detection/export_tflite_ssd_graph.py
// For quantization, use the tflite_convert utility as described in the conversion notebook ( README ).
private val modelInputImageDim = 300
private val isQuantized = false
private val maxDetections = 10
private val boundingBoxesTensorShape = intArrayOf( 1 , maxDetections , 4 ) // [ 1 , 10 , 4 ]
private val confidenceScoresTensorShape = intArrayOf( 1 , maxDetections ) // [ 1 , 10 ]
private val classesTensorShape = intArrayOf( 1 , maxDetections ) // [ 1 , 10 ]
private val numBoxesTensorShape = intArrayOf( 1 ) // [ 1 , ]
// Input tensor processor for quantized and non-quantized versions of the model.
private val inputImageProcessorQuantized = ImageProcessor.Builder()
.add( ResizeOp( modelInputImageDim , modelInputImageDim , ResizeOp.ResizeMethod.BILINEAR ) )
.add( CastOp( DataType.FLOAT32 ) )
.build()
private val inputImageProcessorNonQuantized = ImageProcessor.Builder()
.add( ResizeOp( modelInputImageDim , modelInputImageDim , ResizeOp.ResizeMethod.BILINEAR ) )
.add( NormalizeOp( 127.5f , 127.5f ) )
.build()
// See app/src/main/assets for the TFLite model.
private val modelName = "model.tflite"
private val numThreads = 4
private var interpreter : Interpreter
// Confidence threshold for filtering the predictions ( applied after the model's NMS )
private val outputConfidenceThreshold = 0.9f
...
Let’s go through each of the terms in the above snippet,
modelInputImageDim is the size of the input image for our model. Our model takes in an image of size 300 * 300.
maxDetections represents the maximum number of predictions made by our model. It determines the shapes of boundingBoxesTensorShape, confidenceScoresTensorShape, classesTensorShape and numBoxesTensorShape.
outputConfidenceThreshold is used to filter the predictions made by our model. This is not NMS; we simply keep the boxes whose score is at least this threshold.
inputImageProcessorQuantized and inputImageProcessorNonQuantized are instances of ImageProcessor which resize the given images to a size of modelInputImageDim * modelInputImageDim. In case of the non-quantized ( float ) model, we also standardize the given image with mean and standard deviation both equal to 127.5.
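To see what standardizing with a mean and standard deviation of 127.5 does, here is a one-function sketch ( normalize_pixel is a hypothetical name ). It computes ( value - mean ) / std, which maps the 8-bit pixel range [ 0 , 255 ] onto [ -1 , 1 ].

```python
def normalize_pixel(value, mean=127.5, std=127.5):
    """Standardize a pixel value: (value - mean) / std."""
    return (value - mean) / std


# The two ends and the midpoint of the 8-bit range.
print(normalize_pixel(0), normalize_pixel(127.5), normalize_pixel(255))
```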
Now, we’ll implement a method run() which will take a Bitmap image and output the bounding boxes in the form of List<Prediction>. Prediction is a class which holds the predicted data, like the confidence score and bounding box coordinates.
// Store the width and height of the input frames as they will be used for future transformations.
inputFrameWidth = inputImage.width
inputFrameHeight = inputImage.height
var tensorImage = TensorImage.fromBitmap( inputImage )
tensorImage = if ( isQuantized ) {
inputImageProcessorQuantized.process( tensorImage )
}
else {
inputImageProcessorNonQuantized.process( tensorImage )
}
val confidenceScores = TensorBuffer.createFixedSize( confidenceScoresTensorShape , DataType.FLOAT32 )
val boundingBoxes = TensorBuffer.createFixedSize( boundingBoxesTensorShape , DataType.FLOAT32 )
val classes = TensorBuffer.createFixedSize( classesTensorShape , DataType.FLOAT32 )
val numBoxes = TensorBuffer.createFixedSize( numBoxesTensorShape , DataType.FLOAT32 )
val outputs = mapOf(
0 to boundingBoxes.buffer ,
1 to classes.buffer ,
2 to confidenceScores.buffer ,
3 to numBoxes.buffer
)
val t1 = System.currentTimeMillis()
interpreter.runForMultipleInputsOutputs( arrayOf(tensorImage.buffer), outputs )
Logger.logInfo( "Model inference time -> ${System.currentTimeMillis() - t1} ms." )
return processOutputs( confidenceScores , boundingBoxes )
confidenceScores, boundingBoxes, classes and numBoxes are the four tensors which will hold the outputs of our model. The processOutputs method will filter the bounding boxes and return only those boxes whose confidence score is at least the threshold.
private fun processOutputs( scores : TensorBuffer ,
boundingBoxes : TensorBuffer ) : List<Prediction> {
// Flattened version of array of shape [ 1 , maxDetections ] ( size = maxDetections )
val scoresFloatArray = scores.floatArray
// Flattened version of array of shape [ 1 , maxDetections , 4 ] ( size = maxDetections * 4 )
val boxesFloatArray = boundingBoxes.floatArray
val predictions = ArrayList<Prediction>()
for ( i in boxesFloatArray.indices step 4 ) {
// Store predictions which have a confidence >= threshold
if ( scoresFloatArray[ i / 4 ] >= outputConfidenceThreshold ) {
predictions.add(
Prediction(
getRect( boxesFloatArray.sliceArray( i..i+3 )) ,
scoresFloatArray[ i / 4 ]
)
)
}
}
return predictions.toList()
}
// Transform the normalized bounding box coordinates relative to the input frames.
private fun getRect( coordinates : FloatArray ) : Rect {
return Rect(
max( (coordinates[ 1 ] * inputFrameWidth).toInt() , 1 ),
max( (coordinates[ 0 ] * inputFrameHeight).toInt() , 1 ),
min( (coordinates[ 3 ] * inputFrameWidth).toInt() , inputFrameWidth ),
min( (coordinates[ 2 ] * inputFrameHeight).toInt() , inputFrameHeight )
)
}
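The same filtering and coordinate conversion can be sketched in a few lines of Python, which makes the tensor layout easier to see. The function names are hypothetical. Each box is stored as [ ymin , xmin , ymax , xmax ] in normalized coordinates, and the result is a ( left , top , right , bottom ) rectangle in frame pixels.

```python
def to_pixel_rect(box, frame_width, frame_height):
    """Convert a normalized [ymin, xmin, ymax, xmax] box to (left, top, right, bottom)."""
    ymin, xmin, ymax, xmax = box
    left = max(int(xmin * frame_width), 1)
    top = max(int(ymin * frame_height), 1)
    right = min(int(xmax * frame_width), frame_width)
    bottom = min(int(ymax * frame_height), frame_height)
    return (left, top, right, bottom)


def process_outputs(scores, boxes, frame_width, frame_height, threshold=0.9):
    """Keep boxes whose score is at least the threshold.

    scores is a flat list of length N; boxes is a flat list of length 4 * N.
    """
    predictions = []
    for i, score in enumerate(scores):
        if score >= threshold:
            box = boxes[4 * i: 4 * i + 4]
            predictions.append((to_pixel_rect(box, frame_width, frame_height), score))
    return predictions
```

For a 640 * 480 frame, a box of [ 0.25 , 0.125 , 0.75 , 0.5 ] with a score of 0.95 becomes the rectangle ( 80 , 120 , 320 , 360 ).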
4. Drawing the Bounding Boxes over the Camera Feed
Once we’ve received the bounding boxes, we’d like to draw them over the camera feed, just as we would with OpenCV. We’ll create a new class BoundingBoxOverlay and add it to activity_main.xml. The class looks like this,
class BoundingBoxOverlay(context : Context, attributeSet : AttributeSet)
: SurfaceView( context , attributeSet ) , SurfaceHolder.Callback {
// Variables used to compute output2overlay transformation matrix
// These are assigned in FrameAnalyser.kt
var areDimsInit = false
var frameHeight = 0
var frameWidth = 0
// This var is assigned in FrameAnalyser.kt
var handBoundingBoxes: List<Prediction>? = null
// This var is assigned in MainActivity.kt
var isFrontCameraOn = false
private var output2OverlayTransform: Matrix = Matrix()
private val boxPaint = Paint().apply {
color = Color.YELLOW
style = Paint.Style.STROKE
strokeWidth = 16f
}
private val textPaint = Paint().apply {
strokeWidth = 2.0f
textSize = 32f
color = Color.YELLOW
}
private val displayMetrics = DisplayMetrics()
...
override fun onDraw(canvas: Canvas?) {
if ( handBoundingBoxes == null ) {
return
}
if (!areDimsInit) {
...
}
else {
for ( prediction in handBoundingBoxes!! ) {
val rect = prediction.boundingBox.toRectF()
output2OverlayTransform.mapRect( rect )
canvas?.drawRoundRect( rect , 16f, 16f, boxPaint )
canvas?.drawText(
prediction.confidence.toString(),
rect.centerX(),
rect.centerY(),
textPaint
)
}
}
}
}
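The snippet above elides how output2OverlayTransform is built, so the following is only a guess at the idea rather than the app’s actual matrix. It assumes a plain per-axis scale from frame size to overlay size, plus a horizontal flip when the front camera is on; frame_to_overlay and its parameters are hypothetical, and the real code may also handle aspect ratio differently.

```python
def frame_to_overlay(rect, frame_size, overlay_size, mirror=False):
    """Map a (left, top, right, bottom) rect from frame pixels to overlay pixels.

    Assumption: a simple per-axis scale, optionally mirrored horizontally.
    """
    left, top, right, bottom = rect
    frame_w, frame_h = frame_size
    overlay_w, overlay_h = overlay_size
    sx = overlay_w / frame_w
    sy = overlay_h / frame_h
    l, t, r, b = left * sx, top * sy, right * sx, bottom * sy
    if mirror:
        # Flip around the vertical centre line of the overlay.
        l, r = overlay_w - r, overlay_w - l
    return (l, t, r, b)
```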
That’s all! We’ve just implemented a hand detector in an Android app. You may run the app after reviewing all the code.
The End
Hope you enjoyed this story! Feel free to express your thoughts at equipintelligence@gmail.com or in the comments below.
Have a nice day ahead, dear developer!