Example number
96
Example slug
example_96_transformers_attention_as_qk_t_then_av
Background

Attention mechanisms emerged to solve the bottleneck problem in sequence-to-sequence models (Cho et al. 2014, Sutskever et al. 2014). Early RNNs compressed entire source sequences into a fixed-size context vector, losing information for long sequences. Bahdanau et al. (2015) introduced attention: instead of compressing into one vector, the decoder learns to focus on relevant parts of the encoder output at each decoding step. This was a breakthrough: suddenly, long sequences became tractable. However, attention on RNNs still required sequential processing; the computation could not be parallelized across time steps. Vaswani et al. (2017) proposed Transformers, replacing recurrence entirely with stacked attention layers (“Attention is All You Need”). The key insight: scaled dot-product attention is fully parallelizable and GPU-efficient. Each layer computes $O(n^2)$ dot products (queries vs keys) and $O(n^2)$ weighted aggregations (attention times values), but all in parallel via matrix multiplication, the fundamental GPU operation (GEMM: general matrix-matrix multiply). By 2017–2020, transformers replaced RNNs as the default for NLP, scaling to billions of parameters (BERT, GPT-2/3, T5). From 2020 onward, vision transformers (ViT) adapted attention to images, and multimodal transformers (CLIP, GPT-4V) unified vision and language. The transformer is, fundamentally, sophisticated matrix algebra executed at massive scale.

Purpose

Understand scaled dot-product attention, the fundamental building block of transformer neural networks. Learn how three steps ($QK^\top$ scoring, row-wise softmax normalization, and the product $AV$) combine to create a differentiable, parallelizable mechanism for selective information aggregation. Master shape discipline: track how query, key, and value matrices flow through the computation, how scaling affects attention sharpness, and why softmax is the right normalization. Discover why this simple mechanism scales to billions of parameters: attention is GPU-friendly (matrix products = GEMM kernels), parallelizable (all dot products computed at once), and interpretable (attention weights are probabilities). Learn the key insights: similarities are dot products, normalization is softmax, aggregation is weighted sums. Apply this understanding to multi-head attention, causal masking, cross-attention, and the full transformer stack.

Problem

Compute scores, attention weights, and outputs for a tiny attention head; verify weights sum to 1.

Solution (Math)

Attention: $O=\mathrm{softmax}\left(QK^\top/\sqrt{d_k}\right)V$. Row-wise softmax yields row-stochastic weights; multiplying by $V$ forms weighted sums.

We use:

  • Data matrix $X\in\mathbb{R}^{n\times d}$ (rows are examples).
  • Vectors are column vectors by default.
  • $\|x\|_2$ is Euclidean norm; $\langle x,y\rangle=x^Ty$.
Solution (Python)

import numpy as np
from scripts.toy_data import softmax

Q = np.array([[1., 0.],
              [0., 1.]])
K = np.array([[1., 0.],
              [1., 1.],
              [0., 1.]])
V = np.array([[1., 0.],
              [0., 2.],
              [1., 1.]])

scores = Q @ K.T / np.sqrt(2)
A = softmax(scores, axis=1)
O = A @ V

print("A row sums:", A.sum(axis=1))
print("O:\n", O)
Code Introduction
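The solution imports `softmax` from `scripts.toy_data`, which is not shown in this example. The block below is a minimal self-contained sketch of the same computation, assuming that helper is a row-wise, max-shifted softmax. The name `softmax_rows` is a hypothetical stand-in, not the module's actual function.

```python
import numpy as np

def softmax_rows(x, axis=1):
    """Hypothetical stand-in for scripts.toy_data.softmax (assumed behaviour)."""
    shifted = x - x.max(axis=axis, keepdims=True)  # shift for numerical stability
    e = np.exp(shifted)
    return e / e.sum(axis=axis, keepdims=True)

Q = np.array([[1., 0.], [0., 1.]])
K = np.array([[1., 0.], [1., 1.], [0., 1.]])
V = np.array([[1., 0.], [0., 2.], [1., 1.]])

scores = Q @ K.T / np.sqrt(Q.shape[1])   # d_k = 2
A = softmax_rows(scores, axis=1)         # attention weights, shape (2, 3)
O = A @ V                                # outputs, shape (2, 2)

print("A:\n", A.round(3))
print("A row sums:", A.sum(axis=1))
print("O:\n", O.round(3))
```

Running it prints the attention weights, their row sums, and the outputs, so the walkthrough that follows can be checked without the external helper.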

Mechanism and Intuition

Scaled dot-product attention computes learned similarity-based routing of information. The core idea: each query (what we’re looking for) is compared to all keys (what’s available) via dot products. Similarities are normalized into probabilities using softmax. Finally, these probabilities weight values (the actual information we retrieve). The result: each query gets a weighted aggregation of all values, with weights determined by query-key similarity.

Code Walkthrough: Q, K, V Setup

Q = np.array([[1., 0.],
              [0., 1.]])           # 2 queries, 2D
K = np.array([[1., 0.],
              [1., 1.],
              [0., 1.]])           # 3 keys, 2D  
V = np.array([[1., 0.],
              [0., 2.],
              [1., 1.]])           # 3 values, 2D
  • $Q \in \mathbb{R}^{2 \times 2}$: Two queries, each a 2D vector. In transformer terminology, these are learned projections of input elements in the decoder (or within same input for self-attention).
  • $K \in \mathbb{R}^{3 \times 2}$: Three keys, each 2D. These are learned projections of input elements in the encoder (or same input for self-attention). Keys define “what can we attend to?”
  • $V \in \mathbb{R}^{3 \times 2}$: Three values, each 2D. These are the actual information we retrieve. Crucially, values can have different dimension than keys (here both 2D; in practice may differ).

Constraint: Number of keys must equal number of values (both 3 here). Query and key dimensions must match (both 2D). Value dimension determines output dimension.

Computing Scores via Dot Products

scores = Q @ K.T / np.sqrt(2)      # Shape: (2, 3) 
# scores = [[q1·k1, q1·k2, q1·k3],
#           [q2·k1, q2·k2, q2·k3]]
  • Matrix product $QK^\top$ (Q @ K.T in code) computes all query-key dot products in one GEMM operation: $QK^\top \in \mathbb{R}^{2 \times 3}$.

  • Entry $(i, j)$ is similarity between query $i$ and key $j$: $Q_i \cdot K_j = q_i^\top k_j$.

  • Concrete values:

    scores before scaling:
    [[1·1 + 0·0, 1·1 + 0·1, 1·0 + 0·1],  = [1, 1, 0]
     [0·1 + 1·0, 0·1 + 1·1, 0·0 + 1·1]]  = [0, 1, 1]

    After dividing by $\sqrt{2} \approx 1.414$:

    [[0.707, 0.707, 0.],
     [0., 0.707, 0.707]]

Why Scaling by $\sqrt{d_k}$ Matters

Dot products grow with dimension. Suppose queries and keys are 768D (typical in BERT) with roughly independent, zero-mean, unit-variance components:

  • Typical dot product magnitude: the dot product has mean 0 but standard deviation $\sqrt{768} \approx 27.7$; strongly correlated vectors can reach $\sim d_k$.
  • Softmax of widely spread scores becomes nearly one-hot: softmax([100, 50, 0]) ≈ [1, 0, 0].
  • Gradient dies: backprop through a near one-hot softmax provides little signal.

Scaling by $\sqrt{d_k}$ (dividing scores by $\sqrt{768} \approx 27.7$) brings the scores back to roughly unit variance, a regime where softmax has reasonable entropy. This is a simple but crucial trick, introduced together with the Transformer by Vaswani et al. (2017), and it is fundamental to transformers working well in practice.
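The growth of dot products with dimension can be checked empirically. The sketch below assumes query and key components are independent standard normal draws, which is an idealization (trained projections need not behave this way); `n_pairs` and `results` are illustrative names.

```python
import numpy as np

rng = np.random.default_rng(0)
n_pairs = 5_000  # number of random (q, k) pairs per dimension

results = {}
for d_k in (2, 64, 768):
    q = rng.standard_normal((n_pairs, d_k))
    k = rng.standard_normal((n_pairs, d_k))
    dots = np.sum(q * k, axis=1)                 # one dot product per pair
    raw_std = dots.std()                         # grows like sqrt(d_k)
    scaled_std = (dots / np.sqrt(d_k)).std()     # stays near 1 after scaling
    results[d_k] = (raw_std, scaled_std)
    print(f"d_k={d_k:4d}  raw std={raw_std:6.2f}  scaled std={scaled_std:5.2f}")
```

Under this assumption the raw standard deviation tracks $\sqrt{d_k}$, while the scaled scores stay near unit spread regardless of dimension.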

Softmax Normalization: From Scores to Probabilities

A = softmax(scores, axis=1)        # Shape: (2, 3)

Softmax converts scores to probability distributions: \[ A_{ij} = \frac{\exp(\text{scores}_{ij})}{\sum_k \exp(\text{scores}_{ik})} \]

Each row sums to 1. Using our example:

Before softmax: [[0.707, 0.707, 0.],
                 [0., 0.707, 0.707]]

After softmax (numerically):
[[0.401, 0.401, 0.198],  # row 1 sums to 1
 [0.198, 0.401, 0.401]]  # row 2 sums to 1

Why softmax?

  1. Non-negative weights: $A_{ij} \geq 0$. Each query uses a non-negative blend of values.
  2. Probabilities: $\sum_j A_{ij} = 1$. Each row can be read as a probability distribution over keys for query $i$.
  3. Differentiable: enables backpropagation.
  4. Entropy control: the softmax temperature (here the $\sqrt{d_k}$ scaling) tunes how “sharp” attention is. Smaller temperature → sharper (more one-hot). Larger temperature → softer (more uniform).

Weighted Aggregation: Computing Output

O = A @ V                          # Shape: (2, 2)
# O = [[a11*v1 + a12*v2 + a13*v3],
#      [a21*v1 + a22*v2 + a23*v3]]

Output is a weighted sum of values: \[ O_i = \sum_j A_{ij} V_j \]

Each row of $O$ is a context vector: a learned aggregation of all values, with weights determined by attention.

Concrete computation:

Query 1 output: 0.401 * [1,0] + 0.401 * [0,2] + 0.198 * [1,1]
              = [0.401, 0.] + [0., 0.802] + [0.198, 0.198]
              = [0.599, 1.000]  ≈ [0.60, 1.00]

Query 2 output: 0.198 * [1,0] + 0.401 * [0,2] + 0.401 * [1,1]
              = [0.198, 0.] + [0., 0.802] + [0.401, 0.401]
              = [0.599, 1.203]  ≈ [0.60, 1.20]

Interpretation: The two outputs share the same first coordinate and differ only moderately in the second, because the attention patterns overlap: both queries put weight 0.401 on key 2, and neither pattern is sharply peaked. If queries had very different patterns (one attends mostly to key 1, other mostly to key 3), outputs would differ more.
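How sharply a query concentrates on its best-matching keys depends on the temperature applied to the scores, as noted in the softmax discussion above. The sketch below reuses the toy Q and K and varies a temperature `tau` (a hypothetical parameter name; the standard choice is $\sqrt{d_k}$), reporting the entropy of the first row of weights.

```python
import numpy as np

def softmax_rows(x):
    e = np.exp(x - x.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

def row_entropy(A):
    """Shannon entropy (in nats) of each row; 0 * log(0) is treated as 0."""
    safe = np.where(A > 0, A, 1.0)
    return -(A * np.log(safe)).sum(axis=1)

Q = np.array([[1., 0.], [0., 1.]])
K = np.array([[1., 0.], [1., 1.], [0., 1.]])
raw = Q @ K.T                                  # unscaled scores

taus = (0.1, np.sqrt(2), 10.0)                 # sharp, standard, soft
entropies = []
for tau in taus:
    A = softmax_rows(raw / tau)
    H = row_entropy(A)
    entropies.append(H[0])
    print(f"tau={tau:6.3f}  A[0]={A[0].round(3)}  entropy={H[0]:.3f}")
```

A small temperature pushes weight onto the top-scoring keys (low entropy), while a large one approaches the uniform distribution, whose entropy over three keys is $\log 3$.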

Verification Check: Row Sums

print("A row sums:", A.sum(axis=1))  # Should be [1., 1.]

This verifies softmax worked correctly: each row of attention matrix is a probability distribution.

Why This Matters for Machine Learning

  1. Transformers Scalability: The two heavy operations ($QK^\top$ and $AV$) are matrix multiplications (GEMMs), and the softmax between them is a cheap row-wise operation, so the whole computation parallelizes well on GPUs. This is a major reason transformers have scaled to billions of parameters. RNNs, by contrast, must process positions in order (position 1 → position 2 → position 3), so they cannot be parallelized across the sequence.

  2. Attention is Learned Routing: Unlike convolution (fixed small receptive field) or pooling (fixed aggregation), attention learns what to look at. Each query dynamically decides which values to weight heavily.

  3. Interpretability: Attention weights are human-readable probabilities. Visualizing which keys each query attends to provides model interpretability. You can ask: “What did this query focus on?” and get a precise answer from attention weights.

  4. Transfer Learning: Transformers pretrained on massive corpora (BERT, GPT) learn attention patterns that transfer to downstream tasks. The same learned “ability to route information” works for different domains.

  5. Multimodal Learning: Attention handles variable-length sequences naturally. Text (tokens), images (patches), audio (frames) all feed through the same attention mechanism. Cross-attention between modalities is straightforward: different queries/keys/values from different modalities.

  6. Long-Range Dependencies: Attention attends to any position (unlike convolution’s limited receptive field or RNN’s gradient flow issues). This enables learning long-range dependencies directly.

Numerical and Shape Notes

  1. Log-Sum-Exp Trick: In implementation, softmax is computed as:

    scores_shifted = scores - scores.max(axis=1, keepdims=True)
    A = np.exp(scores_shifted) / np.exp(scores_shifted).sum(axis=1, keepdims=True)

    Subtracting max prevents overflow (exp of large numbers overflows). PyTorch/TensorFlow do this automatically.

  2. Causal Masking: For autoregressive models (GPT), position $i$ should attend only to positions $\leq i$. Implement by setting future scores to $-\infty$ before softmax:

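    future = np.triu(np.ones(scores.shape, dtype=bool), k=1)  # True where key j > query i
    scores = np.where(future, -np.inf, scores)  # Future positions
    A = softmax(scores, axis=1)  # exp(-inf) = 0, so masked weights vanish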
  3. Attention Dropout: For regularization, zero out random attention weights:

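    dropout_p = 0.1  # example dropout rate
    mask = np.random.binomial(1, 1 - dropout_p, A.shape)  # 1 = keep, 0 = drop
    A = A * mask / (1 - dropout_p)  # Rescale so each weight is unchanged in expectation (rows need not sum to 1)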
  4. Head Dimension: Modern transformers split into $h$ “heads” with dimension $d/h$ each. Example: 768 dims / 12 heads = 64 dims per head. Benefits: different heads learn different patterns (some focus on syntax, others semantics). Cost is same (distributed across heads).

  5. Memory Complexity: Full $n \times n$ attention matrix requires $O(n^2)$ memory. For $n = 1\text{M}$ tokens and 32-bit floats: $10^{12}$ floats = 4 TB. This is why “long context” transformers use sparse attention (only compute/store subset of pairs).
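The max-subtraction step from item 1 can be checked on scores large enough to overflow. This is a minimal sketch; `softmax_naive` and `softmax_stable` are illustrative names, not library functions.

```python
import numpy as np

def softmax_naive(x):
    e = np.exp(x)                                  # overflows to inf for large x
    return e / e.sum(axis=1, keepdims=True)

def softmax_stable(x):
    e = np.exp(x - x.max(axis=1, keepdims=True))   # largest exponent is 0
    return e / e.sum(axis=1, keepdims=True)

big = np.array([[1000.0, 999.0, 998.0]])

# Silence the expected overflow/invalid warnings for the naive version.
with np.errstate(over="ignore", invalid="ignore"):
    naive = softmax_naive(big)                     # inf / inf -> nan
stable = softmax_stable(big)

print("naive: ", naive)
print("stable:", stable.round(3))
```

Because softmax is invariant to adding a constant to every score in a row, the shifted version gives the same weights as softmax([2, 1, 0]) without ever exponentiating a large number.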

Verification and Extensions

Verify that output has correct shape and values make sense:

print("O shape:", O.shape)          # (2, 2)
print("O min/max:", O.min(), O.max())  # Should be reasonable (not NaN)

Multi-Head Extension:

# Instead of one 2D attention, split into 2 heads of 1D each
Q1, Q2 = Q[:, :1], Q[:, 1:]        # Split queries into 2 heads
K1, K2 = K[:, :1], K[:, 1:]        # Split keys into 2 heads
V1, V2 = V[:, :1], V[:, 1:]        # Split values into 2 heads

# Compute attention separately for each head
O1 = softmax(Q1 @ K1.T, axis=1) @ V1   # Head 1 output: (2, 1); d_h = 1, so sqrt(d_h) = 1
O2 = softmax(Q2 @ K2.T, axis=1) @ V2   # Head 2 output: (2, 1)

# Concatenate
O_multihead = np.concatenate([O1, O2], axis=1)  # (2, 2)

All modern transformers use multi-head attention. It’s empirically shown to improve model capacity and generalization (different heads capture different linguistic phenomena: syntax, semantics, anaphora, etc.).

Causal Attention Example (for autoregressive):

# Align query i with key i; query i may attend only to keys j <= i (0-based)
scores = Q @ K.T / np.sqrt(2)
scores[0, 1:] = -np.inf             # Query 0 can't attend to keys 1 and 2 (in the future)
scores[1, 2:] = -np.inf             # Query 1 can't attend to key 2 (in the future)
# (Each query still sees its own position and everything before it)
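For square self-attention, the causal mask is usually built once as an upper-triangular boolean matrix rather than row by row. The sketch below assumes a toy input `X` used directly as both queries and keys (no learned projections) and checks that no weight lands on a future position.

```python
import numpy as np

def softmax_rows(x):
    e = np.exp(x - x.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

rng = np.random.default_rng(0)
n, d_k = 4, 8
X = rng.standard_normal((n, d_k))                    # toy input: Q = K = X

scores = X @ X.T / np.sqrt(d_k)                      # (n, n)
future = np.triu(np.ones((n, n), dtype=bool), k=1)   # True where key j > query i
masked = np.where(future, -np.inf, scores)

A = softmax_rows(masked)                             # exp(-inf) = 0 for masked entries
print(A.round(3))
```

The first row has a single unmasked entry, so the first position attends only to itself; every row still sums to 1 over the positions it is allowed to see.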

Pedagogical Takeaway

Scaled dot-product attention is the simplest possible mechanism that achieves sophisticated learned routing: score via bilinear form (dot product), normalize via softmax (probabilistic), aggregate via weighted sum. Yet this simplicity, combined with GPU parallelization, enables transformers to scale to hundred-billion-parameter models and solve diverse tasks (language, vision, speech, multimodal). Understanding the shapes, the role of scaling, the softmax normalization, and the GPU efficiency underlying attention is essential to understanding modern deep learning. Get the shapes right, and the math follows naturally.

Numerical Implementation Details
  • Inputs and shapes: $Q \in \mathbb{R}^{n\times d_k}$, $K \in \mathbb{R}^{m\times d_k}$, $V \in \mathbb{R}^{m\times d_v}$. In batched code: $(B,n,d_k)$, $(B,m,d_k)$, $(B,m,d_v)$. For multi-head, reshape/proj to $(B,h,n,d_h)$ etc. with $d_h=d/h$.
  • Scores with scaling: Compute $S = QK^\top / \sqrt{d_k}$ via GEMM. Prefer float32 accumulation even with float16/bfloat16 inputs to reduce rounding error.
  • Masking: Add a large negative bias before softmax. Padding mask (don’t attend to pad tokens): add $-\infty$ where keys are padding. Causal mask (no future): set $S_{ij}=-\infty$ for $j>i$. Ensure mask broadcasts to $(B,h,n,m)$.
  • Stable softmax (row-wise): Use log-sum-exp: S -= S.max(axis=...), then A = exp(S) / exp(S).sum(axis=1, keepdims=True). Softmax axis is across keys (rows sum to 1 for each query).
  • Output: $O = AV \in \mathbb{R}^{n\times d_v}$ (batched: $(B,h,n,d_h)$ → concat heads → $(B,n,d)$ via output projection $W^O$).
  • Complexity: Self-attention with $m=n$, $d_k\approx d_v\approx d$ costs $O(n^2 d)$ time and $O(n^2)$ memory for $A$. Use blockwise/sparse/linear attention for long sequences.
  • Mixed precision: Use fp16/bf16 for matmuls, but keep softmax reduction and normalization in fp32. Clip logits if necessary to avoid NaNs.
  • Sanity checks: verify A.sum(axis=1) ≈ 1, O.shape == (n, d_v), no NaN/inf after softmax. With masks, ensure masked positions have zero attention.
  • Reference implementation sketch: S = (Q @ K.T) / sqrt(d_k) → S += mask → S -= S.max(axis=1, keepdims=True) → A = exp(S) / exp(S).sum(axis=1, keepdims=True) → O = A @ V.
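The reference sketch in the last item can be written out as a small function. This is a minimal 2-D version under stated assumptions: the function name `attention` and the mask convention (a boolean array where True means "do not attend") are choices made for this illustration, not a fixed API.

```python
import numpy as np

def attention(Q, K, V, mask=None):
    """Scaled dot-product attention for 2-D inputs.

    mask: optional boolean array of shape (n, m); True marks positions
    that must NOT be attended to (a convention assumed for this sketch).
    """
    d_k = Q.shape[-1]
    S = (Q @ K.T) / np.sqrt(d_k)                   # (n, m) scaled scores
    if mask is not None:
        S = np.where(mask, -np.inf, S)             # masked scores -> -inf
    S = S - S.max(axis=1, keepdims=True)           # stable softmax shift
    E = np.exp(S)
    A = E / E.sum(axis=1, keepdims=True)           # rows sum to 1
    return A @ V, A

rng = np.random.default_rng(0)
n, m, d_k, d_v = 3, 5, 4, 6
Q = rng.standard_normal((n, d_k))
K = rng.standard_normal((m, d_k))
V = rng.standard_normal((m, d_v))

pad = np.zeros((n, m), dtype=bool)
pad[:, -2:] = True                                 # pretend the last two keys are padding

O, A = attention(Q, K, V, mask=pad)

# Sanity checks from the list above
assert O.shape == (n, d_v)
assert np.allclose(A.sum(axis=1), 1.0)
assert np.isfinite(O).all()
assert np.allclose(A[:, -2:], 0.0)
```

One caveat of this sketch: a row whose keys are all masked would produce NaN, since the row maximum is then $-\infty$; real implementations need a policy for that case.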
What This Example Demonstrates
  • Scaled dot-product attention formula: $\text{Attention}(Q, K, V) = \text{softmax}(QK^\top / \sqrt{d_k}) V$ computes attention in two matrix products with a row-wise softmax between them. Query-key dot products reveal similarities; softmax normalizes; weighted aggregation produces context.

  • Shape tracking: $Q \in \mathbb{R}^{n \times d_k}$ (queries), $K \in \mathbb{R}^{m \times d_k}$ (keys), $V \in \mathbb{R}^{m \times d_v}$ (values). $QK^\top \in \mathbb{R}^{n \times m}$ (similarities). Output $O \in \mathbb{R}^{n \times d_v}$ (context). Query-key and value dimensions can differ.

  • Scaling factor $\sqrt{d_k}$ prevents saturation: Without scaling, dot products grow with dimension, causing softmax to be nearly one-hot (low entropy). Scaling stabilizes softmax entropy.

  • Softmax as probability normalization: Each row of softmax output sums to 1 (verified by A.sum(axis=1) = [1, 1]). Softmax is differentiable, enabling backpropagation.

  • Attention weights are learned similarity-based routing: Unlike fixed pooling/convolution, attention dynamically learns what each query “attends to.” Weights reflect relevance.

  • Matrix multiplication enables parallelization: All $n \times m$ attention scores computed in one $QK^\top$ product. All outputs computed in one $AV$ product. No sequential loop required; highly parallelizable on GPUs.

  • Output is weighted sum of values: Each row of output $O$ is a context vector: weighted combination of all values, with weights determined by attention. $O_i = \sum_j A_{ij} V_j$.

  • Bilinear form interpretation: Attention is a learned similarity function: score $(q, k) \to q^\top k$. Applied to all query-key pairs, it forms a bilinear ranking.

  • Query, key, value independence: In self-attention, $Q = XW_Q, K = XW_K, V = XW_V$ are learned projections of the same input. This enables the model to extract different aspects (what to query, what to attend to, what information to extract) from the same input.

  • Multi-head attention: Splitting queries/keys/values into multiple “heads” enables attending to different representation subspaces (low-level vs high-level features, different linguistic phenomena). Conceptually: apply attention separately, concatenate.

Numerical Implementation Details

  • Dot product similarity: $s_{ij} = q_i^\top k_j$ (single scalar per pair). Over matrices: $S = QK^\top \in \mathbb{R}^{n \times m}$. Time: $O(n \times m \times d_k)$ (typically $O(n^2 d)$ for self-attention where $m = n$). GPUs compute this efficiently via GEMM.

  • Scaling by $1/\sqrt{d_k}$: For queries/keys of dimension $d_k$ with roughly independent, unit-variance components, the dot product has standard deviation $\approx \sqrt{d_k}$. Dividing by $\sqrt{d_k}$ brings scores back to scale $\approx 1$ (a regime where softmax has good gradient properties). Standard choice in transformers.

  • Softmax computation: $A_{ij} = \frac{\exp(S_{ij} / \sqrt{d_k})}{\sum_k \exp(S_{ik} / \sqrt{d_k})}$ (axis=1, row-wise). Each row is a probability distribution. Numerically stable via log-sum-exp trick: compute $\max_j S_{ij}$ first, subtract before exp to prevent overflow.

  • Weighted aggregation: $O_i = \sum_j A_{ij} V_j$. Over matrices: $O = AV \in \mathbb{R}^{n \times d_v}$. Time: $O(n \times m \times d_v)$ (typically $O(n^2 d)$ for self-attention). Again, efficient GEMM.

  • Total attention cost: Two GEMM operations, $QK^\top$ ($O(n m d_k)$) and $AV$ ($O(n m d_v)$), plus a row-wise softmax ($O(n m)$). For self-attention with $m = n$ and $d_k = d_v = d$: $O(n^2 d)$ (quadratic in sequence length $n$).

  • Complexity limitation: $O(n^2)$ cost is the fundamental bottleneck. For sequences of length $n = 10^6$, $10^{12}$ attention scores are required. Workarounds: local attention (fixed window around each position), sparse attention (only a chosen subset of pairs), approximations (linear attention via random features).

  • Backward pass (for training): Backpropagation through softmax and matrix products computes gradients w.r.t. $Q, K, V$. Via chain rule: $\frac{\partial L}{\partial Q} = \frac{\partial L}{\partial O} \frac{\partial O}{\partial A} \frac{\partial A}{\partial S} \frac{\partial S}{\partial Q}$. All operations differentiable; PyTorch/JAX/TensorFlow compute automatically.

  • Multi-head computational pattern: Instead of one attention head with dimension $d$, use $h$ heads with dimension $d / h$ each. Compute attention separately, concatenate: $\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h) W^O$. Cost: same (distributed across heads). Benefit: different heads learn different attention patterns (syntactic, semantic, etc.). Practice shows empirically that multiple heads capture linguistic structure better than single head.
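The multi-head pattern in the last item can be sketched for a single sequence as follows. Random matrices stand in for the learned projections $W_Q, W_K, W_V, W^O$ (an assumption for illustration), and `split_heads` and `softmax_last` are hypothetical helper names.

```python
import numpy as np

def softmax_last(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
n, d, h = 5, 8, 2
d_h = d // h                                       # per-head dimension

X = rng.standard_normal((n, d))
# Random matrices standing in for learned projections (assumption).
W_Q, W_K, W_V, W_O = (rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(4))

def split_heads(M):
    # (n, d) -> (h, n, d_h): head i takes columns i*d_h .. (i+1)*d_h - 1
    return M.reshape(n, h, d_h).transpose(1, 0, 2)

Qh, Kh, Vh = split_heads(X @ W_Q), split_heads(X @ W_K), split_heads(X @ W_V)

S = Qh @ Kh.transpose(0, 2, 1) / np.sqrt(d_h)      # (h, n, n) scores per head
A = softmax_last(S)                                # rows sum to 1 within each head
heads = A @ Vh                                     # (h, n, d_h)

concat = heads.transpose(1, 0, 2).reshape(n, d)    # concatenate heads -> (n, d)
O = concat @ W_O                                   # output projection -> (n, d)
print(S.shape, heads.shape, O.shape)
```

Note that each head scales by $\sqrt{d_h}$, its own key dimension, not by $\sqrt{d}$.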

Notes

Practical guidance for implementing and debugging scaled dot-product attention. Focus on shape discipline, numerical stability (scaling and log-sum-exp), masking semantics, batching/head dimensions, memory, and quick verification checks.

Part 1: Scaled Dot-Product Attention and Query-Key-Value Mechanism

Attention solves an information routing problem: Given query $q \in \mathbb{R}^{d_k}$ (what are we looking for?), keys $k_1, \ldots, k_m \in \mathbb{R}^{d_k}$ (what can we attend to?), and values $v_1, \ldots, v_m \in \mathbb{R}^{d_v}$ (what information to extract?), compute a context vector as weighted sum of values: \[ c = \sum_{j=1}^{m} a_j v_j, \] where the weights $a_j$ are computed from the similarity between the query and each key: \[ a_j = \frac{\exp(s_j / \sqrt{d_k})}{\sum_{l=1}^{m} \exp(s_l / \sqrt{d_k})}, \quad s_j = q^\top k_j. \]

The scaling factor $\sqrt{d_k}$ is crucial. Without it, the spread of $s_j$ grows with dimension (standard deviation $\approx \sqrt{d_k}$ for unit-variance components), causing softmax to be nearly one-hot (one key captures all attention). Dividing by $\sqrt{d_k}$ keeps softmax entropy reasonable.

Over matrices: with $Q = [q_1, \ldots, q_n]^\top \in \mathbb{R}^{n \times d_k}$, $K = [k_1, \ldots, k_m]^\top \in \mathbb{R}^{m \times d_k}$, $V = [v_1, \ldots, v_m]^\top \in \mathbb{R}^{m \times d_v}$: \[ C = \text{softmax}(QK^\top / \sqrt{d_k}) V \in \mathbb{R}^{n \times d_v}. \]
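The per-query formula and the matrix formula describe the same computation. A small sketch on random data (shapes chosen arbitrarily for illustration) confirms that looping over queries reproduces the matrix result:

```python
import numpy as np

rng = np.random.default_rng(0)
n, m, d_k, d_v = 3, 4, 5, 2
Q = rng.standard_normal((n, d_k))
K = rng.standard_normal((m, d_k))
V = rng.standard_normal((m, d_v))

# Per-query form: c = sum_j a_j v_j
C_loop = np.zeros((n, d_v))
for i, q in enumerate(Q):
    s = K @ q / np.sqrt(d_k)             # s_j = q^T k_j, scaled
    a = np.exp(s - s.max())
    a /= a.sum()                         # softmax over the m keys
    C_loop[i] = a @ V                    # weighted sum of value rows

# Matrix form: C = softmax(Q K^T / sqrt(d_k)) V
S = Q @ K.T / np.sqrt(d_k)
E = np.exp(S - S.max(axis=1, keepdims=True))
C_mat = (E / E.sum(axis=1, keepdims=True)) @ V

print(np.allclose(C_loop, C_mat))
```

The matrix form is what runs in practice; the loop is only useful for checking a new implementation against the definition.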

Part 2: Softmax Normalization and Probability Interpretation

Softmax turns scores into probability distributions: \[ A_{ij} = \frac{\exp(S_{ij})}{\sum_k \exp(S_{ik})}, \quad \text{where } S = QK^\top / \sqrt{d_k}. \]

Properties:

  • Non-negative: $A_{ij} \in [0, 1]$
  • Row-stochastic: $\sum_j A_{ij} = 1$ (each row sums to 1)
  • Differentiable: enables backpropagation
  • Entropy-controlled: softmax entropy is tuned by temperature scaling (dividing by $\sqrt{d_k}$)

Probability interpretation: $A_{ij} = P(\text{attend to key } j | \text{query } i)$. Attention weights form a categorical distribution over keys for each query.

Part 3: Matrix Product Efficiency and GPU Implementation

Attention’s success is fundamentally due to matrix multiplication efficiency on modern hardware:

  1. $QK^\top$ is a GEMM (general matrix-matrix multiply): $n \times m \times d_k$ operations. GPUs are engineered for GEMM; this is the “peak” operation.
  2. Softmax is row-wise (an exponential per entry plus a sum per row): $O(nm)$ operations, negligible compared to GEMM.
  3. $AV$ is another GEMM: $n \times m \times d_v$ operations.
  4. Total: two GEMMs plus one softmax. Full parallelization; no sequential loop required.

Contrast with RNNs: RNNs process sequentially (position by position), inherently serial. Even with GPU acceleration, they can’t parallelize across sequence positions. Transformers compute all query-key similarities at once, then all value aggregations at once—fully parallelizable.

Scaling to large $n$: For $n = 1024$ (typical): $QK^\top$ is $1024 \times 1024 \times 64 \approx 67\text{M}$ flops (< 1ms on modern GPU). For $n = 100\text{k}$ (long context): $\approx 640\text{B}$ flops ($\approx 1$s on GPU). This is why “long context” transformers are challenging—$O(n^2)$ cost.
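The operation and memory counts quoted here are simple products, so they are easy to reproduce. A small sketch with hypothetical helper names, counting multiply-adds for the score product and bytes for one dense fp32 attention matrix:

```python
def qk_multiply_adds(n, m, d_k):
    """Multiply-add count for the (n x d_k) @ (d_k x m) score product."""
    return n * m * d_k

def attention_matrix_bytes(n, m, bytes_per_element=4):
    """Bytes needed to store one dense n x m attention matrix (fp32 by default)."""
    return n * m * bytes_per_element

for n in (1024, 100_000, 1_000_000):
    ops = qk_multiply_adds(n, n, 64)
    mem = attention_matrix_bytes(n, n)
    print(f"n={n:>9,}  multiply-adds={ops:.2e}  attention matrix={mem / 1e9:,.3f} GB")
```

Both quantities grow with $n^2$: going from 1,024 to 100,000 tokens multiplies them by roughly 10,000.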

Why This Matters for ML

  • Transformer scalability: Attention’s parallelizability enabled training billion-parameter models (BERT, GPT). RNNs could not scale this way.
  • Interpretability: Attention weights are human-readable probabilities. You can visualize which parts of the input the model attends to. This transparency is invaluable for debugging and understanding model behavior.
  • Transfer learning: Transformers pretrained on large corpora (BERT, GPT-2/3) transfer well to downstream tasks. The learned attention patterns generalize across domains.
  • Multimodal learning: Attention naturally handles variable-length inputs (tokens, pixels, audio frames). Cross-attention between modalities (text → image) is straightforward.
  • Efficiency via approximation: When full attention is too expensive (long sequences, resource-constrained devices), sparse/approximate attention variants (Linformer, Performer, Local Attention) approximate $O(n^2)$ with $O(n \log n)$ or $O(n)$.
  • Position bias and relative attention: Attention alone doesn’t encode position. Adding positional encodings (absolute or relative) enables the model to understand sequence structure.

ML Examples and Patterns

  1. Self-attention (transformer encoder): $Q$, $K$, and $V$ are all projections of the same input $X$ ($Q = XW_Q$, $K = XW_K$, $V = XW_V$). Each position attends to all positions (or, with a causal mask, to itself + past). Used in BERT, RoBERTa, T5 encoder.
  2. Cross-attention (encoder-decoder): $Q$ from decoder, $K, V$ from encoder. Decoder learns to retrieve relevant encoder information. Used in machine translation (encoder-decoder transformers), image captioning, visual question answering.
  3. Causal attention (transformer decoder): Self-attention with causal mask: position $i$ attends only to positions $\leq i$. Prevents future information leak during generation. Used in GPT, autoregressive language models.
  4. Multi-head attention: Compute attention in $h$ subspaces (heads) with reduced dimension $d/h$. Concatenate outputs. Benefits: captures different linguistic phenomena (syntax, semantics, etc.). All modern transformers use multi-head.
  5. Sparse attention: Instead of full $O(n^2)$ attention, compute only subset of pairs (local window, learnable patterns, etc.). Reduces complexity to $O(n \log n)$ or $O(n)$. Essential for long-context transformers.
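The difference between self-attention and cross-attention is only where the queries, keys, and values come from. The sketch below makes that concrete with random stand-ins for the projection matrices; for brevity it reuses one set of matrices for both cases, whereas a real model would typically learn separate weights for each attention layer.

```python
import numpy as np

def attention(Q, K, V):
    S = Q @ K.T / np.sqrt(Q.shape[-1])
    E = np.exp(S - S.max(axis=1, keepdims=True))
    return (E / E.sum(axis=1, keepdims=True)) @ V

rng = np.random.default_rng(0)
d_model, d_k, d_v = 8, 4, 4
# Random stand-ins for learned projections (assumption).
W_Q = rng.standard_normal((d_model, d_k))
W_K = rng.standard_normal((d_model, d_k))
W_V = rng.standard_normal((d_model, d_v))

X_enc = rng.standard_normal((6, d_model))   # 6 encoder positions
X_dec = rng.standard_normal((3, d_model))   # 3 decoder positions

# Self-attention: queries, keys and values all come from the same sequence.
O_self = attention(X_enc @ W_Q, X_enc @ W_K, X_enc @ W_V)     # (6, d_v)

# Cross-attention: queries from the decoder, keys/values from the encoder.
O_cross = attention(X_dec @ W_Q, X_enc @ W_K, X_enc @ W_V)    # (3, d_v)

print(O_self.shape, O_cross.shape)
```

The output always has one row per query, so cross-attention returns 3 rows here even though it reads from 6 encoder positions.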

Connection to Linear Algebra Theory

  • Bilinear form: Attention score $(q, k) \mapsto q^\top k$ is a bilinear form. Over matrices, $QK^\top$ is the bilinear form applied to all pairs.
  • Softmax as log-sum-exp: $\log \text{softmax}(x)_i = x_i - \log \sum_j \exp(x_j)$, where the second term is the log-partition function. Dividing scores by $\sqrt{d_k}$ before the softmax is temperature scaling with temperature $\sqrt{d_k}$.
  • Eigenvalue analysis of attention: Spectral properties of attention weight matrix $A$ determine convergence of attention patterns across layers. Recent work (e.g., studying attention bottleneck, gradient flow) uses eigenvalue analysis.
  • Rank and expressiveness: Rank of attention matrix $A$ determines how many “effective” values are used. Full-rank $A$ means all values contribute; rank-1 $A$ means single output for all queries.
  • Singular value structure: SVD of $QK^\top$ reveals principal attention directions (top singular vectors are most-attended “keys”). Low-rank approximation via truncated SVD reduces attention cost.
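Two of the statements above are easy to check numerically: the log-softmax identity, and the claim that a rank-1 attention matrix gives every query the same output. A small sketch with illustrative helper names:

```python
import numpy as np

def logsumexp(x):
    m = x.max()
    return m + np.log(np.exp(x - m).sum())   # stable log of the partition function

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

x = np.array([2.0, 1.0, -1.0])

lhs = np.log(softmax(x))
rhs = x - logsumexp(x)
print(np.allclose(lhs, rhs))

# A rank-1 attention matrix (identical rows) gives every query the same output.
A = np.tile(softmax(x), (4, 1))              # 4 queries sharing one weight vector
V = np.arange(6.0).reshape(3, 2)             # 3 values, 2-D
O = A @ V
print(np.linalg.matrix_rank(A), np.allclose(O, O[0]))
```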

Numerical and Implementation Notes

  • Log-sum-exp trick: Compute softmax as $a_i = \exp(s_i - \max(s)) / \sum_j \exp(s_j - \max(s))$ to prevent overflow (exp of large numbers overflows). Most frameworks (PyTorch, TensorFlow) do this automatically.
  • Causal masking: For autoregressive attention, before softmax, set future positions to $-\infty$ (or a very large negative number): $S_{ij} = -\infty$ for $j > i$. After softmax, these weights are $\exp(-\infty) = 0$.
  • Attention dropout: After softmax, zero out random attention weights and rescale the survivors by $1/(1-p)$ (rows are not renormalized to sum to 1). Regularization technique; prevents overfitting to particular attention patterns.
  • Head dimension choice: Most transformers use $d / h$ per head (e.g., 768 dims / 12 heads = 64 dims per head). Smaller heads are more efficient; larger heads capture complex patterns. Choice often empirical.
  • Memory complexity: For self-attention with $n$ tokens and dimension $d$: storing full $n \times n$ attention matrix requires $O(n^2)$ memory. For $n = 10^6$, this is terabytes! Solution: compute attention in blocks (local attention) or use approximations.
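A sketch of attention dropout using the common inverted-dropout rescaling by $1/(1-p)$, which preserves each weight in expectation. The rate `dropout_p` and the uniform toy matrix are assumptions for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
dropout_p = 0.1                                   # assumed example dropout rate

A = np.full((4, 5), 0.2)                          # toy attention weights; each row sums to 1

keep = rng.random(A.shape) >= dropout_p           # True with probability 1 - p
A_drop = A * keep / (1.0 - dropout_p)             # inverted-dropout rescaling

print("row sums after dropout:", A_drop.sum(axis=1).round(3))

# Averaged over many random masks, the rescaled weights match the originals.
trials = rng.random((20_000,) + A.shape) >= dropout_p
A_mean = (A * trials / (1.0 - dropout_p)).mean(axis=0)
print("max deviation of mean from A:", np.abs(A_mean - A).max().round(4))
```

After dropout an individual row generally no longer sums to exactly 1; only the expectation is preserved. Dropout is applied during training only.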

Numerical and Shape Notes

  • Shapes: $Q \in \mathbb{R}^{n \times d_k}$, $K \in \mathbb{R}^{m \times d_k}$, $V \in \mathbb{R}^{m \times d_v}$. Output $O \in \mathbb{R}^{n \times d_v}$.
  • Shape compatibility: Query and key dimensions must match ($d_k$). Value dimension ($d_v$) can be anything (determines output dim). Number of keys/values must match ($m$).
  • Scaling factor: Divide by $\sqrt{d_k}$ (key dimension), not $d_k$ or other. This is standard and has theoretical justification (controls softmax entropy).
  • Broadcasting in attention: All-to-all attention (full $n \times m$ matrix). Each query attends to all keys. For self-attention ($m = n$), this is $n^2$ attention weights.
  • Batch and head dimensions: In practice, transformers use batch attention: $(batch, n, m)$ or $(batch, heads, n, m)$ for multi-head. Shape discipline is critical to avoid errors.
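The batched shapes can be traced with NumPy's broadcasting matrix product. This sketch assumes the layout (batch, heads, positions, head dimension) and a padding mask of shape (B, m) in which True marks a padded key; both are conventions chosen for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
B, h, n, m, d_h = 2, 3, 4, 6, 5

Q = rng.standard_normal((B, h, n, d_h))
K = rng.standard_normal((B, h, m, d_h))
V = rng.standard_normal((B, h, m, d_h))

# Swap only the last two axes of K; batch and head axes broadcast in matmul.
S = Q @ np.swapaxes(K, -1, -2) / np.sqrt(d_h)       # (B, h, n, m)

# Padding mask per batch element; True = padded key (assumed convention).
pad = np.zeros((B, m), dtype=bool)
pad[1, -2:] = True                                  # second sequence has 2 padded keys
S = np.where(pad[:, None, None, :], -np.inf, S)     # broadcasts to (B, h, n, m)

S = S - S.max(axis=-1, keepdims=True)               # stable softmax shift
A = np.exp(S)
A /= A.sum(axis=-1, keepdims=True)                  # softmax over keys (last axis)
O = A @ V                                           # (B, h, n, d_h)

print(S.shape, A.shape, O.shape)
```

A frequent bug is transposing K with a plain `.T`, which reverses all four axes instead of swapping the last two.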

Pedagogical Significance

Attention is the gateway to understanding modern deep learning. Once you grasp scaled dot-product attention (similarities via dot products, normalization via softmax, aggregation via weighted sum), you understand the core of transformers, which have revolutionized NLP, vision, and beyond. This example teaches layered thinking: top-level (what is attention doing?), middle-level (how does softmax work?), implementation-level (matrix products on GPUs). It shows that simplicity is powerful: two matrix products and one softmax create a mechanism that has displaced RNNs and CNNs in many domains. Finally, it teaches shape discipline: tracking dimensions through computation is non-negotiable. Get shapes wrong, and the code crashes (or worse, runs but gives nonsense). Master this, and you can reason about any neural network architecture.

History and Applications

2014–2015: The Attention Mechanism Emerges

Machine translation with RNNs dominated 2012–2014, but faced the bottleneck problem: source sentences were compressed into fixed-size context vectors, and long sentences caused information loss. Sutskever et al. (2014, “Sequence to Sequence Learning with Neural Networks”) demonstrated that neural machine translation worked reasonably well, but blew up on long sentences. Bahdanau et al. (2015, “Neural Machine Translation by Jointly Learning to Align and Translate”) introduced the breakthrough idea: instead of compressing the entire source into one vector, the decoder should learn to attend to different parts of the source at each decoding step. Their attention mechanism computed attention weights via a small learned network (not yet scaled dot-product; that came later). The result was transformative: suddenly, long sentences were tractable. The paper had immediate impact—within months, attention was adopted in image captioning (Xu et al. 2015), machine translation, speech recognition, and beyond.

2016–2017: Scaling Attention; Birth of Transformers

Despite attention’s success, attention-augmented RNNs were still sequentially processed (position by position). Vaswani et al. (2016) experimented with attention-only models, gradually removing recurrence. In June 2017, “Attention Is All You Need” appeared as a preprint; it was presented at NeurIPS (then NIPS) in December 2017 and changed everything. Key contributions:

  • Scaled dot-product attention: Replaced the learned attention mechanism with a simpler, more efficient formula: $\text{Attention}(Q, K, V) = \text{softmax}(QK^\top / \sqrt{d_k}) V$. This is not just efficient; it’s fully parallelizable.
  • Multi-head attention: Split queries/keys/values into $h$ subspaces, compute attention separately, concatenate. Empirically, this captured different linguistic phenomena in different heads.
  • Stacked layers: The “Transformer” is many layers of multi-head attention + feed-forward networks. No recurrence, no convolution—pure matrix algebra.
  • Positional encodings: Since attention has no inherent notion of sequence position, they added sinusoidal positional encodings (later, learned embeddings) to tell the model where in the sequence each token is.

The parallelization was revolutionary. While RNNs process sequences position-by-position (inherently serial, gradient flow issues), transformers compute all query-key similarities at once, all outputs at once. GPUs are engineered for matrix multiplication (GEMM); transformers are one big GEMM sandwich. This enabled rapid scaling: BERT (2018, 110M–340M parameters), GPT-2 (2019, 1.5B parameters), GPT-3 (2020, 175B parameters). By comparison, the largest RNNs topped out around 1B parameters by 2017.

2018–2020: Scaling and Pretraining

BERT (Devlin et al. 2018): Bidirectional encoder-only transformers, pretrained on masked language modeling. 340M parameters, demonstrated that pretraining on unlabeled data then fine-tuning on downstream tasks was the paradigm. Within months, BERT variants appeared (RoBERTa, ALBERT, DistilBERT, etc.). NLP was transformed: BERT achieved state-of-the-art on nearly every benchmark.

GPT (Radford et al. 2018): Autoregressive (left-to-right) language model using decoder-style transformer with causal masking. Demonstrated that language model pretraining at scale (117M parameters) learned useful representations. Subsequent scaling showed that GPT-2 (1.5B params) generated coherent long-form text without explicit task supervision—just language modeling.

Encoder-decoder transformers: Machine translation (Vaswani et al. 2017), summarization, question answering—any sequence-to-sequence task now used transformers. Cross-attention (decoder attends to encoder) enabled the decoder to dynamically fetch relevant encoder information.
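Causal masking and cross-attention differ from plain self-attention only in which scores are allowed and where the keys and values come from. The sketch below is a hedged illustration: the arrays `enc` and `dec` are hypothetical stand-ins for encoder outputs and decoder states, and the learned query/key/value projections are left out for brevity.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q, K, V, mask=None):
    # mask: boolean (n_q, n_k); True marks query-key pairs that may attend.
    scores = Q @ K.T / np.sqrt(Q.shape[-1])
    if mask is not None:
        scores = np.where(mask, scores, -np.inf)  # masked pairs get weight 0
    A = softmax(scores, axis=-1)
    return A @ V, A

rng = np.random.default_rng(0)
d = 4
enc = rng.normal(size=(6, d))  # hypothetical encoder outputs: 6 source tokens
dec = rng.normal(size=(3, d))  # hypothetical decoder states: 3 target tokens

# Causal self-attention: position i may attend only to positions j <= i.
causal = np.tril(np.ones((3, 3), dtype=bool))
_, A_self = attention(dec, dec, dec, mask=causal)

# Cross-attention: queries from the decoder, keys and values from the encoder.
out_cross, A_cross = attention(dec, enc, enc)
```

Note the shapes: `A_self` is square (target by target) with zeros above the diagonal, while `A_cross` is rectangular (target by source).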

T5 (Raffel et al. 2019): Encoder-decoder transformers unified all NLP tasks into a single “text-to-text” framework. 11B parameters. Demonstrated that one model could be pretrained and then fine-tuned (or prompted) on translation, summarization, QA, etc.

2020–2021: Vision and Multimodal

Vision Transformers (ViT; Dosovitskiy et al. 2020): Applied transformers to images by treating images as sequences of patches. Instead of fixed convolution receptive fields, vision transformers could learn to attend to any patch. ViT-Large (about 300M params) matched or exceeded state-of-the-art CNNs (ResNet-based) on ImageNet when pretrained on large datasets, while needing less compute to pretrain. This opened the floodgates: transformers are now a standard choice for vision.

CLIP (Radford et al. 2021): Contrastive learning with image and text encoders (separate transformers), aligned via dot-product similarity. CLIP learned rich image representations from 400M image-text pairs. Most importantly: CLIP was zero-shot—you could query new categories (e.g., “a photo of a tabby cat”) without fine-tuning. CLIP’s success sparked multimodal learning: combining text, vision, and eventually audio/video in a single model.

GPT-3 (Brown et al. 2020): 175B-parameter autoregressive language model. At this scale, GPT-3 exhibited in-context learning: you could prompt it with a few examples of a task (e.g., “Q: 2+3 A: 5”) and it would perform the task without gradient updates. This “few-shot” capability suggested that scale alone could unlock general reasoning.

2021–2023: Specialization and Long Context

Efficient attention variants: Full $O(n^2)$ attention is prohibitive for long sequences. Sparse, local, and linear attention variants emerged:
  • BigBird (Zaheer et al. 2020): Sparse attention pattern (local window + global tokens + random connections) reduces complexity to $O(n)$, enabling processing of 4096-token documents (RoBERTa max: 512).
  • Linformer (Wang et al. 2020): Approximates $O(n^2)$ attention in $O(n)$ by projecting keys and values to a fixed lower dimension along the sequence axis (a low-rank approximation).
  • Performer (Choromanski et al. 2020): Random features approximate softmax attention in $O(n)$ time.
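The linear-time idea behind random-feature attention can be sketched as follows. This is a simplified illustration in the spirit of Performer, not the published algorithm: it uses plain Gaussian features (no orthogonalization), and the names `positive_random_features`, `linear_attention`, and `quadratic_reference` are hypothetical. The point it shows is that once scores factor as $\Phi(Q)\Phi(K)^\top$, associativity lets us compute $\Phi(Q)\,(\Phi(K)^\top V)$ without ever forming the $n \times n$ matrix.

```python
import numpy as np

def positive_random_features(X, W):
    # phi(x) = exp(W x - |x|^2 / 2) / sqrt(m), so E[phi(q) . phi(k)] = exp(q . k)
    # when the rows of W are drawn from N(0, I).
    m = W.shape[0]
    sq_norm = 0.5 * (X ** 2).sum(axis=-1, keepdims=True)
    return np.exp(X @ W.T - sq_norm) / np.sqrt(m)

def linear_attention(Q, K, V, W):
    d_k = Q.shape[-1]
    # Fold the 1/sqrt(d_k) scaling into queries and keys.
    phi_q = positive_random_features(Q / d_k ** 0.25, W)  # (n, m)
    phi_k = positive_random_features(K / d_k ** 0.25, W)  # (n, m)
    kv = phi_k.T @ V             # (m, d_v): cost O(n m d_v)
    z = phi_k.sum(axis=0)        # (m,): used for row normalization
    return (phi_q @ kv) / (phi_q @ z)[:, None]

def quadratic_reference(Q, K, V, W):
    # Same feature map, but forming the n x n matrix explicitly.
    d_k = Q.shape[-1]
    phi_q = positive_random_features(Q / d_k ** 0.25, W)
    phi_k = positive_random_features(K / d_k ** 0.25, W)
    A = phi_q @ phi_k.T
    A = A / A.sum(axis=1, keepdims=True)
    return A @ V

rng = np.random.default_rng(0)
n, d, m = 64, 8, 256
Q, K, V = (0.5 * rng.normal(size=(n, d)) for _ in range(3))
W = rng.normal(size=(m, d))      # random projection directions
out_linear = linear_attention(Q, K, V, W)
out_quadratic = quadratic_reference(Q, K, V, W)
```

The two outputs agree up to floating-point error because they are the same product evaluated in a different order; how closely either matches exact softmax attention depends on the number of features `m` and the norms of the queries and keys.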

Domain-specific transformers: ELECTRA (discriminative pretraining for text), DeiT (data-efficient vision), DALL-E (image generation with transformers), Perceiver (unifying modality with transformers).

Instruction-tuning and RLHF: InstructGPT (Ouyang et al. 2022) showed that fine-tuning language models with human preferences (via reinforcement learning from human feedback, RLHF) produces models that follow instructions better than pure language model pretraining. This was the precursor to ChatGPT.

2023–2025: Scaling to Multimodal and Long Context

GPT-4 (OpenAI, 2023): Multimodal (text + images), dramatically improved reasoning, longer context windows (8K and 32K tokens at launch; up to 128K tokens with GPT-4 Turbo later in 2023). Suggested that scale + engineering continues to unlock new capabilities.

Gemini (Google DeepMind, 2023): Unified multimodal model (text, images, video, audio). Trained from scratch on multimodal data, showing that a single transformer architecture scales across modalities.

LLaMA and open-source revolution (Meta, 2023): LLaMA shipped at 7B, 13B, 33B, and 65B parameters; Llama 2 at 7B, 13B, and 70B. Openly available weights enabled rapid experimentation. Variants like Alpaca (instruction-tuned LLaMA) demonstrated that you don’t need 175B parameters to get strong instruction-following behavior: a carefully fine-tuned 7B model can be competitive with much larger models that were not tuned for the task.

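Context window expansion: Newer models (Claude, GPT-4 Turbo) support 100K+ token context windows. The exact methods are not always published; approaches explored in the literature include:
  • Sparse attention patterns (local + recent + sparse)
  • Position interpolation (rescaling position indices so that rotary or relative position embeddings stay within the range seen in training)
  • Newer architectures (Mamba, RetNet) that avoid $O(n^2)$ entirely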

Mixture-of-Experts (MoE): Instead of dense transformers, MoE splits feed-forward networks into experts (e.g., 8 or 16 parallel FFNs). Each token routes to a subset of experts (e.g., top-2). Because only the selected experts run for a given token (sparse activation), total parameter count can grow to 300B–500B while per-token compute stays roughly fixed.
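Top-k routing can be sketched in NumPy as below. This is a minimal illustration under assumptions: `moe_layer`, the gating matrix `W_gate`, the two-layer ReLU experts, and the toy sizes are all hypothetical, and practical details such as load-balancing losses and expert capacity limits are omitted.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def moe_layer(X, W_gate, experts, k=2):
    # X: (n, d); W_gate: (d, E); experts: list of E (W1, W2) weight pairs.
    logits = X @ W_gate                                # (n, E) routing scores
    top = np.argsort(logits, axis=-1)[:, -k:]          # k highest-scoring experts per token
    top_logits = np.take_along_axis(logits, top, axis=-1)
    gates = softmax(top_logits, axis=-1)               # renormalize over the chosen experts
    out = np.zeros_like(X)
    for e, (W1, W2) in enumerate(experts):
        token_idx, slot = np.nonzero(top == e)         # tokens routed to expert e
        if token_idx.size == 0:
            continue
        hidden = np.maximum(X[token_idx] @ W1, 0.0)    # ReLU feed-forward, only on routed tokens
        out[token_idx] += gates[token_idx, slot][:, None] * (hidden @ W2)
    return out, top

rng = np.random.default_rng(0)
n, d, d_ff, E = 10, 8, 16, 4
X = rng.normal(size=(n, d))
W_gate = rng.normal(size=(d, E))
experts = [(rng.normal(size=(d, d_ff)), rng.normal(size=(d_ff, d))) for _ in range(E)]
out, top = moe_layer(X, W_gate, experts, k=2)
```

Each token touches only `k` of the `E` expert networks, so adding experts increases parameters without increasing the work done per token.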

Real-world Applications of Scaled Dot-Product Attention

  1. Machine Translation: Most modern machine translation systems (Google Translate, DeepL, Amazon Translate) use attention-based transformers. Attention weights provide some interpretability: you can visualize which source words the decoder attended to when generating each target word.

  2. Large Language Models: GPT-3, GPT-4, Claude, LLaMA, and Gemini all use scaled dot-product attention as the core. Hundreds of millions of users interact with attention daily without knowing it, through products such as ChatGPT, Copilot, Bard, and Claude.

  3. Image Recognition and Classification: Vision Transformers (ViT) are now competitive with or superior to CNNs for ImageNet, COCO, and other benchmarks. Deployed in mobile (efficient attention variants) and cloud services.

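  4. Image Generation: The original DALL-E is an autoregressive transformer over image tokens, and Stable Diffusion is a diffusion model whose denoising network uses self-attention and text-conditioned cross-attention layers. Attention enables these models to generate coherent multi-object images with spatial relationships.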

  5. Video Understanding: Divided attention (spatial + temporal) extends transformers to video. Applications: video classification, action recognition, video captioning.

  6. Speech Recognition and Synthesis: Attention replaces RNNs in speech (Conformer), improves robustness, enables streaming (partial attention).

  7. Code Generation: GitHub Copilot, CodePilot, Replit Ghostwriter use transformer language models. Attention’s parallelism lets the whole prompt be processed in one pass, although output tokens are still generated one at a time.

  8. Multimodal Interfaces: CLIP powers image search, DALL-E 2 conditions on CLIP embeddings for text-image alignment, and GPT-4V combines text and vision, which is what lets ChatGPT process images.

  9. Recommendation Systems: Attention is used to weight user history (context vectors from past items) when predicting next item. YouTube, Netflix, Spotify all use attention-based recommendation (internally, often called “transformer” models).

  10. Graph Neural Networks: Graph Transformers apply attention to graph nodes. Attention weight between nodes reflects learned edge importance. Applications: molecular property prediction, knowledge graphs, social networks.

  11. Document Understanding and Retrieval: Transformers chunk long documents and use sparse attention to process them. Retrieval systems (Elasticsearch, vector DBs) index document embeddings from transformers.

  12. Time Series Forecasting: Temporal attention learns to weight past timesteps differently depending on context. Temporal transformers forecast electricity demand, weather, stock prices, etc.

Challenges and Open Questions

  1. Quadratic complexity: $O(n^2)$ cost limits sequence length. Workarounds (sparse attention, linear approximations) trade off expressiveness for efficiency. No clear winner yet.

  2. Interpretability: While attention weights are visualizable, they don’t always align with human interpretability (attention heads often encode non-obvious patterns). “Why did the model attend to that word?” remains hard to answer rigorously.

  3. Computational cost: Training GPT-3 (175B params) cost millions of dollars in compute. Inference latency is a bottleneck for real-time applications: tokens are generated one at a time, and each new token attends over the entire context even when keys and values are cached. MoE and efficient attention help, but large models remain fundamentally expensive.

  4. Data efficiency: Transformers require massive datasets to reach strong performance. Small-data regimes (few examples) still favor inductive biases (CNNs, domain-specific architectures). Transfer learning (pretraining) mitigates this, but still data-hungry.

  5. Context length: Despite progress on long context (128K tokens), most models are still limited to far shorter windows. Human-scale context (entire books) is not yet practical.

  6. Alignment and safety: As models scale, ensuring they’re aligned with human values (via RLHF, constitutional AI, etc.) becomes harder. Attention mechanisms are not inherently safer; techniques like attention constraints are being explored.

Why Attention Matters for Linear Algebra Education

Transformers are proof that matrix algebra at scale solves real problems. The core building blocks—dot products, softmax, matrix multiplication, rank-constrained approximations—are all fundamental linear algebra. Yet combined cleverly (scaled dot-product attention, multi-head, layer norm, positional encoding), they create models that understand language, generate images, and compete with humans on reasoning tasks. This shows students that linear algebra is not abstract; it’s the foundation of modern AI.

Key Insights for Practitioners

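  1. Scaling matters: Dot-product attention predates the transformer; the $1/\sqrt{d_k}$ factor is what makes it train well at large $d_k$. For queries and keys with roughly unit-variance entries, the raw dot product has variance about $d_k$, so unscaled scores push softmax into saturated regions with tiny gradients. Often, the difference between a paper idea and a real system is careful engineering: normalization, initialization, scaling.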

  2. GPU efficiency: Transformers succeeded partly because of hardware alignment. Matrix multiplication is what GPUs do best. Future architectures (State Space Models like Mamba, Retentive Networks like RetNet) achieve better complexity but must compete with highly optimized GEMM kernels.

  3. Pretraining + Fine-tuning: The dominant paradigm. Pretrain on massive unlabeled data, fine-tune on task. This unlocks transfer learning and few-shot learning (GPT-3 style). In-context learning (prompting) is becoming the norm.

  4. Multimodal is the future: Vision + Language + Audio unified in one architecture via attention. Cross-attention between modalities is straightforward. This trend will continue.

  5. Open-source democratization: Openly available models such as LLaMA, Llama 2, and Mistral enable researchers and practitioners to build without multi-million-dollar pretraining budgets. Fine-tuning is cheap; pretraining is expensive.
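The scaling point in item 1 above can be checked numerically. The sketch below assumes queries and keys with independent standard normal entries (an idealization, not a property of trained models) and uses a hypothetical helper `score_std`; under that assumption the raw dot product has standard deviation near $\sqrt{d_k}$, and dividing by $\sqrt{d_k}$ brings it back to about 1.

```python
import numpy as np

rng = np.random.default_rng(0)

def score_std(d_k, n=2000, scaled=True):
    # q and k have independent standard normal entries, so Var(q . k) = d_k.
    q = rng.normal(size=(n, d_k))
    k = rng.normal(size=(n, d_k))
    s = (q * k).sum(axis=1)          # n independent dot products
    if scaled:
        s = s / np.sqrt(d_k)
    return s.std()

for d_k in (16, 64, 256):
    raw = score_std(d_k, scaled=False)
    scaled = score_std(d_k, scaled=True)
    print(f"d_k={d_k:4d}  raw std={raw:6.2f}  scaled std={scaled:5.2f}")
```

Without the scaling, score magnitudes grow with the head dimension, and softmax over large-magnitude scores concentrates almost all weight on one key.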

Connection to Broader Mathematics

  • Bilinear forms: Attention is fundamentally a learned bilinear form on queries and keys.
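  • Probability theory: Softmax maps a vector of logits (natural parameters) to the probabilities of a categorical distribution, which is an exponential family; each row of $A$ is therefore a probability distribution over keys.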
  • Information theory: Attention entropy (controlled by temperature scaling) relates to information flow in the network.
  • Spectral analysis: Eigenvalues of attention weight matrices determine convergence across layers; recent work analyzes attention rank and its effect on model capacity.
  • Approximation theory: Low-rank approximations (truncated SVD, random features) explain why sparse/linear attention variants work.
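The information-theory item above can be made concrete by computing the entropy of a single attention row at different temperatures. This is an illustrative sketch: the score vector and temperature values are arbitrary, and `entropy` is a hypothetical helper. Dividing scores by $\sqrt{d_k}$ plays the same role as the temperature `t` here.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def entropy(p):
    # Shannon entropy in nats; entries with p == 0 contribute 0.
    safe = np.where(p > 0, p, 1.0)
    return -(p * np.log(safe)).sum(axis=-1)

scores = np.array([2.0, 1.0, 0.5, -1.0])   # arbitrary example scores for one query
temperatures = [0.1, 1.0, 10.0]
entropies = [float(entropy(softmax(scores / t))) for t in temperatures]

for t, h in zip(temperatures, entropies):
    print(f"temperature={t:5.1f}  entropy={h:.4f} nats")
```

Low temperature gives a nearly one-hot row (entropy near 0); high temperature gives a nearly uniform row (entropy near $\log n$, here $\log 4$).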

The success of transformers validates that linear algebra is not just foundational; it’s sufficient. With careful design, scaling, and engineering, matrix algebra solves billion-parameter problems in language, vision, and beyond. Understanding attention deeply—from the simplicity of scaled dot-product to the complexity of inference optimization—gives practitioners intuition for the next generation of models.

Connection to Broader Examples
  • Least squares (Ex 92): Each attention output row solves a weighted least-squares problem: $o_i = \arg\min_{o} \sum_j A_{ij} \|o - v_j\|^2$, whose minimizer is the weighted mean $\sum_j A_{ij} v_j$ because each row of $A$ is nonnegative and sums to 1 (row-stochastic, probabilistic weights). Scaled dot-product attention is a learned similarity-based method for choosing those weights.
  • Conditioning (Ex 94): Attention weights depend on scores $QK^\top / \sqrt{d_k}$. If keys are nearly collinear (ill-conditioned), different queries produce very similar attention patterns (redundant). Well-conditioned keys (diverse) lead to varied, informative attention.
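  • SVD and low-rank (Ex 90): Attention can be approximated via low-rank projections (Linformer) or random features (Performer). The random-feature view approximates the unnormalized kernel as $\exp(QK^\top / \sqrt{d_k}) \approx \Phi(Q) \Phi(K)^\top$ with $r$ features (a rank-$r$ approximation) and applies the row normalization afterwards. Computing $\Phi(Q)\,(\Phi(K)^\top V)$ reduces complexity from $O(n^2 d)$ to $O(n r d)$.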
  • PCA (Ex 91): Principal components of keys ($K$) determine what the model can attend to. If keys have low effective rank (few principal directions), attention patterns are limited. Diverse key representations enable fine-grained attention.
  • Eigenvalues and power iteration (Ex 88): Attention dynamics (how attention patterns evolve across layers) can be studied via eigenvalue analysis of attention matrices. Spectral properties of $A$ reveal convergence and stability.
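  • Cholesky and SPD systems (Ex 93): In some attention variants (softmax over covariances), the Gram matrix of keys $K^\top K$ (proportional to the key covariance when keys are centered) is symmetric positive semidefinite, and SPD when $K$ has full column rank. In the SPD case, Cholesky factorization can be used for efficient sampling or approximation.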
  • Sparse matrices (Ex 95): Sparse attention (only compute/store $O(n \log n)$ or $O(n)$ attention weights instead of full $O(n^2)$) exploits sparsity to reduce memory/compute. Graph structure (local neighborhoods) determines sparsity pattern.
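  • Orthogonality and projections (Ex 86): Each attention output is a convex combination of value rows, so it lies in the convex hull (and hence the row space) of $V$. In the limit of very sharp softmax this becomes hard selection of a single value; for distributed weights it is a soft average. Attention is therefore a soft, learned analogue of projection rather than an orthogonal projection in the strict sense.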
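  • Rank and nullspace (Ex 87): Numerical rank of attention matrix $A$ bounds how many linearly independent output rows $AV$ can contain. A full-rank $A$ can produce a distinct mixture for every query; a low-rank $A$ means many queries receive nearly the same mixture. For example, uniform attention has rank 1 and every output equals the mean of the values.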
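The rank and projection items above can be illustrated with two extreme attention matrices. The setup is deliberately artificial and is an assumption of this sketch: keys and queries are the standard basis vectors, so each query matches exactly one key, and the factor 50 is an arbitrary low-temperature choice.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
n = 4
Q = np.eye(n)                        # query i matches key i exactly
K = np.eye(n)
V = rng.normal(size=(n, 3))          # arbitrary values

# Uniform attention: all scores equal, so every row of A is 1/n.
A_uniform = softmax(np.zeros((n, n)))
rank_uniform = np.linalg.matrix_rank(A_uniform)
out_uniform = A_uniform @ V          # every row equals the mean of V's rows

# Very sharp attention: large score gaps make each row nearly one-hot.
A_sharp = softmax(50.0 * (Q @ K.T))
rank_sharp = np.linalg.matrix_rank(A_sharp)
out_sharp = A_sharp @ V              # each row selects (almost exactly) one value

print("rank of uniform A:", rank_uniform)
print("rank of sharp A:  ", rank_sharp)
```

Uniform attention collapses every query to the same average, while sharp attention acts as a lookup that returns one value per query; real attention heads sit between these two extremes.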

Comments