Skip to content

Instantly share code, notes, and snippets.

@xinsblog
Last active August 21, 2023 09:37
Show Gist options
  • Save xinsblog/e0926e2bbc3696a9145f83a88358dcf8 to your computer and use it in GitHub Desktop.
Save xinsblog/e0926e2bbc3696a9145f83a88358dcf8 to your computer and use it in GitHub Desktop.
from typing import List
import json
import sys
import openai
class SimChatGPT:
    """Sentence-pair matcher that "learns" through a ChatGPT conversation.

    The whole "model" is the chat history held in ``self.messages``: an
    instruction prompt followed by question / answer / feedback turns added
    by :meth:`train`. :meth:`predict` asks one more question on top of that
    history without recording it, and :meth:`save` / :meth:`load` persist
    the history as JSON.
    """

    def __init__(self, api_key: str, messages: List = None):
        """Configure the OpenAI key and initialise the conversation.

        Args:
            api_key: OpenAI API key; assigned to the module-level
                ``openai.api_key``, so it applies to the whole process.
            messages: Existing chat history (list of ``{"role", "content"}``
                dicts). When omitted or empty, the default instruction
                prompt is used instead.
        """
        openai.api_key = api_key
        if messages:
            # Shallow-copy so that later appends in train()/predict() do not
            # mutate the list object owned by the caller.
            self.messages = list(messages)
        else:
            self.messages = [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "接下来我会提供给你两句短文本,如果这两个文本的语义匹配,则回复'匹配',反之则回复'不匹配',"
                                            "不要回复额外的文字或者标点符号"},
            ]

    @staticmethod
    def _build_question(x1: str, x2: str) -> dict:
        """Return the user message asking whether ``x1`` and ``x2`` match."""
        return {"role": "user", "content": f"'{x1}'和'{x2}'匹配还是不匹配"}

    def ask_chat_gpt(self) -> str:
        """Send the current history to the chat model and return its reply.

        Returns:
            The text content of the first choice in the API response.
        """
        # NOTE(review): `openai.ChatCompletion` is the legacy interface of the
        # openai package; confirm the installed SDK version still provides it.
        response = openai.ChatCompletion.create(
            model="gpt-3.5-turbo",
            messages=self.messages
        )
        return response['choices'][0]['message']['content']

    def train(self, x1: str, x2: str, y: str):
        """Ask about one labelled pair and record the model's answer plus feedback.

        Appends three messages to the history: the question, the assistant's
        reply, and a user feedback message (format complaint, praise, or the
        correct answer). The updated history is printed afterwards.

        Args:
            x1: First sentence.
            x2: Second sentence.
            y: Expected label, compared verbatim against the model's reply.

        Raises:
            Exception: Whatever ``ask_chat_gpt`` raises; in that case the
                history is left exactly as it was before the call.
        """
        self.messages.append(self._build_question(x1, x2))
        try:
            response_content = self.ask_chat_gpt()
        except Exception:
            # Do not leave a dangling, unanswered question in the history.
            self.messages.pop()
            raise
        self.messages.append({"role": "assistant", "content": response_content})

        if response_content not in {'匹配', '不匹配'}:
            feedback = "你回答的格式不对,你只能回复'匹配'或者'不匹配',不能回复额外的文字或者标点符号"
        elif response_content == y:
            feedback = "你回答的很对,棒棒哒"
        else:
            feedback = f"你回答的不对,你回答的是'{response_content}',正确答案是'{y}'"
        self.messages.append({"role": "user", "content": feedback})

        print(f"\n当前训练样本x1={x1}, x2={x2}, y={y}")
        print("self.messages=")
        for msg in self.messages:
            print(msg)

    def predict(self, x1: str, x2: str) -> str:
        """Return the model's raw reply for one pair without altering the history.

        Args:
            x1: First sentence.
            x2: Second sentence.

        Returns:
            The reply text exactly as produced by ``ask_chat_gpt``.
        """
        self.messages.append(self._build_question(x1, x2))
        try:
            return self.ask_chat_gpt()
        finally:
            # The question is temporary; remove it even when the API call
            # fails so a failed prediction cannot pollute the history.
            self.messages.pop()

    def save(self, model_path: str):
        """Write the chat history to ``model_path`` as UTF-8 JSON.

        The file contains a single object with a ``"messages"`` key.
        """
        model_dict = {
            'messages': self.messages
        }
        with open(model_path, "w", encoding='utf-8') as f:
            json.dump(model_dict, f, ensure_ascii=False, indent=2)

    @classmethod
    def load(cls, model_path: str, api_key: str) -> 'SimChatGPT':
        """Build an instance from a file previously written by :meth:`save`.

        Args:
            model_path: Path of the JSON file to read.
            api_key: OpenAI API key for the new instance.

        Returns:
            A new instance of the class ``load`` was called on (so subclasses
            get an instance of themselves), initialised with the stored history.
        """
        with open(model_path, "r", encoding='utf-8') as f:
            model_dict = json.load(f)
        return cls(api_key=api_key, messages=model_dict['messages'])
if __name__ == '__main__':
    # Demo driver: feed labelled sentence pairs through the conversation,
    # persist the resulting history, reload it into a second object and print
    # that object's answers for a few held-out pairs.
    # Every sample is (sentence 1, sentence 2, expected label).
    labelled_pairs = [
        ("小张比小王更高吗", "小王比小张更矮吗", "匹配"),
        ("小张比小王更高吗", "小王比小张更高吗", "不匹配"),
        ("上海比北京更远吗", "北京比上海更远吗", "不匹配"),
        ("鱼和鸡蛋能一起吃吗", "鸡蛋和鱼能同时吃吗", "匹配"),
    ]
    held_out_pairs = [
        ("苹果8比苹果9更贵吗", "苹果9比苹果8更贵吗", "不匹配"),
        ("iphone8比iphone9更贵吗", "iphone9比8更便宜吗", "匹配"),
        ("上海和北京一样远吗", "北京和上海同样远吗", "匹配"),
        ("杭州比深圳更热吗", "深圳比杭州更热吗", "不匹配"),
    ]

    # The OpenAI API key is expected as the first command-line argument.
    if not sys.argv[1:]:
        raise RuntimeError("命令行参数缺少api_key")
    key = sys.argv[1]
    snapshot_path = 'sim_chatgpt.json'

    trainer = SimChatGPT(api_key=key)
    # Send the bare instruction prompt once and show what the model says.
    first_reply = trainer.ask_chat_gpt()
    print(first_reply)

    for sample in labelled_pairs:
        trainer.train(*sample)
    trainer.save(snapshot_path)

    # Predictions come from a fresh object rebuilt from the saved history.
    restored = SimChatGPT.load(snapshot_path, api_key=key)
    for left, right, expected in held_out_pairs:
        predicted = restored.predict(left, right)
        print(left, right, expected, predicted)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment