Skip to content

Instantly share code, notes, and snippets.

@izmailovpavel
Created June 6, 2023 04:40
Show Gist options
  • Save izmailovpavel/e7f94f71af2b9949b5f02b7a49b0805c to your computer and use it in GitHub Desktop.
Save izmailovpavel/e7f94f71af2b9949b5f02b7a49b0805c to your computer and use it in GitHub Desktop.
mode_connectivity_example
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"env: CUDA_VISIBLE_DEVICES=0\n"
]
}
],
"source": [
"# Expose only GPU 0 to this kernel (IPython line magic).\n",
"%env CUDA_VISIBLE_DEVICES=0"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import torch\n",
"import torchvision\n",
"import torch.nn.functional as F\n",
"import tqdm\n",
"\n",
"from matplotlib import pyplot as plt\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"class LeNet(torch.nn.Module):\n",
"    \"\"\"Small LeNet-style CNN mapping 1-channel 28x28 images to 10 logits.\n",
"\n",
"    The registration order of the layers (conv1, conv2, fc1, fc2) fixes the layout of\n",
"    ``parameters_to_vector(model.parameters())``, which the mode-connectivity code\n",
"    relies on, so the layers must not be reordered or renamed.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        self.conv1 = torch.nn.Conv2d(1, 32, 5, padding=2)  # 28x28 -> 28x28, pooled to 14x14\n",
"        self.conv2 = torch.nn.Conv2d(32, 64, 5)            # 14x14 -> 10x10, pooled to 5x5\n",
"        self.fc1 = torch.nn.Linear(64*5*5, 100)\n",
"        self.fc2 = torch.nn.Linear(100, 10)\n",
"        self.num_flat_features = 64*5*5\n",
"\n",
"    def forward(self, x):\n",
"        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))\n",
"        x = F.max_pool2d(F.relu(self.conv2(x)), (2, 2))\n",
"        # Flatten per sample, keeping the batch dimension: a wrongly sized input then\n",
"        # fails loudly in fc1 instead of being reshaped into a different batch size.\n",
"        x = torch.flatten(x, start_dim=1)\n",
"        x = F.relu(self.fc1(x))\n",
"        x = self.fc2(x)\n",
"        return x"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Both MNIST splits share the same local root and transform. download is left at its\n",
"# default, so the dataset files are expected to already exist under that root.\n",
"transform = torchvision.transforms.ToTensor()\n",
"mnist_kwargs = dict(root=\"/datasets/\", transform=transform)\n",
"\n",
"trainset = torchvision.datasets.MNIST(train=True, **mnist_kwargs)\n",
"testset = torchvision.datasets.MNIST(train=False, **mnist_kwargs)\n",
"\n",
"# Shuffle only the training stream; the test order stays fixed.\n",
"trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True)\n",
"testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# Fresh, randomly initialised network placed on the GPU.\n",
"model = LeNet().cuda()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# Training objective: cross-entropy averaged over the batch (default reduction).\n",
"criterion = torch.nn.CrossEntropyLoss()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# SGD with momentum; the cosine schedule anneals the learning rate from 0.01 towards\n",
"# its default minimum of 0 over n_epochs calls to scheduler.step().\n",
"n_epochs = 10\n",
"optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)\n",
"scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"def evaluate(model, loader):\n",
"    \"\"\"Return ``(accuracy, mean cross-entropy loss)`` of ``model`` over ``loader``.\n",
"\n",
"    ``loader`` may be any iterable of ``(x, y)`` batches (a DataLoader or a plain\n",
"    list of batches). Batches are moved to the device of the model's parameters,\n",
"    so the function works for both CPU and GPU models.\n",
"\n",
"    Raises:\n",
"        ValueError: if ``loader`` yields no samples.\n",
"    \"\"\"\n",
"    device = next(model.parameters()).device\n",
"    # Sum per-sample losses so dividing by ``total`` gives an exact per-sample mean\n",
"    # even when the last batch is smaller than the others.\n",
"    criterion = torch.nn.CrossEntropyLoss(reduction=\"sum\")\n",
"    loss = 0.\n",
"    correct, total = 0, 0\n",
"    with torch.no_grad():\n",
"        for (x, y) in loader:\n",
"            x, y = x.to(device), y.to(device)\n",
"            logits = model(x)\n",
"            preds = torch.argmax(logits, dim=1)\n",
"            correct += (preds == y).sum().item()\n",
"            loss += criterion(logits, y).item()\n",
"            total += len(x)\n",
"    if total == 0:\n",
"        raise ValueError(\"evaluate() received an empty loader\")\n",
"    return correct / total, loss / total"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(0.1135, 11805508.464)"
]
},
"execution_count": 47,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# (accuracy, mean loss) of the current model on the test set.\n",
"evaluate(model, testloader)"
]
},
{
"cell_type": "code",
"execution_count": 117,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1875/1875 [00:07<00:00, 246.29it/s]\n",
"100%|██████████| 313/313 [00:00<00:00, 347.23it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test accuracy: 0.9844\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1875/1875 [00:09<00:00, 202.46it/s]\n",
"100%|██████████| 313/313 [00:00<00:00, 343.35it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test accuracy: 0.9867\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1875/1875 [00:07<00:00, 256.00it/s]\n",
"100%|██████████| 313/313 [00:00<00:00, 346.95it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test accuracy: 0.9897\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1875/1875 [00:12<00:00, 149.80it/s]\n",
"100%|██████████| 313/313 [00:00<00:00, 342.27it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test accuracy: 0.9917\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1875/1875 [00:07<00:00, 256.36it/s]\n",
"100%|██████████| 313/313 [00:00<00:00, 344.25it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test accuracy: 0.9928\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1875/1875 [00:07<00:00, 255.57it/s]\n",
"100%|██████████| 313/313 [00:00<00:00, 347.69it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test accuracy: 0.9927\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1875/1875 [00:08<00:00, 217.34it/s]\n",
"100%|██████████| 313/313 [00:00<00:00, 343.95it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test accuracy: 0.9931\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1875/1875 [00:07<00:00, 256.49it/s]\n",
"100%|██████████| 313/313 [00:00<00:00, 347.81it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test accuracy: 0.9936\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1875/1875 [00:07<00:00, 256.48it/s]\n",
"100%|██████████| 313/313 [00:00<00:00, 347.29it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test accuracy: 0.9933\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1875/1875 [00:07<00:00, 256.88it/s]\n",
"100%|██████████| 313/313 [00:00<00:00, 347.28it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test accuracy: 0.9933\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"# Standard supervised training: one cosine LR step per epoch, test accuracy after each.\n",
"for epoch in range(n_epochs):\n",
"    total_loss = 0.\n",
"    pbar = tqdm.tqdm(trainloader)\n",
"    for n_batches, (x, y) in enumerate(pbar, start=1):\n",
"        x, y = x.cuda(), y.cuda()\n",
"        optimizer.zero_grad()\n",
"        logits = model(x)\n",
"        loss = criterion(logits, y)\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"        total_loss += loss.item()\n",
"        # Update the label while the bar is still live (it is closed once iteration\n",
"        # ends); refresh=False lets tqdm redraw at its own rate, not every batch.\n",
"        pbar.set_description(f\"Epoch {epoch}: Loss {total_loss / n_batches:.3f}\", refresh=False)\n",
"    scheduler.step()\n",
"    acc, _ = evaluate(model, testloader)\n",
"    print(f\"Test accuracy: {acc}\")"
]
},
{
"cell_type": "code",
"execution_count": 118,
"metadata": {},
"outputs": [],
"source": [
"# Uncomment one of these after a training run to save its weights. The mode-connectivity\n",
"# section loads both files (presumably two independently trained runs -- confirm).\n",
"# torch.save(model.state_dict(), \"mnist_model1.pt\")\n",
"# torch.save(model.state_dict(), \"mnist_model2.pt\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Mode Connectivity"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# Load the two trained endpoints and flatten each into a single weight vector.\n",
"# The vectors are detached: they are constants of the curve and should not carry\n",
"# autograd history back into the model's parameters.\n",
"model = LeNet().cuda()\n",
"model.load_state_dict(torch.load(\"mnist_model1.pt\"))\n",
"weights_1 = torch.nn.utils.parameters_to_vector(model.parameters()).detach()\n",
"\n",
"model.load_state_dict(torch.load(\"mnist_model2.pt\"))\n",
"weights_2 = torch.nn.utils.parameters_to_vector(model.parameters()).detach()"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"def get_bezier_curve(weights_1, weights_2):\n",
"    \"\"\"Build a quadratic Bezier curve between two flattened weight vectors.\n",
"\n",
"    Returns ``(curve, theta)``:\n",
"      * ``curve(theta, t)`` gives the weights at ``t`` (``weights_1`` at t=0,\n",
"        ``weights_2`` at t=1, ``theta`` as the control point);\n",
"      * ``theta`` is initialised at the midpoint of the endpoints, which makes the\n",
"        initial curve the straight line between them.\n",
"    \"\"\"\n",
"    def curve(theta, t):\n",
"        return (1 - t)**2 * weights_1 + 2 * t * (1 - t) * theta + weights_2 * t**2\n",
"    # Fresh tensor without autograd history, on the endpoints' own device and dtype\n",
"    # (no GPU -> CPU -> NumPy round trip).\n",
"    theta = ((weights_1 + weights_2) / 2).detach().clone()\n",
"    return curve, theta\n",
"\n",
"curve, theta = get_bezier_curve(weights_1, weights_2)\n",
"theta = theta.cuda()"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 313/313 [00:01<00:00, 299.77it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.9933\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 313/313 [00:00<00:00, 344.51it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.9933\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 313/313 [00:00<00:00, 346.92it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.9553\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"# Sanity check: t=0 and t=1 should reproduce the two endpoints; t=0.5 is the\n",
"# midpoint of the curve.\n",
"for t_eval in (0, 1, 0.5):\n",
"    w = curve(theta, t_eval)\n",
"    torch.nn.utils.vector_to_parameters(w, model.parameters())\n",
"    print(evaluate(model, testloader))"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"# Only the control point theta is optimised; the endpoint weights stay fixed.\n",
"n_epochs = 5\n",
"optimizer = torch.optim.Adam([theta], lr=1.e-4)\n",
"scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([213206])"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Length of the flattened parameter vector (the same for theta and each endpoint).\n",
"# Inspects theta, which exists at this point, so the cell also works on a fresh\n",
"# top-to-bottom run.\n",
"theta.shape"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1875/1875 [00:12<00:00, 151.85it/s]\n",
"100%|██████████| 313/313 [00:00<00:00, 343.60it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test accuracy: 0.9891\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1875/1875 [00:08<00:00, 232.31it/s]\n",
"100%|██████████| 313/313 [00:00<00:00, 349.25it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test accuracy: 0.9899\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1875/1875 [00:14<00:00, 133.84it/s]\n",
"100%|██████████| 313/313 [00:00<00:00, 344.55it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test accuracy: 0.9879\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1875/1875 [00:13<00:00, 135.04it/s]\n",
"100%|██████████| 313/313 [00:00<00:00, 342.73it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test accuracy: 0.989\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1875/1875 [00:13<00:00, 138.75it/s]\n",
"100%|██████████| 313/313 [00:00<00:00, 319.25it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test accuracy: 0.9878\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"# Train the control point theta: at each step sample t ~ U(0, 1), load the curve's\n",
"# weights at t into ``model``, and map the weight gradient back onto theta.\n",
"for epoch in range(n_epochs):\n",
"    total_loss = 0.\n",
"    pbar = tqdm.tqdm(trainloader)\n",
"    for n_batches, (x, y) in enumerate(pbar, start=1):\n",
"        x, y = x.cuda(), y.cuda()\n",
"        optimizer.zero_grad()\n",
"        # The optimizer only owns theta, so it never clears the model's .grad buffers.\n",
"        # Without this, backward() would accumulate gradients from every previous\n",
"        # step (each taken at a different t) into the gradient used below.\n",
"        model.zero_grad()\n",
"\n",
"        t = torch.rand(1).cuda()\n",
"        curve_t_fn = lambda theta: curve(theta, t)\n",
"        w = curve_t_fn(theta)\n",
"\n",
"        torch.nn.utils.vector_to_parameters(w, model.parameters())\n",
"        logits = model(x)\n",
"        loss = criterion(logits, y)\n",
"        loss.backward()\n",
"\n",
"        # chain rule: dL/dtheta = (dw/dtheta)^T dL/dw, computed as a vector-Jacobian product\n",
"        grads = torch.nn.utils.parameters_to_vector([p.grad for p in model.parameters()])\n",
"        theta.grad = torch.autograd.functional.vjp(curve_t_fn, theta, grads)[1]\n",
"\n",
"        optimizer.step()\n",
"        total_loss += loss.item()\n",
"        pbar.set_description(f\"Epoch {epoch}: Loss {total_loss / n_batches:.3f}\", refresh=False)\n",
"    scheduler.step()\n",
"    # Evaluate at the curve midpoint, not at whatever t the last batch happened to draw.\n",
"    with torch.no_grad():\n",
"        torch.nn.utils.vector_to_parameters(curve(theta, 0.5), model.parameters())\n",
"    acc, _ = evaluate(model, testloader)\n",
"    print(f\"Test accuracy at t=0.5: {acc}\")"
]
},
{
"cell_type": "code",
"execution_count": 73,
"metadata": {},
"outputs": [],
"source": [
"def loss_surface(weights_2, weights_1, theta, N, model, loader,\n",
"                 u_range=(-.5, 1.5), v_range=(-.2, 1.2)):\n",
"    \"\"\"Evaluate ``model`` on an N x N grid in the plane through w1, w2 and theta.\n",
"\n",
"    The plane has ``weights_1`` at its origin and is spanned by\n",
"    ``u = weights_2 - weights_1`` and ``v``, the component of ``theta - weights_1``\n",
"    orthogonal to ``u``. In these (u, v) coordinates ``weights_1`` is (0, 0),\n",
"    ``weights_2`` is (1, 0) and ``theta`` lies on v = 1.\n",
"\n",
"    Note the argument order: (weights_2, weights_1, ...).\n",
"\n",
"    Args:\n",
"        u_range, v_range: (start, end) of the grid along u and v.\n",
"\n",
"    Returns:\n",
"        (grid_us, grid_vs, accs, losses), each an (N, N) tensor indexed [i, j] with\n",
"        i along u and j along v.\n",
"\n",
"    Side effect: ``model`` is left holding the weights of the last grid point.\n",
"\n",
"    Raises:\n",
"        ValueError: if weights_1 and weights_2 are identical (u has zero length).\n",
"    \"\"\"\n",
"    with torch.no_grad():\n",
"        u = weights_2 - weights_1\n",
"        uu = u @ u\n",
"        if uu == 0:\n",
"            raise ValueError(\"weights_1 and weights_2 are identical\")\n",
"        v = theta - weights_1\n",
"        v = v - u * (u @ v) / uu  # Gram-Schmidt: remove the u component\n",
"\n",
"        us = torch.linspace(*u_range, N)\n",
"        vs = torch.linspace(*v_range, N)\n",
"        grid_us, grid_vs = torch.meshgrid(us, vs, indexing='ij')\n",
"        accs = torch.zeros_like(grid_us)\n",
"        losses = torch.zeros_like(grid_us)\n",
"        for i in tqdm.tqdm(range(N)):\n",
"            for j in range(N):\n",
"                a, b = grid_us[i, j], grid_vs[i, j]\n",
"                w = a * u + b * v + weights_1\n",
"                torch.nn.utils.vector_to_parameters(w, model.parameters())\n",
"                accs[i, j], losses[i, j] = evaluate(model, loader)\n",
"\n",
"    return grid_us, grid_vs, accs, losses"
]
},
{
"cell_type": "code",
"execution_count": 74,
"metadata": {},
"outputs": [],
"source": [
"from itertools import islice\n",
"import cmocean"
]
},
{
"cell_type": "code",
"execution_count": 93,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 30/30 [00:23<00:00, 1.27it/s]\n"
]
}
],
"source": [
"# Coarse 30x30 surface on only the first 40 test batches (40 * 32 = 1280 images) to keep\n",
"# the N*N evaluations fast; the batches are materialised once and reused at every point.\n",
"grid_us, grid_vs, accs, losses = loss_surface(\n",
" weights_2, weights_1, theta, N=30, model=model, loader=list(islice(testloader, 40)))"
]
},
{
"cell_type": "code",
"execution_count": 94,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7f5d243948b0>]"
]
},
"execution_count": 94,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"levels = [0.01, 0.25, 0.5, 0.75, 0.85, 0.90, 0.95, 0.98, 0.99, 1.,]\n",
"\n",
"# One colour per contour level, sampled from indices 1..250 of the cmocean 'solar' map.\n",
"xs = np.linspace(1, 250, len(levels))\n",
"cmap_colors = [cmocean.cm.solar(int(x)) for x in xs]\n",
"\n",
"plt.contour(grid_us, grid_vs, accs, zorder=1, levels=levels, colors=cmap_colors)\n",
"plt.contourf(grid_us, grid_vs, accs, zorder=0, alpha=0.8, levels=levels, colors=cmap_colors)\n",
"plt.colorbar()\n",
"\n",
"# Plane coordinates as used by loss_surface: w1 at the origin, w2 at (1, 0) and theta\n",
"# at v = 1. theta's u coordinate is its projection onto w2 - w1; training moves theta,\n",
"# so compute that coordinate instead of assuming the initial 0.5.\n",
"with torch.no_grad():\n",
"    u_dir = weights_2 - weights_1\n",
"    theta_u = ((theta - weights_1) @ u_dir / (u_dir @ u_dir)).item()\n",
"w1 = np.array([0., 0.])\n",
"w2 = np.array([1., 0.])\n",
"th = np.array([theta_u, 1.])\n",
"# Same Bezier parameterisation as curve(): w1 at t=0, w2 at t=1, th as control point.\n",
"curve_xy = np.stack([(1 - t)**2 * w1 + 2 * t * (1 - t) * th + t**2 * w2 for t in np.linspace(0, 1, 20)])\n",
"\n",
"plt.plot([w1[0], w2[0], th[0]], [w1[1], w2[1], th[1]], 'o', ms=5, mec='black', mfc='black', mfcalt='black')\n",
"plt.plot([w1[0]], [w1[1] + 0.12], marker='$w_1$', mec='black', mfc='black', mfcalt='black', ms=20)\n",
"plt.plot([w2[0]], [w2[1] + 0.12], marker='$w_2$', mec='black', mfc='black', mfcalt='black', ms=20)\n",
"plt.plot([th[0]], [th[1] + 0.12], marker=r'$\\theta$', mec='black', mfc='black', mfcalt='black', ms=12)\n",
"\n",
"plt.plot(curve_xy[:, 0], curve_xy[:, 1], \"--k\")"
]
},
{
"cell_type": "code",
"execution_count": 88,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(20, 2)"
]
},
"execution_count": 88,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# One row per sampled t (20 of them), each a 2-D (u, v) plane coordinate.\n",
"curve_xy.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "py38",
"language": "python",
"name": "py38"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment