2D or 3D? A Simple Comparison of Convolutional Neural Networks for Automatic Segmentation of Cardiac Imaging

Will Burton
Towards Data Science
Apr 21, 2019

Convolutional neural networks (CNNs) have shown promise for a multitude of computer vision tasks, and automatic segmentation is one of them. In research settings, segmentation of medical imaging is used to develop computational models and to support quantitative studies and population-based analyses. In industry, its applications range from diagnosis to patient-specific device development. Depending on the application, manually segmenting a single scan volume can take hours. A well-trained CNN, by contrast, can accurately segment the structures of a scan in a matter of seconds. Accordingly, CNNs have the potential to supplement traditional medical imaging workflows and drive down their costs.

CNNs for segmentation can be categorized by the dimensionality of their convolutional kernels. 2D CNNs use 2D convolutional kernels to predict the segmentation map for a single slice, so a full volume is segmented by predicting one slice at a time. The 2D kernels can leverage context across the height and width of the slice. However, because a 2D CNN takes a single slice as input, it cannot leverage context from adjacent slices, even though voxel information from those slices may be useful for predicting the segmentation map.
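Here is a minimal sketch of slice-by-slice inference, assuming NumPy arrays in (depth, height, width) order. `predict_slice` and `toy_predict` are hypothetical stand-ins for a trained 2D CNN. The sketch only shows that each slice is segmented independently, so the model never sees its neighbours.

```python
import numpy as np


def segment_volume_2d(volume, predict_slice):
    """Segment a (depth, height, width) volume one slice at a time.

    `predict_slice` (hypothetical) maps a (height, width) slice to a
    (height, width) array of integer class labels.
    """
    # Each slice is predicted in isolation: no context from slices z-1 or z+1.
    label_slices = [predict_slice(volume[z]) for z in range(volume.shape[0])]
    # Stack the per-slice maps back into a (depth, height, width) label volume.
    return np.stack(label_slices, axis=0)


def toy_predict(slice_2d):
    """Hypothetical stand-in for a 2D CNN: bins intensities into 3 classes."""
    return np.digitize(slice_2d, bins=[0.33, 0.66])


volume = np.random.default_rng(0).random((10, 64, 64))
segmentation = segment_volume_2d(volume, toy_predict)
print(segmentation.shape)  # (10, 64, 64)
```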

2D CNNs predict segmentation maps for MRI slices in a single anatomical plane.

3D CNNs address this issue by using 3D convolutional kernels to make segmentation predictions for a volumetric patch of a scan. The ability to leverage interslice context can lead to improved performance, but it comes at a computational cost. Each 3D kernel has more parameters than its 2D counterpart, and the volumetric feature maps take up considerably more memory.
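Some back-of-the-envelope arithmetic for a single convolutional layer makes the cost concrete. The channel count (64), kernel size (3), and float32 storage are illustrative assumptions, not the architectures used in this post. A 3x3x3 kernel has three times as many weights as a 3x3 kernel, and a 256x256x8 feature map holds eight times as many activations as a 256x256 one.

```python
def conv_params(in_channels, out_channels, kernel_size, ndim):
    """Weights plus biases of one conv layer with a square (2D) or cubic (3D) kernel."""
    weights = in_channels * out_channels * kernel_size ** ndim
    return weights + out_channels


def activation_bytes(channels, spatial_shape, bytes_per_value=4):
    """Memory for one float32 feature map of a single sample."""
    n = channels
    for size in spatial_shape:
        n *= size
    return n * bytes_per_value


# Hypothetical layer: 64 input and 64 output channels, kernel size 3.
p2d = conv_params(64, 64, 3, ndim=2)  # 64*64*9 + 64 = 36,928
p3d = conv_params(64, 64, 3, ndim=3)  # 64*64*27 + 64 = 110,656
print(p2d, p3d, round(p3d / p2d, 2))  # 36928 110656 3.0

# One 64-channel feature map at full resolution.
a2d = activation_bytes(64, (256, 256))     # a 256x256 slice: 16 MiB
a3d = activation_bytes(64, (256, 256, 8))  # a 256x256x8 patch: 128 MiB
print(a2d // 2**20, a3d // 2**20)  # 16 128
```

The activation figure grows with every feature map the network keeps for backpropagation. For that reason, memory can become the binding constraint for 3D models even when the parameter counts look modest.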

3D CNNs use 3D convolutional kernels to predict segmentation maps for volumetric patches of an MRI volume.

So how much better are 3D CNNs, and are they worth the extra cost? In this post I explore this question with a simple experiment that compares the performance of 2D and 3D CNNs on a cardiac imaging data set. The resulting metrics give some intuition about the performance gap between these two types of CNNs.

2D and 3D CNNs were trained to automatically segment the heart in magnetic resonance imaging (MRI), using the data set from [1]. This data set contains MRIs of 33 patients, and each patient is associated with 20 volumes acquired at different time steps. The first 25 patients were used for training (5920 axial slices). The scans of the remaining 8 patients were held out for testing (2060 slices). The data set provides annotated contours for the endocardium and epicardium. These annotations were used to create segmentation maps with 3 classes: background, myocardium, and inner chamber.
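The 20 time steps of one patient are highly correlated, so the split should be made by patient rather than by slice, which is what is done here. Below is a minimal sketch that assumes each slice is indexed by a hypothetical `(patient_id, time_step, slice_index)` record. The toy data uses a fixed 12 slices per volume, so its counts differ from those of the real data set.

```python
def split_by_patient(records, n_train_patients=25):
    """Split (patient_id, time_step, slice_index) records so that every slice
    of a given patient lands in exactly one of the two sets."""
    patients = sorted({patient_id for patient_id, _, _ in records})
    train_ids = set(patients[:n_train_patients])  # the first 25 patients
    train = [r for r in records if r[0] in train_ids]
    test = [r for r in records if r[0] not in train_ids]
    return train, test


# Toy index: 33 patients x 20 time steps x 12 slices (hypothetical slice count).
records = [(p, t, s) for p in range(1, 34) for t in range(20) for s in range(12)]
train, test = split_by_patient(records)
print(len(train), len(test))  # 6000 1920
```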

Original slice, annotated contours, and resulting ground truth segmentation map used in the experiments.
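The conversion from contours to a three-class map can be sketched as follows. The sketch assumes each contour is an (N, 2) array of (x, y) vertices in pixel coordinates and uses the integer codes 0 = background, 1 = myocardium, 2 = inner chamber. Both are assumptions, because the post does not describe the annotation format. The polygon fill is a hand-rolled even-odd (ray casting) test so that the sketch needs only NumPy. A library routine such as `skimage.draw.polygon` would normally do this job.

```python
import numpy as np


def polygon_mask(contour, shape):
    """Boolean mask of pixels whose centres lie inside a closed polygon.

    `contour` is an (N, 2) array of (x, y) = (column, row) vertices. Uses the
    even-odd rule: a point is inside if a ray cast to its right crosses the
    polygon boundary an odd number of times.
    """
    rows, cols = np.mgrid[0:shape[0], 0:shape[1]]
    px, py = cols.ravel().astype(float), rows.ravel().astype(float)
    x0, y0 = contour[:, 0].astype(float), contour[:, 1].astype(float)
    x1, y1 = np.roll(x0, -1), np.roll(y0, -1)  # each vertex's successor
    inside = np.zeros(px.shape, dtype=bool)
    for ax, ay, bx, by in zip(x0, y0, x1, y1):
        # Edge straddles the horizontal line through the pixel centre.
        straddles = (ay > py) != (by > py)
        # x where the edge meets that line; horizontal edges give inf/nan,
        # but they never straddle, so `straddles` masks those values out.
        with np.errstate(divide="ignore", invalid="ignore"):
            x_cross = ax + (py - ay) * (bx - ax) / (by - ay)
            inside ^= straddles & (px < x_cross)
    return inside.reshape(shape)


def contours_to_labels(endo, epi, shape):
    """0 = background, 1 = myocardium, 2 = inner chamber (assumed codes)."""
    labels = np.zeros(shape, dtype=np.uint8)
    labels[polygon_mask(epi, shape)] = 1   # everything inside the epicardium
    labels[polygon_mask(endo, shape)] = 2  # the cavity overrides its interior
    return labels


# Toy contours: two concentric squares on a 64x64 slice.
epi = np.array([[10, 10], [50, 10], [50, 50], [10, 50]])
endo = np.array([[20, 20], [40, 20], [40, 40], [20, 40]])
labels = contours_to_labels(endo, epi, (64, 64))
print(np.bincount(labels.ravel()).tolist())  # [2496, 1200, 400]
```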

A fully convolutional 2D CNN (U-Net [2]) was trained to segment the scans. The model was trained on full 2D slices at their original resolution (256x256 pixels) with a batch size of 8.

A 3D CNN (V-Net [3]) was trained for the same task. The training settings of the 2D CNN were kept for the 3D CNN, except for input size and batch size: the 3D CNN was trained on 256x256x8 patches from the training set with a batch size of 1.
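Here is a minimal sketch of drawing one 256x256x8 training patch, assuming volumes are NumPy arrays in (slices, height, width) order. Only the slice axis is cropped because the patches span the full slice. Zero-padding volumes that have fewer than 8 slices is an assumption; the post does not say how such volumes were handled.

```python
import numpy as np


def sample_patch(image, labels, depth=8, rng=None):
    """Randomly crop matching (depth, H, W) patches from a (D, H, W) volume.

    The full in-plane extent is kept; only the slice axis is cropped.
    Volumes with fewer than `depth` slices are zero-padded (an assumption).
    """
    if rng is None:
        rng = np.random.default_rng()
    if image.shape[0] < depth:
        pad = ((0, depth - image.shape[0]), (0, 0), (0, 0))
        image = np.pad(image, pad)    # zero intensity
        labels = np.pad(labels, pad)  # zero = background class
    z0 = rng.integers(0, image.shape[0] - depth + 1)  # random start slice
    return image[z0:z0 + depth], labels[z0:z0 + depth]


# Toy 12-slice volume in which every voxel of slice z has the value z.
volume = np.arange(12.0)[:, None, None] * np.ones((1, 256, 256))
seg = volume.astype(np.int64) % 3
img_patch, lab_patch = sample_patch(volume, seg, rng=np.random.default_rng(0))
print(img_patch.shape, lab_patch.shape)  # (8, 256, 256) (8, 256, 256)

# A batch size of 1 in a channels-first layout: (batch, channel, D, H, W).
batch = img_patch[np.newaxis, np.newaxis]
```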

Both CNNs were augmented with two dropout layers [4] and deep supervision [5] during training. I pretty much always use these when training deep learning models, regardless of the application. The models were trained for 50k iterations with a learning rate of 5e-6, which was plenty of training for each model to converge. No data augmentation was used.
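Deep supervision [5] attaches extra losses to intermediate layers. The post does not say how the auxiliary outputs were built or weighted. The sketch below therefore assumes one common segmentation variant. Auxiliary heads predict at coarser resolutions, and their targets are nearest-neighbour-downsampled label maps. The per-head losses are combined with hypothetical weights (1.0, 0.5, 0.25). The sketch uses plain NumPy to show the arithmetic; in practice this would live inside an autodiff framework.

```python
import numpy as np


def softmax_cross_entropy(logits, labels):
    """Mean per-pixel cross-entropy; logits (C, H, W), labels (H, W) ints."""
    shifted = logits - logits.max(axis=0, keepdims=True)  # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=0, keepdims=True))
    h, w = labels.shape
    # Pick log p(true class) at every pixel via advanced indexing.
    picked = log_probs[labels, np.arange(h)[:, None], np.arange(w)[None, :]]
    return -picked.mean()


def deep_supervision_loss(outputs, labels, weights):
    """Weighted sum of the losses from the main head and auxiliary heads.

    `outputs` holds (C, h, w) logits: the full-resolution prediction first,
    then coarser auxiliary predictions (e.g. 1/2 and 1/4 resolution).
    """
    total = 0.0
    for logits, weight in zip(outputs, weights):
        factor = labels.shape[0] // logits.shape[1]
        target = labels[::factor, ::factor]  # nearest-neighbour downsampling
        total += weight * softmax_cross_entropy(logits, target)
    return total


labels = np.random.default_rng(0).integers(0, 3, size=(8, 8))
outputs = [np.zeros((3, 8, 8)), np.zeros((3, 4, 4)), np.zeros((3, 2, 2))]
loss = deep_supervision_loss(outputs, labels, weights=[1.0, 0.5, 0.25])
# Uniform logits cost log(3) per head, so loss = 1.75 * log(3), about 1.923.
```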

The models were evaluated on the test scans by computing the dice similarity coefficient (DSC) of the predicted segmentation for each test scan and averaging over scans. For a predicted region A and a ground-truth region B, DSC = 2|A∩B| / (|A| + |B|). It ranges from 0 (no overlap) to 1 (perfect match), so it helps to think of the DSC as a more discriminating version of “accuracy”. Inference with the 2D CNN was performed one slice of a scan at a time. Inference with the 3D CNN was performed by iteratively sampling 3D patches from a scan until every voxel had a prediction.
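The 3D inference loop can be sketched as a deterministic sliding window along the slice axis. This is one way to guarantee that every voxel gets a prediction; it is an assumption, not necessarily the post's exact sampling scheme. Only depth needs tiling because the patches span the full 256x256 slice. `predict_patch` is a hypothetical stand-in for the trained 3D CNN that returns class probabilities. Where windows overlap, the probabilities are averaged before the arg-max is taken.

```python
import numpy as np


def window_starts(length, size):
    """Start indices of length-`size` windows that together cover range(length).

    The final window is shifted back to end exactly at `length`, so it may
    overlap its predecessor.
    """
    if length <= size:
        return [0]
    starts = list(range(0, length - size + 1, size))
    if starts[-1] + size < length:
        starts.append(length - size)
    return starts


def predict_volume_3d(volume, predict_patch, n_classes, depth=8):
    """Label a (D, H, W) volume with a model that sees (depth, H, W) patches."""
    d = volume.shape[0]
    # Zero-pad volumes shorter than one patch (assumption), crop afterwards.
    padded = np.pad(volume, ((0, max(0, depth - d)), (0, 0), (0, 0)))
    probs = np.zeros((n_classes,) + padded.shape)
    hits = np.zeros(padded.shape[0])  # how many windows covered each slice
    for z0 in window_starts(padded.shape[0], depth):
        probs[:, z0:z0 + depth] += predict_patch(padded[z0:z0 + depth])
        hits[z0:z0 + depth] += 1
    probs /= hits[None, :, None, None]  # average overlapping predictions
    return probs[:, :d].argmax(axis=0)


def toy_predict_patch(patch):
    """Hypothetical 3D model: one-hot probabilities from intensity bins."""
    classes = np.digitize(patch, bins=[0.33, 0.66])
    return np.stack([(classes == c).astype(float) for c in range(3)])


volume = np.random.default_rng(0).random((12, 32, 32))
print(window_starts(12, 8))  # [0, 4]
labels = predict_volume_3d(volume, toy_predict_patch, n_classes=3)
```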

Classwise and overall DSC results for each CNN

When looking at DSC, it helps to examine the class-specific DSC of the smaller structures on the segmentation map (i.e. not background), because the DSC is sensitive to class frequency. Background covers most of each slice and is easy to get right, so it inflates any overall score. It looks like the 3D CNN performed better, and the performance gap between the two models was much wider than I had expected. Below are some qualitative results of the predictions. In a few instances the 3D CNN was clearly superior.
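The per-class DSC can be computed as below. This is a minimal sketch; returning 1.0 when a class is absent from both maps is an assumed convention. The toy example shows why background should be set aside: a prediction that misses half of a small structure still scores almost perfectly on background.

```python
import numpy as np


def dice(pred, truth, cls):
    """DSC = 2|P ∩ T| / (|P| + |T|) for the pixels labelled `cls`."""
    p, t = pred == cls, truth == cls
    denom = p.sum() + t.sum()
    if denom == 0:
        return 1.0  # class absent from both maps (assumed convention)
    return 2.0 * np.logical_and(p, t).sum() / denom


def mean_dice(preds, truths, cls):
    """Average one class's DSC over a list of scans."""
    return float(np.mean([dice(p, t, cls) for p, t in zip(preds, truths)]))


# Toy 100x100 slice: a 10x10 structure (class 1) on background (class 0).
truth = np.zeros((100, 100), dtype=int)
truth[45:55, 45:55] = 1
pred = np.zeros_like(truth)
pred[45:50, 45:55] = 1  # only the top half of the structure is found

print(round(dice(pred, truth, 0), 4))  # 0.9975  (background)
print(round(dice(pred, truth, 1), 4))  # 0.6667  (structure)
```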

Some qualitative results: Original slice (top row), 2D CNN predictions (middle), 3D CNN predictions (bottom). Notice the first two columns, where the 3D CNN performed much better.

There are some considerations to discuss. The 2D CNN was trained on a single 1080 Ti GPU with a batch size of 8 slices of 256x256 pixels. The 3D CNN, on the other hand, required 2 GPUs for a batch size of 1 with 256x256x8 patches. I had initially trained the 3D CNN on even smaller patches (94x94x8) to increase the batch size during training, but this led to extremely poor results. Therefore, when training 3D CNNs, it may be beneficial to keep a large field of view in the slice plane.

The speed of training is another consideration. A training iteration for the 2D CNN took ~0.5 seconds (I performed all batch pre-processing in parallel with the actual training steps). A training step with the 3D CNN took 3 seconds, around 6 times as long! For the 50k iterations used here, that works out to roughly 7 hours versus roughly 42 hours. For a longer run, it is the difference between training a neural network for 1 day and training it for almost a week!

One other note: the 3D patches had a depth of 8, but the 3D CNN had an output stride of 16, which caused complications with the max pooling layers. For the max pooling layers, I used a stride of 2 for the height and width dimensions but kept a stride of 1 for the depth dimension. If you don’t do this, the depth dimension of your batch will shrink to zero during the forward pass: four rounds of 2x pooling take a depth of 8 down to 4, 2, 1, and then 0.
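The shape arithmetic behind that pooling trick is easy to check. The sketch assumes four 2x max-pooling layers (which is what an output stride of 16 implies) with kernel size equal to stride and no padding. Real frameworks may raise an error rather than return a zero-sized dimension. In PyTorch, for instance, 3D layers order the spatial dimensions as (depth, height, width), so the setting described above would correspond to `nn.MaxPool3d(kernel_size=(1, 2, 2))`.

```python
def pooled_shapes(shape, pool_strides):
    """Feature-map shape after each max-pooling layer (kernel = stride, no padding)."""
    shapes = [tuple(shape)]
    for stride in pool_strides:
        shapes.append(tuple(size // s for size, s in zip(shapes[-1], stride)))
    return shapes


patch = (256, 256, 8)  # (height, width, depth)

# Four 2x poolings in every dimension (output stride 16): depth collapses.
print(pooled_shapes(patch, [(2, 2, 2)] * 4))
# [(256, 256, 8), (128, 128, 4), (64, 64, 2), (32, 32, 1), (16, 16, 0)]

# Stride 1 along depth keeps all 8 slices while height and width shrink 16x.
print(pooled_shapes(patch, [(2, 2, 1)] * 4))
# [(256, 256, 8), (128, 128, 8), (64, 64, 8), (32, 32, 8), (16, 16, 8)]
```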

So are 3D CNNs worth the effort? In my opinion, yes! But be smart about it. If you are tackling a new problem for the first time, start with 2D CNNs so you can quickly iterate through experiments. Once you think you have tuned all the knobs just right, upgrade to 3D CNNs. Additionally, you can begin training your 3D CNNs on smaller patches or downsampled training instances for faster training steps, then fine-tune at the desired resolution at the end of training!
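One way to implement the downsampled-instances idea is sketched below. The 2x factor, the block-average and nearest-neighbour resampling, and the `fine_tune_start` switch point are all hypothetical choices. Because the networks here are fully convolutional, the same weights can be applied to 128x128 and 256x256 inputs, as long as each in-plane size stays divisible by the output stride.

```python
import numpy as np


def downsample_pair(image, labels, factor=2):
    """Shrink a (D, H, W) image/label pair in-plane by an integer factor.

    Intensities are block-averaged; labels use nearest neighbour (every
    `factor`-th pixel) so they remain valid class ids. Depth is untouched.
    """
    d, h, w = image.shape
    h2, w2 = h // factor, w // factor
    cropped = image[:, :h2 * factor, :w2 * factor]
    small_image = cropped.reshape(d, h2, factor, w2, factor).mean(axis=(2, 4))
    small_labels = labels[:, :h2 * factor:factor, :w2 * factor:factor]
    return small_image, small_labels


def downsample_factor(iteration, fine_tune_start=40_000):
    """Hypothetical schedule: half resolution first, full resolution at the end."""
    return 2 if iteration < fine_tune_start else 1


image = np.arange(16.0).reshape(1, 4, 4)
labels = (image % 3).astype(np.int64)
small_image, small_labels = downsample_pair(image, labels)
print(small_image[0].tolist())  # [[2.5, 4.5], [10.5, 12.5]]
```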


Want to chat about deep learning? Reach out to me on LinkedIn or leave a comment below.

References

[1] Andreopoulos A, Tsotsos JK. Efficient and generalizable statistical models of shape and appearance for analysis of cardiac MRI.

[2] Ronneberger O, Fischer P, Brox T. U-net: Convolutional networks for biomedical image segmentation.

[3] Milletari F, Navab N, Ahmadi SA. V-net: Fully convolutional neural networks for volumetric medical image segmentation.

[4] Srivastava N, Hinton G, Krizhevsky A, Sutskever I, Salakhutdinov R. Dropout: a simple way to prevent neural networks from overfitting.

[5] Lee CY, Xie S, Gallagher P, Zhang Z, Tu Z. Deeply-supervised nets.
