Authors: Jenny Huang, Ian Hunt-Isaak, William Palmer
Training an image segmentation model on new images can be daunting, especially when you need to label your own data. To make this task easier and faster, we built a user-friendly tool that lets you carry out this entire process in a single Jupyter notebook. In the sections below, we will show you how our tool lets you:
- Manually label your own images
- Build an effective segmentation model through Transfer Learning
- Visualize the model and its results
- Share your project as a Docker image
The main benefits of this tool are that it is easy to use, all in one platform, and well integrated with existing data science workflows. Through interactive widgets and command prompts, we built a user-friendly way to label images and train the model. On top of that, everything can run in a single Jupyter notebook, making it quick and easy to spin up a model without much overhead. Lastly, because it runs in a Python environment and uses standard libraries like TensorFlow and Matplotlib, this tool fits well into existing data science workflows, making it ideal for uses like scientific research.
For instance, in microbiology, it can be very useful to segment microscopy images of cells. However, tracking cells over time can easily require segmenting hundreds of images, which is very tedious to do by hand. In this article, we will use microscopy images of yeast cells as our dataset and show how we built our tool to differentiate between the background, mother cells, and daughter cells.
1. Labelling
There are many existing tools for creating labelled masks for images, including Labelme, ImageJ, and even the graphics editor GIMP. While these are all great tools, they can't be integrated into a Jupyter notebook, which makes them harder to use with many existing workflows. Fortunately, Jupyter Widgets make it easy for us to build interactive components and connect them with the rest of our Python code.
To create training masks in the notebook, we have two problems to solve:
- Select parts of an image with a mouse
- Easily switch between images and select the class to label
To solve the first problem, we used the Matplotlib widget backend and the built-in LassoSelector. The LassoSelector handles drawing a line to show what you are selecting, but we need a little bit of custom code to draw the masks as an overlay:
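The original overlay code isn't reproduced here. The following is a minimal sketch of the idea under our own assumptions. `LassoSelector` passes the lasso's vertices in data coordinates to an `onselect` callback. `matplotlib.path.Path.contains_points` then tells us which pixel centres fall inside that polygon. We write the current class id into an integer mask and redraw it as a semi-transparent overlay. The class name `LassoMasker` and its attributes are hypothetical and are not taken from the authors' code.

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.path import Path
from matplotlib.widgets import LassoSelector


class LassoMasker:
    """Fill pixels enclosed by each lasso stroke with the current class id
    and redraw a semi-transparent mask overlay on top of the image."""

    def __init__(self, ax, image, class_id=1):
        self.ax = ax
        self.class_id = class_id
        h, w = image.shape[:2]
        self.mask = np.zeros((h, w), dtype=np.uint8)  # 0 = unlabelled/background
        # Pixel centres as (x, y) = (column, row): the same data coordinates
        # imshow uses and LassoSelector reports for its vertices.
        rows, cols = np.mgrid[:h, :w]
        self.pixel_centres = np.column_stack([cols.ravel(), rows.ravel()]).astype(float)
        ax.imshow(image, cmap="gray")
        self.overlay = ax.imshow(
            np.ma.masked_equal(self.mask, 0),  # masked (unlabelled) pixels stay transparent
            cmap="tab10", vmin=0, vmax=9, alpha=0.5, interpolation="nearest",
        )
        self.lasso = LassoSelector(ax, onselect=self.on_select)

    def on_select(self, verts):
        # Which pixel centres fall inside the (implicitly closed) lasso polygon?
        inside = Path(verts).contains_points(self.pixel_centres)
        self.mask[inside.reshape(self.mask.shape)] = self.class_id
        self.overlay.set_data(np.ma.masked_equal(self.mask, 0))
        self.ax.figure.canvas.draw_idle()
```

In a notebook with `%matplotlib widget`, you would create `fig, ax = plt.subplots()` and then `masker = LassoMasker(ax, image)`. Keep a reference to `masker`, because otherwise the selector can be garbage-collected and stop responding.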
For the second problem, we added nice looking buttons and other controls using ipywidgets:
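The controls code from the article isn't shown, so here is a minimal sketch of that kind of control panel. It uses the real ipywidgets `Button`, `Dropdown`, `Label`, and `HBox` classes. The class `LabelControls`, its methods, and the class names in the dropdown are hypothetical illustrations. In the full tool, `step` would also load the next image into the labelling axes.

```python
import ipywidgets as widgets


class LabelControls:
    """Previous/next buttons to move through images, plus a dropdown for the class to label."""

    def __init__(self, image_paths, classes):
        self.image_paths = list(image_paths)
        self.index = 0
        self.prev_button = widgets.Button(description="Previous")
        self.next_button = widgets.Button(description="Next")
        self.class_picker = widgets.Dropdown(options=classes, value=classes[0], description="Class:")
        self.status = widgets.Label(value=self._status_text())
        # on_click handlers receive the clicked button, which we ignore
        self.prev_button.on_click(lambda _: self.step(-1))
        self.next_button.on_click(lambda _: self.step(+1))
        self.ui = widgets.HBox([self.prev_button, self.next_button, self.class_picker, self.status])

    def _status_text(self):
        return f"{self.index + 1}/{len(self.image_paths)}: {self.image_paths[self.index]}"

    def step(self, delta):
        # Clamp so we never move past the first or last image
        self.index = min(max(self.index + delta, 0), len(self.image_paths) - 1)
        self.status.value = self._status_text()

    @property
    def current_class(self):
        return self.class_picker.value
```

In a notebook, `display(LabelControls(paths, ["mother", "daughter"]).ui)` renders the panel.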

We combined these elements (along with improvements like scroll to zoom) to make a single labelling controller object. Now we can take microscopy images of yeast and segment the mother cells and daughter cells:
You can check out the full object, which lets you scroll to zoom, right click to pan, and select multiple classes here.
Now we can label a small number of images in the notebook, save them into the correct folder structure, and start training a CNN!
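The article doesn't spell out the folder structure it uses. The sketch below therefore assumes a hypothetical layout: `images/<name>.npy` and `masks/<name>.npy` under one root, paired by file name. Masks are saved as raw `.npy` arrays so that the integer class labels survive unchanged. Saving them as colour-mapped PNGs would not preserve the labels. Adapt the layout to whatever your training code expects.

```python
from pathlib import Path
import numpy as np

# Hypothetical layout: <root>/images/<name>.npy and <root>/masks/<name>.npy


def save_labelled_pair(root, name, image, mask):
    """Write an image and its label mask under matching file names."""
    root = Path(root)
    for sub, array in (("images", image), ("masks", mask)):
        (root / sub).mkdir(parents=True, exist_ok=True)
        np.save(root / sub / f"{name}.npy", array)


def load_labelled_pairs(root):
    """Load all (image, mask) pairs, matched by file name, in sorted order."""
    root = Path(root)
    images, masks = [], []
    for image_path in sorted((root / "images").glob("*.npy")):
        images.append(np.load(image_path))
        masks.append(np.load(root / "masks" / image_path.name))
    return images, masks
```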
2. Model Training
The Model
U-net is a convolutional neural network that was initially designed to segment biomedical images but has been successful on many other types of images. It builds upon existing convolutional networks to work better with very few training images and to produce more precise segmentations. It is a state-of-the-art model that is also easy to implement using the [segmentation_models](https://github.com/qubvel/segmentation_models) library.

What makes U-net distinctive is that it connects its encoder and decoder with skip connections (the gray arrows in the figure above). These connections pass the feature maps from each level of the downsampling path to the same-sized level of the upsampling path. As a result, the decoder has access to fine-grained spatial detail from the input image while it upsamples, which has been shown to improve performance on segmentation tasks.
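A shape-level sketch in numpy can make the skip connection concrete. This is not a full U-net: it leaves out the convolutions and uses a single level. It keeps only the mechanics. The encoder output is downsampled and then upsampled back to the same size, and it is concatenated with the upsampled decoder features along the channel axis. The helper names are our own.

```python
import numpy as np


def max_pool2x2(x):
    """2x2 max pooling on an (H, W, C) feature map; H and W must be even."""
    h, w, c = x.shape
    return x.reshape(h // 2, 2, w // 2, 2, c).max(axis=(1, 3))


def upsample2x(x):
    """Nearest-neighbour upsampling that doubles H and W."""
    return x.repeat(2, axis=0).repeat(2, axis=1)


encoder_features = np.random.rand(64, 64, 16)   # output of one encoder block
bottleneck = max_pool2x2(encoder_features)       # (32, 32, 16): downsampling path
decoder_features = upsample2x(bottleneck)        # (64, 64, 16): upsampling path
# Skip connection: join same-sized encoder and decoder maps along the channel axis
merged = np.concatenate([encoder_features, decoder_features], axis=-1)
print(merged.shape)  # (64, 64, 32)
```

In the real network, convolutions follow the concatenation so the decoder can combine the coarse, upsampled context with the encoder's full-resolution detail.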
As great as U-net is, it won’t work well if we don’t give it enough training examples. And given how tedious it is to manually segment images, we only manually labelled 13 images. With so few training examples, it seems impossible to train a neural network with millions of parameters. To overcome this, we need both Data Augmentation and Transfer Learning.
Data Augmentation
Naturally, a model with many parameters needs a correspondingly large number of training examples to perform well. From our small dataset of images and masks, we can create new images that are still useful training examples for our model.
How do we do that? We can flip the image, rotate it, scale it up or down, crop it, translate it, blur it, or add noise. Most importantly, we can combine these operations to create many new training examples.
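As a minimal numpy illustration of combining augmentations, each call below applies a random mix of flips, a right-angle rotation, a small translation, and Gaussian noise. This is only a stand-in: the authors used albumentations, which also covers arbitrary-angle rotation, scaling, cropping, and blurring. The translation here uses `np.roll`, which wraps pixels around the edges to keep the sketch short.

```python
import numpy as np


def random_augment(image, rng):
    """Apply a random combination of simple augmentations to a 2-D grayscale image."""
    out = image
    if rng.random() < 0.5:
        out = np.fliplr(out)                    # horizontal flip
    if rng.random() < 0.5:
        out = np.flipud(out)                    # vertical flip
    out = np.rot90(out, k=rng.integers(4))      # rotate by 0, 90, 180 or 270 degrees
    # Translate by up to 4 pixels; np.roll wraps pixels around the edges (a simplification)
    out = np.roll(out, shift=tuple(rng.integers(-4, 5, size=2)), axis=(0, 1))
    return out + rng.normal(0.0, 0.05, size=out.shape)  # additive Gaussian noise


rng = np.random.default_rng(0)
image = rng.random((64, 64))
augmented = [random_augment(image, rng) for _ in range(10)]
```

Even these few operations multiply quickly. The flips and rotations alone produce 8 distinct orientations, and each one can be combined with many translations and noise draws.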

Data augmentation has one more complication in segmentation than in classification. For classification, you only need to augment the image, because the label stays the same (0 or 1 or 2…). For segmentation, however, the label is a mask, so it must be transformed in sync with the image. To do this, we used the [albumentations](https://albumentations.readthedocs.io/en/latest/) library with a custom data generator since, to our knowledge, the Keras ImageDataGenerator does not currently support the combination "Image + mask".
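With albumentations, a composed transform is called as `transform(image=image, mask=mask)` and applies the same geometric changes to both. The numpy sketch below shows the same idea without that dependency, and the function names are our own. It draws one random rotation and flip and applies it to the image and the mask alike. Noise is added to the image only, because the mask must keep its integer class labels. The generator then yields batches in the form Keras' `Model.fit` accepts.

```python
import numpy as np


def paired_augment(image, mask, rng):
    """Apply the SAME random geometric transform to an image and its mask."""
    k = rng.integers(4)
    image, mask = np.rot90(image, k), np.rot90(mask, k)
    if rng.random() < 0.5:
        image, mask = np.fliplr(image), np.fliplr(mask)
    # Photometric changes go on the image only; mask labels must stay exact integers
    image = image + rng.normal(0.0, 0.05, size=image.shape)
    return image, mask


def augmented_batches(images, masks, batch_size, seed=0):
    """Endless generator of (image_batch, mask_batch) built from random augmented pairs."""
    rng = np.random.default_rng(seed)
    while True:
        idx = rng.integers(0, len(images), size=batch_size)
        pairs = [paired_augment(images[i], masks[i], rng) for i in idx]
        x = np.stack([p[0] for p in pairs])[..., np.newaxis]  # add a channel axis
        y = np.stack([p[1] for p in pairs])                   # integer class masks
        yield x, y
```

Depending on the loss, the masks may still need one-hot encoding before training. You would then pass the generator to `model.fit` together with `steps_per_epoch`.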
Transfer Learning
Even though we have now created 100 or more images, this still isn’t enough as the U-net model has more than 6 million parameters. This is where transfer learning comes into play.
Transfer learning lets you take a model trained on one task and reuse it for a similar task. It drastically reduces training time and, more importantly, it can produce effective models even from a small training set like ours. For example, neural networks like MobileNet, Inception, and DeepNet learn features such as shapes, colors, and textures by training on a large number of images. We can then transfer what they learned by taking these model weights and adjusting them slightly so that they respond to patterns in our own training images.
Now how do we use transfer learning with U-net? We used the segmentation_models library to do this. It takes the layers of a deep neural network of your choosing (MobileNet, Inception, ResNet), along with the weights learned by training on image classification (ImageNet), and uses them as the first half (encoder) of your U-net. You then train the decoder layers on your own augmented dataset.
Putting it Together
We put this all together in a Segmentation model class that you can find here. When creating your model object, you get an interactive command prompt where you can customize aspects of your U-net like the loss function, backbone, and more:
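The authors' prompt lives in their class and isn't reproduced here. The sketch below shows one way such a prompt could work. `prompt_choice` and `configure_unet` are hypothetical helpers, and the option lists are examples only, although `mobilenetv2`, `resnet34`, and `inceptionv3` are backbone names that segmentation_models accepts. Passing `input_fn` as a parameter makes the prompt testable without a keyboard.

```python
def prompt_choice(question, options, default, input_fn=input):
    """Ask until the user picks one of `options`; an empty answer selects the default."""
    menu = ", ".join(options)
    while True:
        answer = input_fn(f"{question} [{menu}] (default: {default}): ").strip()
        if not answer:
            return default
        if answer in options:
            return answer
        print(f"Please choose one of: {menu}")


def configure_unet(input_fn=input):
    """Collect U-net settings interactively (hypothetical option lists)."""
    return {
        "backbone": prompt_choice("Backbone", ["mobilenetv2", "resnet34", "inceptionv3"],
                                  "resnet34", input_fn),
        "loss": prompt_choice("Loss function", ["dice", "categorical_crossentropy"],
                              "dice", input_fn),
        "epochs": int(input_fn("Epochs (default: 30): ").strip() or 30),
    }
```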
After 30 epochs of training, we achieved 95% accuracy. Note that it is important to choose a good loss function. We first tried cross-entropy loss, but the model was unable to distinguish between the similar-looking mother and daughter cells and performed poorly because of class imbalance: there are many more non-yeast pixels than yeast pixels. We found that dice loss gave us much better results. The dice loss is closely related to the Intersection over Union (IoU) score and is usually better suited to segmentation tasks, because it rewards maximizing the overlap between the predicted and ground-truth masks.
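For a predicted mask $B$ and a ground-truth mask $A$, the Dice coefficient is $D = 2|A \cap B| / (|A| + |B|)$ and IoU is $J = |A \cap B| / |A \cup B|$. For hard binary masks, these satisfy $D = 2J/(1+J)$. Below is a numpy sketch of a soft, class-averaged dice loss, which is one common variant. segmentation_models ships its own `DiceLoss` that you would use in practice. The example also shows why plain accuracy can mislead under class imbalance. A prediction of "all background" scores 96% pixel accuracy here, yet its dice loss is above 0.5, because it completely misses the small yeast class.

```python
import numpy as np


def dice_coefficient(y_true, y_pred, eps=1e-7):
    """Soft Dice per class for one-hot y_true and probability y_pred, both (H, W, C)."""
    intersection = np.sum(y_true * y_pred, axis=(0, 1))
    denom = np.sum(y_true, axis=(0, 1)) + np.sum(y_pred, axis=(0, 1))
    return (2.0 * intersection + eps) / (denom + eps)


def dice_loss(y_true, y_pred):
    # Average over classes so small classes count as much as the background
    return 1.0 - np.mean(dice_coefficient(y_true, y_pred))


def iou(y_true, y_pred, eps=1e-7):
    """Per-class Intersection over Union for (H, W, C) masks."""
    intersection = np.sum(y_true * y_pred, axis=(0, 1))
    union = np.sum(y_true, axis=(0, 1)) + np.sum(y_pred, axis=(0, 1)) - intersection
    return (intersection + eps) / (union + eps)


# 10x10 image, class 0 = background, class 1 = a 2x2 "cell"
y_true = np.zeros((10, 10, 2))
y_true[..., 0] = 1
y_true[4:6, 4:6, 0], y_true[4:6, 4:6, 1] = 0, 1

all_background = np.zeros_like(y_true)
all_background[..., 0] = 1
accuracy = np.mean(all_background.argmax(-1) == y_true.argmax(-1))
```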

3. Visualization
Now that our model is trained, let’s use some Visualization techniques to see how it works. We follow Ankit Paliwal’s tutorial to do so. You can find the implementation in his corresponding GitHub repository. In this section, we will visualize two of his techniques, Intermediate Layer Activations and Heatmaps of Class Activations, on our yeast cell segmentation model.
Intermediate Layer Activations
This first technique shows the output of intermediate layers in a forward pass of the network on a test image. This lets us see what features of the input image are highlighted at each layer. After inputting a test image, we visualized the first few outputs for some convolutional layers in our network:

In the encoder layers, filters close to the input detect fine detail, while filters deeper in the encoder detect more general features, which is to be expected. In the decoder layers, we see the opposite pattern, going from abstract features back to more specific details, which is also to be expected.
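In Keras, intermediate activations are typically read out by building a second model whose outputs are the inner layers' outputs, `tf.keras.Model(inputs=model.input, outputs=[layer.output for layer in chosen_layers])`, and calling `predict` on a test image. To show the same idea without TensorFlow, the numpy sketch below runs a toy two-layer convolutional network with random kernels and keeps each layer's output. It then plots the first few channels per layer in the same grid layout as the figure above. The network and helper names are illustrative assumptions, not the yeast model.

```python
import numpy as np
import matplotlib.pyplot as plt
from numpy.lib.stride_tricks import sliding_window_view


def conv2d_relu(x, kernels):
    """'Valid' 2-D convolution (cross-correlation, as CNN layers compute it) + ReLU.
    x has shape (H, W, C_in); kernels has shape (k, k, C_in, C_out)."""
    k = kernels.shape[0]
    windows = sliding_window_view(x, (k, k), axis=(0, 1))  # (H-k+1, W-k+1, C_in, k, k)
    return np.maximum(np.einsum("hwcij,ijco->hwo", windows, kernels), 0.0)


def forward_with_activations(image, layers):
    """Run a toy stack of conv layers and keep every intermediate output."""
    x = image[..., np.newaxis]
    activations = []
    for kernels in layers:
        x = conv2d_relu(x, kernels)
        activations.append(x)
    return activations


def plot_activations(activations, n_channels=4):
    """One row per layer, showing the first few channels of each feature map."""
    fig, axes = plt.subplots(len(activations), n_channels, squeeze=False,
                             figsize=(2 * n_channels, 2 * len(activations)))
    for row, act in enumerate(activations):
        for col in range(n_channels):
            axes[row, col].imshow(act[..., col], cmap="viridis")
            axes[row, col].set_title(f"layer {row}, ch {col}", fontsize=8)
            axes[row, col].axis("off")
    return fig


rng = np.random.default_rng(0)
toy_layers = [rng.normal(size=(3, 3, 1, 8)), rng.normal(size=(3, 3, 8, 8))]
activations = forward_with_activations(rng.random((32, 32)), toy_layers)
```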
Heatmaps of Class Activations
Next, we look at class activation maps. These heat maps let you see how important each location of the image is for predicting an output class. Here, we visualize the final layer of our yeast cell model, since the class prediction label will largely depend on it.

We see from the heat maps that the cell locations are correctly activated, along with parts of the image border, which is somewhat surprising.
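We don't reproduce the tutorial's exact method here. One widely used way to compute such heatmaps is Grad-CAM, and the sketch below assumes that formulation. Each channel $k$ of a layer's activations $A_k$ is weighted by the spatial mean of the gradient of the class score with respect to that channel, and the heatmap is $\mathrm{ReLU}(\sum_k \alpha_k A_k)$. In TensorFlow, the activations and gradients would come from a `tf.GradientTape`. For a segmentation model, the class score could be, for example, the summed output for that class. Here both are assumed to be given as numpy arrays, and the heatmap is upsampled to image size for overlay.

```python
import numpy as np


def grad_cam(activations, gradients):
    """Grad-CAM heatmap from a layer's activations and the gradients of a class
    score with respect to them, both shaped (h, w, K)."""
    weights = gradients.mean(axis=(0, 1))  # alpha_k: one weight per channel
    cam = np.maximum(np.tensordot(activations, weights, axes=([2], [0])), 0.0)
    if cam.max() > 0:
        cam = cam / cam.max()  # scale to [0, 1] for display
    return cam


def upsample_nearest(heatmap, out_shape):
    """Resize a heatmap to the image size by nearest-neighbour index mapping."""
    rows = np.arange(out_shape[0]) * heatmap.shape[0] // out_shape[0]
    cols = np.arange(out_shape[1]) * heatmap.shape[1] // out_shape[1]
    return heatmap[np.ix_(rows, cols)]
```

To overlay the result, draw the image with `imshow(image, cmap="gray")` and then draw the upsampled heatmap on top with a colormap and `alpha` around 0.4.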
We also looked at the last technique in the tutorial, which shows the input patterns each convolutional filter responds to most strongly, but those visualizations were not very informative for our yeast cell model.
4. Making and Sharing a Docker Image
Finding an awesome model and trying to run it, only to discover that it doesn't work in your environment because of mysterious dependency issues, is very frustrating. We addressed this by creating a Docker image for our tool. This allows us to completely define the environment that the code runs in, all the way down to the operating system. For this project, we based our Docker image on the jupyter/tensorflow-notebook image from Jupyter Docker Stacks. Then we added a few lines to install the libraries we needed and to copy the contents of our GitHub repository into the Docker image. If you're curious, you can see our final Dockerfile here. Finally, we pushed this image to Docker Hub for easy distribution. You can try it out by running:
```shell
sudo docker run -p 8888:8888 ianhuntisaak/ac295-final-project:v3
```
Conclusion and Future Work
This tool lets you easily train a segmentation model on new images in a user-friendly way. While it works, there is still room for improvement in usability, customization, and model performance. In the future, we hope to:
- Improve the lasso tool by building a custom Jupyter Widget using the HTML5 canvas to reduce lag when manually segmenting
- Explore new loss functions and models (like [this U-net pre-trained on a broad nucleus dataset](https://bioimage.io/?model=2D%20UNet%20Nuclei%20Broad)) as a basis for transfer learning
- Make it easier to interpret visualizations and suggest methods of improving the results to the user
We would like to thank our professor Pavlos Protopapas and the Harvard Applied Computation 295 course teaching staff for their guidance and support.