Object Detection using RetinaNet and KerasCV

Object detection using the power and simplicity of the KerasCV library.

Ed Izaguirre
Towards Data Science

--

An image of leaves on a plant. Created in DALL·E 2.

Table of Contents

  1. Wait, what’s KerasCV?
  2. Inspecting the Data
  3. Pre-Processing Images
  4. RetinaNet Model Background
  5. Training RetinaNet
  6. Making Predictions
  7. Conclusion
  8. References

Relevant Links

  • Working Kaggle Notebook: Feel free to make a copy of the notebook, play around with the code, and use that free GPU.
  • PlantDoc Dataset: This is the dataset used in this notebook, hosted on Roboflow. The dataset is published under the CC BY 4.0 DEED license, which means you can copy and redistribute the material in any medium or format for any purpose, even commercially.

Wait, what’s KerasCV?

After finishing a mini-project based on image segmentation (see here), I was ready to move into another common task under the computer vision umbrella: object detection. Object detection refers to taking an image and producing boxes around objects of interest, as well as classifying the objects the boxes contain. As a simple example, take a look at the image below:

Example of object detection. Notice the bounding box and class label. Image by author.

The blue box is referred to as a bounding box and the class name is placed right above it. Object detection can thus be broken down into two mini-problems:

  1. A regression problem where the model must predict x and y coordinates for both the upper-left corner and the lower-right corner of the box.
  2. A classification problem where the model must predict what class of object the box is observing.

In this example, the bounding box was created and labeled by a human. We would like to automate this process, and a well-trained object detection model can do just that.

I sat down to review my study material regarding object detection, and was promptly disappointed. Unfortunately, most introductory material scarcely mentions object detection. François Chollet in Deep Learning with Python [1] states:

Note that we won’t cover object detection, because it would be too specialized and too complicated for an introductory book.

Aurélien Géron [2] provides a lot of textual content covering the ideas behind object detection, but provides only a few lines of code covering an object detection task with dummy bounding boxes, far from the end-to-end pipeline I was looking for. Andrew Ng’s [3] famous Deep Learning Specialization course goes the deepest on object detection, but even he ends the coding lab by loading a pre-trained object detection model and just doing inference.

Looking to go deeper, I started to sketch out the outline of an object detection pipeline. Just to do pre-processing for a RetinaNet model, one would have to do the following (note: other object detection models such as YOLO would require different steps):

  • Take input images and resize them all to be the same size, with padding to prevent the aspect ratio from getting messed up. Oh, don’t forget about the bounding boxes; these also need to be appropriately reshaped or you will ruin your data.
  • Generate anchor boxes at different scales and aspect ratios based on the ground truth bounding boxes in the training set. These anchor boxes act as reference points for the model during training.
  • Assign labels to the anchor boxes based on their overlap with ground truth boxes. Anchor boxes with high overlap are labeled as positive examples, while those with low overlap are labeled as negative examples.
  • There are multiple ways to describe the same bounding box. You would need to implement functions for converting between these different formats. More on this in a moment.
  • Implement data augmentation, taking care to not only augment the images but also the boxes. In theory you can omit this, but in practice this is necessary to help our models generalize well.
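
To give a feel for the label-assignment step in the list above, here is a minimal sketch in plain Python. It is an illustration only, not KerasCV's implementation: the helper names iou and assign_anchor_labels are made up for this example, and the thresholds (0.5 for positive, below 0.4 for negative, anything in between ignored) are assumed from the values reported in the RetinaNet paper.

```python
def iou(box_a, box_b):
    """Intersection over union of two boxes in absolute (xmin, ymin, xmax, ymax) format."""
    ix_min = max(box_a[0], box_b[0])
    iy_min = max(box_a[1], box_b[1])
    ix_max = min(box_a[2], box_b[2])
    iy_max = min(box_a[3], box_b[3])
    inter = max(0.0, ix_max - ix_min) * max(0.0, iy_max - iy_min)
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def assign_anchor_labels(anchors, gt_boxes, pos_thresh=0.5, neg_thresh=0.4):
    """Label each anchor 1 (positive), 0 (background) or -1 (ignored during training)."""
    labels = []
    for anchor in anchors:
        # Best overlap between this anchor and any ground truth box
        best = max((iou(anchor, gt) for gt in gt_boxes), default=0.0)
        if best >= pos_thresh:
            labels.append(1)
        elif best < neg_thresh:
            labels.append(0)
        else:
            labels.append(-1)
    return labels


# Hypothetical anchors and a single hypothetical ground truth box
anchors = [(0, 0, 100, 100), (50, 50, 150, 150), (200, 200, 300, 300), (0, 0, 100, 45)]
gt_boxes = [(0, 0, 100, 100)]
print(assign_anchor_labels(anchors, gt_boxes))
```

The last anchor overlaps the ground truth box with an IoU of 0.45, so it falls between the two thresholds and is ignored rather than counted as a positive or a negative.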

Take a look at this example on the Keras website. Yikes. Post-processing of our model predictions would take even more work. To paraphrase the Keras team: this is a technically complex problem.

As I was beginning to despair, I started desperately browsing the internet and stumbled upon a library I had never heard of before: KerasCV. As I read the documentation, it began to dawn on me that this is the future of computer vision in TensorFlow/Keras. From their introduction:

KerasCV can be understood as a horizontal extension of the Keras API: the components are new first-party Keras objects that are too specialized to be added to core Keras. They receive the same level of polish and backwards compatibility guarantees as the core Keras API, and they are maintained by the Keras team.

“But why did none of my study materials even mention this?” I wondered. The answer is simple: this is a fairly new library. The first commit on GitHub was on April 13th, 2022, too new to show up even in the latest editions of my textbooks. In fact, the 1.0 version of the library hasn’t even been released yet (as of November 10th, 2023 it is on 0.6.4). I expect KerasCV will be discussed in detail by the next editions of my textbooks and online courses (to be fair, Géron does mention in passing a “new Keras NLP project” and Keras CV project that the reader may be interested in).

Being so new, KerasCV doesn’t have many tutorials aside from those published by the Keras team themselves (see here). In this tutorial I will demonstrate an end-to-end object detection pipeline to recognize healthy and diseased leaves using techniques inspired by but distinct from the official Keras guides. With KerasCV, even beginners can take labeled datasets and use them to build effective object detection pipelines.

A few notes before we begin. KerasCV is a fast-changing library, with the codebase and documentation being updated on a regular basis. The implementation shown here will work with KerasCV version 0.6.4. The Keras team has stated that: “there is no backwards compatibility contract until KerasCV reaches v1.0.0.” This implies that there is no guarantee the methods used in this tutorial will continue to work as KerasCV gets updated. I have hard-coded the KerasCV version number in the linked Kaggle notebook to prevent these sorts of issues.

KerasCV has quite a few bugs that are already noted in the Issues tab on GitHub. In addition, the documentation is lacking in some areas (I’m looking at you, MultiClassNonMaxSuppression). As you play around with KerasCV, try not to be discouraged by these issues. In fact, this is a great opportunity to become a contributor to the KerasCV codebase!

This tutorial will focus on implementation details of KerasCV. I will briefly review some high-level concepts in object detection, but I will assume the reader has some background knowledge on concepts such as the RetinaNet architecture. The code shown here has been edited and rearranged for clarity; please see the Kaggle notebook linked above for the complete code.

Finally, a note on safety. The model created here is not intended to be state-of-the-art; treat this as a high-level tutorial. Further fine-tuning and data cleaning would be expected before this plant disease detection model could be implemented in production. It would be a good idea to run any predictions a model makes by a human expert to confirm a diagnosis.

Inspecting the Data

The PlantDoc dataset contains 2,569 images across 13 plant species and 30 classes. The goal of the dataset is set out in the abstract of the paper PlantDoc: A Dataset for Visual Plant Disease Detection by Singh et al. [4].

India loses 35% of the annual crop yield due to plant diseases. Early detection of plant diseases remains difficult due to the lack of lab infrastructure and expertise. In this paper, we explore the possibility of computer vision approaches for scalable and early plant disease detection.

This is a noble goal, and an area where computer vision can do a lot of good for farmers.

Roboflow allows us to download the dataset in a variety of different formats. Since we are using TensorFlow, let’s download the dataset as a TFRecord. A TFRecord is a specific format used in TensorFlow that is designed to store large amounts of data efficiently. The data is represented by a sequence of records, where each record is a key-value pair. Each key is referred to as a feature. The downloaded zip file contains four files, two for training and two for validation:

  • leaves_label_map.pbtxt: This is a Protocol Buffers text format file, which is used to describe the structure of the data. Opening the file in a text editor, I see that there are thirty classes. There is a mixture of healthy leaves such as Apple leaf and unhealthy leaves such as Apple Scab Leaf.
  • leaves.tfrecord : This is the TFRecord file that contains all of our data.

Our first step is to inspect leaves.tfrecord. What features do our records contain? Unfortunately this is not specified by Roboflow.

import tensorflow as tf
from tensorflow import keras
import keras_cv
from keras_cv import bounding_box, visualization

train_tfrecord_file = '/kaggle/input/plants-dataset/leaves.tfrecord'
val_tfrecord_file = '/kaggle/input/plants-dataset/test_leaves.tfrecord'

# Create a TFRecordDataset
train_dataset = tf.data.TFRecordDataset([train_tfrecord_file])
val_dataset = tf.data.TFRecordDataset([val_tfrecord_file])

# Parse one raw record and print it to see which features it contains
for record in train_dataset.take(1):
    example = tf.train.Example()
    example.ParseFromString(record.numpy())
    print(example)

I see the following features printed:

  • image/encoded : This is the encoded binary representation of an image. In the case of this dataset the images are encoded in the jpeg format.
  • image/height : This is the height of each image.
  • image/width : This is the width of each image.
  • image/object/bbox/xmin : This is the x-coordinate of the top-left corner of our bounding box.
  • image/object/bbox/xmax : This is the x-coordinate of the bottom-right corner of our bounding box.
  • image/object/bbox/ymin : This is the y-coordinate of the top-left corner of our bounding box.
  • image/object/bbox/ymax : This is the y-coordinate of the bottom-right corner of our bounding box.
  • image/object/class/label : These are the labels associated with each bounding box.

Now we want to take all of the images and associated bounding boxes and put them together in a TensorFlow Dataset object. Dataset objects allow you to store large amounts of data without overwhelming your system’s memory. This is accomplished through features such as lazy loading and batching. Lazy loading means that the data is not loaded into memory until it’s explicitly requested (for example when performing transformations or during training). Batching means that only a select number of images (usually 8, 16, 32, etc.) get loaded into memory at once. In short, I recommend always converting your data into Dataset objects, especially when you are dealing with large amounts of data (typical in object detection).
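
The idea behind lazy loading and batching can be sketched with plain Python generators. This is only an analogy for what a Dataset object does, not how tf.data is implemented; the helpers lazy_records and batched are hypothetical names invented for this example.

```python
from itertools import islice


def lazy_records(num_records, log):
    """Yield records one at a time; `log` notes when each record is actually produced."""
    for i in range(num_records):
        log.append(i)  # stand-in for reading and decoding one image from disk
        yield {"id": i}


def batched(records, batch_size):
    """Group an iterator of records into lists of at most `batch_size` items."""
    iterator = iter(records)
    while True:
        batch = list(islice(iterator, batch_size))
        if not batch:
            return
        yield batch


loaded = []
pipeline = batched(lazy_records(10, loaded), batch_size=4)
print(len(loaded))  # building the pipeline has not loaded anything yet

first_batch = next(pipeline)
print(len(first_batch), len(loaded))  # only the records of the first batch were loaded
```

Nothing is read when the pipeline is defined; records are only produced when a batch is requested, and only one batch worth of records is held at a time.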

To convert a TFRecord to a Dataset object in TensorFlow, you can use the tf.data.TFRecordDataset class to create a dataset from our TFRecord file, and then apply parsing functions using the map method to extract and preprocess features. The parsing code is shown below.

def parse_tfrecord_fn(example):
    feature_description = {
        'image/encoded': tf.io.FixedLenFeature([], tf.string),
        'image/height': tf.io.FixedLenFeature([], tf.int64),
        'image/width': tf.io.FixedLenFeature([], tf.int64),
        'image/object/bbox/xmin': tf.io.VarLenFeature(tf.float32),
        'image/object/bbox/xmax': tf.io.VarLenFeature(tf.float32),
        'image/object/bbox/ymin': tf.io.VarLenFeature(tf.float32),
        'image/object/bbox/ymax': tf.io.VarLenFeature(tf.float32),
        'image/object/class/label': tf.io.VarLenFeature(tf.int64),
    }

    parsed_example = tf.io.parse_single_example(example, feature_description)

    # Decode the JPEG image. Pixel values stay in the [0, 255] range.
    img = tf.image.decode_jpeg(parsed_example['image/encoded'], channels=3)  # Returned as uint8

    # Get the bounding box coordinates and class labels.
    xmin = tf.sparse.to_dense(parsed_example['image/object/bbox/xmin'])
    xmax = tf.sparse.to_dense(parsed_example['image/object/bbox/xmax'])
    ymin = tf.sparse.to_dense(parsed_example['image/object/bbox/ymin'])
    ymax = tf.sparse.to_dense(parsed_example['image/object/bbox/ymax'])
    labels = tf.sparse.to_dense(parsed_example['image/object/class/label'])

    # Stack the bounding box coordinates to create a [num_boxes, 4] tensor.
    rel_boxes = tf.stack([xmin, ymin, xmax, ymax], axis=-1)
    # Convert from relative (0 to 1) coordinates to absolute pixel coordinates.
    boxes = keras_cv.bounding_box.convert_format(rel_boxes, source='rel_xyxy', target='xyxy', images=img)

    # Create the final dictionary.
    image_dataset = {
        'images': img,
        'bounding_boxes': {
            'classes': labels,
            'boxes': boxes
        }
    }
    return image_dataset

Let’s break this down:

  • feature_description: This is a dictionary that describes the expected format of each of our features. We use tf.io.FixedLenFeature when the length of a feature is fixed across all examples in the dataset, and tf.io.VarLenFeature when some variability in the length is expected. Since the number of bounding boxes is not constant across our dataset (some images have more boxes, others have fewer), we use tf.io.VarLenFeature for anything related to bounding boxes.
  • We decode the image files using tf.image.decode_jpeg, since our images are encoded in the JPEG format.
  • Note the use of tf.sparse.to_dense for the bounding box coordinates and labels. When we use tf.io.VarLenFeature the information comes back as a sparse tensor. A sparse tensor stores only the non-zero values along with their indices, which is efficient when most of the elements are zero. Unfortunately, many pre-processing functions in TensorFlow require dense tensors. This includes tf.stack, which we use to stack the four coordinate vectors into one row per bounding box. To fix this issue, we use tf.sparse.to_dense to convert the sparse tensors to dense tensors.
  • After stacking the boxes, we use KerasCV’s keras_cv.bounding_box.convert_format function. When inspecting the data, I noticed that the bounding box coordinates were normalized between 0 and 1. This means that the numbers represent fractions of the image’s total width/height. So a value of 0.5 represents 50% * image_width, as an example. This is a relative format, which Keras refers to as REL_XYXY, rather than the absolute format XYXY. In theory converting to the absolute format is not necessary, but I was running into bugs when training my model with relative coordinates. See the KerasCV documentation for some other supported bounding box formats.
  • Finally, we take the images and bounding boxes and convert them into the format that KerasCV wants: dictionaries. A Python dictionary is a data type that contains key-value pairs. Specifically, KerasCV expects the following format:

image_dataset = {
    "images": [height, width, channels],
    "bounding_boxes": {
        "classes": [num_boxes],
        "boxes": [num_boxes, 4]
    }
}

This is actually a “dictionary within a dictionary”, since bounding_boxes is also a dictionary.
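
To make the difference between the relative and absolute box formats concrete, here is a small sketch in plain Python. It only mimics the arithmetic; it is not KerasCV's convert_format, and the function names rel_xyxy_to_xyxy and xyxy_to_xywh as well as the image size are invented for this example.

```python
def rel_xyxy_to_xyxy(box, image_width, image_height):
    """Convert (xmin, ymin, xmax, ymax) given as fractions of the image size into pixels."""
    xmin, ymin, xmax, ymax = box
    return (xmin * image_width, ymin * image_height,
            xmax * image_width, ymax * image_height)


def xyxy_to_xywh(box):
    """Convert corner format into (xmin, ymin, width, height)."""
    xmin, ymin, xmax, ymax = box
    return (xmin, ymin, xmax - xmin, ymax - ymin)


# A hypothetical box covering the lower-middle part of a 400 x 200 pixel image
abs_box = rel_xyxy_to_xyxy((0.25, 0.5, 0.75, 1.0), image_width=400, image_height=200)
print(abs_box)
print(xyxy_to_xywh(abs_box))
```

The same box can therefore be written in several equivalent ways, which is why every KerasCV layer that touches boxes needs to be told which format it is receiving.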

Finally use the .map function to apply the parsing function to our TFRecord. You may then inspect the Dataset object. Everything checks out.

train_dataset = train_dataset.map(parse_tfrecord_fn)
val_dataset = val_dataset.map(parse_tfrecord_fn)

# Inspecting the data
for data in train_dataset.take(1):
    print(data)

Congratulations, the hardest part is now over with. Creating the “dictionary within a dictionary” that KerasCV wants is the most difficult task in my opinion. The rest is more straightforward.

Pre-Processing Images

Our data is already split into training and validation sets. So we will begin by batching our datasets.

# Batching
BATCH_SIZE = 32
# Adding autotune for pre-fetching
AUTOTUNE = tf.data.experimental.AUTOTUNE

train_dataset = train_dataset.ragged_batch(BATCH_SIZE).prefetch(buffer_size=AUTOTUNE)
val_dataset = val_dataset.ragged_batch(BATCH_SIZE).prefetch(buffer_size=AUTOTUNE)

NUM_ROWS = 4
NUM_COLS = 8
IMG_SIZE = 416
BBOX_FORMAT = "xyxy"

A few notes:

  • We are using ragged_batch for the same reason we used VarLenFeature : we don’t know in advance how many bounding boxes we will have for each image. If all of the images had the same number of bounding boxes, then we would just use batch instead.
  • We set BBOX_FORMAT = "xyxy". Recall that earlier when loading in data we converted the bounding box format from the relative XYXY format to the absolute XYXY format.

Now we can implement data augmentation. Data augmentation is a common technique in computer vision problems. It modifies your training images slightly, for example with a small rotation or a horizontal flip. This helps solve the problem of having too little data and also helps with regularization. Here we will introduce the following augmentations:

  • KerasCV’s JitteredResize function. This function is designed for object detection pipelines and implements an image augmentation technique that involves randomly scaling, resizing, cropping, and padding images along with corresponding bounding boxes. This process introduces variability in scale and local features, enhancing the diversity of the training data for improved generalization.
  • We then add horizontal and vertical RandomFlips as well as a RandomRotation. Here the factor is a float that represents a fraction of 2π. We use 0.25, which means that our augmenter will rotate our images by a random angle between -25% and 25% of 2π. In degrees this means rotations between -90° and 90°.
  • Finally we add in RandomSaturation and RandomHue. A saturation of 0.0 would leave a grayscale image, while 1.0 would be fully saturated. A factor of 0.5 would leave no change, so choosing a range of 0.4–0.6 results in a subtle change. A hue factor of 0.0 would leave no change. Putting factor=0.2 implies a range of 0.0–0.2, another subtle change.

augmenter = keras.Sequential(
    [
        keras_cv.layers.JitteredResize(
            target_size=(IMG_SIZE, IMG_SIZE), scale_factor=(0.8, 1.25), bounding_box_format=BBOX_FORMAT
        ),
        keras_cv.layers.RandomFlip(mode="horizontal_and_vertical", bounding_box_format=BBOX_FORMAT),
        keras_cv.layers.RandomRotation(factor=0.25, bounding_box_format=BBOX_FORMAT),
        keras_cv.layers.RandomSaturation(factor=(0.4, 0.6)),
        keras_cv.layers.RandomHue(factor=0.2, value_range=[0, 255])
    ]
)

train_dataset = train_dataset.map(augmenter, num_parallel_calls=tf.data.AUTOTUNE)

We typically only augment the training set because we want the model to avoid “memorizing” patterns and instead make sure the model learns general patterns that will be found in the real world. This increases the diversity of what the model sees during training.
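
To see why the boxes have to be augmented together with the images, consider what a flip does to a box in the absolute xyxy format. The sketch below is plain Python written for illustration; flip_box is a hypothetical helper, not a KerasCV function, and the image size is an example value.

```python
def flip_box(box, image_width, image_height, horizontal=False, vertical=False):
    """Mirror an absolute (xmin, ymin, xmax, ymax) box the same way the image is flipped."""
    xmin, ymin, xmax, ymax = box
    if horizontal:
        # The old right edge becomes the new left edge, measured from the other side
        xmin, xmax = image_width - xmax, image_width - xmin
    if vertical:
        ymin, ymax = image_height - ymax, image_height - ymin
    return (xmin, ymin, xmax, ymax)


box = (10, 20, 110, 220)
flipped = flip_box(box, image_width=416, image_height=416, horizontal=True)
print(flipped)
```

If only the image were flipped, the box would keep pointing at the leaf's old location, silently corrupting the label.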

We also want to resize the validation images to be the same size (with padding). These will be resized without any distortion. The bounding boxes must also be reshaped accordingly. KerasCV can handle this difficult task with ease:

# Resize and pad images
inference_resizing = keras_cv.layers.Resizing(
    IMG_SIZE, IMG_SIZE, pad_to_aspect_ratio=True, bounding_box_format=BBOX_FORMAT
)

val_dataset = val_dataset.map(inference_resizing, num_parallel_calls=tf.data.AUTOTUNE)
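
As a rough picture of what resizing with padding involves, the following plain-Python sketch computes the scaled image size, the padding, and the rescaled boxes. It rests on an assumption made for this example only: that the padding is added on the right and bottom, so boxes need scaling but no shifting. KerasCV may place the padding differently, and resize_with_padding is a made-up helper.

```python
def resize_with_padding(image_width, image_height, boxes, target_size):
    """Scale an image to fit inside a target_size square without distortion.

    Returns the scaled image size, the (right, bottom) padding needed to fill
    the square, and the boxes (absolute xyxy) scaled by the same factor.
    """
    # One scale factor for both axes keeps the aspect ratio intact
    scale = min(target_size / image_width, target_size / image_height)
    new_width = round(image_width * scale)
    new_height = round(image_height * scale)
    padding = (target_size - new_width, target_size - new_height)
    scaled_boxes = [tuple(coord * scale for coord in box) for box in boxes]
    return (new_width, new_height), padding, scaled_boxes


# A hypothetical 832 x 416 image with one box, resized to fit a 416 x 416 square
size, padding, boxes = resize_with_padding(832, 416, [(100, 50, 300, 250)], target_size=416)
print(size, padding, boxes)
```

The wide image is halved in both directions and the lower half of the square is filled with padding, so the leaf keeps its proportions.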

Finally we can visualize our images and bounding boxes with the pre-processing included:

class_mapping = {
1: 'Apple Scab Leaf',
2: 'Apple leaf',
3: 'Apple rust leaf',
4: 'Bell_pepper leaf',
5: 'Bell_pepper leaf spot',
6: 'Blueberry leaf',
7: 'Cherry leaf',
8: 'Corn Gray leaf spot',
9: 'Corn leaf blight',
10: 'Corn rust leaf',
11: 'Peach leaf',
12: 'Potato leaf',
13: 'Potato leaf early blight',
14: 'Potato leaf late blight',
15: 'Raspberry leaf',
16: 'Soyabean leaf',
17: 'Soybean leaf',
18: 'Squash Powdery mildew leaf',
19: 'Strawberry leaf',
20: 'Tomato Early blight leaf',
21: 'Tomato Septoria leaf spot',
22: 'Tomato leaf',
23: 'Tomato leaf bacterial spot',
24: 'Tomato leaf late blight',
25: 'Tomato leaf mosaic virus',
26: 'Tomato leaf yellow virus',
27: 'Tomato mold leaf',
28: 'Tomato two spotted spider mites leaf',
29: 'grape leaf',
30: 'grape leaf black rot'
}

def visualize_dataset(inputs, value_range, rows, cols, bounding_box_format):
    inputs = next(iter(inputs.take(1)))
    images, bounding_boxes = inputs["images"], inputs["bounding_boxes"]
    visualization.plot_bounding_box_gallery(
        images,
        value_range=value_range,
        rows=rows,
        cols=cols,
        y_true=bounding_boxes,
        scale=5,
        font_scale=0.7,
        bounding_box_format=bounding_box_format,
        class_mapping=class_mapping,
    )

# Visualize training set
visualize_dataset(
    train_dataset, bounding_box_format=BBOX_FORMAT, value_range=(0, 255), rows=NUM_ROWS, cols=NUM_COLS
)

# Visualize validation set
visualize_dataset(
    val_dataset, bounding_box_format=BBOX_FORMAT, value_range=(0, 255), rows=NUM_ROWS, cols=NUM_COLS
)

This type of visualization function is common in KerasCV. It plots a grid of images and boxes with the rows and columns specified in the arguments. We see that our training images have been slightly rotated, some have been horizontally or vertically flipped, some have been zoomed in or out, and subtle changes in hue/saturation can be seen. With all augmentation layers in KerasCV, the bounding boxes also get augmented if necessary. Note that class_mapping is a simple dictionary. I got both keys and labels from the leaves_label_map.pbtxt text file mentioned earlier.

Examples of the original images on the left (validation set) and the augmented images on the right (training set). Images by author.

One last thing before looking at the RetinaNet model. Earlier we had to create the “dictionary within a dictionary” to get the data into a format compatible with KerasCV pre-processing, but now we need to convert it to a tuple of numbers to feed to our model for training. This is fairly straightforward to do:

def dict_to_tuple(inputs):
    return inputs["images"], bounding_box.to_dense(
        inputs["bounding_boxes"], max_boxes=32
    )

train_dataset = train_dataset.map(dict_to_tuple, num_parallel_calls=tf.data.AUTOTUNE)
validation_dataset = val_dataset.map(dict_to_tuple, num_parallel_calls=tf.data.AUTOTUNE)
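
Conceptually, converting ragged boxes to a dense tensor means giving every image the same number of box slots. The plain-Python sketch below shows the idea for a single image. The helper pad_boxes is hypothetical, and filling the unused slots with -1 is an assumption made for this example; check the KerasCV documentation for the value the library actually uses.

```python
def pad_boxes(boxes, classes, max_boxes, pad_value=-1):
    """Pad (or truncate) one image's boxes and classes to exactly max_boxes entries."""
    boxes = list(boxes)[:max_boxes]
    classes = list(classes)[:max_boxes]
    missing = max_boxes - len(boxes)
    # Unused slots get a sentinel value so they can be told apart from real boxes
    boxes += [(pad_value,) * 4] * missing
    classes += [pad_value] * missing
    return boxes, classes


# A hypothetical image with a single labeled leaf, padded to four slots
boxes, classes = pad_boxes([(10, 20, 110, 220)], [3], max_boxes=4)
print(boxes)
print(classes)
```

With every image padded to the same length, a batch can be stored as one regular tensor instead of a ragged one.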

RetinaNet Model Background

One popular model for conducting object detection is called RetinaNet. A detailed description of the model is beyond the scope of this article. In brief, RetinaNet is a single-stage detector, meaning it only looks at the image once before predicting bounding boxes and classes. This is similar to the famous YOLO (You Only Look Once) model, but with some important differences. What I want to highlight here is the novel classification loss function used: the focal loss. This solves the issue of class imbalance in an image.

To understand why this is important, consider the following analogy: imagine you are a teacher in a room of 100 students. 95 students are loud and rambunctious, always yelling and raising their hands. 5 students are quiet and don’t say much. As the teacher you need to pay attention to everyone equally, but the loud students are crowding out the quiet ones. Here you have a problem of class imbalance. To fix the issue, you develop a special hearing aid that boosts the quiet students and deemphasizes the loud students. In this analogy, the loud students are the vast majority of background pixels in our images that do not contain leaves, while the quiet students are those small regions that do. The “hearing aid” is the focal loss, which allows us to focus our model on those pixels that contain leaves, without paying too much attention to those that do not.
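
For readers who prefer numbers to analogies, here is a minimal sketch of the focal loss for a single binary prediction, following the formula in the RetinaNet paper: the usual cross-entropy multiplied by alpha_t * (1 - p_t) ** gamma. The defaults alpha=0.25 and gamma=2.0 are assumed from that paper; this is not KerasCV's implementation, and the example probabilities are invented.

```python
import math


def binary_cross_entropy(p, y):
    """Standard cross-entropy for one prediction p (probability of class 1) and label y."""
    p_t = p if y == 1 else 1.0 - p
    return -math.log(p_t)


def focal_loss(p, y, alpha=0.25, gamma=2.0):
    """Cross-entropy scaled down for examples the model already classifies well."""
    p_t = p if y == 1 else 1.0 - p          # probability assigned to the true class
    alpha_t = alpha if y == 1 else 1.0 - alpha
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)


# An easy background example: the model is already 95% sure it is background
easy = focal_loss(p=0.05, y=0)
# A hard foreground example: the model gives a real leaf only 10% probability
hard = focal_loss(p=0.10, y=1)
print(f"easy background: {easy:.6f}, hard foreground: {hard:.6f}")
```

The easy background example contributes a tiny fraction of the loss of the hard leaf example, which is the “hearing aid” effect described above.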

There are three important components of the RetinaNet model:

  • A backbone. This forms the base of the model. We also call this a feature extractor. As the name suggests, it takes an image and scans for features. Low-level layers extract low-level features (e.g. lines and curves) while higher-level layers extract high-level features (e.g. lips and eyes). In this project the backbone will be a YOLOv8 model that has been pre-trained on the COCO dataset. We are using YOLO only as a feature extractor, not as an object detector.
  • Feature pyramid network (FPN). This is a model architecture that generates a “pyramid” of feature maps at different scales to detect objects of various sizes. It does this by combining low-resolution, semantically strong features with high-resolution, semantically weak features via a top-down pathway and lateral connections. Take a look at this video for a detailed explanation or take a look at the paper [5] that introduced the FPN.
  • Two task-specific subnetworks. These subnetworks take each level of the pyramid and detect objects in each. One subnetwork identifies classes (classification) while the other identifies bounding boxes (regression). These subnetworks are untrained.

Simplified RetinaNet architecture. Image by author.

Earlier we resized the images to be of size 416 by 416. This is a somewhat arbitrary choice, although the object detection model you pick will often specify a desired minimum size. For the YOLOv8 backbone we are using, the image size should be divisible by 32. This is because the maximum stride of the backbone is 32 and it is a fully convolutional network. Do your homework on any model you use to figure out this factor for your own projects.
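
The divisibility rule is easy to check in code. The helpers below are hypothetical names written for this example; the stride of 32 comes from the paragraph above.

```python
def is_valid_size(size, max_stride=32):
    """True if an image side length divides evenly by the backbone's maximum stride."""
    return size % max_stride == 0


def round_up_to_stride(size, max_stride=32):
    """Smallest multiple of max_stride that is greater than or equal to size."""
    return ((size + max_stride - 1) // max_stride) * max_stride


print(is_valid_size(416), 416 // 32)   # 416 is 13 strides of 32
print(is_valid_size(400), round_up_to_stride(400))
```

An image side of 400 would not work with this backbone, but rounding it up gives 416, the size used in this tutorial.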

Training RetinaNet

Let’s begin by setting up some basic parameters, such as the optimizer and the metrics we will be using. Here we will be using Adam as our optimizer. Note the global_clipnorm argument. According to the KerasCV object detection guide:

You will always want to include a global_clipnorm when training object detection models. This is to remedy exploding gradient problems that frequently occur when training object detection models.

base_lr = 0.0001
# including a global_clipnorm is extremely important in object detection tasks
optimizer_Adam = tf.keras.optimizers.Adam(
    learning_rate=base_lr,
    global_clipnorm=10.0
)

We will follow their advice. For our metrics we will be using the BoxCOCOMetrics. These are popular metrics for object detection. They essentially consist of the mean average precision (mAP) and the mean average recall (mAR). In summary, a predicted box counts as correct when it has the right class and its IoU (intersection over union) with a ground truth box exceeds a threshold. Precision is the fraction of predicted boxes that are correct, and mAP averages precision over recall levels, IoU thresholds, and classes. Recall is the fraction of ground truth boxes that the model finds, and mAR averages it over IoU thresholds and classes. See this article for exact details on the metrics. This video gives a great explanation of the basics of precision and recall.
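
As a simplified illustration of the precision and recall that underlie mAP and mAR, the sketch below scores the predictions for one image at a single IoU threshold, using the standard definitions (fraction of predictions that match a ground truth box, and fraction of ground truth boxes that are found). It ignores class labels and the averaging that the COCO metrics perform; the helper names and the boxes are invented for this example.

```python
def box_iou(a, b):
    """Intersection over union of two absolute (xmin, ymin, xmax, ymax) boxes."""
    inter_w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_w * inter_h
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def precision_recall(pred_boxes, gt_boxes, iou_threshold=0.5):
    """pred_boxes should be sorted by descending confidence; each ground truth matches once."""
    matched = set()
    true_positives = 0
    for pred in pred_boxes:
        best_iou, best_idx = 0.0, None
        for idx, gt in enumerate(gt_boxes):
            if idx in matched:
                continue
            overlap = box_iou(pred, gt)
            if overlap > best_iou:
                best_iou, best_idx = overlap, idx
        if best_idx is not None and best_iou >= iou_threshold:
            matched.add(best_idx)
            true_positives += 1
    precision = true_positives / len(pred_boxes) if pred_boxes else 0.0
    recall = true_positives / len(gt_boxes) if gt_boxes else 0.0
    return precision, recall


gt = [(0, 0, 100, 100), (200, 200, 300, 300)]
preds = [(0, 0, 100, 90), (0, 0, 50, 50), (400, 400, 450, 450)]
precision, recall = precision_recall(preds, gt)
print(f"precision={precision:.2f}, recall={recall:.2f}")
```

Here one of three predictions matches a leaf (precision of one third) and one of two leaves is found (recall of one half).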

coco_metrics = keras_cv.metrics.BoxCOCOMetrics(
    bounding_box_format=BBOX_FORMAT, evaluate_freq=5
)

Because the box metrics are computationally expensive to compute, we pass the evaluate_freq=5 argument to tell our model to compute the metrics after every five batches rather than every single batch during training. I noticed that with too high a number the validation metrics weren’t being printed out at all.

Let’s continue by looking at the callbacks we will be using:

class VisualizeDetections(keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs):
        if (epoch + 1) % 5 == 0:
            visualize_detections(
                self.model, bounding_box_format=BBOX_FORMAT, dataset=val_dataset, rows=NUM_ROWS, cols=NUM_COLS
            )

checkpoint_path = "best-custom-model"

callbacks_list = [
    # Conducting early stopping to stop after 6 epochs of non-improving validation loss
    keras.callbacks.EarlyStopping(
        monitor="val_loss",
        patience=6,
    ),

    # Saving the best model
    keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_path,
        monitor="val_loss",
        save_best_only=True,
        save_weights_only=True
    ),

    # Custom metrics printing after each epoch
    tf.keras.callbacks.LambdaCallback(
        on_epoch_end=lambda epoch, logs:
        print(f"\nEpoch #{epoch+1} \n" +
              f"Loss: {logs['loss']:.4f} \n" +
              f"mAP: {logs['MaP']:.4f} \n" +
              f"Validation Loss: {logs['val_loss']:.4f} \n" +
              f"Validation mAP: {logs['val_MaP']:.4f} \n")
    ),

    # Visualizing results after each five epochs
    VisualizeDetections()
]

  • Early stopping. If the validation loss has not improved after six epochs, we will stop training.
  • Model checkpoint. We will be checking the val_loss after every epoch, saving the model weights if it bests an earlier epoch.
  • Lambda callback. A lambda callback is a custom callback that allows you to define and execute arbitrary Python functions during training at different points in each epoch. In this case we are using it to print custom metrics after each epoch. If you just print the COCOMetrics out it is a mess of numbers. For simplicity, we will only print out the training and validation loss and mAP.
  • Visualization of detections. This will print out a 4 by 8 grid of images along with predicted bounding boxes after every five epochs. This will give us insight into how good (or terrible) our model is. If all goes well these visualizations should get better as training progresses.
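
The interplay between early stopping and checkpointing can be traced by hand on a list of validation losses. The sketch below is a simplified stand-in for the two Keras callbacks, not their actual code; the function name and the loss values are hypothetical.

```python
def early_stopping_epoch(val_losses, patience=6):
    """Return (stop_epoch, best_epoch), both 1-indexed, for per-epoch validation losses.

    stop_epoch is None if training would run through every epoch in the list.
    best_epoch is the epoch whose weights a save-best-only checkpoint would keep.
    """
    best_loss = float("inf")
    best_epoch = None
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best_loss:
            best_loss, best_epoch = loss, epoch
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch, best_epoch
    return None, best_epoch


# Hypothetical validation losses: the best value is reached at epoch 3
losses = [0.9, 0.7, 0.6, 0.61, 0.62, 0.65, 0.63, 0.64, 0.66, 0.7]
print(early_stopping_epoch(losses, patience=6))
```

In this made-up run, training stops at epoch 9 after six epochs without improvement, while the saved weights are those from epoch 3.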

Finally we create our model. Recall that the backbone is a YOLOv8 model. We must pass the num_classes we will be using, as well as the bounding_box_format.

# Building a RetinaNet model with a backbone trained on the COCO dataset
def create_model():
    model = keras_cv.models.RetinaNet.from_preset(
        "yolo_v8_m_backbone_coco",
        num_classes=len(class_mapping),
        bounding_box_format=BBOX_FORMAT
    )
    return model

model = create_model()

We also have to customize the non-max suppression parameter of our model. Non-max suppression is used in object detection to filter out multiple overlapping predicted bounding boxes that correspond to the same object. It only keeps the box with the highest confidence score and removes redundant boxes, ensuring that each object is detected only once. It incorporates two parameters: the iou_threshold and the confidence_threshold:

  1. IoU, or intersection over union, is a number between 0 and 1 that measures how much overlap there is between one predicted box and another predicted box. If the overlap is higher than the iou_threshold then the predicted box with the lower confidence score is thrown away.
  2. The confidence score reflects the model’s confidence in its predicted bounding box. If the confidence score for a predicted box is lower than the confidence_threshold the box is thrown away.

Although these parameters do not affect training, they still need to be tuned to your particular application for prediction purposes. Setting iou_threshold=0.5 and confidence_threshold=0.5 is a good starting place.
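
To show how the two thresholds interact, here is a minimal greedy non-max suppression sketch in plain Python for a single class. It is not the code behind MultiClassNonMaxSuppression, which additionally handles multiple classes and batches; the function names, boxes, and scores are invented for this example.

```python
def box_iou(a, b):
    """Intersection over union of two absolute (xmin, ymin, xmax, ymax) boxes."""
    inter_w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_w * inter_h
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def non_max_suppression(boxes, scores, iou_threshold=0.5, confidence_threshold=0.5):
    """Return the indices of the boxes to keep, highest confidence first."""
    # Step 1: drop boxes the model is not confident about
    candidates = [i for i, score in enumerate(scores) if score >= confidence_threshold]
    candidates.sort(key=lambda i: scores[i], reverse=True)
    # Step 2: keep a box only if it does not overlap too much with a better box
    keep = []
    for i in candidates:
        if all(box_iou(boxes[i], boxes[j]) <= iou_threshold for j in keep):
            keep.append(i)
    return keep


boxes = [(0, 0, 100, 100), (10, 10, 110, 110), (200, 200, 300, 300), (0, 0, 50, 50)]
scores = [0.9, 0.8, 0.7, 0.3]
print(non_max_suppression(boxes, scores))
```

The second box overlaps the first with an IoU of about 0.68 and is suppressed, and the last box is dropped for low confidence, leaving boxes 0 and 2.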

One note before beginning training: we discussed why it is helpful for the classification loss to be the focal loss, but we have not discussed a suitable regression loss to define the error on our predicted bounding box coordinates. A popular regression loss (or box_loss) is the smooth L1 loss. I think of smooth L1 as a “best of both worlds” loss. It incorporates both the L1 loss (absolute error) and the L2 loss (squared error). The loss is quadratic for small error values, and linear for large error values (see this link). KerasCV has a built-in smooth L1 loss for our convenience. The loss that will be displayed during training will be the sum of box_loss and classification_loss.
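
The shape of the smooth L1 loss is easiest to see for a single coordinate error. The sketch below uses the common textbook definition with a transition point beta; the parameter name and its default of 1.0 are assumptions for this example and may differ from the KerasCV built-in.

```python
def smooth_l1(error, beta=1.0):
    """Quadratic for |error| < beta, linear beyond it, continuous at the transition."""
    abs_error = abs(error)
    if abs_error < beta:
        return 0.5 * abs_error ** 2 / beta
    return abs_error - 0.5 * beta


for err in (0.1, 0.5, 1.0, 3.0):
    print(err, smooth_l1(err))
```

Small errors are penalized gently, like the squared error, while large errors grow only linearly, which keeps a few badly predicted boxes from dominating the gradient.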

# Using focal classification loss and smoothl1 box loss with coco metrics
model.compile(
    classification_loss="focal",
    box_loss="smoothl1",
    optimizer=optimizer_Adam,
    metrics=[coco_metrics]
)

history = model.fit(
    train_dataset,
    validation_data=validation_dataset,
    epochs=40,
    callbacks=callbacks_list,
    verbose=0,
)

Training on an NVIDIA Tesla P100 GPU takes about one hour and 12 minutes.

Making Predictions

Now we can load the best model seen during training and use it to make some predictions on the validation set:

# Create model with the weights of the best model
model = create_model()
model.load_weights(checkpoint_path)

# Customizing non-max suppression of model prediction. I found these numbers to work fairly well
model.prediction_decoder = keras_cv.layers.MultiClassNonMaxSuppression(
    bounding_box_format=BBOX_FORMAT,
    from_logits=True,
    iou_threshold=0.2,
    confidence_threshold=0.6,
)

# Visualize on validation set
visualize_detections(model, dataset=val_dataset, bounding_box_format=BBOX_FORMAT, rows=NUM_ROWS, cols=NUM_COLS)

Sample visual of validation set predictions. Image by author.

The metrics on our best model are:

  • Loss: 0.4185
  • mAP: 0.2182
  • Validation Loss: 0.4584
  • Validation mAP: 0.2916

Respectable, but this can be improved. More on this in the conclusion. (Note: I noticed that MultiClassNonMaxSuppression does not seem to be working correctly. The bottom left image shown above clearly has boxes that overlap with more than 20% of their area, yet the lower confidence box is not suppressed. This is something I will have to look more into.)

Here is a plot of our training and validation losses per epoch. Some overfitting is seen. Also, it may be wise to add in a learning rate schedule to decrease the learning rate over time. This may help resolve the issue of large jumps being made near the end of training.

A plot of our training and validation losses per epoch. We are seeing signs of overfitting. Image by author.
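One way to try the learning rate schedule idea is a cosine decay over the 40 training epochs. The sketch below is plain Python; the function name and the initial and final learning rates are placeholder assumptions, not the values used in the notebook.

```python
import math


def cosine_decay(epoch, total_epochs=40, initial_lr=1e-3, final_lr=1e-5):
    """Learning rate for a given epoch, decaying from initial_lr to final_lr."""
    # Fraction of training completed, capped at 1.0
    progress = min(epoch, total_epochs) / total_epochs
    # Cosine factor goes smoothly from 1.0 (start) to 0.0 (end)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return final_lr + (initial_lr - final_lr) * cosine


for epoch in (0, 10, 20, 30, 40):
    print(epoch, cosine_decay(epoch))

# With Keras, a function like this could be wrapped in a callback, e.g.
# keras.callbacks.LearningRateScheduler(lambda epoch, lr: cosine_decay(epoch)),
# and appended to callbacks_list before calling model.fit.
```

The smaller steps late in training are what might smooth out the large jumps seen in the loss curve, though that would need to be confirmed by retraining.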

Conclusion

If you have made it this far, give yourself a pat on the back! Object detection is among the more difficult tasks in computer vision. Luckily for us, we have the new KerasCV library to make our lives easier. To summarize the workflow for creating an object detection pipeline:

  • Begin by visualizing your dataset. Ask yourself questions like: “What is my bounding box format? Is it xyxy? Relxyxy? How many classes am I dealing with?” Make sure to create a function similar to visualize_dataset to look at your images and bounding boxes.
  • Convert whatever format of data you have into the “dictionary within a dictionary” format that KerasCV wants. Using a TensorFlow Dataset object to hold the data is especially helpful.
  • Do some basic pre-processing, such as image re-sizing and data augmentation. KerasCV makes this fairly simple. Take care to read the literature on your model of choice to make sure the image sizes are appropriate.
  • Convert the dictionaries back into tuples for training.
  • Select an optimizer (Adam is an easy choice), two loss functions (focal for the class loss and smooth L1 for the box loss are easy choices), and metrics (COCO metrics are an easy choice).
  • Visualize your detections during training to see what sorts of objects your model is missing.

Example of a problematic label in the dataset. Image by author.

One of the primary next steps would be to clean up the dataset. For example, take a look at the image above. The labelers correctly identified the potato leaf late blight, but what about all of the other healthy potato leaves? Why were these not labeled as potato leaf? Looking at the health check tab on the Roboflow website, you can see that some classes are vastly underrepresented in the dataset:

Chart showing the class imbalance. From Roboflow’s website.
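If you would rather check the imbalance yourself than rely on the Roboflow chart, counting labels is straightforward. The sketch below assumes the annotations have already been loaded as one list of class names per image. The helper names, the `min_fraction` cutoff, and the toy labels are all hypothetical.

```python
from collections import Counter


def class_counts(annotations):
    """Count how many boxes each class has across all images."""
    counts = Counter()
    for labels in annotations:
        counts.update(labels)
    return counts


def underrepresented(counts, min_fraction=0.1):
    """Classes whose share of all boxes falls below min_fraction."""
    total = sum(counts.values())
    return sorted(name for name, n in counts.items() if n / total < min_fraction)


# Toy labels, one inner list per image (made up, not the real dataset)
annotations = [
    ["class_a", "class_a", "class_a"],
    ["class_a", "class_b"],
    ["class_c", "class_c", "class_a", "class_a", "class_a", "class_a"],
]
counts = class_counts(annotations)
print(counts.most_common())
print(underrepresented(counts))  # ['class_b']
```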

Try fixing these issues before tweaking any hyperparameters. Best of luck on your object detection tasks!

References

[1] F. Chollet, Deep Learning with Python (2021), Manning Publications Co.

[2] A. Géron, Hands-On Machine Learning with Scikit-Learn, Keras, and TensorFlow (2022), O’Reilly Media Inc.

[3] A. Ng, Deep Learning Specialization, DeepLearning.AI

[4] D. Singh, N. Jain, P. Jain, P. Kayal, S. Kumawat, N. Batra, PlantDoc: A Dataset for Visual Plant Disease Detection (2019), CoDS COMAD 2020

[5] T. Lin, P. Dollár, R. Girshick, K. He, B. Hariharan, S. Belongie, Feature Pyramid Networks for Object Detection (2017), CVPR 2017

[6] T. Lin, P. Goyal, R. Girshick, K. He, P. Dollár, Focal Loss for Dense Object Detection (2020), IEEE Transactions on Pattern Analysis and Machine Intelligence

Machine learning engineer. Focus on computer vision and natural language processing. A former college physics educator.