Example number
80
Example slug
example_80_transformers_attention_as_qk_t_then_av
Background

Matrix products became the computational substrate of deep learning because GPUs and TPUs are hardware-optimized for dense matrix multiplication (GEMM kernels). Early neural networks (1980s-90s) used matrix-vector products for feedforward layers ($y = Wx + b$), but attention mechanisms introduced matrix-matrix products as the dominant operation. The attention mechanism originated in sequence-to-sequence models (Bahdanau et al., 2015) to allow decoders to “attend” to relevant encoder states, replacing fixed-length context vectors with dynamic weighted sums. Transformers (Vaswani et al., 2017, “Attention Is All You Need”) removed recurrence entirely, using only attention and feedforward layers, enabling parallelization across sequence positions (unlike RNNs, which process sequentially). The scaled dot-product formulation $\text{softmax}(QK^\top / \sqrt{d_k}) V$ emerged as the efficient implementation: (1) $QK^\top$ computes all $n_q \times n_k$ pairwise similarities in one BLAS3 operation ($O(n^2 d)$ for self-attention), (2) softmax normalizes in $O(n^2)$, (3) $AV$ aggregates values in another BLAS3 call. This three-step pattern (score, normalize, aggregate) became the universal attention primitive, powering BERT (Devlin et al., 2019, 340M parameters, pre-training on masked language modeling), the GPT family (Radford et al., 2018-2019; Brown et al., 2020; autoregressive language models scaling to 175B parameters), Vision Transformers (Dosovitskiy et al., 2021, image patches as tokens), and text-to-image diffusion models (Rombach et al., 2022, cross-attention for text conditioning). The quadratic $O(n^2)$ complexity drove research into efficient attention variants: sparse patterns (Longformer, BigBird), low-rank approximations (Linformer), and kernel methods (Performers), enabling contexts on the order of $10^4$ tokens.

Purpose

Demonstrate how scaled dot-product attention—the core operation of Transformer architectures—decomposes into two matrix products ($QK^\top$ for similarity scoring, $AV$ for weighted aggregation) connected by row-wise softmax normalization. For queries $Q \in \mathbb{R}^{n_q \times d_k}$, keys $K \in \mathbb{R}^{n_k \times d_k}$, and values $V \in \mathbb{R}^{n_k \times d_v}$, attention computes $O = \text{softmax}(QK^\top / \sqrt{d_k}) V \in \mathbb{R}^{n_q \times d_v}$, producing data-dependent weighted combinations of values where weights arise from query-key similarity. The scaling factor $1/\sqrt{d_k}$ prevents dot products from growing with dimension (which would cause softmax saturation and vanishing gradients). This operation—$O(n_q n_k d_k)$ for scoring, $O(n_q n_k d_v)$ for aggregation—is the computational bottleneck of Transformers, limiting sequence length to $\sim 10^3$ tokens without specialized sparse attention patterns. The example validates row-stochasticity of attention weights (each query’s weights sum to 1, forming a probability distribution over keys) and shows how shapes/transposes “feel inevitable” once the query-key-value paradigm is understood, enabling students to reason about forward/backward passes in BERT, GPT, and Vision Transformers without memorizing formulas.

Problem

Compute scores, attention weights, and outputs for a tiny attention head; verify weights sum to 1.

Solution (Math)

Scaled dot-product attention computes output $O \in \mathbb{R}^{n_q \times d_v}$ from queries $Q$, keys $K$, and values $V$ via:

\[ O = \text{softmax}\left( \frac{Q K^\top}{\sqrt{d_k}} \right) V = A V, \]

where $A \in \mathbb{R}^{n_q \times n_k}$ is the attention weight matrix.

Step 1: Compute attention scores via scaled dot products:

\[ \text{scores} = \frac{Q K^\top}{\sqrt{d_k}} \in \mathbb{R}^{n_q \times n_k}, \]

where $Q \in \mathbb{R}^{n_q \times d_k}$, $K \in \mathbb{R}^{n_k \times d_k}$, and $d_k$ is the key dimension. Each entry $\text{scores}_{ij} = q_i^\top k_j / \sqrt{d_k}$ measures similarity between query $i$ and key $j$. Scaling by $1/\sqrt{d_k}$ keeps the typical score magnitude independent of $d_k$: if the components of $q$ and $k$ are independent with zero mean and unit variance, $q^\top k$ has variance $d_k$, so unscaled scores have standard deviation $\sqrt{d_k}$. Scores that large would push softmax into saturation regions (one weight $\approx 1$, others $\approx 0$).

Step 2: Normalize to attention weights via row-wise softmax:

\[ A_{ij} = \frac{\exp(\text{scores}_{ij})}{\sum_{k=1}^{n_k} \exp(\text{scores}_{ik})} \quad \text{for } i=1,\ldots,n_q. \]

Properties: (1) $A_{ij} \ge 0$ (non-negative), (2) $\sum_{j=1}^{n_k} A_{ij} = 1$ (rows sum to 1, row-stochastic), (3) $A$ is a probability distribution over keys for each query.

Step 3: Aggregate values via weighted sum:

\[ O = A V \in \mathbb{R}^{n_q \times d_v}, \]

where $V \in \mathbb{R}^{n_k \times d_v}$. Each output row is:

\[ O_i = \sum_{j=1}^{n_k} A_{ij} V_j, \]

a convex combination of value vectors, weighted by the attention weights $A_{ij}$.

Computational complexity:
  • $QK^\top$: $O(n_q n_k d_k)$ operations (matrix-matrix multiply)
  • Softmax: $O(n_q n_k)$ operations (exponentiation + normalization)
  • $AV$: $O(n_q n_k d_v)$ operations (matrix-matrix multiply)
  • Total: $O(n_q n_k (d_k + d_v))$ ≈ $O(n^2 d)$ for self-attention ($n_q = n_k = n$, $d_k \approx d_v \approx d$)

Standard notation:
  • $Q \in \mathbb{R}^{n_q \times d_k}$: query matrix ($n_q$ queries, each $d_k$-dimensional)
  • $K \in \mathbb{R}^{n_k \times d_k}$: key matrix ($n_k$ keys, each $d_k$-dimensional)
  • $V \in \mathbb{R}^{n_k \times d_v}$: value matrix ($n_k$ values, each $d_v$-dimensional)
  • $A \in \mathbb{R}^{n_q \times n_k}$: attention weight matrix (row-stochastic)
  • $O \in \mathbb{R}^{n_q \times d_v}$: output matrix
  • Transpose: $K^\top$ (not $K^T$)
  • Softmax along axis 1 (across keys for each query)

Solution (Python)

import numpy as np
from scripts.toy_data import softmax

Q = np.array([[1., 0.],
              [0., 1.]])
K = np.array([[1., 0.],
              [1., 1.],
              [0., 1.]])
V = np.array([[1., 0.],
              [0., 2.],
              [1., 1.]])

scores = Q @ K.T / np.sqrt(2)
A = softmax(scores, axis=1)
O = A @ V

print("A row sums:", A.sum(axis=1))
print("O:\n", O)
Code Introduction

This code demonstrates scaled dot-product attention, the fundamental operation in Transformer architectures. Attention lets the model dynamically select and aggregate relevant information from a set of values, guided by query-key similarity scores. The mechanism decomposes into two sequential matrix products: $QK^\top$ (scoring) followed by $AV$ (aggregation).

Numerical Implementation Details

Step-by-step execution of scaled dot-product attention:

  1. Define input matrices: Queries $Q \in \mathbb{R}^{2 \times 2}$, keys $K \in \mathbb{R}^{3 \times 2}$, values $V \in \mathbb{R}^{3 \times 2}$ are hardcoded NumPy arrays representing two queries attending to three key-value pairs in a 2-dimensional space. Shapes: $Q$ has 2 rows (queries), $K$ has 3 rows (keys), both with $d_k = 2$ columns. $V$ has 3 rows (values) with $d_v = 2$ columns.

  2. Compute scaled scores: scores = Q @ K.T / np.sqrt(2) computes the matrix product $QK^\top \in \mathbb{R}^{2 \times 3}$, then scales by $1/\sqrt{d_k} = 1/\sqrt{2} \approx 0.707$. The transpose K.T converts $K$ from $(3, 2)$ to $(2, 3)$, making the product $(2, 2) \times (2, 3) \to (2, 3)$ valid. Each entry $\text{scores}_{ij}$ is the scaled dot product between query $i$ and key $j$. For example: $\text{scores}_{00} = (Q[0] \cdot K[0]) / \sqrt{2} = (1 \cdot 1 + 0 \cdot 0) / 1.414 \approx 0.707$.

  3. Apply row-wise softmax: A = softmax(scores, axis=1) normalizes each row of scores independently (each query’s distribution over keys). For row 0: $\text{scores}_0 = [0.707, 0.707, 0]$ becomes $A_0 \approx [0.401, 0.401, 0.198]$ after exponentiating ($e^{0.707} \approx 2.028$, $e^{0} = 1$) and dividing by the row sum ($\approx 5.056$). The axis=1 argument is critical: it specifies normalization across keys (columns) for each query (row). Using axis=0 would normalize across queries, which is meaningless.

  4. Aggregate values: O = A @ V computes the output $O \in \mathbb{R}^{2 \times 2}$ by multiplying attention weights $(2, 3)$ by values $(3, 2) \to (2, 2)$. Each output row is a weighted average of value vectors: $O_0 = 0.401 \cdot V_0 + 0.401 \cdot V_1 + 0.198 \cdot V_2 \approx [0.599, 1.000]$.

  5. Verify row-stochasticity: A.sum(axis=1) computes row sums, yielding [1., 1.] (each query’s weights sum to 1, confirming valid probability distributions). This is a fundamental property of softmax normalization—it guarantees that attention weights are convex combination coefficients.

  6. Inspect output: Print statements show attention weights sum to 1 and display the final output matrix $O$. The output represents query-specific weighted combinations of values, where weights reflect query-key similarity.
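The six steps above can be reproduced without the project-specific `scripts.toy_data` import. The following is a minimal sketch that assumes a small stand-in `softmax` helper (written here for illustration, not taken from the original module); everything else mirrors the example's arrays.

```python
import numpy as np

def softmax(x, axis=-1):
    """Stand-in softmax helper (assumption: behaves like the imported one)."""
    shifted = x - x.max(axis=axis, keepdims=True)  # subtract row max for stability
    e = np.exp(shifted)
    return e / e.sum(axis=axis, keepdims=True)

Q = np.array([[1., 0.], [0., 1.]])            # (n_q, d_k) = (2, 2)
K = np.array([[1., 0.], [1., 1.], [0., 1.]])  # (n_k, d_k) = (3, 2)
V = np.array([[1., 0.], [0., 2.], [1., 1.]])  # (n_k, d_v) = (3, 2)

d_k = Q.shape[1]
scores = Q @ K.T / np.sqrt(d_k)  # step 2: scaled scores, shape (2, 3)
A = softmax(scores, axis=1)      # step 3: row-wise softmax over keys
O = A @ V                        # step 4: weighted sum of values, shape (2, 2)

print("A:\n", np.round(A, 3))    # rows approx [0.401 0.401 0.198] and [0.198 0.401 0.401]
print("A row sums:", A.sum(axis=1))
print("O:\n", np.round(O, 3))    # rows approx [0.599 1.0] and [0.599 1.203]
```

The second component of the first output row is exactly 1 because the two largest weights are equal and $2a + b = 1$ when the row sums to 1.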

What This Example Demonstrates
  • Two-stage matrix product pattern: Attention decomposes into (1) similarity scoring via $\text{scores} = QK^\top / \sqrt{d_k} \in \mathbb{R}^{n_q \times n_k}$, computing all pairwise query-key dot products in one matrix multiply ($O(n_q n_k d_k)$ operations), and (2) value aggregation via $O = AV \in \mathbb{R}^{n_q \times d_v}$, forming weighted sums of values ($O(n_q n_k d_v)$ operations). Softmax normalization ($O(n_q n_k)$) connects these stages, converting raw scores to probability distributions.

  • Scaled dot products prevent saturation: Without the $1/\sqrt{d_k}$ factor, the typical score magnitude grows like $\sqrt{d_k}$ (for vectors with independent zero-mean, unit-variance components, $\mathbb{E}[q^\top k] = 0$, $\text{Var}(q^\top k) = d_k$). For $d_k = 512$, unscaled scores would be $\sqrt{512} \approx 23\times$ larger than scaled ones, pushing softmax into saturation (one weight $\approx 1$, others $\approx 0$, gradients $\approx 0$). Scaling keeps scores at roughly unit variance (mostly within $\sim [-2, 2]$), preserving gradient flow. This example uses $d_k = 2$, so scaling by $1/\sqrt{2} \approx 0.707$ reduces scores from $\{0, 1\}$ to $\{0, 0.707\}$.

  • Row-stochastic attention weights: A.sum(axis=1) returns [1., 1.], confirming each query’s attention weights form a probability distribution over keys ($\sum_j A_{ij} = 1$, $A_{ij} \ge 0$). This normalization ensures outputs are convex combinations of values: $O_i = \sum_j A_{ij} V_j$, where weights depend on query-key similarity. The code verifies this property explicitly, showing that attention is not just weighted averaging but normalized weighted averaging.

  • Shape discipline via transposes: The transpose $K^\top$ is essential for dimensional compatibility: $Q \in \mathbb{R}^{n_q \times d_k}$ times $K^\top \in \mathbb{R}^{d_k \times n_k}$ yields $\text{scores} \in \mathbb{R}^{n_q \times n_k}$ (all pairwise similarities). Without transpose, $Q @ K$ would fail (inner dimensions mismatch: $d_k \ne n_k$). This “shape reasoning” generalizes to all attention variants—once the query-key-value paradigm is understood, transposes “feel inevitable.”

  • Query-key symmetry and value asymmetry: $Q$ and $K$ must have the same dimension $d_k$ (to compute dot products), but $V$ can have a different dimension $d_v$ (values encode different information than keys). In self-attention ($Q, K, V$ all derived from the same input $X$), typically $d_k = d_v = d_{\text{model}} / h$ (model dimension divided by number of heads). In cross-attention (e.g., encoder-decoder), $K, V$ come from encoder outputs, $Q$ from decoder, allowing asymmetric dimensions.

  • Attention as soft dictionary lookup: The query-key-value structure mirrors database retrieval: query = “what information do I need?”, keys = “searchable index”, values = “actual content to retrieve”. Unlike hard indexing (retrieve record $k$), attention performs soft retrieval—a weighted combination of all values, where weights are determined by query-key similarity. This differentiability enables end-to-end training of the attention mechanism itself (learning optimal $Q, K, V$ projections).

Notes

Part 1: Query-Key Similarity via Scaled Dot Products

The first matrix product $QK^\top$ computes all pairwise similarities between queries and keys in one operation. For $Q \in \mathbb{R}^{2 \times 2}$ and $K \in \mathbb{R}^{3 \times 2}$, the product $QK^\top \in \mathbb{R}^{2 \times 3}$ has 6 entries (2 queries $\times$ 3 keys), each representing the dot product $q_i^\top k_j$. The transpose $K^\top$ is essential: $Q$ has shape $(2, 2)$, $K$ has shape $(3, 2)$, so $Q @ K$ would fail (inner dimensions $2 \ne 3$ mismatch), but $Q @ K.T$ succeeds ($(2, 2) \times (2, 3) \to (2, 3)$). Scaling by $1/\sqrt{d_k}$: For vectors $q, k \in \mathbb{R}^{d_k}$ whose components are independent with zero mean and unit variance, the dot product $q^\top k$ has variance $d_k$ (a sum of $d_k$ independent products, each with variance 1). For $d_k = 512$, unscaled scores would have standard deviation $\sqrt{512} \approx 23$, pushing softmax into saturation (e.g., $\exp(20) / (\exp(20) + \exp(0)) \approx 1 - 2 \times 10^{-9}$, one key dominates, gradients vanish). Scaling by $1/\sqrt{d_k}$ normalizes the variance to $1$, keeping most scores in a moderate range ($\sim [-2, 2]$) where softmax remains sensitive to all inputs. Why not normalize by $d_k$? Dividing by $d_k$ would shrink the standard deviation to $1/\sqrt{d_k}$ (tiny for large $d_k$), so every softmax row would be nearly uniform and queries could barely discriminate between keys. The $\sqrt{d_k}$ factor balances the two: large enough to prevent saturation, small enough to keep the weights informative.
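The variance argument can be checked empirically. This is a small simulation sketch under a stated assumption: query and key components are drawn i.i.d. from a standard normal distribution (zero mean, unit variance). The sample count and the dimensions tried are arbitrary illustrative choices.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples = 5000  # illustrative sample count

results = {}
for d_k in (4, 64, 512):
    # Assumption: components are i.i.d. N(0, 1).
    q = rng.standard_normal((n_samples, d_k))
    k = rng.standard_normal((n_samples, d_k))
    dots = np.sum(q * k, axis=1)   # one dot product per sampled (q, k) pair
    scaled = dots / np.sqrt(d_k)   # the scaling used in attention
    results[d_k] = (dots.var(), scaled.var())
    print(f"d_k={d_k:4d}  Var(q.k)={dots.var():8.2f}  "
          f"Var(q.k/sqrt(d_k))={scaled.var():.3f}")
```

The unscaled variance should track $d_k$, while the scaled variance should stay near 1 for every dimension.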

Part 2: Softmax Normalization to Attention Weights

Row-wise softmax converts raw scores to probability distributions: $A_{ij} = \exp(\text{scores}_{ij}) / \sum_k \exp(\text{scores}_{ik})$. Properties: (1) Non-negative: $\exp(\cdot) > 0$ ensures $A_{ij} \ge 0$. (2) Row-stochastic: $\sum_j A_{ij} = 1$ (each query distributes 100% attention across keys). (3) Differentiable: Softmax gradients $\partial A_{ij} / \partial \text{scores}_{ik} = A_{ij} (\delta_{jk} - A_{ik})$ enable backpropagation. Why softmax vs. other normalizations? (a) Max-normalization ($A_{ij} = \text{scores}_{ij} / \max_k \text{scores}_{ik}$) violates row-stochasticity (sums $\ne 1$). (b) L1-normalization ($A_{ij} = \text{scores}_{ij} / \sum_k \text{scores}_{ik}$) can produce negative weights if scores are negative. (c) Softmax (exponential + normalization) guarantees non-negativity and stochasticity, and has a maximum entropy interpretation (Boltzmann distribution). Temperature scaling: Softmax with temperature $\tau$: $A_{ij} = \exp(\text{scores}_{ij} / \tau) / \sum_k \exp(\text{scores}_{ik} / \tau)$. As $\tau \to 0$, attention becomes “hard” (argmax, one-hot). As $\tau \to \infty$, attention becomes uniform (all keys equally weighted). Standard attention uses $\tau = 1$.
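The temperature behaviour described above can be illustrated with a short sketch. The function name `softmax_with_temperature` and the sample scores are illustrative choices, not part of the original example.

```python
import numpy as np

def softmax_with_temperature(scores, tau=1.0):
    """Row-wise softmax of scores / tau (hypothetical helper for illustration)."""
    z = scores / tau
    z = z - z.max(axis=-1, keepdims=True)  # stabilize before exponentiating
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

scores = np.array([[2.0, 1.0, 0.0]])

sharp = softmax_with_temperature(scores, tau=0.05)    # small tau: close to one-hot
standard = softmax_with_temperature(scores, tau=1.0)  # standard attention
flat = softmax_with_temperature(scores, tau=100.0)    # large tau: close to uniform

for name, a in [("tau=0.05", sharp), ("tau=1", standard), ("tau=100", flat)]:
    print(name, np.round(a, 3))
```

When two scores tie for the maximum, the small-temperature limit splits the weight evenly between them rather than producing a single one-hot entry.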

Part 3: Value Aggregation via Weighted Sum

The second matrix product $AV$ computes attention-weighted aggregations: $O = AV \in \mathbb{R}^{2 \times 2}$. Each output row is $O_i = \sum_{j=1}^3 A_{ij} V_j$, a convex combination of value vectors (since $\sum_j A_{ij} = 1$, $A_{ij} \ge 0$). Geometric interpretation: Each query $i$ produces an output $O_i$ lying in the convex hull of $\{V_1, V_2, V_3\}$ (the set of all weighted averages). Why separate keys and values? In principle, we could set $V = K$ (attention weights determined by and applied to the same vectors). However, decoupling allows keys to encode what information is available (searchable metadata) while values encode what to retrieve (actual content). In practice, $Q, K, V$ are learned projections of the input: $Q = XW_Q$, $K = XW_K$, $V = XW_V$, where $W_Q, W_K, W_V$ are trained end-to-end. Self-attention vs. cross-attention: In self-attention (e.g., BERT, GPT), $Q, K, V$ all come from the same sequence $X$. In cross-attention (e.g., encoder-decoder translation), $Q$ comes from decoder states, $K, V$ from encoder outputs—allowing the decoder to “attend” to source sentence positions when generating target words.

Why This Matters for ML

Transformer architecture dominance: Transformers replaced RNNs/LSTMs as the default sequence model because attention enables full parallelization: all positions are processed simultaneously ($O(n^2 d)$ work per layer with $O(1)$ sequential steps), whereas RNNs require sequential processing ($O(d^2)$ per step, $n$ sequential steps, $O(n d^2)$ total). The resulting wall-clock speedup on GPUs helped enable scaling to billions of parameters (GPT-3: 175B, PaLM: 540B). Quadratic bottleneck: For sequence length $n$ and dimension $d$, attention costs $O(n^2 d)$. For $n = 10^4$ (long documents), this becomes prohibitive ($10^8$ pairwise scores per head). Efficient attention variants address this: sparse attention (Longformer: $O(nw)$ via local windows of width $w$ plus a few global tokens), low-rank attention (Linformer: $O(nk)$ via key/value projection to length $k \ll n$), kernel attention (Performer: linear in $n$ via random feature approximation of the softmax kernel). Multi-head attention: Standard Transformers use $h = 8$ to $16$ parallel heads, each with $d_k = d_{\text{model}} / h$. Heads can learn complementary attention patterns (e.g., head 1: syntactic dependencies, head 2: semantic similarity, head 3: positional locality). Positional encodings: Raw self-attention is permutation-equivariant (shuffling the input tokens shuffles the outputs the same way, so no information about order is used). Transformers add sinusoidal or learned positional embeddings to inject sequence order information: $X_{\text{pos}} = X + \text{PE}(pos)$, where $\text{PE}(pos)$ encodes absolute or relative positions.

ML Examples and Patterns

Self-attention layer (single head): Q = X @ W_Q, K = X @ W_K, V = X @ W_V project input $X \in \mathbb{R}^{n \times d_{\text{model}}}$ to queries/keys/values via learned matrices $W_Q, W_K, W_V \in \mathbb{R}^{d_{\text{model}} \times d_k}$. Then compute scores = Q @ K.T / sqrt(d_k), A = softmax(scores, axis=-1), O = A @ V. Output $O \in \mathbb{R}^{n \times d_k}$ is projected back to model dimension via $W_O \in \mathbb{R}^{d_k \times d_{\text{model}}}$.
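A single head along these lines can be sketched in NumPy. The weights below are random placeholders standing in for learned parameters, and the sizes `n`, `d_model`, `d_k` are arbitrary illustrative values.

```python
import numpy as np

def softmax(x, axis=-1):
    shifted = x - x.max(axis=axis, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention_head(X, W_Q, W_K, W_V):
    """One self-attention head: project X, score, normalize, aggregate."""
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V        # each (n, d_k)
    d_k = Q.shape[-1]
    A = softmax(Q @ K.T / np.sqrt(d_k), axis=-1)  # (n, n), rows sum to 1
    return A @ V, A                               # (n, d_k), (n, n)

rng = np.random.default_rng(0)
n, d_model, d_k = 5, 8, 4                      # illustrative sizes
X = rng.standard_normal((n, d_model))
# Random stand-ins for learned projection matrices.
W_Q = rng.standard_normal((d_model, d_k)) / np.sqrt(d_model)
W_K = rng.standard_normal((d_model, d_k)) / np.sqrt(d_model)
W_V = rng.standard_normal((d_model, d_k)) / np.sqrt(d_model)
W_O = rng.standard_normal((d_k, d_model)) / np.sqrt(d_k)

head_out, A = self_attention_head(X, W_Q, W_K, W_V)
Y = head_out @ W_O                             # back to model dimension
print(head_out.shape, A.shape, Y.shape)
```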

Multi-head attention: Run $h$ attention heads in parallel with $d_k = d_{\text{model}} / h$ per head, then concatenate: O_multi = Concat(head_1, ..., head_h) @ W_O, where head_i = Attention(X @ W_Q^i, X @ W_K^i, X @ W_V^i). PyTorch implementation: nn.MultiheadAttention(embed_dim=512, num_heads=8).
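The multi-head pattern can also be written directly in NumPy. This sketch assumes the per-head projections are stacked column-wise into single $d_{\text{model}} \times d_{\text{model}}$ matrices (one common layout, chosen here for brevity); the weights are random placeholders for learned parameters.

```python
import numpy as np

def softmax(x, axis=-1):
    shifted = x - x.max(axis=axis, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(X, W_Q, W_K, W_V, W_O, num_heads):
    n, d_model = X.shape
    assert d_model % num_heads == 0
    d_k = d_model // num_heads

    def split_heads(M):
        # (n, d_model) -> (num_heads, n, d_k): head i gets columns [i*d_k, (i+1)*d_k)
        return M.reshape(n, num_heads, d_k).transpose(1, 0, 2)

    Q, K, V = split_heads(X @ W_Q), split_heads(X @ W_K), split_heads(X @ W_V)
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_k)  # (num_heads, n, n)
    A = softmax(scores, axis=-1)                      # softmax over keys, per head
    heads = A @ V                                     # (num_heads, n, d_k)
    concat = heads.transpose(1, 0, 2).reshape(n, d_model)  # concatenate heads
    return concat @ W_O, A

rng = np.random.default_rng(0)
n, d_model, num_heads = 6, 16, 4                      # illustrative sizes
X = rng.standard_normal((n, d_model))
W_Q, W_K, W_V, W_O = (rng.standard_normal((d_model, d_model)) / np.sqrt(d_model)
                      for _ in range(4))

Y, A = multi_head_attention(X, W_Q, W_K, W_V, W_O, num_heads)
print(Y.shape, A.shape)
```

Batching the heads into a leading axis lets one matrix multiply handle all heads at once, which is how the "parallel heads" are typically realized in practice.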

Masked attention (autoregressive): GPT-style models mask future positions to prevent lookahead: mask = torch.triu(torch.ones(n, n), diagonal=1).bool(), then scores.masked_fill_(mask, -1e9) before softmax. This zeroes attention to future tokens, enforcing left-to-right generation.
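The same causal mask can be sketched in NumPy, the language of the main example, using an out-of-place `np.where` in place of the PyTorch calls quoted above. The random `Q` and `K` are placeholders.

```python
import numpy as np

def softmax(x, axis=-1):
    shifted = x - x.max(axis=axis, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
n, d_k = 4, 8                                   # illustrative sizes
Q = rng.standard_normal((n, d_k))
K = rng.standard_normal((n, d_k))

scores = Q @ K.T / np.sqrt(d_k)
# True strictly above the diagonal: key position j > query position i.
future = np.triu(np.ones((n, n), dtype=bool), k=1)
masked_scores = np.where(future, -np.inf, scores)  # out-of-place masking
A = softmax(masked_scores, axis=-1)

print(np.round(A, 3))  # lower-triangular weights; each row still sums to 1
```

Every row keeps at least its diagonal entry unmasked, so the row maximum is finite and the softmax stays well defined.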

Cross-attention (encoder-decoder): In translation, decoder queries attend to encoder keys/values: Q_dec = X_dec @ W_Q, K_enc = X_enc @ W_K, V_enc = X_enc @ W_V, then O = Attention(Q_dec, K_enc, V_enc). Decoder state at position $i$ can attend to all encoder positions (no causal mask).

Positional encoding: PE[pos, 2i] = sin(pos / 10000^(2i/d)), PE[pos, 2i+1] = cos(pos / 10000^(2i/d)) for $i=0, \ldots, d/2-1$. Add to embeddings: X_input = Embedding(tokens) + PE. Alternative: learned positional embeddings nn.Embedding(max_len, d_model).
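The sinusoidal formula translates to a few lines of NumPy. The function name and the sizes in the usage lines are illustrative; an even `d_model` is assumed.

```python
import numpy as np

def sinusoidal_positional_encoding(max_len, d_model):
    """PE[pos, 2i] = sin(pos / 10000^(2i/d)), PE[pos, 2i+1] = cos(same angle)."""
    assert d_model % 2 == 0, "this sketch assumes an even model dimension"
    pos = np.arange(max_len)[:, None]            # (max_len, 1)
    i = np.arange(d_model // 2)[None, :]         # (1, d_model/2)
    angles = pos / np.power(10000.0, 2 * i / d_model)
    pe = np.zeros((max_len, d_model))
    pe[:, 0::2] = np.sin(angles)                 # even feature indices
    pe[:, 1::2] = np.cos(angles)                 # odd feature indices
    return pe

PE = sinusoidal_positional_encoding(max_len=50, d_model=16)
embeddings = np.zeros((50, 16))                  # placeholder for token embeddings
X_input = embeddings + PE                        # add position information
print(PE.shape)
```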

Attention dropout: A = F.dropout(softmax(scores), p=0.1, training=True) randomly zeros attention weights during training to prevent overfitting to specific dependencies.

Connection to Linear Algebra Theory

Attention as convex combination: Each output $O_i = \sum_j A_{ij} V_j$ is a point in the convex hull of $\{V_1, \ldots, V_{n_k}\}$ (all weighted averages with non-negative weights summing to 1). This resembles a projection into the span of the values, but it is not a projection in the strict linear-algebra sense: the weights are data-dependent (computed from query-key similarity) rather than fixed (like the orthogonal projection in least squares).

Dot product as similarity metric: For normalized vectors $\|q\|_2 = \|k\|_2 = 1$, the dot product $q^\top k = \cos(\theta)$ is the cosine of the angle between $q$ and $k$. High dot product ($\cos(\theta) \approx 1$) means $\theta \approx 0$ (aligned vectors, high similarity). Attention amplifies this via softmax: similar pairs get disproportionately large weights.

Softmax as exponential family: Softmax $A_{ij} = \exp(s_{ij}) / Z_i$ (where $Z_i = \sum_k \exp(s_{ik})$) is the maximum entropy distribution subject to expected score constraint $\mathbb{E}[s] = \sum_j A_{ij} s_{ij}$. This is a Boltzmann distribution in statistical mechanics, with “temperature” $\tau = 1$ (or $\tau = \sqrt{d_k}$ if unscaled).

Row-stochastic matrices as Markov transitions: In self-attention, where $n_q = n_k$ and $A$ is square, the attention matrix (rows sum to 1) can be read as the transition matrix of a Markov chain over token positions. Applying $A$ repeatedly ($A^2, A^3, \ldots$) corresponds to multi-hop random walks on the attention graph. Graph-based architectures such as Graph Attention Networks apply the same normalized weighting over the neighbors of each node.

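Rank-1 outer products: Writing $a_j$ for column $j$ of $A$ and $V_j$ for row $j$ of $V$ (as a column vector), the product decomposes as $AV = \sum_{j=1}^{n_k} a_j V_j^\top$, a sum of $n_k$ rank-1 matrices, each the outer product of the weights that all queries place on key $j$ with the value vector $V_j$. Row $i$ of this sum is $O_i = \sum_j A_{ij} V_j$. The structure is analogous to the SVD reconstruction $X \approx \sum_{k=1}^r \sigma_k u_k v_k^\top$, although the attention terms are not orthogonal.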
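One way to make the outer-product view concrete is to build $AV$ from column $j$ of $A$ and row $j$ of $V$, one key at a time. A brief sketch with random placeholder data:

```python
import numpy as np

rng = np.random.default_rng(0)
n_q, n_k, d_v = 2, 3, 2                        # sizes matching the toy example
A = rng.random((n_q, n_k))
A /= A.sum(axis=1, keepdims=True)              # make rows sum to 1
V = rng.standard_normal((n_k, d_v))

# Each term is the outer product of a column of A with a row of V.
terms = [np.outer(A[:, j], V[j]) for j in range(n_k)]
outer_sum = sum(terms)
term_ranks = [np.linalg.matrix_rank(t) for t in terms]

print(term_ranks)                               # every term has rank 1
print(np.allclose(outer_sum, A @ V))            # the terms add up to A @ V
```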

Kernel methods interpretation: Attention can be written as $O_i = \sum_j \kappa(q_i, k_j) V_j / \sum_j \kappa(q_i, k_j)$, where $\kappa(q, k) = \exp(q^\top k / \sqrt{d_k})$ is the exponential (softmax) kernel; it equals a Gaussian kernel up to factors that depend only on $\|q\|$ and $\|k\|$. This enables kernel approximation techniques (random features, Nyström) for efficient attention (Performer, Linear Transformers).
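The kernel form can be checked numerically against the softmax form, along with the relation between the exponential kernel and a Gaussian kernel, $\exp(q^\top k) = \exp(-\|q - k\|^2 / 2)\, \exp(\|q\|^2 / 2)\, \exp(\|k\|^2 / 2)$. A sketch with random placeholder inputs:

```python
import numpy as np

rng = np.random.default_rng(0)
n_q, n_k, d_k, d_v = 3, 5, 4, 2                # illustrative sizes
Q = rng.standard_normal((n_q, d_k))
K = rng.standard_normal((n_k, d_k))
V = rng.standard_normal((n_k, d_v))

# Kernel form: kappa(q, k) = exp(q.k / sqrt(d_k)), normalized per query.
kappa = np.exp(Q @ K.T / np.sqrt(d_k))
O_kernel = (kappa @ V) / kappa.sum(axis=1, keepdims=True)

# Softmax form for comparison.
scores = Q @ K.T / np.sqrt(d_k)
A = np.exp(scores - scores.max(axis=1, keepdims=True))
A /= A.sum(axis=1, keepdims=True)
O_softmax = A @ V

# Gaussian relation, after absorbing the scaling into q and k.
Qs, Ks = Q / d_k ** 0.25, K / d_k ** 0.25
sq_dist = ((Qs[:, None, :] - Ks[None, :, :]) ** 2).sum(axis=-1)
gauss = np.exp(-0.5 * sq_dist)
kappa_from_gauss = (gauss
                    * np.exp(0.5 * (Qs ** 2).sum(axis=1))[:, None]
                    * np.exp(0.5 * (Ks ** 2).sum(axis=1))[None, :])

print(np.allclose(O_kernel, O_softmax), np.allclose(kappa_from_gauss, kappa))
```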

Numerical and Implementation Notes

Shape discipline: $Q \in \mathbb{R}^{n_q \times d_k}$ (2 queries, 2-dim), $K \in \mathbb{R}^{n_k \times d_k}$ (3 keys, 2-dim), $V \in \mathbb{R}^{n_k \times d_v}$ (3 values, 2-dim). Scores: $(n_q, n_k) = (2, 3)$. Attention weights: $(n_q, n_k) = (2, 3)$. Output: $(n_q, d_v) = (2, 2)$. Check dimensions before every matrix multiply: Q @ K.T requires Q.shape[1] == K.shape[1] ($d_k$ match), A @ V requires A.shape[1] == V.shape[0] ($n_k$ match).
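These shape rules can be turned into an explicit check. The helper name `check_attention_shapes` is hypothetical and only illustrates the idea.

```python
import numpy as np

def check_attention_shapes(Q, K, V):
    """Validate attention input shapes and return the shapes they imply."""
    n_q, d_k = Q.shape
    n_k, d_k_keys = K.shape
    n_v, d_v = V.shape
    if d_k != d_k_keys:
        raise ValueError(f"Q and K feature dims differ: {d_k} vs {d_k_keys}")
    if n_k != n_v:
        raise ValueError(f"K and V row counts differ: {n_k} vs {n_v}")
    return {"scores": (n_q, n_k), "weights": (n_q, n_k), "output": (n_q, d_v)}

Q = np.zeros((2, 2))
K = np.zeros((3, 2))
V = np.zeros((3, 2))
shapes = check_attention_shapes(Q, K, V)
print(shapes)

# Forgetting the transpose fails here: (2, 2) @ (3, 2) has mismatched inner dims.
try:
    Q @ K
    missing_transpose_fails = False
except ValueError:
    missing_transpose_fails = True
print(missing_transpose_fails)
```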

Gotcha 1: Softmax axis. softmax(scores, axis=1) normalizes across keys (axis=1, columns) for each query (axis=0, rows). Using axis=0 would normalize across queries, which is meaningless (each key would distribute attention to queries, backwards). Always normalize across the key dimension.

Gotcha 2: Scaling factor. Use $1/\sqrt{d_k}$, not $1/\sqrt{d_v}$ or $1/\sqrt{n_k}$. The goal is to counteract dot product variance growth with key dimension ($\text{Var}(q^\top k) = d_k$ when components have zero mean and unit variance).

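Gotcha 3: Masked attention. For autoregressive models (GPT), mask future positions: mask = np.triu(np.ones((n_q, n_k)), k=1) (strictly upper triangular), then scores = scores - 1e9 * mask (future scores become very negative, so their softmax weights are effectively 0). In autograd frameworks, prefer an out-of-place operation such as masked_fill or where over assigning scores[mask] = -inf, since in-place modification of a tensor needed for the backward pass can raise an error.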

Gotcha 4: Batch dimensions. Production attention has shape (batch, num_heads, seq_len, seq_len). Ensure softmax normalizes over the last dimension (keys): F.softmax(scores, dim=-1). Broadcasting rules: (B, H, N_q, N_k) @ (B, H, N_k, D_v) -> (B, H, N_q, D_v).
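The batched shapes work the same way in NumPy, where `@` broadcasts over leading axes. A sketch with arbitrary illustrative sizes:

```python
import numpy as np

def softmax(x, axis=-1):
    shifted = x - x.max(axis=axis, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
B, H, N_q, N_k, D_k, D_v = 2, 3, 4, 5, 8, 6     # illustrative sizes
Q = rng.standard_normal((B, H, N_q, D_k))
K = rng.standard_normal((B, H, N_k, D_k))
V = rng.standard_normal((B, H, N_k, D_v))

# Transpose only the last two axes of K; batch and head axes are left alone.
scores = Q @ np.swapaxes(K, -1, -2) / np.sqrt(D_k)  # (B, H, N_q, N_k)
A = softmax(scores, axis=-1)                         # normalize over keys
O = A @ V                                            # (B, H, N_q, D_v)
print(scores.shape, O.shape)
```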

Gotcha 5: Numerical stability. Softmax can overflow ($\exp(1000) \to \infty$) or underflow ($\exp(-1000) \to 0$). Use log-sum-exp trick: $\log \sum_j \exp(x_j) = \max_j x_j + \log \sum_j \exp(x_j - \max_j x_j)$. PyTorch F.softmax handles this automatically.
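The overflow problem and the max-subtraction fix can be seen side by side. The function names are illustrative, and the extreme scores are chosen only to trigger the overflow.

```python
import numpy as np

def naive_softmax(x):
    e = np.exp(x)                                # overflows for large inputs
    return e / e.sum(axis=-1, keepdims=True)

def stable_softmax(x):
    shifted = x - x.max(axis=-1, keepdims=True)  # largest entry becomes 0
    e = np.exp(shifted)                          # all exponentials are <= 1
    return e / e.sum(axis=-1, keepdims=True)

x = np.array([[1000.0, 999.0, 998.0]])

with np.errstate(over="ignore", invalid="ignore"):
    naive = naive_softmax(x)    # exp(1000) = inf, and inf / inf = nan
stable = stable_softmax(x)      # same result as softmax([0, -1, -2])

print(naive)
print(np.round(stable, 4))
```

Subtracting the row maximum leaves the result unchanged mathematically because the common factor $\exp(-\max_j x_j)$ cancels between numerator and denominator.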

Gotcha 6: FlashAttention. For long sequences (roughly $n > 2048$), materializing the full $QK^\top \in \mathbb{R}^{n \times n}$ for every head can exhaust GPU memory. FlashAttention (Dao et al., 2022) fuses the attention computation into a single CUDA kernel and tiles it so the full attention matrix is never stored, reducing attention memory from $O(n^2)$ to $O(n)$ and giving wall-clock speedups the authors report at roughly 2-3×.
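The tiling idea can be illustrated in NumPy with a blockwise "online softmax" that never forms the full $n_q \times n_k$ weight matrix. This is only a sketch of the underlying arithmetic, not the FlashAttention kernel itself: it gives no memory or speed benefit in NumPy, and the function names and block size are illustrative.

```python
import numpy as np

def full_attention(Q, K, V):
    scores = Q @ K.T / np.sqrt(Q.shape[-1])
    w = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return (w / w.sum(axis=-1, keepdims=True)) @ V

def blockwise_attention(Q, K, V, block_size=4):
    """Process keys/values in blocks, keeping running max, sum, and output."""
    n_q, d_k = Q.shape
    running_max = np.full((n_q, 1), -np.inf)  # max score seen so far, per query
    running_sum = np.zeros((n_q, 1))          # softmax denominator so far
    acc = np.zeros((n_q, V.shape[1]))         # unnormalized output so far
    for start in range(0, K.shape[0], block_size):
        K_blk = K[start:start + block_size]
        V_blk = V[start:start + block_size]
        s = Q @ K_blk.T / np.sqrt(d_k)        # scores for this block only
        new_max = np.maximum(running_max, s.max(axis=-1, keepdims=True))
        correction = np.exp(running_max - new_max)  # rescale earlier partials
        p = np.exp(s - new_max)
        running_sum = running_sum * correction + p.sum(axis=-1, keepdims=True)
        acc = acc * correction + p @ V_blk
        running_max = new_max
    return acc / running_sum

rng = np.random.default_rng(0)
Q = rng.standard_normal((3, 8))
K = rng.standard_normal((10, 8))              # 10 keys: last block is partial
V = rng.standard_normal((10, 5))

O_full = full_attention(Q, K, V)
O_block = blockwise_attention(Q, K, V, block_size=4)
print(np.allclose(O_full, O_block))
```

The result is exact rather than approximate because each block's contribution is rescaled whenever a larger maximum score appears.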

Numerical and Shape Notes

Verification checks: (1) Row sums: A.sum(axis=1) should equal [1., 1.] (each query’s weights sum to 1). (2) Non-negativity: A >= 0 for all elements (softmax guarantees this). (3) Output shape: O.shape == (n_q, d_v) (2 queries, 2-dim values). (4) Gradient check: analytical gradients from backpropagation should match finite-difference estimates (torch.autograd.gradcheck performs this comparison for PyTorch functions).
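A gradient check in the spirit of item (4) can be sketched in NumPy without any autograd library: compare the analytical softmax Jacobian $\partial A_j / \partial s_k = A_j(\delta_{jk} - A_k)$ for one row with central finite differences. The helper names and step size are illustrative.

```python
import numpy as np

def softmax(s):
    e = np.exp(s - s.max())
    return e / e.sum()

def softmax_jacobian(s):
    """Analytical Jacobian: J[j, k] = a_j * (delta_jk - a_k)."""
    a = softmax(s)
    return np.diag(a) - np.outer(a, a)

s = np.array([0.707, 0.707, 0.0])   # approximately row 0 of the example scores
eps = 1e-6                          # finite-difference step (illustrative)

J_numeric = np.zeros((s.size, s.size))
for k in range(s.size):
    step = np.zeros_like(s)
    step[k] = eps
    # Central difference with respect to score k.
    J_numeric[:, k] = (softmax(s + step) - softmax(s - step)) / (2 * eps)

J_analytic = softmax_jacobian(s)
print(np.max(np.abs(J_numeric - J_analytic)))
```

Each row of the Jacobian sums to zero, reflecting that adding a constant to all scores in a row leaves the softmax output unchanged.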

Cost analysis: For $n_q = n_k = n = 512$ (a common maximum sequence length), $d_k = d_v = 64$ (per-head dimension):
  • $QK^\top$: $n^2 d_k = 512^2 \times 64 \approx 17M$ multiply-adds
  • Softmax: $n^2 = 512^2 \approx 262K$ exponentials (negligible)
  • $AV$: $n^2 d_v = 512^2 \times 64 \approx 17M$ multiply-adds
  • Total: $\sim 34M$ multiply-adds per head; for 8 heads, $\sim 270M$ per layer.
For GPT-3 (96 layers, 96 heads, $d_k = 128$, 2048-token context), the same count gives $2 n^2 d_{\text{model}} \approx 10^{11}$ multiply-adds per layer and $\sim 10^{13}$ across all layers for the two attention matrix products alone.

Memory: Attention matrix $A$ requires $n^2$ storage. For $n = 16384$ (Longformer), $A$ is $16384^2 \approx 268M$ elements ($\sim 1$ GB for float32). This is the memory bottleneck of Transformers—efficient variants (sparse, low-rank) reduce this to $O(n)$ or $O(n \log n)$.
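The cost and memory figures above follow from simple counting, which a few lines of Python can reproduce. The helper name `attention_cost` is hypothetical, and the counts are multiply-adds for the two matrix products only.

```python
def attention_cost(n, d_k, d_v, num_heads=1, bytes_per_element=4):
    """Count multiply-adds and attention-matrix storage for self-attention."""
    score_macs = n * n * d_k          # QK^T
    aggregate_macs = n * n * d_v      # AV
    per_head = score_macs + aggregate_macs
    return {
        "score_macs": score_macs,
        "aggregate_macs": aggregate_macs,
        "per_head_macs": per_head,
        "all_heads_macs": per_head * num_heads,
        "attention_matrix_bytes_per_head": n * n * bytes_per_element,
    }

short = attention_cost(n=512, d_k=64, d_v=64, num_heads=8)
long_ctx = attention_cost(n=16384, d_k=64, d_v=64)

print(short["per_head_macs"])      # 33554432, about 34M
print(short["all_heads_macs"])     # 268435456, about 270M
print(long_ctx["attention_matrix_bytes_per_head"] / 2**30)  # 1.0 GiB in float32
```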

Tolerance: Softmax is numerically stable (log-sum-exp trick), but attention weights can be tiny ($< 10^{-6}$) for dissimilar query-key pairs. Check for underflow: A.min() should be $> 0$ (if $\approx 0$, may indicate numerical issues or extreme sparsity).

Pedagogical Significance

This example is the foundational demonstration of attention mechanisms in modern deep learning:

Key takeaways: (1) Attention = weighted aggregation: Dynamically select and combine information based on query-key similarity. (2) Query-key-value paradigm: Separates “what to search for” (query) from “what’s available” (key) and “what to retrieve” (value). (3) Softmax normalization: Converts raw scores to probabilities (row-stochastic, interpretable weights). (4) Scaling prevents saturation: $1/\sqrt{d_k}$ keeps scores in moderate range for stable gradients. (5) Permutation equivariance: Raw self-attention has no notion of order (permuting the inputs permutes the outputs the same way); positional encodings inject sequence structure.

Common misconceptions addressed: (a) “Attention requires recurrence”: No, Transformers are fully feedforward (parallel), unlike RNNs. (b) “Attention is approximate”: No, it’s an exact weighted sum, fully differentiable. (c) “Softmax axis doesn’t matter”: It is critical; normalization must run across keys (each query distributes attention). (d) “Attention always attends uniformly”: No, exponential weighting (softmax) creates sharp distributions for high-similarity pairs. (e) “Transformers don’t need positional info”: False; raw self-attention ignores token order (it is permutation-equivariant), so positional encodings are essential.

Connection to other examples: Example 70 (attention as data-dependent projection), Example 72 (query-key similarity via inner products), Example 79 (sparse attention masks), Example 74 (multi-head as low-rank approximation). This example provides the implementation foundation for understanding BERT, GPT, Vision Transformers, and diffusion models—all built on variants of this core mechanism.

Why pedagogically powerful: It isolates the core attention operation (scaled dot-product attention) in minimal code (~10 lines), showing how three matrices ($Q, K, V$) and two operations (matmul + softmax) produce dynamic context aggregation. The small dimensions (2 queries, 3 keys) make manual calculation feasible, demystifying the “black box” of Transformers. Students see that attention is fundamentally linear algebra (matrix products + softmax), not magic. This is the gateway to understanding modern NLP, computer vision, and generative AI—attention is the computational primitive underlying language models (GPT-4, Gemini), vision models (ViT, CLIP), and multimodal systems (DALL-E, Stable Diffusion).

History and Applications

Historical foundations: Attention mechanisms originated in neural machine translation (NMT) to address the information bottleneck of encoder-decoder architectures. Before Bahdanau et al. (2015), RNN-based seq2seq models compressed the entire source sentence into a single fixed-length vector $h$ (the final encoder hidden state), forcing $h$ to encode all information, which degrades badly for long sentences. Bahdanau’s additive attention let the decoder dynamically attend to all encoder states $h_1, \ldots, h_n$, computing weighted sums $c_t = \sum_i \alpha_{ti} h_i$ at each decoding step $t$. Weights $\alpha_{ti}$ were learned via a small neural network (“alignment model”) that scored query-key compatibility: $e_{ti} = \text{MLP}(s_{t-1}, h_i)$, then $\alpha_{ti} = \text{softmax}_i(e_{ti})$ (normalized over encoder positions $i$). This improved BLEU scores by 5-10 points on WMT translation benchmarks, demonstrating that dynamic, context-dependent routing of information was far superior to fixed bottlenecks.

Vaswani et al. (2017) revolutionized the field with “Attention Is All You Need,” introducing Transformers. They replaced recurrent layers entirely with self-attention, where each token attends to all other tokens in parallel. Their key innovations: (1) Scaled dot-product attention ($QK^\top / \sqrt{d_k}$) replaced additive attention—simpler, faster (pure BLAS3 matrix multiplies), and more interpretable. (2) Multi-head attention ran $h = 8$ heads in parallel, each learning different dependency patterns (syntactic, semantic, positional). (3) Positional encodings (sinusoidal) injected sequence order, compensating for attention’s permutation-invariance. (4) No recurrence: All tokens processed simultaneously, enabling 10-100× training speedup on GPUs/TPUs (RNNs required sequential steps, $n$ serial dependencies per layer). On WMT English-German translation, Transformers achieved 28.4 BLEU (+2 over state-of-the-art ConvS2S), trained in 3.5 days on 8 P100 GPUs—an order of magnitude faster than RNN baselines.

Modern ML applications: Post-2017, Transformers became the dominant architecture across domains. BERT (Devlin et al., 2019) introduced bidirectional pre-training via masked language modeling (predict 15% of masked tokens from context) on 3.3B words (Wikipedia + BookCorpus). With 340M parameters (24 layers, 16 heads), BERT achieved state-of-the-art on 11 NLP benchmarks (GLUE, SQuAD), often by 5-10% absolute accuracy. BERT’s attention learned syntactic dependencies (subject-verb agreement spans 20+ tokens), coreference resolution (pronouns attend to antecedents), and semantic similarity—all emergent from self-supervised learning, no hand-engineered features.

GPT lineage (Radford et al., 2018, 2019; Brown et al., 2020; OpenAI, 2023): Autoregressive language modeling (“predict next token”) scaled to massive datasets and parameter counts. GPT-2 (1.5B params, 40GB WebText) demonstrated zero-shot generalization: models could translate, summarize, and answer questions without task-specific fine-tuning. GPT-3 (175B params, 300B tokens) showed strong few-shot performance across many tasks without fine-tuning. ChatGPT (2022) was built on the GPT-3.5 series, and GPT-4 (2023; parameter count not publicly disclosed) extended this to stronger multi-turn conversation, code generation, and reasoning. Key architectural detail: GPT uses masked (causal) attention: at position $i$, attention covers only positions $1, \ldots, i$ (scores for future positions are set to $-\infty$), ensuring autoregressive generation.

Vision Transformers (ViT): Dosovitskiy et al. (2021) applied Transformers to images by splitting images into $16 \times 16$ patches, flattening each patch, and treating the resulting sequence of patches as tokens. When pre-trained on very large datasets, the largest ViT models reached about 88.5% top-1 accuracy on ImageNet, matching or exceeding the best ResNet-based models of the time, with favorable scaling as data grows. Self-attention in ViT learns spatial relationships: some early-layer heads attend locally (edges, textures), while later layers attend mostly globally (object parts, context). Models such as CLIP pair a ViT image encoder with a Transformer text encoder for zero-shot image classification and cross-modal retrieval, while DINO and SAM build on ViT backbones for self-supervised features and promptable segmentation.

Diffusion models: Stable Diffusion (Rombach et al., 2022) and DALL-E 2 (Ramesh et al., 2022) use cross-attention to condition image generation on text. Text encoder (CLIP or T5) produces embeddings $E_{\text{text}} \in \mathbb{R}^{n_t \times d}$; image decoder (U-Net with Transformer blocks) generates latent image $I \in \mathbb{R}^{n_p \times d}$ (patches). At each diffusion step, decoder queries $Q = I W_Q$ attend to text keys $K = E_{\text{text}} W_K$, pulling semantic information from prompt into image: “a cat wearing a hat” → cross-attention ensures cat pixels attend to “cat,” hat pixels to “hat.” This mechanism enables fine-grained text-to-image alignment, generating 512×512 images in seconds on consumer GPUs.

Efficient attention (2019-2023): Quadratic complexity ($O(n^2)$) made contexts much beyond $n \sim 2048$ expensive for standard attention on the GPUs of the time. Longformer (Beltagy et al., 2020) and BigBird (Zaheer et al., 2020) use sparse attention patterns (local + global + random): each token attends to $w$ neighbors + $g$ global tokens + $r$ random tokens, reducing to $O(n(w + g + r)) = O(n)$ for fixed $w, g, r$. Enables 4096-16384 token contexts. Linformer (Wang et al., 2020) projects keys and values along the sequence axis: $\text{softmax}(Q (E K)^\top / \sqrt{d_k}) (F V)$ with $E, F \in \mathbb{R}^{k \times n}$ ($k = 256 \ll n$), so the score matrix is $n \times k$ and the cost drops to $O(nk)$. Performer (Choromanski et al., 2020) uses kernel approximation: approximate $\exp(q^\top k) \approx \phi(q)^\top \phi(k)$ with random features $\phi: \mathbb{R}^d \to \mathbb{R}^m$ ($m = 256$), then factorize: $\sum_j \phi(q_i)^\top \phi(k_j) v_j^\top = \phi(q_i)^\top (\sum_j \phi(k_j) v_j^\top)$. Pre-compute $\sum_j \phi(k_j) v_j^\top$ once ($O(nmd)$), then each query costs $O(md)$, for a total of $O(nmd)$, linear in $n$. FlashAttention (Dao et al., 2022) exploits IO-aware tiling: instead of materializing full $A \in \mathbb{R}^{n \times n}$ in HBM (GPU memory), compute attention in blocks that fit in SRAM (on-chip cache) while maintaining running softmax statistics. This reduces memory from $O(n^2)$ to $O(n)$ with no approximation, with wall-clock speedups the authors report at roughly 2-3× over standard attention implementations. Such kernels make training with contexts of tens of thousands of tokens practical (e.g., Llama 2 Long, 2023, with contexts up to 32k tokens).
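The factorization behind kernel-based linear attention rests only on associativity of matrix multiplication, which a short sketch can verify. One stated substitution: in place of Performer's random features, this uses the simple positive feature map $\phi(x) = \text{elu}(x) + 1$ from Linear Transformers, purely as a stand-in, so the result is a kernelized attention and not an approximation of softmax attention.

```python
import numpy as np

def phi(x):
    """Positive feature map elu(x) + 1 (stand-in for random features)."""
    return np.where(x > 0, x + 1.0, np.exp(x))

rng = np.random.default_rng(0)
n, d, d_v = 6, 4, 3                            # illustrative sizes
Q = rng.standard_normal((n, d))
K = rng.standard_normal((n, d))
V = rng.standard_normal((n, d_v))
Qf, Kf = phi(Q), phi(K)                        # (n, m) with m = d here

# Quadratic order: build the full (n, n) kernel matrix first.
W = Qf @ Kf.T
O_quadratic = (W @ V) / W.sum(axis=1, keepdims=True)

# Linear order: summarize keys and values once, then apply each query.
KV = Kf.T @ V                                  # (m, d_v), sum_j phi(k_j) v_j^T
Z = Kf.sum(axis=0)                             # (m,), sum_j phi(k_j)
O_linear = (Qf @ KV) / (Qf @ Z)[:, None]

print(np.allclose(O_quadratic, O_linear))
```

The linear ordering never forms an $n \times n$ matrix; its intermediate `KV` has size $m \times d_v$ regardless of sequence length.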

Current frontiers (2023-2024): Transformers dominate language (GPT-4, Claude, Gemini), vision (ViT, DINOv2), multimodal (CLIP, Flamingo, GPT-4V), reinforcement learning (Decision Transformer), biology (AlphaFold2), and code generation (Copilot, CodeLlama). Mixture-of-experts models (Switch Transformer, Mixtral 8x7B) scale parameter counts, in Switch Transformer's case past a trillion, via sparse routing: each token activates only a few experts (1 in Switch Transformer, 2 of 8 in Mixtral). Retrieval-augmented generation (RAG) combines attention over retrieved documents (from vector databases) with parametric knowledge, grounding LLMs in external sources. Long-context models (Claude 100k, GPT-4-Turbo 128k) combine efficient attention implementations with positional schemes that extend to longer sequences (RoPE with position interpolation, ALiBi) to process entire books or codebases. Attention remains the universal building block for sequence modeling: from 512-token BERT (2019) to 1M-token research prototypes (2024), scaled dot-product attention ($\text{softmax}(QK^\top / \sqrt{d_k}) V$) is the computational primitive underlying the AI revolution.

Connection to Broader Examples
  • Example 72 (Inner products and cosine similarity): Attention scores $q_i^\top k_j$ are inner products measuring query-key similarity. For normalized $q, k$, this becomes cosine similarity. The scaling factor $1/\sqrt{d_k}$ normalizes for dimension, analogous to dividing by norms in cosine similarity.

  • Example 70 (Projections): Attention output $O_i = \sum_j A_{ij} V_j$ is a data-dependent convex combination of the value vectors, so it lies in the convex hull of $\{V_1, \ldots, V_{n_k}\}$. The weights come from softmax-normalized query-key scores rather than from an orthogonal least-squares projection, so the map is projection-like but not a projection in the strict (idempotent) sense.

  • Example 74 (SVD and low-rank approximations): Multi-head attention can be viewed as a low-rank approximation mechanism—each head learns a rank-$d_k$ projection of the input space. The concatenation of heads reconstructs a higher-rank representation, similar to truncated SVD retaining top singular vectors.

  • Example 76 (Least squares and regularization): Attention dropout (randomly zeroing attention weights) acts as regularization, preventing the model from overfitting to specific key-value dependencies. This parallels ridge regression ($\ell_2$ penalty) or Lasso ($\ell_1$ penalty) in least squares.

  • Example 79 (Sparse matrices): Efficient attention variants (Longformer, BigBird) use sparse attention patterns (local windows + global tokens) to reduce $O(n^2)$ complexity to $O(n)$ for a fixed window size. The attention matrix $A$ becomes sparse (most entries masked to zero) and can be viewed like a sparse adjacency matrix in COO/CSR format, although practical implementations typically rely on banded or block-sparse kernels.

  • Example 71 (Norms and distances): The scaling factor $1/\sqrt{d_k}$ rescales query-key dot products to unit variance (when the components of $q$ and $k$ are independent with zero mean and unit variance, $q^\top k$ has variance $d_k$), which keeps the softmax out of its saturated regime where gradients vanish. This is analogous to feature scaling ($X / \|X\|_2$) in least squares to improve conditioning.

  • Chapter 11 (PCA): Attention weights $A$ can be interpreted as a soft clustering mechanism—queries assign “cluster membership” (attention weights) to keys. This parallels soft k-means or PCA (projecting onto principal components with data-dependent weights).

  • Chapter 16 (Matrix products): This example showcases composition of matrix products as a design pattern: $O = (\text{softmax}(QK^\top / \sqrt{d_k})) V$ chains two matrix multiplies with a nonlinear activation. Modern accelerators (GPUs, TPUs) are optimized for this pattern via fused kernels (FlashAttention).
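
Two of the connections listed above lend themselves to quick numerical checks: the variance argument behind the $1/\sqrt{d_k}$ scaling and the sparsity of a local-window attention pattern. The following is a rough sketch under stated assumptions: queries and keys are drawn as i.i.d. standard normal vectors, and the helpers `local_window_mask` and `masked_attention` are hypothetical names that use a dense boolean mask rather than a true sparse format.

```python
import numpy as np

rng = np.random.default_rng(0)

# (1) Scaling: with i.i.d. zero-mean, unit-variance components, q.k has
# variance close to d_k; dividing by sqrt(d_k) brings it back to about 1.
d_k, trials = 64, 20000
q = rng.standard_normal((trials, d_k))
k = rng.standard_normal((trials, d_k))
dots = np.einsum("ij,ij->i", q, k)          # one dot product per row
raw_var = dots.var()
scaled_var = (dots / np.sqrt(d_k)).var()
print(raw_var, scaled_var)                  # roughly d_k and roughly 1

# (2) Sparsity: a local window of half-width w keeps at most 2w + 1 keys per query.
def local_window_mask(n, w):
    """Boolean (n, n) mask, True where |i - j| <= w."""
    idx = np.arange(n)
    return np.abs(idx[:, None] - idx[None, :]) <= w

def masked_attention(Q, K, V, mask):
    """Scaled dot-product attention with disallowed pairs set to -inf."""
    S = Q @ K.T / np.sqrt(Q.shape[-1])
    S = np.where(mask, S, -np.inf)          # masked entries get zero weight
    P = np.exp(S - S.max(axis=-1, keepdims=True))
    A = P / P.sum(axis=-1, keepdims=True)
    return A @ V, A

n, w = 32, 2
mask = local_window_mask(n, w)
X = rng.standard_normal((n, 8))             # self-attention: Q = K = V = X
out, A = masked_attention(X, X, X, mask)
print(mask.sum(), n * n)                    # 154 kept entries out of 1024
print(np.allclose(A.sum(axis=1), 1.0))      # True: rows stay stochastic
```

The number of retained entries grows like $n(2w + 1)$ rather than $n^2$, which is the source of the linear scaling for a fixed window; a dense mask is used here only for readability.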

Comments