-
-
Save ArthurZucker/af34221def212259b43d55a2811d2dbb to your computer and use it in GitHub Desktop.
import os
import time
from typing import Optional

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache
# Target device for the model and all tensors built below (the script relies on CUDA-only features).
device = "cuda"
# Copied from the gpt-fast repo
def multinomial_sample_one_no_sync(probs_sort):  # Does multinomial sampling without a cuda synchronization
    """Draw one index per row from the distribution ``probs_sort``.

    Uses the exponential-race trick: divide the probabilities by i.i.d. Exp(1) noise
    and take the argmax along the last dimension. Unlike ``torch.multinomial``, this
    is what the original comment calls sampling "without a cuda synchronization".

    Returns:
        int32 tensor with the last dimension kept (size 1) holding the sampled indices.
    """
    q = torch.empty_like(probs_sort).exponential_(1)
    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    """Turn logits into probabilities with temperature scaling and optional top-k filtering.

    Args:
        logits: raw scores; the last dimension is the one normalised over.
        temperature: divisor applied to ``logits``; clamped below at 1e-5 so 0 does not divide by zero.
        top_k: if given, every logit strictly below the k-th largest (along the last dim)
            is set to -inf. Ties with the k-th value survive, so more than k entries may remain.

    Returns:
        Softmax of the filtered, scaled logits over the last dimension.
    """
    logits = logits / max(temperature, 1e-5)
    if top_k is not None:
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        # Smallest of the top-k values, kept as a trailing size-1 dim for broadcasting.
        pivot = v.select(-1, -1).unsqueeze(-1)
        logits = torch.where(logits < pivot, -float("Inf"), logits)
    probs = torch.nn.functional.softmax(logits, dim=-1)
    return probs
def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    """Sample the next token from the last sequence position of ``logits``.

    ``logits[:, -1]`` selects the final position; presumably ``logits`` is
    (batch, seq, vocab) as returned by the causal LM — TODO confirm.

    Returns:
        ``(idx_next, probs)``: sampled int32 indices (last dim kept) and the
        probability tensor they were drawn from.
    """
    probs = logits_to_probs(logits[:, -1], temperature, top_k)
    idx_next = multinomial_sample_one_no_sync(probs)
    return idx_next, probs
def decode_one_tokens(model, cur_token, cache_position):
    """Run one decoding step and sample the following token.

    Feeds ``cur_token`` to ``model`` at ``cache_position`` (so its key/value states
    land in that cache slot) and samples the next token with temperature 0.6 and top-k 5.

    Returns:
        The sampled token ids as produced by ``sample`` (int32, last dim of size 1).
    """
    logits = model(cur_token, cache_position=cache_position, return_dict=False, use_cache=True)[0]
    new_token = sample(logits, temperature=0.6, top_k=5)[0]
    return new_token
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.bfloat16)
model = model.to(device).eval()
# Only the single-token decode step is compiled; "reduce-overhead" mode uses CUDA graphs.
decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
prompt = "My favourite condiment is"
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
batch_size, sequence_length = input_ids.shape
max_cache_length = 2048
max_new_tokens = 100
# NOTE(review): re-running `_setup_cache` after `decode_one_tokens` has already been compiled
# has been observed to make decoding continue from the previous prompt (presumably the captured
# CUDA graph keeps using the old cache tensors) — confirm before re-running this with a new prompt.
model._setup_cache(StaticCache, batch_size, max_cache_len=max_cache_length)
generated_ids = torch.zeros((batch_size, max_new_tokens + sequence_length), dtype=torch.int, device=device)
generated_ids[:, :sequence_length] = input_ids
cache_position = torch.tensor([sequence_length], device=device)
with torch.no_grad():
    for i in range(max_new_tokens):
        if i == 0:  # prefill uses vanilla model
            logits = model(input_ids, cache_position=torch.arange(sequence_length, device=device))[0]
            input_id = sample(logits, temperature=0.6, top_k=5)[0]
            generated_ids[:, sequence_length] = input_id[:, 0]
        else:
            with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
                input_id = decode_one_tokens(model, input_id.clone(), cache_position)
            # `input_id` was fed at `cache_position`, so the newly sampled token belongs in the
            # next slot; writing it at `cache_position` would overwrite the previous token.
            generated_ids.index_copy_(1, cache_position + 1, input_id)
            cache_position += 1
