Skip to content

Instantly share code, notes, and snippets.

Created September 12, 2019 13:07
Show Gist options
  • Save WenchaoDing/0f6539688715c568960075f77caa9ad3 to your computer and use it in GitHub Desktop.
Save WenchaoDing/0f6539688715c568960075f77caa9ad3 to your computer and use it in GitHub Desktop.
simp model
# Filename :
# Author : Wenchao Ding
# Date : 2018-06-25
import math
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torch.distributions import Categorical
# Normalisation constant 1/sqrt(2*pi) of the univariate Gaussian density;
# used by gaussian_probability.
ONEOVERSQRT2PI = 1.0 / math.sqrt(2*math.pi)
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Shared negative-log-likelihood criterion; simp_loss applies it to the
# area scores.
nllloss = nn.NLLLoss().to(device)
class SIMP(nn.Module):
    """Shared embedding feeding an area classifier and one MDN head per area.

    Args:
        num_mdns: number of area classes, and of MDN heads.
        in_features: size of each input feature vector.
        out_features: number of dimensions predicted by every MDN head.
        num_gaussians: number of mixture components in every MDN head.
    """

    def __init__(self, num_mdns, in_features, out_features, num_gaussians):
        super(SIMP, self).__init__()
        self.hidden_size = 400
        # NOTE(review): the layers between the Linear modules were missing
        # from the source. ReLU is an assumption -- confirm against the
        # original model. Without any non-linearity the four Linear layers
        # would collapse into a single affine map.
        self.simp_embed = nn.Sequential(
            nn.Linear(in_features, 400),
            nn.ReLU(),
            nn.Linear(400, 400),
            nn.ReLU(),
            nn.Linear(400, 400),
            nn.ReLU(),
            nn.Linear(400, self.hidden_size),
            nn.ReLU(),
        )
        # simp_loss feeds the area scores to nn.NLLLoss, which expects
        # log-probabilities, hence the LogSoftmax.
        # NOTE(review): this layer was also missing from the source -- confirm.
        self.area = nn.Sequential(
            nn.Linear(self.hidden_size, num_mdns),
            nn.LogSoftmax(dim=1),
        )
        self.mdns = nn.ModuleList(
            [MDN(self.hidden_size, out_features, num_gaussians).to(device)
             for _ in range(num_mdns)]
        )

    def forward(self, minibatch):
        """Run the embedding, the area classifier and every MDN head.

        Returns:
            A tuple ``(area_score, mdn_inferences)``. ``mdn_inferences`` is a
            flat list ``[pi_0, sigma_0, mu_0, pi_1, sigma_1, mu_1, ...]`` with
            three entries per head, which is the layout simp_loss indexes.
        """
        embeds = self.simp_embed(minibatch)
        area_score = self.area(embeds)
        mdn_inferences = []
        for head in self.mdns:
            # Each head returns (pi, sigma, mu); keep the list flat.
            mdn_inferences.extend(head(embeds))
        return area_score, mdn_inferences
def simp_loss(area_score, mdn_inferences, target):
    """Compute the area-classification loss and the mean masked MDN loss.

    Args:
        area_score: per-area scores with one column per MDN head; passed to
            the module-level ``nllloss`` criterion.
        mdn_inferences: flat list ``[pi_0, sigma_0, mu_0, pi_1, ...]`` with
            three entries per MDN head.
        target: 2-D tensor whose column 0 holds the 1-based area label; the
            remaining columns are forwarded to mdn_loss_with_mask.

    Returns:
        A tuple ``(area_loss, mdn_loss)``.
    """
    num_mdns = area_score.size(1)
    # Allocate next to the predictions instead of on a global device.
    loss = torch.zeros(num_mdns, device=area_score.device)
    for idx in range(num_mdns):
        pi, sigma, mu = mdn_inferences[3 * idx:3 * idx + 3]
        head_loss = mdn_loss_with_mask(pi, sigma, mu, target,
                                       select_label=idx + 1)
        # The masked loss may come back with shape (1,) (empty selection)
        # or as a scalar; normalise to a scalar before storing it.
        loss[idx] = head_loss.reshape(())
    mdn_loss = torch.mean(loss)
    # Labels are 1-based in `target`; NLLLoss wants 0-based class indices.
    # A plain cast avoids copy-constructing a tensor from a tensor.
    area_target = (target[:, 0] - 1).long().to(area_score.device)
    area_loss = nllloss(area_score, area_target)
    return area_loss, mdn_loss
