Learnings from reproducing DQN for Atari games

A raw portrayal of what it’s like to implement this reinforcement learning algorithm from scratch

Dennis Feng
Towards Data Science


Deep learning researchers tend to say that one of the fastest and most effective ways to gain better understanding and practice is to reproduce results from key papers. I decided to give it a try, starting with the seminal paper “Human-level control through deep reinforcement learning” by Google DeepMind, whose Deep Q-Network (DQN) algorithm played classic Atari games (like Pong and Breakout) at human, or even superhuman, levels. More amazingly, its performance generalized to ~50 Atari games, using the same algorithm with identical hyperparameters!

Here, I show a summary of my process, learnings, and results, especially listing out all the bugs I ran into (even the dumb or very specific ones!) — my aim is to show a raw, realistic portrayal of what this project feels like, which past-me would have appreciated seeing at the start. I’ve also open sourced my project on GitHub, for those curious about the detailed implementation: https://github.com/dennischenfeng/dqn. A quick disclaimer: this is not a how-to guide. There are plenty of good guides out there on building DQN for Atari games; I simply wanted to share my experience.

Having worked through this project, I think it wasn’t quite as daunting as I had initially anticipated (it just required sustained persistence), and I hope that readers like you might feel the same after reading! For context: I have professional experience writing software, but I’ve only really dived into studying and practicing deep reinforcement learning in the past several months.

Contents

  1. Introductory material
  2. Part 1: Implementing first draft of code
  3. Part 2: Test on easy benchmarks
  4. Part 3: Test on Atari environments
  5. Part 4: Test on easy benchmarks (revisited)
  6. Part 5: Test on Atari environments (revisited)
  7. Conclusion
  8. Appendix: Other bugs I ran into

Introductory material

Before I started this project, I had done some reading and free online coursework. For this, I’d recommend:

Other excellent resources that spurred me into starting this project:

Part 1: Implementing first draft of code

After reading the paper (especially the Supplementary Information section, which contains the pseudocode and hyperparameters used), I first translated the paper’s pseudocode into code. Here’s a rough conceptual breakdown of the DQN algorithm, following the pseudocode in the paper (a minimal code sketch follows the list):

  1. Execute an action in the environment (Atari game). With probability ε (epsilon), the action is randomly selected. Otherwise the “best” action is selected, i.e. the action that maximizes the estimated value (expected future reward) according to the current action-value estimator Q. Note that ε is slowly reduced (annealed) over the course of training, to reduce the amount of exploration (manifested as random actions) in later phases of training.
  2. Receive observation (image snapshot of game screen) and reward (value increase of game score) and store that data in the “replay memory.” We will repeatedly sample from this replay memory in order to update Q, similar to how we as people pull from memory to learn and improve our decision making.
  3. Sample data from replay memory, roughly treating it as labeled training data (the observation and action are the input; the reward, plus the discounted value estimate of the next observation, forms the target), and take a gradient descent step to update Q.
  4. Repeat.
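
To make that breakdown concrete, here’s a heavily simplified sketch of the loop in PyTorch (using the old gym step API). Everything here (q_net, memory, epsilon_schedule, compute_loss, the various constants) is a hypothetical placeholder rather than the actual code in my repo:

```python
import random
import torch

# Hypothetical pieces, not the actual classes from my repo:
#   q_net, target_net  - PyTorch networks estimating Q(s, a)
#   memory             - replay memory with .store(...) and .sample(batch_size)
#   env                - the (preprocessed) Atari environment
#   epsilon_schedule, compute_loss, optimizer, and the constants below

obs = env.reset()
for step in range(total_steps):
    # 1. Epsilon-greedy action selection, with epsilon annealed over training
    epsilon = epsilon_schedule(step)
    if random.random() < epsilon:
        action = env.action_space.sample()
    else:
        with torch.no_grad():
            q_values = q_net(torch.as_tensor(obs).unsqueeze(0))
        action = int(q_values.argmax(dim=1).item())

    # 2. Step the environment and store the transition in replay memory
    next_obs, reward, done, _ = env.step(action)
    memory.store(obs, action, reward, next_obs, done)
    obs = env.reset() if done else next_obs

    # 3. Sample a minibatch and take one gradient descent step on Q
    #    (the loss computation is spelled out in Part 4)
    batch = memory.sample(batch_size)
    loss = compute_loss(q_net, target_net, batch)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Periodically sync the target network with the online network
    if step % target_update_interval == 0:
        target_net.load_state_dict(q_net.state_dict())
```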

As I was writing the code, it was crucial to modularize the independent pieces, de-coupling their implementations from the main algorithm. For example, for the annealed epsilon (reducing exploration through training), the replay memory (storage container to play back past transitions), and the preprocessing of observations (frame skipping, grayscale, cropping, stacking frames), I created stand-alone functions/modules for each of these to reduce interactions and complexity in the DQN implementation.
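
For instance, the annealed epsilon is simple enough to live in a tiny stand-alone function. Here’s a sketch of the linear schedule described in the paper (anneal from 1.0 to 0.1 over the first million steps, then hold); my actual implementation differs in the details:

```python
def annealed_epsilon(step, eps_start=1.0, eps_end=0.1, anneal_steps=1_000_000):
    """Linearly anneal epsilon from eps_start down to eps_end over anneal_steps,
    then hold it at eps_end for the rest of training."""
    fraction = min(step / anneal_steps, 1.0)
    return eps_start + fraction * (eps_end - eps_start)
```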

Writing unit tests in lockstep for every non-trivial part and feature also ended up saving me lots of debugging time throughout this whole project (as expected), so that was well worth it too.
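
To give a flavor of what those tests looked like, here are a couple of minimal pytest checks against the epsilon sketch above (assuming annealed_epsilon is importable from your module):

```python
import pytest

def test_annealed_epsilon_endpoints():
    # Starts at 1.0, reaches the 0.1 floor at the end of annealing, then stays there
    assert annealed_epsilon(0) == pytest.approx(1.0)
    assert annealed_epsilon(1_000_000) == pytest.approx(0.1)
    assert annealed_epsilon(5_000_000) == pytest.approx(0.1)

def test_annealed_epsilon_never_increases():
    # Exploration should only ever shrink as training progresses
    values = [annealed_epsilon(s) for s in range(0, 1_000_001, 100_000)]
    assert all(a >= b for a, b in zip(values, values[1:]))
```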

Part 2: Test on easy benchmarks

Putting the code to the test, I attempted to train the model on the CartPole environment (env), from OpenAI’s gym.

Then after some learning rate tuning, I started to see some signs of life: episode return was on the order of 100 to 200, out of the maximum 500.

Three training runs on the CartPole environment. (Left) X-axis is number of env steps, y-axis is mean episode return (score) on 10 evaluation episodes. Note: maximum episode return on CartPole is 500. (Right) X-axis is also number of env steps, y-axis is training loss. (Image source: author)

I couldn’t get a trained model to hit 500 just by manually trying a handful of hyperparameter variations, so I turned to Optuna for an overnight automated 100-trial hyperparameter study. It seemed to find hyperparameter configurations that achieved the max score, so I was satisfied and moved on to the next step. (Unfortunately, it wasn’t until later that I found out this was a fluke; see Part 4.)
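
For reference, the Optuna setup is only a few lines. In this sketch, train_and_evaluate is a hypothetical stand-in for my training function (returning the mean evaluation episode return), and the search ranges are just illustrative:

```python
import optuna

def objective(trial):
    # Let Optuna propose hyperparameters within (hopefully) sensible ranges
    lr = trial.suggest_float("learning_rate", 1e-5, 1e-2, log=True)
    batch_size = trial.suggest_categorical("batch_size", [32, 64, 128])
    target_update = trial.suggest_int("target_update_interval", 100, 10_000)

    # Hypothetical training function returning mean evaluation episode return
    return train_and_evaluate(lr=lr, batch_size=batch_size,
                              target_update_interval=target_update)

study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=100)
print(study.best_params)
```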

Parallel coordinate plot (generated by Optuna) for 100 training runs with varying hyperparameters. Each curve from left to right represents one training run, and intersections with the vertical axes indicate the selected values for each hyperparameter on that run. As shown by the color bar, darker blue curves are runs with higher objective value (episode return). (Image source: author)

Part 3: Test on Atari environments

From looking at DQN’s training curves over a variety of Atari environments (see Appendix of Rainbow DQN paper), I chose Pong and Breakout to start with because of their sharply positive slope early in training, and because I understood them the best from playing them in my younger days.

[Bug] Immediately in my first run on Breakout, my evaluation step was stalling progress → it turns out Breakout requires you to press the “FIRE” button to spawn the ball, or else the game sits idle! Initially, I created a wrapper env that presses “FIRE” at the start, but later I created one that instead terminates the game after idleness (because I wanted the model to learn something as crucial as pressing “FIRE”).
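
Here’s a sketch of what an idleness-terminating wrapper can look like (using the old gym API that returns a 4-tuple from step). It treats “no reward for too many consecutive steps” as idleness; the threshold is arbitrary, and my actual wrapper differs in the details:

```python
import gym

class TerminateOnIdle(gym.Wrapper):
    """End the episode if the score hasn't changed for too many steps,
    e.g. because the agent never pressed FIRE to launch the ball."""

    def __init__(self, env, max_idle_steps=1000):
        super().__init__(env)
        self.max_idle_steps = max_idle_steps
        self.idle_steps = 0

    def reset(self, **kwargs):
        self.idle_steps = 0
        return self.env.reset(**kwargs)

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        self.idle_steps = 0 if reward != 0 else self.idle_steps + 1
        if self.idle_steps >= self.max_idle_steps:
            done = True
        return obs, reward, done, info
```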

At this point my model was actually training, but at a snail’s pace: something like 100k steps per hour. To give context, the original paper trains each model for 200M environment steps (4 ⨉ 50M update steps)! Clearly I needed to speed things up, and it was reasonable to believe I had some inefficiencies buried deep in my code. I turned to cProfile, a handy built-in Python module that profiles the runtime of function/method calls.

Example cProfile output on a training run. The “tottime” column shows the compute duration (in seconds) spent inside the particular function indicated in the last column. (Image source: author)
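
For anyone who hasn’t used it, profiling a chunk of the training loop looks roughly like this, where train_for_n_steps is a hypothetical stand-in for your training entry point:

```python
import cProfile
import pstats

# Profile a short training run and dump the stats to a file
cProfile.run("train_for_n_steps(10_000)", "profile_output")

# Print the 20 most expensive calls, sorted by time spent inside each function
stats = pstats.Stats("profile_output")
stats.sort_stats("tottime").print_stats(20)
```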

[Bug] Using cProfile, I found an obscure quirk: converting a list of large NumPy arrays directly into a PyTorch tensor was inefficient → converting the list to a single NumPy array first, and then to a PyTorch tensor, was much faster.
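
In code, the fix is essentially a one-liner (and I believe newer PyTorch versions now warn about this exact pattern):

```python
import numpy as np
import torch

# A minibatch of 32 fake 84x84 grayscale frames, as a list of arrays
frames = [np.random.randint(0, 256, size=(84, 84), dtype=np.uint8) for _ in range(32)]

# Slow: builds the tensor element by element from a Python list of arrays
batch_slow = torch.tensor(frames)

# Much faster: stack into one contiguous NumPy array first, then convert
batch_fast = torch.as_tensor(np.array(frames))
```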

For a while, I didn’t see much performance gain. I tried to make the env potentially easier by giving it a parametrized fixed reward at every step (like in CartPole), I increased the minibatch size to reduce stochasticity, and I tried training for more env steps by repeatedly saving the model and loading it at the beginning of the next training session (I had to do this because my training sessions were limited to a given number of hours).

Still, the performance was not much better than random, even after 5 to 10M env steps. A bit disheartening, but it meant I had likely made a crucial mistake somewhere, so I went back to easier environments to analyze performance with a faster feedback loop.

Part 4: Test on easy benchmarks (revisited)

Uh-oh. I ran 3 repeated training runs on CartPole using the previously “optimized” hyperparameters, and got this:

Three (smoothed) training runs on CartPole. X-axis is number of env steps, y-axis is mean episode return on 10 evaluation episodes. (Image source: author)

Looks like the run-to-run variance was far too high; sometimes the episode return stayed below 50 for the entire training run, which is quite poor. Turns out my earlier benchmark tests were a fluke!

[Big mistake] I skipped repeatability/variance studies during my initial benchmarking because I was too excited to move on to Atari envs. This cost me weeks of mostly wasted effort.

By carefully sifting through my code, making plentiful use of cProfile to identify inefficiencies, and comparing against a known-good open-source implementation of DQN (from Stable Baselines 3), I squashed a handful of bugs.

[Bug] There was a factor of 2 wrong in my loss function, because I had implemented the smoothing incorrectly → I ended up using PyTorch’s smooth_l1_loss.

[Bug] I was sampling the replay memory using np.random.choice (with replacement), which cProfile told me was fairly inefficient → switching to np.random.randint brought the runtime down significantly.

[Bug] I was accidentally running a backward pass through the target network (which is supposed to stay constant for long periods) on every gradient step! → wrapping the target network’s forward pass in torch.no_grad() (which suppresses gradient calculations) eliminated the unnecessary backward passes.

[Bug] A showstopper bug: I had accidentally written done instead of (1 - done), where done is either 0 or 1, which effectively gave the Q update the exact opposite information about whether an episode had terminated. Oof.
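
Putting those fixes together, the core gradient step ended up looking roughly like the sketch below (variable names and the memory.get_batch accessor are illustrative, not my exact code): the minibatch indices come from np.random.randint, the target computation sits inside torch.no_grad(), the loss is PyTorch’s smooth_l1_loss, and the (1 - dones) factor zeroes out the bootstrap term on terminal transitions.

```python
import numpy as np
import torch
import torch.nn.functional as F

def gradient_step(q_net, target_net, memory, optimizer, batch_size=32, gamma=0.99):
    # Sample minibatch indices without np.random.choice (much faster for large memories)
    idx = np.random.randint(0, len(memory), size=batch_size)
    # Hypothetical accessor; assume tensors on the right device, actions as int64,
    # rewards and dones as floats
    obs, actions, rewards, next_obs, dones = memory.get_batch(idx)

    # Q(s, a) for the actions actually taken
    q_values = q_net(obs).gather(1, actions.unsqueeze(1)).squeeze(1)

    # Bootstrapped target: no_grad keeps gradients out of the target network,
    # and (1 - dones) drops the bootstrap term for terminal transitions
    with torch.no_grad():
        max_next_q = target_net(next_obs).max(dim=1).values
        targets = rewards + gamma * (1 - dones) * max_next_q

    loss = F.smooth_l1_loss(q_values, targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```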

Finally, after removing all those bugs, the model could solve CartPole reliably:

(Left) Mean of 10 training runs on CartPole. Error ribbons, indicating 1 standard error, are in red. (Middle) A representative training run, where x-axis is number of env steps, y-axis is mean episode return over 100 evaluation episodes. (Right) Gameplay of a fully trained agent, whose goal is to move the cart so the pole stays balanced without toppling. (Image and gif source: author)

The CartPole environment (CartPole-v1) is considered solved when an agent reaches an average episode return of 475 or higher over 100 consecutive episodes, which our model achieves robustly (every run hits 500 at some point during training).

One thing I learned from benchmarking DQN on CartPole: don’t be too concerned about initial unstable drops in performance, because the agent is still able to solve the environment given enough training. These finicky performance drops are perhaps due to the distributional shift that occurs during training, e.g. early on, as the model learns rapidly, the observations it sees (the test distribution) deviate from the training distribution.

And for diversity’s sake, I tested on FrozenLake too:

(Left) Mean of 10 training runs on FrozenLake. Error ribbons, indicating 1 standard error, are in red. (Middle) A representative training run, where x-axis is number of env steps, y-axis is mean episode return over 100 evaluation episodes. (Right) Gameplay of a fully trained agent, whose goal is to navigate from the start position S to the goal position G by walking through frozen spaces F without falling into hole spaces H. The catch is that the floor is slippery and the actual step direction can be randomly rotated 90° from the intended direction. The agent’s input direction for every step is indicated at the top of the screen. (Image and gif source: author)

FrozenLake (FrozenLake-v0) is considered solved when an agent surpasses an average return threshold of 0.78. And it looks like our model reaches this as well (it climbs above the threshold at some point during training)!

Part 5: Test on Atari environments (revisited)

Testing with Pong, Breakout, and Freeway (another Atari game with sharply positive initial slope on its training curve), I was elated to find that the model was finally able to learn intelligent gameplay!

(Top) Three training runs on Pong, where x-axis is number of env steps and y-axis is episode return of a single evaluation episode. (Bottom) Gameplay of fully trained agent (green player), whose goal is to hit the ball past the opponent’s paddle. Here, I added a small amount of stochasticity (10% chance of random action) to show how the agent deals with a more varied range of scenarios. Without the added stochasticity, the agent beats the opponent in a very similar way each time. (Image and gif source: author)
(Top) Three training runs on Freeway, where x-axis is number of env steps and y-axis is episode return of a single evaluation episode. (Bottom) Gameplay of fully trained agent (left-side player), whose goal is to direct the chicken across the road as quickly as possible while avoiding cars. (Image and gif source: author)

As for Breakout, the training curve shows it’s learning beyond random behavior, but it seems to require more steps than Pong and Freeway did to achieve human performance. I’m currently in the process of using cloud compute (Google Compute Engine) to initiate longer runs, and I’m planning to update this post or start a new one after I finish!

Conclusion

Overall, this was definitely a fun side project, and I’d echo the recommendation to folks interested in really getting into the nitty-gritty engineering of reinforcement learning; I learned a lot myself while doing it. Thanks for reading!

Appendix: Other bugs I ran into

For completeness, here I show the rest of the bugs I came across, which may be a bit more specific to my own experience and perhaps less generalizable to the reader.

[Bug] For my very first run, the (training) episode return was hardly increasing at all even after 200k env steps → I found that I had forgotten to increment the counter that triggers the target Q-network update.

[Bug] Early in the process, I noticed that my memory usage was through the roof, ~10x higher than the expected number of bytes in the replay memory! → aha, the NumPy array was defaulting to datatype float64 instead of the intended uint8.
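
The difference is easy to see with a quick back-of-the-envelope sketch: a million 84×84 uint8 frames take roughly 7 GB, while the float64 default is 8x that.

```python
import numpy as np

capacity, height, width = 1_000_000, 84, 84

# Default dtype is float64: roughly 56 GB for a million 84x84 frames (far too much)
# frames = np.zeros((capacity, height, width))

# Storing raw pixels as uint8 instead: roughly 7 GB
frames = np.zeros((capacity, height, width), dtype=np.uint8)
print(frames.nbytes / 1e9, "GB")
```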

[Bug] During an effort to make the env easier to learn in Part 3, I realized I was training on a full game of Breakout (5 lives) → re-reading the paper, I found that terminating after 1 life was the way to go.

[Bug] A fairly obvious one: I learned I needed to manually send the network parameters and input tensors to the GPU in order to utilize it, i.e. calling .to(torch.device("cuda")). Additionally, I found a speedup by running the image preprocessing (rescaling, grayscale, cropping) on the GPU as well.
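
A minimal, self-contained sketch of the pattern (with a toy network, not my actual Q-network):

```python
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical tiny network and input batch, just to show the pattern
q_net = nn.Sequential(nn.Linear(4, 64), nn.ReLU(), nn.Linear(64, 2)).to(device)
obs_batch = torch.randn(32, 4).to(device)  # inputs must live on the same device

q_values = q_net(obs_batch)  # runs on the GPU if one is available
```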
