Run the Llama-2 7B Chat GPTQ model on a GPU
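Loading a GPTQ checkpoint through transformers assumes a CUDA-capable GPU and the GPTQ runtime dependencies are installed (typically the optimum and auto-gptq packages, though exact requirements vary by transformers version; this is an assumption about the environment, not something stated in the gist). A quick sanity check before running the script:

import torch

# Fail fast if no CUDA device is visible; the GPTQ model below needs a GPU.
assert torch.cuda.is_available(), "This script expects a CUDA-capable GPU"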
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

model_name_or_path = "TheBloke/Llama-2-7b-Chat-GPTQ"

# Load the quantized model; device_map="auto" places it on the available GPU.
# To use a different branch, change revision.
# For example: revision="gptq-4bit-64g-actorder_True"
model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
                                             device_map="auto",
                                             trust_remote_code=False,
                                             revision="main")
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)

# Build a Llama-2 chat prompt using the [INST] / <<SYS>> template.
prompt = "Tell me about AI"
prompt_template = f'''[INST] <<SYS>>
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
<</SYS>>
{prompt}[/INST]
'''

print("\n\n*** Generate:")

# Tokenize the prompt and move the input ids to the GPU.
input_ids = tokenizer(prompt_template, return_tensors='pt').input_ids.cuda()
# Sample up to 512 new tokens with temperature, nucleus, and top-k sampling.
output = model.generate(inputs=input_ids, temperature=0.7, do_sample=True,
                        top_p=0.95, top_k=40, max_new_tokens=512)
print(tokenizer.decode(output[0]))
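# (Addition, not in the original gist) output[0] contains the prompt tokens
# followed by the completion, so the print above echoes the prompt. To show
# only the newly generated text, slice off the prompt before decoding:
print(tokenizer.decode(output[0][input_ids.shape[-1]:], skip_special_tokens=True))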
# Inference can also be done using transformers' pipeline
print("*** Pipeline:")
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=512,
    do_sample=True,
    temperature=0.7,
    top_p=0.95,
    top_k=40,
    repetition_penalty=1.1
)
print(pipe(prompt_template)[0]['generated_text'])
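By default the text-generation pipeline also returns the prompt concatenated with the completion. If only the completion is wanted, the pipeline supports return_full_text=False (a usage sketch added here; the option is a standard transformers pipeline argument, not something shown in the gist):

print(pipe(prompt_template, return_full_text=False)[0]['generated_text'])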