Skip to content

Instantly share code, notes, and snippets.

@kusal1990
Created June 25, 2022 07:30
Show Gist options
  • Save kusal1990/c8d0ab0f9714fdcfd8d319784e25adff to your computer and use it in GitHub Desktop.
Save kusal1990/c8d0ab0f9714fdcfd8d319784e25adff to your computer and use it in GitHub Desktop.
import numpy as np

from transformers import AutoTokenizer, T5ForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained('t5-small')
def arc_preprocessor(dataset, tokenizer, max_len=None, target_max_len=None):
    '''
    Convert each row of a multiple-choice QA dataset into T5-style
    seq2seq features and tokenize them.

    Format produced per row:
        encoder input : question \\n options </s>
        decoder input : <pad> answer          (T5 starts decoding from pad)
        labels        : answer </s>

    NOTE(review): the 'context' column is read and truncated to 350 words
    but is NOT included in the encoder input string; this mirrors the
    original behaviour — confirm whether context was meant to be appended.

    Parameters
    ----------
    dataset : pandas.DataFrame
        Must contain columns 'context', 'only_question', 'only_answers'
        and 'Answer'.
    tokenizer : object
        Tokenizer exposing ``encode()``, ``pad_token`` and ``pad_token_id``.
    max_len : int, optional
        Fixed encoder sequence length. Defaults to the module-level
        ``MAX_LEN`` global (original behaviour).
    target_max_len : int, optional
        Fixed decoder sequence length. Defaults to the module-level
        ``decoder_max_len`` global (original behaviour).

    Returns
    -------
    tuple of four lists (one int32 numpy array per dataset row):
        (all_input_ids, all_attention_mask, all_decoder_input_ids, all_labels)
    '''
    if max_len is None:
        max_len = MAX_LEN                      # original global fallback
    if target_max_len is None:
        target_max_len = decoder_max_len       # original global fallback

    pad_id = tokenizer.pad_token_id
    all_input_ids = []
    all_attention_mask = []
    all_decoder_input_ids = []
    all_labels = []

    for i in range(len(dataset)):
        # Read (and cap at 350 words) even though it is unused downstream,
        # to keep the original's column-access behaviour.
        context = ' '.join(dataset['context'].iloc[i].split()[:350])
        question = dataset['only_question'].iloc[i]
        options = dataset['only_answers'].iloc[i]
        answer = dataset['Answer'].iloc[i]

        # '\\n' is a literal backslash-n separator token, not a newline.
        input_string = question + ' ' + '\\n' + ' ' + options + ' ' + '</s>'
        # T5's decoder_start_token_id is the pad token.
        decoder_input = tokenizer.pad_token + ' ' + answer
        target = answer + ' ' + '</s>'

        input_ids = tokenizer.encode(input_string, truncation=True,
                                     max_length=max_len)
        decoder_input_ids = tokenizer.encode(decoder_input,
                                             max_length=target_max_len,
                                             truncation=True)
        labels = tokenizer.encode(target, max_length=target_max_len,
                                  truncation=True)

        attention_mask = [1] * len(input_ids)

        # Right-pad every sequence to its own fixed length. (Bug fix: the
        # original padded `labels` by the deficit of `decoder_input_ids`,
        # which breaks when the two tokenize to different lengths.)
        # NOTE(review): labels are padded with pad_token_id; for T5 loss
        # masking, -100 is typically used instead — confirm intent.
        padding_length = max_len - len(input_ids)
        input_ids = input_ids + [pad_id] * padding_length
        attention_mask = attention_mask + [0] * padding_length
        decoder_input_ids = decoder_input_ids + \
            [pad_id] * (target_max_len - len(decoder_input_ids))
        labels = labels + [pad_id] * (target_max_len - len(labels))

        assert len(input_ids) == max_len
        assert len(attention_mask) == max_len
        assert len(decoder_input_ids) == target_max_len
        assert len(labels) == target_max_len

        all_input_ids.append(np.asarray(input_ids, dtype='int32'))
        all_attention_mask.append(np.asarray(attention_mask, dtype='int32'))
        all_decoder_input_ids.append(np.asarray(decoder_input_ids,
                                                dtype='int32'))
        all_labels.append(np.asarray(labels, dtype='int32'))

    return all_input_ids, all_attention_mask, all_decoder_input_ids, all_labels
@kusal1990
Copy link
Author

ok

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment