Machine Learning Tips

So you've trained your ML/DL model and are now thinking of deploying it to the cloud. One of the most popular ways to do that is to wrap your model in a Flask app and serve it through a REST API. But hold on: let's add some security to your app first. I'll briefly explain the code with the help of a real-world example, and also why it's important to consider the security aspects of your app. I'll try to be as straightforward as I can with the implementation. But first, like any other tutorial, let's understand what CSRF is and why we actually need it.
What the heck is CSRF?
Cross-site request forgery (CSRF), also known as XSRF, Sea Surf or Session Riding, is an attack vector that tricks a web browser into executing an unwanted action in an application to which a user is logged in.
What does that mean? Let me explain it in the simplest terms.
I was working on a personal deep learning project that takes an image of an Indian paper currency note as input and predicts its class from the new denominations of Indian currency: 10, 20, 50, 100, 200, 500 and 2000.
Question: Well, I'm not a security guy, but while working on this project I noticed that my API endpoints were now exposed to the public. Why should you care if they're exposed? What does that even mean for a developer?
Answer: In my case, anyone on the internet could send a legitimate-looking POST request with the required data to my app and get predictions back. You really don't want to be in that situation when you're storing every uploaded image and its prediction in an S3 bucket. If you're on the AWS free tier (like me!), these bad actors can use up your free-tier quota in no time and eventually land you with an unexpected bill.

I tried to reproduce the scenario on my local machine by sending a POST request from my terminal, and unsurprisingly, I got the predictions back. The input image and the prediction were also stored in my S3 bucket.

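Strictly speaking, CSRF (as defined above) is about tricking a logged-in user's browser into sending a request, but the token check that CSRF protection relies on also rejects scripted requests like this one, which don't carry a valid token issued by your app. Now the question is…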
How to implement CSRF protection for your app?
According to the Flask-WTF documentation, any view that uses FlaskForm to process the request already gets CSRF protection. If you have views that don't use FlaskForm, or that handle AJAX requests, you need to add that layer of security explicitly with the provided CSRFProtect extension.
```python
# app.py
from flask import Flask
from flask_wtf.csrf import CSRFProtect, CSRFError

app = Flask(__name__)
app.config['SECRET_KEY'] = "anything_unique"  # use a long random value in production, not a hard-coded string
app.config['WTF_CSRF_TIME_LIMIT'] = 3600  # token lifetime in seconds (3600 is the default)
CSRFProtect(app)
```
First, we import the Flask class; an instance of this class will be our WSGI application. Next, we import the CSRFProtect class (and the CSRFError exception, which we'll use later) from the flask_wtf.csrf module.
We then create the Flask instance and set two config values: the secret key and the CSRF token's expiration time limit. Note that without a secret key, Flask-WTF can't generate CSRF tokens at all; it raises an error instead.
Note: CSRF protection requires a secret key to securely sign the token. By default this will use the Flask app's SECRET_KEY. If you'd like to use a separate key, you can set WTF_CSRF_SECRET_KEY.
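Why does the secret key matter? Here is a simplified, stdlib-only sketch of the idea, not Flask-WTF's actual implementation (which differs in its details and also embeds a timestamp for expiry): a random value stored in the user's session is signed with an HMAC keyed by the app's secret, and a submitted token is accepted only if the signature checks out and the value matches the session. All function names here are hypothetical.

```python
import hashlib
import hmac
import secrets


def new_session_value() -> str:
    """Random value kept server-side in the user's session (hypothetical helper)."""
    return secrets.token_hex(20)


def sign_token(secret_key: bytes, session_value: str) -> str:
    """Token sent to the browser: the session value plus an HMAC made with the secret key."""
    mac = hmac.new(secret_key, session_value.encode(), hashlib.sha256).hexdigest()
    return f"{session_value}.{mac}"


def verify_token(secret_key: bytes, session_value: str, token: str) -> bool:
    """Accept a token only if it was signed with our key AND belongs to this session."""
    value, _, mac = token.rpartition(".")
    expected = hmac.new(secret_key, value.encode(), hashlib.sha256).hexdigest()
    # compare_digest on bytes avoids timing leaks and never raises on odd input
    signature_ok = hmac.compare_digest(mac.encode(), expected.encode())
    session_ok = hmac.compare_digest(value.encode(), session_value.encode())
    return signature_ok and session_ok
```

Without the secret key nobody, including your own app, can produce a valid signature, which is why the key is mandatory. And since anyone who knows the key can forge tokens, it must stay private.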
I'd strongly recommend storing your keys in a .env file (kept out of version control) or in environment variables, so that they don't get distributed when you push your code to production. I'll talk about WTF_CSRF_TIME_LIMIT in the last part of this blog. The last line is where the magic happens! Registering the CSRFProtect extension enables CSRF protection globally for the Flask app.
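To follow the advice about not hard-coding keys, a minimal sketch could read the key from an environment variable at startup. The variable name FLASK_SECRET_KEY and the helper load_secret_key are hypothetical; pick whatever naming fits your deployment.

```python
import os
import secrets


def load_secret_key(env_var: str = "FLASK_SECRET_KEY") -> str:
    """Read the secret key from the environment instead of hard-coding it."""
    key = os.environ.get(env_var)
    if not key:
        # Fail fast at startup rather than when the first CSRF token is generated.
        raise RuntimeError(f"{env_var} is not set")
    return key


if __name__ == "__main__":
    # Print a fresh random key that you can paste into your .env file.
    print(secrets.token_hex(32))
```

In app.py you would then write `app.config['SECRET_KEY'] = load_secret_key()`. If the value lives in a .env file, a loader such as python-dotenv's `load_dotenv()` can populate `os.environ` before this runs.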
Next, we add the CSRF token to our form as a hidden field.
```html
<!-- index.html -->
<form method="post">
<input type="hidden" name="csrf_token" value="{{ csrf_token() }}" />
...
...
...
</form>
```
If you send the request via AJAX instead of a regular form submission, the official documentation suggests sending the token in the X-CSRFToken request header.
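In the browser you would set this header from JavaScript, but the shape of the request is the same from any client. Here is a minimal Python sketch that only builds (and never sends) such a request. The URL, the helper name and the placeholder values are hypothetical; the one assumption about Flask itself is its default session cookie name, `session`.

```python
import urllib.request


def build_ajax_style_request(url: str, csrf_token: str, session_cookie: str,
                             body: bytes) -> urllib.request.Request:
    """Build (but do not send) a POST that carries the CSRF token like an AJAX call would."""
    req = urllib.request.Request(url, data=body, method="POST")
    # Flask-WTF checks this header when the token is not in the form body.
    req.add_header("X-CSRFToken", csrf_token)
    # The token is validated against the user's session, so the session
    # cookie ("session" is Flask's default cookie name) has to travel with it.
    req.add_header("Cookie", f"session={session_cookie}")
    return req


if __name__ == "__main__":
    # Hypothetical endpoint and values, for illustration only.
    r = build_ajax_style_request("http://127.0.0.1:5000/predict",
                                 "<token from the page>", "<session cookie>", b"...")
    print(r.get_method(), r.full_url, r.header_items())
```

The token alone is not enough: it only validates together with the session it was issued for.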
And guess what? We’re done!

Okay, so you don’t believe me? Here’s the proof 😅
After adding the code, I executed the same command in my terminal, and this time the request was rejected with a 400 Bad Request error.

The Flask app couldn't find the csrf_token in the request's body, hence the bad request.
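Where does the app look for the token? Here is a minimal sketch of the lookup, assuming Flask-WTF's documented defaults: the form field named by WTF_CSRF_FIELD_NAME (`csrf_token`), then the headers listed in WTF_CSRF_HEADERS (`X-CSRFToken`, `X-CSRF-Token`). The function names are hypothetical and this is not the library's code. When neither place has a token, the request is rejected with a 400, which is what happened to the terminal request above.

```python
from typing import Mapping, Optional, Tuple

FIELD_NAME = "csrf_token"                      # default of WTF_CSRF_FIELD_NAME
HEADER_NAMES = ("X-CSRFToken", "X-CSRF-Token")  # default of WTF_CSRF_HEADERS


def find_csrf_token(form: Mapping[str, str], headers: Mapping[str, str]) -> Optional[str]:
    """Return the submitted token from the form body or headers, or None."""
    token = form.get(FIELD_NAME)
    if token:
        return token
    # HTTP header names are case-insensitive, so normalise before looking up.
    lowered = {name.lower(): value for name, value in headers.items()}
    for name in HEADER_NAMES:
        token = lowered.get(name.lower())
        if token:
            return token
    return None


def check_request(form: Mapping[str, str], headers: Mapping[str, str]) -> Tuple[int, str]:
    """Reject requests without a token; a real check would then verify the signature."""
    if find_csrf_token(form, headers) is None:
        return 400, "The CSRF token is missing."  # message modelled on Flask-WTF's
    return 200, "token present; signature and session checks would follow"
```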
Customization
- You can set the expiration time of your CSRF token using WTF_CSRF_TIME_LIMIT. It is the maximum age of a CSRF token in seconds. The default value is 3600. If set to None, the CSRF token is valid for the life of the session.
- You can also catch the error raised for a missing or invalid CSRF token and handle it in your app. The CSRFError class comes to the rescue here: you just need to register an error handler in your Flask app to catch the exception.
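```python
# app.py (continued)
from flask import jsonify

@app.errorhandler(CSRFError)
def handle_csrf_error(e):
    # e.description holds the reason, e.g. a missing or expired token
    return jsonify({"error": e.description}), 400
```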
In the code snippet above, we catch any CSRFError and return its description as JSON, with the status code 400 (Bad Request).
You can modify it to suit your application. With this handler in place, a POST request with an altered CSRF token gets a JSON error message back instead of Flask's default HTML error page.

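To make the WTF_CSRF_TIME_LIMIT option from the Customization list concrete, here is a small sketch of the expiry rule it describes: a numeric limit in seconds, or None for no separate expiry. The function name is hypothetical and this is only an illustration of the rule, not Flask-WTF's code.

```python
import time
from typing import Optional


def token_expired(issued_at: float, time_limit: Optional[int],
                  now: Optional[float] = None) -> bool:
    """time_limit mirrors WTF_CSRF_TIME_LIMIT: seconds, or None for 'valid for the session'."""
    if time_limit is None:
        return False  # no expiry of its own; the token lives as long as the session
    if now is None:
        now = time.time()
    return now - issued_at > time_limit
```

A shorter limit narrows the window in which a leaked token can be reused, at the cost of users seeing errors if a page sits open longer than the limit before they submit.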
End Notes
All the code snippets used in this blog are part of my Indian Paper Currency Prediction project.
I’d be more than happy to e-meet you. You can visit my personal website www.rohitswami.com to know more about me. Also, you can find me on LinkedIn and GitHub 🎉
I would love to hear your feedback on this article. Feel free to spam the comment section below. 😊
Note from Towards Data Science’s editors: While we allow independent authors to publish articles in accordance with our rules and guidelines, we do not endorse each author’s contribution. You should not rely on an author’s works without seeking professional advice. See our Reader Terms for details.