
Cross attention is a fundamental tool in creating AI models that can understand multiple forms of data simultaneously. Think of language models that can understand images, like those used in ChatGPT, or models that generate video from text, like Sora.
This summary goes over all critical mathematical operations within cross attention, allowing you to understand its inner workings at a fundamental level.
Step 1: Defining the Inputs
Cross attention is used when modeling with a variety of data types, each of which might be formatted differently. For natural language data, one would likely use a word-to-vector embedding, paired with a positional encoding, to calculate a vector that represents each word.

For visual data, one might pass the image through an encoder specifically designed to summarize the image into a vector representation.

In cross attention, I like to think of one input being used to filter the other, which allows the data from both inputs to interact with one another in an AI model.

In this example, we might refer to the image data as the "query source input", as it will be used to construct the "query" within the attention mechanism, and the textual data as the "key-value source input", as it will be used to construct the "key" and "value" within the attention mechanism.
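To make the rest of the steps concrete, here's a minimal NumPy sketch of the two inputs. The shapes are made up for illustration (4 image patches and 5 text tokens, each represented by a 6-dimensional vector); in a real model the image vectors would come out of an image encoder and the text vectors out of an embedding layer with positional encoding.

```python
import numpy as np

np.random.seed(0)

# Toy stand-ins for the two inputs. Each row is one token's vector representation.
image_tokens = np.random.randn(4, 6)  # query source input: 4 image patches, 6 dims each
text_tokens = np.random.randn(5, 6)   # key-value source input: 5 text tokens, 6 dims each
```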
Step 2: Defining the Learnable Parameters
Multi-headed cross attention essentially learns three weight matrices. These are used to construct the "query", "key", and "value", which are used later in the cross attention mechanism. They are randomly initialized and then updated throughout the training process.
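As a sketch, assuming the same 6-dimensional toy embeddings as above, the three weight matrices might be initialized like this (real implementations use smarter initialization schemes, but random values make the point):

```python
import numpy as np

d_model = 6  # assumed embedding size for this toy example

# Randomly initialized weights; during training these would be updated by gradient descent.
W_query = np.random.randn(d_model, d_model)
W_key = np.random.randn(d_model, d_model)
W_value = np.random.randn(d_model, d_model)
```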

Step 3: Defining the Query, Key, and Value
Now that we have weight matrices for our model, we can multiply our inputs by them to generate our query, key, and value. In this example we’ll multiply the image data by the query weight matrix to generate the query, and the text data by the key and value weight matrices to generate the key and value. Recall that in matrix multiplication, every value in a row of the first matrix is multiplied by the corresponding value in a column of the second matrix, and those products are summed to produce one value in the output.

Once the image data has been multiplied by the query weights, and the text data has been multiplied by the key and value weights, we have our query, key, and value.
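In NumPy terms, continuing the toy shapes from the earlier sketches, this step is just three matrix multiplications:

```python
import numpy as np

image_tokens = np.random.randn(4, 6)  # query source input
text_tokens = np.random.randn(5, 6)   # key-value source input
W_query, W_key, W_value = (np.random.randn(6, 6) for _ in range(3))

Q = image_tokens @ W_query  # shape (4, 6): one query vector per image patch
K = text_tokens @ W_key     # shape (5, 6): one key vector per text token
V = text_tokens @ W_value   # shape (5, 6): one value vector per text token
```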

Step 4: Dividing into Heads
In this example we’ll use two attention heads, meaning we’ll do cross attention with two sub-representations of the input. We’ll set that up by dividing the query, key, and value in two.

The Query, Key, & Value with label 1 will be passed to the first attention head, and the Query, Key, & Value with label 2 will be passed to the second attention head. Essentially, this allows multi-headed cross attention to reason about the same inputs in several different ways in parallel.
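One common way to do this split (a sketch assuming the toy 6-dimensional vectors from before) is simply to cut the feature dimension in half, giving each head a 3-dimensional slice of every query, key, and value vector:

```python
import numpy as np

d_head = 3  # 6-dimensional vectors split across 2 heads
Q = np.random.randn(4, 6)
K = np.random.randn(5, 6)
V = np.random.randn(5, 6)

# Head 1 gets the first 3 columns, head 2 gets the last 3.
Q1, Q2 = Q[:, :d_head], Q[:, d_head:]
K1, K2 = K[:, :d_head], K[:, d_head:]
V1, V2 = V[:, :d_head], V[:, d_head:]
```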
Step 5: Calculating the Z Matrix
To construct the attention matrix, we’ll first multiply the query by the transpose of the key to get what’s commonly referred to as the "Z" matrix. We’ll only be doing this for attention head 1, but keep in mind all of these calculations are also going on in attention head 2.

Because of the way the math shakes out, the "Z" matrix values have a tendency to grow as the dimension of the query and key vectors grows. This is counteracted by dividing the values in the "Z" matrix by the square root of that dimension (the per-head dimension of the key vectors).
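A sketch of this step for head 1, using the toy shapes from before (4 image patches, 5 text tokens, 3 dimensions per head):

```python
import numpy as np

d_head = 3
Q1 = np.random.randn(4, d_head)  # head 1 query: one row per image patch
K1 = np.random.randn(5, d_head)  # head 1 key: one row per text token

# Raw scores, scaled by the square root of the per-head dimension.
Z1 = (Q1 @ K1.T) / np.sqrt(d_head)  # shape (4, 5)
```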

Step 6 (optional): Masking
Depending on the application, a mask might be applied to cross attention so that only certain query source input tokens can interact with certain key-value source input tokens. We won’t be doing that in this example because we want all image data to interact with all text data, but you can get an idea of how masking works from another one of my by-hand articles.
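For completeness, if we did want a mask, one common approach (shown here purely as a hypothetical sketch, not something used in this example) is to set the disallowed positions of the "Z" matrix to a large negative number before the softmax, which drives their attention weights to roughly zero:

```python
import numpy as np

Z1 = np.random.randn(4, 5)  # scaled scores from the previous step

# Hypothetical mask: True where a query token may attend to a key-value token.
mask = np.ones((4, 5), dtype=bool)
mask[0, 3:] = False  # e.g. hide the last two text tokens from the first image patch

# Masked-out positions get a large negative score, so softmax sends them to ~0.
Z1_masked = np.where(mask, Z1, -1e9)
```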
Step 7: Calculating the Attention Matrix
The whole point of calculating the "Z" matrix, and optionally applying a mask, was to create an attention matrix. This can be done by softmaxing each row in the "Z" matrix. The equation for calculating softmax is as follows:
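$$\text{softmax}(z_i) = \frac{e^{z_i}}{\sum_{j} e^{z_j}}$$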

Meaning a value in a row is equal to e raised to that value, divided by the sum of e raised to all values in that row. We can softmax the first row of the Z matrix:

In the same way, we can calculate the softmax for each row in the Z matrix, giving us the attention matrix.
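In code, using the toy "Z" matrix from the previous sketch, the row-wise softmax looks like this (subtracting each row's maximum first is a standard numerical-stability trick that doesn't change the result):

```python
import numpy as np

Z1 = np.random.randn(4, 5)  # scaled (and optionally masked) scores for head 1

# Softmax each row: exponentiate, then normalize so every row sums to 1.
exp_Z = np.exp(Z1 - Z1.max(axis=1, keepdims=True))
attention_1 = exp_Z / exp_Z.sum(axis=1, keepdims=True)  # shape (4, 5), rows sum to 1
```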

Step 8: Calculating the Output of the Attention Head
Now that we’ve calculated the attention matrix, we can multiply it by the value matrix to construct the output of the attention head.
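Continuing the sketch, each row of the output is a weighted average of the value vectors, with the weights coming from the corresponding row of the attention matrix:

```python
import numpy as np

attention_1 = np.random.rand(4, 5)
attention_1 /= attention_1.sum(axis=1, keepdims=True)  # rows sum to 1, like a real attention matrix
V1 = np.random.randn(5, 3)  # head 1 value matrix: one row per text token

head_1_output = attention_1 @ V1  # shape (4, 3): one output vector per image patch
```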

Step 9: Concatenating the Output
Each attention head creates a distinct output; these are concatenated together to produce the final output of multi-headed cross attention.
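With the toy shapes used throughout, the concatenation just glues the two per-head outputs back together along the feature dimension:

```python
import numpy as np

head_1_output = np.random.randn(4, 3)
head_2_output = np.random.randn(4, 3)

# Final output of multi-headed cross attention: shape (4, 6),
# one 6-dimensional vector per query source (image) token.
output = np.concatenate([head_1_output, head_2_output], axis=1)
```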

Conclusion
And that’s it. In this article we covered the major steps of computing the output of multi-headed cross attention:
- defining the input
- defining the learnable parameters of the mechanism
- calculating the Query, Key, and Value
- dividing into multiple heads
- calculating the Z matrix
- (optionally) masking
- calculating the attention matrix
- calculating the output of the attention head
- concatenating the output
While we used the example of cross attending an image with text, ultimately what we really cross attended was a matrix of values with another matrix of values. If you can turn some form of data (like audio, video, robot sensor data, etc.) into a similar matrix, you can use that data in cross attention!
If you’d like to learn more about the intuition behind this topic, check out these IAEE articles:
Transformers – Intuitively and Exhaustively Explained
Flamingo – Intuitively and Exhaustively Explained
Join IAEE
At IAEE you can find:
- Long form content, like the article you just read
- Thought pieces, based on my experience as a data scientist, engineering director, and entrepreneur
- A discord community focused on learning AI
- Lectures, by me, every week
