@eusip
Created April 12, 2023 08:55
T5 generation with ONNX Runtime
import time
import torch
import torch.nn.functional as F
from tqdm import trange
from transformers import AutoTokenizer
from onnxruntime import InferenceSession
class GenerativeT5(torch.nn.Module):
    """Greedy or sampled generation from T5 encoder/decoder ONNX Runtime sessions."""

    def __init__(self, encoder, decoder_with_lm_head, tokenizer):
        super().__init__()
        self.encoder = encoder
        self.decoder_with_lm_head = decoder_with_lm_head
        self.tokenizer = tokenizer

    def forward(self, prompt, max_length, temperature=1., repetition_penalty=1., top_k=50, top_p=0, max_context_length=512):
        with torch.no_grad():
            new_tokens = torch.tensor((), dtype=torch.long)  # long dtype so the ids can be decoded
            new_logits = []
            # Tokenize the prompt and truncate it to the context window.
            generated = torch.tensor(self.tokenizer(prompt)['input_ids'])[:max_context_length - 1].unsqueeze(0)
            # Run the encoder once; its hidden states are reused at every decoding step.
            encoder_outputs_prompt = self.encoder.run(None, {"input_ids": generated.cpu().numpy()})[0]
            # The decoder sequence starts with the pad token (T5's decoder_start_token_id is 0).
            generated = torch.zeros((1, 1), dtype=torch.long)
            for _ in trange(max_length):
                outputs = torch.tensor(self.decoder_with_lm_head.run(None, {"input_ids": generated.cpu().numpy(), "encoder_hidden_states": encoder_outputs_prompt})[0][0])
                next_token_logits = outputs[-1, :] / (temperature if temperature > 0 else 1.0)
                # Token id 1 is </s> for T5; stop once EOS is the most likely next token.
                if int(next_token_logits.argmax()) == 1:
                    break
                new_logits.append(next_token_logits)
                # Penalize tokens that have already been generated.
                for token_id in set(generated.view(-1).tolist()):
                    next_token_logits[token_id] /= repetition_penalty
                if temperature == 0:  # greedy decoding
                    next_token = torch.argmax(next_token_logits).unsqueeze(0)
                else:  # top-k / top-p (nucleus) sampling
                    filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
                    next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
                generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)
                new_tokens = torch.cat((new_tokens, next_token), 0)
            return self.tokenizer.decode(new_tokens), new_logits
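The sampling branch calls top_k_top_p_filtering, which the gist never defines or imports. Transformers releases current when this gist was written exported a helper by that name (from transformers import top_k_top_p_filtering), but it has since been removed from the library; the standard standalone version below can stand in. A minimal sketch for the 1-D logits tensor used here:

def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float("inf")):
    """Filter a 1-D logits tensor with top-k and/or nucleus (top-p) filtering."""
    if top_k > 0:
        # Remove every token whose logit is below the k-th largest logit.
        top_k = min(top_k, logits.size(-1))
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value
    if top_p > 0.0:
        # Keep the smallest set of tokens whose cumulative probability exceeds top_p.
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift right so the first token past the threshold is still kept.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
    return logits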
model_id = "google/flan-t5-small"
tokenizer = AutoTokenizer.from_pretrained(model_id)
decoder_sess = InferenceSession("/mnt/training/mirage-onnx/no_opt_alt/-decoder-with-lm-head.onnx", providers=['CUDAExecutionProvider'])
encoder_sess = InferenceSession("/mnt/training/mirage-onnx/no_opt_alt/-encoder.onnx", providers=['CUDAExecutionProvider'])
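# Note: one way to produce encoder/decoder ONNX files like the ones loaded
# above is Hugging Face Optimum's seq2seq exporter; this is a sketch, not how
# these particular files were made. It assumes optimum[onnxruntime] is
# installed, and the exported file names (encoder_model.onnx,
# decoder_model.onnx) are Optimum's defaults rather than the paths used here:
#   from optimum.onnxruntime import ORTModelForSeq2SeqLM
#   ORTModelForSeq2SeqLM.from_pretrained(model_id, export=True).save_pretrained("flan-t5-small-onnx")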
t5 = GenerativeT5(encoder_sess, decoder_sess, tokenizer)
prompt = """<|prompter|>[TRANSCRIPT]"""
while True:
    start_time = time.time()
    output_text, output_logits = t5(prompt, max_length=512, temperature=0.)
    print(output_text)
    print("--- %s seconds ---" % (time.time() - start_time))