Code snippet to try OpenAssistant's pre-release model on your own system.
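Requires a CUDA GPU. Beyond transformers, the flags used below assume the accelerate package (for device_map='auto') and bitsandbytes (for load_in_8bit), e.g. pip install transformers accelerate bitsandbytes.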
# For more info, see my notes here: https://github.com/cedrickchee/chatgpt-universe#open-source-chatgpt

from transformers import AutoModelForCausalLM, AutoTokenizer
import textwrap

MAX_NEW_TOKENS = 500
name = "Rallio67/joi_20B_instruct_alpha"

# device_map='auto' needs the accelerate package; load_in_8bit needs bitsandbytes.
model = AutoModelForCausalLM.from_pretrained(
    name,
    device_map='auto',
    load_in_8bit=True,
)

# Load the tokenizer once at startup instead of on every call to reply().
tokenizer = AutoTokenizer.from_pretrained(name)
def reply(model, prompt):
    # Wrap the user's text in the "User: ... / Joi: " dialogue format the model expects.
    prompt = "User: " + prompt + "\n\nJoi: "
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")
    gen_tokens = model.generate(
        input_ids,
        do_sample=True,
        max_new_tokens=MAX_NEW_TOKENS,  # cap on generated tokens, not total sequence length
        num_return_sequences=1,
        top_p=0.95,
        temperature=0.5,
        penalty_alpha=0.6,  # only affects contrastive search (do_sample=False); ignored when sampling
        top_k=4,
        repetition_penalty=1.03,
        pad_token_id=tokenizer.eos_token_id,
        use_cache=True,
    )
    # Decode and wrap the output to 128 columns for terminal readability.
    wrapped_text = textwrap.wrap(tokenizer.decode(gen_tokens[0], skip_special_tokens=True), width=128)
    for line in wrapped_text:
        print(line)
# Simple REPL: read prompts until EOF (Ctrl-D).
while True:
    try:
        prompt = input("User: ")
    except EOFError:
        break
    reply(model, prompt)
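Rough footprint, as an estimate rather than a measured number: a 20B-parameter model loaded in 8-bit takes on the order of 20 GB of GPU memory before activation overhead, so expect to need a large GPU (or several, which device_map='auto' will shard the weights across).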