Skip to content

Instantly share code, notes, and snippets.

@sarthakbagaria
Created December 21, 2025 09:44
Show Gist options
  • Select an option

  • Save sarthakbagaria/7ed67aeed62816a98babaa1a92713883 to your computer and use it in GitHub Desktop.

Select an option

Save sarthakbagaria/7ed67aeed62816a98babaa1a92713883 to your computer and use it in GitHub Desktop.
A script to analyze the geometry of model training using singular learning theory.
# python 3.12
# pip install torch devinterp matplotlib scikit-learn
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from devinterp.slt.callback import SamplerCallback
from devinterp.slt.llc import LLCEstimator
from devinterp.slt.sampler import sample
from devinterp.optim import SGLD
from sklearn.decomposition import PCA
from matplotlib.animation import FuncAnimation, PillowWriter
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider
import numpy as np
import os
import warnings
from tqdm import tqdm
warnings.filterwarnings("ignore")
# ==========================================
# CONFIGURATION & TOGGLE
# ==========================================
# Experiment switches and hyperparameters for the modular-arithmetic task.
USE_SYMMETRY = True # True selects MLPAdd (sums the two embeddings); False selects MLPConcat (concatenates them)
MULTIPLY = True # True computes (a*b) % P, False computes (a+b) % P
P = 67  # modulus; also the vocabulary size and the number of output classes
DEVICE = "cpu"
TRAIN_PCT = 0.45  # fraction of all P*P input pairs used for training
BATCH_SIZE = 512
ESTIMATION_INTERVAL = 100  # run the SLT estimators (LLC/WBIC/func-var) every this many epochs
EPOCHS = 10000
# ==========================================
# MODELS
# ==========================================
class MLPConcat(nn.Module):
    """Baseline MLP: the two token embeddings are concatenated, so any
    symmetry of the underlying operation must be learned from data."""

    def __init__(self, p, dim=128):
        super().__init__()
        self.embed = nn.Embedding(p, dim)
        self.linear = nn.Sequential(
            nn.Linear(dim * 2, dim),
            nn.ReLU(),
            nn.Linear(dim, p),
        )

    def forward(self, x):
        # x: (batch, 2) integer pairs -> flatten to (batch, 2*dim) embeddings.
        flat = self.embed(x).view(x.shape[0], -1)
        return self.linear(flat)
class MLPAdd(nn.Module):
    """Symmetry-respecting MLP: the two token embeddings are summed, so the
    network is invariant to the order of its two inputs by construction."""

    def __init__(self, p, dim=128):
        super().__init__()
        self.p = p
        self.embed = nn.Embedding(p, dim)
        self.linear = nn.Sequential(
            nn.Linear(dim, dim),
            nn.ReLU(),
            nn.Linear(dim, p),
        )

    def forward(self, x):
        # x: (batch, 2) integer pairs; summing the embeddings enforces symmetry.
        summed = self.embed(x[:, 0]) + self.embed(x[:, 1])
        return self.linear(summed)
# Initialize model: pick the architecture according to the symmetry toggle.
if USE_SYMMETRY:
    ModelClass, arch_label = MLPAdd, "MLP-Add (Symmetric)"
else:
    ModelClass, arch_label = MLPConcat, "MLP-Concat (Standard)"
model_name = ("Multiply " if MULTIPLY else "") + arch_label
model = ModelClass(P).to(DEVICE)
# AdamW with a comparatively strong weight decay (0.4).
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.4)
criterion = nn.CrossEntropyLoss()
# ==========================================
# DATA
# ==========================================
# Enumerate every ordered pair (a, b) with 0 <= a, b < P and label it with the
# modular product or modular sum, depending on the MULTIPLY toggle.
residues = torch.arange(P)
pairs = torch.cartesian_prod(residues, residues)
if MULTIPLY:
    labels = (pairs[:, 0] * pairs[:, 1]) % P
else:
    labels = (pairs[:, 0] + pairs[:, 1]) % P
dataset = TensorDataset(pairs, labels)
# Random train/test split. NOTE(review): unseeded, so the split (and hence
# grokking dynamics) differs between runs.
train_size = int(TRAIN_PCT * len(dataset))
train_ds, test_ds = torch.utils.data.random_split(dataset, [train_size, len(dataset) - train_size])
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
# Materialize the held-out split as flat tensors for fast full-batch eval.
held_out_idx = test_ds.indices
test_x = dataset.tensors[0][held_out_idx].to(DEVICE)
test_y = dataset.tensors[1][held_out_idx].to(DEVICE)
# ==========================================
# CUSTOM ESTIMATORS
# ==========================================
class WBICEstimator(SamplerCallback):
    """Collects the per-draw loss from the SGLD chain and reports the mean.

    Under the n/log(n) inverse temperature used in this script, the mean
    sampled loss serves as the WBIC estimate."""

    def __init__(self):
        super().__init__()
        self.losses = []

    def __call__(self, loss, **_kwargs):
        # Invoked once per sampler draw with that draw's loss tensor.
        self.losses.append(loss.item())

    def get_results(self):
        return {"wbic": np.mean(self.losses)}
