Skip to content

Instantly share code, notes, and snippets.

@the-crypt-keeper
Last active September 27, 2023 01:55
Show Gist options
  • Star 3 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save the-crypt-keeper/8d781a12ee515903edc89ef69383570f to your computer and use it in GitHub Desktop.
Save the-crypt-keeper/8d781a12ee515903edc89ef69383570f to your computer and use it in GitHub Desktop.
Llama 2 chat prompt format: reverse engineering
#
# this is adapted from https://github.com/facebookresearch/llama/blob/main/llama/generation.py#L213
# the tokenizer is replaced with ord() to make it easier to see what's actually happening
from typing_extensions import TypedDict, Literal
from typing import List, Optional
# Speaker roles understood by the chat prompt builder.
Role = Literal["system", "user", "assistant"]

# A single chat turn: the speaker's role plus the text of the turn.
Message = TypedDict(
    "Message",
    {
        "role": Role,
        "content": str,
    },
)

# Result record types carried over from the upstream module this file adapts.
# Every key is optional (total=False); only "generation" differs between them.
CompletionPrediction = TypedDict(
    "CompletionPrediction",
    {
        "generation": str,
        "tokens": List[str],
        "logprobs": List[float],
    },
    total=False,
)
ChatPrediction = TypedDict(
    "ChatPrediction",
    {
        "generation": Message,
        "tokens": List[str],
        "logprobs": List[float],
    },
    total=False,
)

# A whole conversation, oldest turn first.
Dialog = List[Message]
# Llama 2 chat control markers. chat_completion wraps each user turn as
# "[INST] ... [/INST]" and wraps the system prompt in "<<SYS>>\n ... \n<</SYS>>\n\n"
# before prepending it to the first non-system message.
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
# System prompt that chat_completion inserts when a dialog does not begin with
# a "system" message. The backslash right after the opening quotes drops the
# leading newline, so the text starts at "You are ...".
# NOTE(review): the upstream Meta reference appears to separate the two
# paragraphs with a blank line ("\n\n"); this copy has a single "\n" between
# them -- confirm which is intended, since it changes the encoded prompt.
DEFAULT_SYSTEM_PROMPT = """\
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""
def encode(str, bos, eos):
    """Stand-in tokenizer: spell the text out as Unicode code points.

    Every character becomes ``ord(character)``, so the result can be turned
    back into readable text with ``chr``. When *bos* is truthy the literal
    marker ``<s>`` is placed in front; when *eos* is truthy ``</s>`` is
    appended at the end.

    Note: the first parameter is named ``str`` (kept for caller
    compatibility), which shadows the builtin inside this function.
    """
    prefix = '<s>' if bos else ''
    suffix = '</s>' if eos else ''
    return list(map(ord, prefix + str + suffix))
def _merge_system_prompt(dialog: Dialog) -> Dialog:
    """Return a copy of *dialog* with the system prompt folded into turn one.

    A leading ``system`` message is used if present; otherwise
    ``DEFAULT_SYSTEM_PROMPT`` is inserted. The system text is wrapped in
    ``B_SYS``/``E_SYS`` and prepended to the following message, which keeps
    its own role. The caller's list is not modified.
    """
    messages = list(dialog)
    if not messages or messages[0]["role"] != "system":
        messages = [{"role": "system", "content": DEFAULT_SYSTEM_PROMPT}] + messages
    # Without this check an empty or system-only dialog fails with a bare
    # IndexError on messages[1] instead of an explanatory message.
    assert len(messages) >= 2, (
        "dialog must contain at least one user message after the optional "
        "system message"
    )
    system, first = messages[0], messages[1]
    merged = {
        "role": first["role"],
        "content": B_SYS + system["content"] + E_SYS + first["content"],
    }
    return [merged] + messages[2:]


def _encode_dialog(dialog: Dialog) -> List[int]:
    """Encode a dialog whose system prompt is already merged into flat token ids.

    Each completed user/assistant exchange becomes
    ``<s>[INST] user [/INST] answer </s>``. The trailing user message becomes
    ``<s>[INST] user [/INST]`` with no end-of-sequence marker, which leaves
    the reply open. Message contents are stripped of surrounding whitespace.
    """
    assert all(msg["role"] == "user" for msg in dialog[::2]) and all(
        msg["role"] == "assistant" for msg in dialog[1::2]
    ), (
        "model only supports 'system', 'user' and 'assistant' roles, "
        "starting with 'system', then 'user' and alternating (u/a/u/a/u...)"
    )
    tokens: List[int] = []
    for prompt, answer in zip(dialog[::2], dialog[1::2]):
        # extend() keeps this linear; sum(list_of_lists, []) re-copies the
        # accumulated list on every exchange (quadratic in dialog length).
        tokens.extend(
            encode(
                f"{B_INST} {prompt['content'].strip()} {E_INST} {answer['content'].strip()} ",
                bos=True,
                eos=True,
            )
        )
    assert (
        dialog[-1]["role"] == "user"
    ), f"Last message must be from user, got {dialog[-1]['role']}"
    tokens += encode(
        f"{B_INST} {dialog[-1]['content'].strip()} {E_INST}",
        bos=True,
        eos=False,
    )
    return tokens


def chat_completion(
    dialogs: List[Dialog],
    temperature: float = 0.6,
    top_p: float = 0.9,
    max_gen_len: Optional[int] = None,
    logprobs: bool = False,
) -> List[List[int]]:
    """Build the Llama 2 chat prompt for each dialog and return its token ids.

    Only the prompt-construction half of the upstream ``chat_completion`` is
    reproduced here; nothing is generated. ``temperature``, ``top_p``,
    ``max_gen_len`` and ``logprobs`` are accepted for signature parity with
    the upstream function but are not used.

    Args:
        dialogs: one message list per prompt. A dialog may begin with a
            ``system`` message; otherwise ``DEFAULT_SYSTEM_PROMPT`` is used.
            After that, roles must alternate user/assistant, and the last
            message must come from the user.

    Returns:
        One list of ids per dialog, as produced by ``encode``. With the
        ord()-based ``encode`` in this file, ``''.join(map(chr, ids))``
        recovers the prompt text. (The upstream annotation said
        ``List[ChatPrediction]``, which did not match what is returned.)

    Raises:
        AssertionError: if a dialog violates the role rules above (these
            checks are skipped when Python runs with ``-O``).
    """
    return [_encode_dialog(_merge_system_prompt(dialog)) for dialog in dialogs]
# Demo 1: a single user turn, so the default system prompt gets injected.
d1 = [Message(role="user", content="<prompt>")]
p1 = chat_completion([d1])
print("".join(map(chr, p1[0])))

# Demo 2: user / assistant / user -- the finished exchange is closed with </s>,
# the final user turn is left open for the model's reply.
d2 = [
    Message(role="user", content="<prompt>"),
    Message(role="assistant", content="<answer>"),
    Message(role="user", content="<prompt-second>"),
]
p2 = chat_completion([d2])
print("".join(map(chr, p2[0])))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment