@ArthurDelannoyazerty
Created May 17, 2024 09:52
Example of a PyTorch Dataset that loads sparse matrices stored as .npz files and returns them as dense tensors.
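import os

import torch
from scipy import sparse
from torch.utils.data import Dataset


class CustomImageDatasetNumpy(Dataset):
    """Loads each sparse matrix saved as .npz in dir_path and returns it as a dense tensor."""

    def __init__(self, dir_path: str):
        self.dir_path = dir_path
        # Sorted so that index -> file is deterministic; non-.npz files are skipped.
        self.filenames = sorted(f for f in os.listdir(dir_path) if f.endswith(".npz"))

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        np_path = os.path.join(self.dir_path, self.filenames[idx])
        sparse_array = sparse.load_npz(np_path)  # scipy.sparse matrix
        np_array = sparse_array.toarray()        # densify to a NumPy ndarray
        t = torch.tensor(np_array, dtype=torch.float32)  # same dtype as Tensor(np_array)
        return t, 0  # constant placeholder label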
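A minimal usage sketch, assuming the matrices were written with `scipy.sparse.save_npz`. The helper `write_example_files`, the temporary directory, the 8x8 shape, the density, and the batch size are all hypothetical values chosen for illustration. The dataset class is repeated so the block runs on its own.

```python
import os
import tempfile

import torch
from scipy import sparse
from torch.utils.data import DataLoader, Dataset


class CustomImageDatasetNumpy(Dataset):
    """Same dataset as above: each .npz file in dir_path becomes one dense tensor."""

    def __init__(self, dir_path: str):
        self.dir_path = dir_path
        self.filenames = sorted(f for f in os.listdir(dir_path) if f.endswith(".npz"))

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        path = os.path.join(self.dir_path, self.filenames[idx])
        dense = sparse.load_npz(path).toarray()
        return torch.tensor(dense, dtype=torch.float32), 0


def write_example_files(dir_path: str, n_files: int = 4, shape=(8, 8)) -> None:
    """Hypothetical helper: save random sparse CSR matrices as .npz files."""
    for i in range(n_files):
        m = sparse.random(shape[0], shape[1], density=0.1, format="csr", random_state=i)
        sparse.save_npz(os.path.join(dir_path, f"sample_{i}.npz"), m)


with tempfile.TemporaryDirectory() as tmp:
    write_example_files(tmp)
    dataset = CustomImageDatasetNumpy(tmp)
    n_samples = len(dataset)
    # The default collate_fn stacks the dense tensors and the integer labels.
    loader = DataLoader(dataset, batch_size=2, shuffle=False)
    batch_shapes = [x.shape for x, _ in loader]
    labels = [y.tolist() for _, y in loader]
    # Check that densifying preserved the stored values for the first file.
    reference = sparse.load_npz(os.path.join(tmp, "sample_0.npz")).toarray()
    first_matches = torch.allclose(dataset[0][0], torch.tensor(reference, dtype=torch.float32))

print(n_samples, batch_shapes, labels)
```

Because `__getitem__` densifies each matrix, the default `collate_fn` can stack samples into a regular batch tensor. Keeping the data sparse until it reaches the model would instead require a custom `collate_fn`, at the cost of extra code.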