@tomonari-masada
Created November 27, 2017 11:48
Adversarial variational Bayes for univariate Gaussian mixture models
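# Adversarial variational Bayes (AVB) for a univariate Gaussian mixture model.
# Rough structure, as read from the code below:
#   Q  -- implicit posterior over the K component means: maps a learned per-cluster
#         embedding plus noise eps to a mean sample, q(mu | eps).
#   Qz -- amortized encoder producing cluster responsibilities q(z | x).
#   T  -- discriminator that separates Q's mean samples from samples of a broad
#         Gaussian prior; its output stands in for the log-density-ratio term of AVB.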
import sys
import torch
import torch.nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import numpy as np
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
import os
if not os.path.exists('out/'):
    os.makedirs('out/')
torch.manual_seed(123)
np.random.seed(123)
mb_size = 100
eps_size = 100
z_dim = 1
embed_dim = 1
eps_dim = 2
X_dim = 1
h_dim = 8
cnt = 0
q_lr = 0.0001
qz_lr = 0.0001
t_lr = 0.01
lr = 0.001
def log(x):
    # numerically stabilized log (avoids log(0))
    return torch.log(x + 1e-10)
weight_init_dict = {'kn': torch.nn.init.kaiming_normal,
                    'ku': torch.nn.init.kaiming_uniform,
                    'xn': torch.nn.init.xavier_normal,
                    'xu': torch.nn.init.xavier_uniform}
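# NOTE: these are the pre-0.4 torch.nn.init names; newer PyTorch versions expose the
# in-place variants kaiming_normal_, kaiming_uniform_, xavier_normal_, xavier_uniform_.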
weight_init = dict()
if len(sys.argv) > 3:
    weight_init['q'] = sys.argv[1]
    weight_init['qz'] = sys.argv[2]
    weight_init['t'] = sys.argv[3]
elif len(sys.argv) == 3:
    weight_init['q'] = sys.argv[1]
    weight_init['qz'] = sys.argv[1]
    weight_init['t'] = sys.argv[2]
elif len(sys.argv) == 2:
    weight_init['q'] = sys.argv[1]
    weight_init['qz'] = sys.argv[1]
    weight_init['t'] = sys.argv[1]
else:
    weight_init['q'] = 'ku'
    weight_init['qz'] = 'ku'
    weight_init['t'] = 'ku'
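# Assumed usage (inferred from the argument parsing above; the script name is a placeholder):
#   python avb_gmm.py            -> default 'ku' init for Q, Qz and T
#   python avb_gmm.py xn         -> 'xn' init for all three networks
#   python avb_gmm.py xn ku      -> 'xn' for Q and Qz, 'ku' for T
#   python avb_gmm.py xn xu kn   -> 'xn' for Q, 'xu' for Qz, 'kn' for T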
# number of clusters
K = 3
# Encoder: q(mu|eps)
Q = torch.nn.Sequential(
    torch.nn.Linear(embed_dim + eps_dim, h_dim),
    torch.nn.ReLU(),
    torch.nn.Linear(h_dim, h_dim),
    torch.nn.ReLU(),
    torch.nn.Linear(h_dim, h_dim),
    torch.nn.ReLU(),
    torch.nn.Linear(h_dim, z_dim)
)
# apply the selected init to the Linear layer that precedes each ReLU
for l in Q:
    if type(l) == torch.nn.ReLU:
        weight_init_dict[weight_init['q']](prev_l.weight)
    prev_l = l
# Encoder: q(z|X)
Qz = torch.nn.Sequential(
    torch.nn.Linear(X_dim, h_dim),
    torch.nn.ReLU(),
    torch.nn.Linear(h_dim, h_dim),
    torch.nn.ReLU(),
    torch.nn.Linear(h_dim, h_dim),
    torch.nn.ReLU(),
    torch.nn.Linear(h_dim, K),
    torch.nn.Softmax()
)
for l in Qz:
    if type(l) == torch.nn.ReLU:
        weight_init_dict[weight_init['qz']](prev_l.weight)
    prev_l = l
# Discriminator: T(z)
T = torch.nn.Sequential(
    torch.nn.Linear(z_dim, h_dim),
    torch.nn.ReLU(),
    torch.nn.Linear(h_dim, h_dim),
    torch.nn.ReLU(),
    torch.nn.Linear(h_dim, h_dim),
    torch.nn.ReLU(),
    torch.nn.Linear(h_dim, 1)
)
for l in T:
    if type(l) == torch.nn.ReLU:
        weight_init_dict[weight_init['t']](prev_l.weight)
    prev_l = l
def reset_grad():
    Q_optimizer.zero_grad()
    Qz_optimizer.zero_grad()
    T_optimizer.zero_grad()
    optimizer.zero_grad()
true_scale = 20.0
true_mu = np.random.normal(loc=0.0, scale=true_scale, size=(K))
print(true_mu)
true_cluster_prob = np.ones(K) + np.random.random(K) * 3
true_cluster_prob /= true_cluster_prob.sum()
print(true_cluster_prob)
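# sample_X draws a minibatch from the true mixture: pick a component
# z ~ Categorical(true_cluster_prob), then draw x ~ N(true_mu[z], 1).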
def sample_X(size):
    z = np.random.choice(K, size, p=true_cluster_prob)
    X = np.random.normal(loc=true_mu[z], scale=1.0).astype(np.float32)
    X = Variable(torch.from_numpy(X))
    return X
cluster_prob = Variable(torch.randn(K), requires_grad=True)
cluster_embedding = Variable(torch.randn(K, embed_dim), requires_grad=True)
Q_optimizer = optim.Adam(list(Q.parameters()), lr=q_lr)
Qz_optimizer = optim.Adam(list(Qz.parameters()), lr=qz_lr)
T_optimizer = optim.SGD(T.parameters(), lr=t_lr, momentum=0.9)
optimizer = optim.Adam([cluster_embedding, cluster_prob], lr=lr)
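# Each iteration alternates two updates:
#   1) discriminator step: train T to tell Q's mean samples apart from samples of
#      the broad prior N(0, true_scale^2);
#   2) inference step: update Q, Qz, the mixture weights and the cluster embeddings
#      to maximize the adversarial surrogate ELBO, in which -T(mu) stands in for the
#      density-ratio term that exact variational inference would need in closed form.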
for it in range(200000):
    # Discriminator
    eps = Variable(torch.randn(eps_size, K, eps_dim))
    mu = Variable(torch.randn(eps_size, K, z_dim) * true_scale)
    mu_sample = Q(torch.cat([cluster_embedding.unsqueeze(0).repeat(eps_size, 1, 1), eps], 2))
    T_q = F.sigmoid(T(mu_sample))
    T_prior = F.sigmoid(T(mu))
    T_loss = - torch.mean(log(T_q) + log(1.0 - T_prior))
    T_loss.backward()
    T_optimizer.step()
    reset_grad()
    # Encoder
    X = sample_X((mb_size, 1))
    eps = Variable(torch.randn(K, eps_dim))
    mu_sample = Q(torch.cat([cluster_embedding, eps], 1))
    resp = Qz(X)
    T_sample = T(mu_sample)
    disc = torch.mean(- T_sample)
    temp = torch.pow(X.expand(mb_size, K) - mu_sample.squeeze(), 2)
    temp = log(F.softmax(cluster_prob)) - log(resp) - 0.5 * temp
    loglike = torch.mean(torch.bmm(resp.view(mb_size, 1, K),
                                   temp.view(mb_size, K, 1)))
    elbo = - (disc + loglike)
    elbo.backward()
    Q_optimizer.step()
    Qz_optimizer.step()
    optimizer.step()
    reset_grad()
    # Print and plot every now and then
    if it % 1000 == 0:
        print('{} {:.4} {:.4} '
              .format(it, -elbo.data[0], -T_loss.data[0]), end=' ')
        print('{}'.format(' '.join([str(x) for x in F.softmax(cluster_prob).data.squeeze().numpy()])))
        sys.stdout.flush()
        eps = Variable(torch.randn(1000, K, eps_dim))
        mu_sample = Q(torch.cat([cluster_embedding.unsqueeze(0).repeat(1000, 1, 1), eps], 2))
        mu_sample = mu_sample.squeeze().data.numpy()
        plt.figure(figsize=(12, 4))
        for k in range(K):
            n, bins, patches = plt.hist(mu_sample[:, k], 50)
        plt.xlim(-25, 25)
        plt.ylim(0, 80)
        plt.savefig('out/mu_{}{}{}_{}.png'.format(weight_init['q'],
                                                  weight_init['qz'],
                                                  weight_init['t'],
                                                  str(cnt).zfill(4)),
                    bbox_inches='tight')
        plt.clf()
        cnt += 1