Matrix products became the computational substrate of deep learning because GPUs and TPUs are hardware-optimized for dense matrix multiplication (GEMM kernels). Early neural networks (1980s-90s) used matrix-vector products for feedforward layers ($y = Wx + b$), but attention mechanisms made matrix-matrix products the dominant operation. The attention mechanism originated in sequence-to-sequence models (Bahdanau et al., 2015) to let decoders "attend" to relevant encoder states, replacing fixed-length context vectors with dynamic weighted sums. Transformers (Vaswani et al., 2017, "Attention Is All You Need") removed recurrence entirely, using only attention and feedforward layers, which enables parallelization across sequence positions (unlike RNNs, which process tokens sequentially). The scaled dot-product formulation $\text{softmax}(QK^\top / \sqrt{d_k}) V$ emerged as the efficient implementation: (1) $QK^\top$ computes all $n_q \times n_k$ pairwise similarities in one BLAS3 operation ($O(n^2 d)$ for self-attention), (2) softmax normalizes in $O(n^2)$, and (3) $AV$ aggregates values in another BLAS3 call. This three-step pattern (score, normalize, aggregate) became the universal attention primitive, powering BERT (Devlin et al., 2019, 340M parameters, pre-trained with masked language modeling), the GPT series (Radford et al., 2018, 2019; Brown et al., 2020; autoregressive language models scaling to 175B parameters), Vision Transformers (Dosovitskiy et al., 2021, image patches as tokens), and text-to-image diffusion models (e.g., Rombach et al., 2022, which condition on text via cross-attention). The quadratic $O(n^2)$ cost drove research into efficient attention variants: sparse patterns (Longformer, BigBird), low-rank approximations (Linformer), and kernel methods (Performer), enabling contexts on the order of $10^4$ tokens.
Demonstrate how scaled dot-product attention, the core operation of Transformer architectures, decomposes into two matrix products ($QK^\top$ for similarity scoring, $AV$ for weighted aggregation) connected by row-wise softmax normalization. For queries $Q \in \mathbb{R}^{n_q \times d_k}$, keys $K \in \mathbb{R}^{n_k \times d_k}$, and values $V \in \mathbb{R}^{n_k \times d_v}$, attention computes $O = \text{softmax}(QK^\top / \sqrt{d_k}) V \in \mathbb{R}^{n_q \times d_v}$, producing data-dependent weighted combinations of values whose weights come from query-key similarity. The scaling factor $1/\sqrt{d_k}$ keeps the variance of the scores from growing with the key dimension (which would saturate the softmax and shrink its gradients). This operation, $O(n_q n_k d_k)$ for scoring and $O(n_q n_k d_v)$ for aggregation, is the computational bottleneck of Transformers and historically limited practical sequence lengths to roughly $10^3$ tokens without sparse or otherwise efficient attention. The example validates the row-stochasticity of the attention weights (each query's weights sum to 1, forming a probability distribution over keys) and shows how the shapes and transposes "feel inevitable" once the query-key-value paradigm is understood, so that students can reason about forward and backward passes in BERT, GPT, and Vision Transformers without memorizing formulas.
Compute scores, attention weights, and outputs for a tiny attention head; verify weights sum to 1.
Scaled dot-product attention computes output $O \in \mathbb{R}^{n_q \times d_v}$ from queries $Q$, keys $K$, and values $V$ via:
\[ O = \text{softmax}\left( \frac{Q K^\top}{\sqrt{d_k}} \right) V = A V, \]
where $A \in \mathbb{R}^{n_q \times n_k}$ is the attention weight matrix.
Step 1: Compute attention scores via scaled dot products:
\[ \text{scores} = \frac{Q K^\top}{\sqrt{d_k}} \in \mathbb{R}^{n_q \times n_k}, \]
where $Q \in \mathbb{R}^{n_q \times d_k}$, $K \in \mathbb{R}^{n_k \times d_k}$, and $d_k$ is the key dimension. Each entry $\text{scores}_{ij} = q_i^\top k_j / \sqrt{d_k}$ measures the similarity between query $i$ and key $j$. Scaling by $1/\sqrt{d_k}$ keeps the variance of the scores from growing with $d_k$: if the components of $q_i$ and $k_j$ are independent with zero mean and unit variance, $q_i^\top k_j$ has variance $d_k$, and such large scores push softmax into saturation regions (one weight $\approx 1$, others $\approx 0$).
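A quick empirical check of the variance argument, as a sketch under an idealized assumption: queries and keys with i.i.d. standard-normal components (learned projections need not look like this). The helper name score_variances is hypothetical. The sample variance of $q^\top k$ should track $d_k$, while the scaled scores stay near 1:

```python
import numpy as np

rng = np.random.default_rng(0)  # fixed seed so the sketch is reproducible


def score_variances(d_k, n_pairs=5000):
    """Sample variances of raw and scaled dot products for random q, k in R^{d_k}."""
    q = rng.standard_normal((n_pairs, d_k))  # i.i.d. N(0, 1) components (assumption)
    k = rng.standard_normal((n_pairs, d_k))
    raw = np.sum(q * k, axis=1)              # one dot product q^T k per pair
    scaled = raw / np.sqrt(d_k)              # the 1/sqrt(d_k) scaling
    return raw.var(), scaled.var()


results = {d_k: score_variances(d_k) for d_k in (2, 64, 512)}
for d_k, (raw_var, scaled_var) in results.items():
    print(f"d_k={d_k:4d}  Var(q.k) ~ {raw_var:7.1f}  Var(q.k/sqrt(d_k)) ~ {scaled_var:.2f}")
```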
Step 2: Normalize to attention weights via row-wise softmax:
\[ A_{ij} = \frac{\exp(\text{scores}_{ij})}{\sum_{k=1}^{n_k} \exp(\text{scores}_{ik})} \quad \text{for } i=1,\ldots,n_q. \]
Properties: (1) $A_{ij} \ge 0$ (non-negative), (2) $\sum_{j=1}^{n_k} A_{ij} = 1$ (rows sum to 1, row-stochastic), (3) consequently, each row of $A$ is a probability distribution over the keys for that query.
Step 3: Aggregate values via weighted sum:
\[ O = A V \in \mathbb{R}^{n_q \times d_v}, \]
where $V \in \mathbb{R}^{n_k \times d_v}$. Each output row is:
\[ O_i = \sum_{j=1}^{n_k} A_{ij} V_j, \]
a convex combination of value vectors weighted by attention scores.
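To make the convex-combination reading concrete, here is a small NumPy sketch (random inputs, arbitrary sizes) that computes $O = AV$ both as one matrix product and as the explicit per-query sum $O_i = \sum_j A_{ij} V_j$. Since each output row lies in the convex hull of the value rows, every output coordinate must also fall between the smallest and largest value in that coordinate, which the last check tests:

```python
import numpy as np


def softmax(x, axis=-1):
    """Numerically stable softmax along `axis`."""
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


rng = np.random.default_rng(1)
n_q, n_k, d_k, d_v = 4, 5, 3, 2
Q = rng.standard_normal((n_q, d_k))
K = rng.standard_normal((n_k, d_k))
V = rng.standard_normal((n_k, d_v))

A = softmax(Q @ K.T / np.sqrt(d_k), axis=1)  # (n_q, n_k), rows sum to 1
O = A @ V                                    # (n_q, d_v), one matrix product

# The same outputs, built one query at a time: O_i = sum_j A_ij V_j
O_loop = np.zeros((n_q, d_v))
for i in range(n_q):
    for j in range(n_k):
        O_loop[i] += A[i, j] * V[j]

# Convex combination => each coordinate stays within the range of that value coordinate
lo, hi = V.min(axis=0), V.max(axis=0)
print(np.allclose(O, O_loop), bool(np.all((O >= lo - 1e-12) & (O <= hi + 1e-12))))
```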
Computational complexity:
- $QK^\top$: $O(n_q n_k d_k)$ operations (matrix-matrix multiply)
- Softmax: $O(n_q n_k)$ operations (exponentiation + normalization)
- $AV$: $O(n_q n_k d_v)$ operations (matrix-matrix multiply)
- Total: $O(n_q n_k (d_k + d_v))$ → $O(n^2 d)$ for self-attention ($n_q = n_k = n$, $d_k \approx d_v \approx d$)
Standard notation:
- $Q \in \mathbb{R}^{n_q \times d_k}$: query matrix ($n_q$ queries, each $d_k$-dimensional)
- $K \in \mathbb{R}^{n_k \times d_k}$: key matrix ($n_k$ keys, each $d_k$-dimensional)
- $V \in \mathbb{R}^{n_k \times d_v}$: value matrix ($n_k$ values, each $d_v$-dimensional)
- $A \in \mathbb{R}^{n_q \times n_k}$: attention weight matrix (row-stochastic)
- $O \in \mathbb{R}^{n_q \times d_v}$: output matrix
- Transpose: $K^\top$ (not $K^T$)
- Softmax along axis 1 (across keys for each query)
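import numpy as np

try:
    from scripts.toy_data import softmax  # project helper used by the course scripts
except ImportError:
    # Fallback so the snippet also runs standalone: numerically stable softmax along `axis`
    def softmax(x, axis=-1):
        z = x - x.max(axis=axis, keepdims=True)
        e = np.exp(z)
        return e / e.sum(axis=axis, keepdims=True)

# Two queries attending to three key-value pairs, d_k = d_v = 2
Q = np.array([[1., 0.],
              [0., 1.]])
K = np.array([[1., 0.],
              [1., 1.],
              [0., 1.]])
V = np.array([[1., 0.],
              [0., 2.],
              [1., 1.]])

scores = Q @ K.T / np.sqrt(2)   # (2, 3) scaled dot products; 1/sqrt(d_k) with d_k = 2
A = softmax(scores, axis=1)     # row-wise softmax: each query's weights over the keys
O = A @ V                       # (2, 2) weighted sums of the value rows

print("A row sums:", A.sum(axis=1))
print("O:\n", O)

This code demonstrates scaled dot-product attention, the fundamental operation in Transformer architectures. Attention lets the model dynamically select and aggregate relevant information from a set of values, guided by query-key similarity scores. The mechanism decomposes into two sequential matrix products, $QK^\top$ (scoring) followed by $AV$ (aggregation), with a row-wise softmax in between.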
Step-by-step execution of scaled dot-product attention:
Define input matrices: Queries $Q \in \mathbb{R}^{2 \times 2}$, keys $K \in \mathbb{R}^{3 \times 2}$, values $V \in \mathbb{R}^{3 \times 2}$ are hardcoded NumPy arrays representing two queries attending to three key-value pairs in a 2-dimensional space. Shapes: $Q$ has 2 rows (queries), $K$ has 3 rows (keys), both with $d_k = 2$ columns. $V$ has 3 rows (values) with $d_v = 2$ columns.
Compute scaled scores:
scores = Q @ K.T / np.sqrt(2) computes the matrix product $QK^\top \in \mathbb{R}^{2 \times 3}$, then scales by $1/\sqrt{d_k} = 1/\sqrt{2} \approx 0.707$. The transpose K.T converts $K$ from shape $(3, 2)$ to $(2, 3)$, making the product $(2, 2) \times (2, 3) \to (2, 3)$ valid. Each entry $\text{scores}_{ij}$ is the scaled dot product between query $i$ and key $j$. For example, $\text{scores}_{00} = (Q[0] \cdot K[0]) / \sqrt{2} = (1 \cdot 1 + 0 \cdot 0) / 1.414 \approx 0.707$.
Apply row-wise softmax:
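A = softmax(scores, axis=1) normalizes each row of scores independently (each query's distribution over keys). For row 0, $\text{scores}_0 = [0.707, 0.707, 0]$; since $e^{0.707} \approx 2.028$, the row sum of exponentials is $2 \cdot 2.028 + 1 \approx 5.056$, and normalizing gives $A_0 \approx [0.401, 0.401, 0.198]$. Row 1 is the mirror image: $\text{scores}_1 = [0, 0.707, 0.707]$ gives $A_1 \approx [0.198, 0.401, 0.401]$. The axis=1 argument is critical: it specifies normalization across keys (columns) for each query (row). Using axis=0 would instead normalize each key's column across the queries, so the rows would no longer be distributions over keys.
Aggregate values: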
O = A @ V computes the output $O \in \mathbb{R}^{2 \times 2}$ by multiplying attention weights $(2, 3)$ by values $(3, 2) \to (2, 2)$. Each output row is a weighted average of value vectors: $O_0 = 0.401 \cdot V_0 + 0.401 \cdot V_1 + 0.198 \cdot V_2 \approx [0.599, 1.000]$ and $O_1 = 0.198 \cdot V_0 + 0.401 \cdot V_1 + 0.401 \cdot V_2 \approx [0.599, 1.203]$.
Verify row-stochasticity:
A.sum(axis=1) computes row sums, yielding [1., 1.] (up to floating-point rounding), which confirms that each query's weights form a valid probability distribution. This is a fundamental property of softmax normalization: it guarantees that attention weights are convex combination coefficients.
Inspect output: The print statements show that the attention weights sum to 1 and display the final output matrix $O$. The output represents query-specific weighted combinations of values, where the weights reflect query-key similarity.
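The numbers in the walkthrough can be recomputed by hand. In this sketch, assuming the same $Q$, $K$, $V$ as the example, every nonzero raw score equals 1, so each row of $A$ contains only two distinct weights, $a = e^{s}/(2e^{s} + 1)$ and $b = 1/(2e^{s} + 1)$ with $s = 1/\sqrt{2}$, and the outputs reduce to simple expressions in $a$ and $b$:

```python
import numpy as np

Q = np.array([[1., 0.], [0., 1.]])
K = np.array([[1., 0.], [1., 1.], [0., 1.]])
V = np.array([[1., 0.], [0., 2.], [1., 1.]])

# By hand: each nonzero raw score is 1, so each scaled score is 0 or s = 1/sqrt(2)
s = 1 / np.sqrt(2)
Z = 2 * np.exp(s) + 1            # sum of exponentials, identical for both rows
a, b = np.exp(s) / Z, 1 / Z      # a ~ 0.401, b ~ 0.198
A_hand = np.array([[a, a, b],
                   [b, a, a]])
O_hand = np.array([[a + b, 2 * a + b],   # a*V0 + a*V1 + b*V2
                   [a + b, 3 * a]])      # b*V0 + a*V1 + a*V2

# Vectorized, as in the example code
scores = Q @ K.T / np.sqrt(2)
A = np.exp(scores) / np.exp(scores).sum(axis=1, keepdims=True)
O = A @ V

print(np.allclose(A, A_hand), np.allclose(O, O_hand))
print(np.round(A, 3))   # approximately [[0.401, 0.401, 0.198], [0.198, 0.401, 0.401]]
print(np.round(O, 3))   # approximately [[0.599, 1.0], [0.599, 1.203]]
```

Since the weights in row 0 satisfy $2a + b = 1$, the entry $O_{0,1}$ comes out as exactly 1 up to rounding.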
Two-stage matrix product pattern: Attention decomposes into (1) similarity scoring via $\text{scores} = QK^\top / \sqrt{d_k} \in \mathbb{R}^{n_q \times n_k}$, computing all pairwise query-key dot products in one matrix multiply ($O(n_q n_k d_k)$ operations), and (2) value aggregation via $O = AV \in \mathbb{R}^{n_q \times d_v}$, forming weighted sums of values ($O(n_q n_k d_v)$ operations). Softmax normalization ($O(n_q n_k)$) connects these stages, converting raw scores to probability distributions.
Scaled dot products prevent saturation: Without the $1/\sqrt{d_k}$ factor, the spread of the scores grows with $d_k$ (for $q, k$ with independent zero-mean, unit-variance components, $\mathbb{E}[q^\top k] = 0$ and $\text{Var}(q^\top k) = d_k$). For $d_k = 512$, unscaled scores would have standard deviation $\sqrt{512} \approx 22.6$, about $23\times$ larger than the scaled ones, pushing softmax into saturation (one weight $\approx 1$, others $\approx 0$, gradients $\approx 0$). Scaling restores roughly unit variance, so most scores fall within about $[-2, 2]$ and gradients keep flowing. This example uses $d_k = 2$, so scaling by $1/\sqrt{2} \approx 0.707$ maps the raw scores $\{0, 1\}$ to $\{0, 0.707\}$.
Row-stochastic attention weights:
A.sum(axis=1) returns [1., 1.], confirming that each query's attention weights form a probability distribution over keys ($\sum_j A_{ij} = 1$, $A_{ij} \ge 0$). This normalization makes the outputs convex combinations of values, $O_i = \sum_j A_{ij} V_j$, with weights that depend on query-key similarity. The code verifies this property explicitly: attention is not just weighted averaging but normalized weighted averaging.
Shape discipline via transposes: The transpose $K^\top$ is essential for dimensional compatibility: $Q \in \mathbb{R}^{n_q \times d_k}$ times $K^\top \in \mathbb{R}^{d_k \times n_k}$ yields $\text{scores} \in \mathbb{R}^{n_q \times n_k}$ (all pairwise similarities). Without the transpose, Q @ K fails whenever $d_k \ne n_k$ (inner dimensions mismatch), and when $d_k = n_k$ happens to hold it runs but silently computes the wrong quantity. This "shape reasoning" generalizes to all attention variants: once the query-key-value paradigm is understood, the transposes "feel inevitable."
Query-key symmetry and value asymmetry: $Q$ and $K$ must have the same dimension $d_k$ (to compute dot products), but $V$ can have a different dimension $d_v$ (values encode different information than keys). In self-attention ($Q, K, V$ all derived from the same input $X$), typically $d_k = d_v = d_{\text{model}} / h$ (model dimension divided by number of heads). In cross-attention (e.g., encoder-decoder), $K$ and $V$ come from encoder outputs and $Q$ from the decoder, so the number of queries $n_q$ (target positions) and keys $n_k$ (source positions) generally differ, and the encoder and decoder widths may differ as long as $W_Q$ and $W_K$ project both into the same $d_k$.
Attention as soft dictionary lookup: The query-key-value structure mirrors database retrieval: query = "what information do I need?", keys = "searchable index", values = "actual content to retrieve". Unlike hard indexing (retrieve record $k$), attention performs soft retrieval: a weighted combination of all values, where weights are determined by query-key similarity. This differentiability enables end-to-end training of the attention mechanism itself (learning optimal $Q, K, V$ projections).
Part 1: Query-Key Similarity via Scaled Dot Products
The first matrix product $QK^\top$ computes all pairwise similarities between queries and keys in one operation. For $Q \in \mathbb{R}^{2 \times 2}$ and $K \in \mathbb{R}^{3 \times 2}$, the product $QK^\top \in \mathbb{R}^{2 \times 3}$ has 6 entries (2 queries $\times$ 3 keys), each the dot product $q_i^\top k_j$. The transpose $K^\top$ is essential: $Q$ has shape $(2, 2)$ and $K$ has shape $(3, 2)$, so Q @ K fails (inner dimensions $2 \ne 3$), while Q @ K.T succeeds ($(2, 2) \times (2, 3) \to (2, 3)$). Scaling by $1/\sqrt{d_k}$: if $q, k \in \mathbb{R}^{d_k}$ have independent components with zero mean and unit variance, the dot product $q^\top k$ is a sum of $d_k$ independent products, each with variance 1, so it has variance $d_k$. For $d_k = 512$, unscaled scores would have standard deviation $\sqrt{512} \approx 23$, pushing softmax into saturation (e.g., $\exp(20) / (\exp(20) + \exp(0)) \approx 1 - 2 \times 10^{-9}$: one key dominates and gradients vanish). Dividing by $\sqrt{d_k}$ brings the variance back to $1$, keeping scores in a moderate range (mostly within $[-2, 2]$) where softmax remains sensitive to all inputs. Why not divide by $d_k$? That would shrink the standard deviation to $1/\sqrt{d_k}$ (tiny for large $d_k$), pushing all scores toward 0 so the softmax becomes nearly uniform and attention loses its selectivity. Dividing by $\sqrt{d_k}$ is exactly the factor that keeps the variance at 1, avoiding both saturation and washed-out weights.
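The saturation effect can be seen directly. This sketch assumes a single random query against 16 random keys with $d_k = 512$ and i.i.d. standard-normal components (an idealization), and compares the largest softmax weight for unscaled scores, scores divided by $\sqrt{d_k}$, and scores divided by $d_k$. Dividing the scores by a larger constant can only flatten the softmax, so the largest weight shrinks from the first setting to the last, with the $d_k$-scaled version sitting near the uniform value $1/16$:

```python
import numpy as np


def softmax(x):
    """Numerically stable softmax of a 1-D score vector."""
    e = np.exp(x - x.max())
    return e / e.sum()


rng = np.random.default_rng(0)
d_k, n_k = 512, 16
q = rng.standard_normal(d_k)               # i.i.d. N(0, 1) components (idealization)
K = rng.standard_normal((n_k, d_k))
raw = K @ q                                # 16 raw dot products, std about sqrt(512) ~ 22.6

max_weight = {}
for name, scores in [("unscaled", raw),
                     ("/ sqrt(d_k)", raw / np.sqrt(d_k)),
                     ("/ d_k", raw / d_k)]:
    max_weight[name] = softmax(scores).max()
    print(f"{name:12s} largest weight = {max_weight[name]:.3f}  score std = {scores.std():.3f}")
```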
Part 2: Softmax Normalization to Attention Weights
Row-wise softmax converts raw scores to probability distributions: $A_{ij} = \exp(\text{scores}_{ij}) / \sum_k \exp(\text{scores}_{ik})$. Properties: (1) Non-negative: $\exp(\cdot) > 0$ ensures $A_{ij} \ge 0$. (2) Row-stochastic: $\sum_j A_{ij} = 1$ (each query distributes 100% attention across keys). (3) Differentiable: Softmax gradients $\partial A_{ij} / \partial \text{scores}_{ik} = A_{ij} (\delta_{jk} - A_{ik})$ enable backpropagation. Why softmax vs. other normalizations? (a) Max-normalization ($A_{ij} = \text{scores}_{ij} / \max_k \text{scores}_{ik}$) violates row-stochasticity (sums $\ne 1$). (b) L1-normalization ($A_{ij} = \text{scores}_{ij} / \sum_k \text{scores}_{ik}$) can produce negative weights if scores are negative. (c) Softmax (exponential + normalization) guarantees non-negativity and stochasticity, and has a maximum entropy interpretation (Boltzmann distribution). Temperature scaling: Softmax with temperature $\tau$: $A_{ij} = \exp(\text{scores}_{ij} / \tau) / \sum_k \exp(\text{scores}_{ik} / \tau)$. As $\tau \to 0$, attention becomes âhardâ (argmax, one-hot). As $\tau \to \infty$, attention becomes uniform (all keys equally weighted). Standard attention uses $\tau = 1$.
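Both the Jacobian formula and the temperature behavior can be checked numerically. The sketch below (plain NumPy; the one-dimensional softmax with a tau argument is a hypothetical helper, not a library function) compares the analytic Jacobian $\text{diag}(a) - aa^\top$, the matrix form of $A_{ij}(\delta_{jk} - A_{ik})$ for a single row, against central finite differences, then sweeps the temperature on a score vector with a unique maximum:

```python
import numpy as np


def softmax(s, tau=1.0):
    """Softmax of a 1-D score vector with temperature tau (max subtracted for stability)."""
    z = s / tau
    e = np.exp(z - z.max())
    return e / e.sum()


s = np.array([0.707, 0.707, 0.0])     # row 0 of the scaled scores in the example
a = softmax(s)

# Analytic Jacobian: dA_j / ds_k = A_j * (delta_jk - A_k)
J = np.diag(a) - np.outer(a, a)

# Central finite differences, one score at a time
eps = 1e-6
J_fd = np.zeros((3, 3))
for k in range(3):
    step = np.zeros(3)
    step[k] = eps
    J_fd[:, k] = (softmax(s + step) - softmax(s - step)) / (2 * eps)

print("Jacobian matches finite differences:", np.allclose(J, J_fd, atol=1e-8))

# Temperature: small tau sharpens toward argmax, large tau flattens toward uniform
t = np.array([2.0, 1.0, 0.0])
for tau in (0.05, 1.0, 1000.0):
    print(f"tau={tau:7.2f}  weights={np.round(softmax(t, tau), 4)}")
```

Each column of the Jacobian sums to zero because the weights always sum to 1: raising one score only moves probability mass around. With tied top scores (as in row 0 of the example) a low temperature splits the mass between the tied keys rather than producing a one-hot vector.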
Part 3: Value Aggregation via Weighted Sum
The second matrix product $AV$ computes attention-weighted aggregations: $O = AV \in \mathbb{R}^{2 \times 2}$. Each output row is $O_i = \sum_{j=1}^3 A_{ij} V_j$, a convex combination of value vectors (since $\sum_j A_{ij} = 1$, $A_{ij} \ge 0$). Geometric interpretation: Each query $i$ produces an output $O_i$ lying in the convex hull of $\{V_1, V_2, V_3\}$ (the set of all weighted averages). Why separate keys and values? In principle, we could set $V = K$ (attention weights determined by and applied to the same vectors). However, decoupling allows keys to encode what information is available (searchable metadata) while values encode what to retrieve (actual content). In practice, $Q, K, V$ are learned projections of the input: $Q = XW_Q$, $K = XW_K$, $V = XW_V$, where $W_Q, W_K, W_V$ are trained end-to-end. Self-attention vs. cross-attention: In self-attention (e.g., BERT, GPT), $Q, K, V$ all come from the same sequence $X$. In cross-attention (e.g., encoder-decoder translation), $Q$ comes from decoder states and $K, V$ from encoder outputs, allowing the decoder to "attend" to source sentence positions when generating target words.
Why This Matters for ML
Transformer architecture dominance: Transformers replaced RNNs/LSTMs as the default sequence model because attention enables full parallelization across positions. Self-attention costs $O(n^2 d)$ per layer, but all $n$ positions are processed simultaneously, whereas an RNN costs $O(d^2)$ per step and must run its $n$ steps in sequence ($O(n d^2)$ total, with no parallelism across positions). The resulting wall-clock speedups on GPUs helped make it practical to scale to billions of parameters (GPT-3: 175B, PaLM: 540B).
Quadratic bottleneck: For sequence length $n$ and dimension $d$, attention costs $O(n^2 d)$. For $n = 10^4$ (long documents), this becomes prohibitive ($10^8$ pairwise scores per head). Efficient attention variants address this: sparse attention (Longformer: $O(n)$ via local sliding-window plus global attention), low-rank attention (Linformer: linear in $n$ by projecting keys and values along the sequence axis to length $k \ll n$), and kernel attention (Performer: linear in $n$ via a random-feature approximation of the softmax kernel).
Multi-head attention: Standard Transformers use $h = 8$ to 16 parallel heads, each with $d_k = d_{\text{model}} / h$. Different heads can learn complementary attention patterns; analyses of trained models have found heads that track syntactic dependencies, semantic similarity, or positional locality.
Positional encodings: Raw self-attention has no notion of order: permuting the input tokens permutes the attention weights and outputs in the same way (permutation equivariance), so the model cannot tell positions apart. Transformers add sinusoidal or learned positional embeddings to inject sequence order information: $X_{\text{pos}} = X + \text{PE}(pos)$, where $\text{PE}(pos)$ encodes absolute or relative positions.
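The order-blindness of raw self-attention is easy to verify. In this minimal sketch (random weights stand in for learned $W_Q, W_K, W_V$; the function name self_attention is hypothetical), permuting the input rows permutes the output rows in exactly the same way, so without positional encodings the layer cannot tell which token came first:

```python
import numpy as np


def softmax(x, axis=-1):
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def self_attention(X, W_Q, W_K, W_V):
    """Single-head self-attention with no positional information."""
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V
    A = softmax(Q @ K.T / np.sqrt(K.shape[1]), axis=1)
    return A @ V


rng = np.random.default_rng(0)
n, d_model, d_k = 5, 8, 4
X = rng.standard_normal((n, d_model))                 # token embeddings, no positions added
W_Q, W_K, W_V = (rng.standard_normal((d_model, d_k)) for _ in range(3))

perm = rng.permutation(n)                             # shuffle the token order
out = self_attention(X, W_Q, W_K, W_V)
out_shuffled = self_attention(X[perm], W_Q, W_K, W_V)

# Shuffling the inputs only shuffles the output rows in the same way
print(np.allclose(out_shuffled, out[perm]))   # True
```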
ML Examples and Patterns
Self-attention layer (single head): Q = X @ W_Q, K = X @ W_K, V = X @ W_V project input $X \in \mathbb{R}^{n \times d_{\text{model}}}$ to queries/keys/values via learned matrices $W_Q, W_K, W_V \in \mathbb{R}^{d_{\text{model}} \times d_k}$. Then compute scores = Q @ K.T / sqrt(d_k), A = softmax(scores, axis=-1), O = A @ V. Output $O \in \mathbb{R}^{n \times d_k}$ is projected back to model dimension via $W_O \in \mathbb{R}^{d_k \times d_{\text{model}}}$.
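The recipe above can be sketched as a small NumPy function. This is an illustrative single-head layer under stated assumptions (no bias terms, no masking, random matrices standing in for learned weights, hypothetical function name), not a reference implementation:

```python
import numpy as np


def softmax(x, axis=-1):
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def single_head_self_attention(X, W_Q, W_K, W_V, W_O):
    """X: (n, d_model); W_Q, W_K, W_V: (d_model, d_k); W_O: (d_k, d_model)."""
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V       # each (n, d_k)
    scores = Q @ K.T / np.sqrt(Q.shape[-1])   # (n, n)
    A = softmax(scores, axis=-1)              # normalize over keys
    return (A @ V) @ W_O, A                   # output (n, d_model), weights (n, n)


rng = np.random.default_rng(0)
n, d_model, d_k = 6, 16, 4
X = rng.standard_normal((n, d_model))
# Random matrices stand in for learned weights (illustration only)
W_Q, W_K, W_V = (rng.standard_normal((d_model, d_k)) / np.sqrt(d_model) for _ in range(3))
W_O = rng.standard_normal((d_k, d_model)) / np.sqrt(d_k)

out, A = single_head_self_attention(X, W_Q, W_K, W_V, W_O)
print(out.shape, A.shape)   # (6, 16) (6, 6)
```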
Multi-head attention: Run $h$ attention heads in parallel with $d_k = d_{\text{model}} / h$ per head, then concatenate: O_multi = Concat(head_1, ..., head_h) @ W_O, where head_i = Attention(X @ W_Q^i, X @ W_K^i, X @ W_V^i). PyTorch implementation: nn.MultiheadAttention(embed_dim=512, num_heads=8).
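A sketch of the split-heads pattern, assuming the common convention that one $(d_{\text{model}} \times d_{\text{model}})$ matrix per projection is reshaped into $h$ heads of width $d_k = d_{\text{model}}/h$; it also uses the batched (B, h, n, n) score layout. The function name multi_head_attention is hypothetical, and the loop at the end cross-checks the reshapes against an explicit per-head computation:

```python
import numpy as np


def softmax(x, axis=-1):
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def multi_head_attention(X, W_Q, W_K, W_V, W_O, h):
    """X: (B, n, d_model); each W: (d_model, d_model); h heads of width d_k = d_model // h."""
    B, n, d_model = X.shape
    d_k = d_model // h

    def split_heads(T):  # (B, n, d_model) -> (B, h, n, d_k)
        return T.reshape(B, n, h, d_k).transpose(0, 2, 1, 3)

    Q, K, V = split_heads(X @ W_Q), split_heads(X @ W_K), split_heads(X @ W_V)
    scores = Q @ K.transpose(0, 1, 3, 2) / np.sqrt(d_k)          # (B, h, n, n)
    A = softmax(scores, axis=-1)                                 # over keys (last axis)
    heads = A @ V                                                # (B, h, n, d_k)
    concat = heads.transpose(0, 2, 1, 3).reshape(B, n, d_model)  # Concat(head_1, ..., head_h)
    return concat @ W_O


rng = np.random.default_rng(0)
B, n, d_model, h = 2, 5, 8, 2
X = rng.standard_normal((B, n, d_model))
W_Q, W_K, W_V, W_O = (rng.standard_normal((d_model, d_model)) / np.sqrt(d_model) for _ in range(4))
out = multi_head_attention(X, W_Q, W_K, W_V, W_O, h)

# Cross-check: head i uses columns i*d_k:(i+1)*d_k of each projection
d_k = d_model // h
ref = np.zeros_like(out)
for b in range(B):
    parts = []
    for i in range(h):
        cols = slice(i * d_k, (i + 1) * d_k)
        Qi, Ki, Vi = X[b] @ W_Q[:, cols], X[b] @ W_K[:, cols], X[b] @ W_V[:, cols]
        parts.append(softmax(Qi @ Ki.T / np.sqrt(d_k), axis=-1) @ Vi)
    ref[b] = np.concatenate(parts, axis=-1) @ W_O

print(out.shape, np.allclose(out, ref))   # (2, 5, 8) True
```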
Masked attention (autoregressive): GPT-style models mask future positions to prevent lookahead: mask = torch.triu(torch.ones(n, n), diagonal=1).bool(), then scores.masked_fill_(mask, -1e9) before softmax. This zeroes attention to future tokens, enforcing left-to-right generation.
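The same causal mask in NumPy, as a minimal sketch with random inputs (np.where plays the role of PyTorch's masked_fill_; the sizes are arbitrary). Row $i$ of $A$ ends up with nonzero weights only on keys $0, \ldots, i$, so the first query can only copy the first value:

```python
import numpy as np


def softmax(x, axis=-1):
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


rng = np.random.default_rng(0)
n, d_k = 4, 3
Q, K, V = (rng.standard_normal((n, d_k)) for _ in range(3))

scores = Q @ K.T / np.sqrt(d_k)
mask = np.triu(np.ones((n, n), dtype=bool), k=1)   # True strictly above the diagonal: future keys
A = softmax(np.where(mask, -1e9, scores), axis=-1)  # masked entries get weight ~0
O = A @ V

print(np.round(A, 3))            # lower-triangular; the first row is [1, 0, 0, 0]
print(np.allclose(O[0], V[0]))   # True: position 0 can only see itself
```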
Cross-attention (encoder-decoder): In translation, decoder queries attend to encoder keys/values: Q_dec = X_dec @ W_Q, K_enc = X_enc @ W_K, V_enc = X_enc @ W_V, then O = Attention(Q_dec, K_enc, V_enc). Decoder state at position $i$ can attend to all encoder positions (no causal mask).
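A shape-level sketch of cross-attention, assuming hypothetical sizes in which the decoder and encoder have different lengths and widths. Only the projected query and key widths must agree ($d_k$); the attention matrix is rectangular, $(n_{\text{dec}}, n_{\text{enc}})$, and needs no causal mask:

```python
import numpy as np


def softmax(x, axis=-1):
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


rng = np.random.default_rng(0)
n_dec, n_enc = 3, 7            # target and source lengths differ
d_dec, d_enc = 10, 12          # decoder and encoder widths may differ too
d_k, d_v = 4, 5

X_dec = rng.standard_normal((n_dec, d_dec))
X_enc = rng.standard_normal((n_enc, d_enc))
W_Q = rng.standard_normal((d_dec, d_k))   # decoder states -> queries
W_K = rng.standard_normal((d_enc, d_k))   # encoder states -> keys (same d_k as queries)
W_V = rng.standard_normal((d_enc, d_v))   # encoder states -> values

Q, K, V = X_dec @ W_Q, X_enc @ W_K, X_enc @ W_V
A = softmax(Q @ K.T / np.sqrt(d_k), axis=-1)   # (n_dec, n_enc), no causal mask
O = A @ V                                      # (n_dec, d_v)
print(A.shape, O.shape)   # (3, 7) (3, 5)
```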
Positional encoding: PE[pos, 2i] = sin(pos / 10000^(2i/d)), PE[pos, 2i+1] = cos(pos / 10000^(2i/d)) for $i=0, \ldots, d/2-1$. Add to embeddings: X_input = Embedding(tokens) + PE. Alternative: learned positional embeddings nn.Embedding(max_len, d_model).
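The sinusoidal formula vectorizes naturally. This sketch (hypothetical function name sinusoidal_pe; it assumes an even model dimension $d$) builds the full table, interleaving sine and cosine columns as in the formula above:

```python
import numpy as np


def sinusoidal_pe(max_len, d):
    """PE[pos, 2i] = sin(pos / 10000^(2i/d)), PE[pos, 2i+1] = cos(pos / 10000^(2i/d)); d must be even."""
    pos = np.arange(max_len)[:, None]              # (max_len, 1)
    two_i = np.arange(0, d, 2)[None, :]            # (1, d/2), the values 2i
    angles = pos / np.power(10000.0, two_i / d)    # (max_len, d/2)
    pe = np.zeros((max_len, d))
    pe[:, 0::2] = np.sin(angles)                   # even columns
    pe[:, 1::2] = np.cos(angles)                   # odd columns
    return pe


PE = sinusoidal_pe(max_len=50, d=8)
print(PE.shape)   # (50, 8)
print(PE[0])      # position 0: sin(0) = 0 in even columns, cos(0) = 1 in odd columns
```

For a sequence of $n$ token embeddings X of width $d$, the Transformer input would then be X + PE[:n].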
Attention dropout: A = F.dropout(F.softmax(scores, dim=-1), p=0.1, training=True) randomly zeros attention weights during training (rescaling the surviving weights by $1/(1-p)$) to prevent overfitting to specific dependencies.
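A NumPy sketch of what dropout does to a row-stochastic matrix, assuming inverted dropout (zero each entry with probability $p$ and rescale the survivors by $1/(1-p)$, the convention PyTorch's F.dropout also follows). The helper attention_dropout is hypothetical. Individual rows no longer sum to exactly 1 during training, although each entry keeps its expected value; in evaluation mode the weights pass through unchanged:

```python
import numpy as np


def attention_dropout(A, p, rng, training=True):
    """Inverted dropout: zero each entry with probability p, scale survivors by 1/(1 - p)."""
    if not training or p == 0.0:
        return A
    keep = rng.random(A.shape) >= p
    return A * keep / (1.0 - p)


rng = np.random.default_rng(0)
A = np.full((4, 5), 0.2)                      # uniform attention over 5 keys; rows sum to 1

A_train = attention_dropout(A, p=0.1, rng=rng)
A_eval = attention_dropout(A, p=0.1, rng=rng, training=False)

print(A_train.sum(axis=1))   # k surviving entries give k * 0.2 / 0.9, never exactly 1
print(A_eval.sum(axis=1))    # unchanged in eval mode: rows still sum to 1
```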
Connection to Linear Algebra Theory
Attention as convex combination: Each output $O_i = \sum_j A_{ij} V_j$ is a point in the convex hull of $\{V_1, \ldots, V_{n_k}\}$ (all weighted averages with non-negative weights summing to 1). This is loosely analogous to projecting onto the values, except that the weights are data-dependent (computed from query-key similarity) rather than fixed by geometry (as in the orthogonal projection of least squares).
Dot product as similarity metric: For normalized vectors $\|q\|_2 = \|k\|_2 = 1$, the dot product $q^\top k = \cos(\theta)$ is the cosine of the angle between $q$ and $k$. High dot product ($\cos(\theta) \approx 1$) means $\theta \approx 0$ (aligned vectors, high similarity). Attention amplifies this via softmax: similar pairs get disproportionately large weights.
Softmax as exponential family: Softmax $A_{ij} = \exp(s_{ij}) / Z_i$ (where $Z_i = \sum_k \exp(s_{ik})$) is the maximum entropy distribution subject to expected score constraint $\mathbb{E}[s] = \sum_j A_{ij} s_{ij}$. This is a Boltzmann distribution in statistical mechanics, with "temperature" $\tau = 1$ (or $\tau = \sqrt{d_k}$ if unscaled).
Row-stochastic matrices as Markov transitions: In self-attention ($n_q = n_k = n$), the attention matrix $A$ is square with rows summing to 1, so it can be read as the transition matrix of a Markov chain over token positions. Applying $A$ repeatedly ($A^2, A^3, \ldots$) corresponds to multi-hop random walks on the token graph. Graph Attention Networks make the graph view explicit by restricting attention to each node's neighbors.
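A small check of the Markov-chain reading, assuming a random square self-attention matrix: powers of a row-stochastic matrix stay row-stochastic, and because every softmax entry is strictly positive, repeated application mixes every row toward the same stationary distribution $\pi$ with $\pi A = \pi$ (a standard Perron-Frobenius fact, not a property specific to Transformers):

```python
import numpy as np


def softmax(x, axis=-1):
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


rng = np.random.default_rng(0)
n, d_k = 5, 4
Q, K = rng.standard_normal((n, d_k)), rng.standard_normal((n, d_k))
A = softmax(Q @ K.T / np.sqrt(d_k), axis=1)   # square, strictly positive, rows sum to 1

A2 = A @ A                                     # two-hop transition matrix
print("A^2 rows sum to 1:", np.allclose(A2.sum(axis=1), 1.0))

# With all entries positive, A^t converges to a matrix whose rows all equal pi
A_inf = np.linalg.matrix_power(A, 200)
pi = A_inf[0]
print("rows identical:", np.allclose(A_inf, pi), " pi @ A == pi:", np.allclose(pi @ A, pi))
```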
Rank-1 outer products: Writing $a_j \in \mathbb{R}^{n_q}$ for column $j$ of $A$ (the weight every query places on key $j$) and $V_j \in \mathbb{R}^{d_v}$ for row $j$ of $V$, the output is a sum of $n_k$ rank-1 outer products: $O = AV = \sum_{j=1}^{n_k} a_j V_j^\top$. Row by row this is $O_i = \sum_j A_{ij} V_j$, and the sum-of-rank-1-terms form is analogous to the SVD reconstruction $X \approx \sum_{k=1}^r \sigma_k u_k v_k^\top$.
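The outer-product decomposition can be checked on the example's matrices. This sketch rebuilds $O = AV$ as the sum over keys of the rank-1 matrices formed by column $j$ of $A$ and row $j$ of $V$:

```python
import numpy as np

Q = np.array([[1., 0.], [0., 1.]])
K = np.array([[1., 0.], [1., 1.], [0., 1.]])
V = np.array([[1., 0.], [0., 2.], [1., 1.]])

scores = Q @ K.T / np.sqrt(2)
A = np.exp(scores) / np.exp(scores).sum(axis=1, keepdims=True)

# One rank-1 term per key j: outer(column j of A, row j of V), shape (n_q, d_v)
terms = [np.outer(A[:, j], V[j]) for j in range(V.shape[0])]
O_outer = sum(terms)

print(np.allclose(O_outer, A @ V))                       # True
print([int(np.linalg.matrix_rank(T)) for T in terms])    # [1, 1, 1]
```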
Kernel methods interpretation: Attention can be written as $O_i = \sum_j \kappa(q_i, k_j) V_j / \sum_j \kappa(q_i, k_j)$, where $\kappa(q, k) = \exp(q^\top k / \sqrt{d_k})$ is the exponential (dot-product) kernel; it equals a Gaussian kernel in $\|q - k\|$ up to factors that depend only on $\|q\|$ and $\|k\|$. This view enables kernel approximation techniques (random features, Nyström) for efficient attention (Performer, Linear Transformers).
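Both identities are easy to confirm numerically. This sketch (random inputs, arbitrary sizes) checks that the kernel-smoother form reproduces softmax attention exactly, and that $\exp(q^\top k / s)$ with $s = \sqrt{d_k}$ factors as $\exp(\|q\|^2/2s)\,\exp(-\|q - k\|^2/2s)\,\exp(\|k\|^2/2s)$, which is the sense in which the exponential kernel is a Gaussian kernel up to norm-dependent factors:

```python
import numpy as np

rng = np.random.default_rng(0)
n_q, n_k, d_k, d_v = 3, 4, 5, 2
Q = rng.standard_normal((n_q, d_k))
K = rng.standard_normal((n_k, d_k))
V = rng.standard_normal((n_k, d_v))
s = np.sqrt(d_k)

# Softmax form
scores = Q @ K.T / s
A = np.exp(scores - scores.max(axis=1, keepdims=True))
A /= A.sum(axis=1, keepdims=True)
O_softmax = A @ V

# Kernel-smoother form: O_i = sum_j kappa(q_i, k_j) V_j / sum_j kappa(q_i, k_j)
kappa = np.exp(Q @ K.T / s)
O_kernel = (kappa @ V) / kappa.sum(axis=1, keepdims=True)

# exp(q.k / s) = exp(|q|^2 / 2s) * exp(-|q - k|^2 / 2s) * exp(|k|^2 / 2s)
sq_q = np.sum(Q**2, axis=1)[:, None]                             # (n_q, 1)
sq_k = np.sum(K**2, axis=1)[None, :]                             # (1, n_k)
sq_dist = np.sum((Q[:, None, :] - K[None, :, :])**2, axis=-1)    # (n_q, n_k)
kappa_via_gauss = np.exp(sq_q / (2 * s)) * np.exp(-sq_dist / (2 * s)) * np.exp(sq_k / (2 * s))

print(np.allclose(O_softmax, O_kernel), np.allclose(kappa, kappa_via_gauss))   # True True
```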
Numerical and Implementation Notes
Shape discipline: $Q \in \mathbb{R}^{n_q \times d_k}$ (2 queries, 2-dim), $K \in \mathbb{R}^{n_k \times d_k}$ (3 keys, 2-dim), $V \in \mathbb{R}^{n_k \times d_v}$ (3 values, 2-dim). Scores: $(n_q, n_k) = (2, 3)$. Attention weights: $(n_q, n_k) = (2, 3)$. Output: $(n_q, d_v) = (2, 2)$. Check dimensions before every matrix multiply: Q @ K.T requires Q.shape[1] == K.shape[1] ($d_k$ match), and A @ V requires A.shape[1] == V.shape[0] ($n_k$ match).
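The two compatibility checks can be written as assertions before any multiply. In this sketch the helper attention_shapes_ok is hypothetical; it returns the expected shapes, and the last lines show why forgetting the transpose is dangerous: it raises an error for the example's shapes but runs silently (and wrongly) when $d_k = n_k$:

```python
import numpy as np


def attention_shapes_ok(Q, K, V):
    """Assert the two matmul compatibilities and return the expected shapes."""
    assert Q.shape[1] == K.shape[1], f"d_k mismatch: Q has {Q.shape[1]}, K has {K.shape[1]}"
    assert K.shape[0] == V.shape[0], f"n_k mismatch: K has {K.shape[0]} rows, V has {V.shape[0]}"
    n_q, n_k, d_v = Q.shape[0], K.shape[0], V.shape[1]
    return {"scores": (n_q, n_k), "A": (n_q, n_k), "O": (n_q, d_v)}


Q, K, V = np.zeros((2, 2)), np.zeros((3, 2)), np.zeros((3, 2))   # the example's shapes
print(attention_shapes_ok(Q, K, V))   # {'scores': (2, 3), 'A': (2, 3), 'O': (2, 2)}

# Forgetting the transpose fails loudly here: inner dimensions are 2 and 3
try:
    Q @ K
    forgot_transpose_failed = False
except ValueError as err:
    forgot_transpose_failed = True
    print("Q @ K raised ValueError:", err)

# ... but runs silently (and wrongly) when d_k happens to equal n_k
Q3 = np.array([[1., 0., 0.], [0., 1., 0.]])
K3 = np.arange(9.).reshape(3, 3)          # not symmetric, so K3 != K3.T
print(np.allclose(Q3 @ K3, Q3 @ K3.T))     # False: Q3 @ K3 is not the score matrix
```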
Gotcha 1: Softmax axis. softmax(scores, axis=1) normalizes across keys (axis=1, columns) for each query (axis=0, rows). Using axis=0 would normalize across queries, which is meaningless (each key would distribute attention to queries, backwards). Always normalize across the key dimension.
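Using the example's scaled scores, this sketch contrasts the two axes (softmax here is a local stable helper, assumed to behave like the one imported in the example). With axis=0 the columns sum to 1 instead of the rows, and in this example every row sums to 1.5, so the rows are not distributions over keys:

```python
import numpy as np


def softmax(x, axis=-1):
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


scores = np.array([[1., 1., 0.],
                   [0., 1., 1.]]) / np.sqrt(2)   # the example's scaled scores

A_right = softmax(scores, axis=1)   # each query's weights over the keys
A_wrong = softmax(scores, axis=0)   # each key's weights over the queries

print(A_right.sum(axis=1))   # rows sum to 1
print(A_wrong.sum(axis=1))   # rows sum to 1.5 here: not distributions over keys
print(A_wrong.sum(axis=0))   # with axis=0 it is the columns that sum to 1
```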
Gotcha 2: Scaling factor. Use $1/\sqrt{d_k}$, not $1/\sqrt{d_v}$ or $1/\sqrt{n_k}$. The goal is to counteract the growth of dot-product variance with the key dimension ($\text{Var}(q^\top k) = d_k$ when the components of $q$ and $k$ are independent with zero mean and unit variance).
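Gotcha 3: Masked attention. For autoregressive models (GPT), mask future positions: mask = np.triu(np.ones((n_q, n_k)), k=1) (strictly upper triangular), then scores = scores - 1e9 * mask, which drives future scores toward $-\infty$ so their softmax weights become $\approx 0$. Setting masked scores to -inf also works, but if every key in a row is masked the softmax computes $0/0$ and returns NaN; a large finite negative value avoids that edge case.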
Gotcha 4: Batch dimensions. Production attention has shape (batch, num_heads, seq_len, seq_len). Ensure softmax normalizes over the last dimension (keys): F.softmax(scores, dim=-1). Broadcasting rules: (B, H, N_q, N_k) @ (B, H, N_k, D_v) -> (B, H, N_q, D_v).
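Gotcha 5: Numerical stability. A naive softmax can overflow ($\exp(1000) \to \infty$, and then $\infty/\infty$ gives NaN), and if every score in a row is very negative, all terms underflow ($\exp(-1000) \to 0$) and the normalization becomes $0/0$. Use the log-sum-exp trick: $\log \sum_j \exp(x_j) = \max_j x_j + \log \sum_j \exp(x_j - \max_j x_j)$, i.e., subtract the row maximum before exponentiating. PyTorch's F.softmax handles this automatically.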
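A minimal NumPy comparison of naive and max-shifted softmax (helper names are hypothetical). The naive version turns $\exp(1000)$ into inf and then $\infty/\infty$ into NaN; the shifted version is exact up to rounding. The last lines show a related masking edge case: a row of all $-\infty$ scores yields NaN even with the shift, whereas a large finite mask value gives a uniform row:

```python
import numpy as np


def softmax_naive(x):
    e = np.exp(x)
    return e / e.sum()


def softmax_stable(x):
    e = np.exp(x - x.max())     # largest exponent becomes exp(0) = 1
    return e / e.sum()


def logsumexp(x):
    m = x.max()
    return m + np.log(np.sum(np.exp(x - m)))


x = np.array([1000.0, 999.0, 0.0])
with np.errstate(over="ignore", invalid="ignore"):
    naive = softmax_naive(x)    # exp(1000) overflows to inf, and inf / inf gives nan
stable = softmax_stable(x)
print(naive)                    # nan in the first two entries
print(np.round(stable, 4))      # approximately [0.7311, 0.2689, 0.0]
print(logsumexp(x))             # approximately 1000.3133

# Masking edge case: every score in the row is -inf
with np.errstate(invalid="ignore"):
    all_masked = softmax_stable(np.full(3, -np.inf))   # (-inf) - (-inf) = nan
print(all_masked)                                     # all nan
print(softmax_stable(np.full(3, -1e9)))               # finite mask value: uniform 1/3 each
```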
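Gotcha 6: FlashAttention. For long sequences (e.g., $n > 2048$), materializing the full $QK^\top \in \mathbb{R}^{n \times n}$ for every head can exceed GPU memory. FlashAttention (Dao et al., 2022) fuses the attention computation into a single GPU kernel and processes it in tiles, so the full attention matrix is never stored: the extra memory grows linearly rather than quadratically in $n$, the result is exact, and the paper reports substantial wall-clock speedups over standard implementations.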
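The core trick that lets FlashAttention avoid storing the $n \times n$ matrix is an online (streaming) softmax: process keys in blocks while keeping a running maximum, a running normalizer, and a running weighted sum, rescaling all three whenever the maximum grows. The NumPy sketch below shows only that rescaling logic (the helper tiled_attention is hypothetical); the real kernel also tiles queries, keeps tiles in on-chip SRAM, and recomputes attention in the backward pass, none of which is modeled here, so this version has no memory or speed benefit:

```python
import numpy as np


def full_attention(Q, K, V):
    """Reference: materializes the whole (n_q, n_k) score matrix."""
    s = Q @ K.T / np.sqrt(Q.shape[1])
    A = np.exp(s - s.max(axis=1, keepdims=True))
    return (A / A.sum(axis=1, keepdims=True)) @ V


def tiled_attention(Q, K, V, block=2):
    """Streams over key/value blocks with an online softmax; never stores all scores."""
    n_q, d_k = Q.shape
    m = np.full(n_q, -np.inf)            # running max score per query
    l = np.zeros(n_q)                    # running sum of exp(score - m)
    acc = np.zeros((n_q, V.shape[1]))    # running sum of exp(score - m) * value
    for start in range(0, K.shape[0], block):
        Kb, Vb = K[start:start + block], V[start:start + block]
        s = Q @ Kb.T / np.sqrt(d_k)                 # scores for this block only
        m_new = np.maximum(m, s.max(axis=1))
        correction = np.exp(m - m_new)              # rescales old sums; 0 on the first block
        p = np.exp(s - m_new[:, None])
        l = l * correction + p.sum(axis=1)
        acc = acc * correction[:, None] + p @ Vb
        m = m_new
    return acc / l[:, None]


rng = np.random.default_rng(0)
Q = rng.standard_normal((4, 3))
K = rng.standard_normal((7, 3))
V = rng.standard_normal((7, 2))

O_ref = full_attention(Q, K, V)
for block in (1, 3, 7):                   # block=3 leaves a ragged final block
    print(block, np.allclose(tiled_attention(Q, K, V, block), O_ref))
```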
Numerical and Shape Notes
Verification checks: (1) Row sums: A.sum(axis=1) should equal [1., 1.] (each query's weights sum to 1). (2) Non-negativity: every element satisfies A >= 0 (softmax guarantees this). (3) Output shape: O.shape == (n_q, d_v) (2 queries, 2-dim values). (4) Gradient correctness: when a backward pass is written by hand, compare its analytical gradients against finite differences (torch.autograd.gradcheck does this, typically in double precision).
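Cost analysis: For $n_q = n_k = n = 512$ (a typical maximum input length for BERT-style models) and $d_k = d_v = 64$ (per-head dimension), counting multiply-adds (each is about 2 FLOPs):
- $QK^\top$: $n^2 d_k = 512^2 \times 64 \approx 16.8$M
- Softmax: $\sim n^2 = 512^2 \approx 262$K exponentials and divisions (negligible)
- $AV$: $n^2 d_v = 512^2 \times 64 \approx 16.8$M
- Total: $\sim 33.6$M multiply-adds per head, or $\sim 270$M per layer for 8 heads.
For GPT-3-scale shapes (96 layers, 96 heads of dimension 128, 2048-token context), the two attention products alone come to $96 \times 96 \times 2 \times 2048^2 \times 128 \approx 9.9 \times 10^{12}$ multiply-adds per forward pass, excluding the projections and MLPs.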
Memory: The attention matrix $A$ requires $n^2$ entries per head per layer. At $n = 16384$ (a Longformer-scale context), a dense $A$ would have $16384^2 \approx 268$M elements ($\sim 1$ GB in float32). This is the memory bottleneck of standard Transformers; efficient variants (sparse, low-rank) reduce it to $O(n)$ or $O(n \log n)$.
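The counts above are simple products, so a few lines of Python reproduce them. This sketch counts multiply-adds for the two attention products only (softmax, projections, and MLPs are excluded); the helper names are hypothetical, and the GPT-3-scale shapes are the 175B configuration listed in the GPT-3 paper, used here purely for illustration:

```python
def attention_macs(n_q, n_k, d_k, d_v):
    """Multiply-adds for Q K^T plus A V (softmax, projections, and MLPs excluded)."""
    return n_q * n_k * d_k + n_q * n_k * d_v


def dense_attention_bytes(n, bytes_per_elem=4):
    """Memory for one dense (n x n) attention matrix: one head, one layer, float32 by default."""
    return n * n * bytes_per_elem


per_head = attention_macs(512, 512, 64, 64)
per_layer_8_heads = 8 * per_head
# GPT-3 175B shapes: 96 layers, 96 heads of dimension 128, 2048-token context
gpt3_scale = 96 * 96 * attention_macs(2048, 2048, 128, 128)

print(per_head)                                       # 33554432 (about 33.6M)
print(per_layer_8_heads)                              # 268435456 (about 270M)
print(f"{gpt3_scale:.2e}")                            # 9.90e+12
print(dense_attention_bytes(16384) / 2**30, "GiB")    # 1.0 GiB
```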
Tolerance: A softmax implemented with the max-subtraction (log-sum-exp) trick is numerically stable, but attention weights can be tiny ($< 10^{-6}$) for dissimilar query-key pairs. Check for underflow: A.min() should be $> 0$ (exact zeros may indicate numerical issues, masking, or extreme sparsity).
Pedagogical Significance
This example is the foundational demonstration of attention mechanisms in modern deep learning:
Key takeaways: (1) Attention = weighted aggregation: dynamically select and combine information based on query-key similarity. (2) Query-key-value paradigm: separates "what to search for" (query) from "what's available" (key) and "what to retrieve" (value). (3) Softmax normalization: converts raw scores to probabilities (row-stochastic, interpretable weights). (4) Scaling prevents saturation: $1/\sqrt{d_k}$ keeps scores in a moderate range for stable gradients. (5) No built-in order: raw attention is permutation-equivariant (it treats its inputs as a set), so positional encodings must inject sequence structure.
Common misconceptions addressed: (a) "Attention requires recurrence": no, Transformers are fully feedforward (parallel), unlike RNNs. (b) "Attention is approximate": no, it is an exact weighted sum, fully differentiable. (c) "The softmax axis doesn't matter": it is critical; softmax must normalize across keys (each query distributes its attention). (d) "Attention always attends uniformly": no, exponential weighting (softmax) creates sharp distributions for high-similarity pairs. (e) "Transformers don't need positional info": false; raw attention cannot distinguish positions (it is permutation-equivariant), so positional encodings are essential.
Connection to other examples: Example 70 (attention as data-dependent projection), Example 72 (query-key similarity via inner products), Example 79 (sparse attention masks), Example 74 (multi-head as low-rank approximation). This example provides the implementation foundation for understanding BERT, GPT, Vision Transformers, and diffusion modelsâall built on variants of this core mechanism.
Why pedagogically powerful: It isolates the core attention operation (scaled dot-product attention) in minimal code (~10 lines), showing how three matrices ($Q, K, V$) and two operations (matmul + softmax) produce dynamic context aggregation. The small dimensions (2 queries, 3 keys) make manual calculation feasible, demystifying the "black box" of Transformers. Students see that attention is fundamentally linear algebra (matrix products + softmax), not magic. This is the gateway to understanding modern NLP, computer vision, and generative AI: attention is the computational primitive underlying language models (GPT-4, Gemini), vision models (ViT, CLIP), and multimodal systems (DALL-E, Stable Diffusion).
Historical foundations: Attention mechanisms originated in neural machine translation (NMT) to address the information bottleneck of encoder-decoder architectures. Before Bahdanau et al. (2015), RNN-based seq2seq models compressed entire source sentences into a single fixed-length vector $h$ (the final encoder hidden state), forcing $h$ to encode all the information, which becomes increasingly difficult as sentences get longer. Bahdanau's additive attention let the decoder dynamically attend to all encoder states $h_1, \ldots, h_n$, computing weighted sums $c_t = \sum_i \alpha_{ti} h_i$ at each decoding step $t$. The weights $\alpha_{ti}$ came from a small neural network (the "alignment model") that scored query-key compatibility, $e_{ti} = \text{MLP}(s_{t-1}, h_i)$, followed by a softmax over source positions, $\alpha_{ti} = \exp(e_{ti}) / \sum_{i'} \exp(e_{ti'})$. This improved BLEU scores by 5-10 points on WMT translation benchmarks, demonstrating that dynamic, context-dependent routing of information was far superior to fixed bottlenecks.
Vaswani et al. (2017) revolutionized the field with "Attention Is All You Need," introducing Transformers. They replaced recurrent layers entirely with self-attention, where each token attends to all other tokens in parallel. Their key innovations: (1) Scaled dot-product attention ($QK^\top / \sqrt{d_k}$) replaced additive attention: simpler and faster in practice (pure BLAS3 matrix multiplies). (2) Multi-head attention ran $h = 8$ heads in parallel, each able to learn different dependency patterns (syntactic, semantic, positional). (3) Positional encodings (sinusoidal) injected sequence order, compensating for attention's lack of built-in order. (4) No recurrence: all tokens in a layer are processed simultaneously, which sped up training substantially on GPUs/TPUs (RNNs require $n$ serial steps per layer). On WMT English-German translation, Transformers achieved 28.4 BLEU (+2 over the previous state of the art, ConvS2S), trained in 3.5 days on 8 P100 GPUs, a fraction of the training cost of earlier models.
Modern ML applications: Post-2017, Transformers became the dominant architecture across domains. BERT (Devlin et al., 2019) introduced bidirectional pre-training via masked language modeling (predict 15% of masked tokens from context) on 3.3B words (Wikipedia + BookCorpus). With 340M parameters (24 layers, 16 heads), BERT achieved state-of-the-art results on 11 NLP tasks (including GLUE and SQuAD), often by 5-10% absolute accuracy. BERT's attention learned syntactic dependencies (subject-verb agreement spans 20+ tokens), coreference resolution (pronouns attend to antecedents), and semantic similarity, all emergent from self-supervised learning with no hand-engineered features.
GPT lineage (Radford et al., 2018, 2019; Brown et al., 2020; OpenAI, 2023): Autoregressive language modeling ("predict the next token") scaled to massive datasets and parameter counts. GPT-2 (1.5B params, 40GB WebText) demonstrated zero-shot generalization: models could translate, summarize, and answer questions without task-specific fine-tuning. GPT-3 (175B params, 300B tokens) showed strong few-shot performance across reading comprehension, question answering, and writing tasks. ChatGPT (2022) brought these models to conversational use, and GPT-4 (2023, parameter count not publicly disclosed) handles multi-turn conversation, code generation, and reasoning. Key architectural detail: GPT uses masked (causal) attention: position $i$ attends only to positions $1, \ldots, i$ (future positions are masked to $-\infty$), which makes autoregressive generation possible.
Vision Transformers (ViT): Dosovitskiy et al. (2021) applied Transformers to images by splitting each image into $16 \times 16$ patches, flattening the patches into a 1D sequence, and treating them as tokens. When pre-trained on large datasets (JFT-300M), ViT models matched or exceeded strong ResNet-based baselines on ImageNet (the largest, ViT-H/14, reached 88.55% top-1; ViT-Large has 307M parameters), with ViT overtaking the CNNs as pre-training data grew. Self-attention in ViT learns spatial relationships: the average attention distance tends to grow with depth, with many early-layer heads attending locally and later layers attending more globally. Later models build on ViT backbones: CLIP pairs a ViT image encoder with a Transformer text encoder for zero-shot image classification and cross-modal retrieval, DINO trains ViTs with self-supervision, and SAM uses a ViT image encoder for promptable segmentation.
Diffusion models: Stable Diffusion (Rombach et al., 2022) and DALL-E 2 (Ramesh et al., 2022) use cross-attention to condition image generation on text. A text encoder (CLIP or T5) produces embeddings $E_{\text{text}} \in \mathbb{R}^{n_t \times d}$; the denoising network (a U-Net with Transformer blocks) operates on a latent image $I \in \mathbb{R}^{n_p \times d}$ ($n_p$ spatial positions). At each diffusion step, image queries $Q = I W_Q$ attend to text keys $K = E_{\text{text}} W_K$ (with values $V = E_{\text{text}} W_V$), pulling semantic information from the prompt into the image: for "a cat wearing a hat", cross-attention lets cat regions draw on the "cat" token and hat regions on "hat". This mechanism enables fine-grained text-to-image alignment, generating 512×512 images in seconds on consumer GPUs.
Efficient attention (2019-2023): Quadratic complexity ($O(n^2)$) historically kept typical Transformer contexts around $n \sim 2048$ (GPU memory). Longformer (Beltagy et al., 2020) and BigBird (Zaheer et al., 2020) use sparse attention patterns (local + global, plus random in BigBird): each token attends to $w$ neighbors + $g$ global tokens + $r$ random tokens, reducing the cost to $O(n(w + g + r)) = O(n)$ for fixed $w, g, r$ and enabling 4096-16384 token contexts. Linformer (Wang et al., 2020) approximates attention via low-rank projection along the sequence axis, $\text{softmax}(Q (EK)^\top / \sqrt{d_k})\,(FV)$ with $E, F \in \mathbb{R}^{k \times n}$ ($k = 256 \ll n$), reducing the cost to $O(nk)$. Performer (Choromanski et al., 2020) uses kernel approximation: $\exp(q^\top k) \approx \phi(q)^\top \phi(k)$ with random features $\phi: \mathbb{R}^d \to \mathbb{R}^m$ ($m = 256$), then factorizes $\sum_j \phi(q_i)^\top \phi(k_j)\, v_j^\top = \phi(q_i)^\top \big(\sum_j \phi(k_j) v_j^\top\big)$. The $m \times d_v$ matrix $\sum_j \phi(k_j) v_j^\top$ is computed once ($O(n m d_v)$), after which each query costs $O(m d_v)$, so the total is linear in $n$ (the softmax normalizer $\phi(q_i)^\top \sum_j \phi(k_j)$ factorizes the same way). FlashAttention (Dao et al., 2022) exploits IO-aware tiling: instead of materializing the full $A \in \mathbb{R}^{n \times n}$ in HBM (GPU main memory), it computes attention in blocks that fit in SRAM (on-chip memory), rescaling partial softmax sums as it goes. This reduces extra memory from $O(n^2)$ to $O(n)$ with no approximation and gives substantial speedups over standard implementations, helping make long-context training practical (e.g., Llama-2-Long, 2023).
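A sketch of the Performer-style factorization using positive random features, $\phi(x) = \exp(Wx - \|x\|^2/2)/\sqrt{m}$ with Gaussian $W$, for which $\mathbb{E}[\phi(q)^\top \phi(k)] = \exp(q^\top k)$. It is a simplified illustration (the published method adds refinements such as orthogonal random features), and the helper names are hypothetical. Scaling queries and keys by $d^{-1/4}$ each reproduces the $1/\sqrt{d}$ factor, and the inputs are kept small because the estimator's variance grows quickly with their norms:

```python
import numpy as np


def softmax_attention(Q, K, V):
    s = Q @ K.T / np.sqrt(Q.shape[1])
    A = np.exp(s - s.max(axis=1, keepdims=True))
    return (A / A.sum(axis=1, keepdims=True)) @ V


def positive_random_features(X, W):
    """phi(x) = exp(W x - |x|^2 / 2) / sqrt(m); E[phi(q) . phi(k)] = exp(q . k) for W ~ N(0, I)."""
    m = W.shape[0]
    return np.exp(X @ W.T - 0.5 * np.sum(X**2, axis=1, keepdims=True)) / np.sqrt(m)


def random_feature_attention(Q, K, V, m, rng):
    d = Q.shape[1]
    W = rng.standard_normal((m, d))                  # random projection, drawn once per call
    scale = d ** -0.25                               # so that q' . k' = q . k / sqrt(d)
    phi_q = positive_random_features(Q * scale, W)   # (n_q, m)
    phi_k = positive_random_features(K * scale, W)   # (n_k, m)
    KV = phi_k.T @ V                                 # (m, d_v): sum_j phi(k_j) v_j^T, computed once
    z = phi_k.sum(axis=0)                            # (m,): sum_j phi(k_j), for the normalizer
    return (phi_q @ KV) / (phi_q @ z)[:, None]       # never forms the (n_q, n_k) matrix


rng = np.random.default_rng(0)
n_q, n_k, d, d_v = 5, 8, 4, 3
Q = 0.3 * rng.standard_normal((n_q, d))     # small norms keep the estimator's variance low
K = 0.3 * rng.standard_normal((n_k, d))
V = rng.standard_normal((n_k, d_v))

exact = softmax_attention(Q, K, V)
approx = random_feature_attention(Q, K, V, m=20000, rng=rng)
print("max abs error:", np.abs(exact - approx).max())
```

The approximation never forms the $n_q \times n_k$ matrix, so its cost is linear in $n_q$ and $n_k$, at the price of Monte Carlo error that shrinks roughly like $1/\sqrt{m}$.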
Current frontiers (2023-2024): Transformers dominate language (GPT-4, Claude, Gemini), vision (ViT, DINOv2), multimodal models (CLIP, Flamingo, GPT-4V), reinforcement learning (Decision Transformer), biology (AlphaFold2), and code generation (Copilot, CodeLlama). Mixture-of-experts models (Switch Transformer, Mixtral 8x7B) scale total parameter counts, up to trillions in Switch Transformer, via sparse routing: each token activates only a small subset of experts (top-1 in Switch Transformer, top-2 of 8 in Mixtral). Retrieval-augmented generation (RAG) combines attention over retrieved documents (from vector databases) with parametric knowledge, grounding LLMs in external sources. Long-context models (Claude 100k, GPT-4 Turbo 128k) process entire books or codebases; publicly documented techniques for windows of this size include efficient attention kernels and positional schemes such as RoPE (extended via position interpolation) and ALiBi. From 512-token BERT (2019) to 1M-token research prototypes (2024), scaled dot-product attention, $\text{softmax}(QK^\top / \sqrt{d_k})\,V$, remains the universal building block for sequence modeling and the computational primitive underlying the AI revolution.
Example 72 (Inner products and cosine similarity): Attention scores $q_i^\top k_j$ are inner products measuring query-key similarity. For unit-norm $q, k$, the score is exactly the cosine similarity. The scaling factor $1/\sqrt{d_k}$ normalizes for dimension, analogous to dividing by norms in cosine similarity, except that it is a fixed constant rather than depending on the actual norms of $q_i$ and $k_j$.
Example 70 (Projections): Attention output $O_i = \sum_j A_{ij} V_j$ is a convex combination of the value vectors, so it lies in the convex hull of $\{V_1, \ldots, V_{n_k}\}$. It acts as a data-dependent "projection" onto that hull, but the analogy is loose: the weights come from softmax-normalized query-key scores rather than from an orthogonal least-squares fit.
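The convex-hull property follows from row-stochasticity and is easy to verify. A minimal sketch with random toy matrices (sizes and seed are arbitrary assumptions): every row of $A$ is nonnegative and sums to 1, so every coordinate of each output row must lie between the smallest and largest value of that coordinate across the rows of $V$.

```python
import numpy as np

rng = np.random.default_rng(1)
n_q, n_k, d_k, d_v = 5, 8, 4, 3            # toy sizes (assumed)
Q = rng.standard_normal((n_q, d_k))
K = rng.standard_normal((n_k, d_k))
V = rng.standard_normal((n_k, d_v))

S = Q @ K.T / np.sqrt(d_k)
A = np.exp(S - S.max(axis=1, keepdims=True))
A /= A.sum(axis=1, keepdims=True)           # row-wise softmax
O = A @ V                                   # O_i = sum_j A_ij V_j

# Each row of A is a probability vector over the n_k keys ...
print(np.all(A >= 0), np.allclose(A.sum(axis=1), 1.0))
# ... so each O_i is a convex combination of the rows of V and stays within their range.
print(np.all(O >= V.min(axis=0) - 1e-12), np.all(O <= V.max(axis=0) + 1e-12))
```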
Example 74 (SVD and low-rank approximations): Multi-head attention can be viewed as a low-rank mechanism: each head scores token pairs with the bilinear form $x_i^\top W_Q^{(h)} W_K^{(h)\top} x_j$, whose matrix has rank at most $d_k$ (typically $d_{\text{model}}/h$). Concatenating the heads combines these low-rank views into a higher-rank representation, loosely similar to truncated SVD retaining several top singular directions.
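The rank bound per head can be checked directly. This sketch assumes random Gaussian projection matrices and the common split $d_k = d_{\text{model}}/h$; it computes the rank of each head's $d_{\text{model}} \times d_{\text{model}}$ score matrix $W_Q^{(h)} W_K^{(h)\top}$, which factors through only $d_k$ dimensions.

```python
import numpy as np

rng = np.random.default_rng(2)
d_model, h = 64, 8
d_k = d_model // h                       # 8 dimensions per head (assumed split)

ranks = []
for head in range(h):
    W_Q = rng.standard_normal((d_model, d_k)) / np.sqrt(d_model)
    W_K = rng.standard_normal((d_model, d_k)) / np.sqrt(d_model)
    # Scores are X W_Q (X W_K)^T = X (W_Q W_K^T) X^T; the middle matrix is
    # d_model x d_model but is a product through d_k dims, so rank <= d_k.
    M = W_Q @ W_K.T
    ranks.append(int(np.linalg.matrix_rank(M)))
print(ranks)
```

With random weights each head generically attains the full rank $d_k$, far below $d_{\text{model}}$.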
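Example 76 (Least squares and regularization): Attention dropout (randomly zeroing attention weights during training and rescaling the survivors) acts as regularization, discouraging the model from overfitting to specific key-value dependencies. This loosely parallels ridge regression ($\ell_2$ penalty) or Lasso ($\ell_1$ penalty) in least squares; for linear models, dropout on the inputs behaves in expectation like an adaptive $\ell_2$ penalty.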
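A minimal sketch of attention dropout, assuming the common "inverted dropout" convention (zero each weight with probability $p$, scale survivors by $1/(1-p)$); the helper name `attention_dropout` is hypothetical. A single draw breaks row-stochasticity, but averaged over many draws the weights are unbiased, which is why no rescaling is needed at inference time.

```python
import numpy as np

def attention_dropout(A, p, rng):
    # Inverted dropout (training time only): zero each weight with probability p
    # and rescale survivors by 1 / (1 - p), so the expected weights equal A.
    keep = rng.random(A.shape) >= p
    return A * keep / (1.0 - p)

rng = np.random.default_rng(3)
S = rng.standard_normal((4, 6))
A = np.exp(S - S.max(axis=1, keepdims=True))
A /= A.sum(axis=1, keepdims=True)             # row-stochastic attention weights

one_draw = attention_dropout(A, p=0.1, rng=rng)
print(one_draw.sum(axis=1))                   # rows no longer sum exactly to 1

N = 20000
many = attention_dropout(np.broadcast_to(A, (N, *A.shape)), p=0.1, rng=rng)
print(np.abs(many.mean(axis=0) - A).max())    # small: unbiased on average
```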
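Example 79 (Sparse matrices): Efficient attention variants (Longformer, BigBird) use sparse attention patterns (local windows + global tokens, plus random tokens in BigBird) to reduce $O(n^2)$ complexity to $O(n)$ for a fixed window size and a fixed number of global and random tokens. The attention matrix $A$ becomes sparse (most entries masked to zero) and can be viewed like a sparse adjacency matrix in COO/CSR form, although practical implementations use banded or block-sparse kernels rather than general sparse formats.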
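The sparsity pattern can be made concrete as a boolean mask. This is a minimal sketch of a sliding-window-plus-global pattern, assuming $w/2$ neighbors on each side and token 0 as the only global token; the function names are hypothetical, and real Longformer/BigBird kernels never build the dense $n \times n$ mask. Counting allowed entries shows linear growth in $n$, and masked entries get exactly zero weight after the softmax.

```python
import numpy as np

def sliding_window_global_mask(n, w, global_idx):
    # True where query i may attend to key j.
    i = np.arange(n)[:, None]
    j = np.arange(n)[None, :]
    mask = np.abs(i - j) <= w // 2      # local band: w // 2 neighbors on each side
    mask[global_idx, :] = True          # global tokens attend to every position
    mask[:, global_idx] = True          # every position attends to global tokens
    return mask

def masked_softmax(S, mask):
    # Disallowed entries get -inf, i.e. exactly zero weight after the softmax.
    S = np.where(mask, S, -np.inf)
    S = S - S.max(axis=1, keepdims=True)   # diagonal is always allowed, so max is finite
    E = np.exp(S)
    return E / E.sum(axis=1, keepdims=True)

w, global_idx = 64, [0]
nnz = {n: int(sliding_window_global_mask(n, w, global_idx).sum()) for n in (1024, 2048)}
print(nnz, nnz[2048] / nnz[1024])       # roughly doubles with n; dense would quadruple

rng = np.random.default_rng(6)
n = 16
mask = sliding_window_global_mask(n, 4, [0])
A = masked_softmax(rng.standard_normal((n, n)), mask)
print(np.allclose(A.sum(axis=1), 1.0), np.all(A[~mask] == 0))
```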
Example 71 (Norms and distances): If the entries of $q$ and $k$ are independent with mean 0 and variance 1, then $q^\top k$ has variance $d_k$, so the scaling factor $1/\sqrt{d_k}$ rescales query-key dot products to unit variance. Without it, scores grow with dimension, the softmax saturates, and gradients vanish. This is analogous to feature scaling ($X / \|X\|_2$) in least squares to improve conditioning.
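The variance argument can be checked by simulation. This sketch assumes standard normal query and key entries (an idealization of the assumption above) and estimates the variance of $q^\top k$ before and after scaling; it also shows that for random unit vectors the raw dot product already has variance $1/d_k$, so the $1/\sqrt{d_k}$ factor targets unnormalized vectors.

```python
import numpy as np

rng = np.random.default_rng(4)
N, d_k = 20000, 64
q = rng.standard_normal((N, d_k))          # entries i.i.d. with mean 0, variance 1
k = rng.standard_normal((N, d_k))
dots = np.sum(q * k, axis=1)               # N independent samples of q^T k
scaled = dots / np.sqrt(d_k)
print(dots.var(), scaled.var())            # approximately d_k = 64 and 1

# For unit vectors the raw dot product has variance 1/d_k, not 1.
u = q / np.linalg.norm(q, axis=1, keepdims=True)
v = k / np.linalg.norm(k, axis=1, keepdims=True)
print(np.sum(u * v, axis=1).var() * d_k)   # approximately 1
```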
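Chapter 11 (PCA): Attention weights $A$ can be interpreted as a soft clustering mechanism: queries assign “cluster membership” (attention weights) to keys. This parallels soft k-means and, more loosely, PCA (projecting onto principal components with data-dependent weights). When all keys have equal norm, the correspondence with soft k-means is exact: expanding $\|q_i - k_j\|^2 = \|q_i\|^2 - 2 q_i^\top k_j + \|k_j\|^2$, the terms that do not depend on $j$ cancel in the row-wise softmax, so responsibilities $\propto \exp(-\beta \|q_i - k_j\|^2)$ with $\beta = 1/(2\sqrt{d_k})$ equal $\text{softmax}(QK^\top/\sqrt{d_k})$.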
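The soft k-means equivalence for equal-norm keys can be verified numerically. In this sketch the keys play the role of cluster centroids and are normalized to unit length (an assumption needed for exactness); the sizes and seed are arbitrary.

```python
import numpy as np

def softmax_rows(S):
    S = S - S.max(axis=1, keepdims=True)
    E = np.exp(S)
    return E / E.sum(axis=1, keepdims=True)

rng = np.random.default_rng(5)
n_q, n_k, d = 6, 4, 8
Q = rng.standard_normal((n_q, d))
K = rng.standard_normal((n_k, d))
K /= np.linalg.norm(K, axis=1, keepdims=True)    # equal-norm keys ("centroids")

beta = 1.0 / (2.0 * np.sqrt(d))
sq_dist = ((Q[:, None, :] - K[None, :, :]) ** 2).sum(axis=2)   # (n_q, n_k)
R = softmax_rows(-beta * sq_dist)                # soft k-means responsibilities
A = softmax_rows(Q @ K.T / np.sqrt(d))           # attention weights
print(np.allclose(R, A))                         # True
```

If the key norms differ, the $\|k_j\|^2$ term no longer cancels and the two assignments diverge.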
Chapter 16 (Matrix products): This example showcases composition of matrix products as a design pattern: $O = \text{softmax}(QK^\top / \sqrt{d_k})\, V$ chains two matrix multiplies through a row-wise softmax nonlinearity. GPUs and TPUs are built for fast dense matrix multiplication, and fused kernels such as FlashAttention execute the whole chain in one pass so the intermediate $n_q \times n_k$ matrix is never written to main memory.
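The fused pattern rests on an "online softmax" that processes keys in blocks while keeping only running statistics. Below is a minimal NumPy sketch of that recurrence, assuming tiling over keys and values only; the real FlashAttention kernel also tiles queries, runs in on-chip SRAM on the GPU, and recomputes tiles in the backward pass, none of which is modeled here. The function names are hypothetical. The blocked result matches exact attention without ever holding the full $n_q \times n_k$ score matrix.

```python
import numpy as np

def blocked_attention(Q, K, V, block=16):
    # Exact attention over key/value blocks with a running (online) softmax.
    n_q, d_k = Q.shape
    m = np.full(n_q, -np.inf)             # running row max of the scores
    l = np.zeros(n_q)                     # running softmax denominator
    acc = np.zeros((n_q, V.shape[1]))     # running unnormalized output
    for start in range(0, K.shape[0], block):
        Kb, Vb = K[start:start + block], V[start:start + block]
        S = Q @ Kb.T / np.sqrt(d_k)                   # scores for this tile only
        m_new = np.maximum(m, S.max(axis=1))
        scale = np.exp(m - m_new)                     # rescale earlier partial sums
        P = np.exp(S - m_new[:, None])
        l = l * scale + P.sum(axis=1)
        acc = acc * scale[:, None] + P @ Vb
        m = m_new
    return acc / l[:, None]

def reference_attention(Q, K, V):
    # Standard implementation that materializes the full score matrix.
    S = Q @ K.T / np.sqrt(Q.shape[1])
    A = np.exp(S - S.max(axis=1, keepdims=True))
    A /= A.sum(axis=1, keepdims=True)
    return A @ V

rng = np.random.default_rng(7)
Q = rng.standard_normal((32, 16))
K = rng.standard_normal((100, 16))        # 100 keys: the last block is partial
V = rng.standard_normal((100, 8))
print(np.allclose(blocked_attention(Q, K, V), reference_attention(Q, K, V)))
```

Rescaling by $\exp(m_{\text{old}} - m_{\text{new}})$ is what lets earlier blocks be corrected when a larger score appears later, so the tiling introduces no approximation.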