Skip to content

Instantly share code, notes, and snippets.

@fullstackwebdev
Created September 25, 2024 19:36
Show Gist options
  • Select an option

  • Save fullstackwebdev/81e64e8faca496e5390d09a4756d8db4 to your computer and use it in GitHub Desktop.

Select an option

Save fullstackwebdev/81e64e8faca496e5390d09a4756d8db4 to your computer and use it in GitHub Desktop.
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from unsloth.chat_templates import get_chat_template
def load_model(model_name="unsloth/Llama-3.2-3B-Instruct"):
    """Load a causal LM and its tokenizer, with a ChatML chat template applied.

    Args:
        model_name: Hugging Face model id to load. Defaults to the
            Llama-3.2-3B-Instruct checkpoint published by unsloth.

    Returns:
        A ``(model, tokenizer)`` tuple ready for chat-style inference.
    """
    # device_map="auto" lets accelerate place weights on available GPU(s)/CPU.
    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Remap the ChatML template onto this script's message schema:
    # messages use {"from": "human"/"gpt", "value": ...} instead of the
    # default {"role": ..., "content": ...} keys.
    tokenizer = get_chat_template(
        tokenizer,
        chat_template="chatml",
        mapping={"role": "from", "content": "value",
                 "user": "human", "assistant": "gpt"},
        map_eos_token=True,  # map the ChatML end token onto the model's EOS
    )
    return model, tokenizer
def main():
    """Run an interactive terminal chat loop against the instruct model."""
    model, tokenizer = load_model()
    model.eval()  # inference mode for the module: disables dropout, etc.

    conversation = []  # running history of {"from": ..., "value": ...} turns
    print("Welcome to the chat with the Llama-3.2-3B-Instruct model! Type 'exit' to end the conversation.")
    while True:
        user_input = input("You: ")
        if user_input.lower() == 'exit':
            break
        conversation.append({"from": "human", "value": user_input})

        # add_generation_prompt=True appends the assistant header to the
        # prompt so the model generates a reply instead of continuing the
        # user's turn — required for chat inference with templates.
        inputs = tokenizer.apply_chat_template(
            conversation,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
        ).to(model.device)

        # Debug aid: show (a prefix of) the exact prompt fed to the model,
        # including special tokens, to verify the template rendered correctly.
        input_text_special_tokens = tokenizer.decode(
            inputs[0], skip_special_tokens=False)
        print(f"--------------------- Input")
        print(input_text_special_tokens[:500])  # Print the first 500 characters as a sample
        print(f"---------------------")

        # No gradients are needed for generation; inference_mode avoids
        # building an autograd graph and reduces memory use.
        with torch.inference_mode():
            outputs = model.generate(
                inputs,
                max_new_tokens=1024,
                temperature=0.7,
                top_p=0.95,
                do_sample=True,
            )

        # Only the tokens generated past the prompt are the assistant reply.
        new_tokens = outputs[0][inputs.shape[1]:]
        response_special_tokens = tokenizer.decode(
            new_tokens, skip_special_tokens=False)
        print(f"--------------------- Response (with special tokens)")
        print(response_special_tokens)
        print(f"---------------------")

        response = tokenizer.decode(new_tokens, skip_special_tokens=True)
        print("Assistant:", response)
        conversation.append({"from": "gpt", "value": response})
# Script entry point: start the interactive chat loop when run directly.
if __name__ == "__main__":
    main()
@fullstackwebdev
Copy link
Author

(unsloth_env) ➜ tmp git:(master) ✗ python chat.py
🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
Traceback (most recent call last):
File "/home/shazam/dev/augmentive-dspy/tmp/chat.py", line 71, in
main()
File "/home/shazam/dev/augmentive-dspy/tmp/chat.py", line 24, in main
model, tokenizer = load_model()
^^^^^^^^^^^^
File "/home/shazam/dev/augmentive-dspy/tmp/chat.py", line 10, in load_model
tokenizer = AutoTokenizer.from_pretrained(model_name)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/shazam/miniforge3/envs/unsloth_env/lib/python3.11/site-packages/transformers/models/auto/tokenization_auto.py", line 897, in from_pretrained
return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/shazam/miniforge3/envs/unsloth_env/lib/python3.11/site-packages/transformers/tokenization_utils_base.py", line 2271, in from_pretrained
return cls._from_pretrained(
^^^^^^^^^^^^^^^^^^^^^
File "/home/shazam/miniforge3/envs/unsloth_env/lib/python3.11/site-packages/transformers/tokenization_utils_base.py", line 2505, in _from_pretrained
tokenizer = cls(*init_inputs, **init_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/shazam/miniforge3/envs/unsloth_env/lib/python3.11/site-packages/transformers/tokenization_utils_fast.py", line 115, in init
fast_tokenizer = TokenizerFast.from_file(fast_tokenizer_file)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Exception: data did not match any variant of untagged enum ModelWrapper at line 1251003 column 3

(Note: this error means the installed `tokenizers` library is too old to parse this model's `tokenizer.json` format; upgrading with `pip install -U tokenizers transformers` typically resolves it.)
(unsloth_env) ➜ tmp git:(master) ✗

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment