Building Prediction APIs in Python (Part 3): Automated Testing

Chris Moradi
Towards Data Science
14 min read · Feb 7, 2018


In the last post, we improved the error handling of our prediction API and considered the nuanced decision about which records we should score. In this post, we’ll look at how we can test our API using pytest.

As always, we’re going to use Python 3, and I’m going to assume that you are either using Anaconda or that you’ve set up an environment with these packages installed: flask, scikit-learn, and pytest.

Note: I’ll walk through snippets of code in this post, but it may be helpful to see the full files in context. You can find the complete examples on GitHub.

The Case for Testing

While automated testing is a core practice in modern software development, it has yet to be fully embraced by many data scientists. Simply put, automated tests are additional code that runs your main code base and checks its results. Testing frameworks, like pytest, make it easy to define and execute a suite of these tests. As new functionality is implemented or existing code is refactored, these tests help developers confirm that existing functionality hasn’t been broken and locate bugs that have been introduced.

Here are some common excuses for not writing tests with my (snarky) responses:

  • There is very little code, so we don’t need tests. What happens when you need to extend your current functionality? Each enhancement might be small, but in the long run you can end up with a large and untested code base. If there’s currently very little code, it should be easy to write tests, so do it!
  • I already checked each function against multiple test cases when I was writing the original code. That’s great! One challenge with constructing tests is to come up with good test cases, and you’ve already done that. Wouldn’t it be great to have the ability to run all those test cases every time you make a minor change?
  • I don’t make mistakes when I write code¹. Of course! You don’t need test cases for your code, but what happens when someone else gets added to the project and starts modifying your beautiful, perfect code? How will we know if they broke something? Are you willing to check every commit they make in detail? Are you excited to maintain this code for the rest of your life?

Before getting into the meat of this post, I want to point out that I’m going to be a hypocrite. We’ll focus on testing the API. Along the way, we’ll also need to modify the code that builds our model. Astute readers will notice that I’m not writing any tests for the model build pipeline. If it helps, I do feel bad about this.

Adjustments for Better Testing

Before we get into writing the actual tests, there’s one thing that I’d like to change about the API response we’re generating. Currently, we’re only sending back the predicted class (iris type). While this is what the users of our prediction API need, it doesn’t provide a ton of information for us to test. The predicted class is chosen as the class (iris type) with the highest model score. Because of this thresholding, the predicted class could be the same even if the underlying scores are somewhat different. This is analogous to a function that performs complicated and precise calculations, but returns a value that is rounded to the nearest integer. Even if the returned integer values for many inputs match the expected values, the underlying calculations may not be correct. Ideally, we’d like to verify that the precise results of the calculations are correct before they are thresholded or rounded.

To provide ourselves with more data to verify that we’re correctly scoring the model, we’re going to change the API response to include the probabilities of each class². This is accomplished by calling MODEL.predict_proba() instead of MODEL.predict(). We then use argmax() to get the index of the largest probability, which gives us the predicted class. We’ll return the raw class probabilities to the user in a probabilities field of the response.

Updated API to return class probabilities
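Here is a rough sketch of the idea. It assumes the query-string interface used throughout this series. The view reads the four features, scores them with predict_proba(), picks the predicted class with an argmax, and returns both. StubModel, CLASS_NAMES, the response keys label and probabilities, and the two petal defaults are assumptions. Only the sepal defaults (5.8 and 3.0) come from this post.

```python
# predict_api.py (sketch): /predict returns the predicted class and the
# probability of every class. StubModel is a hypothetical stand-in for the
# pickled scikit-learn model; it exposes the same predict_proba() interface.
from flask import Flask, jsonify, request

CLASS_NAMES = ["setosa", "versicolor", "virginica"]

# Mean-imputation defaults used when a feature is missing or not a valid float.
DEFAULTS = {
    "sepal_length": 5.8,
    "sepal_width": 3.0,
    "petal_length": 3.8,  # assumed approximate mean
    "petal_width": 1.2,  # assumed approximate mean
}


class StubModel:
    """Hypothetical stand-in: scores rows with a simple petal_width rule."""

    def predict_proba(self, rows):
        probs = []
        for row in rows:
            petal_width = row[3]
            if petal_width < 0.8:
                probs.append([0.90, 0.05, 0.05])
            elif petal_width < 1.7:
                probs.append([0.05, 0.85, 0.10])
            else:
                probs.append([0.02, 0.08, 0.90])
        return probs


MODEL = StubModel()
app = Flask(__name__)


@app.route("/predict", methods=["GET"])
def predict():
    # .get(..., type=float) returns the default when the value is missing
    # or cannot be converted to a float (mean imputation).
    features = [
        request.args.get(name, default=default, type=float)
        for name, default in DEFAULTS.items()
    ]
    probs = list(MODEL.predict_proba([features])[0])
    # argmax without NumPy: index of the largest class probability.
    best = max(range(len(probs)), key=lambda i: probs[i])
    return jsonify(
        label=CLASS_NAMES[best],
        # float() turns NumPy scalars from a real model into JSON-safe floats.
        probabilities={name: float(p) for name, p in zip(CLASS_NAMES, probs)},
    )
```

To serve it, the real predict_api.py would load MODEL with joblib and end with an `if __name__ == "__main__": app.run()` guard.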

We can then run the API (python predict_api.py) and make a test call through requests:
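The post makes this call with requests. Below is a stdlib-only sketch of the same kind of call, assuming the Flask development server's default address. The helper names are illustrative, and `requests.get(url, params=features)` encodes the features into the query string in the same way.

```python
# Manual call against the locally running API (stdlib-only sketch).
import json
from urllib.parse import urlencode
from urllib.request import urlopen

# Assumed address: the Flask development server's default host and port.
BASE_URL = "http://127.0.0.1:5000/predict"


def build_predict_url(features, base_url=BASE_URL):
    """Append the features to the endpoint as an encoded query string."""
    return f"{base_url}?{urlencode(features)}"


def call_predict(features, base_url=BASE_URL):
    """GET the endpoint and decode the JSON body (the server must be running)."""
    with urlopen(build_predict_url(features, base_url)) as resp:
        return json.loads(resp.read())
```

With the server running, `call_predict({"sepal_length": 5.1, "sepal_width": 3.5, "petal_length": 1.4, "petal_width": 0.2})` should return the label and class probabilities. With requests, the equivalent is `requests.get(BASE_URL, params=features).json()`.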

Basic Test of the API

We’re going to begin with a simple example where we score the same record as in the manual call above, but this time with pytest. We first need to create a testing file: test_predict_api.py. For now, we can put this in the same directory as our API file: predict_api.py. Note: pytest can locate test files and test functions automatically, but only if you follow its naming conventions. By default, it will inspect any file whose name begins with the “test_” prefix and will run any test function whose name begins with “test_”.

When we construct tests, we typically follow this pattern:

  1. Set Up (optional): Initialization for the test or environment. Examples: initialize a database, create instances of a class that will be used in the tests, initialize the state of the system, etc…
  2. Run Code: Run some code from your main code base (code under test) on a predefined test case or within the predefined environment. This may include: calling a function/method, creating an instance of a class, initializing a resource, calling an API, etc…
  3. Verify Results: Check that the effects of the code meet expectations through the use of assert statements: the returned value of a function call is correct, exceptions were appropriately raised, the system has changed to the correct state, etc…
  4. Tear Down (optional): Clean up after the test has run to restore the environment back to a default.
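The following pytest-style test makes the four steps concrete outside of Flask. save_scores() is a hypothetical function standing in for the code under test. The scratch directory exists only so that set up and tear down have real work to do.

```python
# The four phases of a test, using a throwaway directory.
import json
import os
import shutil
import tempfile


def save_scores(path, scores):
    """Code under test (hypothetical): write a dict of scores as JSON."""
    with open(path, "w") as f:
        json.dump(scores, f)


def test_save_scores():
    # 1. Set up: create an isolated scratch directory.
    workdir = tempfile.mkdtemp()
    try:
        path = os.path.join(workdir, "scores.json")
        # 2. Run code: call the function under test.
        save_scores(path, {"setosa": 0.9})
        # 3. Verify results: the file round-trips to the same data.
        with open(path) as f:
            assert json.load(f) == {"setosa": 0.9}
    finally:
        # 4. Tear down: remove the directory even if an assert failed.
        shutil.rmtree(workdir)
```

In practice, pytest’s built-in tmp_path fixture takes care of this kind of set up and tear down for you.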

Here’s how this example will align with these steps:

  1. Set Up: Instantiate a test_client (described below) that will allow us to simulate calls to the API.
  2. Run Code: Make a call to the /predict endpoint with a predefined set of features.
  3. Verify Results: We get back a response with status_code of 200 and the content is JSON with the correct format and values.
  4. Tear Down: This is somewhat implicit. We’re going to use a context manager during the “set up” of the test_client. Upon exit, the client will be cleaned up.

As noted above, the test_client we’re using is a feature of Flask that allows us to simulate calls without having to run the server. Here, we create a new client using a context manager. From this, we make a GET request. The query_string keyword argument provides functionality similar to how params works in requests.get(); it allows us to pass data that is used to create the query string.

From the response, we check that we received the “200 OK” status, and then we check that the payload of the response matches what we expected. Since the payload comes as bytes, we use json.loads() to convert it into a dict.
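Here is a sketch of this first test. It assumes the API module exposes a Flask object named app. So that the block runs on its own, a tiny stand-in app with a fixed response replaces `from predict_api import app`. The response keys and values are placeholders, not real model output.

```python
# test_predict_api.py (sketch): simulate a /predict call with Flask's test client.
import json

from flask import Flask, jsonify

# Stand-in for `from predict_api import app` (hypothetical fixed response).
app = Flask(__name__)


@app.route("/predict")
def predict():
    return jsonify(
        label="setosa",
        probabilities={"setosa": 0.9, "versicolor": 0.05, "virginica": 0.05},
    )


def test_single_prediction():
    query = {"sepal_length": 5.1, "sepal_width": 3.5,
             "petal_length": 1.4, "petal_width": 0.2}
    # Set up: the context manager creates the client and cleans it up on exit.
    with app.test_client() as client:
        # Run code: query_string plays the role of params in requests.get().
        response = client.get("/predict", query_string=query)
        # Verify results: status code first, then the decoded JSON payload.
        assert response.status_code == 200
        assert json.loads(response.data) == {
            "label": "setosa",
            "probabilities": {"setosa": 0.9, "versicolor": 0.05, "virginica": 0.05},
        }
```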

We can now execute the tests using pytest³ at the command line.

Always Fail First

Whenever you’re writing automated tests, it’s important to verify that you’re actually testing something, that is, that your tests can fail. While this is an obvious statement, a common pitfall is writing tests that don’t actually test anything. When they pass, the developer assumes that the code under test is correct, when in reality the tests pass only because they were poorly written.

One way of preventing this kind of error is to use test-driven development (TDD). We’re not going to go into this in depth, but it’s a development process where:

  1. You start by writing a test for a new feature.
  2. You verify that the test fails.
  3. You write code to implement the new feature.
  4. You verify that the test now passes.

If you haven’t tried TDD before, I definitely recommend it. It takes discipline, especially when getting started, and buy-in from other developers and stakeholders. However, dedication to the process is rewarded with fewer bugs and lower stress when implementing new features. To learn more, read Harry Percival’s excellent book, Test-Driven Development with Python, which he’s made available online for free.

If you’re not up for TDD, the lazy method that I’ve used to verify that each test is actually testing something is to deliberately invert my assert expressions so that they fail. For example, we’d change assert response.status_code == 200 to assert response.status_code != 200. If you make this change and rerun the tests, pytest should report an AssertionError for that line.

If you’re going to use this method, be aware that a test function stops at its first failing assert, so pytest reports only the first AssertionError in each test. You must therefore invert each assert separately and rerun the tests.
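A few lines of plain Python show this behavior: once an assert fails, the rest of the function never runs. The values here are made up.

```python
# A test function stops at its first failing assert.
executed = []


def inverted_checks(status_code=200, label="setosa"):
    executed.append("first assert")
    assert status_code != 200  # deliberately inverted, so it fails here
    executed.append("second assert")  # never reached
    assert label != "setosa"


try:
    inverted_checks()
except AssertionError:
    pass

print(executed)  # ['first assert']
```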

More Testing

We now have a test call for our API, and it’s working. How can we extend this to test multiple calls with different feature values and different expected results (labels and probabilities)? One quick option is to use our test dataset that we created during the model build. However, we need to get the class probabilities and the predicted label to use as the expected result for each input record.

One important thing to note is that we are testing the API platform and not the model itself. Basically, this means that we don’t care if the model is making false predictions; we just want to verify that the model output as scored on the API platform matches the model output from the build/offline/development environment. We’ll also need to test that the preparation of features (e.g., mean imputation) is done correctly on the API platform, but we’ll save that for the next section of this post.

Since our test dataset may change with each new version of our model, we should incorporate the generation of these data into our model build. I did some light refactoring (more is needed) to our model build script and added the code for generating the test dataset after the model is built. We’ll store our test cases in a JSON file, and each test case will be structured as:
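One plausible shape for a single test case, written as a Python dict and assumed here for illustration, keeps the API inputs under features and the expected API output under expected_response. The probability values are placeholders rather than real model output.

```python
# Assumed shape of one test case; the JSON file holds a list of these.
import json

test_case = {
    "features": {
        "sepal_length": 5.1,
        "sepal_width": 3.5,
        "petal_length": 1.4,
        "petal_width": 0.2,
    },
    "expected_response": {
        "label": "setosa",
        "probabilities": {"setosa": 1.0, "versicolor": 0.0, "virginica": 0.0},
    },
}

print(json.dumps([test_case], indent=2))
```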

Here’s the modified version of our model build code that incorporates the test dataset generation:

At the top, there’s a function called prep_test_cases() which just reformats the features, class probabilities, and predicted label for each test into our test case format.
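Here is a hypothetical sketch of what prep_test_cases() might look like. Plain lists stand in for the NumPy arrays the build script would pass: the test-set features, the model.predict_proba() output, and the predicted labels. The key names (features, expected_response, label, probabilities) are assumptions, as is the idea that labels are class-name strings.

```python
# Hypothetical prep_test_cases(): zip features, probabilities, and labels
# into a list of JSON-serializable test cases.
import json

FEATURE_NAMES = ["sepal_length", "sepal_width", "petal_length", "petal_width"]
CLASS_NAMES = ["setosa", "versicolor", "virginica"]


def prep_test_cases(features, probabilities, labels):
    cases = []
    for feat_row, prob_row, label in zip(features, probabilities, labels):
        cases.append({
            # float() converts NumPy scalars into plain floats for json.
            "features": dict(zip(FEATURE_NAMES, map(float, feat_row))),
            "expected_response": {
                "label": label,
                "probabilities": dict(zip(CLASS_NAMES, map(float, prob_row))),
            },
        })
    return cases
```

In the build script, the result would then be written with json.dump(cases, f) to a file such as testdata_iris_v1.0.json.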

Now that we’ve generated our test data, we need to add a new test to score all the records in this file and check the responses:

Because we structured the test data in a clean way where each test case has the features (API inputs) as well as the expected response (API outputs), the test code is fairly straightforward. One deficiency with this approach is that class probabilities are floats, and we’re doing an exact comparison of these values. Typically, some tolerance is allowed when comparing float values so that values that are very close are considered equivalent. To handle this, we’d need to parse the expected response and use pytest.approx() when comparing each value in the probabilities. It doesn’t require much more code, but I thought it would clutter this discussion a bit, so I’m leaving out the implementation.
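Here is a sketch of that tolerant comparison. It uses math.isclose() so that it runs without pytest. The values rel_tol=1e-6 and abs_tol=1e-12 mirror pytest.approx()’s default tolerances. responses_match(), check_all_cases(), and the score callable are hypothetical names. In the real test, score would wrap a test-client request, and the test would end with `assert not check_all_cases(cases, score)`.

```python
# Exact match on the label, approximate match on each class probability.
import math


def responses_match(actual, expected, rel_tol=1e-6, abs_tol=1e-12):
    if actual["label"] != expected["label"]:
        return False
    act_probs = actual["probabilities"]
    exp_probs = expected["probabilities"]
    if set(act_probs) != set(exp_probs):
        return False
    return all(
        math.isclose(act_probs[name], exp_probs[name],
                     rel_tol=rel_tol, abs_tol=abs_tol)
        for name in exp_probs
    )


def check_all_cases(cases, score):
    """Return indices of cases whose scored response does not match.

    `score` maps a feature dict to a response dict (hypothetical hook).
    """
    return [
        i for i, case in enumerate(cases)
        if not responses_match(score(case["features"]), case["expected_response"])
    ]
```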

Handling Missing Values

Our API is configured to use mean imputation to replace bad or missing values, but our test dataset doesn’t include any records with missing values. However, this is not an issue since we can simulate these data using the data we already have. We simply need to replace the existing values with the mean value for a feature and rescore the records. To our model build script we’ll add the following after our original test data generation code:

We likely don’t need to be this thorough, but for each record, we are testing every combination of features that could be missing. We create two versions of each record: one that has None for the missing values and one in which those values are imputed with the mean. The first will be stored (after None-valued features are dropped) in the test case as features. The second will be scored to get the prediction probabilities that we’ll expect to see returned by the API. In order to drop the None-valued features, we must make a small change to the creation of feat_dict in the prep_test_cases function. Here’s the revised function:
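Here is a sketch of the generation step. It assumes four features and approximate mean values for the petal features; 5.8 and 3.0 are the sepal defaults cited in this post. For every non-empty subset of features, it yields two things. The first is the API input with those features dropped. The second is the mean-imputed row that would be passed to model.predict_proba() to get the expected probabilities. This sketch drops the None values inside the generator, whereas the post makes that change in prep_test_cases().

```python
# Every combination of missing features, as (API input, imputed model row).
from itertools import combinations

FEATURE_NAMES = ["sepal_length", "sepal_width", "petal_length", "petal_width"]
FEATURE_MEANS = [5.8, 3.0, 3.8, 1.2]  # petal means are assumed approximations


def missing_value_variants(record):
    n = len(FEATURE_NAMES)
    for k in range(1, n + 1):
        for missing in combinations(range(n), k):
            with_none = [None if i in missing else v for i, v in enumerate(record)]
            imputed = [FEATURE_MEANS[i] if i in missing else v
                       for i, v in enumerate(record)]
            # Drop None-valued features so they are absent from the request.
            api_features = {name: v for name, v in zip(FEATURE_NAMES, with_none)
                            if v is not None}
            yield api_features, imputed
```

For one record, this yields 2⁴ − 1 = 15 variants, including the case where every feature is missing.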

We also need to change our tests to use this new file. While we could just copy the last test function test_api() and replace the filename testdata_iris_v1.0.json, that would result in duplicated code. Since we need the test function to be exactly the same except for the filename, a better approach is to use pytest’s parametrize functionality. We simply add a decorator that allows us to specify arguments for the test function, and pytest reruns the test for each of these values. In this case, we’ll pass in the filename:

Testing Errors

In the last post on error handling, I mentioned that we could be more selective in which records we were willing to score, but I did not provide an example of this. Here we’ll tweak our API to look at a simple example where we’ll reject requests that are missing data for petal_width. We'll score all other records, using mean imputation if needed.

As a slight tangent, how did I choose petal_width? Well, if we look at the feature importances (using model.feature_importances_), we see that the fourth feature (petal_width) has the highest value with a normalized score of 0.51. Since this is the most important feature in our model, it makes the most sense to reject records that are missing this feature.

A simple way to implement this is to remove the default value for petal_width and then handle the case where it’s missing. Nearly all of the code remains the same, but I’ve included it here for context.
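Here is a sketch of the stricter parsing as a hypothetical helper working on a plain dict of query parameters. In Flask, `request.args.get("petal_width", type=float)` without a default returns None for both a missing value and one that cannot be parsed, and to_float(..., None) mimics that here. The error message and the petal_length default are assumptions.

```python
# petal_width is required; the other features keep mean imputation.
DEFAULTS = {"sepal_length": 5.8, "sepal_width": 3.0, "petal_length": 3.8}
REQUIRED = "petal_width"


def to_float(value, default):
    """Convert to float, falling back to `default` on missing or bad input."""
    try:
        return float(value)
    except (TypeError, ValueError):
        return default


def parse_features(args):
    """Return (features, error_message); exactly one of them is None."""
    petal_width = to_float(args.get(REQUIRED), None)
    if petal_width is None:
        # In the Flask view, this would become a 400 Bad Request response.
        return None, "petal_width is missing or not a number"
    features = [to_float(args.get(name), default) for name, default in DEFAULTS.items()]
    features.append(petal_width)
    return features, None
```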

We can do a quick test to make sure it works for a simple case:

Great! It works! Now, we just need to add this to our test suite.

For simplicity, I’m going to skip showing how to modify our old missing value tests and just implement the new tests that handle missing or bad values for petal_width. Basically, I removed all tuples from missing_grps that had a 3 in them (index of petal_width) and the test where all features are missing.

For our new tests, we could use the same JSON format. This would be a cleaner implementation. For clarity though, I’m just going to implement these tests in a separate function that has two test cases for petal_width: the feature is missing and the feature has a bad value ("junk").
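Here is a sketch of that separate test function, again assuming the API module exposes app. A minimal stand-in app replaces the real import so the block runs alone. The error message and the JSON error body are assumptions. The 400 status follows the description in this post.

```python
# One test function covering a missing and a bad ("junk") petal_width.
import json

from flask import Flask, jsonify, request

app = Flask(__name__)  # stand-in for `from predict_api import app`
ERROR_MSG = "petal_width is required and must be numeric"  # hypothetical text


@app.route("/predict")
def predict():
    # Without a default, .get(..., type=float) yields None when the value is
    # missing or cannot be parsed as a float.
    petal_width = request.args.get("petal_width", type=float)
    if petal_width is None:
        return jsonify(error=ERROR_MSG), 400
    return jsonify(label="setosa")


def test_rejects_bad_petal_width():
    base = {"sepal_length": 5.1, "sepal_width": 3.5, "petal_length": 1.4}
    cases = [base, {**base, "petal_width": "junk"}]  # missing, then bad value
    with app.test_client() as client:
        for query in cases:
            response = client.get("/predict", query_string=query)
            assert response.status_code == 400
            assert json.loads(response.data) == {"error": ERROR_MSG}
```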

We can rerun our tests and verify that these pass. Of course, we should also try to change the == to != to verify that they fail for each condition too. Again, this will help ensure that we're testing what we actually think we're testing.

Identifying Issues

Now that we have our tests, we might wonder if they’re actually going to catch bugs in our code. Perhaps, you already found some issues when you were creating these tests and running them on your API code. If not, here are some simple changes that you can make that should cause one or more of the tests to fail (do each of these independently):

  • In the API code, change the default (mean imputation) value for sepal_length from 5.8 to 5.3.
  • In the API, put back the mean imputation for petal_width. This will allow the API to score records that are missing petal_width. You should see failures in the tests that are expecting the API to return "400 Bad Request" when petal_width is missing.
  • In the API, change the text of the error message that is sent when petal_width is missing.
  • We can also mimic the case where someone accidentally modifies the model and tries to deploy it. To test this, build an alternate version of the model and deploy it, but keep using the test data files (JSON) for the original model. A quick way to do this is to use a different value for random_state in the training/test set split (e.g., random_state=30). Remember to change the model output filename in joblib.dump() to something else (e.g., 'iris-rf-altmodel.pkl'); you’ll need to change MODEL in the API to reference this file. Also, make sure you don’t execute the code that generates the test data files, as it would rebuild them based on the alternate model. When you rerun your tests, you will likely see failures in all tests except the ones that reject requests when petal_width is missing or invalid. If your tests still pass, try another random_state: the new split may leave the training set unchanged, or change it too little to alter the model’s predictions.

Our tests are definitely catching problems, but are we catching all of the problems? The simple answer is that we are probably not catching everything. While creating this post, I tried changing the mean imputation (default value) for sepal_width to 3.1 instead of 3.0. When I reran the tests, they all passed. Perhaps this isn't a big deal; perhaps our model just isn't that sensitive to small shifts in sepal_width around the mean value. This is the feature of lowest importance. However, we used our test set for test cases, and these data points don't necessarily fall near the boundary of different classes. If we had more test cases or just better test cases, we may have been able to catch this type of bug.

Wrapping Up

We have seen that automated tests can help us find bugs in our code. While we started with testing a single API call, we were able to quickly move towards a framework for running numerous test cases, and it only required adding a little extra code.

Testing is an important topic, so it’s likely we’ll revisit this in later posts. Here’s a quick preview of some areas that we didn’t cover, but may in the future:

  • Speed of Tests: It’s beneficial to run tests often while you’re changing your code. This makes it easier to detect errors early when refactoring existing code or adding new features. If the tests take a while to run, developers are less likely to do this. One approach is to separate tests that run quickly from those that take more time. Developers can then run the quick test suite when they make incremental changes, and run the full suite before they integrate the changes back into the main repository.
  • Mocking & Patching: We saw that small changes to the default (mean imputation) value for sepal_width didn't cause our tests to fail. If this was a requirement, we could use patching to intercept the call to model.predict_proba() during scoring to verify that the correct values are being substituted.
  • Fixtures: This is a feature of pytest that lets you create, configure, and destroy resources in order to establish a clean and consistent environment in which each test can run. If you’re familiar with “setup” and “teardown” in many unit testing frameworks, fixtures are an extension of this idea.
  • Integration of Subsystems: Currently, we just have our model and our API. In subsequent posts, we’ll look at adding a database backend and perhaps some other services. How do we test these? How do we test the system as a whole?
  • Continuous Integration Tools: These help make it easier to integrate code into a shared repository. By making it easier, the hope is that it will be done more often and in smaller chunks. One common feature of these tools is that they will automatically run the test suite every time a pull request is submitted, and the pass/fail results will be available to reviewers.
  • Test Coverage: Did we test every line of our code? We could create a test coverage report to help us see which lines of our code were run during testing and which were not. This won’t tell us if we’ve handled all possible cases, but it can give us insight into where our test suite is falling short.
  • Advanced Testing Methodologies: We’re unlikely to cover these topics, but I wanted to mention them. With property-based testing (see Hypothesis), you create parameterized tests and the framework generates an extensive set of test cases for you. This can lead to more comprehensive tests without requiring you to think up all the edge cases. Mutation testing (see Cosmic Ray) takes a very different approach. It works with your existing test cases and actually modifies your source code (code under test) in some small way (mutation) to see if your existing tests fail. If all tests still pass, your test code is incomplete as it’s unable to find the bugs that were introduced by the mutation.
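Here is a sketch of the patching idea with unittest.mock from the standard library. score_request() and FakeModel are hypothetical stand-ins for the API’s scoring path and the loaded model. The imputation logic is simplified and does not handle bad values. Because the test checks the exact row passed to the model, it would fail if the sepal_width default drifted from 3.0 to 3.1, the change that slipped past the existing tests.

```python
# Intercept predict_proba() to check the exact values the model receives.
from unittest import mock

DEFAULTS = {"sepal_length": 5.8, "sepal_width": 3.0,
            "petal_length": 3.8, "petal_width": 1.2}


class FakeModel:
    def predict_proba(self, rows):
        return [[1 / 3, 1 / 3, 1 / 3] for _ in rows]


MODEL = FakeModel()


def score_request(args):
    """Impute missing features with their means, then score with MODEL."""
    row = [float(args.get(name, default)) for name, default in DEFAULTS.items()]
    return MODEL.predict_proba([row])[0]


def test_sepal_width_is_imputed_with_mean():
    with mock.patch.object(MODEL, "predict_proba",
                           return_value=[[1.0, 0.0, 0.0]]) as fake:
        score_request({"sepal_length": "5.1", "petal_length": "1.4",
                       "petal_width": "0.2"})
    # The model must have received exactly the mean-imputed row.
    fake.assert_called_once_with([[5.1, 3.0, 1.4, 0.2]])
```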

FOOTNOTES

  1. Not sure if anyone would actually say this, but on occasion you meet someone who seems to think this.
  2. You may be thinking, “My users don’t need to know the underlying scores for each class; they just need the prediction. So, why am I changing my response just for testing?” Good question! We’re taking this approach somewhat out of convenience and for the sake of clarity. In a real implementation, you might use a flag to specify whether underlying scores are returned and perhaps restrict this functionality to certain users. We could also use mocking and patching to access the scores without modifying the response to include model scores.
  3. If you’re having trouble running pytest, try these options instead: py.test or python -m pytest.
