Last active
August 17, 2024 08:15
-
-
Save cedrickchee/236e53ed2dca95bd96e5baa35cdd7be2 to your computer and use it in GitHub Desktop.
Code snippet to try OpenAssistant's pre-release model on your own system.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# For more info, see my notes here: https://github.com/cedrickchee/chatgpt-universe#open-source-chatgpt
from transformers import AutoModelForCausalLM, AutoTokenizer
import textwrap

# Upper bound on generation length, passed to model.generate() in reply().
MAX_NEW_TOKENS = 500

# Hugging Face Hub id of the pre-release OpenAssistant-style instruct model.
name = "Rallio67/joi_20B_instruct_alpha"

# Load the 20B model in 8-bit (requires bitsandbytes) and let accelerate
# place/shard it automatically across the available devices.
model = AutoModelForCausalLM.from_pretrained(
    name,
    device_map='auto',
    load_in_8bit=True
)
def reply(model, prompt):
    """Generate the model's reply to one user prompt and print it.

    The decoded output (which echoes the prompt) is word-wrapped to 128
    columns and printed line by line.

    Args:
        model: causal LM returned by ``AutoModelForCausalLM.from_pretrained``.
        prompt: raw user text; the "User: ... Joi: " chat template is
            applied here before tokenization.
    """
    # The Joi instruct models expect this exact dialogue template.
    prompt = "User: " + prompt + "\n\nJoi: "
    # NOTE(review): the tokenizer is re-loaded on every call; hoisting it to
    # module level would avoid repeated disk/Hub access.
    tokenizer = AutoTokenizer.from_pretrained(name)
    # assumes the model weights ended up on a CUDA device — TODO confirm
    # against the device_map='auto' placement above.
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")
    gen_tokens = model.generate(
        input_ids,
        do_sample=True,
        # Fixed: was max_length=MAX_NEW_TOKENS, which also counted the
        # prompt tokens against the 500-token budget; max_new_tokens
        # matches the constant's intent.
        max_new_tokens=MAX_NEW_TOKENS,
        num_return_sequences=1,  # fixed: missing trailing comma in original
        top_p=0.95,
        temperature=0.5,
        # NOTE(review): penalty_alpha/top_k configure contrastive search,
        # which transformers only activates with do_sample=False — as
        # written they are likely ignored. Left in place to preserve intent;
        # confirm the desired decoding strategy.
        penalty_alpha=0.6,
        top_k=4,
        repetition_penalty=1.03,
        pad_token_id=tokenizer.eos_token_id,
        use_cache=True
    )
    wrapped_text = textwrap.wrap(
        tokenizer.decode(gen_tokens[0], skip_special_tokens=True),
        width=128,
    )
    for line in wrapped_text:
        print(line)
# Minimal REPL: read prompts until EOF (Ctrl-D / Ctrl-Z) terminates the loop.
while True:
    try:
        prompt = input("User: ")
    except EOFError:
        break
    # Fixed: original called gen_from_model(model, prompt), a name defined
    # nowhere in this file (NameError); reply() is the function defined above.
    reply(model, prompt)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment