@lewtun
Last active October 6, 2023 12:57
Dialogue template
# coding=utf-8
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Type, TypeVar, Union
from huggingface_hub import ModelHubMixin, hf_hub_download
# Generic variable that is either ModelHubMixin or a subclass thereof
T = TypeVar("T", bound="ModelHubMixin")
TEMPLATE_FILENAME = "dialogue_template.json"
@dataclass
class DialogueTemplate(ModelHubMixin):
"""Converts all turns of a dialogue between a user and assistant to a standardized format.
Adapted from OpenAI's ChatML (https://github.com/openai/openai-python/blob/main/chatml.md) and FastChat (https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py)
"""
system: str
name: str = "default"
    messages: Optional[List[Dict[str, str]]] = None
system_token: str = "<|system|>"
system_format: Literal["standard", "no_leading_space", "compressed", "empty"] = "standard"
user_token: str = "<|user|>"
assistant_token: str = "<|assistant|>"
end_token: Union[str, list] = "<|end|>"
mid_str: str = "\n"
end_str: Union[str, list] = "\n"
extra_end_text: str = ""
def get_end_token(self, idx):
if isinstance(self.end_token, str):
return self.end_token
else:
return self.end_token[idx % 2]
def get_end_str(self, idx):
if isinstance(self.end_str, str):
return self.end_str
else:
return self.end_str[idx % 2]
    def get_system_prompt(self):
        """Render the system turn according to the template's `system_format`."""
if self.system_format == "standard": # General case
# If the first message is a system message, then we need to add the system token
if self.messages is not None and self.messages[0]["role"] == "system":
prompt = (
self.system_token
+ self.mid_str
+ self.messages[0]["content"]
+ self.get_end_token(0)
+ self.get_end_str(0)
)
else:
prompt = self.system_token + self.mid_str + self.system + self.get_end_token(0) + self.get_end_str(0)
elif self.system_format == "no_leading_space": # Koala / Vicuna
prompt = self.system_token + self.system + self.get_end_token(0) + self.get_end_str(0)
elif self.system_format == "compressed": # Dolly case
prompt = self.system + self.get_end_token(0)
        elif self.system_format == "empty":  # OAsst case
            prompt = ""
        else:
            raise ValueError(f"Unsupported system_format: {self.system_format}")
        return prompt
    def get_training_prompt(self) -> str:
        """Render the full conversation for training, ending after the final assistant turn."""
prompt = self.get_system_prompt()
if self.messages is None:
raise ValueError("Dialogue template must have at least one message.")
for i, message in enumerate(self.messages):
end_token = self.get_end_token(i)
end_str = self.get_end_str(i)
if message["role"] == "system":
continue
elif message["role"] == "user":
prompt += self.user_token + self.mid_str + message["content"] + end_token + end_str
else:
prompt += self.assistant_token + self.mid_str + message["content"] + end_token + end_str
return prompt
    def get_inference_prompt(self) -> str:
        """Render the conversation and append a bare assistant token so generation continues from there."""
prompt = self.get_system_prompt()
if self.messages is None:
raise ValueError("Dialogue template must have at least one message.")
for i, message in enumerate(self.messages):
end_token = self.get_end_token(i)
end_str = self.get_end_str(i)
if message["role"] == "system":
continue
elif message["role"] == "user":
prompt += self.user_token + self.mid_str + message["content"] + end_token + end_str
else:
prompt += self.assistant_token + self.mid_str + message["content"] + end_token + end_str
prompt += self.assistant_token + self.extra_end_text
return prompt
    def get_special_tokens(self) -> List[str]:
        """Return the template's special tokens, excluding strings that are pure text formatting (e.g. whitespace)."""
        end_tokens = self.end_token if isinstance(self.end_token, list) else [self.end_token]
        tokens = [self.system_token, self.user_token, self.assistant_token] + end_tokens
        return [token for token in tokens if token not in ("", " ", "\n\n")]
def copy(self):
return DialogueTemplate(
system=self.system,
name=self.name,
messages=self.messages,
system_token=self.system_token,
system_format=self.system_format,
user_token=self.user_token,
assistant_token=self.assistant_token,
end_token=self.end_token,
mid_str=self.mid_str,
end_str=self.end_str,
extra_end_text=self.extra_end_text,
)
def to_dict(self) -> Dict[str, Any]:
        return asdict(self)
@classmethod
def from_dict(cls, data):
return DialogueTemplate(
system=data.get("system", ""),
name=data.get("name", ""),
messages=data.get("messages", None),
system_token=data.get("system_token", "<|system|>"),
system_format=data.get("system_format", "standard"),
user_token=data.get("user_token", "<|user|>"),
assistant_token=data.get("assistant_token", "<|assistant|>"),
end_token=data.get("end_token", "<|end|>"),
mid_str=data.get("mid_str", "\n"),
end_str=data.get("end_str", "\n"),
extra_end_text=data.get("extra_end_text", ""),
)
def _save_pretrained(self, save_directory: Union[str, Path]) -> None:
save_directory = Path(save_directory)
        save_directory.mkdir(parents=True, exist_ok=True)
        with open(save_directory / TEMPLATE_FILENAME, "w") as f:
json.dump(self.to_dict(), f, indent=2)
@classmethod
def _from_pretrained(
cls: Type[T],
*,
model_id: str,
revision: Optional[str],
cache_dir: Optional[Union[str, Path]],
force_download: bool,
proxies: Optional[Dict],
resume_download: bool,
local_files_only: bool,
token: Optional[Union[str, bool]],
**model_kwargs,
) -> T:
"""Loads the dialogue template from a local directory or the Huggingface Hub.
Args:
model_id (`str`):
                ID of the model to load from the Hugging Face Hub (e.g. `bigscience/bloom`).
revision (`str`, *optional*):
Revision of the model on the Hub. Can be a branch name, a git tag or any commit id. Defaults to the
latest commit on `main` branch.
force_download (`bool`, *optional*, defaults to `False`):
Whether to force (re-)downloading the model weights and configuration files from the Hub, overriding
the existing cache.
resume_download (`bool`, *optional*, defaults to `False`):
Whether to delete incompletely received files. Will attempt to resume the download if such a file exists.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint (e.g., `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`).
token (`str` or `bool`, *optional*):
The token to use as HTTP bearer authorization for remote files. By default, it will use the token
cached when running `huggingface-cli login`.
cache_dir (`str`, `Path`, *optional*):
Path to the folder where cached files are stored.
local_files_only (`bool`, *optional*, defaults to `False`):
If `True`, avoid downloading the file and return the path to the local cached file if it exists.
model_kwargs:
Additional keyword arguments passed along to the [`~ModelHubMixin._from_pretrained`] method.
"""
if os.path.isdir(model_id): # Can either be a local directory
print("Loading dialogue template from local directory")
template_file = os.path.join(model_id, TEMPLATE_FILENAME)
else: # Or a template on the Hub
template_file = hf_hub_download( # Download from the hub, passing same input args
repo_id=model_id,
filename=TEMPLATE_FILENAME,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
token=token,
local_files_only=local_files_only,
)
# Load template
with open(template_file, "r") as f:
data = json.load(f)
return cls.from_dict(data=data)
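

# Illustrative sketch (not part of the original gist): with the default ChatML-style tokens
# defined above, a short exchange renders as shown in the comments below. Handy as a quick
# sanity check when tweaking special tokens or separators.
def _chatml_rendering_example() -> str:
    """Return the training prompt for a tiny conversation using the default template values."""
    example = DialogueTemplate(
        system="You are a helpful assistant.",
        messages=[
            {"role": "user", "content": "Hi!"},
            {"role": "assistant", "content": "Hello! How can I help?"},
        ],
    )
    # Expected rendering:
    # <|system|>\nYou are a helpful assistant.<|end|>\n
    # <|user|>\nHi!<|end|>\n
    # <|assistant|>\nHello! How can I help?<|end|>\n
    return example.get_training_prompt()
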
# Template for Falcon / Starcoder / Pythia models with EOS token used to indicate end of turn
default_v2_template = DialogueTemplate(system="", name="default_v2", end_token="<|endoftext|>", extra_end_text="\n")
alpaca_template = DialogueTemplate(
system="Below is an instruction that describes a task. Write a response that appropriately completes the request.",
name="alpaca",
user_token="### Instruction:",
assistant_token="### Response:",
)
# Sourced from FastChat https://github.com/lm-sys/FastChat/blob/75d8ab26ee308f9cf0990976508232f06dd421e4/fastchat/conversation.py#L234
vicuna_template = DialogueTemplate(
system="A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.",
name="vicuna",
user_token="### USER:",
assistant_token="### ASSISTANT:",
)
# for Vicuna v1.1
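# Note: the end_token / end_str lists below alternate by message index, so (assuming no
# explicit system message inside `messages`) user turns are closed with a plain space and
# assistant turns with "</s>", giving the Vicuna v1.1 "USER: ... ASSISTANT: ...</s>" layout.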
vicuna_template_v2 = DialogueTemplate(
system="A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.",
name="vicuna_v2",
system_token="", # no system token
system_format="no_leading_space",
user_token="USER:",
assistant_token="ASSISTANT:",
end_token=["", "</s>"],
mid_str=" ",
end_str=[" ", ""],
)
koala_template = DialogueTemplate(
system="BEGINNING OF CONVERSATION:",
name="koala",
system_token="",
system_format="no_leading_space",
user_token="USER:",
assistant_token="GPT:",
end_token=["", "</s>"],
mid_str=" ",
end_str=[" ", ""],
)
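
# The bare "### Instruction" / "### Response" headers combined with mid_str=":\n" are meant
# to reproduce the Dolly-style "### Instruction:\n{...}\n\n### Response:\n{...}" layout,
# with "### End" closing the assistant turn (a reading of the fields below, not an official spec).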
dolly_template = DialogueTemplate(
system="Below is an instruction that describes a task. Write a response that appropriately completes the request.",
name="dolly",
system_token="",
system_format="compressed",
user_token="### Instruction",
assistant_token="### Response",
end_token=["\n\n", "### End"],
end_str=["", "\n\n"],
mid_str=":\n",
extra_end_text=":\n",
)
oasst_template = DialogueTemplate(
system="",
name="oasst",
system_token="",
system_format="empty",
user_token="<|prompter|>",
assistant_token="<|assistant|>",
end_token="<|endoftext|>",
mid_str="",
end_str="",
)
SUPPORTED_DIALOGUE_TEMPLATES = {
"default_v2": default_v2_template,
"vicuna": vicuna_template,
"vicuna_v2": vicuna_template_v2,
"koala": koala_template,
"dolly": dolly_template,
"alpaca": alpaca_template,
"oasst": oasst_template,
}
def get_dialogue_template(template: str) -> DialogueTemplate:
    if template not in SUPPORTED_DIALOGUE_TEMPLATES:
        raise ValueError(
            f"Template {template} is not supported! Choose one of: {list(SUPPORTED_DIALOGUE_TEMPLATES)}"
        )
return SUPPORTED_DIALOGUE_TEMPLATES[template].copy()
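

# A minimal usage sketch (illustrative; the example messages and template choices are
# arbitrary): render the same conversation with a few of the supported templates to
# compare their training and inference prompts.
if __name__ == "__main__":
    example_messages = [
        {"role": "user", "content": "What is the capital of France?"},
        {"role": "assistant", "content": "The capital of France is Paris."},
    ]
    for name in ("default_v2", "vicuna_v2", "oasst"):
        dialogue_template = get_dialogue_template(name)
        dialogue_template.messages = example_messages
        print(f"=== {name}: training prompt ===")
        print(dialogue_template.get_training_prompt())
        print(f"=== {name}: inference prompt ===")
        print(dialogue_template.get_inference_prompt())
        print(f"Special tokens: {dialogue_template.get_special_tokens()}")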