Skip to content

Instantly share code, notes, and snippets.

@KyleOng
Created April 8, 2021 06:59
Show Gist options
  • Save KyleOng/e70c80c49991613eaa8acc7c238576f5 to your computer and use it in GitHub Desktop.
Save KyleOng/e70c80c49991613eaa8acc7c238576f5 to your computer and use it in GitHub Desktop.
PyTorch DataLoader for sparse tensors
from typing import Union
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from scipy.sparse import (random,
coo_matrix,
csr_matrix,
vstack)
from tqdm import tqdm
class SparseDataset(Dataset):
    """Dataset that pairs row ``i`` of ``data`` with row ``i`` of ``targets``.

    Both containers may be NumPy arrays or SciPy sparse matrices. COO
    matrices are converted to CSR on construction so that they can be
    indexed by row in ``__getitem__``; every other input is stored as given.

    Args:
        data: Feature container, indexed along its first axis.
        targets: Target container, indexed along its first axis.
        transform: Stored on the instance as ``self.transform`` but never
            applied by this class.
    """

    def __init__(self, data: Union[np.ndarray, coo_matrix, csr_matrix],
                 targets: Union[np.ndarray, coo_matrix, csr_matrix],
                 transform: Union[bool, None] = None):
        # isinstance (rather than an exact type comparison) so that
        # subclasses of coo_matrix are converted to CSR as well.
        self.data = data.tocsr() if isinstance(data, coo_matrix) else data
        self.targets = (
            targets.tocsr() if isinstance(targets, coo_matrix) else targets
        )
        # Kept only so existing callers passing ``transform`` keep working.
        self.transform = transform

    def __getitem__(self, index: int):
        """Return the ``(data[index], targets[index])`` pair."""
        return self.data[index], self.targets[index]

    def __len__(self):
        """Return the number of samples, i.e. the first dimension of ``data``."""
        return self.data.shape[0]
def sparse_coo_to_tensor(coo: coo_matrix):
    """Convert a SciPy COO matrix into a float32 PyTorch sparse COO tensor.

    Args:
        coo: Matrix exposing ``row``, ``col``, ``data`` and ``shape``.

    Returns:
        A sparse tensor with the same shape as ``coo`` whose values are cast
        to ``torch.float32``.
    """
    # PyTorch expects indices as a (2, nnz) int64 tensor: row ids stacked
    # on top of column ids.
    indices = torch.tensor(np.vstack((coo.row, coo.col)), dtype=torch.long)
    values = torch.tensor(coo.data, dtype=torch.float32)
    # torch.sparse_coo_tensor replaces the deprecated legacy
    # torch.sparse.FloatTensor(indices, values, size) constructor.
    return torch.sparse_coo_tensor(indices, values, torch.Size(coo.shape))
def _collate_items(items):
    """Collate one field (all data or all targets) of a batch into a tensor."""
    if isinstance(items[0], csr_matrix):
        # Stack the per-sample sparse rows into a single matrix, then
        # convert it to a PyTorch sparse tensor.
        return sparse_coo_to_tensor(vstack(items).tocoo())
    # Stack with NumPy first: building a tensor directly from a sequence of
    # ndarrays goes through a slow element-by-element path.
    return torch.from_numpy(np.asarray(items, dtype=np.float32))


def sparse_batch_collate(batch: list):
    """Collate function that turns a batch of samples into tensors.

    Args:
        batch: List of ``(data, target)`` pairs. Each field is collated
            separately, and the type of its first element decides how:
            ``csr_matrix`` items are stacked vertically and returned as a
            PyTorch sparse tensor; anything else is converted to a dense
            float32 tensor.

    Returns:
        A ``(data_batch, targets_batch)`` tuple of tensors.
    """
    data_batch, targets_batch = zip(*batch)
    return _collate_items(data_batch), _collate_items(targets_batch)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment