VAE recommender implementation in PyTorch Lightning
from typing import Dict

import torch
import torch.nn as nn
import torch.nn.functional as F


class MVAERecommender(TopNRecommender):
    # TopNRecommender contains the methods used to predict the top-k items
    def __init__(self, model_conf: Dict, novelty_per_item, num_users, num_items, remove_observed=False):
        # ... configuration is skipped (it sets the attributes used below, e.g.
        # self.enc_dims, self.dec_dims, self.dropout, self.anneal_cap,
        # self.total_anneal_steps and self.update_count)

        # # # # Model Structure # # # #
        # the encoder dimensions are given as a list, so the layers are built pairwise
        self.encoder = nn.ModuleList()
        # this enumeration yields the dims in (d_in, d_out) pairs, starting at 1
        for i, (d_in, d_out) in enumerate(zip(self.enc_dims[:-1], self.enc_dims[1:]), start=1):
            # at the last encoder layer, double d_out to hold both the mean and the log-variance
            if i == len(self.enc_dims) - 1:
                d_out *= 2
            self.encoder.append(nn.Linear(d_in, d_out))
            # if NOT at the bottleneck layer, add a nonlinearity
            if i != len(self.enc_dims) - 1:
                self.encoder.append(nn.Tanh())

        self.decoder = nn.ModuleList()
        # this enumeration yields the dims in (d_in, d_out) pairs, starting at 1
        for i, (d_in, d_out) in enumerate(zip(self.dec_dims[:-1], self.dec_dims[1:]), start=1):
            self.decoder.append(nn.Linear(d_in, d_out))
            # if we're not at the last layer, add a nonlinearity
            if i != len(self.dec_dims) - 1:
                self.decoder.append(nn.Tanh())
    def forward(self, x):
        # L2-normalize each user row, then corrupt the input with dropout
        h = F.dropout(F.normalize(x), p=self.dropout, training=self.training)
        # forward through the encoder layers
        for layer in self.encoder:
            h = layer(h)
        # split h into the q(z|x) parameters:
        # the mean and the log-variance
        mu_q = h[:, :self.enc_dims[-1]]
        logvar_q = h[:, self.enc_dims[-1]:]
        std_q = torch.exp(0.5 * logvar_q)

        ## Sample z from q with the reparameterization trick
        # draw noise from a standard normal (zero mean, unit variance)
        epsilon = torch.randn_like(std_q)
        # shift and scale by the posterior mean and standard deviation;
        # at evaluation time self.training is False and z collapses to the mean
        sampled_z = mu_q + float(self.training) * epsilon * std_q

        # decode for the reconstruction error
        output = sampled_z
        for layer in self.decoder:
            output = layer(output)

        # KL divergence between q(z|x) and the standard normal prior
        kl_loss = ((0.5 * (-logvar_q + torch.exp(logvar_q) + torch.pow(mu_q, 2) - 1)).sum(1)).mean()
        return output, kl_loss
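    # A worked check of the KL term above (illustration only, not part of the
    # model): with mu_q = 0 and logvar_q = 0 in some dimension, the summand is
    # 0.5 * (-0 + exp(0) + 0**2 - 1) = 0, i.e. q(z|x) already matches the
    # standard-normal prior; with mu_q = 1 and logvar_q = 0 it is 0.5, so the
    # penalty grows as the posterior drifts away from the prior.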
    def training_step(self, batch, batch_idx):
        """One training step.

        Args:
            batch (torch.Tensor): batch of user interaction rows
            batch_idx (int): index of the mini-batch

        Returns:
            torch.Tensor: loss
        """
        # prepare the KL annealing weight: it increases linearly up to a cap
        if self.total_anneal_steps > 0:
            self.anneal = min(self.anneal_cap, 1. * self.update_count / self.total_anneal_steps)
        else:
            self.anneal = self.anneal_cap
        # forward pass
        pred_matrix, kl_loss = self(batch)
        # loss
        loss = self.__compute_loss(batch, pred_matrix, kl_loss)
        self.update_count += 1
        self.log("train_loss", loss, on_epoch=True, prog_bar=True, logger=True)
        return loss
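    # A worked example of the annealing schedule above (hypothetical settings,
    # not from the original configuration): with total_anneal_steps = 200000 and
    # anneal_cap = 0.2, the KL weight grows linearly from 0 and saturates at the
    # cap once update_count reaches 40000 (0.2 = 40000 / 200000), so early
    # training is dominated by the reconstruction term.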
    def __compute_loss(self, batch_matrix, pred_matrix, kl_loss):
        # the first term is the multinomial reconstruction loss:
        # take the log-softmax of the predicted matrix,
        # mask it with the observed interactions via elementwise multiplication,
        # then sum over items and average over users
        ce_loss = -(F.log_softmax(pred_matrix, 1) * batch_matrix).sum(1).mean()
        loss = ce_loss + kl_loss * self.anneal
        return loss
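
# A minimal usage sketch, not part of the original gist. It assumes the skipped
# configuration block reads keys such as enc_dims, dec_dims, dropout, anneal_cap
# and total_anneal_steps from model_conf, and that configure_optimizers is
# defined in the full class (or in TopNRecommender); all key names and numbers
# below are illustrative.
if __name__ == "__main__":
    import pytorch_lightning as pl
    from torch.utils.data import DataLoader

    num_users, num_items = 1000, 500
    # toy dense binary interaction matrix, one row per user
    interactions = (torch.rand(num_users, num_items) < 0.05).float()

    model_conf = {
        "enc_dims": [num_items, 600, 200],   # hypothetical encoder sizes
        "dec_dims": [200, 600, num_items],   # hypothetical decoder sizes
        "dropout": 0.5,
        "anneal_cap": 0.2,
        "total_anneal_steps": 200000,
    }
    model = MVAERecommender(model_conf, novelty_per_item=None,
                            num_users=num_users, num_items=num_items)

    # a DataLoader over the raw tensor yields dense user-row batches
    loader = DataLoader(interactions, batch_size=128, shuffle=True)
    trainer = pl.Trainer(max_epochs=1)
    trainer.fit(model, loader)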