Created
February 20, 2022 13:09
-
-
Save fufufukakaka/eea6eef1d21aa94cafbc40887df51d6d to your computer and use it in GitHub Desktop.
RecBole full_sort_scores のカスタム版(for sequential model)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import torch | |
from recbole.data.interaction import Interaction | |
@torch.no_grad()
def full_sort_scores(uid_series, model, test_data, device=None):
    """Calculate the scores of all items for each user in uid_series.

    Custom variant for sequential models: when ``test_data.is_sequential`` is
    true, only the last row of ``dataset.inter_feat`` that matches each user
    id is fed to the model, so one input row is built per requested user.

    Note:
        The score of [pad] (column 0) is always set to -inf.
        History items are set to -inf only when ``test_data`` is not
        sequential; in the sequential branch no history mask is applied.

    Args:
        uid_series (numpy.ndarray or list): User id series.
        model (AbstractRecommender): Model to predict.
        test_data (FullSortEvalDataLoader): The test_data of model.
        device (torch.device, optional): The device which model will run on.
            Defaults to ``None``.
            Note: ``device=None`` is equivalent to
            ``device=torch.device('cpu')``.

    Returns:
        torch.Tensor: the scores of all items for each user in uid_series,
        viewed as ``(-1, dataset.item_num)``.

    Raises:
        IndexError: In the sequential branch, if a user id has no row in
            ``dataset.inter_feat``.
    """
    device = device or torch.device('cpu')
    uid_series = torch.tensor(uid_series)
    dataset = test_data.dataset
    uid_field = dataset.uid_field
    model.eval()

    if not test_data.is_sequential:
        input_interaction = dataset.join(Interaction({uid_field: uid_series}))
        history_item = test_data.uid2history_item[list(uid_series)]
        # Row index i is repeated once per history item of the i-th user so
        # that (history_row, history_col) addresses every history cell.
        history_row = torch.cat(
            [torch.full_like(hist_iid, i) for i, hist_iid in enumerate(history_item)]
        )
        history_col = torch.cat(list(history_item))
        history_index = history_row, history_col
    else:
        # Sequential: keep only the very last interaction index of each user.
        user_column = dataset.inter_feat[uid_field]
        last_interaction_indexes = []
        for uid in uid_series:
            positions = (user_column == uid).nonzero(as_tuple=True)[0]
            if positions.numel() == 0:
                raise IndexError(
                    f'user id {uid.item()} has no interaction in test_data.dataset'
                )
            last_interaction_indexes.append(positions[-1].item())
        input_interaction = dataset[torch.tensor(last_interaction_indexes)]
        history_index = None

    # Get scores of all items
    input_interaction = input_interaction.to(device)
    try:
        scores = model.full_sort_predict(input_interaction)
    except NotImplementedError:
        # Fall back to scoring every (user, item) pair explicitly. Each user
        # row must be repeated contiguously (u1, u1, ..., u2, u2, ...) while
        # the item features are tiled (i0..iN, i0..iN, ...); only then does
        # the view below put all of one user's scores on one row.
        input_interaction = input_interaction.repeat_interleave(dataset.item_num)
        input_interaction.update(
            dataset.get_item_feature().to(device).repeat(len(uid_series))
        )
        scores = model.predict(input_interaction)

    scores = scores.view(-1, dataset.item_num)
    scores[:, 0] = -np.inf  # set scores of [pad] to -inf
    if history_index is not None:
        scores[history_index] = -np.inf  # set scores of history items to -inf
    return scores
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
https://github.com/RUCAIBox/RecBole/blob/master/recbole/utils/case_study.py#L53