Should I Really Eat That Mushroom?

Classifying edible and poisonous mushrooms with CatBoost gradient-boosted decision trees

Most educational and real-world datasets contain categorical features. Today we will cover gradient-boosted decision trees from the CatBoost library, which provides native support for categorical data. We will use a dataset of mushrooms that are either edible or poisonous. The mushrooms are described by categorical features such as their color, odor, and shape, and the question we want to answer is:

Is it safe to eat this mushroom – based on its categorical features?

As you can see, the stakes are high. We want to make sure that we get the machine learning model right so that our mushroom omelet does not end in disaster. As a bonus, at the end we will provide a feature importance ranking that tells you which categorical feature is the strongest predictor of mushroom safety.

Introducing the mushroom dataset

The mushroom dataset is available here: https://archive.ics.uci.edu/dataset/73/mushroom [1]. For clarity of presentation, we create a pandas DataFrame from the original cryptic short-form values and annotate it with proper column names and long-form values. We use pandas’ replace function with long-form values taken from the dataset description. The target variable can only take the values True and False – the dataset creators played it safe and classified questionable mushrooms as inedible.
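The preprocessing can be sketched as follows, assuming the raw file agaricus-lepiota.data has been downloaded from the UCI repository. The column names follow the dataset description; only the target and odor mappings are written out here, and the remaining columns are mapped the same way.

```python
import pandas as pd

# Column names from the dataset description; the raw file has no header.
columns = [
    "is_edible", "cap_shape", "cap_surface", "cap_color", "bruises", "odor",
    "gill_attachment", "gill_spacing", "gill_size", "gill_color", "stalk_shape",
    "stalk_root", "stalk_surface_above_ring", "stalk_surface_below_ring",
    "stalk_color_above_ring", "stalk_color_below_ring", "veil_type",
    "veil_color", "ring_number", "ring_type", "spore_print_color",
    "population", "habitat",
]

df = pd.read_csv("agaricus-lepiota.data", header=None, names=columns)

# Map the cryptic single-letter codes to long-form values. Only two columns
# are shown; the other mappings follow the dataset description analogously.
df = df.replace({
    "is_edible": {"e": True, "p": False},
    "odor": {"a": "almond", "l": "anise", "c": "creosote", "y": "fishy",
             "f": "foul", "m": "musty", "n": "none", "p": "pungent",
             "s": "spicy"},
})
```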

After checking the dataset for missing values, we find that only one column – stalk_root – is affected. We drop this column.
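In the raw file, missing values are encoded as a question mark, so a quick check per column along the lines of the sketch below reveals the affected column before we drop it:

```python
# Count "?" entries per column; stalk_root is the only affected one.
print((df == "?").sum().sort_values(ascending=False).head(3))

df = df.drop(columns=["stalk_root"])
```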

Exploration of the dataset reveals that the data is fairly balanced: Of the 8124 mushrooms, 4208 are edible and 3916 are poisonous. We divide the DataFrame into the target variable, is_edible, and the remaining mushroom features. Then, we split the dataset into training and test data by stratifying on the target variable. This ensures that the distribution of classes is comparable in both splits.
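A stratified split might look like this; the test fraction and the random seed are illustrative choices:

```python
from sklearn.model_selection import train_test_split

X = df.drop(columns=["is_edible"])
y = df["is_edible"]

# stratify=y keeps the edible/poisonous ratio comparable in both splits.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)
```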

The CatBoost library

CatBoost is an open-source machine learning package for gradient-boosted decision trees. The CatBoost Python package can be obtained by following the installation instructions. The most important components for us are catboost.Pool, which organizes the dataset and specifies categorical and numerical features, and our model, catboost.CatBoostClassifier. Categorical features can be difficult to handle with machine learning algorithms: they must be encoded into numerical values before they can be used for training, with each categorical value associated with a number, e.g. for mushroom colors, brown -> 0, black -> 1, yellow -> 2, and so on. CatBoost handles categorical input variables automatically, which saves us from adding one-hot encoding to the procedure. Not only is this convenient, but CatBoost’s algorithms are also optimized to train fast on categorical variables.
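As a toy illustration of the Pool interface, the cat_features argument is all it takes to declare which columns are categorical; the values below are made up:

```python
from catboost import Pool

# Two samples with two categorical features (color, odor).
toy_pool = Pool(
    data=[["brown", "almond"], ["yellow", "foul"]],
    label=[1, 0],
    cat_features=[0, 1],  # column indices of the categorical features
)
```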

Gradient-boosted decision trees

Decision trees are well-established machine learning algorithms that classify samples into different categories based on the values of their features. A single decision tree is prone to overfitting, so ensembles of decision trees are typically used to achieve better performance. In gradient-boosted decision trees, the ensemble is constructed iteratively: each new tree provides a small improvement over the previous iteration by training on the residuals left by the ensemble so far. The process stops when the loss converges, i.e., when adding more trees does not add value, or when a fixed total number of trees is reached. For a more detailed introduction to gradient-boosted decision trees, see the recommended blog posts at the bottom of this page.
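To make the idea concrete, here is a minimal boosting loop for squared-error regression, written with scikit-learn trees. It illustrates the principle of fitting each new tree to the residuals, not CatBoost’s actual implementation:

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

def boosted_fit(X, y, n_trees=100, learning_rate=0.1):
    """Fit each new tree to the residuals left by the ensemble so far."""
    # Start from a constant prediction: the mean of the targets.
    prediction = np.full(len(y), np.mean(y), dtype=float)
    trees = []
    for _ in range(n_trees):
        residuals = y - prediction  # negative gradient of the squared error
        tree = DecisionTreeRegressor(max_depth=3).fit(X, residuals)
        prediction += learning_rate * tree.predict(X)
        trees.append(tree)
    return np.mean(y), trees
```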

Classifying mushrooms

In the mushroom dataset, all features are categorical and are specified accordingly in the Pool. We construct a Pool for each split, training and testing. The target variable is cast to numeric values, as this integrates better with the loss routines of the CatBoostClassifier. The classifier itself follows an interface similar to scikit-learn’s. There are many hyperparameters that can be changed, including the learning rate, the total number of trees, and the tree regularization. The loss function is log-loss, since we are dealing with binary classification.

We define the dataset and the model in the code box below. For comparison, we train both a single decision tree and a full gradient-boosted ensemble.
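A sketch of that code box, building on the splits from above; the hyperparameter values (500 iterations, learning rate 0.1, depth 6) are illustrative choices:

```python
from catboost import CatBoostClassifier, Pool

# All features in this dataset are categorical.
cat_features = list(X_train.columns)

# Cast the boolean target to integers for the loss routines.
train_pool = Pool(X_train, label=y_train.astype(int), cat_features=cat_features)
test_pool = Pool(X_test, label=y_test.astype(int), cat_features=cat_features)

# A single decision tree, i.e. one boosting iteration, for comparison.
single_tree = CatBoostClassifier(
    iterations=1, depth=6, loss_function="Logloss", verbose=0
)
single_tree.fit(train_pool)

# The full gradient-boosted ensemble.
gbdt = CatBoostClassifier(
    iterations=500, learning_rate=0.1, depth=6, loss_function="Logloss", verbose=0
)
gbdt.fit(train_pool)
```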

Evaluation

Now we are ready to evaluate the performance of the classifiers on the test data. Eating poisonous mushrooms can cause serious health problems, so we are particularly interested in avoiding false positives: poisonous mushrooms mislabeled as edible. We calculate the precision metric, which is the fraction of mushrooms predicted to be edible that are actually edible.

The single decision tree gives a precision of 97%, which is quite good for a classification algorithm. But with gradient-boosted trees, we can improve precision to 100%: not a single poisonous mushroom in the test dataset is mislabeled as edible. The confusion matrix confirms that the gradient-boosted decision tree performs optimally on the test set.
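Continuing from the sketch above, both metrics can be computed with scikit-learn:

```python
from sklearn.metrics import confusion_matrix, precision_score

for name, model in [("single tree", single_tree), ("boosted ensemble", gbdt)]:
    y_pred = model.predict(test_pool)
    print(f"{name}: precision = {precision_score(y_test.astype(int), y_pred):.3f}")

# Rows: actual class, columns: predicted class.
print(confusion_matrix(y_test.astype(int), gbdt.predict(test_pool)))
```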

Feature importance

This is great, but we may not have all day to determine 22 features for each mushroom we want to eat. So what is the most important feature to determine whether a mushroom is edible?

To answer this question, we use the built-in model attribute feature_importances_ to derive a feature importance ranking for the gradient boosted tree classifier. As it turns out, odor dominates the feature importance ranking, followed by spore print color and population.
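Continuing the sketch, the ranking can be read off in a couple of lines:

```python
import pandas as pd

# Pair the importance scores with the feature names and sort.
importances = pd.Series(gbdt.feature_importances_, index=X_train.columns)
print(importances.sort_values(ascending=False).head(3))
```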

A closer look at the possible odor values reveals that this feature by itself is already a good predictor of whether a mushroom will be a tasty addition to your meal or end your day in hospital. All mushrooms in this dataset that smell of anise or almond are edible. Mushrooms without odor are mostly edible. You should stay away from mushrooms that smell fishy, spicy, pungent, foul, musty, or of creosote – which, to be honest, do not sound tasty in the first place.
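A cross-tabulation of odor against the target, as in the sketch below, makes this pattern visible directly in the data:

```python
# Count edible vs. poisonous mushrooms for each odor value.
print(pd.crosstab(df["odor"], df["is_edible"]))
```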

Summary

We have presented the mushroom dataset, which contains samples of edible and poisonous mushrooms described only by categorical variables. We introduced the catboost package, which works well with categorical data and provides gradient-boosted decision trees. A model was trained to classify the mushrooms accordingly and achieved satisfactory performance. Odor is the strongest predictor of mushroom safety. We hope you enjoyed this blog post, and take no responsibility for the application of the model to real mushrooms 🙂 .

Further reading

Dataset reference

[1] Mushroom. UCI Machine Learning Repository (1987). https://doi.org/10.24432/C5959T. This dataset is licensed under a Creative Commons Attribution 4.0 International (CC BY 4.0) license.

