Skip to content

Instantly share code, notes, and snippets.

@MathieuTuli
Created January 12, 2025 19:07
Show Gist options
  • Select an option

  • Save MathieuTuli/b0859a8a62439999a0a33d55cb297189 to your computer and use it in GitHub Desktop.

Select an option

Save MathieuTuli/b0859a8a62439999a0a33d55cb297189 to your computer and use it in GitHub Desktop.
Standalone Conditional Flow Matching (CFM) for Image Generation
from functools import partial
from typing import Tuple
from pathlib import Path
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import torch.nn as nn
import torchvision
import torch
import math
def timestep_embedding(timesteps, dim, max_period=10000):
    """Return sinusoidal embeddings for a 1-D batch of timesteps.

    Args:
        timesteps: 1-D tensor with one (possibly fractional) time per batch
            element.
        dim: width of each embedding vector.
        max_period: sets the lowest frequency; frequencies are
            ``max_period ** (-k / (dim // 2))`` for ``k = 0 .. dim // 2 - 1``.

    Returns:
        Float tensor of shape ``(len(timesteps), dim)``. It holds the cosines in
        the first half and the sines in the second half. A zero column is
        appended when ``dim`` is odd.
    """
    n_freqs = dim // 2
    positions = torch.arange(start=0, end=n_freqs, dtype=torch.float32, device=timesteps.device)
    freqs = torch.exp(-math.log(max_period) * positions / n_freqs)
    phases = timesteps.float().unsqueeze(1) * freqs.unsqueeze(0)
    pieces = [torch.cos(phases), torch.sin(phases)]
    if dim % 2 == 1:
        # Pad odd widths so the output always has exactly `dim` columns.
        pieces.append(torch.zeros_like(phases[:, :1]))
    return torch.cat(pieces, dim=-1)
def avg_pool_nd(dims, *args, **kwargs):
    """Construct an average-pooling layer of the requested spatial rank.

    ``dims`` picks ``nn.AvgPool1d``, ``nn.AvgPool2d`` or ``nn.AvgPool3d``.
    All other arguments are forwarded to that constructor unchanged.

    Raises:
        ValueError: if ``dims`` is not 1, 2 or 3.
    """
    by_rank = ((1, nn.AvgPool1d), (2, nn.AvgPool2d), (3, nn.AvgPool3d))
    pool_cls = next((cls for rank, cls in by_rank if dims == rank), None)
    if pool_cls is None:
        raise ValueError(f"unsupported dimensions: {dims}")
    return pool_cls(*args, **kwargs)
def conv_nd(dims, *args, **kwargs):
    """Construct a convolution layer of the requested spatial rank.

    ``dims`` picks ``nn.Conv1d``, ``nn.Conv2d`` or ``nn.Conv3d``. All other
    arguments (channels, kernel size, stride, padding, ...) are forwarded
    unchanged.

    Raises:
        ValueError: if ``dims`` is not 1, 2 or 3.
    """
    candidates = ((1, nn.Conv1d), (2, nn.Conv2d), (3, nn.Conv3d))
    for rank, conv_cls in candidates:
        if dims == rank:
            return conv_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")
def zero_module(module):
    """Set every parameter of ``module`` to zero in place and return it.

    Because weights and biases are both zeroed, a wrapped linear/conv layer
    outputs zeros until it is trained.
    """
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module
class Upsample(nn.Module):
    """Nearest-neighbour 2x upsampling, optionally followed by a convolution.

    For ``dims == 3`` only the last two axes are doubled and the
    third-from-last axis keeps its size. For other ranks every spatial axis
    is doubled.

    Args:
        channels: channel count the input must have (asserted in forward).
        use_conv: if True, apply a kernel-size-3, padding-1 ``conv_nd`` after
            upsampling.
        dims: spatial rank of the input (1, 2 or 3).
        out_channels: output channels of the optional conv; defaults to
            ``channels``.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if self.use_conv:
            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.dims == 3:
            target_size = (x.shape[2], x.shape[3] * 2, x.shape[4] * 2)
            upsampled = nn.functional.interpolate(x, target_size, mode="nearest")
        else:
            upsampled = nn.functional.interpolate(x, scale_factor=2, mode="nearest")
        return self.conv(upsampled) if self.use_conv else upsampled
class Downsample(nn.Module):
    """Stride-2 downsampling by a convolution or by average pooling.

    For ``dims == 3`` the stride is ``(1, 2, 2)``, so the third-from-last axis
    is left alone. For other ranks every spatial axis is strided by 2.

    Args:
        channels: channel count the input must have (asserted in forward).
        use_conv: if True, use a kernel-size-3, padding-1 strided ``conv_nd``.
            Otherwise use average pooling, which requires
            ``out_channels == channels``.
        dims: spatial rank of the input (1, 2 or 3).
        out_channels: output channels; defaults to ``channels``.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        stride = (1, 2, 2) if dims == 3 else 2
        if not use_conv:
            # Pooling cannot change the channel count.
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
        else:
            self.op = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=1)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)
class ResBlock(nn.Module):
    """Residual block conditioned on an embedding, with optional 2x resampling.

    The output is ``skip_connection(x) + h``. ``h`` comes from
    ``in_layers`` (GroupNorm, SiLU, conv), then the projected embedding, then
    ``out_layers`` (GroupNorm, SiLU, Dropout, zero-initialised conv). Because
    the last conv starts at zero, the residual branch outputs zero at
    construction.

    The projected embedding is either added to the features or, with
    ``use_scale_shift_norm``, split into ``scale`` and ``shift`` and applied
    as ``norm(h) * (1 + scale) + shift``.

    Args:
        channels: input channels (GroupNorm(32, .) requires a multiple of 32).
        emb_channels: width of the conditioning embedding.
        dropout: dropout probability before the final conv.
        out_channels: output channels; defaults to ``channels``.
        use_conv: when the channel count changes, use a kernel-3 conv on the
            skip path instead of a kernel-1 conv.
        use_scale_shift_norm: condition by scale/shift instead of addition.
        dims: spatial rank (1, 2 or 3).
        up / down: resample by 2 without convolutions. ``up`` takes priority
            when both are set. Both the skip input and the hidden path are
            resampled.
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        up=False,
        down=False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_scale_shift_norm = use_scale_shift_norm
        width = self.out_channels

        self.in_layers = nn.Sequential(
            nn.GroupNorm(32, channels),
            nn.SiLU(),
            conv_nd(dims, channels, width, 3, padding=1),
        )

        self.updown = up or down
        resample_cls = Upsample if up else (Downsample if down else None)
        if resample_cls is None:
            # A single shared Identity serves both paths.
            self.h_upd = self.x_upd = nn.Identity()
        else:
            self.h_upd = resample_cls(channels, False, dims)
            self.x_upd = resample_cls(channels, False, dims)

        # Scale/shift conditioning needs two values per output channel.
        emb_width = 2 * width if use_scale_shift_norm else width
        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            nn.Linear(emb_channels, emb_width),
        )
        self.out_layers = nn.Sequential(
            nn.GroupNorm(32, width),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(conv_nd(dims, width, width, 3, padding=1)),
        )

        if width == channels:
            self.skip_connection = nn.Identity()
        else:
            kernel, pad = (3, 1) if use_conv else (1, 0)
            self.skip_connection = conv_nd(dims, channels, width, kernel, padding=pad)

    def forward(self, x, emb):
        """Apply the block to ``x`` conditioned on ``emb``.

        ``emb`` is fed to an ``nn.Linear(emb_channels, ...)``. For scale/shift
        its projection is chunked on dim 1, so it is treated as
        (batch, emb_channels). The projection gets trailing singleton axes so
        it broadcasts over the spatial axes of the features.
        """
        if self.updown:
            # Resample between activation and conv so the conv runs at the
            # new resolution.
            pre_conv, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = self.h_upd(pre_conv(x))
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)

        cond = self.emb_layers(emb).type(h.dtype)
        while cond.dim() < h.dim():
            cond = cond.unsqueeze(-1)

        if self.use_scale_shift_norm:
            norm, tail = self.out_layers[0], self.out_layers[1:]
            scale, shift = cond.chunk(2, dim=1)
            h = tail(norm(h) * (1 + scale) + shift)
        else:
            h = self.out_layers(h + cond)
        return self.skip_connection(x) + h
class QKVAttentionLegacy(nn.Module):
    """Multi-head scaled dot-product attention over a packed QKV tensor.

    ``forward`` takes ``qkv`` of shape ``(batch, 3 * n_heads * ch, length)``.
    The channel axis is grouped per head, and each head holds its q, k and v
    chunks of ``ch`` channels in that order. Attention runs over ``length``.
    The result has shape ``(batch, n_heads * ch, length)``.
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        batch, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        head_dim = width // (3 * self.n_heads)
        per_head = qkv.reshape(batch * self.n_heads, head_dim * 3, length)
        q, k, v = per_head.split(head_dim, dim=1)
        # Scale q and k by ch**-0.25 each instead of the scores by ch**-0.5.
        # This keeps intermediate magnitudes smaller, which is friendlier to
        # float16.
        factor = 1 / math.sqrt(math.sqrt(head_dim))
        logits = torch.einsum("bct,bcs->bts", q * factor, k * factor)
        # Softmax in float32 for stability, then cast back.
        probs = torch.softmax(logits.float(), dim=-1).type(logits.dtype)
        mixed = torch.einsum("bts,bcs->bct", probs, v)
        return mixed.reshape(batch, -1, length)
