@ArthurZucker
Last active May 17, 2024 02:26
simple static kv cache script
from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache
import torch
from typing import Optional

device = "cuda"

# Copied from the gpt-fast repo
def multinomial_sample_one_no_sync(probs_sort):  # Does multinomial sampling without a cuda synchronization
    q = torch.empty_like(probs_sort).exponential_(1)
    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)

def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    logits = logits / max(temperature, 1e-5)
    if top_k is not None:
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        pivot = v.select(-1, -1).unsqueeze(-1)
        logits = torch.where(logits < pivot, -float("Inf"), logits)
    probs = torch.nn.functional.softmax(logits, dim=-1)
    return probs

def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    probs = logits_to_probs(logits[:, -1], temperature, top_k)
    idx_next = multinomial_sample_one_no_sync(probs)
    return idx_next, probs

def decode_one_tokens(model, cur_token, cache_position):
    logits = model(cur_token, cache_position=cache_position, return_dict=False, use_cache=True)[0]
    new_token = sample(logits, temperature=0.6, top_k=5)[0]
    return new_token

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.bfloat16)
model = model.to(device).eval()
decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
prompt = "My favourite condiment is"
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
batch_size, sequence_length = input_ids.shape

max_cache_length = 2048
max_new_tokens = 100
model._setup_cache(StaticCache, batch_size, max_cache_len=max_cache_length)
generated_ids = torch.zeros((batch_size, max_new_tokens + sequence_length), dtype=torch.int, device=device)
generated_ids[:, :sequence_length] = input_ids
cache_position = torch.tensor([sequence_length], device=device)

with torch.no_grad():
    for i in range(max_new_tokens):
        if i == 0:  # prefill uses vanilla model
            logits = model(input_ids, cache_position=torch.arange(sequence_length, device=device))[0]
            input_id = sample(logits, temperature=0.6, top_k=5)[0]
            generated_ids[:, sequence_length] = input_id[:, 0]
        else:
            with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
                input_id = decode_one_tokens(model, input_id.clone(), cache_position)
                generated_ids.index_copy_(1, cache_position, input_id)
        cache_position += 1

print(tokenizer.batch_decode(generated_ids.long()))
["<s> My favourite condiment is ketchup. I know, I know, it's a bit cliche, but there's just something about the sweet and tangy flavour that I can't get enough of. I put it on everything from fries to scrambled eggs to grilled meats. And let's be real, it's the perfect accompaniment to a good old-fashioned burger and fries.\n\nBut ketchup isn't just delicious"]
@attili-sanjeet

Hi @ArthurZucker, can you provide any kind of support on this issue?

@ArthurZucker
Author

Hey sorry:

        NUM_TOKENS_TO_GENERATE = 40
        EXPECTED_TEXT_COMPLETION = [
            "Simply put, the theory of relativity states that 1) the speed of light is constant, 2) the speed of light is the same for all observers, and 3) the laws of physics are the same for all observers.",
            "My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p",
        ]
        prompts = [
            "Simply put, the theory of relativity states that ",
            "My favorite all time favorite condiment is ketchup.",
        ]
        tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token="</s>", padding_side="right")
        model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", device_map="sequential")
        inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)

        def decode_one_tokens(model, cur_token, input_pos, cache_position):
            logits = model(
                cur_token, position_ids=input_pos, cache_position=cache_position, return_dict=False, use_cache=True
            )[0]
            new_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
            return new_token

        batch_size, seq_length = inputs["input_ids"].shape
        with torch.no_grad():
            model._setup_cache(StaticCache, 2, max_cache_len=4096)
            cache_position = torch.arange(seq_length, device=torch_device)
            generated_ids = torch.zeros(
                batch_size, seq_length + NUM_TOKENS_TO_GENERATE + 1, dtype=torch.int, device=torch_device
            )
            generated_ids[:, cache_position] = inputs["input_ids"].to(torch_device).to(torch.int)

            logits = model(**inputs, cache_position=cache_position, return_dict=False, use_cache=True)[0]
            next_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
            generated_ids[:, seq_length] = next_token[:, 0]

            decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True)
            cache_position = torch.tensor([seq_length + 1], device=torch_device)
            for _ in range(1, NUM_TOKENS_TO_GENERATE):
                with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
                    with CaptureLogger(logging.get_logger(__name__)) as cl:
                        next_token = decode_one_tokens(model, next_token.clone(), None, cache_position)
                        self.assertNotIn("skipping cudagraphs due to", cl.out)
                    generated_ids[:, cache_position] = next_token.int()
                cache_position += 1

        text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

This works because the prompts have the same length. You need to make sure the position ids are passed to the model and incremented.
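
For prompts of different lengths, the usual trick is to derive the position ids from the attention mask so padding does not advance them, and then bump them by one on every decode step together with cache_position. A rough, untested sketch along those lines, reusing the names from the snippet above:

        # Build position ids from the attention mask so pad tokens do not
        # advance the position counter (this mirrors what generate() does
        # when it builds position_ids from the attention mask).
        attention_mask = inputs["attention_mask"]
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)

        # Prefill with explicit position ids.
        logits = model(
            inputs["input_ids"],
            position_ids=position_ids,
            cache_position=torch.arange(seq_length, device=torch_device),
            return_dict=False,
            use_cache=True,
        )[0]
        next_token = torch.argmax(logits[:, -1], dim=-1)[:, None]

        # During decoding, advance the per-sequence position ids and the
        # shared cache position by one on every step.
        next_position_ids = position_ids[:, -1:] + 1
        cache_position = torch.tensor([seq_length], device=torch_device)
        for _ in range(NUM_TOKENS_TO_GENERATE):
            next_token = decode_one_tokens(model, next_token.clone(), next_position_ids, cache_position)
            next_position_ids = next_position_ids + 1
            cache_position += 1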

@ArthurZucker
Author

ArthurZucker commented Apr 25, 2024 via email

@aliencaocao

Hi @ArthurZucker, the link is broken, btw.

@aliencaocao

Yes, thanks

@Luke20000429

Luke20000429 commented May 17, 2024

Thanks for providing this sample! I am building a customized MLP block with Triton 2.2. It looks like my model cannot be processed by torch.compile, although it works well with CUDA Graphs when no cache is used. Could you provide sample code that builds a CUDA graph with a static cache? The error I get is:

torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
LoweringException: AttributeError: 'View' object has no attribute 'layout'
  target: triton_kernel_wrapper_functional
  kwargs: {'kernel_idx': 0, 'grid': [(688, 1, 1), (688, 1, 1), (688, 1, 1), (688, 1, 1), (688, 1, 1), (688, 1, 1), (688, 1, 1), (688, 1, 1), (688, 1, 1), (688, 1, 1), (344, 1, 1), (344, 1, 1), (344, 1, 1), (344, 1, 1), (344, 1, 1), (344, 1, 1), (344, 1, 1), (344, 1, 1), (344, 1, 1), (344, 1, 1), (172, 1, 1), (172, 1, 1), (172, 1, 1), (172, 1, 1), (172, 1, 1), (172, 1, 1), (172, 1, 1), (172, 1, 1), (172, 1, 1), (172, 1, 1)], 'kwargs': {'a_ptr': TensorBox(
    View(
      SliceView(
        StorageBox(
          ComputedBuffer(name='buf19', layout=FixedLayout('cuda', torch.float16, size=[1, 1, 4096], stride=[4096, 4096, 1]), data=Pointwise(
            'cuda',
            torch.float16,
            def inner_fn(index):
  ....

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

Any suggestion on fixing this error would also be helpful!
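
For reference, the kind of manual capture I am after follows the usual torch.cuda.CUDAGraph recipe: warm up on a side stream, capture a single decode step with every input living in a pre-allocated buffer, then copy fresh values into those buffers and replay. A rough, untested sketch of that recipe (buffer names are placeholders, and it assumes the static cache from the script above is already set up):

import torch

# Pre-allocated buffers that keep the same shape on every step.
static_token = next_token.clone()
static_pos = cache_position.clone()

with torch.no_grad():
    # Warm up on a side stream so lazy initialization is not captured.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):
            static_logits = model(static_token, cache_position=static_pos, return_dict=False, use_cache=True)[0]
    torch.cuda.current_stream().wait_stream(s)

    # Capture a single decode step into a CUDA graph.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_logits = model(static_token, cache_position=static_pos, return_dict=False, use_cache=True)[0]

    # Replay: copy fresh inputs into the captured buffers, then launch the graph.
    static_token.copy_(next_token)
    static_pos.copy_(cache_position)
    graph.replay()
    new_token = torch.argmax(static_logits[:, -1], dim=-1)[:, None]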
