Skip to content

Instantly share code, notes, and snippets.

@krsnewwave
Created March 27, 2022 17:06
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save krsnewwave/de6990b132d8979f026c4c633d2cf523 to your computer and use it in GitHub Desktop.
Save krsnewwave/de6990b132d8979f026c4c633d2cf523 to your computer and use it in GitHub Desktop.
LightFM model for MLflow
class KedroMLFlowLightFM(mlflow.pyfunc.PythonModel):
    """MLflow ``pyfunc`` wrapper around LightFM embeddings and an Annoy index.

    ``load_context`` reads the LightFM user/item factors and biases, an
    item-ranking table, an index-to-name mapping and a prebuilt Annoy index
    from the model artifacts. ``predict`` retrieves candidate items from the
    Annoy index and then either re-scores them with each user's LightFM
    factors (input has a USER_ID column) or orders them by the stored item
    ranking (input has only an ITEM_ID column).
    """

    # Number of Annoy neighbours fetched per seed item. ``None`` means
    # "use N_RECOS". Kept as a class attribute so instances pickled before
    # this attribute existed still resolve it.
    n_neighbors = None

    def __init__(self, n_neighbors=None):
        """Create the wrapper.

        Args:
            n_neighbors: neighbours to fetch from Annoy per seed item;
                ``None`` falls back to N_RECOS at prediction time.
        """
        super().__init__()
        self.n_neighbors = n_neighbors

    def load_context(self, context):
        """Load model artifacts into memory.

        Args:
            context: MLflow context whose ``artifacts`` maps artifact names
                to local file paths.
        """
        contents = context.artifacts
        # NOTE(review): cloudpickle can execute arbitrary code while loading;
        # these artifacts must come from a trusted model store.
        with open(contents["idx_to_names"], "rb") as f:
            self.idx_to_names = cloudpickle.load(f)
        self.item_factors = np.load(contents["item_factors"])
        self.user_factors = np.load(contents["user_factors"])
        self.item_biases = np.load(contents["item_biases"])
        self.user_biases = np.load(contents["user_biases"])
        self.item_rank = pd.read_csv(contents["item_rank"])

        # Rebuild the Annoy index with the same dimensionality/metric it was
        # built with, then memory-map the saved index file.
        with open(contents["params"], "rb") as f:
            params = cloudpickle.load(f)
        metric = params["metric"]
        self.annoy_index = AnnoyIndex(self.item_factors.shape[1], metric)
        self.annoy_index.load(contents["annoy_index"])

    def _candidate_items(self, seed_items):
        """Return the de-duplicated Annoy neighbours of every seed item.

        First occurrence wins, so the order is deterministic. Seed ids are
        passed straight to ``get_nns_by_item`` — NOTE(review): they are
        presumably positional item indices; confirm against the caller.
        """
        n_neighbors = self.n_neighbors if self.n_neighbors is not None else N_RECOS
        # dict.fromkeys keeps insertion order while dropping repeats that
        # arise when several seed items share neighbours.
        return list(dict.fromkeys(
            neighbour
            for item_id in seed_items
            for neighbour in self.annoy_index.get_nns_by_item(int(item_id), n_neighbors)
        ))

    def _predict_for_users(self, model_input):
        """Warm-start path: rank each user's candidates by LightFM score."""
        users_to_items = model_input.groupby(USER_ID)[ITEM_ID].unique()
        list_recos = []
        # Series.iteritems() was removed in pandas 2.0; items() is equivalent.
        for user_id, items in users_to_items.items():
            candidates = self._candidate_items(items)
            item_factors = self.item_factors[candidates]
            item_biases = self.item_biases[candidates]
            # NOTE(review): user_id is used as a row index into the factor
            # matrices, so it is presumably a positional user index.
            user_factors = self.user_factors[user_id].reshape(1, -1)
            user_bias = self.user_biases[user_id].reshape(1)
            scores = RecommenderUtils.produce_scores(
                item_factors, item_biases, user_factors, user_bias)
            # The original code indexed the sorted result with [0], which
            # implies a single-row (1, n_candidates) score array; ravel
            # flattens it. Sort descending so the highest-scoring
            # (most relevant) items come first.
            order = np.argsort(-np.asarray(scores).ravel())
            sorted_items = np.asarray(candidates)[order]
            recos = [self.idx_to_names[v] for v in sorted_items[:N_RECOS]]
            list_recos.append({USER_ID: user_id, "recos": recos})
        return list_recos

    def _predict_for_items(self, model_input):
        """Cold-start path: order neighbours of the given items by rank table."""
        candidates = self._candidate_items(model_input[ITEM_ID])
        df_rank = self.item_rank.set_index(ITEM_POSITIONAL_INDEX_NAME)
        df_rank_subset = df_rank.loc[candidates].sort_values(
            by=NUM_USERS_RANK_SORT_NAME, ascending=False)
        return df_rank_subset[MOVIE_NAME].iloc[:N_RECOS].tolist()

    def predict(self, context, model_input : pd.DataFrame):
        """Recommend items.

        Args:
            context: MLflow context (unused here).
            model_input (pd.DataFrame): if it has a USER_ID column, the
                prediction is warm-start (per-user LightFM scoring of the
                Annoy neighbours of that user's ITEM_IDs); otherwise, if it
                has an ITEM_ID column, the prediction is item-based only.

        Returns:
            Warm-start: a list of ``{USER_ID: user_id, "recos": [names...]}``
            dicts, one per user. Item-based: a list of up to N_RECOS names.

        Raises:
            ValueError: if neither USER_ID nor ITEM_ID is a column.
        """
        # ``in`` on a DataFrame tests column labels.
        if USER_ID in model_input:
            return self._predict_for_users(model_input)
        if ITEM_ID in model_input:
            return self._predict_for_items(model_input)
        raise ValueError("Please correct format")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment