Enhancing the Robustness of Image Classification Models with AugMix

Adding a Consistency Loss Between AugMix-Augmented Images to Enhance the Generalization of Your Image Classification Model

Lihi Gur Arie, PhD
Towards Data Science


Figure 1: Visualizing AugMix - the original image (left) and two augmented versions. | Image by author

Introduction

Image classification models perform best on data drawn from the same distribution as their training data. In real-world scenarios, however, the input data can vary: when running inference on images from different cameras, for example, the lighting conditions, contrast, and color distortions may differ from the training set and significantly degrade the model's performance. To address this challenge, the AugMix algorithm by Hendrycks et al. [1] can be applied to any image classification model to improve its robustness and uncertainty estimates.

AugMix is a data augmentation technique that generates augmented variations of each training image. When combined with a consistency loss, it encourages the model to make consistent predictions for all versions of the same image. Although training with these augmented versions takes longer, the resulting model becomes more stable, consistent, and robust to a wider range of inputs.

AugMix Augmentation Technique

Augmented versions of a training image are generated by applying three parallel augmentation chains, each composed of one to three randomly selected augmentation operations, such as translation, shear, and contrast, with randomly sampled intensities. The chains are then combined with the original image at random weights to yield a single augmented image. Each augmented version therefore incorporates several sources of randomness: the choice of operations, the intensity of these operations, the length of the augmentation chains, and the mixing weights.

Figure 2: A realization of AugMix Augmentation Chains, taken from Hendrycks et al. (2020).

For each training image, AugMix generates two augmented versions (augmix1 and augmix2) that preserve the semantic content of the image. A demonstration of an image and its two augmented versions is shown in Figure 1.
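To make the procedure concrete, here is a minimal NumPy/PIL sketch of the mixing step, following the pseudocode in the AugMix paper. The operation list and parameter values below are illustrative only; the paper uses a larger operation set that deliberately excludes the corruptions used for evaluation.

import random
import numpy as np
from PIL import Image, ImageOps

# Illustrative operations only; the paper uses a broader set
OPS = [
    ImageOps.autocontrast,
    ImageOps.equalize,
    lambda im: im.rotate(random.uniform(-30, 30)),
    lambda im: ImageOps.posterize(im, random.randint(4, 7)),
]

def augmix(image, width=3, max_depth=3, alpha=1.0):
    # Convex weights for mixing the parallel augmentation chains
    chain_weights = np.random.dirichlet([alpha] * width)
    # Weight for blending the original image with the chain mixture
    m = float(np.random.beta(alpha, alpha))
    mix = np.zeros_like(np.asarray(image, dtype=np.float32))
    for w in chain_weights:
        aug = image.copy()
        # Each chain applies one to max_depth randomly chosen operations
        for _ in range(random.randint(1, max_depth)):
            aug = random.choice(OPS)(aug)
        mix += w * np.asarray(aug, dtype=np.float32)
    # Interpolate between the original image and the mixture of chains
    mixed = (1 - m) * np.asarray(image, dtype=np.float32) + m * mix
    return Image.fromarray(mixed.astype(np.uint8))

Calling augmix(img) twice on the same PIL image produces the two versions augmix1 and augmix2 described above.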

Loading dataset with AugMix versions

The data loader should process both the original images and their modified versions generated using the AugMix technique. The PyTorch Image Models (timm) library [2] provides a convenient implementation for creating a PyTorch dataset and generating these AugMix augmentations.

from timm.data import ImageDataset, AugMixDataset, create_loader

train_dir = '/path/to/training/images/dir/'
train_dataset = ImageDataset(train_dir)

# Wrap the dataset so each sample yields the original image plus two AugMix versions
train_dataset = AugMixDataset(train_dataset, num_splits=3)

loader_train = create_loader(train_dataset,
                             input_size=(3, 224, 224),
                             batch_size=32,
                             is_training=True,
                             scale=[0.08, 1.],
                             ratio=[0.75, 1.33],
                             num_aug_splits=3)

The code above creates a training dataset and a data loader. The ImageDataset class builds a dataset object from the training images. The AugMixDataset class wraps it and generates the additional, modified versions of each original image. The num_splits argument specifies the total number of versions per image: with num_splits=3, two modified versions are generated in addition to the original image. The paper's authors recommend using two augmented versions, so num_splits should be set to three. The create_loader function creates a data loader that iterates over the training dataset in mini-batches during training.
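Before training, it can be helpful to inspect one batch to see how the AugMix versions are arranged. The snippet below assumes timm's default collation for augmentation splits, which stacks the original images and the two AugMix versions along the batch dimension; the exact shapes are worth verifying on your own setup.

inputs, targets = next(iter(loader_train))
# With batch_size=32 and num_splits=3, the batch is laid out as
# [originals | augmix1 | augmix2] along the batch dimension
print(inputs.shape)   # expected: torch.Size([96, 3, 224, 224])
print(targets.shape)  # expected: torch.Size([96]); each label repeated per split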

Jensen-Shannon consistency loss

In each forward pass, the original image is passed through the model along with its two augmented versions, augmix1 and augmix2. To encourage the model to make consistent predictions for differently augmented versions of the same input, the Jensen-Shannon divergence (JS) is added to the original cross-entropy loss (L), weighted by the hyperparameter lambda.

The Jensen-Shannon consistency loss is composed of the original cross-entropy loss (L) and the Jensen-Shannon divergence (JS):

Loss = L(p_orig, y) + λ · JS(p_orig; p_augmix1; p_augmix2)

The Jensen-Shannon divergence (JS) is computed by first obtaining the average prediction probability M of the three versions:

M = (p_orig + p_augmix1 + p_augmix2) / 3

Then, the average KL divergence between each version's predictions and M is computed:

JS(p_orig; p_augmix1; p_augmix2) = 1/3 · (KL[p_orig || M] + KL[p_augmix1 || M] + KL[p_augmix2 || M])

The Jensen-Shannon consistency loss forces the model to embed all versions of the same image similarly, which helps it learn more robust and generalizable features.
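For intuition, here is a minimal PyTorch sketch of the JS term, written directly from the formulas above (an illustration, not timm's exact implementation):

import torch
import torch.nn.functional as F

def js_divergence(logits_clean, logits_aug1, logits_aug2):
    # Prediction probabilities for each version of the image
    p_clean = F.softmax(logits_clean, dim=1)
    p_aug1 = F.softmax(logits_aug1, dim=1)
    p_aug2 = F.softmax(logits_aug2, dim=1)

    # Average prediction probability M, clamped for numerical stability
    log_m = torch.clamp((p_clean + p_aug1 + p_aug2) / 3., 1e-7, 1).log()

    # Average KL divergence of each version from M
    return (F.kl_div(log_m, p_clean, reduction='batchmean') +
            F.kl_div(log_m, p_aug1, reduction='batchmean') +
            F.kl_div(log_m, p_aug2, reduction='batchmean')) / 3.

Given a batch laid out as [originals | augmix1 | augmix2], the three sets of logits can be recovered with torch.split(outputs, outputs.size(0) // 3) before calling the function above.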

The timm library includes a user-friendly implementation of the Jensen-Shannon consistency loss. The num_splits argument specifies the total number of image versions (original plus augmented). The alpha argument, corresponding to the weighting factor lambda in the original paper, controls the weight of the JSD term. Its default value is 12, but it should be tuned to fit the characteristics of the training data.

loss_fn = JsdCrossEntropy(num_splits=3, alpha=12)

Code overview

The following code demonstrates how to incorporate the AugMix data augmentation technique into a simplified training loop using the timm library. The code begins by creating a training dataset using the AugMix technique to generate modified versions of the original images. The data loader iterates over the training dataset, processing both the original images and their augmented versions. A model, loss function, optimizer, and scheduler are also created. Finally, the training loop iterates over the data loader for a specified number of epochs.

from timm import create_model
from timm.data import ImageDataset, AugMixDataset, create_loader
from timm.loss import JsdCrossEntropy
from timm.optim import AdamP
from timm.scheduler import CosineLRScheduler
from tqdm import tqdm

epochs = 50

# Dataset and loader: each image yields the original plus two AugMix versions
train_dir = '/path/to/training/images/dir/'
train_dataset = ImageDataset(train_dir)
train_dataset = AugMixDataset(train_dataset, num_splits=3)
loader_train = create_loader(train_dataset,
                             input_size=(3, 224, 224),
                             batch_size=32,
                             is_training=True,
                             scale=[0.08, 1.],
                             ratio=[0.75, 1.33],
                             num_aug_splits=3)

# Model, Jensen-Shannon consistency loss, optimizer, and scheduler
model = create_model('resnet18', pretrained=True, num_classes=2).cuda()
loss_JSL = JsdCrossEntropy(num_splits=3)
optimizer = AdamP(model.parameters(), lr=0.01)
scheduler = CosineLRScheduler(optimizer, t_initial=epochs)

for epoch in tqdm(range(epochs)):
    for batch in loader_train:
        inputs, targets = batch

        # Forward pass: originals and AugMix versions share one batch
        outputs = model(inputs)
        loss = loss_JSL(outputs, targets)

        # Backward pass
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    # Advance the cosine schedule once per epoch
    scheduler.step(epoch + 1)

Concluding Remarks

In conclusion, AugMix is a powerful data processing technique for improving the robustness and uncertainty estimates of image classifiers. It is useful in situations where the data distribution encountered during deployment may differ from the training distribution, such as when images are captured with different cameras. AugMix has been shown to improve generalization without requiring any changes to the underlying model, and it was recognized as a state-of-the-art (SOTA) algorithm for domain generalization on ImageNet-C, a benchmark dataset that applies various corruptions to images from ImageNet. It is simple to implement and well worth the computational overhead it adds to training. Personally, I have found it useful in my own projects, and I recommend it as a tool that adds extra value to the robustness of image classifiers.

Thank you for reading!


References

[1] Hendrycks, D., Mu, N., Cubuk, E. D., Zoph, B., Gilmer, J., & Lakshminarayanan, B. (2020). AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty. ICLR 2020, arXiv:1912.02781.

[2] The PyTorch Image Models (timm) library. https://timm.fast.ai/
