Motivating Self-Attention

Why do we need queries, keys AND values?

Ryan Xu
Towards Data Science

…self-self-attention?

The goal of this article is to offer an explanation not for how the self-attention mechanism in transformers works, but rather for why it was designed that way.

We’ll start with a discussion of the kinds of abilities we would like a language understanding model to have, followed by an interactive construction of the self-attention mechanism. In the process, we’ll discover why we need queries, keys, AND values in order to model relationships between words in a natural way, and that QKV attention is one of the simplest ways to do so.

This article will be most insightful to readers who have encountered transformers and self-attention before, but should be accessible to anybody familiar with some basic linear algebra. To those who are looking to better understand transformers, I will happily refer you to this blog post.

All images are by the author.

Transformers are frequently presented in the context of sequence-to-sequence modeling tasks such as language translation or, more saliently, sentence completion. However, I think that it’s easier to start off by thinking about the problem of sequence modeling and, specifically, language understanding.

So, here’s a sentence that we want to understand:

“Evan’s dog Riley is so hyper, she never stops moving.”

Let’s think a little bit about how we understand this sentence.

  • Evan’s dog Riley… From this, we know that Riley is the name of the dog, and that Evan owns Riley.
  • …is so hyper… Simple enough: “hyper” refers to the dog Riley, influencing our impression of Riley.
  • …she never stops moving. This one is interesting. “she” refers to Riley, since the dog is the subject of the first phrase. This tells us that Evan’s dog Riley is in fact female, which was previously ambiguous due to the commonly unisex dog name, “Riley”. “never stops moving” is a slightly more complex set of words that elaborates on “hyper”.

The key takeaway here is that to build up our understanding of the sentence, we’re constantly considering how words relate to other words to augment their meaning.

In the machine learning community, the process of augmenting the meaning of a word a by the presence of another word b is colloquially referred to as “a attends to b”, as in word a pays attention to word b.

An arrow a => b indicates that “a attends to b”.

And so, if we would like a machine learning model to understand language, we might understandably want it to be able to have one word attend to another and somehow update its meaning accordingly.

This is precisely the ability that we hope to emulate as we build up the aptly named (self) attention mechanism in 3 parts.

In the following, I will pose a number of questions in italics. I strongly encourage the reader to stop and consider the question for a minute before continuing.

Part 1

For now, let’s focus on the relationship between the words “dog” and “Riley”. The word “dog” strongly influences the meaning of the word “Riley”, so we want “Riley” to attend to “dog”; the goal is to somehow update the meaning of the word “Riley” accordingly.

To make this example more concrete, let’s say that we begin with vector representations of each word, each of length n, based on a context-free understanding of the word. We will assume that this vector space is fairly well organized, meaning that words that are more similar in meaning are associated with vectors that are closer in the space.

So, we have two vectors, v_dog and v_Riley, that capture the meaning of the two words.
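If you’d like something concrete to run the snippets below against, here is a minimal, purely hypothetical stand-in for get_value (and a list version, get_values): random context-free vectors per word. In reality these would come from a learned embedding table, but random vectors are enough to play with.

import numpy as np

n = 8                                   # embedding dimension (arbitrary for this sketch)
rng = np.random.default_rng(0)
_value_table = {}                       # hypothetical context-free embeddings, one per word

def get_value(word):
    # look up (or lazily invent) a context-free vector for this word
    if word not in _value_table:
        _value_table[word] = rng.normal(size=n)
    return _value_table[word]

def get_values(words):
    # list version, used later on
    return [get_value(w) for w in words]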

How can we update the value of v_Riley using v_dog to obtain a new value for the word “Riley” that incorporates the meaning of “dog”?

We don’t want to completely replace the value of v_Riley with v_dog, so let’s say that we take a linear combination of v_Riley and v_dog as the new value for v_Riley:

v_Riley = get_value('Riley')
v_dog = get_value('dog')

# keep 75% of "Riley" and blend in 25% of "dog"
ratio = 0.75
v_Riley = (ratio * v_Riley) + ((1 - ratio) * v_dog)

This seems to work alright: we’ve embedded a bit of the meaning of the word “dog” into the word “Riley”.

Now we would like to try to apply this form of attention to the whole sentence, updating the vector representation of every single word using the vector representations of every other word.

What goes wrong here?

The core problem is that we don’t know which words should take on the meanings of other words. We would also like some measure of how much the value of each word should contribute to each other word.

Part 2

Alright. So we need to know how much two words should be related.

Time for attempt number 2.

I’ve redesigned our vector database so that each word actually has two associated vectors. The first is the same value vector that we had before, still denoted by v. In addition, we now have unit vectors denoted by k that store some notion of word relations. Specifically, if two k vectors are close together, it means that the values associated with these words are likely to influence each other’s meanings.
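Continuing the hypothetical stub from earlier (again, just made-up vectors so the snippets run; a real model would learn these), a key here is simply a unit vector attached to each word:

_key_table = {}                         # hypothetical unit "key" vectors, one per word

def get_key(word):
    if word not in _key_table:
        k = rng.normal(size=n)          # reuses n and rng from the earlier sketch
        _key_table[word] = k / np.linalg.norm(k)   # normalize to unit length
    return _key_table[word]

def get_keys(words):
    return [get_key(w) for w in words]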

With our new k and v vectors, how can we modify our previous scheme to update v_Riley’s value with v_dog in a way that respects how much two words are related?

Let’s continue with the same linear combination business as before, but only when the k vectors of the two words are close in embedding space. Even better, we can use the dot product of the two k vectors (which ranges from −1 to 1 since they are unit vectors; think of unrelated words as scoring near 0 and strongly related words as scoring near 1) to tell us how much we should update v_Riley with v_dog.

v_Riley, v_dog = get_value('Riley'), get_value('dog')
k_Riley, k_dog = get_key('Riley'), get_key('dog')

relevance = k_Riley @ k_dog  # dot product of the two keys

# the more related the keys, the more v_dog pulls on v_Riley
v_Riley = (1 - relevance) * v_Riley + relevance * v_dog

This is a little bit strange since if relevance is 1, v_Riley gets completely replaced by v_dog, but let’s ignore that for a minute.

I want to instead think about what happens when we apply this kind of idea to the whole sequence. The word “Riley” will have a relevance value with every other word via the dot product of their k vectors. So, maybe we can instead update the value of each word proportionally to the value of the dot product. For simplicity, let’s also include its dot product with itself as a way to preserve its own value.

sentence = "Evan's dog Riley is so hyper, she never stops moving"
words = sentence.split()

# obtain a list of values
values = get_values(words)

# oh yeah, that's what k stands for by the way
keys = get_keys(words)

# get riley's relevance key
riley_index = words.index('Riley')
riley_key = keys[riley_index]

# generate relevance of "Riley" to each other word
relevances = [riley_key · key for key in keys] #still pretending python has ·

# normalize relevances to sum to 1
relevances /= sum(relevances)

# takes a linear combination of values, weighted by relevances
v_Riley = relevances · values

Ok that’s good enough for now.

But once again, I claim that there’s something wrong with this approach. It’s not that any of our ideas have been implemented incorrectly, but rather that there’s something fundamentally different between this approach and how we actually think about relationships between words.

If there’s any point in this article where I really really think that you should stop and think, it’s here. Even those of you who think you fully understand attention. What’s wrong with our approach?

A hint

Relationships between words are inherently asymmetric! The way that “Riley” attends to “dog” is different from the way that “dog” attends to “Riley”. The fact that “Riley” refers to a dog, not a human, matters far more to our understanding of “Riley” than the dog’s particular name matters to our understanding of “dog”.

In contrast, the dot product is a symmetric operation, which means that in our current setup, if a attends to b, then b attends equally strongly to a! (Strictly speaking, normalizing the relevance scores breaks the exact symmetry, but the point stands: words should have the option of attending to each other asymmetrically, even when all the other tokens are held constant.)
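A toy illustration of the point, with made-up two-dimensional vectors (and a preview of the query vectors introduced in Part 3): key-key dot products are always order-independent, while query-key dot products need not be.

import numpy as np

# two hypothetical unit key vectors
k_a = np.array([1.0, 0.0])
k_b = np.array([0.6, 0.8])

print(k_a @ k_b, k_b @ k_a)   # 0.6 0.6 -- always equal

# separate "query" vectors (previewing Part 3) break the symmetry
q_a = np.array([0.0, 1.0])
q_b = np.array([1.0, 0.0])
print(q_a @ k_b, q_b @ k_a)   # 0.8 1.0 -- the two directions can differ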

Part 3

We’re almost there! Finally, the question becomes:

How can we most naturally extend our current setup to allow for asymmetric relationships?

Well, what can we do with one more vector type? We still have our value vectors v and our relation vectors k. Now we add yet another vector, q, for each token.

How can we modify our setup and use q to achieve the asymmetric relationship that we want?

Those of you who are familiar with how self-attention works will hopefully be smirking at this point.

Instead of computing the relevance as k_Riley · k_dog when “Riley” attends to “dog”, we can instead query q_Riley against the key k_dog by taking their dot product. When computing the other way around, we will have q_dog · k_Riley instead: asymmetric relevance!
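To round out the hypothetical stubs from earlier (still fake, still just there so the snippet below runs), each word also gets a query vector:

_query_table = {}                       # hypothetical query vectors, one per word

def get_query(word):
    if word not in _query_table:
        _query_table[word] = rng.normal(size=n)   # reuses rng and n from the earlier stubs
    return _query_table[word]

def get_queries(words):
    return [get_query(w) for w in words]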

Here’s the whole thing together, computing the update for every value at once!

sentence = "Evan's dog Riley is so hyper, she never stops moving"
words = sentence.split()
seq_len = len(words)

# obtain arrays of queries, keys, and values, each of shape (seq_len, n)
Q = array(get_queries(words))
K = array(get_keys(words))
V = array(get_values(words))

relevances = Q @ K.T
normalized_relevances = relevances / relevances.sum(axis=1)

new_V = normalized_relevances @ V

And that’s basically self-attention!

There are a few more details that I left out, but the important ideas are all there.
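One of those details is worth naming: in the standard formulation, the raw sum normalization above is replaced by a softmax (which behaves sensibly even when dot products are negative), and the dot products are scaled down by the square root of the vector length so they don’t blow up for long vectors. A rough sketch under those assumptions, not code from any particular library:

import numpy as np

def attention(Q, K, V):
    # Q, K, V: arrays of shape (seq_len, n)
    n = Q.shape[-1]
    scores = (Q @ K.T) / np.sqrt(n)                        # scaled dot products
    scores = np.exp(scores - scores.max(axis=1, keepdims=True))
    weights = scores / scores.sum(axis=1, keepdims=True)   # row-wise softmax
    return weights @ V                                     # weighted combination of values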

To recap, we started with value vectors (v) to represent the meaning of each word, but quickly found that we need key vectors (k) to account for how words relate to each other. Finally, to properly model the asymmetric nature of word relationships, we introduced query vectors (q). It almost feels like if all we’re allowed are dot products and such, 3 is the minimal number of vectors per word needed to properly model relationships between words.

The purpose of this article is to unravel the self-attention mechanism in a way that’s less overwhelming than the traditional algorithm-first approach. I hope that from this more language-motivated perspective, the elegance and simplicity of the query-key-value design can show through.

Some details that I left out:

  • Instead of storing 3 vectors for each token, we store a single embedding vector from which we extract our q-k-v vectors. The extraction process is just a linear projection (a short sketch follows this list).
  • Technically, in this whole setup, each word has no idea where the other words are in the sentence. Self-attention is really a set operation. So, we need to embed positional knowledge, typically done by adding a position vector to the embedding vector. This isn’t completely trivial since transformers should allow for sequences of arbitrary length. How this works in practice is outside the scope of this article.
  • A single self-attention layer only allows us to represent two-word relationships. But since the output of a self-attention layer has the same sequence length as its input, we can compose layers, and by composing self-attention layers we can model higher-level relationships between words. In fact, transformer blocks are just self-attention layers followed by position-wise feed-forward blocks. Stack a few hundred of these, pay a few million dollars, and you have yourself an LLM! :)
OpenAI showed that it takes 512 transformer blocks to understand that second half.
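As a sketch of that first bullet (with made-up, untrained weight matrices; in a real transformer they are learned end to end): each token’s single embedding vector is mapped to its query, key, and value by three linear projections.

import numpy as np

n = 8                                   # embedding dimension, as before
rng = np.random.default_rng(0)

# hypothetical projection matrices; in practice these are learned parameters
W_q = rng.normal(size=(n, n))
W_k = rng.normal(size=(n, n))
W_v = rng.normal(size=(n, n))

def qkv(x):
    # x: a single token's embedding vector, shape (n,)
    return W_q @ x, W_k @ x, W_v @ x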
