VAE recommender implementation on PyTorch Lightning
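The training objective below combines a softmax (multinomial) reconstruction term, -(log_softmax(x̂) * x).sum(1).mean(), with the KL divergence between the approximate posterior q(z|x) and a standard normal prior; the KL term is scaled by a linearly annealed weight capped at anneal_cap.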
from typing import Dict

import torch
import torch.nn as nn
import torch.nn.functional as F


class MVAERecommender(TopNRecommender):
    # TopNRecommender contains methods to predict the top-k items
    def __init__(self, model_conf: Dict, novelty_per_item, num_users, num_items, remove_observed=False):
        # ... configuration is skipped

        # # # # Model Structure # # # #
        # the encoder dimensions are given as a list of layer sizes
        self.encoder = nn.ModuleList()
        # this enumeration produces the dims in pairs, starting at 1
        for i, (d_in, d_out) in enumerate(zip(self.enc_dims[:-1], self.enc_dims[1:]), start=1):
            # double d_out at the last layer for the mean and variance parameters
            if i == len(self.enc_dims) - 1:
                d_out *= 2
            self.encoder.append(nn.Linear(d_in, d_out))
            # if NOT at the middle bottleneck point, simply add a nonlinearity
            if i != len(self.enc_dims) - 1:
                self.encoder.append(nn.Tanh())

        self.decoder = nn.ModuleList()
        # this enumeration produces the dims in pairs, starting at 1
        for i, (d_in, d_out) in enumerate(zip(self.dec_dims[:-1], self.dec_dims[1:]), start=1):
            self.decoder.append(nn.Linear(d_in, d_out))
            # if we're not at the last layer, add a nonlinearity
            if i != len(self.dec_dims) - 1:
                self.decoder.append(nn.Tanh())
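
    # Example (illustrative sizes, not from the original): with enc_dims = [num_items, 200, 64],
    # the encoder becomes Linear(num_items, 200) -> Tanh -> Linear(200, 128), where the final
    # 128 = 2 * 64 holds the mean and log-variance; dec_dims would typically mirror it,
    # e.g. [64, 200, num_items].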
    def forward(self, x):
        # corrupt the input after normalizing each row by its Euclidean norm
        h = F.dropout(F.normalize(x), p=self.dropout, training=self.training)

        # pass through the encoder layers
        for layer in self.encoder:
            h = layer(h)

        # split h into the q(z|x) parameters: mean and log-variance
        mu_q = h[:, :self.enc_dims[-1]]
        logvar_q = h[:, self.enc_dims[-1]:]
        std_q = torch.exp(0.5 * logvar_q)

        ## sample z from q(z|x) via the reparameterization trick
        # draw epsilon from a zero-mean Gaussian (std 0.01 here)
        epsilon = torch.zeros_like(std_q).normal_(mean=0, std=0.01)
        # add the noise only during training (self.training is False at eval time)
        sampled_z = mu_q + self.training * epsilon * std_q

        # decode for the reconstruction term
        output = sampled_z
        for layer in self.decoder:
            output = layer(output)

        # KL divergence between q(z|x) and the standard normal prior
        kl_loss = ((0.5 * (-logvar_q + torch.exp(logvar_q) + torch.pow(mu_q, 2) - 1)).sum(1)).mean()
        return output, kl_loss
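
    # For reference, the KL term above is the closed form of
    # KL( N(mu, sigma^2) || N(0, 1) ) = 0.5 * sum(-log(sigma^2) + sigma^2 + mu^2 - 1),
    # summed over the latent dimensions and averaged over the batch.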
    def training_step(self, batch, batch_idx):
        """One training step.

        Args:
            batch (torch.Tensor): batch of rows from the user-item interaction matrix
            batch_idx (int): index of the mini-batch

        Returns:
            torch.Tensor: loss
        """
        # prepare the KL annealing weight: linearly increasing, capped at anneal_cap
        if self.total_anneal_steps > 0:
            self.anneal = min(self.anneal_cap, 1. * self.update_count / self.total_anneal_steps)
        else:
            self.anneal = self.anneal_cap

        # forward pass
        pred_matrix, kl_loss = self(batch)

        # loss
        loss = self.__compute_loss(batch, pred_matrix, kl_loss)
        self.update_count += 1
        self.log("train_loss", loss, on_epoch=True, prog_bar=True, logger=True)
        return loss
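
    # Example (assumed values, for illustration only): with total_anneal_steps = 200000
    # and anneal_cap = 0.2, the KL weight rises linearly from 0 to 0.2 over the first
    # 200k updates and stays at 0.2 afterwards.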
    def __compute_loss(self, batch_matrix, pred_matrix, kl_loss):
        # the first term is the reconstruction loss:
        # log-softmax the predicted matrix, mask it with the observed interactions
        # via multiplication, then sum per user and average over the batch
        ce_loss = -(F.log_softmax(pred_matrix, 1) * batch_matrix).sum(1).mean()
        loss = ce_loss + kl_loss * self.anneal
        return loss
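

# --- Usage sketch (assumptions, not part of the original gist) ---
# This shows how the module above might be trained. It assumes that model_conf
# uses the key names below (guessed from the attributes referenced above), that
# the skipped configuration / TopNRecommender base class sets those attributes
# and defines configure_optimizers(), and that each batch is a dense block of
# user rows from the interaction matrix.
if __name__ == "__main__":
    import pytorch_lightning as pl
    from torch.utils.data import DataLoader

    num_users, num_items = 1000, 5000
    # toy implicit-feedback matrix: roughly 1% of the entries are interactions
    interactions = (torch.rand(num_users, num_items) < 0.01).float()
    loader = DataLoader(interactions, batch_size=128, shuffle=True)

    model_conf = {
        "enc_dims": [num_items, 200, 64],       # assumed key names and sizes
        "dec_dims": [64, 200, num_items],
        "dropout": 0.5,
        "anneal_cap": 0.2,
        "total_anneal_steps": 200000,
    }
    model = MVAERecommender(model_conf, novelty_per_item=None,
                            num_users=num_users, num_items=num_items)

    trainer = pl.Trainer(max_epochs=10)
    trainer.fit(model, loader)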