@ThiagoLira
Created August 2, 2022 19:51
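import torch


def q_sample(x_start, t, list_bar_alphas, device):
    """
    Diffuse the data directly to timestep t (t == 0 means diffused for 1 step).

    Samples from q(x_t | x_0) = N(sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) * I).
    """
    alpha_bar_t = list_bar_alphas[t]
    # The mean is scaled by sqrt(alpha_bar_t), not by alpha_bar_t itself.
    mean = (alpha_bar_t ** 0.5) * x_start
    # Isotropic covariance (1 - alpha_bar_t) * I over the feature dimension.
    cov = torch.eye(x_start.shape[-1], device=device) * (1 - alpha_bar_t)
    return torch.distributions.MultivariateNormal(loc=mean, covariance_matrix=cov).sample().to(device)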
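
The gist does not show how `list_bar_alphas` is built, so the following is only a minimal sketch under stated assumptions. It assumes a linear beta schedule from 1e-4 to 0.02 over 1000 steps (the gist does not confirm this choice) and the usual DDPM convention in which the mean is scaled by `sqrt(alpha_bar_t)`. The names `make_bar_alphas` and `q_sample_reparam` are hypothetical and not part of the gist. The sketch uses only the standard library and draws the sample in the reparameterized form `x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps`, which has the same distribution as sampling from a Gaussian with isotropic covariance `(1 - alpha_bar_t) * I`.

```python
import itertools
import math
import random


def make_bar_alphas(num_steps=1000, beta_start=1e-4, beta_end=0.02):
    """Return alpha_bar_t = prod_{s<=t} (1 - beta_s) for an assumed linear beta schedule."""
    step = (beta_end - beta_start) / (num_steps - 1)
    betas = [beta_start + i * step for i in range(num_steps)]
    alphas = [1.0 - b for b in betas]
    # Running product of the alphas gives the cumulative alpha_bar values.
    return list(itertools.accumulate(alphas, lambda a, b: a * b))


def q_sample_reparam(x_start, t, list_bar_alphas, rng=random):
    """Sample x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps, eps ~ N(0, I)."""
    alpha_bar_t = list_bar_alphas[t]
    scale = math.sqrt(alpha_bar_t)        # shrinks the signal
    sigma = math.sqrt(1.0 - alpha_bar_t)  # standard deviation of the added noise
    return [scale * x + sigma * rng.gauss(0.0, 1.0) for x in x_start]


bar_alphas = make_bar_alphas()
rng = random.Random(0)  # seeded for reproducibility
x0 = [1.0, -2.0]
x_early = q_sample_reparam(x0, 0, bar_alphas, rng)    # barely noised
x_late = q_sample_reparam(x0, 999, bar_alphas, rng)   # close to pure noise
print(x_early, x_late)
```

A list built this way can be passed as `list_bar_alphas` to a torch-based `q_sample`, optionally wrapped in `torch.tensor(bar_alphas)` first. At `t == 0` the sample stays close to `x0`, while at the last step `alpha_bar_t` is near zero and the sample is almost standard Gaussian noise.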