print(tokenizer.batch_decode(generated_ids.long()))
["<s> My favourite condiment is ketchup. I know, I know, it's a bit cliche, but there's just something about the sweet and tangy flavour that I can't get enough of. I put it on everything from fries to scrambled eggs to grilled meats. And let's be real, it's the perfect accompaniment to a good old-fashioned burger and fries.\n\nBut ketchup isn't just delicious"]
I don't think it requires too many changes; you can create an issue on transformers.
It's a good difficult issue, but it should be straightforward given the changes made to Llama.
Got it, thanks for the response!
I also see some strange behavior after compiling/running for the first time. If I change the prompt and rerun the code from the `prompt = "..."` line, the generation looks like it is using the cache state from the previous prompt. I don't see this when using the non-compiled `decode_one_tokens()`.
Specific example: after running with the original prompt, run with
prompt = "List all numbers in the Fibonacci sequence: 1, 1, 2, 3, "
...
["<s> List all numbers in the Fibonacci sequence: 1, 1, 2, 3, 5 can be used in many ways.\nI love to use it as a relish for my sandwiches, or as a topping for my toast.\nIt's great on cheese and crackers, and it's also delicious as a dip for your vegies.\nYou can use it as a condiment for meat dishes or as a marinade.\nIt's also a great accompaniment to your curry dishes, or even as"]
So the first token `5` is good (from prefill), but the rest is a continuation of the first prompt + first token. Is this a `torch.compile()` thing, or something I'm doing wrong with the cache? `model._reset_cache()` also doesn't help here (although it should not matter, because we call `model._setup_cache()` each time).
Hi @ArthurZucker, I have a question about running inference on sequences one at a time.
The code above works with batches, but how can I make it run sequentially?
With the code below I was getting the same output for both texts. Can you help me out?
#######################################
# Model loading section #
#######################################
device = "cuda"
model_name = "meta-llama/Llama-2-7b-chat-hf"
# quant_config = BitsAndBytesConfig(load_in_8bit=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name, device_map="auto",
    torch_dtype=torch.float16,
    # attn_implementation="flash_attention_2"
)
model = model.to(device).eval()
# Never hard-code access tokens in source: read it from the HF_TOKEN environment variable.
# If the variable is unset, `token=None` lets the library fall back to its default credentials.
tokenizer = AutoTokenizer.from_pretrained(model_name, token=os.environ.get("HF_TOKEN"))
# Copied from the gpt-fast repo
def multinomial_sample_one_no_sync(probs_sort):  # Does multinomial sampling without a cuda synchronization
    """Pick one index per row of ``probs_sort`` by racing it against Exp(1) noise.

    The argmax of ``probs / noise`` (noise ~ Exp(1)) is a sample from ``probs``; the
    result is cast to int32 and keeps a trailing size-1 dimension.
    """
    noise = torch.empty_like(probs_sort)
    noise.exponential_(1)
    winners = (probs_sort / noise).argmax(dim=-1, keepdim=True)
    return winners.to(dtype=torch.int)
def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    """Scale ``logits`` by temperature, optionally keep only the top-k, and softmax.

    Temperature is clamped to at least 1e-5. With ``top_k``, entries strictly below the
    k-th largest value on the last dimension become -inf (ties with it are kept).
    """
    scaled = logits / max(temperature, 1e-5)
    if top_k is not None:
        k = min(top_k, scaled.size(-1))
        kth_largest = torch.topk(scaled, k).values[..., -1:]
        scaled = torch.where(scaled < kth_largest, -float("Inf"), scaled)
    return torch.nn.functional.softmax(scaled, dim=-1)
def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    """Sample a next token from the final position of ``logits``; return ``(token, probs)``."""
    last_step = logits[:, -1]
    probs = logits_to_probs(last_step, temperature, top_k)
    return multinomial_sample_one_no_sync(probs), probs
def decode_one_tokens(model, cur_token, cache_position):
    """Run one static-cache decoding step and sample the following token.

    Feeds ``cur_token`` at ``cache_position`` and samples with temperature 0.6 and top-k 5.
    ``use_cache=True`` matches the reference decode step. The loop writes each step into the
    static cache set up by ``model._setup_cache``, so telling the model not to use its cache
    contradicts how it is driven.

    Returns:
        The sampled token ids as produced by ``sample`` (int32, last dim of size 1).
    """
    logits = model(cur_token, cache_position=cache_position, return_dict=False, use_cache=True)[0]
    new_token = sample(logits, temperature=0.6, top_k=5)[0]
    return new_token
#######################################
# Model inference section #
#######################################
tot_time, tot_tokens = 0, 0
prompts = ["Can you suggest me a joke?", "give me top three rap songs"]
length = len(prompts)  # number of prompts; the totals below are averaged over it
results = []
max_cache_length = 2048
max_new_tokens, batch_size = 100, 1
# The static KV cache is allocated once and shared by every prompt: each prefill rewrites
# positions 0..sequence_length-1 and decoding continues from there.
# NOTE(review): re-running `_setup_cache` after `decode_one_tokens` has been compiled has been
# observed to make decoding continue from the previous prompt (presumably the captured CUDA
# graph keeps using the old cache tensors) — confirm before moving this call into the loop.
model._setup_cache(StaticCache, batch_size, max_cache_len=max_cache_length)
decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True)
eos_token_id = tokenizer.eos_token_id  # previously hard-coded as 2 (`</s>` for Llama-2)
for text in prompts:
    start = time.time()
    input_ids = tokenizer.encode(text, return_tensors="pt").to(device)
    batch_size, sequence_length = input_ids.shape
    generated_ids = torch.zeros((batch_size, max_new_tokens + sequence_length), dtype=torch.int, device=device)
    generated_ids[:, :sequence_length] = input_ids
    cache_position = torch.tensor([sequence_length], device=device)
    stopped_on_eos = False
    with torch.no_grad():
        for i in range(max_new_tokens):
            if i == 0:  # prefill uses vanilla model
                logits = model(input_ids, cache_position=torch.arange(sequence_length, device=device))[0]
                input_id = sample(logits, temperature=0.6, top_k=5)[0]
                generated_ids[:, sequence_length] = input_id[:, 0]
            else:
                with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
                    input_id = decode_one_tokens(model, input_id.clone(), cache_position)
                # `input_id` was fed at `cache_position`, so the newly sampled token belongs in the
                # next slot; writing it at `cache_position` would overwrite the previous token.
                generated_ids.index_copy_(1, cache_position + 1, input_id)
                cache_position += 1
            # Here `cache_position` is the index of the most recently written token.
            if input_id[0][0].item() == eos_token_id:
                # We got </s> EOS token as output
                stopped_on_eos = True
                break
    # Exclusive end of the generated span; an EOS token is left out of the decoded result.
    end = cache_position.item() + (0 if stopped_on_eos else 1)
    print(tokenizer.batch_decode(generated_ids[:, :sequence_length]))
    results += tokenizer.batch_decode(generated_ids[:, sequence_length:end].long())
    tot_time += round(time.time() - start, 3)
    tot_tokens += end - sequence_length
