On-Device Machine Learning: Text Generation on Android 📝

Combine the power of GPT-2, TensorFlow, and Kotlin to bring state-of-the-art NLP to mobile

Pierric Cistac
Towards Data Science
7 min read · Dec 19, 2019


At Hugging Face, our goal is to solve and democratize Natural Language Processing (NLP). Currently, most models in production run remotely on servers — for example at Google for search. Still, improving hardware on mobile devices and an increasing focus on privacy make it more and more practical to run these models offline, directly on-device.

The goal of this article is to give a high-level view of the development of an Android application for text generation running entirely on-device. The code is available here: https://github.com/huggingface/tflite-android-transformers/tree/master/gpt2

Text generation using GPT-2 on mobile
What we’re going to build 🤓

Part I: Converting GPT-2 to TensorFlow Lite format

GPT-2 is a model released in 2019 whose Natural Language Generation capabilities (NLG, a subset of NLP) were so impressive that the release of its biggest version was delayed for months. You can play with it using this funny (scary?) tool we released. In this app, we are going to use the smallest version of the model. Its generation capabilities are less impressive than those of the biggest one, but its size (500MB vs. 6GB) makes it way more suitable for mobile use!

Before being able to run it on-device, we need to convert it to a suitable format, TensorFlow Lite (TFLite). To do so, we can run this Python script:

The “tf-nightly” and “transformers” libraries need to be installed in your environment. You can also try the script directly in your browser using this colab notebook.

This script makes use of our 🤗 Transformers library to import the “raw” model and then converts it to the TFLite format. Note lines 15/16 of the script: before running the conversion, we use TFLite to specify that we want to quantize the weights (parameters) of the model to the half-precision floating-point format. This results in a final size of 237MB for our converted model, i.e. half the size of the original “input” model 🎉. The downside? A very minimal loss in accuracy, but it’s definitely worth it on mobile given the storage space saved!

We could go even further in compressing our model by converting the weights to 8-bit integer representations, for a resulting size of only 128MB. But in our tests this version proved to be much slower on-device, so we stick with the half-precision floating-point version here. You can still experiment with the 8-bit version in our app by changing the default model.

Part II: Integrating the converted GPT-2 model in an Android app

Now that we have converted our model, we can focus on actually building our app. The entire source code is available on GitHub, so here I’m only going to focus on the most interesting parts.

In the Python script, we specified (lines 6/7) that our model is going to take as input a two-dimensional array of integers of shape [1, 64], i.e. something like this, where the inner array contains 64 elements:
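(The token values below are made up; only the shape matters.)

```kotlin
// A batch of one sequence of 64 token ids: shape [1, 64].
val input = arrayOf(intArrayOf(8241, 318, 257, 1545, /* …59 more ids… */ 13))
```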

But what we’re going to have in real life is a string, corresponding to the current text. We thus need to convert that string into integers, a.k.a. tokens. Roughly, we can say that a token is a numerical representation of a part of our string.

Tokens are also what the model returns as output. Each run of the model lets us determine the next token of our text, which we then append to the previous text and feed back to the model for its next run, and so on.

We need something to convert our string to tokens, and tokens back to a string. That’s the role of the Tokenizer. The two main functions of a Tokenizer are usually encode and decode.

Full implementation of the Tokenizer available here

The encode function takes our starting/previous text as a parameter, parses it using a regex, and then converts every character to a specific representation. It finally applies a Byte-Pair Encoding (BPE) algorithm whose output is mapped to integers thanks to the model vocabulary. 🤯

The decode function does the reverse, mapping tokens to their vocabulary representation, and then decoding this representation to our final string.
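To make that contract concrete, here is a deliberately simplified sketch: a whole-word vocabulary lookup only, with no regex splitting, byte-level mapping, or BPE merges (the class name and token ids are made up; the real implementation is linked above):

```kotlin
class SimpleTokenizer(private val encoder: Map<String, Int>) {
    // Reverse mapping, id -> token string.
    private val decoder: Map<Int, String> =
        encoder.entries.associate { (token, id) -> id to token }

    // String -> token ids (unknown words are simply dropped in this sketch).
    fun encode(text: String): List<Int> =
        text.split(" ").mapNotNull { word -> encoder[word] }

    // Token ids -> string.
    fun decode(tokens: List<Int>): String =
        tokens.joinToString(" ") { id -> decoder[id] ?: "<unk>" }
}

fun main() {
    val tokenizer = SimpleTokenizer(mapOf("Hello" to 1, "world" to 2))
    val ids = tokenizer.encode("Hello world")   // [1, 2]
    println(tokenizer.decode(ids))              // prints "Hello world"
}
```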

Now that we know how to encode and decode our text, we can call our model! This is the role of the following generate function:

Click here to see the full implementation

The function’s inputs are the initial text and the number of tokens we want to generate (i.e., the number of times our model is called). The first step is to tokenize our text.

Remember we said the input of the model was an array of shape [1, 64]? We need to truncate our previous text tokens, keeping at most the last 64 of them. Those are our inputIds. It means that the generation of the next token only depends on those 64 previous tokens, ignoring any earlier ones.
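In a minimal sketch (the function name and the zero-padding of short prompts are assumptions, not necessarily what the repository does):

```kotlin
// Build the [1, 64] model input from the full token history, keeping only the
// last 64 tokens.
fun toModelInput(tokens: List<Int>, sequenceLength: Int = 64): Array<IntArray> {
    val lastTokens = tokens.takeLast(sequenceLength)
    return arrayOf(IntArray(sequenceLength) { i -> lastTokens.getOrElse(i) { 0 } })
}
```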

We could specify a higher sequence length when we convert our model, but it would imply more computation at inference, slowing down our app. It’s a trade-off between speed and “quality” of our generation. 🤔

We also need to create the data structures that our model will fill with its output. Our model has many outputs, but we’re only interested in the first one, the “predictions”.

We’re reaching here a limit of Kotlin’s expressiveness when it comes to multidimensional arrays; compare the declaration with what it would be in Java:
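Roughly (the Kotlin version first, with the Java equivalent as a comment; 50,257 is the size of GPT-2’s vocabulary):

```kotlin
// Output buffer for the model's logits: shape [1, 64, 50257].
val predictions = Array(1) { Array(64) { FloatArray(50257) } }

// The same declaration in Java:
// float[][][] predictions = new float[1][64][50257];
```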

I’m far from being a Java fanboy, but the right side of the expression seems easier to read to me!

We can then — finally! — run our model by calling the TFLite interpreter:
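A sketch of that call (variable names are illustrative, and which output index holds the predictions depends on how the model was converted):

```kotlin
import org.tensorflow.lite.Interpreter

fun runModel(
    interpreter: Interpreter,
    inputIds: Array<IntArray>,               // shape [1, 64]
    predictions: Array<Array<FloatArray>>    // shape [1, 64, 50257]
) {
    val inputs = arrayOf<Any>(inputIds)
    val outputs = mutableMapOf<Int, Any>(0 to predictions)  // we only read output 0
    interpreter.runForMultipleInputsOutputs(inputs, outputs)
}
```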

Once our “predictions” array is filled by the interpreter, we need to determine the token that will be our “next” one. There are many different ways of doing so; here we’re first using Top-K filtering, selecting the k highest predictions. We then apply a Softmax function to get a probability distribution over those values before finally selecting “the one” through multinomial sampling.
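Here is a self-contained sketch of that selection step (the function name, k = 40, and the use of kotlin.random are assumptions; the app’s implementation may differ in the details):

```kotlin
import kotlin.math.exp
import kotlin.random.Random

fun sampleNextToken(logits: FloatArray, k: Int = 40, rng: Random = Random.Default): Int {
    // Top-K filtering: indices of the k highest logits.
    val topK = logits.indices.sortedByDescending { logits[it] }.take(k)

    // Softmax over the kept logits (subtracting the max for numerical stability).
    val maxLogit = topK.maxOf { logits[it] }
    val expValues = topK.map { exp((logits[it] - maxLogit).toDouble()) }
    val sum = expValues.sum()
    val probabilities = expValues.map { it / sum }

    // Multinomial sampling: walk the cumulative distribution until it passes a
    // uniform random draw.
    val draw = rng.nextDouble()
    var cumulative = 0.0
    for (i in topK.indices) {
        cumulative += probabilities[i]
        if (draw < cumulative) return topK[i]
    }
    return topK.last()  // fallback for rounding errors
}
```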

Part III: Interfacing with the Activity in a UI-friendly way thanks to Kotlin coroutines

It’s now time to link our model to the interface of our application! Running a model such as GPT-2 on-device, even a quantized version, requires significant computing resources. If we do it wrong, we might end up with the UI freezing while the model is running, which is not very user-friendly! 😱

To avoid such a bad result, we’re going to make use of coroutines, a really nice way of doing non-blocking programming in Kotlin. Here is our (nearly) complete GPT2Client class, which is a ViewModel loaded from our main activity:

For full implementation, check here

The class first needs to load all the assets for our model and to initialize the TFLite interpreter. To do that without blocking the UI, in the init block we’re launching a new coroutine thanks to viewModelScope.launch. Inside this coroutine, we’re loading our model assets by calling 3 “load” methods. Here is the signature for loadModel:
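In outline (the asset file name and the way the ViewModel accesses a Context are illustrative; the point is the suspend function running its body on Dispatchers.IO via withContext):

```kotlin
private suspend fun loadModel(): Interpreter = withContext(Dispatchers.IO) {
    // Memory-map the converted .tflite file from the app's assets and build the
    // TFLite interpreter, all off the main thread.
    val fileDescriptor = getApplication<Application>().assets.openFd("model.tflite")
    val fileChannel = FileInputStream(fileDescriptor.fileDescriptor).channel
    val modelBuffer = fileChannel.map(
        FileChannel.MapMode.READ_ONLY,
        fileDescriptor.startOffset,
        fileDescriptor.declaredLength
    )
    Interpreter(modelBuffer)
}
```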

What’s important here is the withContext(Dispatchers.IO) part. We’re saying that we want to execute this method on a different thread than the main one, here using one designed for I/O operations (see here for more details).

The “beauty” of creating a coroutine through viewModelScope.launch is that it ties its lifetime to that of the ViewModel. It ensures that when the ViewModel is cleared, the coroutine is canceled! 🙌
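Putting the initialization together, the init block looks roughly like this (the three load-method names and the GPT2Tokenizer constructor are assumptions based on the description above; see the repository for the real class):

```kotlin
private lateinit var tokenizer: GPT2Tokenizer
private lateinit var tflite: Interpreter
private val initJob: Job

init {
    initJob = viewModelScope.launch {
        // Each "load" call switches to Dispatchers.IO internally, so the main
        // thread is never blocked while the assets are read from disk.
        val encoder  = loadEncoder()    // token string -> id vocabulary
        val decoder  = loadDecoder()    // the reverse mapping
        val bpeRanks = loadBpeRanks()   // BPE merge ranks
        tokenizer = GPT2Tokenizer(encoder, decoder, bpeRanks)
        tflite = loadModel()
    }
}
```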

Then, when the user clicks the “Trigger autocomplete” button in the app, the launchAutocomplete method is executed, creating another coroutine from which we’re going to call our model.

Inside this coroutine, we first make sure the initialization of the assets (initJob) is complete, then check for a potentially still-running previous model run (autocompleteJob), which we cancel if needed. We can then call our generate method:
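A condensed sketch of those two pieces (property and parameter names are illustrative, and predictNextToken is a hypothetical stand-in for the tokenize / run / sample steps described above):

```kotlin
private var autocompleteJob: Job? = null
val completion = MutableLiveData<String>()

fun launchAutocomplete(prompt: String) {
    val previousJob = autocompleteJob
    autocompleteJob = viewModelScope.launch {
        initJob.join()                 // wait for the assets and the interpreter
        previousJob?.cancelAndJoin()   // cancel a still-running generation, if any
        completion.postValue(generate(prompt, nbTokens = 100))
    }
}

private suspend fun generate(prompt: String, nbTokens: Int): String =
    withContext(Dispatchers.Default) {
        val tokens = tokenizer.encode(prompt).toMutableList()
        repeat(nbTokens) {
            tokens.add(predictNextToken(tokens))  // one interpreter run + sampling
            yield()  // cooperative cancellation point after each generated token
        }
        tokenizer.decode(tokens)
    }
```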

The dispatcher used for this method is not Dispatchers.IO since we’re not doing any I/O operation here, but a more generic Dispatchers.Default which uses a common pool of shared background threads.

Another interesting part of this method is the yield() call at the end of the repeat block. This is what allows the method to check for a possible cancellation. Without it, no cancellation would be possible, and we would have to wait until the end of the entire generation before being able to free the resources! ☠️ Here we’re checking for cancellation after each token generation.

Another way of checking for cancellation would be to check the value of the isActive property.

The completed text is then displayed in the app “automagically” thanks to the use of the LiveData structure (our completion property). 🧙‍♀️
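On the Activity side, observing that LiveData is all it takes for the text view to refresh whenever a value is posted from the background coroutine (the view and ViewModel variable names are illustrative):

```kotlin
// In the Activity's onCreate, after obtaining the GPT2Client ViewModel:
gpt2.completion.observe(this) { text ->
    completionTextView.text = text
}
```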

That’s it! At Hugging Face we believe that we’re only at the beginning of the era of AI running on-device. With the continuous development of dedicated hardware, drivers, and frameworks on one side, and techniques like quantization and distillation on the other, our smartphones have a bright future ahead of them, allowing ever more complex models to run in an ever more efficient and performant way.

You can check the entire repository if you want more Android examples. We also released a repo with models and apps for iOS making use of the Apple-specific CoreML framework. And if you’re interested in more in-depth state-of-the-art NLP, our 🤗 Transformers library is here!
