Example number
16
Example slug
example_16_attention_block_computation_qk_t_then_av
Background

Matrix products became the computational substrate of deep learning because GPUs and TPUs are engineered around GEMM (General Matrix Multiply) kernels. A single @ operation maps to highly optimized, parallel hardware routines. Attention decomposes into two GEMM operations: $QK^\top$ (compute pairwise similarities) and $AV$ (aggregate values). The intervening softmax costs far fewer floating-point operations than the GEMMs, so choosing matrix shapes and memory layouts that keep GEMM efficient is paramount. The scaling factor $1/\sqrt{d_k}$ keeps attention scores from saturating the softmax: if query and key entries are independent with zero mean and unit variance, their dot product has variance $d_k$, so unscaled scores grow with dimension and push softmax into regimes where gradients vanish. Dividing by $\sqrt{d_k}$ restores unit variance and stabilizes both forward and backward passes.
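The effect of the $1/\sqrt{d_k}$ factor can be checked empirically. The sketch below assumes i.i.d. standard normal entries for queries and keys (an illustrative assumption, not something the example prescribes); the seed, sample count, and dimensions are arbitrary choices.

```python
import numpy as np

rng = np.random.default_rng(0)  # fixed seed, chosen arbitrarily


def softmax_rows(x):
    # Row-wise softmax with max subtraction for stability.
    z = x - x.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


results = {}
for d_k in (4, 64, 1024):
    # Assumption: i.i.d. standard normal entries, so each dot product has variance d_k.
    Q = rng.standard_normal((2000, d_k))
    K = rng.standard_normal((2000, d_k))
    raw = np.sum(Q * K, axis=1)        # 2000 sample dot products
    scaled = raw / np.sqrt(d_k)
    results[d_k] = (raw.std(), scaled.std())
    print(f"d_k={d_k:5d}  std(raw)={raw.std():7.2f}  std(scaled)={scaled.std():5.2f}")

# Effect on softmax: one query against 8 keys at d_k = 1024.
d_k = 1024
q = rng.standard_normal((1, d_k))
keys = rng.standard_normal((8, d_k))
max_w_raw = softmax_rows(q @ keys.T).max()
max_w_scaled = softmax_rows(q @ keys.T / np.sqrt(d_k)).max()
print("largest weight, unscaled:", max_w_raw)
print("largest weight, scaled:  ", max_w_scaled)
```

The unscaled standard deviation should track $\sqrt{d_k}$ while the scaled one stays near 1, and the largest softmax weight can only shrink (or stay equal) once the scores are divided by $\sqrt{d_k}$.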

Purpose

Make shapes and transposes feel inevitable—so you can reason about forward/backward passes and attention without memorizing formulas:

  • Understand scaled dot-product attention: $\text{scores} = QK^\top / \sqrt{d_k}$, then softmax weights, then $O = AV$.
  • Build shape intuition: $Q \in \mathbb{R}^{n \times d_k}$, $K \in \mathbb{R}^{m \times d_k}$, $V \in \mathbb{R}^{m \times d_v}$ produce output $O \in \mathbb{R}^{n \times d_v}$.
  • See that attention is a learnable convex combination: weights are non-negative and sum to 1.
  • Connect to gradient flow: understand why scaling by $1/\sqrt{d_k}$ prevents softmax saturation.
Problem

Compute scores, softmax weights, and outputs for tiny attention; confirm row sums of weights are 1.

Solution (Math)

Attention computes scores $S = QK^\top/\sqrt{d_k}$, applies a row-wise softmax to get row-stochastic weights $A$, then outputs $O = AV$.

We use:

  • Data matrix $X\in\mathbb{R}^{n\times d}$ (rows are examples).
  • Vectors are column vectors by default.
  • $\|x\|_2$ is Euclidean norm; $\langle x,y\rangle=x^Ty$.
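For the tiny instance used in this example ($Q = I_2$, three keys, three values, $d_k = 2$), the three steps can be written out by hand. The block below is a worked sketch with values rounded to four decimals; the shorthand $a$ and $b$ is introduced here only for illustration.

```latex
Q = \begin{pmatrix} 1 & 0 \\ 0 & 1 \end{pmatrix}, \quad
K = \begin{pmatrix} 1 & 0 \\ 1 & 1 \\ 0 & 1 \end{pmatrix}, \quad
V = \begin{pmatrix} 1 & 0 \\ 0 & 2 \\ 1 & 1 \end{pmatrix}

QK^\top = \begin{pmatrix} 1 & 1 & 0 \\ 0 & 1 & 1 \end{pmatrix}, \qquad
S = \frac{QK^\top}{\sqrt{2}} \approx
\begin{pmatrix} 0.7071 & 0.7071 & 0 \\ 0 & 0.7071 & 0.7071 \end{pmatrix}

a = \frac{e^{1/\sqrt{2}}}{2e^{1/\sqrt{2}} + 1} \approx 0.4011, \qquad
b = \frac{1}{2e^{1/\sqrt{2}} + 1} \approx 0.1978, \qquad 2a + b = 1

A = \begin{pmatrix} a & a & b \\ b & a & a \end{pmatrix}, \qquad
O = AV = \begin{pmatrix} a + b & 2a + b \\ a + b & 3a \end{pmatrix}
\approx \begin{pmatrix} 0.5989 & 1.0000 \\ 0.5989 & 1.2033 \end{pmatrix}
```

The entry $O_{12} = 2a + b$ equals 1 exactly because the first row of $A$ sums to 1.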
Solution (Python)

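import numpy as np
from scripts.toy_data import softmax  # project-local helper, not part of NumPy

# Two queries (n=2), three keys and values (m=3), d_k = d_v = 2.
Q = np.array([[1., 0.], [0., 1.]])
K = np.array([[1., 0.], [1., 1.], [0., 1.]])
V = np.array([[1., 0.], [0., 2.], [1., 1.]])

scores = Q @ K.T / np.sqrt(2)   # (2, 3) scaled query-key dot products, d_k = 2
A = softmax(scores, axis=1)     # (2, 3) attention weights, one row per query
O = A @ V                       # (2, 2) weighted averages of the rows of V
print("row sums:", A.sum(axis=1))
print("O:\n", O)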
Code Introduction

This code implements scaled dot-product attention, the core mechanism of transformer architectures. It demonstrates how matrix products chain together to compute a weighted average of value vectors, where weights come from learned similarities between queries and keys.
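The listing imports softmax from scripts.toy_data, which is not shown here. A self-contained variant is sketched below, assuming that helper is an ordinary softmax along the given axis; the local softmax function is a stand-in written for this sketch, not the original helper.

```python
import numpy as np


def softmax(x, axis=-1):
    # Stand-in for scripts.toy_data.softmax (assumed behavior, not the original code).
    z = x - np.max(x, axis=axis, keepdims=True)   # shift for numerical stability
    e = np.exp(z)
    return e / np.sum(e, axis=axis, keepdims=True)


Q = np.array([[1., 0.], [0., 1.]])
K = np.array([[1., 0.], [1., 1.], [0., 1.]])
V = np.array([[1., 0.], [0., 2.], [1., 1.]])

d_k = Q.shape[1]
scores = Q @ K.T / np.sqrt(d_k)   # (2, 3)
A = softmax(scores, axis=1)       # (2, 3), rows sum to 1
O = A @ V                         # (2, 2)

assert np.allclose(A.sum(axis=1), 1.0)
print("A:\n", np.round(A, 4))
print("O:\n", np.round(O, 4))
```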

Numerical Implementation Details

Numerical and Shape Notes

  • Shapes first: Declare shapes (e.g., $X \in \mathbb{R}^{n \times d}$, $w \in \mathbb{R}^{d}$, $b \in \mathbb{R}^{n}$). Vectors are column by convention; keep row/column usage consistent.
  • Axis discipline: Be explicit with axis in reductions and normalizations. For attention-like ops, softmax over keys (row-wise) so rows sum to ≈1.
  • Broadcasting: Check that broadcasts are intended (e.g., (n,1) with (n,d)). Prefer reshape/expand-dims to make semantics clear.
  • Stability eps: Add $\varepsilon$ for divisions/logs and $\varepsilon I$ (jitter) for SPD solves; use log-sum-exp for softmax.
  • Masking preserves shape: Masks should broadcast to the score/activation tensor; verify masked outputs keep the same shape and zero out excluded entries.
  • Dtype choices: Use float64 for clarity in scripts; with mixed precision, keep reductions/factorizations in float32/float64 to avoid under/overflow.
  • Sanity checks: Print shapes and residuals (e.g., ||Ax-b||, reconstruction error, row-sum ≈ 1). Assert finiteness and expected monotonicity where applicable.
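The stability advice for softmax can be illustrated with a small comparison. This is a sketch using the max-subtraction form of the log-sum-exp trick; the function names and the extreme score values are made up for the demonstration.

```python
import numpy as np


def softmax_naive(x, axis=-1):
    # Direct formula; np.exp overflows to inf for large inputs.
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def softmax_stable(x, axis=-1):
    # Subtracting the row max leaves the result unchanged mathematically
    # but keeps every exponent <= 0, so np.exp cannot overflow.
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


scores = np.array([[1000.0, 1001.0, 999.0],   # deliberately extreme row
                   [0.5, -0.5, 0.0]])         # ordinary row

with np.errstate(over="ignore", invalid="ignore"):
    naive = softmax_naive(scores, axis=1)     # first row becomes inf / inf = nan
stable = softmax_stable(scores, axis=1)

print("naive finite: ", np.isfinite(naive).all())
print("stable finite:", np.isfinite(stable).all())
print("stable row sums:", stable.sum(axis=1))
```

On the ordinary row the two versions agree, so the stable form costs nothing in accuracy.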

Numerical and Implementation Notes

  • Dtype & precision: Prefer float64 for clarity; if using mixed precision, keep reductions (norms, softmax sums, factorizations) in float32/float64. Avoid explicit inverses; use solve, lstsq, Cholesky/QR/SVD.
  • Shapes & broadcasting: Annotate shapes (e.g., $X \in \mathbb{R}^{n \times d}$); vectors are column by default. Verify axes for reductions (axis) and ensure broadcasts are intended.
  • Stability: Use log-sum-exp for softmax; add small diagonal $\varepsilon I$ (jitter) for SPD solves; prefer QR/SVD for ill-conditioned least squares.
  • Conditioning: Inspect np.linalg.cond(A) when solutions look unstable; regularize (ridge) or rescale features to improve conditioning.
  • Reproducibility: Set NumPy seed for random data; print shapes and residuals (e.g., ||Ax-b||, reconstruction errors) and assert finiteness.
  • Complexity & memory: Dense matmul and factorizations cost ~ $O(n^3)$ for $n \times n$ matrices; triangular solves and matrix-vector products cost $O(n^2)$. Prefer vectorization over Python loops; avoid materializing large intermediates.
  • Masking & indexing: Use boolean masks that broadcast to target shapes; for attention-like ops, set masked scores to $-\infty$ before softmax, or zero out the masked weights after softmax and renormalize each row, then verify rows sum to ~1.
  • Sanity checks: Compare against references (e.g., lstsq vs. solve), check orthogonality (U.T @ U ≈ I), PSD (x.T @ A @ x ≥ 0 for all x), and residual norms within tolerance (~1e-12 for float64).
  • Query-key scores: scores = Q @ K.T / np.sqrt(d_k) computes all pairwise similarities; shape $(n, m)$.
  • Softmax normalization: A = softmax(scores, axis=1) normalizes each row; each entry is in $[0, 1]$, rows sum to 1.
  • Verification: A.sum(axis=1) should yield all ones (or near 1.0 up to roundoff).
  • Value aggregation: O = A @ V computes a weighted average of the value vectors for each query; shape $(n, d_v)$.
  • Scaling factor: $1/\sqrt{d_k}$ keeps the typical size of the scores from growing as $d_k$ increases; it is the default in the original transformer and in most implementations.
  • Multi-head: repeat with different $Q, K, V$ projections, concatenate outputs, and project; this enables learning diverse attention patterns.
  • Shapes: $Q \in \mathbb{R}^{n \times d_k}$, $K \in \mathbb{R}^{m \times d_k}$, $V \in \mathbb{R}^{m \times d_v}$; in this example $n=2, m=3, d_k=d_v=2$.
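The shape and masking notes above can be exercised together on a non-square case. The sketch below uses arbitrary sizes with $n \ne m$ and $d_k \ne d_v$, random data, and a hypothetical mask that hides the last two keys; none of these come from the example itself.

```python
import numpy as np


def softmax(x, axis=-1):
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


rng = np.random.default_rng(1)
n, m, d_k, d_v = 4, 6, 8, 5            # arbitrary sizes, chosen so no two coincide
Q = rng.standard_normal((n, d_k))
K = rng.standard_normal((m, d_k))
V = rng.standard_normal((m, d_v))

scores = Q @ K.T / np.sqrt(d_k)        # (n, m)

# Hypothetical mask: True means the key may be attended to.
keep = np.ones((n, m), dtype=bool)
keep[:, -2:] = False                   # hide the last two keys from every query

# Masked scores become -inf, so exp(...) is exactly 0 for them.
# Caveat: a row with every key masked would produce nan and needs separate handling.
masked_scores = np.where(keep, scores, -np.inf)
A = softmax(masked_scores, axis=1)     # (n, m)
O = A @ V                              # (n, d_v)

print("scores", scores.shape, "A", A.shape, "O", O.shape)
print("masked weights:", A[:, -2:].max())
print("row sums:", A.sum(axis=1))
```

Choosing four distinct sizes makes a misplaced transpose fail loudly instead of silently producing a same-shaped wrong answer.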
What This Example Demonstrates

Pedagogical Significance

  • Learning goals: Build intuition for when and why this tool is used in ML, not just how to compute it.
  • ML-first framing: Tie the concept to a concrete task pattern (fit / project / decompose / solve / measure) to anchor understanding.
  • Shape discipline: Habitually annotating dimensions prevents silent bugs and reinforces linear map thinking.
  • Numerical habits: Prefer stable factorizations over inverses; check residuals and condition numbers to separate bugs from ill-conditioning.
  • Transfer: Reuse the same pattern across models (e.g., projection in PCA, orthogonalization in regressions, attention as weighted sums).
  • Assessment ideas: Quick checks: predict sensitivity from $\kappa(A)$, verify projection properties, or compare solver outputs within tolerance.

ML Examples and Patterns

  • Fit: Linear/logistic regression via least squares or softmax; regularization (ridge) improves conditioning and generalization.
  • Project: PCA/SVD for dimensionality reduction; orthogonal projections to subspaces for denoising and feature extraction.
  • Decompose: Eigen/SVD factorizations to expose structure (low rank, PSD) used in recommender systems, LSA, and spectral clustering.
  • Solve: Stable solves without inversion (Cholesky/QR/SVD; CG for SPD) for optimization steps and kernel methods.
  • Measure: Norms, angles, and condition number $\kappa(A)$ to diagnose sensitivity, stability, and training difficulty.
  • Attention decomposes into two matrix products: $QK^\top$ (scores) and $AV$ (outputs).
  • Softmax converts scores into row-stochastic weights (sum to 1 per row).
  • Output is a convex combination of value vectors: each output row lies in the convex hull of the rows of $V$, which is a subset of their span.
  • Shape tracking: $Q:(n \times d_k)$, $K:(m \times d_k)$, $V:(m \times d_v)$ yield $QK^\top:(n \times m)$, $A:(n \times m)$, $O:(n \times d_v)$.
  • Numerical stability: scaling by $1/\sqrt{d_k}$ keeps attention scores in a range where softmax gradients are informative.
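The convex-combination property has a simple observable consequence: every output coordinate stays between the smallest and largest value in the matching column of $V$. The sketch below checks this with random scores (an arbitrary choice) and contrasts it with a general linear combination whose weights are made up for illustration.

```python
import numpy as np

V = np.array([[1., 0.], [0., 2.], [1., 1.]])     # value vectors from this example

rng = np.random.default_rng(2)
scores = 3.0 * rng.standard_normal((1000, 3))    # 1000 arbitrary score rows
e = np.exp(scores - scores.max(axis=1, keepdims=True))
A = e / e.sum(axis=1, keepdims=True)             # row-stochastic weights
O = A @ V

# Convex combinations cannot leave the coordinate-wise bounding box of the rows of V.
lo, hi = V.min(axis=0), V.max(axis=0)
inside_box = bool(np.all((O >= lo - 1e-12) & (O <= hi + 1e-12)))
print("all outputs inside bounding box:", inside_box)

# A linear combination with a negative weight stays in the span but leaves the box.
w = np.array([2.0, -1.0, 0.0])                   # hypothetical non-convex weights
outside = w @ V
print("non-convex combination:", outside)
```

The bounding box is a necessary condition rather than the convex hull itself, so this is a quick sanity check, not a full membership test.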
Notes

Part 1: Core setup - Attention block computation (QK^T then AV)

State the objects, shapes, and target question for Attention block computation (QK^T then AV). Name the data matrices or vectors, specify their dimensions, and clarify the transformation or comparison this example develops.

Part 2: Geometry and algebraic insight - Attention block computation (QK^T then AV)

Describe the geometric picture (subspaces, projections, bases, or decompositions) and the algebraic identities that make Attention block computation (QK^T then AV) work. Highlight how these structures constrain solutions and connect to earlier linear algebra tools.

Part 3: Numerics and ML practice - Attention block computation (QK^T then AV)

Give the computational recipe for Attention block computation (QK^T then AV), note stability or conditioning checks, and tie to an ML use case. Mention parameter choices, common pitfalls, and quick sanity checks such as shape validation or reconstruction error.

  • Shape discipline: check dimensions before manipulating formulas.
  • Numerical note: prefer stable primitives (lstsq, QR/SVD, Cholesky for SPD) over explicit inverses.
  • Interpretation: relate algebraic steps to geometry (subspaces, projections) and to ML behavior (generalization, stability).
  • Computing Attention Scores The code begins with three matrices playing different roles in attention. The query matrix $Q \in \mathbb{R}^{2 \times 2}$ contains 2 query vectors (one per query position), the key matrix $K \in \mathbb{R}^{3 \times 2}$ contains 3 key vectors, and the value matrix $V \in \mathbb{R}^{3 \times 2}$ contains the 3 corresponding value vectors. The first operation scores = Q @ K.T / np.sqrt(2) computes the query-key dot products: Q @ K.T produces a $2 \times 3$ matrix whose entry $(i,j)$ is the inner product between query $i$ and key $j$. This measures alignment: how much the $i$-th query “matches” the $j$-th key. Scaling by $1/\sqrt{d_k} = 1/\sqrt{2}$ (where $d_k=2$ is the key dimension) stabilizes the magnitudes; without scaling, the typical size of the dot products grows with embedding dimension, pushing softmax into saturation regions where gradients vanish.
  • Converting Scores to Attention Weights The raw scores are passed through softmax(scores, axis=1), which normalizes each row independently to form a probability distribution. For each query $i$, the softmax computes $a_{ij} = \frac{\exp(\text{scores}_{ij})}{\sum_k \exp(\text{scores}_{ik})}$, producing attention weights $A \in \mathbb{R}^{2 \times 3}$ where each row sums to 1.0 (verified by A.sum(axis=1)). The printout should show [1. 1.], confirming that softmax produced valid probability distributions. This is the crucial step that makes attention a convex combination rather than an arbitrary linear combination: the weights are non-negative and sum to 1, so the output lies in the convex hull of the value vectors.
  • Weighted Aggregation of Values The final operation O = A @ V computes the attention output by taking a weighted sum of value vectors. With $A \in \mathbb{R}^{2 \times 3}$ and $V \in \mathbb{R}^{3 \times 2}$, the result is $O \in \mathbb{R}^{2 \times 2}$. For each query position $i$, the output row $O_i$ is the convex combination $O_i = \sum_j a_{ij} V_j$, where $V_j$ is the $j$-th row of $V$. This is the core insight of attention: for each query, compute a weighted average of the values, where the weights reflect how relevant each value is to that query. The weights come from the query-key dot products: if query $i$ aligns strongly with key $j$, then value $j$ contributes heavily to output $i$.
  • Connection to Broader ML Patterns This attention computation applies the span and linear combination pattern (Example 2) dynamically: each output lies in the convex hull of the value vectors, a subset of $\text{span}(V)$ (the set of all linear combinations of the value vectors), and which point of that hull is selected depends on the attention weights. In transformers, stacking multiple attention heads (each with separate $Q, K, V$ projections) and MLP layers allows the model to compute different weighted combinations at each layer, hierarchically building richer representations. The matrix products Q @ K.T and A @ V are the fundamental computational primitives; understanding their shapes and numerical stability (gradient flow through softmax, scaling to prevent saturation) is essential for debugging transformer training. Shape discipline: $Q \in \mathbb{R}^{n \times d_k}$, $K \in \mathbb{R}^{m \times d_k}$, $V \in \mathbb{R}^{m \times d_v}$ produce scores $\in \mathbb{R}^{n \times m}$, weights $A \in \mathbb{R}^{n \times m}$, and output $O \in \mathbb{R}^{n \times d_v}$. Batch processing extends this to $Q, K, V$ with a leading batch dimension, enabling efficient computation on GPU.
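The batch and multi-head extension can be sketched with NumPy's broadcasting matmul. The code assumes the per-head projections have already been applied, uses random tensors in their place, and picks the axis layout (batch, head, position, feature) as one common convention rather than the only one.

```python
import numpy as np


def softmax(x, axis=-1):
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


rng = np.random.default_rng(3)
B, H, n, m, d_k, d_v = 2, 3, 4, 5, 8, 6    # batch, heads, queries, keys, dims (arbitrary)

# Random stand-ins for already-projected queries, keys, and values.
Q = rng.standard_normal((B, H, n, d_k))
K = rng.standard_normal((B, H, m, d_k))
V = rng.standard_normal((B, H, m, d_v))

# swapaxes transposes only the last two axes; @ broadcasts over the leading (B, H).
scores = Q @ np.swapaxes(K, -1, -2) / np.sqrt(d_k)   # (B, H, n, m)
A = softmax(scores, axis=-1)                         # softmax over keys
O = A @ V                                            # (B, H, n, d_v)

# Concatenate heads along the feature axis: (B, n, H * d_v).
O_cat = O.transpose(0, 2, 1, 3).reshape(B, n, H * d_v)

print("scores", scores.shape, "O", O.shape, "O_cat", O_cat.shape)
```

Each (batch, head) slice is exactly the 2-D computation from the example, so a single slice can be checked against the plain Q @ K.T and A @ V formulas when debugging.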
History and Applications

Attention mechanisms: Bahdanau et al. (2015) introduced attention to machine translation as a way to let decoders “focus” on relevant source tokens, replacing fixed-size context vectors with learned, dynamic alignments. This let sequence-to-sequence models handle long-range dependencies without forcing the whole source sentence through a single fixed-size bottleneck.

Scaled dot-product attention: Vaswani et al. (2017, “Attention Is All You Need”) proposed the transformer architecture built entirely on attention, removing recurrence and convolution. They introduced scaled dot-product attention ($QK^\top / \sqrt{d_k}$), multi-head attention (parallel heads with different projections), and positional encodings. This architecture scaled to billions of parameters and became the foundation of BERT, GPT, and modern LLMs.

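Computational efficiency: Matrix products dominate transformer runtime, and in self-attention the cost of the two GEMM operations grows quadratically with sequence length, which makes them a bottleneck for long inputs. Modern hardware (TPUs, GPUs with tensor cores) is optimized around these operations. FlashAttention (Dao et al., 2022) and other kernel-level optimizations reduce memory I/O by tiling the computation and fusing the score, softmax, and value-aggregation steps into a single kernel, so the full $n \times m$ attention matrix is never written out to slower GPU memory; this enables longer sequences and larger models. Understanding the matrix structure (shapes, transposition, memory layout) is essential both for implementing efficient transformers and for deploying them at scale.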
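One idea behind I/O-aware attention kernels is that the softmax can be accumulated one block of keys at a time, using a running maximum and a running sum. The NumPy sketch below illustrates only that rescaling idea; it is not the FlashAttention kernel, and the function names and block size are choices made for this illustration.

```python
import numpy as np


def attention_reference(Q, K, V):
    # Plain attention that materializes the full (n, m) score matrix.
    s = Q @ K.T / np.sqrt(Q.shape[1])
    e = np.exp(s - s.max(axis=1, keepdims=True))
    return (e / e.sum(axis=1, keepdims=True)) @ V


def attention_blockwise(Q, K, V, block=4):
    # Processes keys/values in blocks; only an (n, block) score tile exists at a time.
    n, d_k = Q.shape
    running_max = np.full((n, 1), -np.inf)   # largest score seen so far, per query
    running_sum = np.zeros((n, 1))           # softmax denominator, relative to running_max
    acc = np.zeros((n, V.shape[1]))          # unnormalized output, relative to running_max
    for start in range(0, K.shape[0], block):
        Kb = K[start:start + block]
        Vb = V[start:start + block]
        s = Q @ Kb.T / np.sqrt(d_k)                          # (n, block) tile
        new_max = np.maximum(running_max, s.max(axis=1, keepdims=True))
        correction = np.exp(running_max - new_max)           # rescale earlier partial results
        p = np.exp(s - new_max)
        running_sum = running_sum * correction + p.sum(axis=1, keepdims=True)
        acc = acc * correction + p @ Vb
        running_max = new_max
    return acc / running_sum


rng = np.random.default_rng(4)
Q = rng.standard_normal((5, 8))
K = rng.standard_normal((10, 8))
V = rng.standard_normal((10, 3))

out_ref = attention_reference(Q, K, V)
out_blk = attention_blockwise(Q, K, V, block=4)
print("max abs difference:", np.max(np.abs(out_ref - out_blk)))
```

The two functions should agree up to floating-point roundoff for any block size, including one that does not divide the number of keys.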

Modern applications: Transformers power large language models (GPT-2/3/4, LLaMA, Gemini), vision transformers (ViT), multimodal systems (CLIP, vision-language models), and domain-specific models (AlphaFold for protein structure, models for code generation). Attention’s expressiveness comes from learning task-specific similarity metrics (via $Q, K$ projections) and value aggregation strategies (via $V$). Shape discipline—keeping track of $n$ (sequence length), $m$ (key/value length), $d_k$, $d_v$—is the foundation for reasoning about model capacity and computational cost.

Connection to Broader Examples
  • Span and linear combinations (Ch. 2): Output lies in $\text{span}(V)$, and more specifically in the convex hull of the value vectors, because the weights are non-negative and sum to 1.
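  • Projections (Ch. 6): Attention maps each query to a point in the value space, with weights chosen via query-key similarity; this is loosely analogous to projection, though it is not an orthogonal projection in the strict sense.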
  • Matrix products efficiency: GPUs optimize GEMM; attention’s two-GEMM structure is ideal for hardware acceleration.
  • Numerics and scaling: $1/\sqrt{d_k}$ prevents saturation—connects to conditioning (Ch. 14) and numerical stability in deep networks.
  • Transformers and modern deep learning: Attention is the foundation; understanding shapes and matrix operations is essential for implementing, debugging, and optimizing large models.
