
Solving A Rubik’s Cube with Supervised Learning – Intuitively and Exhaustively Explained

A Popular Toy in a Brave New World

"Mosaic Space" by Daniel Warfield using Midjourney, Matplotlib, and Affinity Design 2. All images by the author unless otherwise specified. Article originally made available on Intuitively and Exhaustively Explained.
"Mosaic Space" by Daniel Warfield using Midjourney, Matplotlib, and Affinity Design 2. All images by the author unless otherwise specified. Article originally made available on Intuitively and Exhaustively Explained.

In this article we’ll make an AI model that can solve a Rubik’s Cube. We’ll define our own dataset, make a transformer style model that can learn based on that dataset, and use that model to solve new and randomly shuffled Rubik’s Cubes.

In tackling this problem we’ll discuss practical problems which come up frequently in Data Science, and the techniques data scientists use to solve those problems.

Who is this useful for? Anyone interested in achieving mastery of modern AI.

How advanced is this post? This post covers advanced modeling strategies intuitively, and is appropriate for readers of all levels.

Pre-requisites: There are no prerequisites for this article, though an understanding of transformer style models may be useful for some of the later, code heavy sections.

References: A link to the code and supporting resources can be found in the reference section at the end of this article.


Defining a Rubik’s Cube as a Modeling Problem

As you likely know, the Rubik’s Cube is a geometric game featuring a 3x3x3 cube with different colored segments on each face. These faces can be turned by 90 degrees in either direction to scramble or solve the Rubik’s Cube.

The goal of solving a Rubik’s Cube is to make all faces a single uniform color by performing a series of rotations. The process of mixing up a Rubik’s Cube by randomly rotating various faces is called "Scrambling".

The goal of this article is to create a model which can accept a scrambled Rubik’s Cube and output a series of steps to solve said Rubik’s Cube.

The Goal of the AI Model: to accept some shuffled Rubik’s Cube and output a list of steps to solve that Rubik’s Cube.

There are a ton of ways this can be done. In this article we’ll be exploring one of the more straightforward approaches: supervised learning.

The Plan, From a High Level

A natural approach to making an AI model that solves Rubik’s Cubes might be to gather data from solutions of skilled players, then train a model to mimic those solutions. While using human data to train a model has its merits, it also has its drawbacks. Finding and licensing data from pro Rubik’s Cube players might be difficult if not impossible, and hiring pro Rubik’s Cube players to create a custom dataset would be costly and time-consuming. If you’re clever, all this work might be unnecessary. In this article, for instance, we’ll be using a completely synthetic dataset, meaning we’ll be generating all our training data automatically, without using any data from human players.

Essentially, we’ll frame the task of solving a Rubik’s Cube as trying to predict the reverse of the sequence that was used to scramble it. The idea is to randomly scramble millions of Rubik’s Cubes, reverse the sequence used to scramble them, then create a model which is tasked with predicting the reversed scrambling sequence.

To create our dataset we’ll come up with random moves to scramble our Rubik’s Cube. The model’s job will be to predict the reverse order and orientation of that scrambling sequence, which will solve the Rubik’s Cube.

This strategy falls under "Supervised Learning", which is the prototypical approach to training an AI model. When training a model with supervised learning you essentially say to the model "here’s an input (a scrambled Rubik’s Cube), predict an output (a list of steps), and I’ll train you based on how well your response aligns with what I expected (the reverse of the scrambling sequence)".

There are other forms of learning, like contrastive learning, semi-supervised learning, and reinforcement learning, but in this article we’ll stick with the basics. If you’re curious about some of those approaches, I provided some links in the reference section at the end of the article.

So, we have a high-level plan: shuffle a bunch of Rubik’s Cubes and train a model to predict the opposite of the sequence used to scramble them. Before we get into the intricacies of defining a custom Transformer style model to work with this data, let’s review the idea of transformer style models in general.

A Brief Introduction to the Transformer

This section will briefly review transformer style models. This is, essentially, a condensed version of my more comprehensive article on the subject:

Transformers – Intuitively and Exhaustively Explained

In its most basic sense, the transformer is an encoder-decoder style model.

A transformer working in a translation task. The input (I am a manager) is compressed to some abstract representation that encodes the meaning of the entire input. The decoder works recurrently, by feeding into itself, to construct the output. From my article on transformers

The encoder converts an input into an abstract representation which the decoder uses to iteratively generate output.

A high-level representation of how the output of the encoder relates to the decoder. The decoder references the encoded input for every loop of the output. The decoder generates the entire sequence by taking its previous outputs as input, and predicting what token it thinks should come next. From my article on transformers

Both the encoder and decoder employ an abstract representation of text which is created using an operation called multi-headed self-attention.

Multi-headed self-attention, in a nutshell. The mechanism mathematically combines the vectors for different words, creating a matrix which encodes a deeper meaning of the entire input. From my article on transformers

There are a few steps which multi-headed self-attention employs to construct this abstract representation. In a nutshell, a dense neural network constructs three representations, usually referred to as the query, key, and value, based on the input.

Turning the input into the query, key, and value. The query, key, and value all have the same dimensions as the input, and can be thought of as several different representations of the input. From my article on transformers

The query and key are multiplied together. Thus, some representation of every word is combined with a representation of every other word.

Calculating the attention matrix with the query and key. The attention matrix is then used, in combination with the value, to generate the final output of the attention mechanism. From my article on transformers

The value is then multiplied by this abstract combination of the query and key, constructing the final output of multi-headed self-attention.

The attention matrix (which is the matrix multiplication of the query and key) multiplied by the value matrix to yield the final result of the attention mechanism. Because of the shape of the attention matrix, the result is the same shape as the value matrix. Note, I’m skipping some very important steps. From my article on transformers
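
To ground that description, here’s a minimal single-head sketch of the query/key/value flow in PyTorch. It skips the multi-head machinery and the other important steps mentioned above; it’s illustration only, not the article’s implementation:

import torch

seq_len, d_model = 5, 8                      # 5 tokens, each represented by an 8-value vector
x = torch.randn(seq_len, d_model)            # the input sequence

# Dense projections produce the query, key, and value representations
W_q, W_k, W_v = (torch.randn(d_model, d_model) for _ in range(3))
query, key, value = x @ W_q, x @ W_k, x @ W_v

# Multiply the query by the key: every token's representation interacts with every other's
attention_scores = (query @ key.T) / d_model ** 0.5   # shape (seq_len, seq_len)
attention_matrix = torch.softmax(attention_scores, dim=-1)

# Multiply the attention matrix by the value to get the final output
output = attention_matrix @ value                     # same shape as the value matrix
print(output.shape)                                   # torch.Size([5, 8])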

The encoder uses multi-headed self attention to create abstract representations of the input, and the decoder uses multi-headed self attention to create abstract representations of the output.

The Transformer Architecture, with the encoder on the left and the decoder on the right. image source
Recall that this is what the transformer is doing, in essence.

That was a super quick rundown on transformers. I tried to cover the high points without getting too in the weeds; feel free to refer to my article on transformers for more information.

While the original transformer was created for English to French translation, in an abstract way the process of solving a Rubik’s Cube is somewhat similar. We have some input (a shuffled Rubik’s Cube, vs a French sentence), and we need to predict some sequence based on that input (a sequence of moves, vs a sequence of English words).

The transformer was originally designed for English to French translation (left), we’re using it to solve a Rubik’s Cube (right).

We don’t need to make any changes to the fundamental structure of the transformer to get it to solve a Rubik’s Cube. All we have to do is properly format the Rubik’s Cube and moves into a representation the transformer can understand. We’ll cover that in the following sections.

Defining The Cube and Moves

Originally, I thought to define the Rubik’s Cube as a 3x3x3 matrix of segments, each of which has some number of faces which have some color.

The Rubik’s Cube can be thought of as a data structure. The Cube itself consisting of segments, each of which has some number of colored stickers.

This is totally possible, but I’m a data scientist and my 3D spatial programming hasn’t seen the light of day in a hot minute. After some reflection, I decided on a creative and perhaps more elegant approach: representing the Rubik’s Cube as a 5x5x5 tensor, rather than a 3x3x3 cube of segments.

Representing a Rubik’s Cube as a 5x5x5 tensor, rather than a 3x3x3 cube of segments. Essentially, we’re keeping track of the location of all of the stickers as they float through empty space. Note, the tensor on the right doesn’t represent the same Rubik’s Cube as the one on the left, so don’t go pulling your hair out trying to reconcile the two.

The essential idea is that we, technically speaking, don’t really care about the cube. Rather, we care about the stickers and where they are relative to each other. So, instead of having a 3x3x3 data structure consisting of complicated segments that have to obey complicated rules, we can simply put that cube within a 5x5x5 grid and keep track of where the sticker colors are within this space.

When we rotate a "face", we simply need to rotate all the spaces in the 5x5x5 grid which correspond to the stickers that would be on that face and corresponding edges.

Rotating a face, in our 5x5x5 tensor, means we just have to swap some values around. This might require a bit of head scratching, but it’s way easier than actually defining a Rubik’s Cube as a bunch of tiny segments, then getting them to play nicely with one another.

There are 12 fundamental rotations one can apply to a Rubik’s Cube. We can rotate the front, back, top, bottom, left, and right faces (6 faces), and we can rotate each of those 90 degrees clockwise or counterclockwise.

When we scramble a cube we apply some number of these rotations; to solve the Rubik’s Cube, one can simply reverse the order and direction of those moves.
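
For a tiny concrete example (using the (face, reverse) tuples the scrambling code below will use to represent moves), a two-move scramble is undone by reversing both the order of the moves and the direction of each one:

# A two-move scramble and its solution, reversing both order and direction
scramble_moves = [('top', False), ('left', True)]
solution_moves = [(face, not reverse) for face, reverse in reversed(scramble_moves)]
print(solution_moves)  # [('left', False), ('top', True)]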

Here’s a class that defines a Rubik’s Cube and its moves, as well as a neat little visualization:

"""Defining the Rubik's Cube
"""
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
from matplotlib.patches import Polygon
from mpl_toolkits.mplot3d.art3d import Poly3DCollection

class RubiksCube:
    def __init__(self):
        # Initialize a 3D tensor to represent the Rubik's Cube
        self.cube = np.empty((5, 5, 5), dtype='U10')
        self.cube[:, :, :] = ''

        # Initialize sticker colors
        self.cube[0, 1:-1, 1:-1] = 'w'  # Top (white)
        self.cube[1:-1, 0, 1:-1] = 'g'  # Front (green)
        self.cube[1:-1, 1:-1, 0] = 'r'  # Left (red)
        self.cube[-1, 1:-1, 1:-1] = 'y' # Bottom (yellow)
        self.cube[1:-1, -1, 1:-1] = 'b' # Back (blue)
        self.cube[1:-1, 1:-1, -1] = 'o' # Right (orange)

    def print_cube(self):
        print(self.cube)

    def rotate_face(self, face, reverse=False):
        """
        Rotates a given face of the cube 90 degrees.

        Parameters:
            face (str): One of ['top', 'front', 'left', 'bottom', 'back', 'right']
            reverse (bool): if the rotation should be reversed
        """
        # maps a face to the section of the tensor which needs to be rotated
        rot_map = {
            'top': (slice(0, 2), slice(0, 5), slice(0, 5)),
            'left': (slice(0, 5), slice(0, 2), slice(0, 5)),
            'front': (slice(0, 5), slice(0, 5), slice(0, 2)),
            'bottom': (slice(3, 5), slice(0, 5), slice(0, 5)),
            'right': (slice(0, 5), slice(3, 5), slice(0, 5)),
            'back': (slice(0, 5), slice(0, 5), slice(3, 5))
        }

        # getting all of the stickers that will be rotating
        rotating_slice = self.cube[rot_map[face]]

        # getting the axis of rotation
        axis_of_rotation = np.argmin(rotating_slice.shape)

        # rotating about axis of rotation
        axes_of_non_rotation = [0,1,2]
        axes_of_non_rotation.remove(axis_of_rotation)
        axes_of_non_rotation = tuple(axes_of_non_rotation)
        direction = 1 if reverse else -1
        rotated_slice = np.rot90(rotating_slice, k=direction, axes=axes_of_non_rotation)

        # overwriting cube
        self.cube[rot_map[face]] = rotated_slice

    def _rotate_cube_180(self):
        """
        Rotate the entire cube 180 degrees by flipping and transposing
        this is used for visualization
        """
        # Rotate the cube 180 degrees
        rotated_cube = np.rot90(self.cube, k=2, axes=(0,1))
        rotated_cube = np.rot90(rotated_cube, k=1, axes=(1,2))
        return rotated_cube

    def visualize_opposite_corners(self):
        """
        Visualize the Rubik's Cube from two truly opposite corners
        """
        # Create a new figure with two subplots
        fig = plt.figure(figsize=(20, 10))

        # Color mapping
        color_map = {
            'w': 'white',
            'g': 'green',
            'r': 'red',
            'y': 'yellow',
            'b': 'blue',
            'o': 'orange'
        }

        # Cubes to visualize: original and 180-degree rotated
        cubes_to_render = [
            {
                'cube_data': self.cube,
                'title': 'View 1'
            },
            {
                'cube_data': self._rotate_cube_180(),
                'title': 'View 2'
            }
        ]

        # Create subplots for each view
        for i, cube_info in enumerate(cubes_to_render, 1):
            ax = fig.add_subplot(1, 2, i, projection='3d')

            ax.view_init(elev=-150, azim=45, vertical_axis='x')

            # Iterate through the cube and plot non-empty stickers
            cube_data = cube_info['cube_data']
            for x in range(cube_data.shape[0]):
                for y in range(cube_data.shape[1]):
                    for z in range(cube_data.shape[2]):
                        # Only plot if there's a color
                        if cube_data[x, y, z] != '':
                            color = color_map.get(cube_data[x, y, z], 'gray')

                            # Define the 8 vertices of the small cube
                            vertices = [
                                [x, y, z], [x+1, y, z],
                                [x+1, y+1, z], [x, y+1, z],
                                [x, y, z+1], [x+1, y, z+1],
                                [x+1, y+1, z+1], [x, y+1, z+1]
                            ]

                            # Define the faces of the cube
                            faces = [
                                [vertices[0], vertices[1], vertices[2], vertices[3]],  # bottom
                                [vertices[4], vertices[5], vertices[6], vertices[7]],  # top
                                [vertices[0], vertices[1], vertices[5], vertices[4]],  # front
                                [vertices[2], vertices[3], vertices[7], vertices[6]],  # back
                                [vertices[1], vertices[2], vertices[6], vertices[5]],  # right
                                [vertices[0], vertices[3], vertices[7], vertices[4]]   # left
                            ]

                            # Plot each face
                            for face in faces:
                                poly = Poly3DCollection([face], alpha=1, edgecolor='black')
                                poly.set_color(color)
                                poly.set_edgecolor('black')
                                ax.add_collection3d(poly)

            # Set axis limits and equal aspect ratio
            ax.set_xlim(0, 5)
            ax.set_ylim(0, 5)
            ax.set_zlim(0, 5)
            ax.set_box_aspect((1, 1, 1))

            # Remove axis labels and ticks
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_zticks([])
            ax.set_xlabel('')
            ax.set_ylabel('')
            ax.set_zlabel('')
            ax.set_title(cube_info['title'])

        plt.tight_layout()
        plt.show()

# Example usage
cube = RubiksCube()

# Rotate a face to show some variation
cube.rotate_face('bottom', reverse=False)

# Visualize from opposite corners
cube.visualize_opposite_corners()
Creating a Cube, rotating a face, and rendering a visualization of that cube. You can imagine View 1 being the main view, and View 2 being the view you might get from flipping the Rubik’s Cube over and looking at the opposite corner. They both, together, represent a single Rubik’s Cube.

We can scramble our Rubik’s Cube by simply performing a few random moves.

"""Scrambling a Rubik's Cube
"""
from itertools import product
import random

#Defining Possible Moves
faces = ['top', 'left', 'front', 'bottom', 'right', 'back']
possible_moves = tuple(product(faces, [False, True]))

def scramble(cube, n=20):
    moves = []
    for _ in range(n):
        #selecting a random move
        selected_move = random.choice(possible_moves)
        moves.append(selected_move)

        # apply the selected move to the cube
        cube.rotate_face(selected_move[0], reverse=selected_move[1])
    return moves

#creating a cube
cube = RubiksCube()

#shuffling
moves = scramble(cube)
print(moves)

# Visualize from opposite corners
cube.visualize_opposite_corners()
An example of a shuffled Rubik’s Cube

And, to solve that Rubik’s Cube we can simply reverse the order of the moves, and reverse the direction in which they rotate.

"""unscrambling by reversing moves and direction
"""

#reversing order of moves
moves.reverse()

for selected_move in moves:
    # Rotate the same face in the opposite direction
    cube.rotate_face(selected_move[0], reverse=not selected_move[1])

# Visualize from opposite corners
cube.visualize_opposite_corners()
The result of un-scrambling the Rubik’s Cube by reversing the order and direction of the moves

Using this code, we can generate a synthetic dataset consisting of shuffled Rubik’s Cubes and their solutions.

"""Parallelized code that generates 2M scrambled Rubik's Cubes,
and keeps track of the cube (X) and the moves to unscramble it (y)
"""

import random
from multiprocessing import Pool, cpu_count
from functools import partial

def generate_sample(max_scramble, _):
    """
    Generates a single sample (X, y) for the Rubik's Cube task.
    """
    num_moves = random.randint(1, max_scramble)

    # Initializing a cube and scrambling it
    cube = RubiksCube()
    moves = scramble(cube, n=num_moves)

    # Reversing moves, which is the solution
    moves.reverse()
    moves = [(m[0], not m[1]) for m in moves]

    # Turning into modeling data
    x = tokenize(cube.cube)
    y = [0] + [move_to_output_index(m) + 3 for m in moves] + [1]

    # Padding with 2s so the sequence length is always 22
    y.extend([2] * (22 - len(y)))

    return x, y

def parallel_generate_samples(num_samples, max_scramble, num_workers=None):
    """
    Parallelizes the generation of Rubik's Cube samples.
    """
    num_workers = num_workers or cpu_count()

    # Use functools.partial to "lock in" the max_scramble parameter
    generate_sample_partial = partial(generate_sample, max_scramble)

    with Pool(processes=num_workers) as pool:
        results = pool.map(generate_sample_partial, range(num_samples))

    # Unpack results into X and y
    X, y = zip(*results)
    return list(X), list(y)

num_samples = 2_000_000
max_scramble = 20

# Generate data in parallel
X, y = parallel_generate_samples(num_samples, max_scramble)

In this code, X is the thing we’ll be passing into the model (the shuffled Rubik’s Cube) and y will be the thing we try to predict (the sequence of operations to solve it).

In data science, it’s common to label the input to the model as X and the desired output as y. Our dataset consists of two million examples of shuffled Rubik’s Cubes (X), and their corresponding solutions (y)

I’m using two helper functions, tokenize and move_to_output_index, to help me turn the Rubik’s Cube and list of moves into a more friendly representation for modeling. I don’t think it’s necessary to go over the implementation (feel free to refer to the code), but from a high level:

  • the tokenize function accepts a 5x5x5 tensor consisting of sticker colors and empty spaces and outputs a 54×4 tensor. This 54×4 tensor has a vector for all 54 stickers in the Rubik’s Cube where each vector contains the (color, x position, y position, z position) of a particular sticker. It ignores all the empty spaces in the 5x5x5 tensor.
  • The move_to_output_index function simply turns a move, like (top, clockwise) into a number. All 12 moves are assigned a unique number. The reason we add the number 3 in the code will become apparent when we discuss the input and output of the decoder portion of the transformer model.

So, the input to the transformer is a list of 54 vectors consisting of (color, x position, y position, z position), and the output of the transformer is a list of numbers where each number corresponds to one of 12 moves.

After reformatting, a single X/y pair looks like this.

This slight re-formatting of the data has little conceptual impact but will be practically handy when we apply a model to this data.
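
For reference, here’s a rough sketch of what those two helpers might look like based on the description above; the specific color-to-integer and move-to-integer assignments are assumptions for illustration, and the actual implementations live in the linked reference code.

"""A rough sketch of the tokenization helpers (mappings here are illustrative assumptions)
"""
import numpy as np

# Assumed mapping from sticker characters to integer color ids (0-5)
COLOR_TO_ID = {'w': 0, 'g': 1, 'r': 2, 'y': 3, 'b': 4, 'o': 5}

def tokenize(cube_tensor):
    """Turn the 5x5x5 sticker tensor into a 54x4 array of
    (color id, x position, y position, z position) rows, skipping empty cells."""
    stickers = []
    for x in range(5):
        for y in range(5):
            for z in range(5):
                color = cube_tensor[x, y, z]
                if color != '':
                    stickers.append([COLOR_TO_ID[color], x, y, z])
    return np.array(stickers)  # shape (54, 4)

def move_to_output_index(move):
    """Map a (face, reverse) move tuple to a unique integer in 0-11."""
    faces = ['top', 'left', 'front', 'bottom', 'right', 'back']
    face, reverse = move
    return faces.index(face) * 2 + int(reverse)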

Now that we’ve constructed our dataset of shuffled Rubik’s Cubes and sequences of moves to solve them, we can work towards feeding that data into a transformer. This is done through a process called embedding.

Embedding

When creating a transformer in a natural language context you first create a vocabulary consisting of the words (or pieces of words) that your model will understand. Then, you turn all the words in the model’s vocabulary into a vector. When the model receives an input sequence of words, it can "think" about the sequence by doing math with the vectors that represent the words.

The job of a word to vector embedder: turn words into numbers so that a language model can reason about them. From my article on transformers.

We can do something similar with our Rubik’s Cube: we can think of the "vocabulary" of our Rubik’s Cube as six tokens, one for each sticker color. Then we can assign each of those colors some random vector that represents it.

We can assign each of the colors of a Rubik’s Cube some random vector of numbers. These vectors will be used to allow our models to reason about stickers and how they relate to one another. Initially the values in these vectors will be defined randomly, but they’ll be updated based on the model’s needs through the training process, allowing the model to come up with its own representations of the colors.

When we want to give a Rubik’s Cube to the input of our encoder, we can iterate through all the stickers in the Rubik’s Cube and, every time there’s a sticker, we can look up the vector that corresponds to that color and add it to a sequence.

Recall that we re-represented the Rubik’s Cube as a list of vectors, where the first item in each vector corresponds to a sticker’s color. If we iterate through that list, and look up the corresponding vector for each color, we can effectively turn the Rubik’s Cube into a sequence of vectors.

Transformers have become widely popular in a variety of applications, from computer vision, to audio synthesis, to video generation. It turns out "take whatever data you have and represent it as a list of vectors, then throw it into a transformer" is a pretty good general strategy.

Before we put this sequence of vectors into a transformer, though, we need one more piece of information: position.

Positional Encoding

Every time we convert our Rubik’s Cube into a sequence of vectors, whether we’re training our model or trying to predict a solution, we’ll be converting the stickers of the Rubik’s Cube to vectors in the same order. That means each location in the embedded sequence will always correspond to the same location in the Rubik’s Cube.

If we use well defined and consistent code, then the first index in our list of vectors will always represent the same location in our Rubik’s Cube. Likewise, the second, third, fourth, and all other vectors will all always represent the same location across Rubik’s Cubes. Note, the colors of this particular Rubik’s Cube don’t correspond to this list of vectors shown, this figure is for conceptual demonstration.

Some modeling strategies, like convolutional and dense networks, are good at learning to leverage this consistency. They can learn "this location corresponds to this sticker, that location corresponds to that sticker", and thus it’s not necessary to add any additional information about position into the input.

Transformer style models, on the other hand, are famously prone to losing track of the order of the input. To create their abstract and meaning rich representations, they mix and mangle the input so much that positional information (as in "this vector came before that vector") is lost very quickly. As a result, when using a transformer, it’s customary to use a positional encoding.

The idea is to add some information about where each sticker is in our Rubik’s Cube to the vector which represents that sticker’s color. This will allow the model to inject explicit information about location into the values which represent each sticker, meaning it can reason about that sticker’s position, as well as its color.

The idea is to add some positional information to the values of the vectors themselves, allowing the values of the vector to represent both a sticker’s color and position.

We’ll be using a lookup table very similar to the approach described in the previous section, where we assigned a random vector to each color.

Recall that we encoded sticker color by creating a lookup table for each sticker color which corresponds to some random vector.

To encode position, we’ll also assign a random vector to each X, Y, and Z position in the 5x5x5 space of vectors.

Recall that we converted our Rubik’s Cube into a list of vectors corresponding to the color, and X, Y, Z position of each sticker
Each of the five possible positions, across each axis, can be assigned a random vector

For each sticker, we can add the vectors for where that sticker sits along the X, Y, and Z axes to the vector that represents the sticker’s color and, as a result, represent both the color and position of each sticker using a single vector.

If we add the vectors for color, X, Y, and Z positions together, we can create a list of vectors that represent all the stickers in the Rubik’s Cube.

Here’s the implementation for a model which can embed the sticker colors of a Rubik’s Cube and apply a positional encoding:

import torch
import torch.nn as nn

class EncoderEmbedding(nn.Module):
    def __init__(self, vocab_size=6, pos_i_size=5, pos_j_size=5, pos_k_size=5, embedding_dim=128):
        super(EncoderEmbedding, self).__init__()

        # Learnable embeddings for each component
        self.vocab_embedding = nn.Embedding(vocab_size, embedding_dim)
        self.pos_i_embedding = nn.Embedding(pos_i_size, embedding_dim)
        self.pos_j_embedding = nn.Embedding(pos_j_size, embedding_dim)
        self.pos_k_embedding = nn.Embedding(pos_k_size, embedding_dim)

    def forward(self, X):
        """
        Args:
            X (torch.Tensor): Input tensor of shape (batch_size, seq_len, 4)
                where X[..., 0] = vocab indices (0-5)
                      X[..., 1] = position i (0-4)
                      X[..., 2] = position j (0-4)
                      X[..., 3] = position k (0-4)
        Returns:
            torch.Tensor: Output tensor of shape (batch_size, seq_len, embedding_dim)
        """
        # Split the input into components
        vocab_idx = X[..., 0]
        pos_i_idx = X[..., 1]
        pos_j_idx = X[..., 2]
        pos_k_idx = X[..., 3]

        # Look up embeddings
        vocab_embed = self.vocab_embedding(vocab_idx)
        pos_i_embed = self.pos_i_embedding(pos_i_idx)
        pos_j_embed = self.pos_j_embedding(pos_j_idx)
        pos_k_embed = self.pos_k_embedding(pos_k_idx)

        # Sum the embeddings
        final_embedding = vocab_embed + pos_i_embed + pos_j_embed + pos_k_embed
        return final_embedding

embedding_dim = 128

# Initialize the input embedding module
encoder_embedding = EncoderEmbedding(embedding_dim=embedding_dim)

# Get the final embeddings
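# Note: X is assumed here to already be a torch.LongTensor of shape (num_samples, 54, 4);
# the Python lists returned by parallel_generate_samples would need to be converted first,
# e.g. X = torch.tensor(X, dtype=torch.long)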
embedded_encoder_input = encoder_embedding(X[:10])
print("Input shape:", X[:10].shape)
print("Output shape:", embedded_encoder_input.shape)
In this particular example the input is a batch of 10 Rubik’s Cubes. Each of these cubes has 54 stickers, each described by a 4-value vector: the sticker’s color, x position, y position, and z position. The encoder_embedding looks up the four vectors that represent the color and positions of the sticker and adds them all together. Thus, the output is a batch of 10 Rubik’s Cubes, now represented as 54 unique vectors where each vector contains values that describe the color and position of a sticker. 128 is the model_dimension, an arbitrary number defining how long the vectors are which represent an input element. Large transformers (like OpenAI’s GPT models) use vectors which are many hundreds if not thousands of values long. With a larger model dimension more complex representations of the input can be made, at the cost of computational load. I opted for a relatively small model dimension, but still large enough to allow the model to think through the problem. This parameter could certainly be played with.

When we create a new model, we’ll be using completely random vectors to represent both the sticker color and where those stickers are located. Naturally this random information will probably be really hard for the model to understand at first. The idea is that, throughout the training process, these random vectors will update so that the model learns to create vectors for sticker color and position which it understands.

We just implemented the input embedding and positional encoding within the greater transformer architecture

So, we’ve turned the Rubik’s Cube into a list of vectors which a transformer can understand. We’ll also need to perform a similar process on the sequence of moves we want the model to output but, before we do, I’d like to discuss some intricacies of the decoder.

The Input and Output of the Decoder

Recall that when a transformer outputs a sequence, it does so "autoregressively", meaning when you put some sequence into the decoder, the decoder will output a prediction as to what it thinks the next token should be. That new token can then be fed back into the input of the decoder, allowing the transformer to generate a sequence one token at a time.

Recall the general process of autoregressive generation. The decoder references the encoded input for every loop of the output. The decoder generates the entire sequence by taking its previous outputs as input, and predicting what token it thinks should come next. From my article on transformers

One of the defining characteristics of a transformer is the way they’re trained. Older styles of models would typically train on one token at a time. You would feed a sequence into the model, predict a token, then update the model based on whether it was right or wrong. This is an incredibly slow and computationally expensive process, and severely limited older styles of models when applying them to sequences.

When training a transformer, on the other hand, you input the entire target sequence, and the transformer predicts the next token at every position as if future tokens did not exist.

How the decoder of a transformer is trained. Instead of training on words one by one, it’s given the entire sequence of words it’s supposed to output and is tasked with predicting all next words simultaneously. The reason the decoder uses "masked" multi-headed self-attention is to prevent the model from trivially copying future words to come up with its output. When predicting what word should come after "trained" in this example, the model can only see the sequence " A Sentence a model is being trained" and is tasked with predicting that the word "on" is the next word. From my article on Speculative Sampling

So, it predicts the next token for the first spot, the next token for the second spot, the next token for the third spot, etc. simultaneously.
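
One common way to frame this (sketched here with placeholder token values, not the article’s actual training code) is to treat the same sequence, shifted by one position, as the label for every position at once:

import torch

# A toy tokenized sequence of 6 tokens (the values are arbitrary placeholders)
sequence = torch.tensor([[4, 9, 2, 7, 7, 3]])

decoder_input = sequence[:, :-1]   # what the decoder sees at each position
target_output = sequence[:, 1:]    # what it should predict at each position (the "next token")

# With a causal ("masked") attention pattern, the prediction at position t can only attend to
# decoder_input[:, :t + 1], so every "next token" prediction is trained simultaneously.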

I talk about how this works in my article on transformers more in depth, and explore how this quirk of transformers can be used to interesting effect in my article on speculative sampling. For now, though, we know enough of the high-level theory to discuss implementing the embedding and positional encoding for the decoder.

Tokenizing, Embedding, and Positionally Encoding the Solution Sequence

As discussed in the previous section, when we’re training our model we need to input the sequence we want into the decoder, then the decoder will make predictions of all the next moves in the sequence as if future moves didn’t exist.

A breakdown of the general flow of training. Our Rubik’s Cube gets represented as a list of vectors which represent the color and position of each sticker, which is passed to the encoder. The decoder uses the encoded input to generate move predictions.

Just like the input to the encoder, the input to the decoder will take the form of a sequence of vectors.

The input to the encoder is similar to the input to the decoder: a list of vectors. However, the vectors for the encoder represent stickers, while the vectors for the decoder represent moves.

The process of creating those vectors is similar to the process we used to embed and positionally encode the Rubik’s Cube. There are 12 possible moves (6 faces, each in two directions), meaning the entire "vocabulary" of moves can be represented with a list of 12 vectors.

We can create a "vocabulary" of moves. We can assign each move some vector. When we want to input some list of moves into the decoder, we’ll do so by feeding the decoder a list of these vectors.

These moves can be positionally encoded by creating a vector for each location in the sequence.

When converting a sequence of moves into a sequence of vectors, we need to add a vector for each position to encode location, just like what we did when defining the positional encoding for the Rubik’s Cube. For positionally encoding the sequence we need a vector for each possible element in the sequence.

Apparently, the maximum number of moves needed to solve any Rubik’s Cube is 20 (don’t ask me how they figured that out), so we can assume the output sequence of our model will have a maximum length of 20, plus space for two "utility tokens".

A utility token is a special token that doesn’t matter in terms of final output but is useful from a modeling perspective. For instance, it can be useful for a model to have a way to say it’s done generating output. This is commonly handled with an "end of sequence" token, often abbreviated as <EOS>.

Also, recall how the decoder predicts all next tokens based on an input token, meaning we need to input some token to get the first prediction. It’s common practice to prepend each sequence with a "start of sequence" (<SOS>) token, which holds the space for the first prediction from the model.

Another token we’ll be using is a pad (<PAD>) token. Basically, all the math under the hood of the transformer uses matrices, which require some uniform shape. So, if we have a few sequences of moves that are short, and a few sequences that are long, they all need to fit within the same matrix. We can do that by "padding" all the short sequences until they’re the same length as the longest sequence.

The "vocabulary" for our model is actually 15 elements: 12 moves and three utility tokens.

So, the "vocabulary" of the decoder will be our 12 possible moves, plus the start-of-sequence (<SOS>), end-of-sequence (<EOS>), and pad (<PAD>) tokens. The total sequence length will be 22, because the maximum number of moves needed to solve any Rubik’s Cube is 20 (again, I have no idea why), and we need to make room for an (<SOS>) and (<EOS>) token on even the longest sequences.

Now that our tokens are thought out, and we know how long the sequence will be, we can just initialize 15 random vectors for the token embedding, and 22 random vectors for each location in the sequence. When we take in some sequence to either train or make some prediction, we can use these vectors to represent both the value and position of all of the moves.

The move vectors will be the values for the move plus the values of the move position in the sequence.

Let’s go ahead and implement the decoder embedding. First of all, our data already has our utility tokens built in. Recall we used this code to generate the y portion of our dataset:

y = [0] + [move_to_output_index(m) + 3 for m in moves] + [1]

# Padding with 2s so the sequence length is always 22
y.extend([2] * (22 - len(y)))

Here we’re converting each of our moves to an integer from 3–14 with the expression move_to_output_index(m) + 3 (which adds 3 to the numbers 0–11 that represent our possible moves). It then adds 0 and 1 to the beginning and end of the sequence, and appends 2's to the end until the total sequence length is 22.

Thus:

  • 0 represents start of sequence <sos>
  • 1 represents end of sequence <eos>
  • 2 represents pad <pad>
  • 3–14 represent our possible moves
What a batch of 10 solution sequences looks like in y. They all start with <sos> (0), proceed with some number of moves (3–14), end in <eos> (1), and then any remaining room is filled with <pad> (2).

So, we can implement embedding and positional encoding for our sequence of moves as follows:

class DecoderEmbedding(nn.Module):
    def __init__(self, vocab_size=15, pos_size=22, embedding_dim=128):
        super(DecoderEmbedding, self).__init__()

        # Learnable embeddings for each component
        self.vocab_embedding = nn.Embedding(vocab_size, embedding_dim)
        self.pos_embedding = nn.Embedding(pos_size, embedding_dim)

    def forward(self, X):
        """
        Args:
            X (torch.Tensor): Input tensor of shape (batch_size, seq_len), where each element
                              corresponds to a token index.

        Returns:
            torch.Tensor: Output tensor of shape (batch_size, seq_len, embedding_dim)
        """
        # Token embeddings (based on vocab indices)
        vocab_embed = self.vocab_embedding(X)

        # Generate position indices based on input shape
        batch_size, seq_len = X.shape
        position_indices = torch.arange(seq_len, device=X.device).unsqueeze(0).expand(batch_size, -1)

        # Position embeddings
        pos_embedding = self.pos_embedding(position_indices)

        # Sum the embeddings
        final_embedding = vocab_embed + pos_embedding
        return final_embedding

embedding_dim = 128

# Initialize the decoder embedding module
decoder_embedding = DecoderEmbedding(embedding_dim=embedding_dim)

# Get the final embeddings
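# Note: y is assumed here to already be a torch.LongTensor of shape (num_samples, 22);
# the Python lists from the data generation step would need to be converted first,
# e.g. y = torch.tensor(y, dtype=torch.long)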
embedded_decoder_input = decoder_embedding(y[:10])
print("Input shape:", y[:10].shape)
print("Output shape:", embedded_decoder_input.shape)
Here, we’re processing a batch of 10 solution sequences. This embedder takes a list of 22 tokens, each expressed as one of the 15 integers corresponding to the 15 elements in our move and utility token vocabulary, and converts each of those tokens into a vector. It does that by adding the vector that represents the value of the move to a vector that represents the position of that move within the sequence. For convenience and for the same conceptual reasons, the model_dim of 128 is the same as the one used in encoding the Rubik’s Cube which we defined in a previous section.
What we just implemented: an embedder and positional encoder for the move sequences.

Alright, we’ve figured out how to encode both the Rubik’s Cube, and sequences of moves to solve them, in a way the transformer can understand (a big list of vectors). Now we can actually get into building the transformer.

Implementing the Transformer

The main point of this article isn’t really the model, but rather the thought process around making modeling decisions. We’ve really done all the heavy lifting already. By turning our Rubik’s Cubes and moves into vectors a transformer understands, we can use the same standard transformer used in countless other applications.

I’ve covered the core ideas of the transformer in many different articles at this point. Still, this wouldn’t be "exhaustive" if I didn’t cover implementing the transformer, so let’s do it. This will be a fairly brief pass over the process; feel free to dig into some of the linked articles in the reference section for a more in-depth understanding.

1) The Encoder

We already made the embedding and positional encoding that turns a Rubik’s Cube into a list of vectors, so now we need to implement the encoder portion of the model, which is tasked with thinking about that input and turning it into an abstract but meaning-rich representation.

import torch
import torch.nn as nn

# Define the Transformer Encoder
class TransformerEncoder(nn.Module):
    def __init__(self, num_layers, d_model, num_heads, d_ff, dropout=0.1):
        super(TransformerEncoder, self).__init__()

        # Define a single transformer encoder layer
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=num_heads,
            dim_feedforward=d_ff,
            dropout=dropout,
            batch_first=True
        )

        # Stack multiple encoder layers
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

    def forward(self, src):
        """
        Args:
            src (torch.Tensor): Input tensor of shape (batch_size, seq_len, d_model).
        Returns:
            torch.Tensor: Output tensor of shape (batch_size, seq_len, d_model).
        """
        return self.encoder(src)

# Example usage
num_heads = 8
num_layers = 6
d_ff = 2048
dropout = 0.1

# Initialize the transformer encoder
encoder = TransformerEncoder(num_layers=num_layers, d_model=embedding_dim, num_heads=num_heads, d_ff=d_ff, dropout=dropout)

# Forward pass
encoder_output = encoder(embedded_encoder_input)

print("Encoder output shape:", encoder_output.shape)  # Should be (batch_size, seq_len, d_model) since batch_first=True
The output of the encoder is the same size as the input. We gave it a batch of 10 Rubik’s Cubes, each with 54 stickers represented with a vector of 128 values, and we got the same sized output. The encoder gets all these vectors to interact with each other, giving us a similarly sized output, but one that is much more abstract and meaning rich.

PyTorch already has an implementation for the encoder block, so we just used that.

  • num_heads describes how many heads are used in each multi-headed self attention block (see my article on transformers if you want to understand more)
  • num_layers describes how many encoder blocks are used (see my article on transformers if you want to understand more)
  • d_ff is how large the feed-forward network in the transformer is. This is typically much larger than the model_dim as it allows the feed forward network to look at each vector in the model, expand it into a few representations, then shrink those vectors back down into the original size based on that expanded information.
  • dropout is a regularizing parameter which randomly hides certain values in the model. dropout is a common trick that helps AI models learn trends in data without simply memorizing individual examples in the dataset.

2) The Decoder

The decoder takes the embedded representation of the move sequence, and combines it with the output of the encoder to make predictions about which move should come next.

So, let’s build it:

import torch
import torch.nn as nn

class TransformerDecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super(TransformerDecoderLayer, self).__init__()

        # Masked Multi-Head Self-Attention
        self.self_attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=num_heads, dropout=dropout, batch_first=True)
        self.self_attn_norm = nn.LayerNorm(d_model)
        self.self_attn_dropout = nn.Dropout(dropout)

        # Masked Multi-Head Cross-Attention
        self.cross_attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=num_heads, dropout=dropout, batch_first=True)
        self.cross_attn_norm = nn.LayerNorm(d_model)
        self.cross_attn_dropout = nn.Dropout(dropout)

        # Point-wise Feed Forward Network
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout)
        )
        self.ffn_norm = nn.LayerNorm(d_model)
        self.ffn_dropout = nn.Dropout(dropout)

    def forward(self, tgt, memory):
        """
        Args:
            tgt (torch.Tensor): Target sequence of shape (batch_size, tgt_seq_len, d_model).
            memory (torch.Tensor): Encoder output of shape (batch_size, src_seq_len, d_model).

        Returns:
            torch.Tensor: Output tensor of shape (batch_size, tgt_seq_len, d_model).
        """
        tgt_len = tgt.size(1)

        # Generate causal mask for self-attention (causal masking)
        causal_mask = torch.triu(torch.ones(tgt_len, tgt_len, device=tgt.device), diagonal=1).to(torch.bool)

        # Masked Multi-Head Self-Attention
        self_attn_out, _ = self.self_attn(
            tgt, tgt, tgt,
            attn_mask=causal_mask,
        )
        tgt = self.self_attn_norm(tgt + self.self_attn_dropout(self_attn_out))

        # Masked Multi-Head Cross-Attention
        cross_attn_out, _ = self.cross_attn(
            tgt, memory, memory,
        )
        tgt = self.cross_attn_norm(tgt + self.cross_attn_dropout(cross_attn_out))

        # Feed Forward Network
        ffn_out = self.ffn(tgt)
        tgt = self.ffn_norm(tgt + self.ffn_dropout(ffn_out))

        return tgt

class TransformerDecoder(nn.Module):
    def __init__(self, num_layers, d_model, num_heads, d_ff, dropout=0.1):
        super(TransformerDecoder, self).__init__()
        self.layers = nn.ModuleList([
            TransformerDecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(d_model)

    def forward(self, tgt, memory):
        """
        Args:
            tgt (torch.Tensor): Target sequence of shape (batch_size, tgt_seq_len, d_model).
            memory (torch.Tensor): Encoder output of shape (batch_size, src_seq_len, d_model).

        Returns:
            torch.Tensor: Output tensor of shape (batch_size, tgt_seq_len, d_model).
        """
        for layer in self.layers:
            tgt = layer(tgt, memory)
        return self.norm(tgt)

# Example usage
num_heads = 8
num_layers = 6
d_ff = 2048
dropout = 0.1
embedding_dim = 128

# Initialize the transformer decoder
decoder = TransformerDecoder(num_layers=num_layers, d_model=embedding_dim, num_heads=num_heads, d_ff=d_ff, dropout=dropout)

# Example inputs
tgt_seq_len = 22
src_seq_len = 54
batch_size = 10

# Target and memory
tgt = torch.randn(batch_size, tgt_seq_len, embedding_dim)
memory = torch.randn(batch_size, src_seq_len, embedding_dim)

# Forward pass through the decoder
decoder_output = decoder(tgt, memory)

print("Decoder output shape:", decoder_output.shape)  # Expected shape: (batch_size, tgt_seq_len, d_model)
Here, the decoder is taking in a batch of 10 solutions, each a sequence of 22 tokens (our moves plus utility tokens) represented as 128-value vectors, and it's outputting a matrix of the exact same dimensions. While the output is the same size as the input, it is much more abstract and meaning rich, and has allowed the representation of the Rubik's Cube to interact with the previous moves.

3) The Classification Head

The output of the decoder represents all of the moves the model thinks should be taken, but it does so as a big list of abstract vectors. The goal of the classification head is to turn each of these abstract vectors into a prediction of which token should be output (one of our 12 moves and 3 utility tokens). We do that by simply using a neural network on each vector to turn our 128-value vector into a vector of length 15. Then we turn those 15 values into probabilities (where bigger numbers mean higher probability) using an operation called softmax.

So, in other words, the prediction head turns all our abstract vectors into a prediction of which move should happen at each spot in the solution sequence.

The output of the decoder starts as a bunch of random vectors of length model_dim. After the linear projection, all those vectors are turned into vectors of length vocab_size. Then, softmax is used to turn that into a probabilistic prediction.

Here’s that code:

class ProjHead(nn.Module):
    def __init__(self, d_model=128, num_tokens=15):
        super(ProjHead, self).__init__()
        self.num_tokens = num_tokens

        # Linear layer to project from d_model to num_tokens
        self.fc = nn.Linear(d_model, num_tokens)

        # Softmax activation to convert logits into probabilities
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, logits):
        """
        Args:
            logits (torch.Tensor): Input tensor of shape (batch_size, seq_len, d_model).

        Returns:
            torch.Tensor: Output probabilities of shape (batch_size, seq_len, num_tokens).
        """
        # Project logits through a linear layer
        projected_logits = self.fc(logits)

        # Apply Softmax to convert to probabilities
        probabilities = self.softmax(projected_logits)

        return probabilities

# Initialize the module
logits_to_probs = ProjHead()

# Convert logits to probabilities
probabilities = logits_to_probs(decoder_output)

print("Probabilities shape:", probabilities.shape)  # Expected: (batch_size, seq_len, num_tokens)
print("Sum of probabilities for first token:", probabilities[0, 0].sum().item())  # Should be close to 1.0
Here we're passing the decoder output for a batch of 10 solution sequences, each with a sequence length of 22, through the prediction head. For each point in the sequence, the prediction head makes a prediction about which of the 15 possible tokens should come next. For a single output, the probabilities of all 15 predictions add up to 1 (or close to 1, due to rounding error).

4) The Model

We created:

  • a way to embed sticker colors and encode the location of stickers on a Rubik’s Cube as a list of vectors
  • a way to embed moves and encode their position in a solution sequence as a list of vectors
  • the encoder, which accepts the vectors representing a Rubik’s Cube and allows a model to convert that data into a dense and meaning rich representation
  • the decoder, which accepts previous moves and outputs predictions of future moves that should be taken based on the value of previous moves and the encoded representation of the Rubik’s Cube
  • a projection head which takes the output of the decoder and turns it into probabilistic predictions that certain moves should be made

Now we can put that all together to define the actual model:

class RubiksCubeTransformer(nn.Module):
    def __init__(self, layers_encoder=5, layers_decoder=5, d_model=128):
        super(RubiksCubeTransformer, self).__init__()

        #turns the tokens that go into the encoder and decoder into vectors
        self.encoder_embedding = EncoderEmbedding(embedding_dim=d_model)
        self.decoder_embedding = DecoderEmbedding(embedding_dim=d_model)

        #Defining the Encoder and Decoder
        self.encoder = TransformerEncoder(num_layers=layers_encoder, d_model=d_model, num_heads=4, d_ff=d_model*2, dropout=0.1)
        self.decoder = TransformerDecoder(num_layers=layers_decoder, d_model=d_model, num_heads=4, d_ff=d_model*2, dropout=0.1)

        #Defining the projection head to turn logits into probabilities
        self.projection_head = ProjHead(d_model=d_model, num_tokens=15)

    def forward(self, X, y):

        #embedding both inputs
        X_embed = self.encoder_embedding(X)
        y_embed = self.decoder_embedding(y)

        #encoding the Rubik's Cube representation
        X_encode = self.encoder(X_embed)

        #decoding the embedded previous moves, cross-attended with the Rubik's Cube encoding
        y_decode = self.decoder(y_embed, X_encode)

        #turning logits from the decoder into predictions
        return self.projection_head(y_decode)

model = RubiksCubeTransformer()
model(X[:10], y[:10]).shape
The model took in a batch of 10 Rubik's Cubes (X) and each of their solutions (y) and, for each one, predicted which of the 15 tokens should appear at each of the 22 output positions.

Now we can train this model on our synthetic dataset. Before we do, though, I’d like to take a step back and consider some of the costs and benefits of our approach.

The Subtlety of Training Strategy

Before we get into training our model, I'd like to reflect on the synthetic dataset we created in the previous section, and some of the implications of that dataset.

In this article we’re computing a perfectly random sequence of moves, using it to shuffle a Rubik’s Cube, then asking the model to predict the exact opposite of the shuffling sequence. This is great in theory, but there’s a practical problem; if we happened to generate a random sequence of moves like this:

<sos>, front clockwise, front counterclockwise, front clockwise, front counterclockwise, <eos>

Then, instead of training the model to predict that the Rubik's Cube is already solved (because it is; these random moves simply undo each other), we would be training the model to predict the same erroneous set of steps and then output that the Rubik's Cube is complete.

There are a lot of ways to get around this problem. I went with the simplest approach: ignoring it.
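
For context, one of those other approaches might be to strip out immediately-cancelling move pairs before reversing the scramble. The helper below is only a sketch of that idea and is not used in this article; it assumes each move is a (face, reverse_flag) tuple, as in the scramble and solve code, and the face labels in the example are hypothetical.

def remove_cancelling_pairs(moves):
    """Sketch (not used in this article): drop adjacent moves that undo each other,
    i.e. a rotation immediately followed by its reverse on the same face."""
    cleaned = []
    for move in moves:
        #if this move is the exact inverse of the previous one, cancel both
        if cleaned and cleaned[-1][0] == move[0] and cleaned[-1][1] != move[1]:
            cleaned.pop()
        else:
            cleaned.append(move)
    return cleaned

#example with hypothetical face labels: the first two moves cancel out entirely
print(remove_cancelling_pairs([('front', False), ('front', True), ('top', False)]))
#[('top', False)]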

Transformers, the style of model we're using, are known to perform remarkably well on natural language, which contains a lot of random noise and occasional nonsense. So, we already know transformers are good at learning to model complex sequences despite some poor-quality examples in the training set.

The hope, for us, is that silly moves will be much less common than productive moves and, as a result, the model will tend to learn productive decisions.

So, basically, if a transformer is good at learning language even if there are occasionally silly words in the training set, maybe it will be good at solving a Rubik's Cube even if the dataset it's trained on has occasionally silly moves.

It can be easy to be too hopeful about this type of assumption early on. It’s important to remember that, when constructing an AI model, the model is attempting to learn exactly what you’re training it to do. No more, no less. We can hope that the nature of the model will deal with quirks in our synthetic dataset elegantly, but we’ll only really know if we made the right call once we’ve gone ahead, trained, and then tested our model.

Generally speaking, I’ve found that the best modeling strategy is the one you think might work and can implement quickly. Iteration is a fact of life in complex ML problems.

So let’s give it a shot. We have a transformer and a dataset, let’s train this sucker.

Training the Model

Before I actually train the model, I’m doing a bit of setup work:

import os
import torch
from torch.utils.data import DataLoader, TensorDataset
import torch.optim as optim

# Define the checkpoint directory
checkpoint_dir = "/content/drive/My Drive/Colab Notebooks/Blogs/RubiksCubeCheckpoints"

# Initialize key variables
batch_losses = []
epoch_iter = 0  # Keeps track of total epochs trained

# User option: Start from scratch or resume from the last checkpoint
start_from_scratch = False  # Set this to True to start training from scratch

if start_from_scratch:
    print("Starting training from scratch...")

    # Check for GPU availability
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Initialize the model and move it to GPU
    model = RubiksCubeTransformer(layers_encoder=6, layers_decoder=3, d_model=64).to(device)

    # Move data to GPU
    X = X.to(device)
    y = y.to(device)

    # Define dataset and data loader
    batch_size = 16
    dataset = TensorDataset(X, y)
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Initialize optimizer
    optimizer = optim.Adam(model.parameters(), lr=1e-5)

else:
    print("Attempting to resume training from the latest checkpoint...")

    # Check for GPU availability
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Initialize the model and move it to GPU
    model = RubiksCubeTransformer(layers_encoder=6, layers_decoder=3, d_model=64).to(device)

    # Move data to GPU
    X = X.to(device)
    y = y.to(device)

    # Define dataset and data loader
    batch_size = 16
    dataset = TensorDataset(X, y)
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Load the latest checkpoint if available
    latest_checkpoint = None
    if os.path.exists(checkpoint_dir):
        print(os.listdir(checkpoint_dir))
        checkpoints = [f for f in os.listdir(checkpoint_dir) if f.endswith(".pt")]
        print(checkpoints)
        if checkpoints:
            checkpoints.sort(key=lambda x: int(x.split('_')[-1].split('.')[0]))  # Sort by epoch
            latest_checkpoint = os.path.join(checkpoint_dir, checkpoints[-1])

    if latest_checkpoint:
        print(f"Loading checkpoint: {latest_checkpoint}")
        checkpoint = torch.load(latest_checkpoint)

        # Load model and optimizer states
        model.load_state_dict(checkpoint['model_state_dict'])

        # Initialize optimizer and load its state
        optimizer = optim.Adam(model.parameters(), lr=1e-5)
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        # Set epoch_iter to the last epoch from the checkpoint
        epoch_iter = checkpoint['epoch']
        print(f"Resuming training from epoch {epoch_iter}")
    else:
        raise ValueError('No Checkpoint Found')

Transformers can take a while to train (I trained this model over the course of several days). Also, Google Colab has a tendency to log you out of a session if you've been away from the keyboard for too long. As a result, it was vital to save model checkpoints somewhere so that I could recover and resume training. This code loads the most recent checkpoint from my Google Drive before I continue; if I want to start fresh instead, setting start_from_scratch to True defines a new model.

This code also does some other quality-of-life things, like turning our training data into a DataLoader that takes care of creating batches and shuffling data across epochs, and making sure our model and data are both on the GPU.

There are a few things going on in the actual training code. Let's go through it section by section.

from google.colab import drive
import torch
import torch.nn as nn
import torch.nn.functional as F  # used below to one-hot encode the targets
from tqdm import tqdm
import os

# Printing out parameter count
print('model param count:')
print(count_parameters(model))

# Define loss function
criterion = nn.CrossEntropyLoss()

verbose = False

# Training loop
num_epochs = 100
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for batch in tqdm(data_loader):
        X_batch, y_batch = batch

        if verbose:
            print('\n==== Batch Examples ====')
            num_examples = 2
            print('Encoder Input')
            print(X_batch[:num_examples])
            print('Decoder Input')
            print(y_batch[:num_examples, :-1])
            print('Decoder target')
            print(y_batch[:num_examples, 1:])

        # Move batch data to GPU (if they're not already)
        X_batch = X_batch.to(device)
        y_batch = y_batch.to(device)

        optimizer.zero_grad()

        # Defining the input sequence to the model
        y_input = y_batch[:, :-1]

        # Forward pass
        y_pred = model(X_batch, y_input)

        # Transform target to one-hot encoding
        y_target = F.one_hot(y_batch[:, 1:], num_classes=15).float().to(device)

        # Compute loss
        loss = criterion(y_pred.view(-1, 15), y_target.view(-1, 15))
        running_loss += loss.item()
        batch_losses.append(loss.item())

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        if verbose:
            break
    if verbose:
        break

    epoch_iter += 1

    # Print epoch loss
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss / len(data_loader)}")

    # Save checkpoint every epoch
    if (epoch_iter + 1) % 1 == 0:
        checkpoint_path = os.path.join(checkpoint_dir, f"model_epoch_{epoch_iter+1}.pt")
        torch.save({
            'epoch': epoch_iter + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': running_loss / len(data_loader),
        }, checkpoint_path)
        print(f"Checkpoint saved at {checkpoint_path}")
An example of some of the output when training the model. Training was done over the course of several days.

I have this little block of code for my own debugging purposes. I defined a function that gives me the total number of trainable parameters in my model, which is useful for getting a rough idea of how large the model I'm training is.

print('model param count:')
print(count_parameters(model))
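
If you haven't defined a helper like this yourself, a minimal version might look like the following (a sketch; the notebook linked in the references has the actual implementation):

def count_parameters(model):
    #minimal sketch: sum the number of elements in every trainable tensor
    return sum(p.numel() for p in model.parameters() if p.requires_grad)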

Next I'm defining my "criterion", which is a fancy way of saying how I'll be judging how right or wrong the model is. Here I'm using cross entropy, a standard loss function that compares what the model predicted with what it should have predicted, and spits out a big number if the model was very wrong and a small number if the model was mostly right.

# Define loss function
criterion = nn.CrossEntropyLoss()
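
As a quick toy illustration of how cross entropy behaves (this isn't part of the training code; the vocabulary size of 15 is just carried over from our model): a prediction that puts most of its weight on the correct token gets a tiny loss, while a confident wrong prediction gets a big one.

import torch
import torch.nn as nn

toy_criterion = nn.CrossEntropyLoss()

#one prediction over a 15 token vocabulary; the correct token is index 4
target = torch.tensor([4])

confident_right = torch.full((1, 15), -5.0)
confident_right[0, 4] = 5.0
confident_wrong = torch.full((1, 15), -5.0)
confident_wrong[0, 7] = 5.0

print(toy_criterion(confident_right, target).item())  #tiny loss, close to 0
print(toy_criterion(confident_wrong, target).item())  #big loss, around 10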

Then we get into our actual training loop by first defining how many times we want to iterate through our dataset (the number of epochs), then looping over it. Here I'm using tqdm to render fancy little progress bars which let me observe how quickly training is going.

num_epochs = 100
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for batch in tqdm(data_loader):
        # training code...

The first thing we do in a training iteration is unpack the batch:

X_batch, y_batch = batch

Then we reset the gradients of the optimizer. I think the intricacies of training are out of scope for this article, but if you want to learn more, check out my beginner's introduction to AI and my article on gradients. For our purposes, we'll just say this line of code gets us ready to learn from a new batch of examples.

optimizer.zero_grad()

At this point y_batch represents the entire solution sequence. We want to turn that solution into two representations: what we put into the model, and the predictions we would like to get back. For instance, for this sequence:

<sos>, move 0, move 1, move 3, <eos>

we would want to put this sequence into our decoder:

<sos>, move 0, move 1, move 3

and hope to get back this sequence from the decoder output:

move 0, move 1, move 3, <eos>

In this block of code we're defining the input to the model, getting the model's prediction of which moves should be made, and constructing what we would have liked the model to predict:

# Defining the input sequence to the model
y_input = y_batch[:, :-1]

# Forward pass
y_pred = model(X_batch, y_input)

# Transform target to one-hot encoding
y_target = F.one_hot(y_batch[:, 1:], num_classes=15).float().to(device)

Then we're figuring out how wrong the model was, keeping track of that information to get an idea of whether the model is getting better, and updating our model to be ever so slightly less wrong on that particular batch of examples.

# Compute loss
loss = criterion(y_pred.view(-1, 15), y_target.view(-1, 15))
running_loss += loss.item()
batch_losses.append(loss.item())

# Backward pass and optimization
loss.backward()
optimizer.step()

We’re also doing some other quality of life stuff, like printing out statuses and saving model checkpoints.

# Print epoch loss
print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss / len(data_loader)}")

# Save checkpoint every epoch
if (epoch_iter + 1) % 1 == 0:
    checkpoint_path = os.path.join(checkpoint_dir, f"model_epoch_{epoch_iter+1}.pt")
    torch.save({
        'epoch': epoch_iter + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': running_loss / len(data_loader),
    }, checkpoint_path)
    print(f"Checkpoint saved at {checkpoint_path}")

And, tadah, we have defined a Rubik’s Cube solving model and trained it. Let’s see how good it is.

Testing our Model

Now that we've trained our model, we can go ahead and apply it to some newly shuffled Rubik's Cubes and see how it performs. This code creates a new Rubik's Cube, generates a sequence of move predictions, and executes those predicted moves up to the end-of-sequence (<eos>) token:

def predict_and_execute(cube, max_iter = 21):

    #turning rubiks cube into encoder input
    model_X = torch.tensor(tokenize(cube.cube)).to(torch.int32).unsqueeze(0).to(device)

    #input to decoder initialized as a vector of zeros, which is the start token
    model_y = torch.zeros(22).unsqueeze(0).to(torch.int32).to(device)
    current_index = 0

    #predicting the move sequence one token at a time (greedy decoding)
    while current_index < max_iter:
        y_pred = model(model_X, model_y)
        predicted_tokens = torch.argmax(y_pred, dim=-1)
        predicted_next_token = predicted_tokens[0,current_index]
        model_y[0,current_index+1] = predicted_next_token
        current_index+=1

    #converting into a list of moves
    predicted_tokens = model_y.cpu().numpy()[0]

    #executing move sequence
    moves = []
    for token in predicted_tokens:

        #start token
        if token == 0: continue

        #pad token
        if token == 3: continue

        #stop token
        if token == 1: break

        #move
        move = output_index_to_move(token-3) #accounting for start, pad, and end
        cube.rotate_face(move[0], reverse=move[1])
        moves.append(move)

    return moves

#we can define how many shuffles we'll use for this particular test
NUMBER_OF_SHUFFLES = 3

print(f'attempting to solve a Rubiks Cube with {NUMBER_OF_SHUFFLES} scrambling moves\n')

#creating a cube
cube = RubiksCube()

#shuffling (changing n will change the number of moves to scramble the cube)
moves = scramble(cube, n=NUMBER_OF_SHUFFLES)
print(f'moves to scramble the cube:\n{moves}')

# Visualize from opposite corners
fig = cube.visualize_opposite_corners(return_fig = True)
fig.set_size_inches(4, 2)
plt.show()

#trying to solve cube
print('\nsolving...')
solution = predict_and_execute(cube)
print(f'moves predicted by the model to solve the cube:\n{solution}')

# Visualize from opposite corners
fig = cube.visualize_opposite_corners(return_fig = True)
fig.set_size_inches(4, 2)
plt.show()

We can adjust NUMBER_OF_SHUFFLES to observe how well our model solves a few Rubik’s Cubes of various difficulty:

solving a 3 move scramble
solving a 4 move scramble
solving a 5 move scramble
solving a 6 move scramble

It’s doing a pretty good job, certainly much better than I can.

It does appear that the general assumption that the model would tend to ignore erroneous moves was at least somewhat correct. Here are a few examples of the model predicting better solutions than simply reversing the shuffle:

It's not all roses, though. The model appears to be somewhat inconsistent, even on cubes scrambled with only a few moves:

And for more complex scrambles (more than 7 moves or so), the model is pretty much hopeless.

There's one easy solution to this problem: just train for longer. Transformers benefit from a ton of training data and a ton of training time. I have no doubt that, given a few weeks of training, this model could learn to solve pretty much any Rubik's Cube you throw at it. You could also increase some of the model parameters to make the model better at understanding the intricacies of the problem.

There's another solution as well: use a better training strategy. Supervised learning is fine, but the erroneous-move issue adds a lot of noise to the training set, and that noise likely gets worse as the number of shuffles grows. Using the reverse of the shuffling sequence makes less and less sense the longer the sequence gets, meaning we would need a fair amount of training time to reach a highly performant model.

If you want a super performant Rubik's Cube solver right now, go ahead and throw the code described in this article at a GPU, wait a while, and see what happens. Personally, I'm more interested in exploring a better approach to modeling.

In a future article I’ll be fine-tuning this model with reinforcement learning which, hopefully, will allow the model to become much more robust very quickly. So, stay tuned.

Conclusion

In this article we created a model that can solve Rubik's Cubes from scratch by learning from a synthetic dataset of shuffled Rubik's Cubes. First, we created a way to define a Rubik's Cube such that we could shuffle and solve it, then we used that definition to generate a dataset of 2 million shuffled Rubik's Cubes and their solutions. We figured out how to tokenize, embed, and positionally encode both the Rubik's Cube and the sequence of moves, then created a transformer which could accept those inputs and output next-move predictions. We trained that transformer on our data and tested it on newly shuffled Rubik's Cubes. In the end, we got a promising first proof-of-concept model which we'll use for future exploration around this topic.

Join Intuitively and Exhaustively Explained

At IAEE you can find:

  • Long form content, like the article you just read
  • Thought pieces, based on my experience as a data scientist, engineering director, and entrepreneur
  • A discord community focused on learning AI
  • Regular Lectures and office hours
Join IAEE

References

The Code:

MLWritingAndResearch/RubiksCubeAI.ipynb at main · DanielWarfield1/MLWritingAndResearch

Relevant Articles:

AI for the Absolute Novice – Intuitively and Exhaustively Explained

Transformers – Intuitively and Exhaustively Explained

GPT – Intuitively and Exhaustively Explained

LoRA – Intuitively and Exhaustively Explained

Speculative Sampling – Intuitively and Exhaustively Explained

Self-Supervised Learning Using Projection Heads

CLIP, Intuitively and Exhaustively Explained

