
The Ultimate Guide to nnU-Net for State of the Art Image Segmentation

A theoretical and practical guide on how to use nnU-Net for Computer Vision and Semantic Image Segmentation and deliver SOTA performance.

The Ultimate Guide to nnU-Net

Everything you need to know to understand the State of the Art nnU-Net, and how to apply it to your own dataset.

Neuroimaging, by Milak Fakurian on Unsplash, link

During my research internship in Deep Learning and Neurosciences at Cambridge University, I made heavy use of nnU-Net, an extremely strong baseline in semantic image segmentation.

However, I struggled to fully understand the model and how to train it, and I did not find much help on the internet. Now that I am comfortable with it, I have written this tutorial to help you either understand what is behind the model or apply it to your own dataset.

Throughout this guide, you will:

  1. Develop a concise overview of the key contributions of nnU-Net.
  2. Learn how to apply nnU-Net to your own dataset.

All the code is available in this Google Colab notebook.

This work took me a significant amount of time and effort. If you find this content valuable, please consider following me to increase its visibility and help support the creation of more such tutorials!

A Brief History of nnU-Net

Recognized as a state-of-the-art model in image segmentation, nnU-Net performs strongly on both 2D and 3D images. Its performance is so robust that it serves as a strong baseline against which new computer vision architectures are benchmarked. In essence, if you are developing a novel segmentation model, treat nnU-Net as your ‘target to surpass’.

This powerful tool is based on the U-Net model (You can find one of my tutorials here: Cook your first U-Net), which made its debut in 2015. The appellation "nnU-Net" stands for "No New U-Net", a nod to the fact that its design doesn’t introduce revolutionary architectural alterations. Instead, it takes the existing U-Net structure and squeezes out its full potential using a set of ingenious optimization strategies.

Contrary to many modern neural networks, the nnU-Net doesn’t rely on residual connections, dense connections, or attention mechanisms. Its strength lies in its meticulous optimization strategy, which includes techniques like resampling, normalization, judicious choice of loss function, optimiser settings, data augmentation, patch-based inference, and ensembling across models. This holistic approach allows the nnU-Net to push the boundaries of what’s achievable with the original U-Net architecture.

Exploring Diverse Architectures within nnU-Net

While it might seem like a singular entity, the nnU-Net is in fact an umbrella term for three distinct types of U-Nets:

2D, 3D, and cascade, Image from nnU-Net article

  1. 2D U-Net: Arguably the most well-known variant, this operates directly on 2D images.
  2. 3D U-Net: This is an extension of the 2D U-Net and is capable of handling 3D images directly through the application of 3D convolutions.
  3. U-Net Cascade: This model generates low-resolution segmentations and subsequently refines them.

Each of these architectures brings its unique strengths to the table and, inevitably, has certain limitations.

For instance, employing a 2D U-Net for 3D image segmentation might seem counterintuitive, but in practice, it can still be highly effective. This is achieved by slicing the 3D volume into 2D planes.

While a 3D U-Net may seem more sophisticated, given its higher parameter count, it isn’t always the most efficient solution. Particularly, 3D U-Nets often struggle with anisotropy, which occurs when spatial resolutions differ along different axes (for example, 1mm along the x-axis and 1.2 mm along the z-axis).
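
To make the notion of anisotropy concrete, here is a minimal sketch that measures it from the voxel spacing. The helper names and the threshold of 3 are illustrative assumptions, not values taken from this article or from the nnU-Net code.

```python
def anisotropy_ratio(spacing):
    """Ratio between the coarsest and the finest voxel spacing (1.0 = isotropic)."""
    return max(spacing) / min(spacing)


def is_anisotropic(spacing, threshold=3.0):
    # threshold is an illustrative assumption, not a value from the article
    return anisotropy_ratio(spacing) > threshold


# Spacing in mm along (x, y, z)
print(anisotropy_ratio((1.0, 1.0, 1.2)))   # mildly anisotropic, as in the example above
print(is_anisotropic((0.7, 0.7, 5.0)))     # thick slices along z
```

The larger the ratio, the more a 3D convolution mixes information sampled at very different physical distances, which is one reason a 2D U-Net can remain competitive on such data.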

The U-Net Cascade variant becomes particularly handy when dealing with large image sizes. A first 3D U-Net is trained on downsampled images and outputs low-resolution segmentations. These predictions are then upsampled to the original resolution and passed, together with the full-resolution image, to a second 3D U-Net that refines them into the final output.

Image from nnU-Net article

Typically, the methodology involves training all three model variants within the nnU-Net framework. The subsequent step may be to either choose the best performer among the three or employ ensembling techniques. One such technique might involve integrating the predictions of both the 2D and 3D U-Nets.

However, it’s worth noting that this procedure can be quite time-consuming (and costly, since you need GPU credits). If your constraints only allow you to train a single model, fret not: ensembling only brings marginal gains, so training one model is a reasonable choice.

This table illustrates the best-performing model variant in relation to specific datasets:

Image from nnU-Net article

Dynamic adaptation of network topologies

Given the significant discrepancies in image size (consider the median shape of 482 × 512 × 512 for liver images versus 36 × 50 × 35 for hippocampus images), the nnU-Net intelligently adapts the input patch size and the number of pooling operations per axis. This essentially implies an automatic adjustment of the number of convolutional layers per dataset, facilitating the effective aggregation of spatial information. In addition to adapting to the varied image geometries, this model takes into account technical constraints, such as available memory.
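
As a rough illustration of "number of pooling operations per axis", the sketch below halves each axis of a patch until the feature map would become too small. It is a simplified rule under stated assumptions (a minimum feature map size of 4, at most 5 poolings, and only even sizes are halved), not the actual nnU-Net planning code, and `num_pool_per_axis` is a hypothetical helper.

```python
def num_pool_per_axis(patch_size, min_feature_size=4, max_pool=5):
    """Count how many times each axis can be halved (simplified rule)."""
    pools = []
    for size in patch_size:
        n = 0
        # stop when the next halving would be too small, uneven, or too deep
        while n < max_pool and size % 2 == 0 and size // 2 >= min_feature_size:
            size //= 2
            n += 1
        pools.append(n)
    return pools


print(num_pool_per_axis((128, 128, 128)))  # large, isotropic patch
print(num_pool_per_axis((40, 56, 40)))     # small patch
print(num_pool_per_axis((20, 320, 256)))   # anisotropic patch: fewer poolings on the short axis
```

The point of the sketch is the last case: a short axis runs out of voxels early, so it is pooled fewer times than the long axes.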

It’s crucial to note that the model doesn’t perform segmentation directly on the entire image but instead on carefully extracted patches with overlapping regions. The predictions on these patches are subsequently averaged, leading to the final segmentation output.

But a larger patch uses more memory, and so does a larger batch. The tradeoff chosen is to always prioritize the patch size (the model’s capacity) over the batch size (only useful for optimization).

Here is the heuristic algorithm used to compute the optimal patch size and batch size:

Heuristic Rule for Batch and Patch Size, Image from nnU-Net article

And this is what it looks like for different Datasets and input dimensions:

Architecture in function of the input image resolution, Image from nnU-Net article

Great! Now let’s quickly go over all the techniques used in nnU-Net:

Training

All models are trained from scratch and evaluated using five-fold cross-validation on the training set, meaning that the original training dataset is randomly divided into five equal parts, or ‘folds’. In this cross-validation process, four of these folds are used for the training of the model, and the remaining one fold is used for the evaluation or testing. This process is then repeated five times, with each of the five folds being used exactly once as the evaluation set.
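
The splitting procedure described above can be sketched in a few lines. This is a minimal illustration, assuming a plain random shuffle with a fixed seed; `five_fold_splits` and the subject names are hypothetical, and nnU-Net creates its own split file internally.

```python
import random


def five_fold_splits(case_ids, n_folds=5, seed=12345):
    """Return a list of (train_ids, val_ids) pairs, one per fold."""
    ids = sorted(case_ids)
    random.Random(seed).shuffle(ids)
    # deal the shuffled cases into n_folds groups of (almost) equal size
    folds = [ids[i::n_folds] for i in range(n_folds)]
    splits = []
    for k in range(n_folds):
        val = folds[k]
        train = [c for j, fold in enumerate(folds) if j != k for c in fold]
        splits.append((train, val))
    return splits


cases = [f"subject_{i:02d}" for i in range(10)]
splits = five_fold_splits(cases)
for k, (train, val) in enumerate(splits):
    print(f"fold {k}: {len(train)} training cases, validation = {val}")
```

Each case appears in exactly one validation set, so the five validation scores together cover the whole training set.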

For the loss, we use a combination of Dice and cross-entropy loss, a very common choice in image segmentation. More details on the Dice loss can be found in V-Net, the U-Net’s big brother.
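
To show how the two terms combine, here is a minimal NumPy sketch of a Dice plus cross-entropy loss on flattened predictions. It assumes the "1 - Dice" form with an unweighted sum; the exact formulation and weighting inside nnU-Net may differ, and `dice_ce_loss` is a hypothetical helper.

```python
import numpy as np


def dice_ce_loss(probs, target, eps=1e-5):
    """probs: (C, N) softmax probabilities, target: (N,) integer labels."""
    n_classes, n_voxels = probs.shape
    one_hot = np.eye(n_classes)[target].T  # (C, N)

    # cross-entropy: negative log-probability of the true class, averaged over voxels
    ce = -np.mean(np.log(probs[target, np.arange(n_voxels)] + eps))

    # soft Dice per class, then averaged over classes
    intersection = (probs * one_hot).sum(axis=1)
    denominator = probs.sum(axis=1) + one_hot.sum(axis=1)
    dice = (2 * intersection + eps) / (denominator + eps)

    return ce + (1 - dice.mean())


target = np.array([0, 1, 2, 1])
perfect = np.eye(3)[target].T          # probability 1 on the true class
uniform = np.full((3, 4), 1 / 3)       # no information at all
print(dice_ce_loss(perfect, target), dice_ce_loss(uniform, target))
```

Cross-entropy gives a smooth per-voxel signal, while the Dice term directly rewards overlap and is less sensitive to class imbalance, which is why the two are often used together.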

Data Augmentation techniques

nnU-Net has a very strong data augmentation pipeline. The authors use random rotations, random scaling, random elastic deformation, gamma correction and mirroring.

NB: You can add your own transformations by modifying the source code
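
Two of these transformations, mirroring and gamma correction, are simple enough to sketch with NumPy. This is an illustration only: the function names, the 50% flip probability and the gamma range are assumptions, and nnU-Net relies on its own augmentation pipeline rather than on this code.

```python
import numpy as np


def random_mirror(image, label, rng):
    """Flip image and label together along each axis with probability 0.5."""
    for axis in range(image.ndim):
        if rng.random() < 0.5:
            image = np.flip(image, axis=axis)
            label = np.flip(label, axis=axis)
    return image, label


def random_gamma(image, rng, gamma_range=(0.7, 1.5)):
    """Apply a random gamma curve while keeping the original intensity range."""
    gamma = rng.uniform(*gamma_range)
    lo, hi = image.min(), image.max()
    scaled = (image - lo) / (hi - lo + 1e-8)   # rescale to [0, 1)
    return scaled ** gamma * (hi - lo) + lo    # map back to [lo, hi)


rng = np.random.default_rng(0)
image = rng.normal(size=(4, 5, 6))
label = rng.integers(0, 3, size=(4, 5, 6))

aug_image, aug_label = random_mirror(image, label, rng)
aug_image = random_gamma(aug_image, rng)
print(aug_image.shape, aug_label.shape)
```

Note that spatial transformations such as mirroring must be applied identically to the image and its segmentation, whereas intensity transformations such as gamma correction only touch the image.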

Elastic deformation, from this article
Image from OpenCV library

Patch based Inference

As mentioned earlier, the model does not predict directly on the full-resolution image; it predicts on extracted patches and then aggregates the predictions.

This is what it looks like:

Patch Based inference, Image by Author

NB: Within each patch, predictions near the center are given more weight than those near the border when overlapping patches are aggregated, because the model has more context there and performs better.
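
The aggregation step can be sketched as a weighted average over overlapping patches. This is a minimal 2D illustration, assuming a Gaussian weight map that peaks at the patch center and an image at least as large as the patch; the function names and the `sigma_scale` value are assumptions, not the nnU-Net implementation.

```python
import numpy as np


def gaussian_weights(patch_size, sigma_scale=0.125):
    """2D importance map that peaks at the patch center."""
    axes = [np.arange(s) - (s - 1) / 2 for s in patch_size]
    yy, xx = np.meshgrid(*axes, indexing="ij")
    sy, sx = (s * sigma_scale for s in patch_size)
    return np.exp(-(yy ** 2 / (2 * sy ** 2) + xx ** 2 / (2 * sx ** 2)))


def sliding_window_predict(image, predict_patch, patch_size, step):
    """Average overlapping patch predictions using the Gaussian weights."""
    h, w = image.shape
    ph, pw = patch_size
    weights = gaussian_weights(patch_size)
    summed = np.zeros((h, w))
    norm = np.zeros((h, w))

    ys = list(range(0, h - ph + 1, step))
    xs = list(range(0, w - pw + 1, step))
    # make sure the last patch touches the image border
    if ys[-1] != h - ph:
        ys.append(h - ph)
    if xs[-1] != w - pw:
        xs.append(w - pw)

    for y in ys:
        for x in xs:
            patch = image[y:y + ph, x:x + pw]
            summed[y:y + ph, x:x + pw] += predict_patch(patch) * weights
            norm[y:y + ph, x:x + pw] += weights
    return summed / norm


# Stand-in "model" that returns its input, to check the aggregation itself
image = np.random.default_rng(0).random((20, 22))
output = sliding_window_predict(image, lambda p: p, patch_size=(8, 8), step=4)
print(output.shape)
```

With a real model, `predict_patch` would return a probability map for the patch, and the same weighted average would be computed per class before taking the argmax.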

Pairwise Model Ensembling

Model Ensembling, Image by author

As you may remember, we can train up to 3 different models: 2D, 3D, and cascade. But at inference time we can only use one model at a time, right?

Well, it turns out that we can do better. Different models have different strengths and weaknesses, so we can combine the predictions of several models: if one model is very confident, we prioritize its prediction.

nnU-Net tests every combination of 2 models among the 3 available models and picks the best one.

In Practice, there are 2 ways to do that:

Hard voting: For each pixel, we look at all the probabilities output by the 2 models, and we take the class with the highest probability.

Soft Voting: For each pixel, we average the probability of the models, and then we take the class with the maximum probability.
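
The two rules, exactly as defined above, can be compared on a toy example. This is a minimal NumPy sketch with made-up probabilities for 3 classes and 2 pixels; `hard_vote` and `soft_vote` are hypothetical helpers, not nnU-Net functions.

```python
import numpy as np


def hard_vote(probs_a, probs_b):
    """Pick the class holding the single highest probability across both models."""
    return np.maximum(probs_a, probs_b).argmax(axis=0)


def soft_vote(probs_a, probs_b):
    """Average the two probability maps, then pick the most probable class."""
    return ((probs_a + probs_b) / 2).argmax(axis=0)


# Shape (classes, pixels): each column sums to 1
probs_a = np.array([[0.50, 0.90],
                    [0.45, 0.05],
                    [0.05, 0.05]])
probs_b = np.array([[0.05, 0.80],
                    [0.40, 0.10],
                    [0.55, 0.10]])

print(hard_vote(probs_a, probs_b))  # → [2 0]
print(soft_vote(probs_a, probs_b))  # → [1 0]
```

On the first pixel the two rules disagree: hard voting follows the most confident single prediction (class 2), while soft voting picks class 1 because both models give it a reasonably high probability.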

Practical implementation

Before we begin, you can download the dataset here and follow the Google Colab notebook.

If the first part was hard to follow, no worries: this is the practical part. Just follow along, and you will still get the best results.

You need a GPU to train the model; otherwise training does not work. You can train either locally or on Google Colab (don’t forget to change the runtime to GPU).

So, first of all, you need to have a dataset ready with input images and their corresponding segmentation. You can follow my tutorial by downloading this ready dataset for 3D Brain segmentation, and then you can replace it with your own dataset.

Downloading data

First, download your data and place it in the data folder, in two subfolders named "input" (the images) and "ground_truth" (the segmentations).

For the rest of the tutorial I will use the MindBoggle dataset for image segmentation. You can download it on this Google Drive:

We are given 3D MRI scans of the Brain and we want to segment the White and Gray matter:

Image by Author

It should look like this:

Tree, Image by Author

Setting up the main directory

If you run this on Google Colab, set collab = True, otherwise collab = False

collab = True

import os
import shutil
#libraries
from collections import OrderedDict
import json
import numpy as np

#visualization of the dataset
import matplotlib.pyplot as plt
import nibabel as nib

if collab:
    from google.colab import drive
    drive.flush_and_unmount()
    drive.mount('/content/drive', force_remount=True)
    # Change "neurosciences-segmentation" to the name of your project folder
    root_dir = "/content/drive/MyDrive/neurosciences-segmentation"

else:
    # use the current working directory as the project root
    root_dir = os.getcwd()

input_dir = os.path.join(root_dir, 'data/input')
segmentation_dir = os.path.join(root_dir, 'data/ground_truth')

my_nnunet_dir = os.path.join(root_dir,'my_nnunet')
print(my_nnunet_dir)

Now we are going to define a function that creates folders for us:

def make_if_dont_exist(folder_path, overwrite=False):
    """
    Create a folder if it does not exist.

    folder_path: path of the folder to create
    overwrite: (default: False) if True, delete and recreate an existing folder
    """
    if os.path.exists(folder_path):
        if not overwrite:
            print(f"{folder_path} exists.")
        else:
            print(f"{folder_path} overwritten")
            shutil.rmtree(folder_path)
            os.makedirs(folder_path)
    else:
        os.makedirs(folder_path)
        print(f"{folder_path} created!")

And we use this function to create our "my_nnunet" folder where everything is going to be saved

os.chdir(root_dir)
make_if_dont_exist('my_nnunet', overwrite=False)
os.chdir('my_nnunet')
print(f"Current working directory: {os.getcwd()}")

Library installation

Now we are going to install all the requirements. First let’s install the nnunet library. If you are in a notebook run this in a cell:

!pip install nnunet

Otherwise you can install nnunet directly from the terminal with

pip install nnunet

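Now we are going to clone the nnU-Net git repository and NVIDIA Apex. The former contains the training scripts, and the latter is a GPU acceleration library for mixed-precision training. Note that the commands used in this tutorial (nnUNet_plan_and_preprocess, nnUNet_train, nnUNet_predict) belong to nnU-Net v1.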

!git clone https://github.com/MIC-DKFZ/nnUNet.git
!git clone https://github.com/NVIDIA/apex

# repository_dir is the path of the cloned GitHub folder
repository_dir = os.path.join(my_nnunet_dir,'nnUNet')
os.chdir(repository_dir)
# install nnU-Net in editable mode from the current folder (note the trailing dot)
!pip install -e .
!pip install --upgrade git+https://github.com/nanohanno/hiddenlayer.git@bugfix/get_trace_graph#egg=hiddenlayer

Creation of the folders

nnU-Net requires a very specific folder structure.

task_name = 'Task001' #change here for different task name

# We define all the necessary paths (absolute, so they do not depend on the current working directory)
main_dir = os.path.join(my_nnunet_dir,'nnUNet/nnunet') # path to main directory
nnunet_dir = os.path.join(main_dir,'nnUNet_raw_data_base/nnUNet_raw_data') # path to the raw data
task_folder_name = os.path.join(nnunet_dir,task_name) # path to the task folder
train_image_dir = os.path.join(task_folder_name,'imagesTr') # path to training images
train_label_dir = os.path.join(task_folder_name,'labelsTr') # path to training labels
test_dir = os.path.join(task_folder_name,'imagesTs') # path to test images
trained_model_dir = os.path.join(main_dir, 'nnUNet_trained_models') # path to trained models

nnU-Net was originally designed for the Medical Segmentation Decathlon challenge, which contains several tasks. If you have several tasks, just run this cell for each of them.

# Creation of all the folders
overwrite = False # Set this to True if you want to overwrite the folders
make_if_dont_exist(task_folder_name,overwrite = overwrite)
make_if_dont_exist(train_image_dir, overwrite = overwrite)
make_if_dont_exist(train_label_dir, overwrite = overwrite)
make_if_dont_exist(test_dir,overwrite= overwrite)
make_if_dont_exist(trained_model_dir, overwrite=overwrite)

You should have a structure like that now:

Image by Author

Setting the environment variables

The scripts need to know where you put your raw data, where they can find the preprocessed data, and where they have to save the results.

os.environ['nnUNet_raw_data_base'] = os.path.join(main_dir,'nnUNet_raw_data_base')
os.environ['nnUNet_preprocessed'] = os.path.join(main_dir,'preprocessed')
os.environ['RESULTS_FOLDER'] = trained_model_dir

Move the files to the right folders:

We define a function that will copy our images to the right folders inside the nnunet folder:

def copy_and_rename(old_location,old_file_name,new_location,new_filename,delete_original = False):
    shutil.copy(os.path.join(old_location,old_file_name),new_location)
    os.rename(os.path.join(new_location,old_file_name),os.path.join(new_location,new_filename))
    if delete_original:
        os.remove(os.path.join(old_location,old_file_name))

Now let’s run this function for the input and ground truth images:

list_of_all_files = os.listdir(segmentation_dir)
list_of_all_files = [file_name for file_name in list_of_all_files if file_name.endswith('.nii.gz')]

for file_name in list_of_all_files:
    copy_and_rename(input_dir,file_name,train_image_dir,file_name)
    copy_and_rename(segmentation_dir,file_name,train_label_dir,file_name)

Now we have to rename the image files to match the format expected by nnU-Net. For example, subject.nii.gz will become subject_0000.nii.gz, where 0000 identifies the modality.

def check_modality(filename):
    """
    check for the existence of modality
    return False if modality is not found else True
    """
    end = filename.find('.nii.gz')
    modality = filename[end-4:end]
    for mod in modality:
        if not(ord(mod)>=48 and ord(mod)<=57): #if not in 0 to 9 digits
            return False
    return True

def rename_for_single_modality(directory):

    for file in os.listdir(directory):

        if check_modality(file)==False:
            new_name = file[:file.find('.nii.gz')]+"_0000.nii.gz"
            os.rename(os.path.join(directory,file),os.path.join(directory,new_name))
            print(f"Renamed to {new_name}")
        else:
            print(f"Modality present: {file}")

rename_for_single_modality(train_image_dir)
# rename_for_single_modality(test_dir)

Setting up the JSON file

We are almost done!

You mostly need to modify 2 things:

  1. The modality (CT or MRI; this changes the normalization)
  2. The labels: enter your own classes

overwrite_json_file = True #make it True if you want to overwrite the dataset.json file in Task_folder
json_file_exist = False

if os.path.exists(os.path.join(task_folder_name,'dataset.json')):
    print('dataset.json already exist!')
    json_file_exist = True

if json_file_exist==False or overwrite_json_file:

    json_dict = OrderedDict()
    json_dict['name'] = task_name
    json_dict['description'] = "Segmentation of T1 Scans from MindBoggle"
    json_dict['tensorImageSize'] = "3D"
    json_dict['reference'] = "see challenge website"
    json_dict['licence'] = "see challenge website"
    json_dict['release'] = "0.0"

    ######################## MODIFY THIS ########################

    #you may mention more than one modality
    json_dict['modality'] = {
        "0": "MRI"
    }
    #list every label value present in the segmentations, starting with 0 for the background
    json_dict['labels'] = {
        "0": "Non Brain",
        "1": "Cortical gray matter",
        "2": "Cortical White matter",
        "3" : "Cerebellum gray ",
        "4" : "Cerebellum white"
    }

    #############################################################

    train_ids = os.listdir(train_label_dir)
    test_ids = os.listdir(test_dir)
    json_dict['numTraining'] = len(train_ids)
    json_dict['numTest'] = len(test_ids)

    #no modality in train image and labels in dataset.json
    json_dict['training'] = [{'image': "./imagesTr/%s" % i, "label": "./labelsTr/%s" % i} for i in train_ids]

    #removing the modality from test image name to be saved in dataset.json
    json_dict['test'] = ["./imagesTs/%s" % (i[:i.find("_0000")]+'.nii.gz') for i in test_ids]

    with open(os.path.join(task_folder_name,"dataset.json"), 'w') as f:
        json.dump(json_dict, f, indent=4, sort_keys=True)

    if os.path.exists(os.path.join(task_folder_name,'dataset.json')):
        if json_file_exist==False:
            print('dataset.json created!')
        else:
            print('dataset.json overwritten!')
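
Before moving on to preprocessing, it can be useful to check that every image has a label and vice versa. The helper below is a hypothetical convenience sketch, not part of nnU-Net; it only assumes the naming convention used above (images end with `_0000.nii.gz`, labels with `.nii.gz`).

```python
def find_unpaired_cases(image_files, label_files, suffix="_0000.nii.gz"):
    """Return (labels without image, images without label) as sorted case names."""
    image_cases = {f[:-len(suffix)] for f in image_files if f.endswith(suffix)}
    label_cases = {f[:-len(".nii.gz")] for f in label_files if f.endswith(".nii.gz")}
    return sorted(label_cases - image_cases), sorted(image_cases - label_cases)


# Toy file lists; in the notebook you could pass os.listdir(train_image_dir)
# and os.listdir(train_label_dir) instead
images = ["a_0000.nii.gz", "b_0000.nii.gz"]
labels = ["a.nii.gz", "c.nii.gz"]

missing_images, missing_labels = find_unpaired_cases(images, labels)
print("labels without image:", missing_images)
print("images without label:", missing_labels)
```

Both lists should be empty for a consistent dataset.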

Preprocess the data for nnU-Net format

This creates the dataset for the nnU-Net format

# -t 1 means "Task001", if you have a different task change it
!nnUNet_plan_and_preprocess -t 1 --verify_dataset_integrity

Train the models

We are now ready to train the models!

To train the 3D U-Net:

#train 3D full resolution U net
!nnUNet_train 3d_fullres nnUNetTrainerV2 1 0 --npz 

To train the 2D U-Net:

# train 2D U net
!nnUNet_train 2d nnUNetTrainerV2 1 0 --npz

To train the cascade model:

# train 3D U-net cascade: first the low-resolution U-Net, then the full-resolution refinement stage
!nnUNet_train 3d_lowres nnUNetTrainerV2 1 0 --npz
!nnUNet_train 3d_cascade_fullres nnUNetTrainerV2CascadeFullRes 1 0 --npz

Note: If you pause the training and want to resume it, add "-c" (for "continue") at the end of the command.

For example:

#resume training of the 3D full resolution U net
!nnUNet_train 3d_fullres nnUNetTrainerV2 1 0 --npz -c

Inference

Now we can run the inference:

result_dir = os.path.join(task_folder_name, 'nnUNet_Prediction_Results')
make_if_dont_exist(result_dir, overwrite=True)

# -i is the input folder
# -o is where you want to save the predictions
# -t 1 means task 1, change it if you have a different task number
# Use -m 2d, or -m 3d_fullres, or -m 3d_cascade_fullres
!nnUNet_predict -i /content/drive/MyDrive/neurosciences-segmentation/my_nnunet/nnUNet/nnunet/nnUNet_raw_data_base/nnUNet_raw_data/Task001/imagesTs -o /content/drive/MyDrive/neurosciences-segmentation/my_nnunet/nnUNet/nnunet/nnUNet_raw_data_base/nnUNet_raw_data/Task001/nnUNet_Prediction_Results -t 1 -tr nnUNetTrainerV2 -m 2d -f 0  --num_threads_preprocessing 1
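
The command above hard-codes the Google Drive paths. As a sketch, the same command can be assembled from the variables defined earlier, so it keeps working if you change the project folder. `build_predict_command` is a hypothetical helper, the paths in the example are placeholders, and only the flags already shown above are used.

```python
import shlex


def build_predict_command(input_dir, output_dir, task_id=1, model="2d",
                          trainer="nnUNetTrainerV2", fold=0):
    """Assemble the nnUNet_predict command line as a single shell-safe string."""
    args = [
        "nnUNet_predict",
        "-i", input_dir,
        "-o", output_dir,
        "-t", str(task_id),
        "-tr", trainer,
        "-m", model,
        "-f", str(fold),
        "--num_threads_preprocessing", "1",
    ]
    # quote every argument so paths containing spaces do not break the command
    return " ".join(shlex.quote(a) for a in args)


# In the notebook you would pass test_dir and result_dir here
cmd = build_predict_command("/data/imagesTs", "/data/preds")
print(cmd)
```

In a notebook cell, the resulting string can then be run with `!{cmd}`.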

Visualization of the predictions

First let’s check the training loss. This looks very healthy, and we have a Dice Score > 0.9 (green curve).

This is truly excellent for so little work and a 3D Neuroimaging segmentation task.

Training loss, test loss, validation Dice, Image by Author

Let’s look at one sample:

Prediction on the MindBoggle dataset, Image by Author

The results are indeed impressive! The model has clearly learned to segment brain images with high accuracy, even though minor imperfections remain.
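
To put a number on a single prediction, you can compute the Dice score per label between the predicted and the ground truth segmentation. This is a minimal sketch on toy arrays; `dice_per_label` is a hypothetical helper, and with real data the arrays could be loaded with nibabel (for example `nib.load(path).get_fdata()`).

```python
import numpy as np


def dice_per_label(pred, truth, labels):
    """Dice score of each label: 2 * |P and T| / (|P| + |T|)."""
    scores = {}
    for label in labels:
        p = pred == label
        t = truth == label
        denom = p.sum() + t.sum()
        # a label absent from both volumes has no defined Dice score
        scores[label] = 2.0 * np.logical_and(p, t).sum() / denom if denom > 0 else float("nan")
    return scores


truth = np.array([0, 1, 1, 2, 2, 2])
pred = np.array([0, 1, 2, 2, 2, 2])
scores = dice_per_label(pred, truth, labels=[1, 2])
print(scores)
```

A score of 1 means perfect overlap for that label, and 0 means no overlap at all.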

There is scope to further optimize the performance of nnU-Net, but that will be for another article.


Thanks for reading!

References

  1. Ronneberger, O., Fischer, P., & Brox, T. (2015). U-net: Convolutional networks for biomedical image segmentation. In International Conference on Medical image computing and computer-assisted intervention (pp. 234–241). Springer, Cham.
  2. Isensee, F., Jaeger, P. F., Kohl, S. A., Petersen, J., & Maier-Hein, K. H. (2021). nnU-Net: a self-configuring method for Deep Learning-based biomedical image segmentation. Nature Methods, 18(2), 203–211.
  3. Ioffe, S., & Szegedy, C. (2015). Batch normalization: Accelerating deep network training by reducing internal covariate shift. arXiv preprint arXiv:1502.03167.
  4. Ulyanov, D., Vedaldi, A., & Lempitsky, V. (2016). Instance normalization: The missing ingredient for fast stylization. arXiv preprint arXiv:1607.08022.
  5. MindBoggle dataset
