Skip to content

Instantly share code, notes, and snippets.

@mutaguchi
Created August 20, 2023 00:26
Show Gist options
  • Save mutaguchi/2a4fd9a9d90b0103be6cdd4e9b628cd5 to your computer and use it in GitHub Desktop.
Save mutaguchi/2a4fd9a9d90b0103be6cdd4e9b628cd5 to your computer and use it in GitHub Desktop.
from transformers import GPTJForCausalLM, AlbertTokenizer
import torch
# Hugging Face repository id shared by the tokenizer and the model weights.
# Kept under its own name so that `model` only ever refers to the loaded
# model object (the original rebound `model` from this string to the model).
model_name = 'AIBunCho/japanese-novel-gpt-j-6b'

tokenizer = AlbertTokenizer.from_pretrained(
    model_name,
    keep_accents=True,
    remove_space=False,
)

# Load the causal LM with the 4-bit loading flag and bfloat16 dtype, letting
# `device_map='auto'` decide where the weights are placed.
model = GPTJForCausalLM.from_pretrained(
    model_name,
    load_in_4bit=True,
    torch_dtype=torch.bfloat16,
    device_map='auto',
)
model.eval()  # inference only
def completion(prompt):
    """Sample a continuation of ``prompt`` from the module-level model.

    Args:
        prompt: Text to continue. It is encoded without special tokens.

    Returns:
        A tuple ``(output, output_without_input)``: the decoded full
        sequence (prompt plus continuation) and the decoded newly
        generated tokens only. Special tokens are skipped in both.
    """
    # Move the ids to wherever the model lives in a single transfer instead
    # of a hard-coded `.cuda()` followed by a second `.to(model.device)`.
    input_ids = tokenizer.encode(
        prompt,
        add_special_tokens=False,
        return_tensors="pt",
    ).to(device=model.device)

    tokens = model.generate(
        input_ids,
        max_new_tokens=256,
        temperature=0.6,
        top_p=1.0,
        repetition_penalty=1.2,
        do_sample=True,
        pad_token_id=tokenizer.pad_token_id,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    generated = tokens[0]
    prompt_length = input_ids.shape[1]
    output = tokenizer.decode(generated, skip_special_tokens=True)
    # Everything after the prompt's token count is the new text.
    output_without_input = tokenizer.decode(
        generated[prompt_length:],
        skip_special_tokens=True,
    )
    return (output, output_without_input)
def get_line(prompt):
    """Generate one line of dialogue, echoing it to stdout as it is produced.

    ``completion`` is called up to 20 times; each call's full output becomes
    the prompt for the next call. Generation stops as soon as a chunk
    contains an end marker: the closing bracket ``」`` or the start of a new
    utterance by either speaker (``speaker2「`` / ``speaker1「``).

    Args:
        prompt: The conversation prompt, ending where the line should start.

    Returns:
        The accumulated text up to (not including) the first end marker,
        with surrounding whitespace stripped. May be empty.
    """
    markers = ("」", f"{speaker2}「", f"{speaker1}「")
    result = ""
    for _ in range(20):
        prompt, current_output = completion(prompt)
        if not current_output:
            # Nothing new was generated this round; try again.
            continue

        # Cut at the marker that appears EARLIEST in the text. Checking the
        # markers in list order would always cut at the first "」", even
        # when a speaker tag precedes it, leaking the next speaker's words
        # into this line.
        positions = [
            pos for pos in map(current_output.find, markers) if pos != -1
        ]
        found = bool(positions)
        if found:
            current_output = current_output[:min(positions)]

        print(current_output, end="", flush=True)
        result += current_output
        if found:
            break
    return result.strip()
# Display names of the two participants: the user and the model's character.
speaker1 = "私"
speaker2 = "メイド"

# Upper bound on how many past utterances are replayed into each prompt.
max_conversation_length = 20

# Bracketed character sheet placed at the top of every prompt.
system = f'''
[{speaker1}と{speaker2}の会話]
[{speaker2}の職業は、メイド。]
[{speaker2}は、丁寧な話し方をする。]
[{speaker2}は、毒舌家である。]
[{speaker2}の好きな食べ物は、チョコレート。]
[{speaker2}は、{speaker1}に仕えている。]
'''

# Scripted opening exchange, alternating speaker1 / speaker2.
initial_conversation = [
    {"speaker": who, "content": text}
    for who, text in (
        (speaker1, "おはよう、メイドちゃん。"),
        (speaker2, "おはようございます、ご主人様。"),
        (speaker1, "今日もよろしくね。"),
        (speaker2, "はい、精一杯、お仕事を務めさせていただきます。"),
    )
]

# Running state: the utterances exchanged so far and the count of
# completed turns.
conversation = []
turn = 0
# Interactive chat loop: read the user's line, build a prompt from the
# character sheet plus recent history, and let the model write the reply.
while True:
    try:
        speaker1_line = input(f"{speaker1}: ")
    except (EOFError, KeyboardInterrupt):
        # End of input or Ctrl-C at the prompt: leave the loop cleanly
        # instead of dying with a traceback.
        print("")
        break

    prompt = f"{system}\n"

    # Slicing always yields a new list, so the stored history is never
    # mutated by the insertion below.
    current_conversation = conversation[-max_conversation_length:]

    if turn < 5:
        # Early on, replay the scripted opening exchange ahead of the
        # real history.
        current_conversation[:0] = initial_conversation
    else:
        # Later, keep only speaker2's scripted lines, as an example hint.
        speaker_lines = [
            line["content"]
            for line in initial_conversation
            if line["speaker"] == speaker2
        ]
        examples = "」「".join(speaker_lines)
        prompt += f'[{speaker2}の台詞例:「{examples}」]\n'

    for line in current_conversation:
        prompt += f'{line["speaker"]}「{line["content"]}」\n'
    # Leave speaker2's bracket open so the model completes the reply.
    prompt += f'{speaker1}「{speaker1_line}」\n{speaker2}「'

    print(f'{speaker2}: ', end="", flush=True)
    speaker2_line = get_line(prompt)
    print("")

    # Record the exchange only when the model actually produced a reply.
    if speaker2_line:
        turn += 1
        conversation.extend(
            [
                {"speaker": speaker1, "content": speaker1_line},
                {"speaker": speaker2, "content": speaker2_line},
            ]
        )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment