Preprocessing for GPT-2 text generation: tokenize a batch of prompts with left padding and build the input list (input_ids, attention_mask, position_ids, plus one empty past key/value tensor per layer) for the first step of GPT-2 inference with ONNX Runtime.
# Inspired by https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/transformers/notebooks/Inference_GPT2_with_OnnxRuntime_on_CPU.ipynb
import torch

def get_example_inputs(example, tokenizer):
    # Left-pad so generated tokens are appended directly after each prompt
    tokenizer.padding_side = "left"
    tokenizer.pad_token = tokenizer.eos_token
    max_length = 64
    # GPT-2 base configuration
    num_attention_heads, hidden_size, num_layer = 12, 768, 12
    encodings_dict = tokenizer.batch_encode_plus(example, padding="max_length", max_length=max_length)
    input_ids = torch.tensor(encodings_dict["input_ids"], dtype=torch.int32)
    attention_mask = torch.tensor(encodings_dict["attention_mask"], dtype=torch.int32)
    # Derive position ids from the attention mask; padding positions are clamped to 0
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(position_ids < 0, 0)
    position_ids = position_ids.to(torch.int32)
    # Empty past state (key/value cache) for generating the first token:
    # one tensor per layer with shape [2, batch, heads, 0, head_dim]
    batch_size = input_ids.size(0)
    past_shape = [2, batch_size, num_attention_heads, 0, hidden_size // num_attention_heads]
    empty_past = [torch.empty(past_shape, dtype=torch.float32) for _ in range(num_layer)]
    input_list = [input_ids, attention_mask, position_ids] + empty_past
    return input_list
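A minimal usage sketch, assuming the Hugging Face transformers GPT-2 tokenizer; the model name and prompt strings below are illustrative:

from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
prompts = ["best hotel in bay area", "here is an example of gpt2 model"]
inputs = get_example_inputs(prompts, tokenizer)
# inputs = [input_ids, attention_mask, position_ids, past_0, ..., past_11],
# i.e. the flat input list in the order the exported GPT-2 ONNX model expects
# for the first generation step, before any past state has been accumulated.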