Created
April 18, 2024 13:09
-
-
Save kohya-ss/37f4c5ef8171cbb2b6cc1f4fd7999b89 to your computer and use it in GitHub Desktop.
llama-cpp-python と gradio で command-r-plus を動かす
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Apache License 2.0 | |
# 使用法は gist のコメントを見てください | |
import argparse | |
from typing import List, Optional, Union, Iterator | |
from llama_cpp import Llama | |
from llama_cpp.llama_tokenizer import LlamaHFTokenizer | |
from llama_cpp.llama_chat_format import _convert_completion_to_chat, register_chat_completion_handler | |
import llama_cpp.llama_types as llama_types | |
from llama_cpp.llama import LogitsProcessorList, LlamaGrammar | |
from transformers import AutoTokenizer | |
import gradio as gr | |
from llama_cpp import Llama | |
import gradio as gr | |
from transformers import AutoTokenizer | |
# Hugging Face repo ID whose tokenizer is loaded (AutoTokenizer.from_pretrained)
# and handed to Llama in the __main__ block.
MODEL_ID: str = "CohereForAI/c4ai-command-r-plus"

# Cap on generated tokens per reply in chat mode; passed as ``max_tokens``
# to the chat handler by launch_chat.
MAX_TOKENS_IN_CHAT_MODE: int = 1024
@register_chat_completion_handler("command-r")
def command_r_chat_handler(
    llama: Llama,
    messages: List[llama_types.ChatCompletionRequestMessage],
    functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
    function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
    tools: Optional[List[llama_types.ChatCompletionTool]] = None,
    tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
    temperature: float = 0.2,
    top_p: float = 0.95,
    top_k: int = 40,
    min_p: float = 0.05,
    typical_p: float = 1.0,
    stream: bool = False,
    stop: Optional[Union[str, List[str]]] = None,
    response_format: Optional[llama_types.ChatCompletionRequestResponseFormat] = None,
    max_tokens: Optional[int] = None,
    presence_penalty: float = 0.0,
    frequency_penalty: float = 0.0,
    repeat_penalty: float = 1.1,
    tfs_z: float = 1.0,
    mirostat_mode: int = 0,
    mirostat_tau: float = 5.0,
    mirostat_eta: float = 0.1,
    model: Optional[str] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    grammar: Optional[LlamaGrammar] = None,
    **kwargs,  # type: ignore
) -> Union[llama_types.ChatCompletion, Iterator[llama_types.ChatCompletionChunk]]:
    """Render ``messages`` in the Command-R turn format and run a completion.

    Registered with llama-cpp-python as the ``"command-r"`` chat handler.
    Each system/user/assistant message becomes
    ``<|START_OF_TURN_TOKEN|>{role token}{content}<|END_OF_TURN_TOKEN|>``.
    Messages with any other role are skipped. The prompt then ends with an
    open ``<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>`` turn for the model to
    continue.

    Generation always stops at ``<|END_OF_TURN_TOKEN|>``. Any ``stop``
    sequences given by the caller (a string or a list of strings) are added
    to it.

    ``functions``, ``function_call``, ``tools``, ``tool_choice``,
    ``response_format`` and extra ``**kwargs`` are accepted for handler
    signature compatibility but are not used. The sampling arguments are
    forwarded unchanged to ``llama.create_completion``.

    Returns:
        The result of ``_convert_completion_to_chat``: a chat completion, or
        an iterator of chat completion chunks when ``stream`` is true.
    """
    start_turn_token = "<|START_OF_TURN_TOKEN|>"
    end_turn_token = "<|END_OF_TURN_TOKEN|>"
    role_tokens = {
        "system": "<|SYSTEM_TOKEN|>",
        "user": "<|USER_TOKEN|>",
        "assistant": "<|CHATBOT_TOKEN|>",
    }

    # "<BOS_TOKEN>" is not written into the prompt text.
    # NOTE(review): this presumably relies on the tokenizer adding BOS
    # itself; confirm.
    turns = []
    for message in messages:
        role_token = role_tokens.get(message["role"])
        if role_token is None:
            continue  # e.g. tool/function messages have no Command-R role token here
        # Content may be None or absent; treat it as empty text.
        # NOTE(review): list-style (multimodal) content is not supported.
        content = message.get("content") or ""
        turns.append(start_turn_token + role_token + content + end_turn_token)
    # Open a chatbot turn for the model to complete.
    prompt = "".join(turns) + start_turn_token + role_tokens["assistant"]

    # Always end at the close of the chatbot turn; also honor caller stops.
    if stop is None:
        extra_stops = []
    elif isinstance(stop, str):
        extra_stops = [stop]
    else:
        extra_stops = list(stop)
    stop_tokens = [end_turn_token, *extra_stops]

    return _convert_completion_to_chat(
        llama.create_completion(
            prompt=prompt,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            typical_p=typical_p,
            stream=stream,
            stop=stop_tokens,
            max_tokens=max_tokens,
            presence_penalty=presence_penalty,
            frequency_penalty=frequency_penalty,
            repeat_penalty=repeat_penalty,
            tfs_z=tfs_z,
            mirostat_mode=mirostat_mode,
            mirostat_tau=mirostat_tau,
            mirostat_eta=mirostat_eta,
            model=model,
            logits_processor=logits_processor,
            grammar=grammar,
        ),
        stream=stream,
    )
