Build your object detection model to be flexible and easily retrainable when introducing a new class of detectable objects.
The ever-challenging question in training and deploying object detection neural network models – how do I add a new class to my already trained model? Usually, that task becomes a titanic endeavor that involves retraining the entire model with an extended dataset. Training an object detection model can take many hours, if not days, not to mention the time required to label additional examples with the new class we want to introduce. And what if those new objects also appear in the already labeled images? We would have to go back and review them all!
There might be a better way to go about this. I would like to present the solution I came up with for our client. For obvious reasons, I'm unable to share that exact solution. Still, I'm going to explain the architecture used in it and how it performs on a different dataset – in this case, The Oxford-IIIT Pet Dataset.
We're gonna go over training an object detection model – a detector – with the TensorFlow Object Detection API, and then use that model to extract data for our classification model. After that, we're gonna train the classifier to recognize cat breeds. The later steps will focus on how this architecture helps us solve the problem presented in the first sentence – introducing another class (cat breed) to our classifier.
I won’t be including too much code in this article; for that, please have a look at the notebook on my GitHub. Feel free to run it on your side and test the architecture with a different dataset.
Ready? Let’s go!
1. Detector – TensorFlow Object Detection
Dataset
As mentioned – the dataset of choice is The Oxford-IIIT Pet Dataset, which contains "a 37 category pet dataset with roughly 200 images for each class". That’s a lot of data to train our models with.
This dataset also contains annotations – the exact location of the animal's face in each picture. We'll build our detector to detect cat faces only, so we need to filter out the cat breeds from the dataset. Additionally – let's take out six breeds to add them later, to test how well the detector performs when facing a cat breed it has never seen before. This will help us check how well the model generalized the idea of what "a cat" is.
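A minimal sketch of that filtering, assuming the dataset's standard annotations/list.txt layout (image name, class id, species, breed id, where species 1 = cat and 2 = dog); the exact code in my notebook may differ:

```python
# Filter cat breeds out of the Oxford-IIIT Pet annotations and set aside
# the six breeds we hold out for the "never seen" test later in the article.
HELD_OUT = {"Russian_Blue", "British_Shorthair", "Maine_Coon",
            "Persian", "Egyptian_Mau", "Ragdoll"}

cat_images, held_out_images = [], []
with open("annotations/list.txt") as f:
    for line in f:
        if line.startswith("#"):
            continue  # skip the header comments
        image, class_id, species, breed_id = line.split()
        if species != "1":
            continue  # keep cats only
        breed = image.rsplit("_", 1)[0]  # "Abyssinian_100" -> "Abyssinian"
        (held_out_images if breed in HELD_OUT else cat_images).append(image)

print(len(cat_images), "training cats,", len(held_out_images), "held out")
```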
Object Detection Model
For object detection, we're gonna use the TensorFlow Object Detection API (TFOD). The cat detection model is based on the EfficientDet D2 model from the TensorFlow Detection Model Zoo.
After installing all the necessary libraries and preparing the dataset, we can launch model training – this can take up to a few hours! However, this step is gonna be done only once, and the cat detector should serve its purpose no matter what cat breed appears in the picture.
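Launching the training itself is a single script call. A sketch of how it could look – all paths below are placeholders for wherever you cloned the models repo and stored your pipeline config:

```python
import subprocess

# Train the EfficientDet D2 detector with the TFOD training script.
subprocess.run([
    "python", "models/research/object_detection/model_main_tf2.py",
    "--pipeline_config_path=training/efficientdet_d2_pipeline.config",
    "--model_dir=training/",
    "--alsologtostderr",
], check=True)
```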
Saving the TFOD model looks complicated but is pretty straightforward – we run a script installed with the libraries, choose the input type (this can be an image tensor or a B64-encoded string), point to our config file, the checkpoint directory, and finally the output directory.
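The export step could look like this – again, paths are placeholders; "image_tensor" and "encoded_image_string_tensor" are the two input types mentioned above:

```python
import subprocess

# Export the latest checkpoint to a SavedModel using the TFOD exporter script.
subprocess.run([
    "python", "models/research/object_detection/exporter_main_v2.py",
    "--input_type=image_tensor",  # or "encoded_image_string_tensor" for B64 input
    "--pipeline_config_path=training/efficientdet_d2_pipeline.config",
    "--trained_checkpoint_dir=training/",
    "--output_directory=exported_model/",
], check=True)
```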
Look into the notebook for a more detailed explanation of the training process.
We have our model trained and ready, so let's test it on a picture with lots of cats!

Uh oh… This doesn’t look right. Why did the model detect only one cat if there are four in the picture?
The reason is quite simple – our dataset had only one annotation per image, so that’s also what our model learned to do – detect only one cat per image. Anything else – even if it is a cat – is assigned very low confidence; hence it won’t get through to our output as we set our confidence threshold at 50%.
Let’s see how the model performs on crops of the above image:

Much better! The cat detector managed to detect cats in all but one crop. We could expect this model to perform even better with more training time and/or more data, but what we have now is good – the model learned what a cat is!
But why do we even need a detector model if the locations of cat faces in our dataset's images are already known? With this model, we can create cutouts of cat faces from any image, as long as there's a cat face the model can recognize. Additionally – it's better to train the classifier on the kind of data it will later be served to classify, which is crops made from the detector output. And last but not least – the classifier can focus on analyzing cat faces only, without the need to process much larger images, which in some cases could lead to a loss of important information and, in effect, a wrong classification.
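To make the cutout idea concrete, here's a minimal sketch of running the exported detector on an image and cropping everything above the 50% confidence threshold (the exported-model path is a placeholder):

```python
import numpy as np
import tensorflow as tf
from PIL import Image

# Load the SavedModel exported earlier (placeholder path).
detect_fn = tf.saved_model.load("exported_model/saved_model")

def extract_cutouts(image_path, score_threshold=0.5):
    image = np.array(Image.open(image_path).convert("RGB"))
    input_tensor = tf.convert_to_tensor(image)[tf.newaxis, ...]
    detections = detect_fn(input_tensor)

    # TFOD returns normalized [ymin, xmin, ymax, xmax] boxes plus scores.
    boxes = detections["detection_boxes"][0].numpy()
    scores = detections["detection_scores"][0].numpy()

    h, w = image.shape[:2]
    cutouts = []
    for box, score in zip(boxes, scores):
        if score < score_threshold:
            continue  # drop low-confidence detections
        ymin, xmin, ymax, xmax = box
        cutouts.append(image[int(ymin * h):int(ymax * h),
                             int(xmin * w):int(xmax * w)])
    return cutouts
```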
2. Classifier – TensorFlow Transfer Learning

And now we can use the detector to go over the cat dataset and obtain cutouts for training our classifier. For now, we'll use only the cat breeds the detection model was trained on.
This is also a good place to evaluate how well the detector performs!
Breeds the detection network has never seen before:
- Russian Blue: 191 files
- British Shorthair: 185 files
- Maine Coon: 200 files
- Persian: 160 files
- Egyptian Mau: 180 files
- Ragdoll: 183 files
Breeds used for training:
- Abyssinian: 190 files
- Bengal: 176 files
- Birman: 196 files
- Bombay: 176 files
- Sphynx: 197 files
- Siamese: 199 files
Assuming each breed has around 200 examples, recall sits at roughly 80% or higher for every class. Keep in mind that the detector was trained on a small subset of the chosen breeds' data, as we couldn't use all images due to missing .xml annotation files. Despite that – there isn't any significant difference in recall between the breeds used for training and the rest. This means that the detector performs very well and has generalized what "a cat" is.
Having these cutouts – we can finally train our classifier using transfer learning. This way we save a lot of time – training done on Google Colab with GPU runtime enabled takes less than 5 minutes to reach 85% accuracy! Such a short training time is extremely important if we want to have our classifier regularly updated with new data and classes.
Transfer learning can be done in a couple of lines of code:
First of all – we load the Xception model, which is an image recognition model already trained on ImageNet. We omit the top classification layers of the model so that we can add classification layers suitable for our task. We freeze the loaded model so that it won't be affected by the training – it's already trained to extract vital features from images, which is the part that takes the most time and effort. With transfer learning, we don't need to do this again; we only train – in this case – 6 layers, instead of the 132 layers the Xception model is built of.
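A minimal sketch of that setup – the exact head in the notebook may differ, and the Dense/Dropout sizes here are illustrative:

```python
import tensorflow as tf

num_classes = 6  # the 6 breeds used at this stage

# Load Xception pretrained on ImageNet, without its top classification layers.
base = tf.keras.applications.Xception(
    weights="imagenet", include_top=False, input_shape=(299, 299, 3))
base.trainable = False  # freeze the pretrained feature extractor

inputs = tf.keras.Input(shape=(299, 299, 3))
x = tf.keras.applications.xception.preprocess_input(inputs)  # scale to [-1, 1]
x = base(x, training=False)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
x = tf.keras.layers.Dense(256, activation="relu")(x)
x = tf.keras.layers.Dropout(0.3)(x)
outputs = tf.keras.layers.Dense(num_classes, activation="softmax")(x)

model = tf.keras.Model(inputs, outputs)
model.compile(optimizer="adam",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])
```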


The final accuracy was 82.19% on the training set and 85.02% on the validation set. The training lasted 20 epochs – increasing that number could further improve the model's accuracy. Each epoch took about 6 seconds, plus a few seconds for the validation phase.
3. Adding a New Class – "undefined/unknown"
A bit of theory
Our classifier was trained on 6 cat breeds only. Now, what if we show it a face cutout of a cat breed it wasn't trained to recognize? It will have to classify it as one of those 6 classes, and no matter which class it picks – the classification will be wrong.
Additionally – what if our detector outputs cutouts of something that’s not even a cat? Again – our classifier will still classify it as one of those 6 breeds.
That's why we introduce a negative class – an umbrella class that gathers everything the detector outputs that doesn't belong to any of the classes we want our classifier to learn.
In the cat example, the detector doesn’t output anything that we could classify into that "undefined/unknown" group, but we can simulate it by running the detector on different images of cats taken from the internet.
To better understand this issue, let's consider a different application of this architecture – car brand logo detection. In this case, we could expect our detector to output cutouts that are not car logos. Depending on the confidence threshold set in our algorithm, we can control how many wrong cutouts the detector outputs. Keep in mind, however, that with a higher confidence threshold you're also sacrificing correct detections that don't reach that threshold. What we want to do is maximize recall without sacrificing precision too much. What do those two terms mean?
Let’s say you have a picture with 10 cars; each car has one car brand logo on its front bumper. Your model detects 12 car brand logos, but only 8 of them are car logos; the remaining 4 detections are a weird reflection in the windshield, a blurry sticker on the fence in the background, a circular oil stain on the pavement, and a letter on one of the car’s plates.
Recall – if your model detects 8 brand logos out of the 10 total brand logos in the picture, its recall is 80%.
Precision – how many of the detections are actually car logos; in this case, 8/12, or about 67%. In this example, with only one label to assign, precision also serves as the accuracy.
Now, we could lower the confidence threshold to also catch those two undetected logos, but with that, precision would drop. Alongside the two new correct detections, a number of wrong ones would also make it into the output, so we could end up with, say, 20 cutouts – 10 car brand logos and 10 of anything else: 100% recall, 50% precision.
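Spelled out in code, the arithmetic from this example:

```python
def precision_recall(tp, fp, fn):
    # tp = true positives, fp = false positives, fn = false negatives
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return precision, recall

print(precision_recall(tp=8, fp=4, fn=2))    # 50% threshold: (0.667, 0.8)
print(precision_recall(tp=10, fp=10, fn=0))  # lowered threshold: (0.5, 1.0)
```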
A threshold of 50% is a good middle ground, but you can increase or decrease it depending on your use-case and the performance of your detector, as well as the performance of the classifier.
Adding a new class
To add a new class, you need to run the detection process on a new set of images – in this case, pictures of cats we downloaded from Bing. If you want to introduce a new cat breed from the dataset, just run the detection process on those images.
In situations where your detector outputs not only cutouts of the objects you want to classify but also erroneous ones – as explained in the car logo detection example – you need to go over the whole set of cutouts first and assign them to the proper folders.
Back to cats!
After running the detector over the new cat pictures, we end up with 150 new cutouts. That gives us 7 classes – the 6 cat breeds we extracted before, plus one "unknown" class containing pictures of cats that don't belong to any of those 6 breeds.
Now we can retrain the classifier so that it not only classifies the 6 breeds but also labels other cats as belonging to none of them.
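A sketch of what that retraining could look like, assuming the cutouts sit in one sub-folder per class; the folder names and the build_model helper are hypothetical placeholders for whatever your notebook uses:

```python
import tensorflow as tf

# Labels are inferred from the sub-folder names, e.g.
# cutouts/Abyssinian, ..., cutouts/unknown.
train_ds = tf.keras.utils.image_dataset_from_directory(
    "cutouts/", validation_split=0.2, subset="training",
    seed=42, image_size=(299, 299), batch_size=32)
val_ds = tf.keras.utils.image_dataset_from_directory(
    "cutouts/", validation_split=0.2, subset="validation",
    seed=42, image_size=(299, 299), batch_size=32)

num_classes = len(train_ds.class_names)  # 6 breeds + "unknown" = 7

# Rebuild the transfer-learning model from before, now with 7 outputs,
# and retrain; with the frozen base this takes only a few minutes.
model = build_model(num_classes)  # hypothetical helper wrapping the Xception setup above
model.fit(train_ds, validation_data=val_ds, epochs=20)
```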


After 20 epochs, accuracy was 87.50% on the training set and 76.19% on the validation set. The validation accuracy is lower than in the previous 6-class training. Still, the unknown class might contain breeds already defined and used for training, which would confuse the model and decrease its accuracy.
4. Final training on all classes
Final training – 12 classes + unknown


After 40 epochs – 20 epochs longer than the previous training – the model reached 78.59% accuracy on the training set and 74.45% on the validation set.
Out of curiosity – let's see what the training process looks like without the "unknown" class:


The final accuracy is 86.25% on the training set and 79.74% on the validation set. Higher than before! The unknown class decreases the accuracy of the model because – as mentioned before – it might contain cats that belong to breeds already defined in the dataset. Additionally – one more class means a bigger dataset, which requires more training time and maybe even changes to the model architecture.
5. Key takeaways
Ok, now that we went through the process and saw it work pretty well – let’s sum this all up!
- The architecture consists of two neural networks – a detector and a classifier. The detector is an object detection neural network. This one we train – hopefully – only once. We train it to recognize a single class that captures the general features of whatever we want to classify – a cat, a mobile app, a car brand logo.
- The detector is then used to extract cutouts of detected objects. Depending on the confidence threshold and the quality of the model, the number of false detections may be higher or lower. The idea is to catch all correct objects and as few wrong ones as possible – high recall, fairly good precision.
- Having all the cutouts, it's time for some manual work – every extracted cutout now has to be put in the appropriate folder: one folder per class to be recognized by the classifier, plus one folder for everything that doesn't belong to any of those classes.
- A classifier is then trained on that data. It's worth building the architecture with the use of transfer learning. This reduces the necessary training time, making continuous learning or adding a new class a relatively smooth and rapid process.
- To introduce a new class, the process has to be repeated for the new data – extraction of cutouts, manual categorization into the appropriate classes, and then re-training the classifier to recognize n+m classes, where n is the previous number of classes and m is the number of newly introduced classes.
- Finally, our detection process looks like this: the image goes to the detector, which outputs cutouts. The cutouts are then fed to the classifier to obtain the final classification of each cutout and thus – all the detected objects in the analyzed image.
I hope you found this article insightful and helpful in solving your Machine Learning challenges. Please feel free to ask questions regarding the method or the code – be it down below or on GitHub.
Tschüss!