Combining MuZero and Perceiver IO to Create a Generalized AI

Hi everyone, today we’re going to improve upon the chess AI we built in Building a Chess Engine Part 2 by converting it into a more generalized system. This lesson is going to be technical again, so please bear with me; I’ll supply both equations and diagrams to make things a little easier.
For this new AI, we’re going to build off of MuZero, using its powerful search method. MuZero was presented in a white paper by DeepMind as an improvement to AlphaZero. In that paper, MuZero achieves superhuman performance in multiple games without access to an accurate simulator, which makes it a unique algorithm.
MuZero’s precursor, AlphaZero, also achieved superhuman performance, so the two algorithms are comparable in strength. However, AlphaZero relies on external simulators in its planning algorithm, making it less general than MuZero. That reliance limited AlphaZero to perfect-information games with pre-crafted simulators. MuZero does away with these external simulators, allowing it to perform a multitude of tasks. The removal of the external simulators might lead you to believe that MuZero uses some form of model-free reinforcement learning (RL). This is untrue: MuZero instead uses a learnt model to represent the dynamics of the task at hand.
Our AI will differ from the paper in some ways: we will not use the same model architectures and will instead use a Perceiver IO-based architecture. Perceiver IO is a generalized model that works on a variety of data types. This generality is very appealing to us, as it pairs well with the general nature of MuZero.
Search
MuZero was built directly off AlphaZero, so you’ll find the two sharing a lot of similar ideas, including guiding the Monte Carlo Tree Search (MCTS) algorithm with a deep learning model. However, as stated above, MuZero does not use an external simulator to look ahead in its MCTS and instead uses a learnt dynamics model. This dynamics model is unique in that it does not try to predict the exact state resulting from the chosen action, only a hidden state representation of it. Working in this hidden representation eases the learning process, allowing the dynamics model to learn the rules it needs for accurate planning. In this new search algorithm, we will also change our old upper confidence bound (UCB) formula.

We will replace our old UCB equation with the polynomial upper confidence trees (pUCT) equation found in the paper.
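Written out, the pUCT rule from the MuZero paper selects the action that maximizes:

$$a = \arg\max_a \left[ Q(s,a) + P(s,a)\,\frac{\sqrt{\sum_b N(s,b)}}{1 + N(s,a)}\left(c_1 + \log\frac{\sum_b N(s,b) + c_2 + 1}{c_2}\right) \right]$$

where N(s, a) is an action’s visit count, P(s, a) its prior from the policy, Q(s, a) its normalized value estimate, and c1 and c2 are exploration constants (1.25 and 19652 in the paper).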

When comparing the two equations, you will find that they are very similar: both try to balance the exploration-and-exploitation dilemma faced in MCTS problems. The new equation adds an extra hyperparameter (c2) to our old UCB equation. The term containing c2 grows slowly with the parent node’s visit count, which shifts the balance away from exploitation for heavily visited parent nodes. Another difference in our new pUCT formula is that we normalize our Q value to the range [0, 1]. Normalizing the Q value keeps it on the same scale as the prior term, which allows the exploration portion of the equation to keep its influence.
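As a sketch of how that Q normalization can be tracked during the search, here is a small helper in the spirit of the MinMaxStats class from the MuZero pseudocode (the post doesn’t show its own implementation, so treat the details as assumptions):

```python
class MinMaxStats:
    """Tracks the minimum and maximum value seen in the search tree so that
    Q values can be rescaled to [0, 1] before being used in the pUCT formula."""

    def __init__(self):
        self.minimum = float("inf")
        self.maximum = float("-inf")

    def update(self, value):
        self.minimum = min(self.minimum, value)
        self.maximum = max(self.maximum, value)

    def normalize(self, value):
        if self.maximum > self.minimum:
            return (value - self.minimum) / (self.maximum - self.minimum)
        return value
```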

Another trick we employ is adding Dirichlet noise to the policy values at the root node during training. This noise enhances exploration, allowing the AI to contemplate more ideas while training. The noise will cause some poor-performing nodes to be visited, but those nodes get pruned by the pUCT equation further into the search: since we only apply noise to the exploration side of the equation, the exploitation of high-value moves continues unaffected.
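A minimal sketch of mixing that noise into the root priors; the alpha and fraction values below are the ones the AlphaZero/MuZero papers use for chess, not necessarily the ones this project uses:

```python
import numpy as np

def add_dirichlet_noise(priors, alpha=0.3, frac=0.25):
    """Mix Dirichlet noise into the root node's prior probabilities.

    alpha and frac follow the chess settings from the AlphaZero/MuZero papers;
    treat them as placeholders for whatever this project actually uses.
    """
    priors = np.asarray(priors, dtype=np.float64)
    noise = np.random.dirichlet([alpha] * len(priors))
    return (1 - frac) * priors + frac * noise
```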

With our new AI, we have also adjusted the value-update equation used during the backpropagation of nodes. The parent nodes’ values get updated with the help of a gamma discount.

When updating a node, the value being backed up is the node’s reward plus the gamma-discounted value from further down the tree, and the resulting Q values are normalized. Updating node values like this achieves a better approximation by creating a natural decay that causes the AI to prefer higher values found closer to the current node.
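A minimal sketch of that backup step, assuming a Node class with reward, value_sum, visit_count and value() members plus the MinMaxStats helper from earlier (all names here are illustrative, not the post’s actual code):

```python
def backpropagate(search_path, value, min_max, gamma=0.997):
    """Walk from the leaf back up to the root, discounting the value at each step.

    The discount value is illustrative; the post does not state which one it uses.
    """
    for node in reversed(search_path):
        node.value_sum += value
        node.visit_count += 1
        min_max.update(node.value())            # keep the Q normalization statistics current
        value = node.reward + gamma * value     # gamma-discount as we move toward the root
```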

With these changes, our new MCTS algorithm looks something like this.

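As a rough sketch of one full simulation, assuming the root has already been expanded and reusing the helpers above (the Node fields and model interface are placeholders, not the post’s actual code):

```python
import math

def run_simulation(root, model, min_max, c1=1.25, c2=19652, gamma=0.997):
    """One MCTS simulation driven by the learnt model (illustrative sketch)."""
    node, search_path, action = root, [root], None

    # 1. Selection: descend the tree with the pUCT rule until an unexpanded node is reached.
    while node.children:
        parent_visits = node.visit_count

        def puct(act, child):
            pb_c = c1 + math.log((parent_visits + c2 + 1) / c2)
            pb_c *= math.sqrt(parent_visits) / (1 + child.visit_count)
            q = min_max.normalize(child.value()) if child.visit_count else 0.0
            return q + pb_c * child.prior

        action, node = max(node.children.items(), key=lambda kv: puct(*kv))
        search_path.append(node)

    # 2. Expansion: the second (multi-task) model predicts value, policy, reward and
    #    the next hidden state from the parent's hidden state and the chosen action.
    parent = search_path[-2]
    value, policy, reward, hidden_state = model(parent.hidden_state, action)
    node.expand(hidden_state, reward, policy)

    # 3. Backup: propagate the predicted value to the root with a gamma discount
    #    (see the backpropagate sketch above).
    backpropagate(search_path, value, min_max, gamma)
```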
After running MCTS on our game tree, the best action to take is the one whose node was visited most. To express this as a probability distribution, we normalize each action by its number of visits. Using MCTS like this is believed to produce an enhanced policy compared to using our prediction neural net (NN) alone. During training, we decay our temperature variable based on the current move count, which increases exploration early in each game. When the temperature hits zero, we choose the most visited action, allowing the AI to exploit its MCTS knowledge. When evaluating the AI, we set the temperature to one, so it has no effect.
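A small sketch of that visit-count-to-action step (function and variable names are illustrative):

```python
import numpy as np

def select_action(visit_counts, temperature):
    """Turn root visit counts into a move choice, following the scheme described above."""
    counts = np.asarray(visit_counts, dtype=np.float64)
    if temperature == 0:
        return int(counts.argmax())                   # exploit: pick the most visited action
    probs = counts ** (1.0 / temperature)
    probs /= probs.sum()                              # temperature-scaled visit distribution
    return int(np.random.choice(len(counts), p=probs))
```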

Model Architecture
As the introduction stated, we will not use the models found in the MuZero paper. Instead, we will use a Perceiver IO architecture, a new model released by DeepMind as a successor to their Perceiver model.
The Perceiver model is an attention-based method that attempts to fix some of the shortcomings of vanilla Transformers. Vanilla Transformers have been very successful in many domains recently; however, implementations with large input spaces have been less successful due to the quadratic scaling problem found in Transformers. Perceivers, like Transformers, rely heavily on the attention mechanism, so you might think this quadratic scaling problem would still exist. However, the Perceiver uses a combination of multiheaded cross-attention and multiheaded self-attention, and this combination overcomes the problem by projecting the high-dimensional data to a lower dimension.

Self-attention uses the same-sized Q, K and V values, forcing all computations to stay at the size of the input.

Cross-attention instead uses a differently sized Q value, which effectively scales the computation down to the size of the Q value.
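A quick NumPy shape check makes the difference concrete (made-up sizes, with the softmax and value projection omitted):

```python
import numpy as np

N, M, D = 1024, 64, 128                   # input length, latent (query) length, feature dim
x = np.random.randn(N, D)                 # input array
latents = np.random.randn(M, D)           # smaller, learnt latent array

self_scores = x @ x.T                     # (N, N): self-attention scales quadratically with N
cross_scores = latents @ x.T              # (M, N): cross-attention scales with the smaller M
print(self_scores.shape, cross_scores.shape)
```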

Now that we understand the difference between cross-attention and self-attention, I’ll explain how the Perceiver combines the two.

The Perceiver receives a single input array, represented above as a byte array, while the second variable, denoted as a latent array, is a learnt set of weights. The latent array gets passed as the Q value to a cross-attention block, which effectively shrinks the input (assuming the latent array is smaller than the input). Shrinking the input array like this is how the Perceiver mitigates the quadratic scaling problem. The latent transformer block then performs self-attention on the output of the cross-attention block. It can do so without negatively impacting processing time since the reduction in input size makes the quadratic scaling far less impactful. The output of the latent transformer is passed recursively to another latent transformer as many times as you want.

After the latent transformers, there is the option of performing cross-attention between the resulting latent array and the input data again. It’s believed that this step allows the model to focus on parts of the input data that the previous cross-attention might have missed, since the latent array has changed after being transformed by the earlier blocks. The output of that cross-attention block then goes to another stack of latent transformers, and this process of passing the latent transformers’ output to another cross-attention block can be repeated as many times as you want. Weights can optionally be shared between the cross-attention blocks and between the latent transformer blocks, meaning we can pass the data through the same cross-attention block/latent transformer block each time or use separate networks with their own sets of weights.
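To make the data flow concrete, here is a heavily simplified PyTorch sketch of one cross-attention-plus-latent-transformer step; residuals, layer norms, feed-forward layers and input projections are omitted, and none of this is the post’s actual code:

```python
import torch
import torch.nn as nn

class PerceiverBlock(nn.Module):
    """Minimal sketch of one Perceiver encode/process step.

    Assumes the input has already been projected to latent_dim; this only shows
    how the learnt latent array and the two attention types interact.
    """

    def __init__(self, latent_dim=256, n_latents=64, n_heads=8, n_self_layers=4):
        super().__init__()
        self.latents = nn.Parameter(torch.randn(n_latents, latent_dim))   # learnt latent array
        self.cross_attn = nn.MultiheadAttention(latent_dim, n_heads, batch_first=True)
        self.self_attn = nn.ModuleList(
            [nn.MultiheadAttention(latent_dim, n_heads, batch_first=True)
             for _ in range(n_self_layers)]
        )

    def forward(self, inputs):                              # inputs: (batch, seq_len, latent_dim)
        q = self.latents.unsqueeze(0).expand(inputs.size(0), -1, -1)
        x, _ = self.cross_attn(q, inputs, inputs)           # cross-attention shrinks the sequence
        for layer in self.self_attn:
            x, _ = layer(x, x, x)                           # latent transformer (self-attention)
        return x                                            # (batch, n_latents, latent_dim)
```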

Now that we understand the architecture of the Perceiver, the architecture of Perceiver IO is not hard to follow. Perceiver IO adds cross-attention blocks to the Perceiver that scale the input and output data to the desired shapes and sizes. This trick is beneficial as it allows the model to be used in various situations, since the shape and size of our data no longer matter as much.
Our Model
Now that we understand the underlying ideas implemented in the Perceiver IO architecture, I can explain how it’s adapted for this AI.
Our AI uses the same underlying idea as MuZero by having multiple models. However, we do not stick to the three-model format presented in the paper and instead use two models.
The first model is a cross-attention model used to represent the current hidden state of the game and shares the same purpose as the representation model in MuZero. This model can be thought of as the input model of our Perceiver IO architecture as it projects the input to a standard dimension that our second model can process. This model receives the same encoded game state as our AI in Building a Chess Engine Part 2.

The second model is a little more complex, as it is a multi-task model. This multi-task model handles all tasks that the prediction function and dynamics function in MuZero would perform (value, policy, reward, next state). We use a multi-task model here instead of individual models because of the relationship between the tasks; this relationship makes a multi-task model beneficial, as it leads to a notion of shared features amongst the tasks. This model consists of two parts: a backbone and separate specialized heads.

The backbone learns generalized game knowledge through a Perceiver network and acts as the autoencoder layer of this network. This layer gets a hidden state representation and an action token as an input.

The value head is a cross-attention model used to predict the outcome of the game. This layer is one of the outputs of our Perceiver IO architecture. It takes the backbone’s output as input and projects it to a single value representing the outcome (-1 to 1).

The policy head is a cross-attention model used to predict the probability distribution over actions. This layer is another output of our Perceiver IO architecture. It takes the backbone’s output as input and projects it to a vector representing the probability of taking each action.

The state head is a cross-attention model used to predict the hidden state representation that results from performing the desired action. This layer is another output of our Perceiver IO architecture. It takes the backbone’s output as input and projects it to that hidden state representation.

The reward head is a cross-attention model used to predict the instantaneous reward for performing the desired action. This layer is one of the outputs of our Perceiver IO architecture. It takes the backbone’s output as input and projects it to a single value representing the immediate reward (zero for board games like chess).
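Putting the backbone and the four heads together, a simplified PyTorch sketch of the second model might look like this; it reuses the PerceiverBlock sketch from earlier, and the head design, sizes and action-space dimension are assumptions rather than the post’s actual implementation:

```python
import torch
import torch.nn as nn

class OutputHead(nn.Module):
    """Cross-attention decoder head: learnt queries read the backbone's latents
    and are projected to the desired output size."""

    def __init__(self, latent_dim, out_dim, n_queries=1, n_heads=8):
        super().__init__()
        self.queries = nn.Parameter(torch.randn(n_queries, latent_dim))
        self.attn = nn.MultiheadAttention(latent_dim, n_heads, batch_first=True)
        self.proj = nn.Linear(latent_dim, out_dim)

    def forward(self, latents):                        # latents: (batch, n_latents, latent_dim)
        q = self.queries.unsqueeze(0).expand(latents.size(0), -1, -1)
        x, _ = self.attn(q, latents, latents)
        return self.proj(x.squeeze(1))                 # squeeze is a no-op when n_queries > 1

class MultiTaskModel(nn.Module):
    """Backbone plus the value, policy, reward and next-state heads described above."""

    def __init__(self, latent_dim=256, n_latents=64, n_actions=4672):
        super().__init__()
        # n_actions defaults to the AlphaZero-style chess move encoding;
        # the post may use a different action space.
        self.backbone = PerceiverBlock(latent_dim, n_latents)   # sketch from the Perceiver section
        self.value_head = OutputHead(latent_dim, 1)                      # game outcome in [-1, 1]
        self.policy_head = OutputHead(latent_dim, n_actions)            # move probabilities
        self.reward_head = OutputHead(latent_dim, 1)                     # immediate reward
        self.state_head = OutputHead(latent_dim, latent_dim, n_queries=n_latents)  # next hidden state

    def forward(self, hidden_state_with_action):        # (batch, seq_len, latent_dim)
        latents = self.backbone(hidden_state_with_action)
        value = torch.tanh(self.value_head(latents))
        policy = torch.softmax(self.policy_head(latents), dim=-1)
        reward = self.reward_head(latents)
        next_state = self.state_head(latents)
        return value, policy, reward, next_state
```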

Training
Similar to MuZero and our previous AI, our model learns entirely through self-play. With our training, we add an evolutionary aspect by having an active model and a new model. The new model trains itself through self-play, while the active model is our current best-performing model. To minimize potential diminishing returns, the two models compete in a round-robin-style tournament. Here the winner of the round-robin becomes the new active model.

While training, we break our two models into six separate parts: the second model’s value, policy, next-state and reward heads plus its backbone, and the hidden-state model. We split things up this way to individually fine-tune each head of the second model. Individually fine-tuning each head allows the model to learn each task more efficiently, as there is no interference from the other heads’ gradients. While training, we use the Adam optimizer to fine-tune all six parts.
To train the value head of our second model, we use the final outcome found at the end of each self-play game as our target value. Mean squared error (MSE) is used as the loss function when training this head, since this layer’s output is a regression.

To train the policy head of our second model, we use the logged MCTS policy found during self-play as our target value. Binary cross-entropy (BCE) is used for our loss function when training this head. We use the binary cross-entropy loss function here since this head has a multivariable classification output.

To train the next-state head of our second model, we take ideas from the EfficientZero paper. EfficientZero uses self-supervised learning to improve the speed of training MuZero’s dynamics model.

Here we implement the same idea: we pass the known next state through our hidden-state model and use the result as the target value in a mean squared error loss. We use mean squared error here since this layer’s output is a multivariable regression.

To train the reward head of our second model, we use the logged reward found during self-play as our target value. Mean squared error (MSE) is again used as the loss function when training this head, since this layer’s output is a regression.

To train the backbone of our second model, we use the sum of the individual heads’ losses as an overall loss. We sum all losses from the heads so the backbone learns a representation of the game that works for all tasks.

To train our hidden-state model, we sum the losses of the second model, excluding the next-state loss, as an overall loss. We exclude the next-state loss because the hidden-state model’s output is used as the target in that loss calculation.
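A compact sketch of how those per-part losses could be combined, with dictionary keys and shapes as placeholders for the post’s actual training code:

```python
import torch.nn.functional as F

def compute_losses(pred, target):
    """Per-part losses as described above; key names are illustrative."""
    value_loss = F.mse_loss(pred["value"], target["value"])                  # final game outcome
    policy_loss = F.binary_cross_entropy(pred["policy"], target["policy"])   # logged MCTS policy
    reward_loss = F.mse_loss(pred["reward"], target["reward"])               # logged reward
    # EfficientZero-style target: the hidden-state model applied to the real next
    # observation, detached so gradients don't flow into the target.
    state_loss = F.mse_loss(pred["next_state"], target["next_state"].detach())

    backbone_loss = value_loss + policy_loss + reward_loss + state_loss      # trains the backbone
    hidden_state_loss = value_loss + policy_loss + reward_loss               # next-state loss excluded
    return backbone_loss, hidden_state_loss
```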

Thanks
And there you have it: we have successfully upgraded our chess AI to a more general system. This AI starts with no knowledge and learns the game of chess, including the rules. The more training games it plays, the better it will perform. You can check out a full version of the code on my GitHub. You can also see my previous AI [here](https://medium.com/@bellerb/building-a-chess-engine-part2-db4784e843d5).
Thanks for reading. If you liked this, consider subscribing to my account to be notified of my most recent posts.
References
- MuZero: https://arxiv.org/abs/1911.08265
- DeepMind: https://deepmind.com/
- AlphaZero: https://arxiv.org/abs/1712.01815
- Perceiver IO: https://arxiv.org/abs/2107.14795
- Perceiver: https://arxiv.org/abs/2103.03206
- Dirichlet distribution: https://en.wikipedia.org/wiki/Dirichlet_distribution
- Attention Is All You Need (Transformers): https://arxiv.org/abs/1706.03762
- EfficientZero: https://arxiv.org/abs/2111.00210