Skip to content

Instantly share code, notes, and snippets.

@mdvsh
Created November 21, 2022 18:08
Show Gist options
  • Save mdvsh/dcc15167590155af8a4cc9122aa93524 to your computer and use it in GitHub Desktop.
Save mdvsh/dcc15167590155af8a4cc9122aa93524 to your computer and use it in GitHub Desktop.
Example usage of a mixture-of-Gaussians (MoG) distribution class built on torch.distributions.
class MoG:
def __init__(self, means, sigma, weights=None, td="cpu"):
if weights is None:
weights = torch.ones(means.shape[0], device=td) / means.shape[0]
self.means = means.detach()
mix_d = D.Categorical(weights)
comp_d = D.Independent(D.Normal(self.means, sigma * torch.ones(means.shape, device=td)), 2) # 2 needed to interpret M as batch
self.mixture = D.MixtureSameFamily(mix_d, comp_d)
def sample(self, n=None):
# remember to reshape the sampled control
return self.mixture.sample((n,)) if n is not None else self.mixture.sample()
def log_prob(self, x):
# remember to flatten x
return self.mixture.log_prob(x)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment