
Using Entity Embeddings to improve the performance of Machine Learning Models

Tutorial on implementing Embeddings learned by a neural net in ML models

Photo by Mika Baumeister on Unsplash

This article shows how to use embeddings learned by a neural network in ML models, so we won’t go into detail about the theory behind embeddings.

Note: It is assumed that you know the basics of Deep Learning and Machine Learning.

What are Entity Embeddings and why use Entity Embeddings?

Loosely speaking, an entity embedding is a continuous vector representation of a categorical variable. In a neural network, an embedding layer maps each category from its original, discrete space into a low-dimensional continuous vector space, in a way that preserves the information carried by the variable and places similar categories close together in the embedding space. Using embeddings therefore gives the model a sense of how the categories relate to one another, which can boost performance.
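As a minimal illustration (a sketch in PyTorch, with made-up numbers): an embedding layer is simply a lookup table that maps each category code to a trainable dense vector.

import torch
import torch.nn as nn

# A hypothetical categorical variable with 4 levels, embedded into 3 dimensions;
# each row of the weight matrix is the learned vector for one category
emb = nn.Embedding(num_embeddings=4, embedding_dim=3)
codes = torch.tensor([0, 2, 2, 1])  # category codes for four instances
vectors = emb(codes)                # shape: (4, 3), one vector per instance
print(vectors.shape)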

Implementation

The general steps to using embeddings learned by a neural net for training ML models are:

  1. Train a neural network with embedding layers.
  2. Extract the embeddings from the trained neural network.
  3. Replace the categorical variables with the embeddings of the categorical variables from the trained neural network.
  4. Train your ML model using the embeddings.

In this tutorial, we will be using sklearn, fastai, PyTorch and the famous Titanic dataset for demonstration purposes. You can replicate this using the framework of your choice. Data cleaning and feature engineering were done prior to this tutorial (a rough sketch of what that might look like is shown below).
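The exact preprocessing isn’t part of this tutorial; purely for orientation, here is a hypothetical sketch of the kind of feature engineering assumed here, where Title and Fam_size are derived from the standard Kaggle Titanic columns Name, SibSp and Parch:

import pandas as pd

# Hypothetical sketch of the feature engineering assumed in this tutorial:
# 'Title' is extracted from the Name column and 'Fam_size' combines the
# SibSp and Parch columns of the standard Kaggle Titanic dataset
df_train = pd.read_csv('train.csv')
df_train['Title'] = df_train['Name'].str.extract(r' ([A-Za-z]+)\.')
df_train['Fam_size'] = df_train['SibSp'] + df_train['Parch'] + 1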

# required libraries
from sklearn.metrics import classification_report
from sklearn.preprocessing import OneHotEncoder
from sklearn import preprocessing
from sklearn.ensemble import RandomForestClassifier
import pandas as pd
from fastai.tabular.all import *

Preprocess the data with fastai before feeding it into the neural network.

# to_nn holds the processed TabularPandas; dls are the dataloaders built from it
to_nn = TabularPandas(df_train, y_names="Survived", y_block=CategoryBlock,
    cat_names=['Cabin', 'Title', 'Sex'],
    cont_names=['Age', 'Pclass', 'Fam_size'],
    procs=[Categorify, FillMissing, Normalize],
    splits=RandomSplitter(valid_pct=0.2)(range_of(df_train)))
dls = to_nn.dataloaders()
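At this point the categorical columns have been replaced by integer category codes, which is exactly what the embedding layers (and the extraction step further down) operate on. A quick sanity check (a sketch):

# 'Cabin', 'Title' and 'Sex' are now integer codes and the continuous
# columns are normalized
print(to_nn.train.xs.head())
print(to_nn.classes['Sex'])  # mapping from codes back to the original labels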

Create a TabularLearner and use lr_find() to find a suitable learning rate.

learn = tabular_learner(dls, metrics=accuracy)
learn.lr_find()

Train the neural network.

learn.fit_one_cycle(8, 2e-2)
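The trained TabularModel keeps one embedding layer per categorical variable in model.embeds, in the same order as to_nn.cat_names, so they can be inspected before extraction:

# one nn.Embedding per categorical variable; the sizes were chosen
# automatically by fastai from each variable's cardinality
print(learn.model.embeds)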

Extract the embeddings from the trained neural network and replace each categorical variable with its learned embedding.

# Function to embed features, adapted from the fastai forums
def embed_features(learner, xs):
    xs = xs.copy()
    for i, feature in enumerate(learner.dls.cat_names):
        emb = learner.model.embeds[i]  # embedding layer for this variable
        # look up the embedding vector of each category code and expand it
        # into one column per embedding dimension
        vec = emb(tensor(xs[feature].values, dtype=torch.int64)).detach().numpy()
        new_feat = pd.DataFrame(vec, index=xs.index,
                                columns=[f'{feature}_{j}' for j in range(emb.embedding_dim)])
        xs = xs.drop(columns=feature).join(new_feat)
    return xs
emb_xs = embed_features(learn, to_nn.train.xs)
emb_valid_xs = embed_features(learn, to_nn.valid.xs)
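As a quick check (a sketch), each categorical column should now have been expanded into embedding_dim continuous columns, e.g. Sex_0, Sex_1, ..., alongside the untouched continuous features:

# the 'Sex' column is gone, replaced by its embedding columns
print(emb_xs.filter(like='Sex_').head())
print(to_nn.train.xs.shape, emb_xs.shape)  # the embedded frame is wider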

Train your ML model using the embeddings.

rf = RandomForestClassifier(n_estimators=400, min_samples_leaf=10,
                            max_features=1/2, max_samples=50)
rf = rf.fit(emb_xs, to_nn.train.y)
valid_preds = rf.predict(emb_valid_xs)
print(classification_report(to_nn.valid.y, valid_preds))
Image by author.

Compared to a random forest trained on the original dataset, the random forest trained on the embedded data achieved roughly a 3% higher accuracy, which is a substantial improvement.
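The baseline isn’t shown here; for reference, this is a sketch of what it might look like, training the same random forest on the processed features without embeddings (same hyperparameters as above):

# Baseline for comparison: the same random forest on the original processed
# features, where categories are plain integer codes instead of embeddings
rf_base = RandomForestClassifier(n_estimators=400, min_samples_leaf=10,
                                 max_features=1/2, max_samples=50)
rf_base.fit(to_nn.train.xs, to_nn.train.y)
print(classification_report(to_nn.valid.y, rf_base.predict(to_nn.valid.xs)))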

Conclusion

Using entity embeddings can not only help us understand and visualize our data better, but it can also improve the performance of ML models. Embeddings are a useful tool for handling categorical variables and are an upgrade over traditional encoding methods such as one-hot encoding.

