from typing import List
import os

import torch
import hidet
from hidet.apps.llm import create_llm
from hidet.apps.llm.sampler import SamplingParams
from hidet.apps.llm.nn.attention import DefaultAttnState
from hidet.apps.llm.tokenizer import Tokenizer
from hidet.apps.llm.modeling.llama import LlamaForCausalLM

hidet.option.cache_dir('./outs/cache')
hidet.option.auth_tokens.for_huggingface(os.environ['HUGGINGFACE_TOKEN'])
def demo_page_attention():
    from hidet.apps.llm import ops

    hidet.option.cache_dir('./outs/pa-cache')
    hidet.option.save_lower_ir()
    hidet.option.parallel_build(False)

    num_heads = 32
    num_kv_heads = 32
    head_size = 128
    num_blocks = 1024
    block_size = 32

    query = hidet.symbol(['bs', num_heads, 1, head_size], dtype='f16', device='cuda')
    seq_lengths = hidet.symbol(['bs'], dtype='int32', device='cuda')  # per-sequence token counts
    cache_blocks = hidet.symbol(['bs', 'max_cache_blocks'], dtype='int32', device='cuda')  # block table
    key_cache = hidet.symbol([num_blocks, num_kv_heads, head_size, block_size], dtype='f16', device='cuda')
    value_cache = hidet.symbol([num_blocks, num_kv_heads, head_size, block_size], dtype='f16', device='cuda')

    output = ops.page_attention(query, seq_lengths, cache_blocks, key_cache, value_cache)
    graph = hidet.trace_from(output, inputs=[query, seq_lengths, cache_blocks, key_cache, value_cache])
    compiled_graph = graph.build()
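    # Illustrative sketch, not part of the original gist: feed concrete tensors to
    # the compiled graph, assuming bs=1 and a one-block table for the sequence.
    q = hidet.randn([1, num_heads, 1, head_size], dtype='f16', device='cuda')
    lengths = hidet.asarray([1], dtype='int32', device='cuda')   # one cached token in the sequence
    blocks = hidet.asarray([[0]], dtype='int32', device='cuda')  # sequence 0 uses cache block 0
    k_cache = hidet.randn([num_blocks, num_kv_heads, head_size, block_size], dtype='f16', device='cuda')
    v_cache = hidet.randn([num_blocks, num_kv_heads, head_size, block_size], dtype='f16', device='cuda')
    out = compiled_graph(q, lengths, blocks, k_cache, v_cache)
    print(out.shape)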
# model_name = 'stas/tiny-random-llama-2'
model_name = 'meta-llama/Llama-2-7b-chat-hf'
def demo_llm():
    llm = create_llm(name=model_name, block_size=32)
    prompt = 'hello, how are you?'
    greedy_sampling = SamplingParams(temperature=0.0)  # temperature 0.0 means greedy decoding
    llm.add_sequence(sequence_id=0, prompt=prompt, sampling_params=greedy_sampling)

    # step until every sequence has left the scheduler
    while len(llm.scheduler.new) + len(llm.scheduler.running) + len(llm.scheduler.waiting) > 0:
        outputs = llm.step()
        for output in outputs:
            print('[{}] {} {}'.format(output.sequence_id, output.prompt_tokens, output.output_tokens))
            if output.is_finished():
                print('{}'.format(output.output_text))
def demo_tokenizer():
    tokenizer = Tokenizer(model_name)
    print(repr(tokenizer.decode([13])))  # token id 13 is the newline byte token in the Llama-2 vocabulary
def demo_default_attn_state():
    os.environ['TOKENIZERS_PARALLELISM'] = 'true'
    prompt = 'hello, world!'
    tokenizer = Tokenizer(model_name)
    model = LlamaForCausalLM.from_pretrained(name=model_name)
    attn_states = [DefaultAttnState(is_prefill=True) for _ in range(model.num_attention_layers())]
    max_tokens = 11

    # prefill: run the whole prompt through the model in one pass
    input_ids_list = [tokenizer.encode(prompt)]
    position_ids_list = [list(range(len(input_ids_list[0])))]
    output = model.forward(
        input_ids=hidet.asarray(input_ids_list, dtype='int32', device='cuda'),
        position_ids=hidet.asarray(position_ids_list, dtype='int32', device='cuda'),
        attn_states=attn_states
    )  # output: [bs, seq_length, hidden_size]
    embedding = model.embedding()  # [hidden_size, vocab_size]
    last_token_output = output[:, -1, :]  # [bs, hidden_size]
    logits = last_token_output @ embedding  # [bs, vocab_size]
    choices: List[int] = torch.argmax(logits, dim=-1).tolist()  # [bs]
    next_token = choices[0]
    print(next_token)
    output_tokens = [next_token]

    # decode: generate one token at a time, reusing the cached attention state
    for state in attn_states:
        state.is_prefill = False
    for _ in range(max_tokens):
        input_ids_list = [[next_token]]
        # continue from the position right after the last processed token
        position_ids_list = [[position_ids_list[0][-1] + 1]]
        output = model.forward(
            input_ids=hidet.asarray(input_ids_list, dtype='int32', device='cuda'),
            position_ids=hidet.asarray(position_ids_list, dtype='int32', device='cuda'),
            attn_states=attn_states
        )
        embedding = model.embedding()  # [hidden_size, vocab_size]
        last_token_output = output[:, -1, :]  # [bs, hidden_size]
        logits = last_token_output @ embedding  # [bs, vocab_size]
        choices: List[int] = torch.argmax(logits, dim=-1).tolist()  # [bs]
        next_token = choices[0]
        print(next_token)
        output_tokens.append(next_token)

    print(tokenizer.decode(output_tokens))
def main():
    # demo_page_attention()
    # demo_default_attn_state()
    # demo_tokenizer()
    demo_llm()


if __name__ == '__main__':
    main()
Before running the script, set the environment variable HUGGINGFACE_TOKEN to your Hugging Face token (you need to apply for access to Llama 2).
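For example, a minimal sketch (the token string below is a placeholder, not a real value), set in Python before the hidet.option.auth_tokens.for_huggingface(...) call at the top of the script:

import os
os.environ['HUGGINGFACE_TOKEN'] = 'hf_...'  # placeholder; substitute your own token

Alternatively, export HUGGINGFACE_TOKEN in the shell before launching the script.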