
In this article we’ll make an AI model that can solve a Rubik’s Cube. We’ll define our own dataset, make a transformer style model that can learn based on that dataset, and use that model to solve new and randomly shuffled Rubik’s Cubes.
In tackling this problem we’ll discuss practical problems which come up frequently in Data Science, and the techniques data scientists use to solve those problems.
Who is this useful for? Anyone interested in achieving mastery of modern AI.
How advanced is this post? This post covers advanced modeling strategies intuitively, and is appropriate for readers of all levels.
Pre-requisites: There are no prerequisites for this article, though an understanding of transformer style models may be useful for some of the later, code heavy sections.
References: A link to the code and supporting resources can be found in the reference section at the end of this article.
Defining a Rubik’s Cube as a Modeling Problem
As you likely know, the Rubik’s Cube is a geometric game featuring a 3x3x3 cube with different colored segments on each face. These faces can be turned by 90 degrees in either direction to scramble or solve the Rubik’s Cube.

The goal of this article is to create a model which can accept a scrambled Rubik’s Cube and output a series of steps to solve said Rubik’s Cube.

There are a ton of ways this can be done. In this article we’ll be exploring one of the more straightforward approaches: supervised learning.
The Plan, From a High Level
A natural approach to making an AI model that solves Rubik’s Cubes might be to gather data from solutions of skilled players, then train a model to mimic those solutions. While using human data to train a model has its merits, it also has its drawbacks. Finding and licensing data from pro Rubik’s Cube players might be difficult, if not impossible, and hiring pro Rubik’s Cube players to create a custom dataset would be costly and time consuming. If you’re clever, all this work might be unnecessary. In this article, for instance, we’ll be using a completely synthetic dataset, meaning we’ll be generating all our training data automatically and not using any data from human players.
Essentially, we’ll frame the task of solving a Rubik’s Cube as trying to predict the reverse of the sequence that was used to scramble it. The idea is to randomly scramble millions of Rubik’s Cubes, reverse the sequence used to scramble them, then create a model which is tasked with predicting the reversed scrambling sequence.

This strategy falls under "Supervised Learning", which is the prototypical approach to training an AI model. When training a model with supervised learning you essentially say to the model "here’s an input (a scrambled Rubik’s Cube), predict an output (a list of steps), and I’ll train you based on how well your response aligns with what I expected (the reverse of the scrambling sequence)".
There are other forms of learning, like contrastive learning, semi-supervised learning, and reinforcement learning, but in this article we’ll stick with the basics. If you’re curious about some of those approaches, I provided some links in the reference section at the end of the article.
So, we have a high-level plan: shuffle a bunch of Rubik’s Cubes and train a model to predict the opposite of the sequence used to scramble them. Before we get into the intricacies of defining a custom Transformer style model to work with this data, let’s review the idea of transformer style models in general.
A Brief Introduction to the Transformer
This section will briefly review transformer style models. This is, essentially, a condensed version of my more comprehensive article on the subject:
In its most basic sense, the transformer is an encoder-decoder style model.

The encoder converts an input into an abstract representation which the decoder uses to iteratively generate output.

Both the encoder and the decoder employ an abstract representation of text which is created using an operation called multi-headed self-attention.

There are a few steps which multi-headed self-attention employs to construct this abstract representation. In a nutshell, a dense neural network constructs three representations, usually referred to as the query, key, and value, based on the input.

The query and key are multiplied together. Thus, some representation of every word is combined with a representation of every other word.

The value is then multiplied by this abstract combination of the query and key, constructing the final output of multi-headed self-attention.

The encoder uses multi-headed self-attention to create abstract representations of the input, and the decoder uses multi-headed self-attention to create abstract representations of the output.
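If you want a slightly more concrete picture, the core operation a single attention head performs looks something like the following sketch. This is a simplification: the multi-head splitting and output projections are glossed over, and the scaling and softmax are refinements the summary above doesn’t mention.

import torch
import torch.nn.functional as F

def single_head_self_attention(x, w_q, w_k, w_v):
    """A minimal single-head self-attention sketch.
    x: (seq_len, d_model); w_q, w_k, w_v: (d_model, d_head) projection matrices."""
    q = x @ w_q  # queries
    k = x @ w_k  # keys
    v = x @ w_v  # values
    scores = (q @ k.T) / (k.shape[-1] ** 0.5)  # every position scored against every other position
    weights = F.softmax(scores, dim=-1)        # scores turned into attention weights
    return weights @ v                         # weighted sum of the values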


That was a super quick rundown on transformers. I tried to cover the high points without getting too far into the weeds; feel free to refer to my article on transformers for more information.
While the original transformer was created for English to French translation, in an abstract way the process of solving a Rubik’s Cube is somewhat similar. We have some input (a shuffled Rubik’s Cube, vs a French sentence), and we need to predict some sequence based on that input (a sequence of moves, vs a sequence of English words).

We don’t need to make any changes to the fundamental structure of the transformer to get it to solve a Rubik’s Cube. All we have to do is properly format the Rubik’s Cube and moves into a representation the transformer can understand. We’ll cover that in the following sections.
Defining The Cube and Moves
Originally, I thought to define the Rubik’s Cube as a 3x3x3 matrix of segments, each of which has some number of faces which have some color.

This is totally possible, but I’m a data scientist and my 3D spatial programming hasn’t seen the light of day in a hot minute. After some reflection, I decided on a creative and perhaps more elegant approach: representing the Rubik’s Cube as a 5x5x5 tensor, rather than a 3x3x3 cube of segments.

The essential idea is that we, technically speaking, don’t really care about the cube. Rather, we care about the stickers and where they are relative to each other. So, instead of having a 3x3x3 data structure consisting of complicated segments that have to obey complicated rules, we can simply put that cube within a 5x5x5 grid and keep track of where the sticker colors are within this space.
When we rotate a "face", we simply need to rotate all the spaces in the 5x5x5 grid which correspond to the stickers that would be on that face and corresponding edges.

There are 12 fundamental rotations one can apply to a Rubik’s Cube. We can rotate the front, back, top, bottom, left, and right face (6 faces), and we can rotate each of those 90 degrees clockwise or counterclockwise.
When we scramble a cube we apply some number of these rotations, and then to solve the Rubik’s Cube one can simply reverse the order and direction of the moves.
Here’s a class that defines a Rubik’s Cube and its moves, as well as a neat little visualization:
"""Defining the Rubik's Cube
"""
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
from matplotlib.patches import Polygon
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
class RubiksCube:
def __init__(self):
# Initialize a 3D tensor to represent the Rubik's Cube
self.cube = np.empty((5, 5, 5), dtype='U10')
self.cube[:, :, :] = ''
# Initialize sticker colors
self.cube[0, 1:-1, 1:-1] = 'w' # Top (white)
self.cube[1:-1, 0, 1:-1] = 'g' # Front (green)
self.cube[1:-1, 1:-1, 0] = 'r' # Left (red)
self.cube[-1, 1:-1, 1:-1] = 'y' # Bottom (yellow)
self.cube[1:-1, -1, 1:-1] = 'b' # Back (blue)
self.cube[1:-1, 1:-1, -1] = 'o' # Right (orange)
def print_cube(self):
print(self.cube)
def rotate_face(self, face, reverse=False):
"""
Rotates a given face of the cube 90 degrees.
Parameters:
face (str): One of ['top', 'front', 'left', 'bottom', 'back', 'right']
reverse (bool): if the rotation should be reversed
"""
# maps a face to the section of the tensor which needs to be rotated
rot_map = {
'top': (slice(0, 2), slice(0, 5), slice(0, 5)),
'left': (slice(0, 5), slice(0, 2), slice(0, 5)),
'front': (slice(0, 5), slice(0, 5), slice(0, 2)),
'bottom': (slice(3, 5), slice(0, 5), slice(0, 5)),
'right': (slice(0, 5), slice(3, 5), slice(0, 5)),
'back': (slice(0, 5), slice(0, 5), slice(3, 5))
}
# getting all of the stickers that will be rotating
rotating_slice = self.cube[rot_map[face]]
# getting the axis of rotation
axis_of_rotation = np.argmin(rotating_slice.shape)
# rotating about axis of rotation
axes_of_non_rotation = [0,1,2]
axes_of_non_rotation.remove(axis_of_rotation)
axes_of_non_rotation = tuple(axes_of_non_rotation)
direction = 1 if reverse else -1
rotated_slice = np.rot90(rotating_slice, k=direction, axes=axes_of_non_rotation)
# overwriting cube
self.cube[rot_map[face]] = rotated_slice
def _rotate_cube_180(self):
"""
Rotate the entire cube 180 degrees by flipping and transposing
this is used for visualization
"""
# Rotate the cube 180 degrees
rotated_cube = np.rot90(self.cube, k=2, axes=(0,1))
rotated_cube = np.rot90(rotated_cube, k=1, axes=(1,2))
return rotated_cube
def visualize_opposite_corners(self):
"""
Visualize the Rubik's Cube from two truly opposite corners
"""
# Create a new figure with two subplots
fig = plt.figure(figsize=(20, 10))
# Color mapping
color_map = {
'w': 'white',
'g': 'green',
'r': 'red',
'y': 'yellow',
'b': 'blue',
'o': 'orange'
}
# Cubes to visualize: original and 180-degree rotated
cubes_to_render = [
{
'cube_data': self.cube,
'title': 'View 1'
},
{
'cube_data': self._rotate_cube_180(),
'title': 'View 2'
}
]
# Create subplots for each view
for i, cube_info in enumerate(cubes_to_render, 1):
ax = fig.add_subplot(1, 2, i, projection='3d')
ax.view_init(elev=-150, azim=45, vertical_axis='x')
# Iterate through the cube and plot non-empty stickers
cube_data = cube_info['cube_data']
for x in range(cube_data.shape[0]):
for y in range(cube_data.shape[1]):
for z in range(cube_data.shape[2]):
# Only plot if there's a color
if cube_data[x, y, z] != '':
color = color_map.get(cube_data[x, y, z], 'gray')
# Define the 8 vertices of the small cube
vertices = [
[x, y, z], [x+1, y, z],
[x+1, y+1, z], [x, y+1, z],
[x, y, z+1], [x+1, y, z+1],
[x+1, y+1, z+1], [x, y+1, z+1]
]
# Define the faces of the cube
faces = [
[vertices[0], vertices[1], vertices[2], vertices[3]], # bottom
[vertices[4], vertices[5], vertices[6], vertices[7]], # top
[vertices[0], vertices[1], vertices[5], vertices[4]], # front
[vertices[2], vertices[3], vertices[7], vertices[6]], # back
[vertices[1], vertices[2], vertices[6], vertices[5]], # right
[vertices[0], vertices[3], vertices[7], vertices[4]] # left
]
# Plot each face
for face in faces:
poly = Poly3DCollection([face], alpha=1, edgecolor='black')
poly.set_color(color)
poly.set_edgecolor('black')
ax.add_collection3d(poly)
# Set axis limits and equal aspect ratio
ax.set_xlim(0, 5)
ax.set_ylim(0, 5)
ax.set_zlim(0, 5)
ax.set_box_aspect((1, 1, 1))
# Remove axis labels and ticks
ax.set_xticks([])
ax.set_yticks([])
ax.set_zticks([])
ax.set_xlabel('')
ax.set_ylabel('')
ax.set_zlabel('')
ax.set_title(cube_info['title'])
plt.tight_layout()
plt.show()
# Example usage
cube = RubiksCube()
# Rotate a face to show some variation
cube.rotate_face('bottom', reverse=False)
# Visualize from opposite corners
cube.visualize_opposite_corners()

We can scramble our Rubik’s Cube by simply performing a few random moves.
"""Scrambling a Rubik's Cube
"""
from itertools import product
import random
#Defining Possible Moves
faces = ['top', 'left', 'front', 'bottom', 'right', 'back']
possible_moves = tuple(product(faces, [False, True]))
def scramble(cube, n=20):
moves = []
for _ in range(n):
#selecting a random move
selected_move = random.choice(possible_moves)
moves.append(selected_move)
# Rotate a face to show some variation
cube.rotate_face(selected_move[0], reverse=selected_move[1])
return moves
#creating a cube
cube = RubiksCube()
#shuffling
moves = scramble(cube)
print(moves)
# Visualize from opposite corners
cube.visualize_opposite_corners()

And, to solve that Rubik’s Cube we can simply reverse the order of the moves, and reverse the direction in which they rotate.
"""unscrambling by reversing moves and direction
"""
#reversing order of moves
moves.reverse()
for i in range(20):
#selecting a random move
selected_move = moves[i]
# Rotate a face in the opposite direction
cube.rotate_face(selected_move[0], reverse=not selected_move[1])
# Visualize from opposite corners
cube.visualize_opposite_corners()

Using this code, we can generate a synthetic dataset consisting of shuffled Rubik’s Cubes and their solutions.
"""Parallelized code that generates 2M scrambled Rubik's Cubes,
and keeps track of the cube (X) and the moves to unscramble it (y)
"""
import random
from multiprocessing import Pool, cpu_count
from functools import partial
def generate_sample(max_scramble, _):
"""
Generates a single sample (X, y) for the Rubik's Cube task.
"""
num_moves = random.randint(1, max_scramble)
# Initializing a cube and scrambling it
cube = RubiksCube()
moves = scramble(cube, n=num_moves)
# Reversing moves, which is the solution
moves.reverse()
moves = [(m[0], not m[1]) for m in moves]
# Turning into modeling data
x = tokenize(cube.cube)
y = [0] + [move_to_output_index(m) + 3 for m in moves] + [1]
# Padding with 2s so the sequence length is always 22
y.extend([2] * (22 - len(y)))
return x, y
def parallel_generate_samples(num_samples, max_scramble, num_workers=None):
"""
Parallelizes the generation of Rubik's Cube samples.
"""
num_workers = num_workers or cpu_count()
# Use functools.partial to "lock in" the max_scramble parameter
generate_sample_partial = partial(generate_sample, max_scramble)
with Pool(processes=num_workers) as pool:
results = pool.map(generate_sample_partial, range(num_samples))
# Unpack results into X and y
X, y = zip(*results)
return list(X), list(y)
num_samples = 2_000_000
max_scramble = 20
# Generate data in parallel
X, y = parallel_generate_samples(num_samples, max_scramble)
In this code, X is the thing we’ll be passing into the model (the shuffled Rubik’s Cube) and y will be the thing we try to predict (the sequence of operations to solve it).

I’m using two helper functions, tokenize and move_to_output_index, to help me turn the Rubik’s Cube and list of moves into a more friendly representation for modeling. I don’t think it’s necessary to go over the implementation line by line (feel free to refer to the code, and a rough sketch of both is shown a little further below), but from a high level:
- The tokenize function accepts a 5x5x5 tensor consisting of sticker colors and empty spaces and outputs a 54×4 tensor. This 54×4 tensor has a vector for each of the 54 stickers in the Rubik’s Cube, where each vector contains the (color, x position, y position, z position) of a particular sticker. It ignores all the empty spaces in the 5x5x5 tensor.
- The move_to_output_index function simply turns a move, like (top, clockwise), into a number. All 12 moves are assigned a unique number. The reason we add the number 3 in the code will become apparent when we discuss the input and output of the decoder portion of the transformer model.

So, the input to the transformer is a list of 54 vectors consisting of (color, x position, y position, z position), and the output of the transformer is a list of numbers where each number corresponds to one of 12 moves.

This slight re-formatting of the data has little conceptual impact but will be practically handy when we apply a model to this data.
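For concreteness, here’s a rough sketch of what these helpers (and the output_index_to_move helper used later, during inference) might look like. The exact orderings and names here are assumptions on my part; the real implementations live in the linked notebook.

import numpy as np

# a sketch of the helper functions (assumed details; see the notebook for the real versions)
color_to_index = {'w': 0, 'g': 1, 'r': 2, 'y': 3, 'b': 4, 'o': 5}
faces = ['top', 'left', 'front', 'bottom', 'right', 'back']

def tokenize(cube_tensor):
    """Convert the 5x5x5 sticker tensor into a 54x4 array of (color, x, y, z)."""
    tokens = []
    for x in range(5):
        for y in range(5):
            for z in range(5):
                sticker = cube_tensor[x, y, z]
                if sticker != '':  # skip the empty spaces in the 5x5x5 grid
                    tokens.append([color_to_index[sticker], x, y, z])
    return np.array(tokens)  # shape (54, 4)

def move_to_output_index(move):
    """Map a (face, reverse) tuple to a unique integer from 0 to 11."""
    return faces.index(move[0]) * 2 + int(move[1])

def output_index_to_move(index):
    """The inverse mapping: an integer from 0 to 11 back to a (face, reverse) tuple."""
    return (faces[index // 2], bool(index % 2))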
Now that we’ve constructed our dataset of shuffled Rubik’s Cubes and sequences of moves to solve them, we can work towards feeding that data into a transformer. This is done through a process called embedding.
Embedding
When creating a transformer in a natural language context you first create a vocabulary consisting of the words (or pieces of words) that your model will understand. Then, you turn all the words in the model’s vocabulary into a vector. When the model receives an input sequence of words, it can "think" about the sequence by doing math with the vectors that represent the words.

We can do something similar with our Rubik’s Cube: we can think of its "vocabulary" as six tokens, one for each sticker color. Then we can assign each of those colors some random vector that represents it.

When we want to give a Rubik’s Cube to the input of our encoder, we can iterate through all the stickers in the Rubik’s Cube and, every time there’s a sticker, we can look up the vector that corresponds to that color and add it to a sequence.

Transformers have become widely popular in a variety of applications, from computer vision, to audio synthesis, to video generation. It turns out "take whatever data you have, represent it as a list of vectors, then throw it into a transformer" is a pretty good general strategy.
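In PyTorch terms, that "vocabulary of six colors" is just a small lookup table of learnable vectors. A minimal sketch of the idea (the full embedding module for our model comes a little later):

import torch
import torch.nn as nn

# six learnable vectors, one per sticker color; 128 is the length of each vector
color_embedding = nn.Embedding(num_embeddings=6, embedding_dim=128)

# if 'w' -> 0, 'g' -> 1, ..., a handful of stickers becomes a sequence of indices...
sticker_colors = torch.tensor([0, 1, 1, 5])

# ...and the lookup turns them into a (4, 128) sequence of vectors
sticker_vectors = color_embedding(sticker_colors)
print(sticker_vectors.shape)  # torch.Size([4, 128])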
Before we put this sequence of vectors into a transformer, though, we need one more piece of information: position.
Positional Encoding
Every time we convert our Rubik’s Cube into a sequence of vectors, whether we’re training our model or trying to predict a solution, we’ll be converting the stickers of the Rubik’s Cube to vectors in the same order. That means each location in the embedded sequence will always correspond to the same location in the Rubik’s Cube.

Some modeling strategies, like convolutional and dense networks, are good at learning to leverage this consistency. They can learn "this location corresponds to this sticker, that location corresponds to that sticker", and thus it’s not necessary to add any additional information about position into the input.
Transformer style models, on the other hand, are famously prone to losing track of the order of the input. To create their abstract and meaning rich representations, they mix and mangle the input so much that positional information (as in "this vector came before that vector") is lost very quickly. As a result, when using a transformer, it’s customary to use a positional encoding.
The idea is to add some information about where each sticker was in our Rubik’s Cube to the vector which represents that sticker’s color. This will allow the model to inject explicit information about location into the value which represents each sticker, meaning it can reason about that sticker’s position as well as its color.

We’ll be using a lookup table very similar to the approach described in the previous section. In the previous section we assigned a random vector to each color.

To encode position, we’ll also assign a random vector to each X, Y, and Z position in the 5x5x5 space of vectors.


For each sticker, we can add the vectors for where that sticker sits along the X, Y, and Z axes to the vector that represents the sticker’s color and, as a result, represent both the color and position of each sticker using a single vector.

Here’s the implementation for a model which can embed the sticker colors of a Rubik’s Cube and apply a positional encoding:
import torch
import torch.nn as nn

class EncoderEmbedding(nn.Module):
    def __init__(self, vocab_size=6, pos_i_size=5, pos_j_size=5, pos_k_size=5, embedding_dim=128):
        super(EncoderEmbedding, self).__init__()

        # Learnable embeddings for each component
        self.vocab_embedding = nn.Embedding(vocab_size, embedding_dim)
        self.pos_i_embedding = nn.Embedding(pos_i_size, embedding_dim)
        self.pos_j_embedding = nn.Embedding(pos_j_size, embedding_dim)
        self.pos_k_embedding = nn.Embedding(pos_k_size, embedding_dim)

    def forward(self, X):
        """
        Args:
            X (torch.Tensor): Input tensor of shape (batch_size, seq_len, 4)
                where X[..., 0] = vocab indices (0-5)
                      X[..., 1] = position i (0-4)
                      X[..., 2] = position j (0-4)
                      X[..., 3] = position k (0-4)
        Returns:
            torch.Tensor: Output tensor of shape (batch_size, seq_len, embedding_dim)
        """
        # Split the input into components
        vocab_idx = X[..., 0]
        pos_i_idx = X[..., 1]
        pos_j_idx = X[..., 2]
        pos_k_idx = X[..., 3]

        # Look up embeddings
        vocab_embed = self.vocab_embedding(vocab_idx)
        pos_i_embed = self.pos_i_embedding(pos_i_idx)
        pos_j_embed = self.pos_j_embedding(pos_j_idx)
        pos_k_embed = self.pos_k_embedding(pos_k_idx)

        # Sum the embeddings
        final_embedding = vocab_embed + pos_i_embed + pos_j_embed + pos_k_embed
        return final_embedding

# assumed step: the dataset needs to be integer tensors before it can go through
# the embedding layers (and, later, the DataLoader)
X = torch.tensor(np.array(X), dtype=torch.long)
y = torch.tensor(np.array(y), dtype=torch.long)

embedding_dim = 128

# Initialize the input embedding module
encoder_embedding = EncoderEmbedding(embedding_dim=embedding_dim)

# Get the final embeddings
embedded_encoder_input = encoder_embedding(X[:10])
print("Input shape:", X[:10].shape)
print("Output shape:", embedded_encoder_input.shape)

encoder_embedding looks up the four vectors that represent the color and positions of each sticker and adds them all together. Thus, the output is a batch of 10 Rubik’s Cubes, each now represented as 54 vectors, where each vector contains values that describe the color and position of a sticker.

The 128 is the model_dimension, an arbitrary number defining how long the vectors which represent an input element are. Large transformers (like OpenAI’s GPT models) use vectors which are many hundreds if not thousands of values long. With a larger model dimension, more complex representations of the input can be made, at the cost of computational load. I opted for a relatively small model dimension, but one still large enough to allow the model to think through the problem. This parameter could certainly be played with.

When we create a new model, we’ll be using completely random vectors to represent both the sticker colors and where those stickers are located. Naturally, this random information will probably be really hard for the model to understand at first. The idea is that, throughout the training process, these random vectors will update so that the model learns to create vectors for sticker color and position which it understands.

So, we’ve turned the Rubik’s Cube into a list of vectors which a transformer can understand. We’ll also need to perform a similar process on the sequence of moves we want the model to output but, before we do, I’d like to discuss some intricacies of the decoder.

The Input and Output of the Decoder
Recall that when a transformer outputs a sequence, it does so "autoregressively", meaning when you put some sequence into the decoder, the decoder will output a prediction as to what it thinks the next token should be. That new token can then be fed back into the input of the decoder, allowing the transformer to generate a sequence one token at a time.
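To make that loop concrete, here’s a conceptual sketch. model_step is just a stand-in for "embed the sequence so far, run the model, and take the most likely next token"; the real inference code appears near the end of the article.

def generate(model_step, start_token, end_token, max_len):
    """Conceptual sketch of autoregressive generation."""
    sequence = [start_token]
    for _ in range(max_len):
        next_token = model_step(sequence)  # predict the next token from everything generated so far
        sequence.append(next_token)
        if next_token == end_token:        # stop once the model says it's done
            break
    return sequence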

One of the defining characteristics of a transformer is the way it’s trained. Older styles of models would typically train on one token at a time. You would feed a sequence into the model, predict a token, then update the model based on whether it was right or wrong. This is an incredibly slow and computationally expensive process, and it severely limited older styles of models when applied to sequences.
When training a transformer, on the other hand, you input the entire sequence you want, and the transformer predicts the next token at every position of that sequence as if future tokens did not exist.

So, it predicts the next token for the first spot, the next token for the second spot, the next token for the third spot, etc. simultaneously.
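Under the hood, this "as if future tokens did not exist" behavior is implemented with a causal mask, which hides every position’s future from the attention operation (this is exactly what the decoder implementation later in this article does with torch.triu). A tiny sketch of the mask itself:

import torch

seq_len = 5
# True marks the positions each token is NOT allowed to look at (its own future)
causal_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
print(causal_mask)
# each row can "see" itself and everything before it, but nothing after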
I talk about how this works in my article on transformers more in depth, and explore how this quirk of transformers can be used to interesting effect in my article on speculative sampling. For now, though, we know enough of the high-level theory to discuss implementing the embedding and positional encoding for the decoder.
Tokenizing, Embedding, and Positionally Encoding the Solution Sequence
As discussed in the previous section, when we’re training our model we need to input the sequence we want into the decoder, then the decoder will make predictions of all the next moves in the sequence as if future moves didn’t exist.

Just like the input to the encoder, the input to the decoder will take the form of a sequence of vectors.

The process of creating those vectors is similar to the process in which we embedded and positionally encoded the Rubik’s Cube. There are 12 possible moves (6 faces in two directions), meaning the entire move vocabulary can be represented with a lookup table of 12 vectors, one per move.

These moves can be positionally encoded by creating a vector for each location in the sequence.

Apparently, the maximum number of moves needed to solve any Rubik’s Cube is 20 (a result known as "God’s number"; don’t ask me how they figured that out), so we can assume the output sequence of our model will have a maximum length of 20, plus space for two "utility tokens".
A utility token is a special token that doesn’t matter in terms of final output but is useful from a modeling perspective. For instance, it can be useful for a model to have a way to say it’s done generating output. This is a common token called the "end of sequence" token, often abbreviated as <EOS>.
Also, recall how the decoder predicts all next tokens based on an input token, meaning we need to input some token to get the first prediction. It’s common practice to prepend each sequence with a "start of sequence" (<SOS>) token, which holds the space for the first prediction from the model.
Another token we’ll be using is a pad (<PAD>) token. Basically, all the math under the hood of the transformer uses matrices, which require some uniform shape. So, if we have a few sequences of moves that are short, and a few sequences that are long, they all need to fit within the same matrix. We can do that by "padding" all the short sequences until they’re the same length as the longest sequence.

So, the "vocabulary" of the decoder will be our 12 possible moves, plus the start-of-sequence (<SOS>), end-of-sequence (<EOS>), and pad (<PAD>) tokens. The total sequence length will be 22, because the maximum number of moves needed to solve any Rubik’s Cube is 20 (again, God’s number), and we need to make room for an <SOS> and <EOS> token on even the longest sequences.
Now that our tokens are thought out, and we know how long the sequence will be, we can just initialize 15 random vectors for the token embedding, and 22 random vectors for each location in the sequence. When we take in some sequence, either to train or to make a prediction, we can use these vectors to represent both the value and position of all of the moves.

Let’s go ahead and implement the decoder embedding. First of all, our data already has our utility tokens built in. Recall we used this code to generate the y portion of our dataset:

y = [0] + [move_to_output_index(m) + 3 for m in moves] + [1]
# Padding with 2s so the sequence length is always 22
y.extend([2] * (22 - len(y)))

Here we’re converting each of our moves to an integer from 3–14 with the expression move_to_output_index(m) + 3 (which adds 3 to the numbers 0–11 which represent our possible moves). It then adds a 0 to the beginning and a 1 to the end of the sequence, and appends 2’s to the end until the total sequence is of length 22.
Thus:
- 0 represents the start of sequence token (<sos>)
- 1 represents the end of sequence token (<eos>)
- 2 represents the pad token (<pad>)
- 3–14 represent our possible moves
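So, for example, a cube that was scrambled (and can therefore be solved) with two moves would get a target sequence shaped like this (the particular move numbers here are made up for illustration):

# <sos>, two move tokens, <eos>, then <pad> out to a fixed length of 22
y_example = [0, 7, 12, 1] + [2] * 18
print(len(y_example))  # 22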

So, we can implement embedding and positional encoding for our sequence of moves as follows:
class DecoderEmbedding(nn.Module):
    def __init__(self, vocab_size=15, pos_size=22, embedding_dim=128):
        super(DecoderEmbedding, self).__init__()

        # Learnable embeddings for each component
        self.vocab_embedding = nn.Embedding(vocab_size, embedding_dim)
        self.pos_embedding = nn.Embedding(pos_size, embedding_dim)

    def forward(self, X):
        """
        Args:
            X (torch.Tensor): Input tensor of shape (batch_size, seq_len), where each element
                corresponds to a token index.
        Returns:
            torch.Tensor: Output tensor of shape (batch_size, seq_len, embedding_dim)
        """
        # Token embeddings (based on vocab indices)
        vocab_embed = self.vocab_embedding(X)

        # Generate position indices based on input shape
        batch_size, seq_len = X.shape
        position_indices = torch.arange(seq_len, device=X.device).unsqueeze(0).expand(batch_size, -1)

        # Position embeddings
        pos_embedding = self.pos_embedding(position_indices)

        # Sum the embeddings
        final_embedding = vocab_embed + pos_embedding
        return final_embedding

embedding_dim = 128

# Initialize the input embedding module
decoder_embedding = DecoderEmbedding(embedding_dim=embedding_dim)

# Get the final embeddings
embedded_decoder_input = decoder_embedding(y[:10])
print("Input shape:", y[:10].shape)
print("Output shape:", embedded_decoder_input.shape)

The model_dim of 128 is the same as the one used in encoding the Rubik’s Cube, which we defined in a previous section.
Alright, we’ve figured out how to encode both the Rubik’s Cube, and sequences of moves to solve them, in a way the transformer can understand (a big list of vectors). Now we can actually get into building the transformer.
Implementing the Transformer
The main point of this article isn’t really the model, but rather the thought process around making modeling decisions. We’ve really done all the heavy lifting already. By turning our moves into vectors which a transformer understands, we can use the same standard transformer used in countless other applications.
I’ve covered the core ideas of the transformer in many different articles at this point. Still, this wouldn’t be "exhaustive" if I didn’t cover implementing the transformer, so let’s do it. This will be a fairly brief pass over the process; feel free to dig into some of the linked articles in the reference section for a more in-depth understanding.
1) The Encoder
We already made the embedding and positional encoding that turns a Rubik’s Cube into a list of vectors, so now we need to implement the encoder portion of the model which is tasked with thinking about that representation and turning it into an abstract but meaning rich representation.
import torch
import torch.nn as nn

# Define the Transformer Encoder
class TransformerEncoder(nn.Module):
    def __init__(self, num_layers, d_model, num_heads, d_ff, dropout=0.1):
        super(TransformerEncoder, self).__init__()

        # Define a single transformer encoder layer
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=num_heads,
            dim_feedforward=d_ff,
            dropout=dropout,
            batch_first=True
        )

        # Stack multiple encoder layers
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

    def forward(self, src):
        """
        Args:
            src (torch.Tensor): Input tensor of shape (batch_size, seq_len, d_model).
        Returns:
            torch.Tensor: Output tensor of shape (batch_size, seq_len, d_model).
        """
        return self.encoder(src)

# Example usage
num_heads = 8
num_layers = 6
d_ff = 2048
dropout = 0.1

# Initialize the transformer encoder
encoder = TransformerEncoder(num_layers=num_layers, d_model=embedding_dim, num_heads=num_heads, d_ff=d_ff, dropout=dropout)

# Forward pass
encoder_output = encoder(embedded_encoder_input)
print("Encoder output shape:", encoder_output.shape)  # Should be (batch_size, seq_len, d_model)

PyTorch already has an implementation of the encoder block, so we just used that.
- num_heads describes how many heads are used in each multi-headed self-attention block (see my article on transformers if you want to understand more)
- num_layers describes how many encoder blocks are used (see my article on transformers if you want to understand more)
- d_ff is how large the feed-forward network in the transformer is. This is typically much larger than the model_dim, as it allows the feed-forward network to look at each vector in the model, expand it into a larger representation, then shrink it back down to the original size based on that expanded information.
- dropout is a regularizing parameter which randomly hides certain values in the model. dropout is a common trick that helps AI models learn trends in data without simply memorizing individual examples in the dataset.
2) The Decoder
The decoder takes the embedded representation of the move sequence, and combines it with the output of the encoder to make predictions about which move should come next.
So, let’s build it:
import torch
import torch.nn as nn

class TransformerDecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super(TransformerDecoderLayer, self).__init__()

        # Masked Multi-Head Self-Attention
        self.self_attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=num_heads, dropout=dropout, batch_first=True)
        self.self_attn_norm = nn.LayerNorm(d_model)
        self.self_attn_dropout = nn.Dropout(dropout)

        # Multi-Head Cross-Attention
        self.cross_attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=num_heads, dropout=dropout, batch_first=True)
        self.cross_attn_norm = nn.LayerNorm(d_model)
        self.cross_attn_dropout = nn.Dropout(dropout)

        # Point-wise Feed Forward Network
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout)
        )
        self.ffn_norm = nn.LayerNorm(d_model)
        self.ffn_dropout = nn.Dropout(dropout)

    def forward(self, tgt, memory):
        """
        Args:
            tgt (torch.Tensor): Target sequence of shape (batch_size, tgt_seq_len, d_model).
            memory (torch.Tensor): Encoder output of shape (batch_size, src_seq_len, d_model).
        Returns:
            torch.Tensor: Output tensor of shape (batch_size, tgt_seq_len, d_model).
        """
        tgt_len = tgt.size(1)

        # Generate causal mask for self-attention (causal masking)
        causal_mask = torch.triu(torch.ones(tgt_len, tgt_len, device=tgt.device), diagonal=1).to(torch.bool)

        # Masked Multi-Head Self-Attention
        self_attn_out, _ = self.self_attn(
            tgt, tgt, tgt,
            attn_mask=causal_mask,
        )
        tgt = self.self_attn_norm(tgt + self.self_attn_dropout(self_attn_out))

        # Multi-Head Cross-Attention
        cross_attn_out, _ = self.cross_attn(
            tgt, memory, memory,
        )
        tgt = self.cross_attn_norm(tgt + self.cross_attn_dropout(cross_attn_out))

        # Feed Forward Network
        ffn_out = self.ffn(tgt)
        tgt = self.ffn_norm(tgt + self.ffn_dropout(ffn_out))

        return tgt

class TransformerDecoder(nn.Module):
    def __init__(self, num_layers, d_model, num_heads, d_ff, dropout=0.1):
        super(TransformerDecoder, self).__init__()
        self.layers = nn.ModuleList([
            TransformerDecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(d_model)

    def forward(self, tgt, memory):
        """
        Args:
            tgt (torch.Tensor): Target sequence of shape (batch_size, tgt_seq_len, d_model).
            memory (torch.Tensor): Encoder output of shape (batch_size, src_seq_len, d_model).
        Returns:
            torch.Tensor: Output tensor of shape (batch_size, tgt_seq_len, d_model).
        """
        for layer in self.layers:
            tgt = layer(tgt, memory)
        return self.norm(tgt)

# Example usage
num_heads = 8
num_layers = 6
d_ff = 2048
dropout = 0.1
embedding_dim = 128

# Initialize the transformer decoder
decoder = TransformerDecoder(num_layers=num_layers, d_model=embedding_dim, num_heads=num_heads, d_ff=d_ff, dropout=dropout)

# Example inputs
tgt_seq_len = 22
src_seq_len = 54
batch_size = 10

# Target and memory
tgt = torch.randn(batch_size, tgt_seq_len, embedding_dim)
memory = torch.randn(batch_size, src_seq_len, embedding_dim)

# Forward pass through the decoder
decoder_output = decoder(tgt, memory)
print("Decoder output shape:", decoder_output.shape)  # Expected shape: (batch_size, tgt_seq_len, d_model)

3) The Classification Head
The output of the decoder represents all of the moves the model thinks should be taken, but it does so as a big list of abstract vectors. The goal of the classification head is to turn each of these abstract vectors into a prediction of which token should be output (our 12 moves and 3 utility tokens). We do that by simply using a neural network on each vector to turn our 128-value-long vector into a vector of length 15. Then we turn those 15 values into probabilities (where bigger numbers mean higher probability) using an operation called softmax.
So, in other words, the prediction head turns all our abstract vectors into a prediction of which move should happen at each spot in the solution sequence.

Here’s that code:
class ProjHead(nn.Module):
    def __init__(self, d_model=128, num_tokens=15):
        super(ProjHead, self).__init__()
        self.num_tokens = num_tokens

        # Linear layer to project from d_model to num_tokens
        self.fc = nn.Linear(d_model, num_tokens)

        # Softmax activation to convert logits into probabilities
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, logits):
        """
        Args:
            logits (torch.Tensor): Input tensor of shape (batch_size, seq_len, d_model).
        Returns:
            torch.Tensor: Output probabilities of shape (batch_size, seq_len, num_tokens).
        """
        # Project logits through a linear layer
        projected_logits = self.fc(logits)

        # Apply Softmax to convert to probabilities
        probabilities = self.softmax(projected_logits)
        return probabilities

# Initialize the module
logits_to_probs = ProjHead()

# Convert logits to probabilities
probabilities = logits_to_probs(decoder_output)
print("Probabilities shape:", probabilities.shape)  # Expected: (batch_size, seq_len, num_tokens)
print("Sum of probabilities for first token:", probabilities[0, 0].sum().item())  # Should be close to 1.0

4) The Model
We created:
- a way to embed sticker colors and encode the location of stickers on a Rubik’s Cube as a list of vectors
- a way to embed moves and encode their position in a solution sequence as a list of vectors
- the encoder, which accepts the vectors representing a Rubik’s Cube and allows a model to convert that data into a dense and meaning rich representation
- the decoder, which accepts previous moves and outputs predictions of future moves that should be taken based on the value of previous moves and the encoded representation of the Rubik’s Cube
- A projection head which takes the output of the decoder and turns it into probabilistic predictions that certain moves should be made
Now we can put that all together to define the actual model
class RubiksCubeTransformer(nn.Module):
    def __init__(self, layers_encoder=5, layers_decoder=5, d_model=128):
        super(RubiksCubeTransformer, self).__init__()

        # turns the tokens that go into the encoder and decoder into vectors
        self.encoder_embedding = EncoderEmbedding(embedding_dim=d_model)
        self.decoder_embedding = DecoderEmbedding(embedding_dim=d_model)

        # Defining the Encoder and Decoder
        self.encoder = TransformerEncoder(num_layers=layers_encoder, d_model=d_model, num_heads=4, d_ff=d_model*2, dropout=0.1)
        self.decoder = TransformerDecoder(num_layers=layers_decoder, d_model=d_model, num_heads=4, d_ff=d_model*2, dropout=0.1)

        # Defining the projection head to turn logits into probabilities
        self.projection_head = ProjHead(d_model=d_model, num_tokens=15)

    def forward(self, X, y):
        # embedding both inputs
        X_embed = self.encoder_embedding(X)
        y_embed = self.decoder_embedding(y)

        # encoding the Rubik's Cube representation
        X_encode = self.encoder(X_embed)

        # decoding embedded previous moves, cross attended with the Rubik's Cube encoding
        y_decode = self.decoder(y_embed, X_encode)

        # turning logits from the decoder into predictions
        return self.projection_head(y_decode)

model = RubiksCubeTransformer()
model(X[:10], y[:10]).shape

Now we can train this model on our synthetic dataset. Before we do, though, I’d like to take a step back and consider some of the costs and benefits of our approach.
The Subtlety of Training Strategy
Before we get into training our model, I’d like to reflect on the synthetic dataset we created earlier, and some of the implications that dataset suggests.
In this article we’re generating a perfectly random sequence of moves, using it to shuffle a Rubik’s Cube, then asking the model to predict the exact opposite of the shuffling sequence. This is great in theory, but there’s a practical problem: if we happened to generate a random sequence of moves like this:
<sos>, front clockwise, front counterclockwise, front clockwise, front counterclockwise, <eos>
Then instead of training the model to predict that the Rubik’s Cube is already completed (because it is, these random moves simply undo each other), we would train the model to predict the same erroneous set of steps and then output that the Rubik’s Cube is completed.
There are a lot of ways to get around this problem. I went with the simplest approach: ignoring it.
Transformers, the style of model we’re using, are known to perform remarkably well in a natural language context, which has a lot of random noise and occasional nonsense. So, we already know transformers are good at learning to model complex sequences despite some poor-quality examples in the training set.
The hope, for us, is that silly moves will be much less common than productive moves and, as a result, the model will tend to learn productive decisions.
So, basically, if a transformer is good at learning language even if there are occasionally silly words in the training set, maybe it will be good at solving a Rubik’s Cube even if the dataset it’s trained on has occasionally silly moves.
It can be easy to be too hopeful about this type of assumption early on. It’s important to remember that, when constructing an AI model, the model is attempting to learn exactly what you’re training it to do. No more, no less. We can hope that the nature of the model will deal with quirks in our synthetic dataset elegantly, but we’ll only really know if we made the right call once we’ve gone ahead, trained, and then tested our model.
Generally speaking, I’ve found that the best modeling strategy is the one you think might work and can implement quickly. Iteration is a fact of life in complex ML problems.
So let’s give it a shot. We have a transformer and a dataset, let’s train this sucker.
Training the Model
Before I actually train the model, I’m doing a bit of setup work:
import os
import torch
from torch.utils.data import DataLoader, TensorDataset
import torch.optim as optim

# Define the checkpoint directory
checkpoint_dir = "/content/drive/My Drive/Colab Notebooks/Blogs/RubiksCubeCheckpoints"

# Initialize key variables
batch_losses = []
epoch_iter = 0  # Keeps track of total epochs trained

# User option: Start from scratch or resume from the last checkpoint
start_from_scratch = False  # Set this to True to start training from scratch

if start_from_scratch:
    print("Starting training from scratch...")

    # Check for GPU availability
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Initialize the model and move it to GPU
    model = RubiksCubeTransformer(layers_encoder=6, layers_decoder=3, d_model=64).to(device)

    # Move data to GPU
    X = X.to(device)
    y = y.to(device)

    # Define dataset and data loader
    batch_size = 16
    dataset = TensorDataset(X, y)
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Initialize optimizer
    optimizer = optim.Adam(model.parameters(), lr=1e-5)

else:
    print("Attempting to resume training from the latest checkpoint...")

    # Check for GPU availability
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Initialize the model and move it to GPU
    model = RubiksCubeTransformer(layers_encoder=6, layers_decoder=3, d_model=64).to(device)

    # Move data to GPU
    X = X.to(device)
    y = y.to(device)

    # Define dataset and data loader
    batch_size = 16
    dataset = TensorDataset(X, y)
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Load the latest checkpoint if available
    latest_checkpoint = None
    if os.path.exists(checkpoint_dir):
        print(os.listdir(checkpoint_dir))
        checkpoints = [f for f in os.listdir(checkpoint_dir) if f.endswith(".pt")]
        print(checkpoints)
        if checkpoints:
            checkpoints.sort(key=lambda x: int(x.split('_')[-1].split('.')[0]))  # Sort by epoch
            latest_checkpoint = os.path.join(checkpoint_dir, checkpoints[-1])

    if latest_checkpoint:
        print(f"Loading checkpoint: {latest_checkpoint}")
        checkpoint = torch.load(latest_checkpoint)

        # Load model and optimizer states
        model.load_state_dict(checkpoint['model_state_dict'])

        # Initialize optimizer and load its state
        optimizer = optim.Adam(model.parameters(), lr=1e-5)
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        # Set epoch_iter to the last epoch from the checkpoint
        epoch_iter = checkpoint['epoch']
        print(f"Resuming training from epoch {epoch_iter}")
    else:
        raise ValueError('No Checkpoint Found')
Transformers can take a while to train (I trained this model over the course of several days). Also, Google Colab has a tendency to log you out of a session if you’ve been away from the keyboard for too long. As a result, it was vital to save model checkpoints somewhere such that I could recover and resume training. This code allows me to recover the most recent checkpoint from my Google Drive before I continue. If there’s no checkpoint, it defines a new model.
This code also does some other quality-of-life things, like turning our training data into a DataLoader that takes care of creating batches and shuffling data across epochs, and making sure our model and data are both on the GPU.
There are a few things going on in the actual training code. Let’s go through it section by section.
from google.colab import drive
import torch
import torch.nn as nn
import torch.nn.functional as F  # needed for F.one_hot below
from tqdm import tqdm
import os

# Printing out parameter count
print('model param count:')
print(count_parameters(model))

# Define loss function
criterion = nn.CrossEntropyLoss()

verbose = False

# Training loop
num_epochs = 100
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    for batch in tqdm(data_loader):
        X_batch, y_batch = batch

        if verbose:
            print('\n==== Batch Examples ====')
            num_examples = 2
            print('Encoder Input')
            print(X_batch[:num_examples])
            print('Decoder Input')
            print(y_batch[:num_examples, :-1])
            print('Decoder target')
            print(y_batch[:num_examples, 1:])

        # Move batch data to GPU (if they're not already)
        X_batch = X_batch.to(device)
        y_batch = y_batch.to(device)

        optimizer.zero_grad()

        # Defining the input sequence to the model
        y_input = y_batch[:, :-1]

        # Forward pass
        y_pred = model(X_batch, y_input)

        # Transform target to one-hot encoding
        y_target = F.one_hot(y_batch[:, 1:], num_classes=15).float().to(device)

        # Compute loss
        loss = criterion(y_pred.view(-1, 15), y_target.view(-1, 15))
        running_loss += loss.item()
        batch_losses.append(loss.item())

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        if verbose:
            break

    if verbose:
        break

    epoch_iter += 1

    # Print epoch loss
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss / len(data_loader)}")

    # Save checkpoint every epoch
    if (epoch_iter + 1) % 1 == 0:
        checkpoint_path = os.path.join(checkpoint_dir, f"model_epoch_{epoch_iter+1}.pt")
        torch.save({
            'epoch': epoch_iter + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': running_loss / len(data_loader),
        }, checkpoint_path)
        print(f"Checkpoint saved at {checkpoint_path}")

I have this little block of code for my own debugging purposes. I defined a function that gives me the total number of trainable parameters in my model, which is useful for getting a rough idea of how large the model I’m training is.
print('model param count:')
print(count_parameters(model))
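count_parameters isn’t shown in the article; it’s a small utility along these lines (a sketch of what it does, not necessarily the exact code from the notebook):

def count_parameters(model):
    """Count the trainable parameters in a PyTorch model."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)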
Next I’m defining my "criterion", which is a fancy way of saying how I’ll be judging how right or wrong the model is. Here I’m using cross entropy, which is a standard loss function that compares what the model predicted with what it should have predicted, and spits out a big number if the model was very wrong and a small number if the model was mostly right.
# Define loss function
criterion = nn.CrossEntropyLoss()
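As a tiny illustration of how cross entropy behaves (the numbers here are made up):

import torch
import torch.nn.functional as F

# three possible tokens; the correct token is index 0
target = torch.tensor([0])
confident_and_right = torch.tensor([[5.0, 0.0, 0.0]])  # model strongly favors token 0
confident_and_wrong = torch.tensor([[0.0, 5.0, 0.0]])  # model strongly favors token 1

print(F.cross_entropy(confident_and_right, target))  # small loss (roughly 0.01)
print(F.cross_entropy(confident_and_wrong, target))  # large loss (roughly 5.01)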
Then we get into our actual training loop by first defining how many times we want to iterate through our dataset, then iterating over it. Here I’m using tqdm to render fancy little progress bars which allow me to observe how quickly training is going.
num_epochs = 100
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    for batch in tqdm(data_loader):
        # training code...
First thing we do in a training iteration is unpack the batch
X_batch, y_batch = batch
Then we reset the gradients of the optimizer. I think the intricacies of training are out of scope for this article, but if you want to learn more check out my beginner’s introduction to AI and my article on gradients for more information. For our purposes, we’ll just say this line of code gets us ready to learn from a new batch of examples.
optimizer.zero_grad()
At this point y_batch represents the entire solution sequence. We want to turn that solution into two representations: what we would be putting into the model, and the predictions we would like to get back. For instance, for this sequence:
<sos>, move 0, move 1, move 3, <eos>
we would want to put in this sequence into our decoder:
<sos>, move 0, move 1, move 3
and hope to get back this sequence from the decoder output:
move 0, move 1, move 3, <eos>
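Here’s that shift on a toy tensor (the token values are made up):

import torch

# <sos>=0, three move tokens, <eos>=1
y_toy = torch.tensor([[0, 5, 7, 4, 1]])

y_input = y_toy[:, :-1]   # what the decoder sees:    [[0, 5, 7, 4]]
y_target = y_toy[:, 1:]   # what it should predict:   [[5, 7, 4, 1]]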
In this block of code we’re defining the input to the model, getting the model’s prediction of what moves should be made, and getting what we would have liked the model to have predicted:
# Defining the input sequence to the model
y_input = y_batch[:, :-1]
# Forward pass
y_pred = model(X_batch, y_input)
# Transform target to one-hot encoding
y_target = F.one_hot(y_batch[:, 1:], num_classes=15).float().to(device)
Then we’re figuring out how wrong the model was, keeping track of that information to get an idea of whether the model is getting better, and updating our model to be ever so slightly less bad at that particular example.
# Compute loss
loss = criterion(y_pred.view(-1, 15), y_target.view(-1, 15))
running_loss += loss.item()
batch_losses.append(loss.item())
# Backward pass and optimization
loss.backward()
optimizer.step()
We’re also doing some other quality of life stuff, like printing out statuses and saving model checkpoints.
# Print epoch loss
print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss / len(data_loader)}")

# Save checkpoint every epoch
if (epoch_iter + 1) % 1 == 0:
    checkpoint_path = os.path.join(checkpoint_dir, f"model_epoch_{epoch_iter+1}.pt")
    torch.save({
        'epoch': epoch_iter + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': running_loss / len(data_loader),
    }, checkpoint_path)
    print(f"Checkpoint saved at {checkpoint_path}")
And, tadah, we have defined a Rubik’s Cube solving model and trained it. Let’s see how good it is.
Testing our Model
Now that we’ve trained our model we can go ahead and apply it to some newly shuffled Rubik’s Cubes and see how it performs. This code creates a new Rubik’s Cube and generates a sequence of move predictions until the model outputs the end-of-sequence (<eos>) token.
def predict_and_execute(cube, max_iter=21):
    # turning the Rubik's Cube into encoder input
    model_X = torch.tensor(tokenize(cube.cube)).to(torch.int32).unsqueeze(0).to(device)

    # input to the decoder initialized as a vector of zeros, which is the start token
    model_y = torch.zeros(22).unsqueeze(0).to(torch.int32).to(device)
    current_index = 0

    # predicting move sequence
    while current_index < max_iter:
        y_pred = model(model_X, model_y)
        predicted_tokens = torch.argmax(y_pred, dim=-1)
        predicted_next_token = predicted_tokens[0, current_index]
        model_y[0, current_index+1] = predicted_next_token
        current_index += 1

    # converting into a list of moves
    predicted_tokens = model_y.cpu().numpy()[0]

    # executing move sequence
    moves = []
    for token in predicted_tokens:
        # start token
        if token == 0: continue
        # pad token
        if token == 2: continue
        # stop token
        if token == 1: break
        # move
        move = output_index_to_move(token-3)  # accounting for the start, end, and pad tokens
        cube.rotate_face(move[0], reverse=move[1])
        moves.append(move)

    return moves

# we can define how many shuffles we'll use for this particular test
NUMBER_OF_SHUFFLES = 3
print(f'attempting to solve a Rubiks Cube with {NUMBER_OF_SHUFFLES} scrambling moves\n')

# creating a cube
cube = RubiksCube()

# shuffling (changing n will change the number of moves to scramble the cube)
moves = scramble(cube, n=NUMBER_OF_SHUFFLES)
print(f'moves to scramble the cube:\n{moves}')

# Visualize from opposite corners
fig = cube.visualize_opposite_corners(return_fig=True)
fig.set_size_inches(4, 2)
plt.show()

# trying to solve the cube
print('\nsolving...')
solution = predict_and_execute(cube)
print(f'moves predicted by the model to solve the cube:\n{solution}')

# Visualize from opposite corners
fig = cube.visualize_opposite_corners(return_fig=True)
fig.set_size_inches(4, 2)
plt.show()
We can adjust NUMBER_OF_SHUFFLES to observe how well our model solves a few Rubik’s Cubes of various difficulty:




It’s doing a pretty good job, certainly much better than I can.
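If you want to put a rough number on that impression, one option is to measure the solve rate over a batch of freshly scrambled cubes at each difficulty. Here’s a sketch of what that might look like; is_solved and solve_rate are helpers I’m defining here for illustration, not functions from the notebook, and a cube is counted as solved when its sticker tensor matches a freshly initialized cube.

import numpy as np

def is_solved(cube):
    # the centers never move, so a solved cube's tensor matches a brand-new cube's tensor
    return np.array_equal(cube.cube, RubiksCube().cube)

def solve_rate(num_shuffles, trials=100):
    solved = 0
    for _ in range(trials):
        cube = RubiksCube()
        scramble(cube, n=num_shuffles)
        predict_and_execute(cube)  # mutates the cube by applying the predicted moves
        solved += int(is_solved(cube))
    return solved / trials

for n in [1, 3, 5, 7]:
    print(f'{n} shuffles: {solve_rate(n):.0%} solved')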
It does appear that the general assumption that the model would tend to ignore erroneous moves was at least somewhat correct. Here are a few examples of the model predicting better solutions than simply reversing the shuffling sequence:



It’s not all roses, though. The model appears to be somewhat inconsistent, even when solving cubes scrambled with only a small number of moves:

And for complex scrambles (more than 7 or so moves), the model is pretty much hopeless.

There’s one easy solution to this problem: just train for longer. Transformers benefit from a ton of training data and a ton of training time. I have no doubt that, given a few weeks of training, this model could learn to solve pretty much any Rubik’s Cube you throw at it. You could also increase some of the model parameters to make the model better at understanding the intricacies of the problem.
There’s another solution as well: use a better training strategy. Supervised learning is okay, but the erroneous-move issue adds a lot of noise to the training set, and that noise likely gets worse as the number of shuffles grows. Using the reverse of the shuffling sequence as the target makes less and less sense the longer the sequence gets, meaning we would need a fair amount of training time to get to a highly performant model.
If you want a super performant Rubik’s Cube solver right now, go ahead and throw the code described in this article at a GPU, wait a while, and see what happens. Personally, I’m more interested in exploring a better approach to modeling.
In a future article I’ll be fine-tuning this model with reinforcement learning which, hopefully, will allow the model to become much more robust very quickly. So, stay tuned.
Conclusion
In this article we created a model that can solve Rubik’s Cubes from scratch by learning from a synthetic dataset of shuffled Rubik’s Cubes. First, we created a way to define a Rubik’s Cube such that we could shuffle and solve it, then we used that definition to generate a dataset of 2 million shuffled Rubik’s Cubes and their solutions. We figured out how to tokenize, embed, and positionally encode both the Rubik’s Cube and the sequences of moves that solve it, then created a transformer which could accept that data and output next-move predictions. We trained that transformer on our data and tested it on newly shuffled Rubik’s Cubes. In the end, we got a promising first proof-of-concept model which we’ll use for future exploration around this topic.
Join Intuitively and Exhaustively Explained
At IAEE you can find:
- Long form content, like the article you just read
- Thought pieces, based on my experience as a data scientist, engineering director, and entrepreneur
- A discord community focused on learning AI
- Regular Lectures and office hours

References
The Code:
MLWritingAndResearch/RubiksCubeAI.ipynb at main · DanielWarfield1/MLWritingAndResearch
Relevant Articles:
AI for the Absolute Novice – Intuitively and Exhaustively Explained
Transformers – Intuitively and Exhaustively Explained
GPT – Intuitively and Exhaustively Explained
LoRA – Intuitively and Exhaustively Explained
Speculative Sampling – Intuitively and Exhaustively Explained