def generate_completion(llama, prompt, max_tokens, temperature, top_p, repeat_penalty, top_k):
    """Stream a raw-text completion of ``prompt``, yielding the text generated so far.

    ``llama`` is called as ``llama(prompt, ..., stream=True)`` with the given
    sampling arguments. After every chunk that carries text, the accumulated
    generated text (without the prompt) is yielded.

    Generation ends early in two cases:

    * The module-level ``stop_generating`` flag becomes true. It is reset to
      False at the start of every call and set by the Stop button handler.
    * A chunk contains the literal ``<EOS_TOKEN>`` marker. The marker and
      anything after it in that chunk are dropped.

    Args:
        llama: Callable model (e.g. ``llama_cpp.Llama``) that returns an
            iterator of completion-chunk dicts when called with ``stream=True``.
        prompt: Prompt text to continue.
        max_tokens, temperature, top_p, repeat_penalty, top_k: Sampling
            options forwarded to ``llama`` as keyword arguments.

    Yields:
        str: All generated text so far.
    """
    global stop_generating
    stop_generating = False  # clear a stop request left over from a previous run

    eos_marker = "<EOS_TOKEN>"
    generated = ""
    for chunk in llama(
        prompt,
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        repeat_penalty=repeat_penalty,
        top_k=top_k,
        stream=True,
    ):
        # print(chunk)  # uncomment to show each chunk
        if stop_generating:
            break
        choices = chunk.get("choices")
        if not choices or "text" not in choices[0]:
            continue
        # Cut at the EOS marker wherever it appears in the chunk, not only at its end.
        text, found_eos, _ = choices[0]["text"].partition(eos_marker)
        generated += text
        yield generated
        if found_eos:
            break
def launch_completion(llama, listen=False):
    """Build and launch the Gradio UI for plain text completion.

    A single textbox holds both the prompt and the output. Clicking
    "Generate" streams text from ``generate_completion`` and writes
    ``prompt + generated text`` back into the same box. While generating,
    the Generate button is hidden and a Stop button is shown. Stop sets the
    module-level ``stop_generating`` flag that ``generate_completion``
    checks on each chunk.

    Args:
        llama: Loaded model, passed through to ``generate_completion``.
        listen: If true, serve on ``0.0.0.0`` so other machines can connect;
            otherwise use Gradio's default host.
    """
    with gr.Blocks() as demo:
        with gr.Row():
            io_textbox = gr.Textbox(
                label="Input/Output",
                placeholder="Enter your prompt here...",
                interactive=True,
                # Class hook for custom CSS, e.g. ".prompt textarea {font-size:1.0em !important}"
                # via gr.Blocks(css=...).
                elem_classes=["prompt"],
            )
        with gr.Row():
            generate_button = gr.Button("Generate")
            stop_button = gr.Button("Stop", visible=False)
        with gr.Row():
            max_tokens = gr.Slider(minimum=1, maximum=2048, value=128, step=1, label="Max Tokens")
            temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.3, step=0.01, label="Temperature")
            top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, step=0.05, label="Top P")
            repeat_penalty = gr.Slider(minimum=0.0, maximum=2.0, value=1.1, step=0.1, label="Repeat Penalty")
            top_k = gr.Slider(minimum=1, maximum=200, value=40, step=1, label="Top K")

        def generate_and_display(prompt, max_tokens, temperature, top_p, repeat_penalty, top_k):
            # Initialised so the final update works even if nothing is generated
            # (previously an UnboundLocalError).
            output = ""
            for output in generate_completion(llama, prompt, max_tokens, temperature, top_p, repeat_penalty, top_k):
                # While streaming: hide Generate, show Stop.
                yield gr.update(value=prompt + output, autoscroll=True), gr.update(visible=False), gr.update(visible=True)
            # Finished or stopped: restore the buttons.
            yield gr.update(value=prompt + output, autoscroll=True), gr.update(visible=True), gr.update(visible=False)

        def stop_generation():
            global stop_generating
            stop_generating = True
            return gr.update(visible=True), gr.update(visible=False)

        generate_button.click(
            generate_and_display,
            inputs=[io_textbox, max_tokens, temperature, top_p, repeat_penalty, top_k],
            outputs=[io_textbox, generate_button, stop_button],
            show_progress=True,
        )
        stop_button.click(
            stop_generation,
            outputs=[generate_button, stop_button],
        )
        # Pressing Enter in the textbox appends a newline instead of submitting.
        io_textbox.submit(
            lambda x: x + "\n",
            inputs=[io_textbox],
            outputs=[io_textbox],
        )

    demo.launch(server_name="0.0.0.0" if listen else None)