class MDN(nn.Module):
    """A mixture density network layer.

    Maps each input vector to the parameters of a mixture of Gaussians.

    Args:
        in_features: number of dimensions of each input sample.
        out_features: number of dimensions of each Gaussian (O).
        num_gaussians: number of Gaussians per mixture (G).

    Output:
        (pi, sigma, mu) (BxG, BxGxO, BxGxO): B is the batch size, G is the
        number of Gaussians, and O is the number of dimensions for each
        Gaussian. Pi is a multinomial distribution of the Gaussians. Sigma
        is the standard deviation of each Gaussian. Mu is the mean of each
        Gaussian.
    """

    def __init__(self, in_features, out_features, num_gaussians):
        super(MDN, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.num_gaussians = num_gaussians
        # forward() exponentiates this head's output, so it must emit
        # log-probabilities for pi to be a distribution over the Gaussians.
        # NOTE(review): the layer after the Linear was missing from the
        # source; LogSoftmax is inferred from that exp() -- confirm.
        self.pi = nn.Sequential(
            nn.Linear(in_features, num_gaussians),
            nn.LogSoftmax(dim=1),
        )
        self.sigma = nn.Linear(in_features, out_features*num_gaussians)
        self.mu = nn.Linear(in_features, out_features*num_gaussians)

    def forward(self, minibatch):
        """Return the mixture parameters ``(pi, sigma, mu)`` for a batch.

        Raises:
            ValueError: if the predicted means contain NaN.
        """
        # pi (batch_size x num_gaussians)
        pi = torch.exp(self.pi(minibatch))
        # exp keeps the standard deviations strictly positive.
        # sigma: (batch_size x out_features*num_gaussians) before the view
        sigma = torch.exp(self.sigma(minibatch))
        sigma = sigma.view(-1, self.num_gaussians, self.out_features)
        mu = self.mu(minibatch)
        if torch.isnan(mu).any():
            print('input', minibatch)
            raise ValueError('weight overflow')
        mu = mu.view(-1, self.num_gaussians, self.out_features)
        return pi, sigma, mu
def gaussian_probability(sigma, mu, data):
    """Returns the probability of `data` given MoG parameters `sigma` and `mu`.

    Arguments:
        sigma (BxGxO): The standard deviation of the Gaussians. B is the batch
            size, G is the number of Gaussians, and O is the number of
            dimensions per Gaussian.
        mu (BxGxO): The means of the Gaussians. B is the batch size, G is the
            number of Gaussians, and O is the number of dimensions per Gaussian.
        data (BxI): A batch of data. B is the batch size and I is the number of
            input dimensions.

    Returns:
        probabilities (BxG): The probability of each point under the Gaussian
            at the corresponding sigma/mu index.
    """
    # Repeat every sample once per Gaussian so it lines up with sigma/mu.
    data = data.unsqueeze(1).expand_as(sigma)
    ret = ONEOVERSQRT2PI * torch.exp(-0.5 * ((data - mu) / sigma)**2) / sigma
    # Multiply the per-dimension densities together (dimension 2 is O).
    return torch.prod(ret, 2)
def mdn_loss(pi, sigma, mu, target):
    """Calculates the error, given the MoG parameters and the target.

    The loss is the negative log likelihood of the data given the MoG
    parameters, averaged over the batch.

    Note: no epsilon is added inside the log, so a sample with zero mixture
    likelihood yields an infinite loss.
    """
    prob = pi * gaussian_probability(sigma, mu, target)
    nll = -torch.log(torch.sum(prob, dim=1))
    return torch.mean(nll)
def mdn_loss_with_mask(pi, sigma, mu, target, select_label):
    """Negative log-likelihood restricted to rows labelled `select_label`.

    Args:
        pi (BxG), sigma (BxGxO), mu (BxGxO): mixture parameters.
        target: 2-D tensor; column 0 is the row's label and the remaining
            columns are the values scored against the mixture.
        select_label: only rows whose label equals this value contribute.

    Returns:
        The summed NLL of the selected rows divided by the full batch size
        B (not by the number of selected rows), or a zero tensor of shape
        (1,) when no row matches.
    """
    n = pi.size(0)
    # Vectorised row selection; only the first n rows are considered.
    selected = target[:n, 0] == select_label
    if not bool(selected.any()):
        return torch.zeros(1, device=pi.device)
    # pi(BxG) sigma(BxGxO) mu(BxGxO)
    pi_select = pi[selected]
    sigma_select = sigma[selected]
    mu_select = mu[selected]
    target_select = target[:n][selected]
    prob = gaussian_probability(sigma_select, mu_select, target_select[:, 1:])
    prob = torch.sum(pi_select * prob, dim=1)
    # The epsilon keeps log() finite when the mixture likelihood underflows.
    nll = -torch.log(prob + 1e-10)
    return torch.sum(nll) / n
def sample(pi, sigma, mu):
    """Draw one sample per batch row from a mixture of Gaussians.

    Args:
        pi (NxK): mixture weights; each row is expected to sum to 1.
        sigma (NxKxO): standard deviations of the Gaussians.
        mu (NxKxO): means of the Gaussians.

    Returns:
        A CPU tensor of shape (N, O). Randomness comes from ``np.random``.
    """
    num_rows, num_components = pi.shape
    num_dims = mu.shape[2]
    weights = pi.detach().cpu().numpy()
    means = mu.detach().cpu().numpy()
    stds = sigma.detach().cpu().numpy()
    out = torch.zeros(num_rows, num_dims)
    for i in range(num_rows):
        # pi must sum to 1, thus we can sample from a uniform distribution
        # and map the draw onto the cumulative weights to pick a component.
        u = np.random.uniform()  # sample from [0, 1)
        # Fall back to the last component if rounding leaves the cumulative
        # weight marginally below u.
        chosen = num_components - 1
        cumulative = 0.0
        for k in range(num_components):
            cumulative += float(weights[i, k])
            if u < cumulative:
                chosen = k
                # Stop at the first matching segment so that later
                # components cannot overwrite the draw.
                break
        draw = np.random.normal(means[i, chosen], stds[i, chosen])
        out[i] = torch.as_tensor(draw, dtype=out.dtype)
    return out
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment