-
-
Save krsnewwave/de6990b132d8979f026c4c633d2cf523 to your computer and use it in GitHub Desktop.
LightFM model for MLflow
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class KedroMLFlowLightFM(mlflow.pyfunc.PythonModel):
    """MLflow pyfunc model that serves recommendations from stored factors.

    ``load_context`` restores the factor/bias arrays, the item ranking table,
    the index-to-name mapping and an Annoy index from the MLflow artifacts;
    ``predict`` uses them to produce per-user or item-based recommendations.
    """

    def load_context(self, context):
        """Load every artifact the model needs at prediction time.

        Args:
            context: MLflow model context; ``context.artifacts`` must map the
                keys ``idx_to_names``, ``item_factors``, ``user_factors``,
                ``item_biases``, ``user_biases``, ``item_rank``,
                ``annoy_index`` and ``params`` to local file paths.
        """
        contents = context.artifacts

        # SECURITY NOTE: cloudpickle executes arbitrary code while loading;
        # only load artifacts that come from a trusted model registry.
        # The context manager guarantees the file handle is closed.
        with open(contents["idx_to_names"], "rb") as handle:
            self.idx_to_names = cloudpickle.load(handle)

        self.item_factors = np.load(contents["item_factors"])
        self.user_factors = np.load(contents["user_factors"])
        self.item_biases = np.load(contents["item_biases"])
        self.user_biases = np.load(contents["user_biases"])
        self.item_rank = pd.read_csv(contents["item_rank"])

        # Load the Annoy index; it must be created with the same vector
        # dimension (number of item-factor columns) and metric it was built
        # with before the stored file can be loaded.
        annoy_index_file_path = contents["annoy_index"]
        with open(contents["params"], "rb") as handle:
            params = cloudpickle.load(handle)
        metric = params["metric"]
        self.annoy_index = AnnoyIndex(self.item_factors.shape[1], metric)
        self.annoy_index.load(annoy_index_file_path)

    def _nearest_neighbors(self, item_ids):
        """Return the concatenated Annoy neighbours of every id in *item_ids*.

        Duplicates are kept, in query order, exactly as Annoy returns them.
        """
        neighbors = []
        for item_id in item_ids:
            # NOTE(review): `n` (neighbours per item) is not defined in this
            # class; it is presumably a module-level constant -- confirm.
            neighbors.extend(self.annoy_index.get_nns_by_item(item_id, n))
        return neighbors

    def predict(self, context, model_input: pd.DataFrame):
        """Produce recommendations for the rows of *model_input*.

        Args:
            context: MLflow model context (unused here).
            model_input (pd.DataFrame): if it contains the ``USER_ID`` column
                the prediction is a warm start: each user's items are expanded
                to their nearest neighbours, which are then scored for that
                user. Otherwise, if it contains ``ITEM_ID``, the prediction is
                item-based only: neighbours are ordered by the stored ranking.

        Returns:
            list: in the user branch, one ``{USER_ID: user_id, "recos": [...]}``
            dict per user; in the item branch, a flat list of at most
            ``N_RECOS`` values taken from the ``MOVIE_NAME`` column.

        Raises:
            ValueError: if *model_input* has neither ``USER_ID`` nor
                ``ITEM_ID``.
        """
        # (1) dataframe contains user ids: score neighbours per user
        if USER_ID in model_input:
            # group every user to the unique item ids they appear with
            users_to_items = model_input.groupby(USER_ID)[ITEM_ID].unique()
            list_recos = []
            # Series.items() replaces iteritems(), removed in pandas 2.0.
            for user_id, items in users_to_items.items():
                list_nn = self._nearest_neighbors(items)

                # factors/biases of the candidate items
                item_factors = self.item_factors[list_nn]
                item_biases = self.item_biases[list_nn]

                # factors/bias of the user
                user_factors = self.user_factors[user_id].reshape(1, -1)
                user_bias = self.user_biases[user_id].reshape(1)

                scores = RecommenderUtils.produce_scores(
                    item_factors, item_biases, user_factors, user_bias)

                # argsort, then map positions back to the item indexes.
                # The trailing [0] presumably unwraps a (1, n_items) score
                # matrix -- TODO confirm against produce_scores.
                # NOTE(review): np.argsort is ascending, so the lowest scores
                # come first; confirm this is the intended ranking order.
                sorted_items = np.array(list_nn)[np.argsort(scores)][0]

                recos = [self.idx_to_names[v] for v in sorted_items][:N_RECOS]
                list_recos.append({USER_ID: user_id, "recos": recos})
            return list_recos

        # (2) only item ids: rank the neighbours by the stored ranking
        if ITEM_ID in model_input:
            list_nn = self._nearest_neighbors(model_input[ITEM_ID])
            df_rank = self.item_rank.set_index(ITEM_POSITIONAL_INDEX_NAME)
            df_rank_subset = df_rank.loc[list_nn].sort_values(
                by=NUM_USERS_RANK_SORT_NAME, ascending=False)
            return df_rank_subset[MOVIE_NAME][:N_RECOS].tolist()

        raise ValueError("Please correct format")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment