Created
October 29, 2023 11:24
-
-
Save jerrylususu/3ebcf6262d110da89ce58d1e8d55bc22 to your computer and use it in GitHub Desktop.
GLM-to-OpenAI adapter: a mitmproxy script that rewrites OpenAI chat-completion requests into Zhipu GLM SSE requests and converts the streamed responses back.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json | |
from mitmproxy import http, ctx | |
from collections.abc import Iterable | |
import time | |
import jwt # pip install PyJWT / pipx inject mitmproxy PyJWT | |
import re | |
# --- Configuration ---
# Zhipu GLM API key in "<id>.<secret>" form; consumed by generate_token().
GLM_TOKEN = "[INSERT_YOUR_TOKEN]"
# Host and path of the GLM SSE streaming endpoint that requests are rerouted to.
GLM_HOST = "open.bigmodel.cn"
GLM_PATH = "/api/paas/v3/model-api/chatglm_turbo/sse-invoke"
# Requests addressed to this host are intercepted by the request() hook below.
OPENAI_HOST = "api.openai.com"
# Build the GLM "authorization" header value: a short-lived HS256 JWT
# signed with the secret half of the API key.
def generate_token(apikey: str, exp_seconds: int) -> str:
    """Return a signed JWT accepted by the GLM API.

    apikey: Zhipu API key in "<id>.<secret>" form.
    exp_seconds: validity window, in seconds from now.
    Raises Exception("invalid apikey", ...) when the key cannot be split.
    """
    try:
        # Renamed from `id` to avoid shadowing the builtin.
        key_id, secret = apikey.split(".")
    except (ValueError, AttributeError) as e:
        # Wrong number of "." separators, or apikey is not a string.
        raise Exception("invalid apikey", e) from e
    # Sample the clock once so "exp" and "timestamp" are consistent;
    # GLM expects millisecond timestamps.
    now_ms = int(round(time.time() * 1000))
    payload = {
        "api_key": key_id,
        "exp": now_ms + exp_seconds * 1000,
        "timestamp": now_ms,
    }
    return jwt.encode(
        payload,
        secret,
        algorithm="HS256",
        headers={"alg": "HS256", "sign_type": "SIGN"},
    )
# GLM does not accept "system" (or any other non-chat) role, so coerce
# every role outside {user, assistant} to "user". Mutates arr in place
# and returns it for convenient chaining.
def convert_role(arr):
    allowed_roles = ('user', 'assistant')
    for message in arr:
        if message['role'] not in allowed_roles:
            message['role'] = 'user'
    return arr
# GLM rejects two consecutive messages from the same role, so merge runs
# of same-role messages into one message, joining contents with "\n---\n".
def combine_same_role(arr):
    """Return a new message list with consecutive same-role messages merged.

    arr: list of {"role": str, "content": str} dicts.
    Bug fix: an empty input now returns [] instead of a single
    {"role": "", "content": ""} placeholder.
    """
    new_arr = []
    current_role, current_msg = "", ""
    for msg in arr:
        if msg['role'] != current_role:
            # Role changed: flush the accumulated message (if any).
            if current_role != "" and current_msg != "":
                new_arr.append({
                    "role": current_role,
                    "content": current_msg
                })
            current_role = msg['role']
            current_msg = msg['content']
        else:
            # Same role as the previous message: append with a separator.
            current_msg += "\n---\n"
            current_msg += msg['content']
    # Flush the final accumulated message; skip when the input was empty.
    if current_role != "":
        new_arr.append({
            "role": current_role,
            "content": current_msg
        })
    return new_arr
def request(flow: http.HTTPFlow) -> None:
    """mitmproxy request hook.

    Rewrites OpenAI chat-completion requests whose model is
    "chatglm_turbo" into GLM SSE-invoke requests: redirects host/path,
    attaches a fresh GLM JWT, and reshapes the JSON body.
    """
    req = flow.request
    is_openai_json_post = (
        req.host == OPENAI_HOST
        and req.method == "POST"
        and req.headers.get("content-type", "").startswith("application/json")
    )
    if not is_openai_json_post:
        return
    try:
        body = json.loads(req.get_text())
    except json.JSONDecodeError:
        return
    if body.get("model") != "chatglm_turbo":
        return
    # Redirect the request to the GLM endpoint.
    req.host = GLM_HOST
    req.path = GLM_PATH
    req.headers["authorization"] = generate_token(GLM_TOKEN, 300)  # token valid for 300s
    req.headers["accept"] = "text/event-stream"
    # GLM rejects non-chat roles and back-to-back same-role messages.
    body['messages'] = combine_same_role(convert_role(body['messages']))
    # GLM calls the message list "prompt" instead of "messages".
    body["prompt"] = body["messages"]
    # Ask GLM to stream only the delta per SSE event, matching OpenAI.
    body['incremental'] = True
    del body["messages"]
    req.set_text(json.dumps(body))
