Skip to content

Instantly share code, notes, and snippets.

@sleexyz
Created March 16, 2024 02:24
Show Gist options
  • Save sleexyz/48b9a86bd6688433cf2b7526db4ea942 to your computer and use it in GitHub Desktop.
Save sleexyz/48b9a86bd6688433cf2b7526db4ea942 to your computer and use it in GitHub Desktop.
@torch.no_grad()
def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., *, log_steps=True, draw_model_graph=True):
    """Implements Algorithm 2 (Euler steps) from Karras et al. (2022).

    This variant is instrumented for debugging: it can print per-step sigma
    information and call ``draw_graph`` on the model during the first step.

    Args:
        model: Denoiser called as ``model(x, sigma_vector, **extra_args)``;
            its result is passed to ``to_d`` as the denoised estimate.
        x: Starting tensor. Its leading dimension sizes the per-sample sigma
            vector (assumed to be the batch dimension).
        sigmas: Noise-level schedule; ``len(sigmas) - 1`` Euler steps are
            taken, stepping from ``sigmas[i]`` towards ``sigmas[i + 1]``.
        extra_args: Extra keyword arguments forwarded to ``model`` (and to
            ``draw_graph``). ``None`` means no extra arguments.
        callback: Optional callable invoked once per step with a dict with
            keys ``'x'``, ``'i'``, ``'sigma'``, ``'sigma_hat'``, ``'denoised'``.
        disable: Forwarded to ``trange`` (progress-bar control).
        s_churn: Total stochastic churn; per-step gamma is
            ``s_churn / (len(sigmas) - 1)``, capped at ``sqrt(2) - 1``.
        s_tmin, s_tmax: Churn is applied only while
            ``s_tmin <= sigmas[i] <= s_tmax``.
        s_noise: Scale of the fresh noise injected when churn is applied.
        log_steps: When true (default, preserving the original behavior),
            print the step index, sigma, sigma_hat and ``x.shape`` each step.
        draw_model_graph: When true (default, preserving the original
            behavior), call ``draw_graph(..., save_graph=True)`` on step 0.

    Returns:
        The tensor produced by the final Euler step.
    """
    extra_args = {} if extra_args is None else extra_args
    # One sigma per leading-dimension entry, so sigma is passed as a vector.
    s_in = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        # Churn: temporarily raise the noise level from sigmas[i] to sigma_hat.
        gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
        sigma_hat = sigmas[i] * (gamma + 1)
        if gamma > 0:
            # Add noise with std sqrt(sigma_hat^2 - sigma^2) to move x to sigma_hat.
            eps = torch.randn_like(x) * s_noise
            x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
        if log_steps:
            print(f"i: {i}, sigma: {sigmas[i]}, sigma_hat: {sigma_hat}, x.shape: {x.shape}")
        if draw_model_graph and i == 0:
            # Visualize the model with the exact inputs of the first step.
            # Only the side effect is wanted (save_graph=True), so the returned
            # graph object is not kept. NOTE(review): this presumably runs an
            # extra forward pass of the model — confirm against torchview.
            draw_graph(
                model,
                input_data=(
                    x,
                    sigma_hat * s_in,
                ),
                expand_nested=True,
                save_graph=True,
                **extra_args,
            )
        denoised = model(x, sigma_hat * s_in, **extra_args)
        d = to_d(x, sigma_hat, denoised)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
        dt = sigmas[i + 1] - sigma_hat
        # Euler method: step from sigma_hat to sigmas[i + 1] along derivative d.
        x = x + d * dt
    return x
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment