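# Mixture model in Pyro: categorical observations per data point, governed by a
# latent cluster assignment and a latent "correctness" indicator. The model and
# guide share one definition (mg_gen); the guide learns Dirichlet pseudocounts
# ('pc_q') and Beta concentrations ('c0corr', 'c1corr'). Training uses SVI with
# TraceEnum_ELBO so the discrete latents are enumerated in parallel, on
# synthetic data generated in main().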
import numpy as np
import matplotlib.pyplot as plt
import sys
import torch
import torch.distributions.constraints as constraints
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.optim import Adam
from pyro.infer import SVI, TraceEnum_ELBO, JitTraceEnum_ELBO
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

def mg_gen(is_guide=False):
    # The model (is_guide=False) and the guide (is_guide=True) share this definition;
    # the guide additionally registers learnable parameters for the Dirichlet
    # pseudocounts and the Beta concentrations.
    def mg(data):
        clust, corr, vals, lengths = data
        clust = clust.float()
        vals = vals.t().float()  # (n_features, n_data)
        corr = corr.float()
        n_clusts = clust.size(1)
        n_corr = 2
        n_class = 5
        # Symmetric Dirichlet prior over per-(cluster, correctness) class probabilities
        pseudocounts = clust.new_ones(n_clusts, n_corr, n_class).float() * 0.5
        if is_guide:
            pseudocounts = pyro.param('pc_q', pseudocounts)
        obs_dir = pyro.sample('obs_dir', dist.Dirichlet(pseudocounts).independent(2))
        feature_plate = pyro.plate('feature', vals.size(0), dim=-2)
        with pyro.plate('data', vals.size(1), dim=-1) as idx:
            c0corr = vals.new_tensor(4.0).float()
            c1corr = vals.new_tensor(1.0).float()
            if is_guide:
                c0corr = pyro.param('c0corr', c0corr, constraint=constraints.greater_than(1e-3))
                c1corr = pyro.param('c1corr', c1corr, constraint=constraints.greater_than(1e-3))
            corrpr = pyro.sample('corrpr', dist.Beta(concentration0=c0corr, concentration1=c1corr))
            # The correctness indicator is observed in the model and enumerated in the guide
            corrobs = dict(obs=corr[idx]) if not is_guide else dict()
            corrch = pyro.sample('corrch', dist.Bernoulli(corrpr),
                                 infer=dict(enumerate='parallel', is_auxiliary=True), **corrobs).long()
            clustch = pyro.sample('clustch', dist.Categorical(clust.index_select(0, idx)),
                                  infer=dict(enumerate='parallel'))
            with feature_plate as fidx:
                # Mask out padded positions beyond each data point's observed length
                mask = fidx.unsqueeze(-1).int() < lengths[idx].repeat(fidx.size(0), 1).int()
                if not is_guide:
                    with poutine.mask(mask=mask):
                        pyro.sample('obs', dist.Categorical(obs_dir[clustch, corrch]), obs=vals)
    return mg

model = mg_gen(False)
guide = mg_gen(True)

def _train(svi, train_loader):
    # initialize loss accumulator
    epoch_loss = 0.
    # do a training epoch over each mini-batch returned by the data loader
    for _, data in enumerate(train_loader):
        # do an ELBO gradient step and accumulate the loss
        data = tuple(d.squeeze(0) for d in data)
        epoch_loss += svi.step(data)
    return epoch_loss / train_loader.dataset.tensors[0].size(0)


def infer(dataloader, learning_rate=1e-3, n_epochs=50):
    optimizer_args = {'lr': learning_rate}
    optimizer = Adam(optimizer_args)
    # Enumerate the discrete latents; two plate dimensions ('feature' and 'data') are used
    loss = TraceEnum_ELBO(max_iarange_nesting=2)
    svi = SVI(model, guide, optimizer, loss)
    train_elbo = []
    pbar = tqdm(range(n_epochs))
    for epoch in pbar:
        total_epoch_loss_train = _train(svi, dataloader)
        train_elbo.append(-total_epoch_loss_train)
        pbar.set_description("[epoch {}] avg train loss: {}".format(epoch, total_epoch_loss_train))
    param_names = ('c0corr', 'c1corr', 'pc_q')
    params = {pn: pyro.param(pn) for pn in param_names}
    return params


def main(_args):
    pyro.clear_param_store()
    pyro.enable_validation(False)
    dirichlet_alphas = np.ones((20,))
    dirichlet_alphas[:3] = 10
    corr_prop = 0.2
    n_data = 10000
    n_features = 52
    claz = []
    for j in range(len(dirichlet_alphas)):
        # Correct class
        cclaz = np.random.dirichlet((0.1, 0.1, 0.5, 0.9, 0.9))
        # Incorrect class
        iclaz = np.random.dirichlet((0.7, 0.7, 0.5, 0.3, 0.3))
        claz.append(np.stack((iclaz, cclaz)))
    claz = np.stack(claz)
    clust = []
    corr = []
    vals = []
    lengths = []
    for i in range(n_data):
        np.random.shuffle(dirichlet_alphas)
        clust.append(np.random.dirichlet(dirichlet_alphas))
        corr.append(np.random.binomial(1, corr_prop))
        cch = np.flatnonzero(np.random.multinomial(1, clust[-1]))[0]
        vvals = []
        b = False
        for k in range(n_features):
            # Randomly truncate the sequence and zero-pad the remainder
            if k > 1 and np.random.binomial(1, 2 / n_features):
                lengths.append(k)
                vvals.extend([0 for _ in range(n_features - k)])
                b = True
                break
            vvals.append(np.flatnonzero(np.random.multinomial(1, claz[cch, corr[-1]]))[0])
        if not b:
            lengths.append(n_features)
        vals.append(np.stack(vvals))
    clust = np.stack(clust)
    corr = np.stack(corr)
    vals = np.stack(vals)
    lengths = np.stack(lengths)
    dataset = TensorDataset(torch.tensor(clust), torch.tensor(corr), torch.tensor(vals), torch.tensor(lengths))
    dataloader = DataLoader(dataset, batch_size=250)
    params = infer(dataloader)
    for p, v in params.items():
        print(p)
        print(v)


if __name__ == '__main__':
    main(sys.argv)
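# Run this file directly (e.g. `python <this file>`) with NumPy, PyTorch, Pyro,
# and tqdm installed; it prints the learned guide parameters 'c0corr', 'c1corr',
# and 'pc_q' after the default 50 training epochs.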