
Create and Deploy a REST API Extracting Predominant Colors from Images

Using unsupervised machine learning, FastAPI and Docker

Image by author.

Table of contents

  1. Problem statement
  2. Extract colors from images
  3. Project structure
  4. Code
  5. Deploy the Docker container
  6. Let’s try it!
  7. API documentation
  8. Conclusions
  9. License disclaimer

1. Problem statement

Let us imagine the control room of a manufacturing facility, where fabricated products need to be sorted automatically. For instance, based on their color, goods may be redirected to different branches of a roller conveyor for further processing or packaging.

Alternatively, we can imagine an online retailer enhancing the user experience with a search-by-color functionality. Customers may more easily find clothing items in a particular color, simplifying their access to products of interest.

Or, just like the author, you can picture yourself as an IT consultant implementing a simple, fast and reusable tool to generate color palettes for presentations, charts and apps from input images.

These are just a few examples of how extracting the main colors from a picture may either improve operational efficiency or boost customer experience.

In this blog post, we use Python to implement the extraction of predominant colors from a given picture. Then, we use FastAPI and Docker to package and deploy the solution as a service.

The purpose of this post is to share an end-to-end illustration of the deployment of a lightweight, self-contained service that leverages Machine Learning techniques to carry out a business purpose. Such a service may be easily integrated into a microservice architecture.

2. Extract colors from images

A digital image is essentially a 2-dimensional grid of individual components known as pixels. Pixels are the smallest units of display in the image, and each carries information about its color. A popular approach to color representation is the RGB color model. This additive model uses combinations of the three primary colors – red, green and blue (hence the name, RGB) – to create a broad array of colors. The intensity of each primary color is represented by an 8-bit value. Therefore, each pixel has three intensity values, one for each of the primary colors, ranging from 0 to 255:

The additive RGB color model. From Wikipedia.
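
To make this concrete, we can build a toy image as a NumPy array and verify that each pixel is a triplet of 8-bit intensities (the 2x2 image below is made up purely for illustration):

import numpy as np

# A toy 2x2 image: each pixel holds three 8-bit intensities (R, G, B)
tiny_image = np.array([
    [[255, 0, 0], [0, 255, 0]],      # red pixel, green pixel
    [[0, 0, 255], [255, 255, 255]]   # blue pixel, white pixel
], dtype=np.uint8)

print(tiny_image.shape)  # (2, 2, 3): height, width, color channels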

We can extract predominant colors from an image through clustering. In brief, clustering techniques try to group similar objects together. Among clustering methods, we are going to use the K-Means algorithm. It aims at creating "compact" groups by minimizing the sum of squared distances between the data points and their respective group centers, named centroids. We can choose the number of groups/clusters k to create. In our case, the data points are the RGB values of each pixel. After model training, we can consider the centroids as representatives of the main colors inside the image.
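
Formally, given the pixels x_1, …, x_n treated as points in RGB space, K-Means with k clusters seeks centroids μ_1, …, μ_k and cluster assignments C_1, …, C_k minimizing the within-cluster sum of squares:

\min_{C_1, \ldots, C_k} \sum_{j=1}^{k} \sum_{x_i \in C_j} \lVert x_i - \mu_j \rVert^2

where each centroid μ_j is the mean of the pixels assigned to cluster C_j.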

Let us jump into practice and create a ColorAnalyzer class accepting an input image and extracting its main colors. The class will have the following methods:

  • load_image loads the image into a 2D array from the local path or URL.
  • is_url checks if the input path is a URL.
  • preprocess_image resizes the image to improve processing speed.
  • find_clusters applies K-Means clustering to extract important colors as centroids.
  • sort_clusters_by_size sorts colors by cluster size in descending order.
  • plot_image displays the original input image (resized).
  • plot_3d_clusters shows a 3-D plot of the clusters. The observations display their centroid’s (predominant) color.
  • plot_predominant_colors plots a bar chart of the main colors ordered by presence in the image.
  • get_predominant_colors returns a list of the main colors as JSON objects.
import cv2
from PIL import Image
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
import numpy as np
import requests
import json
from urllib.parse import urlparse
from io import BytesIO

class ColorAnalyzer:
    '''
    This class analyzes the predominant colors in an image 
    using K-Means clustering based on the RGB color paradigm.

    Attributes:
        url_or_path (str): The URL or local file path of the image.
        num_clusters (int): The number of clusters to identify as predominant colors.
        scaling_factor (int): The percentage by which to scale the image for preprocessing.
        image (numpy.ndarray): The loaded and preprocessed image.
        pixels (numpy.ndarray): Reshaped image data for clustering.
        image_rgb (numpy.ndarray): Resized image in RGB format.
        centroids (numpy.ndarray): Centroids (predominant colors) obtained through clustering.
        percentages (numpy.ndarray): Percentage of pixels belonging to each cluster.
        labels (numpy.ndarray): Labels indicating cluster membership for each pixel.
        sorted_colors (numpy.ndarray): Predominant colors sorted by cluster size.
        sorted_percentages (numpy.ndarray): Percentages of pixels per cluster, sorted by cluster size.
    '''
    def __init__(self, url_or_path, num_clusters=4, scaling_factor=10):
        '''
        Initializes the ColorAnalyzer with the provided parameters.

        Args:
            url_or_path (str): The URL or local file path of the image.
            num_clusters (int, optional): The number of clusters to identify as predominant colors (default is 4).
            scaling_factor (int, optional): The percentage by which to scale the image for preprocessing (default is 10).
        '''
        self.url_or_path = url_or_path
        self.num_clusters = num_clusters
        self.scaling_factor = scaling_factor
        self.image = self.load_image()
        self.pixels, self.image_rgb = self.preprocess_image()
        self.centroids, self.percentages, self.labels = self.find_clusters()
        self.sorted_colors, self.sorted_percentages = self.sort_clusters_by_size()

    def load_image(self):
        '''
        Load the image into a 2D array from the local path or URL.

        Returns:
            numpy.ndarray: The loaded image.

        Raises:
            Exception: If the URL does not exist or is broken, or if the image path is invalid.
        '''
        # If the input image path is a URL 
        if self.is_url():

            # Get the response
            response = requests.get(self.url_or_path)

            # If there is a problem in getting the response..
            if response.status_code != 200:

                # ..raise an exception
                raise Exception('URL does not exist or it is broken.')

            # Try to extract the image from the URL
            try:

                # Get PIL image object file from the response 
                image = Image.open(BytesIO(response.content))

                # Convert image from PIL to OpenCV format
                image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)

            # If an error occurs in processing the URL..
            except Exception:

                # ..raise an exception
                raise Exception('URL may not contain an image.')

        # If the input image path is not a URL
        else:

            # Load the image from a local path
            image = cv2.imread(self.url_or_path)

            # cv2.imread does not raise on a bad path: it returns None..
            if image is None:

                # ..so raise an exception ourselves
                raise Exception('Invalid image path.')

        # return the loaded image
        return image

    def is_url(self):
        '''
        Check if the input path is URL.

        Returns:
            bool: True if the path is a URL, False otherwise.
        '''
        # Return True if the path is a URL, False otherwise
        return 'http' in urlparse(self.url_or_path).scheme

    def preprocess_image(self):
        '''
        Resize the image to improve processing speed.

        Returns:
            tuple: 
                Resized image data for clustering
                Image for plotting.
        '''
        # Resize the image by the scaling factor for performances
        width = int(self.image.shape[1] * self.scaling_factor / 100)
        height = int(self.image.shape[0] * self.scaling_factor / 100)
        resized_img = cv2.resize(
            self.image, (width, height), interpolation=cv2.INTER_AREA)

        # Convert the image back to RGB
        image_rgb = cv2.cvtColor(resized_img, cv2.COLOR_BGR2RGB)

        # Extract pixels as 2D array for clustering
        pixels = image_rgb.reshape(-1, 3)

        # Return array for clustering and image for plotting
        return pixels, image_rgb

    def find_clusters(self):
        '''
        Find predominant colors through clustering.

        Returns:
            tuple: 
                centroids (predominant colors)
                percentages of pixels per cluster
                labels of each point.           
        '''
        # Instantiate clustering model
        kmeans = KMeans(n_clusters=self.num_clusters, n_init=10)

        # Fit the model on the image and get labels
        labels = kmeans.fit_predict(self.pixels)

        # Get centroids (predominant colors)
        centroids = kmeans.cluster_centers_.round(0).astype(int)

        # Get percentage of pixels belonging to each cluster
        percentages = np.bincount(labels) / len(self.pixels) * 100

        # Return:
        #   - centroids 
        #   - percentage of pixels per cluster        
        #   - labels of each point
        return centroids, percentages, labels

    def sort_clusters_by_size(self):
        '''
        Sort predominant colors and percentages 
        of pixels per cluster by cluster size 
        in descending order.

        Returns:
            tuple: 
                Predominant colors sorted by cluster size
                Percentages of pixels per cluster sorted by cluster size.
        '''
        sorted_indices = np.argsort(self.percentages)[::-1]
        sorted_colors = self.centroids[sorted_indices]
        sorted_percentages = self.percentages[sorted_indices]
        return sorted_colors, sorted_percentages

    def plot_image(self):
        '''
        Plot the preprocessed image (resized).
        '''
        plt.imshow(self.image_rgb)
        plt.title('Preprocessed Image')
        plt.axis('off')
        plt.show()

    def plot_3d_clusters(self, width=15, height=12):
        '''
        Plot a 3D visualization of the clustering.

        Args:
            width (int, optional): Width of the plot (default is 15).
            height (int, optional): Height of the plot (default is 12).
        '''
        # Prepare figure
        fig = plt.figure(figsize=(width, height))
        ax = fig.add_subplot(111, projection='3d')

        # Plot point labels with their cluster's color
        for label, color in zip(np.unique(self.labels), self.centroids):
            cluster_pixels = self.pixels[self.labels == label]
            r, g, b = color
            ax.scatter(cluster_pixels[:, 0], 
                       cluster_pixels[:, 1], 
                       cluster_pixels[:, 2], 
                       c=[[r/255, g/255, b/255]],  
                       label=f'Cluster {label+1}')

        # Display title, axis labels and legend
        ax.set_title('3D Cluster Visualization')
        ax.set_xlabel('r')
        ax.set_ylabel('g')
        ax.set_zlabel('b')
        plt.legend()
        plt.show()

    def plot_predominant_colors(self, width=12, height=8):
        '''
        Plot a bar chart of predominant colors 
        ordered by presence in the picture.

        Args:
            width (int, optional): Width of the plot (default is 12).
            height (int, optional): Height of the plot (default is 8).
        '''
        # Prepare color labels for the plot
        color_labels = [f'Color {i+1}' for i in range(self.num_clusters)]

        # Prepare figure
        plt.figure(figsize=(width, height))

        # Plot bars
        bars = plt.bar(color_labels, 
                       self.sorted_percentages, 
                       color=self.sorted_colors / 255.0, 
                       edgecolor='black')

        # Add percentage of each bar on the plot
        for bar, percentage in zip(bars, 
                                   self.sorted_percentages):
            plt.text(
                bar.get_x() + bar.get_width() / 2,
                bar.get_height(), 
                f'{percentage:.2f}%', 
                ha='center', 
                va='bottom')

        # Display title and axis labels
        plt.title(f'Top {self.num_clusters} Predominant Colors')
        plt.xlabel('Colors')
        plt.ylabel('Percentage of Pixels')
        plt.xticks(rotation=45)
        plt.show()

    def get_predominant_colors(self):
        '''
        Return a list of predominant colors.
        Each color is a JSON object with RGB code and percentage.
        '''
        # Prepare output list
        colors_json = []

        # For each predominant color
        for color, percentage in zip(self.sorted_colors, 
                                     self.sorted_percentages):
            # Get the RGB code
            r, g, b = color

            # Prepare JSON object
            color_entry = {'color': {'R': f'{r}', 
                                     'G': f'{g}', 
                                     'B': f'{b}'}, 
                           'percentage': f'{percentage:.2f}%'}

            # Append JSON object to color list
            colors_json.append(color_entry)

        # Return the results
        return colors_json

We can test our class using a Python notebook:

  • Instantiate the class using an input URL pointing to an online image:
colors_extractor = ColorAnalyzer(
  'https://fastly.picsum.photos/id/63/5000/2813.jpg?hmac=HvaeSK6WT-G9bYF_CyB2m1ARQirL8UMnygdU9W6PDvM',
  num_clusters=4)
  • Plot the original image (resized):
colors_extractor.plot_image()
Image by author.
  • Plot a bar chart of the most predominant colors:
colors_extractor.plot_predominant_colors()
Image by author.
  • Display the obtained clusters:
colors_extractor.plot_3d_clusters()
Image by author.
  • Return a list of extracted colors as JSON objects:
print(json.dumps(
    colors_extractor.get_predominant_colors(), 
    indent=3)
)
[
   {
      "color": {
         "R": "140",
         "G": "15",
         "B": "19"
      },
      "percentage": "54.56%"
   },
   {
      "color": {
         "R": "231",
         "G": "37",
         "B": "47"
      },
      "percentage": "35.55%"
   },
   {
      "color": {
         "R": "163",
         "G": "111",
         "B": "63"
      },
      "percentage": "6.78%"
   },
   {
      "color": {
         "R": "211",
         "G": "201",
         "B": "186"
      },
      "percentage": "3.11%"
   }
]

We can repeat the process for any input image (URL or local path) and inspect the outcome. For instance:

# Different image
colors_extractor = ColorAnalyzer(
  'https://fastly.picsum.photos/id/165/2000/1333.jpg?hmac=KK4nT-Drh_vgMxg3hb7rOd6peHRIYmxMg0IEyxlTVFg',
  num_clusters=4) 

# Plot resized original image
colors_extractor.plot_image()

# 3D plot of the centroids and data points
colors_extractor.plot_3d_clusters()

# Bar chart of predominant colors
colors_extractor.plot_predominant_colors()

# Predominant colors as list of JSON
print(json.dumps(
    colors_extractor.get_predominant_colors(), 
    indent=3)
)
Image by author.

How can we create a web service that provides this image analysis capability on request? We need to turn our notebook into a Python project that exposes a REST API.

3. Project structure

Let us introduce the main ingredients of our project:

  1. [REST](https://en.wikipedia.org/wiki/REST) API: REST (Representational State Transfer) API is an architectural style for designing applications. It uses standard HTTP methods (GET, POST, …) to allow communication between different systems. In our case, we want to enable a client to request the extraction of predominant colors from an input image using HTTP requests. We will use FastAPI to build the API service.
  2. Docker: a platform to build, deploy, and run applications in isolated containers. Using Docker helps us package together all the dependencies needed for the color extraction task, ensuring consistency and portability, and removing operational headaches (What Python version am I using? Did I install all packages? What version of OpenCV do I need?).
REST API. Image by author.

We can structure the project as follows:

colors-extractor/
├── api/
│   ├── __init__.py
│   └── endpoints.py
├── dto/
│   ├── __init__.py
│   └── image_data.py
├── service/
│   ├── __init__.py
│   └── image_analyzer.py
├── notebooks/
│   └── extract_colors.ipynb
├── main.py
├── requirements.txt
├── Dockerfile
└── README.md
  • README.md: project documentation in Markdown.
  • requirements.txt: list of Python dependencies needed to run the project.
  • Dockerfile: text document with all the commands to assemble a Docker image, i.e. an isolated environment for the project.
  • main.py: the entry point of our application.
  • api/: sub-folder with the definition of the REST API endpoints. In our simple example, we just have one endpoint to request color extraction.
  • dto/: sub-folder with the data classes used in the requests and responses for the API service. The name dto derives from Data Transfer Object, as these classes represent the interface between the client and the service.
  • service/: sub-folder with the application logic. In our example, the ColorAnalyzer class provides the image processing capability.
  • notebooks/: sub-folder with notebook experiments.

The separation of data (or model), endpoint definitions (or controller) and application logic (or service) into distinct modules (dto/, api/, service/) is one way to guarantee clarity, maintainability, and reusability. It also promotes a cleaner architecture and simplifies further development. The reader may refer to the MVC design pattern for more information.

4. Code

Let’s start looking into main.py. In our entry point, we:

  • Create a FastAPI application: app = FastAPI().
  • Enable CORS to allow client-side calls to the web service through the add_middleware method.
  • Define a GET request handler for the root endpoint ("/") returning a simple message.
  • Include a router to our api module with the "/api" prefix. The prefix will become part of the final endpoint as follows: "http://<host>:<port>/api/<endpoint>".
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from api.endpoints import router as api_router

# Create a FastAPI application instance
app = FastAPI()

app.add_middleware(
    CORSMiddleware,            # Add Cross-Origin Resource Sharing (CORS) middleware to handle browser security restrictions
    allow_origins = ['*'],     # Allow requests from all origins (insecure, for development only)
    allow_credentials = True,  # Allow credentials like cookies in the requests
    allow_methods = ['*'],     # Allow all HTTP methods (GET, POST, etc.)
    allow_headers = ['*'],     # Allow all headers in requests
)

# Define a GET request handler for the root endpoint ('/')
@app.get('/')
# Define an asynchronous function for the root endpoint
async def root():
    # Return a JSON response with a message
    return {'message': 'API for color extraction from images.'}

# Include the API router with a prefix of '/api'
app.include_router(api_router, prefix='/api')
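
During development, we may also start the app directly with Uvicorn from the project root (a quick local check, assuming the dependencies from requirements.txt are installed):

uvicorn main:app --reload --port 8000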

The endpoints are defined in the api module. Inside the api/endpoints.py file, we:

  • Create the router for the FastAPI app: router = APIRouter(), which is imported in main.py.
  • Define a POST request handler for the /colors endpoint. The application expects a request of class ColorExtractionRequest, and returns a response of class ColorExtractionResponse. Both objects are defined in the dto module.
  • Upon receiving a request, a ColorAnalyzer object is instantiated and the color extraction results are returned as a response to the user.
from fastapi import APIRouter, HTTPException
from service.image_analyzer import ColorAnalyzer
from dto.image_data import ColorExtractionRequest, ColorExtractionResponse
import logging 

# Define the router for the FastAPI app
router = APIRouter()

# Logging configuration
logging.basicConfig(
    format = '%(levelname)s:     %(asctime)s, %(module)s, %(processName)s, %(message)s', 
    level = logging.INFO)

# Instantiate logger
logger = logging.getLogger(__name__)

# Define a POST request handler for the '/colors' endpoint
@router.post(
        '/colors',                                 # Endpoint name
        response_model = ColorExtractionResponse,  # Data model for the response 
        tags = ['Colors Extraction']               # Tag used for documentation
        )
# Define an asynchronous function accepting a 'ColorExtractionRequest' as request body
async def colors(input_data: ColorExtractionRequest):
    '''
    Analyze an image and return predominant colors.

    Parameters:
      - input_data (ColorExtractionRequest): Request data including 'url_or_path' (str) and 'num_clusters' (int, optional).

    Returns:
      - ColorExtractionResponse: Response data containing a list of predominant colors.

    Example Usage:
      - Send a POST request with JSON data containing the 'url_or_path' parameter to extract colors from an image.
    '''

    # Log request information
    logger.info(f'Analysis for image key: {input_data.url_or_path}.')
    logger.info(f'Requested colors: {input_data.num_clusters}.')

    # Perform the color extraction
    try:

        # Instantiate the ColorAnalyzer class for image processing
        color_json = ColorAnalyzer(
                input_data.url_or_path, 
                input_data.num_clusters
            ).get_predominant_colors()

        logger.info('Analysis completed.')

        # Return the predominant colors
        return {'predominant_colors': color_json}

    # If an error occurs
    except Exception as e:

        # Log the error message 
        logger.error(f'Exception in image processing: {str(e)}.')

        # Raise an exception
        raise HTTPException(status_code = 500, detail = str(e))

Let us explore the data model for the requests and responses. Its classes are in the dto module, inside the dto/image_data.py file:

from pydantic import BaseModel
from typing import List

class Color(BaseModel):
    '''
    Color representation as RGB values.
    '''
    R: int
    G: int
    B: int

class ColorInfo(BaseModel):
    '''
    Information about a color: RGB and percentage of pixels across image.
    '''
    color: Color
    percentage: str

class ColorExtractionRequest(BaseModel):
    '''
    Colors extraction request.
    '''
    url_or_path: str
    num_clusters: int = 4 # Default to 4 most predominant colors if not provided

class ColorExtractionResponse(BaseModel):
    '''
    Color extraction response from an image analysis request.
    '''
    predominant_colors: List[ColorInfo]

This data model is straightforward. In brief, the service:

  • Accepts an input URL or path, as well as a desired number of clusters/predominant colors.
  • Returns a list of JSON objects made of RGB values and the percentage of pixels in the image belonging to that cluster.
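
As a quick illustration of how pydantic validates inputs and fills in defaults (the URL below is just a placeholder), we can instantiate a request from the project root:

from dto.image_data import ColorExtractionRequest

# num_clusters is optional and falls back to the default of 4
request = ColorExtractionRequest(url_or_path='https://example.com/image.jpg')
print(request.num_clusters)  # 4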

Interestingly, defining pydantic classes not only improves readability and maintainability, but also simplifies the generation of API documentation through the FastAPI framework (more on that in the following paragraphs).

Now, we only need to define a Docker image and deploy the Docker container.

5. Deploy the Docker container

Let us observe the Dockerfile inside the project root. This text document contains all the commands to create a Docker image for our project. In detail:

  • FROM python:3.8-slim sets the base image as starting point.
  • WORKDIR /colors-extractor sets the working directory inside the container to /colors-extractor.
  • COPY requirements.txt requirements.txt copies the dependencies from our local machine into the container inside the WORKDIR.
  • RUN pip install -r requirements.txt installs the Python dependencies listed in requirements.txt on the Docker container.
  • COPY . . copies the project files from our local machine into the container. We do this after installing the dependencies because Docker builds an image by layering subsequent commands: if we only update our code base, the Docker engine will not reinstall all the dependencies, given the current order of commands.
  • EXPOSE 8000 exposes port 8000, on which our FastAPI app runs.
  • CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"] is the command to run when the container starts. In this case, we instruct the container to run the FastAPI app using Uvicorn. The mapping main:app sets the correct entry point of our app, in our case main.py.
# Use the official Python image as the base image
FROM python:3.8-slim

# Set the working directory
WORKDIR /colors-extractor

# Copy the requirements.txt file and install dependencies
COPY requirements.txt requirements.txt
RUN pip install -r requirements.txt

# Copy the project files into the container
COPY . .

# Expose the port that the FastAPI app will listen on
EXPOSE 8000

# Command to run the FastAPI app using Uvicorn (handled by Dockerfile)
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]

To deploy the Docker container locally, from the command line:

# Move into the project directory
cd colors-extractor

# Create Docker image
docker build -t colors-extractor .

# Execute Docker container
docker run -d -p 8000:8000 colors-extractor
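
Once the container is up, we can hit the root endpoint defined in main.py as a quick sanity check:

curl http://localhost:8000/
# {"message":"API for color extraction from images."}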

We can list the running containers through:

docker ps

With this command, we can get the container id associated with our application, and use it to inspect the logs:

docker logs <container_id>

The logs confirm that the application is running:

Image by author.

6. Let’s try it!

Let us recall the endpoint structure:

  • main.py declares a router to the api module with an /api prefix.
  • The api module defines a POST request handler for the /colors endpoint inside the endpoints.py file.
  • The running port is 8000.

Therefore, we should perform a POST call to:

  • http://localhost:8000/api/colors

To test the service, we may use tools such as curl or Postman:

curl --location 'http://localhost:8000/api/colors' \
--header 'Content-Type: application/json' \
--data '{
    "url_or_path": "https://fastly.picsum.photos/id/63/5000/2813.jpg?hmac=HvaeSK6WT-G9bYF_CyB2m1ARQirL8UMnygdU9W6PDvM",
    "num_clusters": 3
}'

Result:

{
  "predominant_colors": [
    {
      "color": {
        "R": 140,
        "G": 16,
        "B": 19
      },
      "percentage": "54.97%"
    },
    {
      "color": {
        "R": 231,
        "G": 37,
        "B": 47
      },
      "percentage": "35.55%"
    },
    {
      "color": {
        "R": 180,
        "G": 142,
        "B": 104
      },
      "percentage": "9.47%"
    }
  ]
}
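
Equivalently, we may call the service from Python (a minimal sketch using the requests library):

import requests

# Request body mirroring the curl example above
payload = {
    'url_or_path': 'https://fastly.picsum.photos/id/63/5000/2813.jpg?hmac=HvaeSK6WT-G9bYF_CyB2m1ARQirL8UMnygdU9W6PDvM',
    'num_clusters': 3
}

# POST the JSON body to the /colors endpoint and print the response
response = requests.post('http://localhost:8000/api/colors', json=payload)
print(response.json())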

Similarly, using Postman:

Testing the service using Postman. Image by author.

We can inspect the container again to verify the presence of our test calls in the logs:

Logs. Image by author.

7. API documentation

FastAPI automatically provides documentation of the implemented APIs. By default, it is available at:

  • http://<host>:<port>/docs (in our case, http://localhost:8000/docs)

By reaching that URL, we can find a web user interface (Swagger UI) fully documenting our endpoints:

API documentation. Image by author.

The data model for the requests and responses is under the Schemas section of the web interface, populated with the pydantic models defined in the dto module and associated with the router's handlers:

Data models. Image by author.

8. Conclusions

In this blog post, we shared a step-by-step implementation of a service leveraging:

  • Unsupervised Machine Learning techniques to reach a business goal, i.e. extracting predominant colors from images.
  • FastAPI to serve the solution as a REST API.
  • Docker for isolated and consistent deployment.

Our goal is to showcase a comprehensive example that could be easily reused and extended to deploy a Machine Learning model as a REST API.

The full code for this blog is available on GitHub.

9. License disclaimer

To write this post, we used two images. Both sources are free for personal and commercial use under the Unsplash license. We generated the image URLs using Picsum (GitHub repo), available under the MIT license.

