Gist ThiagoLira/98c5ab5fab07844fd402912322f6a9eb, created August 2, 2022 19:52
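```python
import torch


def denoise_with_mu(denoise_model, x_t, t, list_alpha, list_alpha_bar, DATA_SIZE, device):
    """
    One reverse (denoising) step x_t -> x_{t-1}, assuming the denoising model
    predicts the posterior mean mu_theta(x_t, t) directly.
    """
    alpha_t = list_alpha[t]
    beta_t = 1 - alpha_t
    alpha_bar_t = list_alpha_bar[t]  # not needed when the model outputs mu directly
    mu_theta = denoise_model(x_t, t)
    # Reverse-process variance sigma_t^2 = beta_t, built on mu_theta's device/dtype
    beta_vec = torch.as_tensor(beta_t, dtype=mu_theta.dtype, device=mu_theta.device).repeat(DATA_SIZE)
    # Sample x_{t-1} ~ N(mu_theta, beta_t * I)
    x_t_before = torch.distributions.MultivariateNormal(
        loc=mu_theta, covariance_matrix=torch.diag(beta_vec)
    ).sample().to(device)
    return x_t_before
```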
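A step like `denoise_with_mu` is normally called inside an ancestral sampling loop that starts from x_T ~ N(0, I) and walks t down to 0. The sketch below shows that loop under stated assumptions; it is not the gist author's code. The linear beta schedule, `T = 50`, and the names `posterior_mean`, `sample`, and `oracle_mu` are hypothetical. The "model" here is an oracle that returns the closed-form DDPM posterior mean of q(x_{t-1} | x_t, x_0) for a known target x_0, which is the quantity a mu-predicting network is trained to approximate. Unlike the gist's function, the loop follows the usual DDPM convention of adding no noise on the final step.

```python
import torch


def posterior_mean(x_t, x_0, t, alphas, alpha_bars):
    """Closed-form mean of q(x_{t-1} | x_t, x_0) for DDPM, with a 0-based step index t."""
    alpha_t = alphas[t]
    beta_t = 1 - alpha_t
    alpha_bar_t = alpha_bars[t]
    # By convention alpha_bar before the first step is 1
    alpha_bar_prev = alpha_bars[t - 1] if t > 0 else torch.tensor(1.0)
    coef_x0 = torch.sqrt(alpha_bar_prev) * beta_t / (1 - alpha_bar_t)
    coef_xt = torch.sqrt(alpha_t) * (1 - alpha_bar_prev) / (1 - alpha_bar_t)
    return coef_x0 * x_0 + coef_xt * x_t


def sample(mu_model, alphas, data_size, generator=None):
    """Ancestral sampling from x_T ~ N(0, I) down to x_0 with sigma_t^2 = beta_t."""
    x = torch.randn(data_size, generator=generator)
    for t in reversed(range(len(alphas))):
        mu = mu_model(x, t)
        if t > 0:
            sigma = torch.sqrt(1 - alphas[t])  # sigma_t^2 = beta_t, as in denoise_with_mu
            x = mu + sigma * torch.randn(data_size, generator=generator)
        else:
            x = mu  # no noise is added on the last step
    return x


# Hypothetical linear beta schedule (an assumption, not taken from the gist)
T = 50
betas = torch.linspace(1e-4, 0.02, T)
alphas = 1 - betas
alpha_bars = torch.cumprod(alphas, dim=0)

# Hypothetical oracle "model" that knows the target x_0 exactly
x0_target = torch.tensor([1.0, -2.0, 0.5])


def oracle_mu(x_t, t):
    return posterior_mean(x_t, x0_target, t, alphas, alpha_bars)


gen = torch.Generator().manual_seed(0)
x0_sample = sample(oracle_mu, alphas, data_size=3, generator=gen)
print(x0_sample)
```

With the oracle, the last step recovers the target exactly: at t = 0 the x_0 coefficient of the posterior mean is 1 and the x_t coefficient is 0. A trained network in place of `oracle_mu` only approximates this mean, so its samples will differ from any single training point.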