How to tune PyTorch Lightning hyperparameters

Wondering how to optimize PyTorch Lightning hyperparameters in 30 lines of code?

Richard Liaw
Towards Data Science


PyTorch Lightning is one of the hottest AI libraries of 2020, and it makes AI research scalable and fast to iterate on. But even with PyTorch Lightning, you’ll still need to tune your hyperparameters.

Proper hyperparameter tuning can make the difference between a good training run and a failing one, and good tuning methods can substantially boost the accuracy of your model.

In this blog post, we’ll demonstrate how to use Ray Tune, an industry standard for hyperparameter tuning, with PyTorch Lightning. Ray Tune is part of Ray, a library for scaling Python.

It is available as a PyPI package and can be installed like this:

pip install "ray[tune]"

To use Ray Tune with PyTorch Lightning, we only need to add a few lines of code!

Getting started with Ray Tune + PTL!

To run the code in this blog post, be sure to first run:

pip install "ray[tune]" 
pip install "pytorch-lightning>=1.0"
pip install "pytorch-lightning-bolts>=0.2.5"

The example below was tested with ray==1.0.1, pytorch-lightning==1.0.2, and pytorch-lightning-bolts==0.2.5. See the full example here.

Let’s first start with some imports:

After imports, there are three easy steps.

  1. Create your LightningModule
  2. Create a function that calls Trainer.fit with the Tune callback
  3. Use tune.run to execute your hyperparameter search

Step 1: Create your LightningModule

First, create your LightningModule. It should take a configuration dict as a parameter on initialization, and this dict should set the model parameters you want to tune, such as the layer sizes and the learning rate.
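To make the config-driven pattern concrete without requiring PyTorch, here is a minimal plain-Python sketch. ConfigurableMLPSketch is a hypothetical class, not a real LightningModule: it only reads the tunable values from the config dict in its constructor, the way the module’s __init__ would, and counts how many weights the chosen layer sizes imply. The MNIST input size (28 * 28 = 784 pixels) and the 10 output classes are assumptions matching the MNIST example.

```python
class ConfigurableMLPSketch:
    """Hypothetical stand-in for a LightningModule whose architecture comes from `config`."""

    def __init__(self, config, in_features=28 * 28, num_classes=10):
        # Every tunable value is read from the config dict, so each trial can vary it.
        self.lr = config["lr"]
        self.batch_size = config["batch_size"]
        # Assumed MNIST setup: 784 input pixels, two hidden layers, 10 output classes.
        self.layer_sizes = [in_features, config["layer_1_size"], config["layer_2_size"], num_classes]

    def num_parameters(self):
        # A fully connected layer with n inputs and m outputs has n * m weights plus m biases.
        sizes = self.layer_sizes
        return sum(n * m + m for n, m in zip(sizes, sizes[1:]))


model = ConfigurableMLPSketch({"layer_1_size": 128, "layer_2_size": 64, "lr": 1e-3, "batch_size": 32})
print(model.layer_sizes, model.num_parameters())  # [784, 128, 64, 10] 109386
```

Changing layer_1_size from 128 to 32 shrinks the first layer alone from 100,480 to 25,120 parameters, which shows how much a single config entry can change the model being trained.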

Step 2: Create a function that calls Trainer.fit with the Tune callback

We will use a callback to communicate with Ray Tune. The callback is very simple:

import pytorch_lightning as pl
from ray.tune.integration.pytorch_lightning import TuneReportCallback

...
# Name reported to Ray Tune -> metric name logged by the LightningModule
metrics = {"loss": "ptl/val_loss", "acc": "ptl/val_accuracy"}
callbacks = [TuneReportCallback(metrics, on="validation_end")]
trainer = pl.Trainer(..., callbacks=callbacks)

This callback ensures that after each validation epoch, the validation loss and accuracy are reported back to Ray Tune under the names loss and acc. The ptl/val_loss and ptl/val_accuracy keys refer to the metrics logged in the validation_epoch_end method of the LightningModule.
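The metrics dict maps the names reported to Ray Tune (left) to the metric names the LightningModule logs (right). The plain-Python function below, build_report, is a hypothetical illustration of that mapping, not Ray Tune’s implementation; the logged dict stands in for the metrics Lightning has collected at the end of a validation epoch, and its values are made up.

```python
def build_report(metrics, logged):
    """Hypothetical helper: return {report_name: value} for each entry in `metrics`."""
    # Fail loudly if the module never logged a metric the mapping refers to.
    missing = [name for name in metrics.values() if name not in logged]
    if missing:
        raise KeyError(f"metrics not logged by the LightningModule: {missing}")
    return {report_name: logged[logged_name] for report_name, logged_name in metrics.items()}


metrics = {"loss": "ptl/val_loss", "acc": "ptl/val_accuracy"}
# Made-up values standing in for what Lightning logged during one validation epoch.
logged = {"ptl/val_loss": 0.21, "ptl/val_accuracy": 0.94, "ptl/train_loss": 0.18}
print(build_report(metrics, logged))  # {'loss': 0.21, 'acc': 0.94}
```

Metrics that are logged but not listed (ptl/train_loss here) are simply not reported, so the mapping also acts as a filter.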

Further, Ray Tune starts a number of different training runs, one per hyperparameter configuration. To create these runs for the hyperparameter search, we wrap the trainer call in a function that Tune can call once per run.

The train_mnist_tune() function expects a config dict, which it then passes to the LightningModule. This config dict contains the hyperparameter values for a single trial.
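The point of the wrapper is that each trial is just one call of the same function with a different config. Below is a dependency-free sketch of that pattern; toy_train and run_trial are hypothetical names, and the toy "loss" is an invented function of the learning rate standing in for building the LightningModule and calling trainer.fit(). It is not how Ray Tune executes trials internally.

```python
def toy_train(config, report, epochs=3):
    """Hypothetical training function: everything tunable comes from `config`."""
    for epoch in range(1, epochs + 1):
        # Invented stand-in for a validation loss; NOT a real MNIST result.
        toy_loss = abs(config["lr"] - 1e-3) * 100 + 1.0 / epoch
        # Like the Tune callback, report metrics after every "validation epoch".
        report({"epoch": epoch, "loss": toy_loss})


def run_trial(trainable, config):
    """Call the training function once with one config and collect its reports."""
    reports = []
    trainable(config, reports.append)
    return reports


for cfg in ({"lr": 1e-3}, {"lr": 1e-1}):
    print(cfg, run_trial(toy_train, cfg)[-1])
```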

Step 3: Use tune.run to execute your hyperparameter search

Finally, we call tune.run to optimize our hyperparameters. The first step is to tell Ray Tune which values are valid choices for each parameter. This is called the search space, and we can define it like so:

from ray import tune

# Defining a search space
config = {
    "layer_1_size": tune.choice([32, 64, 128]),
    "layer_2_size": tune.choice([64, 128, 256]),
    "lr": tune.loguniform(1e-4, 1e-1),
    "batch_size": tune.choice([32, 64, 128]),
}

Let’s take a quick look at the search space. For the first and second layer sizes, we let Ray Tune choose between three fixed values. The learning rate is sampled log-uniformly between 0.0001 and 0.1, so every order of magnitude in that range is equally likely. The batch size is likewise chosen from three fixed values. Of course, Ray Tune offers many other sampling methods for defining the search space, including custom ones.
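To see what these definitions mean for the sampled values, here is a dependency-free sketch that mimics the distributions of tune.choice and tune.loguniform with Python’s random module and draws ten configurations. The helpers choice, loguniform, and sample_configs are hypothetical and are not Ray Tune’s implementation.

```python
import math
import random


def choice(options):
    # Hypothetical: pick one of a fixed set of values uniformly at random.
    return lambda rng: rng.choice(options)


def loguniform(low, high):
    # Hypothetical: sample uniformly in log space, then map back with exp.
    return lambda rng: math.exp(rng.uniform(math.log(low), math.log(high)))


search_space = {
    "layer_1_size": choice([32, 64, 128]),
    "layer_2_size": choice([64, 128, 256]),
    "lr": loguniform(1e-4, 1e-1),
    "batch_size": choice([32, 64, 128]),
}


def sample_configs(space, num_samples, seed=0):
    """Draw `num_samples` independent configurations (plain random search)."""
    rng = random.Random(seed)  # seeded so the draws are reproducible
    return [{name: sampler(rng) for name, sampler in space.items()} for _ in range(num_samples)]


configs = sample_configs(search_space, num_samples=10)
print(configs[0])
```

Sampling in log space suits learning rates: about a third of the draws land in each decade of [1e-4, 1e-1], whereas a plain uniform draw over the same range would put roughly 90% of them above 0.01.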

Ray Tune will then randomly sample ten parameter combinations from this space (set by num_samples=10 in the tune.run call), train a model for each, and compare their performance afterwards.

We wrap the train_mnist_tune function in tune.with_parameters to pass constants such as the maximum number of epochs to train each model and the number of GPUs available for each trial. Ray Tune also supports fractional GPUs: a trial can reserve, for example, 0.25 of a GPU (through the resources_per_trial argument of tune.run), which works as long as the model still fits in GPU memory.
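Two ideas in this paragraph can be sketched without Ray. Binding constants ahead of time, so that the search only has to supply config, is similar in spirit to Python’s functools.partial, which is used below purely as an analogy; tune.with_parameters is the real Ray Tune API and its internals differ. The GPU arithmetic assumes each trial reserves a fixed fraction of a GPU and that the Trainer inside the trial is given a whole device (math.ceil of the fraction). train_mnist_tune_sketch and max_concurrent_trials are hypothetical names.

```python
import functools
import math


def train_mnist_tune_sketch(config, epochs, gpus):
    # Hypothetical stand-in for the real training function: it only echoes what it
    # would pass on. The Trainer is assumed to need a whole GPU count, hence ceil.
    return {"lr": config["lr"], "epochs": epochs, "trainer_gpus": math.ceil(gpus)}


# Bind the constants once; afterwards only `config` varies between trials.
# (functools.partial is an analogy, not what tune.with_parameters does internally.)
trainable = functools.partial(train_mnist_tune_sketch, epochs=10, gpus=0.25)
print(trainable({"lr": 1e-3}))  # {'lr': 0.001, 'epochs': 10, 'trainer_gpus': 1}


def max_concurrent_trials(total_gpus, gpus_per_trial):
    """How many trials fit at once if each reserves `gpus_per_trial` GPUs."""
    return math.floor(total_gpus / gpus_per_trial)


print(max_concurrent_trials(1, 0.25))  # 4
print(max_concurrent_trials(2, 0.25))  # 8
```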

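The final invocation of tune.run can look like this:

# Execute the hyperparameter search
analysis = tune.run(
    tune.with_parameters(train_mnist_tune, epochs=10, gpus=0),
    config=config,
    metric="loss",  # name reported by the TuneReportCallback
    mode="min",  # lower validation loss is better
    num_samples=10)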

And finally, the tuning result could look like this:

In this simple example, several configurations reached good accuracy. The best result we observed was a validation accuracy of 0.978105 with a batch size of 32, layer sizes of 128 and 64, and a small learning rate of around 0.001. The learning rate also appears to be the main factor influencing performance: when it is too large, the runs fail to reach good accuracy.

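You can retrieve the best trial and its results from the return value of tune.run. These properties rank trials by the metric and mode passed to tune.run (for example, metric="loss" and mode="min"):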

analysis = tune.run(...)

best_trial = analysis.best_trial  # Get best trial
best_config = analysis.best_config  # Get best trial's hyperparameters
best_logdir = analysis.best_logdir  # Get best trial's logdir
best_checkpoint = analysis.best_checkpoint  # Get best trial's best checkpoint
best_result = analysis.best_result  # Get best trial's last results
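Each of these properties first has to decide which trial is best. The dependency-free sketch below shows the underlying idea: compare one metric from each trial’s last reported result and take the minimum or maximum depending on mode. best_trial here is a hypothetical helper, and the trial records are made-up numbers, not results from this experiment.

```python
def best_trial(trials, metric, mode="min"):
    """Hypothetical helper: pick the trial whose LAST reported `metric` is best."""
    if mode not in ("min", "max"):
        raise ValueError("mode must be 'min' or 'max'")
    pick = min if mode == "min" else max
    return pick(trials, key=lambda t: t["results"][-1][metric])


# Made-up trial records: one config plus the list of results it reported per epoch.
trials = [
    {"config": {"lr": 1e-3}, "results": [{"loss": 0.40, "acc": 0.90}, {"loss": 0.20, "acc": 0.95}]},
    {"config": {"lr": 1e-1}, "results": [{"loss": 2.30, "acc": 0.11}, {"loss": 2.30, "acc": 0.11}]},
]
print(best_trial(trials, metric="loss", mode="min")["config"])  # {'lr': 0.001}
```

The same trial wins whether you minimize loss or maximize acc here, but with real results the two metrics can disagree, which is why the metric and mode have to be stated explicitly.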

You can also easily leverage some of Ray Tune’s more powerful optimization features. For example, Ray Tune’s search algorithms can choose promising hyperparameter combinations based on earlier results instead of sampling them at random, and Ray Tune’s schedulers can stop poorly performing trials early to save resources.
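To illustrate what stopping bad trials early means, here is a simplified, dependency-free version of a median stopping rule; median_stop and the loss histories are hypothetical. Ray Tune ships its own schedulers (for example, MedianStoppingRule and ASHAScheduler), whose exact rules differ from this sketch.

```python
from statistics import mean, median


def median_stop(trial_losses, other_trials_losses, grace_period=2):
    """Return True if the trial should stop early (lower loss is better).

    Simplified rule: after `grace_period` steps, stop the trial if its best loss
    so far is worse than the median of the other trials' running-average loss
    over the same number of steps.
    """
    step = len(trial_losses)
    if step < grace_period:
        return False  # give every trial a few steps before judging it
    # Only compare against trials that have reached at least this step.
    averages = [mean(losses[:step]) for losses in other_trials_losses if len(losses) >= step]
    if not averages:
        return False
    return min(trial_losses) > median(averages)


# Made-up loss histories of three trials that have already run for three steps.
others = [[0.9, 0.5, 0.3], [1.0, 0.6, 0.4], [0.8, 0.4, 0.2]]
print(median_stop([1.2, 1.1], others))  # True
print(median_stop([0.7, 0.3], others))  # False
```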

See the full example here.

Inspecting results in TensorBoard

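Ray Tune automatically logs every trial’s metrics in TensorBoard format, so you can inspect them by pointing TensorBoard at the results directory (by default ~/ray_results): tensorboard --logdir ~/ray_results. Ray Tune also integrates with Weights & Biases (W&B).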

That’s it!

To enable easy hyperparameter tuning with Ray Tune, we only needed to add a callback, wrap the train function, and then start Tune.

Of course, this is a very simple example that doesn’t leverage many of Ray Tune’s search features, like early stopping of poorly performing trials or population based training. If you would like to see a full example of these, please have a look at our full PyTorch Lightning tutorial.

If you’ve been successful in using PyTorch Lightning with Ray Tune, or if you need help with anything, please reach out by joining our Slack or dropping by our GitHub. We would love to hear from you!
