from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache
import torch
from typing import Optional

device = "cuda"

# Copied from the gpt-fast repo
def multinomial_sample_one_no_sync(probs_sort):  # Does multinomial sampling without a cuda synchronization
    q = torch.empty_like(probs_sort).exponential_(1)
    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)

def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    logits = logits / max(temperature, 1e-5)
    if top_k is not None:
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        pivot = v.select(-1, -1).unsqueeze(-1)
        logits = torch.where(logits < pivot, -float("Inf"), logits)
    probs = torch.nn.functional.softmax(logits, dim=-1)
    return probs

def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    probs = logits_to_probs(logits[:, -1], temperature, top_k)
    idx_next = multinomial_sample_one_no_sync(probs)
    return idx_next, probs

def decode_one_tokens(model, cur_token, cache_position):
    logits = model(cur_token, cache_position=cache_position, return_dict=False, use_cache=True)[0]
    new_token = sample(logits, temperature=0.6, top_k=5)[0]
    return new_token

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.bfloat16)
model = model.to(device).eval()
decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
prompt = "My favourite condiment is"
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
batch_size, sequence_length = input_ids.shape
max_cache_length = 2048
max_new_tokens = 100

model._setup_cache(StaticCache, batch_size, max_cache_len=max_cache_length)
generated_ids = torch.zeros((batch_size, max_new_tokens + sequence_length), dtype=torch.int, device=device)
generated_ids[:, :sequence_length] = input_ids
cache_position = torch.tensor([sequence_length], device=device)

with torch.no_grad():
    for i in range(100):
        if i == 0:  # prefill uses vanilla model
            logits = model(input_ids, cache_position=torch.arange(sequence_length, device=device))[0]
            input_id = sample(logits, temperature=0.6, top_k=5)[0]
            generated_ids[:, sequence_length] = input_id[:, 0]
        else:
            with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
                input_id = decode_one_tokens(model, input_id.clone(), cache_position)
                generated_ids.index_copy_(1, cache_position, input_id)
        cache_position += 1

print(tokenizer.batch_decode(generated_ids.long()))
["<s> My favourite condiment is ketchup. I know, I know, it's a bit cliche, but there's just something about the sweet and tangy flavour that I can't get enough of. I put it on everything from fries to scrambled eggs to grilled meats. And let's be real, it's the perfect accompaniment to a good old-fashioned burger and fries.\n\nBut ketchup isn't just delicious"] |
You need to call setup cache and reset the cache for each sentence.
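For illustration, a per-sentence loop could look like the sketch below. It reuses the sample / decode_one_tokens helpers plus the model, tokenizer, max_cache_length and max_new_tokens defined in the gist above; _setup_cache and _reset_cache are private APIs (the _reset_cache call is assumed to be the counterpart of _setup_cache in the same transformers version), so treat this as a sketch rather than a stable recipe.

# Sketch: re-initialise the static cache for every prompt so sequences do not
# leak into each other. Names come from the gist above; _reset_cache is assumed.
prompts = ["Can you suggest me a joke?", "give me top three rap songs"]

for prompt in prompts:
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    batch_size, sequence_length = input_ids.shape

    model._setup_cache(StaticCache, batch_size, max_cache_len=max_cache_length)  # fresh cache per prompt
    generated_ids = torch.zeros((batch_size, max_new_tokens + sequence_length), dtype=torch.int, device=device)
    generated_ids[:, :sequence_length] = input_ids
    cache_position = torch.tensor([sequence_length], device=device)

    with torch.no_grad():
        # prefill with the vanilla (uncompiled) model
        logits = model(input_ids, cache_position=torch.arange(sequence_length, device=device))[0]
        input_id = sample(logits, temperature=0.6, top_k=5)[0]
        generated_ids[:, sequence_length] = input_id[:, 0]
        cache_position += 1
        # one-token-at-a-time decoding
        for _ in range(max_new_tokens - 1):
            with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
                input_id = decode_one_tokens(model, input_id.clone(), cache_position)
                generated_ids.index_copy_(1, cache_position, input_id)
            cache_position += 1

    print(tokenizer.batch_decode(generated_ids.long()))
    model._reset_cache()  # drop the cached keys/values before the next prompt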
I tried the way you suggested, but I was getting the same answer for both texts.
Responses: ["19th century?\n\nI am looking for a joke that would be considered appropriate and funny in the 19th century. Do you have any suggestions?\n\nAnswer:\n\nSure! Here's a joke from the 19th century that might fit the bill:\n\nWhy did the chicken cross the road?\n\nTo get to the other side... of the pond!\n\nThis joke plays on the common phrase",
"of the day?\n\nI'd be happy to! Here is a joke for you:\n\nWhy couldn't the bicycle stand up by itself?\n\nBecause it was two-tired!\n\nI hope you found that joke amusing! Let me know if you would like another one."]
The inputs for these responses were:
["Can you suggest me a joke?", "give me top three rap songs"]
Hi @ArthurZucker, can you provide any support on this issue?
Hey sorry:
NUM_TOKENS_TO_GENERATE = 40
EXPECTED_TEXT_COMPLETION = [
    "Simply put, the theory of relativity states that 1) the speed of light is constant, 2) the speed of light is the same for all observers, and 3) the laws of physics are the same for all observers.",
    "My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p",
]
prompts = [
    "Simply put, the theory of relativity states that ",
    "My favorite all time favorite condiment is ketchup.",
]
tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token="</s>", padding_side="right")
model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", device_map="sequential")
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)

def decode_one_tokens(model, cur_token, input_pos, cache_position):
    logits = model(
        cur_token, position_ids=input_pos, cache_position=cache_position, return_dict=False, use_cache=True
    )[0]
    new_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
    return new_token

batch_size, seq_length = inputs["input_ids"].shape
with torch.no_grad():
    model._setup_cache(StaticCache, 2, max_cache_len=4096)
    cache_position = torch.arange(seq_length, device=torch_device)
    generated_ids = torch.zeros(
        batch_size, seq_length + NUM_TOKENS_TO_GENERATE + 1, dtype=torch.int, device=torch_device
    )
    generated_ids[:, cache_position] = inputs["input_ids"].to(torch_device).to(torch.int)

    logits = model(**inputs, cache_position=cache_position, return_dict=False, use_cache=True)[0]
    next_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
    generated_ids[:, seq_length] = next_token[:, 0]

    decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True)
    cache_position = torch.tensor([seq_length + 1], device=torch_device)
    for _ in range(1, NUM_TOKENS_TO_GENERATE):
        with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
            with CaptureLogger(logging.get_logger(__name__)) as cl:
                next_token = decode_one_tokens(model, next_token.clone(), None, cache_position)
            self.assertNotIn("skipping cudagraphs due to", cl.out)
        generated_ids[:, cache_position] = next_token.int()
        cache_position += 1

text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
This works because both prompts have the same length. You need to make sure the position ids are passed to the model and incremented for each new token.
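To make that last point concrete, here is a minimal sketch of a decode loop that tracks position ids explicitly, adapted from the snippet above. The name input_pos is illustrative; with right-padded prompts of different lengths you would derive each sequence's starting position from the attention mask rather than a single seq_length.

# Sketch: pass position_ids explicitly and advance them together with cache_position.
# Assumes the tensors defined in the snippet above (next_token, generated_ids, seq_length, ...).
input_pos = torch.full((batch_size, 1), seq_length, dtype=torch.long, device=torch_device)
cache_position = torch.tensor([seq_length + 1], device=torch_device)
for _ in range(1, NUM_TOKENS_TO_GENERATE):
    next_token = decode_one_tokens(model, next_token.clone(), input_pos, cache_position)
    generated_ids[:, cache_position] = next_token.int()
    input_pos += 1       # position id of the next token, per sequence
    cache_position += 1  # slot written in the static cache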
Hi @ArthurZucker, the link is broken, by the way.
Yes, thanks
Thanks for providing this sample! I am building a customized MLP block with Triton 2.2, and it looks like my model cannot be processed by torch.compile. It works well with CUDA graphs without the cache, though. Could you provide sample code that builds a CUDA graph with a static cache? The error I get is:
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
LoweringException: AttributeError: 'View' object has no attribute 'layout'
target: triton_kernel_wrapper_functional
kwargs: {'kernel_idx': 0, 'grid': [(688, 1, 1), (688, 1, 1), (688, 1, 1), (688, 1, 1), (688, 1, 1), (688, 1, 1), (688, 1, 1), (688, 1, 1), (688, 1, 1), (688, 1, 1), (344, 1, 1), (344, 1, 1), (344, 1, 1), (344, 1, 1), (344, 1, 1), (344, 1, 1), (344, 1, 1), (344, 1, 1), (344, 1, 1), (344, 1, 1), (172, 1, 1), (172, 1, 1), (172, 1, 1), (172, 1, 1), (172, 1, 1), (172, 1, 1), (172, 1, 1), (172, 1, 1), (172, 1, 1), (172, 1, 1)], 'kwargs': {'a_ptr': TensorBox(
View(
SliceView(
StorageBox(
ComputedBuffer(name='buf19', layout=FixedLayout('cuda', torch.float16, size=[1, 1, 4096], stride=[4096, 4096, 1]), data=Pointwise(
'cuda',
torch.float16,
def inner_fn(index):
....
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
Any suggestion on fixing this error would also be helpful!
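For what it's worth, the usual manual capture pattern with torch.cuda.CUDAGraph looks roughly like the sketch below. It reuses the buffers from the gist at the top and calls the eager (uncompiled) decode_one_tokens; whether the custom Triton kernels and the static-cache forward are actually capture-safe in a given setup is an assumption here, not something verified in this thread.

# Sketch: capture one decode step in a CUDA graph and replay it against static buffers.
# Standard torch.cuda.CUDAGraph workflow; decode_one_tokens is the plain eager function
# from the gist, not the torch.compile'd one.
static_token = input_id.clone()                 # input buffer with a fixed address
static_cache_position = cache_position.clone()  # position buffer with a fixed address

# Warm up on a side stream so allocator state settles before capture.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        static_out = decode_one_tokens(model, static_token, static_cache_position)
torch.cuda.current_stream().wait_stream(s)

# Record a single decode step.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_out = decode_one_tokens(model, static_token, static_cache_position)

# Replay: mutate the static buffers in place, then launch the recorded kernels.
for _ in range(max_new_tokens - 1):
    graph.replay()
    generated_ids.index_copy_(1, static_cache_position, static_out)
    static_token.copy_(static_out)
    static_cache_position += 1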
Hi @ArthurZucker, I have a question about running inference on sequences sequentially.
The code above works with batches, but how can I make it run one sequence at a time?
I was getting the same output for both texts with the code below. Can you help me out?