class FunctionalVarianceEstimator(SamplerCallback):
    """Records the model's outputs on a fixed probe batch at every sampler
    draw and reports the mean per-output variance across draws."""

    def __init__(self, x_fixed):
        super().__init__()
        self.x_fixed = x_fixed  # probe inputs, held constant across draws
        self.outputs = []

    def __call__(self, model, **_kwargs):
        # Snapshot the current chain state's predictions without building a graph.
        model.eval()
        with torch.no_grad():
            self.outputs.append(model(self.x_fixed).cpu().numpy())

    def get_results(self):
        stacked = np.stack(self.outputs)  # (draws, batch, classes)
        return {"func_var": np.var(stacked, axis=0).mean()}
def get_embedding_pca(model):
    """Project the model's embedding matrix onto its top (at most 6)
    principal components; returns an array of shape (p, n_components)."""
    emb = model.embed.weight.detach().cpu().numpy()
    n_components = min(6, emb.shape[1])
    return PCA(n_components=n_components).fit_transform(emb)
# ==========================================
# TRAINING LOOP
# ==========================================
# History buffers: one entry per estimation checkpoint (every ESTIMATION_INTERVAL epochs).
history = {"train_acc": [], "test_acc": [], "llc": [], "wbic": [], "func_var": [], "embeddings": [], "epochs": []}
print(f"Training {model_name}...")
for epoch in tqdm(range(EPOCHS)):
    model.train()
    correct, total_loss = 0, 0
    for x, y in train_loader:
        x, y = x.to(DEVICE), y.to(DEVICE)
        out = model(x)
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        # Train accuracy is tallied from each batch's pre-update outputs, so it
        # slightly lags the end-of-epoch model; fine as a progress signal.
        correct += (out.argmax(1) == y).sum().item()
        total_loss += loss.item()
    if epoch % ESTIMATION_INTERVAL == 0:
        model.eval()
        train_acc = correct / train_size
        # FIX: score the held-out set under no_grad so no autograd graph is
        # built for the full test batch (less memory, identical numbers).
        with torch.no_grad():
            test_acc = (model(test_x).argmax(1) == test_y).sum().item() / len(test_ds)
        # Inverse temperature for the tempered posterior: n / log(n) (WBIC convention).
        nbeta = train_size / np.log(train_size)
        llc_est = LLCEstimator(num_chains=1, num_draws=100, nbeta=nbeta, device=DEVICE, init_loss=total_loss / len(train_loader))
        wbic_est = WBICEstimator()
        fvar_est = FunctionalVarianceEstimator(test_x[:100])
        # SGLD sampling around the current parameters; the callbacks collect
        # the LLC, WBIC and functional-variance statistics from the chain.
        sample(model, train_loader, num_chains=1, num_draws=100, burnin=50,
               evaluate=lambda m, d: criterion(m(d[0].to(DEVICE)), d[1].to(DEVICE)),
               callbacks=[llc_est, wbic_est, fvar_est],
               optimizer_kwargs=dict(lr=1e-4, nbeta=nbeta), method=SGLD)
        llc = llc_est.get_results()["llc/mean"]
        wbic = wbic_est.get_results()["wbic"]
        fvar = fvar_est.get_results()["func_var"]
        history["train_acc"].append(train_acc)
        history["test_acc"].append(test_acc)
        history["llc"].append(llc)
        history["wbic"].append(wbic)
        history["func_var"].append(fvar)
        history["embeddings"].append(get_embedding_pca(model))
        history["epochs"].append(epoch)
        print(f"Epoch {epoch}: Train Acc {train_acc:.4f}, Test Acc {test_acc:.4f}, LLC: {llc:.4f}, WBIC: {wbic:.4f}, FVar: {fvar:.4f}")
# ==========================================
# INTERACTIVE VISUALIZATION
# ==========================================
# Figure layout: three PCA scatter panels (top row), three metric strips
# (lower half), numeric readouts on the right, scrubber slider at the bottom.
fig = plt.figure(figsize=(16, 12))
plt.suptitle(f"Analysis: {model_name}", fontsize=16)
# Embedding axes (Top Row)
ax_pca1 = fig.add_axes([0.05, 0.65, 0.25, 0.25])
ax_pca2 = fig.add_axes([0.37, 0.65, 0.25, 0.25])
ax_pca3 = fig.add_axes([0.70, 0.65, 0.25, 0.25])
# Metrics axes (Bottom Half)
ax_acc = fig.add_axes([0.1, 0.42, 0.65, 0.15])
ax_slt = fig.add_axes([0.1, 0.24, 0.65, 0.15])
ax_fvar = fig.add_axes([0.1, 0.08, 0.65, 0.12])
# Initialization of Scatters
# FIX: removed the unused `color_map = plt.cm.get_cmap('hsv', P)` line — the
# scatters use cmap='Greys', and plt.cm.get_cmap was deprecated in matplotlib
# 3.7 and removed in 3.9, so the dead call could crash on upgrade.
initial_pca = history["embeddings"][0]
colors = range(P)  # residue index 0..P-1, shaded via the 'Greys' cmap
scat1 = ax_pca1.scatter(initial_pca[:, 0], initial_pca[:, 1], c=colors, cmap='Greys', edgecolors='k', s=40)
ax_pca1.set_title("PCA 1 & 2")
scat2 = ax_pca2.scatter(initial_pca[:, 2], initial_pca[:, 3], c=colors, cmap='Greys', edgecolors='k', s=40)
ax_pca2.set_title("PCA 3 & 4")
scat3 = ax_pca3.scatter(initial_pca[:, 4], initial_pca[:, 5], c=colors, cmap='Greys', edgecolors='k', s=40)
ax_pca3.set_title("PCA 5 & 6")
# Static Metric Plots (Showing full history)
ax_acc.plot(history["epochs"], history["train_acc"], label="Train Acc", alpha=0.3, color='blue')
ax_acc.plot(history["epochs"], history["test_acc"], label="Test Acc", color='green', linewidth=2)
ax_acc.set_ylabel("Accuracy")
ax_acc.legend(loc='lower right')
ax_slt.plot(history["epochs"], history["llc"], color='red', label="LLC (λ)")
ax_slt.set_ylabel("LLC")
# WBIC shares the epoch axis with LLC but needs its own y-scale.
ax_wbic = ax_slt.twinx()
ax_wbic.plot(history["epochs"], history["wbic"], color='purple', linestyle='--', label="WBIC")
ax_wbic.set_ylabel("WBIC")
ax_slt.legend(loc='upper left'); ax_wbic.legend(loc='upper right')
ax_fvar.plot(history["epochs"], history["func_var"], color='orange', label="Func Var")
ax_fvar.set_ylabel("Var[f]")
ax_fvar.set_xlabel("Epochs")
# Vertical indicators for current epoch on metrics
vline1 = ax_acc.axvline(history["epochs"][0], color='black', linestyle=':', alpha=0.5)
vline2 = ax_slt.axvline(history["epochs"][0], color='black', linestyle=':', alpha=0.5)
vline3 = ax_fvar.axvline(history["epochs"][0], color='black', linestyle=':', alpha=0.5)
# Value displays (placed to the right of charts)
txt_acc = fig.text(0.85, 0.495, '', verticalalignment='center')
txt_slt = fig.text(0.85, 0.315, '', verticalalignment='center')
txt_fvar = fig.text(0.85, 0.14, '', verticalalignment='center')
# Slider Setup: one step per recorded estimation checkpoint.
ax_slider = fig.add_axes([0.2, 0.01, 0.6, 0.02])
slider = Slider(ax_slider, f"Epoch/{ESTIMATION_INTERVAL}", 0, len(history["epochs"]) - 1, valinit=len(history["epochs"])-1, valstep=1)
def update(_val):
    """Refresh every panel to show the checkpoint selected by the slider."""
    idx = int(slider.val)
    epoch = history["epochs"][idx]
    pca_data = history["embeddings"][idx]
    # Redraw the three PCA scatter panels: component pairs (0,1), (2,3), (4,5).
    panels = ((scat1, ax_pca1, 0), (scat2, ax_pca2, 2), (scat3, ax_pca3, 4))
    for scat, ax, c in panels:
        scat.set_offsets(pca_data[:, c:c + 2])
        ax.set_xlim(pca_data[:, c].min() - 0.1, pca_data[:, c].max() + 0.1)
        ax.set_ylim(pca_data[:, c + 1].min() - 0.1, pca_data[:, c + 1].max() + 0.1)
    # Move the current-epoch markers on the metric plots.
    for vline in (vline1, vline2, vline3):
        vline.set_xdata([epoch])
    # Refresh the numeric readouts.
    txt_acc.set_text(f"Train Acc: {history['train_acc'][idx]:.4f}\nTest Acc: {history['test_acc'][idx]:.4f}")
    txt_slt.set_text(f"LLC: {history['llc'][idx]:.4f}\nWBIC: {history['wbic'][idx]:.4f}")
    txt_fvar.set_text(f"Func Var: {history['func_var'][idx]:.4f}")
    fig.canvas.draw_idle()
slider.on_changed(update)
update(len(history["epochs"]) - 1)
# Save Animation: drive the slider frame by frame and capture the figure as a GIF.
print("Saving animation...")
os.makedirs("generated", exist_ok=True)
# Pick an FPS so the whole sweep lasts roughly 10 seconds.
num_frames = len(history["epochs"])
duration_seconds = 10
calculated_fps = max(1, num_frames // duration_seconds)

def animate(i):
    # Reuse the slider callback so the GIF matches the interactive view exactly.
    slider.set_val(i)
    return []

ani = FuncAnimation(fig, animate, frames=num_frames, interval=1000 / calculated_fps)
ani.save("generated/modulo_animation.gif", writer=PillowWriter(fps=calculated_fps))
print(f"Animation saved to generated/modulo_animation.gif at {calculated_fps} FPS")
ani.pause()
# Leave the interactive figure showing the final checkpoint.
slider.set_val(len(history["epochs"]) - 1)
# Show plots
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment