Last active
October 6, 2023 12:57
-
-
Save lewtun/51ea3ad9a86c73caa4e3847f5af512e6 to your computer and use it in GitHub Desktop.
Dialogue template
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# coding=utf-8 | |
# Copyright 2023 The HuggingFace Team. All rights reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import json | |
import os | |
from dataclasses import asdict, dataclass | |
from pathlib import Path | |
from typing import Any, Dict, List, Literal, Optional, Type, TypeVar, Union | |
from h4.utils import IGNORE_INDEX | |
from huggingface_hub import ModelHubMixin, hf_hub_download | |
# Generic variable that is either ModelHubMixin or a subclass thereof | |
T = TypeVar("T", bound="ModelHubMixin") | |
TEMPLATE_FILENAME = "dialogue_template.json" | |
@dataclass
class DialogueTemplate(ModelHubMixin):
    """Converts all turns of a dialogue between a user and assistant to a standardized format.

    Adapted from OpenAI's ChatML (https://github.com/openai/openai-python/blob/main/chatml.md)
    and FastChat (https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py)

    Attributes:
        system: Default system message, used when no system turn is present in ``messages``.
        name: Identifier of the template (e.g. "alpaca", "vicuna").
        messages: Conversation turns as dicts with "role" and "content" keys, or None.
        system_token: Marker prepended to the system block.
        system_format: How the system block is rendered (see ``get_system_prompt``).
        user_token: Marker prepended to user turns.
        assistant_token: Marker prepended to assistant turns.
        end_token: Token appended after each turn's content. A 2-element list alternates
            by message index parity (``idx % 2``).
        mid_str: Separator between a role token and the turn's content.
        end_str: String appended after the end token; a 2-element list alternates like
            ``end_token``.
        extra_end_text: Text appended after the trailing assistant token in inference prompts.
    """

    system: str
    name: str = "default"
    # Annotation fixed: the default is None, so the type must be Optional.
    messages: Optional[List[Dict[str, str]]] = None
    system_token: str = "<|system|>"
    system_format: Literal["standard", "no_leading_space", "compressed", "empty"] = "standard"
    user_token: str = "<|user|>"
    assistant_token: str = "<|assistant|>"
    end_token: Union[str, list] = "<|end|>"
    mid_str: str = "\n"
    end_str: Union[str, list] = "\n"
    extra_end_text: str = ""

    def get_end_token(self, idx: int) -> str:
        """Return the end token for message index ``idx`` (alternating by parity if a list)."""
        if isinstance(self.end_token, str):
            return self.end_token
        return self.end_token[idx % 2]

    def get_end_str(self, idx: int) -> str:
        """Return the end string for message index ``idx`` (alternating by parity if a list)."""
        if isinstance(self.end_str, str):
            return self.end_str
        return self.end_str[idx % 2]

    def get_system_prompt(self) -> str:
        """Render the system portion of the prompt according to ``system_format``.

        Returns:
            The rendered system block (possibly empty).

        Raises:
            ValueError: If ``system_format`` is not one of the supported formats
                (the original code raised ``UnboundLocalError`` here).
        """
        if self.system_format == "standard":  # General case
            # If the first message is a system message, its content overrides `self.system`.
            if self.messages is not None and self.messages[0]["role"] == "system":
                content = self.messages[0]["content"]
            else:
                content = self.system
            return self.system_token + self.mid_str + content + self.get_end_token(0) + self.get_end_str(0)
        if self.system_format == "no_leading_space":  # Koala / Vicuna
            return self.system_token + self.system + self.get_end_token(0) + self.get_end_str(0)
        if self.system_format == "compressed":  # Dolly case
            return self.system + self.get_end_token(0)
        if self.system_format == "empty":  # OAsst case
            return ""
        raise ValueError(f"Unsupported system_format: {self.system_format!r}")

    def _render_turns(self, prompt: str) -> str:
        """Append every non-system turn of ``messages`` to ``prompt`` and return the result."""
        for i, message in enumerate(self.messages):
            if message["role"] == "system":
                continue  # system content is handled by get_system_prompt()
            role_token = self.user_token if message["role"] == "user" else self.assistant_token
            prompt += role_token + self.mid_str + message["content"] + self.get_end_token(i) + self.get_end_str(i)
        return prompt

    def get_training_prompt(self) -> str:
        """Render the full dialogue (system block plus all turns) as a training prompt.

        Raises:
            ValueError: If ``messages`` is None.
        """
        if self.messages is None:
            raise ValueError("Dialogue template must have at least one message.")
        return self._render_turns(self.get_system_prompt())

    def get_inference_prompt(self) -> str:
        """Render the dialogue, terminating with the assistant token so generation continues it.

        Raises:
            ValueError: If ``messages`` is None.
        """
        if self.messages is None:
            raise ValueError("Dialogue template must have at least one message.")
        prompt = self._render_turns(self.get_system_prompt())
        return prompt + self.assistant_token + self.extra_end_text

    def get_special_tokens(self) -> List[str]:
        """
        Helper function to get the special tokens of the dialogue template and remove text formatting (e.g. not tokens).

        NOTE: the original implementation used ``list.remove``, which strips only the FIRST
        occurrence of each formatting string — templates with a duplicate "" (e.g. vicuna_v2,
        koala) leaked an empty string into the result. This removes all occurrences.
        """
        if isinstance(self.end_token, list):
            end_tokens = list(self.end_token)
        else:
            end_tokens = [self.end_token]
        candidates = [self.system_token, self.user_token, self.assistant_token, *end_tokens]
        # Drop plain formatting strings — they are text layout, not vocabulary tokens.
        return [tok for tok in candidates if tok not in ("", " ", "\n\n")]

    def copy(self) -> "DialogueTemplate":
        """Return a shallow copy (the ``messages`` list is shared, matching the original)."""
        return DialogueTemplate(
            system=self.system,
            name=self.name,
            messages=self.messages,
            system_token=self.system_token,
            system_format=self.system_format,
            user_token=self.user_token,
            assistant_token=self.assistant_token,
            end_token=self.end_token,
            mid_str=self.mid_str,
            end_str=self.end_str,
            extra_end_text=self.extra_end_text,
        )

    def to_dict(self) -> Dict[str, Any]:
        """Return the template's fields as a plain dict (suitable for JSON serialization)."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "DialogueTemplate":
        """Build a template from a dict, falling back to defaults for missing keys.

        Uses ``cls`` (not a hardcoded class name) so subclasses round-trip correctly.
        """
        return cls(
            system=data.get("system", ""),
            name=data.get("name", ""),
            messages=data.get("messages", None),
            system_token=data.get("system_token", "<|system|>"),
            system_format=data.get("system_format", "standard"),
            user_token=data.get("user_token", "<|user|>"),
            assistant_token=data.get("assistant_token", "<|assistant|>"),
            end_token=data.get("end_token", "<|end|>"),
            mid_str=data.get("mid_str", "\n"),
            end_str=data.get("end_str", "\n"),
            extra_end_text=data.get("extra_end_text", ""),
        )

    def _save_pretrained(self, save_directory: Union[str, Path]) -> None:
        """Serialize the template as JSON to ``save_directory / TEMPLATE_FILENAME``."""
        save_directory = Path(save_directory)
        # parents=True also creates missing intermediate directories.
        save_directory.mkdir(parents=True, exist_ok=True)
        # Use the module-level constant instead of repeating the literal filename.
        with open(save_directory / TEMPLATE_FILENAME, "w") as f:
            json.dump(self.to_dict(), f, indent=2)

    @classmethod
    def _from_pretrained(
        cls: Type[T],
        *,
        model_id: str,
        revision: Optional[str],
        cache_dir: Optional[Union[str, Path]],
        force_download: bool,
        proxies: Optional[Dict],
        resume_download: bool,
        local_files_only: bool,
        token: Optional[Union[str, bool]],
        **model_kwargs,
    ) -> T:
        """Loads the dialogue template from a local directory or the Huggingface Hub.

        Args:
            model_id (`str`):
                ID of the model to load from the Huggingface Hub (e.g. `bigscience/bloom`).
            revision (`str`, *optional*):
                Revision of the model on the Hub. Can be a branch name, a git tag or any commit id. Defaults to the
                latest commit on `main` branch.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether to force (re-)downloading the model weights and configuration files from the Hub, overriding
                the existing cache.
            resume_download (`bool`, *optional*, defaults to `False`):
                Whether to delete incompletely received files. Will attempt to resume the download if such a file exists.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint (e.g., `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`).
            token (`str` or `bool`, *optional*):
                The token to use as HTTP bearer authorization for remote files. By default, it will use the token
                cached when running `huggingface-cli login`.
            cache_dir (`str`, `Path`, *optional*):
                Path to the folder where cached files are stored.
            local_files_only (`bool`, *optional*, defaults to `False`):
                If `True`, avoid downloading the file and return the path to the local cached file if it exists.
            model_kwargs:
                Additional keyword arguments passed along to the [`~ModelHubMixin._from_pretrained`] method.
        """
        if os.path.isdir(model_id):  # Can either be a local directory
            print("Loading dialogue template from local directory")
            template_file = os.path.join(model_id, TEMPLATE_FILENAME)
        else:  # Or a template on the Hub
            template_file = hf_hub_download(  # Download from the hub, passing same input args
                repo_id=model_id,
                filename=TEMPLATE_FILENAME,
                revision=revision,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                token=token,
                local_files_only=local_files_only,
            )
        # Load template
        with open(template_file, "r") as f:
            data = json.load(f)
        return cls.from_dict(data=data)
# Template for Falcon / Starcoder / Pythia models with EOS token used to indicate end of turn
default_v2_template = DialogueTemplate(system="", name="default_v2", end_token="<|endoftext|>", extra_end_text="\n")
# Alpaca instruction/response format; keeps the default "<|system|>" system block.
alpaca_template = DialogueTemplate(
    system="Below is an instruction that describes a task. Write a response that appropriately completes the request.",
    name="alpaca",
    user_token="### Instruction:",
    assistant_token="### Response:",
)
# Sourced from FastChat https://github.com/lm-sys/FastChat/blob/75d8ab26ee308f9cf0990976508232f06dd421e4/fastchat/conversation.py#L234
vicuna_template = DialogueTemplate(
    system="A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.",
    name="vicuna",
    user_token="### USER:",
    assistant_token="### ASSISTANT:",
)
# for Vicuna v1.1
vicuna_template_v2 = DialogueTemplate(
    system="A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.",
    name="vicuna_v2",
    system_token="",  # no system token
    system_format="no_leading_space",
    user_token="USER:",
    assistant_token="ASSISTANT:",
    end_token=["", "</s>"],  # no end token after user turns; </s> after assistant turns
    mid_str=" ",
    end_str=[" ", ""],
)
# Koala format: bare "BEGINNING OF CONVERSATION:" preamble, </s> only after assistant turns.
koala_template = DialogueTemplate(
    system="BEGINNING OF CONVERSATION:",
    name="koala",
    system_token="",
    system_format="no_leading_space",
    user_token="USER:",
    assistant_token="GPT:",
    end_token=["", "</s>"],
    mid_str=" ",
    end_str=[" ", ""],
)
# Dolly format: "compressed" system block (no system token), "### End" closes assistant turns.
dolly_template = DialogueTemplate(
    system="Below is an instruction that describes a task. Write a response that appropriately completes the request.",
    name="dolly",
    system_token="",
    system_format="compressed",
    user_token="### Instruction",
    assistant_token="### Response",
    end_token=["\n\n", "### End"],
    end_str=["", "\n\n"],
    mid_str=":\n",  # role tokens above omit the colon; it is supplied here
    extra_end_text=":\n",
)
# OpenAssistant format: no system block at all; turns are delimited purely by role tokens.
oasst_template = DialogueTemplate(
    system="",
    name="oasst",
    system_token="",
    system_format="empty",
    user_token="<|prompter|>",
    assistant_token="<|assistant|>",
    end_token="<|endoftext|>",
    mid_str="",
    end_str="",
)
# Registry consumed by get_dialogue_template(); keys are the templates' `name` fields.
SUPPORTED_DIALOGUE_TEMPLATES = {
    "default_v2": default_v2_template,
    "vicuna": vicuna_template,
    "vicuna_v2": vicuna_template_v2,
    "koala": koala_template,
    "dolly": dolly_template,
    "alpaca": alpaca_template,
    "oasst": oasst_template,
}
def get_dialogue_template(template: str) -> DialogueTemplate:
    """Look up a supported dialogue template by name and return an independent copy.

    Args:
        template: Key into ``SUPPORTED_DIALOGUE_TEMPLATES`` (e.g. "alpaca", "oasst").

    Returns:
        A copy of the registered template, safe for the caller to mutate.

    Raises:
        ValueError: If ``template`` is not a supported template name.
    """
    # Membership test directly on the dict — `.keys()` was redundant.
    if template not in SUPPORTED_DIALOGUE_TEMPLATES:
        raise ValueError(f"Template {template} is not supported!")
    return SUPPORTED_DIALOGUE_TEMPLATES[template].copy()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment