Skip to content

Instantly share code, notes, and snippets.

@dhuynh95
Last active September 1, 2022 10:44
Show Gist options
  • Save dhuynh95/4357aec425bd30fbb41db0bc6ce0f8b2 to your computer and use it in GitHub Desktop.
Preprocessing for GPT2 text generation.
# Inspired from https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/transformers/notebooks/Inference_GPT2_with_OnnxRuntime_on_CPU.ipynb
def get_example_inputs(example, tokenizer, max_length=64,
                       num_attention_heads=12, hidden_size=768, num_layer=12):
    """Build the ONNX Runtime input list for a GPT-2 decoding step.

    Tokenizes *example* with left padding (so generation appends on the
    right), derives position ids from the attention mask, and creates an
    empty past key/value state per layer for the first generated token.

    Args:
        example: batch of input strings passed to the tokenizer.
        tokenizer: a HuggingFace-style tokenizer; its ``padding_side`` and
            ``pad_token`` are mutated here (left padding, pad == eos).
        max_length: padded sequence length (default 64).
        num_attention_heads / hidden_size / num_layer: GPT-2 base model
            dimensions; defaults match the original hard-coded values.

    Returns:
        list of tensors: ``[input_ids, attention_mask, position_ids]``
        followed by ``num_layer`` empty past-state tensors of shape
        ``[2, batch, heads, 0, hidden // heads]``.
    """
    tokenizer.padding_side = "left"
    tokenizer.pad_token = tokenizer.eos_token
    encodings_dict = tokenizer.batch_encode_plus(
        example, padding='max_length', max_length=max_length)
    # int32 dtypes match the exported ONNX graph's expected input types.
    input_ids = torch.tensor(encodings_dict["input_ids"], dtype=torch.int32)
    attention_mask = torch.tensor(encodings_dict["attention_mask"], dtype=torch.int32)
    # Position ids count only real (unmasked) tokens; padding positions are
    # clamped to 0 so they stay valid embedding indices.
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(position_ids < 0, 0)
    position_ids = position_ids.to(torch.int32)
    # Empty past state (sequence dim 0) for generating the first token.
    batch_size = input_ids.size(0)
    past_shape = [2, batch_size, num_attention_heads, 0,
                  hidden_size // num_attention_heads]
    # Fix: the original appended empty_past[0] num_layer times, aliasing one
    # tensor across all layers; create a distinct tensor per layer instead.
    empty_past = [torch.empty(past_shape, dtype=torch.float32)
                  for _ in range(num_layer)]
    input_list = [input_ids, attention_mask, position_ids]
    input_list.extend(empty_past)
    return input_list
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment