Skip to content

Instantly share code, notes, and snippets.

@ssheng
Last active August 31, 2023 18:48
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ssheng/38e59e475f3ac5b0f9299c71f7dc3185 to your computer and use it in GitHub Desktop.
import bentoml
import torch
import typing as t
class LLMRunnable(bentoml.Runnable):
    """BentoML Runnable that wraps a Llama-2 model for streaming generation.

    Loads a tokenizer and model from the local BentoML model store and
    exposes ``generate_iterator``, which samples tokens one at a time and
    yields the decoded text produced so far.
    """

    SUPPORTED_RESOURCES = ("nvidia.com/gpu", "cpu")
    SUPPORTS_CPU_MULTI_THREADING = True

    def __init__(self):
        # Models are resolved by tag from the local BentoML model store.
        self.tokenizer = bentoml.transformers.load_model("llama-2-tokenizer")
        self.model = bentoml.transformers.load_model("llama-2-model")

    @bentoml.Runnable.method()
    def generate_iterator(self, prompt: str) -> t.Generator[t.Dict[str, str], None, None]:
        """Stream text generated from *prompt*, one sampled token per step.

        Yields:
            Dicts of the form ``{"text": <decoded generation so far>}``.
            The decoded text excludes the echoed prompt tokens.
        """
        input_ids = self.tokenizer(prompt).input_ids
        context_length = 4096
        max_new_tokens = 256
        # Left-truncate the prompt so prompt + generation fits in the
        # model's context window (one slot reserved, as in the original).
        max_src_len = context_length - max_new_tokens - 1
        input_ids = input_ids[-max_src_len:]
        output_ids = list(input_ids)
        input_echo_len = len(input_ids)

        past_key_values = None
        token = None
        for step in range(max_new_tokens):
            if step == 0:
                # First step: feed the whole prompt and prime the KV cache.
                out = self.model(torch.as_tensor([input_ids]), use_cache=True)
            else:
                # Later steps: feed only the last sampled token, reusing
                # the cached key/value states.
                out = self.model(
                    input_ids=torch.as_tensor([[token]]),
                    use_cache=True,
                    past_key_values=past_key_values,
                )
            logits = out.logits
            past_key_values = out.past_key_values
            # BUG FIX: the original called torch.sofmax (typo) and never
            # assigned `token`, so every iteration after the first raised
            # NameError. Sample the next token and bind it explicitly.
            probs = torch.softmax(logits[0, -1, :], dim=-1)
            token = int(torch.multinomial(probs, num_samples=1))
            output_ids.append(token)
            yield {
                "text": self.tokenizer.decode(
                    output_ids[input_echo_len:],
                    skip_special_tokens=True,
                    spaces_between_special_tokens=False,
                    clean_up_tokenization_spaces=True,
                )
            }
llm_stream_runner = bentoml.Runner(LLMRunnable)
svc = bentoml.Service("llm-stream-service", runners=[llm_stream_runner])


@svc.api(input=bentoml.io.Text(), output=bentoml.io.Text())
async def generate(prompt: str) -> t.AsyncGenerator[str, None]:
    """Stream generated text for *prompt* through the LLM runner.

    Returns an async stream of decoded-text chunks produced by
    ``LLMRunnable.generate_iterator``.
    """
    # BUG FIX: the original referenced an undefined name `stream_runner`
    # (the runner is bound as `llm_stream_runner`) and a nonexistent
    # runner method `generate` (the Runnable method is `generate_iterator`).
    return llm_stream_runner.generate_iterator.async_stream(prompt)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment