needle test: a needle-in-a-haystack probe for activation-beacon-llama2-7b-chat
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "namespace-Pt/activation-beacon-llama2-7b-chat"

# Long-context haystack: read the book and truncate it.
with open("data/book/dinosaurs.txt", encoding="utf-8") as f:
    story_all = f.read()
story_all = story_all[:108448]  # 32k
#story_all = story_all[:54224]  # 16k

# Slide the needle from the end of the haystack toward the beginning,
# 1000 bytes at a time.
for i in range(len(story_all) // 1000, 1, -1):
    # Force reload: re-instantiate the tokenizer and model every iteration
    # so each probe starts from a fresh state.
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True, torch_dtype=torch.bfloat16)
    model = model.cuda().eval()

    # Insert the needle at byte offset i*1000.
    story = (
        story_all[:i * 1000]
        + ".\nThe best thing to do in San Francisco is eating a hamburg and sit in Dolores Park on a sunny day.\n"
        + story_all[i * 1000:]
    )
    question = "What is the best thing to do in San Francisco?"

    prompts_template = f'''[INST] <<SYS>>
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
<</SYS>>
Context:{story}
Question:{question}
[/INST]'''

    with torch.no_grad():
        inputs = tokenizer(prompts_template, return_tensors="pt").to("cuda")
        # Keep only the newly generated tokens, dropping the echoed prompt.
        outputs = model.generate(**inputs, do_sample=False, max_new_tokens=40)[:, inputs["input_ids"].shape[1]:]
        print("*" * 20)
        print(f"Input Length: {inputs['input_ids'].shape[1]}")
        print(f"Needle position: {i}k bytes")
        print(f"Prediction: {tokenizer.decode(outputs[0], skip_special_tokens=True)}")
    torch.cuda.empty_cache()
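
The script only prints each prediction; a minimal scoring sketch (not part of the original gist, helper name is hypothetical) could flag whether a probe actually retrieved the needle by checking the answer for its key phrase:

# Hypothetical post-processing helper: a probe passes when the model's answer
# mentions the needle's key phrase ("Dolores Park").
def needle_found(prediction: str, key_phrase: str = "Dolores Park") -> bool:
    return key_phrase.lower() in prediction.lower()

# Example use inside the loop above (assumes `outputs`, `tokenizer`, and `i`
# as defined there):
#   answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
#   print(f"{i}k: {'PASS' if needle_found(answer) else 'FAIL'} - {answer}")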