Skip to content

Instantly share code, notes, and snippets.

@Linusp
Last active February 7, 2024 09:37
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
Star You must be signed in to star a gist
Save Linusp/9b4e72c6b4d41d22917ff97bb35c1f09 to your computer and use it in GitHub Desktop.
Play GPT

使用示例:

python playgpt.py --model gpt-3.5-turbo \
       --proxy 'http://localhost:8888' \
       --max-tokens 10 \
       --top-logprobs 3 \
       --result-num 10 \
       --temperature 0 \
       --prompt '起一个女性名字,姓刘,名字要和月亮有关,但不要直接用月字,尝试根据一些古诗词里的典故,使用较常见而不是冷僻的字,只输出名字无需其他。结果是:刘'

结果:

PROMPT:
    起一个女性名字,姓刘,名字要和月亮有关,但不要直接用月字,尝试根据一些古诗词里的典故,使用较常见而不是冷僻的字,只输出名字无需其他。结果是:刘

RESULT:
    婵娟

CANDIDATE RESULTS(With PPL):
    TEXT: '婵娟', PPL: 1.3893958134060524
    TEXT: '嫵娟', PPL: 1.4518073221689531
    TEXT: '娵娟', PPL: 1.753700750011277
    TEXT: '婉娟', PPL: 2.1273314013866393
    TEXT: '嫉娟', PPL: 2.2228908964694374
    TEXT: '娉娟', PPL: 2.6851258929509876
    TEXT: '婷娟', PPL: 3.300551957606001
    TEXT: '嫷娟', PPL: 3.4488123924201473
    TEXT: '婵婟', PPL: 3.5588571357287666
    TEXT: '嫵婟', PPL: 3.718720611038904

Usage:
    prompt_tokens: 87
    completion_tokens: 6
    total_tokens: 93

一些使用上的要点:

  1. 安装依赖: pip install requests click numpy
  2. (可选)设置环境变量: export OPENAI_API_KEY='xxxxxxxxx' (把 xxxx 替换成你的 key),不配置也行,在执行的时候提供了参数来指定
  3. 代码里内置了 prompt,如果想使用内置 prompt,可以使用选项 --prompt 指定为一个数字(从 1 开始),会选择内置 prompt 对应序号的那个,如果是数字外的 prompt 则当作正常 prompt 输入给模型
  4. 使用: python playgpt.py --help 可以查看各命令行参数说明
  5. --temperature 设置为大于 0 的数值,且 --top-logprobs 大于 1 时,结果里输出的 RESULT (我们平时调用 API 时直接得到的输出)可能不会在 CANDIDATE RESULTS(我这个代码里通过组合每一步生成的 topn 个候选 token 得到的结果) 里出现,这个是正常的
  6. top_logprobs 参数不要太大,3 就差不多了,多了也没意义
  7. 预期输出结果会比较长的时候,不要设置 top_logprobs,不然的话,假设输出结果 token 数量为 20、logprobs 为 3,最后输出带 PPL 信息结果的时候要做 3 的 20 次方次组合
  8. 当设置 top_logprobs 时可能会组合出很多结果,这些结果都会按困惑度从低到高展示,如果不想显示那么多,可以通过设置 --result-num 为更小的数字来控制结果数量
  9. PPL 是困惑度(Perplexity)的意思,是语言模型中常用的一个评估指标,有兴趣自己去查一下资料
import os
import random
from itertools import product
from operator import itemgetter
import click
import numpy as np
import requests
# Built-in demo prompts, selectable from the command line by passing a
# 1-based index to --prompt (see main()).  Each entry is a few-shot prompt
# built from adjacent string literals.
# NOTE: these strings are sent to the model verbatim — do not edit or
# "fix" their wording; even the apparent misspellings are intentional input.
PROMPTS = [
    # 1: judge whether a syllogism-style proposition is true or false.
    (
        "对输入的命题判断真假,示例如下:\n"
        "Input:植物都能进行光和作用,银杏树是植物,所以银杏树能进行光合成作用\n"
        "Output:命题为真\n"
        "Input:甲属于乙,乙具有性质丙,所以甲不具有性质丙\n"
        "Output:命题为"
    ),
    # 2: same true/false task, but the query uses made-up entities.
    (
        "对输入的命题判断真假,示例如下:\n"
        "Input:植物都能进行光和作用,银杏树是植物,所以银杏树能进行光合成作用\n"
        "Output:命题为真\n"
        "Input:黑鲨龙是一种动物,罗是一种动物,黑鲨龙是罗,罗具有红蓝属性,那么黑鲨龙不具有红蓝属性\n"
        "Output:命题为"
    ),
    # 3: sentence-type classification with the real Chinese category names.
    (
        "预测每个句子的句型,可选的类别有“陈述句”、“祈使句”、“感叹句”、“疑问句”\n"
        "Input: 今天天气还可以\n"
        "Output: 陈述句\n"
        "Input: 今天天气真好啊!\n"
        "Output: 感叹句\n"
        "Input: 今天天气好吗?\n"
        "Output: 疑问句\n"
        "Input: 告诉我今天的天气\n"
        "Output: 祈使句\n"
        "Input: 今天的天气很差\n"
        "Output: "
    ),
    # 4: the same classification task but with nonsense label names.
    (
        "预测每个句子的句型,可选的类别有“库卡句”、“啥涅句”、“蛞姐句”、“班丁句”\n"
        "Input: 今天天气真好啊!\n"
        "Output: 蛞姐句\n"
        "Input: 今天天气好吗?\n"
        "Output: 班丁句\n"
        "Input: 告诉我今天的天气\n"
        "Output: 啥涅句\n"
        "Input: 今天的天气很差\n"
        "Output: "
    ),
    # 5: pronoun (coreference) resolution, Winograd-schema style.
    (
        "来玩一个游戏,我会给一句话,里面有一个代词被用括号括起来了,请把它替换成上文中提到的某个对象,示例如下:\n"
        "Input: 市议会拒绝给示威者颁发许可,因为(他们)担心暴力\n"
        "Output: 市议会拒绝给示威者颁发许可,因为(市议会)担心暴力\n"
        "Input: 市议会拒绝给示威者颁发许可,因为(他们)宣扬暴力\n"
        "Output: 市议会拒绝给示威者颁发许可,因为(示威者)宣扬暴力\n"
        "Input: 行李箱无法放到行李架上,因为(它)太大了\n"
        "Output: "
    ),
    # 6: last-letter concatenation with chain-of-thought demonstrations.
    (
        "Q: “think, machine\n"
        "A: The last letter of “think” is “k”. The last letter of “machine” is “e”. Concatenating “k”, “e” leads to “ke”. So, “think, machine” outputs “ke”.\n"
        "Q: “think, machine, learning”\n"
        "A: “think, machine” outputs “ke”. The last letter of “learning” is “g”. Concatenating “ke”, “g” leads to “keg”. So, “think, machine, learning” outputs “keg”.\n"
        "Q: “transformer, language”\n"
        "A: The last letter of “transformer” is “r”. The last letter of “language” is “e”. Concatenating: “r”, “e” leads to “re”. So, “transformer, language” outputs “re”.\n"
        "Q: “transformer, language, vision”\n"
        "A: “transformer, language” outputs “re”. The last letter of “vision” is “n”. Concatenating: “re”, “n” leads to “ren”. So, “transformer, language, vision” outputs “ren”.\n"
        "Q: “Answering, complex”\n"
        "A: “Answering” ends with “g”. The last letter of “complex” is “x”. Concatenating: “g”, “x” leads to “gx”. So, “Answering, complex” outputs “gx”.\n"
        "Q: “Answering, complex, questions”\n"
        "A: “Answering, complex” outputs “gx”. The last letter of"
    ),
]
def generate(
    prompt,
    api_key="",
    model="gpt-3.5-turbo",
    n=1,
    max_tokens=64,
    temperature=0.7,
    top_logprobs=0,
    logprobs=False,
    top_p=1.0,
    stop_sequence=None,
    proxy=None,
    session=None,
    echo=False,
    timeout=None,
):
    """Call the OpenAI chat-completions endpoint once and return the parsed JSON.

    Args:
        prompt: user-message content. A falsy prompt short-circuits and
            returns ``[]`` without making any network call.
        api_key: OpenAI API key, sent as a ``Bearer`` token.
        model: chat model name.
        n: number of completions to request.
        max_tokens: cap on generated tokens.
        temperature: sampling temperature (0 means greedy/deterministic).
        top_logprobs: number of alternative tokens reported per position.
        logprobs: whether to request per-token log probabilities.
        top_p: nucleus-sampling cutoff.
        stop_sequence: list of stop strings, or None.
        proxy: optional proxy URL, applied to both http and https traffic.
        session: optional requests.Session-like object; defaults to the
            ``requests`` module itself.
        echo: accepted for interface compatibility but unused — the chat
            completions API has no echo parameter.
        timeout: optional request timeout in seconds (None keeps the
            original wait-forever behavior).

    Returns:
        The decoded JSON response (dict), or ``[]`` for an empty prompt.
    """
    if not prompt:
        return []
    headers = {"Authorization": f"Bearer {api_key}"}
    params = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "n": n,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "stop": stop_sequence,
        "logprobs": logprobs,
        "top_logprobs": top_logprobs,
    }
    url = "https://api.openai.com/v1/chat/completions"
    client = session or requests
    # Single call site: requests treats proxies=None exactly like omitting
    # the argument, so the proxy / no-proxy branches collapse into one.
    proxies = {"http": proxy, "https": proxy} if proxy else None
    response = client.post(
        url, headers=headers, json=params, proxies=proxies, timeout=timeout
    )
    return response.json()
@click.command()
@click.option(
    "--openai-api-key", help="你的 OpenAI API Key,若不指定尝试从环境变量 OPENAI_API_KEY 中读取"
)
@click.option(
    "--top-logprobs",
    type=int,
    default=1,
    help="每次预测输出多少个 token 的概率,注意输出结果较长时尽量设置小一些",
)
@click.option(
    "--temperature", type=float, default=0, help="温度参数,控制随机程度,0 即不随机输出最大概率值的预测结果"
)
@click.option(
    "--model",
    type=click.Choice(["gpt-3.5-turbo", "gpt-4", "gpt-4-32k", "gpt-4-turbo-preview"]),
    default="gpt-3.5-turbo",
    help="要使用的模型",
)
@click.option("--result-num", type=int, help="展示结果数量")
@click.option("--max-tokens", type=int, default=10, help="预测结果的最大 token 数")
@click.option("--proxy", help="要使用的网络代理")
@click.option("--prompt", help="模型输入,若想选择某个 prompt 则输入一个数字", required=True)
@click.option("--dry-run", is_flag=True, help="只打印 prompt 不做任何 API 调用")
@click.option("--verbose", is_flag=True, help="打印一些中间处理的信息")
def main(
    openai_api_key,
    top_logprobs,
    temperature,
    model,
    result_num,
    max_tokens,
    proxy,
    prompt,
    dry_run,
    verbose,
):
    # CLI entry point: run one completion, then enumerate candidate outputs
    # from the per-token top_logprobs and rank them by perplexity.
    # (No docstring on purpose — click would surface it in --help output.)

    # A purely numeric --prompt selects a built-in prompt (1-based index).
    if prompt.isdigit():
        index = int(prompt)
        # BUG FIX: the original only rejected index > len(PROMPTS), so
        # "--prompt 0" silently picked PROMPTS[-1] via negative indexing.
        if not 1 <= index <= len(PROMPTS):
            click.secho(f"Invalid prompt index, use 1-{len(PROMPTS)}", fg="red")
            return 1
        prompt = PROMPTS[index - 1]

    print("PROMPT:")
    for line in prompt.split("\n"):
        print("    " + line)
    if dry_run:
        return 0

    resp = generate(
        prompt,
        api_key=openai_api_key or os.getenv("OPENAI_API_KEY"),
        temperature=temperature,
        top_logprobs=top_logprobs,
        logprobs=True,
        model=model,
        max_tokens=max_tokens,
        stop_sequence=["Input", "Output"],
        proxy=proxy,
        echo=True,
    )
    # Robustness: report API-level errors instead of crashing with a
    # KeyError on the missing "choices" key.
    if "error" in resp:
        click.secho(f"API error: {resp['error']}", fg="red")
        return 1

    choice = resp["choices"][0]
    print("\nRESULT:")
    print("    " + choice["message"]["content"])

    # One entry per generated token: a list of (token, logprob) alternatives.
    logprobs = [item["top_logprobs"] for item in choice["logprobs"]["content"]]
    logprobs = [[(i["token"], i["logprob"]) for i in item] for item in logprobs]

    # Cartesian product over per-position alternatives: with k candidates
    # and m generated tokens this is k**m combinations — keep both small.
    candidate_res = []
    for group in product(*logprobs):
        tokens, group_log_probs = zip(*group)
        # Multi-byte characters can be split across tokens and arrive as
        # "bytes:\xNN" fragments; rebuild the raw byte sequence first.
        bytes_text = (
            "".join([token.replace("bytes:", "") for token in tokens])
            .encode("utf-8")
            .decode("unicode_escape")
            .encode("raw_unicode_escape")
        )
        try:
            text = bytes_text.decode("utf-8")
            # PPL = exp(-mean(logprob)); lower means the model considers the
            # sequence more plausible.
            ppl = np.exp(-1 * sum(group_log_probs) / len(group_log_probs))
            candidate_res.append((tokens, group_log_probs, bytes_text, text, ppl))
        except UnicodeDecodeError:
            # Mixing fragments from different positions can yield invalid
            # UTF-8 — such combinations are simply skipped.
            if verbose:
                click.secho(f"Invalid token sequence: {tokens}", fg="red")

    print("\nCANDIDATE RESULTS(With PPL):")
    # result_num may be None, in which case the slice keeps everything.
    for item in sorted(candidate_res, key=itemgetter(4))[:result_num]:
        print("    TEXT:", repr(item[3]) + ", " + "PPL:", item[4])

    print("\nUsage:")
    for key, value in resp["usage"].items():
        print(f"    {key}: {value}")


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment