Model Pruning in Deep Neural Networks Using the TensorFlow API

Perceval Desforges
Towards Data Science
Feb 11, 2022

Photo by Etienne Delorieux on Unsplash

One of the most common problems in machine learning is overfitting, which can occur for a variety of reasons [1]. A common way to address it is to add regularization terms to the model. Another consists of reducing the complexity of the model by reducing its number of parameters. Ideally, both approaches should be combined. In this article, we will explore the latter, and more specifically how to incorporate model pruning (removing superfluous weights from a model) into your Keras TensorFlow models using the TensorFlow Model Optimization API.

What is pruning and why choose pruning?

Pruning a machine learning model consists of removing certain weights, that is, permanently setting them to 0. The pruned weights are usually those that are already close to 0 in absolute value. This reduces overfitting by lowering the model's effective number of parameters: once a weight has been deemed useless and pruned, it cannot be reactivated later in training.
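To make the idea concrete, here is a minimal NumPy sketch of magnitude pruning on a single weight matrix: the requested fraction of entries with the smallest absolute values is set to 0, and a binary mask records which weights survive. The function name magnitude_prune is hypothetical, and this illustrates the principle rather than how tfmot implements it internally.

```python
import numpy as np


def magnitude_prune(weights: np.ndarray, sparsity: float) -> tuple[np.ndarray, np.ndarray]:
    """Zero out the `sparsity` fraction of entries with the smallest absolute value.

    Returns the pruned weights and the binary mask (1 = kept, 0 = pruned).
    """
    if not 0.0 <= sparsity <= 1.0:
        raise ValueError("sparsity must be in [0, 1]")
    n_prune = int(round(sparsity * weights.size))
    mask = np.ones(weights.size, dtype=weights.dtype)
    if n_prune > 0:
        # Flattened indices sorted by magnitude; the first n_prune are the weakest weights
        order = np.argsort(np.abs(weights), axis=None, kind="stable")
        mask[order[:n_prune]] = 0
    mask = mask.reshape(weights.shape)
    return weights * mask, mask


w = np.array([[0.5, -0.01], [0.2, -0.9]])
pruned, mask = magnitude_prune(w, 0.5)
```

Here, pruning at 50% sparsity zeroes -0.01 and 0.2 and keeps the two largest-magnitude weights. In a training loop, the mask would have to be reapplied after each optimizer update (or multiplied into the weights in the forward pass) so that pruned weights stay at 0.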

There are different ways to prune a model. One can prune a random fraction of the weights at the start of training, or prune at the end of training to simplify the model and make it lighter. Here we will see how to set up a pruning schedule that gradually prunes weights during training until a desired sparsity (the fraction of weights set to 0) is reached.

One may wonder why a model should be pruned instead of being initialized with fewer trainable parameters from the get-go. The reason is that you may wish to keep a rather complex architecture, with high capacity and many possible interactions between features, while limiting the number of interactions actually used. Furthermore, fine-tuning which layers should be shrunk, which ones should be enlarged, and which features should be kept is often a tedious and fruitless venture. It is much simpler to prune the model during training and get rid of excess weights little by little. Another advantage of this method is that it allows one to train multiple models that may have been pruned differently, and then combine them using various ensemble learning techniques [2]. Such aggregate models are usually much more robust than a single model. We will now see how to implement this technique.

The TensorFlow Model Optimization Toolkit

The goal is then to eliminate the weakest weights progressively during training. While one could implement a custom callback to do this, there already exists a TensorFlow library called TensorFlow Model Optimization (tfmot) that does precisely this [3]. It lets one define a pruning schedule that automatically takes care of eliminating weights. The schedule used here, PolynomialDecay, is parameterized by an initial sparsity, a final sparsity, a step at which to begin pruning, a step at which to end it, and the exponent of the polynomial decay (its power argument). Between the begin and end steps, each time the pruning masks are updated (every frequency steps, 100 by default), the toolkit removes enough of the lowest-magnitude weights that the achieved sparsity is:

$$S = S_e + (S_0 - S_e)\left(1 - \frac{t - t_0}{t_e - t_0}\right)^{\alpha}$$

where S is the sparsity, Sₑ is the final sparsity, S₀ is the initial sparsity, t is the current step, tₑ is the end step, t₀ is the begin step, and α is the exponent. The sparsity thus rises from S₀ at t₀ to Sₑ at tₑ, and stays at Sₑ afterwards. The default value of α is 3. The other hyperparameters must be tuned to find good values. Ideally, the weights should be pruned little by little over a fairly long period, so that the network has time to adapt to the loss of weights. Indeed, if the weights are pruned too quickly, one runs the risk of removing weights that could have become influential but were still close to 0 simply because they were initialized there and the network had not yet had time to learn.
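The schedule can be written out directly. The following sketch (the function name is hypothetical) returns the target sparsity at a given step; the progress term is clipped to [0, 1], so the target stays at S₀ before t₀ and at Sₑ after tₑ.

```python
def polynomial_decay_sparsity(step, initial_sparsity, final_sparsity,
                              begin_step, end_step, power=3):
    """Target sparsity S at step t under a polynomial decay schedule.

    S = S_e + (S_0 - S_e) * (1 - (t - t_0) / (t_e - t_0)) ** power
    """
    # Fraction of the pruning window already elapsed, clipped to [0, 1]
    progress = (step - begin_step) / (end_step - begin_step)
    progress = min(max(progress, 0.0), 1.0)
    return final_sparsity + (initial_sparsity - final_sparsity) * (1.0 - progress) ** power
```

For instance, with S₀ = 0, Sₑ = 0.75, t₀ = 1000 and tₑ = 5000, the target at t = 3000, halfway through the window, is already 0.65625: with α = 3, most of the pruning happens early in the window and slows down as tₑ approaches.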

In order to better understand these concepts, we will illustrate through an example how to apply pruning in your models.

Pruning: an example

Pruning is best employed for models with many parameters (such as MLPs) when we suspect that many of them are not useful. This can happen when the dataset has many features (a few hundred) but not all of them are relevant. Naturally, irrelevant features should be eliminated during the feature engineering step whenever possible, either by dropping them when they are known or by using dimension reduction techniques (such as PCA or NMF [4]). However, this is not always possible.

For our example dataset, we will use the make_regression function from the sklearn.datasets module. It generates a regression dataset for which one can specify the number of samples, the number of features, the number of informative (relevant) features, and the level of noise.

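```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split

# Parameters of the dataset
n_samples = 10000
n_features = 1000
n_informative = 500  # only half of the features actually influence the target
noise = 3  # standard deviation of the Gaussian noise added to the target

# Create the dataset (random_state fixed for reproducibility)
x, y = make_regression(
    n_samples=n_samples,
    n_features=n_features,
    n_informative=n_informative,
    noise=noise,
    random_state=42,
)
# Scale each feature, and the target, to [-1, 1] by its maximum absolute value
x = x / np.abs(x).max(axis=0)
y = y / np.abs(y).max()
# 80/20 split: 8,000 training and 2,000 validation samples
x_train, x_val, y_train, y_val = train_test_split(x, y, test_size=0.2, random_state=42)
```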

Now that we have our data, let's build our model. Since there is a large number of features, and for the purpose of our example, we will design a very large MLP with many parameters:

```python
import tensorflow as tf
from tensorflow.keras.layers import Dense, ReLU
from tensorflow.keras.models import Sequential

# A deliberately oversized MLP: two hidden layers of 1,024 units
model = Sequential()
model.add(tf.keras.Input(shape=(n_features,)))
model.add(Dense(1024, kernel_initializer="he_normal"))
model.add(ReLU())
model.add(Dense(1024))
model.add(ReLU())
model.add(Dense(1))  # linear output for regression
```

This architecture results in a model with over 2,000,000 parameters ((1000 × 1024 + 1024) + (1024 × 1024 + 1024) + (1024 + 1) = 2,075,649), which is a lot, especially since we know that half of the features are not actually useful. How do we integrate pruning? It is quite simple. First, we set the pruning parameters in a dictionary. We then wrap the original model with tfmot's prune_low_magnitude, using these parameters, to obtain a new model. Finally, we add a pruning callback.

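```python
import numpy as np
import tensorflow as tf
import tensorflow_model_optimization as tfmot

# Training configuration (also used in model.fit below)
batch_size = 1024
epochs = 200
# Keras takes one optimizer step per batch: ceil(8000 / 1024) = 8 steps per epoch here,
# i.e. only 1,600 steps over 200 epochs
steps_per_epoch = int(np.ceil(len(x_train) / batch_size))

initial_sparsity = 0.0
final_sparsity = 0.75
# begin_step and end_step count optimizer steps (batches), not epochs, and must fall
# within the training steps, otherwise final_sparsity is never reached.
# Illustrative choice: start after 10 epochs and prune over the next 10 epochs.
begin_step = 10 * steps_per_epoch
end_step = 20 * steps_per_epoch
pruning_params = {
    "pruning_schedule": tfmot.sparsity.keras.PolynomialDecay(
        initial_sparsity=initial_sparsity,
        final_sparsity=final_sparsity,
        begin_step=begin_step,
        end_step=end_step,
        frequency=steps_per_epoch,  # update masks once per epoch (default: every 100 steps)
    )
}
# Wrap every prunable layer of the model with low-magnitude pruning
model = tfmot.sparsity.keras.prune_low_magnitude(model, **pruning_params)
# Advances the pruning step counter at every batch; required for pruning to take effect
pruning_callback = tfmot.sparsity.keras.UpdatePruningStep()
```

We then compile and fit the model as usual, without forgetting to pass the callback to the fitting step:

```python
model.compile(
    loss="mse",
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
)
model.fit(
    x_train,
    y_train,
    epochs=epochs,
    batch_size=batch_size,
    validation_data=(x_val, y_val),  # track the validation loss as well
    callbacks=[pruning_callback],
    verbose=1,
)
```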
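After training, it is worth checking that the target sparsity was actually reached. The sketch below assumes the weights are available as NumPy arrays, for example from tfmot.sparsity.keras.strip_pruning(model).get_weights() (strip_pruning removes the pruning wrappers, whose masks would otherwise be counted too). The helper name kernel_sparsity is hypothetical. Only multi-dimensional arrays are counted, since tfmot prunes the kernels of Dense layers but not their biases.

```python
import numpy as np


def kernel_sparsity(weight_arrays):
    """Fraction of exactly-zero entries per kernel, and over all kernels.

    Arrays with ndim < 2 (biases) are skipped because they are not pruned.
    """
    kernels = [w for w in weight_arrays if w.ndim >= 2]
    per_kernel = [float(np.mean(w == 0)) for w in kernels]
    total_zeros = sum(int(np.count_nonzero(w == 0)) for w in kernels)
    total = sum(w.size for w in kernels)
    overall = total_zeros / total if total else 0.0
    return per_kernel, overall


# Typical use (requires a trained, pruned model):
# stripped = tfmot.sparsity.keras.strip_pruning(model)
# per_kernel, overall = kernel_sparsity(stripped.get_weights())
```

With the configuration above, each kernel should end up close to 75% zeros.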

Choosing the Hyperparameters

We now have all the elements needed for pruning. But how effective is it, and how should the parameters be set? I usually keep the initial sparsity at 0, unless I purposefully want to create multiple sparse models with very different random initializations in order to aggregate them. I set the begin step so that pruning starts after a few epochs, and the end step so that pruning takes place over a few epochs as well (on the order of 10). Keep in mind that both are expressed in optimizer steps (batches), not epochs, so they must be converted using the number of steps per epoch.
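The conversion from epochs to steps is simple arithmetic, but easy to get wrong. A minimal sketch, with a hypothetical helper name:

```python
import math


def pruning_window(n_train, batch_size, begin_epoch, n_pruning_epochs):
    """Convert an epoch-based pruning window into optimizer-step counts.

    Keras performs one optimizer step per batch, so an epoch contains
    ceil(n_train / batch_size) steps.
    """
    steps_per_epoch = math.ceil(n_train / batch_size)
    begin_step = begin_epoch * steps_per_epoch
    end_step = begin_step + n_pruning_epochs * steps_per_epoch
    return steps_per_epoch, begin_step, end_step
```

With the 8,000 training samples and batch size of 1024 of our example, pruning_window(8000, 1024, 10, 10) gives 8 steps per epoch, a begin step of 80 and an end step of 160. Such a window is shorter than PolynomialDecay's default frequency of 100 steps between mask updates, so frequency should be lowered accordingly (for instance to the number of steps per epoch); otherwise the masks are updated at most once inside the window.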

In my experience, these parameters matter less than the final sparsity, whose best value depends on the problem at hand. When a model overfits, its performance on the training set is much better than on the validation set. Introducing some sparsity will very likely degrade the performance on the training set but improve it on the validation set, and increasing the sparsity further keeps improving the validation performance, up to a point. When the final sparsity is too large, the model becomes too simple and the performance on both the training and validation sets starts to degrade, which means that our model is now underfitting the data. The validation loss therefore has a minimum as a function of the final sparsity.
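Choosing the final sparsity thus amounts to a one-dimensional search: train one model per candidate value, several times with different random initializations as in the results below, and keep the value with the lowest mean validation loss. The sketch covers only the selection step; the training loop is left out, and the function name and input format (a dictionary mapping each candidate final sparsity to a list of validation losses) are assumptions.

```python
import statistics


def best_final_sparsity(val_losses):
    """Return the final sparsity with the lowest mean validation loss.

    `val_losses` maps each candidate final sparsity to the validation
    losses of several runs (e.g. different random initializations).
    """
    mean_losses = {s: statistics.fmean(losses) for s, losses in val_losses.items()}
    best = min(mean_losses, key=mean_losses.get)
    return best, mean_losses
```

Averaging over several runs before taking the minimum keeps a single lucky initialization from deciding the result.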

Results

In the following graph, I have taken the aforementioned example and trained models with different final sparsities. The performance of each configuration was averaged over 10 trainings with different random initializations:

Averaged loss on the train and validation sets for final sparsities in [0, 1]

For very low final sparsities, the model clearly overfits, as can be seen in the figure above: the loss on the validation set is several orders of magnitude greater than the loss on the training set. As the final sparsity increases, the loss on the training set increases monotonically, while the loss on the validation set reaches a minimum at around 95% sparsity, which leaves a model with about 100,000 parameters (5% of roughly 2.07 million). This can be seen more clearly in the following figure:

Averaged loss on the train and validation sets for final sparsities in [0.8, 1]

Closing remarks

In this short article, we have seen that it is quite simple to implement pruning in TensorFlow models with the TensorFlow Model Optimization Toolkit, adding only a few lines of code. Pruning is yet another way to reduce overfitting, to be combined with other common methods such as regularization, dropout, and batch normalization. Pruning excels when models have a large number of parameters and datasets have many features. It allows one to train models with high capacity while remaining parsimonious, in accordance with Occam's razor. These models can then be combined through ensemble learning techniques to build more robust models.

Acknowledgments

I would like to thank Nicolas Morizet and Christophe GEISSLER for their help and support for the writing of this article.

References

[1] https://hastie.su.domains/ElemStatLearn/ (chapter 7)

[2] https://link.springer.com/article/10.1007/s10462-009-9124-7

[3] https://www.tensorflow.org/model_optimization/

[4] https://hal.archives-ouvertes.fr/hal-03141876/document

[5] https://www.tensorflow.org/model_optimization/guide/pruning/pruning_with_keras#fine-tune_pre-trained_model_with_pruning

[6] https://www.jmlr.org/papers/volume22/21-0366/21-0366.pdf

About us

Advestis is a European Contract Research Organization (CRO) with a deep understanding and practice of statistics and interpretable machine learning techniques. The expertise of Advestis covers the modeling of complex systems and predictive analysis for temporal phenomena.
LinkedIn: https://www.linkedin.com/company/advestis/
