Skip to content

Instantly share code, notes, and snippets.

@akashgit
Created November 7, 2017 18:39
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save akashgit/2a11806b2ae79de7e5faddfd97b76399 to your computer and use it in GitHub Desktop.
Save akashgit/2a11806b2ae79de7e5faddfd97b76399 to your computer and use it in GitHub Desktop.
def highdim_syn_data(batch_size, num_components, num_features,
                     num_constant_features=500, **kwargs):
    """Sample a synthetic high-dimensional Gaussian-mixture dataset.

    Draws ``batch_size`` points from a mixture of ``num_components``
    diagonal Gaussians in ``num_features`` dimensions. The mixture logits
    are all zero, so every component has equal weight.

    Component ``k`` has every coordinate of its mean equal to the ``k``-th
    entry of a fixed table of ten offsets, so at most ten components are
    supported.

    All components share one diagonal scale vector:
      * 1 for the first ``num_features - num_constant_features`` coordinates;
      * 0 for the last ``num_constant_features`` coordinates.
    Those trailing coordinates are therefore degenerate: in every sample
    they equal the mean of the component that generated it.

    Args:
        batch_size: number of samples to draw.
        num_components: number of mixture components, between 1 and 10.
        num_features: dimensionality of each sample.
        num_constant_features: number of trailing coordinates with zero
            scale. Defaults to 500, the value that was previously hard-coded.
        **kwargs: forwarded unchanged to every ``ds.MultivariateNormalDiag``.

    Returns:
        The tensor produced by ``ds.Mixture(...).sample(batch_size)``,
        presumably of shape ``(batch_size, num_features)`` (confirm against
        the ``ds`` version in use).

    Raises:
        ValueError: if ``num_components`` is outside 1..10, or if
            ``num_constant_features`` is negative or exceeds ``num_features``.
    """
    # Per-component mean offsets. Component k's mean is offsets[k] in every
    # coordinate.
    offsets = (-1., -.5, 0., .5, -2., -2.5, 10., .25, -13., -5.5)

    # Validate up front. Otherwise a count above 10 silently yields only 10
    # components, and ds.Mixture later fails with a component/class mismatch.
    if not 1 <= num_components <= len(offsets):
        raise ValueError(
            "num_components must be between 1 and %d, got %r"
            % (len(offsets), num_components))

    # Validate the split before building the scale vector. Otherwise
    # num_features < num_constant_features fails deep in TF with a
    # negative-dimension error.
    if not 0 <= num_constant_features <= num_features:
        raise ValueError(
            "num_constant_features must be between 0 and num_features (%r), "
            "got %r" % (num_features, num_constant_features))

    shape = (num_features,)

    # All-zero logits give uniform mixture weights. This assumes the first
    # positional argument of ds.Categorical is `logits`, as in
    # tf.contrib.distributions / tfp.distributions.
    cat = ds.Categorical(tf.zeros(num_components, dtype=float32))

    # Shared diagonal scale: unit scale on the leading coordinates and zero
    # scale (degenerate) on the trailing num_constant_features coordinates.
    scale = tf.concat(
        [tf.ones((num_features - num_constant_features,), dtype=float32),
         tf.zeros((num_constant_features,), dtype=float32)],
        0)

    # Positional (loc, scale_diag), matching the original call.
    components = [
        ds.MultivariateNormalDiag(offset * tf.ones(shape, dtype=float32),
                                  scale, **kwargs)
        for offset in offsets[:num_components]
    ]
    return ds.Mixture(cat, components).sample(batch_size)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment