Numerical and Shape Notes
- Shapes first: Declare shapes (e.g., $X \in \mathbb{R}^{n \times d}$, $w \in \mathbb{R}^{d}$, $b \in \mathbb{R}^{n}$). Vectors are column vectors by convention; keep row/column usage consistent.
- Axis discipline: Be explicit with `axis` in reductions and normalizations. For attention-like ops, apply softmax over the key axis (row-wise) so each row sums to ≈ 1.
- Broadcasting: Check that broadcasts are intended (e.g., `(n,1)` with `(n,d)`). Prefer reshape/expand-dims to make semantics clear.
- Stability eps: Add a small $\varepsilon$ to denominators and log arguments, and $\varepsilon I$ (jitter) for SPD solves; use log-sum-exp for softmax.
- Masking preserves shape: Masks should broadcast to the score/activation tensor; verify masked outputs keep the same shape and zero out excluded entries.
- Dtype choices: Use float64 for clarity in scripts; with mixed precision, keep reductions/factorizations in float32/float64 to avoid under/overflow.
- Sanity checks: Print shapes and residuals (e.g., `||Ax-b||`, reconstruction error, row sums ≈ 1). Assert finiteness and expected monotonicity where applicable.
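The axis, broadcasting, masking, and stability notes above can be combined in one small check. The following is a minimal sketch, not a reference implementation: `masked_softmax` is a hypothetical helper name, the shapes and the padding mask are made up for illustration, and it assumes every row keeps at least one unmasked entry.

```python
import numpy as np

def masked_softmax(scores, mask, axis=-1):
    """Softmax along `axis`; positions where `mask` is False get zero weight.

    Hypothetical helper. Assumes each slice along `axis` has at least one
    True entry in `mask`, otherwise the max below would be -inf.
    """
    mask = np.broadcast_to(mask, scores.shape)   # raises if shapes are incompatible
    masked = np.where(mask, scores, -np.inf)
    # Subtract the row max before exponentiating (the log-sum-exp trick).
    shifted = masked - masked.max(axis=axis, keepdims=True)
    weights = np.exp(shifted)                    # exp(-inf) is exactly 0.0
    return weights / weights.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
n, m = 4, 5
scores = 1e3 * rng.standard_normal((n, m))       # large magnitudes: a naive exp could overflow
key_mask = np.array([[True, True, True, False, False]])  # shape (1, m), broadcasts over rows

weights = masked_softmax(scores, key_mask)
print("weights shape:", weights.shape)

# Sanity checks: shape preserved, finite, rows sum to 1, masked keys zeroed.
assert weights.shape == scores.shape
assert np.all(np.isfinite(weights))
assert np.allclose(weights.sum(axis=-1), 1.0)
assert np.all(weights[:, 3:] == 0.0)

# Broadcasting pitfall: (n, 1) combined with (n,) silently produces (n, n).
col = np.ones((n, 1))
flat = np.ones(n)
assert (col - flat).shape == (n, n)
assert (col - flat.reshape(n, 1)).shape == (n, 1)  # explicit reshape keeps the intent clear
```

Passing `keepdims=True` to the reductions keeps the reduced axis as size 1, so the division broadcasts row by row rather than relying on an accidental shape match.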
Numerical and Implementation Notes
- Dtype & precision: Prefer float64 for clarity; if using mixed precision, keep reductions (norms, softmax sums, factorizations) in float32/float64. Avoid explicit inverses; use `solve`, `lstsq`, Cholesky/QR/SVD.
- Shapes & broadcasting: Annotate shapes (e.g., $X \in \mathbb{R}^{n \times d}$); vectors are column vectors by default. Verify axes for reductions (`axis`) and ensure broadcasts are intended.
- Stability: Use log-sum-exp for softmax; add small diagonal $\varepsilon I$ (jitter) for SPD solves; prefer QR/SVD for ill-conditioned least squares.
- Conditioning: Inspect `np.linalg.cond(A)` when solutions look unstable; regularize (ridge) or rescale features to improve conditioning.
- Reproducibility: Set the NumPy seed for random data; print shapes and residuals (e.g., `||Ax-b||`, reconstruction errors) and assert finiteness.
- Complexity & memory: Dense matmul and factorizations of $n \times n$ matrices cost ~$O(n^3)$; triangular solves and matrix-vector products cost $O(n^2)$. Prefer vectorization over Python loops; avoid materializing large intermediates.
- Masking & indexing: Use boolean masks that broadcast to target shapes; for attention-like ops, add $-\infty$ to masked scores before softmax, or zero out masked weights after softmax and renormalize, then verify rows sum to ~1.
- Sanity checks: Compare against references (e.g., `lstsq` vs. `solve`), check orthogonality (`U.T @ U` ≈ `I`), positive semidefiniteness (`x.T @ A @ x >= 0` for test vectors `x`, strictly `> 0` for nonzero `x` when `A` should be positive definite), and residual norms within tolerance (~1e-12 for float64 on well-conditioned problems).
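Several of the checks listed above fit in a short script. This is a sketch under stated assumptions: the data are synthetic and noiseless, the sizes `n`, `d`, the vector `w_true`, and the jitter value `eps` are arbitrary illustrative choices, and the tolerances suit this well-conditioned example rather than problems in general.

```python
import numpy as np

rng = np.random.default_rng(0)               # fixed seed for reproducibility
n, d = 50, 4
X = rng.standard_normal((n, d))
w_true = np.array([2.0, -1.0, 0.5, 3.0])     # illustrative ground truth
y = X @ w_true                                # noiseless targets, so the fit should be exact

# Least squares two ways: SVD-based lstsq vs. solve on the normal equations.
w_lstsq, *_ = np.linalg.lstsq(X, y, rcond=None)
A = X.T @ X                                   # (d, d), symmetric positive definite here
w_solve = np.linalg.solve(A, X.T @ y)         # no explicit inverse

# Forming X.T @ X squares the 2-norm condition number.
cond_X, cond_A = np.linalg.cond(X), np.linalg.cond(A)
print(f"cond(X) = {cond_X:.3f}, cond(X.T @ X) = {cond_A:.3f}")

residual = np.linalg.norm(X @ w_lstsq - y)
print("shapes:", X.shape, w_lstsq.shape, "residual:", residual)
assert np.all(np.isfinite(w_lstsq))
assert np.allclose(w_lstsq, w_solve)
assert residual < 1e-10

# Jitter: add eps * I before a Cholesky factorization of an SPD matrix.
eps = 1e-10
L = np.linalg.cholesky(A + eps * np.eye(d))
assert np.allclose(L @ L.T, A, atol=1e-8)

# Orthogonality of the reduced QR factor: Q.T @ Q should be close to I.
Q_factor, R = np.linalg.qr(X)
assert np.allclose(Q_factor.T @ Q_factor, np.eye(d))

# Positive definiteness via the smallest eigenvalue of the symmetric matrix.
assert np.linalg.eigvalsh(A).min() > 0
```

Because `cond(X.T @ X)` is roughly `cond(X)**2`, the normal-equations route loses accuracy faster than `lstsq` as `X` becomes ill-conditioned, which is the motivation for preferring QR/SVD in that regime.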
The code demonstrates two equivalent views of the same mathematical operation, reinforcing that matrix operations are syntactic sugar for explicit linear combinations.
For linear regression, `X @ w` computes predictions via optimized BLAS routines in $O(nd)$ time. The manual computation `2*X[:,0] - 1*X[:,1]` performs the same operation by explicitly scaling and adding column vectors. NumPy applies each scalar multiplication element-wise, producing `[2., 6., 10.]` from `2*X[:,0]` and `[2., 4., 6.]` from `1*X[:,1]`, then subtracts to get `[0., 2., 4.]`. This dual representation is pedagogical; in production, always use the matrix form (`X @ w`) for efficiency. The equivalence confirms that predictions are geometrically constrained to the column span of $X$.
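The two views can be reproduced with a few lines. The values of `X` and `w` below are reconstructed from the intermediate results quoted in the text (columns `[1, 3, 5]` and `[2, 4, 6]`, weights `2` and `-1`); the original listing may differ in details. The rank check at the end is an added illustration of the column-span statement.

```python
import numpy as np

# Values inferred from the numbers quoted in the text (an assumption).
X = np.array([[1.0, 2.0],
              [3.0, 4.0],
              [5.0, 6.0]])                 # shape (3, 2)
w = np.array([2.0, -1.0])                  # shape (2,)

pred_matrix = X @ w                        # matrix-vector product
pred_manual = 2 * X[:, 0] - 1 * X[:, 1]    # explicit linear combination of columns

print(pred_matrix)                         # [0. 2. 4.]
assert np.allclose(pred_matrix, pred_manual)

# Appending the prediction as an extra column does not raise the rank,
# so the prediction lies in the column span of X.
augmented = np.column_stack([X, pred_matrix])
assert np.linalg.matrix_rank(augmented) == np.linalg.matrix_rank(X) == 2
```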
For attention, the code implements scaled dot-product attention step by step. First, `Q @ K.T` computes query-key similarities as a $1 \times 3$ vector of dot products. Scaling by $1/\sqrt{d_k} = 1/\sqrt{2}$ (here $d_k = 2$) keeps the score magnitudes from growing with the key dimension, which would otherwise push softmax into saturation regions with near-zero gradients. The softmax function converts raw scores to nonnegative weights that sum to 1.0; this is what makes attention a convex combination rather than an arbitrary linear combination. Finally, `a @ V` computes the weighted sum of value vectors. The verification `a[0]*V[0] + a[1]*V[1] + a[2]*V[2]` explicitly shows the output is a point in the convex hull of $\{V_0, V_1, V_2\}$.
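A minimal sketch of these steps for one query and three keys follows. The numeric values of `Q`, `K`, and `V` are hypothetical, chosen only so that $d_k = 2$ and the score vector is $1 \times 3$; the original listing's values are not given in the text.

```python
import numpy as np

# Hypothetical inputs: one query, three keys/values, d_k = 2.
Q = np.array([[1.0, 0.0]])                 # shape (1, 2)
K = np.array([[1.0, 0.0],
              [0.0, 1.0],
              [1.0, 1.0]])                 # shape (3, 2)
V = np.array([[10.0, 0.0],
              [0.0, 10.0],
              [5.0, 5.0]])                 # shape (3, 2)

d_k = K.shape[1]
scores = (Q @ K.T) / np.sqrt(d_k)          # shape (1, 3), scaled dot products

# Softmax over the key axis, with the max subtracted for stability.
shifted = scores - scores.max(axis=-1, keepdims=True)
weights = np.exp(shifted)
weights /= weights.sum(axis=-1, keepdims=True)

out = weights @ V                          # shape (1, 2), weighted sum of value rows

# Explicit convex combination, using the single query's weight vector.
a = weights[0]
manual = a[0] * V[0] + a[1] * V[1] + a[2] * V[2]
assert np.allclose(out[0], manual)

# Convexity: nonnegative weights summing to 1.
assert np.all(a >= 0) and np.isclose(a.sum(), 1.0)
# Necessary consequence: the output stays inside the bounding box of the values.
assert np.all(out[0] >= V.min(axis=0)) and np.all(out[0] <= V.max(axis=0))
```

The bounding-box assertion is only a necessary condition for lying in the convex hull; the nonnegativity and sum-to-one checks on `a` are what establish the convex combination.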
Memory consideration: For batch attention with $n$ queries and $m$ keys, the score matrix is $n \times m$. For self-attention ($n = m$, the sequence length), this is $O(n^2)$ memory, which is one reason standard transformers scale poorly to long sequences. Efficient attention variants (linear attention, sparse attention) exploit low-rank structure or sparsity patterns to reduce this cost.
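To make the quadratic growth concrete, the size of a fully materialized score matrix can be estimated directly. This is a back-of-the-envelope sketch: `score_matrix_bytes` is a hypothetical helper, float32 storage is an assumption, and the figure covers a single attention head and a single batch element only.

```python
import numpy as np

def score_matrix_bytes(n_queries, n_keys, dtype=np.float32):
    """Bytes needed to materialize one (n_queries, n_keys) score matrix."""
    return n_queries * n_keys * np.dtype(dtype).itemsize

for seq_len in (1_024, 4_096, 32_768):
    mib = score_matrix_bytes(seq_len, seq_len) / 2**20
    print(f"seq_len={seq_len:>6}: {mib:>6.0f} MiB per head, per batch element")

# Doubling the sequence length quadruples the memory.
assert score_matrix_bytes(2_048, 2_048) == 4 * score_matrix_bytes(1_024, 1_024)

# Cross-check the estimate against an actual small array.
assert np.empty((8, 8), dtype=np.float32).nbytes == score_matrix_bytes(8, 8)
```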