import os
import pickle
import pandas as pd
from sklearn.preprocessing import LabelEncoder
from tqdm import tqdm
import copy

def process_seqs(iseqs, idates, ihist=None, data_aug=False, gen_target=True):
    '''
    Generate model-ready samples from raw sessions.
    :param iseqs: all raw sessions
    :param idates: the time each session occurred
    :param ihist: the user's past sessions preceding each session
    :param data_aug: whether to apply sequence augmentation (emit every prefix)
    :param gen_target: whether to emit a target label (the session's last step)
    :return:
    '''
    out_seqs = []
    out_dates = []
    out_hists = []
    labs = []
    ids, un_aug_ids = [], []
    count = 0
    if ihist:
        zip_input = zip(range(len(iseqs)), iseqs, idates, ihist)
    else:
        zip_input = zip(range(len(iseqs)), iseqs, idates, [None] * len(iseqs))
    for sid, seq, date, hist in zip_input:
        for i in range(1, len(seq)):
            if i == 1:
                un_aug_ids += [count]
            count += 1
            if gen_target:
                # use the last remaining step as the target
                tar = seq[-i]
                labs += [tar]
                out_seqs += [seq[:-i]]
            else:
                out_seqs += [seq]
            out_dates += [date]  # record each session's month
            if ihist:
                out_hists += [hist]
            ids += [sid]  # original session id (renamed from `id` to avoid shadowing the builtin)
            # without data augmentation, keep only the first (full) sample
            if not data_aug:
                break
    if ihist:
        return out_seqs, out_dates, out_hists, labs, ids, un_aug_ids
    else:
        return out_seqs, out_dates, labs, ids, un_aug_ids
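
# A quick sanity check of process_seqs on a toy session (illustrative values only):
# with data_aug=True every prefix/target pair is emitted; with data_aug=False
# (the default) only the full session minus its last step is kept.
#   seqs, dates, labs, ids, un_aug = process_seqs([[1, 2, 3]], ['2019-04'], data_aug=True)
#   # seqs   -> [[1, 2], [1]]
#   # labs   -> [3, 2]       last step following each prefix
#   # ids    -> [0, 0]       both samples come from session 0
#   # un_aug -> [0]          index of the non-augmented sample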


dataset_ = 'trivago'
save = True

if dataset_ == 'jdata2019':
    dataset_path = './jdata2019/jdata_action.csv'
    df = pd.read_csv(dataset_path, delimiter=',')
    dates = pd.to_datetime(df['action_time'])
    df['action_time'] = (dates - pd.Timestamp("1970-01-01")) // pd.Timedelta('1s')  # datetime to unix seconds
    df['day'] = dates.dt.date
    df.rename(columns={'user_id': 'userId', 'sku_id': 'itemId', 'action_time': 'timestamp'}, inplace=True)
    df['sessionId'] = df.groupby(['userId', 'day']).ngroup()  # one session per user per day
elif dataset_ == 'Retailrocket':
    dataset_path = './Retailrocket/events.csv'
    df = pd.read_csv(dataset_path, delimiter=',')
    df = df[df['event'] == 'view']
    df = df.reset_index()
    df = df.drop(columns=['event', 'transactionid'])
    df['timestamp'] = (df['timestamp'] / 1000).astype('int')  # ms -> s
    dates = pd.to_datetime(df['timestamp'], unit='s')
    df['date_hour'] = dates.dt.floor('h')
    df.rename(columns={'visitorid': 'userId', 'itemid': 'itemId'}, inplace=True)
    df['sessionId'] = df.groupby(['userId', 'date_hour']).ngroup()  # one session per user per hour
elif dataset_ == 'nowplaying':
    dataset_path = './nowplaying/user_track_hashtag_timestamp.csv'
    df = pd.read_csv(dataset_path, delimiter=',')
    df['hashtag'] = df['hashtag'].str.lower()
    df = df[df['hashtag'] == 'nowplaying']
    df = df.reset_index()
    df = df[['user_id', 'track_id', 'created_at']]
    df.rename(columns={'user_id': 'userId', 'track_id': 'itemId', 'created_at': 'timestamp'}, inplace=True)
    df['timestamp'] = (pd.to_datetime(df['timestamp']) - pd.Timestamp("1970-01-01")) // pd.Timedelta('1s')
    dates = pd.to_datetime(df['timestamp'], unit='s')
    df['date_hour'] = dates.dt.floor('h')
    # df['day'] = dates.dt.date
    df['sessionId'] = df.groupby(['userId', 'date_hour']).ngroup()
elif dataset_ == 'Taobao':
    dataset_path = './Taobao/UserBehavior-withHeader.csv'
    df = pd.read_csv(dataset_path, delimiter=',')
    df.rename(columns={'sessionId': 'userId'}, inplace=True)
    df = df[['userId', 'itemId', 'timestamp']]
    # keep 2017-11-25 00:00 to 2017-11-30 00:00 (UTC+8); for reference,
    # 1511971200 = 11/30 and 1512316800 = 12/4
    df = df[(df['timestamp'] >= 1511539200) & (df['timestamp'] < 1511971200)]
    dates = pd.to_datetime(df['timestamp'], unit='s')
    df['date_hour'] = dates.dt.floor('h')
    # df['day'] = dates.dt.date
    df['sessionId'] = df.groupby(['userId', 'date_hour']).ngroup()
elif dataset_ == 'trivago':
    dataset_path = 'trivago.csv'
    df = pd.read_csv(dataset_path, delimiter=',')
    df.rename(columns={'user_id': 'userId', 'session_id': 'sessionId', 'reference': 'itemId'}, inplace=True)
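
# Note: unlike the branches above, the trivago log (presumably the RecSys Challenge
# 2019 release) already ships a session_id, so no hour/day-based sessionization is
# needed; 'reference' is treated as the clicked item id.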

dataset = './' + dataset_
if not os.path.exists(dataset + '_hist') and save:
    os.makedirs(dataset + '_hist')
min_item = 10  # a user must have viewed at least this many items
min_user = 50  # an item must have been viewed at least this many times
min_sess = 0  # a user must have at least this many sessions
min_sess_len = 3  # minimum session length (>=)
min_hist_len = 1  # minimum number of historical sessions (>=)
print('Clicks Count:', str(df.shape[0]))
# shortcut = False
# dataset = './' + dataset
# if not os.path.exists(dataset + '_users') and save:
#     os.makedirs(dataset + '_users')
df_with_user = df.dropna(subset=['userId'], axis=0)
del df
# filter out rarely occurring users & items
frequent_user = df_with_user['userId'].value_counts()
frequent_user = frequent_user[frequent_user > min_item].index.tolist()  # users with more than min_item clicks
frequent_user = df_with_user.loc[df_with_user['userId'].isin(frequent_user)]
frequent_user_idx = set(frequent_user.index.tolist())
frequent_item = df_with_user['itemId'].value_counts()
frequent_item = frequent_item[frequent_item > min_user].index.tolist()  # items viewed more than min_user times
frequent_item = df_with_user.loc[df_with_user['itemId'].isin(frequent_item)]
frequent_item_idx = set(frequent_item.index.tolist())
# filter out users with too few sessions
if min_sess > 0:
    multi_sess = df_with_user.groupby('userId')['sessionId'].apply(lambda x: len(x))
    multi_sess_user = multi_sess[multi_sess.gt(min_sess)].index.tolist()
    multi_sess_user = df_with_user.loc[df_with_user['userId'].isin(multi_sess_user)]
    multi_sess_index = set(multi_sess_user.index.tolist())
    intersect_idx = list(set.intersection(frequent_item_idx, frequent_user_idx, multi_sess_index))
else:
    intersect_idx = list(set(frequent_item_idx).intersection(set(frequent_user_idx)))
df_with_user = df_with_user.loc[df_with_user.index.isin(intersect_idx)]
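# To illustrate the filtering above (toy numbers, not real counts): a row survives
# only if its user has more than min_item (10) clicks AND its item was clicked more
# than min_user (50) times, i.e. the row's index is in both filtered index sets.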
print('# of sessions per user: ', str(df_with_user.sessionId.nunique() / df_with_user.userId.nunique()))
# reindex user and item ids
le = LabelEncoder()  # shift so item ids start from 1
df_with_user['itemId'] = le.fit_transform(df_with_user['itemId']) + 1
max_itemid = len(le.classes_) + 1  # first id after the item range; also the user-id offset
print('# of items: ', str(max_itemid))
# map the new item ids to categories
# cat_dataset = './jdata2019/jdata_product.csv'
# cat_data = pd.read_csv(cat_dataset, index_col=False)
# id_cat = {}
# for i in range(1, max_itemid):
#     i_ = le.inverse_transform([i - 1])
#     if not cat_data[cat_data['sku_id'] == int(i_[0])].empty:
#         id_cat[i] = cat_data[cat_data['sku_id'] == int(i_[0])]['cate'].values[0]
#     else:
#         id_cat[i] = 82
le = LabelEncoder()
df_with_user['userId'] = le.fit_transform(df_with_user['userId']) + max_itemid
print('# of users: ', str(len(le.classes_)))
print('total_nodes: ', str(len(le.classes_) + max_itemid))
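# Items and users now share a single node-id space: items occupy 1..#items and user
# ids start at max_itemid. With 3 items (toy size), max_itemid is 4 and users are
# encoded as 4, 5, ..., keeping the two id ranges disjoint for graph construction.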
# sort by session and time, then group each session's clicks into a list
df_with_user.sort_values(['sessionId', 'timestamp'], ascending=[True, True], inplace=True)
# filter out sessions shorter than min_sess_len
sess_grouped_mask = df_with_user.groupby(['sessionId']).transform('count') >= min_sess_len
# the mask is a True/False DataFrame (transform drops the group key column)
sess_grouped_mask['sessionId'] = sess_grouped_mask['userId']  # restore the sessionId column used as the group key
sess_ = df_with_user[sess_grouped_mask].dropna(axis=0)  # drop the rows rejected by the mask
sess_g = sess_.groupby(['sessionId'])
# alternatively, take each session's final userId & timestamp straight from the groupby
sess_ts = sess_g['timestamp'].last().reset_index(drop=True)
sess_user = sess_g['userId'].first().reset_index(drop=True)
sess_df = sess_g['itemId'].apply(list).reset_index()  # add session_id back
sess_its = sess_g['timestamp'].apply(list).reset_index()
sess_its.rename(columns={'timestamp': 'i_timestamp'}, inplace=True)
sess_its = sess_its.drop(columns=['sessionId'])
# sess = sess_g['itemId'].apply(list).tolist()  # merge the item ids of all steps sharing a session id into one session
sess_df = pd.concat([sess_df, sess_user, sess_ts, sess_its], axis=1)  # combine multiple series
del sess_ts, sess_user, sess_its
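# Each row of sess_df now describes one whole session, roughly:
#   sessionId | itemId (list of clicked items) | userId | timestamp (last click) | i_timestamp (list of click times)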
# iterate through sessions in time order, collecting each user's historical sessions
sess_df.sort_values(['timestamp'], ascending=[True], inplace=True)
sess_df.reset_index(drop=True, inplace=True)
sess_df['history'] = ''
sess_df['count_history'] = 0
user_history_session = {u: [] for u in sess_df.userId.unique().tolist()}
for i, row in tqdm(sess_df.iterrows(), total=len(sess_df)):
    user = row.userId
    sess_df.at[i, 'history'] = copy.deepcopy(user_history_session[user][-30:])
    sess_df.at[i, 'count_history'] = len(user_history_session[user][-30:])
    user_history_session[user].append(row.itemId)
sess_df = sess_df.loc[sess_df['count_history'] >= min_hist_len].reset_index(drop=True)
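# Each surviving session now carries in 'history' up to the 30 most recent earlier
# sessions (as item-id lists) of the same user. Since sess_df was sorted by
# timestamp before the loop, a history only ever contains sessions that ended
# before the current one, so no future information leaks in.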
sess = sess_df.itemId.tolist()
sess_hist = sess_df.history.tolist()
sess_count = sess_df.count_history.tolist()
print('Avg Sess len: ', str(sum([len(i) for i in sess]) / len(sess)))
date_obj = pd.to_datetime(sess_df['timestamp'], unit='s').dt.date
n_date = len(pd.unique(date_obj))
if dataset_ == 'Taobao':
    splitdate = sess_df['timestamp'].max() - (86400 * 0.5)
else:
    splitdate = sess_df['timestamp'].max() - (86400 * int(0.2 * n_date))
sess_tra_df = sess_df.loc[sess_df['timestamp'] < splitdate].reset_index(drop=True)
sess_tes_df = sess_df.loc[sess_df['timestamp'] >= splitdate].reset_index(drop=True)
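# Time-based split: the last ~20% of active days (half a day for Taobao) become the
# test set. For example, with n_date = 10 distinct days, splitdate falls 2 days
# before the final timestamp, so sessions in those last 2 days are test sessions.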
sess_tra_df_ = sess_tra_df[['sessionId', 'userId', 'itemId', 'i_timestamp']]
sess_tra_df_ = sess_tra_df_.set_index(['sessionId', 'userId']).apply(pd.Series.explode).reset_index()
sess_tra_df_.rename(columns={'sessionId': 'SessionId', 'userId': 'UserId', 'itemId': 'ItemId', 'i_timestamp': 'Time'}, inplace=True)
sess_tes_df_ = sess_tes_df[['sessionId', 'userId', 'itemId', 'i_timestamp']]
sess_tes_df_ = sess_tes_df_.set_index(['sessionId', 'userId']).apply(pd.Series.explode).reset_index()
sess_tes_df_.rename(columns={'sessionId': 'SessionId', 'userId': 'UserId', 'itemId': 'ItemId', 'i_timestamp': 'Time'}, inplace=True)
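# sess_tra_df_ / sess_tes_df_ hold the exploded long format, one row per click;
# e.g. a session [7, 9, 7] (toy ids) becomes three rows sharing SessionId/UserId:
#   SessionId  UserId  ItemId  Time
#   s1         u1      7       t1
#   s1         u1      9       t2
#   s1         u1      7       t3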
# TODO: can data augmentation be applied when history is present?
sess_tra = sess_tra_df['itemId'].tolist()
sess_tra_m = sess_tra_df['timestamp'].tolist()
sess_tra_hist = sess_tra_df['history'].tolist()
sess_te = sess_tes_df['itemId'].tolist()
sess_te_m = sess_tes_df['timestamp'].tolist()
sess_te_hist = sess_tes_df['history'].tolist()
# build the session data with the reindexed ids above, in the format the augmentation step expects
tr_seqs, tr_dates, tr_hist, tr_labs, tr_ids, _ = process_seqs(sess_tra, sess_tra_m, sess_tra_hist)
te_seqs, te_dates, te_hist, te_labs, te_ids, _ = process_seqs(sess_te, sess_te_m, sess_te_hist)
print('# of tra: ' + str(len(tr_seqs)))
print('# of tes: ' + str(len(te_seqs)))
tra = (tr_seqs, tr_labs, tr_hist)  # session, target, history
tes = (te_seqs, te_labs, te_hist)
if save:
    with open(dataset + '_hist/info.txt', 'w') as f:
        f.write('user need to view how many items: ' + str(min_item) + '\n')
        f.write('user need # of sessions: ' + str(min_sess) + '\n')
        f.write('item must be seen by: ' + str(min_user) + '\n')
        f.write('min len of session: ' + str(min_sess_len) + '\n')
        f.write('min # of hist sessions: ' + str(min_hist_len) + '\n')
        f.write('# of sessions per user: ' + str(sess_df.sessionId.nunique() / sess_df.userId.nunique()) + '\n')
        f.write('# of items: ' + str(max_itemid) + '\n')
        f.write('# of users: ' + str(len(le.classes_)) + '\n')
        f.write('# of sessions kept: ' + str(len(sess_df)) + '\n')  # len(sess_df) counts sessions, not raw clicks
        f.write('Avg Sess len: ' + str(sum([len(i) for i in sess]) / len(sess)) + '\n')
    # for RecBole
    sess_tra_df[['sessionId', 'itemId', 'timestamp']].to_csv(dataset + '_hist/' + dataset + '_train_full.csv', index=False)
    sess_tes_df[['sessionId', 'itemId', 'timestamp']].to_csv(dataset + '_hist/' + dataset + '_test.csv', index=False)
    # for another session framework
    sess_tra_df[['userId', 'sessionId', 'itemId', 'timestamp']].to_csv(dataset + '_hist/' + dataset + '_train_full_u.csv', index=False)
    sess_tes_df[['userId', 'sessionId', 'itemId', 'timestamp']].to_csv(dataset + '_hist/' + dataset + '_test_u.csv', index=False)
    sess_tra_df_.to_csv(dataset + '_hist/' + dataset + '_train_full_u_.csv', index=False)
    sess_tes_df_.to_csv(dataset + '_hist/' + dataset + '_test_u_.csv', index=False)
    # no DA, EII-GNN
    pickle.dump(tra, open(dataset + '_hist/train.txt', 'wb'))
    pickle.dump(tes, open(dataset + '_hist/test.txt', 'wb'))
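    # Minimal sketch for reading the pickles back (matches the tuple layout above):
    #   with open(dataset + '_hist/train.txt', 'rb') as f:
    #       tr_seqs, tr_labs, tr_hist = pickle.load(f)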
    # pickle.dump(id_cat, open(dataset + '_hist/item_cate.txt', 'wb'))
    # for DGTN
    # with DA for other models
    # tr_seqs, tr_dates, tr_labs, tr_ids, tr_unaug_idx = process_seqs(sess_tra, sess_tra_m, data_aug=True)
    # te_seqs, te_dates, te_labs, te_ids, te_unaug_idx = process_seqs(sess_te, sess_te_m, data_aug=True)
    # tra = (tr_seqs, tr_labs)  # session, target
    # tes = (te_seqs, te_labs)
    # if not os.path.exists(dataset + '_users/raw'):
    #     os.makedirs(dataset + '_users/raw')
    # pickle.dump(tra, open(dataset + '_users/raw/train.txt', 'wb'))
    # pickle.dump(tes, open(dataset + '_users/raw/test.txt', 'wb'))
    # # save the indices of the non-augmented samples within the augmented data
    # tr_seqs, tr_dates, tr_labs, _, _ = process_seqs(sess_tra, sess_tra_m, gen_target=False)
    # te_seqs, te_dates, te_labs, _, _ = process_seqs(sess_te, sess_te_m, gen_target=False)
    # tra = (tr_seqs, tr_unaug_idx)
    # tes = (te_seqs, te_unaug_idx)
    # pickle.dump(tra, open(dataset + '_users/unaug_train.txt', 'wb'))
    # pickle.dump(tes, open(dataset + '_users/unaug_test.txt', 'wb'))