-
-
Save eric-czech/ebd9a80d58c7b5e9c40ba390ff884617 to your computer and use it in GitHub Desktop.
VPNLS TypeScript implementation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env node | |
| /** | |
| * Minimal VPNLS (Variable-Projection Nonlinear Least Squares) demo in TS. | |
| * | |
| * Loss model: L(N, D) = E + A * N^(-alpha) + B * D^(-beta) | |
| * | |
| * Strategy: grid-search over (alpha, beta) in (0, 1] at 1/gridSize steps | |
| * (gridSize defaults to 100), solve for (E, A, B) via OLS at each grid | |
| * point, keep lowest RSS. | |
| * | |
| * Zero dependencies — OLS uses the normal equations on a 3x3 system. | |
| * | |
| * Usage: npx tsx demo/vpnls.ts | |
| * | |
| * Benchmarks (Apple M4 Pro, Node v24.3.0, 70 data points): | |
| * gridSize=100 (10k grid points): ~0.03s | |
| * gridSize=1000 (1M grid points): ~2.5s | |
| */ | |
/** Three numbers: a right-hand side or solution of a 3x3 system. */
type Vec3 = [number, number, number];

/** A 3x3 matrix stored row-major as three row tuples. */
type Mat3 = [Vec3, Vec3, Vec3];

/** Outcome of a VPNLS fit: the chosen exponents, linear coefficients, and RSS. */
interface VpnlsResult {
  rss: number;
  E: number;
  A: number;
  B: number;
  alpha: number;
  beta: number;
}
| // --------------------------------------------------------------------------- | |
| // Tiny linear algebra helpers (3x3 only needed) | |
| // --------------------------------------------------------------------------- | |
/**
 * Solve the 3x3 linear system A·x = b.
 *
 * Uses Gaussian elimination with partial pivoting, the standard stable direct
 * method for small dense systems. (Cramer's rule divides by a determinant and
 * silently yields ±Infinity/NaN when that determinant is zero.)
 *
 * Neither `A` nor `b` is modified; elimination runs on a private augmented copy.
 *
 * @param A - coefficient matrix, given as three row tuples
 * @param b - right-hand side
 * @returns the solution x, or `[NaN, NaN, NaN]` when A is singular (a zero
 *   pivot remains after pivoting) or the result is not finite, so callers can
 *   detect failure with `Number.isFinite`.
 */
function solve3x3(A: Mat3, b: Vec3): Vec3 {
  const failed: Vec3 = [NaN, NaN, NaN];
  // Augmented working copy [A | b]; whole rows are swapped during pivoting.
  const m = A.map((row, i) => [row[0], row[1], row[2], b[i]]);

  // Forward elimination to upper-triangular form.
  for (let col = 0; col < 3; col++) {
    // Partial pivoting: bring the row with the largest |entry| in this column up.
    let pivot = col;
    for (let r = col + 1; r < 3; r++) {
      if (Math.abs(m[r][col]) > Math.abs(m[pivot][col])) pivot = r;
    }
    // `!(x > 0)` is true for both 0 and NaN: the system is singular/unusable.
    if (!(Math.abs(m[pivot][col]) > 0)) return failed;
    [m[col], m[pivot]] = [m[pivot], m[col]];
    for (let r = col + 1; r < 3; r++) {
      const factor = m[r][col] / m[col][col];
      for (let c = col; c < 4; c++) m[r][c] -= factor * m[col][c];
    }
  }

  // Back substitution.
  const x: Vec3 = [0, 0, 0];
  for (let r = 2; r >= 0; r--) {
    let s = m[r][3];
    for (let c = r + 1; c < 3; c++) s -= m[r][c] * x[c];
    x[r] = s / m[r][r];
  }
  return x.every(Number.isFinite) ? x : failed;
}
| // --------------------------------------------------------------------------- | |
| // OLS inner solve for fixed (alpha, beta) | |
| // --------------------------------------------------------------------------- | |
/**
 * Ordinary least squares for fixed exponents (alpha, beta).
 *
 * Fits L ≈ E + A·N^(-alpha) + B·D^(-beta) using the design matrix
 * [1, N^(-alpha), D^(-beta)] by solving the 3x3 normal equations
 * (X'X)·[E, A, B] = X'L, then computes the residual sum of squares.
 *
 * @param alpha - exponent on N; N^(-alpha) is computed as exp(-alpha·log N)
 * @param beta - exponent on D; D^(-beta) is computed as exp(-beta·log D)
 * @param logN - natural log of N per observation (indexed in step with L)
 * @param logD - natural log of D per observation (indexed in step with L)
 * @param L - observed losses; its length sets the number of observations
 * @returns the coefficients E, A, B and the RSS. When the normal equations do
 *   not produce finite coefficients (empty, too-small, or degenerate data),
 *   `rss` is `Infinity` so a minimizing caller never selects this point.
 */
function olsSolve(
  alpha: number,
  beta: number,
  logN: number[],
  logD: number[],
  L: number[],
): { rss: number; E: number; A: number; B: number } {
  const n = L.length;
  // Only the two non-constant columns are stored; the intercept column is all ones.
  const xN = new Float64Array(n); // N^{-alpha}
  const xD = new Float64Array(n); // D^{-beta}

  // Single pass accumulating the upper triangle of X'X and all of X'y.
  let sN = 0, sD = 0, sNN = 0, sND = 0, sDD = 0;
  let sy = 0, sNy = 0, sDy = 0;
  for (let i = 0; i < n; i++) {
    const u = Math.exp(-alpha * logN[i]);
    const v = Math.exp(-beta * logD[i]);
    const y = L[i];
    xN[i] = u;
    xD[i] = v;
    sN += u;
    sD += v;
    sNN += u * u;
    sND += u * v;
    sDD += v * v;
    sy += y;
    sNy += u * y;
    sDy += v * y;
  }

  // Symmetric X'X; the (0,0) entry is the sum of 1·1 over all rows, i.e. n.
  const XtX: Mat3 = [
    [n, sN, sD],
    [sN, sNN, sND],
    [sD, sND, sDD],
  ];
  const [E, A, B] = solve3x3(XtX, [sy, sNy, sDy]);

  // A failed solve must never look like a good fit: with n = 0 the residual
  // loop below would otherwise leave rss at 0 despite NaN coefficients.
  if (!(Number.isFinite(E) && Number.isFinite(A) && Number.isFinite(B))) {
    return { rss: Infinity, E, A, B };
  }

  let rss = 0;
  for (let i = 0; i < n; i++) {
    const r = L[i] - (E + A * xN[i] + B * xD[i]);
    rss += r * r;
  }
  return { rss, E, A, B };
}
| // --------------------------------------------------------------------------- | |
| // VPNLS grid search | |
| // --------------------------------------------------------------------------- | |
/**
 * Fit the 5-parameter loss surface L(N, D) = E + A·N^(-alpha) + B·D^(-beta)
 * by variable projection with an exhaustive grid search.
 *
 * The nonlinear exponents (alpha, beta) range over {1/g, 2/g, …, 1} with
 * g = gridSize. At each of the g² grid points the linear coefficients
 * (E, A, B) are solved by OLS, and the point with the lowest residual sum of
 * squares is kept.
 *
 * @param params - model sizes N (all > 0)
 * @param tokens - token counts D (all > 0), same length as `params`
 * @param loss - observed losses, same length as `params`
 * @param gridSize - grid steps per exponent (positive integer, default 100)
 * @returns the best grid point. Comparison is strict `<`, so ties keep the
 *   first point visited (smallest alpha, then smallest beta) and a NaN RSS is
 *   never selected. If no point beats the initial `rss: Infinity`
 *   placeholder, that placeholder (zero coefficients and exponents) is returned.
 * @throws RangeError on mismatched lengths, fewer than 3 observations,
 *   an invalid gridSize, or non-positive N or D.
 */
function fitVpnls(params: number[], tokens: number[], loss: number[], gridSize: number = 100): VpnlsResult {
  const n = loss.length;
  if (params.length !== n || tokens.length !== n) {
    throw new RangeError(
      `fitVpnls: length mismatch (params=${params.length}, tokens=${tokens.length}, loss=${n})`,
    );
  }
  // Three linear coefficients need at least three observations.
  if (n < 3) {
    throw new RangeError(`fitVpnls: need at least 3 observations, got ${n}`);
  }
  // A non-positive or fractional grid would skip the search or miss alpha/beta = 1.
  if (!Number.isInteger(gridSize) || gridSize < 1) {
    throw new RangeError(`fitVpnls: gridSize must be a positive integer, got ${gridSize}`);
  }
  // log() of a non-positive value is NaN/-Infinity and would poison every fit.
  if (!params.every((x) => x > 0) || !tokens.every((x) => x > 0)) {
    throw new RangeError("fitVpnls: params and tokens must all be positive");
  }

  const logN = params.map(Math.log);
  const logD = tokens.map(Math.log);
  let best: VpnlsResult = { rss: Infinity, E: 0, A: 0, B: 0, alpha: 0, beta: 0 };
  for (let ai = 1; ai <= gridSize; ai++) {
    const alpha = ai / gridSize;
    for (let bi = 1; bi <= gridSize; bi++) {
      const beta = bi / gridSize;
      const r = olsSolve(alpha, beta, logN, logD, loss);
      if (r.rss < best.rss) best = { ...r, alpha, beta };
    }
  }
  return best;
}
| // --------------------------------------------------------------------------- | |
| // Synthetic data generation | |
| // --------------------------------------------------------------------------- | |
/** Ground-truth parameters used to synthesize the demo data. */
const TRUE = {
  A: 400,
  B: 400,
  E: 1.69,
  alpha: 0.31,
  beta: 0.31,
};

/** Noise-free loss E + A/N^alpha + B/D^beta evaluated with the TRUE parameters. */
function trueLoss(N: number, D: number): number {
  const { E, A, B, alpha, beta } = TRUE;
  const paramTerm = A / N ** alpha;
  const dataTerm = B / D ** beta;
  return E + paramTerm + dataTerm;
}
/**
 * Generate noise-free IsoFLOP data from the TRUE parameters.
 *
 * For each compute budget C (with C ≈ 6·N·D), model size N is swept
 * log-uniformly from N_opt/spread to N_opt·spread around the compute-optimal
 * N_opt. D is set to C / (6N) and the loss to trueLoss(N, D).
 *
 * The defaults give 7 budgets × 10 points spanning 0.2× to 5× N_opt
 * (70 observations).
 *
 * @param flops - compute budgets C
 * @param pointsPerBudget - samples per budget (integer ≥ 1; one sample is placed at the sweep midpoint)
 * @param spread - multiplicative half-width of the sweep around N_opt (≥ 1)
 * @returns parallel arrays of N (`params`), D (`tokens`) and loss (`losses`)
 * @throws RangeError on an invalid pointsPerBudget or spread
 */
function generateData(
  flops: number[] = [1e17, 3e17, 1e18, 3e18, 1e19, 3e19, 1e20],
  pointsPerBudget: number = 10,
  spread: number = 5,
) {
  if (!Number.isInteger(pointsPerBudget) || pointsPerBudget < 1) {
    throw new RangeError(`generateData: pointsPerBudget must be a positive integer, got ${pointsPerBudget}`);
  }
  if (!(spread >= 1)) {
    throw new RangeError(`generateData: spread must be >= 1, got ${spread}`);
  }

  const { A, B, alpha, beta } = TRUE;
  // N_opt = G · (C/6)^(beta/(alpha+beta)). G and the exponent depend only on
  // TRUE, so they are computed once instead of once per budget.
  const G = (alpha * A / (beta * B)) ** (1 / (alpha + beta));
  const nOptExponent = beta / (alpha + beta);
  const low = 1 / spread;         // lower end of the sweep as a fraction of N_opt
  const span = spread * spread;   // ratio between the highest and lowest N

  const params: number[] = [], tokens: number[] = [], losses: number[] = [];
  for (const C of flops) {
    const Nopt = G * (C / 6) ** nOptExponent;
    for (let k = 0; k < pointsPerBudget; k++) {
      // Fraction of the way along the log-scale sweep.
      const t = pointsPerBudget === 1 ? 0.5 : k / (pointsPerBudget - 1);
      const N = Nopt * low * span ** t;
      const D = C / (6 * N);
      params.push(N);
      tokens.push(D);
      losses.push(trueLoss(N, D));
    }
  }
  return { params, tokens, losses };
}
| // --------------------------------------------------------------------------- | |
| // Main | |
| // --------------------------------------------------------------------------- | |
// Grid steps per exponent: 1000 gives a 0.001 resolution (1000² OLS solves).
const GRID_SIZE = 1000;
// Decimals needed to resolve one grid step (3 for a 0.001 grid). A fixed
// toFixed(2) would hide the last digit of the fitted exponents.
const exponentDigits = Math.max(0, Math.ceil(Math.log10(GRID_SIZE)));

const { params: P, tokens: T, losses: L } = generateData();
console.log(`Data points: ${P.length}`);
console.log(`True params: A=${TRUE.A}, B=${TRUE.B}, E=${TRUE.E}, α=${TRUE.alpha}, β=${TRUE.beta}\n`);

const t0 = performance.now();
const result = fitVpnls(P, T, L, GRID_SIZE);
const elapsed = ((performance.now() - t0) / 1000).toFixed(3);

console.log("VPNLS result:");
console.log(` α = ${result.alpha.toFixed(exponentDigits)}`);
console.log(` β = ${result.beta.toFixed(exponentDigits)}`);
console.log(` A = ${result.A.toFixed(4)}`);
console.log(` B = ${result.B.toFixed(4)}`);
console.log(` E = ${result.E.toFixed(4)}`);
console.log(` RSS = ${result.rss.toExponential(6)}`);
console.log(` Time: ${elapsed}s`);
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment