12  The transformer

Learning objectives

After completing this chapter, you will be able to:

  • Trace information flow through the complete encoder-decoder transformer
  • Compute forward propagation with explicit matrix dimensions
  • Derive backward propagation gradients for all transformer components
  • Distinguish encoder-only, decoder-only, and encoder-decoder architectures
  • Specify the complete transformer with sufficient detail for implementation

This chapter specifies the complete transformer architecture (Vaswani et al. 2017) mathematically and derives its gradients. It presents every computational step in both forward and backward propagation, with explicit dimensions, indexing, and gradient computations.

12.1 Notation and hyperparameters

We work with the following standard transformer configuration:

  • \(V = 50000\): vocabulary size (number of unique tokens)
  • \(d_{model} = 512\): model dimension (dimensionality of all embeddings, hidden states, and layer outputs throughout the network)
  • \(d_k = d_v = 64\): key/value dimension per head (each attention head projects to this lower dimension, where \(d_k = d_{model} / h\))
  • \(d_{ff} = 2048\): feed-forward intermediate dimension (hidden layer size in the position-wise FFN)
  • \(h = 8\): number of attention heads (parallel attention mechanisms, where \(h \cdot d_k = d_{model}\))
  • \(N = 6\): number of encoder/decoder blocks (depth of the network)
  • \(n\): source sequence length (number of tokens in the input sequence)
  • \(m\): target sequence length (number of tokens in the output sequence)
  • \(\epsilon = 10^{-6}\): numerical stability constant (small value added in denominators to prevent division by zero)

Matrices are denoted with bold uppercase (\(\mathbf{W}\)), vectors with bold lowercase (\(\mathbf{x}\)), and scalars with regular font (\(x\)). We use 1-indexing for positions and dimension indices in mathematical expressions.

12.2 Architecture overview

This section builds a complete mental model of the transformer without mathematics. We describe what each component does and why, tracing information flow from input to output. The subsequent sections provide the precise mathematical formulation needed for implementation.

12.2.1 The big picture

The transformer solves sequence-to-sequence problems such as machine translation. Given the English source sequence “The cat sat on the mat”, it produces the French target sequence “Le chat s’est assis sur le tapis”. The architecture has two main components working together: an encoder that builds a representation of the source sequence, and a decoder that generates the target sequence one token at a time.

The key innovation is attention: instead of compressing the entire input into a single fixed-size vector, as early recurrent encoder-decoder models did, the transformer lets every position directly access information from every other position. This removes that information bottleneck, and because there is no recurrence, all positions can be processed in parallel.

12.2.2 From discrete tokens to continuous vectors

Text arrives as discrete tokens (words or subwords). We need continuous vectors to perform meaningful computation. The embedding layer maps each token to a point in a high-dimensional space where similar meanings cluster together. The word “cat” might map to vector \([0.2, -0.5, 0.8, \ldots]\) with 512 dimensions. Words with similar meanings like “kitten” or “feline” map to nearby points in this space.

But there’s a problem: these embeddings contain no information about word order. Attention treats its inputs as an unordered set, so without extra information the sequences “dog bites man” and “man bites dog” would present the model with exactly the same vectors, and it could not tell them apart. We need to inject position information.

12.2.3 Positional encoding: giving order to chaos

Positional encoding adds a unique signature to each position in the sequence. Position 1 gets one pattern, position 2 gets a different pattern, and so on. These patterns are carefully designed using sine and cosine functions at different frequencies. Low frequencies change slowly across positions, encoding coarse position information. High frequencies change rapidly, encoding fine-grained distinctions between nearby positions.

We add (element-wise) these positional signatures to the token embeddings. Now the embedding for “cat” at position 2 is different from “cat” at position 5, even though they represent the same word. The model can use this positional information throughout all subsequent computations.

12.2.4 The encoder: building rich representations

The encoder transforms the input sequence into a rich, contextualized representation. Each position starts with its embedding plus positional encoding. Then we apply six identical blocks in sequence, with each block refining the representation.

Self-attention mechanism. Each encoder block begins with self-attention, the core operation of the transformer. For each position, we ask: “Which other positions in the sequence are relevant for understanding this position?” In a sentence such as “The cat sat down because it was tired”, the word “it” might strongly attend to “cat” to resolve what “it” refers to. The verb “sat” might attend to its subject “cat” to capture subject-verb agreement.

Attention works by computing compatibility scores between positions. For position \(i\), we compute how well it matches with every position \(j\) in the sequence. These scores become weights in a weighted average: positions that are highly relevant contribute more to the output, irrelevant positions contribute little.

Multi-head attention. We don’t compute just one attention pattern. We compute eight parallel attention mechanisms (heads), each learning to capture different types of relationships. One head might specialize in syntactic dependencies (subject-verb relationships). Another might capture coreference (pronouns linking to their antecedents). Another might focus on semantic similarity. The model learns these specializations automatically during training.

Each head operates in a lower-dimensional space (64 dimensions instead of 512). Because the eight 64-dimensional heads together span 512 dimensions, the total computational cost is similar to that of a single full-dimensional head. After computing attention in parallel, we concatenate all heads back together and project to the original dimension. This gives us a representation that combines multiple perspectives on the input.

Feed-forward transformation. After attention gathers information from across the sequence, we apply a position-wise feed-forward network. This is a simple two-layer neural network with a ReLU activation, applied independently to each position. Think of this as giving the model capacity to perform complex nonlinear transformations on the mixed information from attention.

The feed-forward network first expands the representation to a higher dimension (2048), applies ReLU to introduce nonlinearity, then projects back down to the model dimension (512). This happens at each position independently, allowing the model to refine representations without mixing information across positions (attention already did that).

Residual connections and layer normalization. Both the attention and feed-forward sub-layers use residual connections: we add the input directly to the output before passing to the next layer. This creates shortcut paths for gradients during training, which mitigates vanishing gradients and makes deep stacks trainable.

After each residual connection, we apply layer normalization, which standardizes the values across the embedding dimensions for each position independently. This keeps activations in a stable range, making training more reliable.

Stacking blocks. We apply six encoder blocks in sequence. Earlier blocks tend to capture shallower patterns such as syntax and local word relationships, while later blocks tend to capture more abstract semantic relationships and long-range dependencies; this division of labor emerges from training rather than being built in. The final encoder output contains deeply contextualized representations where each position has gathered relevant information from the entire sequence through multiple rounds of attention.

12.2.5 The decoder: generating the target sequence

The decoder generates the output sequence autoregressively, one token at a time. During training, we have the complete target sequence and process it in parallel with teacher forcing. During inference, we generate token by token, feeding each new token back as input for the next prediction.

Masked self-attention. The decoder’s first sub-layer is masked self-attention. Like encoder self-attention, this lets positions attend to each other. But there’s a crucial constraint: position \(i\) can only attend to positions \(1\) through \(i\), never to future positions. This causal masking ensures the model can’t “cheat” by looking ahead to tokens it’s supposed to predict.

We implement masking by setting attention scores to negative infinity for future positions. After softmax, these become zero weights, effectively blocking information flow from the future. This makes generation possible: at inference time, when we only have partial sequences, the model uses the same computation pattern it learned during training.

Cross-attention: connecting encoder and decoder. The decoder’s second sub-layer is where encoder and decoder communicate. In cross-attention, queries come from the decoder (what we’re currently generating), while keys and values come from the encoder output (the source sequence we’re translating from).

For each decoder position, cross-attention asks: “Which source positions are relevant for generating this target word?” When generating the French word “chat”, the decoder might strongly attend to the English word “cat”. When generating “assis” (sat), it attends to “sat”. The model learns these alignments automatically without explicit supervision.

Unlike self-attention, cross-attention isn’t masked. All decoder positions can attend to all encoder positions. The source sequence is fixed and fully available, so there’s no information leakage concern.

Feed-forward network. After gathering information from both the previous target tokens (via masked self-attention) and the source sequence (via cross-attention), the decoder applies the same position-wise feed-forward network as the encoder. This refines the combined representation.

Stacking decoder blocks. We stack six decoder blocks. Each block refines the representation in three steps: (1) attending to the current and earlier target positions, (2) attending to the source sequence, and (3) applying a position-wise nonlinear transformation. By the final decoder block, each position has a rich representation that encodes both what has been generated so far and relevant information from the source.

12.2.6 Output projection: from representations to tokens

The final decoder output is a matrix where each row represents one target position. We need to convert these continuous representations into probability distributions over the vocabulary. A linear layer projects from the model dimension (512) to the vocabulary size (50,000), producing logits for each token. Softmax converts these logits to probabilities.

For position \(i\), the probability distribution tells us how likely each vocabulary token is as the next word. During training, we compare this distribution to the true next token and compute cross-entropy loss. During inference, we sample from this distribution (or take the argmax for greedy decoding).

12.2.7 Information flow through the network

Let’s trace how information flows through the complete architecture for translating “The cat” to “Le chat”.

Encoder path: Tokens “The” and “cat” become embeddings, receive positional encodings, and enter the encoder. In the first encoder block, self-attention lets “cat” gather context from “The” and vice versa. The feed-forward network processes each position. Residual connections and layer normalization maintain training stability. This repeats through six blocks, producing deeply contextualized representations of the English sentence.

Decoder path: We start with target tokens [START], “Le”, “chat”. Masked self-attention ensures position 1 only sees [START], position 2 sees [START, Le], position 3 sees all three. Cross-attention connects each target position to the English words: “Le” attends strongly to “The”, “chat” attends to “cat”. Feed-forward processing refines these representations. This repeats through six decoder blocks.

Output: The final decoder state for position 1 produces probabilities over the vocabulary. The model has learned that after [START], when translating “The cat”, the word “Le” has high probability. Position 2’s state predicts “chat” with high probability. Position 3 predicts [END].

12.2.8 Why this architecture works

The transformer’s power comes from several key design choices. Attention provides direct paths between all positions, eliminating the information bottleneck of sequential processing. Multi-head attention captures diverse relationships in parallel. The feed-forward network adds expressiveness for complex transformations. Residual connections enable deep stacking. Layer normalization stabilizes training. Positional encoding supplies the word-order information that attention, being permutation-equivariant on its own, cannot capture.

The result is an architecture that scales well. By increasing model dimension, number of heads, number of layers, and training data, we obtain models ranging from the 65-million-parameter base model of Vaswani et al. (2017) to 175-billion-parameter models like GPT-3. The fundamental computation pattern remains the same across scales.

12.2.9 Clarifying terminology: architecture vs phases

The terms “encoder” and “decoder” refer to architectural components, not to training versus inference phases. This distinction is important.

Encoder and decoder self-attention differ in their masking pattern:

  • Encoder architecture uses bidirectional self-attention. Every position can attend to every other position. This is used for understanding tasks where we have the complete input.
  • Decoder architecture uses causal (masked) self-attention. Position \(i\) can only attend to positions \(1, \ldots, i\). This prevents information leakage from future tokens.

Both architectures are used in training AND inference:

In a decoder-only model like GPT, the same masked self-attention architecture is used throughout. During training, we feed complete sequences and compute loss at every position. The model learns weight matrices (\(\mathbf{W}^Q\), \(\mathbf{W}^K\), \(\mathbf{W}^V\), etc.) via backpropagation. During inference, we use those same weights with the same masked attention pattern to generate tokens one at a time.

Similarly, encoder-only models like BERT use bidirectional attention in both training and inference. The architecture doesn’t change between phases.

What actually changes between training and inference:

| Aspect | Training | Inference |
|---|---|---|
| Weights | Updated via backpropagation | Fixed (no updates) |
| Input | Complete sequences with known targets | Partial or complete sequences |
| Output | Loss value for optimization | Predictions or generated tokens |
| Generation | Teacher forcing (use true targets) | Autoregressive (use own predictions) |

For decoder models, the key difference at inference is autoregressive generation: we generate one token, append it to the sequence, and repeat. But the underlying computation (masked self-attention with the learned weights) is identical to what happened during training.

With this conceptual foundation in place, we now turn to the precise mathematical specification.

12.3 Forward propagation

We now present the precise mathematical formulation of each component described in the architecture overview. Every operation is specified with explicit dimensions and matrix operations.

12.3.1 Input embedding and positional encoding

Consider a source sequence with token indices \(\mathbf{t} = [t_1, t_2, \ldots, t_n]\), where \(t_i \in \{1, 2, \ldots, V\}\).

Embedding lookup: The embedding matrix \(\mathbf{E} \in \mathbb{R}^{V \times d_{model}}\) maps each token to a vector. For the sequence, we extract rows corresponding to token indices:

\[ \mathbf{X}_{embed} = \mathbf{E}[\mathbf{t}, :] \in \mathbb{R}^{n \times d_{model}} \]

where row \(i\) contains the embedding for token \(t_i\).

Positional encoding: For each position \(i \in \{1, \ldots, n\}\) and dimension \(j \in \{1, \ldots, d_{model}\}\):

\[ p_{ij} = \begin{cases} \sin\left(\frac{i-1}{10000^{(j-1)/d_{model}}}\right) & \text{if } j \text{ is odd} \\ \cos\left(\frac{i-1}{10000^{(j-2)/d_{model}}}\right) & \text{if } j \text{ is even} \end{cases} \]

This forms positional encoding matrix \(\mathbf{P} \in \mathbb{R}^{n \times d_{model}}\).

Combined input:

\[ \mathbf{X}^{(0)} = \mathbf{X}_{embed} + \mathbf{P} \in \mathbb{R}^{n \times d_{model}} \]
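A minimal NumPy sketch of the embedding lookup and sinusoidal encoding, assuming toy sizes (\(V = 100\), \(d_{model} = 16\)) and a random embedding matrix. The helper names `sinusoidal_positional_encoding` and `embed` are hypothetical. Unlike the 1-indexed formula, the code is 0-indexed, so position \(i - 1\) is row `pos` and dimension \(j - 1\) is column `dim`; the even 0-indexed columns correspond to the odd \(j\) that get the sine.

```python
import numpy as np

def sinusoidal_positional_encoding(n, d_model):
    """Return the (n, d_model) matrix P, 0-indexed by position and dimension."""
    pos = np.arange(n)[:, None]                  # 0-indexed position, i.e. i - 1
    dim = np.arange(d_model)[None, :]            # 0-indexed dimension, i.e. j - 1
    # Columns (0, 1), (2, 3), ... share one frequency, so the exponent uses the even index.
    angle = pos / np.power(10000.0, (dim - dim % 2) / d_model)
    return np.where(dim % 2 == 0, np.sin(angle), np.cos(angle))

def embed(tokens, E):
    """X^(0) = E[t, :] + P for a 1-D array of 0-indexed token ids."""
    X_embed = E[tokens]                          # (n, d_model), one row per token
    return X_embed + sinusoidal_positional_encoding(len(tokens), E.shape[1])

rng = np.random.default_rng(0)
V, d_model = 100, 16                             # toy sizes, not the 50000/512 configuration
E = rng.normal(size=(V, d_model))
tokens = np.array([5, 17, 5, 42])                # token 5 appears at positions 0 and 2
X0 = embed(tokens, E)
print(X0.shape)                                  # (4, 16)
```

Because token 5 occurs twice, rows 0 and 2 of `X0` share the same embedding but differ through \(\mathbf{P}\).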

12.3.2 Encoder

Each encoder block applies: (1) multi-head self-attention, (2) residual + layer norm, (3) feed-forward network, (4) residual + layer norm.

Multi-head self-attention

Given input \(\mathbf{X} \in \mathbb{R}^{n \times d_{model}}\) to block \(\ell\), for each head \(k \in \{1, \ldots, h\}\):

Step 1: Project to queries, keys, values

\[ \mathbf{Q}_k = \mathbf{X} \mathbf{W}_k^Q \in \mathbb{R}^{n \times d_k} \]

\[ \mathbf{K}_k = \mathbf{X} \mathbf{W}_k^K \in \mathbb{R}^{n \times d_k} \]

\[ \mathbf{V}_k = \mathbf{X} \mathbf{W}_k^V \in \mathbb{R}^{n \times d_k} \]

where \(\mathbf{W}_k^Q, \mathbf{W}_k^K, \mathbf{W}_k^V \in \mathbb{R}^{d_{model} \times d_k}\) are learnable projection matrices.

Step 2: Compute attention scores

\[ \mathbf{S}_k = \frac{\mathbf{Q}_k \mathbf{K}_k^T}{\sqrt{d_k}} \in \mathbb{R}^{n \times n} \]

Entry \(S_{k,ij}\) measures the compatibility between position \(i\) and position \(j\).

Step 3: Apply softmax

\[ \mathbf{A}_k = \text{softmax}(\mathbf{S}_k) \in \mathbb{R}^{n \times n} \]

where softmax is applied row-wise: \(A_{k,ij} = \frac{\exp(S_{k,ij})}{\sum_{j'=1}^{n} \exp(S_{k,ij'})}\). Each row sums to 1.

Step 4: Compute weighted sum

\[ \mathbf{H}_k = \mathbf{A}_k \mathbf{V}_k \in \mathbb{R}^{n \times d_k} \]

Step 5: Concatenate heads

\[ \mathbf{H} = [\mathbf{H}_1 \,|\, \mathbf{H}_2 \,|\, \cdots \,|\, \mathbf{H}_h] \in \mathbb{R}^{n \times (h \cdot d_k)} = \mathbb{R}^{n \times d_{model}} \]

Step 6: Output projection

\[ \mathbf{Z}^{attn} = \mathbf{H} \mathbf{W}^O \in \mathbb{R}^{n \times d_{model}} \]

where \(\mathbf{W}^O \in \mathbb{R}^{d_{model} \times d_{model}}\).
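Steps 1 through 6 can be sketched in NumPy as follows, assuming the per-head matrices \(\mathbf{W}_k^Q, \mathbf{W}_k^K, \mathbf{W}_k^V\) are stacked into arrays of shape `(h, d_model, d_k)` and using toy sizes. The function names are hypothetical, and the explicit loop over heads favors readability over speed.

```python
import numpy as np

def softmax_rows(S):
    """Row-wise softmax; subtracting each row's max avoids overflow without changing the result."""
    S = S - S.max(axis=1, keepdims=True)
    expS = np.exp(S)
    return expS / expS.sum(axis=1, keepdims=True)

def multi_head_self_attention(X, WQ, WK, WV, WO):
    """Multi-head self-attention for X of shape (n, d_model).

    WQ, WK, WV: shape (h, d_model, d_k), one projection per head.
    WO: shape (h * d_k, d_model).
    """
    h, _, d_k = WQ.shape
    heads = []
    for k in range(h):
        Q = X @ WQ[k]                       # Step 1: (n, d_k)
        K = X @ WK[k]
        V = X @ WV[k]
        S = Q @ K.T / np.sqrt(d_k)          # Step 2: (n, n) scaled scores
        A = softmax_rows(S)                 # Step 3: each row sums to 1
        heads.append(A @ V)                 # Step 4: (n, d_k)
    H = np.concatenate(heads, axis=1)       # Step 5: (n, h * d_k) = (n, d_model)
    return H @ WO                           # Step 6: (n, d_model)

rng = np.random.default_rng(0)
n, d_model, h = 5, 16, 4
d_k = d_model // h
X = rng.normal(size=(n, d_model))
WQ, WK, WV = (rng.normal(size=(h, d_model, d_k)) / np.sqrt(d_model) for _ in range(3))
WO = rng.normal(size=(d_model, d_model)) / np.sqrt(d_model)
Z_attn = multi_head_self_attention(X, WQ, WK, WV, WO)
print(Z_attn.shape)                         # (5, 16)
```

Permuting the rows of `X` permutes the rows of `Z_attn` in the same way; this permutation equivariance is why positional encoding is needed.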

Residual connection and layer normalization

Residual addition:

\[ \mathbf{R}^{(1)} = \mathbf{X} + \mathbf{Z}^{attn} \in \mathbb{R}^{n \times d_{model}} \]

Layer normalization: For each position \(i\), normalize across the \(d_{model}\) dimensions. Let \(\mathbf{r}_i \in \mathbb{R}^{d_{model}}\) be row \(i\) of \(\mathbf{R}^{(1)}\). Compute:

\[ \mu_i = \frac{1}{d_{model}} \sum_{j=1}^{d_{model}} r_{ij}, \quad \sigma_i^2 = \frac{1}{d_{model}} \sum_{j=1}^{d_{model}} (r_{ij} - \mu_i)^2 \]

\[ \hat{r}_{ij} = \frac{r_{ij} - \mu_i}{\sqrt{\sigma_i^2 + \epsilon}} \]

\[ \tilde{x}_{ij} = \gamma^{(1)}_j \hat{r}_{ij} + \beta^{(1)}_j \]

where \(\gamma^{(1)}, \beta^{(1)} \in \mathbb{R}^{d_{model}}\) are learnable parameters. This gives \(\tilde{\mathbf{X}} \in \mathbb{R}^{n \times d_{model}}\).
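The residual addition followed by layer normalization might look like this in NumPy, assuming \(\epsilon = 10^{-6}\) as in the notation section and using a hypothetical helper `layer_norm`. Note the variance divides by \(d_{model}\), matching \(\sigma_i^2\) above.

```python
import numpy as np

def layer_norm(R, gamma, beta, eps=1e-6):
    """Normalize each row (position) of R across its d_model entries, then scale and shift."""
    mu = R.mean(axis=1, keepdims=True)                   # (n, 1) per-position mean
    var = ((R - mu) ** 2).mean(axis=1, keepdims=True)    # (n, 1) variance, dividing by d_model
    R_hat = (R - mu) / np.sqrt(var + eps)                # (n, d_model)
    return gamma * R_hat + beta                          # gamma, beta of shape (d_model,) broadcast over rows

rng = np.random.default_rng(0)
X = rng.normal(size=(3, 8))                              # sub-layer input
Z = rng.normal(size=(3, 8))                              # stand-in for Z^attn
X_tilde = layer_norm(X + Z, np.ones(8), np.zeros(8))     # residual, then layer norm
print(X_tilde.mean(axis=1), X_tilde.std(axis=1))
```

With \(\gamma = \mathbf{1}\) and \(\beta = \mathbf{0}\), every row has mean 0 and standard deviation just under 1 (the \(\epsilon\) term shrinks it slightly).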

Position-wise feed-forward network

\[ \mathbf{F}^{(1)} = \tilde{\mathbf{X}} \mathbf{W}_1^T + \mathbf{1}_n \mathbf{b}_1^T \in \mathbb{R}^{n \times d_{ff}} \]

\[ \mathbf{F}^{(2)} = \text{ReLU}(\mathbf{F}^{(1)}) \in \mathbb{R}^{n \times d_{ff}} \]

where \(\text{ReLU}(x) = \max(0, x)\) applied element-wise.

\[ \mathbf{Z}^{ffn} = \mathbf{F}^{(2)} \mathbf{W}_2^T + \mathbf{1}_n \mathbf{b}_2^T \in \mathbb{R}^{n \times d_{model}} \]

where \(\mathbf{W}_1 \in \mathbb{R}^{d_{ff} \times d_{model}}\), \(\mathbf{b}_1 \in \mathbb{R}^{d_{ff}}\), \(\mathbf{W}_2 \in \mathbb{R}^{d_{model} \times d_{ff}}\), \(\mathbf{b}_2 \in \mathbb{R}^{d_{model}}\), and \(\mathbf{1}_n\) is the \(n\)-dimensional vector of ones.
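A short NumPy sketch of the position-wise FFN, keeping the chapter's convention that \(\mathbf{W}_1\) has shape \((d_{ff}, d_{model})\) and is applied transposed. The function name `ffn` and the toy sizes are assumptions for illustration; NumPy broadcasting of `b1` plays the role of \(\mathbf{1}_n \mathbf{b}_1^T\).

```python
import numpy as np

def ffn(X, W1, b1, W2, b2):
    """Position-wise FFN with W1: (d_ff, d_model) and W2: (d_model, d_ff)."""
    F1 = X @ W1.T + b1           # (n, d_ff); b1 is added to every row
    F2 = np.maximum(F1, 0.0)     # ReLU, element-wise
    return F2 @ W2.T + b2        # (n, d_model)

rng = np.random.default_rng(0)
n, d_model, d_ff = 4, 8, 32      # toy sizes; the chapter uses 512 and 2048
W1, b1 = rng.normal(size=(d_ff, d_model)), rng.normal(size=d_ff)
W2, b2 = rng.normal(size=(d_model, d_ff)), rng.normal(size=d_model)
X = rng.normal(size=(n, d_model))
Z_ffn = ffn(X, W1, b1, W2, b2)
```

Because each row is processed independently, running the FFN on a single row gives exactly that row of `Z_ffn`.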

Second residual and layer normalization

\[ \mathbf{R}^{(2)} = \tilde{\mathbf{X}} + \mathbf{Z}^{ffn} \]

Apply layer normalization (same procedure as before) with parameters \(\gamma^{(2)}, \beta^{(2)}\) to get encoder block output \(\mathbf{X}^{(\ell)} \in \mathbb{R}^{n \times d_{model}}\).

Complete encoder

Apply \(N\) encoder blocks sequentially:

\[ \mathbf{X}^{(\ell)} = \text{EncoderBlock}_\ell(\mathbf{X}^{(\ell-1)}) \quad \text{for } \ell = 1, \ldots, N \]

Final encoder output: \(\mathbf{X}_{enc} = \mathbf{X}^{(N)} \in \mathbb{R}^{n \times d_{model}}\).

12.3.3 Decoder

Each decoder block applies: (1) masked self-attention, (2) residual + layer norm, (3) cross-attention, (4) residual + layer norm, (5) feed-forward network, (6) residual + layer norm.

Input to decoder

Let \(\mathbf{s} = [s_1, \ldots, s_m]\) be the target token indices, where \(s_i \in \{1, \ldots, V\}\). Embed them and add positional encoding (as in the encoder) to form \(\mathbf{Y}^{(0)} \in \mathbb{R}^{m \times d_{model}}\).

Masked self-attention

Define causal mask \(\mathbf{M} \in \mathbb{R}^{m \times m}\) where:

\[ M_{ij} = \begin{cases} 0 & \text{if } j \leq i \\ -\infty & \text{if } j > i \end{cases} \]

Computation: For head \(k\), compute queries, keys, values as before. Then:

\[ \mathbf{S}_k = \frac{\mathbf{Q}_k \mathbf{K}_k^T}{\sqrt{d_k}} + \mathbf{M} \in \mathbb{R}^{m \times m} \]

\[ \mathbf{A}_k = \text{softmax}(\mathbf{S}_k) \]

When \(j > i\), entry \(S_{k,ij} = -\infty\), so \(A_{k,ij} = 0\) after softmax. This ensures position \(i\) only attends to positions \(1, \ldots, i\).

Complete multi-head attention (concatenate, project) and apply residual + layer norm to get \(\tilde{\mathbf{Y}}^{(1)} \in \mathbb{R}^{m \times d_{model}}\).
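To see the mask at work, here is a NumPy sketch that builds \(\mathbf{M}\) and applies a row-wise softmax, assuming random toy queries and keys. The helper names are hypothetical and indices are 0-based. Since the diagonal is never masked, every row has at least one finite score, so the softmax is always well defined.

```python
import numpy as np

def causal_mask(m):
    """M[i, j] = 0 if j <= i and -inf if j > i (0-indexed rows and columns)."""
    return np.triu(np.full((m, m), -np.inf), k=1)

def masked_softmax_rows(S, M):
    S = S + M                                  # future positions become -inf
    S = S - S.max(axis=1, keepdims=True)       # row max is finite because the diagonal is unmasked
    expS = np.exp(S)                           # exp(-inf) = 0 exactly
    return expS / expS.sum(axis=1, keepdims=True)

rng = np.random.default_rng(0)
m, d_k = 4, 8
Q, K = rng.normal(size=(m, d_k)), rng.normal(size=(m, d_k))
A = masked_softmax_rows(Q @ K.T / np.sqrt(d_k), causal_mask(m))
print(np.round(A, 3))                          # entries above the diagonal are exactly zero
```

The first row always puts all its weight on position 1, since that is the only position it may see.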

Cross-attention

For head \(k\):

\[ \mathbf{Q}_k = \tilde{\mathbf{Y}}^{(1)} \mathbf{W}_k^Q \in \mathbb{R}^{m \times d_k} \]

\[ \mathbf{K}_k = \mathbf{X}_{enc} \mathbf{W}_k^K \in \mathbb{R}^{n \times d_k} \]

\[ \mathbf{V}_k = \mathbf{X}_{enc} \mathbf{W}_k^V \in \mathbb{R}^{n \times d_k} \]

\[ \mathbf{S}_k = \frac{\mathbf{Q}_k \mathbf{K}_k^T}{\sqrt{d_k}} \in \mathbb{R}^{m \times n} \]

\[ \mathbf{A}_k = \text{softmax}(\mathbf{S}_k) \in \mathbb{R}^{m \times n} \]

\[ \mathbf{H}_k = \mathbf{A}_k \mathbf{V}_k \in \mathbb{R}^{m \times d_k} \]

Note: the attention matrix is \(m \times n\) because decoder positions (rows) attend to encoder positions (columns).

Concatenate heads, project, and apply residual + layer norm to get \(\tilde{\mathbf{Y}}^{(2)} \in \mathbb{R}^{m \times d_{model}}\).
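One cross-attention head can be sketched in NumPy as below, assuming random stand-ins for \(\tilde{\mathbf{Y}}^{(1)}\) and \(\mathbf{X}_{enc}\) and a hypothetical helper `cross_attention_head`. The point is the shapes: \(m\) decoder rows attend over \(n\) encoder columns, with no mask.

```python
import numpy as np

def softmax_rows(S):
    S = S - S.max(axis=1, keepdims=True)
    e = np.exp(S)
    return e / e.sum(axis=1, keepdims=True)

def cross_attention_head(Y, X_enc, WQ, WK, WV):
    """Queries from the decoder states Y; keys and values from the encoder output X_enc."""
    d_k = WQ.shape[1]
    Q = Y @ WQ                                  # (m, d_k)
    K = X_enc @ WK                              # (n, d_k)
    V = X_enc @ WV                              # (n, d_k)
    A = softmax_rows(Q @ K.T / np.sqrt(d_k))    # (m, n), unmasked
    return A @ V, A                             # H_k: (m, d_k)

rng = np.random.default_rng(0)
m, n, d_model, d_k = 3, 5, 16, 4
Y = rng.normal(size=(m, d_model))               # stand-in for Y~^(1)
X_enc = rng.normal(size=(n, d_model))           # stand-in for the encoder output
WQ, WK, WV = (rng.normal(size=(d_model, d_k)) / np.sqrt(d_model) for _ in range(3))
H, A = cross_attention_head(Y, X_enc, WQ, WK, WV)
print(H.shape, A.shape)                         # (3, 4) (3, 5)
```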

FFN sub-layer

Apply the same FFN as in encoder (two linear layers with ReLU), then residual + layer norm to get decoder block output \(\mathbf{Y}^{(\ell)} \in \mathbb{R}^{m \times d_{model}}\).

Complete decoder

Apply \(N\) decoder blocks sequentially:

\[ \mathbf{Y}^{(\ell)} = \text{DecoderBlock}_\ell(\mathbf{Y}^{(\ell-1)}, \mathbf{X}_{enc}) \quad \text{for } \ell = 1, \ldots, N \]

Final decoder output: \(\mathbf{Y}_{dec} = \mathbf{Y}^{(N)} \in \mathbb{R}^{m \times d_{model}}\).

12.3.4 Output projection and loss

Linear projection to vocabulary:

\[ \mathbf{L} = \mathbf{Y}_{dec} \mathbf{W}_{out} + \mathbf{B}_{out} \in \mathbb{R}^{m \times V} \]

where \(\mathbf{Y}_{dec} \in \mathbb{R}^{m \times d_{model}}\) is the final decoder output (one row per target position), \(\mathbf{W}_{out} \in \mathbb{R}^{d_{model} \times V}\) projects from model dimension to vocabulary size, \(\mathbf{B}_{out} \in \mathbb{R}^{m \times V}\) is the bias matrix (same bias vector \(\mathbf{b}_{out} \in \mathbb{R}^V\) repeated for each row), and \(\mathbf{L} \in \mathbb{R}^{m \times V}\) contains logits (unnormalized scores) for each token at each position.

Softmax over vocabulary:

For each position \(i\), convert logits to probabilities:

\[ P_{ij} = \frac{\exp(L_{ij})}{\sum_{k=1}^{V} \exp(L_{ik})} \]

where \(L_{ij}\) is the logit for token \(j\) at position \(i\), the denominator sums over all \(V\) vocabulary tokens, and \(P_{ij}\) is the probability of token \(j\) being the correct next token at position \(i\). Each row of \(\mathbf{P} \in \mathbb{R}^{m \times V}\) sums to 1.

Cross-entropy loss:

Let \(y_i \in \{1, \ldots, V\}\) be the index of the true target token at position \(i\). The loss measures how well our predictions match the true targets:

\[ \mathcal{L} = -\frac{1}{m} \sum_{i=1}^{m} \log P_{i, y_i} \]

where \(P_{i, y_i}\) is the predicted probability for the correct token at position \(i\), \(\log P_{i, y_i} \leq 0\) (since \(0 < P_{i, y_i} \leq 1\)), the leading negative sign makes \(\mathcal{L}\) nonnegative, and we average over all \(m\) positions. Lower loss means higher probability assigned to correct tokens.
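A NumPy sketch of the output projection, softmax, and mean cross-entropy, assuming 0-indexed target ids and toy sizes. The helper names are hypothetical. It computes \(\log \mathbf{P}\) with the log-sum-exp trick rather than taking the log of an explicit softmax, which avoids overflow and \(\log 0\) for large logits.

```python
import numpy as np

def log_softmax_rows(L):
    """Row-wise log-softmax via log-sum-exp."""
    L = L - L.max(axis=1, keepdims=True)
    return L - np.log(np.exp(L).sum(axis=1, keepdims=True))

def output_layer_loss(Y_dec, W_out, b_out, y):
    """Return logits (m, V), probabilities (m, V), and the mean cross-entropy loss.

    y holds one 0-indexed target token id per position.
    """
    logits = Y_dec @ W_out + b_out            # b_out (V,) broadcasts to every row, acting as B_out
    log_P = log_softmax_rows(logits)
    m = Y_dec.shape[0]
    loss = -log_P[np.arange(m), y].mean()     # -(1/m) * sum_i log P[i, y_i]
    return logits, np.exp(log_P), loss

rng = np.random.default_rng(0)
m, d_model, V = 3, 8, 20
Y_dec = rng.normal(size=(m, d_model))
W_out = rng.normal(size=(d_model, V)) / np.sqrt(d_model)
b_out = np.zeros(V)
y = np.array([4, 0, 19])
logits, P, loss = output_layer_loss(Y_dec, W_out, b_out, y)
```

As a sanity check, all-zero logits give the uniform distribution and a loss of exactly \(\log V\).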

12.4 Backward propagation

Training requires computing how much each parameter contributes to the loss. We do this by propagating gradients backward through the network using the chain rule: if \(z\) depends on \(y\) which depends on \(x\), then \(\frac{\partial z}{\partial x} = \frac{\partial z}{\partial y} \cdot \frac{\partial y}{\partial x}\).

The backward pass mirrors the forward pass in reverse. We start with the gradient of loss with respect to the output, then work backward through each layer. At each layer, we compute two things: (1) gradients with respect to learnable parameters (for updating weights), and (2) gradients with respect to inputs (for continuing the backward pass to earlier layers).

12.4.1 The chain of gradients

The gradient flows backward through this path:

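\[ \mathcal{L} \leftarrow \mathbf{P} \leftarrow \mathbf{L} \leftarrow \mathbf{Y}_{dec} \leftarrow \text{DecoderBlocks} \leftarrow \mathbf{Y}^{(0)} \leftarrow \mathbf{E}_{target} \]

Here \(\mathcal{L}\) is the loss and \(\mathbf{L}\) the logits. Every decoder block also reads \(\mathbf{X}_{enc}\) through cross-attention, so a second branch carries gradient into the encoder: \(\frac{\partial \mathcal{L}}{\partial \mathbf{X}_{enc}}\) sums the contributions from all \(N\) cross-attention sub-layers and then flows back through the encoder blocks to \(\mathbf{X}^{(0)}\) and the source embeddings.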

At each arrow, we apply the chain rule. The notation \(\frac{\partial \mathcal{L}}{\partial \mathbf{X}}\) means “how much does the loss change when we change \(\mathbf{X}\)?”

12.4.2 Output layer: softmax and cross-entropy

We computed \(\mathbf{P} = \text{softmax}(\mathbf{L})\) and \(\mathcal{L} = -\frac{1}{m} \sum_{i=1}^{m} \log P_{i, y_i}\).

The gradient of loss with respect to logits combines both operations:

\[ \frac{\partial \mathcal{L}}{\partial L_{ij}} = \frac{1}{m} \left( P_{ij} - \delta_{j, y_i} \right) \]

where \(\delta_{j, y_i} = 1\) if \(j = y_i\) (the correct token), otherwise 0.

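Interpretation: For position \(i\), the gradient is \(\frac{1}{m}(\mathbf{p}_i - \mathbf{e}_{y_i})\), where \(\mathbf{p}_i\) is the predicted probability vector and \(\mathbf{e}_{y_i}\) is a one-hot vector with 1 at the correct token. A gradient descent step moves the logits along \(-(\mathbf{p}_i - \mathbf{e}_{y_i})\), that is, from the prediction toward the target: it raises the correct token's logit and lowers every other logit in proportion to its predicted probability.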
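The closed form \(\frac{1}{m}(P_{ij} - \delta_{j, y_i})\) can be checked against central finite differences. This is a NumPy sketch with hypothetical helpers `ce_loss` and `ce_grad`, random toy logits, and 0-indexed targets.

```python
import numpy as np

def ce_loss(logits, y):
    z = logits - logits.max(axis=1, keepdims=True)
    log_P = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    return -log_P[np.arange(len(y)), y].mean()

def ce_grad(logits, y):
    """Analytic gradient (1/m) * (P - onehot(y))."""
    m = logits.shape[0]
    z = logits - logits.max(axis=1, keepdims=True)
    P = np.exp(z) / np.exp(z).sum(axis=1, keepdims=True)
    G = P.copy()
    G[np.arange(m), y] -= 1.0
    return G / m

rng = np.random.default_rng(0)
m, V = 3, 7
logits = rng.normal(size=(m, V))
y = np.array([2, 0, 6])
G = ce_grad(logits, y)

h = 1e-6                                   # central differences, one logit at a time
G_num = np.zeros_like(logits)
for i in range(m):
    for j in range(V):
        D = np.zeros_like(logits)
        D[i, j] = h
        G_num[i, j] = (ce_loss(logits + D, y) - ce_loss(logits - D, y)) / (2 * h)
print(np.max(np.abs(G - G_num)))
```

Each row of `G` sums to zero, since a probability vector and a one-hot vector both sum to 1.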

12.4.3 Linear layers

For a linear layer \(\mathbf{Y} = \mathbf{X} \mathbf{W} + \mathbf{b}\) (bias broadcast to all rows), given upstream gradient \(\frac{\partial \mathcal{L}}{\partial \mathbf{Y}}\):

Weight gradient (for parameter update):

\[ \frac{\partial \mathcal{L}}{\partial \mathbf{W}} = \mathbf{X}^T \frac{\partial \mathcal{L}}{\partial \mathbf{Y}} \]

Each entry \(\frac{\partial \mathcal{L}}{\partial W_{ij}}\) accumulates contributions from all positions where input dimension \(i\) affected output dimension \(j\).

Bias gradient (for parameter update):

\[ \frac{\partial \mathcal{L}}{\partial \mathbf{b}} = \sum_{\text{rows}} \frac{\partial \mathcal{L}}{\partial \mathbf{Y}} \]

Sum over all positions since the same bias is added everywhere.

Input gradient (for continuing backward):

\[ \frac{\partial \mathcal{L}}{\partial \mathbf{X}} = \frac{\partial \mathcal{L}}{\partial \mathbf{Y}} \mathbf{W}^T \]

This propagates the gradient back through the weight matrix.
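The three linear-layer formulas can be checked with a toy loss \(\mathcal{L} = \sum_{r,c} C_{rc} Y_{rc}\), whose upstream gradient is simply \(\mathbf{C}\). This NumPy sketch (helper names hypothetical) compares each gradient with a central difference along a random direction; because the loss is linear in each argument, the agreement is exact up to rounding.

```python
import numpy as np

def linear_forward(X, W, b):
    return X @ W + b                      # b broadcasts to every row

def linear_backward(X, W, dY):
    dW = X.T @ dY                         # (d_in, d_out): sums over positions
    db = dY.sum(axis=0)                   # same bias at every position, so sum the rows
    dX = dY @ W.T                         # (n, d_in): passed to the previous layer
    return dW, db, dX

rng = np.random.default_rng(0)
n, d_in, d_out = 4, 3, 5
X = rng.normal(size=(n, d_in))
W = rng.normal(size=(d_in, d_out))
b = rng.normal(size=d_out)
C = rng.normal(size=(n, d_out))           # toy loss L = sum(C * Y), so dL/dY = C

def loss(X, W, b):
    return np.sum(C * linear_forward(X, W, b))

dW, db, dX = linear_backward(X, W, C)

# Directional central differences: each should equal <gradient, direction>.
h = 1e-6
DX, DW, Db = rng.normal(size=X.shape), rng.normal(size=W.shape), rng.normal(size=b.shape)
fd_X = (loss(X + h * DX, W, b) - loss(X - h * DX, W, b)) / (2 * h)
fd_W = (loss(X, W + h * DW, b) - loss(X, W - h * DW, b)) / (2 * h)
fd_b = (loss(X, W, b + h * Db) - loss(X, W, b - h * Db)) / (2 * h)
print(fd_X - np.sum(dX * DX), fd_W - np.sum(dW * DW), fd_b - np.sum(db * Db))
```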

12.4.4 Layer normalization

For layer normalization: \(y_j = \gamma_j \hat{x}_j + \beta_j\) where \(\hat{x}_j = \frac{x_j - \mu}{\sqrt{\sigma^2 + \epsilon}}\).

Parameter gradients:

\[ \frac{\partial \mathcal{L}}{\partial \gamma_j} = \sum_{\text{positions}} \frac{\partial \mathcal{L}}{\partial y_j} \cdot \hat{x}_j \]

\[ \frac{\partial \mathcal{L}}{\partial \beta_j} = \sum_{\text{positions}} \frac{\partial \mathcal{L}}{\partial y_j} \]

Input gradient: This is more involved because each \(x_j\) affects the output through three paths: directly, through \(\mu\), and through \(\sigma^2\). Let \(\mathbf{g} = \gamma \odot \frac{\partial \mathcal{L}}{\partial \mathbf{y}}\) (upstream gradient scaled by \(\gamma\)) and let \(d = d_{model}\):

\[ \frac{\partial \mathcal{L}}{\partial x_j} = \frac{1}{\sqrt{\sigma^2 + \epsilon}} \left( g_j - \frac{1}{d}\sum_k g_k - \frac{\hat{x}_j}{d} \sum_k g_k \hat{x}_k \right) \]

The three terms account for: (1) direct contribution, (2) effect through mean, (3) effect through variance.
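The three-term formula can be verified numerically for a single position. This NumPy sketch assumes \(\epsilon = 10^{-6}\), a toy \(d = 6\), and hypothetical helpers; it also returns the per-row contributions to \(\partial \mathcal{L} / \partial \gamma\) and \(\partial \mathcal{L} / \partial \beta\) (summing these over positions gives the full parameter gradients). The formula is exact even with \(\epsilon > 0\), because \(\hat{x}_j\) is defined with the same \(\epsilon\).

```python
import numpy as np

EPS = 1e-6  # the chapter's layer-norm epsilon

def layer_norm_row(x, gamma, beta):
    """Forward pass for one position; returns the output plus cached x_hat and variance."""
    mu = x.mean()
    var = ((x - mu) ** 2).mean()
    x_hat = (x - mu) / np.sqrt(var + EPS)
    return gamma * x_hat + beta, x_hat, var

def layer_norm_row_backward(dy, gamma, x_hat, var):
    """Input gradient via the three-term formula, plus this row's gamma and beta gradients."""
    d = x_hat.size
    g = gamma * dy
    dx = (g - g.sum() / d - x_hat * (g * x_hat).sum() / d) / np.sqrt(var + EPS)
    return dx, dy * x_hat, dy

rng = np.random.default_rng(0)
d = 6
x, gamma, beta = rng.normal(size=d), rng.normal(size=d), rng.normal(size=d)
c = rng.normal(size=d)                         # toy loss L = sum(c * y), so dL/dy = c
_, x_hat, var = layer_norm_row(x, gamma, beta)
dx, dgamma, dbeta = layer_norm_row_backward(c, gamma, x_hat, var)

h = 1e-6
dx_num = np.array([
    (np.sum(c * layer_norm_row(x + h * e, gamma, beta)[0])
     - np.sum(c * layer_norm_row(x - h * e, gamma, beta)[0])) / (2 * h)
    for e in np.eye(d)
])
print(np.max(np.abs(dx - dx_num)))
```

The input gradient sums to zero: adding a constant to every \(x_j\) leaves the normalized output unchanged.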

12.4.5 ReLU activation

For \(y = \text{ReLU}(x) = \max(0, x)\):

\[ \frac{\partial \mathcal{L}}{\partial x} = \begin{cases} \frac{\partial \mathcal{L}}{\partial y} & \text{if } x > 0 \\ 0 & \text{if } x \leq 0 \end{cases} \]

ReLU passes gradients through where it was active (input positive), and blocks gradients where it was inactive (input negative or zero). This is applied element-wise.
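In NumPy the ReLU backward pass reduces to a boolean mask, sketched here with a hypothetical helper `relu_backward`. Following the case split above, the derivative at exactly \(x = 0\) is taken to be 0.

```python
import numpy as np

def relu_backward(x, dy):
    """Pass dy where the forward input x was positive; block it where x <= 0."""
    return dy * (x > 0)

x = np.array([-2.0, 0.0, 0.5, 3.0])     # pre-activation values, e.g. entries of F^(1)
dy = np.array([0.3, 0.3, 0.3, 0.3])     # upstream gradient with respect to ReLU(x)
dx = relu_backward(x, dy)
```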

12.4.6 Attention mechanism

Consider the attention computation \(\mathbf{H}_k = \mathbf{A}_k \mathbf{V}_k\), where \(\mathbf{A}_k = \text{softmax}(\mathbf{S}_k)\) and \(\mathbf{S}_k = \mathbf{S}_k^{\text{unscaled}} / \sqrt{d_k}\) with \(\mathbf{S}_k^{\text{unscaled}} = \mathbf{Q}_k \mathbf{K}_k^T\). This matches the forward pass; naming the unscaled product separately lets us treat the scaling as its own step.

Step 1: Gradient through weighted sum (\(\mathbf{H}_k = \mathbf{A}_k \mathbf{V}_k\))

\[ \frac{\partial \mathcal{L}}{\partial \mathbf{A}_k} = \frac{\partial \mathcal{L}}{\partial \mathbf{H}_k} \mathbf{V}_k^T \]

\[ \frac{\partial \mathcal{L}}{\partial \mathbf{V}_k} = \mathbf{A}_k^T \frac{\partial \mathcal{L}}{\partial \mathbf{H}_k} \]

Step 2: Gradient through softmax

For each row \(i\), let \(\mathbf{a} = \mathbf{A}_{k,i,:}\) be the attention weights and \(\mathbf{s} = \mathbf{S}_{k,i,:}\) be the scaled scores (the softmax input):

\[ \frac{\partial \mathcal{L}}{\partial s_j} = a_j \left( \frac{\partial \mathcal{L}}{\partial a_j} - \sum_r a_r \frac{\partial \mathcal{L}}{\partial a_r} \right) \]

The subtracted term redistributes gradient to maintain the constraint that attention weights sum to 1.

Step 3: Gradient through scaling

\[ \frac{\partial \mathcal{L}}{\partial S_{k,ij}^{\text{unscaled}}} = \frac{1}{\sqrt{d_k}} \frac{\partial \mathcal{L}}{\partial S_{k,ij}} \]

Step 4: Gradient through score computation (\(\mathbf{S}_k^{\text{unscaled}} = \mathbf{Q}_k \mathbf{K}_k^T\))

\[ \frac{\partial \mathcal{L}}{\partial \mathbf{Q}_k} = \frac{\partial \mathcal{L}}{\partial \mathbf{S}_k^{\text{unscaled}}} \mathbf{K}_k \]

\[ \frac{\partial \mathcal{L}}{\partial \mathbf{K}_k} = \left( \frac{\partial \mathcal{L}}{\partial \mathbf{S}_k^{\text{unscaled}}} \right)^T \mathbf{Q}_k \]

Step 5: Gradient through projections (\(\mathbf{Q}_k = \mathbf{X} \mathbf{W}_k^Q\))

\[ \frac{\partial \mathcal{L}}{\partial \mathbf{W}_k^Q} = \mathbf{X}^T \frac{\partial \mathcal{L}}{\partial \mathbf{Q}_k} \]

\[ \frac{\partial \mathcal{L}}{\partial \mathbf{X}} \mathrel{+}= \frac{\partial \mathcal{L}}{\partial \mathbf{Q}_k} \left( \mathbf{W}_k^Q \right)^T \]

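The \(\mathrel{+}=\) notation means we accumulate gradients because \(\mathbf{X}\) is used to compute \(\mathbf{Q}_k\), \(\mathbf{K}_k\), and \(\mathbf{V}_k\). We compute similar gradients for \(\mathbf{W}_k^K\) and \(\mathbf{W}_k^V\). In cross-attention, \(\mathbf{K}_k\) and \(\mathbf{V}_k\) are computed from \(\mathbf{X}_{enc}\), so their contributions accumulate into \(\frac{\partial \mathcal{L}}{\partial \mathbf{X}_{enc}}\) rather than into the gradient of the decoder input.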
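Steps 1 through 5 for a single self-attention head can be sketched in NumPy and checked against finite differences, assuming a toy loss \(\mathcal{L} = \sum_{r,c} C_{rc} H_{rc}\) (so the upstream gradient is \(\mathbf{C}\)) and hypothetical helpers `head_forward` and `head_backward`. The gradient with respect to `X` sums the query, key, and value paths, which is the accumulation described above.

```python
import numpy as np

def softmax_rows(S):
    S = S - S.max(axis=1, keepdims=True)
    e = np.exp(S)
    return e / e.sum(axis=1, keepdims=True)

def head_forward(X, WQ, WK, WV):
    d_k = WQ.shape[1]
    Q, K, V = X @ WQ, X @ WK, X @ WV
    A = softmax_rows(Q @ K.T / np.sqrt(d_k))            # softmax of the scaled scores
    return A @ V, (Q, K, V, A)

def head_backward(X, WQ, WK, WV, cache, dH):
    Q, K, V, A = cache
    d_k = WQ.shape[1]
    dA = dH @ V.T                                       # Step 1
    dV = A.T @ dH
    dS = A * (dA - (A * dA).sum(axis=1, keepdims=True)) # Step 2: row-wise softmax backward
    dS_unscaled = dS / np.sqrt(d_k)                     # Step 3
    dQ = dS_unscaled @ K                                # Step 4
    dK = dS_unscaled.T @ Q
    grads = {"WQ": X.T @ dQ, "WK": X.T @ dK, "WV": X.T @ dV}   # Step 5
    dX = dQ @ WQ.T + dK @ WK.T + dV @ WV.T              # accumulate all three paths into X
    return grads, dX

rng = np.random.default_rng(0)
n, d_model, d_k = 4, 6, 3
X = rng.normal(size=(n, d_model))
WQ, WK, WV = (rng.normal(size=(d_model, d_k)) / np.sqrt(d_model) for _ in range(3))
C = rng.normal(size=(n, d_k))                           # toy loss L = sum(C * H)
H, cache = head_forward(X, WQ, WK, WV)
grads, dX = head_backward(X, WQ, WK, WV, cache, C)

h = 1e-6
dX_num = np.zeros_like(X)
for i in range(n):
    for j in range(d_model):
        D = np.zeros_like(X)
        D[i, j] = h
        Lp = np.sum(C * head_forward(X + D, WQ, WK, WV)[0])
        Lm = np.sum(C * head_forward(X - D, WQ, WK, WV)[0])
        dX_num[i, j] = (Lp - Lm) / (2 * h)
print(np.max(np.abs(dX - dX_num)))
```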

12.4.7 Residual connections

For \(\mathbf{R} = \mathbf{X} + \mathbf{Z}\):

\[ \frac{\partial \mathcal{L}}{\partial \mathbf{X}} = \frac{\partial \mathcal{L}}{\partial \mathbf{R}}, \quad \frac{\partial \mathcal{L}}{\partial \mathbf{Z}} = \frac{\partial \mathcal{L}}{\partial \mathbf{R}} \]

The gradient flows unchanged to both branches. This is why residual connections help training: within each sub-layer, the gradient reaches \(\mathbf{X}\) without passing through the potentially problematic transformation that produced \(\mathbf{Z}\). Even if gradients through \(\mathbf{Z}\) vanish, the gradient through the skip connection remains intact. (In the post-norm arrangement used in this chapter, the gradient still passes through each layer normalization on its way to earlier blocks.)

12.4.8 Embedding layer

The embedding lookup \(\mathbf{X}_{embed} = \mathbf{E}[\mathbf{t}, :]\) selects rows from the embedding matrix. The gradient updates only the rows that were selected:

\[ \frac{\partial \mathcal{L}}{\partial \mathbf{E}[j, :]} = \sum_{i: t_i = j} \frac{\partial \mathcal{L}}{\partial \mathbf{X}_{embed}[i, :]} \]

If token \(j\) appears at positions 2 and 5 in the sequence, we sum the gradients from both positions and apply that update to row \(j\) of the embedding matrix. Tokens not present in the current batch receive zero gradient.
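In NumPy this scatter-add is `np.add.at`, which accumulates over repeated indices; plain fancy-index assignment `dE[tokens] += dX_embed` would keep only one contribution for a repeated token. A minimal sketch with a hypothetical helper `embedding_backward` and 0-indexed token ids:

```python
import numpy as np

def embedding_backward(tokens, dX_embed, V):
    """Scatter-add each position's gradient into the row of its token."""
    dE = np.zeros((V, dX_embed.shape[1]))
    np.add.at(dE, tokens, dX_embed)      # repeated token ids accumulate
    return dE

tokens = np.array([3, 1, 4, 1, 5])       # token 1 appears at positions 1 and 3
dX_embed = np.arange(10.0).reshape(5, 2)
dE = embedding_backward(tokens, dX_embed, V=8)
print(dE[1])                             # [2, 3] + [6, 7], i.e. [8., 10.]
```

Rows 0, 2, 6, and 7 of `dE` stay zero because those tokens do not occur.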

12.5 Parameter updates

We use the Adam optimizer with:

  • Learning rate: \(\alpha\) (with warmup schedule)
  • Exponential decay rates: \(\beta_1 = 0.9\), \(\beta_2 = 0.98\)
  • Stability constant: \(\epsilon = 10^{-9}\) (distinct from the layer-normalization \(\epsilon = 10^{-6}\) defined in Section 12.1)

For each parameter \(\theta\) with gradient \(g_t = \frac{\partial \mathcal{L}}{\partial \theta}\) at step \(t\):

Update first moment:

\[ m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t \]

Update second moment:

\[ v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2 \]

Bias correction:

\[ \hat{m}_t = \frac{m_t}{1 - \beta_1^t}, \quad \hat{v}_t = \frac{v_t}{1 - \beta_2^t} \]

Parameter update:

\[ \theta_t = \theta_{t-1} - \alpha \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} \]
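A direct NumPy transcription of the four update equations, using the hyperparameters listed above as defaults. The function name `adam_step` is ours. On the first step the bias corrections give \(\hat{m}_1 = g_1\) and \(\hat{v}_1 = g_1^2\), so each parameter moves by almost exactly \(\alpha\) against the sign of its gradient, whatever the gradient's magnitude.

```python
import numpy as np


def adam_step(theta, g, m, v, t, alpha, beta1=0.9, beta2=0.98, eps=1e-9):
    """One Adam update; t is the 1-based step count."""
    m = beta1 * m + (1 - beta1) * g          # first moment
    v = beta2 * v + (1 - beta2) * g**2       # second moment
    m_hat = m / (1 - beta1**t)               # bias correction
    v_hat = v / (1 - beta2**t)
    theta = theta - alpha * m_hat / (np.sqrt(v_hat) + eps)
    return theta, m, v


theta = np.array([1.0, -2.0, 0.5])
g = np.array([0.3, -4.0, 1e-3])
m = np.zeros_like(theta)
v = np.zeros_like(theta)
theta1, m, v = adam_step(theta, g, m, v, t=1, alpha=1e-3)
print(theta1)
```

The per-parameter normalization by \(\sqrt{\hat{v}_t}\) is what makes the step size roughly scale-invariant with respect to the gradient.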

Learning rate schedule with warmup:

\[ \alpha_t = d_{model}^{-0.5} \cdot \min(t^{-0.5}, t \cdot \text{warmup\_steps}^{-1.5}) \]

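with \(\text{warmup\_steps} = 4000\). For \(t \le \text{warmup\_steps}\) the second argument of the \(\min\) is the smaller one, so the learning rate grows linearly in \(t\); afterwards it decays proportionally to \(t^{-0.5}\). The peak, reached at \(t = \text{warmup\_steps}\), is \((d_{model} \cdot \text{warmup\_steps})^{-0.5} = (512 \cdot 4000)^{-0.5} \approx 7.0 \times 10^{-4}\).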
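The schedule as a small Python function (the name `transformer_lr` is ours). Steps are counted from 1, since \(t = 0\) would make \(t^{-0.5}\) undefined.

```python
def transformer_lr(step, d_model=512, warmup_steps=4000):
    """Warmup-then-decay schedule; step is 1-based."""
    return d_model**-0.5 * min(step**-0.5, step * warmup_steps**-1.5)


peak = transformer_lr(4000)
print(peak, transformer_lr(2000), transformer_lr(16000))
```

Halving the step during warmup halves the rate (linear growth), while quadrupling the step after warmup also halves it (inverse square-root decay).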

12.6 Decoder-only architecture (GPT)

For decoder-only models:

Architecture:

  1. Embedding + positional encoding
  2. \(N\) decoder blocks, each with:
    • Masked multi-head self-attention
    • Residual + layer normalization
    • Position-wise FFN
    • Residual + layer normalization
  3. Output projection to vocabulary

Key differences from encoder-decoder:

  • No encoder component
  • No cross-attention sub-layer in decoder blocks
  • Only masked self-attention: causal masking restricts each position to attend to itself and earlier positions, which is what enables autoregressive generation

Forward and backward propagation follow the masked self-attention and FFN components described above.
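Causal masking can be sketched in NumPy by setting every score above the diagonal to \(-\infty\) before the row-wise softmax, so position \(i\) assigns zero weight to every position \(j > i\). The function below is an illustrative single-head version operating on a square score matrix.

```python
import numpy as np


def causal_attention_weights(S):
    """Row-wise softmax of an (m, m) score matrix with future positions masked out."""
    m = S.shape[0]
    future = np.triu(np.ones((m, m), dtype=bool), k=1)  # True strictly above the diagonal
    S = np.where(future, -np.inf, S)
    S = S - S.max(axis=1, keepdims=True)  # stability; the diagonal entry keeps each row finite
    A = np.exp(S)                         # exp(-inf) = 0 for masked entries
    return A / A.sum(axis=1, keepdims=True)


rng = np.random.default_rng(3)
A = causal_attention_weights(rng.standard_normal((4, 4)))
print(np.round(A, 3))
```

The first position can only attend to itself, so its row is \([1, 0, 0, 0]\).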

12.7 Encoder-only architecture (BERT)

For encoder-only models like BERT (Devlin et al. 2018):

Architecture:

  1. Embedding + positional encoding
  2. \(N\) encoder blocks (bidirectional self-attention + FFN)
  3. Task-specific head (e.g., linear classifier on [CLS] token)

Key differences from encoder-decoder:

  • No decoder component
  • No causal masking (bidirectional attention)
  • Task-specific output layer for fine-tuning instead of a vocabulary projection (BERT's masked-language-model pretraining head does project to the vocabulary)

Forward and backward propagation follow the encoder derivation above.
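A minimal NumPy sketch of the "linear classifier on [CLS] token" head, assuming the [CLS] token sits at position 0 of the final encoder output, as in BERT. The weights and sizes are hypothetical; BERT additionally passes this vector through a dense tanh "pooler" layer, omitted here.

```python
import numpy as np

rng = np.random.default_rng(4)
n, d_model, num_classes = 5, 8, 3
H = rng.standard_normal((n, d_model))                     # final encoder outputs, row 0 = [CLS]
W_cls = rng.standard_normal((d_model, num_classes)) * 0.02  # hypothetical head weights
b_cls = np.zeros(num_classes)


def cls_head(H, W, b):
    logits = H[0] @ W + b           # the head reads only the [CLS] vector
    logits = logits - logits.max()  # numerical stability for the softmax
    e = np.exp(logits)
    return e / e.sum()


probs = cls_head(H, W_cls, b_cls)
print(probs)
```

The other positions still influence the prediction, but only through the bidirectional attention that produced `H[0]`; the head itself ignores them.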

12.8 Computational complexity

Time complexity per layer:

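  • Self-attention: \(O(n^2 d_{model})\) for the attention scores and the weighted sum of values, plus \(O(n d_{model}^2)\) for the query, key, value, and output projections
  • Cross-attention: \(O(m n d_{model})\), where \(m\) is the decoder length and \(n\) is the encoder length
  • FFN: \(O(n d_{model} d_{ff})\)
  • Total per sample: \(O(N(n^2 d_{model} + n d_{model} d_{ff}))\) for the encoder, and analogously for the decoder with \(m\); the quadratic attention term dominates only for long sequences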
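A rough multiply-accumulate count for one encoder layer makes the trade-off concrete. This sketch (function name ours) ignores softmax, layer normalization, and biases, and sums the per-head score computations, which together cost \(n^2 d_{model}\) because \(h \cdot d_k = d_{model}\). With the base dimensions, the attention term \(2n^2 d_{model}\) overtakes the linear terms \(n\,d_{model}(4d_{model} + 2d_{ff})\) only once \(n > 3072\).

```python
def encoder_layer_macs(n, d_model=512, d_ff=2048):
    """Approximate multiply-accumulates for one encoder layer (softmax, norms, biases ignored)."""
    projections = 4 * n * d_model**2   # Q, K, V and output projections
    attention = 2 * n**2 * d_model     # Q K^T scores plus weights times V, summed over heads
    ffn = 2 * n * d_model * d_ff       # the two position-wise linear layers
    return {"projections": projections, "attention": attention, "ffn": ffn}


for n in (512, 3072, 4096):
    c = encoder_layer_macs(n)
    print(n, c["attention"], c["projections"] + c["ffn"])
```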

Space complexity:

  • Attention matrices: \(O(n^2 h N)\) for storing attention weights
  • Activations: \(O(n d_{model} N)\) for intermediate states
  • Parameters: \(O(d_{model}^2 N + V d_{model})\) for weights and embeddings

Parameter count (base transformer):

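  • Embeddings: \(V \times d_{model} = 50000 \times 512 \approx 25.6\text{M}\) (assuming one table shared by the encoder and decoder inputs)
  • Per encoder block: \(4d_{model}^2 + 2d_{model} d_{ff} + 4d_{model} \approx 3.1\text{M}\) (attention projections, FFN weights, and two layer norms; biases omitted)
  • Per decoder block: \(8d_{model}^2 + 2d_{model} d_{ff} + 6d_{model} \approx 4.2\text{M}\) (adds the cross-attention projections and a third layer norm)
  • \(N = 6\) encoder blocks: \(\approx 18.9\text{M}\); \(N = 6\) decoder blocks: \(\approx 25.2\text{M}\)
  • Output projection: \(d_{model} \times V \approx 25.6\text{M}\)
  • Total: \(\approx 95\text{M}\) parameters, or \(\approx 70\text{M}\) if the output projection shares its weights with the embedding matrix, as in the original transformer (Vaswani et al. 2017)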
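The per-block formulas can be tallied with a short script. This sketch (function name ours) counts weight matrices and layer-norm \(\gamma, \beta\) only, omits biases, counts the extra cross-attention projections and third layer norm in each decoder block, and assumes a single embedding table shared by encoder and decoder inputs.

```python
def transformer_params(V=50000, d_model=512, d_ff=2048, N=6, tie_output=False):
    """Weight-matrix and layer-norm parameters of an encoder-decoder transformer (biases omitted)."""
    embed = V * d_model  # assumed shared by encoder and decoder inputs
    enc_block = 4 * d_model**2 + 2 * d_model * d_ff + 4 * d_model  # self-attn, FFN, 2 layer norms
    dec_block = 8 * d_model**2 + 2 * d_model * d_ff + 6 * d_model  # + cross-attn, 3rd layer norm
    output = 0 if tie_output else d_model * V
    return embed + N * enc_block + N * dec_block + output


print(transformer_params())                 # untied output projection
print(transformer_params(tie_output=True))  # output projection tied to the embedding
```

With the base hyperparameters this gives 95,270,912 parameters untied and 69,670,912 tied.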

For comparison, GPT-3 (Brown et al. 2020) has 175 billion parameters using \(d_{model}=12288\), \(d_{ff}=49152\), \(h=96\), \(N=96\).

12.9 Implementation notes

A complete implementation requires:

Initialization:

  • Weight matrices: Xavier/Glorot uniform initialization
  • Biases: initialize to zero
  • Layer norm parameters: \(\gamma = 1\), \(\beta = 0\)
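Glorot/Xavier uniform initialization draws each weight from \(U(-a, a)\) with \(a = \sqrt{6/(\text{fan\_in} + \text{fan\_out})}\), which gives variance \(2/(\text{fan\_in} + \text{fan\_out})\). A minimal NumPy sketch for the first FFN weight matrix (\(d_{model} \times d_{ff}\)); the helper name is ours.

```python
import numpy as np


def xavier_uniform(fan_in, fan_out, rng):
    """Glorot/Xavier uniform: U(-a, a) with a = sqrt(6 / (fan_in + fan_out))."""
    a = np.sqrt(6.0 / (fan_in + fan_out))
    return rng.uniform(-a, a, size=(fan_in, fan_out))


rng = np.random.default_rng(5)
W_ff1 = xavier_uniform(512, 2048, rng)
b_ff1 = np.zeros(2048)  # biases start at zero
print(W_ff1.var(), 2.0 / (512 + 2048))
```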

Regularization:

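  • Dropout with rate 0.1 on each sub-layer output (before the residual addition and normalization) and on the sums of embeddings and positional encodings; many implementations also apply it to attention weights and FFN activations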
  • Label smoothing with value 0.1 for cross-entropy loss
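Label smoothing replaces the one-hot target with a softened distribution. Implementations differ in the exact variant; this NumPy sketch assumes one common form that puts \(1 - \epsilon_{ls}\) on the correct token and spreads \(\epsilon_{ls}\) uniformly over all \(V\) tokens. The function name is ours.

```python
import numpy as np


def label_smoothed_ce(logits, targets, eps=0.1):
    """Mean cross-entropy against (1 - eps) * one_hot + eps / V."""
    V = logits.shape[-1]
    z = logits - logits.max(axis=-1, keepdims=True)                  # stable log-softmax
    log_probs = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
    q = np.full_like(log_probs, eps / V)                              # uniform share of eps
    q[np.arange(len(targets)), targets] += 1.0 - eps                  # mass on the true token
    return -(q * log_probs).sum(axis=-1).mean()


rng = np.random.default_rng(7)
logits = rng.standard_normal((3, 6))
targets = np.array([0, 5, 2])
print(label_smoothed_ce(logits, targets))
```

With `eps=0.0` this reduces to ordinary cross-entropy; with `eps > 0` the model is penalized for becoming arbitrarily confident.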

Batching:

  • Process multiple sequences in parallel
  • Pad sequences to maximum length in batch
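  • Create a padding mask: set attention scores to \(-\infty\) (in practice, a large negative number) at padded key positions before the softmax, so padding receives zero attention weight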
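A batched NumPy sketch of the padding mask, assuming right-padded sequences with known true lengths (the `lengths` argument is an illustrative interface, not taken from the text). Queries at padded positions still produce outputs; those rows are simply ignored by the loss.

```python
import numpy as np


def padded_attention_weights(S, lengths):
    """S: (batch, n, n) scores; lengths: true length of each sequence before padding."""
    batch, n, _ = S.shape
    key_is_pad = np.arange(n)[None, :] >= np.asarray(lengths)[:, None]  # (batch, n)
    S = np.where(key_is_pad[:, None, :], -np.inf, S)  # mask padded keys for every query row
    S = S - S.max(axis=-1, keepdims=True)             # each row has at least one real key
    A = np.exp(S)
    return A / A.sum(axis=-1, keepdims=True)


rng = np.random.default_rng(6)
A = padded_attention_weights(rng.standard_normal((2, 4, 4)), lengths=[4, 2])
print(np.round(A[1], 3))
```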

Training stability:

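  • Gradient clipping: rescale all gradients when their global L2 norm exceeds a threshold, to prevent exploding gradients
  • Mixed precision training: run forward and backward passes in FP16 while keeping an FP32 master copy of the weights for parameter updates, with loss scaling to keep small FP16 gradients from underflowing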
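Global-norm clipping rescales every gradient tensor by the same factor, which preserves the overall update direction. A minimal NumPy sketch; the helper name is ours, and the threshold `max_norm` is a free hyperparameter whose value here is only illustrative.

```python
import numpy as np


def clip_by_global_norm(grads, max_norm):
    """Scale all gradients by one factor so their joint L2 norm is at most max_norm."""
    total = np.sqrt(sum(np.sum(g**2) for g in grads))
    scale = min(1.0, max_norm / (total + 1e-12))  # leave small gradients untouched
    return [g * scale for g in grads], total


grads = [np.array([3.0, 0.0]), np.array([[0.0, 4.0]])]  # global norm 5
clipped, norm_before = clip_by_global_norm(grads, max_norm=1.0)
print(norm_before, clipped)
```

Clipping each tensor separately would instead change the relative sizes of the updates across parameters.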

This completes the mathematical specification of the transformer architecture with forward and backward propagation.