Skip to content

Instantly share code, notes, and snippets.

@kabirahuja2431
Last active October 6, 2019 08:57
Show Gist options
  • Save kabirahuja2431/860128fa1dfe4611bbd3f3bea7d9b708 to your computer and use it in GitHub Desktop.
Save kabirahuja2431/860128fa1dfe4611bbd3f3bea7d9b708 to your computer and use it in GitHub Desktop.
from torch.utils.data import Dataset
class CustomDataset(Dataset):
    """A PyTorch dataset holding (text, label) pairs for text classification.

    The raw sentences are kept in ``self.X`` and their labels in ``self.y``;
    preprocessing is applied lazily, each time an item is fetched.
    """

    def __init__(self, filename, encoding=None):
        """Load sentences and labels from a comma-separated text file.

        Every non-blank line is split on commas: the first field is stored
        as the text in ``self.X`` and the second field as the label in
        ``self.y`` (both as unmodified strings). Blank or whitespace-only
        lines, including the empty "line" after a trailing newline, are
        skipped.

        Args:
            filename: Path of the file to read.
            encoding: Text encoding passed to ``open``. ``None`` (the
                default) keeps the platform's default encoding.

        Raises:
            IndexError: If a non-blank line contains no comma.
        """
        texts, labels = [], []
        with open(filename, encoding=encoding) as f:
            # Stream the file line by line instead of reading it whole.
            for raw_line in f:
                line = raw_line.rstrip('\n')
                # A trailing newline or an empty row carries no sample;
                # indexing its split result would raise IndexError.
                if not line.strip():
                    continue
                # Split the text data and label from each other (once).
                fields = line.split(',')
                text, label = fields[0], fields[1]
                texts.append(text)
                labels.append(label)
        # Store them in member variables.
        self.X = texts
        self.y = labels

    def preprocess(self, text: str) -> str:
        """Return ``text`` lower-cased with surrounding whitespace removed."""
        return text.lower().strip()

    def __len__(self) -> int:
        """Return the number of stored samples."""
        return len(self.y)

    def __getitem__(self, index):
        """Return ``(preprocessed text, label)`` for the sample at ``index``."""
        return self.preprocess(self.X[index]), self.y[index]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment