Created June 17, 2023 00:59
Save zubietaroberto/4ee2e696f65d430b1a3741d07247a606 to your computer and use it in GitHub Desktop.
Huggingface Tests
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from transformers import AutoTokenizer, AutoConfig, GPTJForCausalLM
import torch

# Prompt in the Pygmalion chat format:
#   "<name>'s Persona: <description>"
#   "<START>"
#   <dialogue so far, one "Speaker: text" line per turn>
#   "<name>:"
# The square-bracketed tokens in the model card's template are placeholders,
# so only the character's actual name is used here. The final "Pigmalion:"
# line cues the model to produce the character's reply to the user's turn.
PROMPT = """
Pigmalion's Persona: A whimsy robot that orbits around planet earth. His head is an old CRT Monitor, and his arms are two pincers. He enjoys spying on evil people and calling the police on them. He is a good robot.
<START>
You: Are you alive?
Pigmalion:"""

# Load the model and tokenizer.
# NOTE: this runs at import time and downloads the checkpoint on first use.
# The model is loaded without an explicit dtype, i.e. in full precision
# (roughly 24 GB of weights for 6B parameters).
config = AutoConfig.from_pretrained("PygmalionAI/pygmalion-6b")
tokenizer = AutoTokenizer.from_pretrained("PygmalionAI/pygmalion-6b")
# Reuse the config fetched above instead of having from_pretrained load it again.
model = GPTJForCausalLM.from_pretrained("PygmalionAI/pygmalion-6b", config=config)
def check_environment():
    """Print which compute device PyTorch can use, then a blank line.

    On a CUDA machine this also prints the name of the current CUDA device.
    Returns None.
    """
    if not torch.cuda.is_available():
        print("Using CPU")
    else:
        print("Using GPU")
        active = torch.cuda.current_device()
        print(f"Device Name: {torch.cuda.get_device_name(active)}")
    print()
# Entry Point
if __name__ == '__main__':
    check_environment()

    # check_environment() only reports the device. The model and the inputs
    # must be moved there explicitly, otherwise generation runs on the CPU
    # even when a GPU is available.
    if torch.cuda.is_available():
        device = torch.device("cuda")
        # The model is loaded in full precision (~24 GB for 6B parameters);
        # half precision roughly halves that so it fits on more GPUs.
        # TODO(review): confirm fp16 output quality is acceptable.
        model.half()
    else:
        device = torch.device("cpu")
    model.to(device)

    # Keep the attention mask with the input ids: pad_token_id is set to the
    # EOS token, so generate() cannot reliably infer the mask on its own.
    encoded = tokenizer(PROMPT, return_tensors='pt').to(device)
    reply_ids = model.generate(
        encoded.input_ids,
        attention_mask=encoded.attention_mask,
        max_length=1250,  # total length: prompt tokens plus generated tokens
        do_sample=True,
        temperature=0.9,
        pad_token_id=tokenizer.eos_token_id,
    )
    # batch_decode returns one string per sequence; there is a single prompt.
    gen_text = tokenizer.batch_decode(reply_ids)[0]
    print("output text: " + gen_text)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.