Volumetric Medical Image Segmentation with Vox2Vox

How to implement a 3D volumetric generative adversarial network for CT/MRI segmentation

Enoch Kan
Towards Data Science


If you are familiar with generative adversarial networks (GANs) and their popular variants, the term Pix2Pix should not be unfamiliar at all. Pix2Pix is a type of conditional GAN (cGAN) that performs image-to-image translation. In the medical field, Pix2Pix-style networks are typically used to perform modality translation and, in some cases, organ segmentation.

“Voxels on My Mind”- Don Backos

Pix2Pix, the Basics

Similar to most GANs, Pix2Pix consists of a single generator network and a single discriminator network. The generator network is nothing but a U-Net, which is a type of deep convolutional neural network originally proposed to perform biomedical image segmentation. U-Net has the following architecture:

‘U-Net Architecture’- Department of Computer Science, University of Freiburg

The U-Net contains three major components: the Encoder, the Bottleneck and the Decoder. Technical details, including the numbers of input/output channels, kernel sizes, strides and paddings of Pix2Pix’s generator U-Net, can be found in its original paper. In short, the Encoder provides a contracting path that convolves and progressively downsamples a given 2D image while increasing the number of feature channels. The Bottleneck processes this compressed representation at the lowest resolution, and finally the Decoder provides an expansion path that upsamples the encoded representation back to the input size. Feature maps from encoder and decoder layers with the same spatial size are concatenated along the channel dimension; these are the U-Net’s skip connections. If you are interested in learning more about the U-Net specifically and how it performs image segmentation, Heet Sankesara has a great article about it.
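To make the three components and the channel-wise concatenation concrete, here is a deliberately tiny 2D U-Net sketch with a single downsampling level. The class name, widths and kernel sizes are illustrative assumptions, not the Pix2Pix configuration, which stacks many more levels.

```python
import torch
import torch.nn as nn


class TinyUNet2d(nn.Module):
    """Hypothetical two-level 2D U-Net showing encoder, bottleneck, decoder and skip concat."""

    def __init__(self, in_ch=3, out_ch=3, width=16):
        super().__init__()
        # Encoder (contracting path): full-resolution features, then a strided conv
        # that halves H and W while doubling the number of channels
        self.enc = nn.Sequential(nn.Conv2d(in_ch, width, 3, padding=1), nn.LeakyReLU(0.2))
        self.down = nn.Sequential(nn.Conv2d(width, 2 * width, 4, stride=2, padding=1), nn.LeakyReLU(0.2))
        # Bottleneck: processes the most compressed representation, shape unchanged
        self.bottleneck = nn.Sequential(nn.Conv2d(2 * width, 2 * width, 3, padding=1), nn.ReLU())
        # Decoder (expansion path): a transposed conv doubles H and W again
        self.up = nn.Sequential(nn.ConvTranspose2d(2 * width, width, 4, stride=2, padding=1), nn.ReLU())
        # Skip concatenation doubles the channels, so the 1x1 output conv sees 2 * width
        self.head = nn.Conv2d(2 * width, out_ch, 1)

    def forward(self, x):
        e = self.enc(x)                     # (N, width, H, W)
        b = self.bottleneck(self.down(e))   # (N, 2*width, H/2, W/2)
        d = self.up(b)                      # (N, width, H, W)
        d = torch.cat([d, e], dim=1)        # same spatial size, concatenated along channels
        return self.head(d)                 # (N, out_ch, H, W)


if __name__ == "__main__":
    net = TinyUNet2d()
    print(net(torch.randn(1, 3, 64, 64)).shape)  # torch.Size([1, 3, 64, 64])
```

The concatenation only works because the encoder feature map `e` and the upsampled decoder output `d` share the same height and width; their channel counts simply add up.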

The discriminator of Pix2Pix is a convolutional “PatchGAN” classifier which ‘discriminates’ whether each patch of a given image is real (original training data) or fake (generated by the U-Net generator). The training objective of Pix2Pix is a minimax formulation that combines an adversarial loss with an L₁ reconstruction loss between the generated and real images. The original paper uses a binary cross-entropy adversarial loss, while many implementations swap in an L₂/MSE (least-squares) adversarial loss.
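For reference, the objective from the original Pix2Pix paper is written out below, with x the input image, y the target image and z the noise (provided through dropout in Pix2Pix). The last block is a sketch of the least-squares (MSE) replacement for the adversarial terms, assuming real label 1 and fake label 0; the ½ factors are a common convention rather than something stated in this post.

```latex
G^{*} = \arg\min_{G}\max_{D}\; \mathcal{L}_{cGAN}(G, D) + \lambda\,\mathcal{L}_{L1}(G)

\mathcal{L}_{cGAN}(G, D) = \mathbb{E}_{x,y}\big[\log D(x, y)\big]
  + \mathbb{E}_{x,z}\big[\log\big(1 - D(x, G(x, z))\big)\big]

\mathcal{L}_{L1}(G) = \mathbb{E}_{x,y,z}\big[\lVert y - G(x, z)\rVert_{1}\big]

% Least-squares (MSE) adversarial terms, real label 1 and fake label 0
\mathcal{L}_{D} = \tfrac{1}{2}\,\mathbb{E}_{x,y}\big[(D(x, y) - 1)^{2}\big]
  + \tfrac{1}{2}\,\mathbb{E}_{x,z}\big[D(x, G(x, z))^{2}\big]
\qquad
\mathcal{L}_{G}^{adv} = \tfrac{1}{2}\,\mathbb{E}_{x,z}\big[(D(x, G(x, z)) - 1)^{2}\big]
```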

One of the earliest applications of Pix2Pix was to generate cat pictures from drawings (for all you cool cats and kittens). However, it has also been extended to the medical imaging field to perform domain transfer between magnetic resonance (MR), positron emission tomography (PET) and computed tomography (CT) images.

(left) Christopher Hesse’s Pix2Pix demo (right) MRI cross-modality translation by Yang et al.

One of the major limitations of Pix2Pix is that it cannot perform image-to-image translation in 3D. This is a huge roadblock for many medical AI researchers, since medical images are usually volumetric in nature. With advances in graphics processing units (GPUs) and deep neural network design, researchers have made tremendous breakthroughs in recent years that allow them to perform volumetric segmentation. V-Nets (a 3D extension of U-Nets) and Dense V-Nets are common architectures for 3D single- and multi-organ segmentation. M. Cirillo, D. Abramian and A. Eklund from Linköping University in Sweden have proposed a variant of Pix2Pix, which they call Vox2Vox, that performs segmentation in an adversarial fashion. Here is the link to the paper.

Since there were no open-source PyTorch implementations of Vox2Vox at the time of writing, I decided to take a stab at it. The full network implementation can be found in my GitHub repo. It is also linked on PapersWithCode to gain more exposure and feedback. Please feel free to fork and modify my code, and even submit bug fixes, as long as you give me credit.

Implementing Vox2Vox

For this project, you are going to need the following dependencies:

  • Python 3.7
  • PyTorch>=0.4.0
  • Torchvision
  • Matplotlib
  • NumPy, SciPy
  • Pillow
  • scikit-image
  • PyVista
  • h5py

After cloning the repo, you can install all the required libraries by running:

pip install -r requirements.txt
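If you want to confirm the installation worked, a standard-library-only sketch like the one below reports which dependencies cannot be imported. The mapping from package names to import names (Pillow → PIL, scikit-image → skimage) is an assumption worth double-checking, and the check does not verify versions such as PyTorch>=0.4.0.

```python
import importlib.util

# Import names of the dependencies listed above (assumed mapping:
# Pillow -> PIL, scikit-image -> skimage, PyVista -> pyvista)
REQUIRED = ("torch", "torchvision", "matplotlib", "numpy", "scipy",
            "PIL", "skimage", "pyvista", "h5py")


def missing_packages(names=REQUIRED):
    """Return the import names that cannot be found, without importing them."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_packages()
    print("All dependencies found." if not missing else f"Missing: {missing}")
```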

The implementation of the U-Net generator can be a bit tricky, mainly because the authors of the paper did not specify how the convolutional blocks are concatenated. It took me quite a few hours just to get the volume dimensions right.

“3D U-Net generator with volumetric convolutions and skip-connection bottleneck blocks”- M. Cirillo, D. Abramian and A. Eklund

Nevertheless, the volumetric U-Net generator is built from three major building blocks: encoder blocks, bottleneck blocks and decoder blocks. Each block contains a convolutional layer, a normalization layer and an activation layer.

Both the bottleneck and decoder blocks use concatenation. However, the bottleneck blocks are concatenated a little differently from the decoder blocks. As illustrated in the paper, the input to a typical bottleneck block is the concatenation of the input and output of its previous block. The output dimensions of the bottleneck blocks should remain constant (8x8x8x512) despite the repeated concatenation of inputs and outputs. The decoder blocks, on the other hand, concatenate their outputs with the outputs of the corresponding encoder blocks.
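Below is a minimal PyTorch sketch of the three blocks. It assumes Pix2Pix-style choices (4×4×4 strided convolutions in the encoder, 4×4×4 transposed convolutions in the decoder, 3×3×3 convolutions in the bottleneck, instance normalization and (Leaky)ReLU activations); the class names and these hyperparameters are assumptions and may differ from both the paper and the repo. The bottleneck follows one reading of the concatenation rule: each block after the first receives the previous block's incoming feature map concatenated with that block's output, which keeps the channel count at 2c in and c out.

```python
import torch
import torch.nn as nn


class EncoderBlock(nn.Module):
    """Strided Conv3d -> optional InstanceNorm3d -> LeakyReLU; halves D, H and W."""

    def __init__(self, in_ch, out_ch, normalize=True):
        super().__init__()
        layers = [nn.Conv3d(in_ch, out_ch, kernel_size=4, stride=2, padding=1, bias=False)]
        if normalize:
            layers.append(nn.InstanceNorm3d(out_ch))
        layers.append(nn.LeakyReLU(0.2))
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)


class BottleneckBlock(nn.Module):
    """3x3x3 Conv3d -> InstanceNorm3d -> LeakyReLU; keeps D, H and W unchanged."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv3d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),
            nn.InstanceNorm3d(out_ch),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x, prev=None):
        # x: output of the previous block; prev: that block's incoming feature map.
        # Concatenating them gives c + c = 2c channels in and c out, so the
        # bottleneck shape stays fixed however many blocks are chained.
        if prev is not None:
            x = torch.cat([prev, x], dim=1)
        return self.block(x)


class DecoderBlock(nn.Module):
    """ConvTranspose3d -> InstanceNorm3d -> ReLU, then concatenate the encoder skip."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.block = nn.Sequential(
            nn.ConvTranspose3d(in_ch, out_ch, kernel_size=4, stride=2, padding=1, bias=False),
            nn.InstanceNorm3d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x, skip):
        # Upsampled features are joined with the same-sized encoder output
        return torch.cat([self.block(x), skip], dim=1)


if __name__ == "__main__":
    c = 16  # small for a quick check; the post's bottleneck uses 512 channels at 8x8x8
    h0 = torch.randn(1, c, 8, 8, 8)
    h1 = BottleneckBlock(c, c)(h0)
    h2 = BottleneckBlock(2 * c, c)(h1, prev=h0)
    print(h1.shape, h2.shape)  # both torch.Size([1, 16, 8, 8, 8])
```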

“Model Architecture of Vox2Vox”- M. Cirillo, D. Abramian and A. Eklund

With the model architecture description at hand, it is quite straightforward to implement the U-Net generator with the three basic blocks described above.
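Putting the pieces together, the sketch below assembles a hypothetical `Vox2VoxGeneratorSketch`: four encoder levels (64, 128, 256 and 512 channels when `base=64`), a chain of four bottleneck blocks and three decoder blocks, followed by a final transposed convolution and a channel-wise softmax. The channel widths, the four input and output channels, and the layer hyperparameters are assumptions chosen so that a 128×128×128 input yields the 8x8x8x512 bottleneck quoted above; check them against the paper's figure before relying on them.

```python
import torch
import torch.nn as nn


def conv_block(in_ch, out_ch, kernel, stride, transpose=False, normalize=True, act=None):
    """(Transposed) Conv3d -> optional InstanceNorm3d -> activation (LeakyReLU by default)."""
    conv = nn.ConvTranspose3d if transpose else nn.Conv3d
    # padding=1: kernel 4 / stride 2 halves (or doubles) D, H, W; kernel 3 / stride 1 keeps them
    layers = [conv(in_ch, out_ch, kernel, stride, 1, bias=False)]
    if normalize:
        layers.append(nn.InstanceNorm3d(out_ch))
    layers.append(act if act is not None else nn.LeakyReLU(0.2))
    return nn.Sequential(*layers)


class Vox2VoxGeneratorSketch(nn.Module):
    """Hypothetical volumetric U-Net: 4 encoder levels, bottleneck chain, 3 decoder levels."""

    def __init__(self, in_ch=4, out_ch=4, base=64, n_bottleneck=4):
        super().__init__()
        widths = [base, 2 * base, 4 * base, 8 * base]  # 64-128-256-512 when base=64
        self.down = nn.ModuleList()
        prev = in_ch
        for i, w in enumerate(widths):
            # No normalization on the first encoder block (a Pix2Pix convention)
            self.down.append(conv_block(prev, w, 4, 2, normalize=(i > 0)))
            prev = w
        c = widths[-1]
        # First bottleneck block sees c channels; later ones see a 2c-channel concatenation
        self.mid = nn.ModuleList(
            [conv_block(c if i == 0 else 2 * c, c, 3, 1) for i in range(n_bottleneck)]
        )
        self.up = nn.ModuleList()
        for w in reversed(widths[:-1]):  # 256, 128, 64 when base=64
            self.up.append(conv_block(prev, w, 4, 2, transpose=True, act=nn.ReLU()))
            prev = 2 * w  # channels after concatenating the matching encoder output
        # Last upsampling back to the input size, one output channel per class
        self.final = nn.ConvTranspose3d(prev, out_ch, 4, 2, 1)

    def forward(self, x):
        skips = []
        for block in self.down:
            x = block(x)
            skips.append(x)
        # Bottleneck chain: each later block gets cat(previous block's incoming map, its output)
        prev, h = x, self.mid[0](x)
        for block in self.mid[1:]:
            prev, h = h, block(torch.cat([prev, h], dim=1))
        x = h
        # Decoder: upsample, then concatenate the same-sized encoder output
        for block, skip in zip(self.up, reversed(skips[:-1])):
            x = torch.cat([block(x), skip], dim=1)
        return torch.softmax(self.final(x), dim=1)  # per-voxel class probabilities


if __name__ == "__main__":
    net = Vox2VoxGeneratorSketch(base=4)  # tiny widths for a quick CPU check
    print(net(torch.randn(1, 4, 32, 32, 32)).shape)  # torch.Size([1, 4, 32, 32, 32])
```

Passing `base=64` and a (1, 4, 128, 128, 128) tensor reproduces the dimensions discussed above but needs a lot of memory; a small `base` with a 32×32×32 input is enough to check that the shapes line up.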

The implementation of the discriminator is pretty straightforward, as it contains the same volumetric encoder blocks as the generator. In fact, you can train with any discriminator architecture you like; just make sure you pay attention to the balance between the generator and the discriminator during training. By tuning the initial learning rates and setting an accuracy threshold for the discriminator, you can balance the two networks through trial and error. The authors of 3DGAN offer plenty of useful advice on balancing and stabilizing the training of volumetric GANs in their original 3DGAN paper.
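Here is one possible conditional 3D discriminator, assumed to take the input volume and a segmentation stacked along channels and to return a grid of per-patch scores (a PatchGAN-style design). Alongside it are two hypothetical helpers for the accuracy-threshold trick: the discriminator is only updated when its accuracy on the current batch is at or below a threshold. The 0.5 decision boundary assumes real/fake labels of 1 and 0, and the 0.8 default is an assumption in the spirit of the 3DGAN advice, to be tuned by trial and error.

```python
import torch
import torch.nn as nn


class PatchDiscriminator3d(nn.Module):
    """Hypothetical conditional 3D discriminator built from strided encoder-style blocks."""

    def __init__(self, image_ch=4, seg_ch=4, base=64, n_blocks=3):
        super().__init__()
        layers, prev = [], image_ch + seg_ch
        for i in range(n_blocks):
            out = base * 2 ** i
            # Each strided conv halves D, H and W
            layers.append(nn.Conv3d(prev, out, kernel_size=4, stride=2, padding=1, bias=False))
            if i > 0:  # no normalization on the first block
                layers.append(nn.InstanceNorm3d(out))
            layers.append(nn.LeakyReLU(0.2))
            prev = out
        layers.append(nn.Conv3d(prev, 1, kernel_size=3, padding=1))  # one score per patch
        self.model = nn.Sequential(*layers)

    def forward(self, image, seg):
        # Condition on the input volume by stacking it with the (real or fake) segmentation
        return self.model(torch.cat([image, seg], dim=1))


def discriminator_accuracy(pred_real, pred_fake):
    """Fraction of patch scores on the correct side of 0.5 (real -> 1, fake -> 0)."""
    correct = (pred_real > 0.5).float().sum() + (pred_fake < 0.5).float().sum()
    return (correct / (pred_real.numel() + pred_fake.numel())).item()


def should_update_discriminator(pred_real, pred_fake, threshold=0.8):
    """Skip the discriminator step when it is already winning too easily."""
    return discriminator_accuracy(pred_real, pred_fake) <= threshold
```

In a training loop, you would compute the real and fake predictions for the batch and only take the discriminator's optimizer step when `should_update_discriminator` returns `True`.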

Another important piece of Vox2Vox’s implementation is its loss function. Similar to Pix2Pix, Vox2Vox’s loss can be broken down into an adversarial loss (MSE loss) and a reconstruction loss (generalized dice loss). In PyTorch code, a loss function is conventionally referred to as a criterion; hence, criterion_GAN represents the adversarial loss and criterion_voxelwise represents the reconstruction loss. The generator loss includes a tunable parameter, lambda, which controls the relative weight of the adversarial and reconstruction losses. The loss computations for generator and discriminator training are implemented in the repo.
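A minimal sketch of the two loss computations, assuming MSE (least-squares) adversarial targets of 1 for real and 0 for generated pairs and a discriminator that takes (input volume, segmentation). The names `real_A`, `fake_B`, `real_B` and `lambda_voxel` are hypothetical (`lambda` is a reserved word in Python); `criterion_voxelwise` is passed in so that the generalized dice loss, or any other reconstruction loss, can be plugged in. The 0.5 factor on the discriminator loss is a common convention, assumed here.

```python
import torch
import torch.nn as nn

criterion_GAN = nn.MSELoss()  # adversarial (least-squares) loss


def generator_loss(discriminator, real_A, fake_B, real_B, criterion_voxelwise, lambda_voxel):
    """Adversarial term (push D towards 1 on fakes) plus lambda-weighted reconstruction term.

    real_A: input volume, fake_B: generated segmentation, real_B: ground-truth segmentation.
    """
    pred_fake = discriminator(real_A, fake_B)
    loss_GAN = criterion_GAN(pred_fake, torch.ones_like(pred_fake))
    loss_voxel = criterion_voxelwise(fake_B, real_B)
    return loss_GAN + lambda_voxel * loss_voxel


def discriminator_loss(discriminator, real_A, fake_B, real_B):
    """Real pairs should score 1 and generated pairs 0."""
    pred_real = discriminator(real_A, real_B)
    # detach() so this loss does not send gradients into the generator
    pred_fake = discriminator(real_A, fake_B.detach())
    loss_real = criterion_GAN(pred_real, torch.ones_like(pred_real))
    loss_fake = criterion_GAN(pred_fake, torch.zeros_like(pred_fake))
    return 0.5 * (loss_real + loss_fake)
```

Larger values of `lambda_voxel` push the generator towards matching the ground-truth segmentation; smaller values give the adversarial term more influence.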

Generalized Dice Loss

In medical image segmentation, a commonly used loss function is the generalized dice loss. The Sørensen–Dice coefficient is typically used to evaluate the similarity between two samples, and it has the following formula:

Dice = 2TP / (2TP + FP + FN)

TP stands for true positives, FP for false positives and FN for false negatives. The Dice coefficient ranges from 0 to 1, with 1 representing a perfect match between the two samples. Generalized dice loss turns the Dice score into a loss to be minimized during deep learning training: it combines per-class scores using class weights and subtracts the result from 1. My PyTorch implementation of the generalized dice loss is included in the repo.
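As a quick check of the formula, 8 true positives, 2 false positives and 2 false negatives give 2·8 / (2·8 + 2 + 2) = 0.8. The sketch below computes that, plus a minimal generalized dice loss assuming softmax probabilities and one-hot targets of shape (N, C, D, H, W) and the common class weighting w_c = 1 / (number of target voxels in class c)². The function names are hypothetical and the sketch may differ from the implementation in the repo.

```python
import torch


def dice_from_counts(tp, fp, fn):
    """Sørensen–Dice coefficient from confusion counts: 2TP / (2TP + FP + FN)."""
    return 2 * tp / (2 * tp + fp + fn)


def generalized_dice_loss(probs, target, eps=1e-6):
    """Generalized dice loss for softmax outputs and one-hot targets, both (N, C, D, H, W).

    Each class is weighted by 1 / (its target volume)^2, so small structures count
    as much as large ones. Classes absent from the target get a very large weight;
    some implementations clamp it.
    """
    dims = (0, 2, 3, 4)                          # sum over batch and spatial axes
    w = 1.0 / (target.sum(dim=dims) ** 2 + eps)  # one weight per class
    intersection = (w * (probs * target).sum(dim=dims)).sum()
    union = (w * (probs + target).sum(dim=dims)).sum()
    return 1.0 - 2.0 * intersection / (union + eps)


if __name__ == "__main__":
    print(dice_from_counts(8, 2, 2))  # 0.8
```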

Rather than going through the rest of my code line by line, I’m going to leave it to you to read through and understand my implementation. It is also definitely helpful to read through the original paper. I have nothing but respect for the authors for creating Vox2Vox, so please give them credit too if you decide to implement Vox2Vox on your own. Comment below if you have any questions, and tell me what you’d like to see in my next blog post! Stay safe and stay curious :)


ML Lead @ Kognitiv, Founder @ Kortex Labs, The ML Practitioner 🇬🇧 🇺🇸 🇭🇰