
SAM: Segment Anything Model

Quickly customize your product landing page with SAM

Introduction

Transformers have been widely applied to Natural Language Processing use cases, but they can also be applied to several other domains of Artificial Intelligence, such as time series forecasting or computer vision.

Great examples of Transformer models applied to computer vision are Stable Diffusion for image generation, the Detection Transformer for object detection and, more recently, SAM for image segmentation. The great benefit these models bring is that we can use text prompts to manipulate images without much effort; all it takes is a good prompt.

The use cases for this type of model are endless, especially if you work at an e-commerce company. A simple, time-consuming and expensive example is the process that goes from photographing an item to posting it on the website for sale. Companies need to photograph the items, remove the props used and, finally, inpaint the hole left by the prop before publishing the item on the website. What if this entire process could be automated by AI, leaving humans to handle only the complex cases and review what the AI has done?

In this article, I provide a detailed explanation of SAM, an image segmentation model, and apply it to a hypothetical use case where we want to perform an A/B test to understand which type of background increases the conversion rate.

Figure 1: Segment Anything Model (image generated by the author with DALL-E)

As always, the code is available on GitHub.

Segment Anything Model

Segment Anything Model (SAM) [1] is a segmentation model developed by Meta that aims to create masks of the objects in an image guided by a prompt that can be text, a mask, a bounding box or just a point in an image.

The inspiration comes from the latest developments in Natural Language Processing and, particularly, from Large Language Models, where given an ambiguous prompt, the user expects a coherent response. In the same line of thought, the authors wanted to create a model that would return a valid segmentation mask even when the prompt is ambiguous and could refer to multiple objects in an image. This reasoning led to the development of a pre-trained algorithm and a general method for zero-shot transfer to downstream tasks.

SAM, as a promptable segmentation model, can solve new and different segmentation tasks from the ones it was trained on via prompt engineering, making it usable for a wide variety of use cases with little or no fine-tuning on your own data.

Figure 2: SAM architecture (source)

How does it work?

As shown in Figure 2, SAM has three main components:

1. The Image Encoder is an adaptation of the encoder from the Masked AutoEncoder (MAE) model [2]. MAE is pre-trained on images divided into regular, non-overlapping patches, 75% of which are masked.

After this transformation, the encoder receives only the unmasked patches and encodes them into embedding vectors. These vectors are then concatenated with mask tokens (which identify the missing patches that need to be predicted) and positional embeddings before going through the decoder to reconstruct the original image.

Figure 3: Training process of the MAE model (image made by the author)

The most important part of this whole process is deciding which patches to mask. It relies on random sampling without replacement which, together with the high masking ratio (75%), creates a complex reconstruction task that cannot be solved by simply extrapolating from the visible neighbouring patches. The encoder must therefore learn to build a high-quality vector representation of the image so that the decoder can correctly reconstruct the original image (a small sketch of this masking step appears just below).

The authors adapted this image encoder so that it produces an embedding that is a 16× downscaling of the original image, with a 64×64 spatial dimension and 256 channels.
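To make the masking step concrete, here is a minimal NumPy sketch of MAE-style random patch masking. The 16-pixel patch size and 75% masking ratio follow the paper; the dummy image and function name are purely illustrative.

import numpy as np

def random_patch_mask(image, patch_size=16, mask_ratio=0.75, seed=0):
    # split the image into non-overlapping patches
    h, w, c = image.shape
    patches = image.reshape(h // patch_size, patch_size, w // patch_size, patch_size, c)
    patches = patches.transpose(0, 2, 1, 3, 4).reshape(-1, patch_size, patch_size, c)

    # random sampling without replacement: shuffle the patch indices and mask the first 75%
    num_patches = patches.shape[0]
    num_masked = int(mask_ratio * num_patches)
    shuffled = np.random.default_rng(seed).permutation(num_patches)
    masked_idx, visible_idx = shuffled[:num_masked], shuffled[num_masked:]

    # only the visible patches are fed to the MAE encoder
    return patches[visible_idx], masked_idx

dummy_image = np.zeros((512, 512, 3), dtype=np.uint8)
visible_patches, masked_idx = random_patch_mask(dummy_image)
print(visible_patches.shape)  # (256, 16, 16, 3): 25% of the 1024 patches remain visible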

2. The Flexible Prompt Encoder has four different components that are used depending on the prompt provided:

  • Masks are input at a 4× lower resolution than the image and processed by a Convolutional Neural Network that downscales them by an additional factor of 4, using two 2×2, stride-2 convolutions with 4 and 16 output channels, respectively. A final 1×1 convolution then maps the channel dimension to 256, and the result is added element-wise to the output of the Image Encoder. If no mask is provided, a learned embedding representing "no mask" replaces the mask embedding.
Figure 4: Processing mask prompts in SAM architecture (image made by the author)
  • Points are represented as the sum of a positional embedding and one of two learned embeddings that indicate whether the point lies in the background or the foreground. The positional embedding follows the work developed in [3]: the coordinates of a point are mapped into Fourier features before being fed to a Multi-Layer Perceptron (MLP), which significantly improves the generated output. As shown in Figure 5, for the image regression task the model is able to generate a sharp, non-blurred image when using Fourier features. SAM applies the same logic and creates a 256-dimensional vector to represent the point position (a small sketch of this Fourier-feature mapping follows after this list).
Figure 5: Positional Embedding Creation (image made by the author)
  • Boxes follow the same principle as Points: there is a positional embedding for the top-left corner and another for the bottom-right corner, but instead of two learned embeddings identifying foreground and background, there are two learned embeddings identifying the top-left and the bottom-right corner.
  • Text is encoded by the text encoder from CLIP [4]. CLIP is a model developed by OpenAI that was trained to predict which caption goes with which image, rather than following the traditional approach of predicting a fixed set of object classes. This training aligns the embeddings produced by the text encoder with those of the image encoder, which allows zero-shot classification based on the cosine similarity between the two embedding vectors. The output of the text encoder in SAM is a 256-dimensional vector.

Figure 6: Text and Image Encoder trained together to create similar embeddings (image made by the author)
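To illustrate how point (and box corner) coordinates become 256-dimensional positional embeddings, here is a minimal NumPy sketch of the Fourier-feature mapping from [3]. The random Gaussian frequency matrix, its scale and the example coordinates are illustrative assumptions, not SAM's exact parameters.

import numpy as np

def fourier_features(coords, num_features=128, scale=10.0, seed=0):
    # gamma(v) = [cos(2*pi*B*v), sin(2*pi*B*v)] with B a random Gaussian frequency matrix,
    # giving a 2 * num_features = 256 dimensional vector per (x, y) coordinate in [0, 1]
    B = np.random.default_rng(seed).normal(scale=scale, size=(num_features, 2))
    proj = 2 * np.pi * coords @ B.T
    return np.concatenate([np.cos(proj), np.sin(proj)], axis=-1)

points = np.array([[0.25, 0.60],   # e.g. a foreground click
                   [0.80, 0.10]])  # e.g. a box corner
embeddings = fourier_features(points)
print(embeddings.shape)  # (2, 256)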

3. Finally, the Fast Mask Decoder has two decoder layers that map the image embedding and a set of prompt embeddings into an output mask:

  • Each decoder layer receives the prompt tokens as input and transforms them through a self-attention layer.
  • The output is combined with the image embedding in a cross-attention layer to update the prompt embeddings with image information.
  • Finally, the updated prompt embeddings go through an MLP that feeds a new cross-attention layer, responsible for updating the image embedding with prompt information.

The output of the second decoder layer, which is an image embedding conditioned on the input prompts, goes through two transposed convolutional layers to upscale the image embedding.

At the same time, the MLP output attends to the same image embedding in a new cross-attention layer and then feeds a 3-layer MLP, producing a vector that is combined with the upscaled image embedding through a spatially point-wise product to generate the final masks.

Note that at every attention layer, positional encodings are added to the image embedding.

Figure 7: Decoder architecture (image made by the author)
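As a rough illustration of the two-way attention pattern described above, here is a minimal PyTorch sketch of a single decoder layer. The layer sizes, the plain residual connections without layer norms and the flattened image embedding are simplifications for readability, not SAM's actual implementation.

import torch
import torch.nn as nn

class TwoWayDecoderLayerSketch(nn.Module):
    def __init__(self, dim=256, num_heads=8):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.tokens_to_image = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.ReLU(), nn.Linear(4 * dim, dim))
        self.image_to_tokens = nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def forward(self, tokens, image_embedding):
        # 1. self-attention over the prompt tokens
        tokens = tokens + self.self_attn(tokens, tokens, tokens)[0]
        # 2. cross-attention: tokens query the image embedding (tokens gain image information)
        tokens = tokens + self.tokens_to_image(tokens, image_embedding, image_embedding)[0]
        # 3. MLP on the updated tokens
        tokens = tokens + self.mlp(tokens)
        # 4. cross-attention: the image embedding queries the tokens (image gains prompt information)
        image_embedding = image_embedding + self.image_to_tokens(image_embedding, tokens, tokens)[0]
        return tokens, image_embedding

prompt_tokens = torch.randn(1, 5, 256)            # a few prompt/output tokens
image_embedding = torch.randn(1, 64 * 64, 256)    # flattened 64x64x256 image embedding
prompt_tokens, image_embedding = TwoWayDecoderLayerSketch()(prompt_tokens, image_embedding)
print(prompt_tokens.shape, image_embedding.shape)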

Customize your product landing page using SAM

In this section, we will apply SAM to a hypothetical use case where we aim to create several versions of a product landing page in order to run an A/B test and check which background leads to a higher conversion rate.

For that, we will use the model facebook/sam-vit-huge available on Hugging Face 🤗 (HF).

We start by importing the libraries, defining the HF token and some helper functions to visualise the results:

  • show_mask is used to show a specific mask in the image.
  • show_masks_on_image makes use of show_mask to plot all masks created by SAM in the image.
  • add_background is used to combine the original image with a background of your choice.
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import ipyplot
import gc
from transformers import pipeline
access_token = "YOUR HUGGING FACE TOKEN"

def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)
    del mask
    gc.collect()

def show_masks_on_image(raw_image, masks):
    plt.imshow(np.array(raw_image))
    ax = plt.gca()
    ax.set_autoscale_on(False)
    for mask in masks:
        show_mask(mask, ax=ax, random_color=True)
    plt.axis("off")
    plt.show()
    del mask
    gc.collect()

def add_background(bg_path, original_image, sam_mask) -> None:

    # load background
    background = np.array(Image.open(bg_path).resize((512,512)))

    # reverse mask and add the third dimension for easier numpy manipulation
    reverse_mask = np.where(sam_mask==0, 1, 0)
    reverse_mask = reverse_mask.reshape((reverse_mask.shape[0], reverse_mask.shape[1], 1))

    # create a hole in the background equal to the object
    background = background*reverse_mask

    # convert original image to numpy and add third dimension to mask
    np_image = np.array(original_image)
    mask = sam_mask.reshape((sam_mask.shape[0],sam_mask.shape[1], 1))

    # add speaker to background
    img = background+(mask*np_image)

    # plot all images
    ipyplot.plot_images([original_image, Image.fromarray(background.astype('uint8')), Image.fromarray(img.astype('uint8'))], labels=["original", "background", "final"], img_width=512)

After that, we can load the image we want to segment. In our case, we want to segment the Marshall Speaker to be able to customise the background.

raw_image = Image.open("IMG_0126.jpg").resize((512,512)) # the model was trained with square images
ipyplot.plot_images([raw_image], labels=["original"], img_width=512)
Figure 8: Marshall Speaker to be segmented (picture taken by the author)

We apply SAM to our image; since we do not provide any prompt, it will segment every object and not just the Marshall Speaker.

Although the paper says SAM can receive a text prompt as input, the available implementation does not allow it. Therefore, we need to search for the mask that corresponds to the speaker and, in this case, it is the third one.

# load model
generator = pipeline("mask-generation", model="facebook/sam-vit-huge", device=0, token=access_token)

# generate masks
outputs = generator(raw_image, points_per_batch=64)

# plot masks on original image
masks = outputs["masks"]
show_masks_on_image(raw_image, masks)

# in our case we want the mask that covers the marshall speaker
mask = masks[2]
show_masks_on_image(raw_image, [mask])
Figure 9: On the left we have all masks and on the right the mask of interest (image made by the author)

With the mask created and identified, it is just a matter of creating several images with different backgrounds to test on the product landing page and see which one leads to a higher conversion rate.

The mask tells us where the speaker is located in the image, so we can manipulate the original and the background image with NumPy operations to combine the two. We assign the value 0 to the non-speaker pixels of the original image and do the same for the pixels where the speaker should be placed in the background image. Finally, we sum both NumPy arrays to create the final image.

for i in range(1,5):
  add_background(f"back{i}.jpg", raw_image, mask)
Figure 10: Using SAM masks to personalise the background of a landing product page (image made by the author)

Good News!

The HF implementation does not allow text prompts, but there is a project on GitHub called luca-medeiros/lang-segment-anything that enables them. You just need to run pip install -U git+https://github.com/luca-medeiros/lang-segment-anything.git and then the following lines of code:

from lang_sam import LangSAM

def display_image_with_masks(image, masks):
    num_masks = len(masks)

    fig, axes = plt.subplots(1, num_masks + 1, figsize=(15, 5))
    axes[0].imshow(image)
    axes[0].set_title("Original Image")
    axes[0].axis('off')

    for i, mask_np in enumerate(masks):
        axes[i+1].imshow(mask_np, cmap='gray')
        axes[i+1].set_title(f"Mask {i+1}")
        axes[i+1].axis('off')

    plt.tight_layout()
    plt.show()

model = LangSAM()
image_pil = Image.open("IMG_0126.jpg").resize((512,512)).convert("RGB")
text_prompt = "speaker"
masks, boxes, phrases, logits = model.predict(image_pil, text_prompt)

# Convert masks to numpy arrays
masks_np = [mask.squeeze().cpu().numpy() for mask in masks]

# Display the original image and masks side by side
display_image_with_masks(image_pil, masks_np)
Figure 11: Mask generated with SAM and text prompt (image made by the author)

As you can see, the model managed to correctly create a mask for the speaker just by adding the text prompt ‘speaker’ to the inputs. With this implementation, we can automate the mask creation process and no longer need to search for the mask of the object we are interested in.
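To close the loop with the earlier background-swapping step, here is a minimal sketch of feeding this text-prompted mask into the add_background helper defined above. It assumes the same hypothetical back{i}.jpg backgrounds used earlier and simply takes the first returned mask.

# reuse the text-prompted mask with the earlier helper (boolean mask converted to 0/1)
speaker_mask = masks_np[0].astype(np.uint8)

for i in range(1, 5):
    add_background(f"back{i}.jpg", image_pil, speaker_mask)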

Conclusion

In the fast-paced world of e-commerce, where customers are more demanding than ever, being able to provide a personalised experience is a great way to get ahead and distinguish ourselves from the competition.

As shown in this article, SAM’s ability to precisely identify and delineate objects within an image can transform the way products are catalogued, reducing time to market through an automated editing process. Apart from that, we can easily create multiple versions of an image, either as we did with the background personalisation or with more sophisticated approaches where we combine SAM with a Stable Diffusion inpainting model to generate wilder backgrounds or add other objects that enhance the picture.

The possibilities are immense and it is only a matter of being creative when solving the problems that you are facing!

Keep in touch: LinkedIn, Medium

References

[1] Alexander Kirillov, Eric Mintun, Nikhila Ravi, Hanzi Mao, Chloe Rolland, Laura Gustafson, Tete Xiao, Spencer Whitehead, Alexander C. Berg, Wan-Yen Lo, Piotr Dollár, Ross Girshick. Segment Anything. arXiv:2304.02643, 2023.

[2] Kaiming He, Xinlei Chen, Saining Xie, Yanghao Li, Piotr Dollár, Ross Girshick. Masked Autoencoders Are Scalable Vision Learners. arXiv:2111.06377, 2021.

[3] Matthew Tancik, Pratul P. Srinivasan, Ben Mildenhall, Sara Fridovich-Keil, Nithin Raghavan, Utkarsh Singhal, Ravi Ramamoorthi, Jonathan T. Barron, Ren Ng. Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional Domains. arXiv:2006.10739, 2020.

[4] Alec Radford, Jong Wook Kim, Chris Hallacy, Aditya Ramesh, Gabriel Goh, Sandhini Agarwal, Girish Sastry, Amanda Askell, Pamela Mishkin, Jack Clark, Gretchen Krueger, Ilya Sutskever. Learning Transferable Visual Models From Natural Language Supervision. arXiv:2103.00020, 2021.

