Last active May 17, 2024 02:26
ArthurZucker/af34221def212259b43d55a2811d2dbb
Simple static KV cache script
from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache
import torch
from typing import Optional

device = "cuda"


# Copied from the gpt-fast repo
def multinomial_sample_one_no_sync(probs_sort):  # Does multinomial sampling without a CUDA synchronization
    # argmax(p / q) with i.i.d. q ~ Exp(1) selects index i with probability p_i.
    q = torch.empty_like(probs_sort).exponential_(1)
    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)


def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    logits = logits / max(temperature, 1e-5)
    if top_k is not None:
        # Mask out every logit below the k-th largest one.
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        pivot = v.select(-1, -1).unsqueeze(-1)
        logits = torch.where(logits < pivot, -float("Inf"), logits)
    probs = torch.nn.functional.softmax(logits, dim=-1)
    return probs


def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    # Only the distribution at the last sequence position is sampled.
    probs = logits_to_probs(logits[:, -1], temperature, top_k)
    idx_next = multinomial_sample_one_no_sync(probs)
    return idx_next, probs


def decode_one_tokens(model, cur_token, cache_position):
    # One decoding step: feed a single token at cache_position and sample the next token.
    logits = model(cur_token, cache_position=cache_position, return_dict=False, use_cache=True)[0]
    new_token = sample(logits, temperature=0.6, top_k=5)[0]
    return new_token
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.bfloat16)
model = model.to(device).eval()
# mode="reduce-overhead" uses CUDA graphs; fullgraph=True raises an error on graph breaks instead of splitting the graph.
decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
prompt = "My favourite condiment is"
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
batch_size, sequence_length = input_ids.shape
max_cache_length = 2048
max_new_tokens = 100

# Preallocate the fixed-size KV cache. _setup_cache is a private helper of the
# transformers version this script targets and may not exist in other releases.
model._setup_cache(StaticCache, batch_size, max_cache_len=max_cache_length)

# Output buffer: the prompt followed by max_new_tokens generated ids.
generated_ids = torch.zeros((batch_size, max_new_tokens + sequence_length), dtype=torch.int, device=device)
generated_ids[:, :sequence_length] = input_ids
# Position of the first generated token, i.e. the first free cache slot after prefill.
cache_position = torch.tensor([sequence_length], device=device)
with torch.no_grad():
    for i in range(max_new_tokens):
        if i == 0:  # prefill uses the vanilla (uncompiled) model on the whole prompt
            logits = model(input_ids, cache_position=torch.arange(sequence_length, device=device))[0]
            input_id = sample(logits, temperature=0.6, top_k=5)[0]
            generated_ids[:, sequence_length] = input_id[:, 0]
        else:
            # Restrict scaled_dot_product_attention to the math backend for the compiled step.
            with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
                # Feed the previous token at its own position (cache_position) ...
                input_id = decode_one_tokens(model, input_id.clone(), cache_position)
            # ... and store the newly sampled token one slot later.
            generated_ids.index_copy_(1, cache_position + 1, input_id)
            cache_position += 1

print(tokenizer.batch_decode(generated_ids.long()))
Output:
["<s> My favourite condiment is ketchup. I know, I know, it's a bit cliche, but there's just something about the sweet and tangy flavour that I can't get enough of. I put it on everything from fries to scrambled eggs to grilled meats. And let's be real, it's the perfect accompaniment to a good old-fashioned burger and fries.\n\nBut ketchup isn't just delicious"]
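The intended bookkeeping (prefill the prompt at positions 0 to n-1, then feed each sampled token back at the next free `cache_position`) can be modeled without a GPU. The sketch below is a toy, pure-Python stand-in and not the transformers `StaticCache`: `ToyStaticCache` and `simulate` are hypothetical names, each slot holds one float instead of per-layer key/value tensors, and the mask is a plain causal mask. What it illustrates is that the buffers and the mask keep length `max_cache_len` at every step, which is the property CUDA-graph replay (used by `torch.compile(mode="reduce-overhead")`) depends on, since a captured graph assumes fixed tensor shapes.

```python
from typing import List, Optional


class ToyStaticCache:
    """Hypothetical single-sequence, single-head KV cache with a fixed capacity.

    Real static caches hold a (batch, heads, max_cache_len, head_dim) tensor per
    layer for keys and another for values; here each slot is a single float.
    """

    def __init__(self, max_cache_len: int) -> None:
        self.max_cache_len = max_cache_len
        # Every slot is allocated up front, so the buffers never change size.
        self.keys: List[Optional[float]] = [None] * max_cache_len
        self.values: List[Optional[float]] = [None] * max_cache_len

    def update(self, cache_position: List[int], new_keys: List[float], new_values: List[float]) -> None:
        # Write in place at absolute positions, like indexing a preallocated tensor.
        for pos, k, v in zip(cache_position, new_keys, new_values):
            if not 0 <= pos < self.max_cache_len:
                raise IndexError(f"position {pos} is outside max_cache_len={self.max_cache_len}")
            self.keys[pos] = k
            self.values[pos] = v

    def causal_mask(self, query_position: int) -> List[bool]:
        # A query at position p may attend to slots 0..p. The mask always has
        # max_cache_len entries, so its shape is identical at every decode step.
        return [slot <= query_position for slot in range(self.max_cache_len)]


def simulate(prompt_len: int, max_new_tokens: int, max_cache_len: int) -> ToyStaticCache:
    """Prefill the prompt, then run max_new_tokens - 1 single-token decode steps."""
    cache = ToyStaticCache(max_cache_len)
    # Prefill: the whole prompt at cache_position = 0 .. prompt_len - 1.
    # Each slot stores its own position so gaps or overwrites are easy to spot.
    prefill = list(range(prompt_len))
    cache.update(prefill, [float(p) for p in prefill], [float(p) for p in prefill])
    # Prefill samples the first new token; it is fed back at the first free slot.
    cache_position = prompt_len
    for _ in range(max_new_tokens - 1):
        cache.update([cache_position], [float(cache_position)], [float(cache_position)])
        cache_position += 1
    return cache


cache = simulate(prompt_len=4, max_new_tokens=3, max_cache_len=8)
print(cache.keys)            # [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, None, None]
print(cache.causal_mask(5))  # [True, True, True, True, True, True, False, False]
```

The last sampled token is never fed back, so a run fills `prompt_len + max_new_tokens - 1` slots with no gaps. Calling `simulate(4, 10, 8)` raises `IndexError`, the toy analogue of generating past `max_cache_len`.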
ArthurZucker (Author) commented on Apr 25, 2024, via email:
Hey all! This benchmark should help: https://gist.github.com/ArthurZucker/ae0a86ef8f841c0ef69aaa52ccbc0b03
We are fixing the last bits 😉
In reply to Billy Cao (@aliencaocao), 27 Mar 2024, 06:48:
Hi, is the static KV cache supposed to require a lot of VRAM for batch size 1 and a context of 2048? I'm trying to use it on LLaVA-NeXT 7B, but it OOMs on a 24 GB card. I'm already loading in 4-bit, and there is plenty of VRAM left (around 10 GB).
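On the VRAM question: a static cache allocates keys and values for the full `max_cache_len` up front, whatever the prompt length, so its footprint can be estimated from the model config. The sketch below (the helper name `static_kv_cache_bytes` is hypothetical) assumes a Llama-2-7B-like config (32 layers, 32 key/value heads, head dimension 128) and a bf16 cache; a LLaVA-NeXT checkpoint should be plugged in with its own config values. It counts cache storage only, not weights, activations, or memory used by torch.compile and CUDA graphs. Loading weights in 4-bit is not expected to shrink this number, since that quantization applies to the weights rather than to the cache tensors.

```python
def static_kv_cache_bytes(num_layers: int, batch_size: int, num_kv_heads: int,
                          max_cache_len: int, head_dim: int, bytes_per_elem: int) -> int:
    """Bytes needed to preallocate keys and values for every layer."""
    # Factor 2: one key tensor and one value tensor per layer, each of shape
    # (batch_size, num_kv_heads, max_cache_len, head_dim).
    return 2 * num_layers * batch_size * num_kv_heads * max_cache_len * head_dim * bytes_per_elem


# Assumed Llama-2-7B-style config: 32 layers, 32 KV heads, head_dim 128, bf16 cache (2 bytes).
size = static_kv_cache_bytes(num_layers=32, batch_size=1, num_kv_heads=32,
                             max_cache_len=2048, head_dim=128, bytes_per_elem=2)
print(f"{size / 2**30:.2f} GiB")  # 1.00 GiB
```

The result is linear in every argument, so under the same assumptions batch size 2 or a 4096-token cache would need about 2 GiB.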
@aliencaocao (Billy Cao) commented on Apr 25, 2024:
Hi @ArthurZucker, the link is broken, btw.
Yes, thanks
Thanks for providing this sample! I am building a customized MLP block with Triton 2.2, and it looks like my model cannot be processed by torch.compile. It works well with CUDAGraph without the cache, though. Could you provide sample code that builds a CUDA graph with a static cache? The error I get is:
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
LoweringException: AttributeError: 'View' object has no attribute 'layout'
target: triton_kernel_wrapper_functional
kwargs: {'kernel_idx': 0, 'grid': [(688, 1, 1), (688, 1, 1), (688, 1, 1), (688, 1, 1), (688, 1, 1), (688, 1, 1), (688, 1, 1), (688, 1, 1), (688, 1, 1), (688, 1, 1), (344, 1, 1), (344, 1, 1), (344, 1, 1), (344, 1, 1), (344, 1, 1), (344, 1, 1), (344, 1, 1), (344, 1, 1), (344, 1, 1), (344, 1, 1), (172, 1, 1), (172, 1, 1), (172, 1, 1), (172, 1, 1), (172, 1, 1), (172, 1, 1), (172, 1, 1), (172, 1, 1), (172, 1, 1), (172, 1, 1)], 'kwargs': {'a_ptr': TensorBox(
View(
SliceView(
StorageBox(
ComputedBuffer(name='buf19', layout=FixedLayout('cuda', torch.float16, size=[1, 1, 4096], stride=[4096, 4096, 1]), data=Pointwise(
'cuda',
torch.float16,
def inner_fn(index):
....
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
Any suggestions for fixing this error would also be helpful!