Example number
66
Example slug
example_66_span_in_ml_xw_and_attention_as_weighted_sums
Background

Span explains expressivity: linear models produce outputs in the span of the feature columns, and attention mixes value vectors to form outputs in their span. Understanding spans clarifies when adding features or values increases model capacity (only when the new vector is linearly independent of the existing ones), and how constraints (e.g., nonnegativity, normalization) yield convex combinations with bounded, interpretable outputs. In practice, span-based reasoning underlies basis selection, dictionary learning, and projections used throughout ML.

Purpose

Develop a clear, ML-first intuition for span: linear predictors $Xw$ live in the span of the feature columns, and attention outputs are weighted sums of value vectors. This lets you reason about what your model can express, how projections and constraints affect outputs, and why shapes and weights matter for stability and interpretation.

Problem

Show that (a) linear predictions lie in the span of feature columns and (b) attention outputs lie in the span of value vectors. Compute both explicitly and verify the span claim by expressing outputs as linear combinations.

Solution (Math)

For linear predictions $\hat y = Xw$ with $X \in \mathbb{R}^{n\times d}$ and $w \in \mathbb{R}^d$,

\[ Xw = \sum_{j=1}^d w_j\, X_{:j}, \]

so $\hat y \in \mathrm{span}\{X_{:1},\dots,X_{:d}\}$.
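Whether an extra feature column enlarges this span can be checked numerically through the rank of the augmented matrix. The sketch below is illustrative only: the two extra columns (`dependent`, `independent`) are hypothetical and chosen by hand, and the rank test relies on NumPy's default tolerance.

```python
import numpy as np

# Toy design matrix: 3 examples (rows), 2 features (columns).
X = np.array([[1., 2.],
              [3., 4.],
              [5., 6.]])

# Hypothetical extra features, chosen only for illustration.
dependent = X[:, 0] + X[:, 1]          # already a combination of existing columns
independent = np.array([1., 0., 0.])   # not in the column span of X

rank_X = np.linalg.matrix_rank(X)
rank_dep = np.linalg.matrix_rank(np.column_stack([X, dependent]))
rank_ind = np.linalg.matrix_rank(np.column_stack([X, independent]))

# The span grows only when the rank grows.
print("rank(X):", rank_X)
print("rank with dependent column:", rank_dep)
print("rank with independent column:", rank_ind)
```

A column that is already a combination of the others leaves the rank, and therefore the set of reachable predictions $Xw$, unchanged.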

For attention with one query $Q \in \mathbb{R}^{1\times d_k}$, the weights are $a = \mathrm{softmax}(QK^\top/\sqrt{d_k}) \in \mathbb{R}^{n_k}$, and the output is a weighted sum of the value vectors $v_i$ (the rows of $V$):

\[ o = a^\top V = \sum_{i=1}^{n_k} a_i\, v_i, \quad o \in \mathrm{span}\{v_1,\dots,v_{n_k}\}. \]

We use:

  • Data matrix $X\in\mathbb{R}^{n\times d}$ (rows are examples).
  • Vectors are column vectors by default.
  • $\|x\|_2$ is the Euclidean norm; $\langle x,y\rangle=x^\top y$.
Solution (Python)

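import numpy as np
# softmax is the repository's helper; it is called below as softmax(scores, axis=1).
from scripts.toy_data import softmax

# Part (a): linear prediction as a combination of the columns of X.
X = np.array([[1., 2.],
              [3., 4.],
              [5., 6.]])
w = np.array([2., -1.])
yhat = X @ w                                 # shape (3,)
print("yhat:", yhat)
print("check combo:", 2*X[:,0] - 1*X[:,1])   # same vector, built column by column

# Part (b): single-query attention as a weighted sum of the rows of V.
Q = np.array([[1., 0.]])
K = np.array([[1., 0.],
              [0., 1.],
              [1., 1.]])
V = np.array([[1., 0.],
              [0., 2.],
              [1., 1.]])

scores = (Q @ K.T) / np.sqrt(2)              # d_k = 2, shape (1, 3)
a = softmax(scores, axis=1).reshape(-1)      # weights over the 3 keys, shape (3,)
o = a @ V                                    # shape (2,)
print("a:", a, "sum:", a.sum())
print("o:", o)
print("o combo:", a[0]*V[0] + a[1]*V[1] + a[2]*V[2])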
Code Introduction

This snippet illustrates two instances of outputs as linear combinations. First, with $X \in \mathbb{R}^{3\times 2}$ and $w \in \mathbb{R}^2$, we compute $\hat y = Xw$. Row by row, each entry is $\hat y_i = \mathbf{x}_i^\top w$; column by column, the same product is a weighted sum of the columns of $X$, so $\hat y$ lies in their span. With $w = [2, -1]^\top$, this is exactly $\hat y = 2\,X_{:1} - 1\,X_{:2}$ (columns X[:,0] and X[:,1] in the code), verified by the “check combo” line.

Second, the attention part forms scores $S = QK^\top / \sqrt{2}$ (with $d_k=2$), applies row-wise softmax to get weights $a \in \mathbb{R}^3$ that are nonnegative and sum to 1, then mixes values via $o = a^\top V$. This makes $o$ a convex combination of the rows $v_i$ of $V$: $o = \sum_{i=1}^3 a_i v_i$. The prints confirm the probability constraint and that the explicit weighted sum matches $o$.
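One consequence of convexity is that each coordinate of $o$ stays between the smallest and largest value of that coordinate across the rows of $V$, whereas an unconstrained linear combination can leave that box. The sketch below illustrates this under stated assumptions: `softmax_1d` is a local stand-in for the imported helper, and `free_weights` is a hypothetical weight vector picked by hand.

```python
import numpy as np

def softmax_1d(z):
    """Softmax of a 1-D score vector, shifted by its max for stability."""
    e = np.exp(z - np.max(z))
    return e / e.sum()

V = np.array([[1., 0.],
              [0., 2.],
              [1., 1.]])
# Same scores as Q @ K.T / sqrt(d_k) for the example's Q and K.
scores = np.array([1., 0., 1.]) / np.sqrt(2.0)

a = softmax_1d(scores)          # nonnegative, sums to 1
o = a @ V                       # convex combination of the rows of V

lo, hi = V.min(axis=0), V.max(axis=0)
in_box = bool(np.all(o >= lo) and np.all(o <= hi))

# Hypothetical unconstrained weights: still in the span, but outside the box.
free_weights = np.array([2., -1., 0.])
o_free = free_weights @ V
free_in_box = bool(np.all(o_free >= lo) and np.all(o_free <= hi))

print("o:", o, "inside box:", in_box)
print("o_free:", o_free, "inside box:", free_in_box)
```

The coordinate-wise box is only a necessary condition for lying in the convex hull, but it is cheap to check and catches most axis or normalization mistakes.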

Numerical Implementation Details

Numerical and Shape Notes

  • Shapes first: Declare shapes (e.g., $X \in \mathbb{R}^{n\times d}$, $w \in \mathbb{R}^{d}$, $b \in \mathbb{R}^{n}$). Vectors are column by convention; keep row/column usage consistent.
  • Axis discipline: Be explicit with axis in reductions and normalizations. For attention-like ops, softmax over keys (row-wise) so rows sum to ≈1.
  • Broadcasting: Check that broadcasts are intended (e.g., (n,1) with (n,d)). Prefer reshape/expand-dims to make semantics clear.
  • Stability eps: Add $\varepsilon$ for divisions/logs and $\varepsilon I$ (jitter) for SPD solves; use log-sum-exp for softmax.
  • Masking preserves shape: Masks should broadcast to the score/activation tensor; verify masked outputs keep the same shape and zero out excluded entries.
  • Dtype choices: Use float64 for clarity in scripts; with mixed precision, keep reductions/factorizations in float32/float64 to avoid under/overflow.
  • Sanity checks: Print shapes and residuals (e.g., ||Ax-b||, reconstruction error, row-sum ≈ 1). Assert finiteness and expected monotonicity where applicable.
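The stability, masking, and row-sum bullets above can be combined in one small routine. This is a minimal sketch, not the helper used in the example: `masked_softmax` is a hypothetical name, and it assumes every row keeps at least one unmasked key.

```python
import numpy as np

def masked_softmax(scores, mask):
    """Row-wise softmax over keys; entries where mask is False get zero weight.

    Assumes each row has at least one True entry in mask.
    """
    masked = np.where(mask, scores, -np.inf)
    shifted = masked - masked.max(axis=1, keepdims=True)  # max-shift avoids overflow
    e = np.exp(shifted)                                   # exp(-inf) == 0
    return e / e.sum(axis=1, keepdims=True)

# Large scores would overflow a naive exp; the shift keeps everything finite.
scores = np.array([[1000., 1001., 1002.],
                   [0.5, -0.5, 2.0]])
mask = np.array([[True, True, False],
                 [True, True, True]])

A = masked_softmax(scores, mask)
print("shape:", A.shape)
print("row sums:", A.sum(axis=1))
print("masked entry:", A[0, 2])
```

Masking before the softmax keeps each row a valid probability distribution without a separate renormalization step.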

Numerical and Implementation Notes

  • Dtype & precision: Prefer float64 for clarity; if using mixed precision, keep reductions (norms, softmax sums, factorizations) in float32/float64. Avoid explicit inverses; use solve, lstsq, Cholesky/QR/SVD.
  • Shapes & broadcasting: Annotate shapes (e.g., $X \in \mathbb{R}^{n\times d}$); vectors are column by default. Verify axes for reductions (axis) and ensure broadcasts are intended.
  • Stability: Use log-sum-exp for softmax; add small diagonal $\varepsilon I$ (jitter) for SPD solves; prefer QR/SVD for ill-conditioned least squares.
  • Conditioning: Inspect np.linalg.cond(A) when solutions look unstable; regularize (ridge) or rescale features to improve conditioning.
  • Reproducibility: Set NumPy seed for random data; print shapes and residuals (e.g., ||Ax-b||, reconstruction errors) and assert finiteness.
  • Complexity & memory: Dense matmul and factorizations cost ~ $O(n^3)$; triangular solves and matrix-vector products cost $O(n^2)$. Prefer vectorization over Python loops; avoid materializing large intermediates.
  • Masking & indexing: Use boolean masks that broadcast to target shapes; for attention-like ops, add $-\infty$ before softmax, or zero out after softmax and renormalize, then verify rows sum to ~1.
  • Sanity checks: Compare against references (e.g., lstsq vs. solve), check orthogonality (U.T @ U ≈ I), PSD (x.T @ A @ x ≥ 0, strictly > 0 for positive definite), and residual norms within tolerance (~1e-12 for float64).
  • Construct $X \in \mathbb{R}^{n\times d}$ and a weight vector $w \in \mathbb{R}^d$; compute $\hat y = Xw$ and verify $\hat y = \sum_j w_j X_{:j}$ by forming the explicit combination.
  • For attention, build small $Q \in \mathbb{R}^{1\times d_k}$, $K \in \mathbb{R}^{n_k\times d_k}$, $V \in \mathbb{R}^{n_k\times d_v}$.
  • Compute scores $S = QK^\top/\sqrt{d_k}$ and apply softmax(scores, axis=1) to obtain $a \in \mathbb{R}^{n_k}$.
  • Form $o = a^\top V$ and verify elementwise that $o$ equals the weighted sum $\sum_i a_i v_i$.
  • Check $a$ sums to 1 (row-stochastic) and that shapes match: $S \in \mathbb{R}^{1\times n_k}$, $a \in \mathbb{R}^{n_k}$, $o \in \mathbb{R}^{d_v}$.
  • For multi-query cases, keep $A \in \mathbb{R}^{n_q\times n_k}$ and compute $O = AV \in \mathbb{R}^{n_q\times d_v}$.
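The multi-query case in the last bullet can be sketched as follows. The sizes and random inputs are assumptions made for illustration; the point is the shape bookkeeping and the fact that every row of $O$ is a combination of the rows of $V$, so the rank of $O$ cannot exceed $n_k$.

```python
import numpy as np

rng = np.random.default_rng(0)
n_q, n_k, d_k, d_v = 4, 3, 2, 5          # illustrative sizes only
Q = rng.standard_normal((n_q, d_k))
K = rng.standard_normal((n_k, d_k))
V = rng.standard_normal((n_k, d_v))

S = Q @ K.T / np.sqrt(d_k)               # (n_q, n_k) scores
S = S - S.max(axis=1, keepdims=True)     # max-shift for stability
E = np.exp(S)
A = E / E.sum(axis=1, keepdims=True)     # (n_q, n_k), rows sum to 1
O = A @ V                                # (n_q, d_v)

# Explicit weighted sums, one query at a time.
O_explicit = np.stack([
    sum(A[q, i] * V[i] for i in range(n_k)) for q in range(n_q)
])

print("shapes:", S.shape, A.shape, O.shape)
print("max abs diff:", np.max(np.abs(O - O_explicit)))
print("rank(O):", np.linalg.matrix_rank(O), "<= n_k =", n_k)
```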
What This Example Demonstrates

Pedagogical Significance

  • Learning goals: Build intuition for when and why this tool is used in ML, not just how to compute it.
  • ML-first framing: Tie the concept to a concrete task pattern (fit / project / decompose / solve / measure) to anchor understanding.
  • Shape discipline: Habitually annotating dimensions prevents silent bugs and reinforces linear map thinking.
  • Numerical habits: Prefer stable factorizations over inverses; check residuals and condition numbers to separate bugs from ill-conditioning.
  • Transfer: Reuse the same pattern across models (e.g., projection in PCA, orthogonalization in regressions, attention as weighted sums).
  • Assessment ideas: Quick checks: predict sensitivity from $\kappa(A)$, verify projection properties, or compare solver outputs within tolerance.
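One of the quick checks above, comparing solver outputs and relating sensitivity to $\kappa$, might look like the sketch below. The synthetic data, noise level, and `w_true` are assumptions made for illustration; the normal-equations solve appears only as a comparison point.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((50, 3))
w_true = np.array([2., -1., 0.5])             # hypothetical ground truth
y = X @ w_true + 0.01 * rng.standard_normal(50)

# Reference solver (SVD-based) versus the normal equations.
w_lstsq, *_ = np.linalg.lstsq(X, y, rcond=None)
w_normal = np.linalg.solve(X.T @ X, X.T @ y)

cond_X = np.linalg.cond(X)
cond_gram = np.linalg.cond(X.T @ X)           # equals cond(X)**2 in the 2-norm

print("cond(X):", cond_X)
print("cond(X^T X):", cond_gram)
print("max coefficient diff:", np.max(np.abs(w_lstsq - w_normal)))
```

Because forming $X^\top X$ squares the condition number, the two solvers agree closely on well-conditioned data like this but can drift apart when columns are nearly collinear.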

ML Examples and Patterns

  • Fit: Linear/logistic regression via least squares or softmax; regularization (ridge) improves conditioning and generalization.
  • Project: PCA/SVD for dimensionality reduction; orthogonal projections to subspaces for denoising and feature extraction.
  • Decompose: Eigen/SVD factorizations to expose structure (low rank, PSD) used in recommender systems, LSA, and spectral clustering.
  • Solve: Stable solves without inversion (Cholesky/QR/SVD; CG for SPD) for optimization steps and kernel methods.
  • Measure: Norms, angles, and condition number $\kappa(A)$ to diagnose sensitivity, stability, and training difficulty.
  • Linear predictions $Xw$ lie in the span of the columns of $X$.
  • Attention outputs are convex combinations of the rows of $V$ (row-wise mixing).
  • Shapes and weights determine expressivity and numerical stability.
  • How to verify span claims numerically by explicit linear combinations.
  • Why softmax normalization makes attention weights nonnegative and sum to 1.
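The "Project" pattern has the same span structure: a rank-$k$ PCA reconstruction is a combination of $k$ principal directions. The sketch below uses random data and $k = 2$ purely as assumptions; the variable names are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((20, 4))
Xc = X - X.mean(axis=0)                   # center before PCA

U, s, Vt = np.linalg.svd(Xc, full_matrices=False)
k = 2
B = Vt[:k]                                # (k, 4): top-k principal directions as rows
coords = Xc @ B.T                         # (20, k): coefficients in that basis
X_rec = coords @ B                        # (20, 4): rows lie in span of the k directions

rank_rec = np.linalg.matrix_rank(X_rec)
err = np.linalg.norm(Xc - X_rec)          # Frobenius norm of what the span misses

print("rank of reconstruction:", rank_rec)
print("reconstruction error:", err)
print("sqrt of discarded squared singular values:", np.sqrt(np.sum(s[k:] ** 2)))
```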
Notes
  • Part 1: Linear predictions as spans—$\hat y = Xw$ equals a linear combination of columns; basis choice and feature engineering expand the span.
  • Part 2: Attention as convex mixing—rows of $A$ are probability distributions; each output row is a convex combination of $V$ rows.
  • Part 3: Numerical checks—verify equality by explicit sums; confirm $\sum_i a_i = 1$ and inspect shapes to avoid axis errors.
  • Why This Matters for ML. Span reveals expressivity and constraints: what outputs can be produced given features/values. Convex mixing stabilizes outputs and supports interpretation.
  • ML Examples and Patterns. Fit: regression/classification in the column span; Project: least-squares projection onto spans; Decompose: PCA/SVD basis selection; Solve: iterative methods that move within spans via Krylov subspaces.
  • Connection to Linear Algebra Theory. Spans define subspaces; convex combinations live in convex hulls. Attention weights create barycentric coordinates over $\{v_i\}$.
  • Numerical and Implementation Notes. Avoid explicit inverses; rely on GEMM and softmax with max-shift for stability. Validate shapes and axis choices.
  • Numerical and Shape Notes. Annotate $X \in \mathbb{R}^{n\times d}$, $w \in \mathbb{R}^d$, $\hat y \in \mathbb{R}^n$, $Q \in \mathbb{R}^{n_q\times d_k}$, $K \in \mathbb{R}^{n_k\times d_k}$, $V \in \mathbb{R}^{n_k\times d_v}$.
  • Pedagogical Significance. A minimal, end-to-end demonstration that links spans to both classical linear models and modern attention mechanisms.
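The least-squares projection mentioned in these notes can be checked directly: the fitted values are the orthogonal projection of the target onto the column span, and the residual is orthogonal to every column. In the sketch below the target `y` is a hypothetical vector chosen for illustration, and the QR route is one of several ways to build the projection.

```python
import numpy as np

X = np.array([[1., 2.],
              [3., 4.],
              [5., 6.]])
y = np.array([1., 0., 2.])                 # hypothetical target, not in span(X)

# Reduced QR: columns of Q_basis form an orthonormal basis for the column span of X.
Q_basis, _ = np.linalg.qr(X)
y_proj = Q_basis @ (Q_basis.T @ y)         # orthogonal projection of y onto the span

# Least squares reaches the same point in the span.
w, *_ = np.linalg.lstsq(X, y, rcond=None)
residual = y - y_proj

print("X @ w:    ", X @ w)
print("projection:", y_proj)
print("X^T residual:", X.T @ residual)
```

Projecting `y_proj` a second time leaves it unchanged, which is the idempotence property of projections onto a span.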
History and Applications

Span-based reasoning dates to early linear algebra and approximation theory: representing signals as linear combinations of basis elements underlies Fourier analysis, polynomial approximation, and finite element methods. In statistics and ML, outputs in spans explain linear model expressivity, least-squares projections, and PCA reconstructions. In modern architectures, attention implements convex combinations to form interpretable, stable mixtures of value vectors, with variants (sparse/linear attention) optimizing cost while preserving the span structure.

Connection to Broader Examples
  • Links to Chapter 2 (span): linear combinations define subspaces; models operate within spans of provided bases.
  • Connects to Chapter 12 (least squares): solutions live in the span of columns; normal equations and projections use column-span geometry.
  • Bridges to Chapter 11/16 (PCA, matrix products): PCA expresses data in the span of principal directions; attention is $QK^\top$ then $AV$.
  • Relates to Chapter 5/6 (inner products/projections): spans and projections are dual—projecting onto a span yields the closest representable output.
  • Complements Chapter 15 (sparse): sparse dictionaries and graphs still generate outputs in spans, with $O(\mathrm{nnz})$ mixing costs.

Comments