Building an Image Classification API with TensorFlow and FastAPI

Learn to build an image classification API with TensorFlow and FastAPI from scratch.

Aniket Maurya
Towards Data Science


Source: aniketmaurya

FastAPI is a high-performance asynchronous framework for building APIs in Python.

A video tutorial is also available for this blog.

The source code for this blog is available at aniketmaurya/tensorflow-fastapi-starter-pack.

Let's start with a simple hello-world example

First, we import the FastAPI class and create an object, app. The class accepts useful parameters; for example, we can pass the title and description shown in the Swagger UI.

from fastapi import FastAPI
app = FastAPI(title='Hello world')

We define a function and decorate it with @app.get, which means our /index endpoint supports the GET method. The function defined here is async. FastAPI handles both async and regular functions: it runs normal def functions in a thread pool and runs async def functions on the event loop.

@app.get('/index')
async def hello_world():
    return "hello world"
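To make the thread pool versus event loop distinction concrete, here is a standard-library analogy rather than FastAPI's actual internals (FastAPI delegates this to helpers built on Starlette and AnyIO). A blocking function is handed to the event loop's default thread pool with run_in_executor, while a coroutine runs on the loop's own thread. The function names are hypothetical.

```python
import asyncio
import threading
import time


def blocking_work() -> str:
    # Plain `def` with a blocking call: analogous to a `def` endpoint
    time.sleep(0.05)
    return threading.current_thread().name


async def async_work() -> str:
    # `async def` that awaits: analogous to an `async def` endpoint
    await asyncio.sleep(0.05)
    return threading.current_thread().name


async def main() -> list:
    loop = asyncio.get_running_loop()
    # run_in_executor(None, ...) submits to the loop's default ThreadPoolExecutor
    return await asyncio.gather(
        loop.run_in_executor(None, blocking_work),
        async_work(),
    )


sync_thread, async_thread = asyncio.run(main())
print("def ran on:", sync_thread)         # a worker thread from the pool
print("async def ran on:", async_thread)  # the thread running the event loop
```

The same reasoning applies to the endpoints later in this tutorial: a blocking call such as model.predict inside an async def function runs on the event loop thread and holds it until it returns, so a plain def endpoint (run in the thread pool) is an alternative worth considering for heavy synchronous work.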

Image recognition API

We will create an API endpoint to classify images and name it /predict/image. We will use TensorFlow to create the image classification model.

Tutorial for Image Classification with Tensorflow

We create a function, load_model, that returns a MobileNetV2 CNN model with pre-trained ImageNet weights, i.e., it is already trained to classify images into 1,000 categories.

import tensorflow as tf


def load_model():
    # MobileNetV2 with ImageNet weights (Keras downloads and caches them on first use)
    model = tf.keras.applications.MobileNetV2(weights="imagenet")
    print("Model loaded")
    return model


model = load_model()
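Calling load_model() at module level means the weights are loaded once when the server starts, and every request reuses the same model object. If you would rather defer loading until the first request, one option is to cache the loader with functools.lru_cache. The sketch below assumes that approach; get_model and LOAD_CALLS are hypothetical names, and a plain object() stands in for MobileNetV2 so it runs without TensorFlow.

```python
from functools import lru_cache

LOAD_CALLS = 0  # counts how often the (stand-in) expensive loader runs


@lru_cache(maxsize=1)
def get_model():
    """Build the model on the first call and return the cached instance afterwards.

    In the tutorial this body would call tf.keras.applications.MobileNetV2(weights="imagenet");
    object() stands in here so the sketch runs on its own.
    """
    global LOAD_CALLS
    LOAD_CALLS += 1
    return object()


first = get_model()
second = get_model()
print(first is second, LOAD_CALLS)  # True 1
```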

We define a predict function that accepts an image and returns the predictions. We resize the image to 224x224 and scale the pixel values to the range [-1, 1].
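The scaling step is x / 127.5 - 1, which maps 0 to -1 and 255 to 1; this is the same range that tf.keras.applications.mobilenet_v2.preprocess_input produces. Here is a NumPy-only sketch on a synthetic image; to_model_input is a hypothetical helper, and the resize is skipped because the array is already 224x224.

```python
import numpy as np


def to_model_input(rgb: np.ndarray) -> np.ndarray:
    """Turn a (224, 224, 3) uint8 RGB array into a (1, 224, 224, 3) float batch in [-1, 1]."""
    batch = np.expand_dims(rgb.astype(np.float32), 0)  # add the batch dimension
    return batch / 127.5 - 1.0  # 0 -> -1.0, 255 -> 1.0


# Synthetic image: left half black (0), right half white (255)
fake = np.zeros((224, 224, 3), dtype=np.uint8)
fake[:, 112:] = 255

x = to_model_input(fake)
print(x.shape, x.min(), x.max())  # (1, 224, 224, 3) -1.0 1.0
```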

from tensorflow.keras.applications.imagenet_utils import decode_predictions

decode_predictions maps the model's output probabilities to human-readable ImageNet class names. Here we return the two most probable classes.

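import numpy as np
from PIL import Image

def predict(image: Image.Image):
    # Convert to RGB (drops alpha, handles grayscale/palette images) and resize to 224x224
    image = np.asarray(image.convert("RGB").resize((224, 224)))
    image = np.expand_dims(image, 0)  # add batch dimension: (1, 224, 224, 3)
    image = image / 127.5 - 1.0  # scale pixels from [0, 255] to [-1, 1]
    # Top-2 (class_id, class_name, score) tuples for the single image in the batch
    result = decode_predictions(model.predict(image), 2)[0]
    response = []
    for res in result:
        resp = {}
        resp["class"] = res[1]
        resp["confidence"] = f"{res[2]*100:0.2f} %"
        response.append(resp)
    return response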
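The loop at the end of predict reshapes decode_predictions output into JSON-friendly dictionaries. decode_predictions returns one list per image in the batch, each holding (class_id, class_name, score) tuples. The stdlib sketch below replays that step on made-up values (the IDs and scores are illustrative only) with a hypothetical helper, format_top_k, written as a list comprehension.

```python
from typing import Dict, List, Tuple

# Shape of decode_predictions(preds, 2): one inner list per image.
# These IDs, names, and scores are illustrative, not real model output.
fake_decoded: List[List[Tuple[str, str, float]]] = [
    [("n02123045", "tabby", 0.8123), ("n02124075", "Egyptian_cat", 0.0921)]
]


def format_top_k(result: List[Tuple[str, str, float]]) -> List[Dict[str, str]]:
    """Build the same list of {"class", "confidence"} dicts that predict() returns."""
    return [
        {"class": name, "confidence": f"{score * 100:0.2f} %"}
        for _, name, score in result
    ]


out = format_top_k(fake_decoded[0])
print(out)
# [{'class': 'tabby', 'confidence': '81.23 %'}, {'class': 'Egyptian_cat', 'confidence': '9.21 %'}]
```

Because FastAPI serializes the returned list to JSON, clients receive an array of objects with class and confidence keys.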

Now we create the /predict/image endpoint, which accepts a file upload. We check the file extension so that only jpg, jpeg, and png images are accepted.
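One caveat with checking file.filename.split(".")[-1]: the comparison is case-sensitive, so photo.JPG is rejected, and a filename with no dot at all, such as png, passes because the whole name is treated as the extension. A minimal stdlib sketch of a stricter check using pathlib follows; has_allowed_extension and ALLOWED_EXTENSIONS are hypothetical names.

```python
from pathlib import PurePath

ALLOWED_EXTENSIONS = {".jpg", ".jpeg", ".png"}


def has_allowed_extension(filename: str) -> bool:
    """Case-insensitive suffix check; a name without a suffix is rejected."""
    return PurePath(filename).suffix.lower() in ALLOWED_EXTENSIONS


for name in ["photo.JPG", "cat.jpeg", "notes.txt", "png"]:
    print(name, has_allowed_extension(name))
```

An extension is only a hint about the content, so the image decoding step can still fail on a mislabeled file.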

We will use Pillow to load the uploaded image.

from io import BytesIO

from fastapi import File, UploadFile
from PIL import Image


def read_imagefile(file) -> Image.Image:
    # Wrap the raw upload bytes so Pillow can decode them
    image = Image.open(BytesIO(file))
    return image


@app.post("/predict/image")
async def predict_api(file: UploadFile = File(...)):
    # Accept only jpg, jpeg, and png uploads
    is_allowed = file.filename.split(".")[-1] in ("jpg", "jpeg", "png")
    if not is_allowed:
        return "Image must be jpg or png format!"
    image = read_imagefile(await file.read())
    prediction = predict(image)
    return prediction
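To try read_imagefile without running the server, you can build an image in memory and pass its bytes, just as await file.read() would. The sketch repeats the function so it runs on its own, and it also shows what happens with bytes that are not an image: Pillow raises PIL.UnidentifiedImageError, which the endpoint as written does not catch.

```python
from io import BytesIO

from PIL import Image, UnidentifiedImageError


def read_imagefile(file: bytes) -> Image.Image:
    # Same idea as the tutorial's helper: wrap raw bytes so Pillow can read them
    return Image.open(BytesIO(file))


# Encode a small red RGBA square as PNG, standing in for an uploaded file
buffer = BytesIO()
Image.new("RGBA", (32, 32), (255, 0, 0, 255)).save(buffer, format="PNG")

img = read_imagefile(buffer.getvalue())
print(img.format, img.mode, img.size)  # PNG RGBA (32, 32)

# Non-image bytes make Image.open raise UnidentifiedImageError
rejected = False
try:
    read_imagefile(b"definitely not an image")
except UnidentifiedImageError:
    rejected = True
print("rejected:", rejected)  # rejected: True
```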

Final code

import uvicorn
from fastapi import FastAPI, File, UploadFile

from application.components import predict, read_imagefile

app = FastAPI()


@app.post("/predict/image")
async def predict_api(file: UploadFile = File(...)):
    # Accept only jpg, jpeg, and png uploads
    is_allowed = file.filename.split(".")[-1] in ("jpg", "jpeg", "png")
    if not is_allowed:
        return "Image must be jpg or png format!"
    image = read_imagefile(await file.read())
    prediction = predict(image)
    return prediction


if __name__ == "__main__":
    uvicorn.run(app, host="127.0.0.1", port=8000)

The FastAPI documentation is the best place to learn more about the framework's core concepts.

I hope you liked the article.

Feel free to ask questions in the comments or reach out to me personally:

👉 Twitter: https://twitter.com/aniketmaurya

👉 LinkedIn: https://linkedin.com/in/aniketmaurya


ML Engineer at Lightning AI (PyTorch Lightning) ⚡️; Creator of Gradsflow; https://aniketmaurya.com