
Sb3, the Swiss Army Knife of Applied RL

Your choice of model, with any environment

Stable-Baselines3 (sb3) is like a Swiss Army knife: a multi-function utility tool that can be used for many purposes. And just as a Swiss Army knife can save your life when you are stranded in a jungle, sb3 can save your life in the office when you have seemingly impossible deadlines to meet.

This guide uses gymnasium==0.28.1 and stable-baselines3==2.1.0. If you use different versions, or perhaps refer to other, older guides, you may not get the results below. But fret not, an installation guide is given here as well. I guarantee you can reproduce the results if you follow my instructions.

[1] What You Will Get Here

Stable-Baselines3 is easy to use. It is also well documented, and you can follow the tutorials on your own. But…

  • Have you referred to older guides (perhaps those using gym), only to find errors on your machine?
  • Are you always able to ensure compatibility?
  • What if you want to use a gymnasium environment but modify, say, the rewards?
  • Do you know how to wrap your own tasks so that SOTA models can be applied in a few lines?

That’s the objective of this article! After reading this guided demonstration, you will…

  1. Solve classic environments with sb3 models, visualize the results, and save (or load) the trained model in a few lines of code. [Section 3.1]
  2. Understand how to check the action space and observation space for compatibility. [Section 3.2]
  3. Learn how to wrap gymnasium environments so that any sb3 model can be used, without any restrictions on Box or Discrete action spaces. [Section 4.1]
  4. Learn how to wrap gymnasium environments for reward shaping. [Section 4.2]
  5. Learn how to wrap your own custom environments to be compatible with sb3, with minimal changes to your original code, which may follow a different structure. [Section 5]

[2] Installation

Create a virtual environment and set up the relevant dependencies. I cater to the majority here – the guide was created on Windows with Anaconda installed. Open your Anaconda prompt and do the following:

conda create --name rl python=3.8
conda activate rl

conda install gymnasium[box2d]
pip install stable-baselines3==2.1.0
pip install pygame==2.5.2
pip install imageio==2.31.6

conda install jupyter
jupyter notebook

Here, we will be using Jupyter Notebook, as it is a more user-friendly tool for teaching.
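
Once the notebook is running, a quick check of the installed versions helps confirm that your setup matches this guide:

import gymnasium
import stable_baselines3

print(gymnasium.__version__)          # expect 0.28.1
print(stable_baselines3.__version__)  # expect 2.1.0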

[3] A Taste of Success – See Your Trained RL Agents

The first thing is to import the required libraries.

import os
import numpy as np
import gymnasium as gym   # 0.28.1
import stable_baselines3  # 2.1.0
from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.evaluation import evaluate_policy

[3.1] DQN on CartPole

Let’s start small, on the CartPole task, where the objective is to push the cart (left or right) to keep the pole upright.

What’s the absolute minimum you need? Just this, to train.

env = gym.make("CartPole-v1")
model = DQN("MlpPolicy", env)
model.learn(total_timesteps=100000)

And this, to evaluate.

mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=10)
print(f"Mean reward: {mean_reward} +/- {std_reward}")

Finally, this, to visualize.

import pygame
env = gym.make("CartPole-v1", render_mode="human")

obs = env.reset()[0]
score = 0
while True:
    action, states = model.predict(obs)
    obs, reward, terminated, truncated, info = env.step(action)
    score += reward
    env.render()
    if terminated or truncated:   # episode ends when the pole falls or the time limit is hit
        break
print("score: ", score)
env.close()

In just over ten lines of code, and a few minutes of training, we solved a classic RL problem. This is a good example of the extent to which AI has been democratized!

To save your sb3 model, just add a callback during training.

env = gym.make("CartPole-v1")
model = DQN("MlpPolicy", env)
model.learn(
    total_timesteps=100000,
    callback=EvalCallback(
        env, best_model_save_path='./logs/', eval_freq=5000
    )
)

Your model can subsequently be loaded in just two lines.

model = DQN.load("./logs/best_model.zip")
model.set_env(env)
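
The loaded model behaves just like the one you trained: you can evaluate it again or continue training, reusing the same calls as before (a quick sketch):

mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=10)
print(f"Loaded model: {mean_reward} +/- {std_reward}")

model.learn(total_timesteps=10000)  # optionally, continue training from where you left off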

[3.2] Check action/observation space

Suppose we try a different model, say, model = SAC("MlpPolicy", env). An error would result.

This is because SAC (Soft Actor-Critic) only works with continuous action spaces, as stated in the official Stable-Baselines3 documentation, while the CartPole environment has a discrete action space.

I’ve compiled the action space constraints into a simple function below:

def is_compatible(env, model_name):
    action_requirements = {
        'A2C':  [gym.spaces.Box, gym.spaces.Discrete],
        'DDPG': [gym.spaces.Box],
        'DQN':  [gym.spaces.Discrete],
        'PPO':  [gym.spaces.Box, gym.spaces.Discrete],
        'SAC':  [gym.spaces.Box],
        'TD3':  [gym.spaces.Box],
    }
    return isinstance(env.action_space, tuple(action_requirements[model_name]))

With this, is_compatible(env,'DQN') returns True, while is_compatible(env,'SAC') returns False.

There are no constraints on observation spaces, for any of the models in sb3.
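
Before pairing a model with an environment, it is worth printing the spaces and running the helper above yourself:

env = gym.make("CartPole-v1")
print("Action space:     ", env.action_space)        # Discrete(2)
print("Observation space:", env.observation_space)   # Box of shape (4,)
print(is_compatible(env, 'DQN'), is_compatible(env, 'SAC'))  # True False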

[4] Wrap gymnasium Env

What if we want to modify the gymnasium environment according to our own specifications? Should we write the code from scratch? Or perhaps look at the source code and make modifications over there?

The answer to both is, no.

It is better to simply wrap the gymnasium objects. Not only is this fast and easy, it also keeps your code readable and trustworthy.

People do not need to scrutinize every line of your code. Instead, they only need to look at the modifications within your wrapper (assuming they are convinced of the correctness of gymnasium).
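
The pattern is always the same: subclass one of gymnasium's wrapper classes and override only what you want to change; everything else is passed through to the wrapped environment. A bare-bones skeleton (the class name is just for illustration):

class MyWrapper(gym.Wrapper):
    def __init__(self, env):
        super().__init__(env)

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        # ... modify obs, reward, or info here ...
        return obs, reward, terminated, truncated, info

The two subsections below use exactly this idea, with gym.ActionWrapper and gym.Wrapper respectively.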

[4.1] Agnostic to box or discrete

In Section 3.2, we saw that SAC is not compatible with CartPole.

There is a workaround for this. In fact, any sb3 model can be used with any environment; we just need a simple wrapper.

class EnvWrapper(gym.ActionWrapper):
    def __init__(self, env, conversion='Box'):
        super().__init__(env)
        self.conversion = conversion
        if conversion == 'Box':
            # Present a continuous action space to the model,
            # even though the underlying environment is Discrete
            self.action_space = gym.spaces.Box(
                low=np.array([-1]), high=np.array([1]), dtype=np.float32
            )
        elif conversion == 'Discrete':
            # Present a discrete action space to the model,
            # even though the underlying environment is Box
            self.num_actions = 9
            self.action_space = gym.spaces.Discrete(self.num_actions)
        else:
            pass

    def action(self, action):
        if self.conversion == 'Box':
            # The model outputs a continuous action in [-1, 1]; map it onto
            # the underlying environment's discrete actions {0, ..., n-1}
            a = float(np.asarray(action).flatten()[0])
            n = self.env.action_space.n
            action = int(np.clip(np.round((a + 1) / 2 * (n - 1)), 0, n - 1))
        elif self.conversion == 'Discrete':
            # The model outputs a discrete action in {0, ..., num_actions - 1};
            # map it onto the underlying environment's continuous action range
            low, high = self.env.action_space.low, self.env.action_space.high
            action = low + (action / (self.num_actions - 1)) * (high - low)

        return action

With this, you can solve an environment that has a discrete action space using a model like SAC, which requires a continuous action space.

wrapped_env = EnvWrapper(env, 'Box')
model = SAC("MlpPolicy", wrapped_env)
model.learn(total_timesteps=10000)

Any sb3 model can be made compatible with any classic gymnasium environment. Don’t just take my word for it. Try out the following.

env_name_list = ['CartPole-v1', 'MountainCar-v0', 'Pendulum-v1', 'Acrobot-v1']
model_name_list = ['A2C', 'DDPG', 'DQN', 'PPO', 'SAC', 'TD3']
model_class_dict = {'A2C': A2C, 'DDPG': DDPG, 'DQN': DQN, 'PPO': PPO, 'SAC': SAC, 'TD3': TD3}

for env_name in env_name_list:
    for model_name in model_name_list:
        env = gym.make(env_name)
        if not is_compatible(env, model_name):
            # Environment and model are not compatible. Wrap the env to suit the model
            if isinstance(env.action_space, gym.spaces.Box):
                env = EnvWrapper(env, 'Discrete')
                print("Box environment wrapped to be compatible with a Discrete model...")
            else:
                env = EnvWrapper(env, 'Box')
                print("Discrete environment wrapped to be compatible with a Continuous model")
        else:
            print("Already compatible")

        model = model_class_dict[model_name]("MlpPolicy", env, verbose=0)
        print("Using %s in %s. The model's action space is %s" % (model_name, env_name, model.action_space))

        model.learn(total_timesteps=100)  # just for testing

Note that the purpose here is just to show that the environments can be wrapped to be made compatible. The performance might not be ideal, but that is not the point here.

The point is to show you that if you understand how sb3 works with gymnasium, you are able to wrap anything for universal compatibility.

[4.2] Reward shaping

Suppose we want to modify a gymnasium environment, to try out reward shaping. For example, you may have played with Lunar Lander and observed that an agent trained with default hyperparameters may hover at the top, in order not to risk a crash.

In this case, we can impose a penalty when the agent persistently stays at the top.

class LunarWrapper(gym.Wrapper):
    def __init__(self, env, max_top_time=100, penalty=-1):
        super().__init__(env)
        self.max_top_time = max_top_time  # penalty kicks in after this many consecutive steps at the top
        self.penalty = penalty            # scales the extra reward (negative values penalize hovering)
        self.penalty_start_step = 20000   # only start shaping after this many total training steps
        self.step_counter = 0

    def reset(self, **kwargs):
        self.time_at_top = 0
        return super().reset(**kwargs)

    def step(self, action):
        obs, reward, terminated, truncated, info = super().step(action)
        self.step_counter += 1

        y_position = obs[1]
        if y_position > 0.5:
            self.time_at_top += 1
        else:
            self.time_at_top = 0  # Reset the counter once the lander comes down

        # Apply the penalty if the lander stays at the top for too long
        if self.time_at_top >= self.max_top_time:
            if self.step_counter >= self.penalty_start_step:
                reward += self.penalty * y_position  # top of the screen is around 1, so the penalty grows with height

        return obs, reward, terminated, truncated, info

Keep in mind that after training with the pseudo-rewards, the agent should be fine-tuned using the actual environment with the original rewards.

env_name = "LunarLander-v2"
wrapped_env = LunarWrapper(gym.make(env_name))

model = DQN(
    "MlpPolicy", wrapped_env,
    buffer_size=50000, learning_starts=1000, train_freq=4, target_update_interval=1000, 
    learning_rate=1e-3, gamma=0.99
)

model.learn(
    total_timesteps=50000,
    callback=EvalCallback(
        wrapped_env, best_model_save_path='./logs/', log_path='./logs/', eval_freq=2000
    )
)

env = gym.make(env_name)   # the original environment, without reward shaping
model = DQN.load("./logs/best_model.zip")
model.set_env(env)

model.learn(
    total_timesteps=20000,
    callback=EvalCallback(
        env, best_model_save_path='./logs/', eval_freq=2000
    )
)

This looks much better!
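
To watch the fine-tuned lander yourself, you can reuse the rendering loop from Section 3.1 (a minimal sketch, assuming the fine-tuned model is still in model):

env = gym.make("LunarLander-v2", render_mode="human")

obs = env.reset()[0]
score = 0
while True:
    action, states = model.predict(obs)
    obs, reward, terminated, truncated, info = env.step(action)
    score += reward
    env.render()
    if terminated or truncated:
        break
print("score: ", score)
env.close()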

[5] Wrapper Over Custom Tasks

In this final section, I will deliver on my 5ᵗʰ promise – learning how to wrap your own custom environments to be compatible with sb3, with minimal changes to your original code, which may follow a different structure.

As learners, we train RL agents to solve well-known benchmark problems. But the industry pays you to solve real problems, not toy problems. If you are employed for your RL expertise, chances are that you have to solve problems that are unique to your company.

However, sb3 and gymnasium still remain your good friends!

For the purpose of illustration, let’s consider the following simple GridWorld.

class SimpleEnv:
    def __init__(self):
        self.min_row, self.max_row = 0, 4
        self.min_col, self.max_col = 0, 4
        self.terminal = [[self.max_row, self.max_col]]
        self.reset()

    def reset(self, random=False):
        if random:
            while True:
                self.cur_state = [np.random.randint(self.max_row + 1), np.random.randint(self.max_col + 1)]
                if self.cur_state not in self.terminal:
                    break
        else:
            self.cur_state = [0,0]
        return self.cur_state

    def transition(self, state, action):
        reward = 0
        if action == 0:
            state[1] += 1   # move right one column
        elif action == 1:
            state[0] += 1   # move down one row
        elif action == 2:
            state[1] -= 1   # move left one column
        elif action == 3:
            state[0] -= 1   # move up one row
        else:
            assert False, "Invalid action"

        if ((state[0] < self.min_row) or (state[1] < self.min_col)
                or (state[0] > self.max_row) or (state[1] > self.max_col)):
            reward = -1

        next_state = np.clip(
            state, [self.min_row, self.min_col], [self.max_row, self.max_col]
        ).tolist()
        if next_state in self.terminal:
            done = True
        else:
            done = False 
        return reward, next_state, done

    def _get_action_dim(self):
        return 4

    def _get_state_dim(self):
        return np.array([5,5])

Notice here that the transition method returns reward, next_state, and done. Stable-Baselines3 will not accept this style.

Do you have to re-code your environment? No!

Instead, we build a simple wrapper around it.

from gymnasium import spaces

class CustomEnv(gym.Env):
    def __init__(self, **kwargs):
        super().__init__()
        self.internal_env = SimpleEnv(**kwargs)
        self.action_space = spaces.Discrete(self.internal_env._get_action_dim())
        self.observation_space = spaces.MultiDiscrete(self.internal_env._get_state_dim())

    def step(self, action):
        reward, next_state, done = self.internal_env.transition(self.internal_env.cur_state, action)
        self.internal_env.cur_state = next_state  # keep the internal state in sync (and within bounds)
        self.count += 1
        truncated = self.count > 50   # cut the episode short if the goal is not reached within 50 steps
        if truncated:
            reward += -100            # large penalty for failing to reach the goal in time
        return np.array(next_state), reward, done, truncated, {}

    def reset(self, random=True, **kwargs):
        self.count = 0
        return np.array(self.internal_env.reset(random=random)), {}

    def render(self, mode="human"):
        pass

    def close(self):
        pass

In the above, we define a method step, which wraps around the original environment’s transition, and returns what sb3 expects.

At the same time, I’ve used this opportunity to demonstrate that we can perform modifications without dissecting the original environment. Here, CustomEnv terminates the episode (with a large penalty) if the goal is not reached in 50 steps.

How do we know if the environment is wrapped properly? First, it has to pass the following basic check.

from stable_baselines3.common.env_checker import check_env

env = CustomEnv()
check_env(env, warn=True)

obs, info = env.reset()
action = env.action_space.sample()
print("Sampled action:", action)

obs, reward, done, truncated, info = env.step(action)
print(obs.shape, reward, done, info)

Next, we can use an sb3 model to train on the wrapped environment. You could also play around with the hyperparameters here, as shown below.

model = DQN(
    "MlpPolicy", env,
    learning_rate=1e-5,
    exploration_fraction=0.5,
    exploration_initial_eps=1.0,
    exploration_final_eps=0.10,
)
model.learn(
    total_timesteps=100000,
    callback=EvalCallback(
        env, best_model_save_path='./logs/', eval_freq=10000
    )
)
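
Once training finishes, a quick rollout shows whether the agent walks to the goal corner without wandering (a small sketch; the variable names are mine):

obs, info = env.reset()
path = [obs.tolist()]
while True:
    action, states = model.predict(obs, deterministic=True)
    obs, reward, done, truncated, info = env.step(action)
    path.append(obs.tolist())
    if done or truncated:
        break
print("Visited states:", path)

mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=10)
print(f"Mean reward: {mean_reward} +/- {std_reward}")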

Conclusion

In this article, you have learnt to set up your own environment to run sb3 and gymnasium. You now have the ability to implement state-of-the-art RL algorithms on any environment of your choice.

Enjoy!

