Implementing TabNet in PyTorch

Samrat Thapa
Towards Data Science
7 min read · Oct 23, 2020


Cover photo: https://unsplash.com/photos/Wpnoqo2plFA

Deep learning has taken over vision, natural language processing, speech recognition, and many other fields, achieving astonishing results and even superhuman performance in some. However, the use of deep learning to model tabular data has been relatively limited.

For tabular data, the most common approach is to use tree-based models and their ensembles. Tree-based models select, globally over the training data, the features that reduce entropy the most. Ensemble methods such as bagging and boosting improve these tree-based methods further by reducing model variance. Recent tree-based ensembles like XGBoost and LightGBM have dominated Kaggle competitions.

TabNet is a neural architecture developed by the research team at Google Cloud AI. It achieved state-of-the-art results on several datasets, in both regression and classification problems. It combines the ability of neural networks to fit very complex functions with the feature-selection property of tree-based algorithms: the model learns to select only the relevant features during training. Moreover, unlike tree-based models, which can only perform feature selection globally, the feature selection in TabNet is instance-wise. Another desirable property of TabNet is interpretability: whereas most neural networks act like black boxes, with TabNet we can interpret which features the model selects.

In this blog post, I will take you through a step-by-step, beginner-friendly implementation of TabNet in PyTorch. Let’s get started!

Figure 1: The TabNet architecture. Source: https://arxiv.org/pdf/1908.07442v1.pdf

Figure 1 is taken from the original TabNet paper. We will build each component of the figure individually and assemble them at the end. First, let’s go through two essential concepts used in this model: Ghost Batch Normalization (GBN) and Sparsemax.

Ghost Batch Normalization (GBN):
GBN lets us train on large batches while still generalizing well. To put it simply, we split the input batch into equal-sized sub-batches (of the virtual batch size) and apply the same batch normalization layer to each of them. Except for the first batch normalization layer, which is applied to the raw input features, all batch normalization layers in the model are GBN layers. It can be implemented in PyTorch as follows:

import torch
import torch.nn as nn
import torch.nn.functional as F   # used later for the ReLU in the TabNet forward pass

class GBN(nn.Module):
    """Ghost Batch Normalization: BatchNorm1d applied to virtual sub-batches."""
    def __init__(self, inp, vbs=128, momentum=0.01):
        super().__init__()
        self.bn = nn.BatchNorm1d(inp, momentum=momentum)
        self.vbs = vbs

    def forward(self, x):
        # split the batch into chunks of (roughly) vbs rows and normalize each separately
        chunk = torch.chunk(x, max(1, x.size(0) // self.vbs), 0)
        res = [self.bn(y) for y in chunk]
        return torch.cat(res, 0)

Sparsemax:

Sparsemax is a non-linear normalization function just like softmax, but, as the name suggests, the resulting distribution is sparser: most entries of the output are pushed to exactly zero, concentrating the probability mass on a few features. This lets the model select the relevant features at each decision step more decisively. We will use sparsemax to project the mask for the feature selection step onto a sparser space. An implementation of sparsemax can be found at: https://github.com/gokceneraslan/SparseMax.torch
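To build intuition, here is a minimal, forward-only sketch of the sparsemax projection (following Martins and Fernandez Astudillo, 2016). It is for illustration only; in the model below we use the autograd-ready Sparsemax module from the repository linked above.

import torch

def sparsemax(logits, dim=-1):
    # Forward-only sketch: project the logits onto the probability simplex.
    # Unlike softmax, many entries of the result are exactly zero.
    z, _ = torch.sort(logits, dim=dim, descending=True)
    cumsum = z.cumsum(dim) - 1
    k = torch.arange(1, logits.size(dim) + 1, device=logits.device, dtype=logits.dtype)
    shape = [1] * logits.dim()
    shape[dim] = -1
    k = k.view(shape)                           # broadcast 1..K along `dim`
    support = (k * z) > cumsum                  # entries kept in the support
    k_z = support.sum(dim=dim, keepdim=True)    # support size k(z)
    tau = cumsum.gather(dim, k_z - 1) / k_z.to(logits.dtype)
    return torch.clamp(logits - tau, min=0)

# softmax([2.0, 1.0, 0.1]) ≈ [0.66, 0.24, 0.10], whereas
print(sparsemax(torch.tensor([[2.0, 1.0, 0.1]])))   # tensor([[1., 0., 0.]])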

To further increase the sparsity of the masks, we also add a sparsity regularization term that penalizes less sparse masks. It can be computed at each decision step as follows:

(-mask*torch.log(mask+1e-10)).mean()  # F(mask) = -∑ m·log(m+eps), the entropy of the mask

The sum of this value over all decision steps is added to the total loss (after multiplying by a regularization constant λ).
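As a quick illustration of why this term encourages sparsity, a sparser mask yields a smaller penalty (the masks below are made-up examples):

import torch

def penalty(m):
    # same entropy-style term as above
    return (-m * torch.log(m + 1e-10)).mean()

sparse_mask = torch.tensor([[0.0, 1.0, 0.0, 0.0]])    # one feature selected
dense_mask = torch.tensor([[0.25, 0.25, 0.25, 0.25]])  # all features selected equally
print(penalty(sparse_mask))   # ≈ 0
print(penalty(dense_mask))    # ≈ 0.3466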

Attention Transformer:
This is where the model learns the relationship between relevant features and decides which features to pass on to the feature transformer of the current decision step. Each attention transformer consists of a fully connected layer, a Ghost Batch Normalization layer, and a Sparsemax layer. The attention transformer of each decision step receives the processed features from the previous step together with prior information about how much each feature has already been used. The prior is a matrix of size batch_size x input_features, initialized with ones and updated at every decision step’s attention transformer. A relaxation parameter limits how many times a certain feature can be used across the forward pass: a larger value means the model may reuse the same feature in several decision steps. I think the code makes everything clear.

# Sparsemax below is the module from the repository linked above
class AttentionTransformer(nn.Module):
    def __init__(self, d_a, inp_dim, relax, vbs=128):
        super().__init__()
        self.fc = nn.Linear(d_a, inp_dim)
        self.bn = GBN(inp_dim, vbs=vbs)
        self.smax = Sparsemax()
        self.r = relax   # relaxation parameter: how often a feature may be reused

    def forward(self, a, priors):
        # a: processed features from the previous decision step
        a = self.bn(self.fc(a))
        mask = self.smax(a * priors)
        priors = priors * (self.r - mask)   # relax the prior for the features already used
        return mask, priors

This mask is then multiplied element-wise with the (batch-normalized) input features.

Feature Transformer:

The feature transformer is where all the selected features are processed to generate the final output. Each feature transformer is composed of multiple Gated Linear Unit (GLU) blocks. A GLU controls which information is allowed to flow further through the network. To implement a GLU block, we first double the dimension of the input features with a fully connected layer and normalize the result with a GBN layer. Then we apply a sigmoid to the second half of the features and multiply it element-wise with the first half. The GLU output is added to the block’s input and the sum is scaled by sqrt(0.5) to keep the variance stable; this result is the input to the next GLU block in the sequence.

A certain number of GLU blocks can optionally be shared among all the decision steps to improve parameter efficiency and model capacity. The first GLU block (the first shared one, or the first independent one if no blocks are shared) is unique in that it reduces the dimension of the input features to n_a + n_d: n_a is the dimension of the features fed to the attention transformer of the next step, and n_d is the dimension of the features used to compute the final output. The two parts are processed together until they reach the splitter, where a ReLU activation is applied to the n_d-dimensional part. The n_d-dimensional outputs of all the decision steps are summed and passed through a fully connected layer that maps them to the output dimension.

class GLU(nn.Module):
    def __init__(self, inp_dim, out_dim, fc=None, vbs=128):
        super().__init__()
        # reuse a shared fully connected layer if one is passed in
        if fc:
            self.fc = fc
        else:
            self.fc = nn.Linear(inp_dim, out_dim * 2)
        self.bn = GBN(out_dim * 2, vbs=vbs)
        self.od = out_dim

    def forward(self, x):
        x = self.bn(self.fc(x))
        # gate the first half of the features with a sigmoid of the second half
        return x[:, :self.od] * torch.sigmoid(x[:, self.od:])


class FeatureTransformer(nn.Module):
    def __init__(self, inp_dim, out_dim, shared, n_ind, vbs=128):
        super().__init__()
        first = True
        self.shared = nn.ModuleList()
        if shared:
            # the first shared block maps inp_dim to out_dim (= n_d + n_a)
            self.shared.append(GLU(inp_dim, out_dim, shared[0], vbs=vbs))
            first = False
            for fc in shared[1:]:
                self.shared.append(GLU(out_dim, out_dim, fc, vbs=vbs))
        else:
            self.shared = None
        self.independ = nn.ModuleList()
        if first:
            # no shared blocks: the first independent block does the dimension reduction
            self.independ.append(GLU(inp_dim, out_dim, vbs=vbs))
        for _ in range(int(first), n_ind):
            self.independ.append(GLU(out_dim, out_dim, vbs=vbs))
        # scaling factor sqrt(0.5); a buffer so it moves with the model's device
        self.register_buffer('scale', torch.sqrt(torch.tensor([0.5])))

    def forward(self, x):
        if self.shared:
            x = self.shared[0](x)        # first block changes the dimension: no residual
            for glu in self.shared[1:]:
                x = torch.add(x, glu(x))
                x = x * self.scale
            start = 0
        else:
            x = self.independ[0](x)      # first block changes the dimension: no residual
            start = 1
        for glu in self.independ[start:]:
            x = torch.add(x, glu(x))
            x = x * self.scale
        return x

Next, let us combine the Attention Transformer and Feature Transformer into a decision step:

class DecisionStep(nn.Module):
    def __init__(self, inp_dim, n_d, n_a, shared, n_ind, relax, vbs=128):
        super().__init__()
        self.fea_tran = FeatureTransformer(inp_dim, n_d + n_a, shared, n_ind, vbs)
        self.atten_tran = AttentionTransformer(n_a, inp_dim, relax, vbs)

    def forward(self, x, a, priors):
        mask, priors = self.atten_tran(a, priors)
        sparse_loss = (-mask * torch.log(mask + 1e-10)).mean()
        x = self.fea_tran(x * mask)   # select the features, then transform them
        return x, sparse_loss, priors

Finally, we can complete the model by combining several decision steps together:

class TabNet(nn.Module):
    def __init__(self, inp_dim, final_out_dim, n_d=64, n_a=64,
                 n_shared=2, n_ind=2, n_steps=5, relax=1.2, vbs=128):
        super().__init__()
        if n_shared > 0:
            self.shared = nn.ModuleList()
            self.shared.append(nn.Linear(inp_dim, 2 * (n_d + n_a)))
            for _ in range(n_shared - 1):
                self.shared.append(nn.Linear(n_d + n_a, 2 * (n_d + n_a)))
        else:
            self.shared = None
        self.first_step = FeatureTransformer(inp_dim, n_d + n_a, self.shared, n_ind, vbs)
        self.steps = nn.ModuleList()
        for _ in range(n_steps - 1):
            self.steps.append(DecisionStep(inp_dim, n_d, n_a, self.shared, n_ind, relax, vbs))
        self.fc = nn.Linear(n_d, final_out_dim)
        self.bn = nn.BatchNorm1d(inp_dim)   # plain BN on the raw input features
        self.n_d = n_d

    def forward(self, x):
        x = self.bn(x)
        # the first feature transformer only provides the attention input for step 1
        x_a = self.first_step(x)[:, self.n_d:]
        sparse_loss = torch.zeros(1).to(x.device)
        out = torch.zeros(x.size(0), self.n_d).to(x.device)
        priors = torch.ones(x.shape).to(x.device)
        for step in self.steps:
            x_te, l, priors = step(x, x_a, priors)
            out += F.relu(x_te[:, :self.n_d])   # accumulate the n_d-dimensional outputs
            x_a = x_te[:, self.n_d:]
            sparse_loss += l
        return self.fc(out), sparse_loss
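To sanity-check the full model, here is a hypothetical forward and backward pass on random data, with the sparsity penalty folded into a task loss. The dimensions, hyperparameters, and λ value below are made up for illustration:

import torch
import torch.nn.functional as F

# Hypothetical smoke test: 256 random samples, 32 features, one regression target.
model = TabNet(inp_dim=32, final_out_dim=1, n_d=16, n_a=16,
               n_shared=2, n_ind=2, n_steps=4, relax=1.2, vbs=128)
x = torch.randn(256, 32)
y = torch.randn(256, 1)

preds, sparse_loss = model(x)
print(preds.shape)   # torch.Size([256, 1])

# fold the sparsity penalty into the task loss with a regularization constant λ
lambda_sparse = 1e-5
loss = F.mse_loss(preds, y) + lambda_sparse * sparse_loss
loss.backward()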

Approximate Range of Model Hyperparameters:
n_d, n_a: 8 to 512
batch size: 256 to 32768
virtual batch size: 128 to 2048
sparsity regularization constant: 0 to 0.00001
number of shared GLU Blocks: 2 to 10
number of independent GLU Blocks: 2 to 10
relaxation constant: 1 to 2.5
number of decision steps: 2 to 10
batch normalization momentum: 0.5 to 0.98
(Note: these ranges are based on the paper and my personal experience.)

This TabNet module can be extended for classification as well as regression tasks. You can add embeddings for your categorical variables, apply a Sigmoid or Softmax function to the output, and much more. Experiment and see what works best for you. In the example notebook, I tried replacing the Sparsemax with a Sigmoid and was able to get a slightly better accuracy.
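For instance, a thin classification wrapper that embeds categorical columns before feeding them to the TabNet module defined above might look like the sketch below. The column layout, embedding size, and the TabNetClassifier name are purely illustrative:

class TabNetClassifier(nn.Module):
    # Hypothetical wrapper: embeds categorical columns, concatenates them with
    # the numeric columns, and maps TabNet's output to class logits.
    def __init__(self, n_num, cat_cardinalities, emb_dim, n_classes, **tabnet_kwargs):
        super().__init__()
        self.embeds = nn.ModuleList(
            [nn.Embedding(card, emb_dim) for card in cat_cardinalities])
        inp_dim = n_num + emb_dim * len(cat_cardinalities)
        self.tabnet = TabNet(inp_dim, final_out_dim=n_classes, **tabnet_kwargs)

    def forward(self, x_num, x_cat):
        # x_num: (batch, n_num) floats; x_cat: (batch, n_cat) integer category codes
        embs = [emb(x_cat[:, i]) for i, emb in enumerate(self.embeds)]
        x = torch.cat([x_num] + embs, dim=1)
        logits, sparse_loss = self.tabnet(x)
        return logits, sparse_loss   # apply softmax / cross-entropy outside the model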

For a use case example on the Mechanisms of Action (MoA) Prediction dataset, you can find my notebook here: https://www.kaggle.com/samratthapa/drug-prediction. It is the dataset of a competition currently being held on Kaggle.

My implementation of TabNet is a short adaptation of the work of the generous people at DreamQuark. Their complete implementation of TabNet can be found at: https://github.com/dreamquark-ai/tabnet/tree/develop/pytorch_tabnet
You should consider reading the paper for a more detailed description of TabNet.

Thank you for reading. I hope I was able to help.

References:
1) Sercan Ö. Arık, Tomas Pfister. 2020. TabNet: Attentive Interpretable Tabular Learning. https://arxiv.org/abs/1908.07442v4
2) Yann N. Dauphin, Angela Fan, Michael Auli, David Grangier. 2016. Language Modeling with Gated Convolutional Networks. https://arxiv.org/pdf/1612.08083.pdf
3) Elad Hoffer, Itay Hubara, Daniel Soudry. 2017. Train longer, generalize better: closing the generalization gap in large batch training of neural networks. https://arxiv.org/pdf/1705.08741.pdf
4) André F. T. Martins, Ramón Fernandez Astudillo. 2016. From Softmax to Sparsemax: A Sparse Model of Attention and Multi-Label Classification. https://arxiv.org/abs/1602.02068
5) Sparsemax implementation: https://github.com/gokceneraslan/SparseMax.torch
6) Complete PyTorch TabNet implementation: https://github.com/dreamquark-ai/tabnet/tree/develop/pytorch_tabnet
