Learnings from reproducing DQN for Atari games
A raw portrayal of what it’s like to implement this reinforcement learning algorithm from scratch
Deep learning researchers tend to say that one of the fastest and most effective ways to gain better understanding and practice is to reproduce results from key papers. I decided to give it a try, starting with the seminal paper “Human-level control through deep reinforcement learning” by Google DeepMind, whose Deep Q-Network (DQN) algorithm played classic Atari games (like Pong and Breakout) at human, or even superhuman, levels. More amazingly, its performance generalized to ~50 Atari games, using the same algorithm with identical hyperparameters!
Here, I show a summary of my process, learnings, and results, especially listing out all the bugs I ran into (even the dumb or very specific ones!). My aim is to show a raw, realistic portrayal of what this project feels like, which past-me would have appreciated seeing at the start. I’ve also open-sourced my project on GitHub, for those curious about the detailed implementation: https://github.com/dennischenfeng/dqn. A quick disclaimer: this is not a how-to guide. There are plenty of good guides out there on building DQN for Atari games; I simply wanted to share my experience.
Having worked through this project, I found it wasn’t quite as daunting as I had initially anticipated (it just required persistence), and I hope that readers like you might feel the same after reading! To give context, I have professional experience writing software, but I’ve only really dived into studying and practicing deep reinforcement learning in the past several months.
Contents
- Introductory material
- Part 1: Implementing first draft of code
- Part 2: Test on easy benchmarks
- Part 3: Test on Atari environments
- Part 4: Test on easy benchmarks (revisited)
- Part 5: Test on Atari environments (revisited)
- Conclusion
- Appendix: Other bugs I ran into
Introductory material
Before I started this project, I had done some reading and free online coursework. For this, I’d recommend:
- fast.ai’s “Practical Deep Learning for Coders” course
- OpenAI Spinning Up’s short 3-page “Introduction to RL” series
- UC Berkeley’s video lectures for “CS 285: Deep Reinforcement Learning”
Other excellent resources that spurred me into starting this project:
- Matthew Rahtz’s “Lessons Learned Reproducing a Deep Reinforcement Learning Paper”. Fantastic read about his learnings from implementing a more extensive algorithm than I’m doing here.
- John Schulman’s “Nuts and Bolts of Deep RL Research” video. Good advice and intuitions about deep RL, and even gives some tips for implementing DQN specifically.
Part 1: Implementing first draft of code
After reading the paper (especially the Supplementary Information section which contains the pseudocode and hyperparameters used), I first translated the paper’s pseudocode into code. Here’s a rough conceptual breakdown of the DQN algorithm (following the pseudocode in the paper):
- Execute an action in the environment (Atari game). With probability ε (epsilon), the action is randomly selected. Otherwise the “best” action is selected, i.e. the action with the highest estimated value (expected future reward) according to the current action-value estimator Q. Note that ε is slowly reduced (annealed) over the course of training, to reduce the amount of exploration (manifested by random actions) in later phases of training.
- Receive observation (image snapshot of game screen) and reward (value increase of game score) and store that data in the “replay memory.” We will repeatedly sample from this replay memory in order to update Q, similar to how we as people pull from memory to learn and improve our decision making.
- Sample data from replay memory, roughly treating it as labeled training data (observation and action are the input; the reward plus the discounted estimated value of the next observation is the label), and take a gradient descent step to update Q.
- Repeat.
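To make the exploration piece concrete, here is a minimal sketch of a linearly annealed ε and ε-greedy action selection. The defaults roughly follow the paper’s schedule (1.0 annealed to 0.1 over the first million frames, then held fixed); the function names are hypothetical, and `q_values` is just any sequence of per-action value estimates.

```python
import random


def linear_epsilon(step: int, start: float = 1.0, end: float = 0.1,
                   anneal_steps: int = 1_000_000) -> float:
    """Linearly anneal epsilon from `start` to `end` over `anneal_steps`, then hold at `end`."""
    fraction = min(step / anneal_steps, 1.0)
    return start + fraction * (end - start)


def epsilon_greedy(q_values, epsilon: float, rng: random.Random) -> int:
    """With probability epsilon pick a random action, otherwise the argmax of q_values."""
    if rng.random() < epsilon:
        return rng.randrange(len(q_values))
    return max(range(len(q_values)), key=lambda a: q_values[a])


# Usage: early steps explore almost always; late steps are mostly greedy.
rng = random.Random(0)
for step in (0, 500_000, 2_000_000):
    eps = linear_epsilon(step)
    action = epsilon_greedy([0.1, 0.5, 0.2], eps, rng)
```

Keeping the schedule as a pure function of the step count makes it trivial to unit test in isolation.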
As I was writing the code, it was crucial to modularize the independent pieces, de-coupling their implementations from the main algorithm. For example, for the annealed epsilon (reducing exploration through training), the replay memory (storage container to play back past transitions), and the preprocessing of observations (frame skipping, grayscale, cropping, stacking frames), I created stand-alone functions/modules for each of these to reduce interactions and complexity in the DQN implementation.
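As an example of one such stand-alone piece, here’s a sketch of a frame-stacking helper that keeps the k most recent preprocessed frames (the paper stacks 4). The class name and interface are hypothetical; grayscale conversion and cropping would live in a separate function.

```python
from collections import deque

import numpy as np


class FrameStack:
    """Keep the k most recent preprocessed frames and return them stacked."""

    def __init__(self, k: int = 4):
        self.k = k
        self.frames = deque(maxlen=k)  # oldest frame is dropped automatically

    def reset(self, first_frame: np.ndarray) -> np.ndarray:
        # At the start of an episode, fill the stack with copies of the first frame.
        self.frames.clear()
        for _ in range(self.k):
            self.frames.append(first_frame)
        return self.observation()

    def step(self, frame: np.ndarray) -> np.ndarray:
        self.frames.append(frame)
        return self.observation()

    def observation(self) -> np.ndarray:
        # Shape (k, H, W), oldest frame first.
        return np.stack(list(self.frames), axis=0)


stacker = FrameStack(k=4)
obs = stacker.reset(np.zeros((84, 84), dtype=np.uint8))
obs = stacker.step(np.ones((84, 84), dtype=np.uint8))
```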
I also wrote unit tests in lockstep for every non-trivial part and feature, which (as expected) saved me lots of debugging time throughout the whole project, so that was well worth it too.
Part 2: Test on easy benchmarks
Putting the code to the test, I attempted to train the model on the CartPole environment (env) from OpenAI’s gym.
Then after some learning rate tuning, I started to see some signs of life: episode return was on the order of 100 to 200, out of the maximum 500.
I couldn’t get a trained model to hit 500 just by manually trying a handful of hyperparameter variations, so I turned to Optuna for an overnight, automated 100-trial hyperparameter study. It seemed to find configurations that achieved the max score, so I was satisfied and moved on to the next step. (Unfortunately, it wasn’t until later that I found out this was a fluke; see Part 4.)
Part 3: Test on Atari environments
From looking at DQN’s training curves over a variety of Atari environments (see Appendix of Rainbow DQN paper), I chose Pong and Breakout to start with because of their sharply positive slope early in training, and because I understood them the best from playing them in my younger days.
[Bug] Immediately in my first run on Breakout, my evaluation step stalled → it turns out Breakout requires you to press the “FIRE” button to spawn the ball, or else the game sits idle! Initially, I created a wrapper env that presses “FIRE” to start, but later I created one that instead terminates the game after a period of idleness (because I wanted the model to learn something as crucial as pressing “FIRE”).
At this point my model was actually training, but at a snail’s pace: something like 100k steps per hour. To give context, the original paper trains each model for 200M environment steps (4 ⨉ 50M update steps)! Clearly I needed to speed it up, and it was reasonable to believe that I had some inefficiencies buried deep in my code. I turned to cProfile, a handy built-in Python module that profiles the runtime of function and method calls.
[Bug] Using cProfile, I found an obscure quirk: converting a list of large numpy arrays into a PyTorch tensor was inefficient → converting the list to a single numpy array first, and then to a PyTorch tensor, was much faster.
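A minimal sketch of the fix, using a hypothetical minibatch of 32 stacked 4x84x84 uint8 observations. Passing the list straight to `torch.tensor` makes PyTorch treat it as a nested Python sequence (recent PyTorch versions also emit a warning about this); building one contiguous ndarray first avoids that.

```python
import numpy as np
import torch

# Hypothetical minibatch: 32 stacked observations of shape (4, 84, 84).
frames = [np.random.randint(0, 256, size=(4, 84, 84), dtype=np.uint8) for _ in range(32)]

# Slow: torch.tensor(frames) walks the list as a nested sequence.
# Fast: stack into one contiguous ndarray, then wrap it as a tensor.
batch = torch.as_tensor(np.array(frames))
```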
For a while, I didn’t see much improvement. I tried making the env potentially easier by adding a parametrized fixed reward at every step (like in CartPole), increasing the minibatch size to reduce stochasticity, and training for more env steps by repeatedly saving the model and loading it at the beginning of the next training session (I had to do this because my training sessions were limited to a fixed number of hours).
Still, the performance was not much better than random, even after 5 to 10M env steps. A bit disheartening, but this meant there was likely a crucial mistake somewhere, so I went back to easier environments to analyze performance with a faster feedback loop.
Part 4: Test on easy benchmarks (revisited)
Uh-oh. I ran 3 repeated training runs on CartPole using the previously “optimized” hyperparameters, and got this:
Looks like the run-to-run variance was far too high; sometimes the episode return stays below 50 for the entire training run, which is quite poor. Turns out my earlier benchmark tests were a fluke!
[Big mistake] I skipped repeatability/variance studies during my initial benchmarking because I was too excited to move on to Atari envs. This cost me weeks of mostly wasted effort.
By carefully sifting through my code, making plentiful use of cProfile to identify inefficiencies, and comparing against a known-good open-source implementation of DQN (from Stable Baselines 3), I squashed a handful of bugs.
[Bug] My loss function was off by a factor of 2 because I implemented the smoothing incorrectly → I ended up using PyTorch’s smooth_l1_loss.
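For reference, the paper’s error clipping corresponds to a Huber loss with δ = 1, which is what `smooth_l1_loss` computes with its default `beta=1.0`. The hand-written function below is a hypothetical illustration of where a factor of 2 can sneak in (the 0.5 on the quadratic branch); written correctly, it matches the built-in.

```python
import torch
import torch.nn.functional as F


def huber_by_hand(q_pred: torch.Tensor, q_target: torch.Tensor, delta: float = 1.0) -> torch.Tensor:
    """Huber loss: quadratic for |error| <= delta, linear beyond, averaged over the batch."""
    error = q_pred - q_target
    abs_error = error.abs()
    quadratic = 0.5 * error ** 2                 # dropping this 0.5 doubles the loss near zero
    linear = delta * (abs_error - 0.5 * delta)   # keeps the two branches continuous at delta
    return torch.where(abs_error <= delta, quadratic, linear).mean()


q_pred = torch.tensor([0.2, 1.5, -3.0])
q_target = torch.zeros(3)
builtin = F.smooth_l1_loss(q_pred, q_target)  # default beta=1.0 matches delta=1.0
```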
[Bug] I was sampling the replay memory using np.random.choice (with replacement), which cProfile told me was fairly inefficient → switching to np.random.randint brought runtime down significantly.
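Here’s a minimal sketch of that sampling pattern, assuming the replay memory stores each field in its own preallocated numpy array (the array names and sizes are hypothetical): draw integer indices with `np.random.randint`, then gather the whole minibatch with fancy indexing.

```python
import numpy as np

capacity = 1_000   # hypothetical replay memory capacity
num_stored = 600   # hypothetical number of transitions written so far
batch_size = 32
obs_dim = 4        # e.g. CartPole's observation size

observations = np.zeros((capacity, obs_dim), dtype=np.float32)
actions = np.zeros(capacity, dtype=np.int64)
rewards = np.zeros(capacity, dtype=np.float32)
next_observations = np.zeros((capacity, obs_dim), dtype=np.float32)
dones = np.zeros(capacity, dtype=np.float32)

# Draw indices uniformly (with replacement) from the filled part of the memory.
idx = np.random.randint(0, num_stored, size=batch_size)

# One vectorized gather per field.
batch = (observations[idx], actions[idx], rewards[idx], next_observations[idx], dones[idx])
```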
[Bug] I was accidentally running a backward pass through the target network (which is supposed to stay constant for long periods) at every gradient step! → using torch.no_grad (which suppresses gradient calculations) during target network operations eliminated the unnecessary backward passes.
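The sketch below shows the pattern with a tiny hypothetical MLP standing in for the Q-network and random tensors standing in for a sampled minibatch: only the online network’s Q(s, a) needs gradients, so the whole target computation goes inside `torch.no_grad()`.

```python
import torch
import torch.nn as nn


def make_q_net(obs_dim: int, n_actions: int) -> nn.Module:
    # Hypothetical small MLP; the paper uses a conv net on stacked frames.
    return nn.Sequential(nn.Linear(obs_dim, 64), nn.ReLU(), nn.Linear(64, n_actions))


torch.manual_seed(0)
q_net = make_q_net(4, 2)
target_net = make_q_net(4, 2)
target_net.load_state_dict(q_net.state_dict())

gamma, batch_size = 0.99, 8
obs = torch.randn(batch_size, 4)
actions = torch.randint(0, 2, (batch_size,))
rewards = torch.randn(batch_size)
next_obs = torch.randn(batch_size, 4)
dones = torch.zeros(batch_size)

# Q(s, a) for the actions actually taken: this side needs gradients.
q_taken = q_net(obs).gather(1, actions.unsqueeze(1)).squeeze(1)

# Target side: no graph is recorded, so backward() never reaches target_net.
with torch.no_grad():
    next_q_max = target_net(next_obs).max(dim=1).values
    td_target = rewards + gamma * (1 - dones) * next_q_max

loss = nn.functional.smooth_l1_loss(q_taken, td_target)
loss.backward()
```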
[Bug] A showstopper bug: I had accidentally written done instead of (1 - done), where done is either 0 or 1, which effectively gave the Q update the complete opposite information about whether an episode had terminated. Oof.
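A two-transition toy example (made-up numbers) shows how much the mask matters: with done in place of (1 - done), the target bootstraps from the next state’s value only when the episode has ended, which is exactly backwards.

```python
import torch

gamma = 0.99
rewards = torch.tensor([1.0, 1.0])
next_q_max = torch.tensor([10.0, 10.0])
dones = torch.tensor([0.0, 1.0])  # the second transition ends the episode

# Buggy: bootstraps only from terminal states (values ≈ [1.0, 10.9]).
buggy = rewards + gamma * dones * next_q_max
# Correct: bootstraps only from non-terminal states (values ≈ [10.9, 1.0]).
correct = rewards + gamma * (1 - dones) * next_q_max
```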
Finally, after removing all those bugs, the model could solve CartPole reliably:
The CartPole environment (CartPole-v1) is considered solved when an agent can reach an average episode return of 475 or higher over 100 consecutive episodes, which our model achieves robustly (every run hits 500 at some point within the training).
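A small helper for checking this criterion from a list of episode returns might look like the following. The function name is hypothetical, and the defaults (475 over 100 consecutive episodes) come from the definition above; for FrozenLake, the same helper could be called with `threshold=0.78`, assuming the same 100-episode window.

```python
import numpy as np


def is_solved(episode_returns, threshold: float = 475.0, window: int = 100) -> bool:
    """True if some run of `window` consecutive episodes averages at least `threshold`."""
    returns = np.asarray(episode_returns, dtype=np.float64)
    if returns.size < window:
        return False
    # Moving average over every run of `window` consecutive episodes.
    moving_avg = np.convolve(returns, np.ones(window) / window, mode="valid")
    return bool((moving_avg >= threshold).any())


# Usage: a poor start followed by 100 strong episodes still counts as solved.
history = [20.0] * 50 + [480.0] * 100
solved = is_solved(history)
```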
One thing I learned from benchmarking DQN on CartPole: don’t be too concerned about early, unstable drops in performance, because the agent is still able to solve the environment given enough training. These finicky drops are perhaps due to the distributional shift that occurs during training: e.g., early on, as the model learns rapidly, the observations it encounters (test distribution) deviate from the training distribution.
And for diversity’s sake, I tested on FrozenLake too:
FrozenLake (FrozenLake-v0) is considered solved when an agent surpasses an average return threshold of 0.78. And it looks like our model reaches this as well (above threshold at some point during training)!
Part 5: Test on Atari environments (revisited)
Testing with Pong, Breakout, and Freeway (another Atari game with sharply positive initial slope on its training curve), I was elated to find that the model was finally able to learn intelligent gameplay!
As for Breakout, the training curve shows it’s learning beyond random behavior, but it seems to require more steps than Pong and Freeway did to achieve human performance. I’m currently in the process of using cloud compute (Google Compute Engine) to initiate longer runs, and I’m planning to update this post or start a new one after I finish!
Conclusion
Overall, this was definitely a fun side project, and since I learned a lot doing it, I’d recommend it to anyone interested in getting into the nitty-gritty engineering of reinforcement learning. Thanks for reading!
Appendix: Other bugs I ran into
For completeness, here I show the rest of the bugs I came across, which may be a bit more specific to my own experience and perhaps less generalizable to the reader.
[Bug] For my very first run, the (training) episode return was hardly increasing at all even after 200k env steps → I found that I forgot to increment the counter that updated the target Q network.
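For illustration, here’s a sketch of a periodic target-network sync driven by a gradient-step counter. The class is hypothetical; the default period of 10,000 parameter updates is the target network update frequency listed in the paper’s hyperparameters.

```python
import torch
import torch.nn as nn


class TargetSync:
    """Copy online Q-network weights into the target network every `period` gradient steps."""

    def __init__(self, q_net: nn.Module, target_net: nn.Module, period: int = 10_000):
        self.q_net = q_net
        self.target_net = target_net
        self.period = period
        self.num_updates = 0  # gradient steps taken so far

    def on_gradient_step(self) -> bool:
        """Call once after every optimizer step; returns True when a sync happened."""
        self.num_updates += 1  # advance the counter; the schedule below depends on it
        if self.num_updates % self.period == 0:
            self.target_net.load_state_dict(self.q_net.state_dict())
            return True
        return False


torch.manual_seed(0)
q_net = nn.Linear(4, 2)
target_net = nn.Linear(4, 2)  # separately initialized, so weights differ at first
sync = TargetSync(q_net, target_net, period=3)
synced = [sync.on_gradient_step() for _ in range(6)]  # syncs on steps 3 and 6
```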
[Bug] Early in the process, I noticed that my memory usage was through the roof, ~10x higher than the expected number of bytes in the replay memory! → aha, the numpy array was defaulting to datatype float64 instead of the intended uint8.
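The difference is easy to check with `ndarray.nbytes`. Here’s a sketch with a small, hypothetical buffer of 84x84 frames: `np.zeros` defaults to float64 (8 bytes per pixel), while uint8 uses 1 byte, so the default costs 8x the memory. Converting only the sampled minibatch to float right before the forward pass is one common way to keep storage compact.

```python
import numpy as np

capacity = 1_000        # hypothetical number of stored frames (kept small here)
frame_shape = (84, 84)  # preprocessed frame size used in the paper

default_buf = np.zeros((capacity, *frame_shape))                # dtype defaults to float64
uint8_buf = np.zeros((capacity, *frame_shape), dtype=np.uint8)  # 1 byte per pixel

print(default_buf.dtype, default_buf.nbytes)  # float64 56448000
print(uint8_buf.dtype, uint8_buf.nbytes)      # uint8 7056000

# Cast (and scale to [0, 1]) only the sampled minibatch, just before the network sees it.
minibatch = uint8_buf[:32].astype(np.float32) / 255.0
```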
[Bug] During an effort to make the env easier to learn in Part 3, I realized I was training on a full game of Breakout (5 lives) → re-reading the paper, I found that terminating after 1 life was the way to go.
[Bug] A fairly obvious one: I learned that I needed to manually send the network parameters and input tensors to the GPU in order to use it, via .to(torch.device("cuda")). Additionally, I found a speedup by running the image preprocessing (rescaling, grayscale, cropping) on the GPU as well.
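A minimal sketch of both ideas, falling back to the CPU when no GPU is available: the preprocessing runs on whatever device the tensor lives on, and the network is moved with `.to(device)`. The luminance weights, the 84x84 target size, and the tiny linear stand-in network are assumptions for illustration; cropping is omitted because the crop region is implementation-specific.

```python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ITU-R BT.601 luminance weights for an approximate grayscale conversion.
LUMA = torch.tensor([0.299, 0.587, 0.114], device=device)


def preprocess_on_device(frame_rgb: np.ndarray) -> torch.Tensor:
    """Grayscale + resize a raw (H, W, 3) uint8 frame to (84, 84) uint8, computed on `device`."""
    x = torch.as_tensor(frame_rgb, device=device).float()       # (H, W, 3)
    gray = (x * LUMA).sum(dim=-1)                                # (H, W)
    small = F.interpolate(gray[None, None], size=(84, 84),
                          mode="bilinear", align_corners=False)  # (1, 1, 84, 84)
    return small[0, 0].round().clamp(0, 255).to(torch.uint8)


# Parameters must live on the same device as the inputs.
q_net = nn.Linear(84 * 84, 6).to(device)  # hypothetical stand-in for the DQN conv net

frame = np.random.randint(0, 256, size=(210, 160, 3), dtype=np.uint8)  # raw Atari screen size
processed = preprocess_on_device(frame)
q_values = q_net(processed.float().flatten()[None] / 255.0)            # shape (1, 6)
```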