Skip to content

Instantly share code, notes, and snippets.

@kabirahuja2431
Created October 9, 2019 12:50
Show Gist options
  • Save kabirahuja2431/2572a0ae090b12ef4741f58728b71ee2 to your computer and use it in GitHub Desktop.
import torch
from torch.utils.data import Dataset
from transformers import BertTokenizer
import pandas as pd
class SSTDataset(Dataset):
    """SST sentiment dataset backed by a tab-separated file with
    'sentence' and 'label' columns.

    Each item is a triple ``(token_ids, attn_mask, label)`` where
    ``token_ids`` is a 1-D LongTensor of length ``maxlen`` of BERT
    vocabulary indices, ``attn_mask`` holds 1 for real tokens and 0 for
    padding, and ``label`` is the raw label value from the dataframe.
    """

    def __init__(self, filename, maxlen):
        """
        Args:
            filename: path to a TSV file with 'sentence' and 'label' columns.
            maxlen: fixed sequence length every example is padded/truncated to.
        """
        # Store the contents of the file in a pandas dataframe
        self.df = pd.read_csv(filename, delimiter='\t')
        # Initialize the BERT tokenizer (downloads the vocab on first use)
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.maxlen = maxlen

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        # Select the sentence and label at the specified index in the dataframe
        sentence = self.df.loc[index, 'sentence']
        label = self.df.loc[index, 'label']

        # Preprocess the text to be suitable for BERT:
        # tokenize, then add the [CLS]/[SEP] sentence delimiters.
        tokens = self.tokenizer.tokenize(sentence)
        tokens = ['[CLS]'] + tokens + ['[SEP]']
        if len(tokens) < self.maxlen:
            # Pad with the tokenizer's own pad token rather than a
            # hard-coded '[PAD]' string, so a different vocab still works.
            tokens = tokens + [self.tokenizer.pad_token] * (self.maxlen - len(tokens))
        else:
            # Truncate to maxlen, keeping a trailing [SEP]
            tokens = tokens[:self.maxlen - 1] + ['[SEP]']

        # Map tokens to their indices in the BERT vocabulary
        token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
        token_ids_tensor = torch.tensor(token_ids)

        # Attention mask: 1 for real tokens, 0 for padded ones. Compare
        # against the tokenizer's pad id instead of a hard-coded 0 so the
        # mask stays correct even if the pad token's id is not 0.
        attn_mask = (token_ids_tensor != self.tokenizer.pad_token_id).long()
        return token_ids_tensor, attn_mask, label
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment