Last active
January 30, 2024 14:27
-
-
Save gradetwo/55a739475dfc76e6995a14147faecc2f to your computer and use it in GitHub Desktop.
Needle-in-a-haystack retrieval test
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json | |
import torch | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
# Needle-in-a-haystack retrieval test for the activation-beacon Llama-2 model:
# a known "needle" sentence is spliced into a long context at descending byte
# offsets, and the model is asked a question only the needle answers.
story = ""
question = ""
model_id = "namespace-Pt/activation-beacon-llama2-7b-chat"

# Long context: read the book and truncate to a fixed byte budget.
with open("data/book/dinosaurs.txt", encoding="utf-8") as f:
    story_all = f.read()
story_all = story_all[:108448]  # 32k
# story_all = story_all[:54224]  # 16k

# Sweep the needle position from the end of the context toward the start,
# one 1000-byte step per iteration.
for i in range(len(story_all) // 1000, 1, -1):
    # Force-reload tokenizer and model every iteration so no state from the
    # previous trial (e.g. the beacon's compressed memory) can carry over.
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_id, trust_remote_code=True, torch_dtype=torch.bfloat16
    )
    model = model.cuda().eval()

    # Splice the needle sentence into the story at byte offset i*1000.
    story = (
        story_all[:i * 1000]
        + ".\nThe best thing to do in San Francisco is eating a hamburg and sit in Dolores Park on a sunny day.\n"
        + story_all[i * 1000:]
    )
    question = "What is the best thing to do in San Francisco?"

    # Llama-2 chat template: system prompt, then context + question.
    # ("aquestion" typo fixed to match the official Llama-2 system prompt.)
    prompts_template = f'''[INST] <<SYS>>
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don’t know the answer to a question, please don’t share false information.
<</SYS>>
Context:{story}
Question:{question}
[/INST]'''

    with torch.no_grad():
        inputs = tokenizer(prompts_template, return_tensors="pt").to("cuda")
        # Greedy decode, then slice off the prompt tokens so only the newly
        # generated answer remains.
        outputs = model.generate(
            **inputs, do_sample=False, max_new_tokens=40
        )[:, inputs["input_ids"].shape[1]:]

    print("*" * 20)
    print(f"Input Length: {inputs['input_ids'].shape[1]}")
    print(f"Position in bytes: {i}k")
    print(f"Prediction: {tokenizer.decode(outputs[0], skip_special_tokens=True)}")

    # BUG FIX: drop all references to the model and tensors BEFORE emptying
    # the cache — otherwise the allocator still holds the 7B weights and
    # empty_cache() frees essentially nothing, leaking GPU memory each loop.
    del model, tokenizer, inputs, outputs
    torch.cuda.empty_cache()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment