Created
January 12, 2025 19:07
-
-
Save MathieuTuli/b0859a8a62439999a0a33d55cb297189 to your computer and use it in GitHub Desktop.
Standalone Conditional Flow Matching (CFM) for Image Generation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from functools import partial | |
| from typing import Tuple | |
| from pathlib import Path | |
| from torchvision import datasets, transforms | |
| import matplotlib.pyplot as plt | |
| import torch.nn as nn | |
| import torchvision | |
| import torch | |
| import math | |
def timestep_embedding(timesteps, dim, max_period=10000):
    """Build sinusoidal embeddings for a 1-D batch of timesteps.

    The first ``dim // 2`` output columns are cosines and the next
    ``dim // 2`` are sines, evaluated at geometrically spaced frequencies
    that decay from 1 towards ``1 / max_period``. When ``dim`` is odd a
    single zero column is appended so the output width is exactly ``dim``.

    Args:
        timesteps: 1-D tensor of (possibly fractional) timesteps; it is
            indexed as ``timesteps[:, None]`` and cast to float.
        dim: width of the returned embedding.
        max_period: controls the lowest frequency of the embedding.

    Returns:
        Tensor of shape ``(len(timesteps), dim)``.
    """
    n_freqs = dim // 2
    exponents = torch.arange(
        start=0, end=n_freqs, dtype=torch.float32, device=timesteps.device)
    freqs = torch.exp(-math.log(max_period) * exponents / n_freqs)
    phases = timesteps[:, None].float() * freqs[None]
    columns = [torch.cos(phases), torch.sin(phases)]
    if dim % 2:
        # Odd width: pad with one zero column.
        columns.append(torch.zeros_like(phases[:, :1]))
    return torch.cat(columns, dim=-1)
def avg_pool_nd(dims, *args, **kwargs):
    """Create a 1-D, 2-D or 3-D average-pooling layer.

    Args:
        dims: number of spatial dimensions (1, 2 or 3).
        *args, **kwargs: forwarded unchanged to the ``nn.AvgPool{dims}d``
            constructor.

    Raises:
        ValueError: if ``dims`` is not 1, 2 or 3.
    """
    pooling_layers = (
        (1, nn.AvgPool1d),
        (2, nn.AvgPool2d),
        (3, nn.AvgPool3d),
    )
    for rank, layer_cls in pooling_layers:
        if dims == rank:
            return layer_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")
def conv_nd(dims, *args, **kwargs):
    """Create a 1-D, 2-D or 3-D convolution layer.

    Args:
        dims: number of spatial dimensions (1, 2 or 3).
        *args, **kwargs: forwarded unchanged to the ``nn.Conv{dims}d``
            constructor.

    Raises:
        ValueError: if ``dims`` is not 1, 2 or 3.
    """
    conv_layers = (
        (1, nn.Conv1d),
        (2, nn.Conv2d),
        (3, nn.Conv3d),
    )
    for rank, layer_cls in conv_layers:
        if dims == rank:
            return layer_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")
def zero_module(module):
    """Set every parameter of ``module`` to zero in place and return it.

    The zeroing is done outside of autograd tracking, so the parameters
    remain ordinary trainable leaves afterwards.
    """
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module
class Upsample(nn.Module):
    """Nearest-neighbour 2x upsampling, optionally followed by a 3x3 conv.

    Args:
        channels: number of channels expected on the input.
        use_conv: when true, apply a convolution after the interpolation.
        dims: number of spatial dimensions. For ``dims == 3`` only the last
            two axes are doubled; the first spatial axis keeps its size.
        out_channels: output channels of the convolution; defaults to
            ``channels``. It only has an effect when ``use_conv`` is true.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
            self.conv = conv_nd(
                dims, self.channels, self.out_channels, 3, padding=1)

    def forward(self, x):
        """Upsample ``x`` (channels on axis 1) and return the result."""
        assert x.shape[1] == self.channels
        if self.dims != 3:
            upsampled = nn.functional.interpolate(
                x, scale_factor=2, mode="nearest")
        else:
            # Volumetric input: leave the first spatial axis untouched.
            target_size = (x.shape[2], x.shape[3] * 2, x.shape[4] * 2)
            upsampled = nn.functional.interpolate(
                x, target_size, mode="nearest")
        return self.conv(upsampled) if self.use_conv else upsampled
class Downsample(nn.Module):
    """2x downsampling via a strided 3x3 convolution or average pooling.

    Args:
        channels: number of channels expected on the input.
        use_conv: when true, downsample with a strided convolution;
            otherwise use average pooling (which cannot change the channel
            count, so ``out_channels`` must then equal ``channels``).
        dims: number of spatial dimensions. For ``dims == 3`` the stride is
            ``(1, 2, 2)``, i.e. the first spatial axis is not reduced.
        out_channels: output channels; defaults to ``channels``.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        step = (1, 2, 2) if dims == 3 else 2
        if not use_conv:
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=step, stride=step)
        else:
            self.op = conv_nd(
                dims, self.channels, self.out_channels, 3,
                stride=step, padding=1)

    def forward(self, x):
        """Downsample ``x`` (channels on axis 1) and return the result."""
        assert x.shape[1] == self.channels
        return self.op(x)
class ResBlock(nn.Module):
    """Residual block conditioned on an embedding vector.

    The block applies GroupNorm -> SiLU -> conv, injects the embedding,
    then applies GroupNorm -> SiLU -> Dropout -> zero-initialised conv and
    adds a skip connection. Both GroupNorm layers use 32 groups, so
    ``channels`` and ``out_channels`` must be divisible by 32.

    Args:
        channels: number of input channels.
        emb_channels: width of the conditioning embedding.
        dropout: dropout probability used in the output stack.
        out_channels: number of output channels; defaults to ``channels``.
        use_conv: when the channel count changes, use a 3x3 conv for the
            skip path instead of a 1x1 conv.
        use_scale_shift_norm: inject the embedding as a scale/shift applied
            after the second normalisation instead of adding it.
        dims: number of spatial dimensions.
        up: upsample both branches inside the block.
        down: downsample both branches inside the block (only consulted
            when ``up`` is false).
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        up=False,
        down=False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_scale_shift_norm = use_scale_shift_norm

        # NOTE: sub-modules are created in a fixed order so that parameter
        # initialisation and state_dict layout stay stable.
        first_conv = conv_nd(dims, channels, self.out_channels, 3, padding=1)
        self.in_layers = nn.Sequential(
            nn.GroupNorm(32, channels), nn.SiLU(), first_conv)

        self.updown = up or down
        if up:
            resampler_cls = Upsample
        elif down:
            resampler_cls = Downsample
        else:
            resampler_cls = None
        if resampler_cls is None:
            # One shared Identity instance serves both branches.
            self.h_upd = self.x_upd = nn.Identity()
        else:
            self.h_upd = resampler_cls(channels, False, dims)
            self.x_upd = resampler_cls(channels, False, dims)

        # Scale/shift conditioning needs two values per output channel.
        if use_scale_shift_norm:
            emb_width = 2 * self.out_channels
        else:
            emb_width = self.out_channels
        self.emb_layers = nn.Sequential(
            nn.SiLU(), nn.Linear(emb_channels, emb_width))

        # The last conv starts at zero so the block is initially the skip path.
        last_conv = zero_module(
            conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1))
        self.out_layers = nn.Sequential(
            nn.GroupNorm(32, self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            last_conv,
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        else:
            kernel, pad = (3, 1) if use_conv else (1, 0)
            self.skip_connection = conv_nd(
                dims, channels, self.out_channels, kernel, padding=pad)

    def forward(self, x, emb):
        """Apply the block to ``x`` conditioned on ``emb``.

        Args:
            x: feature tensor with channels on axis 1.
            emb: conditioning tensor whose last axis has ``emb_channels``
                entries (it is fed to an ``nn.Linear``).

        Returns:
            ``skip_connection(x) + h`` where ``h`` is the residual branch.
        """
        if not self.updown:
            h = self.in_layers(x)
        else:
            # Resample between the activation and the conv.
            pre_conv, conv = self.in_layers[:-1], self.in_layers[-1]
            h = conv(self.h_upd(pre_conv(x)))
            x = self.x_upd(x)

        emb_out = self.emb_layers(emb).type(h.dtype)
        # Append singleton axes so the embedding broadcasts over space.
        missing_axes = len(h.shape) - len(emb_out.shape)
        if missing_axes > 0:
            emb_out = emb_out[(...,) + (None,) * missing_axes]

        if not self.use_scale_shift_norm:
            h = self.out_layers(h + emb_out)
        else:
            norm, remainder = self.out_layers[0], self.out_layers[1:]
            scale, shift = torch.chunk(emb_out, 2, dim=1)
            h = remainder(norm(h) * (1 + scale) + shift)
        return self.skip_connection(x) + h
class QKVAttentionLegacy(nn.Module):
    """Multi-head scaled dot-product attention over a packed QKV tensor.

    The heads are split first and each head's channels are then divided
    into contiguous query, key and value thirds.
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """Attend over the last axis of ``qkv``.

        Args:
            qkv: tensor of shape ``(batch, heads * 3 * ch, length)``.

        Returns:
            Tensor of shape ``(batch, heads * ch, length)``.
        """
        batch, width, seq_len = qkv.shape
        assert width % (3 * self.n_heads) == 0
        head_dim = width // (3 * self.n_heads)
        per_head = qkv.reshape(batch * self.n_heads, head_dim * 3, seq_len)
        query, key, value = per_head.split(head_dim, dim=1)
        # Scaling q and k separately (by ch ** -0.25 each) is more stable
        # with f16 than dividing the product afterwards.
        scale = 1 / math.sqrt(math.sqrt(head_dim))
        logits = torch.einsum("bct,bcs->bts", query * scale, key * scale)
        # Softmax is computed in float32 and cast back.
        probs = torch.softmax(logits.float(), dim=-1).type(logits.dtype)
        attended = torch.einsum("bts,bcs->bct", probs, value)
        return attended.reshape(batch, -1, seq_len)
class AttentionBlock(nn.Module):
    """Self-attention block that lets spatial positions attend to each other.

    The input is flattened over its spatial axes, normalised with a
    32-group GroupNorm (so ``channels`` must be divisible by 32), projected
    to packed QKV with a 1x1 conv, attended, projected back through a
    zero-initialised 1x1 conv and added to the input.

    Args:
        channels: number of input (and output) channels.
        num_heads: number of attention heads; only used when
            ``num_head_channels == -1``.
        num_head_channels: channels per head. When not ``-1`` the head count
            is derived as ``channels // num_head_channels`` and ``num_heads``
            is ignored.
    """

    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
    ):
        super().__init__()
        self.channels = channels
        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert (
                channels % num_head_channels == 0
            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels
        # Validate the head count that is actually used, not the raw
        # ``num_heads`` argument (which is ignored in the branch above).
        assert self.num_heads > 0
        self.norm = nn.GroupNorm(32, channels)
        self.qkv = conv_nd(1, channels, channels * 3, 1)
        self.attention = QKVAttentionLegacy(self.num_heads)
        # Zero-initialised projection: the block starts as an identity map.
        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))

    def forward(self, x):
        """Apply residual self-attention to ``x`` of shape ``(b, c, *spatial)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        b, c, *spatial = x.shape
        # Flatten all spatial axes into a single sequence axis.
        x = x.reshape(b, c, -1)
        qkv = self.qkv(self.norm(x))
        h = self.attention(qkv)
        h = self.proj_out(h)
        return (x + h).reshape(b, c, *spatial)
class TimestepEmbedSequential(nn.Sequential):
    """Sequential container that routes an embedding to the layers needing it.

    ``ResBlock`` children are called as ``layer(x, emb)``; every other child
    is called as ``layer(x)``.
    """

    def forward(self, x, emb):
        """Run ``x`` through all children in order and return the result."""
        out = x
        for child in self:
            out = child(out, emb) if isinstance(child, ResBlock) else child(out)
        return out
class UNet(nn.Module):
    """Time- and class-conditional U-Net predicting a velocity field.

    The network consists of an encoder (``input_blocks``), a bottleneck
    (``middle_block``) and a decoder (``output_blocks``) with skip
    connections between encoder and decoder. The conditioning vector is
    the sum of an MLP over a sinusoidal timestep embedding and, when
    ``num_classes`` is not ``None``, a learned class embedding.

    All GroupNorm layers use 32 groups, so every channel count
    (``model_channels * mult``) must be divisible by 32.

    Args:
        dim: spatial size of the input; only used to convert
            ``attention_resolutions`` into downsampling factors.
        in_channels: channels of the input tensor.
        out_channels: channels of the output tensor.
        model_channels: base channel count.
        num_classes: number of class labels, or ``None`` for an
            unconditional model.
        channel_mult: per-level channel multipliers.
        num_res_blocks: residual blocks per level in the encoder (the
            decoder uses one more per level).
        attention_resolutions: comma-separated resolutions; each is mapped
            to the downsampling factor ``dim // resolution`` at which
            attention is inserted.
        dims: number of spatial dimensions.
        num_heads: attention heads in the encoder and the bottleneck.
        num_head_channels: channels per attention head, or ``-1`` to use
            the head counts directly.
        num_heads_upsample: attention heads in the decoder; ``-1`` means
            "same as ``num_heads``".
        use_scale_shift_norm: forwarded to every ``ResBlock``.
        dropout: dropout probability forwarded to every ``ResBlock``.
        resblock_updown: resample with ``ResBlock`` instead of
            ``Downsample`` / ``Upsample``.
    """

    def __init__(self,
                 dim: int = 28,
                 in_channels: int = 1,
                 out_channels: int = 1,
                 model_channels: int = 32,
                 num_classes: int = 10,
                 channel_mult=(1, 2, 2),
                 num_res_blocks: int = 1,
                 attention_resolutions: str = "16",
                 dims: int = 2,
                 num_heads: int = 1,
                 num_head_channels: int = -1,
                 num_heads_upsample: int = -1,
                 use_scale_shift_norm: bool = False,
                 dropout: float = 0,
                 resblock_updown: bool = False,
                 ):
        super().__init__()
        self.model_channels = model_channels
        time_embed_dim = model_channels * 4
        # Convert resolutions (e.g. "16") into downsampling factors.
        attention_resolutions = [
            dim // int(res) for res in attention_resolutions.split(",")
        ]
        conv_resample = True
        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        self.time_embed = nn.Sequential(
            nn.Linear(model_channels, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, time_embed_dim),
        )
        self.num_classes = num_classes
        if self.num_classes is not None:
            self.label_emb = nn.Embedding(num_classes, time_embed_dim)

        # ---- encoder ----
        ch = input_ch = int(channel_mult[0] * model_channels)
        self.input_blocks = nn.ModuleList(
            [TimestepEmbedSequential(
                conv_nd(dims, in_channels, ch, 3, padding=1))]
        )
        self._feature_size = ch
        # Channel count of every encoder output, consumed (in reverse) by
        # the decoder to size its skip-connection inputs.
        input_block_chans = [ch]
        ds = 1  # current downsampling factor
        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=int(mult * model_channels),
                        dims=dims,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = int(mult * model_channels)
                if ds in attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch,
                            num_heads=num_heads,
                            num_head_channels=num_head_channels,
                        )
                    )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                # Downsample between levels (not after the last one).
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                        )
                        if resblock_updown
                        else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

        # ---- bottleneck ----
        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            AttentionBlock(
                ch,
                num_heads=num_heads,
                num_head_channels=num_head_channels,
            ),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )
        self._feature_size += ch

        # ---- decoder ----
        self.output_blocks = nn.ModuleList([])
        for level, mult in list(enumerate(channel_mult))[::-1]:
            for i in range(num_res_blocks + 1):
                # Each decoder block receives the matching encoder output
                # concatenated on the channel axis.
                ich = input_block_chans.pop()
                layers = [
                    ResBlock(
                        ch + ich,
                        time_embed_dim,
                        dropout,
                        out_channels=int(model_channels * mult),
                        dims=dims,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = int(model_channels * mult)
                if ds in attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch,
                            num_heads=num_heads_upsample,
                            num_head_channels=num_head_channels,
                        )
                    )
                if level and i == num_res_blocks:
                    # Upsample at the end of every level except the top one.
                    out_ch = ch
                    layers.append(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_scale_shift_norm=use_scale_shift_norm,
                            up=True,
                        )
                        if resblock_updown
                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
                    )
                    ds //= 2
                self.output_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch

        # Zero-initialised output conv: the untrained model predicts zeros.
        self.out = nn.Sequential(
            nn.GroupNorm(32, ch),
            nn.SiLU(),
            zero_module(conv_nd(dims, input_ch, out_channels, 3, padding=1)),
        )

    def forward(self, t, x, y=None):
        """Evaluate the network at time ``t`` and state ``x``.

        Note the argument order ``(t, x, y)``.

        Args:
            t: timesteps. A 0-dim tensor is repeated for the whole batch;
                a tensor with more than one axis is reduced by repeatedly
                taking index 0 of the second axis.
            x: input batch with channels on axis 1.
            y: class labels of shape ``(batch,)``; must be given if and
                only if the model was built with ``num_classes`` set.

        Returns:
            Tensor produced by the output head, with ``out_channels``
            channels.
        """
        timesteps = t
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
        # Reduce timesteps to one value per batch element.
        while timesteps.dim() > 1:
            timesteps = timesteps[:, 0]
        if timesteps.dim() == 0:
            timesteps = timesteps.repeat(x.shape[0])

        emb = self.time_embed(timestep_embedding(
            timesteps, self.model_channels))
        if self.num_classes is not None:
            assert y.shape == (x.shape[0],), f"{y.shape}, {x.shape}"
            emb = emb + self.label_emb(y)

        hs = []  # encoder activations kept for the skip connections
        h = x.type(torch.float32)
        for module in self.input_blocks:
            h = module(h, emb)
            hs.append(h)
        h = self.middle_block(h, emb)
        for module in self.output_blocks:
            h = torch.cat([h, hs.pop()], dim=1)
            h = module(h, emb)
        h = h.type(x.dtype)
        return self.out(h)
class OptimalTransportPath:
    """Straight-line (optimal-transport) conditional probability path.

    For a data sample ``x1``, a source sample ``x0`` and a time ``t``:

        xt = t * x1 + (1 - (1 - sig_min) * t) * x0
        vt = x1 - (1 - sig_min) * x0

    ``vt`` is the derivative of ``xt`` with respect to ``t`` and serves as
    the regression target.

    Args:
        sig_min: coefficient left on ``x0`` at ``t = 1``.
    """

    def __init__(self, sig_min: float = 1e-5) -> None:
        self.sig_min = sig_min

    def sample(self, x1: torch.Tensor, x0: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the interpolated point ``xt`` and target velocity ``vt``.

        Args:
            x1: data batch with the batch on axis 0 (any number of
                trailing axes).
            x0: source batch, broadcast-compatible with ``x1``.
            t: one time value per batch element (or a single value).

        Returns:
            ``(xt, vt)``, each shaped like the broadcast of ``x1`` and ``x0``.
        """
        # One trailing singleton axis per non-batch axis of x1, so that t
        # broadcasts per sample for inputs of any rank (not only 4-D).
        t = t.view(-1, *([1] * (x1.dim() - 1)))
        xt = x1 * t + (1 - (1 - self.sig_min) * t) * x0
        vt = x1 - (1 - self.sig_min) * x0
        return xt, vt
if __name__ == "__main__":
    # Train a class-conditional flow-matching model on MNIST, then draw one
    # sample per digit with a midpoint ODE solver and plot the trajectory.

    # Prefer the GPU but keep the script runnable on CPU-only machines.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = UNet().to(device)
    path = OptimalTransportPath()
    optim = torch.optim.Adam(model.parameters(), lr=1e-3)
    B = 128
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../mnist_data', download=True, train=True,
                       transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])),
        batch_size=B, shuffle=True)
    data_iter = iter(train_loader)

    print("Running Flow Matching with OT")
    for i in range(10000):
        optim.zero_grad()
        try:
            x1, c = next(data_iter)
        except StopIteration:
            # Epoch exhausted: restart the loader and keep iterating.
            data_iter = iter(train_loader)
            x1, c = next(data_iter)
        x1, c = x1.to(device), c.to(device)
        # Source samples are standard normal noise.
        x0 = torch.randn_like(x1, device=device)
        t = torch.rand(x1.shape[0], device=device)
        xt, vt = path.sample(x1, x0, t)
        # Regress the predicted velocity onto the target velocity.
        loss = torch.pow(model(t, xt, c) - vt, 2).mean()
        loss.backward()
        optim.step()
        if (i + 1) % 100 == 0:
            print(f"| iter {i+1:6d} | loss {loss.item():8.3f}")

    # sample using midpoint solver
    model.eval()
    with torch.no_grad():  # no autograd graph is needed while sampling
        # Start from the same source distribution used during training
        # (standard normal), not from uniform noise.
        xt = torch.randn(10, 1, 28, 28, dtype=torch.float32, device=device)
        T = torch.linspace(0, 1, 11, device=device)  # sample times
        c = torch.arange(10, device=device)  # one sample per digit class
        odefunc = partial(model.forward, y=c)
        sol = []
        for i in range(10):
            t_start = T[i].expand(xt.shape[0])
            t_end = T[i + 1].expand(xt.shape[0])
            half_step = (t_end - t_start) / 2
            # Midpoint rule: half Euler step, then a full step using the
            # velocity evaluated at the midpoint.
            x_mid = xt + odefunc(x=xt, t=t_start) * half_step[..., None, None, None]
            xt = xt + (t_end - t_start)[..., None, None, None] * odefunc(
                t=t_start + half_step, x=x_mid)
            sol.append(xt)

    Path("outputs").mkdir(exist_ok=True)
    num_timesteps = len(sol)
    # Rows are samples (one per class), columns are solver steps.
    fig, axes = plt.subplots(10, num_timesteps, figsize=(2*num_timesteps, 20))
    plt.subplots_adjust(hspace=0.1, wspace=0.1)
    for t in range(num_timesteps):
        for i in range(10):
            ax = axes[i, t]
            ax.imshow(sol[t][i].squeeze().detach().cpu(), cmap='gray')
            ax.axis('off')
            if t == 0:
                ax.set_ylabel(f'Sample {i}')
            if i == 0:
                ax.set_title(f'Step {t}')
    plt.savefig('outputs/diffusion_process.png', bbox_inches='tight', dpi=150)
    plt.close()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment