Building a Deep Learning Model API for Tagging Trout Species

A full-stack Deep Learning Research Project for Identifying Trout Species & More

Perry Johnson
Towards Data Science

--

This post describes ongoing work on the Aqua Vision Deep Learning Research Project.

I recently got into dry fly fishing in Montana. It’s a technical sport and I’m not very good (yet!) but it’s a blast getting out on the river and learning more each time I’m out there.

Rock Creek, Montana

One thing I’ve noticed when I’m out fishing for trout is that I know very little about the different trout species. When I catch a trout (which is still not very often), I find myself Googling photos and comparing them to what I caught; I can barely tell the species apart. It made me think: maybe there is a way to programmatically identify trout species with deep learning computer vision. I ran the idea past a few friends, professional fishing guides and die-hard amateurs, who thought image recognition could be helpful. Here is what they said:

“Trout have ~30 species so it’s hard to tell some of them apart.”

“This could be very interesting for all of the hybrid trout species like cut-bows” (a cut-bow is a cutthroat and rainbow trout crossbreed)

I set out to ask the following questions:

  1. Can I build an accurate deep learning computer vision model to identify trout species from images?
  2. Can I build a web-application, REST API, and mobile application as supporting infrastructure for the deep learning model?
  3. Can I build an open-source trout photo collection library to encourage further deep learning model development, image collection/labeling and trout research?

Data Collection and Machine Learning Pipeline

The full data pipeline for the Aqua Vision deep learning application and API

Data

There isn’t a comprehensive labeled dataset available for trout images (let alone other fish species), so I built my own dataset of 500+ images from scratch. I collected the images using a few different methods, outlined below.

Data Sources:

  • Google Image Search: ran a JavaScript snippet in the browser console to collect and download image files
  • Instagram: used a command-line Instagram scraper to pull photos by hashtag (e.g. #browntrout, #rainbowtrout)
  • Fishing blogs and research outlets: downloaded photos directly

After collecting the images, I manually reviewed all of the photos to ensure that each image matched its trout species label, and removed any miscellaneous images that didn’t belong (fishing equipment, a river, a tackle box, incorrectly labeled fish species, etc.). The data collection and cleaning phase took about three weeks. Now equipped with 500+ images across a handful of trout species, I had an initial dataset to start building a deep learning trout identification model.

Building the Deep Learning Model

The fastai deep learning library has some really nice built-in tools to scrape images from Google Image Search and then load the image files into an ImageDataBunch object. The ImageDataBunch can then apply data augmentation transforms, set the image pixel resolution, and split the data into training, validation, and test sets.

For data augmentation, I kept it pretty simple with the default settings, except that I randomly flip images horizontally to account for fish facing in either direction.

Data Augmentation: Horizontal Flip example for Brown Trout

Data augmentation is really nice because it effectively lets the model train on more examples: the augmented images all look fairly different, yet they require no extra manual labeling, so they’re like free extra data. I’ve found that experimenting with augmentation parameters, coupled with domain expertise about the dataset, is really the way to get the most out of it.
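As a rough sketch of this setup in fastai v1 (the folder layout and parameter values here are my own assumptions, not the project’s exact code), loading the images with a random horizontal flip might look like:

```python
from fastai.vision import *

# Hypothetical layout: one folder per species, e.g. data/brown_trout/,
# data/rainbow_trout/, data/bull_trout/
path = Path("data")

# do_flip=True randomly flips images horizontally; flip_vert=False
# keeps the fish right-side up.
tfms = get_transforms(do_flip=True, flip_vert=False)

data = ImageDataBunch.from_folder(
    path,
    valid_pct=0.3,   # hold out 30% of the images for validation
    ds_tfms=tfms,
    size=224,        # starting pixel resolution
).normalize(imagenet_stats)
```

Normalizing with imagenet_stats matches the ImageNet transfer learning described below.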

I took advantage of transfer learning by preloading the model weights from ImageNet instead of initializing them from scratch. ImageNet is a dataset of millions of labeled high-resolution images belonging to roughly 22,000 categories, collected from the internet and labeled by humans using a crowd-sourcing tool. Transfer learning is a great way to leverage a previously trained model, reduce compute cost, and then fine-tune on domain-specific data to perform a domain-specific task.

The ImageDataBunch is then loaded into a Learner object along with a neural network architecture and loss function to train a deep learning model.

learn = cnn_learner(data, models.resnet34, metrics=error_rate)

A little trick to improve model performance is progressive resizing: start training at a smaller image size (224 pixels), then take the fitted model and keep fitting it at a bigger image size (352 pixels). This helps performance in almost all cases.
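A sketch of that trick, assuming a fastai Learner `learn` already built on a 224px ImageDataBunch (the `path` and `tfms` names and the epoch counts are illustrative assumptions):

```python
# A few epochs at the smaller 224px resolution first.
learn.fit_one_cycle(4)

# Rebuild the data bunch at 352px and keep training the same
# (already partly fitted) model at the higher resolution.
data_352 = ImageDataBunch.from_folder(
    path, valid_pct=0.3, ds_tfms=tfms, size=352
).normalize(imagenet_stats)
learn.data = data_352
learn.fit_one_cycle(4)
```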

I trained the model on 70% of the images and reserved the remaining 30% for testing. I like using Paperspace to train deep learning models as they have free GPU-backed Jupyter notebooks that are really easy to set up and load in deep learning libraries.

After a few more runs through the data (epochs) at the higher resolution, the model peaked at 96% classification accuracy on the test set. This was good enough to ship as the first version of the deep learning trout identifier for the web app and API.

Deploying the Deep Learning Model as an API & Web-App

I exported the trained deep learning model for use in a production environment to make predictions. I wrapped the model in Flask, a Python web framework, to serve predictions via a web application and a REST API. The web app and API documentation supporting the Aqua Vision Deep Learning Research Project can be found here.
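A minimal sketch of what such a Flask wrapper could look like, assuming the model was exported with fastai’s learn.export() to an export.pkl file (the file names and details here are my assumptions, not the project’s actual code):

```python
from flask import Flask, jsonify, request
from fastai.vision import load_learner, open_image

app = Flask(__name__)

# Load the exported fastai model; "export.pkl" is a hypothetical path.
learn = load_learner(".", "export.pkl")

@app.route("/api/classes", methods=["GET"])
def classes():
    # The species labels the model was trained on.
    return jsonify([str(c) for c in learn.data.classes])

@app.route("/api/predict", methods=["POST"])
def predict():
    # Read the uploaded image from the form-encoded "image" field.
    img = open_image(request.files["image"])
    pred_class, _, _ = learn.predict(img)
    return jsonify({"class": str(pred_class), "success": True})

if __name__ == "__main__":
    app.run()
```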

REST API

There are currently two API endpoints available: one to get the classes (species) currently supported by the deep learning model and one to make species predictions by posting an image file. You can test the API by making the following cURL commands from your command line.

Classes (Species) currently supported:

You can see which trout species the model currently supports at the endpoint:

GET /api/classes

curl -X GET "https://aqua-vision.herokuapp.com/api/classes"

You’ll get a return like this: ["brown_trout", "bull_trout", "rainbow_trout"]

Species Predictions

The predictions endpoint is at:

POST /api/predict

curl -X POST -F image=@brown_trout.jpg "https://aqua-vision.herokuapp.com/api/predict"

The -X flag with the POST value indicates we're performing a POST request. We supply -F image=@brown_trout.jpg to indicate we're submitting form-encoded data, with the image key set to the contents of the brown_trout.jpg file. The @ before brown_trout.jpg tells cURL to load the contents of the file and send them with the request.

You’ll get a return like this: {"class":"rainbow_trout","success":true}

Accessing the API Programmatically via Python

Here is a simple Python script using the requests library to submit data to the API and then consume the returned predictions.
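A minimal sketch of such a script (the endpoint URLs match the cURL examples above; the function names are my own):

```python
import requests

API_URL = "https://aqua-vision.herokuapp.com/api"

def get_classes():
    """Fetch the list of species the model currently supports."""
    response = requests.get(API_URL + "/classes")
    response.raise_for_status()
    return response.json()

def parse_prediction(payload):
    """Pull the predicted species out of an API response payload."""
    if not payload.get("success"):
        raise ValueError("prediction failed: {!r}".format(payload))
    return payload["class"]

def predict_species(image_path):
    """POST an image file to the API and return the predicted species."""
    with open(image_path, "rb") as f:
        response = requests.post(API_URL + "/predict", files={"image": f})
    response.raise_for_status()
    return parse_prediction(response.json())

# Usage:
#   get_classes()
#   predict_species("brown_trout.jpg")
```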

Ongoing Work

You can check out the on-going progress by visiting the Project’s Website and the GitHub Issues.

The continued goals of the project are to:

  1. Build an Open-Source Trout Photo Library to continue to improve upon the deep learning model accuracy and increase the number of supporting classes for trout species identification (eventually expanding to other fish types)
  2. Label Images with Bounding Boxes to incorporate Object Detection into the Deep Learning image classification
  3. Geotag Photos with coordinates and timestamp for future Trout species research (population distributions, migratory patterns, cross-breeding, geospatial analysis, etc)
  4. Build a supporting native Mobile Application

Code

The code for this project can be found on my GitHub.

If you’d like to contribute to this project, or if you have any comments or questions, shoot me an email at perryrjohnson7@gmail.com or reach out on LinkedIn.
