Skip to content

Instantly share code, notes, and snippets.

@xeb
Created March 9, 2023 22:28
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save xeb/f48770e65118d2c544e2901c21e58f9c to your computer and use it in GitHub Desktop.
Quick demo of text completion using GPT-J
from transformers import AutoTokenizer, AutoModelForCausalLM
import time

# Single source of truth for the checkpoint name so tokenizer and model
# can never drift apart.
MODEL_NAME = "EleutherAI/gpt-j-6B"

print("Loading tokenizer...")
# Load the tokenizer from the SAME checkpoint as the model. The original
# loaded the plain "gpt2" tokenizer; GPT-J ships its own tokenizer config
# (vocabulary padded to 50400 entries to match its embedding table), and the
# Hugging Face convention is to pair tokenizer and model from one checkpoint.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
print("Loading model...")
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
print("Model Loaded..!")
def infer(input_text="This is a sample of", temperature=0.8, max_length=150):
    """Generate a sampled text completion for *input_text* with GPT-J.

    Parameters
    ----------
    input_text : str
        Prompt to complete.
    temperature : float
        Sampling temperature forwarded to ``model.generate``.
    max_length : int
        Total sequence length budget (prompt + completion) in tokens,
        clamped so it never exceeds GPT-J's 2048-token context window.

    Returns
    -------
    str
        The decoded generation (includes the prompt text).
    """
    start_time = time.time()

    # Tokenize FIRST so the context-window budget is measured in tokens.
    # The original used len(input_text) — a character count — which
    # over-estimates prompt length, since one token usually spans several
    # characters, and so clamped max_length far too aggressively.
    inputs = tokenizer(input_text, return_tensors="pt")
    input_ids = inputs["input_ids"]
    prompt_tokens = input_ids.shape[1]

    max_max_length = 2048  # GPT-J context window size, in tokens
    if prompt_tokens + max_length > max_max_length:
        max_length = max_max_length - prompt_tokens
        print(f"Using max_length of {max_length}")

    output = model.generate(
        input_ids,
        attention_mask=inputs["attention_mask"],
        do_sample=True,
        max_length=max_length,
        temperature=temperature,
        use_cache=True,
    )

    end_time = time.time() - start_time
    print("Total Taken => ", end_time)
    output = tokenizer.decode(output[0])
    print(output)
    return output
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment