import os
import pickle
import pandas as pd
from sklearn.preprocessing import LabelEncoder
from tqdm import tqdm
import copy
def process_seqs(iseqs, idates, ihist=None, data_aug=False, gen_target=True):
    '''
    Generate training/testing samples from sessions.
    :param iseqs: all raw sessions
    :param idates: timestamp at which each session occurred
    :param ihist: historical sessions preceding each session
    :param data_aug: whether to apply data augmentation
    :param gen_target: whether to generate target labels
    :return:
    '''
    out_seqs = []
    out_dates = []
    out_hists = []
    labs = []
    ids, un_aug_ids = [], []
    count = 0
    if ihist:
        zip_input = zip(range(len(iseqs)), iseqs, idates, ihist)
    else:
        zip_input = zip(range(len(iseqs)), iseqs, idates, [None] * len(iseqs))
    for id, seq, date, hist in zip_input:
        for i in range(1, len(seq)):
            if i == 1:
                un_aug_ids += [count]
            count += 1
            if gen_target:
                # the last step is the target
                tar = seq[-i]
                labs += [tar]
                out_seqs += [seq[:-i]]
            else:
                out_seqs += [seq]
            out_dates += [date]  # record each session's timestamp
            if ihist:
                out_hists += [hist]
            ids += [id]  # original session id
            # stop after the first split unless data augmentation is requested
            if not data_aug:
                break
    if ihist:
        return out_seqs, out_dates, out_hists, labs, ids, un_aug_ids
    else:
        return out_seqs, out_dates, labs, ids, un_aug_ids
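

# A quick illustration of process_seqs (my own toy example, not part of the
# original gist): with one session [1, 2, 3] at timestamp 100 and no history,
#   process_seqs([[1, 2, 3]], [100])
# returns ([[1, 2]], [100], [3], [0], [0]) -- only the last item is split off
# as the target. With data_aug=True every prefix is emitted instead:
#   out_seqs = [[1, 2], [1]], labs = [3, 2]
# and un_aug_ids = [0] marks the index of the non-augmented sample.
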
dataset_ = 'trivago'
save = True
if dataset_ == 'jdata2019':
    dataset_path = './jdata2019/jdata_action.csv'
    df = pd.read_csv(dataset_path, delimiter=',')
    dates = pd.to_datetime(df['action_time'])
    df['action_time'] = (dates - pd.Timestamp("1970-01-01")) // pd.Timedelta('1s')  # datetime to unix
    df['day'] = dates.dt.date
    df.rename(columns={'user_id': 'userId', 'sku_id': 'itemId', 'action_time': 'timestamp'}, inplace=True)
    df['sessionId'] = df.groupby(['userId', 'day']).ngroup()
elif dataset_ == 'Retailrocket':
    dataset_path = './Retailrocket/events.csv'
    df = pd.read_csv(dataset_path, delimiter=',')
    df = df[df['event'] == 'view']
    df = df.reset_index()
    df = df.drop(columns=['event', 'transactionid'])
    df['timestamp'] = (df['timestamp'] / 1000).astype('int')
    dates = pd.to_datetime(df['timestamp'], unit='s')
    df['date_hour'] = dates.dt.floor('h')
    df.rename(columns={'visitorid': 'userId', 'itemid': 'itemId'}, inplace=True)
    df['sessionId'] = df.groupby(['userId', 'date_hour']).ngroup()
elif dataset_ == 'nowplaying':
    dataset_path = './nowplaying/user_track_hashtag_timestamp.csv'
    df = pd.read_csv(dataset_path, delimiter=',')
    df['hashtag'] = df['hashtag'].str.lower()
    df = df[df['hashtag'] == 'nowplaying']
    df = df.reset_index()
    df = df[['user_id', 'track_id', 'created_at']]
    df.rename(columns={'user_id': 'userId', 'track_id': 'itemId', 'created_at': 'timestamp'}, inplace=True)
    df['timestamp'] = (pd.to_datetime(df['timestamp']) - pd.Timestamp("1970-01-01")) // pd.Timedelta('1s')
    dates = pd.to_datetime(df['timestamp'], unit='s')
    df['date_hour'] = dates.dt.floor('h')
    # df['day'] = dates.dt.date
    df['sessionId'] = df.groupby(['userId', 'date_hour']).ngroup()
elif dataset_ == 'Taobao':
    dataset_path = './Taobao/UserBehavior-withHeader.csv'
    df = pd.read_csv(dataset_path, delimiter=',')
    df.rename(columns={'sessionId': 'userId'}, inplace=True)
    df = df[['userId', 'itemId', 'timestamp']]
    # 11/30 1511971200 12/4 1512316800
    df = df[(df['timestamp'] >= 1511539200) & (df['timestamp'] < 1511971200)]
    dates = pd.to_datetime(df['timestamp'], unit='s')
    df['date_hour'] = dates.dt.floor('h')
    # df['day'] = dates.dt.date
    df['sessionId'] = df.groupby(['userId', 'date_hour']).ngroup()
elif dataset_ == 'trivago':
    dataset_path = 'trivago.csv'
    df = pd.read_csv(dataset_path, delimiter=',')
    df.rename(columns={'user_id': 'userId', 'session_id': 'sessionId', 'reference': 'itemId'}, inplace=True)
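
# A note on the sessionization above (my summary, not from the original gist):
# for the datasets without an explicit session id, groupby(...).ngroup() assigns
# one integer per (user, time-bucket) pair, e.g.
#   pd.DataFrame({'userId': [1, 1, 2], 'date_hour': ['10:00', '11:00', '10:00']})
#       .groupby(['userId', 'date_hour']).ngroup()  ->  0, 1, 2
# so each user/hour bucket (user/day for jdata2019) becomes its own session,
# while trivago keeps its native session_id.
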
dataset = './' + dataset_
if not os.path.exists(dataset + '_hist') and save:
    os.makedirs(dataset + '_hist')
min_item = 10      # a user must have viewed more than this many items
min_user = 50      # an item must have been viewed by more than this many users
min_sess = 0       # a user must have more than this many sessions
min_sess_len = 3   # minimum session length (>=)
min_hist_len = 1   # minimum number of historical sessions (>=)
print('Clicks Count:', str(df.shape[0]))
#shortcut = False
#dataset = './' + dataset
#if not os.path.exists(dataset + '_users') and save:
# os.makedirs(dataset + '_users')
df_with_user = df.dropna(subset=['userId'], axis=0)
del df
# Filter out users & items that appear too rarely
frequent_user = df_with_user['userId'].value_counts()
frequent_user = frequent_user[frequent_user > min_item].index.tolist()  # users with more than min_item interactions
frequent_user = df_with_user.loc[df_with_user['userId'].isin(frequent_user)]
frequent_user_idx = set(frequent_user.index.tolist())
frequent_item = df_with_user['itemId'].value_counts()
frequent_item = frequent_item[frequent_item > min_user].index.tolist()  # items viewed more than min_user times
frequent_item = df_with_user.loc[df_with_user['itemId'].isin(frequent_item)]
frequent_item_idx = set(frequent_item.index.tolist())
# Filter out users with too few sessions
if min_sess > 0:
    multi_sess = df_with_user.groupby('userId')['sessionId'].nunique()  # distinct sessions per user
    multi_sess_user = multi_sess[multi_sess.gt(min_sess)].index.tolist()
    multi_sess_user = df_with_user.loc[df_with_user['userId'].isin(multi_sess_user)]
    multi_sess_index = set(multi_sess_user.index.tolist())
    intersect_idx = list(set.intersection(frequent_item_idx, frequent_user_idx, multi_sess_index))
else:
    intersect_idx = list(set(frequent_item_idx).intersection(set(frequent_user_idx)))
df_with_user = df_with_user.loc[df_with_user.index.isin(intersect_idx)]
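
# Sketch of the filtering idea above (hypothetical mini-example): value_counts()
# gives per-id frequencies, e.g. for userId [7, 7, 8] it yields {7: 2, 8: 1};
# with min_item = 1 only user 7 passes, its row indices land in
# frequent_user_idx, and intersecting with frequent_item_idx keeps exactly the
# clicks whose user AND item are both frequent enough.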
print('# of sessions per user: ', str(df_with_user.sessionId.nunique()/df_with_user.userId.nunique()))
# Re-index user and item ids
le = LabelEncoder()  # item ids start from 1
df_with_user['itemId'] = le.fit_transform(df_with_user['itemId']) + 1
max_itemid = len(le.classes_) + 1
print('# of items: ', str(max_itemid))
# new item id map category
#cat_dataset = './jdata2019/jdata_product.csv'
#cat_data = pd.read_csv(cat_dataset, index_col=False)
#id_cat = {}
#for i in range(1,max_itemid):
# i_ = le.inverse_transform([i-1])
# if not cat_data[cat_data['sku_id']==int(i_[0])].empty:
# id_cat[i] = cat_data[cat_data['sku_id']==int(i_[0])]['cate'].values[0]
# else:
# id_cat[i] = 82
le = LabelEncoder()
df_with_user['userId'] = le.fit_transform(df_with_user['userId']) + max_itemid
print('# of users: ', str(len(le.classes_)))
print('total_nodes: ', str(len(le.classes_) + max_itemid))
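
# Id layout after the two LabelEncoder passes (my reading of the code above):
# items occupy 1 .. n_items, users start at max_itemid = n_items + 1, so item
# and user ids never collide when both appear as nodes later; e.g. with 5 items
# the first user is encoded as 6.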
# Sort by session and time, then group each session's clicks into a list
df_with_user.sort_values(['sessionId', 'timestamp'], ascending=[True, True], inplace=True)
# Drop sessions shorter than min_sess_len
sess_grouped_mask = df_with_user.groupby(['sessionId']).transform('count') >= min_sess_len
# the mask is a True/False DataFrame (sessionId itself was consumed as the group key)
sess_grouped_mask['sessionId'] = sess_grouped_mask['userId']  # restore a sessionId column so the mask covers every column
sess_ = df_with_user[sess_grouped_mask].dropna(axis=0)  # drop the rows the mask flags for removal
sess_g = sess_.groupby(['sessionId'])
# alternatively, take the last timestamp & first userId of each session directly from the groupby
sess_ts = sess_g['timestamp'].last().reset_index(drop=True)
sess_user = sess_g['userId'].first().reset_index(drop=True)
sess_df = sess_g['itemId'].apply(list).reset_index() # add session_id back
sess_its = sess_g['timestamp'].apply(list).reset_index()
sess_its.rename(columns={'timestamp':'i_timestamp'}, inplace=True)
sess_its = sess_its.drop(columns=['sessionId'])
# sess = sess_g['itemId'].apply(list).tolist()  # merge the item ids of all steps sharing a session id into one session
sess_df = pd.concat([sess_df, sess_user, sess_ts, sess_its], axis=1) # combine multi series
del sess_ts, sess_user, sess_its
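
# At this point each sess_df row is one session (illustration only, values made
# up): sessionId, itemId = [item ids in click order], userId, timestamp = time
# of the last click, i_timestamp = [per-click timestamps], e.g.
#   (42, [3, 3, 7], 901, 1541500000, [1541499000, 1541499300, 1541500000])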
# iterate through sessions and collect each user's past sessions as history
sess_df.sort_values(['timestamp'], ascending=[True], inplace=True)
sess_df.reset_index(drop=True, inplace=True)
sess_df['history'] = ''
sess_df['count_history'] = 0
user_history_session = {u: [] for u in sess_df.userId.unique().tolist()}
for i, row in tqdm(sess_df.iterrows(), total=len(sess_df)):
    user = row.userId
    sess_df.at[i, 'history'] = copy.deepcopy(user_history_session[user][-30:])
    sess_df.at[i, 'count_history'] = len(user_history_session[user][-30:])
    user_history_session[user].append(row.itemId)
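
# History example (hypothetical): if user u has sessions A, B, C in time order,
# row A gets history [] (count 0), row B gets [A's item list], and row C gets
# [A's items, B's items]; only the 30 most recent past sessions are kept, and
# rows with fewer than min_hist_len past sessions are dropped just below.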
sess_df = sess_df.loc[sess_df['count_history'] >= min_hist_len].reset_index(drop=True)
sess = sess_df.itemId.tolist()
sess_hist = sess_df.history.tolist()
sess_count = sess_df.count_history.tolist()
print('Avg Sess len: ', str(sum([len(i) for i in sess]) / len(sess)))
date_obj = pd.to_datetime(sess_df['timestamp'], unit='s').dt.date
n_date = len(pd.unique(date_obj))
if dataset_ == 'Taobao':
    splitdate = sess_df['timestamp'].max() - (86400 * 0.5)
else:
    splitdate = sess_df['timestamp'].max() - (86400 * int(0.2 * n_date))
sess_tra_df = sess_df.loc[sess_df['timestamp'] < splitdate].reset_index(drop=True)
sess_tes_df = sess_df.loc[sess_df['timestamp'] >= splitdate].reset_index(drop=True)
sess_tra_df_ = sess_tra_df[['sessionId','userId','itemId','i_timestamp']]
sess_tra_df_ = sess_tra_df_.set_index(['sessionId','userId']).apply(pd.Series.explode).reset_index()
sess_tra_df_.rename(columns={'sessionId':'SessionId','userId':'UserId', 'itemId':'ItemId','i_timestamp':'Time'}, inplace=True)
sess_tes_df_ = sess_tes_df[['sessionId','userId','itemId','i_timestamp']]
sess_tes_df_ = sess_tes_df_.set_index(['sessionId','userId']).apply(pd.Series.explode).reset_index()
sess_tes_df_.rename(columns={'sessionId':'SessionId','userId':'UserId', 'itemId':'ItemId','i_timestamp':'Time'}, inplace=True)
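
# The explode step above flattens each list-valued session row back into one
# row per click (illustration only): a row (SessionId=42, UserId=901,
# itemId=[3, 7], i_timestamp=[10, 20]) becomes (42, 901, 3, 10) and
# (42, 901, 7, 20) -- the click-level layout used by the "_u_" CSV exports below.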
# TODO: can data augmentation be applied when history is included?
sess_tra = sess_tra_df['itemId'].tolist()
sess_tra_m = sess_tra_df['timestamp'].tolist()
sess_tra_hist = sess_tra_df['history'].tolist()
sess_te = sess_tes_df['itemId'].tolist()
sess_te_m = sess_tes_df['timestamp'].tolist()
sess_te_hist = sess_tes_df['history'].tolist()
# Generate session data based on the re-indexed ids above, in the format used for data augmentation
tr_seqs, tr_dates, tr_hist, tr_labs, tr_ids, _ = process_seqs(sess_tra, sess_tra_m, sess_tra_hist)
te_seqs, te_dates, te_hist, te_labs, te_ids, _ = process_seqs(sess_te, sess_te_m, sess_te_hist)
print('# of tra: '+str(len(tr_seqs)))
print('# of tes: '+str(len(te_seqs)))
tra = (tr_seqs, tr_labs, tr_hist) # session, target, history
tes = (te_seqs, te_labs, te_hist)
if save:
    with open(dataset + '_hist/info.txt', 'w') as f:
        f.write('user need to view how many items: ' + str(min_item) + '\n')
        f.write('user need # of sessions: ' + str(min_sess) + '\n')
        f.write('item must be seen by: ' + str(min_user) + '\n')
        f.write('min len of session: ' + str(min_sess_len) + '\n')
        f.write('min # of hist sessions: ' + str(min_hist_len) + '\n')
        f.write('# of sessions per user: ' + str(sess_df.sessionId.nunique() / sess_df.userId.nunique()) + '\n')
        f.write('# of items: ' + str(max_itemid) + '\n')
        f.write('# of users: ' + str(len(le.classes_)) + '\n')
        f.write('Clicks matched: ' + str(len(sess_df)) + '\n')
        f.write('Avg Sess len: ' + str(sum([len(i) for i in sess]) / len(sess)) + '\n')
    # for Recbole
    sess_tra_df[['sessionId', 'itemId', 'timestamp']].to_csv(dataset + '_hist/' + dataset + '_train_full.csv', index=False)
    sess_tes_df[['sessionId', 'itemId', 'timestamp']].to_csv(dataset + '_hist/' + dataset + '_test.csv', index=False)
    # for another session framework
    sess_tra_df[['userId', 'sessionId', 'itemId', 'timestamp']].to_csv(dataset + '_hist/' + dataset + '_train_full_u.csv', index=False)
    sess_tes_df[['userId', 'sessionId', 'itemId', 'timestamp']].to_csv(dataset + '_hist/' + dataset + '_test_u.csv', index=False)
    sess_tra_df_.to_csv(dataset + '_hist/' + dataset + '_train_full_u_.csv', index=False)
    sess_tes_df_.to_csv(dataset + '_hist/' + dataset + '_test_u_.csv', index=False)
    # no DA, EII-GNN
    pickle.dump(tra, open(dataset + '_hist/train.txt', 'wb'))
    pickle.dump(tes, open(dataset + '_hist/test.txt', 'wb'))
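    # To read these back later (sketch, mirroring the dump format above):
    #   tr_seqs, tr_labs, tr_hist = pickle.load(open(dataset + '_hist/train.txt', 'rb'))
    # where tr_seqs[i] is a session with its last item removed, tr_labs[i] is
    # that removed item (the target), and tr_hist[i] is the user's earlier sessions.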
# pickle.dump(id_cat, open(dataset + '_hist/item_cate.txt', 'wb'))
# for DGTN
# with DA for other models
# tr_seqs, tr_dates, tr_labs, tr_ids, tr_unaug_idx = process_seqs(sess_tra, sess_tra_m, data_aug=True)
# te_seqs, te_dates, te_labs, te_ids, te_unaug_idx = process_seqs(sess_te, sess_te_m, data_aug=True)
# tra = (tr_seqs, tr_labs) # session, target
# tes = (te_seqs, te_labs)
# if not os.path.exists(dataset + '_users/raw'):
# os.makedirs(dataset + '_users/raw')
# pickle.dump(tra, open(dataset + '_users/raw/train.txt', 'wb'))
# pickle.dump(tes, open(dataset + '_users/raw/test.txt', 'wb'))
# # save the indices of the non-augmented samples within the augmented data
# tr_seqs, tr_dates, tr_labs, _, _ = process_seqs(sess_tra, sess_tra_m, gen_target=False)
# te_seqs, te_dates, te_labs, _, _ = process_seqs(sess_te, sess_te_m, gen_target=False)
# tra = (tr_seqs, tr_unaug_idx)
# tes = (te_seqs, te_unaug_idx)
# pickle.dump(tra, open(dataset + '_users/unaug_train.txt', 'wb'))
# pickle.dump(tes, open(dataset + '_users/unaug_test.txt', 'wb'))