9 Self-attention

Learning objectives

After completing this chapter, you will be able to:

  • Distinguish self-attention from cross-attention
  • Compute self-attention where queries, keys, and values come from the same sequence
  • Trace through a complete self-attention computation with concrete numbers
  • Explain why nonlinearity is essential in neural networks
  • Describe the role of the feed-forward network in transformer blocks

In the previous chapter, we developed attention as a mechanism for combining information based on relevance. The queries asked about keys, and the answers came from values. But where do these come from? In self-attention, the queries, keys, and values all come from the same sequence. Each position attends to every other position (including itself) in the same sequence. This is the fundamental operation at the heart of transformers.

9.1 From attention to self-attention

In general attention, queries, keys, and values can come from different sequences. Consider machine translation from English to French. The encoder processes the English sentence “The cat sat” and produces a sequence of hidden states. The decoder generates French tokens one at a time: “Le”, “chat”, … When generating “chat”, the decoder needs to look back at the English sentence to find the relevant word (“cat”). So the decoder’s current state becomes the query, while the encoder’s hidden states provide the keys and values. The decoder is asking: “Which English words should I pay attention to right now?” This is called cross-attention: the query comes from one sequence (French, being generated) and the keys/values come from another (English, already encoded).

In self-attention, a single sequence provides all three. Queries, keys, and values all come from the same sequence. Each token attends to every other token in that same sequence, building a representation that incorporates information from the entire context. There’s no second sequence involved.

Why is this useful? Self-attention lets each position gather information from all other positions based on relevance. A verb can find its subject, a pronoun can find its antecedent, and a word can incorporate context from anywhere in the sequence, all in a single operation.

9.2 The self-attention computation

Given an input sequence \(\mathbf{X} \in \mathbb{R}^{n \times d}\) where \(n\) is the sequence length and \(d\) is the embedding dimension (each row is one token’s embedding), self-attention proceeds in four steps.

Step 1: Project to queries, keys, values. We transform the input into three different representations:

\[ \mathbf{Q} = \mathbf{X} \mathbf{W}^Q, \quad \mathbf{K} = \mathbf{X} \mathbf{W}^K, \quad \mathbf{V} = \mathbf{X} \mathbf{W}^V \]

Here \(\mathbf{W}^Q, \mathbf{W}^K, \mathbf{W}^V \in \mathbb{R}^{d \times d_k}\) are learned weight matrices that project the \(d\)-dimensional input into \(d_k\)-dimensional query, key, and value spaces. Typically \(d_k = d\) or \(d_k = d/h\) where \(h\) is the number of attention heads (covered in the next chapter). The key point is that the same input \(\mathbf{X}\) is used for all three projections. This is what makes it self-attention.

Step 2: Compute attention scores. We measure how relevant each position is to each other position:

\[ \mathbf{S} = \frac{\mathbf{Q} \mathbf{K}^T}{\sqrt{d_k}} \]

The score matrix \(\mathbf{S} \in \mathbb{R}^{n \times n}\) contains all pairwise scores. Entry \(s_{ij}\) is the scaled dot product between position \(i\)’s query and position \(j\)’s key, answering: how relevant is position \(j\) to position \(i\)?

Step 3: Apply softmax. We convert scores to probabilities:

\[ \mathbf{A} = \text{softmax}(\mathbf{S}) \]

Softmax is applied row-wise. Row \(i\) of the attention matrix \(\mathbf{A}\) contains the attention weights, a probability distribution over all positions representing how much position \(i\) attends to each position.

Step 4: Compute output. We gather information according to the attention weights:

\[ \mathbf{O} = \mathbf{A} \mathbf{V} \]

Each row of the output \(\mathbf{O} \in \mathbb{R}^{n \times d_k}\) is a weighted combination of value vectors. Position \(i\)’s output is the sum of all value vectors, weighted by how much position \(i\) attends to each position.

Combining all steps, the complete self-attention operation is:

\[ \text{SelfAttention}(\mathbf{X}) = \text{softmax}\left(\frac{\mathbf{X}\mathbf{W}^Q (\mathbf{X}\mathbf{W}^K)^T}{\sqrt{d_k}}\right) \mathbf{X}\mathbf{W}^V \]
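The four steps translate almost line for line into NumPy. The sketch below is a minimal single-head version. It assumes row-vector inputs as in the text (each row of `X` is one token) and no masking. The names `softmax` and `self_attention` are our own, not a library API.

```python
import numpy as np

def softmax(S):
    """Row-wise softmax; subtracting the row max avoids overflow without changing the result."""
    S = S - S.max(axis=-1, keepdims=True)
    E = np.exp(S)
    return E / E.sum(axis=-1, keepdims=True)

def self_attention(X, W_q, W_k, W_v):
    """Single-head self-attention: Q, K, V are all projections of the same X."""
    Q, K, V = X @ W_q, X @ W_k, X @ W_v      # Step 1: (n, d_k) each
    d_k = K.shape[-1]
    S = Q @ K.T / np.sqrt(d_k)               # Step 2: (n, n) scaled scores
    A = softmax(S)                           # Step 3: each row sums to 1
    return A @ V, A                          # Step 4: (n, d_k) outputs, plus the weights

rng = np.random.default_rng(0)
n, d, d_k = 5, 8, 4
X = rng.normal(size=(n, d))
W_q, W_k, W_v = (rng.normal(size=(d, d_k)) for _ in range(3))
O, A = self_attention(X, W_q, W_k, W_v)
print(O.shape, A.shape)   # (5, 4) (5, 5)
```

Returning `A` alongside the output is a convenience for inspecting attention patterns. The arithmetic only needs `A @ V`.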

9.3 Concrete example

Let’s work through self-attention on a tiny sequence. Suppose we have 3 tokens with 4-dimensional embeddings:

\[ \mathbf{X} = \begin{bmatrix} 1 & 0 & 1 & 0 \\ 0 & 1 & 0 & 1 \\ 1 & 1 & 0 & 0 \end{bmatrix} \]

Row 1 is token 1’s embedding, row 2 is token 2’s, row 3 is token 3’s.

For simplicity, let’s use \(d_k = 2\) and assume the projection matrices are:

\[ \mathbf{W}^Q = \begin{bmatrix} 1 & 0 \\ 0 & 1 \\ 1 & 0 \\ 0 & 1 \end{bmatrix}, \quad \mathbf{W}^K = \begin{bmatrix} 0 & 1 \\ 1 & 0 \\ 0 & 1 \\ 1 & 0 \end{bmatrix}, \quad \mathbf{W}^V = \begin{bmatrix} 1 & 1 \\ 0 & 0 \\ 0 & 1 \\ 1 & 0 \end{bmatrix} \]

Step 1: Compute Q, K, V

\[ \mathbf{Q} = \mathbf{X}\mathbf{W}^Q = \begin{bmatrix} 1 & 0 & 1 & 0 \\ 0 & 1 & 0 & 1 \\ 1 & 1 & 0 & 0 \end{bmatrix} \begin{bmatrix} 1 & 0 \\ 0 & 1 \\ 1 & 0 \\ 0 & 1 \end{bmatrix} = \begin{bmatrix} 2 & 0 \\ 0 & 2 \\ 1 & 1 \end{bmatrix} \]

\[ \mathbf{K} = \mathbf{X}\mathbf{W}^K = \begin{bmatrix} 1 & 0 & 1 & 0 \\ 0 & 1 & 0 & 1 \\ 1 & 1 & 0 & 0 \end{bmatrix} \begin{bmatrix} 0 & 1 \\ 1 & 0 \\ 0 & 1 \\ 1 & 0 \end{bmatrix} = \begin{bmatrix} 0 & 2 \\ 2 & 0 \\ 1 & 1 \end{bmatrix} \]

\[ \mathbf{V} = \mathbf{X}\mathbf{W}^V = \begin{bmatrix} 1 & 0 & 1 & 0 \\ 0 & 1 & 0 & 1 \\ 1 & 1 & 0 & 0 \end{bmatrix} \begin{bmatrix} 1 & 1 \\ 0 & 0 \\ 0 & 1 \\ 1 & 0 \end{bmatrix} = \begin{bmatrix} 1 & 2 \\ 1 & 0 \\ 1 & 1 \end{bmatrix} \]

Step 2: Compute scores

\[ \mathbf{Q}\mathbf{K}^T = \begin{bmatrix} 2 & 0 \\ 0 & 2 \\ 1 & 1 \end{bmatrix} \begin{bmatrix} 0 & 2 & 1 \\ 2 & 0 & 1 \end{bmatrix} = \begin{bmatrix} 0 & 4 & 2 \\ 4 & 0 & 2 \\ 2 & 2 & 2 \end{bmatrix} \]

Scaling by \(\sqrt{d_k} = \sqrt{2} \approx 1.414\):

\[ \mathbf{S} = \frac{1}{1.414} \begin{bmatrix} 0 & 4 & 2 \\ 4 & 0 & 2 \\ 2 & 2 & 2 \end{bmatrix} \approx \begin{bmatrix} 0 & 2.83 & 1.41 \\ 2.83 & 0 & 1.41 \\ 1.41 & 1.41 & 1.41 \end{bmatrix} \]

Step 3: Softmax

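For row 1: exponentiating gives \(e^{0} = 1\), \(e^{2.83} \approx 16.95\), and \(e^{1.41} \approx 4.10\), which sum to about 22.05, so \(\text{softmax}([0, 2.83, 1.41]) \approx [\frac{1}{22.05}, \frac{16.95}{22.05}, \frac{4.10}{22.05}] \approx [0.045, 0.769, 0.186]\)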

For row 2: \(\text{softmax}([2.83, 0, 1.41]) \approx [0.769, 0.045, 0.186]\)

For row 3: \(\text{softmax}([1.41, 1.41, 1.41]) = [0.333, 0.333, 0.333]\)

\[ \mathbf{A} \approx \begin{bmatrix} 0.045 & 0.769 & 0.186 \\ 0.769 & 0.045 & 0.186 \\ 0.333 & 0.333 & 0.333 \end{bmatrix} \]

Step 4: Compute output

\[ \mathbf{O} = \mathbf{A}\mathbf{V} = \begin{bmatrix} 0.045 & 0.769 & 0.186 \\ 0.769 & 0.045 & 0.186 \\ 0.333 & 0.333 & 0.333 \end{bmatrix} \begin{bmatrix} 1 & 2 \\ 1 & 0 \\ 1 & 1 \end{bmatrix} \]

For row 1: \(0.045 \cdot [1,2] + 0.769 \cdot [1,0] + 0.186 \cdot [1,1] = [1.0, 0.28]\)

For row 2: \(0.769 \cdot [1,2] + 0.045 \cdot [1,0] + 0.186 \cdot [1,1] = [1.0, 1.72]\)

For row 3: \(0.333 \cdot [1,2] + 0.333 \cdot [1,0] + 0.333 \cdot [1,1] = [1.0, 1.0]\)

\[ \mathbf{O} \approx \begin{bmatrix} 1.0 & 0.28 \\ 1.0 & 1.72 \\ 1.0 & 1.0 \end{bmatrix} \]
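To double-check the arithmetic, here is the same example in NumPy, using the \(\mathbf{X}\) and projection matrices above. The code uses unrounded \(\sqrt{2}\) and exponentials, so its attention weights can differ from the hand calculation in the third decimal place.

```python
import numpy as np

X = np.array([[1, 0, 1, 0],
              [0, 1, 0, 1],
              [1, 1, 0, 0]], dtype=float)
W_q = np.array([[1, 0], [0, 1], [1, 0], [0, 1]], dtype=float)
W_k = np.array([[0, 1], [1, 0], [0, 1], [1, 0]], dtype=float)
W_v = np.array([[1, 1], [0, 0], [0, 1], [1, 0]], dtype=float)

Q, K, V = X @ W_q, X @ W_k, X @ W_v          # Step 1
S = Q @ K.T / np.sqrt(K.shape[1])            # Step 2: divide by sqrt(d_k) = sqrt(2)
E = np.exp(S - S.max(axis=1, keepdims=True))
A = E / E.sum(axis=1, keepdims=True)         # Step 3: row-wise softmax
O = A @ V                                    # Step 4

print(np.round(A, 3))
print(np.round(O, 2))
```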

9.3.1 Interpreting the attention pattern

Look at the attention matrix \(\mathbf{A}\):

  • Token 1 mostly attends to token 2 (weight 0.769)
  • Token 2 mostly attends to token 1 (weight 0.769)
  • Token 3 attends equally to all tokens (weights 0.333 each)

These patterns emerged from the dot products between queries and keys. Token 1’s query \([2, 0]\) points in the same direction as token 2’s key \([2, 0]\) but is orthogonal to its own key \([0, 2]\), so its diagonal score is 0. Token 2 mirrors this. Token 3’s query \([1, 1]\) has the same dot product (2) with every key, so all of its scores are equal.

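The output for each token is a blend of all value vectors, weighted by these attention patterns. Token 1’s output is dominated by token 2’s value. Token 3’s output is the average of all values. Every output’s first coordinate is exactly 1, because all three value vectors have first coordinate 1 and each row of attention weights sums to 1.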

9.4 What self-attention learns

Self-attention learns to compute useful relationships between tokens. The projection matrices \(\mathbf{W}^Q\), \(\mathbf{W}^K\), \(\mathbf{W}^V\) determine what aspects of tokens are compared and what information is gathered. Through training, these matrices evolve to support patterns the model needs for its task.

What kinds of patterns emerge? Self-attention can learn syntactic relationships: a verb might learn to attend to its subject and object, so in “The cat ate the fish,” the query for “ate” strongly matches the keys for “cat” and “fish.” It can learn coreference: pronouns attend to their antecedents, so in “John said he was tired,” “he” attends strongly to “John.”

Unlike RNNs, self-attention naturally handles long-range dependencies. In “The cat that I saw yesterday at the park was black,” the word “was” can directly attend to “cat” without information passing through all the intervening tokens. The path length between any two positions is just one attention step, not proportional to their distance.

Self-attention also supports copying and retrieval: when the model needs to copy information from one position to another, it simply puts high attention weight on the source position. The output then strongly reflects the source’s value. This is useful for tasks like question answering, where the answer often appears verbatim in the context.

9.5 The attention matrix

The attention matrix \(\mathbf{A} \in \mathbb{R}^{n \times n}\) is central to understanding self-attention. Each entry \(a_{ij}\) represents how much position \(i\) attends to position \(j\). Let’s look at a concrete example for the sentence “The cat sat”:

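\[ \begin{array}{l|ccc} & \text{The} & \text{cat} & \text{sat} \\ \hline \text{The} & 0.70 & 0.15 & 0.15 \\ \text{cat} & 0.10 & 0.60 & 0.30 \\ \text{sat} & 0.05 & 0.75 & 0.20 \end{array} \]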

Reading this table: each row shows how one token distributes its attention. The row for “sat” shows that when computing the new representation for “sat”, the model pulls 5% from “The”, 75% from “cat”, and 20% from itself. This makes sense: “sat” is a verb looking for its subject, and “cat” is the subject.

Notice the asymmetry: “sat” attends strongly to “cat” (0.75), but “cat” attends only moderately to “sat” (0.30). The attention relationship is not mutual.

Let’s examine the general properties of attention matrices.

Each row sums to 1 because softmax produces a probability distribution. Position \(i\) distributes its attention across all positions, and the total weight must be 1. All entries are non-negative (softmax outputs are always positive), so attention weights can be interpreted as probabilities.

The matrix is generally not symmetric: \(a_{ij} \neq a_{ji}\). How much position \(i\) attends to position \(j\) can differ from how much \(j\) attends to \(i\). In “The cat sat,” the verb “sat” might strongly attend to “cat” (looking for its subject), but “cat” might not strongly attend to “sat” (nouns don’t typically search for their verbs).

The diagonal entries are often significant. Tokens can attend to themselves, and often do. A token’s own value is frequently relevant to its output. However, the diagonal is learned, not special; if self-attention isn’t useful, the model can learn to ignore it.

Trained models often develop sparse attention patterns, where most weights in a row are near zero and a few dominate. This suggests the model learns to focus on a small number of relevant positions rather than spreading attention uniformly. Different attention heads (covered in the next chapter) often specialize in different patterns.
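The row-sum, positivity, and asymmetry properties can be checked numerically. This quick sketch uses random, untrained projections (scaled by \(1/\sqrt{d}\) to keep the scores moderate, an arbitrary choice), so it says nothing about the sparsity patterns of trained models.

```python
import numpy as np

rng = np.random.default_rng(1)
n, d, d_k = 6, 8, 4
X = rng.normal(size=(n, d))
W_q = rng.normal(size=(d, d_k)) / np.sqrt(d)
W_k = rng.normal(size=(d, d_k)) / np.sqrt(d)

S = (X @ W_q) @ (X @ W_k).T / np.sqrt(d_k)
E = np.exp(S - S.max(axis=1, keepdims=True))
A = E / E.sum(axis=1, keepdims=True)

print(np.allclose(A.sum(axis=1), 1.0))   # each row is a probability distribution
print(bool((A > 0).all()))               # softmax weights are positive
print(np.allclose(A, A.T))               # separate W_q and W_k make A asymmetric in general
```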

9.6 Computational complexity

Computational complexity measures how the number of operations grows as the input size increases. We express this using big-O notation: \(O(n)\) means operations grow linearly with input size, \(O(n^2)\) means they grow quadratically, and so on. This matters because it determines whether an algorithm is practical for large inputs. An \(O(n)\) algorithm that processes a sequence of 1,000 tokens can likely handle 10,000 tokens (10× more work). An \(O(n^2)\) algorithm would require 100× more work.

Self-attention has complexity \(O(n^2 \cdot d)\), where \(n\) is the sequence length and \(d\) is the dimension.

The dominant cost comes from the matrix multiplications. Computing \(\mathbf{Q}\mathbf{K}^T\) multiplies an \(n \times d\) matrix by a \(d \times n\) matrix, producing an \(n \times n\) score matrix. This takes \(O(n^2 \cdot d)\) operations. The softmax takes \(O(n^2)\) to process all entries. Computing \(\mathbf{A}\mathbf{V}\) multiplies the \(n \times n\) attention matrix by the \(n \times d\) value matrix, another \(O(n^2 \cdot d)\) operations.

The \(n^2\) term is significant. For sequence length \(n = 1{,}000\), we compute and store 1 million attention scores. For \(n = 10{,}000\), it’s 100 million. This quadratic scaling limits the sequence lengths transformers can process efficiently, and it is one reason early GPT models had context windows of only 1,024 or 2,048 tokens.
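A back-of-the-envelope count makes the scaling concrete. This sketch assumes \(d_k = d\), counts one multiply-add (MAC) per inner-product term, and ignores the softmax. The value \(d = 512\) is just an illustrative choice. It also counts the \(O(n \cdot d^2)\) projections, which the \(O(n^2 \cdot d)\) figure leaves out. Those grow only linearly in \(n\), so with these counts the \(n^2\) terms dominate once \(n > 1.5\,d\).

```python
def attention_cost(n, d):
    """Rough multiply-add counts for one single-head self-attention layer, assuming d_k = d."""
    projections = 3 * n * d * d   # X @ W_q, X @ W_k, X @ W_v
    scores = n * n * d            # Q @ K.T: n * n dot products of length d
    mixing = n * n * d            # A @ V: n * d outputs, each a sum over n terms
    return {"score_entries": n * n,
            "attention": scores + mixing,
            "projections": projections}

for n in (1_000, 10_000):
    c = attention_cost(n, d=512)
    print(f"n={n:>6}: {c['score_entries']:>11,} scores, "
          f"{c['attention']:.2e} attention MACs, {c['projections']:.2e} projection MACs")
```

Going from 1,000 to 10,000 tokens multiplies the attention work and the number of stored scores by 100, but the projection work only by 10.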

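Various “efficient attention” methods attack this cost. Some approximate attention (linear attention), some restrict which positions may attend to which (sparse transformers), and some reorganize the exact computation so the full \(n \times n\) matrix is never stored at once (FlashAttention, which reduces memory use and memory traffic but still performs \(O(n^2 \cdot d)\) arithmetic). But the basic transformer uses full quadratic attention, and understanding it is essential before exploring optimizations.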
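The core trick behind memory-efficient exact attention fits in a few lines. Keys are processed in blocks while a running maximum and running sum are kept for each row’s softmax. This is a simplified NumPy sketch of that idea, often called online softmax, not FlashAttention itself, which also fuses these steps into GPU kernels. The `block` parameter is an arbitrary illustrative choice. The sketch still does \(O(n^2 \cdot d)\) arithmetic but never holds more than an \(n \times \text{block}\) slice of scores at once.

```python
import numpy as np

def full_attention(Q, K, V):
    S = Q @ K.T / np.sqrt(Q.shape[1])
    E = np.exp(S - S.max(axis=1, keepdims=True))
    return (E / E.sum(axis=1, keepdims=True)) @ V

def blockwise_attention(Q, K, V, block=2):
    """Exact attention computed over key/value blocks with a running softmax."""
    n, d_k = Q.shape
    out = np.zeros((n, V.shape[1]))     # unnormalized weighted sum of values
    row_max = np.full(n, -np.inf)       # running max score per query
    row_sum = np.zeros(n)               # running softmax denominator per query
    for start in range(0, K.shape[0], block):
        S = Q @ K[start:start + block].T / np.sqrt(d_k)   # only an (n, block) slice
        new_max = np.maximum(row_max, S.max(axis=1))
        rescale = np.exp(row_max - new_max)                 # re-express old terms against the new max
        P = np.exp(S - new_max[:, None])
        row_sum = row_sum * rescale + P.sum(axis=1)
        out = out * rescale[:, None] + P @ V[start:start + block]
        row_max = new_max
    return out / row_sum[:, None]

rng = np.random.default_rng(0)
Q, K, V = (rng.normal(size=(7, 4)) for _ in range(3))
print(np.allclose(blockwise_attention(Q, K, V), full_attention(Q, K, V)))
```

Rescaling by `exp(row_max - new_max)` is what makes the result exact. Every earlier exponential is brought onto the same reference maximum before new terms are added.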

9.7 Self-attention vs. recurrence

How does self-attention compare to RNNs for sequence modeling? The table below summarizes the key differences:

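\[ \begin{array}{l|cc} \text{Aspect} & \text{Self-attention} & \text{RNN} \\ \hline \text{Complexity per layer} & O(n^2 \cdot d) & O(n \cdot d^2) \\ \text{Sequential operations} & O(1) & O(n) \\ \text{Maximum path length} & O(1) & O(n) \\ \text{Parallelizable} & \text{Yes} & \text{No} \end{array} \]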

The path length measures how many steps information must travel between two positions. In an RNN, information from position 1 must pass through all intermediate hidden states to reach position \(n\), a path of length \(n-1\). Each step applies a transformation, and information can be lost or distorted along the way. In self-attention, any position can directly attend to any other in one step, regardless of distance.

Short paths have two benefits. First, gradients flow more easily during training. In an RNN, gradients must backpropagate through \(n\) steps, risking vanishing or exploding. In self-attention, gradients flow directly between any two positions. Second, information is less likely to be lost. The “telephone game” effect, where information degrades as it passes through many transformations, is avoided.

Parallelization is the other huge advantage. RNNs must compute sequentially because \(\mathbf{h}_t\) depends on \(\mathbf{h}_{t-1}\). Self-attention computes all positions simultaneously. On modern GPUs that excel at parallel matrix operations, this makes self-attention dramatically faster to train. A sequence of length 1,000 requires 1,000 sequential steps in an RNN but just one parallel operation in self-attention.
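The sequential dependency is visible in code. This sketch assumes a vanilla tanh RNN, \(\mathbf{h}_t = \tanh(\mathbf{h}_{t-1}\mathbf{W}_h + \mathbf{x}_t\mathbf{W}_x)\), which is one common form but not the only one. The attention half uses identity query, key, and value projections purely for brevity.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 6, 4
X = rng.normal(size=(n, d))
W_h, W_x = rng.normal(size=(d, d)) * 0.5, rng.normal(size=(d, d)) * 0.5

# RNN: step t needs h_{t-1}, so this loop cannot be spread across positions.
h = np.zeros(d)
H = []
for t in range(n):
    h = np.tanh(h @ W_h + X[t] @ W_x)
    H.append(h)
H = np.stack(H)

# Self-attention: all positions are handled by the same few matrix products at once.
S = X @ X.T / np.sqrt(d)
E = np.exp(S - S.max(axis=1, keepdims=True))
O = (E / E.sum(axis=1, keepdims=True)) @ X
print(H.shape, O.shape)
```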

9.8 Adding nonlinearity

Self-attention computes weighted averages of value vectors. Once the attention weights are fixed, this final step is linear in the values: if you double every value vector, the outputs double. Softmax adds nonlinearity in computing the attention weights, but the combination \(\sum_j a_{ij} \mathbf{v}_j\) itself is just a weighted sum, and nothing nonlinear is ever applied to the values. Why is this a problem?
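A small experiment separates the two effects. This sketch uses random weights and a helper named `attend`, which is our own. Holding the attention weights fixed makes the output exactly linear in the values. Doubling the input \(\mathbf{X}\) also changes the weights themselves, so the output does not simply double.

```python
import numpy as np

def attend(X, W_q, W_k, W_v):
    """Return the attention weights A and the values V for input X."""
    Q, K, V = X @ W_q, X @ W_k, X @ W_v
    S = Q @ K.T / np.sqrt(K.shape[1])
    E = np.exp(S - S.max(axis=1, keepdims=True))
    return E / E.sum(axis=1, keepdims=True), V

rng = np.random.default_rng(0)
X = rng.normal(size=(4, 3))
W_q, W_k, W_v = (rng.normal(size=(3, 3)) * 0.5 for _ in range(3))
A, V = attend(X, W_q, W_k, W_v)

# With A held fixed, scaling the values scales the output.
print(np.allclose(A @ (2 * V), 2 * (A @ V)))
# Doubling X scales the scores by 4, which reshapes the softmax weights.
A2, V2 = attend(2 * X, W_q, W_k, W_v)
print(np.allclose(A2 @ V2, 2 * (A @ V)))
```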

9.8.1 The limitation of linear functions

A function \(f\) is linear if it satisfies two properties: additivity, \(f(\mathbf{x} + \mathbf{y}) = f(\mathbf{x}) + f(\mathbf{y})\), and homogeneity, \(f(a\mathbf{x}) = af(\mathbf{x})\). Together they amount to \(f(a\mathbf{x} + b\mathbf{y}) = af(\mathbf{x}) + bf(\mathbf{y})\) for any scalars \(a, b\) and vectors \(\mathbf{x}, \mathbf{y}\). Matrix multiplication is linear: \(\mathbf{W}(a\mathbf{x} + b\mathbf{y}) = a\mathbf{W}\mathbf{x} + b\mathbf{W}\mathbf{y}\).

The critical limitation: composing linear functions gives another linear function. If \(f\) and \(g\) are linear, then \(f(g(\mathbf{x}))\) is also linear. Mathematically, if \(f(\mathbf{x}) = \mathbf{W}_2\mathbf{x}\) and \(g(\mathbf{x}) = \mathbf{W}_1\mathbf{x}\), then:

\[ f(g(\mathbf{x})) = \mathbf{W}_2(\mathbf{W}_1\mathbf{x}) = (\mathbf{W}_2\mathbf{W}_1)\mathbf{x} = \mathbf{W}_3\mathbf{x} \]

where \(\mathbf{W}_3 = \mathbf{W}_2\mathbf{W}_1\) is just another matrix. No matter how many linear layers you stack, the result is equivalent to a single linear layer. A 100-layer linear network has the same representational power as a 1-layer linear network.
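The collapse is easy to confirm numerically. The shapes below are arbitrary. Applying \(\mathbf{W}_1\) and then \(\mathbf{W}_2\) gives the same result as applying their product \(\mathbf{W}_3\) once.

```python
import numpy as np

rng = np.random.default_rng(0)
W1 = rng.normal(size=(5, 3))   # g: R^3 -> R^5
W2 = rng.normal(size=(2, 5))   # f: R^5 -> R^2
x = rng.normal(size=3)

W3 = W2 @ W1                   # a single 2x3 matrix replaces both layers
print(np.allclose(W2 @ (W1 @ x), W3 @ x))
```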

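Linear functions can only learn linear relationships. On their own they cannot express a threshold (“if this is greater than some value”). Even when a threshold is applied to their output, the decision boundary they produce is a straight line, or a flat hyperplane in higher dimensions, and some simple rules cannot be captured by any such boundary. Consider a classic example: a hallway light controlled by two switches (one at each end). The light turns on when exactly one switch is flipped.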

Figure 9.1: Left: Linear function fails to separate the two classes. Right: A nonlinear (curved) boundary successfully separates them

The four corners represent all switch combinations. When both switches are off (0,0) or both are on (1,1), the light is off (black circles). When exactly one switch is on (0,1) or (1,0), the light is on (white circles). This creates a diagonal pattern.

On the left, we try a straight line (linear function) but it can’t separate “light on” from “light off”. The dashed line fails to keep the black and white points on opposite sides. On the right, a curved boundary (nonlinear function) succeeds. The curve wraps around the two white “on” points, separating them from the black “off” corners. Linear functions can’t draw curves, so they can’t solve this problem. To solve problems like this, we need nonlinearity.
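One way to see the failure numerically is to fit the best linear function to the four switch combinations. Least squares is our choice of fitting criterion here.

```python
import numpy as np

# The four switch combinations and whether the light is on (XOR).
X = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=float)
y = np.array([0, 1, 1, 0], dtype=float)

# Best linear fit y ≈ w1*x1 + w2*x2 + b in the least-squares sense.
X_aug = np.hstack([X, np.ones((4, 1))])       # extra column of ones for the bias b
coef, *_ = np.linalg.lstsq(X_aug, y, rcond=None)
print(np.round(coef, 6))          # [w1, w2, b]
print(np.round(X_aug @ coef, 6))  # predictions for the four corners
```

The fit sets both switch weights to 0 and the bias to 0.5, so it predicts 0.5 for every corner. The best straight-line model cannot do better than hedging.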

9.8.2 What is ReLU?

The Rectified Linear Unit (ReLU) is the most common nonlinearity in modern neural networks. It’s defined as:

\[ \text{ReLU}(x) = \max(0, x) = \begin{cases} x & \text{if } x > 0 \\ 0 & \text{if } x \leq 0 \end{cases} \]

For a vector \(\mathbf{x} = [x_1, x_2, \ldots, x_d]\), we apply ReLU element-wise:

\[ \text{ReLU}(\mathbf{x}) = [\max(0, x_1), \max(0, x_2), \ldots, \max(0, x_d)] \]

ReLU is simple to compute: if the input is positive, pass it through unchanged. If negative, replace it with zero.

Example computation:

\[ \text{ReLU}([2, -1, 0, -3, 5]) = [2, 0, 0, 0, 5] \]
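In NumPy, element-wise ReLU is a single call to `np.maximum`. The helper name `relu` is ours.

```python
import numpy as np

def relu(x):
    """Element-wise ReLU: keep positive entries, replace the rest with 0."""
    return np.maximum(0, x)

print(relu(np.array([2, -1, 0, -3, 5])))   # [2 0 0 0 5]
```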

9.8.3 How ReLU introduces nonlinearity

Let’s understand why ReLU is nonlinear and how this simple function enables complex learning. Consider a linear function \(f(x) = 2x\):

Figure 9.2: Linear function f(x) = 2x showing a straight line passing through the origin

This satisfies the linearity property: \(f(2x) = 2f(x)\), and \(f(x+y) = f(x) + f(y)\). No matter what inputs you give it, the output is always proportional to the input.

Now consider ReLU applied to the same input: \(g(x) = \text{ReLU}(2x)\):

Figure 9.3: ReLU function showing nonlinearity with a bend at x=0, staying flat for negative x

The key difference: for \(x < 0\), the output is 0 instead of negative. This breaks linearity. We can verify: \(g(-1) = 0\) and \(g(1) = 2\), so \(g(-1) + g(1) = 2\). But \(g(-1 + 1) = g(0) = 0 \neq 2\). The linearity property is violated.

9.8.4 Why the bend matters

The bend at \(x = 0\) divides the input space into two regions with different behavior:

  • Region 1 (\(x < 0\)): Output is always 0, no matter how negative \(x\) gets. The function is “turned off.”
  • Region 2 (\(x > 0\)): Output equals input. The function is “turned on” and passes values through.

This creates a threshold behavior: below zero, nothing happens; above zero, the function activates. This is fundamentally different from a line, which has the same behavior everywhere.

9.8.5 How multiple ReLUs create complex boundaries

A single ReLU creates one bend, one threshold. But what happens when you combine many ReLUs? You get many bends, many thresholds, and together they can approximate any shape.

Consider a 2D input \(\mathbf{x} = [x_1, x_2]\) passing through a layer with multiple ReLU units. Each unit computes:

\[ h_i = \text{ReLU}(\mathbf{w}_i^T \mathbf{x} + b_i) \]

where \(\mathbf{w}_i\) is a weight vector and \(b_i\) is a bias. Geometrically, \(\mathbf{w}_i^T \mathbf{x} + b_i = 0\) defines a line in 2D space. The ReLU turns this into a decision: on one side of the line, \(h_i = 0\) (off); on the other side, \(h_i\) is positive (on).

With \(n\) ReLU units, you have \(n\) different lines dividing the space into regions. Each region has a different pattern of which units are on vs. off. A subsequent layer can combine these regions to create complex decision boundaries.

Example: To solve the two-switch problem (XOR), a simple network might work like this:

  1. First ReLU unit: Detects “Switch A is on” (region where \(x_1 > 0.5\))
  2. Second ReLU unit: Detects “Switch B is on” (region where \(x_2 > 0.5\))
  3. Third ReLU unit: Detects “both switches are on” (region where \(x_1 + x_2 > 1.5\))
  4. Output layer: Combine these: “light on” = (Switch A on) + (Switch B on) - 2×(both on)

This gives:

  • (0,0): No units fire, output = 0 (off)
  • (0,1): Unit 2 fires, output = 1 (on)
  • (1,0): Unit 1 fires, output = 1 (on)
  • (1,1): Units 1, 2, and 3 fire, output = 1 + 1 - 2 = 0 (off)

Each ReLU creates a region, and combining regions lets us draw the curved boundary we saw earlier.
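The four-step recipe for the two-switch problem can be made concrete with specific weights. The scale factors and biases below are our own choice. They put each unit’s switch-on point at the thresholds named in the recipe (0.5, 0.5, and 1.5) and make each unit output exactly 1 when it fires on a 0/1 corner.

```python
import numpy as np

def relu(z):
    return np.maximum(0, z)

W1 = np.array([[2, 0],    # unit 1: switch A on  (x1 > 0.5)
               [0, 2],    # unit 2: switch B on  (x2 > 0.5)
               [2, 2]])   # unit 3: both on      (x1 + x2 > 1.5)
b1 = np.array([-1, -1, -3])
w2 = np.array([1, 1, -2])  # light = A + B - 2 * both

def xor_net(x):
    h = relu(W1 @ np.asarray(x) + b1)   # hidden ReLU units
    return h, w2 @ h                     # hidden activations and the "light on" output

for x in [(0, 0), (0, 1), (1, 0), (1, 1)]:
    h, out = xor_net(x)
    print(x, h, out)
```

The printed hidden activations show exactly which units fire at each corner, matching the list above.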

9.8.6 Piecewise linear approximation

From another angle, functions with ReLU are piecewise linear: they’re made of straight line segments joined at bends. Between bends, the function is linear. But the bends let the overall shape be nonlinear.

Think of approximating a smooth curve with straight line segments. With one segment, you can’t do much. With two segments (one bend), you can make a corner. With ten segments, you can approximate a gentle curve. With hundreds of segments, you can approximate almost any smooth function.

A deep network with ReLU is doing exactly this in high-dimensional space: creating many regions (separated by hyperplanes at the bends) and assigning each region a different linear function. The result looks smooth and nonlinear when you zoom out, even though it’s technically piecewise linear.
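The segment picture can be tried directly in one dimension. This sketch approximates \(\sin x\) on \([0, 2\pi]\) with a sum of shifted ReLUs, \(\text{ReLU}(x - k)\), plus a line. Fixed, evenly spaced bend points \(k\) and a least-squares fit are our simplifications. A trained network would also learn where to put the bends.

```python
import numpy as np

def relu(z):
    return np.maximum(0.0, z)

def fit_piecewise_linear(num_knots):
    """Least-squares fit of sin(x) on [0, 2π] using 1, x, and ReLU(x - knot) features."""
    x = np.linspace(0, 2 * np.pi, 400)
    knots = np.linspace(0, 2 * np.pi, num_knots + 2)[1:-1]   # interior bend points
    features = np.column_stack([np.ones_like(x), x] + [relu(x - k) for k in knots])
    coef, *_ = np.linalg.lstsq(features, np.sin(x), rcond=None)
    return np.max(np.abs(features @ coef - np.sin(x)))        # worst-case error

for k in (1, 3, 10, 30):
    print(k, round(fit_piecewise_linear(k), 4))
```

The worst-case error shrinks as bends are added, just as more line segments trace a curve more closely.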

9.8.7 Why ReLU is simple but powerful

ReLU’s power comes from being just barely nonlinear. It is about as simple as a nonlinearity can be:

  • Computation: Just \(\max(0, x)\), with no expensive operations like exponentials
  • Gradient: 0 for negative inputs and 1 for positive inputs (the value at exactly 0 is set by convention), so it is cheap to compute during backpropagation
  • Unbounded above: Unlike sigmoid or tanh, ReLU doesn’t saturate for large positive values, which helps avoid vanishing gradients

Yet this simple “clip negatives to zero” operation is enough. With enough ReLU units, a network can approximate any continuous function on a bounded input region to any desired accuracy, although training still has to find suitable weights. The key insight: you don’t need fancy nonlinearities. You just need some nonlinearity, and ReLU’s threshold behavior is sufficient.

9.8.8 Why project to higher dimensions?

The feedforward network in transformers doesn’t just apply ReLU. It projects to a higher dimension first:

\[ \text{FFN}(\mathbf{x}) = \mathbf{W}_2 \cdot \text{ReLU}(\mathbf{W}_1 \mathbf{x} + \mathbf{b}_1) + \mathbf{b}_2 \]

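where \(\mathbf{W}_1 \in \mathbb{R}^{d_{ff} \times d}\) projects the input \(\mathbf{x} \in \mathbb{R}^d\) (a column vector) up to a larger dimension \(d_{ff}\) (typically \(d_{ff} = 4d\)), and \(\mathbf{W}_2 \in \mathbb{R}^{d \times d_{ff}}\) projects back down. When tokens are stored as rows, as in \(\mathbf{X}\) earlier in this chapter, the same computation is written \(\text{ReLU}(\mathbf{x}\mathbf{W}_1 + \mathbf{b}_1)\mathbf{W}_2 + \mathbf{b}_2\) with the transposed shapes.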
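A minimal NumPy sketch of this FFN, applied to every row (position) of a sequence at once. We assume the column-vector convention of the formula, so `W1` has shape \((d_{ff}, d)\) and `W2` has shape \((d, d_{ff})\). Because NumPy stores the sequence as rows, the code multiplies by the transposes. The sizes \(d = 8\) and \(d_{ff} = 4d\) are illustrative.

```python
import numpy as np

def ffn(X, W1, b1, W2, b2):
    """Position-wise FFN: each row x of X becomes W2 @ relu(W1 @ x + b1) + b2."""
    H = np.maximum(0, X @ W1.T + b1)   # (n, d_ff): project up, then ReLU
    return H @ W2.T + b2               # (n, d): project back down

rng = np.random.default_rng(0)
n, d = 5, 8
d_ff = 4 * d
W1, b1 = rng.normal(size=(d_ff, d)), np.zeros(d_ff)
W2, b2 = rng.normal(size=(d, d_ff)), np.zeros(d)
X = rng.normal(size=(n, d))
Y = ffn(X, W1, b1, W2, b2)
print(Y.shape)   # (5, 8)
```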

Why go up and then back down? The intuition is geometric: higher dimensions give you more room to work.

Consider a factory quality control system inspecting manufactured parts on a conveyor belt. You measure two features for each part: size (in millimeters) and weight (in grams). Some parts are good, some are defective. The defective parts happen to fall along a diagonal pattern in 2D space, while good parts are at the corners.

Figure 9.4: Left: With only size and weight, no straight line separates good parts from defects. Right: Adding a derived feature (density) makes separation possible with a simple threshold

In 2D (left panel), using just size and weight, the defective parts (black) are mixed with good parts (white) in a way that no straight line can separate them. The defects fall along the diagonal where size and weight are proportional, while good parts are scattered.

But what if we add a third feature: density = weight / size? Now we’re in 3D (right panel). The good parts have low density (they’re light for their size), while the defective parts have high density. Now a simple horizontal plane separates them: good parts below the threshold, defects above.

The problem became linearly separable by adding a dimension. We didn’t change the data, we just looked at it from a higher-dimensional perspective that revealed the underlying pattern.

This is what the FFN does: project to higher dimensions (via \(\mathbf{W}_1\)), apply nonlinearity (ReLU creates thresholds), then project back (via \(\mathbf{W}_2\)). The higher-dimensional space gives more “degrees of freedom” to represent complex patterns.

Think of it like this: in 2D, you can only draw straight lines. In 3D, you can draw planes. In 4D, you can draw 3D volumes. In 2048D (a typical \(d_{ff}\) for transformers), you have enormous flexibility to separate and organize data. The network learns to use these extra dimensions to compute features that aren’t easily expressible in the original space, then combines them back into a useful representation.

9.8.9 Why projecting back down doesn’t lose what we learned

But wait: if we can separate things in high dimensions, why doesn’t projecting back down collapse everything and lose the separation? This seems paradoxical.

The key insight: we’re not projecting back to the original space. We’re creating a new lower-dimensional representation that encodes what we learned in the high-dimensional space.

Let’s continue the factory example. After computing density in 3D and separating good parts from defects, the projection back doesn’t just throw away the density information. Instead, it might create a new 2D output like:

\[ \begin{aligned} \text{quality score} &= 0.3 \cdot \text{size} + 0.2 \cdot \text{weight} - 0.8 \cdot \text{ReLU}(\text{density} - 1.5) \\ \text{defect probability} &= -0.1 \cdot \text{size} - 0.1 \cdot \text{weight} + 0.9 \cdot \text{ReLU}(\text{density} - 1.5) \end{aligned} \]

Notice what happened: the high-dimensional feature (density) got compressed into the output, but the decision based on density is preserved. The ReLU created a threshold at density = 1.5. After ReLU:

  • Good parts (low density): the ReLU output is 0, so they get a low defect probability
  • Defects (high density): the ReLU output is positive, so they get a high defect probability

The projection back (\(\mathbf{W}_2\)) is learning how to combine the thresholded features. It’s not reversing the projection up; it’s a completely different linear transformation that says “here’s how to weight these high-dimensional insights when summarizing back to lower dimensions.”

9.8.10 The projection cycle creates new features

Think of the process as:

  1. Project up (\(\mathbf{W}_1\)): “Let me look at the input from many angles.” Each dimension in the high-dimensional space asks a different question about the input, in the form of a weighted sum plus a bias. One dimension might compute “how much larger is size than weight?”, another “how far is weight above 2?”, another “how large is size plus weight?”, and so on.

  2. Apply ReLU: “Decide which angles matter right now.” For this particular input, some questions have positive answers (kept), others negative (zeroed out). This is selecting which features are relevant.

  3. Project down (\(\mathbf{W}_2\)): “Combine the answers that survived into a useful summary.” The output isn’t the original input; it’s a summary of which high-dimensional features were active.

Here’s a concrete toy example. Suppose \(\mathbf{x} = [3, 4]\) (size=3, weight=4), and we project to 3D with:

\[ \mathbf{h} = \text{ReLU}(\mathbf{W}_1 \mathbf{x}) = \text{ReLU}\left(\begin{bmatrix} 1 & 0 \\ 0 & 1 \\ 0.5 & -0.5 \end{bmatrix} \begin{bmatrix} 3 \\ 4 \end{bmatrix}\right) = \text{ReLU}\left(\begin{bmatrix} 3 \\ 4 \\ -0.5 \end{bmatrix}\right) = \begin{bmatrix} 3 \\ 4 \\ 0 \end{bmatrix} \]

The third dimension computed \(0.5 \times 3 - 0.5 \times 4 = -0.5\) (checking if size > weight), which is negative, so ReLU killed it. Now project back:

\[ \mathbf{y} = \mathbf{W}_2 \mathbf{h} = \begin{bmatrix} 1 & 0 & 0.5 \\ 0 & 1 & -0.5 \end{bmatrix} \begin{bmatrix} 3 \\ 4 \\ 0 \end{bmatrix} = \begin{bmatrix} 3 \\ 4 \end{bmatrix} \]

In this case, the output is the same as input because the third feature wasn’t active. But if we had \(\mathbf{x} = [4, 2]\) (size > weight):

\[ \mathbf{h} = \text{ReLU}\left(\begin{bmatrix} 4 \\ 2 \\ 1 \end{bmatrix}\right) = \begin{bmatrix} 4 \\ 2 \\ 1 \end{bmatrix} \]

Now the third feature is active (size > weight). Projecting back:

\[ \mathbf{y} = \begin{bmatrix} 1 & 0 & 0.5 \\ 0 & 1 & -0.5 \end{bmatrix} \begin{bmatrix} 4 \\ 2 \\ 1 \end{bmatrix} = \begin{bmatrix} 4.5 \\ 1.5 \end{bmatrix} \]

The output is different from the input! The network has encoded “this part has size > weight” by boosting the first dimension and reducing the second. The output is a new representation that encodes what the network learned by going to higher dimensions.
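The two toy cases can be reproduced with the same \(\mathbf{W}_1\) and \(\mathbf{W}_2\), with no biases, as in the text. The name `toy_ffn` is a hypothetical helper.

```python
import numpy as np

W1 = np.array([[1.0, 0.0],
               [0.0, 1.0],
               [0.5, -0.5]])     # third row: "is size larger than weight?"
W2 = np.array([[1.0, 0.0, 0.5],
               [0.0, 1.0, -0.5]])

def toy_ffn(x):
    h = np.maximum(0, W1 @ x)    # project up to 3D, then ReLU
    return h, W2 @ h             # project back down to 2D

for x in ([3.0, 4.0], [4.0, 2.0]):
    h, y = toy_ffn(np.array(x))
    print(x, "->", h, "->", y)
```

The first input comes back unchanged as [3, 4]. The second becomes [4.5, 1.5], matching the hand calculation.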

9.8.11 Why this is powerful

The FFN is computing derived features in high dimensions and encoding their activation patterns in the output. With 2048 dimensions (typical \(d_{ff}\)), you can compute 2048 different features, threshold them with ReLU, and then the output encodes “feature 17 was active, feature 203 was active, feature 1842 was inactive, …” in a compressed form.

The network learns through training:

  • Which features to compute (what \(\mathbf{W}_1\) does)
  • Which features matter for the task (what survives ReLU)
  • How to combine active features into outputs (what \(\mathbf{W}_2\) does)

So projecting back down doesn’t throw away what the FFN computed. Some information is necessarily lost going from \(d_{ff}\) dimensions to \(d\), but training shapes \(\mathbf{W}_2\) to keep what matters for the task, compressing the high-dimensional decisions into a lower-dimensional representation. It’s like saying “I examined 2048 different properties of the input, and here’s a summary of what I found.”

9.8.12 The complete feedforward network

Putting it together, the position-wise feedforward network is:

\[ \text{FFN}(\mathbf{x}) = \mathbf{W}_2 \cdot \text{ReLU}(\mathbf{W}_1 \mathbf{x} + \mathbf{b}_1) + \mathbf{b}_2 \]

where:

  • \(\mathbf{x} \in \mathbb{R}^d\) is the input (from self-attention)
  • \(\mathbf{W}_1 \in \mathbb{R}^{d_{ff} \times d}\) projects to dimension \(d_{ff}\) (typically \(4d\))
  • \(\mathbf{b}_1 \in \mathbb{R}^{d_{ff}}\) is a bias (shifts the activation)
  • \(\text{ReLU}\) zeros out negative values
  • \(\mathbf{W}_2 \in \mathbb{R}^{d \times d_{ff}}\) projects back to dimension \(d\)
  • \(\mathbf{b}_2 \in \mathbb{R}^d\) is the final bias

This is applied independently to each position. It doesn’t mix information across positions (self-attention does that). Instead, it transforms each position’s representation in a complex, nonlinear way.

The combination of self-attention (mixing information across positions) and FFN (nonlinear transformation at each position) gives transformers their power. Self-attention handles “communication”: which positions should influence each other. FFN handles “computation”: what complex features should be computed from the gathered information.
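One way to see this division of labor is to perturb a single token and watch which output rows change. This sketch uses random, untrained weights, single-head attention without masking, and a bias-free FFN, all simplifications for brevity. Changing token 2 alters every row of the attention output, but only row 2 of the FFN output.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d, d_ff = 4, 6, 24
W_q, W_k, W_v = (rng.normal(size=(d, d)) / np.sqrt(d) for _ in range(3))
W1, W2 = rng.normal(size=(d_ff, d)), rng.normal(size=(d, d_ff))

def attention(X):
    S = (X @ W_q) @ (X @ W_k).T / np.sqrt(d)
    E = np.exp(S - S.max(axis=1, keepdims=True))
    return (E / E.sum(axis=1, keepdims=True)) @ (X @ W_v)

def ffn(X):
    return np.maximum(0, X @ W1.T) @ W2.T     # same weights, applied to each row separately

X = rng.normal(size=(n, d))
X_changed = X.copy()
X_changed[2] += 1.0                            # perturb token 2 only

attn_rows = ~np.all(np.isclose(attention(X), attention(X_changed)), axis=1)
ffn_rows = ~np.all(np.isclose(ffn(X), ffn(X_changed)), axis=1)
print(attn_rows)   # which attention output rows changed
print(ffn_rows)    # which FFN output rows changed
```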

9.8.13 What the output space looks like

Let’s visualize exactly what happens to the coordinates. The diagram below shows how the FFN transforms a 2D input space into a different 2D output space:

Figure 9.5: The FFN transforms input coordinates (size, weight) into new output coordinates (Feature A, Feature B) where the problem becomes linearly separable

Left panel (Input space): The x-axis is “Size” and y-axis is “Weight”, the original measurements. The five points are mixed: defective parts (black) fall on a diagonal, good parts (white) are scattered. No straight line can separate them.

Right panel (Output space): After passing through the FFN, we’re in a completely different 2D space. The x-axis is now “Feature A” and y-axis is “Feature B”. These aren’t size or weight anymore; they’re new derived features computed by the network. Look at what happened to the points:

  • Good parts moved to the bottom (low Feature B values: 0.4, 0.5, 0.6)
  • Defective parts moved to the top (high Feature B values: 1.8, 2.1)
  • A horizontal line at Feature B = 1.1 now cleanly separates them!

The coordinates literally changed. The input point at (1, 1) moved to approximately (1.3, 1.8) in the output space. The input point at (2.5, 1) moved to approximately (2.2, 0.5). The FFN didn’t just relabel the axes. It computed entirely new coordinates where separation is possible.

What are Feature A and Feature B mathematically? Let’s trace through the computation with a concrete example. Suppose we project to 3 dimensions (instead of 2048) with:

\[ \mathbf{W}_1 = \begin{bmatrix} 1 & 0 \\ 0 & 1 \\ -0.5 & 0.5 \end{bmatrix}, \quad \mathbf{b}_1 = \begin{bmatrix} 0 \\ 0 \\ 1.5 \end{bmatrix} \]

For input \(\mathbf{x} = [\text{size}, \text{weight}]^T\), the high-dimensional features are:

\[ \mathbf{h} = \text{ReLU}\left(\begin{bmatrix} 1 & 0 \\ 0 & 1 \\ -0.5 & 0.5 \end{bmatrix} \begin{bmatrix} \text{size} \\ \text{weight} \end{bmatrix} + \begin{bmatrix} 0 \\ 0 \\ 1.5 \end{bmatrix}\right) = \text{ReLU}\left(\begin{bmatrix} \text{size} \\ \text{weight} \\ -0.5 \cdot \text{size} + 0.5 \cdot \text{weight} + 1.5 \end{bmatrix}\right) \]

So the three high-dimensional features before ReLU are:

  • \(h_1 = \text{size}\) (always positive, so it passes through ReLU unchanged)
  • \(h_2 = \text{weight}\) (always positive, so it passes through ReLU unchanged)
  • \(h_3 = 1.5 + 0.5(\text{weight} - \text{size})\) (grows as weight increases relative to size, plus a shift)

After ReLU, \(h_3\) is positive only when \(\text{weight} > \text{size} - 3\); otherwise it is zeroed. Now project back down (taking \(\mathbf{b}_2 = \mathbf{0}\) for simplicity):

\[ \begin{bmatrix} \text{Feature A} \\ \text{Feature B} \end{bmatrix} = \mathbf{W}_2 \mathbf{h} = \begin{bmatrix} 0.3 & 0.2 & 0.1 \\ 0.1 & 0.1 & 0.8 \end{bmatrix} \begin{bmatrix} h_1 \\ h_2 \\ h_3 \end{bmatrix} \]

Therefore:

  • Feature A \(= 0.3 \cdot h_1 + 0.2 \cdot h_2 + 0.1 \cdot h_3\)
  • Feature B \(= 0.1 \cdot h_1 + 0.1 \cdot h_2 + 0.8 \cdot h_3\)
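The trace above can be checked with a few lines of NumPy. This sketch uses exactly the toy \(\mathbf{W}_1\), \(\mathbf{b}_1\), \(\mathbf{W}_2\) from the text with \(\mathbf{b}_2 = \mathbf{0}\); the helper name features is ours, and these toy weights are not the ones behind Figure 9.5, so they do not reproduce the figure's exact coordinates.

```python
import numpy as np

# Toy weights from the worked example (b2 is taken to be zero).
W1 = np.array([[1.0, 0.0],
               [0.0, 1.0],
               [-0.5, 0.5]])
b1 = np.array([0.0, 0.0, 1.5])
W2 = np.array([[0.3, 0.2, 0.1],
               [0.1, 0.1, 0.8]])

def features(size, weight):
    x = np.array([size, weight])
    h = np.maximum(0.0, W1 @ x + b1)  # h1 = size, h2 = weight, h3 = 1.5 + 0.5*(weight - size)
    feat_a, feat_b = W2 @ h           # project the 3 hidden features back to 2 outputs
    return h, feat_a, feat_b

for point in [(1.0, 1.0), (2.5, 1.0)]:
    h, a, b = features(*point)
    print(point, "h =", h, "Feature A = %.3f, Feature B = %.3f" % (a, b))
```

With these toy weights, (1, 1) gets \(h_3 = 1.5\) and Feature B = 1.4, while (2.5, 1) gets \(h_3 = 0.75\) and Feature B = 0.95, so the two points land on opposite sides of Feature B = 1.1.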

Feature B heavily weights \(h_3\), the feature that tracks weight relative to size. When \(h_3\) is small or zero (weight low relative to size), Feature B is low. When \(h_3\) is large, Feature B is high. With suitable weights, this pushes the good parts down and the defects up, which is why a horizontal line such as Feature B = 1.1 can separate them.

So Feature A and Feature B are linear combinations of the high-dimensional features \(h_1, h_2, h_3\). But those high-dimensional features are themselves thresholded (ReLU) linear combinations of the inputs. The full formula is:

\[ \begin{bmatrix} \text{Feature A} \\ \text{Feature B} \end{bmatrix} = \mathbf{W}_2 \cdot \text{ReLU}(\mathbf{W}_1 \begin{bmatrix} \text{size} \\ \text{weight} \end{bmatrix} + \mathbf{b}_1) + \mathbf{b}_2 \]

This is not a linear function of size and weight, because of the ReLU. Without it, \(\mathbf{W}_2(\mathbf{W}_1 \mathbf{x} + \mathbf{b}_1) + \mathbf{b}_2\) would collapse to a single affine transformation \((\mathbf{W}_2\mathbf{W}_1)\mathbf{x} + (\mathbf{W}_2\mathbf{b}_1 + \mathbf{b}_2)\). The ReLU creates the nonlinearity: different regions of input space activate different combinations of the \(h_i\) features, leading to different linear mappings in different regions. That's how the FFN creates the bent (piecewise-linear) decision boundaries we need.
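To make "different linear mappings in different regions" concrete, here is a small sketch with the same toy weights (again \(\mathbf{b}_2 = \mathbf{0}\)). The input (5, 1) is a hypothetical point chosen so that the ReLU switches \(h_3\) off: there the FFN no longer agrees with the collapsed affine map, while at (1, 1), where every feature is active, it does.

```python
import numpy as np

W1 = np.array([[1.0, 0.0], [0.0, 1.0], [-0.5, 0.5]])
b1 = np.array([0.0, 0.0, 1.5])
W2 = np.array([[0.3, 0.2, 0.1], [0.1, 0.1, 0.8]])

def ffn(x):
    return W2 @ np.maximum(0.0, W1 @ x + b1)

def ffn_without_relu(x):
    # Without ReLU the two layers collapse into ONE affine map: (W2 W1) x + W2 b1
    return (W2 @ W1) @ x + W2 @ b1

inside = np.array([1.0, 1.0])   # h3 = 1.5 > 0: all three features active
outside = np.array([5.0, 1.0])  # h3 = 1.5 + 0.5*(1 - 5) = -0.5, so ReLU zeros it

print(np.allclose(ffn(inside), ffn_without_relu(inside)))    # True: same linear piece
print(np.allclose(ffn(outside), ffn_without_relu(outside)))  # False: a different piece
```

In the region where \(h_3\) is zeroed, the FFN behaves like the affine map that uses only the first two columns of \(\mathbf{W}_2\); stitching such pieces together is what bends the decision boundary.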

This is why going to higher dimensions and back is powerful: the network has room to compute many candidate features (2048 in the original Transformer, where \(d = 512\)), threshold them with ReLU, and then \(\mathbf{W}_2\) learns to combine the useful ones into output coordinates that make the problem easier for the next layer.

9.9 Residual connections and layer normalization

We’ve built powerful components: self-attention mixes information across positions, and feedforward networks compute complex nonlinear transformations. But when we stack these operations into deep networks, we encounter severe training problems. Two techniques solve these problems: residual connections and layer normalization.

9.9.1 The vanishing gradient problem

To understand why we need residual connections, we need to understand the vanishing gradient problem. When training deep networks, we update weights by computing gradients via backpropagation. The gradient tells us how to adjust each weight to reduce the loss.

Consider a simple deep network where each layer multiplies by a weight matrix \(\mathbf{W}\):

\[ \mathbf{y} = \mathbf{W}_L \cdot \mathbf{W}_{L-1} \cdots \mathbf{W}_2 \cdot \mathbf{W}_1 \mathbf{x} \]

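To update \(\mathbf{W}_1\) (the first layer), we need \(\frac{\partial \mathcal{L}}{\partial \mathbf{W}_1}\) where \(\mathcal{L}\) is the loss. Write \(\mathbf{h}_k = \mathbf{W}_k \mathbf{h}_{k-1}\) for the output of layer \(k\), with \(\mathbf{h}_0 = \mathbf{x}\) and \(\mathbf{h}_L = \mathbf{y}\). By the chain rule:

\[ \frac{\partial \mathcal{L}}{\partial \mathbf{W}_1} = \frac{\partial \mathcal{L}}{\partial \mathbf{y}} \cdot \frac{\partial \mathbf{h}_L}{\partial \mathbf{h}_{L-1}} \cdots \frac{\partial \mathbf{h}_2}{\partial \mathbf{h}_1} \cdot \frac{\partial \mathbf{h}_1}{\partial \mathbf{W}_1}, \qquad \text{where } \frac{\partial \mathbf{h}_k}{\partial \mathbf{h}_{k-1}} = \mathbf{W}_k \]

Each intermediate factor is one of the weight matrices, so the gradient is multiplied by \(\mathbf{W}_L, \mathbf{W}_{L-1}, \ldots, \mathbf{W}_2\) on its way down. If the weights are small (for example, initialized with small values, which is common practice), each factor shrinks the gradient, acting like a number less than 1. Multiplying many numbers less than 1 produces exponentially small results.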

Concrete example: Suppose each derivative is 0.5, and we have 10 layers. The gradient reaching the first layer is proportional to \(0.5^{10} \approx 0.001\). With 20 layers, it becomes \(0.5^{20} \approx 0.000001\). The gradient becomes so small it’s essentially zero, and the first layer stops learning. This is the vanishing gradient.

The opposite problem, exploding gradients, occurs when the factors are larger than 1. Then \(1.5^{10} \approx 58\) or \(1.5^{20} \approx 3325\), and gradients become enormous. Updates become unstable, weights can overflow to infinity or NaN, and training collapses.
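Both numerical examples can be reproduced with a deep linear network whose layers are all \(\mathbf{W}_k = a\mathbf{I}\). This is a deliberately simplified sketch: the per-layer factor \(a\), the dimension, and the unit-norm upstream gradient are assumptions chosen so that each layer scales the gradient by exactly \(a\).

```python
import numpy as np

def backprop_norm(scale, num_layers, d=4):
    """Norm of the gradient w.r.t. the network input after num_layers layers W_k = scale * I.

    For y = W_L ... W_1 x, that gradient is W_1^T ... W_L^T (dL/dy).
    """
    grad = np.ones(d) / np.sqrt(d)  # hypothetical upstream gradient dL/dy with norm 1
    W = scale * np.eye(d)
    for _ in range(num_layers):
        grad = W.T @ grad           # one chain-rule step per layer
    return np.linalg.norm(grad)

for scale in (0.5, 1.5):
    for layers in (10, 20):
        print(f"factor {scale}, {layers} layers: gradient norm {backprop_norm(scale, layers):.3g}")
```

The printed norms are simply \(a^L\): about 0.001 and 0.000001 for \(a = 0.5\), and about 58 and 3325 for \(a = 1.5\). Real networks have non-identity weight matrices and nonlinearities, but the compounding is the same.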

Both problems worsen with depth. A 5-layer network might train fine, while a 50-layer network can be very hard to train without special techniques. This was a major barrier to training very deep networks, and residual connections were one of the key breakthroughs that removed it.

9.9.2 Why residual connections solve this

Residual connections (also called skip connections) add the input directly to the output:

\[ \mathbf{X}' = \mathbf{X} + \text{SelfAttention}(\mathbf{X}) \]

Instead of \(\mathbf{X}' = f(\mathbf{X})\), we compute \(\mathbf{X}' = \mathbf{X} + f(\mathbf{X})\). The “+\(\mathbf{X}\)” creates a direct path from input to output that bypasses the function \(f\).

Why does this help? Look at the gradient. Using the chain rule:

\[ \frac{\partial \mathbf{X}'}{\partial \mathbf{X}} = \frac{\partial}{\partial \mathbf{X}}[\mathbf{X} + f(\mathbf{X})] = \mathbf{I} + \frac{\partial f(\mathbf{X})}{\partial \mathbf{X}} \]

The derivative is the identity matrix \(\mathbf{I}\) plus something extra. Even if \(\frac{\partial f}{\partial \mathbf{X}}\) vanishes (goes to zero), the gradient is still \(\mathbf{I}\). The gradient can flow backward through the “\(+\mathbf{X}\)” path without any multiplicative degradation.

The gradient highway: This direct path is called a “gradient highway” or “shortcut connection.” Gradients from later layers can flow directly back to earlier layers without passing through multiple matrix multiplications. Compare:

  • Without residual: Gradient passes through \(\mathbf{W}_L, \mathbf{W}_{L-1}, \ldots, \mathbf{W}_1\), multiplying many matrices.
  • With residual: Gradient can flow through the shortcut path with no multiplication at all, just addition.
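To see the gradient highway numerically, here is a small NumPy experiment under stated assumptions: the per-layer Jacobians \(\partial f/\partial \mathbf{X}\) are hypothetical random \(4 \times 4\) matrices with small entries (standing in for sublayers whose own derivatives have nearly vanished), chained over 20 layers. Without residuals the chain is a product of these Jacobians; with residuals each factor becomes \(\mathbf{I} + \partial f/\partial \mathbf{X}\).

```python
import numpy as np

rng = np.random.default_rng(0)
d, num_layers = 4, 20

# Hypothetical per-layer Jacobians df/dX of sublayers that have "gone quiet" (small entries).
jacobians = [0.01 * rng.normal(size=(d, d)) for _ in range(num_layers)]

plain = np.eye(d)     # chain rule without residuals: product of df/dX
residual = np.eye(d)  # chain rule with residuals: product of (I + df/dX)
for J in jacobians:
    plain = plain @ J
    residual = residual @ (np.eye(d) + J)

# Spectral norm (largest singular value): how much the chain can scale a gradient.
print("without residual:", np.linalg.norm(plain, 2))
print("with residual:   ", np.linalg.norm(residual, 2))
```

The first norm comes out vanishingly small, while the second stays close to 1: each residual factor is close to the identity, so the whole product is too.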

This was a major advance: networks with hundreds of layers became trainable because gradients reliably reach the early layers. Residual connections, introduced with ResNets in 2015, made very deep networks practical in computer vision, and they are an essential ingredient of transformers (2017).

9.9.3 How residual connections work mechanically

Let’s trace through a concrete example. Suppose we have a 3-token sequence with 4-dimensional embeddings:

\[ \mathbf{X} = \begin{bmatrix} 1.0 & 0.5 & -0.5 & 0.2 \\ 0.8 & -0.3 & 0.6 & 0.1 \\ -0.2 & 0.7 & 0.3 & -0.4 \end{bmatrix} \]

After self-attention, we get an output \(\mathbf{O}\) (the weighted combination of value vectors):

\[ \mathbf{O} = \text{SelfAttention}(\mathbf{X}) = \begin{bmatrix} 0.2 & 0.1 & 0.3 & -0.1 \\ 0.1 & 0.2 & -0.2 & 0.1 \\ 0.3 & -0.1 & 0.1 & 0.2 \end{bmatrix} \]

Without residual connection, the output would just be \(\mathbf{O}\). We’d completely replace the input with the self-attention result. If self-attention makes a mistake or produces small values, the original information is lost.

With residual connection, we compute:

\[ \mathbf{X}' = \mathbf{X} + \mathbf{O} \]

\[ = \begin{bmatrix} 1.0 & 0.5 & -0.5 & 0.2 \\ 0.8 & -0.3 & 0.6 & 0.1 \\ -0.2 & 0.7 & 0.3 & -0.4 \end{bmatrix} + \begin{bmatrix} 0.2 & 0.1 & 0.3 & -0.1 \\ 0.1 & 0.2 & -0.2 & 0.1 \\ 0.3 & -0.1 & 0.1 & 0.2 \end{bmatrix} = \begin{bmatrix} 1.2 & 0.6 & -0.2 & 0.1 \\ 0.9 & -0.1 & 0.4 & 0.2 \\ 0.1 & 0.6 & 0.4 & -0.2 \end{bmatrix} \]

Notice what happened:

  • The original values (like 1.0, 0.8, -0.2 in the first column) are still present, just modified
  • Self-attention made small adjustments: +0.2, +0.1, +0.3, -0.1 in the first row
  • The output is the original plus a modification, not a replacement
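The same residual addition in NumPy, using the example's \(\mathbf{X}\) and \(\mathbf{O}\) (the attention output is simply taken as given here, not computed). The last line illustrates the "preserves information" point: a sublayer that outputs zeros leaves \(\mathbf{X}\) untouched.

```python
import numpy as np

X = np.array([[1.0, 0.5, -0.5, 0.2],
              [0.8, -0.3, 0.6, 0.1],
              [-0.2, 0.7, 0.3, -0.4]])
O = np.array([[0.2, 0.1, 0.3, -0.1],   # self-attention output from the example
              [0.1, 0.2, -0.2, 0.1],
              [0.3, -0.1, 0.1, 0.2]])

X_prime = X + O  # residual connection: keep X, add the sublayer's correction
print(X_prime)

# If the sublayer outputs all zeros, the residual path passes X through unchanged.
print(np.array_equal(X + np.zeros_like(O), X))  # True
```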

This has several benefits:

  1. Preserves information: The original embedding is always present. Even if self-attention fails completely (outputs zero), we still have \(\mathbf{X}' = \mathbf{X}\).

  2. Easier learning: The network learns to compute modifications rather than new representations from scratch. It’s easier to learn “add 0.2 to this feature” than “compute the right value from scratch.”

  3. Graceful initialization: At initialization, when weights are random, \(\text{SelfAttention}(\mathbf{X})\) produces essentially meaningless output. If that output is small (as it is when the weights start small), residual connections give \(\mathbf{X}' \approx \mathbf{X} + \text{small noise}\), so the network starts out close to an identity function. Training then gradually learns useful modifications.

9.9.4 Why layer normalization is needed

Even with residual connections, we have another problem: activation explosion or vanishing. As signals pass through many layers, the scale of values can drift. Some features might grow to huge values (like 1000), others might shrink to tiny values (like 0.001). This makes training unstable.

Consider what happens across many layers (the numbers below are illustrative). If each layer adds a small positive value on average, values grow:

  • After layer 1: mean value = 1.0
  • After layer 10: mean value = 5.0
  • After layer 50: mean value = 50.0

Conversely, if values get smaller on average:

  • After layer 1: mean value = 1.0
  • After layer 10: mean value = 0.1
  • After layer 50: mean value = 0.0001

Either scenario is bad. Large values cause gradients to explode (tiny weight changes cause huge output changes). Small values cause gradients to vanish (the signal becomes noise).

Layer normalization solves this by normalizing the scale of features after each layer. It ensures that regardless of what happened in previous layers, the output has a consistent mean and standard deviation.

9.9.5 How layer normalization works

Layer normalization acts on the vector representation of a single token independently. For a given position’s vector \(\mathbf{x} \in \mathbb{R}^d\) (containing \(d\) features), the operation consists of two main steps: normalization and affine transformation.

The full formula is:

\[ \text{LayerNorm}(\mathbf{x}) = \gamma \odot \underbrace{\frac{\mathbf{x} - \mu}{\sigma + \epsilon}}_{\text{normalization}} + \beta \]

Let’s break this down:

  1. Normalization: The fraction \(\frac{\mathbf{x} - \mu}{\sigma + \epsilon}\) standardizes the vector features.

    • Center (\(\mathbf{x} - \mu\)): We subtract the mean \(\mu\) of the features. This ensures the vector is centered around zero.
    • Scale (Divide by \(\sigma\)): We divide by the standard deviation \(\sigma\). This ensures the values have a spread (variance) of 1.
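    • Stability (\(\epsilon\)): We add a tiny number \(\epsilon\) (e.g., \(10^{-5}\)) to the denominator to prevent division by zero if the variance is 0. Many libraries, including PyTorch's nn.LayerNorm, instead divide by \(\sqrt{\sigma^2 + \epsilon}\), placing \(\epsilon\) inside the square root; both forms serve the same purpose.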

    \[ \mu = \frac{1}{d}\sum_{i=1}^d x_i, \quad \sigma = \sqrt{\frac{1}{d}\sum_{i=1}^d (x_i - \mu)^2} \]

  2. Affine Transformation: Forcing every vector to have exactly mean 0 and variance 1 limits the network's expressiveness. To fix this, we introduce two learnable parameters per feature dimension:

    • Scale (\(\gamma\)): A learned multiplier that can expand or shrink the range.
    • Shift (\(\beta\)): A learned bias that can shift the center.

    These parameters (\(\gamma, \beta \in \mathbb{R}^d\)) allow the model to learn the optimal distribution for the features, effectively “undoing” the normalization if necessary, but starting from a stable baseline.

Key intuition: Unlike Batch Normalization (which uses statistics across a batch of samples), Layer Normalization computes \(\mu\) and \(\sigma\) using only the features within the single vector \(\mathbf{x}\) itself. This makes it independent of batch size and well suited to sequence models like transformers.

The process in summary:

  1. Center: Subtract the mean \(\mu\), so the values have mean 0
  2. Scale: Divide by the standard deviation \(\sigma\), so the values have variance 1
  3. Re-scale and re-shift: Multiply by learned \(\gamma\) and add learned \(\beta\)
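These three steps translate directly into a short NumPy sketch. It follows this section's convention of adding \(\epsilon\) to \(\sigma\) and normalizes over the last axis, so each row (one token) uses only its own statistics. The badly scaled input matrix is a hypothetical example; the last two checks illustrate batch independence and the claim that the output scale does not depend on what earlier layers did.

```python
import numpy as np

def layer_norm(X, gamma, beta, eps=1e-5):
    """LayerNorm over the last axis: gamma * (x - mu) / (sigma + eps) + beta per row."""
    mu = X.mean(axis=-1, keepdims=True)
    sigma = X.std(axis=-1, keepdims=True)  # population std (divide by d), as in the text
    return gamma * (X - mu) / (sigma + eps) + beta

rng = np.random.default_rng(0)
d = 8
gamma, beta = np.ones(d), np.zeros(d)             # typical initial values
X = rng.normal(loc=3.0, scale=50.0, size=(5, d))  # badly scaled activations

Y = layer_norm(X, gamma, beta)
print(Y.mean(axis=-1).round(6))  # approximately 0 for every row
print(Y.std(axis=-1).round(4))   # approximately 1 for every row

# Batch independence: normalizing one row on its own gives the same result.
print(np.allclose(layer_norm(X[:1], gamma, beta), Y[:1]))  # True
# Rescaling and shifting the input barely changes the output.
print(np.allclose(layer_norm(1000 * X + 7, gamma, beta), Y, atol=1e-4))  # True
```

Batch normalization would instead compute statistics down axis 0 (across samples), which is why its output for one sample depends on the rest of the batch.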

Concrete example: Suppose after self-attention + residual, one position has:

\[ \mathbf{x} = [1.2, 0.6, -0.2, 0.1] \]

Step 1: Compute mean and standard deviation:

\[ \mu = \frac{1.2 + 0.6 + (-0.2) + 0.1}{4} = \frac{1.7}{4} = 0.425 \]

\[ \sigma = \sqrt{\frac{(1.2-0.425)^2 + (0.6-0.425)^2 + (-0.2-0.425)^2 + (0.1-0.425)^2}{4}} \]

\[ = \sqrt{\frac{0.601 + 0.031 + 0.391 + 0.106}{4}} = \sqrt{\frac{1.129}{4}} = \sqrt{0.282} \approx 0.531 \]

Step 2: Normalize (\(\epsilon\) is negligible here, so we omit it):

\[ \hat{\mathbf{x}} = \frac{\mathbf{x} - \mu}{\sigma} = \frac{[1.2, 0.6, -0.2, 0.1] - 0.425}{0.531} = \frac{[0.775, 0.175, -0.625, -0.325]}{0.531} \]

\[ \approx [1.46, 0.33, -1.18, -0.61] \]

Notice: the normalized values have mean 0 and standard deviation 1. You can verify: \((1.46 + 0.33 - 1.18 - 0.61)/4 = 0\) and \((1.46^2 + 0.33^2 + 1.18^2 + 0.61^2)/4 \approx 1.00\).

Step 3: Apply learned parameters. Suppose \(\gamma = [1, 1, 1, 1]\) and \(\beta = [0, 0, 0, 0]\) (initialization values). Then:

\[ \text{LayerNorm}(\mathbf{x}) = \gamma \odot \hat{\mathbf{x}} + \beta = [1.46, 0.33, -1.18, -0.61] \]
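The hand calculation can be reproduced in a few lines, assuming the initial values \(\gamma = \mathbf{1}\), \(\beta = \mathbf{0}\) and \(\epsilon = 10^{-5}\) (which shifts the results only in the fifth decimal place). Note that NumPy's std divides by \(d\) by default, matching the \(\sigma\) formula above.

```python
import numpy as np

x = np.array([1.2, 0.6, -0.2, 0.1])  # one position after attention + residual
gamma = np.ones(4)                    # initial scale
beta = np.zeros(4)                    # initial shift
eps = 1e-5

mu = x.mean()    # 0.425
sigma = x.std()  # population std: sqrt(mean((x - mu)**2)), about 0.531
x_hat = (x - mu) / (sigma + eps)
out = gamma * x_hat + beta

print(round(mu, 3), round(sigma, 3))  # 0.425 0.531
print(out.round(2))                   # approximately [1.46, 0.33, -1.18, -0.61]
```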

During training, \(\gamma\) and \(\beta\) might learn to be different values (like \(\gamma = [2, 1, 0.5, 1]\)) if the network finds that helpful. But the normalization ensures values don’t explode or vanish.

9.9.6 When and where these techniques are applied

In the original transformer layer, residual connections and layer normalization are applied after both self-attention and the feedforward network. Here's the complete flow:

Step 1: Start with input embeddings \(\mathbf{X}_0\) (dimension \(n \times d\))

Step 2: Self-attention with residual connection:

\[ \mathbf{X}_1 = \mathbf{X}_0 + \text{SelfAttention}(\mathbf{X}_0) \]

Step 3: Apply layer normalization:

\[ \mathbf{X}_2 = \text{LayerNorm}(\mathbf{X}_1) \]

Step 4: Feedforward network with residual connection:

\[ \mathbf{X}_3 = \mathbf{X}_2 + \text{FFN}(\mathbf{X}_2) \]

Step 5: Apply layer normalization:

\[ \mathbf{X}_4 = \text{LayerNorm}(\mathbf{X}_3) \]

This sequence is repeated in every transformer layer. The output \(\mathbf{X}_4\) from one layer becomes the input \(\mathbf{X}_0\) to the next layer.

This arrangement is called post-norm because layer normalization comes after the residual addition. Many modern transformers (including GPT-2 and GPT-3) use pre-norm, where layer normalization is applied to the input of each sublayer:

\[ \mathbf{X}_1 = \mathbf{X}_0 + \text{SelfAttention}(\text{LayerNorm}(\mathbf{X}_0)) \]

\[ \mathbf{X}_2 = \mathbf{X}_1 + \text{FFN}(\text{LayerNorm}(\mathbf{X}_1)) \]

Pre-norm is often easier to train in deep networks. Each sublayer receives normalized inputs, which keeps extreme values out of its computation, and the residual path from a layer's input to its output is never normalized, so gradients can flow back through it unchanged.
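The two arrangements differ only in where LayerNorm sits. In this minimal sketch, attn and ffn are generic callables (hypothetical stand-ins for real self-attention and FFN sublayers), and \(\gamma = 1\), \(\beta = 0\) for brevity. Plugging in a sublayer that outputs zeros exposes the difference: pre-norm passes the input straight through, while post-norm always re-normalizes it.

```python
import numpy as np

def layer_norm(X, eps=1e-5):
    # gamma = 1, beta = 0 for brevity
    mu = X.mean(axis=-1, keepdims=True)
    return (X - mu) / (X.std(axis=-1, keepdims=True) + eps)

def post_norm_block(X, attn, ffn):
    """Original arrangement: sublayer, add residual, then normalize."""
    X = layer_norm(X + attn(X))
    return layer_norm(X + ffn(X))

def pre_norm_block(X, attn, ffn):
    """Pre-norm: normalize each sublayer's input; the residual path is never normalized."""
    X = X + attn(layer_norm(X))
    return X + ffn(layer_norm(X))

def zero_sublayer(Z):
    return np.zeros_like(Z)  # a sublayer that contributes nothing

rng = np.random.default_rng(0)
X = rng.normal(scale=10.0, size=(3, 4))

print(np.allclose(pre_norm_block(X, zero_sublayer, zero_sublayer), X))   # True
print(np.allclose(post_norm_block(X, zero_sublayer, zero_sublayer), X))  # False
```

Because nothing in a pre-norm stack rescales the residual stream itself, pre-norm models typically add one final LayerNorm after the last layer; GPT-2, for example, does this.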

9.9.7 The complete transformer layer equation

Let’s integrate everything we’ve built in this chapter into one complete equation for a single transformer layer. Start with:

  • Input: \(\mathbf{X} \in \mathbb{R}^{n \times d}\) (sequence of \(n\) token embeddings, each \(d\)-dimensional)

Self-attention sublayer (single head, with \(\mathbf{W}^V \in \mathbb{R}^{d \times d}\) so that the residual sum is defined):

\[\begin{align} \mathbf{Q} &= \mathbf{X}\mathbf{W}^Q, \quad \mathbf{K} = \mathbf{X}\mathbf{W}^K, \quad \mathbf{V} = \mathbf{X}\mathbf{W}^V \\ \mathbf{A} &= \text{softmax}\left(\frac{\mathbf{Q}\mathbf{K}^T}{\sqrt{d_k}}\right) \\ \mathbf{O}_{\text{attn}} &= \mathbf{A}\mathbf{V} \\ \mathbf{H} &= \text{LayerNorm}(\mathbf{X} + \mathbf{O}_{\text{attn}}) \end{align}\]

Feedforward sublayer (the position-wise FFN applied to each row of \(\mathbf{H}\); the transposes appear because tokens are stored as rows, and the biases are added to every row):

\[\begin{align} \mathbf{O}_{\text{ffn}} &= \text{ReLU}(\mathbf{H}\mathbf{W}_1^T + \mathbf{b}_1^T)\,\mathbf{W}_2^T + \mathbf{b}_2^T \\ \mathbf{Y} &= \text{LayerNorm}(\mathbf{H} + \mathbf{O}_{\text{ffn}}) \end{align}\]

Output: \(\mathbf{Y} \in \mathbb{R}^{n \times d}\)

This is the complete computation for one transformer layer. Stack \(L\) of these layers (typically \(L = 6\) to \(L = 96\)):

\[ \mathbf{X}^{(0)} \xrightarrow{\text{Layer 1}} \mathbf{X}^{(1)} \xrightarrow{\text{Layer 2}} \mathbf{X}^{(2)} \xrightarrow{\text{Layer 3}} \cdots \xrightarrow{\text{Layer L}} \mathbf{X}^{(L)} \]
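Putting the layer equations into code gives the following NumPy sketch of one post-norm layer, stacked a few times. It assumes a single attention head with no output projection, random weights (scaled by \(1/\sqrt{d}\), an arbitrary choice) standing in for trained parameters, and \(\gamma = \mathbf{1}\), \(\beta = \mathbf{0}\); the parameter names in the dictionary are ours.

```python
import numpy as np

def layer_norm(X, gamma, beta, eps=1e-5):
    # Per-row statistics, epsilon added to sigma as in this chapter's formula
    mu = X.mean(axis=-1, keepdims=True)
    sigma = X.std(axis=-1, keepdims=True)
    return gamma * (X - mu) / (sigma + eps) + beta

def softmax(S):
    S = S - S.max(axis=-1, keepdims=True)  # subtract the row max for numerical stability
    E = np.exp(S)
    return E / E.sum(axis=-1, keepdims=True)

def init_layer(rng, d, d_k, d_ff):
    """Hypothetical random parameters standing in for trained weights (single head)."""
    return {
        "WQ": rng.normal(scale=d ** -0.5, size=(d, d_k)),
        "WK": rng.normal(scale=d ** -0.5, size=(d, d_k)),
        "WV": rng.normal(scale=d ** -0.5, size=(d, d)),  # maps back to d so X + AV is defined
        "W1": rng.normal(scale=d ** -0.5, size=(d_ff, d)), "b1": np.zeros(d_ff),
        "W2": rng.normal(scale=d_ff ** -0.5, size=(d, d_ff)), "b2": np.zeros(d),
        "gamma1": np.ones(d), "beta1": np.zeros(d),
        "gamma2": np.ones(d), "beta2": np.zeros(d),
    }

def transformer_layer(X, p):
    # Self-attention sublayer
    Q, K, V = X @ p["WQ"], X @ p["WK"], X @ p["WV"]
    A = softmax(Q @ K.T / np.sqrt(K.shape[-1]))        # (n, n) attention weights
    H = layer_norm(X + A @ V, p["gamma1"], p["beta1"])  # residual, then LayerNorm
    # Feedforward sublayer, applied to every row of H
    O_ffn = np.maximum(0.0, H @ p["W1"].T + p["b1"]) @ p["W2"].T + p["b2"]
    return layer_norm(H + O_ffn, p["gamma2"], p["beta2"])  # residual, then LayerNorm

rng = np.random.default_rng(0)
n, d, d_k, d_ff, num_layers = 3, 8, 8, 32, 4
X = rng.normal(size=(n, d))  # placeholder for X^(0)
layers = [init_layer(rng, d, d_k, d_ff) for _ in range(num_layers)]
for p in layers:             # X^(l) = Layer_l(X^(l-1))
    X = transformer_layer(X, p)
print(X.shape)  # (3, 8)
```

Every layer maps an \(n \times d\) matrix to another \(n \times d\) matrix, which is what makes stacking possible.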

The final output \(\mathbf{X}^{(L)}\) is passed to a task-specific head (like a language model head for predicting the next token, or a classification head for sentiment analysis).

9.9.8 Connecting to the full picture

Let’s trace the complete flow from raw text to predictions:

1. Tokenization: Text “The cat sat” becomes token IDs [50, 123, 456]

2. Embedding: Each token ID is converted to a \(d\)-dimensional embedding, and positional encodings are added (discussed in detail in Chapter 11):

\[ \mathbf{X}^{(0)} = \text{TokenEmbedding}(\text{tokens}) + \text{PositionalEncoding}(\text{positions}) \]

3. Transformer layers: Apply \(L\) layers of self-attention + FFN + residuals + layer norms:

\[ \mathbf{X}^{(\ell)} = \text{TransformerLayer}_\ell(\mathbf{X}^{(\ell-1)}) \quad \text{for } \ell = 1, 2, \ldots, L \]

4. Output head: The final representation \(\mathbf{X}^{(L)}\) is passed to a task-specific head. For language modeling:

\[ \text{logits} = \mathbf{X}^{(L)} \mathbf{W}_{\text{vocab}} \]

where \(\mathbf{W}_{\text{vocab}} \in \mathbb{R}^{d \times V}\) projects to vocabulary size \(V\). Apply softmax to get probabilities over the vocabulary.
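A toy NumPy sketch of the two ends of the pipeline: embedding lookup for the illustrative token IDs, then the language-model head. Everything here is a hypothetical placeholder: the vocabulary size, the random embedding and \(\mathbf{W}_{\text{vocab}}\) matrices, the learned-table style of position embedding (Chapter 11 covers real schemes), and skipping the transformer layers entirely.

```python
import numpy as np

rng = np.random.default_rng(0)
V, d, max_len = 1000, 8, 16                         # hypothetical sizes
token_embedding = rng.normal(size=(V, d))           # one row per token ID
position_embedding = rng.normal(size=(max_len, d))  # placeholder position table
W_vocab = rng.normal(size=(d, V))

tokens = np.array([50, 123, 456])                   # "The cat sat" (illustrative IDs)
X0 = token_embedding[tokens] + position_embedding[np.arange(len(tokens))]

XL = X0  # stand-in: a real model would apply its L transformer layers here

logits = XL @ W_vocab                                 # (3, V): one score per vocabulary entry
logits = logits - logits.max(axis=-1, keepdims=True)  # stabilize before exponentiating
probs = np.exp(logits)
probs = probs / probs.sum(axis=-1, keepdims=True)     # softmax over the vocabulary
print(probs.shape)                                    # (3, 1000)
print(int(probs[-1].argmax()))                        # most likely ID after "sat" (random here)
```

Each row of probs is a distribution over the whole vocabulary; the last row is the model's prediction for the token that follows "sat".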

Each component plays a role:

  • Embeddings: Convert discrete tokens to continuous vectors
  • Positional encoding: Inject position information (covered in Chapter 11)
  • Self-attention: Mix information across positions based on relevance
  • FFN: Compute complex nonlinear transformations at each position
  • Residual connections: Enable gradient flow through deep networks
  • Layer normalization: Stabilize activation scales

Together, these components let transformers learn powerful representations from data. The residual connections and layer normalization are crucial glue that makes deep transformers trainable. Without them, we’d be stuck with shallow networks that couldn’t learn complex patterns.

9.10 Causal (masked) self-attention

For autoregressive language modeling, where the goal is to predict the next token given the previous ones, standard self-attention has a critical flaw: it can see the future.

9.10.1 The problem: Cheating prevents learning

Imagine we are training a model to predict the next word in the sequence “The cat sat on the mat.”

  • Input: “The cat sat on the”
  • Target output: “cat sat on the mat”

If we use standard self-attention, when the model is processing the word “sat” (position 3), it can attend to “on” (position 4) and “the” (position 5). The model essentially “sees” the answer key.

If the model can see the future tokens it is supposed to predict, it will learn a trivial shortcut: just copy the next token. It won’t learn grammar, logic, or world knowledge; it will just become a lookup table.

This is disastrous because at inference time (when we actually use the model to generate text), the future doesn’t exist yet. We generate one token at a time. If the model relies on seeing the future to make predictions, it will fail completely when that future is hidden.

To force the model to actually learn language structure, we must blind it to the future. When processing position \(i\), the model should only be allowed to access positions \(1, 2, \ldots, i\).

9.10.2 The solution: Masking

Causal self-attention (or masked self-attention) enforces this constraint by modifying the attention scores before the softmax step. We explicitly mask out any connection from a position to a future position.

Mathematically, we adjust the score matrix \(\mathbf{S}\):

\[ \mathbf{S}_{ij} = \begin{cases} \frac{\mathbf{q}_i^T \mathbf{k}_j}{\sqrt{d_k}} & \text{if } j \leq i \quad (\text{past and present}) \\ -\infty & \text{if } j > i \quad (\text{future}) \end{cases} \]

Why negative infinity? Recall that we pass these scores through a softmax function: \(\text{softmax}(x_i) = \frac{e^{x_i}}{\sum e^{x_k}}\). Since \(e^{-\infty} = 0\), setting a score to \(-\infty\) ensures that the resulting attention weight is exactly zero. The future token is effectively erased from the weighted sum.
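A three-score example makes this concrete. The scores are hypothetical, with the last one standing for a masked future position; subtracting the maximum before exponentiating is the usual numerical-stability step and does not change the result.

```python
import numpy as np

scores = np.array([2.0, 1.0, -np.inf])  # last entry is a masked future position
weights = np.exp(scores - scores.max())  # exp(-inf) evaluates to exactly 0.0
weights /= weights.sum()
print(weights)  # the last weight is exactly 0.0
```

One caveat: if every score in a row were \(-\infty\), the row would become 0/0 (NaN). Causal masking never does this, because each position can always attend to itself.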

9.10.3 Implementation

In practice, this is done efficiently using a mask matrix \(\mathbf{M}\). For a sequence of length 4, the mask looks like this:

\[ \mathbf{M} = \begin{bmatrix} 0 & -\infty & -\infty & -\infty \\ 0 & 0 & -\infty & -\infty \\ 0 & 0 & 0 & -\infty \\ 0 & 0 & 0 & 0 \end{bmatrix} \]

We compute \(\mathbf{A} = \text{softmax}(\mathbf{S} + \mathbf{M})\).

  • Row 1: Can only attend to column 1 (itself).
  • Row 2: Can attend to columns 1 and 2.
  • Row 3: Can attend to columns 1, 2, and 3.
  • Row 4: Can attend to columns 1, 2, 3, and 4.

This produces a lower-triangular attention matrix. This structure allows us to train on all tokens in a sequence simultaneously (parallel training) while mathematically guaranteeing that no prediction depends on future information. This is the mechanism that powers GPT and other decoder-only architectures.
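Here is a minimal single-head sketch of causal self-attention in NumPy. The projection matrices are hypothetical random placeholders. np.triu with k=1 keeps the \(-\infty\) entries strictly above the diagonal and zeros everything else, which yields exactly the mask \(\mathbf{M}\) shown above.

```python
import numpy as np

def causal_self_attention(X, WQ, WK, WV):
    n = X.shape[0]
    Q, K, V = X @ WQ, X @ WK, X @ WV
    S = Q @ K.T / np.sqrt(K.shape[-1])
    # M has 0 on/below the diagonal and -inf above it (future positions)
    M = np.triu(np.full((n, n), -np.inf), k=1)
    S = S + M
    S = S - S.max(axis=-1, keepdims=True)  # row max is finite: the diagonal is never masked
    A = np.exp(S)
    A = A / A.sum(axis=-1, keepdims=True)
    return A @ V, A

rng = np.random.default_rng(0)
n, d = 4, 8
X = rng.normal(size=(n, d))
WQ, WK, WV = (rng.normal(size=(d, d)) for _ in range(3))
out, A = causal_self_attention(X, WQ, WK, WV)
print(A.round(2))  # lower-triangular; the first row is [1, 0, 0, 0]
```

All four rows are computed in one matrix operation, yet row \(i\) only mixes the value vectors of positions \(1, \ldots, i\); in particular, the first output is just the first value vector.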

9.11 Summary

Self-attention is the core of transformers:

  • Queries, keys, and values all come from the same sequence
  • Each position can attend to every other position (or only earlier positions in causal attention)
  • Attention weights are computed via scaled dot products followed by softmax
  • The output is a weighted combination of values

Key properties:

  • \(O(n^2 \cdot d)\) cost for computing all pairwise scores, the price of direct connections between all positions
  • Highly parallelizable, unlike sequential RNNs
  • Path length of 1 between any two positions
  • Learns to extract relevant relationships through training

A single self-attention operation uses one set of query, key, and value projections, so each position computes just one attention distribution over the sequence. In the next chapter, we'll see how multi-head attention allows the model to attend to different aspects of the input simultaneously.