Rocking Hyperparameter Tuning with PyTorch’s Ax Package

Use Bayesian and bandit optimization for shorter tuning runs, while enjoying an unbelievably simple setup. It works like a charm.

Marina Gandlin
Towards Data Science


Hyperparameter tuning is a must for many machine learning tasks. We usually work hard on selecting the right algorithm and architecture for our problem, then train rigorously to get a great model. Doing hyperparameter tuning (HPT) after those two might seem unnecessary, but it is, in fact, crucial. HPT should be done periodically, and it can bring great improvements in performance with little effort.

What I'm suggesting here is an easy way to implement HPT with a newly released Python package. It will work perfectly and take no more than half an hour to set up for anything you train on your own computer. For other models, especially ones that are trained on remote deployments or run in production, this will work with some small adjustments that we will discuss in part B of this post.

What is Ax?

Ax is an open-source package from PyTorch that helps you find a minimum for any function over the range of parameters you define. The most common use case in machine learning is finding the best hyperparameters for training a model, the ones that will bring your overall loss to a minimum. The package does this by running training multiple times, each run with a different set of parameters, and returning the set that gave the lowest loss. The trick is that it does so without grid search or random search over these parameters, but with more sophisticated algorithms, hence saving a lot of training and run time.

Ax can find minima for both continuous parameters (say, learning rate) and discrete parameters (say, the size of a hidden layer). It uses Bayesian optimization for the former and bandit optimization for the latter. Even though the package is from PyTorch, it will work for any function, as long as it returns a single value you want to minimize.
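To make the continuous/discrete distinction concrete, here is what a mixed parameter specification looks like in Ax's format: a "range" entry for a continuous parameter and a "choice" entry for a discrete one. The names "lr" and "hidden_size" are just illustrative, not anything the package requires.

```python
# An Ax-style parameter specification mixing both kinds:
# "range" = continuous (handled with Bayesian optimization),
# "choice" = discrete (handled with bandit optimization).
parameters = [
    {"name": "lr",            # continuous: learning rate
     "type": "range",
     "bounds": [1e-5, 1e-1]},
    {"name": "hidden_size",   # discrete: width of a hidden layer
     "type": "choice",
     "values": [64, 128, 256, 512]},
]
```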

Before we start — Installation instructions can be found here.

Part A — “Hello world!”

Ax has a couple of operating modes, but we'll start with the most basic one and a small example. As mentioned before, we'll usually use Ax to help us find the best hyperparameters, but at its core, this package helps us find a function's minimum with respect to some parameters. That's why for this example we'll run Ax to find the minimum of a quadratic function. For that we'll define a function named booth that receives its parameters {x1, x2} within a dictionary p:

# Sample quadratic function
def booth(p):
    # p = dictionary of parameters
    return (p["x1"] + 2*p["x2"] - 7)**2 + (2*p["x1"] + p["x2"] - 5)**2

print(booth({"x1": 6, "x2": 7}))

This returns: 365.

To find the minimum of booth, we'll let Ax run the function several times. The parameters it chooses for each run depend on previous runs: which parameters were tried and what the function returned for them (in machine learning, "result" == "loss"). The next code bit performs 20 consecutive runs of the function, with 20 different sets of parameters {x1, x2}, then prints out the best parameters. For comparison, if we wanted to do this with grid search, with steps of 0.5 for x1 and x2, we would need about 1,600 runs instead of 20!

from ax import optimize

best_parameters, best_values, _, _ = optimize(
    parameters=[
        {"name": "x1",
         "type": "range",
         "bounds": [-10.0, 10.0]},
        {"name": "x2",
         "type": "range",
         "bounds": [-10.0, 10.0]}],
    evaluation_function=booth,
    minimize=True,
)
print(best_parameters)

This prints out (the true minimum is at (1, 3)):

{'x1': 1.0358775112792173, 'x2': 2.8776698220783423}

Pretty close! This run takes less than 10 seconds on a laptop. You might notice that the best parameters returned are not exactly the true minimum, but they bring the function very close to it. How large a margin you can tolerate in the result is something you can play with later.
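To see where the grid-search comparison comes from, here is a hypothetical brute-force baseline over the same bounds: with a step of 0.5 over [-10, 10] there are 41 grid points per axis, so 41 x 41 = 1,681 evaluations (roughly the 1,600 quoted above) instead of Ax's 20 adaptive trials.

```python
# Brute-force grid search over the same bounds, for comparison with
# Ax's ~20 adaptive trials. Every grid point costs one full "training".
def booth(p):
    return (p["x1"] + 2*p["x2"] - 7)**2 + (2*p["x1"] + p["x2"] - 5)**2

def grid_search(step=0.5, lo=-10.0, hi=10.0):
    best_p, best_val, n_runs = None, float("inf"), 0
    n = int((hi - lo) / step) + 1          # 41 points per axis
    for i in range(n):
        for j in range(n):
            p = {"x1": lo + i * step, "x2": lo + j * step}
            val = booth(p)
            n_runs += 1
            if val < best_val:
                best_p, best_val = p, val
    return best_p, best_val, n_runs

best_p, best_val, n_runs = grid_search()
print(best_p, best_val, n_runs)  # finds (1.0, 3.0) with loss 0 in 1681 runs
```

The grid happens to contain the exact minimum here, but the cost scales exponentially with the number of parameters, which is exactly what Ax avoids.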

The only thing we needed to do here was make sure we have a function that takes a dictionary as input and returns a single value. This is what you'll need in most cases when running Ax.

Part B — Running with Deployments

Part A was pretty easy and will work wonderfully on anything you run on your computer, but it has a couple of faults:

  1. It only runs one “training” at a time, which is bad if a training session takes more than a few minutes.
  2. It only runs locally, which is bad if your regular model runs on the cloud / deployments etc.

For that we need something more sophisticated: preferably, some kind of magic way to just get a couple of sets of parameters to run on, so we could deploy trainings and then patiently wait until the results come back. When the first batch of trainings is finished, we can let Ax know what the losses were, get the next sets of parameters and start the next batch. In Ax lingo, this way of running is called “Ax service”. It is quite simple to phrase and run, almost as simple as what we saw in part A.

To fit the “ax paradigm” we still need some kind of function that runs the training and returns a loss, maybe something that looks like this:

def training_wrapper(parameters_dictionary):
    # 1. run training with parameters_dictionary
    #    (this writes the loss value to somewhere in the cloud)
    # 2. read loss from somewhere in the cloud
    # 3. return loss
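As a minimal sketch of what such a wrapper could look like: here a local temp file plays the role of "somewhere in the cloud", and launch_training and fetch_loss_from_storage are hypothetical stand-ins for whatever your deployment actually does.

```python
import json
import os
import tempfile

# Hypothetical stand-in for a real deployment: "training" writes its loss
# to a shared location (a local temp file here, cloud storage in real life).
def launch_training(parameters_dictionary, result_path):
    loss = (parameters_dictionary["x1"] - 1) ** 2  # dummy training loss
    with open(result_path, "w") as f:
        json.dump({"loss": loss}, f)

def fetch_loss_from_storage(result_path):
    with open(result_path) as f:
        return json.load(f)["loss"]

def training_wrapper(parameters_dictionary):
    result_path = os.path.join(tempfile.gettempdir(), "loss.json")
    launch_training(parameters_dictionary, result_path)  # 1. run training
    loss = fetch_loss_from_storage(result_path)          # 2. read loss back
    return loss                                          # 3. return loss
```

The important part is only the shape: a dictionary of parameters goes in, a single loss value comes out.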

After defining a proper wrapper, we can run the Ax service HPT code. Since I want you to be able to run this demo code on your computer, I will stick to the quadratic function we introduced in part A, booth. The parameter ranges for {x1, x2} stay the same. For running deployments, I would replace booth in the next sample code with my version of training_wrapper. The sample code looks like this:

from ax.service.ax_client import AxClient

ax = AxClient(enforce_sequential_optimization=False)
ax.create_experiment(
    name="booth_experiment",
    parameters=[
        {"name": "x1",
         "type": "range",
         "bounds": [-10.0, 10.0]},
        {"name": "x2",
         "type": "range",
         "bounds": [-10.0, 10.0]}],
    objective_name="booth",
    minimize=True,
)
for _ in range(15):
    next_parameters, trial_index = ax.get_next_trial()
    ax.complete_trial(trial_index=trial_index, raw_data=booth(next_parameters))
best_parameters, metrics = ax.get_best_parameters()

This is a little different from part A. We no longer just use the “optimize” function. Instead:

  1. We initialize an ax client.
  2. We set up an “experiment”, and choose a range of hyperparameters we want to check.
  3. We get the next set of parameters we want to run the function on with get_next_trial().
  4. Wait for the function to complete its run, then report the result with complete_trial().

Meaning, we separated getting the parameters for the next run from actually running it. If you want to run concurrently, you call get_next_trial() N times and run the trials asynchronously. Make sure you don't forget to set the "enforce_sequential_optimization" flag to False if you want to do so. If you are wondering how many runs you can do concurrently, you can use get_recommended_max_parallelism (read about the output of this function here).
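The batched pattern can be sketched without a live AxClient. In this sketch, a random sampler stands in for get_next_trial(), booth stands in for a deployed training, and a thread pool stands in for your deployments; in real code you would replace those stand-ins with the actual AxClient calls shown above.

```python
import random
from concurrent.futures import ThreadPoolExecutor, as_completed

def booth(p):
    return (p["x1"] + 2*p["x2"] - 7)**2 + (2*p["x1"] + p["x2"] - 5)**2

# Stand-in for ax.get_next_trial(): in real code the AxClient proposes
# parameters; here we just sample the same bounds at random.
def get_next_trial(trial_index):
    p = {"x1": random.uniform(-10, 10), "x2": random.uniform(-10, 10)}
    return p, trial_index

results = {}
with ThreadPoolExecutor(max_workers=5) as pool:
    # Launch a batch of 5 trials concurrently (one "deployment" each)...
    batch = [get_next_trial(i) for i in range(5)]
    futures = {pool.submit(booth, p): idx for p, idx in batch}
    # ...then report each loss back as it completes, as you would with
    # ax.complete_trial(trial_index=idx, raw_data=loss).
    for fut in as_completed(futures):
        results[futures[fut]] = fut.result()

print(results)  # {trial_index: loss} for the whole batch
```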

That's pretty much it. The package is wonderfully customizable and could probably handle whatever you might want to change for your specific environment. The documentation is very readable, especially the tutorials. It also has a wide variety of visualizations to help you figure out what’s going on.

Summary

The days when you had to manually set up multiple runs or use grid search are pretty much over. It is true that on a very large set of hyperparameters Ax might resort to something resembling random search, but even if it didn't reduce the number of trainings you ended up running, it still saved you the coding and the handling of the multiple runs. Setting up is super easy. I highly recommend you try it yourself!

Links:

  1. General Documentation : https://ax.dev/docs/why-ax.html
  2. Tutorials : https://ax.dev/tutorials/
  3. Visualizations : https://ax.dev/tutorials/visualizations.html


Data science & tech lead focused on large scale online recommendations. Used to design satellites, but thinks machine learning is way cooler.