-
-
Save fullstackwebdev/81e64e8faca496e5390d09a4756d8db4 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| from unsloth.chat_templates import get_chat_template | |
def load_model():
    """Load the Llama-3.2-3B-Instruct model and a ChatML-patched tokenizer.

    Returns:
        tuple: ``(model, tokenizer)`` where the tokenizer has Unsloth's
        ChatML template applied with ``from``/``value`` message keys and
        ``human``/``gpt`` role names.
    """
    model_name = "unsloth/Llama-3.2-3B-Instruct"

    # Tokenizer first, then the (much larger) model weights.
    raw_tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")

    # Patch the tokenizer so conversations using the ShareGPT-style
    # {"from": ..., "value": ...} message format render as ChatML.
    tokenizer = get_chat_template(
        tokenizer=raw_tokenizer,
        chat_template="chatml",
        mapping={
            "role": "from",
            "content": "value",
            "user": "human",
            "assistant": "gpt",
        },
        map_eos_token=True,
    )
    return model, tokenizer
def main():
    """Run an interactive terminal chat loop against Llama-3.2-3B-Instruct.

    Reads user turns from stdin, maintains the running conversation in
    ShareGPT format ({"from": ..., "value": ...}), and prints both the
    raw (special-token) and cleaned model responses. Type 'exit' to quit.
    """
    model, tokenizer = load_model()

    # Inference only: disable dropout etc.
    model.eval()

    conversation = []
    print("Welcome to the chat with the Llama-3.2-3B-Instruct model! Type 'exit' to end the conversation.")
    while True:
        user_input = input("You: ")
        if user_input.lower() == 'exit':
            break
        if not user_input.strip():
            # Ignore blank lines instead of sending an empty turn to the model.
            continue

        conversation.append({"from": "human", "value": user_input})

        # add_generation_prompt=True appends the assistant header so the
        # model generates a reply instead of continuing the user's turn.
        inputs = tokenizer.apply_chat_template(
            conversation,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
        ).to(model.device)

        input_text_special_tokens = tokenizer.decode(
            inputs[0], skip_special_tokens=False)
        print("--------------------- Input")
        print(input_text_special_tokens[:500])  # Print the first 500 characters as a sample
        print("---------------------")

        # Sampling-based generation; pad_token_id silences the Llama
        # "no pad token" warning by padding with EOS.
        with torch.no_grad():
            outputs = model.generate(
                inputs,
                max_new_tokens=1024,
                temperature=0.7,
                top_p=0.95,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
            )

        # Slice off the prompt tokens so only the new reply is decoded.
        new_tokens = outputs[0][inputs.shape[1]:]
        response_special_tokens = tokenizer.decode(
            new_tokens, skip_special_tokens=False)
        print("--------------------- Response (with special tokens)")
        print(response_special_tokens)
        print("---------------------")

        response = tokenizer.decode(new_tokens, skip_special_tokens=True)
        print("Assistant:", response)
        conversation.append({"from": "gpt", "value": response})


if __name__ == "__main__":
    main()
Author
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
(unsloth_env) ➜ tmp git:(master) ✗ python chat.py
🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
Traceback (most recent call last):
File "/home/shazam/dev/augmentive-dspy/tmp/chat.py", line 71, in
main()
File "/home/shazam/dev/augmentive-dspy/tmp/chat.py", line 24, in main
model, tokenizer = load_model()
^^^^^^^^^^^^
File "/home/shazam/dev/augmentive-dspy/tmp/chat.py", line 10, in load_model
tokenizer = AutoTokenizer.from_pretrained(model_name)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/shazam/miniforge3/envs/unsloth_env/lib/python3.11/site-packages/transformers/models/auto/tokenization_auto.py", line 897, in from_pretrained
return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/shazam/miniforge3/envs/unsloth_env/lib/python3.11/site-packages/transformers/tokenization_utils_base.py", line 2271, in from_pretrained
return cls._from_pretrained(
^^^^^^^^^^^^^^^^^^^^^
File "/home/shazam/miniforge3/envs/unsloth_env/lib/python3.11/site-packages/transformers/tokenization_utils_base.py", line 2505, in _from_pretrained
tokenizer = cls(*init_inputs, **init_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/shazam/miniforge3/envs/unsloth_env/lib/python3.11/site-packages/transformers/tokenization_utils_fast.py", line 115, in init
fast_tokenizer = TokenizerFast.from_file(fast_tokenizer_file)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Exception: data did not match any variant of untagged enum ModelWrapper at line 1251003 column 3
(unsloth_env) ➜ tmp git:(master) ✗