Last active: October 6, 2019 08:57
-
-
Save kabirahuja2431/860128fa1dfe4611bbd3f3bea7d9b708 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from torch.utils.data import Dataset | |
class CustomDataset(Dataset):
    """A PyTorch dataset holding data for a text classification task.

    Each non-blank line of the input file has the form ``<text>,<label>``.
    The label is the field after the LAST comma, so the text may itself
    contain commas. Labels are stored and returned as strings (no numeric
    conversion is performed).
    """

    def __init__(self, filename, *, encoding="utf-8"):
        """Read ``filename`` and store texts in ``self.X`` and labels in ``self.y``.

        Args:
            filename: Path to a file with one comma-separated ``text,label``
                pair per line.
            encoding: Text encoding used to open the file (keyword-only).

        Raises:
            ValueError: If a non-blank line contains no comma separator.
        """
        X, y = [], []
        with open(filename, encoding=encoding) as f:
            for lineno, line in enumerate(f, start=1):
                # Skip blank lines, including the empty trailing "line" that a
                # final newline would otherwise produce (it has no label).
                if not line.strip():
                    continue
                # Split on the last comma so commas inside the text are kept
                # as part of the text rather than leaking into the label.
                text, sep, label = line.rstrip("\r\n").rpartition(",")
                if not sep:
                    raise ValueError(
                        f"{filename}:{lineno}: expected 'text,label', got {line!r}"
                    )
                X.append(text)
                # Strip so that "text, pos" and "text,pos" give the same label.
                y.append(label.strip())
        # Store them in member variables.
        self.X = X
        self.y = y

    def preprocess(self, text):
        """Normalize a raw text sample: lowercase it and strip surrounding whitespace."""
        ### Do something with text here
        text_pp = text.lower().strip()
        ###
        return text_pp

    def __len__(self):
        """Return the number of (text, label) samples."""
        return len(self.y)

    def __getitem__(self, index):
        """Return ``(preprocessed_text, label)`` for the sample at ``index``."""
        return self.preprocess(self.X[index]), self.y[index]
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment