@proger
Created May 26, 2024 16:27
#%%
import math
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
plt.rcParams['axes.spines.left'] = False
plt.rcParams['axes.spines.right'] = False
plt.rcParams['axes.spines.top'] = False
plt.rcParams['axes.spines.bottom'] = False
torch.manual_seed(3407)
dataset = []
labels = []
plt.figure(figsize=(12, 3))
for label, mean in enumerate([-0.75, -0.25, 0.25, 0.75]):
    mu = mean + torch.randn(1000) * 0.05
    plt.hist(mu, bins=100, alpha=0.5)
    dataset.append(mu)
    labels.append(label)
dataset = torch.cat(dataset, dim=0).unsqueeze(-1)
labels = torch.tensor(labels)
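# Optional shape check: four modes of 1000 points each, unsqueezed to a single
# channel, should give a (4000, 1) tensor.
assert dataset.shape == (4000, 1)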
class Siren(nn.Module):
    # sinusoidal MLP (SIREN-style): maps a (noisy value, time) pair to a per-channel prediction
    def __init__(self, channels=1, dim=512, bandwidth=20):
        super().__init__()
        self.channels = channels
        self.input = nn.Linear(channels + 1, dim, bias=False)
        self.hidden = nn.Linear(dim, dim, bias=False)
        self.output = nn.Linear(dim, channels, bias=False)
        self.bandwidth = bandwidth
        with torch.no_grad():
            self.input.weight.uniform_(-1 / 2, 1 / 2)
            l = (6 / dim) ** 0.5 / bandwidth
            self.hidden.weight.uniform_(-l, l)
            self.output.weight.uniform_(-l, l)

    def forward(self, mu, t):
        x = self.input(torch.cat([mu, t], dim=-1))
        x = (self.bandwidth * x).sin()
        x = self.hidden(x)
        x = (self.bandwidth * x).sin()
        x = self.output(x)
        #x.register_hook(lambda grad: print('output grad', grad.norm()))
        return x
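# Optional shape check (small dim chosen only to keep it cheap): a batch of 8
# scalar values with 8 timestamps should map to an (8, 1) prediction.
_siren = Siren(channels=1, dim=64)
assert _siren(torch.zeros(8, 1), torch.rand(8, 1)).shape == (8, 1)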
class Sampler(nn.Module):
    def __init__(self, channels=1, sigma=0.01, unroll_steps=10):
        super().__init__()
        self.state_dim = channels
        self.sigma = nn.Parameter(torch.tensor(sigma), requires_grad=False)
        self.h_min = -1
        self.h_max = 1
        self.unroll_steps = unroll_steps
        self.score = Siren(channels=channels, dim=256, bandwidth=20)
        self.step_ids = nn.Parameter(torch.arange(self.unroll_steps + 1), requires_grad=False)
        self.times = nn.Parameter(((self.step_ids / self.unroll_steps).repeat(1, 1).T).clip(1e-6, 1), requires_grad=False)  # T,1

    def step(self, state, t):
        # treat the network output as a noise estimate and invert
        # mu = gamma*x + sqrt(gamma*(1-gamma))*eps to recover an estimate of x,
        # clipped to [h_min, h_max] and pinned to 0 near t = 0
        update = self.score(state, t)
        gamma = 1 - self.sigma ** (2 * t)
        f = 1 / gamma
        i = -((1 - gamma) / gamma + 1e-6).sqrt()
        h = f * state + i * update
        return torch.where(t < 1e-6, torch.zeros_like(h), h.clip(self.h_min, self.h_max))

    def forward(self, x_NC):
        x_NTC = x_NC.unsqueeze(1)  # N,T,C
        #t = self.times.T.unsqueeze(-1).repeat(x_NTC.shape[0], 1, 1)  # pretend timesteps are fixed
        t = torch.rand(x_NC.shape[0], self.unroll_steps, 1, device=x_NC.device).clip(1e-6, 1)
        gamma_NT1 = 1 - self.sigma ** (2 * t)
        std_NT1 = (gamma_NT1 * (1 - gamma_NT1) + 1e-6).sqrt()
        # noisy mean mu ~ N(gamma * x, gamma * (1 - gamma))
        mu_NTC = gamma_NT1 * x_NTC + torch.randn_like(x_NTC) * std_NT1
        x1_NTC = self.step(mu_NTC, t)
        # time-dependent weight -log(sigma) / sigma**(2t) on the reconstruction error
        scale_NT1 = math.log(self.sigma) / self.sigma ** (2 * t)
        diff_NTC = x1_NTC - x_NTC
        norm_NT = (diff_NTC.square().sum(-1) + 1e-6).sqrt()
        mse_NT = -norm_NT * scale_NT1.squeeze(-1)
        return mse_NT

    def normal_update(self, prior_precision, state, likelihood_precision, obs):
        return (prior_precision * state + likelihood_precision * obs) / (prior_precision + likelihood_precision)

    def generate(self, batch_size, ax=None):
        device = next(self.parameters()).device
        state = torch.zeros(batch_size, self.unroll_steps + 1, self.state_dim, device=device)
        t = self.times
        ids = self.step_ids
        T = self.unroll_steps
        # per-step observation precisions and cumulative prior precisions for the conjugate normal updates
        likelihood_precision = self.sigma ** (-2 * (ids + 1) / T) * (1 - self.sigma.pow(2 / T))
        prior_precision = torch.cat([torch.ones(1, device=device), likelihood_precision.cumsum(0)], dim=0)
        out = self.step(state[:, 0], t[None, 0].repeat(batch_size, 1))
        for step in range(1, T + 1):
            # observe the current prediction through Gaussian noise, then fuse it
            # with the running mean via the conjugate normal update
            y = out + torch.randn_like(out) / likelihood_precision[step - 1].sqrt()
            state[:, step] = self.normal_update(prior_precision[step - 1], state[:, step - 1], likelihood_precision[step - 1], y)
            out = self.step(state[:, step], t[None, step].repeat(batch_size, 1))
        if ax is not None:
            for i in range(batch_size):
                ax.plot(state[i, :, 0].detach().numpy(), alpha=0.3)
        return out
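# Optional smoke test (CPU, default hyperparameters): forward() returns one
# weighted reconstruction error per example and per sampled timestep, so a
# batch of 4 scalars should yield a (4, unroll_steps) loss matrix.
_check = Sampler(channels=1)
assert _check(torch.rand(4, 1) * 2 - 1).shape == (4, _check.unroll_steps)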
torch.manual_seed(6)
torch.set_anomaly_enabled(True)
flow = Sampler(channels=dataset.size(1), sigma=0.01).to('cuda')
dataset = dataset.to('cuda')
opt = torch.optim.Adam(flow.parameters(), lr=1e-3)
train_steps = 200
trace_loss = torch.zeros(train_steps)
trace_gnorm = torch.zeros(train_steps)
for i in range(train_steps):
    opt.zero_grad()
    minibatch = torch.arange(len(dataset))  # full batch training
    x = dataset[minibatch]
    losses = flow(x)
    loss = losses.mean()
    assert not torch.isnan(loss), f'loss is nan at step {i}'
    loss.backward()
    trace_loss[i] = loss.item()
    trace_gnorm[i] = torch.nn.utils.clip_grad_norm_(flow.parameters(), 1.0)
    opt.step()
fig, (axl, axc, axr, axf) = plt.subplots(1, 4, figsize=(20, 3))
axl.plot(trace_loss)
axl.set_title('loss')
axc.plot(trace_gnorm)
axc.set_title('gradient norms')
axl.set_xlim(0, train_steps)
axc.set_xlim(0, train_steps)
flow = flow.to('cpu')
with torch.no_grad():
    gen = flow.generate(batch_size=1000, ax=axf)
axr.set_title('sampled data')
axf.set_title('sample flows')
axf.set_ylim(-1, 1)
for i in range(dataset.size(1)):
    axr.hist(gen[:, i].numpy(), bins=100, alpha=0.5);
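# Optional numeric summary to read alongside the histograms: compare the
# overall mean and spread of the generated samples with the training data.
print('data      mean/std:', dataset.mean().item(), dataset.std().item())
print('generated mean/std:', gen.mean().item(), gen.std().item())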