@talesa
Last active March 19, 2019 09:45
Simple script to minimize KL between distributions using PyTorch
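The script fits a Gaussian q = Normal(mu, std) to a target p by gradient descent on a Monte Carlo estimate of a KL divergence, in either direction:

KL(p||q) ~= mean_i [log p(x_i) - log q(x_i)]  with x_i sampled from p,
KL(q||p) ~= mean_i [log q(x_i) - log p(x_i)]  with x_i drawn from q via rsample(), so the gradient of the estimate reaches mu and std.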
import torch
from torch import tensor
import torch.distributions as td
import torch.optim as optim
import matplotlib.pyplot as plt
from torch import isinf, isnan
from tqdm import tqdm
import numpy as np
import scipy.stats


def log_mean_exp(x):
    # Numerically stable log of the mean of exp along dim=1,
    # i.e. log((1/K) * sum_k exp(x[:, k])); used when p is a mixture.
    xmax, _ = x.max(dim=1)
    output = -tensor([x.shape[1]], dtype=torch.float).log_() + \
        xmax + \
        (x - xmax.unsqueeze(1)).exp_().sum(dim=1).log_()
    return output.unsqueeze(1)


# Variational distribution q with learnable parameters.
mu = tensor([0.]).requires_grad_()
std = tensor([0.1]).requires_grad_()
q = td.Normal(mu, std)


class TruncatedNormal:
    """Normal(loc, scale) with the density below theta suppressed (truncation from below)."""

    def __init__(self, loc, scale, theta):
        self.theta = theta
        self.loc = loc
        self.scale = scale

    # Hard-truncation version: effectively zero density below theta.
    # def log_prob(self, x):
    #     tail_value = scipy.stats.norm.cdf(self.theta, loc=self.loc, scale=self.scale)
    #     return torch.where(x < self.theta,
    #                        -1e30 * torch.ones_like(x),
    #                        td.Normal(self.loc, self.scale).log_prob(x) - np.log(tail_value))

    def log_prob(self, x):
        # Smooth truncation: below theta the density is pushed towards zero by a
        # 1 + tanh(...) multiplier instead of a hard cutoff, keeping the log-density
        # finite and its gradient usable.
        alpha = 4.
        tail_value = scipy.stats.norm.cdf(self.theta, loc=self.loc, scale=self.scale)
        output = td.Normal(self.loc, self.scale).log_prob(x) - np.log(tail_value)
        multiplier = torch.where(x < self.theta,
                                 1. + torch.tanh((x - self.theta) * alpha),
                                 torch.ones_like(x))
        return output + multiplier.log().clamp(-1e30, 1e30)

    def sample(self, shape):
        raise NotImplementedError()

# Hard-truncation variant of the class above, kept commented out for reference:
# class TruncatedNormal:
#     def __init__(self, loc, scale, theta):
#         self.theta = theta
#         self.loc = loc
#         self.scale = scale
#
#     def log_prob(self, x):
#         tail_value = scipy.stats.norm.cdf(self.theta, loc=self.loc, scale=self.scale)
#         return torch.where(x < self.theta,
#                            -1e7 * torch.ones_like(x),
#                            td.Normal(self.loc, self.scale).log_prob(x) - np.log(tail_value))
#
#     def sample(self, shape):
#         raise NotImplementedError()


# Target distribution p. The commented-out alternative is a 2-component Gaussian
# mixture (its log-density needs log_mean_exp(); see the commented line in the loop).
# p = td.Normal(tensor([-5., 5.]),
#               tensor([1., 1.]))
p = TruncatedNormal(0.5, 1. / np.sqrt(2.), 0.7)

optimizer = optim.Adam([mu, std], lr=1e-2)

losses = []
mus = []
stds = []

batch_size = 1000
epochs = 10000

# Which KL direction to minimize: 'pq' is KL(p||q), 'qp' is KL(q||p).
# kl_name = 'pq'
kl_name = 'qp'

for i in tqdm(range(epochs)):
    optimizer.zero_grad()

    if kl_name == 'pq':
        # Forward KL(p||q): Monte Carlo estimate using samples from p.
        x = p.sample((batch_size,))
        logqx = q.log_prob(x)
        logpx = p.log_prob(x)
        kl = logpx - logqx
    else:
        # Reverse KL(q||p): reparameterized samples from q, so the gradient of the
        # estimate flows back to mu and std.
        x = q.rsample((batch_size,))
        logqx = q.log_prob(x)
        # log_mean_exp() is needed when p is the 2-component mixture of Gaussians.
        # logpx = log_mean_exp(p.log_prob(x))
        logpx = p.log_prob(x)
        kl = logqx - logpx

    loss = kl.mean()

    if isinf(loss) or isnan(loss):
        print(f"Loss is invalid! {loss.item()} on epoch {i}")
        import ipdb
        ipdb.set_trace()

    losses.append(loss.item())
    mus.append(mu.item())
    stds.append(std.item())

    loss.backward()
    optimizer.step()

# Training curves of the variational parameters.
# plt.plot(losses[:-1])
plt.plot(mus[:-1], label='mu')
plt.plot(stds[:-1], label='std')
plt.legend()
plt.show()

# Fitted q against the target p.
x = np.linspace(-15, 15, 1000)
plt.plot(x, p.log_prob(tensor(x, dtype=torch.float).unsqueeze(1)).exp().mean(dim=1).numpy(), label='p')
plt.plot(x, q.log_prob(tensor(x, dtype=torch.float)).exp().detach().numpy(), label='q')
plt.legend()
plt.show()
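
A possible sanity check, not part of the original script: when the target is a plain Normal (like the commented-out p above), torch.distributions provides the analytic KL, which the Monte Carlo estimate used in the loop should approach. A minimal sketch, assuming a hypothetical Gaussian target p_check:

# Hypothetical sanity check (an addition, not in the original gist).
p_check = td.Normal(tensor([0.5]), tensor([1. / np.sqrt(2.)]))
xs = q.rsample((100000,))
mc_kl = (q.log_prob(xs) - p_check.log_prob(xs)).mean()
analytic_kl = td.kl_divergence(q, p_check).mean()
print(f"MC KL(q||p_check): {mc_kl.item():.4f}  analytic: {analytic_kl.item():.4f}")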