I always thought that you’d need separate modules for pruning, but I recently discovered that PyTorch has built-in support for it. The documentation is a bit lacking, though … so I decided to pen this article and share some tips and tricks!

What is Pruning?
Pruning is a technique that removes weights or biases (parameters) from a neural network. If done right, this reduces the memory footprint of the model, improves generalization, speeds up inference, and allows training/fine-tuning with fewer samples. Of course, you cannot just randomly remove parameters from your network and expect it to perform better, but you can determine which parameters are unnecessary for your target and remove those. Needless to say, you should also watch how many parameters you remove: if you remove too many, your network will perform much worse or may become entirely defunct if you block the gradient flow (e.g. by pruning all parameters of a connecting layer).
Note that while it is most common to prune after training, it is also possible to apply pruning before or during training.
What parameters should I prune?
In the previous paragraph, I intentionally used the word "unnecessary" to refer to prunable parameters. But what makes a parameter unnecessary? This is quite a complicated question and still an active field of research. Among the most popular methods for finding prunable weights (pruning criteria) are:
- **Random:** Simply prune random parameters.
- **Magnitude:** Prune the parameters with the smallest magnitude (e.g. measured by their L2 norm).
- **Gradient:** Prune parameters based on the accumulated gradient (requires a backward pass and therefore data).
- **Information:** Leverage other information, such as higher-order curvature information, for pruning.
- **Learned:** Of course, we can also train our network to prune itself (very expensive, requires training)!
PyTorch has built-in support for random- and magnitude-based pruning. Both methods are surprisingly effective given how easy they are to compute, and that they can be computed without any data.
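To make the magnitude criterion concrete, here is a minimal, purely illustrative sketch (not PyTorch’s implementation) that masks the smallest-magnitude entries of a weight tensor:
import torch

weight = torch.randn(3, 3)
k = 4  # number of entries to prune
# Score each entry by its absolute value and zero out the k smallest.
threshold = weight.abs().flatten().kthvalue(k).values
mask = (weight.abs() > threshold).float()  # 0 = pruned, 1 = kept
pruned_weight = weight * mask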
Types of pruning
Unstructured Pruning
Unstructured pruning refers to pruning individual atoms of parameters, e.g. individual weights in linear layers, individual filter pixels in convolution layers, some scaling floats in your custom layer, etc. The point is that you prune parameters without regard to their respective structure, hence the name unstructured pruning.
Structured Pruning
As an alternative to unstructured pruning, structured pruning removes entire structures of parameters. This does not mean that an entire parameter tensor has to go, but you go beyond removing individual atoms: e.g. in linear weights you’d drop entire rows or columns, or, in convolution layers, entire filters (I point the interested reader to [1], where we have shown that many publicly available models contain a bunch of degenerated filters that should be prunable).
In practice, you can achieve much higher pruning ratios with unstructured pruning, but it probably won’t speed up your model, as you still have to perform all the computations. Structured pruning can, for example, remove entire convolution channels and thereby significantly lower the number of matrix multiplications needed. Currently, there is a trend towards supporting sparse tensors in both software and hardware, so in the future unstructured pruning may become highly relevant.
Local vs. Global Pruning
Pruning can happen per layer (local) or over multiple/all layers (global). In global pruning, the amount is computed over all specified parameters combined, so individual layers may end up with different sparsity levels.
Pruning in PyTorch
How does pruning work in PyTorch?
Pruning is implemented in `torch.nn.utils.prune`.
Interestingly, PyTorch goes beyond simply setting pruned parameters to zero. It copies the parameter `<param>` into a parameter called `<param>_orig` and creates a buffer that stores the pruning mask, `<param>_mask`. It also registers a module-level `forward_pre_hook` (a callback that is invoked before a forward pass) that applies the pruning mask to the original weight.
This has the following consequences: printing `<param>` will show the parameter with the applied mask, but listing it via `<module>.parameters()` or `<module>.named_parameters()` will show the original, unpruned parameter (as `<param>_orig`).
This has the following advantages: it is possible to determine whether a module has been pruned, and the original parameters remain accessible, which allows experimentation with various pruning techniques. However, this comes at the cost of some memory overhead.
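You can observe this reparametrization yourself, e.g. with a minimal sketch like this:
import torch
import torch.nn.utils.prune as prune

linear = torch.nn.Linear(2, 2)
prune.random_unstructured(linear, name="weight", amount=0.5)

# 'weight' is no longer a parameter; 'weight_orig' holds the unpruned values.
print([name for name, _ in linear.named_parameters()])  # e.g. ['bias', 'weight_orig']
# The pruning mask is stored as a buffer.
print([name for name, _ in linear.named_buffers()])  # ['weight_mask']
# Accessing .weight returns the masked tensor (weight_orig * weight_mask).
print(linear.weight)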
Which PyTorch versions are supported?
You are good if you have version 1.4.0 or later.
How can I implement it?
The supported options are a bit confusing and the API is slightly inconsistent, so I made the following overview, which will hopefully clear things up:

Local Unstructured Pruning
The following functions are available for local unstructured pruning:
torch.nn.utils.prune.random_unstructured(module, name, amount)
torch.nn.utils.prune.l1_unstructured(module, name, amount, importance_scores=None)
Just call the functions above and pass your layer/module as `module` and the name of the parameter to prune as `name`. Typically, this will be `weight` or `bias`. The `amount` parameter specifies how much to prune: you can pass a float between 0 and 1 for a ratio, or an integer to define an absolute number of parameters. Be aware that these commands can be applied iteratively, and `amount` is always relative to the number of remaining (i.e. not yet pruned) parameters. So, if you iteratively prune a parameter with 12 entries with `amount=0.5`, you will end up with 6 parameters after the first round, then 3, …
Here is an example that prunes 4 of the 9 entries of a convolution layer weight. Note how 4 parameters are set to zero.
>>> import torch
>>> import torch.nn.utils.prune as prune
>>> conv = torch.nn.Conv2d(1, 1, 3)
>>> prune.random_unstructured(conv, name="weight", amount=4)
>>> conv.weight
tensor([[[[-0.0000,  0.0000,  0.2603],
          [-0.3278,  0.0000,  0.0280],
          [-0.0361,  0.1409,  0.0000]]]], grad_fn=<MulBackward0>)
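Because `amount` is relative to the remaining parameters, repeated calls compound. A quick sketch, continuing the session above with a fresh layer of 12 weight entries:
>>> linear = torch.nn.Linear(4, 3)  # 12 weight entries
>>> prune.random_unstructured(linear, name="weight", amount=0.5)
>>> int((linear.weight == 0).sum())  # 6 of 12 pruned
6
>>> prune.random_unstructured(linear, name="weight", amount=0.5)
>>> int((linear.weight == 0).sum())  # plus 3 of the remaining 6
9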
Norms other than L1 are not supported here, and for good reason: since we operate on individual atoms, every p-norm reduces to the absolute value anyway.
Global Unstructured Pruning
If you want global unstructured pruning the command is slightly different:
torch.nn.utils.prune.global_unstructured(parameters, pruning_method, importance_scores=None, **kwargs)
Here we need to pass `parameters` as a list of tuples that hold the modules and the names of their parameters to prune. `pruning_method=prune.L1Unstructured` seems to be the only supported built-in option. Here is an example from the PyTorch docs:
model = ...
parameters = (
    (model.conv1, "weight"),
    (model.conv2, "weight"),
    (model.fc1, "weight"),
    (model.fc2, "weight"),
    (model.fc3, "weight"),
)
prune.global_unstructured(
    parameters,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)
If you want to prune all weights of a specific layer type (e.g. a convolution layer), you can automatically collect them as follows:
model = ...
parameters_to_prune = [
    (module, "weight")
    for module in filter(lambda m: type(m) == torch.nn.Conv2d, model.modules())
]
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)
Of course, you can adjust the filter to your needs.
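To verify the result, you can check the sparsity that each layer ended up with, e.g. like this (reusing `parameters_to_prune` from above):
for module, name in parameters_to_prune:
    weight = getattr(module, name)
    sparsity = float((weight == 0).sum()) / weight.numel()
    print(f"{type(module).__name__}: {sparsity:.1%} of '{name}' pruned")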
Local Structured Pruning
PyTorch only supports local structured pruning:
torch.nn.utils.prune.ln_structured(module, name, amount, n, dim, importance_scores=None)
torch.nn.utils.prune.random_structured(module, name, amount, dim)
The commands are fairly similar to the local unstructured ones, with the only difference that you have to define the `dim` parameter, which defines the axis of your structure. Here is a helper for the relevant dimensions:
For `torch.nn.Linear`:
- Disconnect all connections to one input: `dim=1`
- Disconnect one neuron: `dim=0`

For `torch.nn.Conv2d`:
- Channels (stack of kernels that output one feature map): `dim=0`
- Neurons (stack of kernels that process the same input feature map in different channels): `dim=1`
- Filter kernels: not supported (would require multi-axis `dim=[2, 3]` or a prior reshape, which is not that easy either)
Note that, contrary to unstructured pruning, you can actually define which norm to use via the `n` parameter. You can find a list of supported values here: https://pytorch.org/docs/stable/generated/torch.norm.html#torch.norm.
Here is an example that prunes an entire channel (this corresponds to 2 kernels in our example) based on the L2 norm:
>>> conv = torch.nn.Conv2d(2, 3, 3)
>>> prune.ln_structured(conv, name="weight", amount=1, n=2, dim=0)
>>> conv.weight
tensor([[[[ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000],
          [ 0.0000, -0.0000, -0.0000]],

         [[ 0.0000, -0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.0000]]],


        [[[ 0.2284,  0.1574, -0.0215],
          [-0.1096,  0.0952, -0.2251],
          [-0.0805, -0.0173,  0.1648]],

         [[-0.1104,  0.2012, -0.2088],
          [-0.1687,  0.0815,  0.1644],
          [-0.1963,  0.0762, -0.0722]]],


        [[[-0.1055, -0.1729,  0.2109],
          [ 0.1997,  0.0158, -0.2311],
          [-0.1218, -0.1244,  0.2313]],

         [[-0.0159, -0.0298,  0.1097],
          [ 0.0617, -0.0955,  0.1564],
          [ 0.2337,  0.1703,  0.0744]]]], grad_fn=<MulBackward0>)
Note how the output changes if we prune a neuron instead:
>>> conv = torch.nn.Conv2d(2, 3, 3)
>>> prune.ln_structured(conv, name="weight", amount=1, n=2, dim=1)
>>> conv.weight
tensor([[[[ 0.0000, -0.0000, -0.0000],
          [-0.0000, -0.0000,  0.0000],
          [-0.0000,  0.0000,  0.0000]],

         [[-0.1013,  0.1255,  0.0151],
          [-0.1110,  0.2281,  0.0783],
          [-0.0215,  0.1412, -0.1201]]],


        [[[ 0.0000, -0.0000,  0.0000],
          [ 0.0000, -0.0000,  0.0000],
          [ 0.0000, -0.0000,  0.0000]],

         [[ 0.0878,  0.2104,  0.0414],
          [ 0.0724, -0.1888,  0.1855],
          [ 0.2354,  0.1313, -0.1799]]],


        [[[-0.0000, -0.0000, -0.0000],
          [-0.0000, -0.0000,  0.0000],
          [ 0.0000, -0.0000,  0.0000]],

         [[ 0.1891,  0.0992,  0.1736],
          [ 0.0451,  0.0173,  0.0677],
          [ 0.2121,  0.1194, -0.1031]]]], grad_fn=<MulBackward0>)
Custom importance-based Pruning
You may have noticed that some of the previous functions support the `importance_scores` argument:
torch.nn.utils.prune.l1_unstructured(module, name, amount, importance_scores=None)
torch.nn.utils.prune.ln_structured(module, name, amount, n, dim, importance_scores=None)
torch.nn.utils.prune.global_unstructured(parameters, pruning_method, importance_scores=None, **kwargs)
You can pass a tensor (or a list of tensors for `global_unstructured`) of the same shape as your parameter, holding your custom pruning scores. This serves as a replacement for the magnitude and lets you plug in any custom scoring.
For example, let’s implement a simple pruning approach that eliminates the first 5 entries in a linear layer’s weight tensor:
>>> linear = torch.nn.Linear(3, 3)
>>> prune.l1_unstructured(linear, name="weight", amount=5, importance_scores=torch.arange(9).view(3, 3))
>>> linear.weight
tensor([[-0.0000,  0.0000, -0.0000],
        [ 0.0000, -0.0000, -0.1293],
        [ 0.1886,  0.4086, -0.1588]], grad_fn=<MulBackward0>)
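The same mechanism also gives you a straightforward way to implement the gradient criterion mentioned at the beginning of this article: run a backward pass and use the gradient magnitudes as scores. A sketch, where the data `x`, `y` and the MSE loss are just placeholders:
linear = torch.nn.Linear(3, 3)
x, y = torch.randn(8, 3), torch.randn(8, 3)  # placeholder data
loss = torch.nn.functional.mse_loss(linear(x), y)
loss.backward()  # populates linear.weight.grad
# Prune the 50% of weights with the smallest gradient magnitude.
prune.l1_unstructured(linear, name="weight", amount=0.5,
                      importance_scores=linear.weight.grad.abs())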
Helper functions
PyTorch also offers a couple of helper functions. The first one I want to show is:
torch.nn.utils.prune.is_pruned(module)
As you may have guessed, this function allows you to check whether any parameter in a module has been pruned. It returns True if a module was pruned. However, you cannot specify which parameter to check.
The last function I want to show you is:
torch.nn.utils.prune.remove(module, name)
Naively, you may think that this undoes the pruning, but it does quite the opposite: it makes the pruning permanent by removing the mask, the original parameter, and the forward hook, and writing the pruned tensor into the parameter. Consequently, calling `torch.nn.utils.prune.is_pruned(module)` on such a module returns False.
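Here are both helpers in action:
>>> linear = torch.nn.Linear(3, 3)
>>> prune.l1_unstructured(linear, name="weight", amount=0.5)
>>> prune.is_pruned(linear)
True
>>> prune.remove(linear, name="weight")
>>> prune.is_pruned(linear)
False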
Conclusion
PyTorch offers a built-in way to apply unstructured or structured pruning to tensors randomly, by magnitude, or by a custom metric. However, the API is a bit confusing and the documentation could be improved.
Thank you for reading this article! If you enjoyed it please consider subscribing to my updates. If you have any questions feel free to leave them in the comments.
References:
[1] P. Gavrikov and J. Keuper, CNN Filter DB: An Empirical Investigation of Trained Convolutional Filters (2022), CVPR 2022 Orals
This work was funded by the Ministry for Science, Research and Arts, Baden-Wuerttemberg, Germany under Grant 32–7545.20/45/1 (Q-AMeLiA).