Skip to content

Instantly share code, notes, and snippets.

@zubietaroberto
Created June 17, 2023 00:59
Show Gist options
  • Save zubietaroberto/4ee2e696f65d430b1a3741d07247a606 to your computer and use it in GitHub Desktop.
Huggingface Tests
from transformers import AutoTokenizer, AutoConfig, GPTJForCausalLM
import torch
# Roleplay prompt — presumably Pygmalion's expected chat format (persona
# block, <START> token, dialogue-history marker, then the user's opening
# line); verify against the model card before changing the layout.
PROMPT = """
[Pigmalion]'s Persona: A whimsy robot that orbits around planet earth. His head is an old CRT Monitor, and his arms are two pincers. He enjoys spying on evil people and calling the police on them. He is a good robot.
<START>
[DIALOGUE HISTORY]
You: Are you alive?
"""
# Load the model and tokenizer (downloads weights on first run).
# NOTE(review): the original also fetched AutoConfig into an unused `config`
# variable; that download served no purpose and has been dropped.
tokenizer = AutoTokenizer.from_pretrained("PygmalionAI/pygmalion-6b")
model = GPTJForCausalLM.from_pretrained("PygmalionAI/pygmalion-6b")
def check_environment() -> None:
    """Report whether inference will use a CUDA GPU or fall back to the CPU.

    Prints "Using GPU" plus the active device's name when CUDA is
    available, otherwise "Using CPU", followed by a blank separator line.
    """
    gpu_available = torch.cuda.is_available()
    if not gpu_available:
        print("Using CPU")
    else:
        print("Using GPU")
        active_device = torch.cuda.current_device()
        print("Device Name: " + torch.cuda.get_device_name(active_device))
    print()
# Entry Point
if __name__ == '__main__':
    check_environment()

    # Fix: the original detected the GPU but never moved the model or the
    # input ids onto it, so generation always ran on CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Keep the full tokenizer output so generate() receives an explicit
    # attention_mask instead of inferring it from pad_token_id alone.
    encoded = tokenizer(PROMPT, return_tensors='pt').to(device)

    with torch.no_grad():  # inference only — skip gradient bookkeeping
        reply_ids = model.generate(
            encoded.input_ids,
            attention_mask=encoded.attention_mask,
            max_length=1250,
            do_sample=True,
            temperature=0.9,
            pad_token_id=tokenizer.eos_token_id,
        )

    # batch_decode returns one string per sequence; we generated one.
    gen_text = tokenizer.batch_decode(reply_ids)[0]
    print("output text: " + gen_text)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment