8  The attention mechanism

Note: Learning objectives

After completing this chapter, you will be able to:

  • Explain the information bottleneck problem in sequence-to-sequence models
  • Define queries, keys, and values and their roles in attention
  • Compute scaled dot-product attention scores and weights
  • Derive why the scaling factor \(\sqrt{d}\) prevents softmax saturation and vanishing gradients
  • Implement the complete attention mechanism in matrix form

Attention is the core innovation that makes transformers work. It’s a mechanism for selectively combining information from different positions based on relevance. Instead of compressing a sequence into a fixed-size vector, attention lets a model look back at all positions and decide dynamically which ones matter for the current computation. This chapter develops attention from first principles.

8.1 The bottleneck problem revisited

Recall from Chapter 3 that RNN-based sequence-to-sequence models have a bottleneck: the encoder compresses the entire input sequence into a single vector, which the decoder must use to generate the output.

\[ \text{input sequence} \to \text{encoder} \to \text{single vector} \to \text{decoder} \to \text{output sequence} \]

For long sequences, this single vector can’t capture all the relevant information. Early parts of the input get “overwritten” by later parts.

Attention solves this by giving the decoder direct access to all encoder hidden states. At each decoding step, the decoder can “attend to” different parts of the input, focusing on what’s relevant for the current output.

8.2 The basic idea: weighted combinations

Suppose we have a sequence of vectors \(\mathbf{v}_1, \mathbf{v}_2, \ldots, \mathbf{v}_n\) (the “values”) and we want to produce a single output vector that combines them. In a transformer processing the sentence “The cat sat on the mat,” these might be \(\mathbf{v}_1\) for “The”, \(\mathbf{v}_2\) for “cat”, \(\mathbf{v}_3\) for “sat”, and so on. Each vector is a learned representation of that token (we’ll see exactly how these are computed later).

The simplest way to combine them is averaging:

\[ \mathbf{o} = \frac{1}{n} \sum_{i=1}^n \mathbf{v}_i \]

But this treats all positions equally. What if some positions are more relevant than others for our current purpose? When computing the representation for “sat,” we might care more about “cat” (the subject) than “the” (a function word).

Attention computes a weighted average where the weights depend on relevance:

\[ \mathbf{o} = \sum_{i=1}^n \alpha_i \mathbf{v}_i \]

where \(\alpha_1, \ldots, \alpha_n\) are attention weights satisfying \(\alpha_i \geq 0\) and \(\sum_i \alpha_i = 1\).

If \(\alpha_3 = 0.7\) and all other \(\alpha_i\) are small, the output is dominated by \(\mathbf{v}_3\). We’re “paying attention” to position 3. The key question is: how do we compute the weights?
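To make the contrast concrete, here is a minimal NumPy sketch (the value vectors and the weights are made up for illustration) comparing a uniform average with a hand-picked weighted combination:

```python
import numpy as np

# Three value vectors as rows; the numbers are made up for illustration.
V = np.array([[1.0, 2.0],
              [3.0, 4.0],
              [5.0, 6.0]])

# Uniform averaging treats every position equally.
uniform = V.mean(axis=0)              # [3. 4.]

# Attention-style weights: nonnegative and summing to 1, chosen by hand here.
alpha = np.array([0.15, 0.15, 0.70])
weighted = alpha @ V                  # [4.1 5.1], pulled toward the third value

print(uniform, weighted)
```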

8.3 Queries, keys, and values

Attention uses three concepts: queries, keys, and values. Think of it like a soft database lookup. The query is what we’re looking for, representing the current position’s “question” to the rest of the sequence. The keys label what’s available at each position, advertising what each position offers. The values are the actual content at each position, the information that gets retrieved.

The query compares against each key to determine relevance (the attention weights), then the weights are used to combine the values.

Concretely, suppose we have a query vector \(\mathbf{q} \in \mathbb{R}^d\) representing what we’re looking for, key vectors \(\mathbf{k}_1, \ldots, \mathbf{k}_n \in \mathbb{R}^d\) representing what each of \(n\) positions offers, and value vectors \(\mathbf{v}_1, \ldots, \mathbf{v}_n \in \mathbb{R}^{d_v}\) containing the actual content at each position. Here \(d\) is the dimension of queries and keys (they must match for the dot product), and \(d_v\) is the dimension of values (which can differ).

The attention weight for position \(i\) measures how well the query matches key \(i\):

\[ \alpha_i = \frac{\exp(s(\mathbf{q}, \mathbf{k}_i))}{\sum_{j=1}^n \exp(s(\mathbf{q}, \mathbf{k}_j))} \]

where \(s(\mathbf{q}, \mathbf{k})\) is a score function measuring similarity. This is a softmax over scores, ensuring weights are positive and sum to 1.

The output is the weighted combination of values:

\[ \mathbf{o} = \sum_{i=1}^n \alpha_i \mathbf{v}_i \]
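Putting the two formulas together gives a complete, if minimal, attention routine for a single query. The sketch below is ours, not a library API; it defaults to a plain dot product for the score function \(s\), with the scaled version coming in the next section:

```python
import numpy as np

def attend_single_query(q, K, V, score=lambda q, k: q @ k):
    """Attention for one query: softmax over scores, then a weighted sum of values.

    q: (d,) query; K: (n, d) keys as rows; V: (n, d_v) values as rows.
    """
    scores = np.array([score(q, k) for k in K])      # s(q, k_i) for each of the n positions
    scores = scores - scores.max()                   # shift for numerical stability (softmax is unchanged)
    alpha = np.exp(scores) / np.exp(scores).sum()    # attention weights: positive, sum to 1
    return alpha @ V, alpha                          # weighted combination of values, plus the weights
```

Passing `score=lambda q, k: q @ k / np.sqrt(len(q))` turns this into the scaled dot-product attention developed next.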

8.4 The score function: scaled dot product

The most common score function in transformers is the scaled dot product:

\[ s(\mathbf{q}, \mathbf{k}) = \frac{\mathbf{q}^T \mathbf{k}}{\sqrt{d}} \]

where \(d\) is the dimension of the query and key vectors.

Why the dot product? It measures alignment: \(\mathbf{q}^T \mathbf{k}\) is large when \(\mathbf{q}\) and \(\mathbf{k}\) point in similar directions. If the query represents “what I’m looking for” and the key represents “what this position offers,” a high dot product means a good match.

Why scale by \(\sqrt{d}\)? Consider what happens without scaling. If the entries of \(\mathbf{q}\) and \(\mathbf{k}\) are independent with mean 0 and variance 1, then \(\mathbf{q}^T \mathbf{k} = \sum_{i=1}^d q_i k_i\) has variance \(d\) (a sum of \(d\) independent terms, each with variance 1), so its typical magnitude grows like \(\sqrt{d}\). For large \(d\), the dot products can therefore be large.

Large dot products cause the softmax to saturate. If one score is 100 and others are around 0, then \(\exp(100)\) dominates and the softmax puts essentially all weight on that position. The gradients through saturated softmax are very small, making learning difficult.

Dividing by \(\sqrt{d}\) keeps the variance around 1 regardless of dimension, preventing saturation.
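You can check this numerically. In the sketch below (random Gaussian vectors, so the exact numbers vary with the seed), the standard deviation of the raw dot products grows roughly like \(\sqrt{d}\), while the scaled scores stay near 1:

```python
import numpy as np

rng = np.random.default_rng(0)
for d in (4, 64, 1024):
    q = rng.standard_normal((10_000, d))
    k = rng.standard_normal((10_000, d))
    raw = (q * k).sum(axis=1)          # unscaled dot products q^T k, one per sample
    scaled = raw / np.sqrt(d)          # scaled dot products
    print(d, raw.std().round(2), scaled.std().round(2))
# Unscaled std is roughly sqrt(d) (~2, ~8, ~32); scaled std stays near 1.
```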

8.4.1 Concrete example

Let’s compute attention step by step. Suppose \(d = 3\) and we have:

Query: \(\mathbf{q} = [1, 0, 1]^T\)

Keys: \(\mathbf{k}_1 = [1, 1, 0]^T\), \(\mathbf{k}_2 = [0, 1, 1]^T\), \(\mathbf{k}_3 = [1, 0, 1]^T\)

Values: \(\mathbf{v}_1 = [1, 2]^T\), \(\mathbf{v}_2 = [3, 4]^T\), \(\mathbf{v}_3 = [5, 6]^T\)

Step 1: Compute scores

\[ s_1 = \frac{\mathbf{q}^T \mathbf{k}_1}{\sqrt{3}} = \frac{1 \cdot 1 + 0 \cdot 1 + 1 \cdot 0}{\sqrt{3}} = \frac{1}{1.732} \approx 0.577 \]

\[ s_2 = \frac{\mathbf{q}^T \mathbf{k}_2}{\sqrt{3}} = \frac{1 \cdot 0 + 0 \cdot 1 + 1 \cdot 1}{\sqrt{3}} = \frac{1}{1.732} \approx 0.577 \]

\[ s_3 = \frac{\mathbf{q}^T \mathbf{k}_3}{\sqrt{3}} = \frac{1 \cdot 1 + 0 \cdot 0 + 1 \cdot 1}{\sqrt{3}} = \frac{2}{1.732} \approx 1.155 \]

Step 2: Compute attention weights (softmax)

\[ \alpha_1 = \frac{\exp(0.577)}{\exp(0.577) + \exp(0.577) + \exp(1.155)} = \frac{1.781}{1.781 + 1.781 + 3.174} = \frac{1.781}{6.736} \approx 0.264 \]

\[ \alpha_2 = \frac{1.781}{6.736} \approx 0.264 \]

\[ \alpha_3 = \frac{3.174}{6.736} \approx 0.471 \]

Notice that \(\alpha_3\) is largest because \(\mathbf{q}\) and \(\mathbf{k}_3\) are identical (maximum alignment).

Step 3: Compute output

\[ \mathbf{o} = 0.264 \begin{bmatrix} 1 \\ 2 \end{bmatrix} + 0.264 \begin{bmatrix} 3 \\ 4 \end{bmatrix} + 0.471 \begin{bmatrix} 5 \\ 6 \end{bmatrix} \]

\[ = \begin{bmatrix} 0.264 + 0.792 + 2.355 \\ 0.528 + 1.056 + 2.826 \end{bmatrix} = \begin{bmatrix} 3.41 \\ 4.41 \end{bmatrix} \]

The output \([3.41, 4.41]\) is a blend of all three value vectors, weighted toward \(\mathbf{v}_3 = [5, 6]\) because the query best matches \(\mathbf{k}_3\).

What is this output useful for? In a transformer, this becomes the new representation for the position that issued the query. If position 1 queries positions 1, 2, and 3, the output replaces position 1’s old embedding with a new one that incorporates relevant information from all three positions. The original embedding only knew about itself; the new embedding is context-aware, having gathered information from wherever the attention weights pointed.

In this example, the output \([3.41, 4.41]\) is closer to \(\mathbf{v}_3 = [5, 6]\) than to \(\mathbf{v}_1 = [1, 2]\) because position 3 was deemed most relevant. If these were word embeddings in a sentence, the querying word now “knows about” the other words, weighted by relevance.
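If you want to verify the arithmetic, here is a short NumPy sketch reproducing the numbers above:

```python
import numpy as np

q = np.array([1.0, 0.0, 1.0])
K = np.array([[1.0, 1.0, 0.0],
              [0.0, 1.0, 1.0],
              [1.0, 0.0, 1.0]])
V = np.array([[1.0, 2.0],
              [3.0, 4.0],
              [5.0, 6.0]])

scores = K @ q / np.sqrt(3)                       # [0.577, 0.577, 1.155]
alpha = np.exp(scores) / np.exp(scores).sum()     # [0.264, 0.264, 0.471]
o = alpha @ V                                     # [3.41, 4.41] up to rounding
print(scores.round(3), alpha.round(3), o.round(2))
```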

8.5 Matrix formulation

When we have multiple queries, we can compute attention for all of them in parallel using matrix operations. We stack the queries into a matrix \(\mathbf{Q} \in \mathbb{R}^{m \times d}\) where each of the \(m\) rows is a query vector. Similarly, we have \(\mathbf{K} \in \mathbb{R}^{n \times d}\) containing \(n\) key vectors as rows, and \(\mathbf{V} \in \mathbb{R}^{n \times d_v}\) containing \(n\) value vectors as rows.

The scaled dot-product attention computes:

\[ \text{Attention}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{softmax}\left(\frac{\mathbf{Q} \mathbf{K}^T}{\sqrt{d}}\right) \mathbf{V} \]

State of the art: This formula is used essentially unchanged in modern transformers, including GPT-4, Claude, LLaMA, and BERT. Scaled dot-product attention was introduced in the original “Attention Is All You Need” paper (Vaswani et al., 2017).

Let’s unpack this formula step by step. First, the matrix product \(\mathbf{Q} \mathbf{K}^T\) gives us an \(m \times n\) matrix where entry \((i, j)\) is the dot product of query \(i\) with key \(j\). This computes all pairwise similarity scores in one operation. Next, we divide by \(\sqrt{d}\) to prevent the dot products from growing too large (as discussed earlier). Then, softmax is applied to each row independently, so row \(i\) becomes a probability distribution representing how much query \(i\) attends to each of the \(n\) keys. Finally, we multiply by \(\mathbf{V}\): each row of the result is a weighted combination of value vectors, using that row’s attention weights.

The output has shape \(m \times d_v\): for each of the \(m\) queries, we get a \(d_v\)-dimensional output vector.
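The formula translates almost directly into code. Here is a minimal NumPy sketch (the function name is ours; real implementations add batching, masking, and other refinements):

```python
import numpy as np

def scaled_dot_product_attention(Q, K, V):
    """softmax(Q K^T / sqrt(d)) V, for Q: (m, d), K: (n, d), V: (n, d_v)."""
    d = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d)                          # (m, n) matrix of scaled pairwise scores
    scores = scores - scores.max(axis=-1, keepdims=True)   # row-wise shift for numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)  # softmax over each row
    return weights @ V                                     # (m, d_v) outputs
```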

8.5.1 Matrix example

Using the same values as before, with our single query as a \(1 \times 3\) matrix:

\[ \mathbf{Q} = \begin{bmatrix} 1 & 0 & 1 \end{bmatrix} \]

\[ \mathbf{K} = \begin{bmatrix} 1 & 1 & 0 \\ 0 & 1 & 1 \\ 1 & 0 & 1 \end{bmatrix} \]

\[ \mathbf{V} = \begin{bmatrix} 1 & 2 \\ 3 & 4 \\ 5 & 6 \end{bmatrix} \]

Scores: \(\mathbf{Q}\mathbf{K}^T = \begin{bmatrix} 1 & 1 & 2 \end{bmatrix}\)

Scaled: \(\frac{1}{\sqrt{3}} \begin{bmatrix} 1 & 1 & 2 \end{bmatrix} \approx \begin{bmatrix} 0.577 & 0.577 & 1.155 \end{bmatrix}\)

Softmax: \(\begin{bmatrix} 0.264 & 0.264 & 0.471 \end{bmatrix}\)

Output: \(\begin{bmatrix} 0.264 & 0.264 & 0.471 \end{bmatrix} \begin{bmatrix} 1 & 2 \\ 3 & 4 \\ 5 & 6 \end{bmatrix} = \begin{bmatrix} 3.41 & 4.41 \end{bmatrix}\)

Same result as before, but computed via matrix operations.
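Feeding these matrices to the scaled_dot_product_attention sketch defined above reproduces the output:

```python
Q = np.array([[1.0, 0.0, 1.0]])
K = np.array([[1.0, 1.0, 0.0],
              [0.0, 1.0, 1.0],
              [1.0, 0.0, 1.0]])
V = np.array([[1.0, 2.0],
              [3.0, 4.0],
              [5.0, 6.0]])

print(scaled_dot_product_attention(Q, K, V).round(2))   # [[3.41 4.41]]
```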

8.6 Where do Q, K, V come from?

In the examples above, we assumed queries, keys, and values were given. In a transformer, they’re computed from the input using learned linear projections.

Given an input sequence \(\mathbf{X} \in \mathbb{R}^{n \times d_{\text{model}}}\) where each of the \(n\) rows is a token embedding of dimension \(d_{\text{model}}\), we compute:

\[ \mathbf{Q} = \mathbf{X} \mathbf{W}^Q, \quad \mathbf{K} = \mathbf{X} \mathbf{W}^K, \quad \mathbf{V} = \mathbf{X} \mathbf{W}^V \]

Here \(\mathbf{W}^Q \in \mathbb{R}^{d_{\text{model}} \times d_k}\) is the query projection matrix, \(\mathbf{W}^K \in \mathbb{R}^{d_{\text{model}} \times d_k}\) is the key projection matrix, and \(\mathbf{W}^V \in \mathbb{R}^{d_{\text{model}} \times d_v}\) is the value projection matrix. The dimensions \(d_k\) and \(d_v\) are hyperparameters; often \(d_k = d_v = d_{\text{model}}\), but they can be smaller to reduce computation.

Each token’s embedding gets projected into three different spaces. The query projection asks “What am I looking for?” and transforms the token into a representation optimized for searching. The key projection asks “What do I offer to others?” and transforms the token into a representation optimized for being found. The value projection asks “What information do I contribute?” and transforms the token into the content that will be retrieved.

The same token has different projections because it plays different roles depending on whether it’s the one asking (query) or being asked about (key/value).
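As a sketch of the full pipeline, the code below projects random stand-in embeddings with random stand-in weight matrices (the sizes are illustrative, and in a real model the projections are learned) and reuses the scaled_dot_product_attention function from Section 8.5:

```python
import numpy as np

rng = np.random.default_rng(0)
n, d_model, d_k, d_v = 6, 8, 4, 4            # illustrative sizes, not from any particular model

X = rng.standard_normal((n, d_model))        # n token embeddings (stand-ins for an embedding layer)

# Projection matrices; random stand-ins here, learned by gradient descent in a real model.
W_Q = rng.standard_normal((d_model, d_k))
W_K = rng.standard_normal((d_model, d_k))
W_V = rng.standard_normal((d_model, d_v))

Q, K, V = X @ W_Q, X @ W_K, X @ W_V          # project the same tokens into three roles
out = scaled_dot_product_attention(Q, K, V)
print(out.shape)                             # (6, 4): one d_v-dimensional output per token
```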

8.6.1 Why separate projections?

Why not use the same vector for query, key, and value? Consider what attention computes: “how relevant is position \(j\) to position \(i\)?” This relevance might depend on different aspects of the tokens.

For example, in “The cat sat on the mat,” when generating output for “sat,” we might want:

  • Query based on “sat” emphasizing: “I’m a verb, I need a subject”
  • Key for “cat” emphasizing: “I’m a noun, I can be a subject”
  • Value for “cat” providing: its actual semantic content

Separate projections let the model learn these different roles. The query projection can emphasize syntactic features, the key projection can respond to those queries, and the value projection can provide semantic content.

8.7 Attention as information routing

Another way to understand attention: it routes information between positions. Each position can read from any other position, with the amount of information transferred controlled by the attention weights.

The attention weight \(\alpha_{ij}\) represents “how much information flows from position \(j\) to position \(i\).” When \(\alpha_{ij}\) is high, position \(i\)’s output strongly reflects position \(j\)’s value.

This creates a flexible information flow that depends on the content of the sequence. Unlike RNNs where information flows strictly sequentially, attention allows direct long-range connections. A token at position 1 can directly influence position 100 if the attention weights say it’s relevant.

The computational complexity is \(O(n^2 \cdot d)\): forming the score matrix requires \(n^2\) query-key dot products, each costing \(d\) multiplications, and each of the \(n\) outputs is then a weighted combination of \(n\) value vectors. For very long sequences, this quadratic cost becomes expensive, which is why various “efficient attention” variants exist (though we won’t cover them in detail).
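As a back-of-the-envelope illustration of the quadratic cost, the score matrix alone has \(n^2\) entries; assuming 4-byte floats, its memory footprint grows quickly with sequence length:

```python
# Memory for the n-by-n score matrix alone, assuming 4-byte (float32) entries.
for n in (1_000, 10_000, 100_000):
    print(n, f"{n * n * 4 / 1e9:.3f} GB")    # 0.004 GB, 0.400 GB, 40.000 GB
```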

8.8 Attention visualized

Attention weights form an \(n \times n\) matrix (for self-attention, which we’ll cover next chapter). Visualizing this matrix reveals patterns that emerge from training and reflect the linguistic structure the model has learned.

Diagonal patterns appear when tokens attend to themselves or nearby tokens, capturing local context and sequential relationships. Vertical stripes occur when many tokens attend to a specific position, often the start of a sentence or a semantically important word like a verb. Block patterns emerge when groups of tokens attend to other groups, reflecting phrase structure or clause boundaries. Sparse patterns are common in trained models: most weights are near zero, with a few dominant connections, suggesting the model learns to focus on a small number of relevant positions rather than spreading attention uniformly.
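A heatmap is the usual way to look at attention weights. The sketch below plots a made-up, diagonally biased weight matrix purely for illustration; in practice you would plot weights extracted from a trained model:

```python
import numpy as np
import matplotlib.pyplot as plt

# Fabricated attention weights with a diagonal bias (rows sum to 1), for illustration only.
rng = np.random.default_rng(0)
scores = rng.standard_normal((8, 8)) + 3 * np.eye(8)
weights = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)

plt.imshow(weights, cmap="viridis")
plt.xlabel("key position")
plt.ylabel("query position")
plt.colorbar(label="attention weight")
plt.show()
```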

8.9 Properties of attention

Attention has several important properties that make it well-suited for sequence modeling.

First, attention is permutation equivariant. This means if you shuffle the input sequence, the output gets shuffled in exactly the same way. Consider a sentence “cat sat mat” with tokens at positions 1, 2, 3. If we shuffle it to “mat cat sat” (positions 3, 1, 2), attention produces the same outputs, just reordered to match. The output for “cat” is identical whether “cat” was at position 1 or position 2.

Why? Because attention only compares content (queries against keys), not positions. The attention weight between “cat” and “sat” depends on their embeddings, not on whether “cat” is first or second. Attention treats the input as a set of vectors, not an ordered sequence.

This is a problem for language, where order matters (“dog bites man” vs “man bites dog”). The solution is to add position information to the embeddings before attention sees them, via positional encodings (covered later). Without positional encodings, attention would process “the cat sat” and “sat the cat” identically.
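Permutation equivariance is easy to check numerically. The sketch below reuses the scaled_dot_product_attention function from Section 8.5 on randomly generated vectors, using the same sequence as queries, keys, and values (a preview of self-attention):

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((5, 4))              # 5 "token" vectors with no positional information
perm = rng.permutation(5)                    # a random reordering of the positions

out = scaled_dot_product_attention(X, X, X)
out_perm = scaled_dot_product_attention(X[perm], X[perm], X[perm])

print(np.allclose(out[perm], out_perm))      # True: shuffling the inputs just shuffles the outputs
```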

Second, attention is differentiable. Every operation (the dot products, the scaling, the softmax, the weighted sum) is a smooth function of its inputs. Gradients flow through attention without discontinuities, allowing end-to-end training via backpropagation.

Third, attention is parallelizable. Unlike RNNs where each step depends on the previous step, all positions in attention can be processed simultaneously. The matrix multiplication \(\mathbf{Q}\mathbf{K}^T\) computes all pairwise scores at once, and modern GPUs are optimized for exactly this kind of parallel matrix operation.

Fourth, attention is dynamic: the attention pattern changes based on the input content. The same trained model will attend to different positions for different inputs, because the attention weights are computed from the queries and keys, not fixed during training. This content-based routing is what makes attention so powerful.

8.10 Attention vs. full connection

You might wonder: why not just use a fully connected layer instead of attention? A fully connected layer also lets every position influence every other position.

The difference is in how the weights are determined. In a fully connected layer, the weights are fixed parameters learned during training. The connection from position \(j\) to position \(i\) has the same weight regardless of what’s at those positions. The weights are static, determined only by the positions.

In attention, the weights are computed from the content. The connection strength depends on how well the query at \(i\) matches the key at \(j\). Different inputs produce different attention patterns, even with the same trained weights.

This makes attention a form of content-based routing: the same architecture can route information differently for different inputs. This is crucial for handling variable-length sequences and capturing context-dependent relationships. A fully connected layer can’t generalize to sequences of different lengths, but attention can.

8.11 Summary

We’ve developed the attention mechanism:

  • Attention computes a weighted average of values, where weights depend on query-key similarity.
  • The scaled dot product \(\frac{\mathbf{q}^T \mathbf{k}}{\sqrt{d}}\) measures how well a query matches a key.
  • Softmax turns scores into a probability distribution (attention weights).
  • In matrix form: \(\text{Attention}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{softmax}\left(\frac{\mathbf{Q} \mathbf{K}^T}{\sqrt{d}}\right) \mathbf{V}\)
  • Queries, keys, and values are computed from input via learned linear projections.
  • Attention enables content-based information routing between all positions.

In the next chapter, we’ll see how attention is applied when queries, keys, and values all come from the same sequence: self-attention.