Skip to content

Instantly share code, notes, and snippets.

@advanceboy
Last active February 1, 2024 17:38
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save advanceboy/717fde162a6f9ccb592f04898f0aacc1 to your computer and use it in GitHub Desktop.
Save advanceboy/717fde162a6f9ccb592f04898f0aacc1 to your computer and use it in GitHub Desktop.
rinna/japanese-gpt-neox-3.6b-instruction-sft もとい rinna/japanese-gpt-neox-3.6b-instruction-ppo と gradio を使ったチャット UI のサンプル実装です。 transformers.TextIteratorStreamer API を利用して、 ChatGPT のように生成したテキストを少しずつ出力し、gradio の UI でユーザーに表示させています。
# coding=utf-8
# License: CC0
"""
rinna/japanese-gpt-neox-3.6b-instruction-ppo と gradio を使ったチャット UI のサンプル実装です。
-> https://huggingface.co/rinna/japanese-gpt-neox-3.6b-instruction-ppo
transformers.TextIteratorStreamer API を利用して、 ChatGPT のように生成したテキストを少しずつ表示し、ユーザー体験を向上させています。
-> https://huggingface.co/docs/transformers/v4.29.1/en/internal/generation_utils#transformers.TextIteratorStreamer
streamer クラスの API は開発中のため、近い将来互換性がなくなる可能性があります。
transformers==4.37.2 gradio==4.16.0 での動作を確認しています。
環境作成手順
1. CUDA Toolkit のインストール https://developer.nvidia.com/cuda-toolkit-archive
2. CUDA の環境にあわせた PyTorch のパッケージを pip で追加
* `pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118`
3. transformers 関連のパッケージを pip で追加
* `pip3 install ipython sentencepiece transformers accelerate gradio`
4. Python でスクリプトを実行
* `python rinna_gradio_chat.py`
* 初回実行時、 huggingface.co からモデルを DL にするのに時間がかかったり、失敗したりする場合があります。
5. コンソールに表示された URL <http://127.0.0.1:7860> にブラウザでアクセスする
pip パッケージを入れる際は、 venv などで仮想環境を作成しておくことを強くおすすめします。
"""
import re
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from threading import Thread
# 定数宣言
model_name = "rinna/japanese-gpt-neox-3.6b-instruction-ppo"
torch_dtype = torch.bfloat16
max_new_tokens = 512
temperature = 0.7
speaker_name_user = 'ユーザー'
speaker_name_system = 'システム'
start_message = '<NL>'.join([
f'{speaker_name_user}: こんにちは。',
f'{speaker_name_system}: こんにちは、私は{speaker_name_system}です。あなたの質問に適切な回答をします。どのようなご用件ですか?'
])
# モデルの初期化
if not torch.cuda.is_available():
raise 'CUDA is not available'
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map='auto', torch_dtype=torch_dtype)
def user(message, history):
# history にユーザーメッセージを追加
return "", history + [[message, ""]]
def chat(curr_system_message, history):
# プロンプトの作成
prompt = [
f"{speaker_name_user}: {item[0]}"
'<NL>'
f"{speaker_name_system}: {item[1]}"
for item in history
]
prompt = '<NL>'.join(prompt)
prompt = (curr_system_message
+ '<NL>'
+ prompt
+ '<NL>'
+ f'{speaker_name_system}: '
).replace("\n", '<NL>')
# テキスト生成の開始
token_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
generation_args = [token_ids.to(model.device)]
generation_kwargs = dict(
streamer=streamer,
do_sample=True,
max_new_tokens=max_new_tokens,
temperature=temperature,
pad_token_id=tokenizer.pad_token_id,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id
)
thread = Thread(target=model.generate, args=generation_args, kwargs=generation_kwargs)
thread.start()
# TextIteratorStreamer を使った生成結果の受け取り
print(f'{speaker_name_system}: ', end='')
generated_text = ''
pending_buffer = ''
for next_text in streamer:
if not next_text:
continue
print(next_text.replace('<NL>', "\n"), end='', flush=True)
last_pending_buffer = pending_buffer
generated_token = re.sub('</s>$', '', next_text)
pending_buffer = '' if next_text == generated_token else '</s>'
generated_text += last_pending_buffer + generated_token.replace('<NL>', "\n")
history[-1][1] = generated_text
yield history
print('')
# 生成結果
return generated_text
with gr.Blocks() as app:
gr.Markdown(f"## {model_name} Chat")
chatbot = gr.Chatbot(height=500)
with gr.Row():
with gr.Column():
msg = gr.Textbox(label="Chat Message Box", placeholder="Chat Message Box", show_label=False, container=False)
with gr.Column():
with gr.Row():
submit = gr.Button("Submit")
stop = gr.Button("Stop")
clear = gr.Button("Clear")
system_msg = gr.Textbox(start_message, label="System Message", interactive=False, visible=False)
submit_kwargs = dict(fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False)
submit_then_kwargs = dict(fn=chat, inputs=[system_msg, chatbot], outputs=[chatbot], queue=True)
submit_event = msg.submit(**submit_kwargs).then(**submit_then_kwargs)
submit_click_event = submit.click(**submit_kwargs).then(**submit_then_kwargs)
stop.click(fn=None, inputs=None, outputs=None, cancels=[submit_event, submit_click_event], queue=False)
clear.click(lambda: None, None, [chatbot], queue=False)
app.queue(max_size=32)
app.launch(max_threads=2)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment