Entity Embeddings for ML

Give them a try: they perform better than one-hot encodings.

Adam Mehdi
Towards Data Science


Representing categorical data continuously, like this continuous view of the savanna. [source]

This post is a guide to using embeddings to represent categorical variables in machine learning algorithms. I’ve seen other articles address this method, but none actually show how to do it in code, so I decided to fill that gap with my own post.

If you are itching for the code, feel free to skip to the Implementation section. There is also an accompanying notebook with the pipeline laid out in full.

Why Entity Embeddings?

Simply put, they perform better than one-hot encodings because they represent categorical variables in a compact and continuous way.

We can replace one-hot encodings with embeddings to represent categorical variables in practically any modelling algorithm, from Neural Nets to k-Nearest Neighbors and tree ensembles. Whereas one-hot encodings ignore informative relations between a feature’s values, entity embeddings can map related values closer together in embedding space, revealing the inherent continuity of the data (Guo 2016).

For instance, when using word embeddings (which are essentially the same as entity embeddings) to represent each category, a perfect set of embeddings would hold the relationship: king - queen = husband - wife.
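
To make that relationship concrete, here is a minimal sketch with hand-picked, hypothetical 2-D vectors (they are not taken from any trained model). The first element loosely stands for "royalty" and the second for "gender"; the point is only that the two difference vectors coincide.

```python
# Hypothetical 2-D embeddings, chosen by hand for illustration only.
# Element 0 ~ "royalty", element 1 ~ "gender".
emb = {
    "king":    (1.0,  1.0),
    "queen":   (1.0, -1.0),
    "husband": (0.0,  1.0),
    "wife":    (0.0, -1.0),
}

def diff(a, b):
    """Element-wise difference between the embeddings of two words."""
    return tuple(x - y for x, y in zip(emb[a], emb[b]))

royal_gap = diff("king", "queen")
spouse_gap = diff("husband", "wife")
print(royal_gap, spouse_gap)
```

Real embeddings are learned rather than hand-picked, so the equality would hold only approximately.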

Values in a categorical variable virtually always exhibit some sort of relationship. Another example: if representing colors with entity embeddings, “brown” and “black” have similar values in an element indicating shade and different values in another element, say, one indicating composition of primary colors. Such a representation gives the model a sense of how each variable interrelates, thereby making the learning process easier and boosting performance.

Harnessing these benefits is really as simple as replacing one-hot matrices with embedding matrices (taken from a NN trained with those categories).
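
One way to see why the swap is so simple: multiplying a one-hot matrix by an embedding matrix is the same as looking up rows of the embedding matrix by category index. The sketch below uses NumPy and a random matrix as a stand-in for trained weights; the sizes and codes are made up for illustration.

```python
import numpy as np

# Hypothetical feature with 4 categories embedded in 2 dimensions.
# Random values stand in for weights that would come from a trained network.
rng = np.random.default_rng(0)
embedding_matrix = rng.normal(size=(4, 2))

codes = np.array([2, 0, 3, 2])            # integer-coded category of each row
one_hot = np.eye(4)[codes]                # shape (4, 4)

via_matmul = one_hot @ embedding_matrix   # shape (4, 2)
via_lookup = embedding_matrix[codes]      # same values, no matrix product

print(np.allclose(via_matmul, via_lookup))
```

The lookup form is what an embedding layer does internally, which is why the embedded representation stays compact even when a feature has thousands of categories.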

Performance (mean absolute percentage error; lower is better) of algorithms with and without entity embeddings for categorical variables on the Rossmann dataset (Guo 2016).

Enough talk; let’s get into the implementation.

Implementation

I’ll be using PyTorch, fastai, and sklearn. There are three steps to the pipeline:

1. Train a neural network with embeddings

# import modules, read data, and define options
from fastai.tabular.all import *

df = pd.read_csv('/train.csv', low_memory=False)
# split the columns into continuous and categorical by cardinality
cont, cat = cont_cat_split(df, max_card=9000, dep_var='target')
procs = [Categorify, Normalize]
splits = RandomSplitter()(df)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# feed the data into the Learner and train
to_nn = TabularPandas(df, procs, cat, cont,
                      splits=splits, y_names='target')
dls = to_nn.dataloaders(1024, device=device)
learn = tabular_learner(dls, layers=[500,250], n_out=1)
learn.fit_one_cycle(12, 3e-3)

The end result is a neural network with an embedding layer that can be used to represent each categorical variable. Usually the size of the embeddings is a hyperparameter to be specified, but fast.ai makes this easy by automatically inferring an appropriate embedding size based on the cardinality of the variable.

This step can be done in pure PyTorch or TensorFlow; just make sure to modify the appropriate parts of the subsequent code if you choose to do so.
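
If you do build the network yourself, you will have to choose the embedding sizes by hand. As a starting point, the sketch below reproduces the rule of thumb that, to my understanding, fastai applies by default (its `emb_sz_rule`). Treat the constants as an assumption and check them against your installed version; the function name `embedding_size` is my own.

```python
def embedding_size(n_categories: int) -> int:
    """Rule-of-thumb embedding width for a feature with `n_categories` levels.

    Assumed to mirror fastai's default heuristic: grows sub-linearly with
    cardinality and is capped at 600.
    """
    return min(600, round(1.6 * n_categories ** 0.56))

for n in (10, 1000, 1_000_000):
    print(n, embedding_size(n))
```

The sub-linear growth is the design point: a feature with a thousand levels gets a few dozen dimensions instead of a thousand one-hot columns.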

2. Replace each categorical value with its embedding vector

def embed_features(learner, xs):
    """
    learner: fastai Learner used to train the neural net
    xs: DataFrame containing input variables. Categorical values are defined by their rank.
    ::return:: copy of `xs` with embeddings replacing each categorical variable
    """
    xs = xs.copy()
    for i, col in enumerate(learner.dls.cat_names):

        # get matrix containing each row's embedding vector
        emb = learner.model.embeds[i]
        idx = tensor(xs[col], dtype=torch.int64).to(emb.weight.device)
        with torch.no_grad():  # no gradients needed; convert to NumPy for pandas
            emb_data = emb(idx).cpu().numpy()
        emb_names = [f'{col}_{j}' for j in range(emb_data.shape[1])]

        # join the embedded category and drop the old feature column
        feat_df = pd.DataFrame(data=emb_data, index=xs.index,
                               columns=emb_names)
        xs = xs.drop(col, axis=1)
        xs = xs.join(feat_df)
    return xs

This function expands each categorical column (a vector of size n_rows) into an embedding matrix of shape (n_rows, embedding_dim). Now we use it to embed the categorical columns of our data.

emb_xs = embed_features(learn, to_nn.train.xs)
emb_valid_xs = embed_features(learn, to_nn.valid.xs)

It is hard to follow the code without experimenting with it yourself, so here are the before and after of a sample dataset:

Dataset before entity embedding method is applied.
Dataset after the entity embedding method is applied.
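
For readers without the notebook at hand, here is a toy before-and-after in NumPy. The column names, codes, and weights are all hypothetical; the weight matrix stands in for a trained embedding layer.

```python
import numpy as np

# "Before": one categorical column stored as integer codes, one continuous column.
color_codes = np.array([0, 2, 1, 2, 0])      # e.g. 0=black, 1=brown, 2=white
price = np.array([9.5, 3.0, 7.2, 4.1, 8.8])

# Stand-in for trained embedding weights: 3 categories x 2 dimensions.
color_weights = np.array([[0.1, 0.9],
                          [0.2, 0.7],
                          [0.9, 0.1]])

# "After": the code column is replaced by embedding_dim new columns.
color_embedded = color_weights[color_codes]            # shape (5, 2)
after = np.column_stack([price, color_embedded])       # shape (5, 3)
after_names = ["price"] + [f"color_{j}" for j in range(color_embedded.shape[1])]

print(after_names)
print(after)
```

Rows that share a category (here rows 0 and 4, or 1 and 3) end up with identical embedding columns, while the continuous column passes through unchanged.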

3. Train ML algorithms on the embedded data

Most of the heavy lifting has already been done; we can now train our machine learning algorithm in the standard way, but passing the embedded data as input.

Here’s an example of a pipeline for training a random forest:

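from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_auc_score

rf = RandomForestClassifier(n_estimators=40, max_samples=100_000,
                            max_features=.5, min_samples_leaf=5)
rf = rf.fit(emb_xs, to_nn.train.y)
# roc_auc_score expects (y_true, y_score); score with the positive-class
# probability (this assumes a binary target)
roc_auc_score(to_nn.valid.y, rf.predict_proba(emb_valid_xs)[:, 1])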

If you wish to use a different algorithm, just replace RandomForestClassifier with GradientBoostingClassifier or any other estimator, and you should see a boost in performance relative to the one-hot encoding method.

Feature selection

There is one final consideration to make. We now have many more features, because each rank-coded categorical column has been expanded into an embedding matrix. A tree ensemble trains much more slowly on this data because there are more features to loop through and accordingly more splits to evaluate. With one-hot encodings, a decision tree trains rapidly, splitting examples into groups of 1 or 0.

However, in the embedded representation, each column contains continuous real values rather than just 0 or 1. The tree must therefore consider many candidate thresholds per column and evaluate each one separately. So, for decision trees, the entity embedding representation requires a lot more computation than one-hot encodings, meaning that training and inference take much longer.
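
The difference in work per column can be sketched by counting candidate thresholds. This is a simplified picture of an exhaustive split search (midpoints between consecutive distinct values), not a description of any particular library's internals, and the sample values are made up.

```python
def candidate_thresholds(values):
    """Midpoints between consecutive distinct sorted values of one column."""
    distinct = sorted(set(values))
    return [(a + b) / 2 for a, b in zip(distinct, distinct[1:])]

one_hot_column = [0, 1, 0, 0, 1, 1]
embedding_column = [-0.8, 0.3, 1.7, -0.1, 0.3, 2.4]   # hypothetical embedding values

print(len(candidate_thresholds(one_hot_column)))
print(len(candidate_thresholds(embedding_column)))
```

A one-hot column only ever offers a single cut, whereas one embedding column can offer up to one fewer than the number of categories, and each categorical feature now contributes several such columns.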

We can ameliorate that problem: choose the most salient features and only train on those.

Sklearn tree ensembles calculate feature importance automatically and expose it through the feature_importances_ attribute of the fitted model, an array of normalized scores (they sum to 1), one per feature column. To calculate those scores, sklearn loops through each split on each tree, sums the impurity decrease (information gain) of each split on which a feature is used, and uses that accumulated gain as a proxy for the feature's contribution.
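
As a rough illustration of that bookkeeping, the sketch below computes the Gini impurity decrease of two hypothetical root splits on a binary target and normalizes the results so they sum to 1. It is a hand-rolled simplification: a real ensemble also weights each split by the share of samples reaching the node and averages over trees, and the feature names are made up.

```python
def gini(labels):
    """Gini impurity of a list of binary (0/1) labels."""
    if not labels:
        return 0.0
    p = sum(labels) / len(labels)
    return 2 * p * (1 - p)

def impurity_decrease(parent, left, right):
    """Parent impurity minus the size-weighted impurity of the two children."""
    n = len(parent)
    children = len(left) / n * gini(left) + len(right) / n * gini(right)
    return gini(parent) - children

parent = [0, 0, 0, 1, 1, 1]
# Hypothetical splits: the feature name maps to (left labels, right labels).
splits = {
    "color_0": ([0, 0, 0], [1, 1, 1]),   # separates the classes perfectly
    "color_1": ([0, 0, 1], [0, 1, 1]),   # barely helps
}

raw = {name: impurity_decrease(parent, l, r) for name, (l, r) in splits.items()}
total = sum(raw.values())
importances = {name: gain / total for name, gain in raw.items()}
print(importances)
```

The feature that separates the classes cleanly collects almost all of the normalized importance, which is the signal the thresholding step relies on.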

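m = RandomForestClassifier().fit(emb_xs, to_nn.train.y)
fi = pd.DataFrame({'cols': emb_xs.columns, 'imp': m.feature_importances_})
# keep only the columns whose importance exceeds the threshold
emb_xs_filt = emb_xs.loc[:, fi.loc[fi['imp'] > .002, 'cols']]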

The code is short and sweet: fit a model, define a DataFrame that relates each feature to its importance, and drop any column whose importance is at or under .002 (a hyperparameter to tune).

Our reduced dataset performs just as well as the full-sized one at a fraction of the dimensionality. Fewer columns also mean less risk of overfitting and more robustness to changes in new data, so it is a good idea to apply feature selection even when time is not a constraint.

If you are pressed for time when using a tree ensemble and unable to apply the entity embeddings method, though, just represent categorical variables as you would any ordinal variable. Ordinal-style ranked representations are faster and at least as good as one-hot encodings (Wright 2019).
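
A minimal sketch of that fallback: map each category to an integer rank and feed the single resulting column to the tree ensemble. The alphabetical ordering used here is an arbitrary assumption for illustration, and the helper name `ordinal_encode` is my own.

```python
def ordinal_encode(values):
    """Replace each category with its integer rank in sorted order."""
    categories = sorted(set(values))
    code_of = {category: rank for rank, category in enumerate(categories)}
    return [code_of[v] for v in values], code_of

colors = ["brown", "black", "white", "black"]
codes, mapping = ordinal_encode(colors)
print(codes, mapping)
```

This keeps one column per categorical feature, so the tree has far fewer columns to scan than with either one-hot or embedded inputs.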

Conclusion

There are better replacements for one-hot encodings. Using embeddings from a trained neural network to represent categorical variables outperforms the one-hot encoding method in machine learning algorithms. Entity embeddings represent categorical variables in a continuous way, retaining the relationships between different data values and thereby making the model easier to train. The better-performing machine learning models can be used in an ensemble, or as a substitute for neural nets when interpretability is required.

References

  1. Guo, Cheng et al. 2016. “Entity Embeddings of Categorical Variables”. arXiv:1604.06737
  2. Wright MN, König IR. 2019. “Splitting on categorical predictors in random forests”. PeerJ. https://doi.org/10.7717/peerj.6339.
  3. Howard, Jeremy. Deep Learning for Coders with Fastai and PyTorch: AI Applications Without a PhD. https://www.amazon.com/Deep-Learning-Coders-fastai-PyTorch/dp/1492045527

Thank you to the fastai forums and the fastai book [3], which provided the resources for learning this material.


Thinking about AI & epistemology. Researching CV & ML as published Assistant Researcher. Studying CS @ Columbia Engineering.