
Paper Walkthrough: U-Net

A PyTorch implementation of one of the most popular semantic segmentation models.

Photo by Caleb Jones on Unsplash

Introduction to U-Net

When we talk about image segmentation, we should not forget about U-Net, a neural network architecture that was first proposed by Ronneberger et al. [1] back in 2015. The model was initially intended for segmentation tasks on medical images, but other researchers later found that the architecture works for general semantic segmentation tasks as well. Furthermore, it is also possible to use the model for other things like super resolution (i.e., upscaling a low-resolution image into a higher-resolution one) and diffusion (i.e., generating images from noise). In this article, I would like to show you how to implement U-Net from scratch using PyTorch. You can see the entire U-Net architecture in Figure 1. By looking at this structure, I think it is pretty clear how this network got its name.

Figure 1. The U-Net architecture [1].

There are several key components in the architecture. First, there is the Contracting Path, which is also known as the Encoder. This component is responsible for gradually shrinking the spatial dimensions of the input image from 572×572 down to 64×64. Notice, however, that the number of channels doubles at each downsampling stage to compensate for the information loss caused by the spatial reduction. In contrast, the Expansive Path (Decoder) expands the feature map back to larger spatial dimensions while reducing the number of channels. Despite the symmetrical-looking architecture, it is important to note that the output produced by the final upsampling stage has a different resolution from the input image (388×388 versus 572×572).

There are two types of connections linking the Encoder and the Decoder, namely the Bottleneck and the Residual-Like paths. The Bottleneck corresponds to everything between the last pooling layer and the first transpose convolution layer, i.e., the lowermost part of the network shown in Figure 1. Meanwhile, the Residual-Like paths are the gray arrows that help the network preserve high-resolution features from the Encoder (relying solely on the Bottleneck would result in a significant loss of spatial information). I call them Residual-Like because they differ from the skip connections proposed in ResNet: there, the merge is performed by element-wise summation, whereas in U-Net the two tensors are concatenated instead.
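
To make the difference between the two merging strategies concrete, below is a tiny illustration of my own (not part of the U-Net implementation) that merges two dummy feature maps by summation and by concatenation. Notice that summation keeps the channel count unchanged, whereas concatenation doubles it.

# Extra illustration (not one of the numbered codeblocks): summation vs concatenation
import torch

encoder_features = torch.randn(1, 64, 56, 56)   # dummy feature map from the Encoder
decoder_features = torch.randn(1, 64, 56, 56)   # dummy feature map in the Decoder

summed = encoder_features + decoder_features                            # ResNet-style merge
concatenated = torch.cat([encoder_features, decoder_features], dim=1)   # U-Net-style merge

print(summed.size())        # torch.Size([1, 64, 56, 56])
print(concatenated.size())  # torch.Size([1, 128, 56, 56])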


Implementing U-Net with PyTorch

There are three imports we need for this project: the base PyTorch module (torch) for standard tensor operations, the nn submodule for the neural network layers, and the summary() function from torchinfo to print out the details of a model.

# Codeblock 1
import torch
import torch.nn as nn
from torchinfo import summary

The Encoder

Now that the modules have been successfully loaded, we can actually start coding. Let's begin with the Encoder. In Figure 2 below, all components belonging to the Encoder are highlighted in green. If you take a closer look at these Encoder stages, you can see that each of them consists of two consecutive convolution layers with 3×3 kernels followed by a 2×2 max-pooling layer. Since the process done in all these stages is basically the same, we can just wrap each stage into a single class and repeat the process four times.

Figure 2. The Encoder part of U-Net consists of four downsampling stages [2].

The stack of two convolution layers is implemented in the DoubleConv() class shown in Codeblocks 2 and 3 below. Here I initialize both layers (conv_0 and conv_1) as well as their corresponding batch normalization layers (bn_0 and bn_1). In fact, the use of batch normalization is not mentioned in the original U-Net paper. However, I will implement it anyway since it usually allows the model to obtain better accuracy. In this article I only focus on demonstrating how to implement the U-Net architecture, so if you want to actually train this model, it would be a good idea to do so both with and without batch normalization to see if my hypothesis is correct. Furthermore, if you decide not to use batch normalization, make sure you change the bias parameter of the convolution layers to True (at the lines marked with #(1) and #(2) in Codeblock 2). This is because when batch normalization is used there is no point in having a bias term in the convolutions, as the normalization layer cancels the biases out.

# Codeblock 2
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.conv_0 = nn.Conv2d(in_channels=in_channels, 
                                out_channels=out_channels, 
                                kernel_size=3, bias=False)    #(1)
        self.bn_0 = nn.BatchNorm2d(num_features=out_channels)

        self.conv_1 = nn.Conv2d(in_channels=out_channels,
                                out_channels=out_channels, 
                                kernel_size=3, bias=False)    #(2)
        self.bn_1 = nn.BatchNorm2d(num_features=out_channels)

        self.relu = nn.ReLU(inplace=True)

Now that all the layers as well as the ReLU activation function have been initialized, we need to string them together in the forward() function. Just a quick reminder: the conventional layer sequence when working with CNNs is Conv-BN-ReLU, and this is exactly the structure implemented below. To make the process clearer, I also print out the tensor dimensions after each convolution operation.

# Codeblock 3
    def forward(self, x):
        print(f'original\t\t: {x.size()}')

        x = self.conv_0(x)
        x = self.bn_0(x)
        x = self.relu(x)
        print(f'after first conv\t: {x.size()}')

        x = self.conv_1(x)
        x = self.bn_1(x)
        x = self.relu(x)
        print(f'after second conv\t: {x.size()}')

        return x

We can run Codeblock 4 to check whether our DoubleConv() works properly. Here I set the network to accept a 1-channel image and output a 64-channel one, as written at the line marked with #(1). The tensor of random numbers (#(2)), which we treat as an image, has the dimensions 1×1×572×572. The axes of this tensor represent the number of images in a single batch, the number of color channels, the image height, and the image width, respectively.

# Codeblock 4
double_conv = DoubleConv(in_channels=1, out_channels=64)    #(1)
x = torch.randn((1, 1, 572, 572))    #(2)
x = double_conv(x)
# Codeblock 4 output
original                : torch.Size([1, 1, 572, 572])
after first conv        : torch.Size([1, 64, 570, 570])
after second conv       : torch.Size([1, 64, 568, 568])

The above output shows that both the height and width of the input image are reduced by two pixels after each convolution layer, which exactly matches the first two convolution operations written in Figure 1 (572 to 570 and 570 to 568). This reduction is due to the 3×3 kernel size and the absence of padding prior to the convolution operation.
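
If you would like to verify this arithmetic yourself, the standard output-size relation for a convolution, output = (input - kernel + 2*padding) / stride + 1, reproduces exactly these numbers. The small helper below is purely my own illustration and not part of the U-Net implementation.

# Illustration only (not part of the U-Net code): convolution output-size formula
def conv_output_size(input_size, kernel_size=3, padding=0, stride=1):
    return (input_size - kernel_size + 2 * padding) // stride + 1

print(conv_output_size(572))  # 570
print(conv_output_size(570))  # 568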


Previously I mentioned that every downsampling stage consists of two convolution layers and a single max-pooling layer. At this point we have successfully implemented the two convolutions inside the DoubleConv() class, but we haven't added the pooling operation yet. What I am going to do now is create a new class named DownSample() which encapsulates both the convolutions and the pooling. The detailed code for this is shown in Codeblock 5 below.

# Codeblock 5
class DownSample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.double_conv = DoubleConv(in_channels=in_channels, 
                                      out_channels=out_channels)    #(1)
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)    #(2)

    def forward(self, x):
        print(f'original\t\t: {x.size()}')

        convolved = self.double_conv(x)
        print(f'after double conv\t: {convolved.size()}')

        maxpooled = self.maxpool(convolved)
        print(f'after pooling\t\t: {maxpooled.size()}')

        return convolved, maxpooled    #(3)

I think everything inside the __init__() function in the above code is pretty straightforward. First, we use the DoubleConv() class, whose numbers of input and output channels are adjustable (#(1)), and second, we use an nn.MaxPool2d() layer with a value of 2 for both the kernel_size and stride parameters (#(2)). With this pooling configuration, the resulting spatial dimensions are going to be two times smaller.

Taking a closer look at the forward() function, especially at the line marked with #(3), you can see that we return both the output of the double convolution (convolved) and the output of the pooling layer (maxpooled). The reason I do this is that there are two branches in every downsampling stage: convolved is the tensor to be transferred directly to the corresponding upsampling stage in the Decoder, while maxpooled is the one to be passed on to the subsequent layer. In the figure below, convolved is highlighted in pink while maxpooled is highlighted in cyan.

Figure 3. The feature maps to be transferred directly to the decoder through Residual-Like connections (pink) and the ones to be fed into the subsequent layers (cyan) [2].

Next, we are going to check whether the DownSample() class we just created works properly by running Codeblock 6 below. Here I assume that we are initializing the very first downsampling stage, which includes the two convolutions as well as the pooling layer. Figure 1 shows that the output of this stage should have a height and width of 284 with 64 channels, and that is exactly what we get.

# Codeblock 6
down_sample = DownSample(in_channels=1, out_channels=64)
x = torch.randn((1, 1, 572, 572))
x = down_sample(x)
# Codeblock 6 output
original                : torch.Size([1, 1, 572, 572])
after double conv       : torch.Size([1, 64, 568, 568])
after pooling           : torch.Size([1, 64, 284, 284])

The Decoder

Now let's jump into the counterpart of the Encoder: the Decoder. This part of U-Net essentially reverses the downsampling process done by the Encoder, which is why the processes inside the Decoder are also known as upsampling. The four upsampling stages in the architecture are highlighted in orange in Figure 4 below. Every upsampling stage consists of a transpose convolution layer followed by two consecutive standard convolution layers. In the case of U-Net, the transpose convolution is responsible for doubling the spatial dimensions of the feature map while halving the number of channels. Meanwhile, the two standard convolutions in the subsequent step refine the features and bring the number of channels back down after the concatenation with the Encoder features.

Figure 4. The four upsampling stages in the Decoder [2].

However, it is not quite as trivial as it sounds, because we also need to think about how the Residual-Like paths are connected to each upsampling stage. The main idea is simple: just concatenate. What makes it somewhat tricky is that the feature map produced in each downsampling stage is larger than the one produced in the corresponding upsampling stage. For instance, if you look at the first (uppermost) Residual-Like path in Figure 1, you will see that the feature map from the Encoder has a spatial dimension of 568×568, whereas the one in the Decoder is 392×392. Therefore, in order to make the concatenation possible, the feature maps from the Encoder need to be cropped so that the sizes of the two tensors match. In Codeblock 7 below, I create a function named crop_image() to do so.

# Codeblock 7
def crop_image(original, expected):    #(1)

    original_dim = original.size()[-1]    #(2)
    expected_dim = expected.size()[-1]    #(3)

    difference = original_dim - expected_dim    #(4)
    padding = difference // 2    #(5)

    cropped = original[:, :, padding:original_dim-padding, padding:original_dim-padding]    #(6)

    return cropped

This function accepts two tensors: original and expected (#(1)). The former refers to the tensor coming from the Encoder, while the latter corresponds to the tensor that is already in the Decoder. So basically, I want the original tensor to be cropped such that it has the same size as the expected tensor. At the lines marked with #(2) and #(3), I use simple indexing to get the width of both images. In this case, we don't need to take the height since it is the same as the width. Next, we calculate the width difference between the original and the expected images to determine how much of the original image should be cropped (#(4)). At line #(5), we divide the difference by two because we want to use the resulting number as the padding, ensuring the cropped region is symmetrical on all sides. Finally, we actually crop the original image using the code at line #(6).
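
As a quick sanity check of my own (this is not one of the article's numbered codeblocks), we can crop a dummy 64×64 Encoder feature map down to the 56×56 size expected in the Decoder, which is exactly the pair of sizes we will encounter in the deepest Residual-Like path later on.

# Quick sanity check for crop_image() (illustration only)
original = torch.randn(1, 512, 64, 64)   # dummy feature map from the Encoder
expected = torch.randn(1, 512, 56, 56)   # dummy tensor already in the Decoder

cropped = crop_image(original, expected)
print(cropped.size())  # torch.Size([1, 512, 56, 56])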


Once the crop_image() function is finished, we can move on to the UpSample() class, which encapsulates the entire upsampling stage. You can see in Codeblock 8 below that I initialize another DoubleConv() block after an nn.ConvTranspose2d() layer. Since we want the resulting feature map to be spatially twice as large as the input, we need to set both kernel_size and stride to 2, as written at line #(1). Figure 5 illustrates how a transpose convolution layer works in case you're not yet familiar with it.

# Codeblock 8
class UpSample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.conv_transpose = nn.ConvTranspose2d(in_channels=in_channels,
                                                 out_channels=out_channels, 
                                                 kernel_size=2, stride=2)    #(1)
        self.double_conv = DoubleConv(in_channels=in_channels,
                                      out_channels=out_channels)
Figure 5. An example of a transpose convolution operation on a 2×2 input image with a 2×2 kernel and stride 2, resulting in a 4×4 image [3].
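
If you want to see this doubling behavior in isolation, the short check below (my own addition, using a single channel for simplicity) applies a 2×2 transpose convolution with stride 2 to a 2×2 input, mirroring the example in Figure 5.

# Standalone check of the transpose convolution behavior (illustration only)
conv_transpose_demo = nn.ConvTranspose2d(in_channels=1, out_channels=1,
                                         kernel_size=2, stride=2)
x_demo = torch.randn(1, 1, 2, 2)
print(conv_transpose_demo(x_demo).size())  # torch.Size([1, 1, 4, 4])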

Moving on to the forward() function of the UpSample() class, we can see in Codeblock 9 that it accepts two inputs, where x is the tensor from the main flow and connection is the one coming directly from the Encoder (#(1)). Initially, we apply the transpose convolution to tensor x (#(2)). Then, we put both connection and x into the crop_image() function so that the spatial dimensions of connection become the same as those of x (#(3)). Once the cropping is done, the two tensors are concatenated along the channel dimension, which we achieve using torch.cat() with dim=1 (#(4)). Finally, we pass the tensor through the DoubleConv() block before returning it (#(5)).

# Codeblock 9
    def forward(self, x, connection):    #(1)
        print(f'x original\t\t\t: {x.size()}')
        print(f'connection original\t\t: {connection.size()}')

        x = self.conv_transpose(x)    #(2)
        print(f'x after conv transpose\t\t: {x.size()}')

        cropped_connection = crop_image(connection, x)    #(3)
        print(f'connection after cropped\t: {cropped_connection.size()}')

        x = torch.cat([x, cropped_connection], dim=1)    #(4)
        print(f'after concatenation\t\t: {x.size()}')

        x = self.double_conv(x)    #(5)
        print(f'after double conv\t\t: {x.size()}')

        return x

Now let's check whether our UpSample() class works properly by running the following codeblock. For this example, I will simulate the first upsampling stage, where the tensor coming from the Bottleneck is denoted as x (#(2)), while the one coming from the last downsampling stage is denoted as connection (#(3)). Here, I set the UpSample() stage to accept a feature map with 1024 channels and return one with 512 channels (#(1)). If our implementation is correct, the resulting feature map should have a height and width of 52.

# Codeblock 10
up_sample = UpSample(1024, 512)    #(1)

x = torch.randn((1, 1024, 28, 28))    #(2)
connection = torch.randn((1, 512, 64, 64))    #(3)

x = up_sample(x, connection)
# Codeblock 10 output
x original                      : torch.Size([1, 1024, 28, 28])    #(1)
connection original             : torch.Size([1, 512, 64, 64])     #(2)
x after conv transpose          : torch.Size([1, 512, 56, 56])     #(3)
connection after cropped        : torch.Size([1, 512, 56, 56])     #(4)
after concatenation             : torch.Size([1, 1024, 56, 56])    #(5)
after double conv               : torch.Size([1, 512, 52, 52])     #(6)

In the output above, we can observe that after the tensor x is processed by the transpose convolution, its dimensions change from 1×1024×28×28 (#(1)) to 1×512×56×56 (#(3)). The tensor connection initially has dimensions of 1×512×64×64 (#(2)) and is successfully cropped to 1×512×56×56 (#(4)). At this point, since both x and connection have 512 channels, the total number of channels becomes 1024 after they are concatenated (#(5)). Finally, the tensor is transformed to 1×512×52×52 (#(6)), which matches our expectations.


The Complete U-Net Architecture

So far, we have created the DoubleConv(), DownSample(), and UpSample() classes. What we are going to do next is use them all to construct the entire U-Net architecture. Have a look at Codeblock 11 below to see how I do that.

# Codeblock 11
class UNet(nn.Module):
    def __init__(self, in_channels=1, num_classes=2):    #(1)
        super().__init__()

        # Encoder    #(2)
        self.downsample_0 = DownSample(in_channels=in_channels, out_channels=64)
        self.downsample_1 = DownSample(in_channels=64, out_channels=128)
        self.downsample_2 = DownSample(in_channels=128, out_channels=256)
        self.downsample_3 = DownSample(in_channels=256, out_channels=512)

        # Bottleneck    #(3)
        self.bottleneck   = DoubleConv(in_channels=512, out_channels=1024)

        # Decoder    #(4)
        self.upsample_0   = UpSample(in_channels=1024, out_channels=512)
        self.upsample_1   = UpSample(in_channels=512, out_channels=256)
        self.upsample_2   = UpSample(in_channels=256, out_channels=128)
        self.upsample_3   = UpSample(in_channels=128, out_channels=64)

        # Output    #(5)
        self.output = nn.Conv2d(in_channels=64, out_channels=num_classes, kernel_size=1)

At the line marked with #(1), I set the default number of input channels to 1 and the number of classes to 2, matching the original U-Net architecture as described in the paper (refer to Figure 1 to verify this in the first and last layers). In other words, by default this model accepts a grayscale image and outputs a two-channel segmentation map for binary segmentation. Nevertheless, it is definitely possible to change these numbers if you want to use the model for a more complex segmentation task.
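
For instance (purely as an illustration, using arbitrary numbers that do not come from the paper), a model for 3-channel RGB images with five segment classes could be instantiated like this:

# Illustration only: U-Net for RGB input and five segmentation classes
unet_multiclass = UNet(in_channels=3, num_classes=5)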

Next, we create the encoder by stacking four downsampling stages (#(2)). Here you need to ensure that the number of input channels of a DownSample() stage is the same as the number of output channels of the previous stage. This principle also applies to the upsampling stages in the Decoder (#(4)). However, while the number of channels typically increases as the network deepens in the Encoder, it decreases in the Decoder as we move towards the output.

For the Bottleneck, we employ a DoubleConv() block, as it is essentially just a stack of two convolution layers (#(3)). Finally, we use a standard nn.Conv2d() for the output layer, with a 1×1 kernel and the number of output channels set to the number of classes (#(5)). By the way, if you're new to image segmentation models, the number of classes essentially corresponds to the number of channels in the output layer, with each channel representing a specific segment. For instance, if you have segments for the sky, the ground, and an object, you would need to set the number of output channels to 3: one channel for the sky, one for the ground, and one for the object. This can be thought of as classifying each pixel in the output image as belonging to one of these classes.
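
To make the pixel-wise classification idea more concrete, here is a small sketch of my own (using made-up numbers, not taken from the paper) showing how a segmentation map could be obtained from the raw model output by taking an argmax over the channel dimension.

# Illustration only: converting raw logits into a per-pixel class map
logits = torch.randn(1, 3, 388, 388)     # dummy output for 3 classes: sky, ground, object
segmentation_map = logits.argmax(dim=1)  # pick the most likely class for every pixel
print(segmentation_map.size())           # torch.Size([1, 388, 388])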

Now that all the U-Net components have been initialized, we can define the flow of the network in the following forward() function.

# Codeblock 12
    def forward(self, x):
        print(f'original\t\t: {x.size()}')

        convolved_0, maxpooled_0 = self.downsample_0(x)    #(1)
        print(f'maxpooled_0\t\t: {maxpooled_0.size()}')

        convolved_1, maxpooled_1 = self.downsample_1(maxpooled_0)    #(2)
        print(f'maxpooled_1\t\t: {maxpooled_1.size()}')

        convolved_2, maxpooled_2 = self.downsample_2(maxpooled_1)    #(3)
        print(f'maxpooled_2\t\t: {maxpooled_2.size()}')

        convolved_3, maxpooled_3 = self.downsample_3(maxpooled_2)    #(4)
        print(f'maxpooled_3\t\t: {maxpooled_3.size()}')

        x = self.bottleneck(maxpooled_3)
        print(f'after bottleneck\t: {x.size()}')

        upsampled_0 = self.upsample_0(x, convolved_3)    #(5)
        print(f'upsampled_0\t\t: {upsampled_0.size()}')

        upsampled_1 = self.upsample_1(upsampled_0, convolved_2)    #(6)
        print(f'upsampled_1\t\t: {upsampled_1.size()}')

        upsampled_2 = self.upsample_2(upsampled_1, convolved_1)
        print(f'upsampled_2\t\t: {upsampled_2.size()}')

        upsampled_3 = self.upsample_3(upsampled_2, convolved_0)
        print(f'upsampled_3\t\t: {upsampled_3.size()}')

        x = self.output(upsampled_3)
        print(f'final output\t\t: {x.size()}')

        return x

There are several things I want to emphasize in this code. First, remember that the downsampling stages return two outputs, convolved and maxpooled (lines #(1) to #(4)). Later, in the upsampling stages, we pair x with convolved_3 (line #(5)), upsampled_0 with convolved_2 (line #(6)), and so on for the remaining stages. This pairing is essential because the output of each upsampling stage is produced by combining the feature map from the previous stage with the corresponding feature map from the Encoder.

We can test the UNet() class we just created using the following codeblock. You can see in the output that every tensor dimension in the flow matches the dimensions written in the original paper (refer to Figure 1). This essentially means that we have correctly implemented the U-Net architecture.

# Codeblock 13
unet = UNet()
x = torch.randn((1, 1, 572, 572))
x = unet(x)
# Codeblock 13 output
original                : torch.Size([1, 1, 572, 572])
convolved_0             : torch.Size([1, 64, 568, 568])
maxpooled_0             : torch.Size([1, 64, 284, 284])
convolved_1             : torch.Size([1, 128, 280, 280])
maxpooled_1             : torch.Size([1, 128, 140, 140])
convolved_2             : torch.Size([1, 256, 136, 136])
maxpooled_2             : torch.Size([1, 256, 68, 68])
convolved_3             : torch.Size([1, 512, 64, 64])
maxpooled_3             : torch.Size([1, 512, 32, 32])
after bottleneck        : torch.Size([1, 1024, 28, 28])
upsampled_0             : torch.Size([1, 512, 52, 52])
upsampled_1             : torch.Size([1, 256, 100, 100])
upsampled_2             : torch.Size([1, 128, 196, 196])
upsampled_3             : torch.Size([1, 64, 388, 388])
final output            : torch.Size([1, 2, 388, 388])

In order to display the detailed architecture, including the parameter count, model size, etc., we can use the summary() function we imported earlier.

# Codeblock 14
summary(unet, input_size=(1,1,572,572))
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
UNet                                     [1, 2, 388, 388]          --
├─DownSample: 1-1                        [1, 64, 568, 568]         --
│    └─DoubleConv: 2-1                   [1, 64, 568, 568]         --
│    │    └─Conv2d: 3-1                  [1, 64, 570, 570]         576
│    │    └─BatchNorm2d: 3-2             [1, 64, 570, 570]         128
│    │    └─ReLU: 3-3                    [1, 64, 570, 570]         --
│    │    └─Conv2d: 3-4                  [1, 64, 568, 568]         36,864
│    │    └─BatchNorm2d: 3-5             [1, 64, 568, 568]         128
│    │    └─ReLU: 3-6                    [1, 64, 568, 568]         --
│    └─MaxPool2d: 2-2                    [1, 64, 284, 284]         --
├─DownSample: 1-2                        [1, 128, 280, 280]        --
│    └─DoubleConv: 2-3                   [1, 128, 280, 280]        --
│    │    └─Conv2d: 3-7                  [1, 128, 282, 282]        73,728
│    │    └─BatchNorm2d: 3-8             [1, 128, 282, 282]        256
│    │    └─ReLU: 3-9                    [1, 128, 282, 282]        --
│    │    └─Conv2d: 3-10                 [1, 128, 280, 280]        147,456
│    │    └─BatchNorm2d: 3-11            [1, 128, 280, 280]        256
│    │    └─ReLU: 3-12                   [1, 128, 280, 280]        --
│    └─MaxPool2d: 2-4                    [1, 128, 140, 140]        --
├─DownSample: 1-3                        [1, 256, 136, 136]        --
│    └─DoubleConv: 2-5                   [1, 256, 136, 136]        --
│    │    └─Conv2d: 3-13                 [1, 256, 138, 138]        294,912
│    │    └─BatchNorm2d: 3-14            [1, 256, 138, 138]        512
│    │    └─ReLU: 3-15                   [1, 256, 138, 138]        --
│    │    └─Conv2d: 3-16                 [1, 256, 136, 136]        589,824
│    │    └─BatchNorm2d: 3-17            [1, 256, 136, 136]        512
│    │    └─ReLU: 3-18                   [1, 256, 136, 136]        --
│    └─MaxPool2d: 2-6                    [1, 256, 68, 68]          --
├─DownSample: 1-4                        [1, 512, 64, 64]          --
│    └─DoubleConv: 2-7                   [1, 512, 64, 64]          --
│    │    └─Conv2d: 3-19                 [1, 512, 66, 66]          1,179,648
│    │    └─BatchNorm2d: 3-20            [1, 512, 66, 66]          1,024
│    │    └─ReLU: 3-21                   [1, 512, 66, 66]          --
│    │    └─Conv2d: 3-22                 [1, 512, 64, 64]          2,359,296
│    │    └─BatchNorm2d: 3-23            [1, 512, 64, 64]          1,024
│    │    └─ReLU: 3-24                   [1, 512, 64, 64]          --
│    └─MaxPool2d: 2-8                    [1, 512, 32, 32]          --
├─DoubleConv: 1-5                        [1, 1024, 28, 28]         --
│    └─Conv2d: 2-9                       [1, 1024, 30, 30]         4,718,592
│    └─BatchNorm2d: 2-10                 [1, 1024, 30, 30]         2,048
│    └─ReLU: 2-11                        [1, 1024, 30, 30]         --
│    └─Conv2d: 2-12                      [1, 1024, 28, 28]         9,437,184
│    └─BatchNorm2d: 2-13                 [1, 1024, 28, 28]         2,048
│    └─ReLU: 2-14                        [1, 1024, 28, 28]         --
├─UpSample: 1-6                          [1, 512, 52, 52]          --
│    └─ConvTranspose2d: 2-15             [1, 512, 56, 56]          2,097,664
│    └─DoubleConv: 2-16                  [1, 512, 52, 52]          --
│    │    └─Conv2d: 3-25                 [1, 512, 54, 54]          4,718,592
│    │    └─BatchNorm2d: 3-26            [1, 512, 54, 54]          1,024
│    │    └─ReLU: 3-27                   [1, 512, 54, 54]          --
│    │    └─Conv2d: 3-28                 [1, 512, 52, 52]          2,359,296
│    │    └─BatchNorm2d: 3-29            [1, 512, 52, 52]          1,024
│    │    └─ReLU: 3-30                   [1, 512, 52, 52]          --
├─UpSample: 1-7                          [1, 256, 100, 100]        --
│    └─ConvTranspose2d: 2-17             [1, 256, 104, 104]        524,544
│    └─DoubleConv: 2-18                  [1, 256, 100, 100]        --
│    │    └─Conv2d: 3-31                 [1, 256, 102, 102]        1,179,648
│    │    └─BatchNorm2d: 3-32            [1, 256, 102, 102]        512
│    │    └─ReLU: 3-33                   [1, 256, 102, 102]        --
│    │    └─Conv2d: 3-34                 [1, 256, 100, 100]        589,824
│    │    └─BatchNorm2d: 3-35            [1, 256, 100, 100]        512
│    │    └─ReLU: 3-36                   [1, 256, 100, 100]        --
├─UpSample: 1-8                          [1, 128, 196, 196]        --
│    └─ConvTranspose2d: 2-19             [1, 128, 200, 200]        131,200
│    └─DoubleConv: 2-20                  [1, 128, 196, 196]        --
│    │    └─Conv2d: 3-37                 [1, 128, 198, 198]        294,912
│    │    └─BatchNorm2d: 3-38            [1, 128, 198, 198]        256
│    │    └─ReLU: 3-39                   [1, 128, 198, 198]        --
│    │    └─Conv2d: 3-40                 [1, 128, 196, 196]        147,456
│    │    └─BatchNorm2d: 3-41            [1, 128, 196, 196]        256
│    │    └─ReLU: 3-42                   [1, 128, 196, 196]        --
├─UpSample: 1-9                          [1, 64, 388, 388]         --
│    └─ConvTranspose2d: 2-21             [1, 64, 392, 392]         32,832
│    └─DoubleConv: 2-22                  [1, 64, 388, 388]         --
│    │    └─Conv2d: 3-43                 [1, 64, 390, 390]         73,728
│    │    └─BatchNorm2d: 3-44            [1, 64, 390, 390]         128
│    │    └─ReLU: 3-45                   [1, 64, 390, 390]         --
│    │    └─Conv2d: 3-46                 [1, 64, 388, 388]         36,864
│    │    └─BatchNorm2d: 3-47            [1, 64, 388, 388]         128
│    │    └─ReLU: 3-48                   [1, 64, 388, 388]         --
├─Conv2d: 1-10                           [1, 2, 388, 388]          130
==========================================================================================
Total params: 31,036,546
Trainable params: 31,036,546
Non-trainable params: 0
Total mult-adds (Units.GIGABYTES): 167.34
==========================================================================================
Input size (MB): 1.31
Forward/backward pass size (MB): 1992.61
Params size (MB): 124.15
Estimated Total Size (MB): 2118.07
==========================================================================================

This concludes our exploration of the U-Net architecture. Feel free to leave a comment if you notice any mistakes, especially regarding the implementation. I would be very happy to hear your feedback.

For your reference, all the code used in this article can be accessed on my GitHub repo, which you can find here.

Thank you for reading!


References

[1] Olaf Ronneberger et al. U-Net: Convolutional Networks for Biomedical Image Segmentation. arXiv. https://arxiv.org/pdf/1505.04597 [Accessed August 21, 2024].

[2] Image created originally by author based on [1].

[3] Image created originally by author.

