Part 1: Query-Key Scoring

- $Q \in \mathbb{R}^{2 \times 2}$: 2 queries, each 2-dimensional.
- $K \in \mathbb{R}^{3 \times 2}$: 3 keys, matching the query dimension so that dot products are defined.
- $\text{scores} = QK^\top / \sqrt{2} \in \mathbb{R}^{2 \times 3}$: entry $(i,j)$ is $\langle q_i, k_j \rangle / \sqrt{2}$.
- Transpose alignment: $QK^\top$ computes all pairwise query-key similarities; $QK$ would fail (shape mismatch).
- Scaling: Without $1/\sqrt{d_k}$, scores grow with dimension, pushing softmax into saturation (near-zero gradients).
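
The scoring step can be sketched in NumPy as follows. The numeric entries of `Q` and `K` are made-up illustration values (the notes fix only the shapes), so treat this as a minimal sketch rather than the worked example itself.

```python
import numpy as np

# Hypothetical example values; only the shapes (2x2 and 3x2) come from the notes.
Q = np.array([[1.0, 0.0],
              [0.0, 1.0]])          # 2 queries, d_k = 2
K = np.array([[1.0, 0.0],
              [0.0, 1.0],
              [1.0, 1.0]])          # 3 keys, d_k = 2

d_k = Q.shape[1]
scores = Q @ K.T / np.sqrt(d_k)     # shape (2, 3)

# Entry (i, j) is the scaled inner product <q_i, k_j> / sqrt(d_k).
entry_02 = Q[0] @ K[2] / np.sqrt(d_k)

# Q @ K (no transpose) is a (2x2) @ (3x2) product: inner dimensions 2 and 3 differ.
try:
    Q @ K
    mismatch_raised = False
except ValueError as err:
    mismatch_raised = True
    print("shape mismatch:", err)
```

Because `Q` is the identity here, `scores` is simply `K.T / sqrt(2)`, which makes the transpose alignment easy to read off by eye.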

Part 2: Attention Weights via Softmax

- $A = \text{softmax}(\text{scores}, \text{axis}=1) \in \mathbb{R}^{2 \times 3}$.
- Row-wise normalization: $A_{ij} = \exp(\text{scores}_{ij}) / \sum_k \exp(\text{scores}_{ik})$.
- Each row of $A$ is a probability distribution: $\sum_j A_{ij} = 1$, $A_{ij} \ge 0$.
- High scores → high attention weights; softmax is a differentiable approximation to argmax.
- Numerical stability: Softmax implementations subtract $\max(\text{scores}_i)$ before exponentiating to prevent overflow.
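
A minimal sketch of the row-wise softmax with max subtraction is shown here. The helper name `softmax_rows` and the test scores are assumptions chosen for illustration.

```python
import numpy as np

def softmax_rows(scores):
    """Row-wise softmax using the max-subtraction trick (hypothetical helper)."""
    shifted = scores - scores.max(axis=1, keepdims=True)  # largest entry per row becomes 0
    exp = np.exp(shifted)                                 # all exponents are <= 0, so no overflow
    return exp / exp.sum(axis=1, keepdims=True)

# Illustrative scores; a naive exp(1000.0) would overflow to inf.
scores = np.array([[1.0, 2.0, 3.0],
                   [1000.0, 1000.0, 1000.0]])
A = softmax_rows(scores)
```

Subtracting the row maximum leaves the result unchanged mathematically, since the common factor $\exp(-\max)$ cancels between numerator and denominator.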

Part 3: Output as Weighted Value Sum

- $O = AV \in \mathbb{R}^{2 \times 2}$: each row is a weighted sum of value vectors.
- $o_i = \sum_j A_{ij} v_j$: convex combination of values (weights sum to 1).
- Output dimension: $O \in \mathbb{R}^{n_q \times d_v}$ matches value dimension $d_v$, independent of key dimension $d_k$.
- Geometric interpretation: $o_i$ lies in the convex hull of value vectors; extreme attention (one weight near 1) retrieves a single value.
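
The weighted-sum step can be illustrated with a small sketch. The weights in `A` and the vectors in `V` are hypothetical values picked so that the second query shows the "retrieve a single value" case.

```python
import numpy as np

# Hypothetical attention weights (each row sums to 1) and value vectors.
A = np.array([[0.5, 0.25, 0.25],
              [0.0, 0.0, 1.0]])     # second query puts all its weight on v_3
V = np.array([[1.0, 0.0],
              [0.0, 1.0],
              [2.0, 2.0]])          # 3 values, d_v = 2

O = A @ V                           # shape (n_q, d_v) = (2, 2)

# Row 0 is the convex combination 0.5*v_1 + 0.25*v_2 + 0.25*v_3.
o_0 = 0.5 * V[0] + 0.25 * V[1] + 0.25 * V[2]
```

Here `O[1]` equals `V[2]` exactly, and `O[0]` stays inside the coordinate-wise bounds of the value vectors, as expected for a point in their convex hull.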

Why This Pattern Powers Transformers

- Content-based addressing: Query-key similarity determines relevance; no hand-crafted attention patterns (unlike an RNN's sequential bias).
- Parallelizable: All $n_q$ queries attend simultaneously; no sequential bottleneck.
- Differentiable: Softmax and matrix products are smooth; gradients flow through the entire attention operation.
- Hardware-efficient: GPUs/TPUs optimize dense matrix multiplication (GEMM kernels); $QK^\top$ and $AV$ map directly to BLAS calls.
- Composable: Stack many attention layers; each refines representations hierarchically.

Why This Matters for ML

- Transformers dominate modern ML: BERT, GPT, ViT, CLIP, Whisper, Flamingo all use scaled dot-product attention.
- Long-range dependencies: Attention captures dependencies across arbitrary distances in a single layer (vs. an RNN's sequential propagation).
- Interpretability: Attention weights $A$ reveal which keys each query attends to; useful for debugging and visualization.
- Scalability challenges: $O(n^2 d)$ complexity for sequence length $n$; sparse/efficient attention variants address this.

Connection to ML Applications

- Self-attention: $Q, K, V$ all project from the same input sequence; each token attends to all others.
- Cross-attention: Queries from one sequence, keys/values from another (e.g., decoder attends to encoder in seq2seq).
- Multi-head attention: Run $h$ attention heads in parallel with different weight matrices; concatenate outputs for richer representations.
- Masked attention: Set $\text{scores}_{ij} = -\infty$ for $j > i$ (causal mask) to prevent future tokens from influencing past predictions (autoregressive models).
- Sparse attention: Restrict attention to local windows, strided patterns, or learned sparsity to reduce $O(n^2)$ to $O(n \log n)$ or $O(n)$.
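
The causal mask described above can be sketched as follows. The function name `causal_attention_weights` and the random scores are illustrative assumptions; the sketch also assumes a square score matrix (self-attention).

```python
import numpy as np

def causal_attention_weights(scores):
    """Softmax over keys with positions j > i masked out (hypothetical helper)."""
    n = scores.shape[0]
    future = np.triu(np.ones((n, n), dtype=bool), k=1)    # True where j > i
    masked = np.where(future, -np.inf, scores)
    # The diagonal is never masked, so each row maximum is finite.
    masked = masked - masked.max(axis=1, keepdims=True)
    exp = np.exp(masked)                                  # exp(-inf) evaluates to 0
    return exp / exp.sum(axis=1, keepdims=True)

rng = np.random.default_rng(0)
scores = rng.normal(size=(4, 4))                          # illustrative random scores
A = causal_attention_weights(scores)
```

The first row can only attend to position 0, so `A[0, 0]` is exactly 1, and every entry above the diagonal is exactly 0.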

Connection to Linear Algebra Theory

- Matrix products as operator composition: $QK^\top$ followed by softmax and $AV$ is function composition with learned parameters.
- Projection interpretation: Each output is a weighted projection onto the span of value vectors; attention weights are soft coefficients.
- Inner products define similarity: Dot product $q \cdot k$ measures alignment; cosine similarity normalizes by norms for scale-invariance.
- Softmax as differentiable max: Hard argmax is non-differentiable; softmax provides a smooth relaxation with usable gradients.
- Low-rank structure: Attention matrices often have low effective rank; SVD-based compression or low-rank parameterizations (e.g., Linformer) exploit this.
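
Two of the points above, scale-invariance of cosine similarity and softmax as a smooth stand-in for argmax, can be checked numerically. The vectors, the scale factors 10 and 100, and the helper names are assumptions for illustration only.

```python
import numpy as np

def softmax(x):
    """Softmax of a 1-D score vector (hypothetical helper)."""
    e = np.exp(x - x.max())
    return e / e.sum()

def cosine(a, b):
    """Cosine similarity: the dot product normalized by both norms."""
    return a @ b / (np.linalg.norm(a) * np.linalg.norm(b))

q = np.array([1.0, 2.0])
k = np.array([2.0, 1.0])

# Rescaling k changes the dot product but not the cosine similarity.
dot_plain, dot_scaled = q @ k, q @ (10 * k)
cos_plain, cos_scaled = cosine(q, k), cosine(q, 10 * k)

# Multiplying the scores by a large factor sharpens softmax towards a one-hot argmax.
scores = np.array([1.0, 2.0, 3.0])
soft = softmax(scores)
sharp = softmax(100.0 * scores)
```

This also shows why unscaled scores are a problem: as score magnitudes grow, softmax behaves like the hard argmax and most of its gradient vanishes.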

Numerical and Implementation Notes

- Avoid materializing large attention matrices: For long sequences, compute $QK^\top$ in blocks or use kernel fusion to save memory.
- Softmax numerical stability: Subtract row-wise max before exponentiating: $\exp(x_i - \max(x))$ prevents overflow.
- Gradient clipping: Attention gradients can explode in early training; clip by norm to stabilize.
- Mixed precision: Use float16 for forward pass, float32 for softmax and gradients to balance speed and stability.
- Flash Attention: Tiled computation of attention reduces memory from $O(n^2)$ to $O(n)$ by recomputing on-the-fly instead of materializing full $A$.
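
The blockwise idea can be sketched with a running ("online") softmax over key/value blocks. This is a simplified NumPy illustration of the tiling principle, not the Flash Attention kernel itself: it tiles only the keys, runs on the CPU, and the function names and block size are assumptions.

```python
import numpy as np

def attention_reference(Q, K, V):
    """Plain attention that materializes the full (n_q, n_k) weight matrix."""
    S = Q @ K.T / np.sqrt(Q.shape[1])
    S = S - S.max(axis=1, keepdims=True)
    P = np.exp(S)
    return (P / P.sum(axis=1, keepdims=True)) @ V

def attention_blockwise(Q, K, V, block=4):
    """Process keys/values in blocks, keeping running softmax statistics per query."""
    n_q, d_k = Q.shape
    m = np.full((n_q, 1), -np.inf)          # running row-wise max of the scores
    l = np.zeros((n_q, 1))                  # running sum of exp(score - m)
    acc = np.zeros((n_q, V.shape[1]))       # running unnormalized output
    for start in range(0, K.shape[0], block):
        Kb, Vb = K[start:start + block], V[start:start + block]
        S = Q @ Kb.T / np.sqrt(d_k)         # only an (n_q, block) tile of scores exists
        m_new = np.maximum(m, S.max(axis=1, keepdims=True))
        scale = np.exp(m - m_new)           # rescale old statistics to the new max
        P = np.exp(S - m_new)
        l = l * scale + P.sum(axis=1, keepdims=True)
        acc = acc * scale + P @ Vb
        m = m_new
    return acc / l                          # normalize once at the end

rng = np.random.default_rng(0)
Q = rng.normal(size=(5, 8))
K = rng.normal(size=(10, 8))
V = rng.normal(size=(10, 3))
O_ref = attention_reference(Q, K, V)
O_blk = attention_blockwise(Q, K, V, block=4)
```

Both functions return the same output up to floating-point error, but the blockwise version never holds more than one `(n_q, block)` tile of scores at a time.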

Numerical and Shape Notes

- Shape checks: Before $QK^\top$, verify $Q.shape[1] == K.shape[1]$ (query/key dimension must match).
- Softmax axis: axis=1 normalizes across keys for each query; verify with A.sum(axis=1) ≈ all ones.
- Output shape: $O.shape = (n_q, d_v)$, inheriting rows from queries and columns from values.
- Batch dimensions: In practice, prepend a batch dimension: $Q \in \mathbb{R}^{B \times n_q \times d_k}$; all operations generalize via broadcasting, with the softmax then taken over the last (key) axis rather than axis=1.
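
These checks can be collected into a small batched sketch. The function name `batched_attention` and the sizes `B, n_q, n_k, d_k, d_v` are illustrative assumptions; note that with a leading batch axis the key axis is the last one, so the softmax uses `axis=-1`.

```python
import numpy as np

def batched_attention(Q, K, V):
    """Scaled dot-product attention over a leading batch dimension (hypothetical helper)."""
    assert Q.shape[-1] == K.shape[-1], "query/key dimension must match"
    assert K.shape[-2] == V.shape[-2], "need exactly one value per key"
    scores = Q @ np.swapaxes(K, -1, -2) / np.sqrt(Q.shape[-1])   # (B, n_q, n_k)
    scores = scores - scores.max(axis=-1, keepdims=True)
    A = np.exp(scores)
    A = A / A.sum(axis=-1, keepdims=True)                        # normalize over keys
    return A @ V, A                                              # (B, n_q, d_v), (B, n_q, n_k)

rng = np.random.default_rng(0)
B, n_q, n_k, d_k, d_v = 4, 2, 3, 2, 5
Q = rng.normal(size=(B, n_q, d_k))
K = rng.normal(size=(B, n_k, d_k))
V = rng.normal(size=(B, n_k, d_v))
O, A = batched_attention(Q, K, V)
```

NumPy's `@` operator treats the leading axis as a batch axis and multiplies the last two axes, which is why no explicit loop over `B` is needed.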

ML Context: From Attention to Transformers

- Self-attention layer: Input $X \in \mathbb{R}^{n \times d}$ → project to $Q = XW_Q$, $K = XW_K$, $V = XW_V$ → attention → output $O \in \mathbb{R}^{n \times d_v}$.
- Multi-head attention: Run $h$ heads in parallel, concatenate: $\text{MultiHead}(X) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h) W_O$.
- Positional encoding: Attention is permutation-equivariant; add positional embeddings to $X$ to encode token order.
- Residual connections: $X' = X + \text{Attention}(X)$ stabilizes training and enables deep stacks (100+ layers).
- Layer normalization: Normalize hidden states before/after attention to control activation magnitudes and improve conditioning.
- Feedforward sublayer: After attention, apply a position-wise MLP: $\text{FFN}(x) = W_2\,\sigma(W_1 x)$, where $\sigma$ is an elementwise nonlinearity such as ReLU or GELU (biases omitted); without $\sigma$ the two matrices would collapse into a single linear map and add no expressiveness.
- Causal masking: Autoregressive models mask future positions to maintain left-to-right generation order.
- Cross-attention: Encoder-decoder models use cross-attention (decoder queries, encoder keys/values) for source-target alignment.
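
The projection, multi-head, and residual steps can be combined in one sketch. Several details are assumptions made for brevity: square $d \times d$ weight matrices, $d$ divisible by $h$, random weights, and no masking, layer normalization, or feedforward sublayer.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_self_attention(X, W_Q, W_K, W_V, W_O, h):
    """Multi-head self-attention with a residual connection (hypothetical helper)."""
    n, d = X.shape
    d_head = d // h                                   # assumes d is divisible by h

    def split(M):
        # Split the feature axis into h heads: (n, d) -> (h, n, d_head).
        return M.reshape(n, h, d_head).transpose(1, 0, 2)

    Q, K, V = split(X @ W_Q), split(X @ W_K), split(X @ W_V)
    A = softmax(Q @ K.transpose(0, 2, 1) / np.sqrt(d_head))   # (h, n, n)
    heads = A @ V                                              # (h, n, d_head)
    concat = heads.transpose(1, 0, 2).reshape(n, d)            # Concat(head_1, ..., head_h)
    return X + concat @ W_O                                    # residual: X' = X + Attention(X)

rng = np.random.default_rng(0)
n, d, h = 5, 8, 2
X = rng.normal(size=(n, d))
W_Q, W_K, W_V, W_O = (rng.normal(size=(d, d)) / np.sqrt(d) for _ in range(4))
Y = multi_head_self_attention(X, W_Q, W_K, W_V, W_O, h)

# Without positional information, permuting the input rows permutes the output rows.
perm = rng.permutation(n)
Y_perm = multi_head_self_attention(X[perm], W_Q, W_K, W_V, W_O, h)
```

Comparing `Y_perm` with `Y[perm]` confirms the permutation-equivariance noted above, which is the reason positional embeddings have to be added to $X$ before the layer.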