Matrix products became the computational substrate of deep learning because GPUs/TPUs are engineered around GEMM-like kernels. Scaled dot-product attention (Vaswani et al. 2017) is two back-to-back matrix multiplies, $QK^\top$ then $AV$, and its efficiency hinges on matching shapes to hardware-friendly GEMM calls. The $1/\sqrt{d_k}$ factor keeps logits numerically stable so softmax does not saturate when $d_k$ is large, which in turn keeps gradients well-conditioned.
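The saturation argument can be checked empirically. The sketch below assumes queries and keys with independent standard-normal entries (an illustrative assumption, not something the text specifies); under that assumption each raw dot product has variance close to $d_k$, and dividing by $\sqrt{d_k}$ brings it back to about 1. The helper name `softmax_rows`, the sizes, and the seed are all hypothetical choices.

```python
import numpy as np

rng = np.random.default_rng(0)  # arbitrary fixed seed for repeatability

def softmax_rows(x):
    """Row-wise softmax with a max-shift so exp() cannot overflow."""
    z = x - x.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

n_q, n_k = 64, 64  # illustrative sizes
for d_k in (4, 64, 1024):
    Q = rng.standard_normal((n_q, d_k))
    K = rng.standard_normal((n_k, d_k))
    raw = Q @ K.T                  # spread grows like sqrt(d_k)
    scaled = raw / np.sqrt(d_k)    # spread stays near 1
    # Mean of the largest weight per row: values near 1 indicate saturation.
    peak_raw = softmax_rows(raw).max(axis=1).mean()
    peak_scaled = softmax_rows(scaled).max(axis=1).mean()
    print(f"d_k={d_k:5d}  std(raw)={raw.std():6.2f}  std(scaled)={scaled.std():5.2f}  "
          f"peak(raw)={peak_raw:.2f}  peak(scaled)={peak_scaled:.2f}")
```

As $d_k$ grows, the unscaled softmax should concentrate almost all weight on a single key per row, while the scaled version keeps the weights spread out.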
Make shapes and transposes feel inevitable so you can reason about forward/backward passes and attention without memorizing formulas: $Q \in \mathbb{R}^{n_q\times d_k}$, $K \in \mathbb{R}^{n_k\times d_k}$, $V \in \mathbb{R}^{n_k\times d_v}$, scores $S = QK^\top/\sqrt{d_k} \in \mathbb{R}^{n_q\times n_k}$, weights $A = \mathrm{softmax}(S)$ row-stochastic, outputs $O = AV \in \mathbb{R}^{n_q\times d_v}$.
Compute scores, attention weights, and outputs for a tiny attention head; verify weights sum to 1.
Attention: $O = \mathrm{softmax}(QK^\top/\sqrt{d_k})V$. Row-wise softmax yields row-stochastic weights; multiplying by $V$ forms weighted sums. Shapes: $Q \in \mathbb{R}^{n_q\times d_k}$, $K \in \mathbb{R}^{n_k\times d_k}$, $V \in \mathbb{R}^{n_k\times d_v}$, so $S, A \in \mathbb{R}^{n_q\times n_k}$ and $O \in \mathbb{R}^{n_q\times d_v}$.
import numpy as np
from scripts.toy_data import softmax

# Two queries, three keys, three values; d_k = d_v = 2
Q = np.array([[1., 0.],
              [0., 1.]])
K = np.array([[1., 0.],
              [1., 1.],
              [0., 1.]])
V = np.array([[1., 0.],
              [0., 2.],
              [1., 1.]])
scores = Q @ K.T / np.sqrt(2)   # (2, 3), scaled by 1/sqrt(d_k)
A = softmax(scores, axis=1)     # row-wise softmax over the keys
O = A @ V                       # (2, 2)
print("A row sums:", A.sum(axis=1))
print("O:\n", O)

This code builds a toy attention head to make the two core matrix products concrete. Queries $Q \in \mathbb{R}^{2\times 2}$ attend over keys $K \in \mathbb{R}^{3\times 2}$ to produce weights over three value rows $V \in \mathbb{R}^{3\times 2}$. Scores $S = QK^\top / \sqrt{2}$ apply the standard scaling for $d_k=2$, and softmax along rows yields attention weights $A \in \mathbb{R}^{2\times 3}$ whose rows sum to 1 (verified via A.sum(axis=1)). Each row of the output $O = AV \in \mathbb{R}^{2\times 2}$ is a convex combination of the value rows for that query. This demonstrates the two matrix multiplies at the heart of attention: $QK^\top$ to measure similarity, then $AV$ to mix values according to those similarities.
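The example above imports `softmax` from `scripts.toy_data`, which is not shown here. A self-contained sketch follows, assuming that helper is an ordinary max-shifted softmax along the given axis; the `attention` wrapper and its shape assertions are illustrative additions rather than part of the original script.

```python
import numpy as np

def softmax(x, axis=-1):
    """Max-shifted softmax; an assumed stand-in for scripts.toy_data.softmax."""
    z = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(z)
    return e / np.sum(e, axis=axis, keepdims=True)

def attention(Q, K, V):
    """Single-head scaled dot-product attention with explicit shape checks."""
    n_q, d_k = Q.shape
    n_k, d_v = V.shape
    assert K.shape == (n_k, d_k), "K must have shape (n_k, d_k)"
    S = Q @ K.T / np.sqrt(d_k)   # (n_q, n_k) scaled scores
    A = softmax(S, axis=1)       # (n_q, n_k), each row sums to 1
    O = A @ V                    # (n_q, d_v) weighted sums of value rows
    return A, O

Q = np.array([[1., 0.], [0., 1.]])
K = np.array([[1., 0.], [1., 1.], [0., 1.]])
V = np.array([[1., 0.], [0., 2.], [1., 1.]])

A, O = attention(Q, K, V)
print("A shape:", A.shape, "O shape:", O.shape)
print("rows sum to 1:", np.allclose(A.sum(axis=1), 1.0))
```

Reading $d_k$ from `Q.shape` instead of hard-coding `np.sqrt(2)` keeps the scale correct if the toy matrices are resized.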
Numerical and Shape Notes
- Shapes first: Declare shapes (e.g., $X \in \mathbb{R}^{n\times d}$, $w \in \mathbb{R}^{d}$, $b \in \mathbb{R}^{n}$). Vectors are column by convention; keep row/column usage consistent.
- Axis discipline: Be explicit with `axis` in reductions and normalizations. For attention-like ops, softmax over keys (row-wise) so rows sum to ≈ 1.
- Broadcasting: Check that broadcasts are intended (e.g., `(n,1)` with `(n,d)`). Prefer reshape/expand-dims to make semantics clear.
- Stability eps: Add $\varepsilon$ for divisions/logs and $\varepsilon I$ (jitter) for SPD solves; use log-sum-exp for softmax.
- Masking preserves shape: Masks should broadcast to the score/activation tensor; verify masked outputs keep the same shape and zero out excluded entries.
- Dtype choices: Use float64 for clarity in scripts; with mixed precision, keep reductions/factorizations in float32/float64 to avoid under/overflow.
- Sanity checks: Print shapes and residuals (e.g., `||Ax-b||`, reconstruction error, row-sum ≈ 1). Assert finiteness and expected monotonicity where applicable.
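The masking, broadcasting, and sanity-check notes above can be exercised together in a few lines. This is a minimal sketch assuming a causal (lower-triangular) mask and random stand-in scores; the size `n` and the seed are arbitrary.

```python
import numpy as np

n = 4                                    # illustrative sequence length
rng = np.random.default_rng(0)
S = rng.standard_normal((n, n))          # stand-in score matrix, shape (n, n)

# Causal mask: query i may attend only to keys 0..i.
mask = np.tril(np.ones((n, n), dtype=bool))
S_masked = np.where(mask, S, -np.inf)    # same shape as S

# Max-shifted softmax over keys. The (n, 1) row maxima broadcast against (n, n).
# Every row keeps at least one unmasked entry, so each row maximum is finite.
row_max = S_masked.max(axis=1, keepdims=True)
E = np.exp(S_masked - row_max)           # exp(-inf) evaluates to 0.0
A = E / E.sum(axis=1, keepdims=True)

# Sanity checks from the list above.
assert A.shape == S.shape                # masking preserved the shape
assert np.all(A[~mask] == 0.0)           # excluded entries carry zero weight
assert np.allclose(A.sum(axis=1), 1.0)   # rows remain stochastic
assert np.all(np.isfinite(A))
print(np.round(A, 3))
```

Using `keepdims=True` keeps the reduction result at shape `(n, 1)`, which makes the intended broadcast explicit instead of relying on implicit alignment.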
Numerical and Implementation Notes
- Dtype & precision: Prefer float64 for clarity; if using mixed precision, keep reductions (norms, softmax sums, factorizations) in float32/float64. Avoid explicit inverses; use `solve`, `lstsq`, Cholesky/QR/SVD.
- Shapes & broadcasting: Annotate shapes (e.g., $X \in \mathbb{R}^{n\times d}$); vectors are column by default. Verify axes for reductions (`axis`) and ensure broadcasts are intended.
- Stability: Use log-sum-exp for softmax; add small diagonal $\varepsilon I$ (jitter) for SPD solves; prefer QR/SVD for ill-conditioned least squares.
- Conditioning: Inspect `np.linalg.cond(A)` when solutions look unstable; regularize (ridge) or rescale features to improve conditioning.
- Reproducibility: Set NumPy seed for random data; print shapes and residuals (e.g., `||Ax-b||`, reconstruction errors) and assert finiteness.
- Complexity & memory: Dense $n\times n$ matmul and factorizations cost about $O(n^3)$; triangular solves and matrix-vector products cost $O(n^2)$. Prefer vectorization over Python loops; avoid materializing large intermediates.
- Masking & indexing: Use boolean masks that broadcast to target shapes; for attention-like ops, add $-\infty$ to masked scores before softmax (or zero out afterwards and renormalize), then verify rows sum to ≈ 1.
- Sanity checks: Compare against references (e.g., `lstsq` vs. `solve`), check orthogonality (`U.T @ U ≈ I`), PSD (`x.T @ A @ x ≥ 0`), and residual norms within tolerance (~1e-12 for float64).
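The conditioning and solver advice above can be illustrated with a deliberately ill-conditioned least-squares problem. The data below are synthetic and hypothetical (two nearly collinear feature columns), and the ridge strength `lam` is an arbitrary small value chosen only to show the effect on the condition number.

```python
import numpy as np

rng = np.random.default_rng(0)            # fixed seed for reproducibility
n, d = 50, 5
X = rng.standard_normal((n, d))
X[:, 4] = X[:, 3] + 1e-6 * rng.standard_normal(n)   # nearly collinear columns
w_true = np.arange(1.0, d + 1)
y = X @ w_true                            # noiseless targets

G = X.T @ X                               # Gram matrix of the normal equations
print("cond(X)     =", np.linalg.cond(X))
print("cond(X^T X) =", np.linalg.cond(G)) # roughly cond(X) squared

# Reference solution: SVD-based least squares, no explicit inverse.
w_lstsq, *_ = np.linalg.lstsq(X, y, rcond=None)

# Ridge-regularized normal equations: lam * I lifts the smallest eigenvalue.
lam = 1e-8
G_ridge = G + lam * np.eye(d)
print("cond(X^T X + lam I) =", np.linalg.cond(G_ridge))
w_ridge = np.linalg.solve(G_ridge, X.T @ y)

for name, w in [("lstsq", w_lstsq), ("ridge", w_ridge)]:
    residual = np.linalg.norm(X @ w - y)
    assert np.isfinite(residual)
    print(f"{name}: ||Xw - y|| = {residual:.3e}")
```

Forming $X^\top X$ squares the condition number, which is why `lstsq` (or QR/SVD) is the safer default and the normal equations are shown here only together with regularization.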
- Define small, integer-valued $Q \in \mathbb{R}^{2\times 2}$, $K \in \mathbb{R}^{3\times 2}$, $V \in \mathbb{R}^{3\times 2}$ so every intermediate is easy to inspect.
- Compute scores with $S = QK^\top / \sqrt{d_k}$ using $d_k = 2$; scaling keeps logits in a moderate range for softmax.
- Apply `softmax(scores, axis=1)` to get $A \in \mathbb{R}^{2\times 3}$ with row sums equal to 1 (row-stochastic weights).
- Form outputs with `O = A @ V`, producing $O \in \mathbb{R}^{2\times 2}$ as weighted averages of $V$ rows per query.
- Verify numerically that `A.sum(axis=1)` equals 1 within floating-point tolerance and print $O$ for inspection.
- (Optional) Swap $Q$ rows or permute $K$ rows to see how attention weights redistribute while shapes remain valid.
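The optional permutation experiment in the last step can be sketched as follows. The `softmax` and `attend` helpers are hypothetical stand-ins defined inline so the snippet runs on its own, and the permutation `perm` is an arbitrary choice.

```python
import numpy as np

def softmax(x, axis=-1):
    """Max-shifted softmax (assumed equivalent to the helper used earlier)."""
    z = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(z)
    return e / np.sum(e, axis=axis, keepdims=True)

def attend(Q, K, V):
    """Return attention weights A and outputs O for a single head."""
    A = softmax(Q @ K.T / np.sqrt(Q.shape[1]), axis=1)
    return A, A @ V

Q = np.array([[1., 0.], [0., 1.]])
K = np.array([[1., 0.], [1., 1.], [0., 1.]])
V = np.array([[1., 0.], [0., 2.], [1., 1.]])
A, O = attend(Q, K, V)

# Swapping query rows swaps the corresponding rows of A and O.
A_q, O_q = attend(Q[::-1], K, V)
print("rows of A swapped:", np.allclose(A_q, A[::-1]))
print("rows of O swapped:", np.allclose(O_q, O[::-1]))

# Permuting only the keys permutes the columns of A; O changes because
# each weight is now paired with a different value row.
perm = np.array([2, 0, 1])
A_k, O_k = attend(Q, K[perm], V)
print("columns of A permuted:", np.allclose(A_k, A[:, perm]))
print("O unchanged:", np.allclose(O_k, O))

# Permuting keys and values together leaves O unchanged.
A_kv, O_kv = attend(Q, K[perm], V[perm])
print("O unchanged when K and V move together:", np.allclose(O_kv, O))
```

In every variant the shapes stay $(2, 3)$ for the weights and $(2, 2)$ for the outputs; only the arrangement of the weights changes.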
Pedagogical Significance
- Learning goals: Build intuition for when and why this tool is used in ML, not just how to compute it.
- ML-first framing: Tie the concept to a concrete task pattern (fit / project / decompose / solve / measure) to anchor understanding.
- Shape discipline: Habitually annotating dimensions prevents silent bugs and reinforces linear map thinking.
- Numerical habits: Prefer stable factorizations over inverses; check residuals and condition numbers to separate bugs from ill-conditioning.
- Transfer: Reuse the same pattern across models (e.g., projection in PCA, orthogonalization in regressions, attention as weighted sums).
- Assessment ideas: Quick checks: predict sensitivity from $\kappa(A)$, verify projection properties, or compare solver outputs within tolerance.
ML Examples and Patterns
- Fit: Linear/logistic regression via least squares or softmax; regularization (ridge) improves conditioning and generalization.
- Project: PCA/SVD for dimensionality reduction; orthogonal projections to subspaces for denoising and feature extraction.
- Decompose: Eigen/SVD factorizations to expose structure (low rank, PSD) used in recommender systems, LSA, and spectral clustering.
- Solve: Stable solves without inversion (Cholesky/QR/SVD; CG for SPD) for optimization steps and kernel methods.
- Measure: Norms, angles, and condition number $\kappa(A)$ to diagnose sensitivity, stability, and training difficulty.
- The single takeaway: attention is just two matrix products with a row-wise softmax in between: $QK^\top$ to measure similarity and $AV$ to mix values.
- How scaling by $1/\sqrt{d_k}$ prevents softmax saturation and stabilizes both forward logits and backward gradients.
- How row-stochastic $A$ makes each output row a convex combination of value rows, ensuring interpretability and bounded outputs.
- How the shapes compose: $QK^\top$ yields $n_q \times n_k$, softmax preserves shape, $AV$ yields $n_q \times d_v$.
- A minimal numerical check that $A$ rows sum to 1 and that $O$ matches hand-computable expectations on toy data.
- Part 1: Shape discipline: $Q \in \mathbb{R}^{n_q\times d_k}$, $K \in \mathbb{R}^{n_k\times d_k}$, $V \in \mathbb{R}^{n_k\times d_v}$; $QK^\top$ is $n_q \times n_k$, $A$ matches that shape, $AV$ is $n_q \times d_v$.
- Part 2: Numerical stability: scale by $1/\sqrt{d_k}$ and rely on a softmax implementation with max-shift to avoid overflow.
- Part 3: Interpretation: each row of $A$ is a probability distribution; $O$ rows are convex combinations of $V$ rows.
- Why This Matters for ML: Attention replaces fixed receptive fields with data-dependent weighting; stability and shape correctness are prerequisites for training deep transformers.
- ML Examples and Patterns: Fit: attention weights act as adaptive feature selectors; Project: $AV$ is a projection onto the simplex-weighted span of $V$; Decompose: multi-head attention parallels block-structured factorizations; Solve: masking enforces causal or structural constraints.
- Connection to Linear Algebra Theory: Softmaxed $QK^\top$ resembles a kernel matrix over queries/keys; $AV$ is a linear map applied row-wise using those kernel weights.
- Numerical and Implementation Notes: Use float32/float16 with care: apply log-sum-exp tricks for softmax, apply masking before softmax, and apply any dropout to the attention weights after it. Check that broadcasting and transposes match intended shapes.
- Numerical and Shape Notes: Validate `A.sum(axis=1)` ≈ 1 and `O.shape == (n_q, d_v)`; mismatched shapes usually indicate swapped axes or missing transposes.
- Pedagogical Significance: This is the smallest end-to-end attention example that still exercises both matrix products and a stabilizing scale/softmax, making it ideal for debugging intuition.
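For the toy matrices used in this section, the hand-computable expectations can be written in closed form. Because $Q$ is the identity, the unscaled scores are $K^\top$, so each row of $S$ contains two entries equal to $1/\sqrt{2}$ and one entry equal to 0. Writing $t = e^{1/\sqrt{2}}$, the weights are $a = t/(2t+1)$ on the matching keys and $b = 1/(2t+1)$ on the other one. The sketch below compares those values against the numerical result; the names `t`, `a`, `b`, `A_hand`, and `O_hand` are introduced here for illustration only.

```python
import numpy as np

Q = np.array([[1., 0.], [0., 1.]])
K = np.array([[1., 0.], [1., 1.], [0., 1.]])
V = np.array([[1., 0.], [0., 2.], [1., 1.]])

# Numerical path: scaled scores, max-shifted row-wise softmax, weighted sum.
S = Q @ K.T / np.sqrt(2)
E = np.exp(S - S.max(axis=1, keepdims=True))
A = E / E.sum(axis=1, keepdims=True)
O = A @ V

# Hand-derived path: each row of S has two entries 1/sqrt(2) and one 0.
t = np.exp(1 / np.sqrt(2))
a = t / (2 * t + 1)          # weight on a key with score 1/sqrt(2)
b = 1 / (2 * t + 1)          # weight on the key with score 0
A_hand = np.array([[a, a, b],
                   [b, a, a]])
O_hand = np.array([[a + b, 2 * a + b],   # 2a + b = 1 because the row sums to 1
                   [a + b, 3 * a]])

assert np.allclose(A, A_hand)
assert np.allclose(O, O_hand)
assert np.allclose(A.sum(axis=1), 1.0)
print(np.round(A, 4))
print(np.round(O, 4))
```

Here $a$ is about 0.401 and $b$ about 0.198, so the first output row is roughly $(0.599, 1)$ and the second roughly $(0.599, 1.203)$.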
Scaled dot-product attention emerged in 2017 as the core of the transformer architecture, but it builds on earlier alignment ideas from sequence-to-sequence models (Bahdanau et al. 2015). The $QK^\top$ then $AV$ pattern became dominant because it maps cleanly to fast matrix-multiply kernels on GPUs/TPUs. Subsequent work explored efficiency (sparse/linear attention, Performer), scaling (multi-head, multi-query, grouped-query attention), and hardware co-design for large context windows. Today, the same algebra underpins language models, vision transformers, audio models, and retrieval-augmented systems, with continual innovations in masking, positional encodings, and low-rank or kernelized approximations to reduce $O(n^2)$ cost.
- Links to Chapter 16 (matrix products): attention is a canonical $QK^\top$ then $AV$ pipeline.
- Connects to conditioning (Chapter 14): the $1/\sqrt{d_k}$ scale mitigates exploding logits; masking and clipping further stabilize softmax.
- Complements least squares and projections (Chapter 12): $AV$ is a weighted projection of $V$ rows, constrained to a simplex per query row.
- Relates to SVD/PCA (Chapter 10/11): $QK^\top$ measures alignment in the shared feature space $\mathbb{R}^{d_k}$, akin to cosine similarity.
- Bridges to sparse chapters (Chapter 15): attention on long sequences often uses sparse or low-rank approximations to reduce $O(n^2)$ cost.
- Supports matrix-products pedagogy: reinforces associativity and shape reasoning used throughout earlier examples.