Deep Q-Networks: from theory to implementation

Amrani Amine
Towards Data Science
Dec 19, 2020


In my previous post, I explained and implemented the Q-learning algorithm from scratch using the FrozenLake environment provided by the gym library. Please make sure you already have some Reinforcement Learning (RL) basics and understand the Q-learning algorithm.


In the Q-learning algorithm, we compute the Q-table, which contains the Q-value of every state-action pair, using Q-value iteration. This works well for small, finite state and action spaces. However, since we store every state-action pair, large spaces require a huge amount of memory and many more iterations for the Q-table to converge. When the state space, the action space or both are continuous, there are infinitely many state-action pairs, so a Q-table cannot be built at all.
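To put numbers on the memory argument, here is a back-of-the-envelope sketch. It assumes that each of the 4 continuous state variables of the CartPole environment used later in this post is discretized into the same number of bins, with 2 actions. The bin counts are arbitrary illustrative choices.

```python
def q_table_entries(bins_per_variable, n_state_vars=4, n_actions=2):
    """Number of Q-values in a table after discretizing each state variable."""
    return (bins_per_variable ** n_state_vars) * n_actions


for bins in (10, 100, 1000):
    # 10 -> 20,000; 100 -> 200,000,000; 1000 -> 2,000,000,000,000
    print(f"{bins:>4} bins per variable -> {q_table_entries(bins):,} Q-values")
```

The table grows exponentially with the number of state variables. At 8 bytes per Q-value, the 1000-bin case alone would need about 16 TB, and any finite binning still only approximates a continuous state.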

As a solution to this problem, we can approximate the Q-value function with a Deep Neural Network, since neural networks are known to be good function approximators. Indeed, think of the Q-table as the evaluation of an unknown function at a finite set of points. Since it is a function, we can use a Deep Neural Network to approximate it, which allows us to handle continuous state spaces (the actions must still be discrete, because the network outputs one Q-value per action). This is what we call a Deep Q-Network (DQN). The image below shows the RL process when using a DQN.

from https://www.novatec-gmbh.de/en/blog/deep-q-networks/

As we can see, the Deep Neural Network (DNN) takes a state as input and outputs the Q-values of all possible actions in that state. Hence, the input layer of the DNN has the same size as a state, and the output layer has as many units as there are actions the agent can take.
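The shape constraint can be illustrated without any framework. The sketch below is a plain-Python stand-in for the Keras network. It assumes CartPole’s sizes (a 4-value state and 2 actions) and a hypothetical hidden layer of 8 ReLU units. The names `init_layer`, `dense` and `q_values` are illustrative, and the weights are random, so the printed Q-values are meaningless until training.

```python
import random

STATE_SIZE = 4   # input layer size = size of a state (CartPole: 4 values)
N_ACTIONS = 2    # output layer size = number of actions (CartPole: left/right)
HIDDEN = 8       # hypothetical hidden-layer width


def init_layer(n_in, n_out, rng):
    """Random weights (n_out rows of n_in values) and zero biases."""
    weights = [[rng.uniform(-0.5, 0.5) for _ in range(n_in)] for _ in range(n_out)]
    return weights, [0.0] * n_out


def dense(x, layer, relu):
    """Fully connected layer, optionally followed by a ReLU."""
    weights, biases = layer
    out = [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weights, biases)]
    return [max(0.0, v) for v in out] if relu else out


def q_values(state, params):
    """Forward a state through the network: one Q-value per action."""
    hidden = dense(state, params[0], relu=True)
    return dense(hidden, params[1], relu=False)  # linear output layer


rng = random.Random(0)
params = [init_layer(STATE_SIZE, HIDDEN, rng), init_layer(HIDDEN, N_ACTIONS, rng)]
state = [0.01, -0.02, 0.03, 0.04]
print(q_values(state, params))  # a list of N_ACTIONS numbers
```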

To summarize, when the agent is in a given state, it forwards that state through the DNN and chooses the action with the highest Q-value.
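Choosing the action is then an argmax over the network’s outputs. The minimal sketch below assumes the Q-values arrive as a plain Python list (with NumPy, `np.argmax` does the same job). `greedy_action` is a hypothetical helper name.

```python
def greedy_action(q_values):
    """Index of the action with the highest Q-value (ties go to the first index)."""
    return max(range(len(q_values)), key=lambda a: q_values[a])


print(greedy_action([0.12, 0.57]))  # -> 1
```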

But how does the agent learn? And how is the DNN updated?

To train our DNN, we use a technique called Experience Replay. The idea is that the agent stores all its experiences in a memory buffer called the replay memory. At time step t, the experience is a tuple containing the current state of the environment, the chosen action, the reward and the next state of the environment:

$$e_t = (s_t, a_t, r_t, s_{t+1})$$
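One way to represent such an experience in Python is a named tuple. This is a sketch, not the original code, and the field names are assumptions. The extra `done` flag (whether the next state is terminal) is not part of the four-element tuple above. It is commonly stored with it so that terminal states can be handled when computing the Q-Target.

```python
from collections import namedtuple

# e_t = (s_t, a_t, r_t, s_{t+1}), plus a `done` flag for terminal next states
Experience = namedtuple("Experience", ["state", "action", "reward", "next_state", "done"])

e = Experience(state=[0.0, 0.1, 0.0, -0.1], action=1, reward=1.0,
               next_state=[0.002, 0.3, -0.002, -0.4], done=False)
print(e.action, e.reward)  # -> 1 1.0
```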

After each episode, the agent samples a batch of experiences from the replay memory and uses them to train the DNN. The loss is built from the Temporal Difference (TD) error, which is the difference between the Q-value of a state-action pair and its Q-Target. In practice, the squared TD error (mean squared error) is minimized, so that positive and negative errors are penalized alike.
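For a batch of experiences, the TD errors and the resulting loss can be computed as follows. The sketch assumes a squared-error (MSE) loss, which is the quantity a Keras model compiled with `loss='mse'` minimizes. The numbers are made up for illustration.

```python
def td_errors(q_predicted, q_targets):
    """TD error per sample: Q-Target minus the current Q-value."""
    return [t - q for q, t in zip(q_predicted, q_targets)]


def mse_loss(q_predicted, q_targets):
    """Mean squared TD error over a batch."""
    errors = td_errors(q_predicted, q_targets)
    return sum(e * e for e in errors) / len(errors)


q_pred = [1.0, 2.0, 0.5]  # Q(s, a) for the actions actually taken
q_tgt = [1.5, 1.0, 0.5]   # corresponding Q-Targets
print(td_errors(q_pred, q_tgt))  # -> [0.5, -1.0, 0.0]
print(mse_loss(q_pred, q_tgt))   # -> 0.4166666666666667
```

Only the output of the action actually taken receives a Q-Target. In a Keras implementation, a common trick is to set the targets of the other actions to the network’s current predictions, so that their error is zero.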

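As the Q-Target is unknown, we once again use the Bellman optimality equation:

$$Q^*(s, a) = \mathbb{E}\left[ r + \gamma \max_{a'} Q^*(s', a') \right]$$

where s’ is the next environment state, r is the reward and γ is the discount factor. The Q-Target of an experience (s, a, r, s’) is therefore $r + \gamma \max_{a'} Q(s', a')$. To compute the highest Q-value from state s’, we only need to forward s’ through the DNN and take the highest output value.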
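The Q-Target of one experience can be sketched as shown below. The `done` branch handles a case the equation leaves implicit: when s’ is terminal there is no future reward, so the target is just r. The default discount factor of 0.95 is a hypothetical choice.

```python
def q_target(reward, next_q_values, done, gamma=0.95):
    """Bellman target r + gamma * max_a' Q(s', a'), or just r if s' is terminal."""
    if done:
        return reward
    return reward + gamma * max(next_q_values)


# next_q_values would come from forwarding s' through the DNN
print(q_target(1.0, [2.0, 3.0], done=False, gamma=0.5))  # -> 2.5
print(q_target(1.0, [2.0, 3.0], done=True))              # -> 1.0
```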

You might ask why we sample a random batch instead of selecting consecutive experiences. It is a relevant question.

We prefer random sampling because consecutive experiences are highly correlated. Training on them makes the network overfit to whatever happened most recently and destabilizes learning. Sampling randomly from the memory breaks this correlation.
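A replay memory with uniform random sampling can be sketched with the standard library. The name `ReplayBuffer`, its capacity and its method names are hypothetical. `deque(maxlen=...)` silently drops the oldest experience once the buffer is full. In the toy experiences, the time step stands in for the state, which makes the difference between a consecutive slice and a random batch visible.

```python
import random
from collections import deque


class ReplayBuffer:
    """Fixed-size memory of experiences; the oldest ones are dropped first."""

    def __init__(self, capacity, seed=None):
        self.memory = deque(maxlen=capacity)
        self.rng = random.Random(seed)

    def store(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        """Uniform random batch without replacement (breaks temporal correlation)."""
        return self.rng.sample(list(self.memory), batch_size)

    def __len__(self):
        return len(self.memory)


buffer = ReplayBuffer(capacity=1000, seed=42)
for t in range(50):
    buffer.store(state=t, action=t % 2, reward=1.0, next_state=t + 1, done=False)

consecutive = list(buffer.memory)[:5]  # correlated: time steps 0..4
random_batch = buffer.sample(5)        # scattered across the episode
print([e[0] for e in consecutive])     # -> [0, 1, 2, 3, 4]
print([e[0] for e in random_batch])
```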

Consecutive experiences vs. random experiences

Now we can start coding! We will use the CartPole environment provided by gym, an open-source Python library that offers many environments for Reinforcement Learning, such as Atari games. In CartPole, a pole stands upright on a cart that can move along a track. The goal of the agent is to keep the pole up by pushing the cart to the left or to the right at every time step. The agent receives a reward of +1 for every time step the pole stays up. An episode ends when the pole is more than 12° from the vertical (the threshold used in gym’s implementation), when the cart moves more than 2.4 units from the centre, or when the maximum episode length is reached (200 steps for CartPole-v0, 500 for CartPole-v1).
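The failure conditions can be written as a small predicate. This is a hypothetical helper, not part of gym’s API. It assumes the 12° and 2.4 thresholds used in gym’s CartPole source, where the observation is (cart position, cart velocity, pole angle, pole angular velocity) and the angle is given in radians. The maximum-episode-length cutoff is not modelled.

```python
import math

ANGLE_LIMIT = math.radians(12)  # pole angle threshold in gym's CartPole source
POSITION_LIMIT = 2.4            # cart position limit from the centre


def episode_done(cart_position, pole_angle):
    """True once the pole falls past the angle limit or the cart leaves the track."""
    return abs(pole_angle) > ANGLE_LIMIT or abs(cart_position) > POSITION_LIMIT


print(episode_done(0.0, math.radians(5)))   # -> False
print(episode_done(0.0, math.radians(13)))  # -> True
print(episode_done(2.5, 0.0))               # -> True
```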

Let’s start. First, we import all libraries that we need:

As you can see, we use Keras (integrated into TensorFlow) to build our neural network. You don’t need to master Keras to understand the code, because I will explain each step of the implementation.

As we saw previously, the agent has to:

  • compute the action to choose for a given state
  • store its experiences in a memory buffer
  • train the DNN by sampling a batch of experiences from the memory buffer

In addition, we have to pay attention to the exploration-exploitation trade-off that I explained in my previous post. To recall it briefly: at the beginning, the agent should mostly explore the environment by choosing random actions, because it has no idea how the environment works. As time goes on, the agent gains more and more knowledge, so it should increasingly exploit that knowledge rather than pick random actions. For that purpose, we will use the epsilon-greedy strategy. With probability ε the agent picks a random action, otherwise it picks the action with the highest Q-value, and ε decreases over time.
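The epsilon-greedy strategy with a decaying epsilon can be sketched as follows. The class name is hypothetical, and so are the values: a starting epsilon of 1.0, a floor of 0.01 and a multiplicative decay of 0.995.

```python
import random


class EpsilonGreedy:
    """Explore with probability epsilon, otherwise exploit; epsilon decays over time."""

    def __init__(self, epsilon=1.0, epsilon_min=0.01, decay=0.995, seed=None):
        self.epsilon = epsilon          # start fully exploratory
        self.epsilon_min = epsilon_min  # never stop exploring completely
        self.decay = decay              # multiplicative decay per call to step()
        self.rng = random.Random(seed)

    def choose(self, q_values):
        if self.rng.random() < self.epsilon:
            return self.rng.randrange(len(q_values))                 # explore
        return max(range(len(q_values)), key=lambda a: q_values[a])  # exploit

    def step(self):
        self.epsilon = max(self.epsilon_min, self.epsilon * self.decay)


policy = EpsilonGreedy(seed=0)
for _ in range(1000):
    policy.step()
print(policy.epsilon)  # -> 0.01
```

With these values, epsilon reaches its floor after 919 decay steps (0.995^919 is just below 0.01), so the agent keeps exploring 1% of the time.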

We now create a class named DQNAgent:

The code is well documented; please read the comments to understand what the different methods do.
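Here is a hedged, dependency-free sketch of an agent with the same three responsibilities: choose an action, store experiences, and train on a random batch. It is not the author’s Keras code. The method names `act`, `remember` and `replay` and all hyperparameters are assumptions. A linear model (one weight vector per action) stands in for the DNN, and it is trained by a gradient step on the squared TD error with the Q-Target held fixed.

```python
import random
from collections import deque


class DQNAgent:
    """DQN-style agent sketch. A LINEAR Q-function (one weight vector per
    action) stands in for the Keras DNN used in the article."""

    def __init__(self, state_size, action_size, gamma=0.95, lr=0.01,
                 epsilon=1.0, epsilon_min=0.01, epsilon_decay=0.995,
                 memory_size=2000, seed=None):
        self.action_size = action_size
        self.gamma = gamma                       # discount factor
        self.lr = lr                             # learning rate
        self.epsilon = epsilon                   # exploration rate
        self.epsilon_min = epsilon_min
        self.epsilon_decay = epsilon_decay
        self.memory = deque(maxlen=memory_size)  # replay memory
        self.rng = random.Random(seed)
        # weights[a] = [bias, w_1, ..., w_state_size] for action a
        self.weights = [[0.0] * (state_size + 1) for _ in range(action_size)]

    def q_values(self, state):
        """Forward pass: one Q-value per action."""
        return [w[0] + sum(wi * si for wi, si in zip(w[1:], state))
                for w in self.weights]

    def act(self, state):
        """Epsilon-greedy: random action with probability epsilon, else argmax."""
        if self.rng.random() < self.epsilon:
            return self.rng.randrange(self.action_size)
        q = self.q_values(state)
        return max(range(self.action_size), key=lambda a: q[a])

    def remember(self, state, action, reward, next_state, done):
        """Store one experience in the replay memory."""
        self.memory.append((state, action, reward, next_state, done))

    def replay(self, batch_size):
        """Train on a random batch of experiences, then decay epsilon."""
        if len(self.memory) < batch_size:
            return
        batch = self.rng.sample(list(self.memory), batch_size)
        for state, action, reward, next_state, done in batch:
            target = reward if done else reward + self.gamma * max(self.q_values(next_state))
            td_error = target - self.q_values(state)[action]
            # Gradient step on 0.5 * td_error**2 with the target held fixed;
            # only the weights of the action actually taken are updated.
            w = self.weights[action]
            w[0] += self.lr * td_error
            for i, s in enumerate(state):
                w[i + 1] += self.lr * td_error * s
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)
```

Swapping the linear model for a Keras network would change `q_values` (a `predict` call) and the weight update (a `fit` call on the batch). The memory, epsilon-greedy and Q-Target logic would stay the same.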

Now that we have built our agent, we can start the training:
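A minimal training loop could look like the sketch below. It assumes an agent exposing the hypothetical `act`, `remember` and `replay` methods. It also assumes the pre-0.26 gym API, where `reset()` returns the state and `step(action)` returns `(next_state, reward, done, info)`; newer gym and Gymnasium releases return different tuples. As described earlier, the agent trains on a replay batch after each episode.

```python
def train(env, agent, episodes, batch_size=32):
    """Run `episodes` episodes, storing every transition and training the
    agent on a random replay batch at the end of each episode."""
    scores = []
    for _ in range(episodes):
        state = env.reset()
        done, score = False, 0.0
        while not done:
            action = agent.act(state)
            next_state, reward, done, _ = env.step(action)
            agent.remember(state, action, reward, next_state, done)
            state = next_state
            score += reward
        agent.replay(batch_size)  # learn from a random batch of past experiences
        scores.append(score)      # in CartPole, score = number of steps survived
    return scores
```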

We are done, just let the agent learn!

After 600 iterations of training, we can use the following code to watch our agent play CartPole.

And this is what we get:

Our agent playing CartPole

Conclusion

DQNs are very effective, but they can still be improved. Indeed, when we compute the Q-Target values, we use the same neural network that we are updating. Thus, after each update of the network’s weights, the Q-values move towards the Q-Target values, but the Q-Target values also move in the same direction. The optimization ends up chasing its own tail and becomes unstable. At the same time, sampling experiences uniformly at random from the memory buffer may not be the optimal strategy. The agent could instead focus on experiences where the TD error is large, since these are the ones it has the most to learn from.
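The usual remedy for the moving-target problem is a separate target network. This is a frozen copy of the weights, used only to compute the Q-Targets and refreshed every few updates. The sketch below assumes weights stored as nested lists and a hypothetical `sync_every` period. With Keras, the copy is typically done with `target_model.set_weights(model.get_weights())`.

```python
import copy


def compute_target(reward, next_q_from_target_net, done, gamma=0.95):
    """Q-Target computed from the FROZEN target network, not the online one."""
    return reward if done else reward + gamma * max(next_q_from_target_net)


class TargetSync:
    """Keeps a frozen copy of the online weights, refreshed every `sync_every` updates."""

    def __init__(self, online_weights, sync_every=100):
        self.online = online_weights
        self.target = copy.deepcopy(online_weights)
        self.sync_every = sync_every
        self.updates = 0

    def after_update(self):
        """Call after each gradient step on the online weights."""
        self.updates += 1
        if self.updates % self.sync_every == 0:
            self.target = copy.deepcopy(self.online)


weights = [[0.0, 0.0], [0.0, 0.0]]  # toy online weights
sync = TargetSync(weights, sync_every=3)
for step in range(1, 4):
    weights[0][0] += 1.0            # stand-in for a gradient step
    sync.after_update()
    print(step, sync.target[0][0])  # -> 1 0.0, then 2 0.0, then 3 3.0
```

Between synchronizations, the Q-Targets stay fixed while the online network moves towards them, which removes the "chasing its own tail" effect.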

As solutions to these two problems, we can use a separate target network to compute the Q-Targets (as Double Q-Networks do) and Prioritized Experience Replay, which I will explain in future posts.
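The second idea can be previewed with proportional prioritization, which samples experience i with probability proportional to (|δ_i| + ε)^α, where δ_i is its TD error. This is a simplified sketch, and the `alpha` and `eps` values are hypothetical. The full method also corrects the resulting sampling bias with importance-sampling weights, which are omitted here.

```python
import random


def prioritized_sample(td_errors, batch_size, alpha=0.6, eps=1e-6, rng=random):
    """Sample indices with probability proportional to (|TD error| + eps) ** alpha."""
    priorities = [(abs(e) + eps) ** alpha for e in td_errors]
    total = sum(priorities)
    probs = [p / total for p in priorities]
    indices = rng.choices(range(len(td_errors)), weights=priorities, k=batch_size)
    return indices, probs


errors = [0.0, 0.1, 2.0, 0.5]
indices, probs = prioritized_sample(errors, batch_size=10, rng=random.Random(0))
print([round(p, 3) for p in probs])  # experience 2 (largest error) is most likely
print(indices)
```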

I hope you now understand Deep Q-Networks and are able to implement your own DQN!

Thank you!


I am a 5th-year computer science and statistics engineering student at Polytech Lille, fascinated by machine learning and Graph Neural Networks.