
@f0ster
Created April 28, 2024 15:30
Running mistralai/Mixtral-8x7B-Instruct-v0.1 locally with Hugging Face Transformers
import time

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def load_model_and_tokenizer(model_id):
    """
    Load the tokenizer and model for the given model ID.

    The model is loaded in float16 to reduce memory usage, and
    device_map="auto" lets Accelerate spread the weights across the
    available GPU(s), CPU, and disk.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id, torch_dtype=torch.float16, device_map="auto"
    )
    return tokenizer, model


def prepare_input(tokenizer, messages):
    """
    Convert the list of message dictionaries into model-ready input IDs
    using the model's chat template.
    """
    input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt").to("cuda")
    return input_ids


def generate_response(model, input_ids):
    """
    Generate a response from the model for the given input IDs.
    """
    outputs = model.generate(input_ids, max_new_tokens=20)
    return outputs


def main():
    model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
    messages = [
        {"role": "user", "content": "What is your favourite condiment?"},
        {"role": "assistant", "content": "Well, I'm quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever I'm cooking up in the kitchen!"},
        {"role": "user", "content": "Do you have mayonnaise recipes?"},
    ]

    # Load model and tokenizer (not included in the timing below)
    tokenizer, model = load_model_and_tokenizer(model_id)

    # Time tokenization and generation
    start_time = time.time()
    input_ids = prepare_input(tokenizer, messages)
    outputs = generate_response(model, input_ids)
    end_time = time.time()

    # Print the elapsed time and the decoded output
    print("Elapsed time: {:.2f} seconds".format(end_time - start_time))
    print("Generated text:", tokenizer.decode(outputs[0], skip_special_tokens=True))


if __name__ == "__main__":
    main()
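
Note: the script calls generate() without an attention mask or pad token ID, which is what triggers the warnings in the run log below. A minimal variant (a sketch, assuming a transformers version whose apply_chat_template accepts return_dict=True) keeps the attention mask alongside the input IDs and reuses EOS as the pad token:

    # Tokenize with the chat template; return_dict=True keeps the
    # attention_mask together with the input_ids.
    inputs = tokenizer.apply_chat_template(
        messages, return_tensors="pt", return_dict=True
    ).to("cuda")
    outputs = model.generate(
        **inputs,
        max_new_tokens=20,
        pad_token_id=tokenizer.eos_token_id,  # Mixtral ships no dedicated pad token
    )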

f0ster commented Apr 28, 2024

(transformers) user@fedora:~/code/transformers$ python mixtral_demo.py
Loading checkpoint shards: 100%|██████████| 19/19 [00:44<00:00,  2.35s/it]
WARNING:root:Some parameters are on the meta device device because they were offloaded to the disk and cpu.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Elapsed time: 333.75 seconds
Generated text: [INST] What is your favourite condiment? [/INST]Well, I'm quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever I'm cooking up in the kitchen! [INST] Do you have mayonnaise recipes? [/INST] Yes, I do have a simple mayonnaise recipe that I enjoy making. Here it is:
(transformers) user@fedora:~/code/transformers$ 
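
The "offloaded to the disk and cpu" warning explains most of the 333 s generation time: the fp16 weights do not fit in GPU memory, so layers are paged in from CPU and disk on every forward pass. A common workaround (a sketch, assuming bitsandbytes is installed; not part of the original script) is to load the model 4-bit quantized, which cuts the weight footprint to roughly a quarter:

    import torch
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    # 4-bit NF4 quantization via bitsandbytes; matmuls still run in fp16.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
    )
    model = AutoModelForCausalLM.from_pretrained(
        "mistralai/Mixtral-8x7B-Instruct-v0.1",
        quantization_config=bnb_config,
        device_map="auto",
    )

If the quantized model fits entirely on the GPU, device_map="auto" no longer needs to offload and generation should be substantially faster.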
