Optimizing Style Transfer Models for Videos on Mobile Devices

Experimenting with the flexibility of PyTorch Lightning and Flash to create a face-to-cartoon model

Jacopo Mangiavacchi
Towards Data Science


iOS ToonClip App using the face to cartoon model described in this post — Model, App and image are provided by the author.

Hugging Face Spaces demo at: https://huggingface.co/spaces/Jacopo/ToonClip

Free iOS Mobile App at: https://apps.apple.com/us/app/toonclip/id1536285338

Source code available at (WIP): https://github.com/JacopoMangiavacchi/ToonClip-ComicsHero

Introduction

GANs are a very popular architecture for building image style transfer models, but they can take quite a while to train and, more importantly, it is hard to optimize the generator model so that it fits and runs in real time on mobile devices.

In this experiment I focused on three specific areas:

  1. Easily experiment with different model architectures and backbones, taking advantage of the flexibility of PyTorch Lightning and Flash
  2. Quickly train without GAN discriminators, using a custom PyTorch Perceptual Feature Loss
  3. Constantly convert to ONNX / CoreML and benchmark the model on device to find better optimizations for mobile devices

Credits

Before starting, I really want to thank Doron Adler for providing a great project that truly inspired my experimentation and this post: the U2Net Heros style transfer.

This U2Net model in particular implemented the core idea of a muti_bce_loss_fusion function that builds a Perceptual Feature Loss by comparing the outputs of all layers of the U2Net encoder-decoder.

In my experiment I followed a similar but more general, and hopefully more flexible, technique: a Perceptual Feature Loss computed on a separate image model, able to capture features from different layers independently of the specific model architecture used for the encoder-decoder.

Specifically, I tried to take full advantage of modern high-level frameworks on top of PyTorch to:

  • Support any well-known image model for the encoder-decoder architecture and any image featurizer backbone on the encoder, with pre-trained weights available from PyTorch Lightning Flash
  • Implement a custom loss and train quickly with PyTorch Lightning on GPU, TPU, etc.

Dataset

In my experiment I reused the same data provided by Doron Adler in his U2Net Heros style transfer project mentioned above. This is an image dataset of 100 high-resolution (1024x1024) face images selected from the FFHQ dataset, with associated cartoonized images. The dataset is provided in the same GitHub repo mentioned above and is licensed under the Apache License 2.0.

In the code below you can see the PyTorch Dataset class loading and transforming this data:

As you probably noticed, on top of this small dataset I used the Albumentations library to add some image augmentation transformations, so that the model generalizes better when running inference live on real images and videos captured from a mobile camera.

In particular, as the data contains images both as input (face) and output (styled face), I used the ReplayCompose functionality of the Albumentations API to make sure the same random transformations are applied to each pair of face and styled face images.

In the code snippet below you will find an example of how to configure this image to image Dataset with different augmentation filters:

You will see later in this post that I am using a transfer learning technique, importing weights from a model previously trained on the ImageNet dataset, so I also added a normalization step configured with the ImageNet stats to the Albumentations pipeline above.
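To make the normalization step concrete, here is a small NumPy-only sketch of the arithmetic involved, using the commonly published ImageNet channel means and standard deviations. The function names are hypothetical, and the inverse function is an assumption about how predictions would be turned back into displayable images when the targets were normalized as well.

```python
import numpy as np

# Commonly used ImageNet channel statistics (RGB order).
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def normalize(image_uint8):
    """HWC uint8 image in [0, 255] -> HWC float32 image in ImageNet-normalized space."""
    scaled = image_uint8.astype(np.float32) / 255.0
    return (scaled - IMAGENET_MEAN) / IMAGENET_STD

def denormalize(image_float):
    """Inverse of normalize: back to a displayable HWC uint8 image."""
    scaled = image_float * IMAGENET_STD + IMAGENET_MEAN
    return np.clip(np.rint(scaled * 255.0), 0, 255).astype(np.uint8)
```

Because both the face and the styled face go through the same normalization, a model trained this way predicts in normalized space, so something like `denormalize` is needed before showing the result.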

Finally, in order to be able to re-execute the training loop several times on the same distribution, I set a fixed manual random seed before randomly splitting the dataset into train and evaluation data, and afterwards reset the seed to the initial PyTorch random seed. This way I was able to train every time with different random applications of the augmentation pipeline while still working on the same train and validation split.
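One way to get the same effect is sketched below. Note that it is a variant of what is described above: instead of setting and then resetting the global seed, it hands a dedicated, seeded generator to `random_split`, so the global random state used by the augmentations is never touched. The function name, the 20% validation fraction and the seed value are assumptions.

```python
import torch
from torch.utils.data import TensorDataset, random_split

def reproducible_split(dataset, val_fraction=0.2, seed=42):
    """Split a dataset the same way on every run without touching the global RNG."""
    n_val = int(len(dataset) * val_fraction)
    n_train = len(dataset) - n_val
    # A private generator keeps the split fixed while augmentations stay random.
    generator = torch.Generator().manual_seed(seed)
    return random_split(dataset, [n_train, n_val], generator=generator)
```

With the 100 image pairs of this dataset, the default arguments would give an 80/20 split that is identical across training runs.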

Perceptual Feature Loss

In GANs, and in the Style Transfer domain in general, loss functions based on feature activations and style loss are a well-known technique, frequently adopted in scenarios such as super-resolution.

In this experiment I followed in particular the well-known technique of using an external, pre-trained VGG network to extract features from its different convolution blocks and summing up the loss over these features, instead of simply computing the MSE or L1 distance between the generated image and the original label/style image.

The diagram below from the paper “Perceptual Losses for Real-Time Style Transfer and Super-Resolution” (https://arxiv.org/abs/1603.08155) illustrates this idea, and many articles exist that fully describe this process and provide reference implementations (including the Fast.AI MOOC by Jeremy Howard).

Model structure provided by the paper

In my case I reused the code snippet made available by Alper Ahmetoğlu on GitHub Gist (https://gist.github.com/alper111/8233cdb0414b4cb5853f2f730ab95a49), but I’ve updated it by adding the option to weight the features of each block, in order to follow both the paper and Jeremy Howard's suggestion from the Fast.AI course to boost the Gram matrix product computed on the outputs of these layers.

I simply commented out the call that normalizes the input image to the ImageNet stats, as I’m already normalizing all the input face and style images in the dataset transformation pipeline.
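The following is a minimal sketch of a block-weighted perceptual loss in this spirit, not the code from the gist or from this project. The class name, the way weights are passed and the L1 distance are assumptions; the feature blocks are injected so that any frozen network can be used, for example consecutive slices of a pre-trained VGG16 from torchvision.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def gram_matrix(features):
    """Channel-by-channel correlation of a (B, C, H, W) feature map."""
    b, c, h, w = features.shape
    flat = features.reshape(b, c, h * w)
    return flat @ flat.transpose(1, 2) / (c * h * w)

class WeightedPerceptualLoss(nn.Module):
    """Sum of per-block feature and Gram (style) distances, each with its own weight."""

    def __init__(self, blocks, feature_weights, style_weights):
        super().__init__()
        self.blocks = nn.ModuleList(blocks)
        # The feature extractor is fixed; only the generator is trained.
        for p in self.blocks.parameters():
            p.requires_grad_(False)
        self.feature_weights = list(feature_weights)
        self.style_weights = list(style_weights)

    def forward(self, prediction, target):
        loss = prediction.new_zeros(())
        x, y = prediction, target
        for block, fw, sw in zip(self.blocks, self.feature_weights, self.style_weights):
            # Each block consumes the previous block's output, as in a sliced VGG.
            x, y = block(x), block(y)
            loss = loss + fw * F.l1_loss(x, y)
            loss = loss + sw * F.l1_loss(gram_matrix(x), gram_matrix(y))
        return loss
```

Raising a block's style weight makes texture statistics at that depth count more, while the feature weights control how strongly the spatial layout is matched.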

Image to Image model and Segmentation

A Style Transfer (image to image) problem is in the end very similar to a more generic Semantic Segmentation problem, and any model architecture used for that domain, such as a fully convolutional encoder-decoder, works well in the style transfer scenario too.

In my exploration I wanted to be as flexible as possible in experimenting with different encoder-decoder model architectures and to benchmark real performance results on architectures such as: ‘deeplabv3’, ‘deeplabv3plus’, ‘fpn’, ‘linknet’, ‘manet’, ‘pan’, ‘pspnet’, ‘unet’ and finally ‘unetplusplus’.

I found the PyTorch Lightning Flash library particularly useful in this context, as it offers a high-level SemanticSegmentation class that supports all the architectures above.

In addition, it provides weights for different combinations of image feature backbones to use on the encoder part of these Fully Convolutional Networks (FCN). For example, on UNet++ it supports backbones such as:

'densenet121', 'densenet161', 'densenet169', 'densenet201', 'dpn107', 'dpn131', 'dpn68', 'dpn68b', 'dpn92', 'dpn98', 'efficientnet-b0', 'efficientnet-b0', 'efficientnet-b1', 'efficientnet-b1', 'efficientnet-b2', 'efficientnet-b2', 'efficientnet-b3', 'efficientnet-b3', 'efficientnet-b4', 'efficientnet-b4', 'efficientnet-b5', 'efficientnet-b5', 'efficientnet-b6', 'efficientnet-b6', 'efficientnet-b7', 'efficientnet-b7', 'efficientnet-b8', 'efficientnet-l2', 'gernet_l', 'gernet_m', 'gernet_s', 'inceptionresnetv2', 'inceptionv4', 'mobilenet_v2', 'mobilenetv3_large_075', 'mobilenetv3_large_100', 'mobilenetv3_large_minimal_100', 'mobilenetv3_small_075', 'mobilenetv3_small_100', 'mobilenetv3_small_minimal_100', 'regnetx_002', 'regnetx_004', 'regnetx_006', 'regnetx_008', 'regnetx_016', 'regnetx_032', 'regnetx_040', 'regnetx_064', 'regnetx_080', 'regnetx_120', 'regnetx_160', 'regnetx_320', 'regnety_002', 'regnety_004', 'regnety_006', 'regnety_008', 'regnety_016', 'regnety_032', 'regnety_040', 'regnety_064', 'regnety_080', 'regnety_120', 'regnety_160', 'regnety_320', 'res2net101_26w_4s', 'res2net50_14w_8s', 'res2net50_26w_4s', 'res2net50_26w_6s', 'res2net50_26w_8s', 'res2net50_48w_2s', 'res2next50', 'resnest101e', 'resnest14d', 'resnest200e', 'resnest269e', 'resnest26d', 'resnest50d', 'resnest50d_1s4x24d', 'resnest50d_4s2x40d', 'resnet101', 'resnet152', 'resnet18', 'resnet34', 'resnet50', 'resnext101_32x16d', 'resnext101_32x32d', 'resnext101_32x48d', 'resnext101_32x4d', 'resnext101_32x8d', 'resnext50_32x4d', 'se_resnet101', 'se_resnet152', 'se_resnet50', 'se_resnext101_32x4d', 'se_resnext50_32x4d', 'senet154', 'skresnet18', 'skresnet34', 'skresnext50_32x4d', 'tf_efficientnet_lite0', 'tf_efficientnet_lite1', 'tf_efficientnet_lite2', 'tf_efficientnet_lite3', 'tf_efficientnet_lite4', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19', 'vgg19_bn', 'xception'

Semantic Segmentation is indeed the task of performing classification at the pixel level, meaning each pixel is associated with a given class. Usually, when these pixel-classifier FCN models are adopted for semantic segmentation, the final convolutional layer must have a number of channels equal to the number of expected classes, and an ArgMax is applied at the end to select the highest-scoring class for each pixel.

In the Style Transfer (image to image) scenario, adopting an FCN is even easier: it is simply necessary to use a convolutional layer with 3 channels as the final output to let the model predict a color image.

The code snippet below shows how to easily and flexibly get the specific FCN model you want to use, providing as parameters the expected model architecture (unetplusplus), the backbone for the encoder (mobilenetv3_large_100) and the specific weights to reuse for transfer learning (imagenet).
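The difference between the two output heads can be seen in a few lines of plain PyTorch. This is only an illustration: the 16-channel feature map stands in for the output of a real decoder, and the number of classes is arbitrary.

```python
import torch
import torch.nn as nn

# Stand-in for the decoder output of an FCN (hypothetical 16 channels).
features = torch.rand(1, 16, 64, 64)

# Segmentation: one channel per class, then argmax picks a class id per pixel.
num_classes = 5
seg_head = nn.Conv2d(16, num_classes, kernel_size=1)
class_map = seg_head(features).argmax(dim=1)  # shape (1, 64, 64), integer class ids

# Image to image: three channels, read directly as a color image.
rgb_head = nn.Conv2d(16, 3, kernel_size=1)
image = rgb_head(features)  # shape (1, 3, 64, 64), float values
```

In other words, requesting 3 output channels from a segmentation architecture and skipping the argmax is enough to turn it into an image generator.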

As you can see in this specific test I used a mobilenetv3_large_100 backbone to target specific performance for live video inferencing on mobile devices but certainly other backbones could be used to get even better quality images for final predictions.

Training

Unfortunately, PyTorch Lightning Flash was not flexible enough to directly train the SemanticSegmentation model above on the image to image scenario with the Flash Trainer.

Luckily, I was still able to use PyTorch Lightning and code my own ImageToImage Lightning Module class to train and validate the model with the previously described VGG Perceptual Feature Loss.

The code snippet below shows the ImageToImageModel class and how to instantiate it and pass it to the PyTorch Lightning Trainer to start a training and evaluation loop on the specific ML accelerator you have.

ONNX export, results and HuggingFace Space demo

Exporting a model to ONNX format from PyTorch Lightning, or even just PyTorch, is nowadays a very simple process.

Specifically for PyTorch Lightning, I had to call the following API after first moving the model weights back to the CPU (apparently the ONNX exporter requires this):

It is also worth mentioning that, for the MobileNetV3 backbone I used in this model instance, I had to specify a recent opset version so that the ONNX Runtime supports all the new ops used by this model, such as the hardswish activation function (see the CoreML section below for more details).

For quickly testing the ONNX model I found the Hugging Face Spaces environment super useful, specifically its support for the easy-to-use Gradio API.

Here is the code to quickly test this ONNX model using Gradio on user images and caching the model image from the Hugging Face Hub:

A live Gradio demo of this model is available on Hugging Face Spaces at this url: https://huggingface.co/spaces/Jacopo/ToonClip

For convenient reuse I’ve also published the ONNX model on the Hugging Face model Hub: https://huggingface.co/Jacopo/ToonClip

Here are some example images generated by this model from FFHQ validation images not seen during training:

Sample images provided by the FFHQ dataset with corresponding prediction generated by the model provided by the author

Please try this demo directly on Hugging Face and, if needed, flag any issues or anomalies using the UI feature provided in the demo (https://huggingface.co/spaces/Jacopo/ToonClip).

CoreML export and image processing

Apple provides a very useful Python package called CoreMLTools (https://coremltools.readme.io/docs) to quickly export ML models from different frameworks such as PyTorch and TensorFlow to the CoreML format, allowing the model to be accelerated on the specific hardware available in the Apple ecosystem (GPU, Neural Engine, CPU accelerators).

Unfortunately, exporting an image model requires a bit more coding in the export process, especially if you want to target maximum performance for live image processing on videos.

Developers specifically need to pre-process and post-process the image on the inferencing platform, usually using the Swift programming language and the local libraries for image capture and processing, and apply the proper normalization and shaping to this image.

There are different techniques to implement this: on the model itself, coding in Python, or on the client side, using Swift, iOS/macOS libraries and specifically the CoreML runtime API.

My suggestion is usually to encapsulate the model in a wrapper Python module, feed this wrapper directly to the CoreML exporter API, and finally update the generated CoreML spec (actually in protobuf format) to convert the output tensor into an image, so the model integrates easily into the client app flow.

In the code snippet below you can find the PyTorch wrapper class for the model previously discussed and the code necessary to export and format the model using image types for both the input and output parameters.
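The wrapper idea can be sketched as follows. This is an illustration under stated assumptions rather than the wrapper used for the app: it assumes the wrapped model both expects and produces ImageNet-normalized tensors, that the client passes RGB pixel values in the 0 to 255 range, and the class name `ImageInOutWrapper` is hypothetical.

```python
import torch
import torch.nn as nn

class ImageInOutWrapper(nn.Module):
    """Moves normalization and de-normalization inside the exported graph."""

    def __init__(self, model):
        super().__init__()
        self.model = model
        # Buffers are saved with the module and follow it across devices.
        self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, image):
        # image: (B, 3, H, W) with raw pixel values in [0, 255]
        x = (image / 255.0 - self.mean) / self.std
        y = self.model(x)
        # Undo the normalization so the output is again a plain RGB image.
        y = y * self.std + self.mean
        return torch.clamp(y * 255.0, 0.0, 255.0)
```

With pre- and post-processing folded into the graph like this, the client code only has to hand over camera frames and display the result, and a traced version of the wrapper is what would be given to the converter.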

Finally, the specific UNet++ model described in this sample uses some PyTorch operators that are not yet supported by the CoreMLTools conversion library. Luckily, CoreMLTools allows ML developers to define custom ops using its MIL framework API, to quickly redefine and register a proper conversion for the missing PyTorch op implementations.

Specifically, the following code re-implements the hardswish and hardsigmoid PyTorch activations used by the UNet++ model:
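Independently of how the conversion is registered, both activations reduce to a few elementary operations that any converter needs to express. The sketch below spells out that arithmetic in plain PyTorch rather than in MIL; the function names are hypothetical and the registration call itself is not shown.

```python
import torch
import torch.nn.functional as F

def hardsigmoid_reference(x):
    """hardsigmoid(x) = clamp(x + 3, 0, 6) / 6, a piecewise-linear sigmoid."""
    return torch.clamp(x + 3.0, min=0.0, max=6.0) / 6.0

def hardswish_reference(x):
    """hardswish(x) = x * hardsigmoid(x)."""
    return x * hardsigmoid_reference(x)
```

Since only an addition, a clamp, a division and a multiplication are involved, a custom conversion can be built from ops that are already supported.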

Source Code

Most of the code of this experimentation has already been provided in the code snippets above in this post.

Full source code and a full end-to-end notebook will be provided soon at: https://github.com/JacopoMangiavacchi/ToonClip-ComicsHero

Thank you for reading this article, testing the iPhone App or the Hugging Face demo and thanks again to the people mentioned in this article for the great inspiration. It has been fun learning new things while playing with cartoons!


Microsoft Principal Data Scientist — Google Machine Learning Developer Expert (ML GDE) — Former  + IBM Senior Architect and Engineer