ChatGLM2 to OpenAI API: a minimal FastAPI server that exposes ChatGLM2-6B through an OpenAI-style /v1/chat/completions endpoint.
from fastapi import FastAPI, Request
from transformers import AutoTokenizer, AutoModel
import uvicorn, datetime
import torch
import uuid

DEVICE = "cuda"
DEVICE_ID = "0"
CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE
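
# Free cached GPU memory after each request so long-running serving
# does not accumulate allocator overhead.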
def torch_gc():
    if torch.cuda.is_available():
        with torch.cuda.device(CUDA_DEVICE):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()

app = FastAPI()
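
# OpenAI-style chat completions endpoint backed by the globally loaded model.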
@app.post("/v1/chat/completions")
async def chat_completions(request: Request):
    global model, tokenizer
    # request.json() already returns the parsed payload as a dict.
    json_post_list = await request.json()
    print(f"request: {json_post_list}")
    messages = json_post_list.get('messages')
    mlen = len(messages)
    history = None
    if mlen > 1:
        # Earlier messages are assumed to alternate user/assistant; fold them
        # into (user, assistant) pairs as ChatGLM2 chat history.
        history = []
        for i in range(0, mlen - 1, 2):
            history.append((messages[i].get('content'), messages[i + 1].get('content')))
    # The last message carries the current user prompt.
    prompt = messages[mlen - 1].get('content')
    max_length = json_post_list.get('max_length')
    top_p = json_post_list.get('top_p')
    temperature = json_post_list.get('temperature')
    # Only fall back to the default when temperature is missing; 0 is a valid
    # value that selects greedy decoding via do_sample=False below.
    temperature = temperature if temperature is not None else 0.95
    response, _ = model.chat(tokenizer,
                             prompt,
                             history=history,
                             max_length=max_length if max_length else 10240,
                             top_p=top_p if top_p else 0.7,
                             temperature=temperature,
                             do_sample=temperature != 0)
    now = datetime.datetime.now()
    time = now.strftime("%Y-%m-%d %H:%M:%S")
    choices = []
    choices.append({
        "index": 0,
        "message": {
            "role": "assistant",
            "content": response
        },
        "finish_reason": "stop"
    })
    resp_id = str(uuid.uuid4())
    answer = {
        "choices": choices,
        "id": resp_id
    }
    log = f"[{time}] response >> id:{resp_id}, prompt:\"{prompt}\", history:\"{history}\", response:\"{repr(response)}\""
    print(log)
    torch_gc()
    return answer
if __name__ == '__main__':
    # Retry until the model download from the Hugging Face Hub succeeds.
    while True:
        try:
            tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
            model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True, device='cuda')
            model.eval()
            break
        except Exception as e:
            print(f"network error, retrying. Error: {e}")
    uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
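
For reference, a minimal client sketch. It assumes the server above is running on localhost:8000; the requests library and the example conversation are illustrative additions, not part of the gist:

import requests

# Hypothetical two-turn conversation: earlier user/assistant pairs become
# ChatGLM2 history, and the final user message becomes the current prompt.
payload = {
    "messages": [
        {"role": "user", "content": "Hello!"},
        {"role": "assistant", "content": "Hi, how can I help?"},
        {"role": "user", "content": "Summarize what ChatGLM2 is in one sentence."}
    ],
    "temperature": 0.8,
    "top_p": 0.7,
    "max_length": 2048
}
resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
print(resp.json()["choices"][0]["message"]["content"])

Note that the history pairing in the handler expects an odd number of messages ending with a user turn.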