Skip to content

Instantly share code, notes, and snippets.

@thomwolf
Created October 3, 2017 11:48
Show Gist options
  • Save thomwolf/f58def98e8a61117f7214492c5d72b34 to your computer and use it in GitHub Desktop.
Save thomwolf/f58def98e8a61117f7214492c5d72b34 to your computer and use it in GitHub Desktop.
A simple PyTorch Dataset class
class DeepMojiDataset(Dataset):
    """ A simple Dataset class.

    Stores paired inputs and outputs as int64 tensors and serves them one
    row (index along dimension 0) at a time.

    # Arguments:
        X_in: Inputs of the given dataset. Either a torch tensor of any dtype
            or device (converted with `.long()`), or a Numpy array / anything
            `numpy.asarray` accepts (converted with `.astype('int64')`).
        y_in: Outputs of the given dataset, accepted in the same forms as
            `X_in`. Must have the same size along dimension 0 as `X_in`.

    # Raises:
        ValueError: If `X_in` and `y_in` differ in size along dimension 0.

    # __getitem__ output:
        (torch.LongTensor, torch.LongTensor): the `idx`-th row of `X_in` and
        of `y_in`, with every size-1 dimension squeezed out (so a scalar
        target comes back as a 0-dim tensor).
    """

    def __init__(self, X_in, y_in):
        X_in = self._to_long_tensor(X_in)
        y_in = self._to_long_tensor(y_in)
        # Validate up front: otherwise a shorter y_in only fails later with an
        # IndexError in __getitem__, and a longer one is silently truncated.
        if X_in.size(0) != y_in.size(0):
            raise ValueError(
                f"X_in and y_in must have the same number of rows, "
                f"got {X_in.size(0)} and {y_in.size(0)}"
            )
        # One view per sample along dim 0; torch.split returns views that
        # share storage with the converted tensor.
        self.X_in = torch.split(X_in, 1, dim=0)
        self.y_in = torch.split(y_in, 1, dim=0)

    @staticmethod
    def _to_long_tensor(data):
        """Return `data` as an int64 torch tensor.

        Torch tensors of any dtype/device are converted in place of type with
        `.long()` (a no-op for tensors that are already int64). Everything
        else goes through `numpy.asarray(...).astype('int64')`, matching the
        original Numpy handling while also accepting plain Python sequences.
        """
        if isinstance(data, torch.Tensor):
            return data.long()
        import numpy as np  # local import: the visible file imports no numpy
        return torch.from_numpy(np.asarray(data).astype('int64'))

    def __len__(self):
        return len(self.X_in)

    def __getitem__(self, idx):
        # squeeze() drops the leading size-1 dim added by torch.split, but
        # also any other size-1 dims (e.g. a length-1 feature row).
        return self.X_in[idx].squeeze(), self.y_in[idx].squeeze()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment