from . import contextual_watch_sequence_dataset
from . import word2gm
from . import mixturesamefamily as M
import pandas as pd
import pickle
from itertools import combinations
import time
import torch
import torch.nn.functional as F
import torch.optim as optim
import torch.distributions as D
class GMM_Scale_Clipper(object):
    """Clips the scale embeddings of the updated tokens into [lower_scale_cap, upper_scale_cap]."""
    def __init__(self, lower_scale_cap, upper_scale_cap):
        self.lower_scale_cap = lower_scale_cap
        self.upper_scale_cap = upper_scale_cap

    def __call__(self, module, to_update):
        with torch.no_grad():
            for i in range(module.num_mixture_components):
                scale_tensor = getattr(module, f'gmm_scales_comp_{i}')
                for token in to_update:
                    # element-wise clamp against the two caps (equivalent to the max/min pair)
                    scale_tensor.weight.data[token, :] = scale_tensor.weight.data[token, :].clamp(
                        self.lower_scale_cap, self.upper_scale_cap)
def _batch(iterable, n=1):
    l = len(iterable)
    for ndx in range(0, l, n):
        yield iterable[ndx:min(ndx + n, l)]

def _set_default_tensor_type(device):
    if 'cuda' in device:
        torch.set_default_tensor_type(torch.cuda.FloatTensor)
    else:
        torch.set_default_tensor_type(torch.FloatTensor)
    return
def train(cc_input_data, kg_metadata, softcoded_recs, params):
    MARGIN = params['margin']
    BATCH_SIZE = params['batch_size']  # small as possible?
    NUM_MIXTURE_COMPONENTS = params['num_mixture_components']
    COMPONENT_DIM = params['dim']
    EPOCHS = params['epochs']  # more?
    CONTEXT_WINDOW = params['ws']  # good?
    ALPHA = params['alpha']
    INIT_SCALE = params['init_scale']
    LOW_SCALE_CAP = params['low_scale_cap']
    UPPER_SCALE_CAP = params['upper_scale_cap']
    MEAN_NORM_CAP = params['mean_norm_cap']
    COVAR_MODE = 'diagonal'
    KG = kg_metadata
    SOFTCODED_RECS = softcoded_recs
    CLASS_LABELS = params['class_labels']
    SCALE_GRAD_BY_FREQ = params['scale_grad_by_freq']
    LR = params['lr']
    PATH = cc_input_data

    if torch.cuda.is_available():
        device = 'cuda:0'
    else:
        device = 'cpu'
    print(f'Using {device}')
    _set_default_tensor_type(device)

    dataset = contextual_watch_sequence_dataset.contextual_watch_sequence_dataset(PATH, KG, SOFTCODED_RECS, context_window=CONTEXT_WINDOW, alpha=ALPHA, class_label=CLASS_LABELS)
    train_loader = torch.utils.data.DataLoader(dataset, shuffle=True, batch_size=BATCH_SIZE)
    print(f'Dataset size: {len(dataset)}')
    print(f'Number of batches: {len(dataset) / BATCH_SIZE}')
    df = dataset.title_id

    model = word2gm.word2gm(NUM_MIXTURE_COMPONENTS, COMPONENT_DIM, dataset.content2idx, dataset.metadata2metadataidx, BATCH_SIZE, INIT_SCALE, COVAR_MODE, MEAN_NORM_CAP, SCALE_GRAD_BY_FREQ)
    model = model.to(device)
    scale_clipper = GMM_Scale_Clipper(LOW_SCALE_CAP, UPPER_SCALE_CAP)
    optimizer = optim.Adadelta(model.parameters())
    # optimizer = optim.Adam(model.parameters(), lr=LR)
    # optimizer = optim.SGD(model.parameters(), lr=LR)
    tokens = list(dataset.metadata2metadataidx.values()) + list(dataset.content2idx.values())
    for epoch in range(EPOCHS):
        # Training
        print(f'Epoch: {epoch}')
        for i, (batch, true_context, fake_context) in enumerate(train_loader):
            start = time.time()
            optimizer.zero_grad()
            log_true_energy = model(batch, true_context)
            log_fake_energy = model(batch, fake_context)
            # pairwise hinge
            loss = F.relu(MARGIN + log_fake_energy - log_true_energy)
            # pairwise logistic
            # loss = F.softplus(log_fake_energy - log_true_energy)
            # pointwise hinge
            # loss = F.relu(MARGIN - log_true_energy) + F.relu(MARGIN + log_fake_energy)
            # pointwise logistic
            # loss = F.softplus(-1. * log_true_energy) + F.softplus(log_fake_energy)
            loss = torch.mean(loss)
            loss.backward(retain_graph=False)
            optimizer.step()

            # clip the scale embeddings of every token touched in this batch
            to_update = []
            for t in batch:
                to_update.append(t.item())
            for t in true_context:
                to_update.append(t.item())
            for t in fake_context:
                to_update.append(t.item())
            to_update = list(set(to_update))
            scale_clipper(model, to_update)

            end = time.time()
            with torch.no_grad():
                if i % 50 == 0:
                    print(50 * '--')
                    print(f'Elapsed time: {end - start}')
                    print(epoch, i, loss)
                    break  # NOTE: stops after the first logged batch of each epoch; looks like a debugging shortcut
    model = model.to('cpu')
    return model, dataset
def save_model(directory, dataset, model):
    torch.save(model, directory + 'model.pth')
    with open(directory + 'dataset.pkl', 'wb') as f:
        pickle.dump(dataset, f)
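# Loading counterpart (a sketch, not part of the original gist): mirrors the
# files written by save_model(). torch.load needs the word2gm class to be
# importable, and unpickling the dataset needs its module on the path.
#
#   model = torch.load(directory + 'model.pth')
#   with open(directory + 'dataset.pkl', 'rb') as f:
#       dataset = pickle.load(f)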
def get_token_mean(token, model):
    return torch.stack([getattr(model, f'gmm_means_comp_{i}')(torch.tensor([token], dtype=torch.long))
                        for i in range(model.num_mixture_components)], dim=1).squeeze(0)

def get_token_scale(token, model):
    return torch.stack([getattr(model, f'gmm_scales_comp_{i}')(torch.tensor([token], dtype=torch.long))
                        for i in range(model.num_mixture_components)], dim=1).squeeze(0)

def get_token_mix(token, model):
    return model.gmm_mix(torch.tensor([token], dtype=torch.long)).squeeze(0)

def get_token_gmm(token, model):
    token_means = get_token_mean(token, model)
    token_scales = get_token_scale(token, model)
    token_mix = get_token_mix(token, model)
    gmm = M.MixtureSameFamily(D.Categorical(logits=token_mix), D.Independent(D.Normal(token_means, token_scales), 1))
    return gmm
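# Example (a sketch, not part of the original gist): score a point under one
# token's mixture, given a trained `model` and the `dataset` returned by
# train(). The dictionary key below is a hypothetical placeholder.
#
#   token = dataset.content2idx['some_title_id']          # hypothetical key
#   gmm = get_token_gmm(token, model)
#   log_density = gmm.log_prob(torch.zeros(model.dist_dimensions))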
def get_mean_gmm(content_mog_stack):
    """Folds a list of GMMs into the GMM of their average (1/n) * sum_i X_i, assuming the
    inputs are independent, by repeatedly convolving pairs of mixtures (component counts
    multiply at each merge)."""
    mean_gmm = None
    history_length = len(content_mog_stack)
    if history_length == 1:
        return content_mog_stack[0]
    with torch.no_grad():
        init_round = True
        while len(content_mog_stack) != 0:
            mog_a = content_mog_stack.pop(0)
            mog_b = content_mog_stack.pop(0)
            new_gmm_mix = []
            new_gmm_means = []
            new_gmm_vars = []
            for c_n in range(len(mog_a.mixture_distribution.probs)):
                for c_m in range(len(mog_b.mixture_distribution.probs)):
                    new_comp_weight = mog_a.mixture_distribution.probs[c_n] * mog_b.mixture_distribution.probs[c_m]
                    if init_round:
                        # first merge: both inputs are raw per-item GMMs, so both get the 1/n scaling
                        new_comp_mean = (1. / history_length * mog_a.components_distribution.base_dist.loc[c_n, :]) + (1. / history_length * mog_b.components_distribution.base_dist.loc[c_m, :])
                        new_comp_var = (1. / (history_length ** 2) * mog_a.components_distribution.base_dist.scale[c_n, :].pow(2)) + (1. / (history_length ** 2) * mog_b.components_distribution.base_dist.scale[c_m, :].pow(2))
                    else:
                        # later merges: mog_a is an already-scaled partial sum, only mog_b needs scaling
                        new_comp_mean = mog_a.components_distribution.base_dist.loc[c_n, :] + (1. / history_length * mog_b.components_distribution.base_dist.loc[c_m, :])
                        new_comp_var = mog_a.components_distribution.base_dist.scale[c_n, :].pow(2) + (1. / (history_length ** 2) * mog_b.components_distribution.base_dist.scale[c_m, :].pow(2))
                    new_gmm_mix.append(new_comp_weight)
                    new_gmm_means.append(new_comp_mean)
                    new_gmm_vars.append(new_comp_var)
            new_gmm_mix = torch.stack(new_gmm_mix)
            new_gmm_means = torch.stack(new_gmm_means)
            new_gmm_vars = torch.stack(new_gmm_vars)
            new_gmm = M.MixtureSameFamily(D.Categorical(probs=new_gmm_mix), D.Independent(D.Normal(new_gmm_means, torch.sqrt(new_gmm_vars)), 1))
            init_round = False
            if len(content_mog_stack) == 0:
                mean_gmm = new_gmm
                break
            else:
                content_mog_stack.insert(0, new_gmm)
    return mean_gmm
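# ---------------------------------------------------------------------------
# Usage sketch (not part of the original gist). The keys below are exactly the
# ones train() reads from `params`; every value, path, and label list is a
# hypothetical placeholder. `lr` is only consumed by the commented-out
# Adam/SGD optimizers; the active Adadelta optimizer runs with its defaults.
#
#   example_params = {
#       'margin': 1.0,
#       'batch_size': 256,
#       'num_mixture_components': 2,
#       'dim': 50,
#       'epochs': 5,
#       'ws': 5,
#       'alpha': 0.75,
#       'init_scale': 1.0,
#       'low_scale_cap': 0.01,
#       'upper_scale_cap': 2.0,
#       'mean_norm_cap': None,
#       'class_labels': my_class_labels,   # hypothetical, dataset-specific
#       'scale_grad_by_freq': False,
#       'lr': 0.1,
#   }
#   model, dataset = train('path/to/cc_input_data', kg_metadata, softcoded_recs, example_params)
#   save_model('path/to/output_dir/', dataset, model)
# ---------------------------------------------------------------------------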
# ---------------------------------------------------------------------------
# word2gm module (a separate file in this gist, imported above as `word2gm`)
# ---------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as D
from . import mixturesamefamily as M
import math
import time

# An implementation of word2gm
class word2gm(nn.Module):
    def __init__(self, num_mixture_components,
                 dist_dimensions,
                 content2idx,
                 metadata2metadataidx,
                 batch_size,
                 init_scale=1.,
                 covar_mode='diagonal',
                 mean_max_norm=None,
                 scale_grad_by_freq=False):
        super(word2gm, self).__init__()
        self.content2idx = content2idx
        self.metadata2metadataidx = metadata2metadataidx
        self.batch_size = batch_size
        self.num_mixture_components = num_mixture_components
        self.dist_dimensions = dist_dimensions
        self.init_scale = init_scale
        self.covar_mode = covar_mode
        self.mean_max_norm = mean_max_norm
        self.scale_grad_by_freq = scale_grad_by_freq
        vocab_size = len(content2idx) + len(self.metadata2metadataidx)
        # one mean embedding and one scale embedding per mixture component
        for i in range(self.num_mixture_components):
            setattr(self, f'gmm_means_comp_{i}', nn.Embedding(vocab_size, self.dist_dimensions, max_norm=mean_max_norm, scale_grad_by_freq=scale_grad_by_freq))
            nn.init.uniform_(getattr(self, f'gmm_means_comp_{i}').weight, -1. * math.sqrt(3. / self.dist_dimensions), math.sqrt(3. / self.dist_dimensions))
            setattr(self, f'gmm_scales_comp_{i}', nn.Embedding(vocab_size, self.dist_dimensions, scale_grad_by_freq=scale_grad_by_freq))
            nn.init.constant_(getattr(self, f'gmm_scales_comp_{i}').weight, init_scale)
        # mixture-weight logits, one row of K logits per token
        self.gmm_mix = nn.Embedding(vocab_size, self.num_mixture_components, scale_grad_by_freq=scale_grad_by_freq)
        nn.init.constant_(self.gmm_mix.weight, 1.)
    def forward(self, word, true_context):
        a = self.log_expected_likelihood_kernel(word, true_context)
        return a

    def _log_expected_likelihood_kernel(self, batch_token_a, batch_token_b):
        mean_batch_a = torch.stack([getattr(self, f'gmm_means_comp_{i}')(batch_token_a) for i in range(self.num_mixture_components)], dim=1)
        mean_batch_b = torch.stack([getattr(self, f'gmm_means_comp_{i}')(batch_token_b) for i in range(self.num_mixture_components)], dim=1)
        scale_batch_a = torch.stack([getattr(self, f'gmm_scales_comp_{i}')(batch_token_a) for i in range(self.num_mixture_components)], dim=1)
        scale_batch_b = torch.stack([getattr(self, f'gmm_scales_comp_{i}')(batch_token_b) for i in range(self.num_mixture_components)], dim=1)
        mix_batch_a = F.softmax(self.gmm_mix(batch_token_a), dim=1).unsqueeze(-1)
        mix_batch_b = F.softmax(self.gmm_mix(batch_token_b), dim=1).unsqueeze(-1)
        diag_var_batch_a = scale_batch_a.pow(2)
        diag_var_batch_b = scale_batch_b.pow(2)
        return self.batched_partial_log_energy_diagonal_components(mix_batch_a, mix_batch_b, mean_batch_a, mean_batch_b, diag_var_batch_a, diag_var_batch_b, len(batch_token_a))

    def log_expected_likelihood_kernel(self, batch_token_a, batch_token_b):
        log_energy = self._log_expected_likelihood_kernel(batch_token_a, batch_token_b)
        return log_energy
    def batched_partial_log_energy_diagonal_components(self, mix_batch_a, mix_batch_b, mean_batch_a, mean_batch_b, diag_var_batch_a, diag_var_batch_b, batch_size):
        eps = 0.
        # all pairwise products/sums between the components of a and b, flattened to (batch, K*K, ...)
        mix_prod = mix_batch_a.unsqueeze(2) * mix_batch_b.unsqueeze(1)
        mix_prod = mix_prod.reshape((batch_size, -1))
        mean_diff = mean_batch_a.unsqueeze(2) - mean_batch_b.unsqueeze(1)
        mean_diff = mean_diff.reshape((batch_size, -1, self.dist_dimensions))
        diag_sum = diag_var_batch_a.unsqueeze(2) + diag_var_batch_b.unsqueeze(1)
        diag_sum = diag_sum.reshape((batch_size, -1, self.dist_dimensions)) + eps
        inv_diag_sum = 1. / diag_sum
        # per-pair log N(mu_a - mu_b; 0, diag(var_a + var_b)), up to the constant -d/2 * log(2*pi)
        ple = -0.5 * torch.sum(torch.log(diag_sum), dim=-1) - 0.5 * torch.sum(mean_diff * inv_diag_sum * mean_diff, dim=-1)
        # mixture-weighted log-sum-exp over the K*K component pairs
        max_ple = torch.max(ple, dim=-1)[0].view((-1, 1))
        log_energy = max_ple.view((-1,)) + torch.log(torch.sum(mix_prod * (torch.exp(ple - max_ple)), dim=-1))
        return log_energy

    def partial_log_energy_diagonal_components(self, mean_word2, mean_word1, diag_var_word2, diag_var_word1):
        eps = 0.
        mean_diff = mean_word2 - mean_word1
        inv_diag_sum = 1. / (diag_var_word2 + diag_var_word1 + eps)
        ple = -0.5 * torch.sum(torch.log(diag_var_word2 + diag_var_word1 + eps)) - 0.5 * torch.sum(mean_diff * inv_diag_sum * mean_diff)
        return ple
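# ---------------------------------------------------------------------------
# Sanity-check sketch (not part of the original gist). With a single mixture
# component the expected-likelihood kernel above reduces to the unnormalised
# Gaussian log-density log N(mu_a - mu_b; 0, diag(var_a + var_b)); the model
# drops the constant -d/2 * log(2*pi), which does not affect rankings. The
# function name is ours; it relies only on torch, D and math imported above.
# ---------------------------------------------------------------------------
def _check_single_component_kernel(dim=4):
    content2idx = {'title_a': 0, 'title_b': 1}  # hypothetical two-token vocabulary
    model = word2gm(1, dim, content2idx, {}, batch_size=1)
    a = torch.tensor([0], dtype=torch.long)
    b = torch.tensor([1], dtype=torch.long)
    with torch.no_grad():
        log_energy = model(a, b)
        mu_a = model.gmm_means_comp_0(a)
        mu_b = model.gmm_means_comp_0(b)
        var = model.gmm_scales_comp_0(a).pow(2) + model.gmm_scales_comp_0(b).pow(2)
        ref = D.Independent(D.Normal(mu_b, torch.sqrt(var)), 1).log_prob(mu_a)
        # the kernel equals the Gaussian log-density once the dropped constant is added back
        assert torch.allclose(log_energy, ref + 0.5 * dim * math.log(2 * math.pi), atol=1e-4)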