Building an Image Classification API with TensorFlow and FastAPI
Learn to build an image classification API with TensorFlow and FastAPI from scratch.
FastAPI is a high-performance asynchronous framework for building APIs in Python.
A video tutorial is also available for this blog.
The source code for this blog is available at aniketmaurya/tensorflow-fastapi-starter-pack.
Let's start with a simple hello-world example
First, we import the FastAPI class and create an object, app. The class accepts useful parameters; for example, we can pass the title and description shown in the Swagger UI.
from fastapi import FastAPI
app = FastAPI(title='Hello world')
We define a function and decorate it with @app.get. This means that our endpoint /index supports the GET method. The function defined here is async, but FastAPI handles both kinds of functions: regular def functions are run in a thread pool, while async def functions run directly on the event loop.
@app.get('/index')
async def hello_world():
    return "hello world"
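To illustrate the sync/async handling described above, here is a simplified, stdlib-only model of the idea: coroutine functions are awaited on the event loop, while plain functions are handed to a thread pool so they cannot block it. This is a sketch, not FastAPI's actual code (internally FastAPI delegates to Starlette's thread-pool helper), and call_endpoint, sync_endpoint, and async_endpoint are hypothetical names.

```python
import asyncio
import inspect
import threading


async def call_endpoint(func, *args):
    """Hypothetical dispatcher mimicking how an async framework runs handlers."""
    if inspect.iscoroutinefunction(func):
        # async def handlers run directly on the event loop's thread
        return await func(*args)
    # plain def handlers are sent to the loop's default thread pool executor
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, func, *args)


def sync_endpoint():
    # Report the identity of the thread that executed this handler
    return threading.get_ident()


async def async_endpoint():
    return threading.get_ident()


async def main():
    loop_thread = threading.get_ident()
    sync_thread = await call_endpoint(sync_endpoint)
    async_thread = await call_endpoint(async_endpoint)
    # The async handler shares the loop's thread; the sync one does not
    return async_thread == loop_thread, sync_thread != loop_thread


result = asyncio.run(main())
print(result)  # (True, True)
```

The practical consequence is that a blocking call inside an async def handler stalls every other request on the loop, whereas the same call inside a plain def handler only occupies one worker thread.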
Image recognition API
We will create an API to classify images and name it /predict/image. We will use TensorFlow to create the image classification model.
Tutorial for Image Classification with TensorFlow
We create a function, load_model, which returns a MobileNetV2 CNN model with pre-trained weights, i.e. it is already trained to classify images into 1,000 ImageNet categories.
import tensorflow as tf


def load_model():
    model = tf.keras.applications.MobileNetV2(weights="imagenet")
    print("Model loaded")
    return model


model = load_model()
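Calling load_model() once at module level means the model is built a single time rather than on every request. If you prefer to defer loading until the first request, one alternative sketch wraps the loader in functools.lru_cache so the first call builds the model and later calls reuse it. To keep the example runnable without TensorFlow, _build_model is a hypothetical stand-in that only counts how often it runs, and get_model is a hypothetical name.

```python
from functools import lru_cache

build_calls = 0


def _build_model():
    # Hypothetical stand-in for tf.keras.applications.MobileNetV2(weights="imagenet")
    global build_calls
    build_calls += 1
    return {"name": "mobilenet_v2"}


@lru_cache(maxsize=1)
def get_model():
    # First call builds the model; subsequent calls return the cached object
    return _build_model()


first = get_model()
second = get_model()
print(first is second, build_calls)  # True 1
```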
We define a predict function that accepts an image and returns the predictions. We resize the image to 224x224 and normalize the pixel values to the range [-1, 1].
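The scaling x / 127.5 - 1.0 maps the 8-bit range [0, 255] linearly onto [-1, 1]: 0 becomes -1, 127.5 becomes 0, and 255 becomes 1. A quick NumPy check on a few sample pixel values (chosen only for illustration):

```python
import numpy as np

# Sample 8-bit pixel intensities (illustrative values)
pixels = np.array([0, 51, 127.5, 204, 255], dtype=np.float32)

# Same scaling used in predict(): [0, 255] -> [-1, 1]
scaled = pixels / 127.5 - 1.0
print(scaled)
```

For reference, tf.keras.applications.mobilenet_v2.preprocess_input applies this same scaling, so either form can be used before calling the model.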
from tensorflow.keras.applications.imagenet_utils import decode_predictions
decode_predictions is used to decode the class name of the predicted object. Here we return the top-2 most probable classes.
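import numpy as np
from PIL import Image


def predict(image: Image.Image):
    # Convert to RGB (drops alpha, expands grayscale) and resize to 224x224
    image = np.asarray(image.convert("RGB").resize((224, 224)))
    # Add a batch dimension: (224, 224, 3) -> (1, 224, 224, 3)
    image = np.expand_dims(image, 0)
    # Scale pixel values from [0, 255] to [-1, 1]
    image = image / 127.5 - 1.0

    # Keep the top-2 (class_id, class_name, score) tuples for our single image
    result = decode_predictions(model.predict(image), 2)[0]

    response = []
    for res in result:
        resp = {}
        resp["class"] = res[1]
        resp["confidence"] = f"{res[2]*100:0.2f} %"
        response.append(resp)
    return response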
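decode_predictions maps MobileNetV2's 1,000 output probabilities to ImageNet labels and keeps the top-k entries as (class_id, class_name, score) tuples. The NumPy sketch below reproduces that top-k step together with the response formatting from predict(), on a toy 4-class output. The labels and scores are made up, and top_k_predictions is a hypothetical helper, not part of Keras.

```python
import numpy as np


def top_k_predictions(probs, labels, k=2):
    """Hypothetical helper: pick the k most probable classes and format them."""
    # argsort is ascending, so reverse it and keep the first k indices
    top = np.argsort(probs)[::-1][:k]
    return [
        {"class": labels[i], "confidence": f"{probs[i] * 100:0.2f} %"}
        for i in top
    ]


# Toy 4-class output; real MobileNetV2 predictions have 1,000 scores
labels = ["tabby", "beagle", "sports_car", "banana"]
probs = np.array([0.62, 0.25, 0.08, 0.05])

top2 = top_k_predictions(probs, labels)
print(top2)
# [{'class': 'tabby', 'confidence': '62.00 %'}, {'class': 'beagle', 'confidence': '25.00 %'}]
```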
Now we will create an endpoint, /predict/image, that supports file upload. We filter on the file extension so that only jpg, jpeg, and png images are accepted.
We use Pillow to load the uploaded image.
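from io import BytesIO

from fastapi import File, UploadFile
from PIL import Image


def read_imagefile(file) -> Image.Image:
    # Decode the raw uploaded bytes into a PIL image
    image = Image.open(BytesIO(file))
    return image


@app.post("/predict/image")
async def predict_api(file: UploadFile = File(...)):
    # Accept only jpg, jpeg and png uploads (case-insensitive)
    extension = file.filename.split(".")[-1].lower() in ("jpg", "jpeg", "png")
    if not extension:
        return "Image must be jpg or png format!"
    image = read_imagefile(await file.read())
    prediction = predict(image)
    return prediction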
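The extension check above looks at the text after the last dot, so a file literally named "png" would pass. A slightly more defensive variant, sketched here as a hypothetical helper is_allowed_image, uses os.path.splitext, ignores case, and rejects names without an extension. An extension only reflects the file name; Pillow's Image.open will still raise an error if the bytes are not a recognizable image.

```python
import os

ALLOWED_EXTENSIONS = {".jpg", ".jpeg", ".png"}


def is_allowed_image(filename):
    """Hypothetical helper: True if filename ends in .jpg, .jpeg or .png (any case)."""
    if not filename:
        return False
    _, ext = os.path.splitext(filename)
    return ext.lower() in ALLOWED_EXTENSIONS


for name in ["cat.JPG", "dog.jpeg", "photo.png", "notes.txt", "png", "archive.png.zip"]:
    print(name, is_allowed_image(name))
```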
Final code
import uvicorn
from fastapi import FastAPI, File, UploadFile

from application.components import predict, read_imagefile

app = FastAPI()


@app.post("/predict/image")
async def predict_api(file: UploadFile = File(...)):
    # Accept only jpg, jpeg and png uploads (case-insensitive)
    extension = file.filename.split(".")[-1].lower() in ("jpg", "jpeg", "png")
    if not extension:
        return "Image must be jpg or png format!"
    image = read_imagefile(await file.read())
    prediction = predict(image)
    return prediction


if __name__ == "__main__":
    # Serves on http://127.0.0.1:8000 by default
    uvicorn.run(app)
The FastAPI documentation is the best place to learn more about the core concepts of the framework.
I hope you liked the article.
Feel free to ask questions in the comments or reach out to me personally:
Twitter: https://twitter.com/aniketmaurya
LinkedIn: https://linkedin.com/in/aniketmaurya