# Handle incomplete SSE events ("sticky packet" problem): a single GLM
# message is sometimes split across two SSE chunks, so raw bytes are
# buffered here until a complete event can be parsed.
DATA_BUFFER = b""
# True while the previous chunk failed to parse; the next chunk is then
# appended to DATA_BUFFER instead of replacing it.
RESP_DECODE_FAILED = False
# Matches one GLM SSE event: "event:<name>\nid:<digits>\ndata:<payload>\n"
# followed by any number of extra "<field>:..." meta lines.
GLM_RESP_PATTERN = re.compile(r"(event:(\w+)\nid:(\d+)\ndata:(.*)\n((\w+):.*\n)*)", re.MULTILINE)
# Rewrite each SSE chunk body, converting GLM's event format into
# OpenAI's "chat.completion.chunk" response format.
def modify(data: bytes) -> bytes | Iterable[bytes]:
    """mitmproxy streaming callback for response body chunks.

    Accumulates chunks in the module-level DATA_BUFFER until a complete
    run of GLM SSE events can be parsed, then returns the equivalent
    OpenAI-style chunks. Chunks not starting with b'event:' pass through
    unchanged; incomplete chunks yield [] until more data arrives.
    """
    if not data.startswith(b'event:'):
        # Not a GLM response; pass through untouched. (Detection could be improved.)
        return data
    global DATA_BUFFER
    global RESP_DECODE_FAILED
    if RESP_DECODE_FAILED:
        # Previous chunk was incomplete: keep accumulating.
        DATA_BUFFER = DATA_BUFFER + data
    else:
        DATA_BUFFER = data
    is_finish_event = data.startswith(b'event:finish')
    try:
        decoded = DATA_BUFFER.decode('utf8')  # GLM responses should be UTF-8
        match_spans = list(GLM_RESP_PATTERN.finditer(decoded))
        if len(match_spans) == 0:
            raise Exception("not finished")
        # A complete payload normally carries one extra trailing '\n'.
        if match_spans[-1].span()[1] + 1 != len(decoded):
            raise Exception("still not finished")
        DATA_BUFFER = b""
        RESP_DECODE_FAILED = False
        # ctx.log.info(decoded)
        matches = GLM_RESP_PATTERN.findall(decoded)
        new_resps = []
        for m in matches:
            # One GLM SSE event; m = (whole_match, event, id, data, ...).
            lines = m[0].split("\n")
            glm_data_raws = lines[2:]
            glm_event = m[1]
            glm_id = m[2]
            glm_data = ""
            for i in glm_data_raws:
                if i.startswith('data:'):
                    true_data = i[len('data:'):]
                    if len(true_data) == 0:
                        # An empty "data:" line encodes a literal newline.
                        glm_data += '\n'
                    else:
                        glm_data += true_data
            # hack: code generation seems to emit two leading spaces? drop one
            if glm_data.startswith(" "):
                glm_data = glm_data[1:]
            # Shape the payload like an OpenAI streaming chunk.
            new_obj = {
                "id": glm_id,
                "object": "chat.completion.chunk",
                "created": int(time.time()),
                "model": "chatglm_turbo",
                "choices": [
                    {
                        "index": 0,
                        "delta": {
                            "role": "assistant",
                            "content": glm_data
                        },
                        "finish_reason": None
                    }
                ]
            }
            if glm_event != "add":
                # Log every non-incremental event (finish/error/...) for debugging.
                ctx.log.info(decoded)
            if glm_event == "finish":
                new_obj["choices"][0]["delta"] = {}
                new_obj["choices"][0]["finish_reason"] = "stop"
            if glm_event in ["error", "interrupted"]:
                new_obj["choices"][0]["delta"] = {}
                new_obj["choices"][0]["finish_reason"] = glm_event
            new_data_str = f"data: {json.dumps(new_obj, ensure_ascii=False)}\n\n"
            # On the final event, also emit OpenAI's terminating sentinel.
            if glm_event == "finish":
                new_data_str += "data: [DONE]\n\n"
            new_data_bin = new_data_str.encode('utf8')
            new_resps.append(new_data_bin)
        return new_resps
    except Exception as e:
        RESP_DECODE_FAILED = True
        ctx.log.info(e)
        ctx.log.info(data)
        # On finish, clear the buffer so the next request doesn't stall on stale bytes.
        if is_finish_event:
            DATA_BUFFER = b""
        # Cannot decode yet: emit nothing for now.
        return []
# mitmproxy hook: decide per-response how the body is streamed. GLM
# responses get modify() installed as the chunk transformer; everything
# else streams straight through unmodified.
def responseheaders(flow):
    if flow.request.host == GLM_HOST:
        flow.response.stream = modify
    else:
        flow.response.stream = True
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment