Efficient Image Segmentation Using PyTorch: Part 2

A CNN-based model

Dhruv Matani
Towards Data Science


This is the second part of a 4-part series on implementing image segmentation step by step, from scratch, using deep learning techniques in PyTorch. This part focuses on implementing a baseline image segmentation Convolutional Neural Network (CNN) model.

Co-authored with Naresh Singh

Figure 1: Result of running image segmentation using a CNN. From top to bottom: input images, ground-truth segmentation masks, and predicted segmentation masks. Source: Author(s)

Article outline

In this article, we will implement a Convolutional Neural Network (CNN) based architecture called SegNet, which assigns each pixel in an input image to a corresponding pet, such as a cat or a dog. Pixels that don’t belong to any pet are classified as background pixels. We will build and train this model on the Oxford Pets dataset using PyTorch to develop a sense of what it takes to deliver a successful image segmentation task. The model-building process will be hands-on, and we will discuss in detail the role of each layer in our model. The article contains plenty of references to research papers and articles for further learning.

Throughout this article, we will reference the code and results from this notebook. If you wish to reproduce the results, you’ll need a GPU to ensure that the notebook completes running in a reasonable amount of time.

Articles in this series

This series is for readers at all experience levels with deep learning. If you want to learn about the practice of deep learning and vision AI along with some solid theory and hands-on experience, you’ve come to the right place! This is expected to be a 4-part series with the following articles:

  1. Concepts and Ideas
  2. A CNN-based model (this article)
  3. Depthwise separable convolutions
  4. A Vision Transformer-based model

Let’s start with a short introduction to convolution layers and a few other layers that are typically used together with them as a convolution block.

Conv-BatchNorm-ReLU and Max Pooling/Unpooling

A convolution, batch normalization, ReLU block is the holy trinity of vision AI. You’ll see it used frequently in CNN-based vision AI models. Each of these terms stands for a distinct layer implemented in PyTorch. The convolution layer performs a cross-correlation of learned filters with the input tensor. Batch Normalization normalizes each channel to zero mean and unit variance using statistics computed over the batch, and ReLU is a non-linear activation function that keeps positive values and sets negative values to zero.
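As a minimal sketch, the three layers can be chained with nn.Sequential. The helper name conv_bn_relu and the choice of a 3x3 kernel with padding=1 are our own assumptions, not necessarily what the notebook uses.

```python
import torch
import torch.nn as nn


def conv_bn_relu(in_channels: int, out_channels: int) -> nn.Sequential:
    """Hypothetical helper: a 3x3 Conv -> BatchNorm -> ReLU block."""
    return nn.Sequential(
        # bias=False: BatchNorm subtracts the per-channel mean and adds its own
        # learnable shift, so a convolution bias would be redundant.
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    )


block = conv_bn_relu(3, 64)
x = torch.randn(4, 3, 64, 64)  # (batch, channels, height, width)
y = block(x)
print(y.shape)               # torch.Size([4, 64, 64, 64])
print(bool((y >= 0).all()))  # True: ReLU leaves no negative values
```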

A typical CNN progressively reduces the spatial dimensions of its input as layers are stacked. The motivation behind this reduction is discussed in the next section. The reduction is achieved by pooling neighboring values using a simple function such as max or average; we will discuss this further in the Pooling section. In classification problems, the stack of Conv-BN-ReLU-Pool blocks is followed by a classification head that predicts the probability that the input belongs to each of the target classes. Some problems, such as semantic segmentation, require a per-pixel prediction. For such cases, a stack of upsampling blocks is appended after the downsampling blocks to project their output back to the required spatial dimensions. The upsampling blocks are simply Conv-BN-ReLU-Unpool blocks, which replace the pooling layer with an un-pooling layer. We will talk more about un-pooling in the Pooling section.

Now, let’s further elaborate on the motivation behind convolution layers.

Convolution

Convolutions are the basic building blocks of vision AI models. They are used heavily in computer vision and have historically been used to implement vision transformations such as:

  1. Edge detection
  2. Image blurring and sharpening
  3. Embossing
  4. Intensification

At its core, a convolution operation multiplies two matrices elementwise and sums the products into a single value. An example convolution operation is shown in Figure 2.

Figure 2: An illustration of the convolution operation. Source: Author(s)

In a deep learning context, convolution slides a small n-dimensional parameter matrix, called a filter or a kernel, over a larger input. At each position, the filter is multiplied elementwise with the section of the input it covers, and the products are summed. The step size of the slide is configured using a stride parameter: a stride of one means the kernel moves over by one element to operate on the next section. Unlike traditional approaches, which use fixed hand-designed filters, deep learning learns the filter values from data using backpropagation.
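To make the sliding-window description concrete, here is a deliberately naive sketch (the function name naive_conv2d is ours) that slides a 3x3 kernel over a 5x5 input without padding and compares the result with PyTorch’s F.conv2d, which, like the convolution layers discussed here, computes a cross-correlation.

```python
import torch
import torch.nn.functional as F


def naive_conv2d(image: torch.Tensor, kernel: torch.Tensor, stride: int = 1) -> torch.Tensor:
    """Slide `kernel` over a 2-D `image` (no padding); at each position,
    multiply elementwise and sum the products."""
    kh, kw = kernel.shape
    out_h = (image.shape[0] - kh) // stride + 1
    out_w = (image.shape[1] - kw) // stride + 1
    out = torch.empty(out_h, out_w)
    for i in range(out_h):
        for j in range(out_w):
            patch = image[i * stride : i * stride + kh, j * stride : j * stride + kw]
            out[i, j] = (patch * kernel).sum()
    return out


image = torch.arange(25, dtype=torch.float32).reshape(5, 5)
kernel = torch.tensor([[1.0, 0.0, -1.0]] * 3)  # a simple vertical-edge filter

for stride in (1, 2):
    ours = naive_conv2d(image, kernel, stride)
    # F.conv2d expects (batch, channels, H, W) inputs and (out_ch, in_ch, kH, kW) weights.
    ref = F.conv2d(image[None, None], kernel[None, None], stride=stride)[0, 0]
    print(stride, tuple(ours.shape), torch.allclose(ours, ref))
```

With stride=1 the 5x5 input yields a 3x3 output; with stride=2 it yields 2x2. Every entry equals -6 here because each row of this input increases by 1 per column, so the filter sees the same gradient everywhere.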

So how do convolutions assist in deep learning?

In deep learning, a convolution layer is used to detect visual features. A typical CNN model contains a stack of such layers. The bottom layers in the stack detect simple features such as lines and edges. As we move up in the stack, the layers detect increasingly complex features. Middle layers in the stack detect combinations of lines and edges, and the top layers detect complex shapes such as a car, a face, or an airplane. Figure 3 visualizes what the bottom and top layers of a trained model detect.

Figure 3: What convolutional filters learn to identify. Source: Convolutional Deep Belief Networks for Scalable Unsupervised Learning of Hierarchical Representations

A convolution layer has a set of learnable filters that act on small regions of the input to produce a representative output value for each region. For example, a 3x3 filter operates over a 3x3 region and produces a single value representative of that region. Repeatedly applying a filter across the input produces an output that becomes the input to the next layer in the stack. Intuitively, layers higher up get to “see” a larger region of the input. For example, a 3x3 filter in the second convolution layer operates on the output of the first convolution layer, where each cell contains information about a 3x3 region of the input. Assuming stride=1, the filter in the second layer therefore “sees” a 5x5 region of the original input. This region is called the receptive field of the convolution. Stacking convolutional layers (typically together with strided convolutions or pooling) progressively reduces the spatial dimensions of the input image and widens the field of view of the filters, which enables them to “see” complex shapes. Figure 4 shows a convolution network processing a 1-D input: each element of the output layer represents a relatively large chunk of the input.

Figure 4: Receptive field of a 1d convolution with kernel size=3, applied 3 times. Assume stride=1 and no padding. After the 3rd successive application of the convolutional kernel, a single pixel is able to see 7 pixels in the original input image. Source: Author(s)
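The receptive-field arithmetic extends to stacks that include strides and pooling. Below is a minimal sketch of a standard recurrence; the helper receptive_field is our own, and it assumes no dilation and counts pixels along a single dimension.

```python
def receptive_field(layers):
    """Receptive field (in input pixels) of one output element of a layer stack.

    `layers` is a list of (kernel_size, stride) pairs, first layer first.
    Each layer widens the field by (kernel_size - 1) * jump, where `jump` is the
    distance, in input pixels, between adjacent positions of that layer's input.
    """
    field, jump = 1, 1
    for kernel_size, stride in layers:
        field += (kernel_size - 1) * jump
        jump *= stride
    return field


print(receptive_field([(3, 1)] * 2))               # 5: the 5x5 region from the text
print(receptive_field([(3, 1)] * 3))               # 7: matches Figure 4
print(receptive_field([(3, 1), (2, 2), (3, 1)]))   # 8: a 2x2, stride-2 pool in between
```

The third call shows why pooling helps: after a stride-2 pool, each step of the next 3x3 kernel covers two input pixels instead of one.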

Once convolutional layers can detect these features and generate representations of them, we can use these representations for image classification, image segmentation, and object detection and localization. Broadly speaking, CNNs adhere to the following general principles:

  1. A convolution layer either keeps the number of output channels (C) intact or doubles it.
  2. It either keeps the spatial dimensions intact using stride=1 or halves them using stride=2 (with appropriate padding).
  3. It’s common to pool the output of a convolution block to reduce the spatial dimensions of an image.
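A quick way to check these rules is to print tensor shapes. The sketch below assumes 3x3 kernels with padding=1 and an even-sized input, which is what makes stride=1 preserve and stride=2 halve the spatial dimensions; other kernel or padding choices change the arithmetic.

```python
import torch
import torch.nn as nn

x = torch.randn(1, 32, 64, 64)  # (N, C, H, W)

same = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)  # keep C, keep H and W
down = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)  # double C, halve H and W
pool = nn.MaxPool2d(kernel_size=2, stride=2)                  # keep C, halve H and W

print(same(x).shape)  # torch.Size([1, 32, 64, 64])
print(down(x).shape)  # torch.Size([1, 64, 32, 32])
print(pool(x).shape)  # torch.Size([1, 32, 32, 32])
```

The general rule behind these numbers is out = floor((in + 2 * padding - kernel_size) / stride) + 1 for each spatial dimension.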

A convolution layer applies its kernels to each input independently, so the range of its output values can vary considerably from input to input, and it shifts as the weights change during training. A Batch Normalization layer typically follows a convolution layer to keep these values in a consistent range. Let’s understand its role in detail in the next section.

Batch Normalization

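A Batch Normalization layer normalizes the channel values of a batch of inputs to zero mean and unit variance. This normalization is performed independently for each channel: the mean and variance are computed over all the values of that channel across the batch (and, for images, across all spatial positions), so that each channel’s values follow the same distribution. A learnable per-channel scale and shift are then applied to the normalized values. Batch Normalization has the following benefits: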

  1. It stabilizes the training process by preventing the gradients from becoming too small.
  2. It achieves faster convergence on our tasks.
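The per-channel computation can be checked by hand. In this sketch (the input statistics are arbitrary), nn.BatchNorm2d in training mode normalizes with the current batch’s mean and biased variance; its learnable scale (weight) and shift (bias) start at 1 and 0, so at initialization it matches the manual formula. In eval mode it would instead use running averages collected during training.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(16, 4, 8, 8) * 3 + 5  # (N, C, H, W), deliberately not zero-mean

bn = nn.BatchNorm2d(4)  # weight (scale) = 1 and bias (shift) = 0 at initialization
bn.train()              # training mode: normalize with the current batch's statistics
y = bn(x)

# The same computation by hand: one mean and one variance per channel,
# computed over the batch and spatial dimensions (N, H, W).
mean = x.mean(dim=(0, 2, 3), keepdim=True)
var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
y_manual = (x - mean) / torch.sqrt(var + bn.eps)

print(torch.allclose(y, y_manual, atol=1e-5))  # True
print(y.mean(dim=(0, 2, 3)))                   # close to 0 for every channel
print(y.var(dim=(0, 2, 3), unbiased=False))    # close to 1 for every channel
```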

If all we had was a stack of convolution layers, it would essentially be equivalent to a single convolution layer because of the cascading effect of linear transformations. In other words, a sequence of linear transformations can be replaced with a single linear transformation that has the same effect. Intuitively, multiplying a vector by a constant k₁ and then by another constant k₂ is equivalent to a single multiplication by the constant k₁k₂. Hence, for networks to be meaningfully deep, they need non-linearities between their layers to prevent this collapse. In the next section, we will discuss ReLU, a frequently used non-linearity.
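The collapse can be verified numerically. In this sketch (single-channel 1-D convolutions without bias or padding, random weights of our choosing), two stacked size-3 kernels are exactly reproduced by one size-5 kernel, and inserting a ReLU between them breaks the equivalence.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(1, 1, 16)   # (batch, channels, length)
w1 = torch.randn(1, 1, 3)   # first 1-D conv kernel (no bias, no padding)
w2 = torch.randn(1, 1, 3)   # second 1-D conv kernel

two_layers = F.conv1d(F.conv1d(x, w1), w2)

# The two kernels merge into a single size-5 kernel:
# w[c] = sum over all a + b == c of w2[a] * w1[b]
w = torch.zeros(1, 1, 5)
for a in range(3):
    for b in range(3):
        w[0, 0, a + b] += w2[0, 0, a] * w1[0, 0, b]
one_layer = F.conv1d(x, w)

print(torch.allclose(two_layers, one_layer, atol=1e-5))  # True: the stack collapsed

# With a ReLU in between, the single-kernel shortcut no longer holds.
with_relu = F.conv1d(F.relu(F.conv1d(x, w1)), w2)
print(torch.allclose(with_relu, one_layer, atol=1e-5))   # False
```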

ReLU

ReLU is a simple non-linear activation function that clips negative input values to 0, i.e. ReLU(x) = max(0, x). It also helps mitigate the vanishing gradient problem: for positive inputs its gradient is exactly 1, so, unlike saturating activations such as sigmoid or tanh, it does not shrink gradients as they flow backward through the network. The ReLU layer is typically followed by a pooling layer to shrink the spatial dimensions in the downsampling subnetwork, or by an un-pooling layer to increase the spatial dimensions in the upsampling subnetwork. The details are provided in the next section.
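A small sketch of both properties using autograd: ReLU zeroes negative inputs and passes a gradient of exactly 1 for positive ones, whereas the sigmoid’s gradient never exceeds 0.25. The input values are arbitrary, and 0 is deliberately left out to sidestep the non-differentiable point.

```python
import torch

x = torch.tensor([-2.0, -0.5, 1.5, 3.0], requires_grad=True)

y = torch.relu(x)
y.sum().backward()
print(y.detach().tolist())  # [0.0, 0.0, 1.5, 3.0]
print(x.grad.tolist())      # [0.0, 0.0, 1.0, 1.0]: gradient is 1 wherever x > 0

# For comparison, sigmoid'(x) = s(x) * (1 - s(x)) is at most 0.25 and shrinks as |x| grows.
x2 = x.detach().clone().requires_grad_(True)
torch.sigmoid(x2).sum().backward()
print(x2.grad.tolist())     # every entry is below 0.25
```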

Pooling

A pooling layer is used to shrink the spatial dimensions of its input. Pooling with a 2x2 window and stride=2 transforms an input with spatial dimensions (H, W) into one with dimensions (H/2, W/2). Max-pooling is the most commonly used pooling technique in deep CNNs. It outputs the maximum value within a window of (say) 2x2 values. Then, similar to convolutions, we slide the 2x2 pooling window to the next section based on the stride. Doing this repeatedly with stride=2 results in an output that is half the height and half the width of the input. Another commonly used pooling layer is the average-pooling layer, which computes the average instead of the max.

The reverse of a pooling layer is called an un-pooling layer. For stride=2, it takes an (H, W) input and converts it into a (2H, 2W) output. A necessary ingredient of this transformation is selecting where in each 2x2 section of the output to place the input value. To do this, we need a max-unpooling index map that tells us the target location within each output section. This index map is produced by an earlier max-pooling operation, which records where each maximum came from. Figure 5 shows examples of pooling and un-pooling operations.

Figure 5: Max pooling and un-pooling. Source: DeepPainter: Painter Classification Using Deep Convolutional Autoencoders
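In PyTorch, the index map comes from nn.MaxPool2d(return_indices=True) and is consumed by nn.MaxUnpool2d. A minimal sketch on a hand-picked 4x4 input: each index is a flat position within the 4x4 plane, and every position that did not hold a maximum is filled with zero after un-pooling.

```python
import torch
import torch.nn as nn

x = torch.tensor([[[[ 1.,  2.,  5.,  6.],
                    [ 3.,  4.,  7.,  8.],
                    [ 9., 10., 13., 14.],
                    [11., 12., 15., 16.]]]])  # shape (1, 1, 4, 4)

pool = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
unpool = nn.MaxUnpool2d(kernel_size=2, stride=2)

pooled, indices = pool(x)      # (1, 1, 4, 4) -> (1, 1, 2, 2)
print(pooled[0, 0].tolist())   # [[4.0, 8.0], [12.0, 16.0]]
print(indices[0, 0].tolist())  # [[5, 7], [13, 15]]: flat positions in the 4x4 input

restored = unpool(pooled, indices)  # (1, 1, 2, 2) -> (1, 1, 4, 4)
print(restored[0, 0].tolist())
# [[0.0, 0.0, 0.0, 0.0],
#  [0.0, 4.0, 0.0, 8.0],
#  [0.0, 0.0, 0.0, 0.0],
#  [0.0, 12.0, 0.0, 16.0]]
```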

We can consider max-pooling a type of non-linear activation function. However, it’s reported that using it to replace a non-linearity such as ReLU affects the network’s performance. In contrast, average pooling cannot be considered a non-linear function, since its output is a linear combination (the mean) of all its inputs.
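The difference shows up in a one-window additivity check. This sketch uses two hand-picked 2x2 inputs: average pooling satisfies pool(a + b) = pool(a) + pool(b), as any linear function must, while max pooling does not.

```python
import torch
import torch.nn.functional as F


def avg(t: torch.Tensor) -> torch.Tensor:
    return F.avg_pool2d(t, kernel_size=2)


def mx(t: torch.Tensor) -> torch.Tensor:
    return F.max_pool2d(t, kernel_size=2)


# Two inputs, each a single 2x2 window: shape (N=1, C=1, H=2, W=2).
a = torch.tensor([[[[1.0, 0.0], [0.0, 0.0]]]])
b = torch.tensor([[[[0.0, 2.0], [0.0, 0.0]]]])

print(avg(a + b).item(), (avg(a) + avg(b)).item())  # 0.75 0.75: additive
print(mx(a + b).item(), (mx(a) + mx(b)).item())     # 2.0 3.0: additivity fails
```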

This covers all the basic building blocks of deep CNNs. Now, let’s put them together to create a model. The model we have chosen for this exercise is called SegNet. We’ll discuss it next.

SegNet: A CNN-based model

SegNet is a deep CNN model built from the fundamental blocks that we have discussed in this article. It has two distinct sections. The bottom section, also called the encoder, downsamples the input to generate features representative of it. The top section, the decoder, upsamples these features to produce a per-pixel classification. Each section is composed of a sequence of Conv-BN-ReLU blocks, which also incorporate pooling layers in the downsampling path and un-pooling layers in the upsampling path. Figure 6 shows the arrangement of the layers in more detail. SegNet uses the pooling indices from the max-pooling operations in the encoder to determine where to place values during the corresponding max-unpooling operations in the decoder. Storing these indices is cheaper than storing activations: each element of a float32 activation tensor takes 4 bytes (32 bits), whereas an offset within a 2x2 square can be stored in just 2 bits. This saves memory, since these activations (or, in the case of SegNet, indices) need to be kept around while the model runs.

Figure 6: The SegNet model architecture for image segmentation. Source: SegNet: A Deep Convolutional Encoder-Decoder Architecture for Image Segmentation
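A back-of-the-envelope sketch of that saving, assuming a hypothetical 64-channel, 64x64 pooled feature map (for example, after one pooling step on a 128x128 input). Note that PyTorch’s nn.MaxPool2d(return_indices=True) returns indices as 64-bit integers holding flat positions, so realizing the 2-bit figure would need a custom packed representation; the arithmetic illustrates the idea rather than what PyTorch does out of the box.

```python
def bytes_to_store(num_values: int, bits_per_value: int) -> int:
    """Bytes needed to store `num_values` entries at `bits_per_value` bits each."""
    return num_values * bits_per_value // 8


# Hypothetical pooled feature map: 64 channels of 64x64 positions.
num_positions = 64 * 64 * 64

as_float32 = bytes_to_store(num_positions, 32)  # keep the pooled activations themselves
as_2bit_idx = bytes_to_store(num_positions, 2)  # keep only the offset within each 2x2 window

print(as_float32, as_2bit_idx, as_float32 // as_2bit_idx)  # 1048576 65536 16
```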

This notebook contains all the code for this section.

This model has 15.27M trainable parameters.
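The notebook contains the actual model. As a rough, hypothetical sketch of the same encoder-decoder pattern, the class below (our own MiniSegNet, not the notebook’s code) uses three encoder/decoder stages instead of SegNet’s five, a single Conv-BN-ReLU per stage, and arbitrarily chosen channel widths, so its parameter count is far below 15.27M. The detail worth noticing is how the indices returned by each max-pooling layer are handed to the matching max-unpooling layer. Following the training configuration described in this article, it applies batch normalization to the raw input instead of normalizing the images.

```python
import torch
import torch.nn as nn


def conv_bn_relu(in_ch: int, out_ch: int) -> nn.Sequential:
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
    )


class MiniSegNet(nn.Module):
    """Hypothetical, scaled-down SegNet-style model (not the notebook's architecture)."""

    def __init__(self, in_channels=3, num_classes=3, widths=(16, 32, 64)):
        super().__init__()
        self.input_bn = nn.BatchNorm2d(in_channels)  # stands in for input normalization
        chans = [in_channels, *widths]
        self.encoders = nn.ModuleList(
            [conv_bn_relu(chans[i], chans[i + 1]) for i in range(len(widths))]
        )
        # Decoder mirrors the encoder: widths[j] -> widths[j - 1]; the last stage keeps widths[0].
        self.decoders = nn.ModuleList(
            [conv_bn_relu(widths[j], widths[max(j - 1, 0)]) for j in reversed(range(len(widths)))]
        )
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
        self.unpool = nn.MaxUnpool2d(kernel_size=2, stride=2)
        self.classifier = nn.Conv2d(widths[0], num_classes, kernel_size=1)  # per-pixel logits

    def forward(self, x):
        x = self.input_bn(x)
        indices, sizes = [], []
        for encoder in self.encoders:
            x = encoder(x)
            sizes.append(x.size())
            x, idx = self.pool(x)  # remember where each maximum came from
            indices.append(idx)
        for decoder in self.decoders:
            # Put values back at the recorded max locations, then refine with Conv-BN-ReLU.
            x = self.unpool(x, indices.pop(), output_size=sizes.pop())
            x = decoder(x)
        return self.classifier(x)


def count_trainable_params(model: nn.Module) -> int:
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


model = MiniSegNet()
logits = model(torch.randn(2, 3, 128, 128))
print(logits.shape)  # torch.Size([2, 3, 128, 128]): one score per class per pixel
print(count_trainable_params(model))
```

The same count_trainable_params idea applied to the full model is one way to arrive at a figure like 15.27M. Passing output_size to the unpooling layer guards against odd-sized feature maps, where the un-pooled size would otherwise be ambiguous.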

The following configuration was used during model training and validation.

  1. The random horizontal flip and color jitter data augmentations are applied to the training set to prevent overfitting
  2. The images are resized to 128x128 pixels in a non-aspect preserving resize operation
  3. No input normalization is applied to the images; instead a batch normalization layer is used as the first layer of the model
  4. The model is trained for 20 epochs using the Adam optimizer with a LR of 0.001 and a StepLR Scheduler that decays the learning rate by 0.7 every 7 epochs
  5. The cross-entropy loss function is used to classify a pixel as belonging to a pet, the background, or a pet border
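A hedged sketch of items 4 and 5 in PyTorch. The Adam learning rate and StepLR settings come from the list above; the stand-in one-layer model, the random placeholder batches, and the 0/1/2 label encoding for pet, background, and border are our assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-in model so the snippet runs on its own; in practice this would be the SegNet model.
model = nn.Conv2d(3, 3, kernel_size=3, padding=1)

criterion = nn.CrossEntropyLoss()  # logits (N, C, H, W) vs. class indices (N, H, W)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.7)

for epoch in range(20):
    # Placeholder batch; a real loop would iterate over the training DataLoader here.
    images = torch.rand(4, 3, 128, 128)
    masks = torch.randint(0, 3, (4, 128, 128))  # assumed encoding: 0..2 per pixel

    optimizer.zero_grad()
    loss = criterion(model(images), masks)
    loss.backward()
    optimizer.step()
    scheduler.step()  # once per epoch: LR is multiplied by 0.7 after the 7th and 14th epochs

print(optimizer.param_groups[0]["lr"])  # about 0.00049, i.e. 0.001 * 0.7 ** 2
```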

The model achieved a validation accuracy of 88.28% after 20 training epochs.

We plotted a GIF showing how the model learns to predict the segmentation masks for 21 images in the validation set.

Figure 7: A GIF showing how the SegNet model learns to predict segmentation masks for 21 images in the validation set. Source: Author(s)

The definitions of all the validation metrics are given in Part 1 of this series.

If you’d like to see a fully convolutional model for segmenting pet images implemented using TensorFlow, please see Chapter 4: Efficient Architectures of the Efficient Deep Learning Book.

Observations from model learning

Based on how the model’s predictions evolve from epoch to epoch, we can observe the following.

  1. As early as the first training epoch, the model learns enough for its output to be in the right ballpark for the pet in the image
  2. The border pixels are harder to segment. We’re using an unweighted loss function that treats each pixel’s success (or failure) equally, and border pixels make up only a small fraction of all pixels, so getting them wrong doesn’t cost the model much in terms of the loss. We would encourage you to investigate this and check which strategies could fix this issue. Try using Focal Loss and see how it performs
  3. The model still seems to be learning after 20 training epochs. This suggests that we could improve validation accuracy by training the model longer
  4. Some of the ground-truth labels are themselves hard to figure out. For example, the mask of the dog in the middle row, last column has a lot of unknown pixels in the area where the dog’s body is occluded by plants. This is very hard for the model to figure out, so we should expect lower accuracy on such examples. However, this doesn’t mean that the model isn’t doing well. One should always spot-check the predictions to develop a sense of the model’s behavior, in addition to looking at the overall validation metrics.
Figure 8: An example of a ground-truth segmentation mask containing a lot of unknown pixels. This is a very hard input for any ML model. Source: Author(s)
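As a starting point for the border-pixel experiment suggested above, here is a minimal, hypothetical multi-class focal loss for segmentation (our own function, without the optional per-class alpha weighting). It scales each pixel’s cross-entropy by (1 - p_t)^gamma, where p_t is the predicted probability of the true class, so confidently correct pixels contribute less and hard pixels, such as borders, weigh relatively more. With gamma=0 it reduces to ordinary cross-entropy.

```python
import torch
import torch.nn.functional as F


def focal_loss(logits: torch.Tensor, target: torch.Tensor, gamma: float = 2.0) -> torch.Tensor:
    """Hypothetical multi-class focal loss (no alpha weighting).

    logits: (N, C, H, W) raw scores; target: (N, H, W) class indices.
    """
    ce = F.cross_entropy(logits, target, reduction="none")  # per-pixel -log(p_t)
    p_t = torch.exp(-ce)                                     # probability of the true class
    return ((1 - p_t) ** gamma * ce).mean()                  # down-weight easy pixels


torch.manual_seed(0)
logits = torch.randn(2, 3, 8, 8)
target = torch.randint(0, 3, (2, 8, 8))

print(torch.allclose(focal_loss(logits, target, gamma=0.0), F.cross_entropy(logits, target)))  # True
print(focal_loss(logits, target).item() < F.cross_entropy(logits, target).item())             # True
```

A simpler alternative is class-weighted cross-entropy via nn.CrossEntropyLoss(weight=...), giving the rarer border class a larger weight.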

Conclusion

In Part 2 of this series, we learned about the basic building blocks of deep CNNs for vision AI. We saw how to implement the SegNet model from scratch in PyTorch, and we visualized how the model’s predictions on 21 validation images evolve over successive training epochs. This should help you appreciate how quickly models can learn enough for their output to land somewhere in the right ballpark. In this case, we can see segmentation masks that roughly resemble the actual segmentation masks as early as the first training epoch!

In the next part of this series, we’ll take a look at how we can optimize our model for on-device inference and reduce the number of trainable parameters (and hence model size) while keeping the validation accuracy roughly the same.

Further Reading

Read more about convolutions here:

  1. The course titled “Ancient Secrets of Computer Vision” at the University of Washington, taught by Joseph Redmon, has an excellent set of videos on convolutions (especially chapters 4, 5, and 13), which we highly recommend watching
  2. A guide to convolution arithmetic for deep learning (highly recommended)
  3. https://towardsdatascience.com/computer-vision-convolution-basics-2d0ae3b79346
  4. The Conv2d layer in PyTorch (documentation)
  5. What do convolutions learn?
  6. Convolution visualizer

Read more about batch normalization here:

  1. Batch normalization: Wikipedia
  2. Batch normalization: Machine learning mastery
  3. BatchNorm2d layer in PyTorch

Read more about activation functions and ReLU here:

  1. ReLU: Machine learning mastery
  2. ReLU: Wikipedia
  3. ReLU: Quora
  4. ReLU API in PyTorch
