ChatGLM2 to OpenAI API: serve ChatGLM2-6B behind an OpenAI-style /v1/chat/completions endpoint.
from fastapi import FastAPI, Request
from transformers import AutoTokenizer, AutoModel
import uvicorn, datetime
import torch
import uuid

DEVICE = "cuda"
DEVICE_ID = "0"
CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE


def torch_gc():
    """Release cached GPU memory after each request."""
    if torch.cuda.is_available():
        with torch.cuda.device(CUDA_DEVICE):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()


app = FastAPI()


@app.post("/v1/chat/completions")
async def chat_completions(request: Request):
    global model, tokenizer
    json_post = await request.json()
    print(f"request: {json_post}")
    messages = json_post.get('messages')
    mlen = len(messages)
    # ChatGLM takes prior turns as (user, assistant) pairs, so everything
    # before the final message is folded into the history. This assumes the
    # messages alternate user/assistant and end with a user turn.
    history = None
    if mlen > 1:
        history = []
        for i in range(0, mlen - 1, 2):
            history.append((messages[i].get('content'), messages[i + 1].get('content')))
    prompt = messages[mlen - 1].get('content')
    max_length = json_post.get('max_length')
    top_p = json_post.get('top_p')
    # A temperature of exactly 0 must survive the default (it disables
    # sampling below), so test against None rather than truthiness.
    temperature = json_post.get('temperature')
    temperature = temperature if temperature is not None else 0.95
    response, _ = model.chat(tokenizer,
                             prompt,
                             history=history,
                             max_length=max_length if max_length else 10240,
                             top_p=top_p if top_p else 0.7,
                             temperature=temperature,
                             do_sample=temperature != 0)
    now = datetime.datetime.now()
    time = now.strftime("%Y-%m-%d %H:%M:%S")
    choices = [{
        "index": 0,
        "message": {
            "role": "assistant",
            "content": response
        },
        "finish_reason": "stop"  # OpenAI's schema uses snake_case
    }]
    resp_id = str(uuid.uuid4())
    # Minimal subset of the OpenAI response schema: just choices and id.
    answer = {
        "choices": choices,
        "id": resp_id
    }
    log = f"[{time}] response >> id:{resp_id}, prompt:\"{prompt}\", history:\"{history}\", response:\"{repr(response)}\""
    print(log)
    torch_gc()
    return answer


if __name__ == '__main__':
    # Downloading the weights can fail on a flaky network, so retry until
    # the model loads.
    while True:
        try:
            tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
            model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True, device='cuda')
            model.eval()
            break
        except Exception as e:
            print(f"network error, retrying. Error: {e}")
    uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
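
For a quick smoke test, the endpoint can be exercised with a plain HTTP POST. The snippet below is a minimal client sketch, assuming the server above is running locally on port 8000 and that the requests library is installed; it sends only the fields the handler actually reads (messages, plus optional temperature, top_p, and max_length).

# Minimal client sketch: assumes the server above is running on localhost:8000.
import requests

payload = {
    "messages": [
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi! How can I help?"},
        {"role": "user", "content": "Summarize ChatGLM2 in one sentence."}
    ],
    "temperature": 0.8
}
resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
body = resp.json()
print(body["id"])
print(body["choices"][0]["message"]["content"])

Note that the response carries only id and choices, so OpenAI client libraries that validate the full response schema (object, created, model, usage) may not accept it as-is.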