@yaoyaoding
Created April 10, 2024 18:37
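# Demo scripts for hidet's LLM app API: tracing and building the paged-attention op,
# end-to-end generation through the scheduler, the tokenizer, and a manual
# prefill/decode loop using DefaultAttnState.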
from typing import List
import os
import torch
import hidet
from hidet.apps.llm import create_llm
from hidet.apps.llm.sampler import SamplingParams
from hidet.apps.llm.nn.attention import DefaultAttnState
from hidet.apps.llm.tokenizer import Tokenizer
from hidet.apps.llm.modeling.llama import LlamaForCausalLM
hidet.option.cache_dir('./outs/cache')
hidet.option.auth_tokens.for_huggingface(os.environ['HUGGINGFACE_TOKEN'])

def demo_page_attention():
    from hidet.apps.llm import ops

    hidet.option.cache_dir('./outs/pa-cache')
    hidet.option.save_lower_ir()
    hidet.option.parallel_build(False)

    num_heads = 32
    num_kv_heads = 32
    head_size = 128
    num_blocks = 1024
    block_size = 32

    # build symbolic inputs for the paged-attention operator
    query = hidet.symbol(['bs', num_heads, 1, head_size], dtype='f16', device='cuda')
    # sequence lengths and the block table are index tensors, so they use an integer dtype
    seq_lengths = hidet.symbol(['bs'], dtype='int32', device='cuda')
    cache_blocks = hidet.symbol(['bs', 'max_cache_blocks'], dtype='int32', device='cuda')
    key_cache = hidet.symbol([num_blocks, num_kv_heads, head_size, block_size], dtype='f16', device='cuda')
    value_cache = hidet.symbol([num_blocks, num_kv_heads, head_size, block_size], dtype='f16', device='cuda')

    output = ops.page_attention(query, seq_lengths, cache_blocks, key_cache, value_cache)
    graph = hidet.trace_from(output, inputs=[query, seq_lengths, cache_blocks, key_cache, value_cache])
    compiled_graph = graph.build()

# model_name = 'stas/tiny-random-llama-2'
model_name = 'meta-llama/Llama-2-7b-chat-hf'

def demo_llm():
    llm = create_llm(name=model_name, block_size=32)

    prompt = 'hello, how are you?'
    greedy_sampling = SamplingParams(temperature=0.0)
    llm.add_sequence(sequence_id=0, prompt=prompt, sampling_params=greedy_sampling)

    # step until the scheduler has no new, running, or waiting sequences left
    while len(llm.scheduler.new) + len(llm.scheduler.running) + len(llm.scheduler.waiting) > 0:
        outputs = llm.step()
        for output in outputs:
            print('[{}] {} {}'.format(output.sequence_id, output.prompt_tokens, output.output_tokens))
            if output.is_finished():
                print('{}'.format(output.output_text))

def demo_tokenizer():
    tokenizer = Tokenizer(model_name)
    print(repr(tokenizer.decode([13])))

def demo_default_attn_state():
    os.environ['TOKENIZERS_PARALLELISM'] = 'true'

    prompt = 'hello, world!'
    tokenizer = Tokenizer(model_name)
    model = LlamaForCausalLM.from_pretrained(name=model_name)
    attn_states = [DefaultAttnState(is_prefill=True) for _ in range(model.num_attention_layers())]
    max_tokens = 11

    # prefill: run the whole prompt through the model once
    input_ids_list = [tokenizer.encode(prompt)]
    position_ids_list = [list(range(len(input_ids_list[0])))]
    output = model.forward(
        input_ids=hidet.asarray(input_ids_list, dtype='int32', device='cuda'),
        position_ids=hidet.asarray(position_ids_list, dtype='int32', device='cuda'),
        attn_states=attn_states
    )  # output: [bs, seq_length, hidden_size]
    embedding = model.embedding()  # [hidden_size, vocab_size]
    last_token_output = output[:, -1, :]  # [bs, hidden_size]
    logits = last_token_output @ embedding  # [bs, vocab_size]
    choices: List[int] = torch.argmax(logits, dim=-1).tolist()  # [bs]
    next_token = choices[0]
    print(next_token)
    output_tokens = [next_token]

    # decode: generate one token at a time, reusing the cached attention states
    for state in attn_states:
        state.is_prefill = False
    for _ in range(max_tokens):
        input_ids_list = [[next_token]]
        position_ids_list = [[position_ids_list[0][-1] + 1]]  # continue from the last position
        output = model.forward(
            input_ids=hidet.asarray(input_ids_list, dtype='int32', device='cuda'),
            position_ids=hidet.asarray(position_ids_list, dtype='int32', device='cuda'),
            attn_states=attn_states
        )
        embedding = model.embedding()  # [hidden_size, vocab_size]
        last_token_output = output[:, -1, :]  # [bs, hidden_size]
        logits = last_token_output @ embedding  # [bs, vocab_size]
        choices: List[int] = torch.argmax(logits, dim=-1).tolist()  # [bs]
        next_token = choices[0]
        print(next_token)
        output_tokens.append(next_token)

    print(tokenizer.decode(output_tokens))

def main():
    # demo_page_attention()
    # demo_default_attn_state()
    # demo_tokenizer()
    demo_llm()


if __name__ == '__main__':
    main()
@yaoyaoding (Author) commented:
Before running the script, set the environment variable HUGGINGFACE_TOKEN to your Hugging Face access token (you need to apply for access to Llama-2).
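For example, a minimal way to provide the token (the token string below is a placeholder; substitute your own, or export the variable in your shell before launching):

# Minimal sketch: set the token in-process, before the hidet option calls at the
# top of the script read os.environ['HUGGINGFACE_TOKEN'].
import os

os.environ['HUGGINGFACE_TOKEN'] = 'hf_xxx'  # placeholder token; use your own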
