Writing a custom data augmentation layer in Keras

Subclass Layer, and implement call() with TensorFlow functions

Lak Lakshmanan
Towards Data Science

Data augmentation can help an image ML model learn to handle variations of the image that are not in the training dataset. For example, it is likely that photographs provided to an ML model (especially if these are photographs by amateur photographers) will vary quite considerably in terms of lighting. We can therefore increase the effective size of the training dataset and make the ML model more resilient if we augment the training dataset by randomly changing the brightness, contrast, saturation, etc. of the training images.

While Keras has several built-in data augmentation layers (like RandomFlip), it doesn’t currently have a layer that randomly adjusts both contrast and brightness. So, let’s implement one.

Writing the Data Augmentation Layer

The class inherits from the Keras Layer class and takes two arguments: the ranges within which to adjust the contrast and the brightness (the full code is on GitHub):

import tensorflow as tf

class RandomColorDistortion(tf.keras.layers.Layer):
    def __init__(self, contrast_range=(0.5, 1.5),
                 brightness_delta=(-0.2, 0.2), **kwargs):
        super(RandomColorDistortion, self).__init__(**kwargs)
        # Range from which the contrast factor is drawn
        self.contrast_range = contrast_range
        # Range from which the additive brightness change is drawn
        self.brightness_delta = brightness_delta

When invoked, this layer needs to behave differently depending on whether or not it is in training mode. Outside training mode, the layer simply returns the original images. In training mode, it generates two random numbers, one to adjust the contrast of the image and the other to adjust its brightness. The actual adjustment is carried out using functions from the tf.image module:

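    def call(self, images, training=None):
        # At evaluation/prediction time the layer is a no-op
        if not training:
            return images

        # Draw the random factors with TensorFlow ops (not NumPy) so that a
        # new value is sampled on every call, even inside a traced graph
        contrast = tf.random.uniform(
            [], self.contrast_range[0], self.contrast_range[1])
        brightness = tf.random.uniform(
            [], self.brightness_delta[0], self.brightness_delta[1])

        images = tf.image.adjust_contrast(images, contrast)
        images = tf.image.adjust_brightness(images, brightness)
        # Keep pixel values in the valid [0, 1] range
        images = tf.clip_by_value(images, 0, 1)
        return images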
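To make the two tf.image calls concrete, here is a small worked example on a hand-made 2x2 single-channel image (an illustrative input, not taken from the flowers dataset). tf.image.adjust_contrast maps each pixel x to (x - mean) * factor + mean, using the per-channel mean of the image, and tf.image.adjust_brightness adds a constant delta to every pixel. The last step shows why the layer ends with tf.clip_by_value.

```python
import tensorflow as tf

# Illustrative 2x2 single-channel "image" with pixel values in [0, 1].
# Its mean is (0.2 + 0.4 + 0.6 + 0.8) / 4 = 0.5.
img = tf.constant([[[0.2], [0.4]],
                   [[0.6], [0.8]]], dtype=tf.float32)

# adjust_contrast pushes each pixel away from the per-channel mean:
#   x -> (x - 0.5) * 1.5 + 0.5
contrasted = tf.image.adjust_contrast(img, 1.5)
print(contrasted.numpy().ravel())  # approximately 0.05, 0.35, 0.65, 0.95

# adjust_brightness adds the same delta to every pixel
brightened = tf.image.adjust_brightness(contrasted, 0.2)
print(brightened.numpy().ravel())  # approximately 0.25, 0.55, 0.85, 1.15

# The last value is now outside [0, 1]; clipping brings it back
clipped = tf.clip_by_value(brightened, 0, 1)
print(clipped.numpy().ravel())     # approximately 0.25, 0.55, 0.85, 1.0
```

Because contrast is applied before brightness, a large contrast factor combined with a positive brightness delta can push bright pixels past 1.0, which is what the final clip guards against.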

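Note: It is important that the layer be implemented with TensorFlow operations. This lets it run efficiently on the GPU as part of the model graph, and it also matters for correctness: Python and NumPy code inside call() runs only when the graph is traced, not on every batch, so a NumPy random number would be frozen into the graph and every batch would receive the same adjustment.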
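The point about TensorFlow operations is about correctness as well as speed. A quick way to see this is to compare a NumPy random draw with a TensorFlow one inside a tf.function; the two function names below are illustrative, not part of the article's code.

```python
import numpy as np
import tensorflow as tf

@tf.function
def numpy_random():
    # np.random.uniform runs once, while the function is being traced;
    # its result is baked into the graph as a constant.
    return tf.constant(np.random.uniform(0.5, 1.5))

@tf.function
def tf_random():
    # tf.random.uniform is a graph op, so it executes on every call.
    return tf.random.uniform([], 0.5, 1.5)

np_values = [float(numpy_random()) for _ in range(5)]
tf_values = [float(tf_random()) for _ in range(5)]
print(np_values)  # the same number five times
print(tf_values)  # five (almost certainly) different numbers
```

Because model.fit runs each training step inside a tf.function, a layer that drew its random factors with NumPy would behave like numpy_random, applying the same contrast and brightness change to every batch.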

Testing that the layer works

To test that the layer works, simply create the layer and call it on some images:

layer = RandomColorDistortion()
trainds = create_preproc_dataset('gs://practical-ml-vision-book/flowers_tfr/train-*')

for (img, label) in trainds.take(3):
    ...
    for idx in range(1, 5):
        # training=True forces the random adjustment to be applied
        aug = layer(img, training=True)
        ax[rowno, idx].imshow(aug.numpy())
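If you don't have access to the flowers dataset, a minimal sketch of the same check can run on synthetic data. It assumes float images scaled to [0, 1] (as the layer's clipping implies), uses random tensors in place of real photographs, and repeats the layer definition (with TensorFlow random ops) so that the block runs on its own.

```python
import tensorflow as tf

class RandomColorDistortion(tf.keras.layers.Layer):
    """Randomly adjusts contrast and brightness of float images in [0, 1]."""

    def __init__(self, contrast_range=(0.5, 1.5),
                 brightness_delta=(-0.2, 0.2), **kwargs):
        super().__init__(**kwargs)
        self.contrast_range = contrast_range
        self.brightness_delta = brightness_delta

    def call(self, images, training=None):
        if not training:
            return images  # no-op at evaluation / prediction time
        # TensorFlow random ops: a fresh value is drawn on every call
        contrast = tf.random.uniform(
            [], self.contrast_range[0], self.contrast_range[1])
        brightness = tf.random.uniform(
            [], self.brightness_delta[0], self.brightness_delta[1])
        images = tf.image.adjust_contrast(images, contrast)
        images = tf.image.adjust_brightness(images, brightness)
        return tf.clip_by_value(images, 0, 1)

# Hypothetical stand-in for a batch of flower photos:
# 4 RGB images of 32x32 pixels with values in [0, 1)
tf.random.set_seed(42)
images = tf.random.uniform([4, 32, 32, 3])
layer = RandomColorDistortion()

# Inference mode: the images should come back unchanged
unchanged = layer(images, training=False)

# Training mode: same shape, values still within [0, 1]
augmented = layer(images, training=True)
print(augmented.shape)  # (4, 32, 32, 3)
print(float(tf.reduce_min(augmented)), float(tf.reduce_max(augmented)))
```

These checks cover the two behaviors the article relies on: identity outside training, and valid pixel ranges after augmentation. Visual inspection, as in the plotting loop above, remains the better way to judge whether the chosen ranges look realistic.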

The result is shown below:

Random contrast and brightness adjustment on three of the training images. The original image is shown in the first panel of each row, and four augmented versions are shown in the other panels.

Incorporating the layer into a model

To use the layer, simply insert it into the Keras model layers. The layer will be applied during training and be a no-op during evaluation or prediction:

layers = [
    ...
    tf.keras.layers.experimental.preprocessing.RandomFlip(
        mode='horizontal',
        name='random_lr_flip/none'
    ),
    RandomColorDistortion(name='random_contrast_brightness/none'),
    hub.KerasLayer …
]
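A minimal end-to-end sketch of the layer inside a model follows. It is a stand-in for the article's model, built on several assumptions: a global-average-pooling and dense head replaces the elided hub.KerasLayer, the data are random tensors, the layer names are simplified, and RandomFlip is referenced as tf.keras.layers.RandomFlip, its location in TensorFlow 2.6 and later. The custom layer is repeated in condensed form so the block is self-contained.

```python
import tensorflow as tf

class RandomColorDistortion(tf.keras.layers.Layer):
    # Condensed version of the layer developed in this article
    def __init__(self, contrast_range=(0.5, 1.5),
                 brightness_delta=(-0.2, 0.2), **kwargs):
        super().__init__(**kwargs)
        self.contrast_range = contrast_range
        self.brightness_delta = brightness_delta

    def call(self, images, training=None):
        if not training:
            return images
        contrast = tf.random.uniform([], *self.contrast_range)
        brightness = tf.random.uniform([], *self.brightness_delta)
        images = tf.image.adjust_contrast(images, contrast)
        images = tf.image.adjust_brightness(images, brightness)
        return tf.clip_by_value(images, 0, 1)

# Hypothetical tiny model: augmentation layers first, then a simple head
model = tf.keras.Sequential([
    tf.keras.Input(shape=(32, 32, 3)),
    tf.keras.layers.RandomFlip(mode='horizontal', name='random_lr_flip'),
    RandomColorDistortion(name='random_contrast_brightness'),
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(5, activation='softmax'),
])
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# Random synthetic data (8 images, 5 classes), only to exercise model.fit
x = tf.random.uniform([8, 32, 32, 3])
y = tf.random.uniform([8], maxval=5, dtype=tf.int32)
history = model.fit(x, y, epochs=1, batch_size=4, verbose=0)

# At prediction time both augmentation layers pass images through,
# so repeated predictions on the same input are identical
p1 = model(x, training=False)
p2 = model(x, training=False)
print(bool(tf.reduce_all(tf.equal(p1, p2))))  # True
```

Keras passes training=True to the augmentation layers during fit and training=False during evaluate and predict, which is what makes the layer safe to leave in the model after training.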

Does it work?

The purpose of data augmentation is to improve model accuracy and to reduce overfitting. On that count, this layer works quite well on the flowers dataset.

Compare the training plot without data augmentation:

with the training plot after data augmentation:

We get better accuracy (0.88 instead of 0.85), and the training and validation curves remain in sync, indicating that overfitting is under control.

Enjoy!

Next Steps:

  1. Try out the notebook. The full code is on GitHub, and there are easy links to open the notebook in Colab and in Google Cloud AI Platform Notebooks.
  2. This notebook is part of a series of notebooks for a forthcoming O’Reilly book “Practical ML Vision”. Look at the existing notebooks, and watch/star the GitHub repo to keep up with our work.
  3. The Keras guide on preprocessing layers is very thorough.
