Skip to content

Instantly share code, notes, and snippets.

@priya-dwivedi
Created April 5, 2022 01:03
Show Gist options
  • Save priya-dwivedi/420dfa05ab8ffef4556d27294cbe88eb to your computer and use it in GitHub Desktop.
Tokenizer for a grammatical error correction (GEC) model
class GrammarDataset(Dataset):
    """Map-style dataset that tokenizes (input, output) text pairs on the fly.

    Each element of the wrapped ``dataset`` must be a mapping with an
    ``'input'`` (source text) and an ``'output'`` (target text) field.
    Tokenization happens lazily in ``__getitem__`` using the tokenizer
    supplied to the constructor.
    """

    def __init__(self, dataset, tokenizer, print_text=False, max_len=64,
                 pad_to_max_length=False):
        """Wrap ``dataset`` for tokenized access.

        Args:
            dataset: Indexable collection supporting ``len()``; each element
                provides ``'input'`` and ``'output'`` keys.
            tokenizer: Callable tokenizer (presumably a Hugging Face
                tokenizer -- confirm) accepting ``padding``, ``truncation``,
                ``max_length`` and ``return_attention_mask`` keyword
                arguments and returning a mapping with ``'input_ids'`` and
                ``'attention_mask'``.
            print_text: If true, print the length of every field of each
                item fetched (debugging aid).
            max_len: Maximum number of tokens kept for both the input ids
                and the labels; longer sequences are truncated.
            pad_to_max_length: If true, pad every sequence to ``max_len``;
                otherwise sequences are returned unpadded.
        """
        self.dataset = dataset
        self.pad_to_max_length = pad_to_max_length
        self.tokenizer = tokenizer
        self.print_text = print_text
        self.max_len = max_len

    def __len__(self):
        """Return the number of examples in the wrapped dataset."""
        return len(self.dataset)

    def _encode(self, text):
        """Tokenize one string with this dataset's padding/truncation settings."""
        # Always use the tokenizer given to __init__, never a module-level
        # global. ``padding``/``truncation`` replace the deprecated
        # ``pad_to_max_length`` kwarg; explicit ``truncation=True`` keeps the
        # truncate-to-max_length behaviour without the implicit-truncation warning.
        return self.tokenizer(
            text,
            padding="max_length" if self.pad_to_max_length else False,
            truncation=True,
            max_length=self.max_len,
            return_attention_mask=True,
        )

    def tokenize_data(self, example):
        """Tokenize one example into model inputs.

        Args:
            example: Mapping with ``'input'`` and ``'output'`` text fields.

        Returns:
            dict with ``'input_ids'`` and ``'attention_mask'`` of the input
            text, and ``'labels'`` holding the token ids of the output text.
        """
        tokenized_inputs = self._encode(example['input'])
        tokenized_targets = self._encode(example['output'])
        # NOTE(review): when pad_to_max_length is True, padded label positions
        # keep the tokenizer's pad id; if the loss should ignore them they may
        # need replacing (e.g. with -100) -- confirm against the training loop.
        return {
            "input_ids": tokenized_inputs['input_ids'],
            "attention_mask": tokenized_inputs['attention_mask'],
            "labels": tokenized_targets['input_ids'],
        }

    def __getitem__(self, index):
        """Return the tokenized example at ``index``.

        When ``print_text`` is set, also print each field name with its length.
        """
        inputs = self.tokenize_data(self.dataset[index])
        if self.print_text:
            for key, value in inputs.items():
                print(key, len(value))
        return inputs
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment