Skip to content

Instantly share code, notes, and snippets.

@krsnewwave
Last active May 1, 2022 15:24
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save krsnewwave/d456d66910f628d4becf1c64ad1f79dd to your computer and use it in GitHub Desktop.
Save krsnewwave/d456d66910f628d4becf1c64ad1f79dd to your computer and use it in GitHub Desktop.
class CDAE(pl.LightningModule):
    """Collaborative Denoising Auto-Encoder (CDAE) recommender as a Lightning module.

    A user's rating row is corrupted with dropout (training mode only), encoded
    by a linear layer, summed with a learned per-user embedding, passed through
    an activation, and decoded back to one score per item.

    NOTE(review): this is an excerpt -- lines marked ``# ...`` stand for elided
    code (attribute initialisation, normalisation, negative sampling, extra
    metrics). Attributes used below (``num_users``, ``num_items``,
    ``corruption_ratio``, ``act``, ``out_act``, ``max_k``) and helpers
    (``__apply_activation``, ``predict_topk``, ``__prec_recall_ndcg``) are not
    defined in the visible code and are presumably provided there -- confirm
    against the full source.
    """

    def __init__(self, model_conf : Dict, novelty_per_item, num_users, num_items, remove_observed = False, ):
        """Build the user embedding, encoder/decoder layers and the loss.

        Args:
            model_conf: hyperparameter dict. ``"hidden_dim"`` is read here; any
                other keys are presumably read in the elided initialisation.
                NOTE(review): this dict is mutated in place -- ``"num_users"``
                and ``"num_items"`` are written into the caller's object.
                Consider copying it first if callers reuse the dict.
            novelty_per_item: not used in the visible code.
            num_users: number of users; in the visible code it is only written
                into ``model_conf``.
            num_items: number of items; in the visible code it is only written
                into ``model_conf``.
            remove_observed: not used in the visible code.
        """
        super().__init__()
        self.hidden_dim = model_conf["hidden_dim"]
        # ... other self. initializations
        # NOTE(review): self.num_users / self.num_items are read below but never
        # assigned in the visible code; they must be set in the elided lines.
        # One learned hidden_dim-sized vector per user, added to the encoded
        # ratings in forward().
        self.user_embedding = nn.Embedding(self.num_users, self.hidden_dim)
        # Item space -> hidden space, and back.
        self.encoder = nn.Linear(self.num_items, self.hidden_dim)
        self.decoder = nn.Linear(self.hidden_dim, self.num_items)
        # BCELoss requires predictions in [0, 1]; reduction='sum' sums the loss
        # over every entry of the batch instead of averaging.
        self.criterion = nn.BCELoss(reduction='sum')
        # for flattened dictionary logging, add to conf
        model_conf["num_users"] = num_users
        model_conf["num_items"] = num_items
        # NOTE(review): when a dict is passed positionally, Lightning stores that
        # dict itself as the hparams; `ignore` appears to filter only the
        # automatically collected __init__ arguments, so it likely has no effect
        # here -- verify against the Lightning version in use.
        self.save_hyperparameters(model_conf, ignore=["novelty_per_item", "remove_observed"])

    def forward(self, x):
        """Score every item for each user in the batch.

        Args:
            x: a 2-tuple ``(rating_matrix, user_idx)``. The last dimension of
                ``rating_matrix`` must equal ``num_items`` (the input size of
                ``self.encoder``). ``user_idx`` holds integer indices in
                ``[0, num_users)`` for ``self.user_embedding``.

        Returns:
            The decoder output (last dimension ``num_items``) passed through
            ``self.out_act``. NOTE(review): it must lie in [0, 1] for the
            BCELoss used by the training/evaluation steps; presumably
            ``out_act`` is a sigmoid -- confirm.
        """
        rating_matrix, user_idx = x
        # ... some normalize options here
        # ...
        # (1) corrupt the rating matrix when in training
        # F.dropout returns the input unchanged when training=False (eval mode).
        # In training it zeroes entries with probability corruption_ratio and
        # rescales the survivors by 1 / (1 - corruption_ratio).
        corrupted_rating_matrix = F.dropout(rating_matrix, self.corruption_ratio, training=self.training)
        # (2) build the collaborative denoising autoencoder
        # first term - ratings, second term - user embedding
        embedded_users = self.user_embedding(user_idx)
        encoded_ratings = self.encoder(corrupted_rating_matrix)
        enc = torch.add(embedded_users, encoded_ratings)
        # `self.__apply_activation` is name-mangled to `_CDAE__apply_activation`,
        # so subclasses cannot override it under the same name.
        enc = self.__apply_activation(self.act, enc)
        dec = self.decoder(enc)
        return self.__apply_activation(self.out_act, dec)

    def training_step(self, batch, batch_idx):
        """Compute the reconstruction loss for one training batch.

        After the (elided) negative-sampling step, ``batch`` is unpacked by
        ``__get_pred_matrix`` as ``(batch_matrix, user_ids)``. The model is
        trained to reproduce ``batch_matrix`` from its dropout-corrupted copy.

        Returns:
            The summed BCE loss, returned to Lightning as the training loss.
        """
        # negative sampling options here
        # ...
        pred_matrix, batch_matrix = self.__get_pred_matrix(batch, batch_idx)
        loss = self.criterion(pred_matrix, batch_matrix)
        # logs metrics for each training_step,
        # and the average across the epoch, to the progress bar and logger
        self.log("train_loss", loss, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Evaluate one validation batch via ``__shared_evaluation``.

        ``"val"`` is passed as the prefix. Presumably the elided metric code
        uses it to name the logged metrics; the visible code does not use it.
        """
        score_prefix = "val"
        return self.__shared_evaluation(batch, batch_idx, score_prefix)

    def __get_pred_matrix(self, batch, batch_idx):
        """Run the model on a ``(batch_matrix, user_ids)`` batch.

        The whole ``batch`` is forwarded, and ``forward`` unpacks exactly two
        elements, so ``batch`` must be that 2-tuple. ``user_ids`` and
        ``batch_idx`` are not otherwise used.

        Returns:
            ``(pred_matrix, batch_matrix)``. ``batch_matrix`` is the input as
            received, i.e. not the dropout-corrupted copy (``F.dropout`` is not
            in-place), and serves as the reconstruction target.
        """
        batch_matrix, user_ids = batch
        pred_matrix = self(batch)
        return pred_matrix, batch_matrix

    def __shared_evaluation(self, batch, batch_idx, prefix):
        """Compute the loss and top-k ranking metrics for one evaluation batch.

        Args:
            batch: a 3-tuple ``(train_matrix, test_matrix, user_ids)``. Scores
                are predicted from ``train_matrix`` and compared against
                ``test_matrix`` (presumably held-out interactions -- confirm).
            batch_idx: unused in the visible code.
            prefix: metric-name prefix (``"val"`` from ``validation_step``).
                It is unused in the visible code.

        Returns:
            ``score_dict``, which is also logged via ``self.log_dict``.

        NOTE(review): ``score_dict`` is never assigned in the visible code, and
        ``loss`` and ``scores`` are computed but not used. Unless the elided
        "other metrics" section builds ``score_dict`` (e.g. from ``scores``,
        ``loss`` and ``prefix``), this raises NameError -- confirm against the
        full source.
        """
        # Gradients must be off before the `.numpy()` calls below, which raise
        # on tensors that require grad. Lightning's validation loop already
        # disables them, so this guard is redundant there but harmless.
        with torch.no_grad():
            train_matrix, test_matrix, user_ids = batch
            pred_matrix = self((train_matrix, user_ids))
            # compute loss on targets
            loss = self.criterion(pred_matrix, test_matrix)
            # (1) convert test matrix to dictionary
            # (exact structure is defined by np_mat_to_dict, not visible here)
            targets = np_mat_to_dict(test_matrix.cpu().numpy())
            # (2) Get the top-k predictions
            # From here on pred_matrix is a NumPy array, not a tensor.
            pred_matrix = pred_matrix.cpu().numpy()
            top_k_recos = self.predict_topk(pred_matrix, self.max_k)
            # (3) Precision, Recall, NDCG @ k
            scores = self.__prec_recall_ndcg(top_k_recos, targets)
            # ... and other metrics
            # ...
            self.log_dict(score_dict, on_epoch=True, prog_bar=True, logger=True)
            return score_dict
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment