Real-Time Digit Recognition in iOS

Processing a live camera feed with TorchScript in Swift

Bora Tunca
Towards Data Science


Demo gif by Author

The future of computing involves computers that can see like we see. Phones recognize our faces, cameras read books, cars drive themselves. We will only see more of these.

It is hard not to be curious about how these algorithms work. For people like me, building is the best way to learn. In that endeavor I built an iOS app around the “hello world” problem of computer vision. This post documents the conversations I had with my rubber duck while developing the app. The source code is available on GitHub. I hope you find this useful.

My rubber duck as seen in wikipedia (By Tom Morris — Own work, CC BY-SA 3.0)

Scope

Recognizing handwritten digits is the cornerstone problem of computer vision, and neural networks are the standard way to address it. There are plenty of resources on training neural network models for this problem, so I won’t rehash the details of model building and training, and I assume some prior knowledge of these topics. My focus in this exercise is porting a trained model to a mobile environment.

The questions I am looking for answers to:

  • How do I export a model from Python to iOS?
  • How do I run inference in Swift?
  • How do I capture images from a live camera feed?

Tech Stack

I used PyTorch to build the model because I wanted to give TorchScript a try. For the iOS app I used Swift with SwiftUI.

Creating a Model

By design I picked a widely studied problem. Amitrajit’s post is one of the good resources out there for this problem. I followed its guidance for choosing the network architecture and the hyperparameters.

MNIST database photo from Wikipedia

The model I settled on has two convolutional layers and two fully connected layers. It uses LogSoftmax as the output layer activation.

Other important details are the optimizer, the loss function, and the training parameters.

With this configuration the model achieves 97% accuracy after 3 epochs of training on the MNIST dataset.

Model Input

The model is trained on training data, but I want it to work on images coming from a camera stream. We call this the production data. The app has to preprocess the production data to match the shape and semantics of the training data. Otherwise the results will be suboptimal, if not garbage.

To get the preprocessing right, a thorough understanding of the training data, the production data, and the output data the model returns is a must. Therefore my duck and I spent a good deal of time talking about the data.

Let’s start with the training dataset. Without any transformation the dataset contains tuples of PIL images and their labels.

We’ll apply two transformations to this dataset. One converts the data from PIL format to a tensor. The other normalizes the values.

After the first transformation we get tensors in the shape of (1, 28, 28). This is because the images are single channel (grayscale), 28 by 28 pixel images. These tensors hold floating point numbers, valued between 0 and 1. These values reflect the intensity of the pixels; 0 is black and 1 is white.
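To make the first transformation concrete, here is a minimal sketch of what it does to one image, written with plain Python lists instead of PIL and torch. It assumes the image arrives as 28 rows of 28 integers in the range 0 to 255 (an 8-bit grayscale image); the function name `to_tensor_like` is hypothetical and only mimics the scaling and the added channel dimension of ToTensor, not its full behavior.

```python
def to_tensor_like(pixels):
    """Scale 8-bit grayscale pixels to [0, 1] and add a leading channel axis.

    pixels: 28 rows of 28 ints in 0..255 (assumed input format).
    Returns a nested list shaped (1, 28, 28), like ToTensor's (C, H, W) output.
    """
    if len(pixels) != 28 or any(len(row) != 28 for row in pixels):
        raise ValueError("expected a 28 x 28 image")
    channel = [[value / 255.0 for value in row] for row in pixels]
    return [channel]  # a single channel, because the image is grayscale


# A black image with one white pixel in the middle.
image = [[0] * 28 for _ in range(28)]
image[14][14] = 255
tensor = to_tensor_like(image)
print(len(tensor), len(tensor[0]), len(tensor[0][0]))  # 1 28 28
print(tensor[0][14][14], tensor[0][0][0])              # 1.0 0.0
```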

The second transformation calculates the standard score (aka z-score) of each element in the tensor. The standard score tells us how many standard deviations a value is away from the mean. The mean pixel value of this dataset is 0.1307 and the standard deviation is 0.3081. Therefore, after this transformation the values of the tensors lie between -0.4242 and 2.8215.
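Written out, the standard score of a pixel value x, with the dataset mean μ = 0.1307 and standard deviation σ = 0.3081, is given below. Plugging in the two ends of the [0, 1] range produced by the first transformation gives the bounds quoted above (rounded to four decimals).

```latex
z = \frac{x - \mu}{\sigma}, \qquad \mu = 0.1307,\ \sigma = 0.3081

z_{\min} = \frac{0 - 0.1307}{0.3081} \approx -0.4242, \qquad
z_{\max} = \frac{1 - 0.1307}{0.3081} = \frac{0.8693}{0.3081} \approx 2.8215
```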

There is a lot of interesting information behind these transformation steps. For brevity I will only give a few links. Firstly, the source code tells a lot. Have a look at the ToTensor class and its underlying implementation. Secondly, read about standard score, and why we apply it to the training dataset. Then see the code of the Normalize class and its implementation. Finally, the Compose class will show how these steps are simply chained together.

Visualization

We gained an understanding of the input format. Next we should inspect the semantics. One challenge of neural networks is the large amount of data they work with. A raw printout of the input is not digestible by a human.

We have to make this data digestible so that we can reason about it. A common approach is using visualization techniques. There are many sophisticated tools for visualizing large datasets, but our dataset is simple, and plain print statements take us far enough. We know that the data represents a 28 by 28 image, so let’s print it accordingly.

Highlighting -0.42 in the output shows the pattern we expect to see.

In this input positive numbers represent the bright pixels while negative ones belong to the black background. The label of this input says this data is representing the digit three. The pattern we see verifies this. Let’s change the code to make the pattern more visible.

This ASCII art attempt gives good insight into what we are feeding the model. The advantage of such a primitive visualization is portability. We’ll make use of the same logic later in the iOS environment.
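The printing code itself is not reproduced here; the following is a minimal Python sketch of the same idea, assuming the image is available as a flat list of 784 normalized values in row-major order. The function name `ascii_art` and the threshold of 0 are illustrative choices, based on the observation that positive numbers mark the bright pixels while the background sits at about -0.4242.

```python
def ascii_art(values, width=28, threshold=0.0):
    """Render a flat list of normalized pixel values as text, one row per line.

    Values above `threshold` (bright pixels) become '#', the rest become '.'.
    """
    rows = []
    for start in range(0, len(values), width):
        row = values[start:start + width]
        rows.append("".join("#" if v > threshold else "." for v in row))
    return "\n".join(rows)


# Hypothetical 28 x 28 input: background everywhere, plus a vertical stroke.
background, stroke = -0.4242, 2.8215
pixels = [background] * (28 * 28)
for r in range(5, 23):
    pixels[r * 28 + 14] = stroke
art = ascii_art(pixels)
print(art)
```

Because it needs nothing beyond a loop and string concatenation, the same logic ports directly to Swift.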

Not every dataset is as easy to visualize as this one. For more complex datasets there are dedicated tools; one example is Matplotlib.

Output

This model is solving a classification problem. It predicts a confidence score for each of the 10 classes. This is reflected in the model architecture. The output layer is a fully connected layer with 10 nodes. Each node represents one digit out of ten.

What the confidence scores look like is determined by the activation function. In our case this is LogSoftmax (read more on the softmax function and why we take its logarithm). This function gives us numbers less than or equal to zero, since they are logarithms of probabilities. The closer a number is to zero, the higher the confidence.

In this example the model has the highest confidence in the class at index 3 (classes are numbered from 0 to 9), indicating that this image most likely represents the digit three.
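To see why the scores are never positive and why the one closest to zero wins, here is a small stdlib-only Python sketch of LogSoftmax followed by picking the top class. The logits are made-up numbers, not output from the actual model. Because the classes are numbered 0 to 9, the index of the largest score is directly the predicted digit.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax: x_i - log(sum_j exp(x_j))."""
    largest = max(logits)
    log_sum = largest + math.log(sum(math.exp(x - largest) for x in logits))
    return [x - log_sum for x in logits]


# Hypothetical raw outputs of the last fully connected layer for classes 0..9.
logits = [0.1, -1.2, 0.4, 6.0, -0.3, 1.1, -2.0, 0.2, 0.9, -0.5]
scores = log_softmax(logits)

predicted = max(range(len(scores)), key=lambda i: scores[i])
confidence = math.exp(scores[predicted])  # back to a probability in (0, 1]
print(predicted, round(confidence, 3))
```

Exponentiating the scores recovers probabilities that sum to one, which is a handy sanity check when porting the output handling to another language.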

Getting out of Python

TorchScript is the first step on the journey towards a non-Python environment. In a nutshell, TorchScript gives us a model that can be used from C++. Once the model is trained, we can convert it to a TorchScript module.

TorchScript offers a few ways to inspect the module. One of them is the .code property.

This representation captures the steps of the forward method we defined in the model. Further details are explained in the TorchScript introduction tutorial. We can also optimize the module for mobile.

Finally, we serialize the module and save it to disk. At this point the module is ready to be exported.

Getting into iOS

We trained a model that turns pixel data (images) into digit predictions, and we can invoke it from C++. How do we run it on iOS? This is what we’ll discover in this section.

The prerequisite for using TorchScript on iOS is the LibTorch library. It allows us to deserialize the model we serialized in the Python environment. The library is available on CocoaPods.

The next hurdle is calling this library from Swift. Swift requires some handholding to talk to C++: Swift can talk to Objective-C, and Objective-C can talk to C++. Therefore, we need an Objective-C wrapper around the LibTorch API.

The wrapper consists of three pieces: a bridging header, a module header, and a module implementation.

The bridging header contains a single import for the module header. The module header contains the declarations. Only the module implementation is worth a closer look. It provides two methods.

initWithFileAtPath deserializes the model at the given file path. Once the model is deserialized, it sets it to evaluation mode.

The predictImage method is the heart of the integration. This is where we run inference on the model.

First we need a tensor to run inference. We use the torch::from_blob function to convert an image buffer to a tensor.

at::Tensor tensor = 
torch::from_blob(imageBuffer, {1, 1, 28, 28}, at::kFloat);

The second argument, {1, 1, 28, 28}, gives the sizes of the tensor we want to create: the batch size, followed by the dimensions of a single training image (one channel, 28 by 28 pixels). Even though we are running inference on a single image, the model expects its input as a batch, so the batch size is one. The last argument specifies the type of the elements in the tensor.
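torch::from_blob does not copy the data: when no strides are given, it interprets the existing buffer as a contiguous, row-major array with the given sizes. So imageBuffer must hold exactly 1 × 1 × 28 × 28 = 784 floats, laid out row by row. The Python sketch below uses illustrative helper names (they are not part of LibTorch) to show where each pixel lands under that layout.

```python
def flat_offset(n, c, h, w, shape=(1, 1, 28, 28)):
    """Index into a contiguous (row-major) buffer for an NCHW tensor of `shape`."""
    _, channels, height, width = shape
    return ((n * channels + c) * height + h) * width + w


def buffer_length(shape=(1, 1, 28, 28)):
    """Number of elements the buffer must hold for a tensor of `shape`."""
    total = 1
    for dim in shape:
        total *= dim
    return total


print(buffer_length())            # 784 floats, i.e. 784 * 4 bytes for kFloat
print(flat_offset(0, 0, 0, 1))    # 1: the next pixel in the same row
print(flat_offset(0, 0, 1, 0))    # 28: the first pixel of the second row
print(flat_offset(0, 0, 27, 27))  # 783: the last pixel
```

A buffer that is shorter than this, or laid out column by column, would still produce a tensor, just one with the wrong contents.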

The next two lines are where my lack of C++ experience catches up with me. At a high level, I understand that they disable gradient calculation, because we are only running inference. I could not find further tangible details on these calls.

torch::autograd::AutoGradMode guard(false);
at::AutoNonVariableTypeMode non_var_type_mode(true);

Finally, we call the forward method and convert its result, a generic IValue, to a tensor.

auto outputTensor = _impl.forward({tensor}).toTensor();

The rest of the logic is self-descriptive: we process the output tensor to return an array of numbers.

Running Inference

In my first iteration I ran inference on static images from the MNIST dataset. Since our model is trained on this dataset, we should get high accuracy on these images.

I added a few images from the dataset to the asset catalog. Then I was able to access these images.

Next, we need an instance of the model to call the inference method on. We use the API we implemented earlier to instantiate the model. The file path should point to a serialized TorchScript model; in our case, this is a copy of the file we generated earlier, after converting the model to TorchScript.

We have the input, we have the model instance, we should be ready to call the predict method. However, we are not, because we need to preprocess the input image to transform it into a structure that our model understands. This is where our understanding of the model input becomes important.

What we have as an input is a UIImage. What the predict method expects is an array of floating point numbers. Not only do we want to turn the image pixels to floating point numbers, we also want to normalize these numbers. If you remember the data transformations we ran on the training set, one of them was normalization.

torchvision.transforms.Normalize((0.1307,), (0.3081,))

To emphasize it again: we want the production input data to be as close as possible to the training data. Otherwise we will get garbage results.

There is a series of steps we need to take to convert the input data to the format the model expects. Once identified, these steps are trivial to implement. At most it takes a bit of familiarity with the Swift standard library and Apple’s Core Image framework.

Input Pipeline

I’ll restate the goal: UIImage -> Byte Array -> Float Array. I broke this pipeline down into the following steps:

  • Convert UIImage to CGImage
  • Read CGImage Bytes
  • Normalize

UIImage to CGImage

Working with images in iOS involves some complexity. A case in point is reading the bytes. The route I took to access the bytes of a UIImage is to first convert it to a CIImage, and then to a CGImage. After that we can read the bytes of the image. For an introduction to these image classes, check out the post on integrating Core Image on Hacking with Swift.

This logic may fail, therefore the method returns an optional result. After the error handling we gain access to the CGImage.

Reading the Bytes

We are back on track to read the bytes. In this step we’ll allocate memory and redraw the image into it. After this process, the bytes we are interested in will end up in the allocated memory.

We allocate the memory by initializing a variable.

var rawBytes: [UInt8] =
[UInt8](repeating: 0, count: 28 * 28) // 1 byte per pixel

There are a couple of things to notice. First, we are using 8-bit unsigned integers. Any 8-bit type could hold the data, but an unsigned one reads back as the intended 0 to 255 range. Second, the size of the array matches the number of pixels in the training images. Since the training images are grayscale, we need only one byte to represent each pixel.

Next we create a bitmap graphics context. This structure enables us to redraw the image. The constructor for the context requires a pointer to a memory space. We’ll use the Swift standard library method withUnsafeMutableBytes to get a pointer.

The constructor arguments are important to get right. data points to the allocated memory. width and height are self-descriptive. According to the documentation, each pixel is made up of one or more components in memory. In our case a pixel has only one component, of size one byte. Therefore, bitsPerComponent is 8, and bytesPerRow is the width times one byte, which is 28. space indicates the color space; in our case this is grayscale. bitmapInfo tells that this bitmap does not have an alpha (transparency) channel. The documentation has further information on the arguments.
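CGContext also takes a bytesPerRow argument, and together with width, height, and bitsPerComponent it has to agree with the size of rawBytes. As a quick sanity check of that arithmetic, here is a Python sketch; the helper name is hypothetical, and it assumes tightly packed rows with no padding.

```python
def bitmap_buffer_size(width, height, bits_per_component, components_per_pixel):
    """Return (bytes_per_row, total_bytes) for a tightly packed bitmap."""
    bits_per_pixel = bits_per_component * components_per_pixel
    bytes_per_row = width * bits_per_pixel // 8
    return bytes_per_row, bytes_per_row * height


# Grayscale, no alpha: one 8-bit component per pixel.
print(bitmap_buffer_size(28, 28, 8, 1))  # (28, 784)
# For comparison, an RGBA context would need four times as much memory.
print(bitmap_buffer_size(28, 28, 8, 4))  # (112, 3136)
```

The grayscale result, 784 bytes, is exactly the length of rawBytes, which is why the redraw fills the array with one value per training pixel.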

After the necessary error handling, we can move on to drawing the image. First we define the area we want the image to be drawn in. Then we call the draw method.

When this callback executes, it overwrites the values of the rawBytes array with the pixel data of the image we have in cgImage.

Normalize

The last step of the preprocessing is normalization. First, we divide the values by 255 to arrive at numbers between 0 and 1. Then we apply the same logic we applied to the training dataset. If you looked at the source code of the Normalize class earlier, you know it is straightforward to port.

We get back floating point numbers between -0.4242 and 2.8215, exactly the same range as in the training environment.
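The Swift code for this step is not shown here. As a hedged reference, this is the same arithmetic in Python, assuming the input is the 28 × 28 grayscale byte buffer filled by the redraw. The name `preprocess` is illustrative. Running it on the two extreme byte values reproduces the training-side bounds.

```python
MEAN, STD = 0.1307, 0.3081  # the values passed to Normalize during training


def preprocess(raw_bytes):
    """Port of the training transforms for 8-bit grayscale bytes.

    Mirrors ToTensor (divide by 255) followed by Normalize ((x - mean) / std).
    """
    return [((b / 255.0) - MEAN) / STD for b in raw_bytes]


raw = bytes([0, 128, 255])  # black, mid gray, white
print([round(v, 4) for v in preprocess(raw)])
```

Keeping MEAN and STD as named constants makes it easy to check that the app and the training script use identical values.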

Visualization II

As a last check, let’s run the ASCII art method, this time in Swift.

The output looks very similar to the input image.

Running Inference II

Early on we instantiated the model. We also just preprocessed the input. We are ready to call the predict method.

We again used the withUnsafeMutableBytes method. This is because the predict method on the module object expects a pointer. After the error handling, we get an array of prediction results.

As we investigated earlier, the output has 10 numbers. We are interested in the highest confidence scores, so we sort the results to identify them.
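The sorting step can be sketched in Python as follows. The choice of k = 3 and the score values are made up for illustration; the app may keep a different number of top results.

```python
def top_k(scores, k=3):
    """Return (digit, score) pairs for the k highest scores, best first.

    Works directly on LogSoftmax outputs, because log-probabilities are
    ordered the same way as the probabilities themselves.
    """
    ranked = sorted(enumerate(scores), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


# Hypothetical LogSoftmax output for digits 0..9.
predictions = [-9.1, -7.4, -5.2, -0.01, -8.8, -6.3, -9.9, -4.7, -5.9, -7.0]
print(top_k(predictions))  # [(3, -0.01), (7, -4.7), (2, -5.2)]
```

Pairing each score with its index before sorting keeps track of which digit each score belongs to.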

The first element of the sorted result is 3 with a high confidence. It works! 🥳

Interlude

We just managed to make the iOS app recognize a static image of a digit. It’s amazing how much detail goes into building a solution even for such a widely studied problem. It is also satisfying to see all of the pieces come together and deliver results.

So far we have achieved two of the three goals:

  • How do I export a model from Python to iOS?
  • How do I run inference in Swift?
  • How do I capture images from a live camera feed?

In part II we’ll discover the hidden details of live video processing. Stay tuned.
