@ArthurZucker
Last active April 25, 2024 14:11
simple static kv cache script
from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache
import torch
from typing import Optional

device = "cuda"

# Copied from the gpt-fast repo
def multinomial_sample_one_no_sync(probs_sort):  # Does multinomial sampling without a cuda synchronization
    q = torch.empty_like(probs_sort).exponential_(1)
    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)

def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    logits = logits / max(temperature, 1e-5)
    if top_k is not None:
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        pivot = v.select(-1, -1).unsqueeze(-1)
        logits = torch.where(logits < pivot, -float("Inf"), logits)
    probs = torch.nn.functional.softmax(logits, dim=-1)
    return probs

def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    probs = logits_to_probs(logits[:, -1], temperature, top_k)
    idx_next = multinomial_sample_one_no_sync(probs)
    return idx_next, probs

def decode_one_tokens(model, cur_token, cache_position):
    logits = model(cur_token, cache_position=cache_position, return_dict=False, use_cache=True)[0]
    new_token = sample(logits, temperature=0.6, top_k=5)[0]
    return new_token

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.bfloat16)
model = model.to(device).eval()
decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
prompt = "My favourite condiment is"
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
batch_size, sequence_length = input_ids.shape

max_cache_length = 2048
max_new_tokens = 100
model._setup_cache(StaticCache, batch_size, max_cache_len=max_cache_length)
generated_ids = torch.zeros((batch_size, max_new_tokens + sequence_length), dtype=torch.int, device=device)
generated_ids[:, :sequence_length] = input_ids
cache_position = torch.tensor([sequence_length], device=device)

with torch.no_grad():
    for i in range(max_new_tokens):
        if i == 0:  # prefill uses vanilla model
            logits = model(input_ids, cache_position=torch.arange(sequence_length, device=device))[0]
            input_id = sample(logits, temperature=0.6, top_k=5)[0]
            generated_ids[:, sequence_length] = input_id[:, 0]
        else:
            with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
                input_id = decode_one_tokens(model, input_id.clone(), cache_position)
                generated_ids.index_copy_(1, cache_position, input_id)
        cache_position += 1

print(tokenizer.batch_decode(generated_ids.long()))
["<s> My favourite condiment is ketchup. I know, I know, it's a bit cliche, but there's just something about the sweet and tangy flavour that I can't get enough of. I put it on everything from fries to scrambled eggs to grilled meats. And let's be real, it's the perfect accompaniment to a good old-fashioned burger and fries.\n\nBut ketchup isn't just delicious"]
@conway-abacus

@ArthurZucker nice script! This doesn't work out of the box for Mistral, right? Any idea how much work is needed to support it?

@ArthurZucker

I don't think it needs too many changes. You can create an issue on transformers; it's a good difficult issue, but it should be straightforward given the changes done to Llama.
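
For reference, assuming Mistral gets the same StaticCache support that Llama already has, the script itself should only need a different checkpoint. A minimal sketch (the checkpoint name is just an example, not tested):

# Sketch: swap in a Mistral checkpoint, assuming its modeling code supports
# StaticCache and cache_position the same way Llama does.
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

checkpoint = "mistralai/Mistral-7B-Instruct-v0.2"  # example checkpoint
model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16).to("cuda").eval()
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
# the sampling helpers, _setup_cache call and compiled decode loop from the gist stay unchanged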

@conway-abacus

conway-abacus commented Feb 18, 2024

Got it, thanks for the response!

I also see some strange behavior after compiling/running for the first time. If I change the prompt and rerun the code from the prompt= "..." line, the generation looks like it is using the cache state from the previous prompt. I don't see this when using the non-compiled decode_one_tokens().

Specific example: after running with the original prompt, run with

prompt = "List all numbers in the Fibonacci sequence: 1, 1, 2, 3, "
...
["<s> List all numbers in the Fibonacci sequence: 1, 1, 2, 3, 5 can be used in many ways.\nI love to use it as a relish for my sandwiches, or as a topping for my toast.\nIt's great on cheese and crackers, and it's also delicious as a dip for your vegies.\nYou can use it as a condiment for meat dishes or as a marinade.\nIt's also a great accompaniment to your curry dishes, or even as"]

So the first token 5 is good (from prefill) but the rest is a continuation of the first prompt + first token. Is this a torch.compile() thing or something I'm doing wrong with the cache? model._reset_cache() also doesn't help here (although it should not matter because we are doing model._setup_cache() each time).

@attili-sanjeet

Hi @ArthurZucker, I have a question about running inference on sequences one at a time.
The code above works with batches, but how can I make it run sequentially?
With the code below I was getting the same output for both texts. Can you help me out?

import time
import torch
from typing import Optional
from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache

#######################################
        # Model loading section #
#######################################

device = "cuda"
model_name = "meta-llama/Llama-2-7b-chat-hf"
# quant_config = BitsAndBytesConfig(load_in_8bit=True)
model = AutoModelForCausalLM.from_pretrained(
        model_name, device_map="auto",
        torch_dtype=torch.float16,
        # attn_implementation="flash_attention_2"
        )

model = model.to(device).eval()
tokenizer = AutoTokenizer.from_pretrained(model_name, token="<your_hf_token>")

# Copied from the gpt-fast repo
def multinomial_sample_one_no_sync(probs_sort): # Does multinomial sampling without a cuda synchronization
    q = torch.empty_like(probs_sort).exponential_(1)
    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)

def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    logits = logits / max(temperature, 1e-5)

    if top_k is not None:
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        pivot = v.select(-1, -1).unsqueeze(-1)
        logits = torch.where(logits < pivot, -float("Inf"), logits)
    probs = torch.nn.functional.softmax(logits, dim=-1)
    return probs

def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    probs = logits_to_probs(logits[:, -1], temperature, top_k)
    idx_next = multinomial_sample_one_no_sync(probs)
    return idx_next, probs

def decode_one_tokens(model, cur_token, cache_position):
    logits = model(cur_token, cache_position=cache_position, return_dict=False, use_cache = False)[0]
    new_token = sample(logits,temperature=0.6, top_k=5)[0]
    return new_token

#######################################
     # Model inference section #
#######################################

tot_time, tot_tokens = 0, 0
length = 2 #Defining 2 samples to test
results=[]
max_cache_length = 2048
max_new_tokens, batch_size= 100, 1
model._setup_cache(StaticCache, batch_size, max_cache_len=max_cache_length)
decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead",fullgraph=True)
texts = []
for idx, text in enumerate(["Can you suggest me a joke?", "give me top three rap songs"]):
    start = time.time()
    device='cuda'
    input_ids = tokenizer.encode(text, return_tensors="pt").to(device)
    batch_size, sequence_length = input_ids.shape
    generated_ids = torch.zeros((batch_size, max_new_tokens+sequence_length), dtype = torch.int, device=device)
    generated_ids[:,:sequence_length] = input_ids
    cache_position = torch.tensor([sequence_length], device=device)
    with torch.no_grad():
        for i in range(max_new_tokens):
            if i == 0: # prefill uses vanilla model
                logits = model(input_ids, cache_position=torch.arange(sequence_length, device=device))[0]
                input_id = sample(logits, temperature=0.6, top_k=5)[0]
                generated_ids[:,sequence_length] = input_id[:,0]
            else:
                with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
                    input_id = decode_one_tokens(model, input_id.clone(), cache_position)
                    generated_ids.index_copy_(1, cache_position, input_id)
                    # print(input_id[0][0].item(), tokenizer.decode(input_id[0][0]))
                    if input_id[0][0].item() == 2:
                        # We got </s> EOS token as output
                        break
            cache_position += 1
    # print("New cache value -> ", cache_position)
    print(tokenizer.batch_decode(generated_ids[:, :sequence_length]))
    results += tokenizer.batch_decode(generated_ids[:, sequence_length:cache_position].long())
    # model._reset_cache()
    # print(results)
    tot_time += round(time.time() - start, 3)
    tot_tokens+= cache_position.item() - sequence_length
print(f"Total time model took -> {tot_time/length} secs for {tot_tokens//length} tokens generated")

@ArthurZucker

ArthurZucker commented Feb 23, 2024

You need to call _setup_cache and _reset_cache for each sentence.
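
A minimal sketch of what that could look like in the loop above (a sketch using the same private helpers, not an official API):

# Sketch: re-initialise the static cache for every prompt so the compiled
# decode step never reuses key/value entries from the previous sentence.
for text in ["Can you suggest me a joke?", "give me top three rap songs"]:
    input_ids = tokenizer.encode(text, return_tensors="pt").to(device)
    batch_size, sequence_length = input_ids.shape
    model._setup_cache(StaticCache, batch_size, max_cache_len=max_cache_length)
    # ... prefill + compiled decode loop exactly as above ...
    model._reset_cache()  # clear the cache before the next sentence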

@attili-sanjeet

I tried it the way you suggested, but I was getting the same answer for both texts.

Responses: ["19th century?\n\nI am looking for a joke that would be considered appropriate and funny in the 19th century. Do you have any suggestions?\n\nAnswer:\n\nSure! Here's a joke from the 19th century that might fit the bill:\n\nWhy did the chicken cross the road?\n\nTo get to the other side... of the pond!\n\nThis joke plays on the common phrase",
 "of the day?\n\nI'd be happy to! Here is a joke for you:\n\nWhy couldn't the bicycle stand up by itself?\n\nBecause it was two-tired!\n\nI hope you found that joke amusing! Let me know if you would like another one."]

Where the inputs for these responses were:

["Can you suggest me a joke?", "give me top three rap songs"]

@attili-sanjeet

Hi @ArthurZucker, could you provide any kind of support on this issue?

@ArthurZucker

Hey sorry:

        NUM_TOKENS_TO_GENERATE = 40
        EXPECTED_TEXT_COMPLETION = [
            "Simply put, the theory of relativity states that 1) the speed of light is constant, 2) the speed of light is the same for all observers, and 3) the laws of physics are the same for all observers.",
            "My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p",
        ]
        prompts = [
            "Simply put, the theory of relativity states that ",
            "My favorite all time favorite condiment is ketchup.",
        ]
        tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token="</s>", padding_side="right")
        model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", device_map="sequential")
        inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)

        def decode_one_tokens(model, cur_token, input_pos, cache_position):
            logits = model(
                cur_token, position_ids=input_pos, cache_position=cache_position, return_dict=False, use_cache=True
            )[0]
            new_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
            return new_token

        batch_size, seq_length = inputs["input_ids"].shape
        with torch.no_grad():
            model._setup_cache(StaticCache, 2, max_cache_len=4096)
            cache_position = torch.arange(seq_length, device=torch_device)
            generated_ids = torch.zeros(
                batch_size, seq_length + NUM_TOKENS_TO_GENERATE + 1, dtype=torch.int, device=torch_device
            )
            generated_ids[:, cache_position] = inputs["input_ids"].to(torch_device).to(torch.int)

            logits = model(**inputs, cache_position=cache_position, return_dict=False, use_cache=True)[0]
            next_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
            generated_ids[:, seq_length] = next_token[:, 0]

            decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True)
            cache_position = torch.tensor([seq_length + 1], device=torch_device)
            for _ in range(1, NUM_TOKENS_TO_GENERATE):
                with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
                    with CaptureLogger(logging.get_logger(__name__)) as cl:
                        next_token = decode_one_tokens(model, next_token.clone(), None, cache_position)
                        self.assertNotIn("skipping cudagraphs due to", cl.out)
                    generated_ids[:, cache_position] = next_token.int()
                cache_position += 1

        text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

This works because the prompts have the same length. You need to make sure the position ids are passed to the model and incremented.
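
A rough sketch of that last point (an illustration of the idea, not part of the test above): pass explicit position ids into decode_one_tokens and advance them every step.

# Sketch: per-sequence position ids advanced in lock-step with cache_position.
# Reuses the variables from the snippet above and assumes same-length prompts,
# mirroring its cache_position convention.
input_pos = torch.full((batch_size, 1), seq_length, device=torch_device)  # position of the token being fed
cache_position = torch.tensor([seq_length + 1], device=torch_device)
for _ in range(1, NUM_TOKENS_TO_GENERATE):
    with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
        next_token = decode_one_tokens(model, next_token.clone(), input_pos, cache_position)
        generated_ids[:, cache_position] = next_token.int()
    input_pos += 1        # increment the position ids ...
    cache_position += 1   # ... together with the cache position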

@ArthurZucker

ArthurZucker commented Apr 25, 2024 via email

@aliencaocao

Hi @ArthurZucker the link is broken btw

@aliencaocao

Yes, thanks
