|
import os |
|
import random |
|
from itertools import product |
|
from operator import itemgetter |
|
|
|
import click |
|
import numpy as np |
|
import requests |
|
|
|
# Few-shot prompt library; the CLI's --prompt option accepts a 1-based index
# into this list. Prompts come in pairs that contrast real-world wording with
# invented/nonsense terms, probing whether the model follows the demonstrated
# reasoning pattern or merely recalls memorized facts.
PROMPTS = [
    # 1: syllogism truth judgement with familiar real-world concepts.
    (
        "对输入的命题判断真假,示例如下:\n"
        "Input:植物都能进行光和作用,银杏树是植物,所以银杏树能进行光合成作用\n"
        "Output:命题为真\n"
        "Input:甲属于乙,乙具有性质丙,所以甲不具有性质丙\n"
        "Output:命题为"
    ),
    # 2: same syllogism task, but the test case uses made-up entities.
    (
        "对输入的命题判断真假,示例如下:\n"
        "Input:植物都能进行光和作用,银杏树是植物,所以银杏树能进行光合成作用\n"
        "Output:命题为真\n"
        "Input:黑鲨龙是一种动物,罗是一种动物,黑鲨龙是罗,罗具有红蓝属性,那么黑鲨龙不具有红蓝属性\n"
        "Output:命题为"
    ),
    # 3: sentence-type classification using the real grammatical category names.
    (
        "预测每个句子的句型,可选的类别有“陈述句”、“祈使句”、“感叹句”、“疑问句”\n"
        "Input: 今天天气还可以\n"
        "Output: 陈述句\n"
        "Input: 今天天气真好啊!\n"
        "Output: 感叹句\n"
        "Input: 今天天气好吗?\n"
        "Output: 疑问句\n"
        "Input: 告诉我今天的天气\n"
        "Output: 祈使句\n"
        "Input: 今天的天气很差\n"
        "Output: "
    ),
    # 4: same classification task with invented (meaningless) category names.
    (
        "预测每个句子的句型,可选的类别有“库卡句”、“啥涅句”、“蛞姐句”、“班丁句”\n"
        "Input: 今天天气真好啊!\n"
        "Output: 蛞姐句\n"
        "Input: 今天天气好吗?\n"
        "Output: 班丁句\n"
        "Input: 告诉我今天的天气\n"
        "Output: 啥涅句\n"
        "Input: 今天的天气很差\n"
        "Output: "
    ),
    # 5: pronoun (coreference) resolution, Winograd-schema style.
    (
        "来玩一个游戏,我会给一句话,里面有一个代词被用括号括起来了,请把它替换成上文中提到的某个对象,示例如下:\n"
        "Input: 市议会拒绝给示威者颁发许可,因为(他们)担心暴力\n"
        "Output: 市议会拒绝给示威者颁发许可,因为(市议会)担心暴力\n"
        "Input: 市议会拒绝给示威者颁发许可,因为(他们)宣扬暴力\n"
        "Output: 市议会拒绝给示威者颁发许可,因为(示威者)宣扬暴力\n"
        "Input: 行李箱无法放到行李架上,因为(它)太大了\n"
        "Output: "
    ),
    # 6: last-letter concatenation with chain-of-thought demonstrations,
    # deliberately cut off mid-answer so the model completes the reasoning.
    (
        "Q: “think, machine\n"
        "A: The last letter of “think” is “k”. The last letter of “machine” is “e”. Concatenating “k”, “e” leads to “ke”. So, “think, machine” outputs “ke”.\n"
        "Q: “think, machine, learning”\n"
        "A: “think, machine” outputs “ke”. The last letter of “learning” is “g”. Concatenating “ke”, “g” leads to “keg”. So, “think, machine, learning” outputs “keg”.\n"
        "Q: “transformer, language”\n"
        "A: The last letter of “transformer” is “r”. The last letter of “language” is “e”. Concatenating: “r”, “e” leads to “re”. So, “transformer, language” outputs “re”.\n"
        "Q: “transformer, language, vision”\n"
        "A: “transformer, language” outputs “re”. The last letter of “vision” is “n”. Concatenating: “re”, “n” leads to “ren”. So, “transformer, language, vision” outputs “ren”.\n"
        "Q: “Answering, complex”\n"
        "A: “Answering” ends with “g”. The last letter of “complex” is “x”. Concatenating: “g”, “x” leads to “gx”. So, “Answering, complex” outputs “gx”.\n"
        "Q: “Answering, complex, questions”\n"
        "A: “Answering, complex” outputs “gx”. The last letter of"
    ),
]
|
|
|
|
|
def generate(
    prompt,
    api_key="",
    model="gpt-3.5-turbo",
    n=1,
    max_tokens=64,
    temperature=0.7,
    top_logprobs=0,
    logprobs=False,
    top_p=1.0,
    stop_sequence=None,
    proxy=None,
    session=None,
    echo=False,
):
    """Send *prompt* as a single user message to the OpenAI chat-completions API.

    Args:
        prompt: Message text; any falsy value short-circuits and returns [].
        api_key: Bearer token placed in the Authorization header.
        model, n, max_tokens, temperature, top_p, stop_sequence:
            Forwarded unchanged in the request payload (stop_sequence as "stop").
        logprobs: Whether to request per-token log probabilities.
        top_logprobs: Number of alternatives per token; only sent when
            ``logprobs`` is true (see BUG FIX note below).
        proxy: Optional proxy URL, applied to both http and https traffic.
        session: Object with a requests-compatible ``post`` method; defaults
            to the ``requests`` module itself.
        echo: Accepted for interface compatibility; currently unused.

    Returns:
        The decoded JSON response (a dict), or [] for a falsy prompt. HTTP/API
        errors are not raised here — an error payload is returned as-is.
    """
    if not prompt:
        return []

    headers = {"Authorization": f"Bearer {api_key}"}

    params = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "n": n,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "stop": stop_sequence,
        "logprobs": logprobs,
    }
    # BUG FIX: the original always sent "top_logprobs", but per the API docs
    # that field is only allowed when "logprobs" is enabled, so logprobs=False
    # requests (the function's own default) were rejected by the server.
    if logprobs:
        params["top_logprobs"] = top_logprobs

    url = "https://api.openai.com/v1/chat/completions"
    session = session or requests
    # Single call site instead of duplicated with/without-proxy branches.
    request_kwargs = {"headers": headers, "json": params}
    if proxy:
        request_kwargs["proxies"] = {"http": proxy, "https": proxy}
    response = session.post(url, **request_kwargs)

    return response.json()
|
|
|
|
|
@click.command()
@click.option(
    "--openai-api-key", help="你的 OpenAI API Key,若不指定尝试从环境变量 OPENAI_API_KEY 中读取"
)
@click.option(
    "--top-logprobs",
    type=int,
    default=1,
    help="每次预测输出多少个 token 的概率,注意输出结果较长时尽量设置小一些",
)
@click.option(
    "--temperature", type=float, default=0, help="温度参数,控制随机程度,0 即不随机输出最大概率值的预测结果"
)
@click.option(
    "--model",
    type=click.Choice(["gpt-3.5-turbo", "gpt-4", "gpt-4-32k", "gpt-4-turbo-preview"]),
    default="gpt-3.5-turbo",
    help="要使用的模型",
)
@click.option("--result-num", type=int, help="展示结果数量")
@click.option("--max-tokens", type=int, default=10, help="预测结果的最大 token 数")
@click.option("--proxy", help="要使用的网络代理")
@click.option("--prompt", help="模型输入,若想选择某个 prompt 则输入一个数字", required=True)
@click.option("--dry-run", is_flag=True, help="只打印 prompt 不做任何 API 调用")
@click.option("--verbose", is_flag=True, help="打印一些中间处理的信息")
def main(
    openai_api_key,
    top_logprobs,
    temperature,
    model,
    result_num,
    max_tokens,
    proxy,
    prompt,
    dry_run,
    verbose,
):
    """Query the chat-completions API and rank candidate outputs by perplexity.

    ``--prompt`` is either free text or a 1-based index into PROMPTS. The
    per-token top logprobs from the response are expanded into every possible
    token combination, each combination is scored with its perplexity
    (exp of the negative mean logprob), and the lowest-PPL candidates are
    printed, followed by the API usage stats.
    """
    if prompt.isdigit():
        index = int(prompt)
        # BUG FIX: the original only rejected index > len(PROMPTS), so
        # "--prompt 0" silently selected PROMPTS[-1] via negative indexing.
        # Reject everything outside 1..len(PROMPTS).
        if not 1 <= index <= len(PROMPTS):
            click.secho(f"Invalid prompt index, use 1-{len(PROMPTS)}", fg="red")
            return 1
        prompt = PROMPTS[index - 1]

    print("PROMPT:")
    for line in prompt.split("\n"):
        print(" " + line)

    if dry_run:
        return 0

    resp = generate(
        prompt,
        api_key=openai_api_key or os.getenv("OPENAI_API_KEY"),
        temperature=temperature,
        top_logprobs=top_logprobs,
        logprobs=True,
        model=model,
        max_tokens=max_tokens,
        # Stop before the model starts fabricating further few-shot rounds.
        stop_sequence=["Input", "Output"],
        proxy=proxy,
        echo=True,
    )

    choice = resp["choices"][0]
    print("\nRESULT:")
    print(" " + choice["message"]["content"])

    # One list of (token, logprob) alternatives per generated position.
    logprobs = [item["top_logprobs"] for item in choice["logprobs"]["content"]]
    logprobs = [[(i["token"], i["logprob"]) for i in item] for item in logprobs]

    candidate_res = []
    # Cartesian product over the per-position alternatives: every token
    # sequence the model considered (top_logprobs ** positions combinations).
    for group in product(*logprobs):
        tokens, group_log_probs = zip(*group)
        # Multi-byte (e.g. UTF-8 Chinese) characters come back split across
        # tokens as "bytes:\xe4"-style fragments; strip the prefix, then
        # round-trip through escape decoding to recover the raw byte string.
        bytes_text = (
            "".join([token.replace("bytes:", "") for token in tokens])
            .encode("utf-8")
            .decode("unicode_escape")
            .encode("raw_unicode_escape")
        )
        try:
            text = bytes_text.decode("utf-8")
        except UnicodeDecodeError:
            # Combinations that split a multi-byte character are not valid
            # UTF-8; drop them. (Narrowed from the original's blanket
            # `except Exception`, which could mask unrelated bugs.)
            if verbose:
                click.secho(f"Invalid token sequence: {tokens}", fg="red")
            continue
        # Perplexity = exp(-mean logprob); lower means higher model confidence.
        ppl = np.exp(-1 * sum(group_log_probs) / len(group_log_probs))
        candidate_res.append((tokens, group_log_probs, bytes_text, text, ppl))

    print("\nCANDIDATE RESULTS(With PPL):")
    # result_num may be None (option omitted); slicing [:None] keeps all.
    for item in sorted(candidate_res, key=itemgetter(4))[:result_num]:
        print(" TEXT:", repr(item[3]) + ", " + "PPL:", item[4])

    print("\nUsage:")
    for key, value in resp["usage"].items():
        print(f" {key}: {value}")
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: click parses sys.argv and dispatches to main().
    main()