Example number
48
Example slug
example_48_transformers_attention_as_qk_t_then_av
Background

Historical context: The attention mechanism originated in sequence-to-sequence models (Bahdanau et al., 2015) as a way to let decoders selectively focus on encoder states, replacing fixed-length context vectors with weighted sums. The breakthrough came with “Attention Is All You Need” (Vaswani et al., 2017), which dispensed with RNNs entirely and built transformers purely from attention and feedforward layers. The key insight: attention is query-key-value retrieval via matrix products, not a specialized neural module. This reframing made attention parallelizable (all queries attend simultaneously) and hardware-efficient (GPUs/TPUs are optimized for dense matrix multiplication).

Mathematical characterization: Scaled dot-product attention computes \[ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right) V, \] where $Q \in \mathbb{R}^{n_q \times d_k}$ (queries), $K \in \mathbb{R}^{n_k \times d_k}$ (keys), $V \in \mathbb{R}^{n_k \times d_v}$ (values). The first product $QK^\top \in \mathbb{R}^{n_q \times n_k}$ computes all pairwise query-key similarities; dividing by $\sqrt{d_k}$ prevents score variance from growing with dimension (which would saturate softmax). Softmax row-wise normalizes to produce attention weights $A \in \mathbb{R}^{n_q \times n_k}$, where each row is a probability distribution over keys. The second product $AV \in \mathbb{R}^{n_q \times d_v}$ aggregates values weighted by attention.
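The variance claim can be checked empirically. The sketch below assumes queries and keys with i.i.d. standard-normal entries (an illustrative assumption, not something the example prescribes); under that assumption the raw dot product has variance close to $d_k$, and the scaled one has variance close to 1. The helper name `score_variance` is hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)  # fixed seed, arbitrary choice for reproducibility

def score_variance(d_k, n_samples=5_000):
    """Empirical variance of q.k for q, k with i.i.d. standard-normal entries."""
    q = rng.standard_normal((n_samples, d_k))
    k = rng.standard_normal((n_samples, d_k))
    raw = np.sum(q * k, axis=1)      # unscaled dot products, one per sample
    scaled = raw / np.sqrt(d_k)      # scaled as in the attention formula
    return raw.var(), scaled.var()

for d_k in (4, 64, 256):
    raw_var, scaled_var = score_variance(d_k)
    print(f"d_k={d_k:4d}  var(raw)={raw_var:8.2f}  var(scaled)={scaled_var:5.2f}")
```

The raw variance should track $d_k$ while the scaled variance stays near 1, which is the regime in which softmax does not saturate.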

Prevalence in ML: Transformers now dominate NLP (BERT, GPT, T5), vision (ViT, CLIP), speech (Whisper), and multimodal tasks (Flamingo, GPT-4). Self-attention (where $Q, K, V$ all come from the same sequence) captures long-range dependencies without the sequential bottleneck of RNNs. Multi-head attention runs multiple attention operations in parallel, each learning different query-key relationships (syntax, semantics, positional patterns).

Why attention matters: Unlike convolution (local receptive fields) or RNN (sequential processing), attention provides global, content-based context in a single layer. Every position can attend to every other position, with learned weights determining relevance. This enables modeling dependencies across arbitrary distances, making transformers the architecture of choice for large-scale pretraining.

Purpose

Show how attention decomposes into two matrix products, $QK^\top$ (query-key scoring) followed by $AV$ (attention-weighted value aggregation), making the transformer architecture’s core computation transparent. Emphasize that shapes dictate the algorithm: queries and keys must have matching dimension for dot products, softmax normalizes row-wise to produce probability distributions, and the output inherits value dimension. Build intuition for why this pattern (content-based addressing via similarity, soft weighting via softmax, and convex combination of values) powers modern NLP and vision models. Connect matrix product mechanics (transpose for alignment, batched inner products) to ML semantics (attention weights, context aggregation). Stress that understanding this as linear algebra (not neural network magic) enables reasoning about gradients, memory complexity, and architectural variations (multi-head, sparse, cross-attention).

Problem

Compute scores, attention weights, and outputs for a tiny attention head; verify weights sum to 1.

Solution (Math)

Attention: $O=\mathrm{softmax}\left(QK^\top/\sqrt{d_k}\right)V$. Row-wise softmax yields row-stochastic weights; multiplying by $V$ forms weighted sums.
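For the specific $Q$, $K$, $V$ used in the Python solution, the computation can be carried out by hand. The following is a worked sketch with values rounded to three decimals; since $Q$ is the identity here, $QK^\top = K^\top$.

```latex
\[
QK^\top =
\begin{pmatrix} 1 & 0 \\ 0 & 1 \end{pmatrix}
\begin{pmatrix} 1 & 1 & 0 \\ 0 & 1 & 1 \end{pmatrix}
=
\begin{pmatrix} 1 & 1 & 0 \\ 0 & 1 & 1 \end{pmatrix},
\qquad
S = \frac{QK^\top}{\sqrt{2}} \approx
\begin{pmatrix} 0.707 & 0.707 & 0 \\ 0 & 0.707 & 0.707 \end{pmatrix}.
\]
% Each row contains two entries equal to 1/sqrt(2) and one zero, so with
% e^{1/\sqrt{2}} \approx 2.028 the row normalizer is 2(2.028) + 1 \approx 5.056.
\[
A = \mathrm{softmax}(S) \approx
\begin{pmatrix} 0.401 & 0.401 & 0.198 \\ 0.198 & 0.401 & 0.401 \end{pmatrix},
\qquad
O = AV \approx
\begin{pmatrix} 0.401 & 0.401 & 0.198 \\ 0.198 & 0.401 & 0.401 \end{pmatrix}
\begin{pmatrix} 1 & 0 \\ 0 & 2 \\ 1 & 1 \end{pmatrix}
=
\begin{pmatrix} 0.599 & 1.000 \\ 0.599 & 1.203 \end{pmatrix}.
\]
```

Each row of $A$ sums to 1 (e.g., $0.401 + 0.401 + 0.198 = 1$), as the problem statement asks to verify.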

We use:

  • Data matrix $X\in\mathbb{R}^{n\times d}$ (rows are examples).
  • Vectors are column vectors by default.
  • $\|x\|_2$ is Euclidean norm; $\langle x,y\rangle=x^Ty$.
Solution (Python)

import numpy as np
from scripts.toy_data import softmax

Q = np.array([[1., 0.],
              [0., 1.]])
K = np.array([[1., 0.],
              [1., 1.],
              [0., 1.]])
V = np.array([[1., 0.],
              [0., 2.],
              [1., 1.]])

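d_k = Q.shape[1]                  # query/key dimension (2 here)
scores = Q @ K.T / np.sqrt(d_k)   # (2, 3) scaled query-key similarities
A = softmax(scores, axis=1)       # row-wise softmax over keys
O = A @ V                         # (2, 2) attention-weighted values

print("A row sums:", A.sum(axis=1))
print("O:\n", O)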
Code Introduction

This snippet implements a single-head scaled dot-product attention forward pass using explicit matrix products. It computes query–key scores, applies scaling and row-wise softmax to obtain attention weights, and then mixes values via a final matrix product.

  • Inputs: $Q \in \mathbb{R}^{n\times d_k}$, $K \in \mathbb{R}^{m\times d_k}$, $V \in \mathbb{R}^{m\times d_v}$.
  • Scores: S = Q @ K.T, then scale by $1/\sqrt{d_k}$ to keep softmax in a numerically friendly range.
  • Weights: A = softmax(S, axis=1); each row is a probability distribution. The code prints row sums to verify they equal 1.
  • Output: O = A @ V with shape (n, d_v). Shapes can be checked interactively (e.g., O.shape).
  • Variants: Add a batch dimension for batched attention, or apply a mask to S before softmax for causal/structured attention.
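The bullets above can be reproduced without the repo-local import. In this sketch, `softmax` is a stand-in written for illustration (the implementation of `scripts.toy_data.softmax` is not shown in this example, so equivalence is an assumption); everything else mirrors the snippet's inputs.

```python
import numpy as np

def softmax(x, axis=-1):
    """Stand-in for the repo helper: softmax with max-subtraction for stability."""
    shifted = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(shifted)
    return e / np.sum(e, axis=axis, keepdims=True)

Q = np.array([[1., 0.], [0., 1.]])            # (n=2, d_k=2)
K = np.array([[1., 0.], [1., 1.], [0., 1.]])  # (m=3, d_k=2)
V = np.array([[1., 0.], [0., 2.], [1., 1.]])  # (m=3, d_v=2)

S = Q @ K.T / np.sqrt(Q.shape[1])  # scaled scores, shape (2, 3)
A = softmax(S, axis=1)             # attention weights, rows sum to 1
O = A @ V                          # output, shape (2, 2)

print("A row sums:", A.sum(axis=1))
print("O:\n", np.round(O, 4))      # rows ≈ [0.5989, 1.0] and [0.5989, 1.2033]
```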
Numerical Implementation Details

Numerical and Shape Notes

  • Shapes first: Declare shapes (e.g., $X \in \mathbb{R}^{n \times d}$, $w \in \mathbb{R}^{d}$, $b \in \mathbb{R}^{n}$). Vectors are column by convention; keep row/column usage consistent.
  • Axis discipline: Be explicit with axis in reductions and normalizations. For attention-like ops, softmax over keys (row-wise) so rows sum to ≈1.
  • Broadcasting: Check that broadcasts are intended (e.g., (n,1) with (n,d)). Prefer reshape/expand-dims to make semantics clear.
  • Stability eps: Add $\varepsilon$ for divisions/logs and $\varepsilon I$ (jitter) for SPD solves; use log-sum-exp for softmax.
  • Masking preserves shape: Masks should broadcast to the score/activation tensor; verify masked outputs keep the same shape and zero out excluded entries.
  • Dtype choices: Use float64 for clarity in scripts; with mixed precision, keep reductions/factorizations in float32/float64 to avoid under/overflow.
  • Sanity checks: Print shapes and residuals (e.g., ||Ax-b||, reconstruction error, row-sum ≈ 1). Assert finiteness and expected monotonicity where applicable.

Numerical and Implementation Notes

  • Dtype & precision: Prefer float64 for clarity; if using mixed precision, keep reductions (norms, softmax sums, factorizations) in float32/float64. Avoid explicit inverses; use solve, lstsq, Cholesky/QR/SVD.
  • Shapes & broadcasting: Annotate shapes (e.g., $X \in \mathbb{R}^{n \times d}$); vectors are column by default. Verify axes for reductions (axis) and ensure broadcasts are intended.
  • Stability: Use log-sum-exp for softmax; add small diagonal $\varepsilon I$ (jitter) for SPD solves; prefer QR/SVD for ill-conditioned least squares.
  • Conditioning: Inspect np.linalg.cond(A) when solutions look unstable; regularize (ridge) or rescale features to improve conditioning.
  • Reproducibility: Set NumPy seed for random data; print shapes and residuals (e.g., ||Ax-b||, reconstruction errors) and assert finiteness.
  • Complexity & memory: Dense matmul and factorizations cost ~ $O(n^3)$ for $n \times n$ matrices; triangular solves and matrix-vector products cost $O(n^2)$. Prefer vectorization over Python loops; avoid materializing large intermediates.
  • Masking & indexing: Use boolean masks that broadcast to target shapes; for attention-like ops, add $-\infty$ before softmax or zero-out after, then verify rows sum to ~1.
  • Sanity checks: Compare against references (e.g., lstsq vs. solve), check orthogonality (U.T @ U ≈ I), PSD (x.T @ A @ x ≥ 0), and residual norms within tolerance (~1e-12 for float64).
  1. Define query, key, value matrices: Set up $Q \in \mathbb{R}^{2 \times 2}$, $K \in \mathbb{R}^{3 \times 2}$, $V \in \mathbb{R}^{3 \times 2}$ with explicit values for transparency.
  2. Compute scaled scores: scores = Q @ K.T / np.sqrt(d_k) yields $\text{scores} \in \mathbb{R}^{2 \times 3}$ with dimension $d_k = 2$.
  3. Apply row-wise softmax: A = softmax(scores, axis=1) normalizes each query’s scores into a probability distribution over keys.
  4. Verify attention weights: Check A.sum(axis=1) equals [1., 1.] to confirm valid probability distributions.
  5. Compute attention output: O = A @ V aggregates values weighted by attention, yielding $O \in \mathbb{R}^{2 \times 2}$.
  6. Shape tracking: $Q \in \mathbb{R}^{n_q \times d_k}$, $K \in \mathbb{R}^{n_k \times d_k}$, $V \in \mathbb{R}^{n_k \times d_v}$ → $O \in \mathbb{R}^{n_q \times d_v}$.
  7. Softmax axis: axis=1 normalizes across keys (columns) for each query (row); axis=0 would be incorrect.
  8. Scaling factor: $1/\sqrt{d_k}$ prevents dot product variance growth; critical for gradient stability in deep networks.
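The softmax-axis step is easy to get wrong silently, because both axes produce an array of the same shape. A minimal check, using this example's scaled scores and a locally defined `softmax` (a stand-in, not the repo helper):

```python
import numpy as np

def softmax(x, axis):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

# Q @ K.T / sqrt(d_k) for this example's Q and K
scores = np.array([[1., 1., 0.],
                   [0., 1., 1.]]) / np.sqrt(2)

A_rows = softmax(scores, axis=1)   # correct: normalize over keys, per query
A_cols = softmax(scores, axis=0)   # wrong for attention: normalizes over queries

print(A_rows.sum(axis=1))  # ≈ [1. 1.]
print(A_cols.sum(axis=1))  # ≈ [1.5 1.5], rows are no longer distributions
```

With axis=0 the columns sum to 1 instead, so the row-sum check from step 4 catches the mistake.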
What This Example Demonstrates

Pedagogical Significance

  • Learning goals: Build intuition for when and why this tool is used in ML, not just how to compute it.
  • ML-first framing: Tie the concept to a concrete task pattern (fit / project / decompose / solve / measure) to anchor understanding.
  • Shape discipline: Habitually annotating dimensions prevents silent bugs and reinforces linear map thinking.
  • Numerical habits: Prefer stable factorizations over inverses; check residuals and condition numbers to separate bugs from ill-conditioning.
  • Transfer: Reuse the same pattern across models (e.g., projection in PCA, orthogonalization in regressions, attention as weighted sums).
  • Assessment ideas: Quick checks: predict sensitivity from $\kappa(A)$, verify projection properties, or compare solver outputs within tolerance.

ML Examples and Patterns

  • Fit: Linear/logistic regression via least squares or softmax; regularization (ridge) improves conditioning and generalization.
  • Project: PCA/SVD for dimensionality reduction; orthogonal projections to subspaces for denoising and feature extraction.
  • Decompose: Eigen/SVD factorizations to expose structure (low rank, PSD) used in recommender systems, LSA, and spectral clustering.
  • Solve: Stable solves without inversion (Cholesky/QR/SVD; CG for SPD) for optimization steps and kernel methods.
  • Measure: Norms, angles, and condition number $\kappa(A)$ to diagnose sensitivity, stability, and training difficulty.
  • Two-stage decomposition: Attention separates “where to look” ($QK^\top$ → similarity scores) from “what to retrieve” ($AV$ → weighted values).
  • Shape alignment via transpose: $QK^\top$ aligns query rows with key rows for pairwise dot products; without transpose, dimensions wouldn’t match.
  • Row-wise softmax yields probability distributions: Each row of $A$ sums to 1; attention weights are convex combination coefficients.
  • Output dimension follows values: $O \in \mathbb{R}^{n_q \times d_v}$ inherits value dimension $d_v$, not key dimension $d_k$.
  • Scaling prevents saturation: Dividing by $\sqrt{d_k}$ keeps score variance stable as dimension grows, avoiding vanishing gradients.
  • Content-based addressing: High query-key similarity → high attention weight; this is key-value retrieval by analogy.
  • Permutation behavior: Reordering queries permutes the output rows identically (equivariance), while reordering keys and values together leaves the output unchanged (invariance); order-awareness therefore requires positional encoding.
  • Efficient parallelization: All queries attend simultaneously (no sequential dependencies), leveraging hardware matrix-multiply units.
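How permutations act on attention can be checked numerically. The sketch below uses random inputs and a small `attention` helper defined here for illustration; it tests two separate facts: permuting keys and values together leaves the output unchanged, and permuting queries permutes the output rows.

```python
import numpy as np

def attention(Q, K, V):
    S = Q @ K.T / np.sqrt(Q.shape[1])
    A = np.exp(S - S.max(axis=1, keepdims=True))
    A /= A.sum(axis=1, keepdims=True)
    return A @ V

rng = np.random.default_rng(0)
Q = rng.standard_normal((4, 3))
K = rng.standard_normal((5, 3))
V = rng.standard_normal((5, 2))
O = attention(Q, K, V)

perm_kv = rng.permutation(5)   # joint reordering of keys and values
perm_q = rng.permutation(4)    # reordering of queries

same = np.allclose(attention(Q, K[perm_kv], V[perm_kv]), O)
moved = np.allclose(attention(Q[perm_q], K, V), O[perm_q])
print(same, moved)  # True True
```

In self-attention, where queries, keys, and values all come from the same sequence, the two facts combine: shuffling the tokens just shuffles the outputs, so token order carries no information unless positions are encoded.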
Notes

Part 1: Query-Key Scoring

  • $Q \in \mathbb{R}^{2 \times 2}$: 2 queries, each 2-dimensional.
  • $K \in \mathbb{R}^{3 \times 2}$: 3 keys, matching query dimension for dot products.
  • $\text{scores} = QK^\top / \sqrt{2} \in \mathbb{R}^{2 \times 3}$: entry $(i,j)$ is $\langle q_i, k_j \rangle / \sqrt{2}$.
  • Transpose alignment: $QK^\top$ computes all pairwise query-key similarities; $QK$ would fail here (shape mismatch).
  • Scaling: Without $1/\sqrt{d_k}$, scores grow with dimension, pushing softmax into saturation (near-zero gradients).

Part 2: Attention Weights via Softmax

  • $A = \text{softmax}(\text{scores}, \text{axis}=1) \in \mathbb{R}^{2 \times 3}$.
  • Row-wise normalization: $A_{ij} = \exp(\text{scores}_{ij}) / \sum_k \exp(\text{scores}_{ik})$.
  • Each row of $A$ is a probability distribution: $\sum_j A_{ij} = 1$, $A_{ij} \ge 0$.
  • High scores → high attention weights; softmax is a differentiable approximation to argmax.
  • Numerical stability: Softmax implementations subtract $\max(\text{scores}_i)$ before exponentiating to prevent overflow.
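To see why the max-subtraction matters, the sketch below compares a naive softmax with a shifted one on deliberately large scores. The function names are illustrative; the large inputs are chosen only to trigger float64 overflow.

```python
import numpy as np

def softmax_naive(x):
    e = np.exp(x)                                   # overflows for large x
    return e / e.sum(axis=1, keepdims=True)

def softmax_stable(x):
    e = np.exp(x - x.max(axis=1, keepdims=True))    # largest exponent is 0
    return e / e.sum(axis=1, keepdims=True)

big = np.array([[1000., 1001., 1002.]])

# Silence the expected overflow/invalid warnings for the naive version.
with np.errstate(over="ignore", invalid="ignore"):
    naive = softmax_naive(big)    # exp(1000) = inf, and inf / inf = nan

stable = softmax_stable(big)
print(np.isnan(naive).all())      # True
print(np.round(stable, 4))        # ≈ [[0.0900, 0.2447, 0.6652]]
```

Subtracting the row maximum does not change the result mathematically, since the common factor cancels between numerator and denominator.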

Part 3: Output as Weighted Value Sum

  • $O = AV \in \mathbb{R}^{2 \times 2}$: each row is a weighted sum of value vectors.
  • $o_i = \sum_j A_{ij} v_j$: convex combination of values (weights sum to 1).
  • Output dimension: $O \in \mathbb{R}^{n_q \times d_v}$ matches value dimension $d_v$, independent of key dimension $d_k$.
  • Geometric interpretation: $o_i$ lies in the convex hull of value vectors; extreme attention (one weight near 1) retrieves a single value.
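The "extreme attention retrieves a single value" behavior can be illustrated by sharpening a query. In this sketch the query `q` and the scale factors are hypothetical choices; `K` and `V` are the example's matrices. The bounding-box test is only a coarse necessary condition for lying in the convex hull.

```python
import numpy as np

def attention(Q, K, V):
    S = Q @ K.T / np.sqrt(Q.shape[1])
    A = np.exp(S - S.max(axis=1, keepdims=True))
    A /= A.sum(axis=1, keepdims=True)
    return A @ V

K = np.array([[1., 0.], [1., 1.], [0., 1.]])
V = np.array([[1., 0.], [0., 2.], [1., 1.]])
q = np.array([[1., -1.]])   # raw scores against the keys: (1, 0, -1)

for scale in (1.0, 5.0, 50.0):
    o = attention(scale * q, K, V)
    # Coarse check: each coordinate stays within the range spanned by the values.
    inside = np.all((o >= V.min(axis=0)) & (o <= V.max(axis=0)))
    print(f"scale={scale:5.1f}  o={np.round(o[0], 4)}  within value range: {inside}")
```

As the scale grows, the weight on the first key approaches 1 and the output approaches the first value row, `V[0] = (1, 0)`.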

Why This Pattern Powers Transformers

  • Content-based addressing: Query-key similarity determines relevance; no hand-crafted attention patterns (unlike RNN’s sequential bias).
  • Parallelizable: All $n_q$ queries attend simultaneously; no sequential bottleneck.
  • Differentiable: Softmax and matrix products are smooth; gradients flow through the entire attention operation.
  • Hardware-efficient: GPUs/TPUs optimize dense matrix multiplication (GEMM kernels); $QK^\top$ and $AV$ map directly to BLAS calls.
  • Composable: Stack many attention layers; each refines representations hierarchically.

Why This Matters for ML

  • Transformers dominate modern ML: BERT, GPT, ViT, CLIP, Whisper, Flamingo all use scaled dot-product attention.
  • Long-range dependencies: Attention captures dependencies across arbitrary distances in a single layer (vs. RNN’s sequential propagation).
  • Interpretability: Attention weights $A$ reveal which keys each query attends to; useful for debugging and visualization.
  • Scalability challenges: $O(n^2 d)$ complexity for sequence length $n$; sparse/efficient attention variants address this.

Connection to ML Applications

  • Self-attention: $Q, K, V$ all project from the same input sequence; each token attends to all others.
  • Cross-attention: Queries from one sequence, keys/values from another (e.g., decoder attends to encoder in seq2seq).
  • Multi-head attention: Run $h$ attention heads in parallel with different weight matrices; concatenate outputs for richer representations.
  • Masked attention: Set $\text{scores}_{ij} = -\infty$ for $j > i$ (causal mask) to prevent future tokens from influencing past predictions (autoregressive models).
  • Sparse attention: Restrict attention to local windows, strided patterns, or learned sparsity to reduce $O(n^2)$ to $O(n \log n)$ or $O(n)$.
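The causal mask can be sketched in a few lines. This is a minimal illustration under the assumption that queries and keys index the same $n$ positions (as in self-attention); the function name `causal_attention` and the random input are hypothetical.

```python
import numpy as np

def causal_attention(Q, K, V):
    """Causal self-attention-style masking; assumes Q and K cover the same n positions."""
    n, d_k = Q.shape
    S = Q @ K.T / np.sqrt(d_k)
    future = np.triu(np.ones((n, n), dtype=bool), k=1)   # True where j > i
    S = np.where(future, -np.inf, S)                     # block attention to the future
    S = S - S.max(axis=1, keepdims=True)                 # diagonal is never masked, so max is finite
    A = np.exp(S)                                        # exp(-inf) = 0 for masked entries
    A /= A.sum(axis=1, keepdims=True)
    return A @ V, A

rng = np.random.default_rng(0)
X = rng.standard_normal((4, 3))
O, A = causal_attention(X, X, X)
print(np.round(A, 3))   # lower-triangular weights; each row still sums to 1
```

Position 0 can only attend to itself, so its output equals its own value row.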

Connection to Linear Algebra Theory

  • Matrix products as operator composition: $QK^\top$ followed by softmax and $AV$ is function composition with learned parameters.
  • Projection interpretation: Each output is a convex combination of value vectors and so lies in their span; attention weights are soft coefficients.
  • Inner products define similarity: Dot product $q \cdot k$ measures alignment; cosine similarity normalizes by norms for scale-invariance.
  • Softmax as differentiable max: Hard argmax is non-differentiable; softmax provides a smooth relaxation with usable gradients.
  • Low-rank structure: Attention matrices often have low effective rank; SVD-based compression or low-rank parameterizations (e.g., Linformer) exploit this.

Numerical and Implementation Notes

  • Avoid materializing large attention matrices: For long sequences, compute $QK^\top$ in blocks or use kernel fusion to save memory.
  • Softmax numerical stability: Subtract row-wise max before exponentiating: $\exp(x_i - \max(x))$ prevents overflow.
  • Gradient clipping: Attention gradients can explode in early training; clip by norm to stabilize.
  • Mixed precision: Use float16 for forward pass, float32 for softmax and gradients to balance speed and stability.
  • Flash Attention: Tiled computation of attention reduces memory from $O(n^2)$ to $O(n)$ by recomputing on-the-fly instead of materializing full $A$.
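The blockwise idea can be illustrated in NumPy with a running maximum and normalizer (often called online softmax). This is only a sketch of the arithmetic under that assumption, not the Flash Attention kernel itself; the function names and block size are illustrative.

```python
import numpy as np

def attention_reference(Q, K, V):
    S = Q @ K.T / np.sqrt(Q.shape[1])
    A = np.exp(S - S.max(axis=1, keepdims=True))
    A /= A.sum(axis=1, keepdims=True)
    return A @ V

def attention_blockwise(Q, K, V, block=2):
    """Process keys/values in blocks; never forms the full (n_q, n_k) weight matrix."""
    n_q, d_k = Q.shape
    m = np.full((n_q, 1), -np.inf)       # running row-wise max of scores
    l = np.zeros((n_q, 1))               # running softmax normalizer
    acc = np.zeros((n_q, V.shape[1]))    # running unnormalized output
    for start in range(0, K.shape[0], block):
        Kb, Vb = K[start:start + block], V[start:start + block]
        S = Q @ Kb.T / np.sqrt(d_k)      # scores for this block only
        m_new = np.maximum(m, S.max(axis=1, keepdims=True))
        correction = np.exp(m - m_new)   # rescales earlier partial sums to the new max
        P = np.exp(S - m_new)
        l = correction * l + P.sum(axis=1, keepdims=True)
        acc = correction * acc + P @ Vb
        m = m_new
    return acc / l

rng = np.random.default_rng(0)
Q = rng.standard_normal((4, 8))
K = rng.standard_normal((7, 8))
V = rng.standard_normal((7, 5))
print(np.allclose(attention_blockwise(Q, K, V, block=3), attention_reference(Q, K, V)))
```

The block size need not divide the number of keys; the final slice is simply shorter.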

Numerical and Shape Notes

  • Shape checks: Before $QK^\top$, verify Q.shape[1] == K.shape[1] (query/key dimension must match).
  • Softmax axis: axis=1 normalizes across keys for each query; verify with A.sum(axis=1) → all ones.
  • Output shape: O.shape == (n_q, d_v), inheriting rows from queries and columns from values.
  • Batch dimensions: In practice, prepend batch dimension: $Q \in \mathbb{R}^{B \times n_q \times d_k}$; all operations generalize via broadcasting.
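A batched version with the shape check made explicit might look like the following. The function name `batched_attention` and the sizes are illustrative assumptions; the key points are that the transpose and the softmax act on the last axes, and that `@` broadcasts over the leading batch axis.

```python
import numpy as np

def batched_attention(Q, K, V):
    """Q: (B, n_q, d_k), K: (B, n_k, d_k), V: (B, n_k, d_v) -> O: (B, n_q, d_v), A: (B, n_q, n_k)."""
    assert Q.shape[-1] == K.shape[-1], "query/key dimension must match"
    assert K.shape[-2] == V.shape[-2], "one value per key"
    d_k = Q.shape[-1]
    S = Q @ np.swapaxes(K, -1, -2) / np.sqrt(d_k)   # (B, n_q, n_k)
    S = S - S.max(axis=-1, keepdims=True)
    A = np.exp(S)
    A /= A.sum(axis=-1, keepdims=True)              # normalize over keys (last axis)
    return A @ V, A

rng = np.random.default_rng(0)
B, n_q, n_k, d_k, d_v = 4, 3, 5, 8, 6
Q = rng.standard_normal((B, n_q, d_k))
K = rng.standard_normal((B, n_k, d_k))
V = rng.standard_normal((B, n_k, d_v))

O, A = batched_attention(Q, K, V)
print(O.shape, A.shape)   # (4, 3, 6) (4, 3, 5)
```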

ML Context: From Attention to Transformers

  • Self-attention layer: Input $X \in \mathbb{R}^{n \times d}$ → project to $Q = XW_Q$, $K = XW_K$, $V = XW_V$ → attention → output $O \in \mathbb{R}^{n \times d_v}$.
  • Multi-head attention: Run $h$ heads in parallel, concatenate: $\text{MultiHead}(X) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h) W_O$.
  • Positional encoding: Attention is permutation-equivariant; add positional embeddings to $X$ to encode token order.
  • Residual connections: $X' = X + \text{Attention}(X)$ stabilizes training and enables deep stacks (100+ layers).
  • Layer normalization: Normalize hidden states before/after attention to control activation magnitudes and improve conditioning.
  • Feedforward sublayer: After attention, apply position-wise MLP: $\text{FFN}(x) = W_2\,\sigma(W_1 x)$, with $\sigma$ a nonlinearity such as ReLU (biases omitted), for additional expressiveness.
  • Causal masking: Autoregressive models mask future positions to maintain left-to-right generation order.
  • Cross-attention: Encoder-decoder models use cross-attention (decoder queries, encoder keys/values) for source-target alignment.
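The projection and multi-head steps can be sketched as follows. This is a minimal illustration, assuming square projection matrices of size $d \times d$ that are split evenly across $h$ heads (a common layout, but an assumption here); the weights are random placeholders and the function name is hypothetical.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_self_attention(X, W_Q, W_K, W_V, W_O, h):
    """X: (n, d); all weight matrices: (d, d); h must divide d."""
    n, d = X.shape
    assert d % h == 0, "number of heads must divide the model dimension"
    d_head = d // h

    def split(M):   # (n, d) -> (h, n, d_head): one slice of columns per head
        return M.reshape(n, h, d_head).transpose(1, 0, 2)

    Q, K, V = split(X @ W_Q), split(X @ W_K), split(X @ W_V)
    A = softmax(Q @ K.transpose(0, 2, 1) / np.sqrt(d_head), axis=-1)   # (h, n, n)
    heads = A @ V                                                      # (h, n, d_head)
    concat = heads.transpose(1, 0, 2).reshape(n, d)                    # Concat(head_1, ..., head_h)
    return concat @ W_O, A

rng = np.random.default_rng(0)
n, d, h = 5, 8, 2
X = rng.standard_normal((n, d))
W_Q, W_K, W_V, W_O = (rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(4))

out, A = multi_head_self_attention(X, W_Q, W_K, W_V, W_O, h)
print(out.shape, A.shape)   # (5, 8) (2, 5, 5)
```

With h=1 this reduces to the single-head computation followed by the output projection $W_O$.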

History and Applications

History: Attention emerged in sequence-to-sequence models (Bahdanau et al., 2015) to let decoders focus on relevant encoder states. The breakthrough came with Transformers (Vaswani et al., 2017), which reframed attention as two matrix products: $QK^\top$ (similarities) and $AV$ (weighted sums). This linear algebraic view made attention fully parallel and hardware-friendly.

Applications: Self-attention now powers NLP (BERT, GPT), vision (ViT), speech (Whisper), and multimodal models (CLIP). Variants include multi-head attention, cross-attention, causal masks for autoregressive generation, and sparse/linear attention for long contexts. In practice, this example’s shapes and matrix products map directly to high-performance GEMM kernels.

Connection to Broader Examples
  • Chapter 4 (Linear maps): Attention composes linear maps with a row-wise softmax: $Q, K, V$ projections, then $QK^\top$, softmax, then $AV$.
  • Chapter 5 (Inner products): Query-key scoring is batched inner products; cosine similarity is a normalized variant.
  • Chapter 6 (Projections): Each output row is a weighted projection onto the span of value vectors.
  • Chapter 10 (SVD): Low-rank approximation of attention matrices enables efficient attention variants (e.g., Linformer).
  • Chapter 12 (Least-squares): Attention weights are soft assignments; hard assignments would solve a weighted least-squares problem.
  • Chapter 13 (Solving systems): Iterative attention refinement (e.g., adaptive span) adjusts weights via pseudo-inverse-like updates.
  • Chapter 14 (Conditioning): Ill-conditioned attention matrices (dominated by few large singular values) concentrate weights on few keys.
  • Chapter 15 (Sparse): Sparse attention (local windows, block patterns) exploits sparsity for $O(n)$ instead of $O(n^2)$ complexity.
  • Chapter 16 (Matrix products): Core chapter; attention is the flagship example of matrix product composition in ML.
  • Transformers: Multi-head attention, cross-attention, masked attention all use this $\text{softmax}(QK^\top)V$ pattern.

Comments