Full working example to serve your model using asynchronous Celery tasks and FastAPI.

Overview
There is an abundance of material online related to building and training all kinds of machine learning models. However, once a high-performance model has been trained, there is significantly less material on how to put it into production.
This post walks through a working example for serving an ML model using Celery and FastAPI. All code can be found in the repository here.
We won’t specifically discuss the ML model used for this example; however, it was trained on example bank customer churn data (https://www.kaggle.com/sakshigoyal7/credit-card-customers). There is a notebook in the repository outlining the training of a LightGBM model, including hyperparameter optimization and performance evaluation.
Potential Options
Below is a summary of potential approaches for deploying your trained models to production:
- Load model directly in application: this option involves having the pretrained model directly in the main application code. For small models this might be feasible; however, large models may introduce memory issues. This option also introduces a direct dependency on the model within the main application (tight coupling).
- Offline batch prediction: use cases that do not require near real-time predictions can make use of this option. The model can be used to make predictions for a batch of data in a process that runs at defined intervals (e.g. overnight). The predictions can then be utilized by the application once the batch job is complete. Resources for prediction are only required when the batch process runs, which can be beneficial.
- API: the third option is to deploy the model as its own microservice and communicate with it via an API. This decouples the application from the model and allows the model to be utilized by multiple other services. The ML service can serve requests in one of the two ways described below.
Synchronous: the client requests a prediction and must wait for the model service to return a prediction. This is suitable for small models that require a small number of computations, or where the client cannot continue other processing steps without a prediction.
Asynchronous: instead of directly returning a prediction the model service will return a unique identifier for a task. Whilst the prediction task is being completed by the model service the client is free to continue other processing. The result can then be fetched via a results endpoint using the unique task id.
Process Flow
The steps below describe the actions taken to handle a prediction request:

- Client sends a POST request to the FastAPI prediction endpoint, with the relevant feature information contained in the request body (JSON).
- The request body is validated by FastAPI against a defined model (i.e. checks if the expected features have been provided). If the validation is successful then a Celery prediction task is created and passed to the configured broker (e.g. RabbitMQ).
- The unique id is returned to the client if a task is created successfully.
- The prediction task is delivered to an available worker by the broker. Once delivered the worker generates a prediction using the pretrained ML model.
- Once a prediction has been generated the result is stored using the Celery backend (e.g. Redis).
- At any point after step 3 the client can begin to poll the FastAPI results endpoint using the unique task id. Once the prediction is ready it will be returned to the client.
Now let’s look at some example code that implements this architecture.
Project Structure
The project structure is as follows:
serving_ml
│ app.py
│ models.py
│ README.md
│ requirements.txt
│ test_client.py
│
├───celery_task_app
│ │ tasks.py
│ │ worker.py
│ │ __init__.py
│ │
│ ├───ml
│ │ │ model.py
│ │ │ __init__.py
- app.py: FastAPI application including route definitions.
- models.py: Pydantic model definitions that are used for the API validation and response structure.
- test_client.py: Script used for testing the set-up. We’ll cover this in more detail later.
- celery_task_app/tasks.py: Contains the Celery task definition, specifically the prediction task in our case.
- celery_task_app/worker.py: Defines the Celery app instance and associated config.
- celery_task_app/ml/model.py: Machine learning model wrapper class used to load the pretrained model and serve predictions.
ML Model
First let’s look at how we are going to load the pretrained model and calculate predictions. The code below defines a wrapper class for a pretrained model that loads from file on creation and calculates class probability or membership in its predict method.
This implementation can be re-used for a variety of ML models as long as the model has predict and predict_proba methods (i.e. Scikit-Learn or Keras implementations).
In our example the saved model is in fact a Scikit-Learn pipeline object that contains a preprocessing step, so we don’t need to worry about having additional preprocessing code before prediction. We can simply create a DataFrame with the feature data and call the pipeline’s predict method.
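Below is a sketch of what such a wrapper in model.py could look like; the file path, class name (ChurnModel) and return options are placeholders rather than the repository’s exact code:
# celery_task_app/ml/model.py -- illustrative sketch; path, class name and defaults are assumptions
import joblib
import pandas as pd

class ChurnModel:
    """Wrapper around a pretrained Scikit-Learn pipeline loaded from disk."""

    def __init__(self, model_path: str = 'churn_model.joblib'):
        # Load the full pipeline (preprocessing + LightGBM classifier) once, on creation
        self.model = joblib.load(model_path)

    def predict(self, data: dict, return_option: str = 'Prob') -> float:
        """Return the class-1 probability ('Prob') or the class label for one sample."""
        df = pd.DataFrame([data])  # single-row DataFrame built from the feature dict
        if return_option == 'Prob':
            return float(self.model.predict_proba(df)[:, 1][0])
        return float(self.model.predict(df)[0])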
Celery
Celery is a simple task queue implementation that can be used to distribute tasks across threads and/or machines. The implementation requires a broker and, optionally, a backend:
- Broker: This is used to deliver messages between clients and workers. To initiate a task the client adds a message to the queue, the broker then delivers that message to a worker. RabbitMQ is often used as the broker and is the default used by Celery.
- Backend: This is optional and its only function is to store task results to be retrieved at a later date. Redis is commonly used as the backend.
First let’s look at how we define our Celery app instance:
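A minimal sketch of worker.py is shown below, assuming RabbitMQ and Redis are running locally on their default ports (the broker/backend URLs and app name are placeholders):
# celery_task_app/worker.py -- illustrative sketch; broker/backend URLs are assumptions
from celery import Celery

app = Celery(
    'celery_app',
    broker='amqp://guest:guest@localhost:5672//',   # RabbitMQ as the broker
    backend='redis://localhost:6379/0',             # Redis as the result backend
    include=['celery_task_app.tasks'],              # modules that define tasks
)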
This app definition is very simple; however, there is a wide range of additional config options that can be defined (e.g. timezone, serialization).
The include argument is used to specify modules where tasks are defined for the Celery app. In this case we define a single task in tasks.py:
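The sketch below illustrates the kind of task definition described in the rest of this section; the task name, custom base class and model-loading details are assumptions based on that description:
# celery_task_app/tasks.py -- illustrative sketch; task and class names are assumptions
import logging

from celery import Task

from celery_task_app.ml.model import ChurnModel
from celery_task_app.worker import app

class PredictTask(Task):
    """Base task that loads the ML model once per worker process, on first use."""

    def __init__(self):
        super().__init__()
        self.model = None

    def __call__(self, *args, **kwargs):
        # Lazily load the model the first time a task runs in this process
        if self.model is None:
            logging.info('Loading model...')
            self.model = ChurnModel()
            logging.info('Model loaded')
        return self.run(*args, **kwargs)

@app.task(ignore_result=False, bind=True, base=PredictTask)
def predict_churn_single(self, data: dict) -> float:
    """Return the churn probability for a single customer record."""
    return self.model.predict(data, return_option='Prob')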
Our task implementation is slightly more complex than usual. Most simple tasks can be defined using the task decorator, which wraps the decorated function as the run method of Celery’s base Task class. However, if we were to use this approach, the model object would have to be created, and hence loaded from disk, for each task processed.
By extending the Celery Task object we can override the default behavior such that the ML model is loaded only once when the task is first called. Subsequent calls can then use the same loaded model.
Although this is better than loading the model each time a task is run, there are still some considerations:
- There will be a PredictTask instance for each worker process. Therefore, if a worker runs four child processes, the model will be loaded four times (and each copy is then held in memory as well…).
- This introduces a cold-start scenario for each worker process, where the first task will be slow as the model needs to be loaded. There are different approaches that can be used to tackle this, and the fact that tasks are asynchronous makes it less of an issue.
We can then use the usual decorator approach and specify base=PredictTask so Celery knows to use our custom Task class instead of the default. The bind parameter allows us to access the model attribute using self (as if we were defining the run method directly in the class).
FastAPI
Finally we can create an API that will be used to generate tasks and fetch results from the backend based on client requests.
Two endpoints are required:
- …/churn/predict (POST): Client sends JSON containing required features. Unique task id returned.
- …/churn/result/ (GET): Checks if the task result is available in the backend, and returns the prediction if it is.
We implement these endpoints as follows:
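The sketch below illustrates one way these endpoints could be implemented; the route paths, Pydantic models and field names are placeholders standing in for the actual definitions in models.py and app.py:
# app.py -- illustrative sketch; routes, Pydantic models and field names are assumptions
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from pydantic import BaseModel

from celery_task_app.tasks import predict_churn_single

app = FastAPI()

class Customer(BaseModel):
    """Request body schema (placeholder fields; the repo defines the full feature set in models.py)."""
    Customer_Age: int
    Credit_Limit: float
    # ...remaining features...

class TaskStatus(BaseModel):
    """Returned when a prediction task has been created but is not yet complete."""
    task_id: str
    status: str

class Prediction(BaseModel):
    task_id: str
    status: str
    probability: float

@app.post('/churn/predict', response_model=TaskStatus, status_code=202)
async def predict(customer: Customer):
    """Validate the request body, create a Celery prediction task and return its id."""
    task = predict_churn_single.delay(customer.dict())
    return {'task_id': task.id, 'status': 'Processing'}

@app.get('/churn/result/{task_id}', response_model=Prediction, status_code=200)
async def result(task_id: str):
    """Return the prediction if it is ready, otherwise a 202 'Processing' response."""
    task = predict_churn_single.AsyncResult(task_id)
    if not task.ready():
        return JSONResponse(status_code=202, content={'task_id': task_id, 'status': 'Processing'})
    return {'task_id': task_id, 'status': 'Success', 'probability': task.get()}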
A specific response is implemented within the results route to handle cases where the task result is not ready. This can be used by the client to poll the result endpoint until the result is ready (we’ll use this for testing…).
Testing
To check everything is working as expected we can create a separate Python script that imitates a prediction request. To do so we make use of Python’s requests package.
Once both the broker and backend servers are running, we can start the API using uvicorn:
uvicorn app:app
Next start a worker process:
celery -A celery_task_app.worker worker -l info
Running test_client.py will do the following (see the repo for the full script; a stripped-down sketch follows the list below):
- An example feature JSON is sent to the predict endpoint. The features are hardcoded in a dictionary within the script.
- If successful, the task id returned will be used to poll the results endpoint. The dummy client will wait 5 seconds between each request and will make a maximum of 5 attempts.
- If the task is successful, the result (in this case the probability of membership to class 1) will be printed.
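A stripped-down version of such a client might look like this (the base URL, feature values and response field names are assumptions):
# test_client.py -- illustrative sketch; URL, feature values and field names are assumptions
import time

import requests

BASE_URL = 'http://localhost:8000'

# Hardcoded example features (placeholders for the full feature set used by the model)
CUSTOMER = {'Customer_Age': 45, 'Credit_Limit': 12691.0}

if __name__ == '__main__':
    # 1. Request a prediction task
    response = requests.post(f'{BASE_URL}/churn/predict', json=CUSTOMER)
    response.raise_for_status()
    task_id = response.json()['task_id']

    # 2. Poll the results endpoint, waiting 5 seconds between attempts (max 5 attempts)
    for _ in range(5):
        result = requests.get(f'{BASE_URL}/churn/result/{task_id}')
        if result.status_code == 200:
            # 3. Print the returned churn probability
            print(result.json()['probability'])
            break
        time.sleep(5)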
(venv) python.exe test_client.py
0.011178750583075114
Next Steps
The solution discussed above is simply a working example and should be adapted with more advanced Celery and FastAPI configuration for full production use.
Another possibility would be to dockerize the entire solution such that it can be deployed easily on cloud infrastructure.