from typing import Dict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl


class CDAE(pl.LightningModule):
    def __init__(self, model_conf: Dict, novelty_per_item, num_users, num_items, remove_observed=False):
        super().__init__()
        self.num_users = num_users
        self.num_items = num_items
        self.novelty_per_item = novelty_per_item
        self.remove_observed = remove_observed
        self.hidden_dim = model_conf["hidden_dim"]
        # the conf keys below are assumptions; the original gist elides these initializations
        self.corruption_ratio = model_conf["corruption_ratio"]
        self.act = model_conf["act"]
        self.out_act = model_conf["out_act"]
        self.max_k = model_conf["max_k"]
        # ... other self. initializations
        self.user_embedding = nn.Embedding(self.num_users, self.hidden_dim)
        self.encoder = nn.Linear(self.num_items, self.hidden_dim)
        self.decoder = nn.Linear(self.hidden_dim, self.num_items)
        self.criterion = nn.BCELoss(reduction='sum')
        # for flattened dictionary logging, add to conf
        model_conf["num_users"] = num_users
        model_conf["num_items"] = num_items
        self.save_hyperparameters(model_conf, ignore=["novelty_per_item", "remove_observed"])
    def forward(self, x):
        rating_matrix, user_idx = x
        # ... some normalize options here
        # ...
        # (1) corrupt the rating matrix when in training
        corrupted_rating_matrix = F.dropout(rating_matrix, self.corruption_ratio, training=self.training)
        # (2) build the collaborative denoising autoencoder
        # first term - ratings, second term - user embedding
        embedded_users = self.user_embedding(user_idx)
        encoded_ratings = self.encoder(corrupted_rating_matrix)
        enc = torch.add(embedded_users, encoded_ratings)
        enc = self.__apply_activation(self.act, enc)
        dec = self.decoder(enc)
        return self.__apply_activation(self.out_act, dec)
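    # __apply_activation is not shown in the gist; a minimal sketch, assuming
    # activations are configured by name in model_conf. Note the output
    # activation must keep values in [0, 1] (e.g. sigmoid) since the loss is BCELoss.
    @staticmethod
    def __apply_activation(act_name, x):
        if act_name == "sigmoid":
            return torch.sigmoid(x)
        if act_name == "tanh":
            return torch.tanh(x)
        if act_name == "relu":
            return torch.relu(x)
        if act_name in (None, "none"):
            return x
        raise ValueError(f"unsupported activation: {act_name}")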
    def training_step(self, batch, batch_idx):
        # negative sampling options here
        # ...
        pred_matrix, batch_matrix = self.__get_pred_matrix(batch, batch_idx)
        loss = self.criterion(pred_matrix, batch_matrix)
        # logs metrics for each training_step,
        # and the average across the epoch, to the progress bar and logger
        self.log("train_loss", loss, on_epoch=True, prog_bar=True, logger=True)
        return loss
    def validation_step(self, batch, batch_idx):
        score_prefix = "val"
        return self.__shared_evaluation(batch, batch_idx, score_prefix)
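    # configure_optimizers is required by LightningModule but elided in the
    # gist; a minimal sketch, assuming a "learning_rate" key in model_conf
    # (available via self.hparams after save_hyperparameters):
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams["learning_rate"])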
    def __get_pred_matrix(self, batch, batch_idx):
        batch_matrix, user_ids = batch
        pred_matrix = self((batch_matrix, user_ids))
        return pred_matrix, batch_matrix
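    # predict_topk and __prec_recall_ndcg are referenced below but not shown in
    # the gist. Sketches under assumptions: predict_topk returns the indices of
    # the k highest-scored items per user row, and the metrics assume binary
    # relevance with `targets` mapping user row -> list of held-out item indices.
    @staticmethod
    def predict_topk(pred_matrix, k):
        # argpartition finds the top-k in O(n) per row; the final argsort
        # orders only that k-sized slice by descending score
        top_k_unsorted = np.argpartition(-pred_matrix, k, axis=1)[:, :k]
        rows = np.arange(pred_matrix.shape[0])[:, None]
        order = np.argsort(-pred_matrix[rows, top_k_unsorted], axis=1)
        return top_k_unsorted[rows, order]

    def __prec_recall_ndcg(self, top_k_recos, targets):
        k = self.max_k
        precisions, recalls, ndcgs = [], [], []
        for user, target_items in targets.items():
            target_set = set(target_items)
            hits = [1.0 if item in target_set else 0.0 for item in top_k_recos[user]]
            n_hits = sum(hits)
            precisions.append(n_hits / k)
            recalls.append(n_hits / len(target_set))
            dcg = sum(h / np.log2(rank + 2) for rank, h in enumerate(hits))
            idcg = sum(1.0 / np.log2(rank + 2) for rank in range(min(len(target_set), k)))
            ndcgs.append(dcg / idcg)
        return {
            f"precision@{k}": float(np.mean(precisions)),
            f"recall@{k}": float(np.mean(recalls)),
            f"ndcg@{k}": float(np.mean(ndcgs)),
        }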
    def __shared_evaluation(self, batch, batch_idx, prefix):
        with torch.no_grad():
            train_matrix, test_matrix, user_ids = batch
            pred_matrix = self((train_matrix, user_ids))
            # compute loss on targets
            loss = self.criterion(pred_matrix, test_matrix)
            # (1) convert test matrix to dictionary
            targets = np_mat_to_dict(test_matrix.cpu().numpy())
            # (2) get the top-k predictions
            pred_matrix = pred_matrix.cpu().numpy()
            top_k_recos = self.predict_topk(pred_matrix, self.max_k)
            # (3) Precision, Recall, NDCG @ k
            scores = self.__prec_recall_ndcg(top_k_recos, targets)
            # ... and other metrics
            # ...
            # the gist elides how score_dict is built; assumed here to prefix
            # the metric names (e.g. "val_recall@10") and to include the loss
            score_dict = {f"{prefix}_{name}": value for name, value in scores.items()}
            score_dict[f"{prefix}_loss"] = loss
            self.log_dict(score_dict, on_epoch=True, prog_bar=True, logger=True)
        return score_dict
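
# np_mat_to_dict comes from a utils module in the original; a plausible sketch,
# mapping each user row to the column indices of its nonzero entries:
def np_mat_to_dict(matrix):
    return {
        user: np.flatnonzero(row).tolist()
        for user, row in enumerate(matrix)
        if row.any()
    }


# A hypothetical usage sketch; the conf values and loaders below are
# illustrative assumptions, not part of the original gist:
if __name__ == "__main__":
    model_conf = {
        "hidden_dim": 50,
        "corruption_ratio": 0.5,
        "act": "tanh",
        "out_act": "sigmoid",  # keeps outputs in [0, 1] for BCELoss
        "max_k": 10,
        "learning_rate": 1e-3,
    }
    model = CDAE(model_conf, novelty_per_item=None, num_users=1000, num_items=2000)
    trainer = pl.Trainer(max_epochs=10)
    # train_loader yields (rating_matrix, user_idx) batches and
    # val_loader yields (train_matrix, test_matrix, user_idx) batches:
    # trainer.fit(model, train_loader, val_loader)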