@francois-rozet
Last active April 21, 2024 14:26
Flow Matching in 100 LOC
#!/usr/bin/env python

import math
import matplotlib.pyplot as plt
import torch
import torch.nn as nn

from sklearn.datasets import make_moons
from torch import Tensor
from tqdm import tqdm
from typing import *
from zuko.utils import odeint


def log_normal(x: Tensor) -> Tensor:
    return -(x.square() + math.log(2 * math.pi)).sum(dim=-1) / 2


class MLP(nn.Sequential):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        hidden_features: List[int] = [64, 64],
    ):
        layers = []

        for a, b in zip(
            (in_features, *hidden_features),
            (*hidden_features, out_features),
        ):
            layers.extend([nn.Linear(a, b), nn.ELU()])

        super().__init__(*layers[:-1])


class CNF(nn.Module):
    def __init__(self, features: int, freqs: int = 3, **kwargs):
        super().__init__()

        self.net = MLP(2 * freqs + features, features, **kwargs)

        self.register_buffer('freqs', torch.arange(1, freqs + 1) * torch.pi)

    def forward(self, t: Tensor, x: Tensor) -> Tensor:
        t = self.freqs * t[..., None]
        t = torch.cat((t.cos(), t.sin()), dim=-1)
        t = t.expand(*x.shape[:-1], -1)

        return self.net(torch.cat((t, x), dim=-1))

    def encode(self, x: Tensor) -> Tensor:
        return odeint(self, x, 0.0, 1.0, phi=self.parameters())

    def decode(self, z: Tensor) -> Tensor:
        return odeint(self, z, 1.0, 0.0, phi=self.parameters())

    def log_prob(self, x: Tensor) -> Tensor:
        I = torch.eye(x.shape[-1], dtype=x.dtype, device=x.device)
        I = I.expand(*x.shape, x.shape[-1]).movedim(-1, 0)

        def augmented(t: Tensor, x: Tensor, ladj: Tensor) -> Tensor:
            with torch.enable_grad():
                x = x.requires_grad_()
                dx = self(t, x)

            jacobian = torch.autograd.grad(dx, x, I, create_graph=True, is_grads_batched=True)[0]
            trace = torch.einsum('i...i', jacobian)

            return dx, trace * 1e-2

        ladj = torch.zeros_like(x[..., 0])
        z, ladj = odeint(augmented, (x, ladj), 0.0, 1.0, phi=self.parameters())

        return log_normal(z) + ladj * 1e2


class FlowMatchingLoss(nn.Module):
    def __init__(self, v: nn.Module):
        super().__init__()

        self.v = v

    def forward(self, x: Tensor) -> Tensor:
        t = torch.rand_like(x[..., 0, None])
        z = torch.randn_like(x)
        y = (1 - t) * x + (1e-4 + (1 - 1e-4) * t) * z
        u = (1 - 1e-4) * z - x

        return (self.v(t.squeeze(-1), y) - u).square().mean()


if __name__ == '__main__':
    flow = CNF(2, hidden_features=[64] * 3)

    # Training
    loss = FlowMatchingLoss(flow)
    optimizer = torch.optim.Adam(flow.parameters(), lr=1e-3)

    data, _ = make_moons(16384, noise=0.05)
    data = torch.from_numpy(data).float()

    for epoch in tqdm(range(16384), ncols=88):
        subset = torch.randint(0, len(data), (256,))
        x = data[subset]

        loss(x).backward()

        optimizer.step()
        optimizer.zero_grad()

    # Sampling
    with torch.no_grad():
        z = torch.randn(16384, 2)
        x = flow.decode(z)

    plt.figure(figsize=(4.8, 4.8), dpi=150)
    plt.hist2d(*x.T, bins=64)
    plt.savefig('moons_fm.pdf')

    # Log-likelihood
    with torch.no_grad():
        log_p = flow.log_prob(data[:4])

    print(log_p)
@francois-rozet
Author

francois-rozet commented Feb 15, 2023

As side notes,

  1. Computing the log-likelihood of a CNF requires integrating an ODE. I use the odeint function provided by Zuko to do so. It implements the adaptive checkpoint adjoint (ACA) method, which allows for more accurate back-propagation than the standard adjoint method implemented by torchdiffeq.
  2. Adaptive ODE solvers choose their step size according to an estimate of the integration error. For the trace-augmented ODE, odeint overestimates the integration error because the trace has large(r) absolute values, which leads to small step sizes. To mitigate this without significant loss of accuracy, I multiply the trace by a factor $10^{-2}$.

📢 See francois-rozet/papers-101 for other papers in 100 lines of code.

@pengzhangzhi

Code explains everything! I am wondering if you have code for reproducing the results on some more challenging datasets, like CIFAR-10 or beyond. That may help us explore the method further.

@samedii

samedii commented Mar 16, 2023

Thanks for the nice code! :) Do you have any thoughts on how similar this is to the velocity objective?

@francois-rozet
Author

Code explains everything! I am wondering if you have code for reproducing the results on some more challenging datasets, like CIFAR-10 or beyond. That may help us explore the method further.

Hello @pengzhangzhi, I have not tried flow matching on images yet. The main change would be the underlying network, which should be adapted for images. A good choice would be the UNet2DModel from Hugging Face's diffusers package.
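
For instance, here is a rough sketch of how the velocity network could be swapped for a UNet (the hyper-parameters are illustrative guesses and the snippet assumes diffusers' UNet2DModel interface, i.e. model(x, t).sample; it is not a tested setup):

import torch
import torch.nn as nn

from diffusers import UNet2DModel  # assumes the diffusers package is installed


class ImageCNF(nn.Module):
    # Velocity field over images; channels and resolution below are illustrative guesses.
    def __init__(self, channels: int = 3, size: int = 32):
        super().__init__()
        self.unet = UNet2DModel(sample_size=size, in_channels=channels, out_channels=channels)

    def forward(self, t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        t = t.expand(x.shape[0])       # one (continuous) time value per batch element
        return self.unet(x, t).sample  # assumes t in [0, 1] is accepted as the "timestep"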

@francois-rozet
Author

Thanks for the nice code! :) Do you have any thoughts on how similar this is to the velocity objective?

Hello @samedii, I am not sure I understand what you mean by "velocity objective"?

@samedii

samedii commented Mar 18, 2023

@francois-rozet Sorry for the slow reply

It's also called the v-objective. It was first used by NVIDIA for distillation, I think.

Katherine used it here https://github.com/crowsonkb/v-diffusion-pytorch
It's used in some of the v2 stable diffusion models https://huggingface.co/stabilityai/stable-diffusion-2

@francois-rozet
Author

francois-rozet commented Mar 23, 2023

@samedii Thanks for the references! It is indeed similar in that a difference between $x$ and $\epsilon$ is targeted. In OT flow matching, the target is $\epsilon - x$ while $\alpha_t \epsilon - \sigma_t x$ is the target of the $v$-objective.
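
A minimal sketch of the two regression targets side by side (here $\alpha_t$ and $\sigma_t$ stand for a generic noise schedule, and the concrete values below are arbitrary, for illustration only):

import torch

x = torch.randn(8, 2)    # data
eps = torch.randn(8, 2)  # noise

# OT flow matching target (this gist's convention, up to the 1e-4 regularization)
u_fm = eps - x

# v-objective target, at an arbitrary point of a generic noise schedule (alpha_t, sigma_t)
alpha_t, sigma_t = 0.8, 0.6
v_target = alpha_t * eps - sigma_t * x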

@pengzhangzhi

Hi! I copy-pasted your code and got a bug saying TypeError: grad() got an unexpected keyword argument 'is_grads_batched'.
I checked PyTorch's torch.autograd.grad function, and it does contain the is_grads_batched parameter. I am confused. Not sure what went wrong.

  File "a.py", line 122, in <module>
    log_p = flow.log_prob(data[:4])
  File "a.py", line 74, in log_prob
    z, ladj = odeint(augmented, (x, ladj), 0.0, 1.0, phi=self.parameters())
  File "/opt/anaconda3/envs/antibody/lib/python3.8/site-packages/zuko/utils.py", line 314, in odeint
    return tuple(unpack(AdaptiveCheckpointAdjoint.apply(g, x, t0, t1, *phi)))
  File "/opt/anaconda3/envs/antibody/lib/python3.8/site-packages/zuko/utils.py", line 415, in forward
    y, error = dopri45(f, x, t, dt, error=True)
  File "/opt/anaconda3/envs/antibody/lib/python3.8/site-packages/zuko/utils.py", line 330, in dopri45
    k1 = dt * f(t, x)
  File "/opt/anaconda3/envs/antibody/lib/python3.8/site-packages/zuko/utils.py", line 304, in <lambda>
    g = lambda t, x: pack(f(t, *unpack(x)))
  File "a.py", line 68, in augmented
    jacobian = torch.autograd.grad(dx, x, I, is_grads_batched=True, create_graph=True)[0]
TypeError: grad() got an unexpected keyword argument 'is_grads_batched'

@francois-rozet
Author

Hi! I copy-pasted your code and got a bug saying TypeError: grad() got an unexpected keyword argument 'is_grads_batched'. I checked PyTorch's torch.autograd.grad function, and it does contain the is_grads_batched parameter.

Hello @pengzhangzhi, the is_grads_batched option is available since PyTorch 1.11. You probably use an older version.

@pengzhangzhi

Hi! I copy-pasted your code and got a bug saying TypeError: grad() got an unexpected keyword argument 'is_grads_batched'. I checked PyTorch's torch.autograd.grad function, and it does contain the is_grads_batched parameter.

Hello @pengzhangzhi, the is_grads_batched option is available since PyTorch 1.11. You probably use an older version.

The only way to fix this bug is to upgrade the torch version, right? Thanks! You are very nice!

@hamrel-cxu

@francois-rozet Thank you so much for making this code! Just to make sure, is this an implementation of the work "Flow Matching for Generative Modeling" (https://arxiv.org/abs/2210.02747) by Lipman et al.? If so, may I ask which loss objective in the work your training objective is based on?

@francois-rozet
Author

I am glad you like it @hamrel-cxu! It is the optimal transport (OT) flow matching loss. Note that the 0 and 1 extremities of the time are reversed here.

@DebajyotiS

Really cool! Any particular reason for inverting the time extremities?

@francois-rozet
Author

Hello @DebajyotiS, thanks! In score-based generative modeling, it is standard to set $t=0$ as the data (noiseless) extremity and $t=1$ as the noise extremity. In the Flow Matching paper, the authors do not follow the standard, for unknown reasons, but I do in this implementation.
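
Concretely, with this convention (and ignoring the small 1e-4 regularization used in the gist's loss), the conditional path and its target velocity read:

import torch

x = torch.randn(8, 2)  # data sample, the t = 0 extremity
z = torch.randn(8, 2)  # Gaussian noise, the t = 1 extremity
t = torch.rand(8, 1)   # time in [0, 1]

y = (1 - t) * x + t * z  # equals the data at t = 0 and the noise at t = 1
u = z - x                # target velocity, pointing from data to noise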

@fd873630

fd873630 commented Jul 18, 2023

@francois-rozet Thank you so much for making this code! I have a question about the code.

self.register_buffer('frequencies', 2 ** torch.arange(frequencies) * torch.pi)

t = self.frequencies * t[..., None]
t = torch.cat((t.cos(), t.sin()), dim=-1)

Through this function, t is changed. Could you please explain the reason behind this?

@francois-rozet
Author

Hello @fd873630, $t$ is not "changed" by this function but embedded into a higher-dimensional space. This is often called a "positional encoding" or "time embedding" and allows the network to adjust its behavior with respect to $t$ with more granularity than by simply giving it $t$ as input.
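
As a standalone sketch, mirroring the freqs buffer and the beginning of the forward method above:

import torch


def time_embedding(t: torch.Tensor, freqs: int = 3) -> torch.Tensor:
    # Map a scalar time t in [0, 1] to a 2 * freqs dimensional vector.
    w = torch.arange(1, freqs + 1) * torch.pi  # frequencies pi, 2 pi, ..., freqs pi
    wt = w * t[..., None]
    return torch.cat((wt.cos(), wt.sin()), dim=-1)


print(time_embedding(torch.tensor(0.5)))  # tensor of shape (6,) for freqs = 3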

@shivammehta25

shivammehta25 commented Jul 24, 2023

Thank you for the fantastic breakdown of the code! The code helped me a lot in understanding the equations of the paper.
If I have to run this on a GPU, do you have any suggestions on how to replace Zuko's odeint function with torchdyn? I found Zuko's odeint to be slow, but maybe it is the nature of the solution and it would take equally long with torchdyn. Do you have any ideas around this?

@francois-rozet
Author

francois-rozet commented Jul 24, 2023

Hello @shivammehta25, thanks! I never tried to use torchdyn, mainly because of the lack of documentation. I did try with torchdiffeq's odeint_adjoint, but it was always (1.5-2x) slower than Zuko's.

# Encode
z = torchdiffeq.odeint_adjoint(flow, x, torch.tensor((0.0, 1.0)))[-1]

# Decode
x = torchdiffeq.odeint_adjoint(flow, z, torch.tensor((1.0, 0.0)))[-1]

Note that all adaptive ODE integrators rely on CPU synchronization, so this might be a bottleneck when solving on GPU. Also, the smoother the solution, the faster the integrator, so don't be afraid to train your network for a LONG time (continue even if the loss seems to "have converged"), and use learning rate scheduling. I usually use linear scheduling.
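
For example, a linear learning-rate schedule for a training loop like the one above could look as follows (a sketch; the stand-in model, schedule length and factors are arbitrary choices, not the gist's settings):

import torch
import torch.nn as nn

model = nn.Linear(2, 2)  # stand-in for the CNF above
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.LinearLR(
    optimizer,
    start_factor=1.0,   # start at the base learning rate
    end_factor=1e-2,    # decay linearly to 1% of it
    total_iters=1000,   # length of the decay
)

for step in range(1000):
    loss = model(torch.randn(256, 2)).square().mean()  # dummy loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    scheduler.step()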

Finally, score-based generative modeling (which flow-matching is a special case of) is slow by design. Sampling requires a lot of network evaluations, and there's not much you can do about it.

@shivammehta25

shivammehta25 commented Jul 24, 2023

Awesome! Thanks :)

adaptive ODE integrators rely on CPU synchronization

Interesting! Would you mind elaborating on this? Why would this be the case? Sorry for spamming here; otherwise I would reach out to you by email, if that is fine. But it might be useful for other people as well.

score-based generative modelling (which flow-matching is a special case of)

I thought it was the reverse, that flow matching is the umbrella framework and score-based modeling is one of its special cases.

@francois-rozet
Author

Adaptive ODE solvers modify their integration step size according to an estimation of the integration error. If the error is too large, the step is rejected and the step size is reduced. The "if" can only be evaluated on CPU, and hence requires CPU-GPU synchronization.
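
Schematically, an adaptive integrator looks like the following (a simplified illustration of the accept/reject logic, not Zuko's actual solver):

import torch


def heun_step_with_error(f, x, t, dt):
    # One step of Heun's method plus an error estimate (Heun vs. Euler), for illustration.
    k1 = f(t, x)
    k2 = f(t + dt, x + dt * k1)
    y_euler = x + dt * k1
    y_heun = x + dt * (k1 + k2) / 2
    return y_heun, (y_heun - y_euler).abs().max()


def integrate(f, x, t0=0.0, t1=1.0, dt=0.1, tol=1e-3):
    t = t0
    while t < t1:
        h = min(dt, t1 - t)
        y, error = heun_step_with_error(f, x, t, h)
        # This comparison needs the error estimate on the CPU: if x lives on the GPU,
        # error.item() forces a GPU-CPU synchronization at every step.
        if error.item() < tol:
            x, t = y, t + h
            dt = dt * 1.5  # accept: grow the step size
        else:
            dt = dt * 0.5  # reject: retry with a smaller step
    return x


print(integrate(lambda t, x: -x, torch.ones(4)))  # roughly exp(-1), about 0.37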

I thought it was the reverse, that flow matching is the umbrella framework and score-based modeling is one of its special cases.

You can view this either way. The main difference is that flow-matching approximates an ODE while score-matching approximates an SDE.
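
In symbols (a generic statement, not specific to this gist), flow matching regresses a velocity field $v_\theta$ and samples by solving the ODE $\mathrm{d}x = v_\theta(t, x) \, \mathrm{d}t$, while score-based models estimate the score $\nabla_x \log p_t(x)$ and sample by simulating an SDE of the form $\mathrm{d}x = f(t, x) \, \mathrm{d}t + g(t) \, \mathrm{d}w$ (or its associated probability flow ODE).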

@yuyangw

yuyangw commented Aug 4, 2023

Thanks for the nice implementation! If I understand correctly, in lines 88-89, you implement the conditional flow matching (CFM) loss based on Equation 23 in the flow matching paper (https://arxiv.org/pdf/2210.02747.pdf). However, shouldn't it be as follows?

y = (1 - (1 - 1e-4) * t) * z + t * x
u = x - (1 - 1e-4) * z

However, if I change the code, the CFM model won't work. Could you please help me with that? Thanks a lot!

@francois-rozet
Author

francois-rozet commented Aug 5, 2023

Hello @yuyangw 👋 As mentioned earlier in the thread, the 0 and 1 extremities of the time are reversed in this implementation, which is why the loss is slightly different. If you want to change the loss, you also have to switch the initial and final times ($0 \leftrightarrow 1$) of the odeint calls (in encode, decode and log_prob).
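
For instance, a sketch of the loss with the paper's convention ($t = 0$ is noise, $t = 1$ is data), which then requires reversing the odeint time arguments accordingly:

import torch
from torch import Tensor


def cfm_loss_paper_convention(v, x: Tensor, eps: float = 1e-4) -> Tensor:
    t = torch.rand_like(x[..., 0, None])
    z = torch.randn_like(x)

    y = (1 - (1 - eps) * t) * z + t * x
    u = x - (1 - eps) * z

    return (v(t.squeeze(-1), y) - u).square().mean()


# With this convention, the odeint calls are reversed as well, e.g.
# encode: odeint(self, x, 1.0, 0.0, ...) and decode: odeint(self, z, 0.0, 1.0, ...)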

@yuyangw

yuyangw commented Aug 6, 2023

Hello @yuyangw 👋 As mentioned earlier in the thread, the 0 and 1 extremities of the time are reversed in this implementation, which is why the loss is slightly different. If you want to change the loss, you also have to switch the initial and final times (0↔1) of the odeint calls (in encode, decode and log_prob).

Hi @francois-rozet, thank you so much for the reply!

@radiradev

Hi @francois-rozet 👋, could you give a bit more detail about the probability calculation? In particular, where does the added 1e2 (1e-2) factor come from?

@francois-rozet
Author

francois-rozet commented Apr 11, 2024

Hello @radiradev, as explained in the first comment,

Adaptive ODE solvers choose their step size according to an estimate of the integration error. For the trace-augmented ODE, odeint overestimates the integration error because the trace has large(r) absolute values, which leads to small step sizes. To mitigate this without significant loss of accuracy, I multiply the trace by a factor $10^{-2}$.

To paraphrase, I don't want the computation of the log-absolute-determinant of the Jacobian (ladj) to influence the step size of the solver. But because the trace has a large magnitude compared to the derivative (dx), it does influence it (and makes the integration much slower) in practice. To mitigate this, I multiply the trace by a factor $10^{-2}$, and at the end multiply the ladj by the inverse factor $10^2$.
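
As a numerical illustration of why this matters for the solver (the magnitudes below are assumptions chosen to mimic the situation, not measured values):

import torch

SCALE = 1e-2  # shrink the trace so it barely affects the solver's error estimate

dx = torch.randn(4, 2)         # the velocity is O(1) per dimension
trace = 50.0 * torch.randn(4)  # the trace sums over dimensions and tends to be larger

print(trace.abs().max())            # would dominate the solver's error estimate
print((trace * SCALE).abs().max())  # comparable to dx, so dx drives the step size

# The accumulated ladj is then multiplied by 1 / SCALE = 1e2 at the end,
# exactly as in log_prob above, so the log-likelihood itself is unchanged.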
