Skip to content

Instantly share code, notes, and snippets.

@oscar-defelice
Last active June 3, 2020 08:39
Show Gist options
  • Save oscar-defelice/0ba2c81f1f8c55f9535f7af636885379 to your computer and use it in GitHub Desktop.
Save oscar-defelice/0ba2c81f1f8c55f9535f7af636885379 to your computer and use it in GitHub Desktop.
def get_triplets_hard(batch_size, X_usr, X_item, df, return_cache=False):
    """
    Build a batch of (anchor, positive, hard negative) triplets to feed the model.

    For every slot in the batch a random user (anchor) is drawn, then one of
    its positive items is drawn at random, and finally the negative item with
    the smallest ``cosine_dist`` to that positive item is selected (the
    "hard" negative). If the hard negative cannot be computed, a random
    negative is used instead.

    Ids are used as 1-based row numbers: id ``k`` is looked up at row
    ``k - 1`` of ``X_usr`` / ``X_item``. The id ``0`` is reserved to mean
    "no item available"; the corresponding feature row is left at zero.

    Parameters
    ----------
    batch_size : int
        size of the batch.
    X_usr : numpy array of shape (n_users, n_user_features)
        array of user metadata.
    X_item : numpy array of shape (n_items, n_item_features)
        array of item metadata.
    df : Pandas DataFrame
        dataframe containing user-item ratings. Its index values are used
        as the pool of user ids anchors are drawn from.
    return_cache : bool
        whether to also return the ids corresponding to the triplets.
        default: False

    Returns
    -------
    triplets : list of numpy arrays
        list containing 3 tensors A,P,N corresponding to:
            - Anchor A : (batch_size, n_user_features)
            - Positive P : (batch_size, n_item_features)
            - Negative N : (batch_size, n_item_features)
    cache : dict
        only returned when ``return_cache`` is True. Maps ``'users'``,
        ``'positive'`` and ``'negative'`` to the lists of ids (one entry per
        batch row, ``0`` meaning "missing") used to build the triplets.
    """
    # constant values
    n_user_features = X_usr.shape[1]
    n_item_features = X_item.shape[1]

    # pool of anchor ids: one entry per row of df
    user_list = list(df.index.values)

    # initialise result; rows stay all-zero wherever no positive/negative
    # item is available, so no explicit zero-fill is needed later on
    triplets = [
        np.zeros((batch_size, n_user_features)),  # anchor
        np.zeros((batch_size, n_item_features)),  # pos
        np.zeros((batch_size, n_item_features)),  # neg
    ]
    user_ids = []
    p_ids = []
    n_ids = []

    for i in range(batch_size):
        # pick one random user for anchor
        anchor_id = random.choice(user_list)
        user_ids.append(anchor_id)

        # all possible positive/negative samples for selected anchor
        # (get_pos / get_neg are defined elsewhere in the project)
        p_item_ids = get_pos(df, anchor_id)
        n_item_ids = get_neg(df, anchor_id)

        # pick one of the positive ids; random.choice raises IndexError on
        # an empty sequence, in which case the "missing" sentinel 0 is used
        try:
            positive_id = random.choice(p_item_ids)
        except IndexError:
            positive_id = 0
        p_ids.append(positive_id)

        # pick the negative id most similar to the positive one.
        # Without a positive there is nothing to be "hard" against: indexing
        # X_item[0 - 1] would silently compare against the last item, so the
        # mining step is skipped and a random negative is drawn instead.
        negative_id = None
        if positive_id != 0:
            try:
                positive_vec = X_item[positive_id - 1]  # loop-invariant
                distances = [
                    cosine_dist(positive_vec, X_item[k - 1]) for k in n_item_ids
                ]
                negative_id = n_item_ids[np.argmin(distances)]
            except (ValueError, IndexError):
                # ValueError: argmin over an empty candidate list.
                # IndexError: an id that does not map to a row of X_item.
                # Anything else (e.g. a NameError or KeyboardInterrupt) is a
                # real problem and must not be swallowed.
                negative_id = None

        if negative_id is None:
            # best-effort fallback: any negative, or the sentinel if none
            try:
                negative_id = random.choice(n_item_ids)
            except IndexError:
                negative_id = 0
        n_ids.append(negative_id)

        # define triplet (ids are 1-based, hence the "- 1")
        triplets[0][i, :] = X_usr[anchor_id - 1]
        if positive_id != 0:
            triplets[1][i, :] = X_item[positive_id - 1]
        if negative_id != 0:
            triplets[2][i, :] = X_item[negative_id - 1]

    if return_cache:
        cache = {'users': user_ids, 'positive': p_ids, 'negative': n_ids}
        return triplets, cache

    return triplets
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment