Skip to content

Instantly share code, notes, and snippets.

@jeffacce
Last active July 25, 2024 18:01
Show Gist options
  • Save jeffacce/fc7b5948ef6e5c9f269f69abb350f9e4 to your computer and use it in GitHub Desktop.
Save jeffacce/fc7b5948ef6e5c9f269f69abb350f9e4 to your computer and use it in GitHub Desktop.
Diffusion in 100 lines
# Classifier-free guidance diffusion on a toy spiral dataset
# Trains and infers a diffusion model on CPU, and generates a diffusion video
import io
import torch
import torchvision
import numpy as np
from torch import nn
from tqdm import tqdm
import matplotlib.pyplot as plt
# === Hyperparams ===
T: int = 300  # number of diffusion steps
num_epochs: int = 200
bs: int = 512  # batch size
lr: float = 1e-3
conditional_pdrop: float = 0.1  # probability of dropping the class condition while training
num_samples: int = 3000  # points per class in the toy dataset
num_classes: int = 3
gamma: float = 0.5  # classifier-free guidance weight

# === DDPM schedule ===
# Linear beta schedule over 1e-4 .. 0.02; the 1000 / T factor appears to
# rescale it from a 1000-step schedule to T steps. It has T + 1 entries so it
# can be indexed directly by an integer timestep t in [0, T].
beta = torch.linspace(1e-4, 0.02, T + 1) * 1000 / T
alpha = 1.0 - beta
# Cumulative product: alpha_bar[t] = alpha[0] * alpha[1] * ... * alpha[t].
alpha_bar = alpha.cumprod(dim=0)
def mpl2np(fig):
    """Rasterize a Matplotlib figure into a NumPy image array.

    The figure is rendered with ``savefig(format='raw')``, which writes the
    raw pixel buffer, and the bytes are reshaped to
    ``(height, width, channels)`` using the canvas size.

    Args:
        fig: the Matplotlib figure to render.

    Returns:
        A read-only ``uint8`` array of shape ``(height, width, channels)``;
        channels is normally 4 (RGBA) for the 'raw' format.
    """
    with io.BytesIO() as buff:
        # Force the figure's own dpi: the reshape below uses the canvas size,
        # which only matches the saved buffer when savefig renders at the
        # figure dpi (rcParams['savefig.dpi'] could otherwise override it).
        fig.savefig(buff, format='raw', dpi='figure')
        # Grab the bytes while the buffer is still open.
        raw = buff.getvalue()
    width, height = fig.canvas.get_width_height()
    return np.frombuffer(raw, dtype=np.uint8).reshape(height, width, -1)
def get_spiral_dataset(num_samples, num_classes):
    """Build a noisy multi-arm spiral dataset, one arm per class.

    Each class contributes ``num_samples`` points. A point's radius grows
    linearly from 0.2 to 1.0 along its arm; its angle depends on its position
    along the arm and on its class index, plus Gaussian angular noise with
    standard deviation 0.2.

    Args:
        num_samples: number of points per class.
        num_classes: number of classes (spiral arms).

    Returns:
        A ``TensorDataset`` of ``(xy, label)`` pairs; ``xy`` has shape
        ``(num_classes * num_samples, 2)`` and labels are a long tensor.
    """
    # Position along the arm in [0, 1], tiled once per class.
    progress = torch.linspace(0, 1, num_samples).repeat(num_classes)
    # Class labels laid out as [0]*num_samples + [1]*num_samples + ...
    labels = torch.arange(num_classes).repeat_interleave(num_samples)
    radius = 0.8 * progress + 0.2
    jitter = 0.2 * torch.randn(num_classes * num_samples)
    angle = (2 * progress + labels) * 2 * torch.pi / num_classes + jitter
    points = torch.stack([radius * angle.sin(), radius * angle.cos()], dim=1)
    return torch.utils.data.TensorDataset(points, labels.long())
def forward_process(x_0, t):
    """Sample a noised version of ``x_0`` at timestep ``t`` in closed form.

    Uses the module-level ``alpha_bar`` schedule:
    ``x_t = sqrt(alpha_bar[t]) * x_0 + sqrt(1 - alpha_bar[t]) * eps`` with
    ``eps`` standard normal noise of the same shape as ``x_0``.

    Args:
        x_0: clean data tensor.
        t: integer timestep, or a tensor of timestep indices (e.g. shape
            ``(batch, 1)``) that broadcasts against ``x_0``.

    Returns:
        ``(eps, x_t)``: the sampled noise and the noised data.
    """
    noise = torch.randn_like(x_0)
    ab_t = alpha_bar[t]
    signal_scale = ab_t.sqrt()
    noise_scale = (1 - ab_t).sqrt()
    noised = signal_scale * x_0 + noise_scale * noise
    return noise, noised
def langevin_cfg_once(model, x_t, t, y):
    """Take one reverse (denoising) DDPM step with classifier-free guidance.

    The guided noise estimate extrapolates away from the unconditional
    prediction: ``eps = (1 + gamma) * eps_cond - gamma * eps_uncond``. The
    unconditional prediction uses an all-zero condition vector, which is the
    "dropped" condition the model sees during training.

    Args:
        model: noise-prediction network fed ``cat([x, t / T, condition], -1)``.
        x_t: current noisy samples, one row per sample.
        t: current integer timestep (the sampler walks ``T - 1`` down to 1).
        y: condition vectors (one-hot class labels), one row per sample.

    Returns:
        The denoised samples for the next (smaller) timestep.
    """
    n = x_t.shape[0]  # derive the batch size from the input, not a global
    # No fresh noise is injected on the last step (t == 1).
    z = torch.randn_like(x_t) if t > 1 else torch.zeros_like(x_t)
    ts = torch.full((n, 1), t / T, dtype=x_t.dtype, device=x_t.device)
    eps_hat_cond = model(torch.cat([x_t, ts, y], dim=-1))
    eps_hat_uncond = model(torch.cat([x_t, ts, torch.zeros_like(y)], dim=-1))
    eps_hat = (1 + gamma) * eps_hat_cond - gamma * eps_hat_uncond
    # DDPM posterior mean, plus noise with variance beta_t.
    return 1 / torch.sqrt(alpha[t]) * (x_t - (1 - alpha[t]) / (torch.sqrt(1 - alpha_bar[t])) * eps_hat) + torch.sqrt(beta[t]) * z
# === Data, model, optimizer ===
ds = get_spiral_dataset(num_samples, num_classes)
dl = torch.utils.data.DataLoader(dataset=ds, batch_size=bs, shuffle=True)
# Noise-prediction MLP. Input: noisy 2-D point + normalized timestep (1) +
# one-hot condition (num_classes). Output: 2 values, the predicted noise.
model = torchvision.ops.MLP(
    in_channels=2 + 1 + num_classes,
    hidden_channels=[512, 512, 512, 2],
    activation_layer=nn.LeakyReLU,
)
optim = torch.optim.Adam(params=model.parameters(), lr=lr)
# === Training loop ===
# Trains the model to predict the noise added by forward_process, conditioned
# on the (possibly dropped) one-hot class label.
it = tqdm(range(num_epochs))
for _ in it:
    losses = []
    for x_0, y in dl:
        # One-hot condition as floats so it concatenates with x_t and t
        # without relying on implicit int64 -> float promotion.
        y = torch.nn.functional.one_hot(y, num_classes=num_classes).float()
        # Classifier-free guidance: independently replace each sample's
        # condition with the all-zero "null" condition with probability
        # conditional_pdrop, so one network also learns the unconditional model.
        keep = (torch.rand(y.shape[0], 1) >= conditional_pdrop).float()
        y = y * keep
        # Timesteps in [1, T - 1]: the same range the sampler walks.
        t = torch.randint(1, T, size=(x_0.shape[0], 1))
        eps, x_t = forward_process(x_0, t)
        eps_hat = model(torch.cat([x_t, t / T, y], dim=-1))
        loss = torch.nn.functional.mse_loss(eps_hat, eps)
        optim.zero_grad()
        loss.backward()
        optim.step()
        losses.append(loss.item())
    it.set_postfix_str(f"{np.mean(losses):.3f}")
# === Inference ===
# NOTE: keep the module-level names `bs` and `x`; the sampler step may read
# them as globals.
bs = 1024
target_class = 0  # class to condition generation on
# One-hot condition per sample, sized from num_classes rather than hard-coded.
y = torch.zeros(bs, num_classes)
y[:, target_class] = 1.0
with torch.no_grad():
    x = torch.randn(bs, 2)  # start from pure Gaussian noise
    xs = [x.clone()]  # trajectory snapshots, one per step, used for the video
    for t in range(T - 1, 0, -1):
        x = langevin_cfg_once(model, x, t, y)
        xs.append(x.clone())
# === Render each trajectory snapshot as a video frame ===
frames = []
for snapshot in tqdm(xs):
    fig = plt.figure()
    ax = fig.gca()
    ax.set_aspect('equal')
    ax.scatter(*snapshot.T, alpha=0.5, s=4)
    ax.set_xlim([-3, 3])
    ax.set_ylim([-3, 3])
    frames.append(mpl2np(fig))
    plt.close(fig)  # free the figure; one is created per frame
arr = np.stack(frames)[..., :3]  # drop alpha channel
torchvision.io.write_video('diffusion.mp4', arr, fps=30)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment