Created
December 21, 2025 09:44
-
-
Save sarthakbagaria/7ed67aeed62816a98babaa1a92713883 to your computer and use it in GitHub Desktop.
A script to analyze the geometry of model training using singular learning theory.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # python 3.12 | |
| # pip install torch devinterp matplotlib scikit-learn | |
| import torch | |
| import torch.nn as nn | |
| from torch.utils.data import DataLoader, TensorDataset | |
| from devinterp.slt.callback import SamplerCallback | |
| from devinterp.slt.llc import LLCEstimator | |
| from devinterp.slt.sampler import sample | |
| from devinterp.optim import SGLD | |
| from sklearn.decomposition import PCA | |
| from matplotlib.animation import FuncAnimation, PillowWriter | |
| import matplotlib.pyplot as plt | |
| from matplotlib.widgets import Slider | |
| import numpy as np | |
| import os | |
| import warnings | |
| from tqdm import tqdm | |
| warnings.filterwarnings("ignore") | |
# ==========================================
# CONFIGURATION & TOGGLE
# ==========================================
USE_SYMMETRY = True  # True adds the two embedding vectors (MLPAdd), False concatenates them (MLPConcat)
MULTIPLY = True  # True computes (a*b) % P, False computes (a+b) % P
P = 67  # modulus; also vocabulary size and number of output classes
DEVICE = "cpu"
TRAIN_PCT = 0.45  # fraction of all P*P pairs used for training; the rest is the test split
BATCH_SIZE = 512
ESTIMATION_INTERVAL = 100  # run the SLT estimators (LLC/WBIC/func-var) every this many epochs
EPOCHS = 10000
| # ========================================== | |
| # MODELS | |
| # ========================================== | |
class MLPConcat(nn.Module):
    """Two-layer MLP over the concatenation of the two token embeddings.

    Input is an integer tensor of shape (batch, 2); output is (batch, p)
    logits over the residues mod p.
    """

    def __init__(self, p, dim=128):
        super().__init__()
        self.embed = nn.Embedding(p, dim)
        self.linear = nn.Sequential(
            nn.Linear(dim * 2, dim),
            nn.ReLU(),
            nn.Linear(dim, p),
        )

    def forward(self, x):
        # (batch, 2, dim) embedding pair flattened to (batch, 2*dim).
        flat = torch.flatten(self.embed(x), start_dim=1)
        return self.linear(flat)
class MLPAdd(nn.Module):
    """Two-layer MLP over the SUM of the two token embeddings.

    Summing (instead of concatenating) bakes the a<->b input symmetry of
    commutative modular arithmetic directly into the architecture.
    """

    def __init__(self, p, dim=128):
        super().__init__()
        self.p = p
        self.embed = nn.Embedding(p, dim)
        self.linear = nn.Sequential(
            nn.Linear(dim, dim),
            nn.ReLU(),
            nn.Linear(dim, p),
        )

    def forward(self, x):
        left = self.embed(x[:, 0])
        right = self.embed(x[:, 1])
        # Order-invariant combination of the pair.
        return self.linear(left + right)
# Initialize model according to the toggles above.
ModelClass = MLPAdd if USE_SYMMETRY else MLPConcat
model_name = ("Multiply " if MULTIPLY else "") + ("MLP-Add (Symmetric)" if USE_SYMMETRY else "MLP-Concat (Standard)")
model = ModelClass(P).to(DEVICE)
# NOTE(review): weight_decay=0.4 is unusually large — presumably deliberate
# for grokking-style experiments; confirm before reusing elsewhere.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.4)
criterion = nn.CrossEntropyLoss()
# ==========================================
# DATA
# ==========================================
# All P*P ordered pairs (a, b) with labels (a op b) % P.
pairs = torch.cartesian_prod(torch.arange(P), torch.arange(P))
labels = (pairs[:, 0] * pairs[:, 1]) % P if MULTIPLY else (pairs[:, 0] + pairs[:, 1]) % P
dataset = TensorDataset(pairs, labels)
train_size = int(TRAIN_PCT * len(dataset))
train_ds, test_ds = torch.utils.data.random_split(dataset, [train_size, len(dataset)-train_size])
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
# Materialize the full test split as device tensors for fast whole-set evaluation.
test_x, test_y = dataset.tensors[0][test_ds.indices].to(DEVICE), dataset.tensors[1][test_ds.indices].to(DEVICE)
| # ========================================== | |
| # CUSTOM ESTIMATORS | |
| # ========================================== | |
class WBICEstimator(SamplerCallback):
    """Accumulates the loss of every sampler draw; WBIC is their mean."""

    def __init__(self):
        super().__init__()
        self._draw_losses = []

    def __call__(self, loss, **_kwargs):
        # Called once per SGLD draw with that draw's loss.
        self._draw_losses.append(loss.item())

    def get_results(self):
        return {"wbic": np.mean(self._draw_losses)}
class FunctionalVarianceEstimator(SamplerCallback):
    """Records model outputs on a fixed batch at every sampler draw.

    The reported value is the per-logit variance across draws, averaged over
    all logits — a measure of how much the sampled models disagree
    functionally on identical inputs.
    """

    def __init__(self, x_fixed):
        super().__init__()
        self.x_fixed = x_fixed
        self.outputs = []

    def __call__(self, model, **_kwargs):
        model.eval()
        with torch.no_grad():
            logits = model(self.x_fixed)
        self.outputs.append(logits.cpu().numpy())

    def get_results(self):
        draws = np.stack(self.outputs)
        # Variance over the draw axis, then mean over batch x logit.
        return {"func_var": np.var(draws, axis=0).mean()}
def get_embedding_pca(model):
    """Project the embedding matrix onto its top (at most 6) principal components."""
    emb = model.embed.weight.detach().cpu().numpy()
    n_components = min(6, emb.shape[1])
    return PCA(n_components=n_components).fit_transform(emb)
# ==========================================
# TRAINING LOOP
# ==========================================
# One entry per estimation checkpoint (every ESTIMATION_INTERVAL epochs).
history = {"train_acc": [], "test_acc": [], "llc": [], "wbic": [], "func_var": [], "embeddings": [], "epochs": []}
print(f"Training {model_name}...")
for epoch in tqdm(range(EPOCHS)):
    model.train()
    correct, total_loss = 0, 0
    for x, y in train_loader:
        x, y = x.to(DEVICE), y.to(DEVICE)
        out = model(x); loss = criterion(out, y)
        # zero_grad after step: equivalent to zeroing before backward here,
        # since gradients are cleared once per batch.
        loss.backward(); optimizer.step(); optimizer.zero_grad()
        correct += (out.argmax(1) == y).sum().item()
        total_loss += loss.item()
    if epoch % ESTIMATION_INTERVAL == 0:
        model.eval()
        train_acc = correct / train_size
        # NOTE(review): evaluation runs without torch.no_grad(), so autograd
        # state is built unnecessarily — harmless but wasteful.
        test_acc = (model(test_x).argmax(1) == test_y).sum().item() / len(test_ds)
        # WBIC inverse temperature: nbeta = n / log(n), n = training-set size.
        nbeta = train_size / np.log(train_size)
        llc_est = LLCEstimator(num_chains=1, num_draws=100, nbeta=nbeta, device=DEVICE, init_loss=total_loss/len(train_loader))
        wbic_est = WBICEstimator()
        # Functional variance is measured on a fixed slice of the test inputs.
        fvar_est = FunctionalVarianceEstimator(test_x[:100])
        # NOTE(review): assumes devinterp's sample() operates on a copy so the
        # training weights are not perturbed — confirm for the installed version.
        sample(model, train_loader, num_chains=1, num_draws=100, burnin=50,
               evaluate=lambda m, d: criterion(m(d[0].to(DEVICE)), d[1].to(DEVICE)),
               callbacks=[llc_est, wbic_est, fvar_est],
               optimizer_kwargs=dict(lr=1e-4, nbeta=nbeta), method=SGLD)
        llc = llc_est.get_results()["llc/mean"]
        wbic = wbic_est.get_results()["wbic"]
        fvar = fvar_est.get_results()["func_var"]
        history["train_acc"].append(train_acc)
        history["test_acc"].append(test_acc)
        history["llc"].append(llc)
        history["wbic"].append(wbic)
        history["func_var"].append(fvar)
        history["embeddings"].append(get_embedding_pca(model))
        history["epochs"].append(epoch)
        print(f"Epoch {epoch}: Train Acc {train_acc:.4f}, Test Acc {test_acc:.4f}, LLC: {llc:.4f}, WBIC: {wbic:.4f}, FVar: {fvar:.4f}")
# ==========================================
# INTERACTIVE VISUALIZATION
# ==========================================
fig = plt.figure(figsize=(16, 12))
plt.suptitle(f"Analysis: {model_name}", fontsize=16)
# Embedding axes (Top Row)
ax_pca1 = fig.add_axes([0.05, 0.65, 0.25, 0.25])
ax_pca2 = fig.add_axes([0.37, 0.65, 0.25, 0.25])
ax_pca3 = fig.add_axes([0.70, 0.65, 0.25, 0.25])
# Metrics axes (Bottom Half)
ax_acc = fig.add_axes([0.1, 0.42, 0.65, 0.15])
ax_slt = fig.add_axes([0.1, 0.24, 0.65, 0.15])
ax_fvar = fig.add_axes([0.1, 0.08, 0.65, 0.12])
# Initialization of Scatters: first recorded PCA snapshot of the embeddings.
initial_pca = history["embeddings"][0]
# Fix: removed the unused `color_map = plt.cm.get_cmap('hsv', P)` line.
# `plt.cm.get_cmap` was removed in Matplotlib 3.9, so it crashed on current
# installs while contributing nothing (scatters use cmap='Greys'). If
# per-residue colouring was intended, change cmap='Greys' to 'hsv' below.
colors = range(P)
scat1 = ax_pca1.scatter(initial_pca[:, 0], initial_pca[:, 1], c=colors, cmap='Greys', edgecolors='k', s=40)
ax_pca1.set_title("PCA 1 & 2")
scat2 = ax_pca2.scatter(initial_pca[:, 2], initial_pca[:, 3], c=colors, cmap='Greys', edgecolors='k', s=40)
ax_pca2.set_title("PCA 3 & 4")
scat3 = ax_pca3.scatter(initial_pca[:, 4], initial_pca[:, 5], c=colors, cmap='Greys', edgecolors='k', s=40)
ax_pca3.set_title("PCA 5 & 6")
# Static Metric Plots (showing the full training history)
ax_acc.plot(history["epochs"], history["train_acc"], label="Train Acc", alpha=0.3, color='blue')
ax_acc.plot(history["epochs"], history["test_acc"], label="Test Acc", color='green', linewidth=2)
ax_acc.set_ylabel("Accuracy")
ax_acc.legend(loc='lower right')
ax_slt.plot(history["epochs"], history["llc"], color='red', label="LLC (λ)")
ax_slt.set_ylabel("LLC")
ax_wbic = ax_slt.twinx()  # WBIC shares the x-axis but needs its own scale
ax_wbic.plot(history["epochs"], history["wbic"], color='purple', linestyle='--', label="WBIC")
ax_wbic.set_ylabel("WBIC")
ax_slt.legend(loc='upper left'); ax_wbic.legend(loc='upper right')
ax_fvar.plot(history["epochs"], history["func_var"], color='orange', label="Func Var")
ax_fvar.set_ylabel("Var[f]")
ax_fvar.set_xlabel("Epochs")
# Vertical indicators for the currently selected epoch on each metric chart.
vline1 = ax_acc.axvline(history["epochs"][0], color='black', linestyle=':', alpha=0.5)
vline2 = ax_slt.axvline(history["epochs"][0], color='black', linestyle=':', alpha=0.5)
vline3 = ax_fvar.axvline(history["epochs"][0], color='black', linestyle=':', alpha=0.5)
# Value displays (placed to the right of the charts)
txt_acc = fig.text(0.85, 0.495, '', verticalalignment='center')
txt_slt = fig.text(0.85, 0.315, '', verticalalignment='center')
txt_fvar = fig.text(0.85, 0.14, '', verticalalignment='center')
# Slider Setup: one step per recorded estimation checkpoint.
ax_slider = fig.add_axes([0.2, 0.01, 0.6, 0.02])
slider = Slider(ax_slider, f"Epoch/{ESTIMATION_INTERVAL}", 0, len(history["epochs"]) - 1, valinit=len(history["epochs"])-1, valstep=1)
def update(_val):
    """Redraw every panel for the history snapshot currently selected by the slider."""
    idx = int(slider.val)
    epoch = history["epochs"][idx]
    pca = history["embeddings"][idx]
    # Refresh the three PCA scatter panels and rescale their axes.
    panels = ((scat1, ax_pca1, 0), (scat2, ax_pca2, 2), (scat3, ax_pca3, 4))
    for scatter, axis, col in panels:
        xs, ys = pca[:, col], pca[:, col + 1]
        scatter.set_offsets(pca[:, col:col + 2])
        axis.set_xlim(xs.min() - 0.1, xs.max() + 0.1)
        axis.set_ylim(ys.min() - 0.1, ys.max() + 0.1)
    # Move the current-epoch markers on the metric charts.
    for marker in (vline1, vline2, vline3):
        marker.set_xdata([epoch])
    # Refresh the numeric read-outs.
    txt_acc.set_text(f"Train Acc: {history['train_acc'][idx]:.4f}\nTest Acc: {history['test_acc'][idx]:.4f}")
    txt_slt.set_text(f"LLC: {history['llc'][idx]:.4f}\nWBIC: {history['wbic'][idx]:.4f}")
    txt_fvar.set_text(f"Func Var: {history['func_var'][idx]:.4f}")
    fig.canvas.draw_idle()
slider.on_changed(update)
update(len(history["epochs"]) - 1)
# Save Animation: sweep the slider across all checkpoints and record a GIF.
print("Saving animation...")
os.makedirs("generated", exist_ok=True)
# Choose an FPS that makes the GIF roughly 10 seconds long.
num_frames = len(history["epochs"])
duration_seconds = 10
calculated_fps = max(1, num_frames // duration_seconds)

def animate(i):
    # Driving the slider reuses update() to redraw every panel per frame.
    slider.set_val(i)
    return []

ani = FuncAnimation(fig, animate, frames=num_frames, interval=1000 / calculated_fps)
ani.save("generated/modulo_animation.gif", writer=PillowWriter(fps=calculated_fps))
print(f"Animation saved to generated/modulo_animation.gif at {calculated_fps} FPS")
ani.pause()
# Reset the interactive view to the final checkpoint before showing.
slider.set_val(len(history["epochs"]) - 1)
# Show plots
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment