Skip to content

Instantly share code, notes, and snippets.

@akashgit
Created November 7, 2017 18:38
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save akashgit/134f9fadaca4c6df72dade6e49ddaf89 to your computer and use it in GitHub Desktop.
Save akashgit/134f9fadaca4c6df72dade6e49ddaf89 to your computer and use it in GitHub Desktop.
def grid(batch_size, num_components, num_features, **kwargs):
    """Build an equally weighted mixture of diagonal Gaussians on a 2-D grid.

    Component means are taken from the integer grid
    {-4, -2, 0, 2, 4} x {-4, -2, 0, 2, 4} (25 points, in
    ``itertools.product`` order), each broadcast to shape
    ``(batch_size, num_features)``. Every component gets the same diagonal
    scale of 0.05 per dimension (passed as the second positional argument of
    ``ds.MultivariateNormalDiag``; presumably ``ds`` is
    ``tf.contrib.distributions``, so this is ``scale_diag`` -- confirm).

    Args:
      batch_size: leading dimension of every component's parameters.
      num_components: number of mixture components, between 1 and 25. When
        fewer than 25, the first ``num_components`` grid points are used.
      num_features: event dimensionality; must be 2 because the means are
        2-D grid points.
      **kwargs: forwarded unchanged to every ``ds.MultivariateNormalDiag``.

    Returns:
      ``ds.Mixture`` whose categorical has equal logits (uniform weights)
      over ``num_components`` ``ds.MultivariateNormalDiag`` components.

    Raises:
      ValueError: if ``num_features != 2`` or ``num_components`` is not in
        ``[1, 25]``.
    """
    points = list(itertools.product(range(-4, 5, 2), repeat=2))

    # Fail early with a clear message instead of a numpy broadcasting error
    # (or silent 1-D/2-D mix-up) and a late component-count mismatch in
    # ds.Mixture.
    if num_features != 2:
        raise ValueError(
            "grid() places means on a 2-D grid, so num_features must be 2; "
            "got %r" % (num_features,))
    if not 1 <= num_components <= len(points):
        raise ValueError(
            "num_components must be between 1 and %d; got %r"
            % (len(points), num_components))

    shape = (batch_size, num_features)
    shape_cat = (batch_size, num_components)

    # Equal logits give uniform mixing weights; log(1/K) keeps them
    # normalised for any K (the previous constant 0.04 was 1/25 only).
    cat = ds.Categorical(
        logits=np.log(np.full(shape_cat, 1.0 / num_components,
                              dtype=np.float32)))

    ones = np.ones(shape, dtype=np.float32)
    mus = [np.asarray(point, dtype=np.float32) * ones
           for point in points[:num_components]]

    # A single scale array is shared by every component.
    scale = np.full(shape, 0.05, dtype=np.float32)
    components = [ds.MultivariateNormalDiag(mu, scale, **kwargs)
                  for mu in mus]
    return ds.Mixture(cat, components)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment