Skip to content

Instantly share code, notes, and snippets.

@fufufukakaka
Created February 20, 2022 13:09
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save fufufukakaka/eea6eef1d21aa94cafbc40887df51d6d to your computer and use it in GitHub Desktop.
Save fufufukakaka/eea6eef1d21aa94cafbc40887df51d6d to your computer and use it in GitHub Desktop.
RecBole full_sort_scores のカスタム版(for sequential model)
import numpy as np
import torch
from recbole.data.interaction import Interaction
def _last_interaction_indexes(inter_uids, uid_series):
    """Return, for each uid in ``uid_series``, the index of its last row in ``inter_uids``.

    Args:
        inter_uids (torch.Tensor): 1-D user-id column of the dataset's interactions.
        uid_series (torch.Tensor): User ids to look up.

    Returns:
        torch.LongTensor: One row index per entry of ``uid_series``, in the same order.

    Raises:
        IndexError: If a uid does not occur in ``inter_uids``.
    """
    indexes = []
    for uid in uid_series.tolist():
        (rows,) = (inter_uids == uid).nonzero(as_tuple=True)
        if rows.numel() == 0:
            raise IndexError(f'user id {uid} has no interaction in test_data.dataset')
        # Keep only the final matching row. NOTE(review): this assumes a user's
        # rows are ordered so the last one carries the most recent history — confirm.
        indexes.append(rows[-1].item())
    return torch.tensor(indexes, dtype=torch.long)


@torch.no_grad()
def full_sort_scores(uid_series, model, test_data, device=None):
    """Calculate the scores of all items for each user in ``uid_series``.

    Customized version of RecBole's ``full_sort_scores`` that also supports
    sequential models. For sequential ``test_data``, the model input for each
    user is that user's last interaction row in ``test_data.dataset``.

    Note:
        The score of [pad] (item id 0) is always set to -inf. For
        non-sequential ``test_data``, the user's history items are also set
        to -inf. For sequential ``test_data``, history items are NOT masked.

    Args:
        uid_series (numpy.ndarray or list or torch.Tensor): User id series.
        model (AbstractRecommender): Model to predict.
        test_data (FullSortEvalDataLoader): The test_data of model.
        device (torch.device, optional): The device which model will run on.
            Defaults to ``None``, which is equivalent to ``torch.device('cpu')``.

    Returns:
        torch.Tensor: Scores reshaped to ``(-1, dataset.item_num)``. Row ``i``
        is expected to hold the scores of all items for ``uid_series[i]``.

    Raises:
        IndexError: If ``test_data`` is sequential and a user in
            ``uid_series`` has no interaction in ``test_data.dataset``.
    """
    device = device or torch.device('cpu')
    # as_tensor avoids the copy + UserWarning torch.tensor emits for tensor input.
    uid_series = torch.as_tensor(uid_series)
    uid_field = test_data.dataset.uid_field
    dataset = test_data.dataset
    model.eval()

    if not test_data.is_sequential:
        input_interaction = dataset.join(Interaction({uid_field: uid_series}))
        history_item = test_data.uid2history_item[list(uid_series)]
        # (row, col) pairs: row = position in uid_series, col = history item id.
        history_row = torch.cat([torch.full_like(hist_iid, i) for i, hist_iid in enumerate(history_item)])
        history_col = torch.cat(list(history_item))
        history_index = history_row, history_col
    else:
        # Sequential: feed only the last interaction row of each user.
        last_indexes = _last_interaction_indexes(dataset.inter_feat[uid_field], uid_series)
        input_interaction = dataset[last_indexes]
        history_index = None

    # Get scores of all items
    input_interaction = input_interaction.to(device)
    try:
        scores = model.full_sort_predict(input_interaction)
    except NotImplementedError:
        # Pair every user row with every item: users are interleaved
        # [u1 x item_num, u2 x item_num, ...] while item features are tiled
        # [all items, all items, ...], so each view row below is one user.
        input_interaction = input_interaction.repeat_interleave(dataset.item_num)
        input_interaction.update(test_data.dataset.get_item_feature().to(device).repeat(len(uid_series)))
        scores = model.predict(input_interaction)

    scores = scores.view(-1, dataset.item_num)
    scores[:, 0] = -np.inf  # set scores of [pad] to -inf
    if history_index is not None:
        scores[history_index] = -np.inf  # set scores of history items to -inf

    return scores
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment