
In this article, I summarize part of my research paper "Predicting Music Hierarchies with a Graph-Based Neural Decoder", which presents a data-driven system able to parse jazz chord sequences.
This research is motivated by my frustration with grammar-based parsing systems (which were the only option available for music data):
- The grammar-building phase requires a lot of domain knowledge
- The parser fails on unseen configurations or noisy data
- It is challenging to account for multiple musical dimensions in a single grammar rule
- There is no well-supported active Python framework to help with the development
In contrast, my approach (inspired by similar work in Natural Language Processing) does not rely on any grammar, produces partial results for noisy inputs, trivially handles multiple musical dimensions, and is implemented in PyTorch.
If you are not familiar with parsing and grammars, or simply need a refresher, let’s take a step back.
What is "parsing"?
The term parsing refers to predicting (or inferring) a tree, in the mathematical sense, whose leaves are the elements of the sequence.

Ok then, but why would we need a tree?
Let’s start with the following sequence of jazz chords (section A of "Take the A Train").

In jazz music, chords are connected by a complex system of perceptual relations. For example, Dm7 is a preparation for the dominant chord G7. This means that Dm7 is less important than G7 and could, for example, be omitted in a different reharmonization. Similarly, D7 is a secondary dominant (a dominant of a dominant) that also refers to G7.
Such harmonic relations can be expressed with a tree, which is useful for music analysis and for tasks like reharmonization. However, since chords in music pieces are mostly available only as sequences, we want a system that can build this tree structure automatically.
Constituent vs Dependency Trees
Before continuing we need to differentiate between two kinds of trees.
Musicologists tend to use what are called constituent trees, which you can see in the picture below. Constituent trees contain leaves (chords in blue, the elements of the input sequence) and internal nodes (chords in orange, reductions of their children).

In this work, instead, we consider another kind of tree, called a dependency tree. This kind of tree has no internal nodes, only directed arcs connecting the elements of the sequence.

We can convert a constituent tree into a dependency tree with an algorithm that will be discussed later.
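A dependency tree is often stored in code as a list of (head, dependent) arcs, or equivalently as a "heads" array where entry i holds the index of the head of element i. The sketch below uses a hypothetical ii-V-I fragment (Dm7, G7, C^7), not the exact "Take the A Train" sequence, and assumes the convention that index 0 is an artificial root while chords are numbered from 1. The helper `heads_to_arcs` is illustrative only.

```python
# Hypothetical ii-V-I fragment; index 0 is an artificial root node (assumed convention).
chords = ["ROOT", "Dm7", "G7", "C^7"]

# heads[i] is the index of the head of node i; the root itself has no head (-1).
# Dm7 depends on G7 (preparation), G7 depends on C^7, and C^7 hangs from the root.
heads = [-1, 2, 3, 0]

def heads_to_arcs(heads):
    """Return the (head, dependent) arcs encoded by a heads array."""
    return [(h, d) for d, h in enumerate(heads) if h >= 0]

for h, d in heads_to_arcs(heads):
    print(f"{chords[h]} -> {chords[d]}")
# G7 -> Dm7
# C^7 -> G7
# ROOT -> C^7
```

Because every chord stores exactly one head, this representation cannot give a chord two parents, which is one of the properties that makes the structure a tree.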
Dataset
Since this is a data-driven approach, we need a dataset of chord sequences (the input data) paired with trees (the ground truth) for training and testing. We use the Jazz Harmony Treebank (JHT)¹, which is publicly available in this GitHub repository (it can be freely used for non-commercial applications, and I obtained the author’s permission to use it in this article). In particular, the authors provide a JSON file with all chords and annotations.
We represent each input chord with three features (the encoder shown later additionally uses each chord's duration and metrical position):
- The root, an integer in [0..11], where C -> 0, C# -> 1, etc.
- The basic form, an integer in [0..5], which selects among major, minor, augmented, half-diminished, diminished, and suspended (sus).
- The extension, an integer in [0..2], which selects among 6, minor 7, or major 7.
To extract the chord features from a chord label (a string), we can use a regular expression as follows (note that this code works for this dataset, but the label format may vary in other chord datasets).
```python
import re

def parse_chord_label(chord_label):
    # Regex pattern for chord symbols: root, optional basic form, optional extension.
    # "+" and "^" are regex metacharacters, so they must be escaped to match literally.
    pattern = r"([A-G][#b]?)(m|\+|%|o|sus)?(6|7|\^7)?"
    # Match the pattern at the beginning of the input chord label
    match = re.match(pattern, chord_label)
    if match:
        # Extract the root, basic chord form and extension from the match object
        root = match.group(1)
        form = match.group(2) or "M"
        ext = match.group(3) or ""
        return root, form, ext
    else:
        # Raise an error if the input is not a valid chord symbol
        raise ValueError("Invalid chord symbol: {}".format(chord_label))
```
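The function above returns strings, while the model expects the integer features listed earlier. A minimal sketch of that last step is below. The lookup tables and the name `chord_to_features` are hypothetical and follow the feature description in this article. One assumption goes beyond it: a chord with no extension receives an extra id (3), because the three-class extension scheme does not cover plain triads.

```python
import re

# Hypothetical lookup tables following the feature description above.
PITCH_CLASS = {"C": 0, "D": 2, "E": 4, "F": 5, "G": 7, "A": 9, "B": 11}
FORMS = {"M": 0, "m": 1, "+": 2, "%": 3, "o": 4, "sus": 5}
EXTENSIONS = {"6": 0, "7": 1, "^7": 2}

# Same pattern as above, with the root split into letter and accidental.
CHORD_PATTERN = re.compile(r"([A-G])([#b]?)(m|\+|%|o|sus)?(6|7|\^7)?")

def chord_to_features(chord_label):
    """Map a chord label such as 'Dm7' to integer (root, form, extension) features."""
    match = CHORD_PATTERN.match(chord_label)
    if match is None:
        raise ValueError(f"Invalid chord symbol: {chord_label}")
    letter, accidental, form, ext = match.groups()
    # A sharp raises and a flat lowers the pitch class by one semitone (mod 12).
    shift = {"#": 1, "b": -1, "": 0}[accidental]
    root = (PITCH_CLASS[letter] + shift) % 12
    form_id = FORMS[form or "M"]
    # Assumption: chords without an extension get the extra id len(EXTENSIONS) == 3.
    ext_id = EXTENSIONS.get(ext or "", len(EXTENSIONS))
    return root, form_id, ext_id

print(chord_to_features("Bb7"))  # (10, 0, 1)
```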
Finally, we need to produce the dependency trees. The JHT dataset only contains constituent trees, encoded as nested dictionaries. We import them and transform them into dependency trees with a recursive function, which works as follows.
We start from a fully formed constituent tree and a dependency tree without any dependency arcs, consisting only of the nodes labelled with sequence elements. The algorithm groups each internal tree node with its primary child (the child carrying the same label as the node, so all nodes in a group share one label). It then uses every secondary-child relation originating from a group to create a dependency arc from the group label to the secondary child label.
```python
def parse_jht_to_dep_tree(jht_dict):
    """Parse the Python jazz harmony tree dict to a list of dependencies and a list of chords in the leaves.

    Each dependency is a (head, dependent) pair of leaf indices; -1 denotes the root.
    """
    all_leaves = []

    def _iterative_parse_jht(dict_elem):
        """Recursive helper that converts a (sub)tree to a list of dependencies."""
        children = dict_elem["children"]
        if children == []:  # recursion ending condition: a leaf
            out = (
                [],
                {"index": len(all_leaves), "label": dict_elem["label"]},
            )
            # add the label of the current node to the global list of leaves
            all_leaves.append(dict_elem["label"])
            return out
        else:  # recursive call
            assert len(children) == 2
            # noast() is a label helper defined in the full repository
            current_label = noast(dict_elem["label"])
            out_list = []  # dependency list
            iterative_result_left = _iterative_parse_jht(children[0])
            iterative_result_right = _iterative_parse_jht(children[1])
            # merge the dependency lists computed deeper in the tree
            out_list.extend(iterative_result_left[0])
            out_list.extend(iterative_result_right[0])
            # check whether the label corresponds to the left or right child and return the corresponding result
            if iterative_result_right[1]["label"] == current_label:  # if both children match, the right child is the head
                # append the dependency for the current node (right child -> left child)
                out_list.append((iterative_result_right[1]["index"], iterative_result_left[1]["index"]))
                return out_list, iterative_result_right[1]
            elif iterative_result_left[1]["label"] == current_label:
                # append the dependency for the current node (left child -> right child)
                out_list.append((iterative_result_left[1]["index"], iterative_result_right[1]["index"]))
                return out_list, iterative_result_left[1]
            else:
                raise ValueError(f"Something went wrong with label {current_label}")

    dep_arcs, root = _iterative_parse_jht(jht_dict)
    dep_arcs.append((-1, root["index"]))  # add connection to the root, with index -1
    # add self loop to the root
    dep_arcs.append((-1, -1))  # add loop connection to the root, with index -1
    return dep_arcs, all_leaves
```
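`parse_jht_to_dep_tree` returns arcs whose chords are indexed from 0, with the root written as -1. The `eisner` function used in postprocessing instead reserves index 0 for the root. To compare against its output, or to build one head target per chord, one option is to shift every index by one and store the arcs as a heads array. The sketch below does this under that assumption. The name `dep_arcs_to_heads` is hypothetical, and the arc list is hand-written for a toy Dm7 G7 C^7 tree.

```python
def dep_arcs_to_heads(dep_arcs, num_chords):
    """Convert (head, dependent) arcs with root index -1 into a heads array.

    Position 0 of the result is the artificial root; chord i is stored at i + 1.
    The root's own entry stays -1 because it has no head.
    """
    heads = [-1] * (num_chords + 1)
    for head, dep in dep_arcs:
        if dep == -1:  # skip the (-1, -1) root self-loop
            continue
        heads[dep + 1] = head + 1  # the root (-1) becomes index 0
    return heads

# Hand-written arcs for Dm7 G7 C^7 (Dm7 -> G7 -> C^7 -> root), in the format above.
dep_arcs = [(1, 0), (2, 1), (-1, 2), (-1, -1)]
print(dep_arcs_to_heads(dep_arcs, num_chords=3))  # [-1, 2, 3, 0]
```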
Dependency Parsing Model
The mechanism of our parsing model is simple: we consider all possible arcs and use an arc predictor (a binary classifier) to predict whether each arc should be part of the tree or not.
However, it is hard to make this choice based only on the two chords we are trying to connect: we need some context, which we build with a transformer encoder.
To summarize, our parsing model acts in two steps:
- the input sequence is passed through a transformer encoder to enrich it with contextual information;
- a binary classifier evaluates the graph of all possible dependency arcs to filter out the unwanted arcs.
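The classifier needs the full set of candidate arcs as input. The `ArcPredictor` shown later prepends a root vector at index 0, so this sketch assumes that indexing: the root can be a head but never a dependent, and chords are numbered 1..n. `all_potential_arcs` is a hypothetical helper, and the repository may build its `pot_arcs` tensor differently.

```python
import torch

def all_potential_arcs(num_chords):
    """Return every candidate (head, dependent) pair as a LongTensor of shape (N, 2).

    Index 0 is the artificial root (a possible head, never a dependent);
    chords are indexed 1..num_chords. Self-arcs are excluded.
    """
    heads = torch.arange(num_chords + 1)
    deps = torch.arange(1, num_chords + 1)
    pairs = torch.cartesian_prod(heads, deps)  # all (head, dep) combinations
    return pairs[pairs[:, 0] != pairs[:, 1]]   # drop h == d

pot_arcs = all_potential_arcs(3)
print(pot_arcs.shape)  # torch.Size([9, 2])
```

With n chords there are (n + 1) * n pairs, minus the n self-arcs, so n² candidates. This quadratic set is what the classifier scores in parallel.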

The transformer encoder follows the standard architecture. We use learnable embedding layers to map each categorical input feature to a point in a continuous multidimensional space. All embeddings are then summed, so it is up to the network to "decide" which dimensions to use for each feature.
```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

# Vocabularies. CHORD_FORM and CHORD_EXTENSION follow the feature description above;
# JTB_DURATION and METRICAL_LEVELS are defined in the full repository (placeholder values here).
CHORD_FORM = ["M", "m", "+", "%", "o", "sus"]
CHORD_EXTENSION = ["6", "7", "^7"]
JTB_DURATION = list(range(16))  # placeholder
METRICAL_LEVELS = 8  # placeholder

class TransformerEncoder(nn.Module):
    def __init__(
        self,
        input_dim,
        hidden_dim,
        encoder_depth,
        n_heads=4,
        dropout=0,
        activation="gelu",
    ):
        super().__init__()
        self.input_dim = input_dim
        self.positional_encoder = PositionalEncoding(
            d_model=input_dim, dropout=dropout, max_len=200
        )
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=input_dim, dim_feedforward=hidden_dim, nhead=n_heads,
            dropout=dropout, activation=activation,
        )
        encoder_norm = nn.LayerNorm(input_dim)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=encoder_depth, norm=encoder_norm)
        # one learnable embedding table per categorical feature; every embedding has
        # size input_dim so that their sum can be fed directly to the transformer
        self.embeddings = nn.ModuleDict({
            "root": nn.Embedding(12, input_dim),
            "form": nn.Embedding(len(CHORD_FORM), input_dim),
            "ext": nn.Embedding(len(CHORD_EXTENSION), input_dim),
            "duration": nn.Embedding(len(JTB_DURATION), input_dim),
            "metrical": nn.Embedding(METRICAL_LEVELS, input_dim),
        })

    def forward(self, sequence):
        # sequence: integer tensor of shape (seq_len, 5) with columns
        # root, form, extension, duration, metrical level
        root = sequence[:, 0]
        form = sequence[:, 1]
        ext = sequence[:, 2]
        duration = sequence[:, 3]
        metrical = sequence[:, 4]
        # transform categorical features to embeddings
        root = self.embeddings["root"](root.long())
        form = self.embeddings["form"](form.long())
        ext = self.embeddings["ext"](ext.long())
        duration = self.embeddings["duration"](duration.long())
        metrical = self.embeddings["metrical"](metrical.long())
        # sum all embeddings, shape (seq_len, input_dim)
        z = root + form + ext + duration + metrical
        # add positional encoding
        z = self.positional_encoder(z)
        # reshape to (seq_len, batch=1, input_dim), the default layout of nn.TransformerEncoder
        z = torch.unsqueeze(z, dim=1)
        # run the transformer encoder (no mask: every chord attends to the whole sequence)
        z = self.transformer_encoder(z)
        # remove batch dim
        z = torch.squeeze(z, dim=1)
        return z

class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 500):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x has shape (seq_len, d_model)
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)
```
The arc predictor is a small feed-forward network: it takes as input the concatenation of the hidden features of the two chords and passes it through two linear layers. A learned vector is prepended at index 0 to represent the root. The classification of all arcs is done in parallel thanks to the power of matrix multiplication.
```python
class ArcPredictor(nn.Module):
    def __init__(self, hidden_channels, activation=F.gelu, dropout=0.3):
        super().__init__()
        self.activation = activation
        self.root_linear = nn.Linear(1, hidden_channels)  # linear layer producing the root features
        self.lin1 = nn.Linear(2 * hidden_channels, hidden_channels)
        self.lin2 = nn.Linear(hidden_channels, 1)
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(hidden_channels)

    def forward(self, z, pot_arcs):
        # prepend a learned row for the root element, so that index 0 is the root
        root_feat = self.root_linear(torch.ones((1, 1), device=z.device))
        z = torch.vstack((root_feat, z))
        # proceed with the computation
        z = self.norm(z)
        # concat the embeddings of the two nodes, shape (num_pot_arcs, 2*hidden_channels)
        z = torch.cat([z[pot_arcs[:, 0]], z[pot_arcs[:, 1]]], dim=-1)
        # pass through a linear layer, shape (num_pot_arcs, hidden_channels)
        z = self.lin1(z)
        # pass through activation, shape (num_pot_arcs, hidden_channels)
        z = self.activation(z)
        # normalize
        z = self.norm(z)
        # dropout
        z = self.dropout(z)
        # pass through another linear layer, shape (num_pot_arcs, 1)
        z = self.lin2(z)
        # return a vector of logits of shape (num_pot_arcs,)
        return z.view(-1)
```
We can put the transformer encoder and the arc predictor in a single torch module to simplify its usage.
```python
class ChordParser(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers, dropout=0.2, n_heads=4):
        super().__init__()
        # initialize the encoder (keyword arguments avoid mixing up n_heads and dropout)
        self.encoder = TransformerEncoder(input_dim, hidden_dim, num_layers, n_heads=n_heads, dropout=dropout)
        # initialize the decoder (the arc predictor) on the encoder's output size
        self.decoder = ArcPredictor(input_dim, dropout=dropout)

    def forward(self, chord_features, pot_arcs):
        z = self.encoder(chord_features)
        return self.decoder(z, pot_arcs)
```
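The parser returns one logit per candidate arc. Postprocessing works on a square score matrix indexed by (head, dependent), root included, so the logits have to be scattered into that shape. Below is a minimal sketch that assumes `pot_arcs` follows the root-at-index-0 convention of `ArcPredictor`. The helper name `arc_logits_to_matrix` is hypothetical.

```python
import torch

def arc_logits_to_matrix(logits, pot_arcs, num_chords):
    """Scatter per-arc logits into an (n+1, n+1) matrix of arc probabilities.

    scores[h, d] is the predicted probability of the arc h -> d; pairs that
    were not candidates (e.g. arcs into the root) keep probability 0.
    """
    scores = torch.zeros(num_chords + 1, num_chords + 1)
    scores[pot_arcs[:, 0], pot_arcs[:, 1]] = torch.sigmoid(logits)
    return scores

# Toy example: two chords, candidate arcs 0->1, 0->2, 1->2, 2->1.
pot_arcs = torch.tensor([[0, 1], [0, 2], [1, 2], [2, 1]])
logits = torch.tensor([-2.0, 3.0, 0.0, 4.0])
scores = arc_logits_to_matrix(logits, pot_arcs, num_chords=2)
print(scores[1, 2].item())  # 0.5, since sigmoid(0) = 0.5
```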
Loss Function
As a loss function, we use the sum of two losses:
- The binary cross-entropy loss: we treat the problem as binary classification, where each potential arc is either present or absent.
- The cross-entropy loss: we treat the problem as multiclass classification, where, for each head in a head → dependent arc, we need to predict which one among all other chords is the correct dependent.
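```python
import torch
loss_bce = torch.nn.BCEWithLogitsLoss()
loss_ce = torch.nn.CrossEntropyLoss(ignore_index=-1)
# the loss modules are applied to predictions and targets, then summed (placeholder names):
# total_loss = loss_bce(arc_logits, arc_labels) + loss_ce(class_logits, class_targets)
```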
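`BCEWithLogitsLoss` and `CrossEntropyLoss` are loss modules, so they must be called on predictions and targets before their results can be summed. The sketch below runs both on a toy score matrix. It assumes a formulation that is common in dependency parsing but not necessarily the paper's: the cross-entropy term lets each dependent select its head from a column of scores, whereas the article phrases that term from the head's side. For brevity, it also applies BCE to the full matrix rather than only to the candidate arcs that `ArcPredictor` scores. All variable names are hypothetical.

```python
import torch

torch.manual_seed(0)
loss_bce = torch.nn.BCEWithLogitsLoss()
loss_ce = torch.nn.CrossEntropyLoss(ignore_index=-1)

num_chords = 3
# Toy logit matrix: head_scores[h, d] scores the arc h -> d (index 0 = root).
head_scores = torch.randn(num_chords + 1, num_chords + 1, requires_grad=True)
# Gold head of each node, root first; the root has no head, so its target is ignored (-1).
gold_heads = torch.tensor([-1, 2, 3, 0])

# BCE targets: 1 for each gold arc (head, dependent), 0 everywhere else.
arc_labels = torch.zeros(num_chords + 1, num_chords + 1)
deps = torch.arange(1, num_chords + 1)
arc_labels[gold_heads[1:], deps] = 1.0

# CE: CrossEntropyLoss expects (N, C) logits, so transpose to one row per dependent.
total_loss = loss_bce(head_scores, arc_labels) + loss_ce(head_scores.t(), gold_heads)
total_loss.backward()
```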
Postprocessing
There is one problem that we still have to solve. Nothing during training enforces that the predicted arcs form a tree structure, so we could obtain an invalid configuration such as an arc loop. Fortunately, Eisner's algorithm² solves this: given a matrix of arc scores, it returns the highest-scoring projective dependency tree (a tree whose arcs do not cross).
Instead of simply assuming that an arc exists if its predicted probability is greater than 0.5, we save all predictions in a square matrix (the adjacency matrix) of size (number of chords + 1, number of chords + 1), where the extra row and column correspond to the root, and we run the Eisner algorithm on it.
```python
import numpy as np

# Adapted from https://github.com/HMJW/biaffine-parser
def eisner(scores, return_probs=False):
    """Parse using Eisner's algorithm.

    The matrix follows the following convention:
    scores[i][j] = p(i=head, j=dep) = p(i --> j)
    Index 0 (first row and column) corresponds to the root.
    """
    rows, columns = scores.shape
    assert rows == columns, 'scores matrix must be square'
    num_words = rows - 1  # Number of words (excluding root).
    # Initialize CKY table.
    complete = np.zeros([num_words+1, num_words+1, 2])  # s, t, direction (right=1).
    incomplete = np.zeros([num_words+1, num_words+1, 2])  # s, t, direction (right=1).
    complete_backtrack = -np.ones([num_words+1, num_words+1, 2], dtype=int)  # s, t, direction (right=1).
    incomplete_backtrack = -np.ones([num_words+1, num_words+1, 2], dtype=int)  # s, t, direction (right=1).
    # the root (index 0) can never be a dependent
    incomplete[0, :, 0] -= np.inf
    # Loop from smaller items to larger items.
    for k in range(1, num_words+1):
        for s in range(num_words-k+1):
            t = s + k
            # First, create incomplete items.
            # left tree
            incomplete_vals0 = complete[s, s:t, 1] + complete[(s+1):(t+1), t, 0] + scores[t, s]
            incomplete[s, t, 0] = np.max(incomplete_vals0)
            incomplete_backtrack[s, t, 0] = s + np.argmax(incomplete_vals0)
            # right tree
            incomplete_vals1 = complete[s, s:t, 1] + complete[(s+1):(t+1), t, 0] + scores[s, t]
            incomplete[s, t, 1] = np.max(incomplete_vals1)
            incomplete_backtrack[s, t, 1] = s + np.argmax(incomplete_vals1)
            # Second, create complete items.
            # left tree
            complete_vals0 = complete[s, s:t, 0] + incomplete[s:t, t, 0]
            complete[s, t, 0] = np.max(complete_vals0)
            complete_backtrack[s, t, 0] = s + np.argmax(complete_vals0)
            # right tree
            complete_vals1 = incomplete[s, (s+1):(t+1), 1] + complete[(s+1):(t+1), t, 1]
            complete[s, t, 1] = np.max(complete_vals1)
            complete_backtrack[s, t, 1] = s + 1 + np.argmax(complete_vals1)
    value = complete[0][num_words][1]
    heads = -np.ones(num_words + 1, dtype=int)
    backtrack_eisner(incomplete_backtrack, complete_backtrack, 0, num_words, 1, 1, heads)
    value_proj = 0.0
    for m in range(1, num_words+1):
        h = heads[m]
        value_proj += scores[h, m]
    if return_probs:
        return heads, value_proj
    else:
        return heads

def backtrack_eisner(incomplete_backtrack, complete_backtrack, s, t, direction, complete, heads):
    """
    Backtracking step in Eisner's algorithm.
    - incomplete_backtrack is a (NW+1)-by-(NW+1) numpy array indexed by a start position,
    an end position, and a direction flag (0 means left, 1 means right). This array contains
    the arg-maxes of each step in the Eisner algorithm when building *incomplete* spans.
    - complete_backtrack is a (NW+1)-by-(NW+1) numpy array indexed by a start position,
    an end position, and a direction flag (0 means left, 1 means right). This array contains
    the arg-maxes of each step in the Eisner algorithm when building *complete* spans.
    - s is the current start of the span
    - t is the current end of the span
    - direction is 0 (left attachment) or 1 (right attachment)
    - complete is 1 if the current span is complete, and 0 otherwise
    - heads is a (NW+1)-sized numpy array of integers which is a placeholder for storing the
    head of each word.
    """
    if s == t:
        return
    if complete:
        r = complete_backtrack[s][t][direction]
        if direction == 0:
            backtrack_eisner(incomplete_backtrack, complete_backtrack, s, r, 0, 1, heads)
            backtrack_eisner(incomplete_backtrack, complete_backtrack, r, t, 0, 0, heads)
            return
        else:
            backtrack_eisner(incomplete_backtrack, complete_backtrack, s, r, 1, 0, heads)
            backtrack_eisner(incomplete_backtrack, complete_backtrack, r, t, 1, 1, heads)
            return
    else:
        r = incomplete_backtrack[s][t][direction]
        if direction == 0:
            heads[s] = t
            backtrack_eisner(incomplete_backtrack, complete_backtrack, s, r, 1, 1, heads)
            backtrack_eisner(incomplete_backtrack, complete_backtrack, r+1, t, 0, 1, heads)
            return
        else:
            heads[t] = s
            backtrack_eisner(incomplete_backtrack, complete_backtrack, s, r, 1, 1, heads)
            backtrack_eisner(incomplete_backtrack, complete_backtrack, r+1, t, 0, 1, heads)
            return
```
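While debugging the postprocessing, it can help to check the decoded structure directly. The sketch below uses the same heads convention that `eisner` returns: index 0 is the root and heads[0] == -1. It verifies that the array encodes a well-formed tree, meaning every chord reaches the root without a cycle. It also checks that the tree is projective, meaning no two arcs cross. Both helper names are hypothetical.

```python
def is_valid_tree(heads):
    """Check that a heads array (index 0 = root, heads[0] == -1) encodes a tree.

    Every chord must have a head in range and reach the root without cycles.
    """
    n = len(heads)
    if heads[0] != -1:
        return False
    for node in range(1, n):
        seen = set()
        current = node
        while current != 0:
            if current in seen or not 0 <= heads[current] < n:
                return False  # cycle, or missing/out-of-range head
            seen.add(current)
            current = heads[current]
    return True

def is_projective(heads):
    """Check that no two arcs cross when drawn above the sequence."""
    spans = [(min(h, d), max(h, d)) for d, h in enumerate(heads) if h >= 0]
    for a, b in spans:
        for c, e in spans:
            if a < c < b < e:  # arc (c, e) starts inside (a, b) and ends outside it
                return False
    return True

print(is_valid_tree([-1, 2, 3, 0]), is_projective([-1, 2, 3, 0]))  # True True
```

A cyclic array such as [-1, 2, 1, 0] (chords 1 and 2 pointing at each other) fails the first check. The second check flags trees that a projective decoder like Eisner's could never produce.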
Conclusions
I presented a system for the dependency parsing of chord sequences that uses a transformer to build contextual hidden representations of the chords, and a classifier to decide whether two chords should be linked by an arc.
The main advantage over competing systems is that this approach does not rely on any symbolic grammar; therefore, it can consider multiple musical features simultaneously, make use of sequential context information, and produce partial results for noisy inputs.
To keep this article to a reasonable size, both the explanation and the code focus on the most interesting parts of the system. You can find a more complete explanation in this scientific article and all the code in this GitHub repository.
(All images are by the author.)
References
1. D. Harasim, C. Finkensiep, P. Ericson, T. J. O’Donnell, and M. Rohrmeier, "The jazz harmony treebank," in Proceedings of the International Society for Music Information Retrieval Conference (ISMIR), 2020, pp. 207–215.
2. J. M. Eisner, "Three new probabilistic models for dependency parsing: An exploration," in Proceedings of the International Conference on Computational Linguistics (COLING), 1996.