This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# get the sorted list of unique user ids
# (sorted() already returns a list, so no extra list() wrapper is needed)
users = sorted(set(ratings[user_col].values))
# get the sorted list of unique item ids
items = sorted(set(ratings[item_col].values))
# generate dict of corresponding indexes for the user ids
user2idx = list_2_dict(users)
# generate dict of corresponding indexes for the item ids
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # get a batch -> (user, item and rating arrays) from the dataframe | |
def get_batch(ratings, start: int, end: int):
    """Cut rows ``start:end`` out of ``ratings`` as three arrays.

    Reads the module-level column names ``user_col``, ``item_col`` and
    ``rating_col``.

    Returns:
        A 3-tuple ``(users, items, ratings)`` holding the ``.values`` of each
        of those columns over the requested row range.
    """
    rows = slice(start, end)
    return tuple(
        ratings[column][rows].values
        for column in (user_col, item_col, rating_col)
    )
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # splits ratings dataframe to training and validation dataframes | |
| def get_data(ratings, valid_pct:float = 0.2): | |
| # shuffle the indexes | |
| ln = random.sample(range(0, len(ratings)), len(ratings)) | |
| # split based on the given validation set percentage | |
| part = int(len(ln)*valid_pct) | |
| valid_index = ln[0:part] | |
| train_index = ln[part:] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # this function returns a python dictionary | |
| # which maps each id to a corresponding index value | |
def list_2_dict(id_list: list) -> dict:
    """Map each id in ``id_list`` to its position in the list.

    Used to build lookup tables such as ``user2idx`` from the sorted list of
    unique ids.

    Args:
        id_list: Sequence of ids.

    Returns:
        dict mapping ``id -> index``. If an id occurs more than once, the
        index of its last occurrence wins.
    """
    # enumerate replaces the manual zip(..., range(len(...))) pairing, and
    # ``id_`` avoids shadowing the built-in ``id``.
    return {id_: index for index, id_ in enumerate(id_list)}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# loading the table as a pandas dataframe
ratings = pd.read_csv('ratings.csv')
# getting the three column names from the pandas dataframe.
# Assumes the first three columns are, in order: user id, item id, rating.
# Slicing to [:3] keeps any extra trailing columns from breaking the
# three-way unpacking (which would otherwise raise ValueError).
user_col, item_col, rating_col = ratings.columns[:3]
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # required libraries - numpy, pandas, pytorch | |
| import numpy as np | |
| import pandas as pd | |
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| import random |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # required libraries - numpy, pandas, pytorch | |
| import numpy as np | |
| import pandas as pd | |
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| import random | |
# load the ratings table from 'ratings.csv' into a pandas DataFrame
ratings = pd.read_csv(filepath_or_buffer='ratings.csv')
Newer Older