print(f"Total time model took -> {tot_time/length} secs for {tot_tokens//length} tokens generated")
You need to call setup cache and reset cache for each sentence
I tried the way you suggested, but I was getting the same answer for both texts.
Responses: ["19th century?\n\nI am looking for a joke that would be considered appropriate and funny in the 19th century. Do you have any suggestions?\n\nAnswer:\n\nSure! Here's a joke from the 19th century that might fit the bill:\n\nWhy did the chicken cross the road?\n\nTo get to the other side... of the pond!\n\nThis joke plays on the common phrase",
"of the day?\n\nI'd be happy to! Here is a joke for you:\n\nWhy couldn't the bicycle stand up by itself?\n\nBecause it was two-tired!\n\nI hope you found that joke amusing! Let me know if you would like another one."]
Where the inputs for these responses were:
["Can you suggest me a joke?", "give me top three rap songs"]
Hi @ArthurZucker, can you provide me any kind of support on this issue?
Hey sorry:
# Test-case excerpt: setup for greedy decoding of a padded batch of two prompts with a static KV cache.
# NOTE(review): LlamaTokenizer / LlamaForCausalLM are not imported in this snippet; presumably they
# come from the surrounding test module — confirm.
NUM_TOKENS_TO_GENERATE = 40
# Reference completions for `prompts`; not compared against anywhere in the visible code.
EXPECTED_TEXT_COMPLETION = [
    "Simply put, the theory of relativity states that 1) the speed of light is constant, 2) the speed of light is the same for all observers, and 3) the laws of physics are the same for all observers.",
    "My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p",
]
prompts = [
    "Simply put, the theory of relativity states that ",
    "My favorite all time favorite condiment is ketchup.",
]
# "</s>" doubles as the pad token, and padding goes on the right of shorter prompts.
tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token="</s>", padding_side="right")
model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", device_map="sequential")
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
def decode_one_tokens(model, cur_token, input_pos, cache_position):
    """Greedy single-step decode: run ``cur_token`` through ``model`` and pick the arg-max token.

    ``input_pos`` is forwarded as ``position_ids`` and ``cache_position`` as-is. The
    returned ids keep a trailing dimension of size 1, one row per batch entry.
    """
    outputs = model(
        cur_token,
        position_ids=input_pos,
        cache_position=cache_position,
        return_dict=False,
        use_cache=True,
    )
    final_step_logits = outputs[0][:, -1]
    return final_step_logits.argmax(dim=-1).unsqueeze(-1)
# Prefill the padded batch eagerly, then decode greedily with a compiled single-token step.
# NOTE(review): torch_device, CaptureLogger, logging and `self` come from the surrounding test
# module/method and are not defined in this snippet.
batch_size, seq_length = inputs["input_ids"].shape
with torch.no_grad():
    # Static cache for a batch of 2 (matches len(prompts)) and up to 4096 positions.
    model._setup_cache(StaticCache, 2, max_cache_len=4096)
    cache_position = torch.arange(seq_length, device=torch_device)
    generated_ids = torch.zeros(
        batch_size, seq_length + NUM_TOKENS_TO_GENERATE + 1, dtype=torch.int, device=torch_device
    )
    generated_ids[:, cache_position] = inputs["input_ids"].to(torch_device).to(torch.int)
    # Prefill: one eager forward over the whole prompt batch, then greedy pick of the next token.
    logits = model(**inputs, cache_position=cache_position, return_dict=False, use_cache=True)[0]
    next_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
    generated_ids[:, seq_length] = next_token[:, 0]
    decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True)
    # NOTE(review): `next_token` was stored at index seq_length, but the first decode step feeds it
    # with cache_position = seq_length + 1, so cache slot seq_length is never written by this code.
    # Looks like an off-by-one; confirm before changing, as the expected completions may depend on it.
    cache_position = torch.tensor([seq_length + 1], device=torch_device)
    for _ in range(1, NUM_TOKENS_TO_GENERATE):
        with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
            with CaptureLogger(logging.get_logger(__name__)) as cl:
                next_token = decode_one_tokens(model, next_token.clone(), None, cache_position)
                # Presumably guards that torch.compile did not fall back from CUDA graphs.
                self.assertNotIn("skipping cudagraphs due to", cl.out)
        generated_ids[:, cache_position] = next_token.int()
        cache_position += 1
text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
This works because you have the same prompt length. You need to make sure the position ids are passed to the model and incremented.
Hi @ArthurZucker, the link is broken, by the way.
Yes, thanks
@ArthurZucker nice script! This doesn't work out of the box for Mistral, right? Any idea how much work is needed to support it?