Skip to content

Instantly share code, notes, and snippets.

@mutaguchi
Last active September 30, 2023 13:38
Show Gist options
  • Save mutaguchi/9be172a3a127dc7173707f31eaa41c15 to your computer and use it in GitHub Desktop.
Save mutaguchi/9be172a3a127dc7173707f31eaa41c15 to your computer and use it in GitHub Desktop.
rinna_chat

japanese-gpt-neox-3.6b-instruction-sftを使ったチャットサンプル

https://huggingface.co/rinna/japanese-gpt-neox-3.6b-instruction-sft

python rinna_chat.py で実行できます。

rinna_chat.py -p test.json のように会話設定JSONを指定することもできます。

test.json

{
	"speaker1" : "ユーザー",
	"speaker2": "アシスタント",
	"system_message": "アシスタントは有能なAIアシスタントです。ユーザーの質問に対し、適切な回答を行います。",
	"initial_messages" : [
	    {
	        "speaker": "ユーザー",
	        "text": "こんにちは。"
	    },
	    {
	        "speaker": "アシスタント",
	        "text": "こんにちは、ユーザーさん。何かお手伝いできることはありますか?"
	    }
	]
}

こんな使い方も…

作品名.json

{
	"speaker1" : "作品名",
	"speaker2": "キャラ名",
	"system_message": "「作品名」に登場する「キャラ名」を挙げてください。",
	"initial_messages" : [
	    {
	        "speaker": "作品名",
	        "text": "ご注文はうさぎですか?"
	    },
	    {
	        "speaker": "キャラ名",
	        "text": "チノ, ココア, リゼ, シャロ, マヤ, メグ"
	    },
	    {
	        "speaker": "作品名",
	        "text": "ドラえもん"
	    },
	    {
	        "speaker": "キャラ名",
	        "text": "ドラえもん, のび太, スネ夫, ジャイアン, しずか"
	    }
	]
}

rinna_chat.py -p 作品名.json -i 魔法少女まどかマギカ

鹿目まどか, 暁美ほむら, 巴マミ, 美樹さやか
# japanese-gpt-neox-3.6b-instruction-sftを使ったチャットサンプル
# 以下のコマンドが使えます。
# retry: 回答を再生成します。
# clear, cls: 会話ログを消去します。
# exit, end: スクリプトを終了します。
# 以下のコマンドライン引数が使えます。
# --path, -p: 会話設定用JSONファイルのパス(省略可)
# --input, -i: 入力テキストを指定。実行したらすぐに終了する。(省略可)
# ------------設定------------
use_cuda = True # cudaを使用するかどうか
fp16 = True # モデルを半精度化する
max_history = 5 # 履歷の最大長
max_completion_count = 3 # 補完が途切れた場合、自動補完継続する回数
max_new_tokens= 256 # 補完時の最大トークン長
temperature = 0.7 # 補完の温度
# ------------会話設定(JSONファイルでも指定可)------------
speaker1 = "ユーザー" # デフォルトユーザー名
speaker2 = "アシスタント" # デフォルトAI名
# デフォルトシステムメッセージ
system_message = f"{speaker2}は有能なAIアシスタントです。{speaker1}の質問に対し、適切な回答を行います。"
# デフォルト初期プロンプト
initial_messages = [
{
"speaker": speaker1,
"text": "こんにちは。"
},
{
"speaker": speaker2,
"text": f"こんにちは、{speaker1}さん。何かお手伝いできることはありますか?"
},
]
# ------------コードここから------------
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import json
import argparse
import os
import re
json_path = ""
parser = argparse.ArgumentParser()
parser.add_argument('--path','-p')
parser.add_argument('--input', '-i', dest='input_text')
args = parser.parse_args()
current_dir = os.path.dirname(os.path.abspath(__file__))
if args.path:
json_path = args.path if os.path.isabs(args.path) else os.path.join(
current_dir, args.path)
if json_path:
with open(json_path, 'r', encoding='utf-8') as f:
data = json.load(f)
system_message = data['system_message']
speaker1 = data['speaker1']
speaker2 = data['speaker2']
initial_messages = data['initial_messages']
messages = []
system_message = {
"speaker": "システム",
"text": system_message
}
messages.extend(initial_messages)
model_name="rinna/japanese-gpt-neox-3.6b-instruction-sft"
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
if use_cuda and torch.cuda.is_available():
model = AutoModelForCausalLM.from_pretrained(model_name, device_map='auto', torch_dtype=torch.float16) if fp16 else AutoModelForCausalLM.from_pretrained(model_name).to("cuda")
else:
model = AutoModelForCausalLM.from_pretrained(model_name).to("cpu")
def complete(prompt):
token_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
with torch.no_grad():
output_ids = model.generate(
token_ids.to(model.device),
do_sample=True,
max_new_tokens=max_new_tokens,
temperature=temperature,
pad_token_id=tokenizer.pad_token_id,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id
)
output = tokenizer.decode(output_ids.tolist()[0][token_ids.size(1):])
return output
while True:
input_text = args.input_text if args.input_text else input(f"{speaker1}: ")
input_text = input_text.replace("\n", "<NL>")
if input_text == "exit" or input_text == "end":
break
elif input_text == "retry":
messages.pop()
elif input_text == "clear" or input_text == "cls":
messages.clear()
messages.extend(initial_messages)
continue
else:
messages.append(
{
"speaker":speaker1,
"text":input_text
}
)
while len(messages) > max_history:
messages.pop(0)
prompt = [
f"{uttr['speaker']}: {uttr['text']}"
for uttr in [system_message] + messages
]
prompt = "<NL>".join(prompt)
prompt = prompt + "<NL>" + f"{speaker2}: "
completion_count = 0
output = ""
if not args.input_text:
print(f"{speaker2}: ",end="")
while not output.endswith("</s>") and completion_count < max_completion_count:
current = complete(prompt + output)
output += current
completion_count += 1
output = output.replace("</s>", "")
pattern = re.compile(r"^(.+?)<NL>" + speaker1 + ": ", re.DOTALL)
match = re.search(pattern, output)
if match:
output = match.group(1)
output = output.replace(speaker2 + ": ", "")
messages.append(
{
"speaker": speaker2,
"text": output
}
)
output = output.replace("<NL>", "\n")
print(output)
if args.input_text:
break
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment