@francois-rozet
Last active July 1, 2024 19:03
Flow Matching in 100 LOC
#!/usr/bin/env python

import math
import matplotlib.pyplot as plt
import torch
import torch.nn as nn

from sklearn.datasets import make_moons
from torch import Tensor
from tqdm import tqdm
from typing import *
from zuko.utils import odeint


def log_normal(x: Tensor) -> Tensor:
    # Log-density of the standard normal, summed over the last dimension
    return -(x.square() + math.log(2 * math.pi)).sum(dim=-1) / 2


class MLP(nn.Sequential):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        hidden_features: List[int] = [64, 64],
    ):
        layers = []

        for a, b in zip(
            (in_features, *hidden_features),
            (*hidden_features, out_features),
        ):
            layers.extend([nn.Linear(a, b), nn.ELU()])

        super().__init__(*layers[:-1])


class CNF(nn.Module):
    def __init__(self, features: int, freqs: int = 3, **kwargs):
        super().__init__()

        self.net = MLP(2 * freqs + features, features, **kwargs)

        self.register_buffer('freqs', torch.arange(1, freqs + 1) * torch.pi)

    def forward(self, t: Tensor, x: Tensor) -> Tensor:
        # Sinusoidal embedding of the time t, broadcast to the batch shape of x
        t = self.freqs * t[..., None]
        t = torch.cat((t.cos(), t.sin()), dim=-1)
        t = t.expand(*x.shape[:-1], -1)

        return self.net(torch.cat((t, x), dim=-1))

    def encode(self, x: Tensor) -> Tensor:
        # Integrate from t = 0 (data) to t = 1 (noise)
        return odeint(self, x, 0.0, 1.0, phi=self.parameters())

    def decode(self, z: Tensor) -> Tensor:
        # Integrate from t = 1 (noise) back to t = 0 (data)
        return odeint(self, z, 1.0, 0.0, phi=self.parameters())

    def log_prob(self, x: Tensor) -> Tensor:
        I = torch.eye(x.shape[-1], dtype=x.dtype, device=x.device)
        I = I.expand(*x.shape, x.shape[-1]).movedim(-1, 0)

        def augmented(t: Tensor, x: Tensor, ladj: Tensor) -> Tensor:
            with torch.enable_grad():
                x = x.requires_grad_()
                dx = self(t, x)

            # Exact Jacobian trace via batched autograd
            jacobian = torch.autograd.grad(dx, x, I, create_graph=True, is_grads_batched=True)[0]
            trace = torch.einsum('i...i', jacobian)

            # The trace is down-scaled so that it does not dominate the solver's
            # error estimate (see the discussion in the comments below)
            return dx, trace * 1e-2

        ladj = torch.zeros_like(x[..., 0])
        z, ladj = odeint(augmented, (x, ladj), 0.0, 1.0, phi=self.parameters())

        # Undo the down-scaling of the trace
        return log_normal(z) + ladj * 1e2


class FlowMatchingLoss(nn.Module):
    def __init__(self, v: nn.Module):
        super().__init__()

        self.v = v

    def forward(self, x: Tensor) -> Tensor:
        # Time convention: t = 0 is data and t = 1 is noise,
        # reversed with respect to the flow matching paper
        t = torch.rand_like(x[..., 0, None])
        z = torch.randn_like(x)
        y = (1 - t) * x + (1e-4 + (1 - 1e-4) * t) * z
        u = (1 - 1e-4) * z - x

        return (self.v(t.squeeze(-1), y) - u).square().mean()


if __name__ == '__main__':
    flow = CNF(2, hidden_features=[64] * 3)

    # Training
    loss = FlowMatchingLoss(flow)
    optimizer = torch.optim.Adam(flow.parameters(), lr=1e-3)

    data, _ = make_moons(16384, noise=0.05)
    data = torch.from_numpy(data).float()

    for epoch in tqdm(range(16384), ncols=88):
        subset = torch.randint(0, len(data), (256,))
        x = data[subset]

        loss(x).backward()

        optimizer.step()
        optimizer.zero_grad()

    # Sampling
    with torch.no_grad():
        z = torch.randn(16384, 2)
        x = flow.decode(z)

    plt.figure(figsize=(4.8, 4.8), dpi=150)
    plt.hist2d(*x.T, bins=64)
    plt.savefig('moons_fm.pdf')

    # Log-likelihood
    with torch.no_grad():
        log_p = flow.log_prob(data[:4])

    print(log_p)
@francois-rozet

Adaptive ODE solvers modify their integration step size according to an estimate of the integration error. If the error is too large, the step is rejected and the step size is reduced. The "if" can only be evaluated on the CPU, and hence requires a CPU-GPU synchronization.
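
For illustration, here is a bare-bones version of such a loop (explicit Euler with step doubling; purely a sketch, not zuko's actual odeint, and the names are made up). The Python-level check on the error estimate is where the synchronization happens:

import torch

# Illustrative adaptive step-size loop, NOT zuko's odeint.
def adaptive_euler(f, x: torch.Tensor, t0: float, t1: float, dt: float = 1e-2, tol: float = 1e-5) -> torch.Tensor:
    t = t0
    while t < t1:
        x_full = x + dt * f(t, x)                      # one full step
        x_half = x + dt / 2 * f(t, x)                  # two half steps
        x_half = x_half + dt / 2 * f(t + dt / 2, x_half)
        error = (x_full - x_half).abs().max()          # error estimate lives on the GPU
        if error.item() < tol:                         # reading it in Python forces a GPU -> CPU sync
            t, x = t + dt, x_half                      # accept the step
        else:
            dt = dt / 2                                # reject it and retry with a smaller step
    return x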

I thought it was the reverse: that flow matching is the umbrella framework and score-based modeling is one of the special cases.

You can view this either way. The main difference is that flow-matching approximates an ODE while score-matching approximates an SDE.
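
For concreteness (standard background, not specific to this gist): a diffusion SDE and its probability-flow ODE share the same marginals $p_t$, so a learned score also induces a deterministic flow,

$$\mathrm{d}x = f(x, t)\,\mathrm{d}t + g(t)\,\mathrm{d}w \quad\Longleftrightarrow\quad \frac{\mathrm{d}x}{\mathrm{d}t} = f(x, t) - \frac{1}{2}\,g(t)^2\,\nabla_x \log p_t(x),$$

whereas flow matching regresses a velocity field $v_\theta(t, x)$ and uses the ODE $\mathrm{d}x/\mathrm{d}t = v_\theta(t, x)$ directly, without going through the score.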

@yuyangw

yuyangw commented Aug 4, 2023

Thanks for the nice implementation! If I understand correctly, in lines 88-89 you implement the conditional flow matching (CFM) loss based on Equation 23 in the flow matching paper (https://arxiv.org/pdf/2210.02747.pdf). However, shouldn't it be as follows?

y = (1 - (1 - 1e-4) * t) * z + t * x
u = x - (1 - 1e-4) * z

However, if I change the code, the CFM model won't work. Could you please help me with that? Thanks a lot!

@francois-rozet

francois-rozet commented Aug 5, 2023

Hello @yuyangw 👋 As mentioned earlier in the thread, the 0 and 1 extremities of the time are reversed in this implementation, which is why the loss is slightly different. If you want to change the loss, you also have to switch the initial and final times ($0 \leftrightarrow 1$) of the odeint calls (in encode, decode and log_prob).
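
For concreteness, here is a hypothetical sketch of that change (the class name and the sigma_min argument are mine, not part of the gist; it assumes the same imports as above). The interpolant and target follow the paper's Eq. 23, with t = 0 as noise and t = 1 as data, and the odeint bounds are flipped accordingly.

# Hypothetical variant using the paper's time convention (t = 0 noise, t = 1 data).
class FlowMatchingLossPaperTime(nn.Module):
    def __init__(self, v: nn.Module, sigma_min: float = 1e-4):
        super().__init__()

        self.v = v
        self.sigma_min = sigma_min

    def forward(self, x: Tensor) -> Tensor:
        t = torch.rand_like(x[..., 0, None])
        z = torch.randn_like(x)
        y = (1 - (1 - self.sigma_min) * t) * z + t * x   # Eq. 23 interpolant
        u = x - (1 - self.sigma_min) * z                 # its time derivative

        return (self.v(t.squeeze(-1), y) - u).square().mean()

# The odeint bounds in CNF must then be swapped to stay consistent:
#   encode:   odeint(self, x, 1.0, 0.0, phi=self.parameters())   # data -> noise
#   decode:   odeint(self, z, 0.0, 1.0, phi=self.parameters())   # noise -> data
#   log_prob: integrate the augmented ODE from 1.0 to 0.0 as well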

@yuyangw

yuyangw commented Aug 6, 2023

Hello @yuyangw 👋 As mentioned earlier in the thread, the 0 and 1 extremities of the time are reversed in this implementation, which is why the loss is slightly different. If you want to change the loss, you also have to switch the initial and final times (0↔1) of the odeint calls (in encode, decode and log_prob).

Hi @francois-rozet, thank you so much for the reply!

@radiradev

Hi @francois-rozet 👋, could you give a bit more detail about the probability calculation? In particular, where do the factors 1e2 and 1e-2 come from?

@francois-rozet

francois-rozet commented Apr 11, 2024

Hello @radiradev, as explained in the first comment,

Adaptive ODE solvers choose their step size according to an estimate of the integration error. For the trace-augmented ODE, odeint overestimates the integration error because the trace has large(r) absolute values, which leads to small step sizes. To mitigate this without significant loss of accuracy, I multiply the trace by a factor $10^{-2}$.

To paraphrase, I don't want the computation of the log-absolute-determinant of the Jacobian (ladj) to influence the step size of the solver. But because the trace has a high magnitude compared to the derivative (dx), it does influence it (and makes it much slower) in practice. To mitigate this, I multiply the trace by a factor $10^{-2}$, and at the end multiply the ladj by the inverse factor $10^2$.
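
If you want to tune this yourself, a hypothetical drop-in replacement for CNF.log_prob that exposes the factor as an argument could look like the following (trace_scale is a made-up name; the gist hard-codes 1e-2 and 1e2). The two occurrences only have to stay reciprocal, so the returned log-likelihood itself is unchanged.

    # Hypothetical drop-in replacement for CNF.log_prob; trace_scale is not part of the gist.
    def log_prob(self, x: Tensor, trace_scale: float = 1e-2) -> Tensor:
        I = torch.eye(x.shape[-1], dtype=x.dtype, device=x.device)
        I = I.expand(*x.shape, x.shape[-1]).movedim(-1, 0)

        def augmented(t: Tensor, x: Tensor, ladj: Tensor) -> Tensor:
            with torch.enable_grad():
                x = x.requires_grad_()
                dx = self(t, x)

            jacobian = torch.autograd.grad(dx, x, I, create_graph=True, is_grads_batched=True)[0]
            trace = torch.einsum('i...i', jacobian)

            # Down-scale the trace so it does not dominate the solver's error estimate
            return dx, trace * trace_scale

        ladj = torch.zeros_like(x[..., 0])
        z, ladj = odeint(augmented, (x, ladj), 0.0, 1.0, phi=self.parameters())

        # Undo the scaling, so only the solver's step-size control is affected
        return log_normal(z) + ladj / trace_scale

Evaluating this on the same batch with a few values (e.g. 1e-1, 1e-2, 1e-3) gives a direct read on the accuracy/speed trade-off discussed below.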

@thangld201

To paraphrase, I don't want the computation of the log-absolute-determinant of the Jacobian (ladj) to influence the step size of the solver. But because the trace has a high magnitude compared to the derivative (dx), it does influence it (and makes it much slower) in practice. To mitigate this, I multiply the trace by a factor $10^{-2}$, and at the end multiply the ladj by the inverse factor $10^2$.

Hi @francois-rozet, should I change this factor when dealing with other data (e.g. image embeddings), or keep it the same?

@francois-rozet

francois-rozet commented May 6, 2024

Hello @thangld201, the best would be to try different values for the factor (basically it's a trade-off between log-prob accuracy and efficiency) and pick what suits your needs. Note that this code expects x to be a vector or a batch of vectors. If x has the shape of an image, it will likely not work.

@thangld201

@francois-rozet Thanks for your answer. So if the factor is lower (e.g. 1e-6), it gets less accurate but faster?

@francois-rozet

francois-rozet commented May 6, 2024

Exactly, but potentially much less accurate, while being marginally faster. That's why you should try a few values (with the same input, to compare the results).

@jenkspt

jenkspt commented Jun 8, 2024

For decoding, I don't see anything that requires z to come from a normal distribution. Does this mean z can be sampled from any probability distribution?

@DebajyotiS

@jenkspt I would think so; I am aware of at least one study (in the context of data unfolding in high-energy physics) that does data-to-data translation with this formulation: https://arxiv.org/abs/2311.17175
I still have to think a bit more deeply about whether that makes sense, though. (The results look good nonetheless.)

@francois-rozet

@jenkspt As long as the distribution of $z$ is the same during training and sampling, I think it should work.
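
As a hypothetical sketch (the Laplace base is just an example, and the names beyond the gist's are mine), the only requirement is to swap the base consistently in the loss, in the sampling, and in log_prob if you use it:

# Hypothetical: use the same non-Gaussian base for training and sampling.
base = torch.distributions.Laplace(torch.zeros(2), torch.ones(2))

# Training: inside FlowMatchingLoss.forward, replace torch.randn_like(x) with
#     z = base.sample(x.shape[:-1])

# Sampling: draw from the same base before decoding.
with torch.no_grad():
    z = base.sample((16384,))
    x = flow.decode(z)

# Caveat: CNF.log_prob evaluates log_normal(z), so it would have to use the base's
# log-density instead, e.g. base.log_prob(z).sum(dim=-1).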