class AttentionBlock(nn.Module):
    """Residual multi-head self-attention over all spatial positions.

    The input ``(batch, channels, *spatial)`` is flattened to
    ``(batch, channels, positions)``. It is then group-normalised, projected
    to packed q/k/v by a 1x1 conv, attended, and projected back by a
    zero-initialised 1x1 conv before being added to the input. The block is
    therefore the identity at construction. The output has the input's shape.

    Args:
        channels: feature channels. GroupNorm(32, .) needs a multiple of 32,
            and the value must be divisible by the resulting head count.
        num_heads: number of heads, used only when ``num_head_channels == -1``.
        num_head_channels: if not -1, the per-head width. The head count then
            becomes ``channels // num_head_channels`` and ``num_heads`` is
            ignored.
    """

    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
    ):
        super().__init__()
        self.channels = channels
        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert (
                channels % num_head_channels == 0
            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels
        # Validate the head count actually used. The `num_heads` argument is
        # ignored when `num_head_channels` is given, so checking it would miss
        # bad values. Checking divisibility here reports a clear error at
        # construction instead of an opaque one inside the attention forward.
        assert self.num_heads > 0, f"number of attention heads must be positive, got {self.num_heads}"
        assert channels % self.num_heads == 0, (
            f"channels {channels} is not divisible by num_heads {self.num_heads}"
        )
        self.norm = nn.GroupNorm(32, channels)
        self.qkv = conv_nd(1, channels, channels * 3, 1)
        self.attention = QKVAttentionLegacy(self.num_heads)
        # Zero-initialised so the residual branch starts as a no-op.
        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))

    def forward(self, x):
        b, c, *spatial = x.shape
        x = x.reshape(b, c, -1)
        qkv = self.qkv(self.norm(x))
        h = self.attention(qkv)
        h = self.proj_out(h)
        return (x + h).reshape(b, c, *spatial)
