Train a neural net for semantic segmentation in 50 lines of code, with PyTorch

Sagi Eppel
Towards Data Science
9 min read, Dec 3, 2021


How to train a neural net for semantic segmentation in fewer than 50 lines of code (40 if you exclude imports). The goal here is to give the fastest, simplest overview of how to train a semantic segmentation neural net in PyTorch using the built-in Torchvision neural nets (DeepLabV3).

The code is available at: https://github.com/sagieppel/Train-Semantic-Segmentation-Net-with-Pytorch-In-50-Lines-Of-Code

The goal of semantic segmentation is to take an image and identify the regions belonging to specific classes. This is done by processing the image through a convolutional neural network that outputs a map with a class for each pixel. The classes are given as a set of numbers. For example, in this case, we will use the LabPics V1 dataset with three classes (shown in the figure below):

Images and corresponding segmentation masks: Black (0) = background, Gray (1) = empty vessel, White (2) = filled region. Image by the author.

Class 0: Not a vessel (black),
Class 1: Empty region of the vessel (gray),
Class 2: Filled region of the vessel (white).

The goal of the net is to receive an image and predict for each pixel one of the 3 classes.

The first step is to download the LabPics dataset from: https://zenodo.org/record/3697452/files/LabPicsV1.zip?download=1

You will also need to install PyTorch, as well as OpenCV for reading images.

OpenCV can be installed using:

pip install opencv-python

First, let's import packages and define the main training parameters:

import os
import numpy as np
import cv2
import torchvision.models.segmentation
import torch
import torchvision.transforms as tf
Learning_Rate=1e-5
width=height=800 # image width and height
batchSize=3

Learning_Rate is the step size of gradient descent during training.

Width and height are the dimensions of the images used for training. All images will be resized to this size during training.

batchSize is the number of images used in each training iteration.

The memory requirement of training is roughly proportional to batchSize*width*height. Depending on your hardware, you may need a smaller batchSize to avoid out-of-memory errors.
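To make the scaling concrete, the short sketch below computes the size of just the input batch tensor, assuming float32 values (4 bytes each). The activations stored inside the net take far more memory than this, but they grow with the same batchSize*width*height product. The helper input_batch_bytes is hypothetical, for illustration only.

```python
def input_batch_bytes(batch_size, height, width, channels=3, bytes_per_value=4):
    # Bytes needed for one float32 batch tensor of shape [batch_size, channels, height, width]
    return batch_size * channels * height * width * bytes_per_value

# Settings from this article: batchSize=3, width=height=800
size = input_batch_bytes(3, 800, 800)
print(f"input batch: {size / 2**20:.2f} MiB")  # 23,040,000 bytes, about 21.97 MiB

# Memory scales linearly with the batch size and with the pixel count
print(input_batch_bytes(6, 800, 800) / size)  # 2.0: double the batch, double the memory
print(input_batch_bytes(3, 400, 400) / size)  # 0.25: half the side length, a quarter of the memory
```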

Note that since we train with only a single image size, the trained net is likely to work well only on images of that size. In most cases, you will want to change the image size between training batches.
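One way to do this, sketched under assumptions (the helper names random_hw and resize_batch, the size range, and resizing after loading are illustrative choices, not from the original code), is to draw a new size for every batch and resize both tensors: bilinear interpolation for the images, and nearest-neighbor for the annotations so that class values stay whole numbers.

```python
import numpy as np
import torch
import torch.nn.functional as F

def random_hw(min_side=250, max_side=900, rng=None):
    # Hypothetical helper: draw a new (height, width) for the next training batch
    rng = rng if rng is not None else np.random.default_rng()
    height = int(rng.integers(min_side, max_side + 1))
    width = int(rng.integers(min_side, max_side + 1))
    return height, width

def resize_batch(images, ann, height, width):
    # images: float tensor [batch, 3, H, W]; ann: class-index map [batch, H, W]
    images = F.interpolate(images, size=(height, width), mode="bilinear", align_corners=False)
    # Nearest-neighbor resizing keeps every annotation value a whole class index (0, 1 or 2)
    ann = F.interpolate(ann[:, None].float(), size=(height, width), mode="nearest")[:, 0]
    return images, ann

# Usage with small random stand-in data (in training, this would follow LoadBatch())
images = torch.rand(3, 3, 64, 64)
ann = torch.randint(0, 3, (3, 64, 64)).float()
h, w = random_hw(min_side=32, max_side=96)
images, ann = resize_batch(images, ann, h, w)
print(images.shape, ann.shape)  # [3, 3, h, w] and [3, h, w]
```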

Next, we create a list of all images in the dataset:

TrainFolder="LabPics/Simple/Train//"
ListImages=os.listdir(os.path.join(TrainFolder, "Image"))

Where TrainFolder is the LabPics dataset train folder.
The images are stored in the “Image” subfolder of TrainFolder.

Next, we define a set of transformations that will be performed on the image using the TorchVision transform module:

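transformImg=tf.Compose([tf.ToPILImage(),tf.Resize((height,width)),tf.ToTensor(),tf.Normalize((0.485, 0.456, 0.406),(0.229, 0.224, 0.225))])
transformAnn=tf.Compose([tf.ToPILImage(),tf.Resize((height,width),interpolation=tf.InterpolationMode.NEAREST),tf.ToTensor()]) # Nearest-neighbor resizing keeps class values whole (no blending of 0/1/2 at region borders)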

This defines the transformations applied to the image and to the annotation map: converting to PIL format (the standard input format for these transforms), resizing, and converting to PyTorch tensor format. For the image, we also normalize the pixel intensities by subtracting the mean and dividing by the standard deviation of each channel. These means and standard deviations were calculated beforehand on a large set of images (they are the standard ImageNet statistics).
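The arithmetic behind ToTensor and Normalize can be reproduced in a few lines of NumPy. This is an illustrative re-implementation (the function name normalize_like_torchvision is hypothetical), not torchvision's actual code: ToTensor scales 8-bit values from [0, 255] to [0, 1] and moves the channel axis first, and Normalize then subtracts each channel's mean and divides by its standard deviation.

```python
import numpy as np

MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def normalize_like_torchvision(img_hwc_uint8):
    # Mimic tf.ToTensor() + tf.Normalize(MEAN, STD) for an H x W x 3 uint8 image
    x = img_hwc_uint8.astype(np.float32) / 255.0            # ToTensor: [0, 255] -> [0, 1]
    x = x.transpose(2, 0, 1)                                 # ToTensor: H x W x C -> C x H x W
    return (x - MEAN[:, None, None]) / STD[:, None, None]   # Normalize, channel by channel

pixel = np.full((1, 1, 3), 255, dtype=np.uint8)  # a single white pixel
print(normalize_like_torchvision(pixel)[:, 0, 0])  # approximately [2.249, 2.429, 2.640]
```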

Next, we create a function that will allow us to load a random image and the corresponding annotation map for training:

def ReadRandomImage(): # Load a random image and its annotation map
    idx=np.random.randint(0,len(ListImages)) # Pick random image
    Img=cv2.imread(os.path.join(TrainFolder, "Image",ListImages[idx]))
    Filled = cv2.imread(os.path.join(TrainFolder, "Semantic/16_Filled", ListImages[idx].replace("jpg","png")),0)
    Vessel = cv2.imread(os.path.join(TrainFolder, "Semantic/1_Vessel", ListImages[idx].replace("jpg","png")),0)
    AnnMap = np.zeros(Img.shape[0:2],np.float32) # Segmentation map
    if Vessel is not None: AnnMap[ Vessel == 1 ] = 1
    if Filled is not None: AnnMap[ Filled == 1 ] = 2
    Img=transformImg(Img)
    AnnMap=transformAnn(AnnMap)
    return Img,AnnMap

In the first part, we pick a random index from the list of images and load the image corresponding to this index.

idx = np.random.randint(0,len(ListImages)) # Pick random image
Img = cv2.imread(os.path.join(TrainFolder, "Image",ListImages[idx]))

Next, we load the annotation masks for the image:

Filled = cv2.imread(os.path.join(TrainFolder, "Semantic/16_Filled", ListImages[idx].replace("jpg","png")),0)
Vessel = cv2.imread(os.path.join(TrainFolder, "Semantic/1_Vessel", ListImages[idx].replace("jpg","png")),0)

These annotations are stored as masks that cover the region belonging to a specific class (Filled or Vessel). Each class mask is stored in a separate .png image file, in which pixels belonging to the class have a value of 1 and all other pixels have a value of 0.

Image, masks for the Filled and Vessel regions, and the unified annotation map (black=0, gray=1, white=2). Image by the author.

To train the net, we need to create one segmentation map where the values of pixels belonging to the empty vessel region are 1 (gray), the values of the pixels belonging to the filled region are 2 (white), and the rest are 0 (black).

First, we create a segmentation map full of zeros in the shape of the image:

AnnMap = np.zeros(Img.shape[0:2],np.float32)

Next, we set all pixels that have a value of 1 in the Vessel mask to 1 in the segmentation map, and all pixels that have a value of 1 in the Filled mask to 2:

if Vessel is not None: AnnMap[ Vessel == 1 ] = 1
if Filled is not None: AnnMap[ Filled == 1 ] = 2

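Here, “AnnMap[ Filled == 1 ] = 2” means that every position with a value of 1 in the Filled mask gets a value of 2 in AnnMap. Because the Filled line runs after the Vessel line, pixels that are set in both masks end up with the value 2.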

If there is no annotation file for the Vessel or Filled class (which happens when the class does not appear in the image), cv2.imread returns None, and that mask is ignored.
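The steps above can be collected into one small function and tried on tiny synthetic masks in place of the PNG files. The function name build_annotation_map and the example masks are hypothetical; the logic mirrors the lines above, including skipping a mask that is None.

```python
import numpy as np

def build_annotation_map(shape, vessel_mask=None, filled_mask=None):
    # Start with class 0 (not a vessel) everywhere
    ann_map = np.zeros(shape, np.float32)
    if vessel_mask is not None:          # a missing mask file (None) is simply skipped
        ann_map[vessel_mask == 1] = 1    # class 1: vessel
    if filled_mask is not None:
        ann_map[filled_mask == 1] = 2    # class 2: filled, overrides class 1 where both are set
    return ann_map

vessel = np.array([[0, 1, 1],
                   [0, 1, 1]], np.uint8)
filled = np.array([[0, 0, 0],
                   [0, 1, 1]], np.uint8)
print(build_annotation_map(vessel.shape, vessel, filled))
# [[0. 1. 1.]
#  [0. 2. 2.]]
print(build_annotation_map(vessel.shape, vessel, None))  # no Filled file: only classes 0 and 1
```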

Finally, we convert the image and annotation into PyTorch format using the transformations we defined earlier:

Img=transformImg(Img)
AnnMap=transformAnn(AnnMap)

For training, we need a batch of images: several images stacked on top of each other in a 4D tensor. We create the batch using the function:

def LoadBatch(): # Load batch of images
    images = torch.zeros([batchSize,3,height,width])
    ann = torch.zeros([batchSize, height, width])
    for i in range(batchSize):
        images[i],ann[i]=ReadRandomImage()
    return images, ann

The first part creates an empty 4D tensor for the images, with dimensions [batchSize, channels, height, width], where channels is the number of layers in the image (3 for an RGB image). It also creates an empty 3D tensor with dimensions [batchSize, height, width] for the annotation maps, since each annotation map has only a single layer.

The next part loads a set of images and annotations into the empty tensors, using the ReadRandomImage() function we defined earlier.

for i in range(batchSize):
    images[i],ann[i]=ReadRandomImage()
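To check the shapes without the dataset, the sketch below replaces ReadRandomImage() with a hypothetical read_random_pair() that returns random data, and uses NumPy arrays and tiny image sizes instead of PyTorch tensors. It shows the 4D image batch and the 3D annotation batch side by side.

```python
import numpy as np

batchSize, height, width = 3, 8, 8   # tiny sizes, for illustration only

def read_random_pair():
    # Hypothetical stand-in for ReadRandomImage(): random image and class map of the right shapes
    img = np.random.rand(3, height, width).astype(np.float32)           # [channels, height, width]
    ann = np.random.randint(0, 3, (height, width)).astype(np.float32)  # [height, width], values 0-2
    return img, ann

def load_batch():
    images = np.zeros([batchSize, 3, height, width], np.float32)  # 4D: [batch, channels, H, W]
    ann = np.zeros([batchSize, height, width], np.float32)        # 3D: one class map per image
    for i in range(batchSize):
        images[i], ann[i] = read_random_pair()
    return images, ann

images, ann = load_batch()
print(images.shape, ann.shape)  # (3, 3, 8, 8) (3, 8, 8)
```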

Now that we can load our data, it's time to load the neural net:

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
Net = torchvision.models.segmentation.deeplabv3_resnet50(pretrained=True) # Load pretrained net
Net.classifier[4] = torch.nn.Conv2d(256, 3, kernel_size=(1, 1), stride=(1, 1)) # Change final layer to 3 classes
Net=Net.to(device)
optimizer=torch.optim.Adam(params=Net.parameters(),lr=Learning_Rate) # Create adam optimizer

The first part checks whether the computer has a CUDA GPU. If it does, training will be done on the GPU; otherwise, it runs on the CPU:

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

For any practical dataset, training using a CPU is extremely slow.

Next, we load the DeepLabV3 semantic segmentation net:


Net = torchvision.models.segmentation.deeplabv3_resnet50(pretrained=True)

torchvision.models.segmentation contains several useful models for semantic segmentation, such as FCN and LR-ASPP. We chose DeepLabV3 since it is one of the best semantic segmentation nets. By setting pretrained=True, we load the net with weights pretrained on a subset of the COCO dataset. It is almost always better to start from a pretrained model when learning a new problem, since the net can build on what it has already learned and converges faster. (In torchvision 0.13 and later, pretrained is deprecated in favor of the weights argument.)

We can see all the layers of the net we just loaded by writing:

print(Net)

….

(1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU()
(4): Conv2d(256, 21, kernel_size=(1, 1), stride=(1, 1))

This prints the layers of the net in the order they are used. The final layer of the network is a convolution layer with 256 input channels and 21 output channels. The 21 channels correspond to the output classes the pretrained net was trained on (the 20 Pascal VOC categories plus background). Since we only have 3 classes in our dataset, we replace it with a new convolutional layer with 3 outputs:

Net.classifier[4] = torch.nn.Conv2d(256, 3, kernel_size=(1, 1), stride=(1, 1)) 

To be fair, this part is optional, since a net with 21 output classes can predict 3 classes simply by ignoring the remaining 18. But it's more elegant this way.
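Replacing the last layer works because the DeepLabV3 classifier head is an nn.Sequential, so its layers can be read and assigned by index. The sketch below uses a hypothetical toy head with the same layout as the printout (an nn.Identity stands in for the ASPP module at index 0) and random features instead of a real backbone, to show how the output channels change from 21 to 3.

```python
import torch

# Hypothetical toy head with the same layer order as the printed DeepLabV3 classifier
head = torch.nn.Sequential(
    torch.nn.Identity(),                                               # placeholder for ASPP
    torch.nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
    torch.nn.BatchNorm2d(256),
    torch.nn.ReLU(),
    torch.nn.Conv2d(256, 21, kernel_size=1),                           # 21 output classes
)

features = torch.randn(2, 256, 16, 16)   # fake backbone features: [batch, 256, H, W]
print(head(features).shape)              # torch.Size([2, 21, 16, 16])

# Replace only the last layer, the same way as Net.classifier[4] = ...
head[4] = torch.nn.Conv2d(256, 3, kernel_size=(1, 1), stride=(1, 1))
print(head(features).shape)              # torch.Size([2, 3, 16, 16])
```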

Next, we move the net to our GPU or CPU device:

Net=Net.to(device)

Next, we create an optimizer:

optimizer=torch.optim.Adam(params=Net.parameters(),lr=Learning_Rate) # Create adam optimizer

The optimizer uses the gradients computed during the backpropagation step to update the net weights. Adam is a popular choice that typically converges quickly.

Finally, we start the training loop:

for itr in range(20000): # Training loop
    images,ann=LoadBatch() # Load training batch
    images=torch.autograd.Variable(images,requires_grad=False).to(device) # Load images
    ann = torch.autograd.Variable(ann,requires_grad=False).to(device) # Load annotations
    Pred=Net(images)['out'] # make prediction

LoadBatch, defined earlier, loads a batch of images and annotation maps into images and ann.

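torch.autograd.Variable converts the data into variables that can be used by the net. We set requires_grad=False since we don't need gradients with respect to the images, only with respect to the layers of the net. to(device) copies the tensor to the same device (GPU/CPU) as the net. (Since PyTorch 0.4, Variable is deprecated and simply returns a tensor, and tensors do not require gradients by default, so images.to(device) alone works the same way.)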

Finally, we input the image to the net and get the prediction.

Pred=Net(images)['out'] # make prediction

Once we made a prediction, we can compare it to the real (ground truth) annotation and calculate the loss:

    criterion = torch.nn.CrossEntropyLoss() # Set loss function
    Loss=criterion(Pred,ann.long()) # Calculate cross entropy loss
    optimizer.zero_grad() # Clear gradients left over from the previous iteration
    Loss.backward() # Backpropagate loss
    optimizer.step() # Apply gradient descent change to weights

First, we define the loss function. We use the standard cross-entropy loss:

criterion = torch.nn.CrossEntropyLoss()

We use this function to calculate the loss using the prediction and the real annotation:

Loss=criterion(Pred,ann.long())
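CrossEntropyLoss expects the prediction as raw scores with shape [batch, classes, height, width] and the annotation as integer class indices with shape [batch, height, width], which is why ann is converted with .long(). The sketch below checks this on random toy tensors and recomputes the same value by hand as the mean, over all pixels, of the negative log-softmax score of the true class (the manual version is for illustration only).

```python
import torch

batch, classes, h, w = 2, 3, 4, 5
pred = torch.randn(batch, classes, h, w)                 # raw net output (logits): [B, C, H, W]
ann = torch.randint(0, classes, (batch, h, w)).float()   # class map as loaded: [B, H, W], float

criterion = torch.nn.CrossEntropyLoss()
loss = criterion(pred, ann.long())   # targets must be integer class indices, hence .long()

# Same value by hand: mean over all pixels of -log(softmax probability of the true class)
log_probs = torch.log_softmax(pred, dim=1)                           # [B, C, H, W]
true_class_log_probs = log_probs.gather(1, ann.long().unsqueeze(1))  # [B, 1, H, W]
manual = -true_class_log_probs.mean()
print(loss.item(), manual.item())  # the two numbers match
```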

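Once we have calculated the loss, we can apply backpropagation and update the net weights. Before backpropagating, we clear the gradients from the previous iteration, since PyTorch adds new gradients to the ones already stored instead of replacing them:

optimizer.zero_grad() # Clear old gradients
Loss.backward() # Backpropagate loss
optimizer.step() # Apply gradient descent change to weights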
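Clearing the gradients matters because PyTorch adds each new gradient to whatever is already stored in a parameter's .grad. A toy example with a single parameter (unrelated to the segmentation net, with a plain SGD optimizer chosen only for simplicity) shows the accumulation and the effect of zero_grad():

```python
import torch

w = torch.nn.Parameter(torch.tensor([1.0]))
optimizer = torch.optim.SGD([w], lr=0.1)

def loss_fn():
    return (3 * w).sum()  # d(loss)/dw = 3

loss_fn().backward()
loss_fn().backward()
print(w.grad)           # tensor([6.]): the gradients of both calls were added together

optimizer.zero_grad()   # clear them (recent PyTorch sets .grad to None, older versions zero it)
loss_fn().backward()
print(w.grad)           # tensor([3.]): only the current gradient
```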

This covers the full training stage, but we also need to save the trained model; otherwise, it will be lost once the program stops.
Saving is time-consuming, so we do it only once every 1000 steps:

    if itr % 1000 == 0: # Save the model once every 1000 steps
        print("Saving Model" + str(itr) + ".torch")
        torch.save(Net.state_dict(), str(itr) + ".torch")
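state_dict() holds only the tensors of the net (weights, biases and batch normalization statistics), not its code, so the same architecture must be built before the weights can be loaded back. A minimal round trip, assuming a single Conv2d layer as a stand-in for the segmentation net and a temporary folder as the save location:

```python
import os
import tempfile
import torch

net = torch.nn.Conv2d(256, 3, kernel_size=1)        # small stand-in for the segmentation net
path = os.path.join(tempfile.mkdtemp(), "1000.torch")
torch.save(net.state_dict(), path)                  # saves only the weights, not the code

restored = torch.nn.Conv2d(256, 3, kernel_size=1)   # rebuild the same architecture first
restored.load_state_dict(torch.load(path, map_location="cpu"))
print(all(torch.equal(a, b) for a, b in zip(net.state_dict().values(),
                                             restored.state_dict().values())))  # True
```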

After running this script for about 3000 steps, the net should give decent results.

The full code can be found in the GitHub repository linked at the top of this post.

All together, that's 50 lines of code not counting blank lines, and 40 lines not counting imports :-)

Finally, once the net has been trained, we want to apply it to segment a real image and see the result. We do this using a separate inference script that uses the trained net to segment an image:

import cv2
import torchvision.models.segmentation
import torch
import torchvision.transforms as tf
import matplotlib.pyplot as plt
modelPath = "3000.torch" # Path to trained model
imagePath = "test.jpg" # Test image
height=width=900
transformImg = tf.Compose([tf.ToPILImage(), tf.Resize((height, width)), tf.ToTensor(),tf.Normalize((0.485, 0.456, 0.406),(0.229, 0.224, 0.225))])

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
Net = torchvision.models.segmentation.deeplabv3_resnet50(pretrained=True)
Net.classifier[4] = torch.nn.Conv2d(256, 3, kernel_size=(1, 1), stride=(1, 1)) # Change final layer to 3 classes
Net = Net.to(device) # Set net to GPU or CPU
Net.load_state_dict(torch.load(modelPath, map_location=device)) # Load trained model (map_location allows GPU-trained weights on a CPU)
Net.eval() # Set to evaluation mode
Img = cv2.imread(imagePath) # Load test image
height_orgin, widh_orgin, d = Img.shape # Get image original size
plt.imshow(Img[:,:,::-1]) # Show image (OpenCV loads BGR; reverse the channels to RGB for display)
plt.show()
Img = transformImg(Img) # Transform to PyTorch
Img = torch.autograd.Variable(Img, requires_grad=False).to(device).unsqueeze(0)
with torch.no_grad():
    Prd = Net(Img)['out'] # Run net
# Resize to original size
Prd = tf.Resize((height_orgin,widh_orgin))(Prd[0])
# Convert class scores to class map
seg = torch.argmax(Prd, 0).cpu().detach().numpy()
plt.imshow(seg) # Display image
plt.show()

Most of the code here is the same as the training script, with only a few differences:

Net.load_state_dict(torch.load(modelPath, map_location=device)) # Load trained model; map_location lets GPU-trained weights load on a CPU-only machine

This loads the weights we trained and saved earlier from the file at modelPath.

Net.eval()

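This switches the net from training mode to evaluation mode. In practice, this mainly means that the batch normalization layers use the running statistics accumulated during training instead of computing new statistics from the current batch, and that dropout is turned off.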

with torch.no_grad():

This means the net is run without collecting gradients. Gradients are only relevant for training and collecting them is resource-intensive.

Note that the output Prd is a map with 3 channels per image, where each channel holds the unnormalized score (logit) for one of the 3 classes. To find the class of each pixel, we take the channel (class) with the highest of the 3 values using the argmax function:

seg = torch.argmax(Prd[0], 0)

We do it for every pixel in the output map and get one of 3 classes for every pixel.
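torch.argmax over the channel dimension behaves like NumPy's argmax over axis 0. The toy example below uses made-up scores for a 2 x 2 image. Since softmax preserves the ordering of the scores, taking the argmax of the raw scores gives the same classes as taking it of the probabilities.

```python
import numpy as np

# Made-up net output for one image: 3 class scores for each pixel of a 2 x 2 image, shape [3, 2, 2]
scores = np.array([[[ 2.0, 0.1], [0.3, -1.0]],    # class 0 (not a vessel)
                   [[ 0.5, 1.5], [0.2,  0.0]],    # class 1 (empty vessel)
                   [[-1.0, 0.3], [0.9,  0.4]]])   # class 2 (filled)
seg = np.argmax(scores, axis=0)   # for each pixel, index of the highest-scoring channel
print(seg)
# [[0 1]
#  [2 2]]
```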

The results:

Input Image:

Output prediction:



I am a researcher at the University of Toronto, focusing on applying computer vision to control an autonomous chemistry lab.