A primer on functional PyTorch

How to write Jax-style PyTorch models

Mario Dagrada
Towards Data Science
May 7, 2023


Photo by Ricardo Gomez Angel on Unsplash

PyTorch has recently integrated the torch.func module into its main codebase in the 2.0 release. This module, previously known as functorch, enables the development of purely functional neural network models in PyTorch with a straightforward API. This package is PyTorch’s response to the growing popularity of Jax, a Python framework for general differentiable programming built using a functional programming paradigm from the ground up.

In this post, we will first introduce the basics of torch.func, followed by a simple end-to-end example of using a neural network (NN) model to fit a non-linear function. While using an NN for this task is admittedly overkill, it works well for illustrative purposes. Additionally, we will discover some of the benefits of adopting a functional approach when constructing NN models.

Write a functional model

Using torch.func begins in the same way as standard PyTorch: you need to construct a neural network. For simplicity, let us define one composed of an arbitrary number of linear layers and non-linear activation functions. The forward pass takes as input a batch of data points at which the model is evaluated.

```python
import torch
import torch.nn as nn


class SimpleNN(nn.Module):
    def __init__(
        self,
        num_layers: int = 1,
        num_neurons: int = 5,
    ) -> None:
        """Basic neural network architecture with linear layers

        Args:
            num_layers (int, optional): number of hidden layers
            num_neurons (int, optional): neurons for each hidden layer
        """
        super().__init__()

        layers = []

        # input layer
        layers.append(nn.Linear(1, num_neurons))

        # hidden layers with linear layer and activation
        for _ in range(num_layers):
            layers.extend([nn.Linear(num_neurons, num_neurons), nn.Tanh()])

        # output layer
        layers.append(nn.Linear(num_neurons, 1))

        # build the network
        self.network = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # reshape the input to (batch, 1), evaluate the network and drop
        # the trailing dimension (a single input point gives a 0-dim tensor)
        return self.network(x.reshape(-1, 1)).squeeze()
```

Now things get interesting. Recall that torch.func allows us to build purely functional models. But what does it mean in practice?

First of all, one needs to grasp the central concept of functional programming: pure functions. Essentially, a pure function has two defining properties:

  • its return values are identical for identical input arguments
  • it has no side effects, meaning it does not modify its input arguments or any other external state
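To make the two properties concrete, here is a small sketch with three toy functions. The names are illustrative only and do not come from any library. The first function is pure. The second has a side effect, because it modifies its input in place. The third breaks the first property, because its output depends on hidden global state.

```python
import torch

def scale_pure(x: torch.Tensor, factor: float) -> torch.Tensor:
    # returns a new tensor and leaves `x` untouched
    return x * factor

def scale_inplace(x: torch.Tensor, factor: float) -> torch.Tensor:
    # `mul_` overwrites `x`: a side effect visible to the caller
    return x.mul_(factor)

_calls = 0

def shift_by_call_count(x: torch.Tensor) -> torch.Tensor:
    # the output depends on hidden global state, not only on `x`
    global _calls
    _calls += 1
    return x + _calls

x = torch.ones(3)
scale_pure(x, 2.0)
assert torch.equal(x, torch.ones(3))          # input unchanged
scale_inplace(x, 2.0)
assert torch.equal(x, torch.full((3,), 2.0))  # input was modified

a = torch.zeros(1)
# same input, different outputs: not a pure function
assert not torch.equal(shift_by_call_count(a), shift_by_call_count(a))
```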

By this definition, most of the methods of a standard PyTorch module are not pure, because they depend on parameters stored (and updated) within the PyTorch model. In other words, a standard PyTorch model is stateful rather than stateless, as required by the functional paradigm. Consider this code:

```python
import torch

x = torch.randn(10)
model = SimpleNN()  # constructed above
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# modify the state of the model
# by applying a single optimization step
out1 = model(x)
loss = out1.pow(2).mean()  # any scalar loss works for this demonstration
loss.backward()
optimizer.step()

# recompute the output with exactly the same input
out2 = model(x)
assert not torch.equal(out1, out2)
```

The forward pass of the model does not return the same output for identical input arguments, because the optimizer updated the parameters in place.

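One way of making a PyTorch module pure is therefore to decouple the parameters from the model, which makes the model completely stateless. This is exactly what the torch.func.functional_call() routine does. It runs the module's forward pass with the parameters passed in explicitly as an argument, rather than with the ones stored inside the module: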

```python
import torch
from torch.func import functional_call

x = torch.randn(10)  # random input data
model = SimpleNN()  # constructed above
params = dict(model.named_parameters())  # model parameters

# make a functional call to the model above
out = functional_call(model, params, (x,))
```
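The following sketch shows that the output now depends only on what is passed in. It uses a small nn.Sequential as a stand-in for SimpleNN so that it runs on its own. This is an assumption for illustration. With the module's own parameters, functional_call matches the ordinary forward pass. With an explicitly passed set of all-zero parameters, the output is zero, and the module itself is left unchanged.

```python
import torch
import torch.nn as nn
from torch.func import functional_call

torch.manual_seed(0)
# small stand-in for SimpleNN (input and output size 1)
model = nn.Sequential(nn.Linear(1, 5), nn.Tanh(), nn.Linear(5, 1))
x = torch.randn(10, 1)

params = dict(model.named_parameters())
# same parameters -> same output as the ordinary forward pass
out_same = functional_call(model, params, (x,))
assert torch.allclose(out_same, model(x))

# a different set of parameters, passed explicitly
zero_params = {name: torch.zeros_like(p) for name, p in params.items()}
out_zero = functional_call(model, zero_params, (x,))
assert torch.all(out_zero == 0)

# the module still uses its own parameters afterwards
assert torch.allclose(model(x), out_same)
```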

Now that we have a purely functional forward pass for our neural network model, we can explore how to use it with the composable primitives provided by PyTorch’s functional API. This allows us to construct complex models using modular building blocks, each with its own functional implementation.

Composable function transforms

I just showed how to define a purely functional forward pass for our model. But how can we compute derivatives and define loss functions with it? For that, we need the composable function transforms provided by torch.func.

Function transforms are routines that take a function as input and return a new function, which can then be used to evaluate specific quantities such as gradients. A function that takes or returns another function is known as a higher-order function. For example, one can use the grad transform to evaluate the gradient of the model output with respect to the input data x as follows:

```python
from torch.func import grad

# the `grad` function returns another function
# which takes the same inputs as the model forward pass
grad_fn = grad(model)

# now this function can be used to compute the gradient
# with respect to its (single, scalar) input: here the
# first point of the data `x` defined above
grad_values = grad_fn(x[0])
```
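To see the higher-order pattern in isolation, here is a minimal sketch that compares grad with a hand-written higher-order function. The helper `central_difference` is hypothetical and not part of torch.func. Like grad, it takes a function and returns a new function, but it approximates the derivative numerically instead of using automatic differentiation. tanh is used as the test function because its derivative, 1 - tanh², is known.

```python
import torch
from torch.func import grad

def central_difference(f, eps: float = 1e-3):
    # hypothetical helper: a higher-order function that, like `grad`,
    # takes a function and returns a new function
    def df(x):
        return (f(x + eps) - f(x - eps)) / (2 * eps)
    return df

def f(x: torch.Tensor) -> torch.Tensor:
    return torch.tanh(x)

x0 = torch.tensor(0.3, dtype=torch.float64)
exact = grad(f)(x0)                 # automatic differentiation
approx = central_difference(f)(x0)  # numerical approximation

assert torch.allclose(exact, 1 - torch.tanh(x0) ** 2)
assert torch.allclose(exact, approx, atol=1e-6)
```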

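Notice that grad only works with functions that return a single scalar value, so by default it is applied to one input point at a time. To handle a batch of inputs efficiently, one can use another function transform called vmap. Rather than looping over the batch in Python, vmap vectorizes the computation. It pushes the batch dimension down into the underlying tensor operations, which then run as batched operations on whichever device (CPU or GPU) holds the data, without any code change.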
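Here is a minimal sketch of grad combined with vmap. It uses a toy function with a known derivative, f(x) = x³ and so f'(x) = 3x², in place of the model so that the block runs on its own. The explicit Python loop is only there for comparison: vmap produces the same values without it.

```python
import torch
from torch.func import grad, vmap

def f(x: torch.Tensor) -> torch.Tensor:
    # scalar-in, scalar-out toy function: f(x) = x^3
    return x ** 3

x = torch.linspace(-2.0, 2.0, steps=5)

# grad(f) accepts a single scalar input...
single = grad(f)(x[0])

# ...while vmap maps it over the leading (batch) dimension of x
batched = vmap(grad(f))(x)

# equivalent, but slower, Python loop over the batch
looped = torch.stack([grad(f)(xi) for xi in x])

assert torch.allclose(batched, looped)
assert torch.allclose(batched, 3 * x ** 2)
```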

One important consequence of all the functions in the torch.func module being pure is their ability to be arbitrarily composed together (hence the term “composable” in their name). This is because a pure function can always be replaced with its result without affecting program execution, a direct consequence of the two properties mentioned above.

With this in mind, let’s calculate the second derivative for a batch of input data x:

```python
import torch
from torch.func import grad, vmap

x = torch.rand(10)

# combine `grad` twice with `vmap` to compute
# the model second-order derivative (the Laplacian, for a
# one-dimensional input) with respect to batched input data
laplacian_fn = vmap(grad(grad(model)))
out = laplacian_fn(x)
```
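The model's second derivative has no simple closed form. One way to sanity-check the composition is to apply the same transforms to a function whose second derivative is known. This is a minimal sketch that assumes sin in place of the model, since (sin)'' = -sin.

```python
import torch
from torch.func import grad, vmap

def f(x: torch.Tensor) -> torch.Tensor:
    # stand-in for the model: scalar in, scalar out
    return torch.sin(x)

x = torch.rand(10) * 2.0 * torch.pi

# same composition as above: batched second derivative
d2f = vmap(grad(grad(f)))
assert torch.allclose(d2f(x), -torch.sin(x), atol=1e-6)
```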

It’s worth noting that the forward pass of the model does not take parameters as input. As a result, to compute the gradient with respect to the parameters, we need to define an auxiliary make_functional_fwd routine with the appropriate arguments. In practice, we can achieve this using a closure, as shown below:

```python
import torch
from torch.func import functional_call, grad

x = torch.randn(1)  # random input data point
model = SimpleNN()  # constructed above

# forward pass using the functional API
# to take the parameters as input arguments
def make_functional_fwd(_model):
    def fn(data, parameters):
        return functional_call(_model, parameters, (data,))
    return fn

model_func = make_functional_fwd(model)  # functional forward
params = dict(model.named_parameters())  # model parameters, keyed by name

# the `argnums` argument selects which input argument of the
# functional forward pass (defined in the closure) the gradient
# is computed with respect to; here the parameters
grad_params = grad(model_func, argnums=1)(x[0], params)

# as before but for computing the gradient with
# respect to the input data
grad_x = grad(model_func, argnums=0)(x[0], params)
```
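Because the parameters are passed as a dictionary, grad with argnums=1 returns a dictionary of gradients with the same keys and shapes. The sketch below checks this against ordinary stateful autograd (backward()). It uses a small nn.Sequential stand-in for SimpleNN, which is an assumption made so that the block is self-contained.

```python
import torch
import torch.nn as nn
from torch.func import functional_call, grad

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(1, 5), nn.Tanh(), nn.Linear(5, 1))

def fwd(data, parameters):
    # reshape to (1, 1) and squeeze to a scalar so that `grad` applies
    return functional_call(model, parameters, (data.reshape(1, 1),)).squeeze()

params = {n: p.detach() for n, p in model.named_parameters()}
x0 = torch.tensor(0.5)

grads = grad(fwd, argnums=1)(x0, params)

# reference: standard (stateful) autograd on the same model
model.zero_grad()
model(x0.reshape(1, 1)).squeeze().backward()
for name, p in model.named_parameters():
    assert grads[name].shape == p.shape
    assert torch.allclose(grads[name], p.grad)
```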

The torch.func module offers many more composable function transforms for computing, for example, vector-Jacobian products. Here you can find more details.
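As a small illustration, here is a sketch of two of these transforms, vjp and jacrev, applied to an elementwise function whose Jacobian is known. For f(x) = sin(x), the Jacobian is diag(cos(x)), so the vector-Jacobian product vᵀJ equals v · cos(x).

```python
import torch
from torch.func import jacrev, vjp

def f(x: torch.Tensor) -> torch.Tensor:
    # elementwise map R^n -> R^n, so its Jacobian is diag(cos(x))
    return torch.sin(x)

x = torch.randn(4)
v = torch.randn(4)

# vjp returns the output and a function computing v^T J
out, vjp_fn = vjp(f, x)
(vjp_x,) = vjp_fn(v)

# jacrev builds the full Jacobian with reverse-mode AD
J = jacrev(f)(x)

assert torch.allclose(out, torch.sin(x))
assert torch.allclose(J, torch.diag(torch.cos(x)))
assert torch.allclose(vjp_x, v @ J)
```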

Optimization with functional models

If you made it this far, you might wonder how you can perform gradient-based optimization with a functional model. After all, the standard PyTorch optimizers work by modifying the model parameters in place, which, as shown above, breaks the pure-function requirements.

Unfortunately, PyTorch does not natively provide functional optimizers. However, one can use the torchopt library for this purpose.
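To give an idea of what a functional optimizer does, here is a hand-written plain gradient descent update. It is a sketch and not torchopt's API: the helper `sgd_step` and the toy linear-regression problem (y = 3x + 1) are hypothetical. The key point is that each step takes the current parameters and gradients and returns new tensors, leaving the previous parameters untouched.

```python
import torch
from torch.func import grad

# toy data: y = 3x + 1
x = torch.linspace(-1.0, 1.0, steps=20)
y = 3.0 * x + 1.0

def loss_fn(params: dict) -> torch.Tensor:
    pred = params["w"] * x + params["b"]
    return ((pred - y) ** 2).mean()

def sgd_step(params: dict, grads: dict, lr: float) -> dict:
    # hypothetical helper: returns NEW tensors instead of updating in place
    return {k: params[k] - lr * grads[k] for k in params}

params0 = {"w": torch.tensor(0.0), "b": torch.tensor(0.0)}
params = params0
for _ in range(200):
    params = sgd_step(params, grad(loss_fn)(params), lr=0.1)

assert torch.equal(params0["w"], torch.tensor(0.0))  # initial params untouched
assert loss_fn(params) < 1e-3
```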

To show how functional optimization works, let's assume that we want to fit a simple function, for example f(x) = 2 sin(x + 2π), using some random input points in the domain [0, 2π]. We can generate some training and test data points as follows:

```python
import torch

def get_data(n_points=20):
    x = torch.rand(n_points) * 2.0 * torch.pi
    y = 2.0 * torch.sin(x + 2.0 * torch.pi)
    return x, y

x_train, y_train = get_data(n_points=40)
x_test, y_test = get_data(n_points=10)
```

Now let's use torchopt and the PyTorch functional API to train our NN to fit this function.

```python
import torch
import torchopt

# hyperparameters and optimizer choice from `torchopt`
num_epochs = 500
lr = 0.01
optimizer = torchopt.FuncOptimizer(torchopt.adam(lr=lr))
loss_fn = torch.nn.MSELoss()

loss_evolution = []  # track the loss evolution per epoch
# initialize the parameters (a name -> tensor dict, as functional_call expects)
params = dict(model.named_parameters())

for i in range(num_epochs):

    # update the parameters using the functional API
    y = model_func(x_train, params)
    loss = loss_fn(y, y_train)
    params = optimizer.step(loss, params)
    loss_evolution.append(float(loss))

    if i % 100 == 0:
        print(f"Iteration {i} with loss {float(loss)}")

# loss on the test set
y_pred = model_func(x_test, params)
print(f"Loss on the test set: {loss_fn(y_pred, y_test)}")
```

As you can see, the optimization loop is pretty similar to standard PyTorch. The crucial difference is that the optimizer step now takes the current loss and the current parameter values as inputs and returns the updated parameters, rather than modifying them in place. The FuncOptimizer object still keeps the optimizer's own internal state, such as Adam's moment estimates. In my opinion, this approach looks much cleaner than the typical stateful API of PyTorch!
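Optimizers such as Adam or momentum SGD carry state of their own. In a fully functional style, that state is passed into each step and returned from it, just like the parameters. The sketch below shows this with hand-written momentum SGD on the same kind of toy problem as before (y = 3x + 1). The helpers `momentum_init` and `momentum_step` are hypothetical illustrations, not part of torchopt or PyTorch.

```python
import torch
from torch.func import grad

x = torch.linspace(-1.0, 1.0, steps=20)
y = 3.0 * x + 1.0

def loss_fn(params: dict) -> torch.Tensor:
    return ((params["w"] * x + params["b"] - y) ** 2).mean()

def momentum_init(params: dict) -> dict:
    # hypothetical helper: optimizer state is one velocity tensor per parameter
    return {k: torch.zeros_like(p) for k, p in params.items()}

def momentum_step(params, state, grads, lr=0.1, beta=0.9):
    # hypothetical helper: pure update, returns new params AND new state
    new_state = {k: beta * state[k] + grads[k] for k in params}
    new_params = {k: params[k] - lr * new_state[k] for k in params}
    return new_params, new_state

params = {"w": torch.tensor(0.0), "b": torch.tensor(0.0)}
state = momentum_init(params)
for _ in range(200):
    params, state = momentum_step(params, state, grad(loss_fn)(params))

assert loss_fn(params) < 1e-3
```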

For the complete code for this blog post, you can look at this code snippet. As you will notice, some details of the implementation (particularly the interaction between torch.func.functional_call and the torchopt optimizer) have not been covered in this post. Feel free to send me a message on LinkedIn if you have any questions.

Conclusions

Thank you for reading this blog post. The functional API of PyTorch is a powerful tool that lets you write high-performance neural network models using composable function transforms and automatic vectorization, similar to Jax. However, it is still an experimental feature and should be used with caution. Happy coding!
