
An End-to-End Machine Learning Project – Heart Failure Prediction Part 1

Data exploration, model training, validation and storage

In this series, I will be walking through an end-to-end Machine Learning project covering everything from data exploration to model deployment via a web application. My goal is to provide general insight into the different components involved in getting a model to production; this series is not a comprehensive overview of the machine learning pipeline. This article will cover data exploration, model training, validation and storage. Throughout the series, the reader will be exposed to various languages and technologies such as Git, Python, SQL, Catboost, Flask, HTML, CSS, JavaScript and Heroku. All code for this series can be found on GitHub. Part 2 is now available here.

Problem Statement

Formulating a problem statement is the first and most essential prerequisite for any machine learning project. The problem statement should be clear enough to show how machine learning can be used as a solution. I have chosen the following problem statement for this series:

We would like to analyze risk factors for heart failure and model the probability of heart failure in an individual.

There are two components to this problem statement that can potentially be addressed with machine learning: analyzing risk factors and modeling the probability of heart failure. In particular, one way to address this problem is to build a model that can accurately predict an individual’s chance of heart failure and back the prediction with evidence.

Data

The next, and most vital, prerequisite for a machine learning project is to have a reliable, clean and structured data source. In an enterprise setting, these datasets are often the product of a hard-working data engineering team and should never be taken for granted. Fortunately, websites like Kaggle provide us with clean datasets ready for modeling. We will use a popular heart disease dataset for this project.

Heart Failure Prediction Dataset

This dataset contains 11 features that we will use to model heart failure probability. Moreover, when making a prediction, we would like our model to tell us which features contributed the most (more on this in future articles). In order to make the data source a little more realistic, I will be converting the .csv file to a .db file. This will make data retrieval more modular as it will imitate an actual database.
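For anyone curious how that conversion works without opening the notebook, a minimal sketch is shown below; the .csv file name is an assumption, while the table name (‘heart’) and the database path match the queries used later in this article:

import sqlite3
import pandas as pd

# Load the original Kaggle .csv file (file name assumed)
df = pd.read_csv('data/heart.csv')
# Connect to (and, if necessary, create) the mock database file
conn = sqlite3.connect('data/MockDatabase.db')
# Write the data to a table named 'heart'
df.to_sql('heart', conn, if_exists='replace', index=False)
conn.close()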

Data Exploration

We will start by examining the data used to model heart failure probability. I won’t go over every line of code, but all of it can be found in this notebook. To start, we will import all of the necessary libraries and specify a few storage paths that will be used later on:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from catboost import CatBoostClassifier, Pool, cv
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
import optuna
from optuna.samplers import TPESampler
import sqlite3
from datetime import datetime
# display all of the columns of the dataframe in the notebook
pd.set_option('display.max_columns', None)
%matplotlib inline
# Path to mock database
SQL_PATH = 'data/MockDatabase.db'
# Path to model storage location
MODEL_STORAGE_PATH = 'models/'

Next, we will read data from the .db file using a SQL query. For anyone interested in knowing how to create a .db file, see this notebook. The ‘create_connection’ function used below comes directly from a sqlite tutorial.
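For reference, ‘create_connection’ is essentially a thin wrapper around sqlite3.connect; a minimal version, following the tutorial’s pattern, looks roughly like this:

def create_connection(db_file):
    """Create a connection to the SQLite database specified by db_file."""
    conn = None
    try:
        conn = sqlite3.connect(db_file)
    except sqlite3.Error as e:
        print(e)
    return conn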

# Read in training data
conn = create_connection(SQL_PATH)
read_query = '''SELECT Age,
                       Sex, 
                       ChestPainType,
                       RestingBP,
                       Cholesterol,
                       FastingBS,
                       RestingECG,
                       MaxHR,
                       ExerciseAngina,
                       Oldpeak,
                       ST_Slope,
                       HeartDisease
                 FROM heart'''
data_train = pd.read_sql(read_query, conn)
conn.close()
print(data_train.shape)
data_train.head()
A peek at the training data. Image by Author.

Our dataset appears to be a nice mix of categorical and numeric features. With only 918 rows, there isn’t a whole lot of training data to work with. However, this will make model training time fast and allow us to easily perform hyper-parameter tuning and cross validation. Fortunately, there are no missing values in the data; this is particularly surprising for medical data and is a sign that the data engineering team has done an excellent job. The importance of good data engineering cannot be stressed enough.
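A quick check (not shown in the original snippet) confirms that no column contains missing values:

# Count missing values per column; all counts should be zero
print(data_train.isnull().sum())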

Our next objective is to analyze the categorical and numeric features. The code below creates two lists, one for categorical and one for numeric features. The categorical features list will be used later on when building the model.

# Separate columns by data type for analysis
cat_cols = [col for col in data_train.columns if data_train[col].dtype == object]
num_cols = [col for col in data_train.columns if data_train[col].dtype != object]
# Ensure that all columns have been accounted for
assert len(cat_cols) + len(num_cols) == data_train.shape[1]

From here, we can see the cardinality (number of unique values) of each categorical feature.

# Look at cardinality of the categorical columns
cards = [len(data_train[col].unique()) for col in cat_cols]
fig,ax = plt.subplots(figsize=(18,6))
sns.barplot(x=cat_cols, y=cards)
ax.set_xlabel('Feature')
ax.set_ylabel('Number of Categories')
ax.set_title('Feature Cardinality')
Cardinality of each categorical feature. Image by Author.

All of the categorical features have relatively low cardinality, making it easy for the model to encode and process these features. One interesting property of this dataset is that roughly 79% of the patients were male. Moreover, over 90% of patients with heart disease were male.

Far more patients being studied were male. Image by Author.
Over 90% of patients with heart disease were male. Image by Author.
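These percentages can be read off the plots above, but they can also be computed directly from the dataframe; a quick sketch, assuming ‘HeartDisease’ is encoded as 0/1:

# Share of male and female patients in the full dataset
print(data_train['Sex'].value_counts(normalize=True))
# Share of male and female patients among those with heart disease
print(data_train.loc[data_train['HeartDisease'] == 1, 'Sex'].value_counts(normalize=True))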

Next, we look at the distribution of resting blood pressure between patients with and without disease.

Distribution of resting blood pressure for patients with and without heart disease. Image by Author.

Although the distributions look relatively similar, it appears that patients with disease tend to have slightly higher blood pressures – this observation could be validated by a statistical test. For now, we should keep some of these plots in mind as we start to analyze model results.
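If we wanted to validate that observation, a non-parametric test such as the Mann-Whitney U test would be one option. A sketch is shown below; note that it assumes SciPy is available (it is not used elsewhere in this project) and that ‘HeartDisease’ is encoded as 0/1:

from scipy.stats import mannwhitneyu

# Resting blood pressure for patients with and without heart disease
bp_disease = data_train.loc[data_train['HeartDisease'] == 1, 'RestingBP']
bp_no_disease = data_train.loc[data_train['HeartDisease'] == 0, 'RestingBP']
# One-sided test: is blood pressure higher in the disease group?
stat, p_value = mannwhitneyu(bp_disease, bp_no_disease, alternative='greater')
print(f'Mann-Whitney U statistic: {stat:.1f}, p-value: {p_value:.4f}')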

Lastly, before we start modeling, we should see how imbalanced the classes are. That is, how many patients have and don’t have heart disease.

Heart disease distribution. Image by Author.

Roughly 55% of the patients studied had heart disease, which gives us a baseline percentage to benchmark our model against. In other words, if our model learns anything from the data, it should achieve an accuracy above 55%: a model that simply predicts heart disease for every patient is already 55% accurate, so matching that number would indicate that no novel information has been derived from the data. In many real-world classification problems, such as fraud detection, the baseline accuracy can be 90% or higher (the vast majority of transactions are not fraudulent). In those situations, it is common to track metrics other than accuracy, such as precision, recall, and F1 score.
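The baseline itself is easy to compute; a short sketch, again assuming ‘HeartDisease’ is encoded as 0/1:

# Majority-class baseline: predict heart disease for every patient
baseline_accuracy = data_train['HeartDisease'].mean()
print(f'Majority-class baseline accuracy: {baseline_accuracy:.2%}')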

Initial Model Training

The model of choice for this series is Catboost, and we will begin by training and evaluating a simple model instance. Below, we create training and testing datasets. The training dataset consists of ~80% of the original data and the testing set is ~20%. To replicate the results of training, specify a random state of 13 in the ‘train_test_split’ function.

# Create training and testing data
x, y = data_train.drop(['HeartDisease'], axis=1), data_train['HeartDisease']

x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=13)
print(x_train.shape)
print(x_test.shape)
x_train.head()

The next step is to create a Catboost model instance. Before we do this, a list must be created containing the column numbers of each categorical feature. The list is easy to create by utilizing the ‘cat_cols’ list that was created previously.

# Specify index of categorical features in input data
cat_features = [x.columns.get_loc(col) for col in cat_cols]

Then we can create a dictionary with the desired model parameters. To start, we will only specify a few parameters since Catboost is known for excellent ‘out-of-the-box’ performance.

# Model parameter dict
params = {'iterations':5000,
          'loss_function':'Logloss',
          'depth':4,
          'early_stopping_rounds':20,
          'custom_loss':['AUC', 'Accuracy']}

Descriptions of Catboost parameters can be found here. The important parameters to point out are ‘iterations’ and ‘early_stopping_rounds’. In this case, the model will train until either 5000 iterations (trees) have been completed or 20 consecutive iterations have passed with no improvement in the test error. Finally, we instantiate the model and call the ‘fit’ method. Within the ‘fit’ method, we pass the test set so that the model can apply early stopping (in practice, it is more common to create separate train, validation, and test sets).

# Instantiate model
model = CatBoostClassifier(**params)
# Fit model
model.fit(
    x_train,
    y_train,
    cat_features=cat_features,
    eval_set=(x_test, y_test),
    verbose=50,
    plot=True
)

After the model finishes training, we make predictions on the test data and output performance metrics.

# Make predictions on test data
preds = model.predict(x_test)
# Evaluate predictions
print(classification_report(y_test, preds))
print(confusion_matrix(y_test, preds))
Performance metrics for the initial train/test split. Image by Author.

On this test set comprised of 184 patients, the model achieved an accuracy of 89%. This means that the model was able to correctly identify the presence or absence of heart disease in 89% of the patients in the test set. We can interpret the precision as follows: when the model predicted that a patient had heart disease, it was correct 89% of the time. Similarly, we can interpret recall as follows: of the number of patients that had heart disease, the model successfully predicted the presence of heart disease in these patients 91% of the time.

Moving forward, we can look at the global feature importance of the model instance.

fig,ax = plt.subplots(figsize=(10,6))
feature_importance_data = pd.DataFrame({'feature':model.feature_names_, 'importance':model.feature_importances_})
feature_importance_data.sort_values('importance', ascending=False, inplace=True)
sns.barplot(x='importance', y='feature', data=feature_importance_data)
Feature importance of the first model instance. Image by Author.

The feature importances indicate that ‘ST_Slope’ was the number one predictor of heart disease. ‘ST_Slope’ is a measurement of electrical activity in the heart recorded while the patient is exercising; it takes on the values ‘Up’, ‘Down’, and ‘Flat’. ‘ChestPainType’ is the second most influential predictor, taking on the values ‘TA’ (Typical Angina), ‘ATA’ (Atypical Angina), ‘NAP’ (Non-Anginal Pain), and ‘ASY’ (Asymptomatic). Third, ‘Oldpeak’ is a numeric measurement of the ST depression, closely related to ‘ST_Slope’. Overall, without extensive cardiology knowledge, the ordering of the feature importances seems intuitive.

Hyper-parameter Tuning

Hyper-parameter tuning is often referred to as the ‘dark art’ of machine learning. This is because there are many ways one can go about determining the best set of model hyper-parameters, and, in most cases, it is impossible to know if one has found the optimal set of parameters. Moreover, as the training data size increases, and the number of possible hyper-parameters increases, tuning becomes an extremely expensive high-dimensional optimization problem. Luckily for us, the training data size is relatively small and the number of useful hyper-parameters to search over are few.

We will use a well-known optimization library, Optuna, to search for a good set of hyper-parameters. Some of the code used for this section was adapted from an awesome article written by Zachary Warnes.

We must first define an objective function to optimize:

See this article for more detail. Image by Author.
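Since the objective function only appears as an image above, here is a sketch of what it could look like; the specific hyper-parameters and search ranges are illustrative assumptions rather than the exact ones used in the original notebook:

def classification_objective(trial):
    # Hyper-parameters and ranges to search over (illustrative choices)
    params = {
        'iterations': 500,
        'depth': trial.suggest_int('depth', 3, 8),
        'learning_rate': trial.suggest_float('learning_rate', 0.01, 0.3, log=True),
        'l2_leaf_reg': trial.suggest_float('l2_leaf_reg', 1.0, 10.0),
        'loss_function': 'Logloss',
        'custom_loss': ['AUC', 'Accuracy']
    }
    # 5-fold cross-validation on the full dataset using Catboost's 'cv' function
    cv_data = cv(
        params=params,
        pool=Pool(x, label=y, cat_features=cat_features),
        fold_count=5,
        shuffle=True,
        partition_random_seed=0,
        verbose=False
    )
    # Optimize the best cross-validated test accuracy
    return np.max(cv_data['test-Accuracy-mean'])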

The ‘params’ dictionary defines all of the hyper-parameters to search over, as well as the range of values to consider. It is important to note that we are optimizing the 5-fold cross-validated test accuracy, not just the accuracy on an arbitrary testing set. This is accomplished using the ‘cv’ function from Catboost.

The optimization trials are conducted with the following code:

classification_study = optuna.create_study(sampler=TPESampler(), direction="maximize")
classification_study.optimize(classification_objective, n_trials=20, timeout=600) 
trial = classification_study.best_trial
print(f"Highest Accuracy: {trial.value}")
print("Optimal Parameters:")
for key, val in trial.params.items():
    print(f"{key}:{val}")

The two main things to point out are the number of trials and the timeout. With the above configuration, the optimization algorithm will run for 20 iterations, or terminate early once 10 minutes (600 seconds) have elapsed. After the first iteration, Optuna will output something like this:

Optuna output after the first iteration. Image by Author.

This tells us that the cross-validated test accuracy with the first set of hyper-parameters was ~86.7%. One might notice that this is lower than the 89% accuracy we achieved on the initial test set. This is to be expected as cross-validated accuracy gives us a better estimate of true model accuracy, i.e., how accurate the model will be in production. After 20 iterations, Optuna was able to find a set of parameters that improved cross-validated accuracy to ~88%.

Output of the best hyper-parameter set found. Image by Author.

The number of iterations (trees) in the model will be determined by cross-validation using this set of hyper-parameters. Specifically, the final number of iterations is determined by the iteration number that maximizes cross-validated test accuracy.

# Create new parameter dictionary using optimal hyper-parameters
new_params = trial.params.copy()
new_params['loss_function'] = 'Logloss'
new_params['custom_loss'] = ['AUC','Accuracy']
cv_data = cv(
        params = new_params,
        pool = Pool(x, label=y, cat_features=cat_features),
        fold_count=5,
        shuffle=True,
        partition_random_seed=0,
        plot=False,
        stratified=False,
        verbose=False)
final_params = new_params.copy()
# The final number of iterations is the iteration number that maximizes cross-validated accuracy
final_params['iterations'] = np.argmax(cv_data['test-Accuracy-mean'])
final_params['cat_features'] = cat_features

Train and export final model

Lastly, we train the final model using the optimal set of hyper-parameters and export it to a desired storage location.

final_model = CatBoostClassifier(**final_params)
final_model.fit(x,y,verbose=100)
# Export model
model_name = f'heart_disease_model_{str(datetime.today())[0:10]}'
final_model.save_model(MODEL_STORAGE_PATH + model_name)

Summary of part 1

In part 1 of this series, we examined training data, trained and evaluated a model instance, performed hyper-parameter tuning, and exported the final model. Again, this was by no means a comprehensive article. Every machine learning project is unique and most problems require a lot more attention to the training data than was given here. However, I hope that the code presented here provides the reader with a general framework for a machine learning project.

In the next article, we will take a peek into the world of web development by creating an application that serves our model to end users. This will involve the use of HTML, CSS, JavaScript, and Flask. Thanks for reading!

