import multiprocessing

import torch
from torch.utils.data import Dataset


class RecoSparseTrainDataset(Dataset):
    """Wraps a scipy sparse user-item matrix; each item is one densified user row."""

    def __init__(self, sparse_mat):
        self.sparse_mat = sparse_mat

    def __len__(self):
        return self.sparse_mat.shape[0]

    def __getitem__(self, idx):
        # densify a single user's interaction row into an (n_items,) vector
        batch_matrix = self.sparse_mat[idx].toarray().squeeze()
        return batch_matrix, idx
class RecoSparseTestSet(Dataset):
    """
    The test dataset pairs the training and test matrices:
    the test rows should be predicted from the corresponding training rows.
    """

    def __init__(self, train_mat, test_mat):
        self.train_mat = train_mat
        self.test_mat = test_mat
        assert train_mat.shape == test_mat.shape

    def __len__(self):
        return self.train_mat.shape[0]

    def __getitem__(self, idx):
        train_matrix = self.train_mat[idx].toarray().squeeze()
        test_matrix = self.test_mat[idx].toarray().squeeze()
        return train_matrix, test_matrix, idx
class RecoSparseInferenceDataset(Dataset):
    def __init__(self, sparse_mat, user_ids):
        """
        sparse_mat : interaction matrix
        user_ids : ids of the users (positional)
        """
        self.sparse_mat = sparse_mat
        self.user_ids = user_ids
        assert sparse_mat.shape[0] == len(user_ids)

    def __len__(self):
        return self.sparse_mat.shape[0]

    def __getitem__(self, idx):
        batch_matrix = self.sparse_mat[idx].toarray().squeeze()
        batch_ids = self.user_ids[idx]
        return batch_matrix, batch_ids
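
# ------------------------------------------------------------------
# Illustration (not from the original Gist): one way the inference
# dataset could be instantiated, using a toy random scipy sparse matrix.
# The shapes and variable names here are assumptions for demonstration;
# the real train / val / test splits come from the surrounding project.
# ------------------------------------------------------------------
import numpy as np
import scipy.sparse as sp

n_users, n_items = 1000, 500
interactions = sp.random(n_users, n_items, density=0.01, format="csr",
                         random_state=42, data_rvs=lambda n: np.ones(n))
user_ids = np.arange(n_users)
inference_ds = RecoSparseInferenceDataset(interactions, user_ids)
row, uid = inference_ds[0]   # dense (n_items,) row and its user id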
###
# train, val, and test are pre-split scipy sparse interaction matrices of the same shape
batch_size = 512
num_workers = multiprocessing.cpu_count()

train_loader = torch.utils.data.DataLoader(
    RecoSparseTrainDataset(train),
    batch_size=batch_size, shuffle=True, num_workers=num_workers)
val_loader = torch.utils.data.DataLoader(
    RecoSparseTestSet(train, val),
    batch_size=batch_size, shuffle=False, num_workers=num_workers)
test_loader = torch.utils.data.DataLoader(
    RecoSparseTestSet(train, test),
    batch_size=batch_size, shuffle=False, num_workers=num_workers)
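
# ------------------------------------------------------------------
# Consumption sketch (not from the original Gist): one way these loaders
# might be used, assuming an autoencoder-style recommender that
# reconstructs each user's dense interaction row. The model below is a
# hypothetical stand-in, not the author's architecture.
# ------------------------------------------------------------------
n_items = train.shape[1]
device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Sequential(
    torch.nn.Linear(n_items, 256), torch.nn.ReLU(),
    torch.nn.Linear(256, n_items),
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# training: reconstruct each user's dense interaction row
model.train()
for batch_matrix, idx in train_loader:
    batch_matrix = batch_matrix.float().to(device)
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(batch_matrix), batch_matrix)
    loss.backward()
    optimizer.step()

# validation: score from the training rows, mask already-seen items,
# then rank against the held-out rows
model.eval()
with torch.no_grad():
    for train_matrix, test_matrix, idx in val_loader:
        train_matrix = train_matrix.float().to(device)
        scores = model(train_matrix)
        scores[train_matrix > 0] = -float("inf")   # do not re-recommend seen items
        top_k = torch.topk(scores, k=10, dim=1).indices
        # compare top_k with the nonzero columns of test_matrix for recall@10, etc.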