Last active: July 25, 2024 18:01
-
-
Save jeffacce/fc7b5948ef6e5c9f269f69abb350f9e4 to your computer and use it in GitHub Desktop.
Diffusion in 100 lines
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Classifier-free guidance diffusion on a toy spiral dataset
# Trains a diffusion model and samples from it on CPU, then renders the sampling trajectory as a video
import io | |
import torch | |
import torchvision | |
import numpy as np | |
from torch import nn | |
from tqdm import tqdm | |
import matplotlib.pyplot as plt | |
# === Hyperparams ===
T = 300  # number of diffusion timesteps
num_epochs = 200
bs = 512  # training batch size
lr = 1e-3
conditional_pdrop = 0.1  # probability of replacing the class label with the all-zero "no label"
num_samples = 3000  # points generated per spiral arm (per class)
num_classes = 3
gamma = 0.5  # classifier-free guidance weight

# === DDPM schedule ===
# Linear beta schedule spanning 1e-4 .. 0.02 (the 1000-step DDPM range), scaled by
# 1000 / T -- presumably so a T-step chain injects a comparable total amount of noise.
# Index 0 exists but the training/sampling code below only uses t >= 1.
beta = torch.linspace(start=1e-4, end=0.02, steps=T + 1) * 1000 / T
alpha = 1.0 - beta
alpha_bar = alpha.cumprod(dim=0)  # running product: alpha_bar[t] = prod(alpha[0..t])
def mpl2np(fig):
    """Render a Matplotlib figure into an (H, W, C) uint8 NumPy array.

    The figure is saved in Matplotlib's ``'raw'`` format (RGBA bytes, so C == 4)
    and reshaped using the canvas's width/height.

    Args:
        fig: the Matplotlib figure to rasterize.

    Returns:
        A read-only uint8 array of shape (height, width, channels).
    """
    with io.BytesIO() as buff:
        # Pin dpi='figure': savefig otherwise uses rcParams['savefig.dpi'], and if that
        # differs from the figure dpi the raw buffer no longer matches the canvas
        # size reported by get_width_height(), making the reshape below fail.
        fig.savefig(buff, format='raw', dpi='figure')
        raw = buff.getvalue()  # getvalue() reads the whole buffer; no seek needed
    width, height = fig.canvas.get_width_height()
    return np.frombuffer(raw, dtype=np.uint8).reshape(height, width, -1)
def get_spiral_dataset(num_samples, num_classes):
    """Build a noisy multi-arm 2-D spiral as a ``TensorDataset`` of (point, label).

    Each class is one arm of ``num_samples`` points. Along an arm the radius grows
    linearly from 0.2 to 1.0 while the angle sweeps forward. Gaussian jitter
    (std 0.2) is added to the angle only, so every point keeps its exact radius.

    Returns:
        ``TensorDataset(points, labels)`` with ``points`` of shape
        (num_classes * num_samples, 2) and integer ``labels`` in
        ``[0, num_classes)``, grouped by class.
    """
    labels = torch.arange(num_classes).repeat_interleave(num_samples)
    progress = torch.linspace(0, 1, num_samples).repeat(num_classes)
    radius = 0.2 + 0.8 * progress
    jitter = 0.2 * torch.randn(num_classes * num_samples)
    angle = (2 * progress + labels) * 2 * torch.pi / num_classes + jitter
    points = torch.stack([radius * angle.sin(), radius * angle.cos()], dim=1)
    return torch.utils.data.TensorDataset(points, labels.long())
def forward_process(x_0, t):
    """Noise clean data to timestep ``t`` in closed form.

    Computes ``x_t = sqrt(alpha_bar[t]) * x_0 + sqrt(1 - alpha_bar[t]) * eps`` with
    ``eps ~ N(0, I)`` shaped like ``x_0``, using the module-level ``alpha_bar``.

    Args:
        x_0: clean samples.
        t: integer timestep(s) indexing ``alpha_bar``. The training loop passes an
            integer tensor of shape (batch, 1) so the coefficients broadcast per row.

    Returns:
        ``(eps, x_t)``: the injected noise (the regression target) and the noised input.
    """
    abar_t = alpha_bar[t]
    noise = torch.randn_like(x_0)
    noised = abar_t.sqrt() * x_0 + (1 - abar_t).sqrt() * noise
    return noise, noised
def langevin_cfg_once(model, x_t, t, y):
    """Take one reverse-diffusion step x_t -> x_{t-1} with classifier-free guidance.

    Despite the name, this is the DDPM ancestral update::

        x_{t-1} = (x_t - (1 - alpha_t) / sqrt(1 - alpha_bar_t) * eps_hat) / sqrt(alpha_t)
                  + sqrt(beta_t) * z

    Here ``eps_hat`` blends the conditional and unconditional noise predictions
    using the module-level guidance weight ``gamma``.

    Args:
        model: noise predictor called on ``cat([x_t, t / T, label], dim=-1)``.
        x_t: current samples, one per row.
        t: integer timestep indexing the module-level ``alpha``/``alpha_bar``/``beta``.
        y: conditioning labels, one row per sample. An all-zero row is the
            "no label" input used for the unconditional prediction.

    Returns:
        Samples at step ``t - 1``, same shape as ``x_t``.
    """
    n = x_t.shape[0]
    # Fresh noise on every step except the last (t == 1), which returns the mean.
    z = torch.randn_like(x_t) if t > 1 else torch.zeros_like(x_t)
    # Normalized-timestep column sized from x_t itself rather than a global batch size.
    ts = torch.full((n, 1), t / T, dtype=x_t.dtype, device=x_t.device)
    eps_hat_cond = model(torch.cat([x_t, ts, y], dim=-1))
    # All-zero label = unconditional input, matching the label dropout used in training.
    eps_hat_uncond = model(torch.cat([x_t, ts, torch.zeros_like(y)], dim=-1))
    # Classifier-free guidance: extrapolate away from the unconditional prediction.
    eps_hat = (1 + gamma) * eps_hat_cond - gamma * eps_hat_uncond
    mean = 1 / torch.sqrt(alpha[t]) * (
        x_t - (1 - alpha[t]) / torch.sqrt(1 - alpha_bar[t]) * eps_hat
    )
    return mean + torch.sqrt(beta[t]) * z
# === Data, model, optimizer ===
ds = get_spiral_dataset(num_samples, num_classes)
dl = torch.utils.data.DataLoader(
    dataset=ds,
    batch_size=bs,
    shuffle=True,
)
# Input features: 2-D point + 1 normalized timestep + one-hot class label.
# The final width (2) is the noise estimate, matched against the 2-D eps target.
model = torchvision.ops.MLP(
    in_channels=2 + 1 + num_classes,
    hidden_channels=[512, 512, 512, 2],
    activation_layer=nn.LeakyReLU,
)
optim = torch.optim.Adam(params=model.parameters(), lr=lr)
# === Training loop ===
# Epsilon-prediction objective: noise x_0 to a random step t in [1, T-1] and regress
# the model's output onto the injected noise with MSE. Class labels are randomly
# replaced by the all-zero "no label" so the same network also learns the
# unconditional prediction needed for classifier-free guidance at sampling time.
it = tqdm(range(num_epochs))
for _ in it:
    losses = []
    for x_0, y in dl:
        n = x_0.shape[0]
        # Float one-hot so it concatenates with the float inputs without relying on promotion.
        y = torch.nn.functional.one_hot(y, num_classes=num_classes).float()
        # Per-sample label dropout (as in classifier-free guidance). Previously a single
        # coin flip zeroed the entire batch, so the unconditional path was trained
        # only on whole-batch steps instead of alongside the conditional one.
        keep = (torch.rand(n, 1) >= conditional_pdrop).float()
        y = y * keep
        t = torch.randint(1, T, size=(n, 1))  # high bound is exclusive
        eps, x_t = forward_process(x_0, t)
        eps_hat = model(torch.cat([x_t, t / T, y], dim=-1))
        loss = torch.nn.functional.mse_loss(eps_hat, eps)
        optim.zero_grad()
        loss.backward()
        optim.step()
        losses.append(loss.item())
    it.set_postfix_str(f"{np.mean(losses):.3f}")
# === Inference ===
# Draw samples for the class with one-hot label [1, 0, 0] by running the reverse
# chain from pure Gaussian noise, keeping every intermediate state for the video.
bs = 1024  # number of points to sample (rebinds the training batch-size global)
y = torch.Tensor([1, 0, 0]).repeat((bs, 1))
with torch.no_grad():
    x = torch.randn(bs, 2)
    xs = [x.clone()]
    # Walk t = T-1, ..., 1: the same timestep range drawn during training.
    for t in reversed(range(1, T)):
        x = langevin_cfg_once(model, x, t, y)
        xs.append(x.clone())

# Render each intermediate state as one frame on a fixed [-3, 3] x [-3, 3] canvas.
arr = []
for x in tqdm(xs):
    fig = plt.figure()
    ax = fig.gca()
    ax.set_aspect('equal')
    ax.scatter(*x.T, alpha=0.5, s=4)
    ax.set_xlim([-3, 3])
    ax.set_ylim([-3, 3])
    arr.append(mpl2np(fig))
    plt.close(fig)
arr = np.stack(arr)[..., :3]  # drop alpha channel
torchvision.io.write_video('diffusion.mp4', arr, fps=30)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.