
Convenient Reinforcement Learning With Stable-Baselines3

Reinforcement learning without the boilerplate code


Created by the author with Leonardo Ai.

In my previous articles about reinforcement learning, I have shown you how to implement (deep) Q-learning using nothing but a bit of numpy and TensorFlow. While this was an important step towards understanding how these algorithms work under the hood, the code tended to get lengthy – and even then, I only implemented one of the most basic versions of deep Q-learning.

Hands-On Deep Q-Learning

Given the explanations in that article, understanding the code should be quite straightforward. However, if we really want to get things done, we should rely on well-documented, maintained, and optimized libraries. Just as we don’t want to implement linear regression over and over again, we don’t want to do the same for reinforcement learning.


In this article, I will show you the reinforcement learning library Stable-Baselines3, which is as easy to use as scikit-learn. Instead of training models to predict labels, though, we get trained agents that can navigate well in their environment.

You can find the code and my best trained models on my GitHub.

A Short Recap

If you are not sure what (deep) Q-learning is about, I suggest reading my previous articles. On a high level, we want to train an agent that interacts with its environment with the goal of maximizing its total reward. The most important part of reinforcement learning is to find a good reward function for the agent.

I usually imagine a character in a game searching its way to get the highest score, e.g., Mario running from start to finish without dying and – in the best case – as fast as possible.

Image by the author.

In order to do so, in Q-learning, we learn quality values for each pair (s, a) where s is a state and a is an action the agent can take. Q(s, a) is the expected discounted future reward when doing action a in state s. As an example, being in the state s = "standing in front of a cliff" and doing the action a = "do one step forward" should have a very low value of Q(s, a).

We can then turn this Q-function into a policy; imagine a magical compass that tells us what to do in any given state. The method is simple: if we are in state s, just compute Q(s, a) for all possible actions a and pick the action with the highest value. Done!
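
As a tiny illustration with a made-up Q-table (not from any trained agent), the greedy policy is just an argmax over the actions:

import numpy as np

# Made-up Q-table: one row per state, one column per action
Q = np.array([
    [0.1, 0.5, -1.0],  # state 0
    [0.7, 0.2, 0.3],   # state 1
])

def greedy_policy(state: int) -> int:
    """Pick the action with the highest Q-value in the given state."""
    return int(np.argmax(Q[state]))

print(greedy_policy(0))  # -> 1
print(greedy_policy(1))  # -> 0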

In my other articles, we have seen how to get these Q-values using a table or neural networks. Now, we want to sit back and just enjoy the simplicity of Stable-Baselines3. We have earned it.

Enter Stable-Baselines3

We have already developed agents that play a variety of games, such as Frozen Lake (reach the present without falling into the lake), Taxi (pick up a customer and bring them to the hotel), or Cart Pole (balance a stick).

Frozen Lake, Taxi, and Cart Pole. Images by the author.

We could recreate agents that master these games, but let us start with something different: Mountain Car!

The Mountain Car game

In this game, we steer a car that should go up a mountain. The actions we can take are going left, going right, or doing nothing. Our training goal is to go from here…

A greedy agent that only wants to move directly to the top of the hill. Image by the author.

… to here:

A smart agent that gains momentum first to reach its goal. Image by the author.
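
Before training, it helps to peek at what the agent actually works with. A quick look at the environment (outputs shown as comments; the exact starting position varies from run to run):

import gymnasium as gym

env = gym.make("MountainCar-v0")

print(env.action_space)       # Discrete(3): 0 = push left, 1 = do nothing, 2 = push right
print(env.observation_space)  # Box of shape (2,): [position of the car, velocity of the car]

obs, info = env.reset(seed=42)
print(obs)  # something like [-0.44, 0.0]: the car starts near the valley floor with zero velocity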

Training the model is extremely simple with Stable-Baselines3. Just look:

import gymnasium as gym
from stable_baselines3 import DQN

env_name = "MountainCar-v0"
env = gym.make(env_name)

config = {
    'batch_size': 128,
    'buffer_size': 10000,
    'exploration_final_eps': 0.07,
    'exploration_fraction': 0.2,
    'gamma': 0.98,
    'gradient_steps': 8, # don't do a single gradient update, but 8
    'learning_rate': 0.004,
    'learning_starts': 1000,
    'policy_kwargs': dict(net_arch=[256, 256]), # we train a neural network with two hidden layers of size 256 each
    'target_update_interval': 600, # see below, the target network gets overwritten with the main network every 600 steps
    'train_freq': 16, # don't train after every step in the environment, but after 16 steps
}

model = DQN("MlpPolicy", env, verbose=1, **config) # MlpPolicy = train a normal feed-forward neural network
model.learn(total_timesteps=120_000, progress_bar=True)  # Mountain Car needs on the order of 100,000 steps to learn reliably

The magic lies in finding good hyperparameters for the config, but this is something we as Machine Learning practitioners have to figure out – or something we can let dedicated hyperparameter optimization tools handle.
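
As a rough illustration of the latter, here is a minimal sketch of how such a search could look with Optuna; the parameter ranges and the 50,000 training steps per trial are arbitrary choices for demonstration, not values I have tuned:

import gymnasium as gym
import optuna
from stable_baselines3 import DQN
from stable_baselines3.common.evaluation import evaluate_policy

def objective(trial: optuna.Trial) -> float:
    # Sample a few hyperparameters; the ranges are rough guesses
    config = {
        "learning_rate": trial.suggest_float("learning_rate", 1e-4, 1e-2, log=True),
        "gamma": trial.suggest_float("gamma", 0.95, 0.999),
        "batch_size": trial.suggest_categorical("batch_size", [64, 128, 256]),
    }
    env = gym.make("MountainCar-v0")
    model = DQN("MlpPolicy", env, verbose=0, **config)
    model.learn(total_timesteps=50_000)
    mean_reward, _ = evaluate_policy(model, env, n_eval_episodes=10)
    return mean_reward

study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=20)
print(study.best_params)

Each trial trains a fresh model and returns its mean evaluation reward, which Optuna then tries to maximize across trials.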

Behind the curtains

We already know most of what happens in the .learn method. If you check out the source code, you will see many old friends from my other articles. For example, if you look here, you will find code like this:

for _ in range(gradient_steps):
  # Sample replay buffer
  replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)  # type: ignore[union-attr]

  with th.no_grad():
      # Compute the next Q-values using the target network
      next_q_values = self.q_net_target(replay_data.next_observations)
      # Follow greedy policy: use the one with the highest value
      next_q_values, _ = next_q_values.max(dim=1)
      # Avoid potential broadcast issue
      next_q_values = next_q_values.reshape(-1, 1)
      # 1-step TD target
      target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values

  # Get current Q-values estimates
  current_q_values = self.q_net(replay_data.observations)

We have a replay buffer, and there is the Q-value update step (the 1-step TD target). This shouldn’t look too scary anymore. A noteworthy difference is that the library uses a separate target network, something that I did not implement. The idea is simple though: instead of having one Q-value neural network, we have two.

In the source code above, self.q_net (the main network) is the one that actually gets trained. self.q_net_target (the target network), on the other hand, is used to produce the labels for training our main network. Every few steps (target_update_interval in the config above), the target network gets overwritten with the main network, so you can see the target network as a lagged version of the main network.

If both were the same, we would use our network (there would be only one) to produce labels, and then update that same network’s weights. But this in turn changes the targets again, so essentially we would try to learn moving targets – the training might become unstable. The two-network approach mitigates this problem.
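
In code terms, the synchronization boils down to copying the weights every target_update_interval steps. Here is a rough sketch of that idea (not the library’s actual training loop, just the mechanism):

import copy
import torch.nn as nn

# Sketch of the two-network idea for Mountain Car (2 state dimensions, 3 actions)
q_net = nn.Sequential(nn.Linear(2, 256), nn.ReLU(), nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 3))
q_net_target = copy.deepcopy(q_net)  # lagged copy, only used to compute the training targets
target_update_interval = 600

for step in range(100_000):
    # ... sample a batch, compute targets with q_net_target, update q_net via gradient descent ...
    if step % target_update_interval == 0:
        # overwrite the target network with the current main network
        q_net_target.load_state_dict(q_net.state_dict())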

Callbacks

Training takes a long time, and it is always sad to lose progress because your program crashes. So Stable-Baselines3 offers some nice callbacks to save your progress over time. I recommend using EvalCallback and CheckpointCallback.

from stable_baselines3.common.callbacks import EvalCallback, CheckpointCallback

env_name = "MountainCar-v0"

# callback to check the agent's performance every 1000 steps for 10 episodes
eval_callback = EvalCallback(
    eval_env=env,
    eval_freq=1000,
    n_eval_episodes=10,
    best_model_save_path=f"./logs/{env_name}", 
    log_path=f'./logs/{env_name}',
)

# callback to save the model every 10000 steps
save_callback = CheckpointCallback(save_freq=10000, save_path=f'./logs/{env_name}')

You can just pass these callbacks here:

model.learn(total_timesteps=120_000, progress_bar=True, callback=[eval_callback, save_callback])

The EvalCallback also saves some nice performance numbers that you can plot.

The mean reward (over 10 runs) over time. Image by the author.

You can see how for about 40,000 timesteps, the model did not learn much. A reward of -200 indicates that the model did not reach the top – each timestep gives a reward of -1, and an episode ends after 200 timesteps. Then, suddenly, the learning took off until the agent consistently reached the top of the mountain. You can plot it like this:

import numpy as np
import pandas as pd

data = np.load(f"./logs/{env_name}/evaluations.npz")
pd.DataFrame({
    "mean_reward": data["results"].mean(axis=1),
    "timestep": data["timesteps"]
}).plot(
    x="timestep",
    y="mean_reward",
)
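
Once training has finished, you can load the best model that the EvalCallback saved and watch the agent drive up the mountain yourself. A minimal sketch, assuming the log paths from the callbacks above:

import gymnasium as gym
from stable_baselines3 import DQN

env_name = "MountainCar-v0"

# EvalCallback stores the best model as best_model.zip in best_model_save_path
model = DQN.load(f"./logs/{env_name}/best_model")

env = gym.make(env_name, render_mode="human")
obs, info = env.reset()
done = False
while not done:
    action, _ = model.predict(obs, deterministic=True)
    obs, reward, terminated, truncated, info = env.step(action)
    done = terminated or truncated
env.close()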

Playing Atari Games

Okay, cool, so we beat some kindergarten games. Time to tackle something more challenging: Atari games! For the young folks: Atari was a leader in the video game market back in the late 70s and 80s. They also created Pong, our beloved game of two paddles playing tennis.

An Atari 2600 that I used to play with as a child. Public domain image by Evan Amos.

Most of their games are still simple, but at least they feel like real games that actually challenge you. To spice things up, we will only use the raw screen pixels to train our agent! No more internal game states such as coordinates, velocities, or angles of objects. The machine has to learn how to play the game the same way a human does: by looking at the screen and figuring out what to do.

Breakout

As an example, let us use Breakout, a game where we have to destroy blocks using a ball. The ball bounces around, off the blocks but also off the ship we control. We can steer this "spaceship" left and right to keep the ball in play. But let us just look at a game scene with our agent in the main role:

Our deep Q-learning agent playing Breakout. Image by the author.

This agent was trained for about 3,000,000 frames on GCP (8 vCPUs, 30 GB RAM, 4 NVIDIA T4 GPUs), running 4 game environments in parallel. Training took about 3 hours. Besides using a big machine, I boosted performance with the AtariWrapper, which scales the images down to 84 x 84 pixels and grayscales them, since color is not important in this game. We also use a convolutional neural network instead of a simple feed-forward neural network to achieve better results in less time. Here is the code:

import gymnasium as gym
from stable_baselines3 import DQN
from stable_baselines3.common.atari_wrappers import AtariWrapper
from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback
from stable_baselines3.common.vec_env import SubprocVecEnv, VecFrameStack, VecTransposeImage

if __name__ == "__main__":
    env_name = "BreakoutNoFrameskip-v4"
    env = SubprocVecEnv([lambda: AtariWrapper(gym.make(env_name)) for _ in range(4)]) # train 4 game environments in parallel, scale down images for faster training
    env = VecFrameStack(env, n_stack=4) # don't only use a still image for training, but the last 4 frames
    env = VecTransposeImage(env) # technical magic for putting the channels of the animation in the first coordinate, i.e., turning HxWxC into CxHxW since Stable-Baselines3 likes it that way

    config = {
        "batch_size": 32,
        "buffer_size": 10000,
        "exploration_final_eps": 0.02,
        "exploration_fraction": 0.1,
        "gamma": 0.99,
        "gradient_steps": 4,
        "learning_rate": 1e-4,
        "learning_starts": 10000,
        "target_update_interval": 1000,
        "train_freq": 4,
    }

    eval_callback = EvalCallback(
        eval_env=env,
        eval_freq=1000,
        n_eval_episodes=10,
        best_model_save_path=f"./logs/{env_name}",
        log_path=f"./logs/{env_name}",
    )
    save_callback = CheckpointCallback(save_freq=10000, save_path=f"./logs/{env_name}")

    model = DQN("CnnPolicy", env, verbose=0, **config) # CnnPolicy creates some default convolutional neural network for us for processing the screen pixels in a more efficient way
    model.learn(total_timesteps=10_000_000, progress_bar=True, callback=[eval_callback, save_callback])

Note: JupyterLab usually has problems with multiprocessing, so you might have to paste this code into a .py file and run it from the command line. Also notice that I do not feed the network single images of the game, but four consecutive frames, via the line

env = VecFrameStack(env, n_stack=4)

This way, the agent can learn the direction and speed of the ball in addition to its position. Otherwise, how could it tell what is going on?

Where is the ball going? Image by the author.

The 4 is just a hyperparameter; feel free to try other values as well. This little trick makes it possible for the agent to learn how to play this game without any internal game information.
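
If you are curious about what the network actually receives after all this wrapping, you can inspect the environment right after building it; as far as I understand the wrappers, the shapes look like this:

print(env.observation_space.shape)  # (4, 84, 84): 4 stacked grayscale frames of 84 x 84 pixels each

obs = env.reset()
print(obs.shape)  # (4, 4, 84, 84): one stacked observation for each of the 4 parallel environments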


As usual, the performance of the agent is quite jumpy over the episodes. Still, you can clearly see how the trend goes up over time:

Image by the author.

Space Invaders

Another classic – one that was developed as a response to Breakout – is the game Space Invaders. In case you don’t know it: you shoot aliens and try not to get shot yourself. By replacing just a single line in the code, we can train an agent that beats one wave of the game within 3,000,000 steps of training:

Image by the author.

However, I cherry-picked this run. Usually, my agent dies faster, but it is still quite good:

Image by the author.

You can train it via

...

if __name__ == "__main__":
    env_name = "SpaceInvadersNoFrameskip-v4"
    ...

Of course, you can now train agents for any of the other Atari games as well.

Conclusion

In this article, we have seen a way to train agents without much boilerplate code. At the moment, I would consider Stable-Baselines3 the scikit-learn of Reinforcement Learning: you define the model, configure it a bit, and .learn the game. It cannot get much simpler than this.

Still, I advocate for understanding what is going on behind the curtains. Otherwise, you might be lost when things don’t work out of the box. The same goes for classical machine learning, or any other algorithm. First, at least understand the fundamentals, and then treat yourself to a nice library!

As a last point, if you check out the library’s documentation, you will see that it supports more learning algorithms, such as A2C, DDPG, PPO, SAC, and TD3.

If you want a nice alternative to deep Q-learning, PPO seems to be a popular choice from what I have seen. Play around with all of them to see if you can find something more performant for your learning problem! But make sure to look up how these methods work as well – maybe in one of my future articles!
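
Switching the algorithm is as simple as swapping the import. A minimal sketch with PPO and default hyperparameters (don’t expect it to master Mountain Car without tuning):

import gymnasium as gym
from stable_baselines3 import PPO

env = gym.make("MountainCar-v0")

# Same workflow as before, just a different algorithm
model = PPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=100_000, progress_bar=True)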


I hope that you learned something new, interesting, and valuable today. Thanks for reading!

If you have any questions, write me on LinkedIn!

And if you want to dive deeper into the world of algorithms, give my new publication All About Algorithms a try! I’m still searching for writers!

All About Algorithms

