Reinforcement Learning

In my previous articles about reinforcement learning, I have shown you how to implement (deep) Q-learning using nothing but a bit of numpy and TensorFlow. While this was an important step towards understanding how these algorithms work under the hood, the code tended to get lengthy, and I only implemented one of the most basic versions of deep Q-learning.
Given the explanations in those articles, understanding such code should be quite straightforward. However, if we really want to get things done, we should rely on well-documented, maintained, and optimized libraries. Just as we don’t want to implement linear regression over and over again, we don’t want to do the same for reinforcement learning.
In this article, I will show you the reinforcement learning library Stable-Baselines3, which is as easy to use as scikit-learn. Instead of training models to predict labels, though, we get trained agents that can navigate well in their environment.
You can find the code and my best trained models on my GitHub.
A Short Recap
If you are not sure what (deep) Q-learning is about, I suggest reading my previous articles. On a high level, we want to train an agent that interacts with its environment with the goal of maximizing its total reward. The most important part of reinforcement learning is to find a good reward function for the agent.
I usually imagine a character in a game trying to reach the highest score, e.g., Mario running from start to finish without dying and, in the best case, as fast as possible.

In order to do so, in Q-learning, we learn quality values for each pair (s, a) where s is a state and a is an action the agent can take. Q(s, a) is the expected discounted future reward when doing action a in state s. As an example, being in the state s = "standing in front of a cliff" and doing the action a = "do one step forward" should have a very low value of Q(s, a).
We can then turn this Q-function into a policy: imagine a magical compass that tells us what to do in any given state. The method is simple: if we are in state s, just compute Q(s, a) for all possible actions a and pick the action with the highest value. Done!
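As a tiny illustration (with made-up numbers, not code from my earlier articles): once we have Q-values for a state, picking the greedy action is just an argmax.

import numpy as np

# hypothetical Q-values for the state "standing in front of a cliff"
# actions: 0 = step forward, 1 = step back, 2 = do nothing
q_values = np.array([-100.0, 5.0, 1.0])

best_action = int(np.argmax(q_values))  # picks 1, i.e., step back
print(best_action)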
In my other articles, we have seen how to get these Q-values using a table or neural networks. Now, we want to sit back and just enjoy the simplicity of Stable-Baselines3. We deserved it.
Enter Stable-Baselines3
We have already developed agents that play a variety of games, such as Frozen Lake (get the present without falling into the holes in the ice), Taxi (pick up a customer and bring them to the hotel), or Cart Pole (balance a stick).



We could recreate agents that master these games, but let us start with something different: Mountain Car!
The Mountain Car game
In this game, we steer a car that should go up a mountain. The actions we can take are going left, going right, or doing nothing. Our training goal is to go from here…

… to here:

Training the model is extremely simple with Stable-Baselines3. Just look:
import gymnasium as gym
from stable_baselines3 import DQN

env_name = "MountainCar-v0"
env = gym.make(env_name)

config = {
    "batch_size": 128,
    "buffer_size": 10000,
    "exploration_final_eps": 0.07,
    "exploration_fraction": 0.2,
    "gamma": 0.98,
    "gradient_steps": 8,  # don't do a single gradient update, but 8
    "learning_rate": 0.004,
    "learning_starts": 1000,
    "policy_kwargs": dict(net_arch=[256, 256]),  # we train a neural network with two hidden layers of size 256 each
    "target_update_interval": 600,  # see below, the target network gets overwritten with the main network every 600 steps
    "train_freq": 16,  # don't train after every step in the environment, but after 16 steps
}

model = DQN("MlpPolicy", env, verbose=1, **config)  # MlpPolicy = train a normal feed-forward neural network
model.learn(total_timesteps=2000, progress_bar=True)
The magic is about finding good hyperparameters for the config, but this is something we as Machine Learning practitioners have to figure out. Or let dedicated hyperparameter optimization tools handle it.
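Once the model is trained, you can already watch it play. Here is a minimal sketch, assuming the model from the snippet above; the render_mode="human" option opens a window showing the game.

import gymnasium as gym

# recreate the environment with rendering enabled so we can watch the agent
env = gym.make("MountainCar-v0", render_mode="human")

obs, _ = env.reset()
done = False
while not done:
    # deterministic=True: always pick the action with the highest Q-value
    action, _ = model.predict(obs, deterministic=True)
    obs, reward, terminated, truncated, _ = env.step(action)
    done = terminated or truncated
env.close()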
Behind the curtains
We know most things that happen in the .learn method already. If you check out the source code, you will see many old friends from my other articles. For example, if you look here, you can find code like
for _ in range(gradient_steps):
    # Sample replay buffer
    replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)  # type: ignore[union-attr]

    with th.no_grad():
        # Compute the next Q-values using the target network
        next_q_values = self.q_net_target(replay_data.next_observations)
        # Follow greedy policy: use the one with the highest value
        next_q_values, _ = next_q_values.max(dim=1)
        # Avoid potential broadcast issue
        next_q_values = next_q_values.reshape(-1, 1)
        # 1-step TD target
        target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values

    # Get current Q-values estimates
    current_q_values = self.q_net(replay_data.observations)
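The excerpt stops right before the loss. Conceptually, the next step is to pick, for each transition, the predicted Q-value of the action that was actually taken and regress it towards the TD target. Here is a tiny self-contained sketch of that idea with made-up numbers (not a verbatim quote of the library’s code):

import torch as th
import torch.nn.functional as F

# hypothetical batch of 2 transitions: Q-values for 3 actions, the actions taken, and the TD targets
current_q_values = th.tensor([[1.0, 2.0, 0.5], [0.3, 0.1, 0.9]])
actions = th.tensor([[1], [2]])
target_q_values = th.tensor([[2.5], [1.0]])

# pick the Q-value of the action that was actually taken in each transition
chosen_q_values = th.gather(current_q_values, dim=1, index=actions)
# regress the chosen Q-values towards the 1-step TD targets (Huber loss, as is common for DQN)
loss = F.smooth_l1_loss(chosen_q_values, target_q_values)
print(loss)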
We have a replay buffer, and there is the Q-value update step (the 1-step TD target). This shouldn’t look too scary anymore. A noteworthy difference is that the library uses a separate target network, something that I did not implement. The idea is easy though: instead of having one Q-value neural network, we have two.
In the source code above, self.q_net (called the main network) is the one that gets trained normally. On the other hand, self.q_net_target (called the target network) is used for producing the labels to train our main network. Every target_update_interval steps, the target network gets overwritten with the main network, so you can see the target network as a lagged version of the main network.
If there were only one network, we would use it to produce the labels and then update that very network’s weights. But the update in turn changes the targets again, so essentially we would be learning moving targets, and the training might become unstable. The two-network approach with a lagged target network fixes this problem.
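To make this concrete, here is a minimal PyTorch sketch (not the library’s actual code, and the network sizes are just placeholders) of how such a periodic sync of the target network could look:

import torch.nn as nn

# two networks with identical architecture: the main network and its lagged copy
q_net = nn.Sequential(nn.Linear(2, 256), nn.ReLU(), nn.Linear(256, 3))
q_net_target = nn.Sequential(nn.Linear(2, 256), nn.ReLU(), nn.Linear(256, 3))

def sync_target(step: int, target_update_interval: int = 600) -> None:
    # every target_update_interval steps, overwrite the target network
    # with the main network, so the training targets only move occasionally
    if step % target_update_interval == 0:
        q_net_target.load_state_dict(q_net.state_dict())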
Callbacks
Training takes a long time, and it is always sad to lose progress because your program crashes. So Stable-Baselines3 offers some nice callbacks to save your progress over time. I recommend using EvalCallback and CheckpointCallback.
from stable_baselines3.common.callbacks import EvalCallback, CheckpointCallback

env_name = "MountainCar-v0"

# callback to check the agent's performance every 1000 steps for 10 episodes
eval_callback = EvalCallback(
    eval_env=env,
    eval_freq=1000,
    n_eval_episodes=10,
    best_model_save_path=f"./logs/{env_name}",
    log_path=f"./logs/{env_name}",
)

# callback to save the model every 10000 steps
save_callback = CheckpointCallback(save_freq=10000, save_path=f"./logs/{env_name}")
You can just pass these callbacks here:
model.learn(total_timesteps=2000, progress_bar=True, callback=[eval_callback, save_callback])
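Afterwards, you can load the saved agents again. The EvalCallback stores the best model it has seen under best_model.zip in the path given above, so loading it back could look like this (a small sketch, assuming the paths from the callback configuration):

from stable_baselines3 import DQN

# load the best agent found during evaluation
best_model = DQN.load(f"./logs/{env_name}/best_model")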
The EvalCallback also saves some nice performance numbers that you can plot.

You can see how for about 40,000 timesteps, the model did not learn much. A mean reward of -200 indicates that the agent did not reach the top: each step gives a reward of -1, and an episode ends after 200 timesteps. Then, suddenly, the learning took off until the agent consistently reached the top of the mountain. You can plot it like this:
import numpy as np
import pandas as pd

data = np.load(f"./logs/{env_name}/evaluations.npz")

pd.DataFrame({
    "mean_reward": data["results"].mean(axis=1),
    "timestep": data["timesteps"],
}).plot(
    x="timestep",
    y="mean_reward",
)
Playing Atari Games
Okay, cool, so we beat some kindergarten games. Time to tackle something more challenging: Atari games! For the young folks: Atari was a leader in the video game market back in the late 70s and 80s. They also invented Pong, our beloved game of two paddles playing tennis.

Most of their games are still simple, but at least they feel like real games that actually challenge you. To spice things up, we will only use the raw screen pixels to train our agent! No more internal game states such as coordinates, velocities, or angles of objects. The machine has to learn how to play the game the same way a human does: by looking at the screen and figuring out what to do.
Breakout
As an example, let us use Breakout, a game where we have to destroy blocks using a ball. The ball bounces around, off the blocks but also off the little "spaceship" we control. We can steer it left and right to keep the ball in play. But let us just look at a game scene with our agent in the main role:

This agent was trained for about 3,000,000 frames using GPUs and training on 4 environments at the same time on GCP (8 vCPUs, 30 GB RAM, NVIDIA T4 x 4). It took about 3 hours to train it. Besides using a big machine, I boosted the performance using the AtariWrapper that scales down the images to a size of 84 x 84 pixels and grayscales them, since colors are not important in this game. We also use a convolutional neural network as opposed to a simple feed-forward neural network to achieve better results in less time. Here is the code:
import gymnasium as gym
from stable_baselines3 import DQN
from stable_baselines3.common.atari_wrappers import AtariWrapper
from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback
from stable_baselines3.common.vec_env import SubprocVecEnv, VecFrameStack, VecTransposeImage

if __name__ == "__main__":
    env_name = "BreakoutNoFrameskip-v4"

    # train 4 game environments in parallel, scale down images for faster training
    env = SubprocVecEnv([lambda: AtariWrapper(gym.make(env_name)) for _ in range(4)])
    # don't only use a still image for training, but the last 4 frames
    env = VecFrameStack(env, n_stack=4)
    # technical magic for putting the channels of the animation in the first coordinate,
    # i.e., turning HxWxC into CxHxW since Stable-Baselines3 likes it that way
    env = VecTransposeImage(env)

    config = {
        "batch_size": 32,
        "buffer_size": 10000,
        "exploration_final_eps": 0.02,
        "exploration_fraction": 0.1,
        "gamma": 0.99,
        "gradient_steps": 4,
        "learning_rate": 1e-4,
        "learning_starts": 10000,
        "target_update_interval": 1000,
        "train_freq": 4,
    }

    eval_callback = EvalCallback(
        eval_env=env,
        eval_freq=1000,
        n_eval_episodes=10,
        best_model_save_path=f"./logs/{env_name}",
        log_path=f"./logs/{env_name}",
    )
    save_callback = CheckpointCallback(save_freq=10000, save_path=f"./logs/{env_name}")

    # CnnPolicy creates some default convolutional neural network for us
    # for processing the screen pixels in a more efficient way
    model = DQN("CnnPolicy", env, verbose=0, **config)
    model.learn(total_timesteps=10_000_000, progress_bar=True, callback=[eval_callback, save_callback])
Note: JupyterLab usually has problems with multiprocessing, so you might have to paste this code into a .py file and run it from the command line. Also notice that I feed the network not only single images of the game, but four consecutive frames with the line
env = VecFrameStack(env, n_stack=4)
This way, the agent can learn the direction and speed of the ball in addition to its position. Otherwise, how could it tell what is going on?

The 4 is just a hyperparameter; feel free to try other values as well. This little trick makes it possible for the agent to learn how to play this game without any internal game information.
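If you are curious what the network actually receives after all these wrappers, you can inspect the observation space. A small sketch, assuming the wrapped env from the code above; with this setup, the observations should be 4 stacked grayscale frames of 84 x 84 pixels for each of the 4 parallel environments:

# inspect what the agent actually sees after all wrappers are applied
print(env.observation_space)  # expected: Box(0, 255, (4, 84, 84), uint8)

obs = env.reset()
print(obs.shape)  # expected: (4, 4, 84, 84) -> 4 parallel envs, each with 4 stacked 84 x 84 frames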
As usual, the performance of the agent is quite jumpy over the episodes. Still, you can clearly see how the trend goes up over time:

Space Invaders
Another classic, which was a response to Breakout, is the game Space Invaders. In case you don’t know it: you shoot aliens and try not to get shot. By replacing just a single line in the code, we can train an agent that can beat one wave of the game within 3,000,000 steps of training:

However, I cherry-picked this run. Usually, my agent dies faster, but it is still quite good:


You can train it via
...
if __name__ == "__main__":
    env_name = "SpaceInvadersNoFrameskip-v4"
    ...
Of course, you can now train agents for any of the other Atari games as well.
Conclusion
In this article, we have seen a way to train agents without much boilerplate code. At the moment, I would consider Stable-Baselines3 the scikit-learn of Reinforcement Learning: you define the model, configure it a bit, and .learn the game. It cannot get much simpler than this.
Still, I advocate for understanding what is going on behind the curtains. Otherwise, you might be lost when things don’t work out of the box. The same goes for classical machine learning, or any other algorithm. First, at least understand the fundamentals, and then treat yourself to a nice library!
As a last point, if you check out the library’s documentation, you will see that it supports more learning algorithms, such as
- Advantage Actor Critic (A2C)
- Proximal Policy Optimization (PPO), or
- Deep Deterministic Policy Gradient (DDPG).
If you want a nice alternative to deep Q-learning, PPO seems to be popular from what I have seen. Play around with all of them to see if you can find something more performant for your learning problem! But make sure to look up how these methods work as well, maybe in one of my future articles!
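Switching the algorithm is often just a small change. As a rough sketch (assuming the Mountain Car env from above, with the library’s default hyperparameters rather than tuned ones):

from stable_baselines3 import PPO

# same environment, different algorithm: PPO instead of DQN
model = PPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=100_000, progress_bar=True)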
I hope that you learned something new, interesting, and valuable today. Thanks for reading!
If you have any questions, write me on LinkedIn!
And if you want to dive deeper into the world of algorithms, give my new publication All About Algorithms a try! I’m still searching for writers!