Writing your own Scikit-learn classes — for beginners.

Basics to get you started as fast as possible

Anup Sebastian
Towards Data Science


Image from: https://en.wikipedia.org/wiki/File:Scikit_learn_logo_small.svg

Scikit-learn is one of the most popular Python libraries for supervised machine learning, and it manages to be extremely powerful while remaining simple to use. But the reason for its success, I believe, is not mainly its core functionality, but how cohesive and consistent the framework is across all its modules. This consistency, along with its popularity, is also why many other popular libraries like XGBoost, pyGAM, LightGBM and pyEarth are written in the style of Scikit-learn and are compatible with it, almost as though they were part of the core library. A good number of such useful projects built to be compatible with Scikit-learn can be found here.

This cohesive compatibility is the result of Scikit-learn having its own set of simple guidelines that, if followed, ensure everything works well with it. It also provides a set of base classes that make the process of creating your own classes much easier. If you are new to writing your own classes, whether for your own use or for a library, this article can give you a solid starting point.

Why should you write your own classes as a beginner?

Now, you may be wondering why anyone should write their own classes when there are already so many professionally written ones to use. That is a fair question, and you shouldn't bother writing your own if the freely available ones do exactly what you need them to do. However, even as a beginner, chances are you have run into some operation in your own unique workflow for which nothing available does exactly what you want. If this operation is something you end up doing often, it is worth writing your own class for it.

Also, if your operation sits in the middle of a series of operations that do have pre-written classes, not having a class for it prevents you from adding the whole process to a pipeline and fully automating it. It also makes it impossible to run model selection on the entire process with Scikit-learn tools like GridSearchCV.

Some other good reasons to do it are:

  • It develops your programming skills.
  • It could be developed over time to have more options and functionality for your own work, and also could end up being useful to others.
  • It enables you to contribute to the open source community. Most of the extremely powerful tools you use are available for free precisely because of such contributions.

So with that, let’s dive straight into creating your own classes.

Let's say you wanted to build a transformer that subtracts the smallest number in a column from all the values in that column. For example, if a column has the numbers [10, 11, 12, 13, 14, 15], our transformer would change it to [0, 1, 2, 3, 4, 5]. This may not be a particularly useful transformer, but it is a very simple example just to get started.

Let's create two toy dataframes:

import pandas as pd

df_a = pd.DataFrame({'Number': [11, 12, 13, 14, 15],
                     'Color': ['Red', 'Blue', 'Green', 'Yellow', 'Orange']})
df_b = pd.DataFrame({'Number': [21, 22, 23, 24, 25],
                     'Color': ['Violet', 'Black', 'White', 'Brown', 'Pink']})


Now let's write the transformer:

# These base classes allow our class to inherit Scikit-learn methods
# such as fit_transform, get_params and set_params
from sklearn.base import BaseEstimator, TransformerMixin
# This function just makes sure that the object is fitted
from sklearn.utils.validation import check_is_fitted

class SubtractMin(BaseEstimator, TransformerMixin):
    def __init__(self, cols_to_operate):
        # store the parameter under the same name it has in the
        # __init__ signature, so that get_params/set_params work
        self.cols_to_operate = cols_to_operate

    def fit(self, X, y=None):
        self.min_val_ = X[self.cols_to_operate].min()
        return self

    def transform(self, X):
        # make sure the transformer was fitted first
        check_is_fitted(self, 'min_val_')

        X = X.copy()  # so we do not change the original dataframe
        X[self.cols_to_operate] = X[self.cols_to_operate] - self.min_val_
        return X

The classes we import from sklearn.base are the glue that makes it all work. They are what allow our class to fit in with Scikit-learn's pipelines and model selection tools. BaseEstimator gives it the get_params and set_params methods that all Scikit-learn estimators require, and TransformerMixin gives it the fit_transform method.
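As a quick illustration of what BaseEstimator provides, here is a minimal check using the SubtractMin transformer defined above:

sub = SubtractMin(cols_to_operate='Number')
# get_params is inherited from BaseEstimator; it reports the
# parameters that were passed to __init__
print(sub.get_params())  # {'cols_to_operate': 'Number'}

This is what allows tools like GridSearchCV to clone your estimator and try out different parameter values.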

The BaseEstimator documentation can be found here, and the TransformerMixin documentation can be found here.

There are other kinds of Mixins available based on the type of estimator you want to build, such as ClassifierMixin and RegressorMixin, which give access to the score method. As you are probably aware, in general, data processors use the fit and transform methods (along with fit_transform), while models (regressors and classifiers) use the fit and predict methods. So you need to add the appropriate classes as parent classes when you create your estimator. For instance, a custom regressor would require BaseEstimator and RegressorMixin; other more complicated classes may require more than these, depending on their functionality.
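To make that concrete, here is a minimal sketch of a custom regressor. MeanRegressor is a deliberately trivial, hypothetical example, not part of Scikit-learn:

import numpy as np
from sklearn.base import BaseEstimator, RegressorMixin
from sklearn.utils.validation import check_is_fitted

class MeanRegressor(BaseEstimator, RegressorMixin):
    """A toy regressor that always predicts the mean of y."""
    def fit(self, X, y):
        self.mean_ = np.mean(y)
        return self

    def predict(self, X):
        check_is_fitted(self, 'mean_')
        # one prediction per row of X
        return np.full(len(X), self.mean_)

Because it inherits from RegressorMixin, it gets an R² score method for free, which means it also works with tools like cross_val_score.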

check_is_fitted is just a validation function that makes sure you have already fitted your object before you try to do something with it; if not, it raises a NotFittedError.
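For instance, calling transform before fit on our transformer raises that error:

from sklearn.exceptions import NotFittedError

try:
    # transform before fit: check_is_fitted finds no min_val_ attribute
    SubtractMin('Number').transform(df_b)
except NotFittedError as e:
    print(e)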

In the __init__ method we add all the initial parameters that the transformer requires. Here it is just the name of the dataframe column we need to operate on. Note that Scikit-learn expects each __init__ parameter to be stored on the instance under the same name, which is why the attribute is called cols_to_operate.

The fit and transform methods are required for all Scikit-learn transformers (for regressors and classifiers it would be fit and predict). The fit method in our case just finds the minimum value in the column and stores it. You probably noticed we have a y=None in there even though we do not use y at all in our fit method. This is so that the signature remains uniform with other classes in Scikit-learn, which ensures compatibility in pipelines. The fit method also always has to return self.

The transform method does the work and returns the output. We make a copy so the original dataframe is not touched, then subtract the minimum value that the fit method stored, and return the result. Your own useful transformers would obviously do something more elaborate here.

Let's see our transformer in action.

sub = SubtractMin('Number')
sub.fit_transform(df_a)
sub.transform(df_b)

The first thing you probably noticed is that we were able to call the fit_transform method even though we never explicitly defined it. As mentioned earlier, it is inherited from the TransformerMixin base class.

Our transformer essentially 'learned' the minimum value from df_a (11) and subtracted it from the column when the fit_transform method was called. It then subtracted the same minimum value from df_b when the transform method was called.
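This is also exactly what makes the transformer usable inside a Pipeline. Here is a minimal sketch with only one step, just to show the mechanics; a real pipeline would chain further steps and usually end with a model:

from sklearn.pipeline import Pipeline

pipe = Pipeline([('subtract_min', SubtractMin('Number'))])

# The pipeline passes y along internally, which is why our fit
# method needed to accept y even though it ignores it.
pipe.fit_transform(df_a)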

Scikit-learn has a more elaborate set of guidelines that need to be followed to contribute to the project itself, but what is covered here should be enough to ensure compatibility and to help you write classes for your own projects. It would be a good idea to have a look at those guidelines (available here) once you get comfortable writing your classes, so that you can contribute to the community. They are geared towards contributing to Scikit-learn or similar libraries; following all of them is recommended but not strictly necessary, and there are plenty of projects that do not follow them to the letter. Consider them a long-term goal to work towards in your own classes.

Hopefully, this quick guide is enough to get you started writing your own custom Scikit-learn classes. Good luck!