class TimestepEmbedSequential(nn.Sequential):
    """Sequential container that forwards the embedding to ``ResBlock`` children.

    Each ``ResBlock`` child is called as ``layer(x, emb)``. Every other child
    (convs, attention, resampling) is called as ``layer(x)``.
    """

    def forward(self, x, emb):
        out = x
        for module in self:
            out = module(out, emb) if isinstance(module, ResBlock) else module(out)
        return out
class UNet(nn.Module):
    """Time- and optionally class-conditioned U-Net.

    Architecture:
        - a stem conv;
        - one encoder level per entry of ``channel_mult``, with stride-2
          downsampling between levels;
        - a middle block (ResBlock, attention, ResBlock);
        - a mirrored decoder whose blocks concatenate the matching encoder
          activations (skip connections) before processing them.

    Args:
        dim: spatial size of the input. It is only used to convert
            ``attention_resolutions`` into downsampling factors.
        in_channels: image channels in.
        out_channels: image channels out.
        model_channels: base width. Level ``i`` uses
            ``channel_mult[i] * model_channels`` channels. GroupNorm(32, .) is
            used throughout, so these widths (and the decoder's concatenated
            widths) must be multiples of 32.
        num_classes: number of class labels to condition on, or ``None`` for
            an unconditional model.
        channel_mult: per-level width multipliers.
        num_res_blocks: ResBlocks per encoder level (the decoder uses one more
            per level).
        attention_resolutions: comma-separated resolutions at which to add
            self-attention. Each is converted to the factor ``dim // res``.
            With ``dim=28`` the default ``"16"`` gives factor 1, so attention
            runs at the full input resolution.
        dims: spatial rank of the data (1, 2 or 3).
        num_heads: attention heads in the encoder and middle block.
        num_head_channels: per-head width; if not -1 it overrides the head
            counts.
        num_heads_upsample: attention heads in the decoder (-1 reuses
            ``num_heads``).
        use_scale_shift_norm: condition ResBlocks by scale/shift instead of
            addition.
        dropout: dropout probability inside ResBlocks.
        resblock_updown: change resolution with ResBlocks instead of
            Upsample/Downsample layers.
    """

    def __init__(self,
                 dim: int = 28,
                 in_channels: int = 1,
                 out_channels: int = 1,
                 model_channels: int = 32,
                 num_classes: int = 10,
                 channel_mult=(1, 2, 2),
                 num_res_blocks: int = 1,
                 attention_resolutions: str = "16",
                 dims: int = 2,
                 num_heads: int = 1,
                 num_head_channels: int = -1,
                 num_heads_upsample: int = -1,
                 use_scale_shift_norm: bool = False,
                 dropout: float = 0,
                 resblock_updown: bool = False,
                 ):
        super().__init__()
        self.model_channels = model_channels
        time_embed_dim = model_channels * 4
        # Convert each resolution into the downsampling factor ("ds") at which
        # it occurs; blocks below compare their current ds against this list.
        attention_resolutions = [dim // int(res) for res in attention_resolutions.split(",")]
        conv_resample = True
        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        self.time_embed = nn.Sequential(
            nn.Linear(model_channels, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, time_embed_dim),
        )
        self.num_classes = num_classes
        if self.num_classes is not None:
            self.label_emb = nn.Embedding(num_classes, time_embed_dim)

        # ---- Encoder ----
        ch = input_ch = int(channel_mult[0] * model_channels)
        self.input_blocks = nn.ModuleList(
            [TimestepEmbedSequential(
                conv_nd(dims, in_channels, ch, 3, padding=1))]
        )
        self._feature_size = ch
        # Channel count of every encoder output. The decoder pops these in
        # reverse to size its skip-connection concatenations.
        input_block_chans = [ch]
        ds = 1
        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=int(mult * model_channels),
                        dims=dims,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = int(mult * model_channels)
                if ds in attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch,
                            num_heads=num_heads,
                            num_head_channels=num_head_channels,
                        )
                    )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            # Downsample between levels (not after the last one).
            if level != len(channel_mult) - 1:
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                        )
                        if resblock_updown
                        else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

        # ---- Middle ----
        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            AttentionBlock(
                ch,
                num_heads=num_heads,
                num_head_channels=num_head_channels,
            ),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )
        self._feature_size += ch

        # ---- Decoder ----
        self.output_blocks = nn.ModuleList([])
        for level, mult in list(enumerate(channel_mult))[::-1]:
            for i in range(num_res_blocks + 1):
                ich = input_block_chans.pop()
                layers = [
                    ResBlock(
                        ch + ich,
                        time_embed_dim,
                        dropout,
                        out_channels=int(model_channels * mult),
                        dims=dims,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = int(model_channels * mult)
                if ds in attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch,
                            num_heads=num_heads_upsample,
                            num_head_channels=num_head_channels,
                        )
                    )
                # Upsample at the end of every level except the outermost.
                if level and i == num_res_blocks:
                    out_ch = ch
                    layers.append(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_scale_shift_norm=use_scale_shift_norm,
                            up=True,
                        )
                        if resblock_updown
                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
                    )
                    ds //= 2
                self.output_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch

        # Final conv is zero-initialised, so the untrained network outputs zeros.
        self.out = nn.Sequential(
            nn.GroupNorm(32, ch),
            nn.SiLU(),
            zero_module(conv_nd(dims, input_ch, out_channels, 3, padding=1)),
        )

    def forward(self, t, x, y=None):
        """Evaluate the network at times ``t`` on inputs ``x``.

        Args:
            t: time tensor. It may be a scalar (shared by the whole batch), a
                1-D tensor with one time per sample, or a higher-rank tensor,
                which is reduced by repeatedly taking ``[:, 0]``.
                NOTE(review): ``t`` enters the sinusoidal embedding unscaled.
                With ``t`` in [0, 1], most embedding frequencies barely vary;
                scaling ``t`` (e.g. by 1000) is a common alternative. Confirm
                before changing, since it alters trained behaviour.
            x: input batch, ``(batch, in_channels, *spatial)``.
            y: class labels of shape ``(batch,)``. Required if and only if the
                model was built with ``num_classes``.

        Returns:
            Tensor of shape ``(batch, out_channels, *spatial)``.
        """
        timesteps = t
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
        while timesteps.dim() > 1:
            timesteps = timesteps[:, 0]
        if timesteps.dim() == 0:
            timesteps = timesteps.repeat(x.shape[0])

        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
        if self.num_classes is not None:
            assert y.shape == (x.shape[0],), f"{y.shape}, {x.shape}"
            emb = emb + self.label_emb(y)

        hs = []
        h = x.type(torch.float32)
        for module in self.input_blocks:
            h = module(h, emb)
            hs.append(h)
        h = self.middle_block(h, emb)
        for module in self.output_blocks:
            # Skip connection: concatenate the matching encoder activation.
            h = torch.cat([h, hs.pop()], dim=1)
            h = module(h, emb)
        h = h.type(x.dtype)
        return self.out(h)
class OptimalTransportPath:
    """Conditional optimal-transport path used as the flow-matching target.

    Interpolates from noise ``x0`` at t = 0 to data ``x1`` at t = 1::

        x_t = t * x1 + (1 - (1 - sig_min) * t) * x0
        v_t = d x_t / d t = x1 - (1 - sig_min) * x0

    ``sig_min`` is the noise scale remaining at t = 1, where
    ``x_1 = x1 + sig_min * x0``.
    """

    def __init__(self, sig_min: float = 1e-5) -> None:
        self.sig_min = sig_min

    def sample(self, x1: torch.Tensor, x0: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(x_t, v_t)`` for a batch of data/noise pairs.

        Args:
            x1: data batch, ``(batch, *feature_dims)`` of any rank.
            x0: noise batch with the same shape as ``x1``.
            t: one time per sample, shape ``(batch,)``. A scalar also works.

        Returns:
            ``xt``, the point on the path at time ``t``, and ``vt``, its time
            derivative (the regression target). Both have ``x1``'s shape.
        """
        # Shape t as (batch, 1, ..., 1) so it broadcasts over every non-batch
        # axis of x1, for any data rank.
        t = t.reshape(-1, *([1] * (x1.dim() - 1)))
        xt = x1 * t + (1 - (1 - self.sig_min) * t) * x0
        vt = x1 - (1 - self.sig_min) * x0
        return xt, vt
if __name__ == "__main__":
    # Train a class-conditional CFM model on MNIST, then integrate the learned
    # ODE from noise and save a grid of the sampling trajectories.

    # Fall back to CPU so the script still runs on machines without CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = UNet().to(device)
    path = OptimalTransportPath()
    optim = torch.optim.Adam(model.parameters(), lr=1e-3)
    B = 128
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../mnist_data', download=True, train=True,
                       transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])),
        batch_size=B, shuffle=True)
    data_iter = iter(train_loader)
    print("Running Flow Matching with OT")
    for i in range(10000):
        optim.zero_grad()
        # Cycle through the loader indefinitely, restarting it when exhausted.
        try:
            x1, c = next(data_iter)
        except StopIteration:
            data_iter = iter(train_loader)
            x1, c = next(data_iter)
        x1, c = x1.to(device), c.to(device)
        x0 = torch.randn_like(x1)  # source sample from the N(0, I) prior
        t = torch.rand(x1.shape[0], device=device)  # one time in [0, 1) per sample
        xt, vt = path.sample(x1, x0, t)
        # CFM objective: regress the model output onto the conditional velocity.
        loss = torch.pow(model(t, xt, c) - vt, 2).mean()
        loss.backward()
        optim.step()
        if (i + 1) % 100 == 0:
            print(f"| iter {i+1:6d} | loss {loss.item():8.3f}")

    # Sample with an explicit midpoint (RK2) solver on a uniform time grid.
    model.eval()
    num_steps = 10
    T = torch.linspace(0, 1, num_steps + 1, device=device)  # sample times
    c = torch.arange(10, device=device)  # one sample per digit class
    # Call the module (not .forward) so any registered hooks still run.
    odefunc = partial(model, y=c)
    sol = list()
    # Inference only: avoid building an autograd graph across all solver steps.
    with torch.no_grad():
        # Start from the same standard-normal prior the model was trained on
        # (x0 = randn_like above); uniform noise would be off-distribution.
        xt = torch.randn(10, 1, 28, 28, dtype=torch.float32, device=device)
        for i in range(num_steps):
            t_start = T[i].expand(xt.shape[0])
            t_end = T[i + 1].expand(xt.shape[0])
            dt = (t_end - t_start)[..., None, None, None]
            x_mid = xt + odefunc(x=xt, t=t_start) * (dt / 2)
            xt = xt + dt * odefunc(t=t_start + (t_end - t_start) / 2, x=x_mid)
            sol.append(xt)

    Path("outputs").mkdir(exist_ok=True)
    num_timesteps = len(sol)
    fig, axes = plt.subplots(10, num_timesteps, figsize=(2*num_timesteps, 20))
    plt.subplots_adjust(hspace=0.1, wspace=0.1)
    for step in range(num_timesteps):
        for i in range(10):
            ax = axes[i, step]
            ax.imshow(sol[step][i].squeeze().detach().cpu(), cmap='gray')
            # Hide only the ticks: ax.axis('off') would also hide the row
            # labels set just below.
            ax.set_xticks([])
            ax.set_yticks([])
            if step == 0:
                ax.set_ylabel(f'Sample {i}')
            if i == 0:
                ax.set_title(f'Step {step}')
    plt.savefig('outputs/diffusion_process.png', bbox_inches='tight', dpi=150)
    plt.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment