Skip to content

Instantly share code, notes, and snippets.

@stites
Created September 10, 2019 21:07
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save stites/f24456e81980da4127de306f815a730b to your computer and use it in GitHub Desktop.
Save stites/f24456e81980da4127de306f815a730b to your computer and use it in GitHub Desktop.
def run_gibbs(nsweeps=1):
    """Run ``nsweeps`` sweeps of ``gibbs`` over a freshly generated dataset.

    Every sweep hands the current list of per-feature distributions and the
    feature columns to ``gibbs`` and uses whatever it returns as the prior of
    the next sweep.

    Args:
        nsweeps: Number of sweeps to run.

    Returns:
        The list of distributions left after the final sweep (the initial
        priors unchanged when ``nsweeps`` is 0).
    """
    data = get_dataset()
    # get_datapoint() lays each row out as (feature, feature, cluster label),
    # so the split has to be by column, not by row.
    # xs is Shape([1000,2])
    # ys is Shape([1000,1]) of cluster labels (not consumed by the sampler yet)
    xs, ys = data[:, :2], data[:, 2:]
    prior = [Normal(-1, 0.5), Normal(1, 0.5)]
    for _ in range(nsweeps):
        # Feed each sweep's posterior back in as the next sweep's prior.
        prior = gibbs(prior, xs)
    return prior
def gibbs(prior, data):
    """Perform one (still skeletal) sweep over ``data``.

    For every two-element row a fresh value is drawn from ``prior[0]`` and
    recorded next to the row's current first value. The recorded proposals
    are not used for an update yet, so ``prior`` is returned unchanged.

    Args:
        prior: Sequence of exactly two objects; ``prior[0]`` must provide a
            ``sample()`` method.
        data: Iterable of rows, each unpackable into exactly two values.

    Returns:
        The very same ``prior`` object that was passed in.
    """
    assert len(prior) == 2, "just assume 2 features in the data with 2 variables that correspond."
    # TODO: handle the general case with more than two variables.
    proposals = []  # bookkeeping intended for a later update step
    for current_x, current_y in data:
        proposed_x = prior[0].sample()
        conditional = None  # placeholder: conditional for x is not computed yet
        proposals.append(("x", current_x, proposed_x, conditional))
        # TODO: repeat the same steps for y
    return prior
def get_datapoint(sample_shape=(2,)):
    """Return one labelled datapoint drawn from one of two Gaussians.

    A cluster label is drawn from ``Bernoulli(0.5)``; the features are then
    sampled from ``Normal(5, 1)`` when the label is 1 and from
    ``Normal(2, 1)`` otherwise.

    Args:
        sample_shape: Shape of the feature sample; must be 1-D so it can be
            concatenated with the label. The default is an immutable tuple
            rather than a shared mutable list.

    Returns:
        A 1-D tensor holding the sampled features followed by the cluster
        label (three elements with the default ``sample_shape``).
    """
    cluster = Bernoulli(torch.tensor([0.5])).sample()
    mu = 5 if cluster.item() == 1 else 2
    sample = Normal(mu, 1.0).sample(sample_shape=sample_shape)
    # The label goes last so callers can split features from labels by column.
    return torch.cat((sample, cluster))
def get_dataset(n=1000):
    """Stack ``n`` independent ``get_datapoint()`` draws into one tensor.

    Each row of the result is a single datapoint as returned by
    ``get_datapoint`` called with its default arguments.
    """
    rows = [get_datapoint() for _ in range(n)]
    return torch.stack(rows)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment