Skip to content

Instantly share code, notes, and snippets.

@nenkoru
Created July 16, 2023 12:02
Show Gist options
  • Save nenkoru/ad94d1ed28e9fbb0e0598ac459dc7750 to your computer and use it in GitHub Desktop.
Save nenkoru/ad94d1ed28e9fbb0e0598ac459dc7750 to your computer and use it in GitHub Desktop.
Starcoder for lmi
from typing import Union, Iterable, List
from collections.abc import Iterable as abc_Iterable
import transformers
import ctranslate2
class CT2Generator:
    """Implements LMIProtocol on top of a CTranslate2 generator.

    Wraps a ``ctranslate2.Generator`` and a Hugging Face tokenizer so that
    callers can hand in plain prompt strings and get decoded text back.
    """

    def __init__(
        self,
        *,
        generator: "ctranslate2.Generator",
        tokenizer: "transformers.AutoTokenizer",
    ):
        # Keyword-only on purpose: both dependencies are easy to swap in tests.
        self._generator = generator
        self._tokenizer = tokenizer

    def generate(
        self,
        *args,
        inputs: Union[str, Iterable[str]],
        parameters: "lmi.GenerationParameters",
        **kwargs,
    ) -> List[str]:
        """Generate completions for one prompt or a batch of prompts.

        Args:
            inputs: A single prompt string or an iterable of prompt strings.
            parameters: Generation settings; ``max_tokens``, ``return_prompt``,
                ``top_p`` and ``top_k`` are read from it.

        Returns:
            A list of decoded completion strings, one per prompt — a list is
            returned even for a single-string input.
        """
        # BUG FIX: ``str`` is itself Iterable, so the original
        # ``not isinstance(inputs, abc_Iterable)`` check let a lone prompt
        # string through unwrapped and every *character* was then encoded as
        # its own prompt.  Wrap strings explicitly; keep the non-iterable
        # fallback for backward compatibility with other scalar inputs.
        if isinstance(inputs, str) or not isinstance(inputs, abc_Iterable):
            inputs = [inputs]

        encoded_prompts = [self._tokenizer.encode(prompt) for prompt in inputs]
        # CTranslate2 consumes token *strings*, not ids, hence the conversion.
        tokens = [
            self._tokenizer.convert_ids_to_tokens(ids)
            for ids in encoded_prompts
        ]
        results = self._generator.generate_batch(
            tokens,
            max_length=parameters.max_tokens,
            include_prompt_in_result=parameters.return_prompt,
            top_p=parameters.top_p,
            top_k=parameters.top_k,
        )
        # Only the first hypothesis per prompt is kept (greedy/top-1 usage).
        return [
            self._tokenizer.decode(result.sequences_ids[0])
            for result in results
        ]
# Quantized StarCoder+ converted to the CTranslate2 format, loaded onto GPU.
_ct2_generator = ctranslate2.Generator(
    "starcoderplus_ct2_int8",
    device="cuda",
    compute_type="int8",
)

# Tokenizer from the matching original Hugging Face checkpoint.
_tokenizer = transformers.AutoTokenizer.from_pretrained("bigcode/starcoderplus")

# Module-level singleton exposed to callers of this gist.
generator = CT2Generator(generator=_ct2_generator, tokenizer=_tokenizer)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment