""" | |
Script to quickly check that model answers make sense and not complete garbage | |
Also useful to check if a checkpoint runs at all | |
python scripts/arabic_sanity.py --model-name='mistralai/Mistral-7B-v0.1' | |
python scripts/arabic_sanity.py --model-name='./data/zephyr-7b-sft-lora/checkpoint-12' | |
""" | |
from argparse import ArgumentParser

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda"  # device for the tokenized inputs; the model itself is placed by device_map='auto'
parser = ArgumentParser()
# e.g. 'mistralai/Mistral-7B-v0.1' or './data/zephyr-7b-sft-lora/checkpoint-12'
parser.add_argument("--model-name", type=str, default='tiiuae/falcon-7b-instruct')
args = parser.parse_args()
model = AutoModelForCausalLM.from_pretrained(args.model_name, device_map='auto', trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
prompts = [
    "اسم بعض الفواكه",  # "Name some fruits"
    "من هو اول رئيس لدوله الامارات العربيه المتحده",  # "Who was the first president of the United Arab Emirates?"
    "عد من واحد إلى عشرة",  # "Count from one to ten"
    "اكتب قصيدة قصيرة عن فالكون",  # "Write a short poem about Falcon"
]
# Chat template that prefixes user turns with "المستخدم:" ("User:") and assistant turns with "نور:" ("Noor:").
tokenizer.chat_template = "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ 'المستخدم: \n' + message['content'] }}\n{% elif message['role'] == 'system' %}\n{{ '' }}\n{% elif message['role'] == 'assistant' %}\n{{ 'نور:\n' + message['content']}}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ 'نور:' }}\n{% endif %}\n{% endfor %}"
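# For reference, a single user message rendered through this template with
# add_generation_prompt=True should look roughly like the following (an assumption
# about the Jinja rendering; exact blank lines and whitespace may differ):
#
#   المستخدم: 
#   <user prompt>
#   نور: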
for prompt in prompts:
    request = tokenizer.apply_chat_template([{'content': prompt, 'role': 'user'}], tokenize=False, add_generation_prompt=True)
    model_inputs = tokenizer([request], return_tensors="pt").to(device)
    generated_ids = model.generate(
        input_ids=model_inputs.input_ids,
        attention_mask=model_inputs.attention_mask,
        max_new_tokens=180,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    print("\n\n$*******************$\n")
    print(request)
    print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
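    # Note: generated_ids[0] includes the prompt tokens, so the request above is
    # effectively printed twice. A minimal sketch (not in the original script) that
    # would print only the new completion, slicing off the prompt length:
    # print(tokenizer.decode(generated_ids[0][model_inputs.input_ids.shape[1]:], skip_special_tokens=True))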