def launch_chat(llama, listen=False):
    """Launch a Gradio ChatInterface backed by ``command_r_chat_handler``.

    Sampling parameters are not exposed in the UI. Only ``max_tokens`` is
    set (to ``MAX_TOKENS_IN_CHAT_MODE``); everything else uses the handler's
    defaults. Replies are streamed: each yield is the reply text so far.

    Args:
        llama: Loaded model passed to ``command_r_chat_handler``.
        listen: If true, serve on ``0.0.0.0`` so other machines can connect;
            otherwise use Gradio's default host.
    """
    # Other handler options (temperature, top_p, top_k, min_p, typical_p, stop)
    # could be added here; they currently use the handler's defaults.
    model_kwargs = {
        "max_tokens": MAX_TOKENS_IN_CHAT_MODE,  # max tokens for generation
    }

    def chat(message, history):
        messages = []
        for turn in history:
            if isinstance(turn, dict):
                # "messages"-style history: {"role": ..., "content": ...}
                messages.append({"role": turn["role"], "content": turn["content"]})
                continue
            # Pair-style history: [user_text, bot_text]; either side may be None.
            user_text, bot_text = turn[0], turn[1]
            if user_text is not None:
                messages.append({"role": "user", "content": user_text})
            if bot_text is not None:
                messages.append({"role": "assistant", "content": bot_text})
        messages.append({"role": "user", "content": message})
        # print("debug: messages", messages)

        response = ""
        for chunk in command_r_chat_handler(llama=llama, messages=messages, stream=True, **model_kwargs):
            # print(chunk)  # uncomment to show each chunk
            choices = chunk.get("choices")
            if not choices:
                continue
            content = choices[0].get("delta", {}).get("content")
            if content is None:
                continue  # e.g. the role-only first delta or the final empty delta
            response += content
            yield response

    chatbot = gr.ChatInterface(chat)
    chatbot.launch(server_name="0.0.0.0" if listen else None)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required: with no model path, Llama(model_path=None) would fail only
    # after the tokenizer has already been loaded.
    parser.add_argument("-m", "--model", type=str, required=True, help="Model file path")
    parser.add_argument("-ngl", "--n_gpu_layers", type=int, default=0, help="Number of GPU layers")
    parser.add_argument("-c", "--n_ctx", type=int, default=2048, help="Context length")
    parser.add_argument("--chat", action="store_true", help="Chat mode")
    parser.add_argument("--listen", action="store_true", help="Listen mode")
    args = parser.parse_args()

    # The Hugging Face tokenizer for MODEL_ID is wrapped and passed to Llama
    # as its tokenizer.
    print("Initializing tokenizer")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    print(f"Initializing Llama. Model path: {args.model}, N_GPU_LAYERS: {args.n_gpu_layers}, N_CTX: {args.n_ctx}")
    llama_tokenizer = LlamaHFTokenizer(tokenizer)
    llama = Llama(
        model_path=args.model,
        n_gpu_layers=args.n_gpu_layers,
        n_ctx=args.n_ctx,
        tokenizer=llama_tokenizer,
        # Further Llama options (e.g. tensor_split, n_threads) can be passed here.
    )

    print("Launching Gradio")
    if args.chat:
        launch_chat(llama, args.listen)
    else:
        launch_completion(llama, args.listen)
Command-R v01に対応したバージョンも作りましたので、そちらもご利用ください。
https://gist.github.com/kohya-ss/e23fa9a321dba07fabd1ef61eab6863c
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
使い方
VRAM 24GB とメイン RAM 64GB で Command-R+ Q4 を動かす例です。
Command-R+ の Q3 と Q4 の間には少なくない性能差があるようなので Q4 を動かします。
`extra-index-url` の末尾は、インストールされている CUDA バージョンに合わせてください。公式リポジトリには cu121、cu122、cu123 が記載されています(ただし、このコメントの執筆時点では cu123 は存在しませんでした)。CUDA のマイナーバージョン(末尾の一桁)が違っていても、たいていは動きます。llama-cpp-python のバージョンを固定する場合は `llama-cpp-python==0.2.62` のように指定してください。
スクリプトを起動したら、ブラウザで http://localhost:7860/ を開いてください。
`-ngl 22 -c 2048` で、Q4_K_M がギリギリ動くと思います。IQ4_XS なら `-ngl 24 -c 2048` として、メイン RAM の使用量を減らすと良さそうです。`-c 2048` はコンテキスト長を 2048 にする指定で、入力と生成を合わせて 2048 トークンまで扱えます。このくらいあれば、ある程度の分量のやり取りができると思います。

オプション
-ngl でいくつのレイヤーを GPU で処理するか指定します。多くすると VRAM を使いますが、それだけ速くなります。また少なくするとメイン RAM の使用量が増え、CPU で処理する量が増えるので遅くなります。
-c でコンテキスト長を指定します。長くするとそれだけメモリを使います。
--chat でチャットモードで起動します。デフォルトは文章の続きを生成する補完モードです。
--listen オプションをつけると LAN 内の他の PC からアクセス可能になります。
その他
Q3、Q4 などの数値は量子化ビット数で、K_M や K_S、IQ?_XS などは量子化手法の違いのようです。PPL(パープレキシティ)の値が低いほど、量子化による性能劣化が抑えられているようです。
RAM/VRAM のどちらにも余裕がなければ、量子化ビット数の少ないモデル(サイズの小さいモデル)を選んでください。