Quick demo of text completion using GPT-J
from transformers import AutoTokenizer, AutoModelForCausalLM
import time

print("Loading tokenizer...")
# GPT-J uses the GPT-2 BPE tokenizer, so "gpt2" works here;
# AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B") would also be fine.
tokenizer = AutoTokenizer.from_pretrained("gpt2")

print("Loading model...")
model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B")
print("Model loaded!")


def infer(input_text="This is a sample of", temperature=0.8, max_length=150):
    start_time = time.time()

    inputs = tokenizer(input_text, return_tensors="pt")
    input_ids = inputs["input_ids"]

    # GPT-J has a 2048-token context window, and max_length counts the prompt
    # tokens as well, so clamp it to what still fits (use the token count of
    # the prompt, not its character length).
    max_context_length = 2048
    prompt_length = input_ids.shape[1]
    if prompt_length + max_length > max_context_length:
        max_length = max_context_length - prompt_length
    print(f"Using max_length of {max_length}")

    output = model.generate(
        input_ids,
        attention_mask=inputs["attention_mask"],
        do_sample=True,
        max_length=max_length,
        temperature=temperature,
        use_cache=True,
        pad_token_id=tokenizer.eos_token_id,  # avoids the "no pad token" warning
    )

    elapsed = time.time() - start_time
    print(f"Total time taken => {elapsed:.2f}s")

    output_text = tokenizer.decode(output[0])
    print(output_text)
    return output_text
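
The gist defines infer but never calls it. A minimal entry point, assuming the script is run directly; the prompt string and settings below are only an illustration:

if __name__ == "__main__":
    # Arbitrary example prompt, purely for demonstration.
    infer("The three laws of robotics are", temperature=0.8, max_length=150)

Note that loading the full-precision 6B checkpoint takes on the order of 24 GB of RAM. EleutherAI also publishes a half-precision revision of the weights, which can be loaded with revision="float16" and torch_dtype=torch.float16 (this requires importing torch) to roughly halve the memory footprint.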