@riga
Created October 30, 2023 09:15
Test partial gradient stopping in PyTorch
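In short, the pre-network's forward pass is wrapped in torch.no_grad() whenever its gradients should be blocked, and in a no-op context otherwise, so the rest of the model keeps training while the pre-network stays frozen. A minimal sketch of that pattern (illustrative only; the helper name is made up here, and the full file below defines its own empty_context helper instead of contextlib.nullcontext):

    from contextlib import nullcontext
    import torch

    def run_pre_nn(pre_nn, x, stop_pre_gradients: bool):
        # no_grad() keeps the pre-NN out of the autograd graph, so backward()
        # never reaches its parameters; nullcontext() leaves gradients enabled
        with torch.no_grad() if stop_pre_gradients else nullcontext():
            return pre_nn(x)
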
# coding: utf-8
"""
Setup via
> pip install torch torchvision
"""
from __future__ import annotations
from contextlib import contextmanager
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as dsets

batch_size = 512

train_loader = torch.utils.data.DataLoader(
    dataset=dsets.MNIST(
        root="./data",
        train=True,
        transform=transforms.ToTensor(),
        download=True,
    ),
    batch_size=batch_size,
    shuffle=True,
)

valid_loader = torch.utils.data.DataLoader(
    dataset=dsets.MNIST(
        root="./data",
        train=False,
        transform=transforms.ToTensor(),
    ),
    batch_size=batch_size,
    shuffle=False,
)

@contextmanager
def empty_context():
    # no-op context manager, used when gradients should flow normally
    yield

class NN(nn.Module):

    def __init__(self, *, n_in: int | None, latent_space: list[int], n_out: int | None):
        super().__init__()

        self.n_layers = len(latent_space)

        # linear layers
        for i, n_units in enumerate(latent_space):
            # input width: n_in for the first layer (if given), otherwise the previous layer's width
            n_prev = latent_space[i - 1] if i > 0 else (n_in if n_in is not None else n_units)
            linear = nn.Linear(n_prev, n_units)
            setattr(self, f"linear_{i}", linear)

        # activations
        for i in range(self.n_layers):
            setattr(self, f"activation_{i}", nn.Tanh())

        # optional output layer
        self.last_layer = None
        if n_out is not None:
            self.last_layer = nn.Linear(latent_space[-1], n_out)

    def forward(self, x):
        out = x
        for i in range(self.n_layers):
            linear = getattr(self, f"linear_{i}")
            act = getattr(self, f"activation_{i}")
            out = act(linear(out))

        # optional last layer
        if self.last_layer is not None:
            out = self.last_layer(out)

        return out

class CombinedNN(NN):

    def __init__(self, *, pre_latent_space: list[int], latent_space: list[int]):
        super().__init__(n_in=pre_latent_space[-1], latent_space=latent_space, n_out=10)

        # preprocessing NN
        self.pre_nn = NN(n_in=28 * 28, latent_space=pre_latent_space, n_out=None)

    def forward(self, x, stop_pre_gradients: bool = False):
        # evaluate the pre NN, with or without gradients
        context = torch.no_grad if stop_pre_gradients else empty_context
        with context():
            out = self.pre_nn(x)

        # normal forward pass of this nn
        return super().forward(out)

# define model, loss function and optimizer
model = CombinedNN(pre_latent_space=[32], latent_space=[10])
model_loss = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# training loop
step = 0
for epoch in range(100):
    for images, labels in train_loader:
        optimizer.zero_grad()

        # forward pass, enable the pre-NN training only after a certain point!
        outputs = model(
            images.view(-1, 28 * 28).requires_grad_(),
            stop_pre_gradients=step < 1000,
        )

        # loss, back-prop and update step
        loss = model_loss(outputs, labels)
        loss.backward()
        optimizer.step()

        # validation
        if step % 200 == 0:
            correct = 0
            total = 0
            with torch.no_grad():
                for val_images, val_labels in valid_loader:
                    predicted = torch.max(model(val_images.view(-1, 28 * 28)), 1)[1]
                    correct += (predicted == val_labels).sum().item()
                    total += val_labels.size(0)
            accuracy = 100 * correct / total
            print(f"step {step}, training loss: {loss.item()}, valid accuracy: {accuracy:.2f}%")

        step += 1
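One way to sanity-check the stopping behaviour (a hypothetical addition, not part of the gist) is to assert right after loss.backward() that the pre-NN parameters never receive gradients during the warm-up phase:

    # .grad stays None for parameters that autograd never reaches
    if step < 1000:
        assert all(p.grad is None for p in model.pre_nn.parameters())
    else:
        assert all(p.grad is not None for p in model.pre_nn.parameters())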