Skip to content

Instantly share code, notes, and snippets.

@advanceboy
Created May 19, 2023 16:19
Show Gist options
  • Save advanceboy/b9143aa9de23a6f9a60a07a862e0b4a8 to your computer and use it in GitHub Desktop.
Save advanceboy/b9143aa9de23a6f9a60a07a862e0b4a8 to your computer and use it in GitHub Desktop.
rinna/japanese-gpt-neox-3.6b-instruction-sft を使ったチャット UI のサンプル実装です。 transformers.TextIteratorStreamer API を利用して、 ChatGPT のように生成したテキストを少しずつ表示し、ユーザー体験を向上させています。
# coding=utf-8
# License: CC0
"""
rinna/japanese-gpt-neox-3.6b-instruction-sft を使ったチャット UI のサンプル実装です。
-> https://huggingface.co/rinna/japanese-gpt-neox-3.6b-instruction-sft
transformers.TextIteratorStreamer API を利用して、 ChatGPT のように生成したテキストを少しずつ表示し、ユーザー体験を向上させています。
-> https://huggingface.co/docs/transformers/v4.29.1/en/internal/generation_utils#transformers.TextIteratorStreamer
ユーザー入力には、以下のコマンドが使えます。
clear : すべての入力履歴をクリアし、初期プロンプト状態にリセットします。
exit : プログラムを終了します。
retry : 前回と同じプロンプトでテキストを再生成します。
streamer クラスの API は開発中のため、近い将来互換性がなくなる可能性があります。
transformers==4.29.2 での動作を確認しています。
環境作成手順
1. CUDA Toolkit のインストール https://developer.nvidia.com/cuda-toolkit-archive
2. CUDA の環境にあわせた PyTorch のパッケージを pip で追加
* `pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118`
3. transformers 関連のパッケージを pip で追加
* `pip3 install ipython sentencepiece transformers accelerate`
4. Python でスクリプトを実行
* `python rinna_chat_streaming.py`
* 初回実行時、 huggingface.co からモデルを DL にするのに時間がかかったり、失敗したりする場合があります。
pip パッケージを入れる際は、 venv などで仮想環境を作成しておくことを強くおすすめします。
"""
import re
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from threading import Thread
# 定数宣言
model_name = "rinna/japanese-gpt-neox-3.6b-instruction-sft"
torch_dtype = torch.bfloat16
max_new_tokens = 256
temperature = 0.7
max_history = 16
speaker_name_user = 'ユーザー'
speaker_name_system = 'システム'
initial_messages = [
{
'speaker': speaker_name_user,
'text': 'こんにちは。'
},
{
'speaker': speaker_name_system,
'text': f'こんにちは、私は{speaker_name_system}です。あなたの質問に適切な回答をします。どのようなご用件ですか?'
},
]
messages = []
messages.extend(initial_messages)
# モデルの初期化
if not torch.cuda.is_available():
raise 'CUDA is not available'
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map='auto', torch_dtype=torch_dtype)
output = None
while True:
input_text = input(f'{speaker_name_user}: ')
# コマンドの処理
if input_text == 'clear':
messages.clear()
messages.extend(initial_messages)
print('履歴が初期化されました')
continue
elif input_text == 'exit':
break
elif input_text == 'retry':
pass
else:
if output:
messages.append({
'speaker': speaker_name_system,
'text': output
})
messages.append({
'speaker': speaker_name_user,
'text': input_text
})
if len(messages) > max_history:
del messages[0:(len(messages)-max_history)]
# プロンプトの作成
prompt = [
f"{uttr['speaker']}: {uttr['text']}"
for uttr in messages
]
prompt = "<NL>".join(prompt)
prompt = (
prompt
+ "<NL>"
+ f'{speaker_name_system}: '
)
# テキスト生成の開始
token_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
generation_args = [token_ids.to(model.device)]
generation_kwargs = dict(
streamer=streamer,
do_sample=True,
max_new_tokens=max_new_tokens,
temperature=temperature,
pad_token_id=tokenizer.pad_token_id,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id
)
thread = Thread(target=model.generate, args=generation_args, kwargs=generation_kwargs)
thread.start()
# TextIteratorStreamer を使った生成結果の受け取り
print(f'{speaker_name_system}: ', end='')
generated_text = ''
for next_text in streamer:
if not next_text:
continue
print(next_text.replace('<NL>', "\n"), end='', flush=True)
generated_text += next_text
print('')
# 生成結果
output = re.sub('</s>$', '', generated_text)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment