GPT2 model tracing

Two scripts follow: the first traces GPT2LMHeadModel with a populated past_key_values cache; the second, trace_model_without_past_key_values.py, traces the initial forward pass (no cache) and saves the past_key_values that the first script loads.
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
model_name = 'gpt2-large'
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
# add the EOS token as PAD token to avoid warnings
model = GPT2LMHeadModel.from_pretrained(model_name, pad_token_id=tokenizer.eos_token_id, torchscript=True)
# %% model_inputs
output_attentions = False
output_hidden_states = False
model_inputs = {}
# Load a cached past_key_values (produced by trace_model_without_past_key_values.py below)
model_inputs['past_key_values'] = torch.load(
    "../data/nested_tuple_" + model_name + ".pt")
past_seq = model_inputs['past_key_values'][0][0].shape[-2]
model_inputs['input_ids'] = torch.tensor([[404]])
model_inputs['position_ids'] = torch.tensor([[past_seq]])
# attention_mask spans past and current tokens: past_seq + len(input_ids)
model_inputs['attention_mask'] = torch.ones(past_seq + 1, dtype=torch.int64)
model_inputs['use_cache'] = True
model_inputs['token_type_ids'] = None
model_inputs['return_dict'] = False
model_inputs['output_attentions'] = False
model_inputs['output_hidden_states'] = False
# Test a forward pass for text generation
outputs = model(**model_inputs)
# %% Wrapper class of GPT2LMHeadModel
from typing import Tuple

class Tracable(torch.nn.Module):
    def __init__(self, config: dict):
        super().__init__()
        self.model = GPT2LMHeadModel.from_pretrained(model_name, pad_token_id=tokenizer.eos_token_id, torchscript=True)
        self.config = {'use_cache': config.get('use_cache', True),
                       'token_type_ids': config.get('token_type_ids', None),
                       'return_dict': config.get('return_dict', False),
                       'output_attentions': config.get('output_attentions', False),
                       'output_hidden_states': config.get('output_hidden_states', True)}

    def forward(self, my_input_ids, position_ids, attention_mask, past_key_values):
        return self.model(input_ids=my_input_ids,
                          position_ids=position_ids,
                          attention_mask=attention_mask,
                          past_key_values=past_key_values,
                          **self.config)  # torchscript=True => plain tuple outputs
# %% instantiate the wrapper
config = {}
tracable = Tracable(config)
input = (model_inputs['input_ids'],
         model_inputs['position_ids'],
         model_inputs['attention_mask'],
         model_inputs['past_key_values'])
output = tracable(*input)
# %% trace
tracable.eval()
traced_model = torch.jit.trace(tracable, input)
torch.jit.save(traced_model, "../traced_GPT2_hidden.pt")
out1 = traced_model(*input)
# %% load back
loaded_model = torch.jit.load("../traced_GPT2_hidden.pt")
out2 = loaded_model(*input)
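# %% sanity check (not in the original gist): with torchscript=True the model
# returns plain tuples with the logits first, so eager, traced, and reloaded
# outputs can be compared directly; the tolerances are arbitrary.
assert torch.allclose(outputs[0], out1[0], atol=1e-4)
assert torch.allclose(out1[0], out2[0], atol=1e-5)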
trace_model_without_past_key_values.py

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
model_name = 'gpt2-large'
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
# add the EOS token as PAD token to avoid warnings
model = GPT2LMHeadModel.from_pretrained(model_name, pad_token_id=tokenizer.eos_token_id, torchscript=True)
# %% model_inputs
output_attentions = False
output_hidden_states = False
model_inputs = {}
model_inputs['input_ids'] = torch.tensor([40, 2883, 6155, 351, 616, 13779, 3290])  # "I enjoy walking with my cute dog"
model_inputs['position_ids'] = torch.arange(7)
model_inputs['attention_mask'] = torch.ones(7, dtype=torch.int64)
model_inputs['past_key_values'] = None
model_inputs['use_cache'] = True
model_inputs['token_type_ids'] = None
model_inputs['return_dict'] = True
model_inputs['output_attentions'] = False
model_inputs['output_hidden_states'] = False
# Test a forward pass for text generation
outputs = model(**model_inputs)
# %% Wrapper class of GPT2LMHeadModel
from typing import Tuple

class Tracable(torch.nn.Module):
    def __init__(self, config: dict):
        super().__init__()
        self.model = GPT2LMHeadModel.from_pretrained(model_name, pad_token_id=tokenizer.eos_token_id, torchscript=True)
        self.config = {'use_cache': config.get('use_cache', True),
                       'token_type_ids': config.get('token_type_ids', None),
                       'return_dict': config.get('return_dict', False),
                       'output_attentions': config.get('output_attentions', False),
                       'output_hidden_states': config.get('output_hidden_states', True)}

    def forward(self, my_input_ids, position_ids, attention_mask) -> Tuple:
        return self.model(input_ids=my_input_ids,
                          position_ids=position_ids,
                          attention_mask=attention_mask,
                          past_key_values=None,
                          **self.config)  # torchscript=True => plain tuple outputs
# %% instantiate the wrapper
config = {}
tracable = Tracable(config)

# %% input
input = (torch.tensor([[40, 2883, 6155, 351, 616, 13779, 3290]]),
         torch.arange(7)[None, :],
         torch.ones(7, dtype=torch.int64)[None, :])
output = tracable(*input)
# %% trace
tracable.eval()
traced_model = torch.jit.trace(tracable, input)
torch.jit.save(traced_model, "../traced_GPT2_init_hidden.pt")
out1 = traced_model(*input)
# %% load back
loaded_model = torch.jit.load("../traced_GPT2_init_hidden.pt")
out2 = loaded_model(*input)
# Save the past_key_values (element 1 of the output tuple) for the script that traces with a cache
torch.save(out2[1], "../data/nested_tuple_gpt2-large.pt")
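# %% inspect the saved cache (illustrative, not in the original gist).
# HF's GPT2 cache is a tuple with one (key, value) pair per layer, each of
# shape (batch, num_heads, seq_len, head_dim) -- for gpt2-large that is
# 36 layers, 20 heads, head_dim 64, and seq_len 7 for the 7 prompt tokens above.
pkv = out2[1]
print(len(pkv))         # 36
print(pkv[0][0].shape)  # torch.Size([1, 20, 7, 64])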
KexinFeng commented Jul 28, 2023
Update

trace_model_without_past_key_values.py, along with its output traced_GPT2_init_hidden.pt, is not actually needed; traced_GPT2_hidden.pt also works on inputs whose past_key_values is None. The trick is to create a dummy past_key_values and left-append a 0 to the attention_mask; the position_ids remain unchanged (see the sketch below).
See deepjavalibrary/djl#2637
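A minimal sketch of that trick, assuming gpt2-large's geometry (36 layers, 20 heads, head_dim 64); the variable names are illustrative, not from the gist:

import torch

# dummy single-position cache for each layer; its content is masked out below
n_layer, n_head, head_dim = 36, 20, 64  # gpt2-large geometry
dummy_past = tuple(
    (torch.zeros(1, n_head, 1, head_dim), torch.zeros(1, n_head, 1, head_dim))
    for _ in range(n_layer))

input_ids = torch.tensor([[40, 2883, 6155, 351, 616, 13779, 3290]])
seq_len = input_ids.shape[1]
# left-append a 0 so attention ignores the dummy past position
attention_mask = torch.cat([torch.zeros(1, 1, dtype=torch.int64),
                            torch.ones(1, seq_len, dtype=torch.int64)], dim=1)
# position_ids remain unchanged: they index only the real tokens
position_ids = torch.arange(seq_len)[None, :]

loaded_model = torch.jit.load("../traced_GPT2_hidden.pt")
out = loaded_model(input_ids, position_ids, attention_mask, dummy_past)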
