@kirill-fedyanin
Created January 30, 2024 08:19
"""
Script to quickly check that model answers make sense and not complete garbage
Also useful to check if a checkpoint runs at all
python scripts/arabic_sanity.py --model-name='mistralai/Mistral-7B-v0.1'
python scripts/arabic_sanity.py --model-name='./data/zephyr-7b-sft-lora/checkpoint-12'
"""
from argparse import ArgumentParser
from transformers import AutoModelForCausalLM, AutoTokenizer
device = "cuda" # the device to load the model onto
parser = ArgumentParser()
# Alternatives: 'mistralai/Mistral-7B-v0.1', './data/zephyr-7b-sft-lora/checkpoint-12'
parser.add_argument("--model-name", type=str, default='tiiuae/falcon-7b-instruct')
args = parser.parse_args()
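# trust_remote_code allows checkpoints that ship custom modeling code (e.g. older Falcon releases) to load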
model = AutoModelForCausalLM.from_pretrained(args.model_name, device_map='auto', trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
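# Note: some checkpoints define no pad token, so generate() below falls back to the EOS token for padding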
prompts = [
    "اسم بعض الفواكه",  # "Name some fruits"
    "من هو اول رئيس لدوله الامارات العربيه المتحده",  # "Who was the first president of the United Arab Emirates?"
    "عد من واحد إلى عشرة",  # "Count from one to ten"
    "اكتب قصيدة قصيرة عن فالكون",  # "Write a short poem about Falcon"
]
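# Minimal Arabic chat template: user turns are prefixed with "المستخدم:" ("User:") and
# assistant turns with "نور:" ("Noor:"), which also serves as the generation prompt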
tokenizer.chat_template = "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ 'المستخدم: \n' + message['content'] }}\n{% elif message['role'] == 'system' %}\n{{ '' }}\n{% elif message['role'] == 'assistant' %}\n{{ 'نور:\n' + message['content']}}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ 'نور:' }}\n{% endif %}\n{% endfor %}"
for prompt in prompts:
    # Wrap the raw prompt in the chat template and append the generation prompt ("نور:")
    request = tokenizer.apply_chat_template(
        [{'content': prompt, 'role': 'user'}], tokenize=False, add_generation_prompt=True
    )
    model_inputs = tokenizer([request], return_tensors="pt").to(device)
    generated_ids = model.generate(
        input_ids=model_inputs.input_ids,
        attention_mask=model_inputs.attention_mask,
        max_new_tokens=180,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    print("\n\n$*******************$\n")
    print(request)
    # generate() returns the full sequence, so the decoded text includes the prompt
    print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
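    # Optional sketch (not in the original script): decode only the newly generated
    # tokens, so the prompt isn't repeated in the printout
    new_tokens = generated_ids[0][model_inputs.input_ids.shape[1]:]
    print(tokenizer.decode(new_tokens, skip_special_tokens=True))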