@ThiagoLira
Created August 2, 2022 19:52
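
```python
import torch


def denoise_with_mu(denoise_model, x_t, t, list_alpha, list_alpha_bar, DATA_SIZE, device):
    """
    One reverse (denoising) step, assuming the denoising model predicts the
    posterior mean mu_theta(x_t, t) directly.
    """
    alpha_t = list_alpha[t]  # list_alpha is assumed to be a 1-D tensor
    beta_t = 1 - alpha_t
    alpha_bar_t = list_alpha_bar[t]  # not used in this step
    mu_theta = denoise_model(x_t, t)
    # Sample x_{t-1} ~ N(mu_theta, beta_t * I); keep the covariance on mu_theta's device
    cov = torch.diag(beta_t.repeat(DATA_SIZE)).to(mu_theta.device)
    x_t_before = torch.distributions.MultivariateNormal(
        loc=mu_theta, covariance_matrix=cov
    ).sample().to(device)
    return x_t_before
```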
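Because the covariance in `denoise_with_mu` is diagonal with the single value beta_t, the same step can be sketched without building a DATA_SIZE × DATA_SIZE matrix: x_{t-1} = mu_theta + sqrt(beta_t) * z with z ~ N(0, I). The sketch below is a hypothetical reparameterized equivalent plus a full reverse-chain loop. The linear beta schedule (1e-4 to 0.02), the number of steps `T`, and `dummy_model` (a stand-in for a trained network) are all assumptions for illustration, not part of the original gist.

```python
import torch


def denoise_step_reparam(denoise_model, x_t, t, list_alpha):
    """Hypothetical equivalent of denoise_with_mu: x_{t-1} = mu_theta + sqrt(beta_t) * z."""
    beta_t = 1 - list_alpha[t]
    mu_theta = denoise_model(x_t, t)
    z = torch.randn_like(mu_theta)  # z ~ N(0, I), same shape and device as mu_theta
    return mu_theta + beta_t.sqrt() * z


def sample(denoise_model, list_alpha, n_samples, data_size, T):
    """Run the reverse chain from x_T ~ N(0, I) down to x_0."""
    x = torch.randn(n_samples, data_size)
    with torch.no_grad():
        for t in reversed(range(T)):
            x = denoise_step_reparam(denoise_model, x, t, list_alpha)
    return x


# Assumed linear beta schedule (not from the gist)
T = 100
betas = torch.linspace(1e-4, 0.02, T)
list_alpha = 1 - betas


def dummy_model(x_t, t):
    # Hypothetical stand-in for a trained network: shrinks x_t toward zero
    return 0.9 * x_t


x0 = sample(dummy_model, list_alpha, n_samples=8, data_size=2, T=T)
print(x0.shape)  # torch.Size([8, 2])
```

Scaling standard normal noise by sqrt(beta_t) gives the same distribution as the diagonal `MultivariateNormal`, but it avoids allocating a full covariance matrix and a Cholesky factorization at every step.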