10 Multi-head attention

Learning objectives

After completing this chapter, you will be able to:

  • Explain why a single attention head is limiting
  • Describe how multiple heads attend to different aspects of the input
  • Compute multi-head attention by projecting into subspaces
  • Concatenate and project head outputs back to model dimension
  • Understand how heads specialize for different types of relationships

In the previous chapter, we built the self-attention mechanism. It let each token look at the entire sequence and compute a weighted average of the value vectors of all tokens, including its own, producing a single context-aware representation.

But there is a subtle problem with using just one attention mechanism.

Think about how you understand a sentence like “The pilot flew the plane to Paris.”

To fully grasp this scene, your mind performs several distinct “lookups” simultaneously:

  1. Who performed the action? (You link “flew” to “pilot”.)
  2. What was affected? (You link “flew” to “plane”.)
  3. Where did it happen? (You link “flew” to “Paris”.)

For each query token, a single self-attention head computes a single probability distribution over the sequence. It has to commit to one pattern of attention. Suppose “flew” attends 80% to “pilot” and 20% to “Paris”. The head then produces a blended representation that is mostly “agent” and a little bit “location”. It cannot be precise about both relationships at the same time. It’s like trying to follow the bass line, the vocals, and the drums of a song after they have all been mixed down into a single audio track.
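The blending can be made concrete with a toy calculation in NumPy. The two-dimensional value vectors and their “agent” and “location” labels are invented for illustration. Real value vectors are learned and are far less interpretable.

```python
import numpy as np

# Toy 2-D value vectors (an illustrative assumption, not learned values):
# dimension 0 = "agent" signal, dimension 1 = "location" signal.
values = {
    "pilot": np.array([1.0, 0.0]),  # pure "agent"
    "Paris": np.array([0.0, 1.0]),  # pure "location"
}

# One head produces ONE distribution over positions for the query "flew".
weights = {"pilot": 0.8, "Paris": 0.2}

# The head's output is the weighted average of the value vectors.
blended = sum(w * values[tok] for tok, w in weights.items())
print(blended)  # [0.8 0.2]
```

A single output vector cannot be "fully agent" and "fully location" at once. It can only land somewhere on the line between the two, and where it lands depends on the one distribution the head chose.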

Multi-head attention gives the model the ability to “listen” to different aspects of the input independently. Instead of one global focus, we create multiple parallel attention mechanisms (heads). One head might learn to focus on grammar (finding the subject), while another focuses on geography (finding the location).

10.1 The intuition: Splitting the information bandwidth

To understand how this works mechanically, we need to look at the embedding dimension \(d\) differently.

In most transformer models, \(d\) is a large number, like 512 or 768. We typically think of this as a single long vector representing the token. But intuitively, we can think of this vector as having a total “bandwidth” or “capacity” of 512 units of information.

We don’t need all 512 units just to determine if a word is a subject or an object. That’s a relatively simple question. Maybe we only need 64 units of capacity to answer that.

So, instead of using the huge 512-dimensional space to ask one complex question, we divide our capacity. We split the model into \(h\) smaller, independent subspaces.

  • Standard approach: One giant head operating in 512 dimensions.
  • Multi-head approach: 8 separate heads, each operating in 64 dimensions.

Crucially, we do not just chop the input vector into pieces. We don’t say “Head 1 looks at the first 64 numbers, Head 2 looks at the next 64.”

Instead, every head gets to look at the entire original 512-dimensional vector. But each head has its own learned “lens” (projection matrix) that extracts only the information relevant to its specific job, compressing it down into a smaller 64-dimensional vector.
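The difference between chopping and projecting can be checked numerically. In this NumPy sketch, random matrices stand in for the learned lenses, and the helper names `head_slice` and `head_project` are hypothetical. We perturb a coordinate that sits in the last 64-number slice. Head 1's slice is untouched, but head 1's projected view does change, because every projected feature is a weighted sum over all 512 input numbers.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, h = 512, 8
d_k = d_model // h  # 64

x = rng.normal(size=d_model)  # one token's embedding

# NOT what multi-head attention does: chop the vector into fixed slices.
def head_slice(x, i):
    return x[i * d_k:(i + 1) * d_k]

# What multi-head attention does: a per-head projection of the WHOLE vector.
# Random matrices stand in for learned weights.
W = [rng.normal(size=(d_model, d_k)) / np.sqrt(d_model) for _ in range(h)]

def head_project(x, i):
    return x @ W[i]  # (512,) @ (512, 64) -> (64,)

# Perturb coordinate 500, which lies in the last slice (indices 448..511).
x2 = x.copy()
x2[500] += 1.0

slice_changed = not np.allclose(head_slice(x, 0), head_slice(x2, 0))
proj_changed = not np.allclose(head_project(x, 0), head_project(x2, 0))
print(slice_changed, proj_changed)  # False True
```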

10.1.1 The projection: A specialized lens

This is where the projection matrices \(\mathbf{W}^Q, \mathbf{W}^K, \mathbf{W}^V\) come in. In multi-head attention, every head \(i\) gets its own unique set of matrices: \(\mathbf{W}_i^Q, \mathbf{W}_i^K, \mathbf{W}_i^V\).

Imagine the input vector for “banks” in the sentence “The river banks overflowed.” This vector mixes several potential meanings: financial institution, riverside, tilting an aircraft into a turn, verb form, noun form.

  1. Head 1 (The Syntax Expert): Its projection matrix \(\mathbf{W}_1^Q\) learns to ignore the “financial” or “river” meanings. It projects the 512-dim vector down to a 64-dim vector that purely encodes: “This is a plural noun.”
  2. Head 2 (The Semantic Expert): Its projection matrix \(\mathbf{W}_2^Q\) ignores the grammar. It projects the same 512-dim input down to a different 64-dim vector that encodes: “This is a physical object related to water.”

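By projecting into these smaller subspaces, each head can perform attention cleanly on the features it cares about. Head 1 can link “overflowed” (a verb looking for its subject) to “banks”. Head 2 can link “banks” to “river” (water context). This division of labor is an idealized illustration: heads learned in training are not guaranteed to specialize this neatly.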

10.2 The mathematical formulation

Now we can formalize this. Let \(d_{model} = 512\) be the input dimension and \(h = 8\) be the number of heads. We set the dimension of each head to \(d_k = d_{model} / h = 64\).
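One consequence of \(d_k = d_{model} / h\) is that the \(h\) per-head query matrices, each \(512 \times 64\), tile exactly one \(512 \times 512\) matrix. A common implementation trick is therefore to store them side by side, run a single matrix product, and reshape the result into heads. This NumPy sketch checks that the two views agree. It uses random stand-in weights and an arbitrary 5-token sequence.

```python
import numpy as np

rng = np.random.default_rng(1)
n, d_model, h = 5, 512, 8            # n = 5 tokens is an arbitrary choice
assert d_model % h == 0, "d_model must split evenly across heads"
d_k = d_model // h                   # 512 / 8 = 64

X = rng.normal(size=(n, d_model))

# Per-head view: h separate query matrices W_i^Q, each d_model x d_k (random stand-ins).
W_q_heads = [rng.normal(size=(d_model, d_k)) for _ in range(h)]
Q_heads = [X @ W for W in W_q_heads]             # h arrays of shape (n, d_k)

# Stacked view: the h matrices side by side form one d_model x (h * d_k) = 512 x 512 matrix.
W_q_big = np.concatenate(W_q_heads, axis=1)
Q_big = X @ W_q_big                              # (n, 512)

# Reshape (n, h*d_k) -> (n, h, d_k), then swap axes -> (h, n, d_k): one slab per head.
Q_split = Q_big.reshape(n, h, d_k).transpose(1, 0, 2)

print(W_q_big.shape, Q_split.shape)              # (512, 512) (8, 5, 64)
print(all(np.allclose(Q_split[i], Q_heads[i]) for i in range(h)))  # True
```

The same holds for \(\mathbf{W}_i^K\) and \(\mathbf{W}_i^V\). Splitting into heads is purely a matter of how the 512 output columns are grouped.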

For each head \(i = 1 \dots h\):

  1. Project: We take the same input \(\mathbf{X}\) and project it into the head’s specific subspace. \[ \mathbf{Q}_i = \mathbf{X}\mathbf{W}_i^Q, \quad \mathbf{K}_i = \mathbf{X}\mathbf{W}_i^K, \quad \mathbf{V}_i = \mathbf{X}\mathbf{W}_i^V \] Here, the projection matrices are size \(d_{model} \times d_k\) (e.g., \(512 \times 64\)). This effectively “compresses” the information from the full width down to the head’s specialized width.

  2. Attend: We calculate attention independently in this small subspace. \[ \text{head}_i = \text{softmax}\left(\frac{\mathbf{Q}_i \mathbf{K}_i^T}{\sqrt{d_k}}\right) \mathbf{V}_i \] The result \(\text{head}_i\) is a sequence of vectors of size \(d_k\) (64).

  3. Concatenate: This is the re-assembly phase.

    At this point, for a single token, we have 8 separate vectors, each of size 64.

    • head 1: [Grammar features]
    • head 2: [Context features]
    • …
    • head 8: [Position features]

    To move forward in the network, we need to return to our standard interface: a single vector of size 512. We do this by simply placing the 8 vectors side-by-side to form one long vector.

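    \[ \text{MultiHeadOutput} = \Big[ \underbrace{\mathbf{o}_1}_{\text{Head 1}} \Big| \underbrace{\mathbf{o}_2}_{\text{Head 2}} \Big| \dots \Big| \underbrace{\mathbf{o}_8}_{\text{Head 8}} \Big] \]

    Here \(\mathbf{o}_i\) is this token’s 64-dimensional output from head \(i\), i.e., its row of \(\text{head}_i\). The result is a vector of length \(64 + 64 + \dots + 64 = 512\). Doing this for every token at once places the matrices \(\text{head}_1, \dots, \text{head}_8\) side by side.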

  4. Final linear projection (the mixer):

    We now have a concatenated vector of size 512, but it is “segregated.” The first 64 numbers come exclusively from Head 1, the next 64 from Head 2, and so on. These segments are neighbors, but they haven’t interacted. Head 1 might know the word is “banks” (plural noun), and Head 2 might know the context is “river” (nature), but no single number in the vector represents “river banks”. If we stopped here, the next layer would receive a disjointed input where syntax information lives in one block and semantic information in another.

    To fix this, we apply a final linear transformation using a weight matrix \(\mathbf{W}^O\) (size \(512 \times 512\)).

    \[ \mathbf{Z} = \text{MultiHeadOutput} \times \mathbf{W}^O \]

    10.2.1 How the math mixes information

    To understand why this mixes the heads, look at the linear algebra operation for a single output value. Let the concatenated input vector be \(\mathbf{c} = [c_1, c_2, \dots, c_{512}]\) and the output vector be \(\mathbf{z} = [z_1, z_2, \dots, z_{512}]\). We avoid the letter \(h\) here, since it already denotes the number of heads.

    The \(j\)-th feature of the output is calculated as:

    \[ z_j = \sum_{i=1}^{512} c_i \cdot W_{ij}^O \]

    Look closely at the summation range (\(i=1\) to \(512\)). It iterates over the entire length of the concatenated vector.

    • Indices \(1 \dots 64\) come from Head 1.
    • Indices \(65 \dots 128\) come from Head 2.
    • …and so on.

    This means every single value \(z_j\) in the output is a weighted sum over the outputs of all heads at once. The matrix \(\mathbf{W}^O\) acts as a massive mixing desk. It can say: “To create output feature 5, take 0.2 times Head 1’s syntax signal, add 0.5 times Head 2’s semantic signal, and subtract 0.1 times Head 3’s position signal.”

    10.2.2 Discovering the weights

    What is inside \(\mathbf{W}^O\)? Ideally, it contains the perfect “recipes” for combining these diverse signals. But we don’t hand-code them.

    \(\mathbf{W}^O\) consists of \(512 \times 512 = 262{,}144\) learnable parameters. Initially, these are set to small random numbers. We “discover” useful values for them through training (backpropagation and gradient descent).

    • If the model makes a prediction error because it failed to combine “subject” and “verb” correctly, backpropagation computes the gradient of the loss with respect to every weight in \(\mathbf{W}^O\). This gradient says how changing each weight would change the error, and gradient descent nudges every weight in the direction that reduces it.
    • Over many training examples, \(\mathbf{W}^O\) evolves from random noise into a learned routing system. It emphasizes the heads that carry relevant information and blends them into a more useful representation for the next layer.
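The four steps, together with the mixing formula for \(z_j\), fit in a short NumPy sketch. It loops over heads for readability rather than speed and uses random matrices in place of trained weights. It also omits masking, bias terms, and batching. The function name `multi_head_attention` and the initialization scale are illustrative assumptions, not a reference implementation.

```python
import numpy as np

def softmax(S, axis=-1):
    S = S - S.max(axis=axis, keepdims=True)       # subtract the row max for numerical stability
    E = np.exp(S)
    return E / E.sum(axis=axis, keepdims=True)

def multi_head_attention(X, W_q, W_k, W_v, W_o):
    """X: (n, d_model). W_q, W_k, W_v: lists of h matrices, each (d_model, d_k).
    W_o: (h * d_k, d_model). Returns Z (n, d_model), the attention weights
    stacked as (h, n, n), and the list of per-head outputs head_i (n, d_k)."""
    heads, weights = [], []
    for Wq, Wk, Wv in zip(W_q, W_k, W_v):
        Q, K, V = X @ Wq, X @ Wk, X @ Wv          # 1. project into this head's subspace
        d_k = Q.shape[-1]
        A = softmax(Q @ K.T / np.sqrt(d_k))       # 2. attend: (n, n), each row sums to 1
        heads.append(A @ V)                       #    head_i: (n, d_k)
        weights.append(A)
    concat = np.concatenate(heads, axis=-1)       # 3. concatenate: (n, h * d_k)
    Z = concat @ W_o                              # 4. mix every head with W^O: (n, d_model)
    return Z, np.stack(weights), heads

rng = np.random.default_rng(0)
n, d_model, h = 6, 512, 8
d_k = d_model // h

def init(rows, cols):
    # Small random values standing in for trained weights (the scale is an assumption).
    return rng.normal(size=(rows, cols)) / np.sqrt(rows)

X = rng.normal(size=(n, d_model))
W_q = [init(d_model, d_k) for _ in range(h)]
W_k = [init(d_model, d_k) for _ in range(h)]
W_v = [init(d_model, d_k) for _ in range(h)]
W_o = init(h * d_k, d_model)

Z, A, heads = multi_head_attention(X, W_q, W_k, W_v, W_o)
print(Z.shape, A.shape)                       # (6, 512) (8, 6, 6)
print(np.allclose(A.sum(axis=-1), 1.0))       # True

# Mixing check: z_j = sum_i c_i W_ij^O, so splitting W^O into h blocks of d_k rows
# gives Z as the sum of every head's contribution.
mixed = sum(heads[i] @ W_o[i * d_k:(i + 1) * d_k] for i in range(h))
print(np.allclose(Z, mixed))                  # True
```

The last check splits \(\mathbf{W}^O\) into eight 64-row blocks, one per head. The output equals the sum of the heads' individual contributions, so each column of \(\mathbf{Z}\) draws on all heads at once.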

10.3 Why divide the dimension?

You might ask: Why not just have 8 heads that are all 512 dimensions wide? Why do we have to shrink them to 64?

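Two reasons: computational cost and parameter count.

Suppose we had 8 full-size heads. Every projection and every attention computation would then run 8 times at full width, making the layer roughly 8 times more expensive than a single head, with 8 times as many projection parameters. Shrinking each head’s dimension by a factor of \(h\) makes the math work out neatly:

  • Single head: for each query-key pair, 1 dot product of size 512.
  • Multi-head: for each query-key pair, 8 dot products of size 64.

Since \(8 \times 64 = 512\), the matrix multiplications (the projections, the attention scores, and the final fusion) cost roughly as many floating-point operations as for a single giant head. The parameter count matches as well. Take any one of \(\mathbf{W}_i^Q\), \(\mathbf{W}_i^K\), or \(\mathbf{W}_i^V\): its eight \(512 \times 64\) matrices together hold exactly as many weights as one \(512 \times 512\) matrix. We get multiple independent attention patterns at essentially no extra cost in matrix-multiplication compute or parameters.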
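The counts behind this argument can be tallied in a few lines of plain Python. The helper `attention_cost` is hypothetical. It ignores biases, counts only the matrix products as multiply-accumulate operations (MACs), and uses a sequence length of n = 128 chosen arbitrarily.

```python
def attention_cost(d_model, h, d_head, n):
    """Rough parameter count and multiply-accumulate (MAC) count for one
    multi-head attention layer (no biases) over a sequence of n tokens."""
    params_qkv = 3 * h * d_model * d_head       # W_i^Q, W_i^K, W_i^V for every head
    params_out = (h * d_head) * d_model         # W^O maps the concatenation back to d_model
    macs_proj = n * params_qkv                  # X @ W for Q, K, V across all heads
    macs_scores = h * n * n * d_head            # Q_i K_i^T for every head
    macs_weighted = h * n * n * d_head          # softmax(...) @ V_i for every head
    macs_out = n * params_out                   # concatenation @ W^O
    return params_qkv + params_out, macs_proj + macs_scores + macs_weighted + macs_out

n = 128
split = attention_cost(d_model=512, h=8, d_head=64, n=n)    # standard: d_head = d_model / h
full = attention_cost(d_model=512, h=8, d_head=512, n=n)    # hypothetical full-width heads
single = attention_cost(d_model=512, h=1, d_head=512, n=n)  # one big head

print(split[0], full[0], single[0])             # 1048576 8388608 1048576
print(split == single)                          # True: same parameters and MACs
print(full[0] / split[0], full[1] / split[1])   # 8.0 8.0
```

The match covers the matrix multiplications only. The softmax runs over \(h\) separate \(n \times n\) score matrices instead of one, so that elementwise work, and the memory for storing attention weights, does grow with \(h\).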

10.4 Summary

Multi-head attention is one of the defining features of the Transformer, letting it track several kinds of relationships in language at once.

  1. Decomposition: It projects the input vector into several specialized lower-dimensional subspaces, each through its own learned lens.
  2. Specialization: Different heads learn to look for different things (grammar, coreference, context).
  3. Recomposition: It fuses these diverse insights back into a unified representation.

It transforms the “bag of vectors” into a rich, multi-layered web of relationships. However, despite all this sophistication, our model still has a glaring hole: it has no idea that “The” comes before “cat”. In the next chapter, we will fix this with Positional Encoding.