Example number
64
Example slug
example_64_transformers_attention_as_qk_t_then_av
Background

Matrix products became the computational substrate of deep learning because GPUs and TPUs are engineered around GEMM-like kernels. Scaled dot-product attention (Vaswani et al. 2017) is two back-to-back matrix multiplies, $QK^\top$ then $AV$, and its efficiency hinges on matching shapes to hardware-friendly GEMM calls. The $1/\sqrt{d_k}$ factor controls the scale of the logits. If query and key entries are roughly independent with zero mean and unit variance, each dot product $q\cdot k$ has variance $d_k$, and dividing by $\sqrt{d_k}$ brings it back to about 1. Without the factor, large $d_k$ pushes softmax into saturation, where its gradients become vanishingly small.
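A quick Monte Carlo check of how dot-product variance grows with $d_k$ when query and key entries are i.i.d. standard normal (an illustrative assumption, not real activations), plus a look at how much sharper softmax gets without the scale. Sample sizes, seeds, and the helper names are arbitrary choices for this sketch.

```python
import numpy as np

def dot_product_variance(d_k, n_samples=10_000, seed=0):
    """Return (unscaled, scaled) sample variances of q.k for random q, k in R^{d_k}.

    Assumes i.i.d. standard-normal entries, an illustrative choice.
    """
    rng = np.random.default_rng(seed)
    q = rng.standard_normal((n_samples, d_k))
    k = rng.standard_normal((n_samples, d_k))
    dots = np.sum(q * k, axis=1)            # one dot product per sample, shape (n_samples,)
    return dots.var(), (dots / np.sqrt(d_k)).var()

def max_softmax_weight(logits):
    """Largest softmax probability; values near 1 indicate saturation."""
    e = np.exp(logits - logits.max())       # max-shift for stability
    return (e / e.sum()).max()

results = {}
for d_k in (4, 64, 256):
    unscaled, scaled = dot_product_variance(d_k)
    results[d_k] = (unscaled, scaled)
    print(f"d_k={d_k:4d}  var(q.k)={unscaled:8.2f}  var(q.k/sqrt(d_k))={scaled:5.2f}")

# Saturation: the same 10 random logits, with and without the 1/sqrt(d_k) scale.
rng = np.random.default_rng(1)
d_k = 256
q = rng.standard_normal(d_k)
keys = rng.standard_normal((10, d_k))
raw = keys @ q                              # shape (10,)
p_raw = max_softmax_weight(raw)
p_scaled = max_softmax_weight(raw / np.sqrt(d_k))
print(f"max weight unscaled={p_raw:.3f}  scaled={p_scaled:.3f}")
```

Multiplying logits by a factor greater than 1 can only raise the largest softmax probability, so the unscaled weights are always at least as peaked as the scaled ones.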

Purpose

Make shapes and transposes feel inevitable so you can reason about forward/backward passes and attention without memorizing formulas: $Q \in \mathbb{R}^{n_q\times d_k}$, $K \in \mathbb{R}^{n_k\times d_k}$, $V \in \mathbb{R}^{n_k\times d_v}$, scores $S = QK^\top/\sqrt{d_k} \in \mathbb{R}^{n_q\times n_k}$, weights $A = \mathrm{softmax}(S)$ taken row-wise (so $A$ is row-stochastic), and outputs $O = AV \in \mathbb{R}^{n_q\times d_v}$.
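One way to make the shapes feel inevitable is to encode them as assertions and test with deliberately non-square sizes, so a missing transpose fails immediately instead of broadcasting silently. The function name `attention` and the sizes below are illustrative choices, not part of the original example.

```python
import numpy as np

def attention(Q, K, V):
    """Scaled dot-product attention with explicit shape checks (illustrative sketch)."""
    n_q, d_k = Q.shape
    n_k, d_k2 = K.shape
    n_k2, d_v = V.shape
    assert d_k == d_k2, "Q and K must share the feature dimension d_k"
    assert n_k == n_k2, "K and V must have one row per key"

    S = Q @ K.T / np.sqrt(d_k)                # (n_q, d_k) @ (d_k, n_k) -> (n_q, n_k)
    S = S - S.max(axis=1, keepdims=True)      # max-shift per row for stability
    A = np.exp(S)
    A = A / A.sum(axis=1, keepdims=True)      # row-wise softmax over keys
    O = A @ V                                 # (n_q, n_k) @ (n_k, d_v) -> (n_q, d_v)
    assert A.shape == (n_q, n_k) and O.shape == (n_q, d_v)
    return A, O

# Distinct sizes (n_q=4, n_k=6, d_k=3, d_v=5) so any swapped axis shows up as an error.
rng = np.random.default_rng(0)
Q = rng.standard_normal((4, 3))
K = rng.standard_normal((6, 3))
V = rng.standard_normal((6, 5))
A, O = attention(Q, K, V)
print(A.shape, O.shape)   # (4, 6) (4, 5)

try:
    attention(Q, K.T, V)  # wrong on purpose: passing K^T instead of K
except AssertionError as err:
    print("caught:", err)
```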

Problem

Compute scores, attention weights, and outputs for a tiny attention head; verify weights sum to 1.

Solution (Math)

Attention: $O = \mathrm{softmax}(QK^\top/\sqrt{d_k})V$. Row-wise softmax yields row-stochastic weights; multiplying by $V$ forms weighted sums. Shapes: $Q \in \mathbb{R}^{n_q\times d_k}$, $K \in \mathbb{R}^{n_k\times d_k}$, $V \in \mathbb{R}^{n_k\times d_v}$, so $S, A \in \mathbb{R}^{n_q\times n_k}$ and $O \in \mathbb{R}^{n_q\times d_v}$.
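Working the toy inputs from the Python solution by hand makes every intermediate checkable. Because $Q = I_2$, the product $QK^\top$ is simply $K^\top$:

$$QK^\top = \begin{bmatrix}1&1&0\\0&1&1\end{bmatrix},\qquad S = \frac{1}{\sqrt{2}}\begin{bmatrix}1&1&0\\0&1&1\end{bmatrix}.$$

Each row of $S$ has two entries equal to $1/\sqrt{2}$ and one equal to $0$. Writing $c = e^{1/\sqrt{2}} \approx 2.0281$, $a = c/(2c+1) \approx 0.4011$, and $b = 1/(2c+1) \approx 0.1978$ (so $2a + b = 1$):

$$A = \begin{bmatrix}a&a&b\\b&a&a\end{bmatrix},\qquad O = AV = \begin{bmatrix}a+b & 2a+b\\ a+b & 3a\end{bmatrix} \approx \begin{bmatrix}0.5989 & 1.0000\\ 0.5989 & 1.2033\end{bmatrix}.$$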

Solution (Python)

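import numpy as np

try:
    from scripts.toy_data import softmax  # project helper, if available
except ImportError:
    def softmax(x, axis=-1):
        # Fallback stand-in: max-shifted softmax so the snippet runs standalone.
        z = x - np.max(x, axis=axis, keepdims=True)
        e = np.exp(z)
        return e / np.sum(e, axis=axis, keepdims=True)

Q = np.array([[1., 0.],
              [0., 1.]])          # (n_q, d_k) = (2, 2)
K = np.array([[1., 0.],
              [1., 1.],
              [0., 1.]])          # (n_k, d_k) = (3, 2)
V = np.array([[1., 0.],
              [0., 2.],
              [1., 1.]])          # (n_k, d_v) = (3, 2)

d_k = Q.shape[1]
scores = Q @ K.T / np.sqrt(d_k)   # (n_q, n_k) = (2, 3)
A = softmax(scores, axis=1)       # row-wise softmax over keys, (2, 3)
O = A @ V                         # (n_q, d_v) = (2, 2)

assert np.allclose(A.sum(axis=1), 1.0)
print("A row sums:", A.sum(axis=1))
print("O:\n", O)
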
Code Introduction

This code builds a toy attention head to make the two core matrix products concrete. Queries $Q \in \mathbb{R}^{2\times 2}$ attend over keys $K \in \mathbb{R}^{3\times 2}$ to produce weights over three value rows $V \in \mathbb{R}^{3\times 2}$. Scores $S = QK^\top / \sqrt{2}$ apply the standard scaling for $d_k=2$, and a row-wise softmax yields attention weights $A \in \mathbb{R}^{2\times 3}$ whose rows sum to 1 (checked via A.sum(axis=1)). The output $O = AV \in \mathbb{R}^{2\times 2}$ holds, for each query, a convex combination of the value rows. This demonstrates the two matrix multiplies at the heart of attention: $QK^\top$ to measure similarity, then $AV$ to mix values according to those similarities.
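A short check that the toy output matches its hand-computable closed form and satisfies a necessary condition for being a convex combination of value rows: each coordinate of $O$ must lie between the column-wise minimum and maximum of $V$. The inputs mirror the solution code; the softmax is written inline so the snippet runs on its own.

```python
import numpy as np

# Same toy inputs as the solution code.
Q = np.array([[1., 0.], [0., 1.]])
K = np.array([[1., 0.], [1., 1.], [0., 1.]])
V = np.array([[1., 0.], [0., 2.], [1., 1.]])

S = Q @ K.T / np.sqrt(K.shape[1])
A = np.exp(S - S.max(axis=1, keepdims=True))
A /= A.sum(axis=1, keepdims=True)
O = A @ V

# Closed form: each row of S holds two entries equal to 1/sqrt(2) and one 0.
c = np.exp(1 / np.sqrt(2))
a, b = c / (2 * c + 1), 1 / (2 * c + 1)
O_expected = np.array([[a + b, 2 * a + b],
                       [a + b, 3 * a]])
print(np.round(O, 4))

# Necessary condition for convex combinations: every output coordinate lies
# between the column min and max of V.
lower, upper = V.min(axis=0), V.max(axis=0)
within_bounds = np.all((O >= lower) & (O <= upper))
print("within column bounds of V:", within_bounds)
```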

Numerical Implementation Details

Numerical and Shape Notes

  • Shapes first: Declare shapes (e.g., $X \in \mathbb{R}^{n\times d}$, $w \in \mathbb{R}^{d}$, $b \in \mathbb{R}^{n}$). Vectors are column by convention; keep row/column usage consistent.
  • Axis discipline: Be explicit with axis in reductions and normalizations. For attention-like ops, softmax over keys (row-wise) so rows sum to ≈1.
  • Broadcasting: Check that broadcasts are intended (e.g., (n,1) with (n,d)). Prefer reshape/expand-dims to make semantics clear.
  • Stability eps: Add $\varepsilon$ for divisions/logs and $\varepsilon I$ (jitter) for SPD solves; use log-sum-exp for softmax.
  • Masking preserves shape: Masks should broadcast to the score/activation tensor; verify masked outputs keep the same shape and zero out excluded entries.
  • Dtype choices: Use float64 for clarity in scripts; with mixed precision, keep reductions/factorizations in float32/float64 to avoid under/overflow.
  • Sanity checks: Print shapes and residuals (e.g., ||Ax-b||, reconstruction error, row-sum ≈ 1). Assert finiteness and expected monotonicity where applicable.
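A minimal masking sketch for the shape rules above: the mask broadcasts to the score matrix, excluded scores become $-\infty$ before a max-shifted softmax, and the result keeps the score shape with zeros in excluded positions. The helper name `masked_softmax` and the causal and padding masks are illustrative assumptions.

```python
import numpy as np

def masked_softmax(scores, mask):
    """Row-wise softmax that ignores entries where mask is False.

    mask must broadcast to scores.shape; every row needs at least one True
    entry, otherwise the row is all -inf and the result is NaN.
    """
    masked = np.where(mask, scores, -np.inf)          # same shape as scores
    masked = masked - masked.max(axis=-1, keepdims=True)
    weights = np.exp(masked)                          # exp(-inf) = 0 for excluded keys
    return weights / weights.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
n, d_k = 4, 8
Q = rng.standard_normal((n, d_k))
K = rng.standard_normal((n, d_k))
scores = Q @ K.T / np.sqrt(d_k)                       # (n, n)

# Causal mask: query i may attend to keys 0..i. Shape (n, n), dtype bool.
causal = np.tril(np.ones((n, n), dtype=bool))
A = masked_softmax(scores, causal)

# Padding mask with shape (1, n): broadcasts across all query rows.
key_is_real = np.array([[True, True, True, False]])   # last key treated as padding
A_pad = masked_softmax(scores, key_is_real)

print(A.shape, A_pad.shape)
print(np.round(A, 3))
```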

Numerical and Implementation Notes

  • Dtype & precision: Prefer float64 for clarity; if using mixed precision, keep reductions (norms, softmax sums, factorizations) in float32/float64. Avoid explicit inverses; use solve, lstsq, Cholesky/QR/SVD.
  • Shapes & broadcasting: Annotate shapes (e.g., $X \in \mathbb{R}^{n\times d}$); vectors are column by default. Verify axes for reductions (axis) and ensure broadcasts are intended.
  • Stability: Use log-sum-exp for softmax; add small diagonal $\varepsilon I$ (jitter) for SPD solves; prefer QR/SVD for ill-conditioned least squares.
  • Conditioning: Inspect np.linalg.cond(A) when solutions look unstable; regularize (ridge) or rescale features to improve conditioning.
  • Reproducibility: Set NumPy seed for random data; print shapes and residuals (e.g., ||Ax-b||, reconstruction errors) and assert finiteness.
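  • Complexity & memory: Dense $n\times n$ matmuls and factorizations cost $O(n^3)$; triangular solves and matrix-vector products cost $O(n^2)$. For attention, $QK^\top$ costs $O(n_q n_k d_k)$, $AV$ costs $O(n_q n_k d_v)$, and the score matrix needs $O(n_q n_k)$ memory. Prefer vectorization over Python loops; avoid materializing large intermediates.
  • Masking & indexing: Use boolean masks that broadcast to target shapes. For attention-like ops, set masked scores to $-\infty$ (or a large negative value) before softmax; zeroing weights after softmax requires renormalizing each row. Then verify rows sum to ~1.
  • Sanity checks: Compare against references (e.g., lstsq vs. solve), check orthogonality (U.T @ U ≈ I), PSD (x.T @ A @ x ≥ 0 for all x, or all eigenvalues ≥ -tol), and residual norms within tolerance (e.g., ~1e-12 relative for well-conditioned float64 problems).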
  • Define small, integer-valued $Q \in \mathbb{R}^{2\times 2}$, $K \in \mathbb{R}^{3\times 2}$, $V \in \mathbb{R}^{3\times 2}$ so every intermediate is easy to inspect.
  • Compute scores with $S = QK^\top / \sqrt{d_k}$ using $d_k = 2$; scaling keeps logits in a moderate range for softmax.
  • Apply softmax(scores, axis=1) to get $A \in \mathbb{R}^{2\times 3}$ with row sums equal to 1 (row-stochastic weights).
  • Form outputs with $O = AV$ (A @ V in NumPy), producing $O \in \mathbb{R}^{2\times 2}$ as weighted averages of $V$ rows per query.
  • Verify numerically that A.sum(axis=1) equals 1 up to floating-point tolerance (e.g., with np.allclose) and print $O$ for inspection.
  • (Optional) Swap $Q$ rows or permute $K$ rows to see how attention weights redistribute while shapes remain valid.
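A sketch of the optional permutation experiment in the last step. It makes one subtlety explicit: permuting $K$ alone permutes the columns of $A$ but breaks the key/value pairing, so $O$ changes; permuting $K$ and $V$ together leaves $O$ unchanged (attention treats key/value pairs as an unordered set unless positional information is added); swapping $Q$ rows swaps $O$ rows. The helper `attend` and the permutation `perm` are illustrative.

```python
import numpy as np

def attend(Q, K, V):
    """Scaled dot-product attention with a max-shifted row softmax."""
    S = Q @ K.T / np.sqrt(Q.shape[1])
    A = np.exp(S - S.max(axis=1, keepdims=True))
    A /= A.sum(axis=1, keepdims=True)
    return A, A @ V

Q = np.array([[1., 0.], [0., 1.]])
K = np.array([[1., 0.], [1., 1.], [0., 1.]])
V = np.array([[1., 0.], [0., 2.], [1., 1.]])
A, O = attend(Q, K, V)

perm = np.array([2, 0, 1])               # reorder the three keys

# 1) Permute K alone: the columns of A follow the keys, but they now
#    pair with unpermuted V rows, so O generally changes.
A_k, O_k = attend(Q, K[perm], V)

# 2) Permute K and V together: same key/value pairs, so O is unchanged.
A_kv, O_kv = attend(Q, K[perm], V[perm])

# 3) Swap the query rows: the rows of A and O swap with them.
A_q, O_q = attend(Q[::-1], K, V)

print(np.allclose(A_k, A[:, perm]), np.allclose(O_kv, O), np.allclose(O_q, O[::-1]))
```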
What This Example Demonstrates

Pedagogical Significance

  • Learning goals: Build intuition for when and why this tool is used in ML, not just how to compute it.
  • ML-first framing: Tie the concept to a concrete task pattern (fit / project / decompose / solve / measure) to anchor understanding.
  • Shape discipline: Habitually annotating dimensions prevents silent bugs and reinforces linear map thinking.
  • Numerical habits: Prefer stable factorizations over inverses; check residuals and condition numbers to separate bugs from ill-conditioning.
  • Transfer: Reuse the same pattern across models (e.g., projection in PCA, orthogonalization in regressions, attention as weighted sums).
  • Assessment ideas: Predict sensitivity from $\kappa(A)$, verify projection properties, or compare solver outputs within tolerance.
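A sketch of these numerical habits on a small symmetric positive definite system: solve without an explicit inverse, cross-check against `lstsq` and a Cholesky-based solve, inspect the condition number and residual, and spot-check positive semidefiniteness. The random matrix `M` is an illustrative assumption; the plain `np.linalg.solve` calls on the Cholesky factor ignore its triangular structure, which `scipy.linalg.solve_triangular` or `cho_solve` would exploit if SciPy is available.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 5
B = rng.standard_normal((n, n))
M = B @ B.T + 1e-3 * np.eye(n)          # symmetric positive definite by construction
x_true = rng.standard_normal(n)
rhs = M @ x_true

x_solve = np.linalg.solve(M, rhs)                    # LU-based solve, no explicit inverse
x_lstsq, *_ = np.linalg.lstsq(M, rhs, rcond=None)    # SVD-based least squares
L = np.linalg.cholesky(M)                            # M = L @ L.T for SPD M
y = np.linalg.solve(L, rhs)                          # forward step (generic solver for brevity)
x_chol = np.linalg.solve(L.T, y)                     # backward step

kappa = np.linalg.cond(M)
residual = np.linalg.norm(M @ x_solve - rhs)
print(f"cond(M) = {kappa:.2e}, residual = {residual:.2e}")
print("agree:", np.allclose(x_solve, x_lstsq), np.allclose(x_solve, x_chol))

# Spot-check PSD: x^T M x >= 0 for random x (a necessary check, not a proof).
xs = rng.standard_normal((100, n))
quad = np.einsum("ij,jk,ik->i", xs, M, xs)
print("min x^T M x:", quad.min())
```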

ML Examples and Patterns

  • Fit: Linear regression via least squares and logistic/softmax regression via cross-entropy minimization; regularization (ridge) improves conditioning and generalization.
  • Project: PCA/SVD for dimensionality reduction; orthogonal projections to subspaces for denoising and feature extraction.
  • Decompose: Eigen/SVD factorizations to expose structure (low rank, PSD) used in recommender systems, LSA, and spectral clustering.
  • Solve: Stable solves without inversion (Cholesky/QR/SVD; CG for SPD) for optimization steps and kernel methods.
  • Measure: Norms, angles, and condition number $\kappa(A)$ to diagnose sensitivity, stability, and training difficulty.
  • The single takeaway: attention is just two matrix products with a row-wise softmax in between: $QK^\top$ to measure similarity and $AV$ to mix values.
  • How scaling by $1/\sqrt{d_k}$ mitigates softmax saturation, stabilizing both forward logits and backward gradients.
  • How row-stochastic $A$ makes each output row a convex combination of value rows, ensuring interpretability and bounded outputs.
  • How the shapes compose: $QK^\top$ yields $n_q \times n_k$, softmax preserves shape, $AV$ yields $n_q \times d_v$.
  • A minimal numerical check that $A$ rows sum to 1 and that $O$ matches hand-computable expectations on toy data.
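A small sketch of why saturation hurts the backward pass: for a single logit vector the softmax Jacobian is $\operatorname{diag}(p) - pp^\top$, which approaches zero as one probability approaches 1. The logits and the factor 20 are arbitrary stand-ins for logits whose scale grew with $d_k$.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())              # max-shift for stability
    return e / e.sum()

def softmax_jacobian(z):
    """d softmax(z) / dz = diag(p) - p p^T for a single logit vector z."""
    p = softmax(z)
    return np.diag(p) - np.outer(p, p)

z = np.array([1.0, 0.5, -0.5])           # moderate logits
for scale in (1.0, 20.0):                # 20x mimics logits that grew without 1/sqrt(d_k)
    J = softmax_jacobian(scale * z)
    print(f"scale={scale:5.1f}  p={np.round(softmax(scale * z), 4)}  ||J||_F={np.linalg.norm(J):.2e}")

J_moderate = softmax_jacobian(z)
J_saturated = softmax_jacobian(20.0 * z)
```

The Jacobian's columns sum to zero because the probabilities always sum to 1; in the saturated case every entry is tiny, so almost no gradient reaches the logits.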
Notes
  • Part 1: Shape discipline—$Q \in \mathbb{R}^{n_q\times d_k}$, $K \in \mathbb{R}^{n_k\times d_k}$, $V \in \mathbb{R}^{n_k\times d_v}$; $QK^\top$ is $n_q \times n_k$, $A$ matches that shape, $AV$ is $n_q \times d_v$.
  • Part 2: Numerical stability—scale by $1/\sqrt{d_k}$ and rely on a softmax implementation with max-shift to avoid overflow.
  • Part 3: Interpretation—each row of $A$ is a probability distribution; $O$ rows are convex combinations of $V$ rows.
  • Why This Matters for ML: Attention replaces fixed receptive fields with data-dependent weighting; stability and shape correctness are prerequisites for training deep transformers.
  • ML Examples and Patterns: Fit: attention weights act as adaptive feature selectors; Project: each row of $AV$ lies in the convex hull of the rows of $V$ (a simplex-weighted combination rather than an orthogonal projection); Decompose: multi-head attention parallels block-structured factorizations; Solve: masking enforces causal or structural constraints.
  • Connection to Linear Algebra Theory: Applied entrywise, $\exp(QK^\top/\sqrt{d_k})$ is an exponential dot-product kernel between queries and keys; normalizing its rows gives $A$, so $AV$ is a kernel-weighted average of the value rows, similar to a Nadaraya–Watson smoother.
  • Numerical and Implementation Notes: Use float32/float16 with care: apply log-sum-exp (max-shift) tricks for softmax, apply masks to the scores before softmax, and, if attention dropout is used, apply it to $A$ after softmax. Check that broadcasting and transposes match intended shapes.
  • Numerical and Shape Notes: Validate A.sum(axis=1) ≈ 1 and O.shape == (n_q, d_v); mismatched shapes usually indicate swapped axes or missing transposes.
  • Pedagogical Significance: This is the smallest end-to-end attention example that still exercises both matrix products and a stabilizing scale/softmax, making it ideal for debugging intuition.
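To see why the max-shift matters in low precision, here is a sketch comparing a naive softmax with the max-shifted one (and the equivalent log-sum-exp form) on float32 logits large enough to overflow `exp`. The logit values are arbitrary.

```python
import numpy as np

def naive_softmax(x):
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def stable_softmax(x):
    # Subtracting the row max leaves softmax unchanged but keeps exp() arguments <= 0.
    z = x - x.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def log_softmax(x):
    # log-sum-exp form: log p_i = x_i - m - log(sum_j exp(x_j - m)), with m = max_j x_j
    m = x.max(axis=-1, keepdims=True)
    return x - m - np.log(np.exp(x - m).sum(axis=-1, keepdims=True))

logits = np.array([[100.0, 101.0, 102.0]], dtype=np.float32)  # exp(100) overflows float32

with np.errstate(over="ignore", invalid="ignore"):
    p_naive = naive_softmax(logits)
p_stable = stable_softmax(logits)

print("naive :", p_naive)          # inf / inf -> nan
print("stable:", p_stable)
print("exp(log_softmax):", np.exp(log_softmax(logits)))
```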
History and Applications

Scaled dot-product attention emerged in 2017 as the core of the transformer architecture, but it builds on earlier alignment ideas from sequence-to-sequence models (Bahdanau et al. 2015). The $QK^\top$ then $AV$ pattern became dominant because it maps cleanly to fast matrix-multiply kernels on GPUs/TPUs. Subsequent work explored efficiency (sparse/linear attention, Performer), scaling (multi-head, multi-query, grouped-query attention), and hardware co-design for large context windows. Today, the same algebra underpins language models, vision transformers, audio models, and retrieval-augmented systems, with continual innovations in masking, positional encodings, and low-rank or kernelized approximations to reduce $O(n^2)$ cost.

Connection to Broader Examples
  • Links to Chapter 16 (matrix products): attention is a canonical $QK^\top$ then $AV$ pipeline.
  • Connects to conditioning (Chapter 14): the $1/\sqrt{d_k}$ scale mitigates exploding logits; masking and clipping further stabilize softmax.
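  • Complements least squares and projections (Chapter 12): $AV$ forms, for each query row, a convex combination of $V$ rows with weights on the probability simplex, a constrained cousin of the projections used in least squares.
  • Relates to SVD/PCA (Chapter 10/11): $QK^\top$ measures alignment in the shared feature space $\mathbb{R}^{d_k}$; it is an unnormalized analogue of cosine similarity and equals it when the rows of $Q$ and $K$ have unit norm.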
  • Bridges to sparse chapters (Chapter 15): attention on long sequences often uses sparse or low-rank approximations to reduce $O(n^2)$ cost.
  • Supports matrix-products pedagogy: reinforces associativity and shape reasoning used throughout earlier examples.
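Two linear-algebra facts behind these connections can be checked directly: $QK^\top$ has rank at most $d_k$ (the low-rank structure that SVD exposes), and with the softmax removed, $(QK^\top)V = Q(K^\top V)$, a regrouping that never forms the $n \times n$ matrix. The sizes and random data are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d_k, d_v = 6, 2, 3
Q = rng.standard_normal((n, d_k))
K = rng.standard_normal((n, d_k))
V = rng.standard_normal((n, d_v))

# Rank: QK^T is n x n, but its rank is at most d_k.
rank_scores = np.linalg.matrix_rank(Q @ K.T)

# Without softmax, matrix products associate: (QK^T)V == Q(K^T V).
# The left grouping costs O(n^2 (d_k + d_v)); the right costs O(n d_k d_v).
left = (Q @ K.T) @ V
right = Q @ (K.T @ V)

# The row-wise softmax acts on the full n x n score matrix, so
# softmax(QK^T / sqrt(d_k)) V admits no such regrouping; linear-attention
# variants replace it with feature maps to regain it.
S = Q @ K.T / np.sqrt(d_k)
A = np.exp(S - S.max(axis=1, keepdims=True))
A /= A.sum(axis=1, keepdims=True)
softmax_out = A @ V

print("rank(QK^T) =", rank_scores)
print("linear regrouping matches:", np.allclose(left, right))
```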
