WIP: a mock API that can be used in place of the OpenAI endpoint in Open WebUI, backed by Amazon Bedrock.

bedrock_api.py:
from typing import Optional
from flask import Flask, render_template, request, jsonify, Response, stream_with_context
import os
import boto3
import json
import datetime
import logging
import sys
import secrets
import string
import random
from functools import wraps

logging.basicConfig(level=logging.WARN,
                    stream=sys.stdout,
                    format='%(name)s - %(levelname)s - %(message)s')

app = Flask(__name__)
session = boto3.Session()
bedrock_runtime = session.client('bedrock-runtime')

r_hash = secrets.token_hex(10)
SECRET_TOKEN = os.getenv("AUTH_TOKEN")
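
# Simple bearer-token auth: requests must carry "Authorization: Bearer <token>",
# where the expected token is read from the AUTH_TOKEN environment variable.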
def token_required(f):
    @wraps(f)
    def decorated_function(*args, **kwargs):
        auth_header = request.headers.get('Authorization')
        if not auth_header:
            return jsonify({'message': 'Missing token'}), 403
        try:
            token_type, token = auth_header.split(" ")
            if token_type.lower() != 'bearer' or token != SECRET_TOKEN:
                raise ValueError('Invalid token')
        except Exception as e:
            return jsonify({'message': 'Invalid token', 'error': str(e)}), 403
        return f(*args, **kwargs)
    return decorated_function


@app.route('/')
def main():
    return {'response': 'OK'}

@app.route('/ask', methods=['POST'])
@token_required
def ask_bedrock():
    data = request.get_json()
    try:
        return {
            'r': query_llm(data['p'])
        }, 200
    except Exception as e:
        logging.error(e, exc_info=True)
        return {
            'r': "ERROR"
        }, 503
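
# Mimics OpenAI's GET /models so Open WebUI can populate its model picker.
# The entries are hard-coded Bedrock model IDs with a fixed "created" timestamp.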
@app.route('/models', methods=['GET'])
@token_required
def models():
    all_params = request.args.to_dict()
    logging.info(f"Models Query: {all_params}")
    return {
        "object": "list",
        "data": [
            {
                "id": "amazon.titan-text-lite-v1",
                "object": "model",
                "created": 1711171627,
                "owned_by": "system"
            },
            {
                "id": "amazon.titan-text-express-v1",
                "object": "model",
                "created": 1711171627,
                "owned_by": "system"
            },
            {
                "id": "anthropic.claude-instant-v1",
                "object": "model",
                "created": 1711171627,
                "owned_by": "system"
            },
            {
                "id": "anthropic.claude-v2",
                "object": "model",
                "created": 1711171627,
                "owned_by": "system"
            },
            {
                "id": "anthropic.claude-v2:1",
                "object": "model",
                "created": 1711171627,
                "owned_by": "system"
            },
            {
                "id": "meta.llama2-13b-chat-v1",
                "object": "model",
                "created": 1711171627,
                "owned_by": "system"
            },
            {
                "id": "meta.llama2-70b-chat-v1",
                "object": "model",
                "created": 1711171627,
                "owned_by": "system"
            },
            # {
            #     "id": "meta.llama2-13b-v1",
            #     "object": "model",
            #     "created": 1711171627,
            #     "owned_by": "system"
            # },
            # {
            #     "id": "meta.llama2-70b-v1",
            #     "object": "model",
            #     "created": 1711171627,
            #     "owned_by": "system"
            # },
            {
                "id": "mistral.mistral-7b-instruct-v0:2",
                "object": "model",
                "created": 1711171627,
                "owned_by": "system"
            },
            {
                "id": "mistral.mixtral-8x7b-instruct-v0:1",
                "object": "model",
                "created": 1711171627,
                "owned_by": "system"
            }
        ]
    }
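
# The /api/* routes below mimic an Ollama server (tags, version, generate),
# so clients that probe for an Ollama backend also receive plausible answers.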
@app.route('/api/tags', methods=['GET'])
def tags():
    all_params = request.args.to_dict()
    logging.info(f"Tags Query: {all_params}")
    return {
        "models": [
            {
                "name": "meta.llama2-70b-v1",
                "model": "meta.llama2-70b-v1",
                "modified_at": "2024-03-22T12:24:00.806097832Z",
                "size": 4998664424,
                "digest": "69b94d2f208641931b7a10fbf9cb9c749d7725e2980226939283a03b76e84b7f",
                "details": {
                    "parent_model": "",
                    "format": "gguf",
                    "family": "llama",
                    "families": [
                        "llama"
                    ],
                    "parameter_size": "70B",
                    "quantization_level": "Q5_0"
                },
                "urls": [
                    0
                ]
            }
        ]
    }

@app.route('/api/version', methods=['GET'])
def version():
    all_params = request.args.to_dict()
    logging.info(f"Version Query: {all_params}")
    return {'version': '0.1.29'}


@app.route('/api/generate', methods=['POST'])
@token_required
def generate():
    data = request.get_json()
    logging.info(f"/api/generate: {data}")
    response_data = {
        "model": "llama2",
        "created_at": datetime.datetime.now().isoformat(),
        "response": query_llm(data['prompt']),
        "done": True
    }
    return jsonify(response_data), 200


def generate_string(length=29):
    chars = string.ascii_letters + string.digits
    return ''.join(random.choice(chars) for _ in range(length))
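
# Mimics OpenAI's POST /chat/completions. The full Bedrock reply is fetched
# in one call, then re-chunked word by word and streamed back as server-sent
# events in the OpenAI "chat.completion.chunk" format.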
@app.route('/chat/completions', methods=['POST'])
@token_required
def chat():
    data = request.get_json()
    model = data.get('model')
    msgs = data.get('messages')
    logging.info({
        'model_id': model,
        'type': 'request',
        'body': data
    })
    prompt = gen_prompt(msgs, model)
    try:
        response = query_llm(prompt, model)
        if not response:
            response = 'ERROR'
    except Exception as e:
        response = 'ERROR'
        logging.error(e, exc_info=True)
    # ip_token = int(len(prompt)/6)+1
    # op_token = int(len(response)/6)+1
    r_str = generate_string()

    def next_word():
        yield "event: start\n\n"
        for result in response.strip().split(' '):
            r = {
                "id": f"chatcmpl-{r_str}",
                "object": "chat.completion.chunk",
                "created": int(datetime.datetime.now().timestamp()),
                "model": model,
                "system_fingerprint": f"fp_{r_hash}",
                "choices": [
                    {
                        "index": 0,
                        "delta": {
                            "content": f"{result} "
                        },
                        "logprobs": None,
                        "finish_reason": None
                    }
                ]
            }
            yield f"data: {json.dumps(r)}\n\n"
        # Final chunk: empty delta with finish_reason "stop", then the
        # OpenAI-style [DONE] sentinel.
        r = {
            "id": f"chatcmpl-{r_str}",
            "object": "chat.completion.chunk",
            "created": int(datetime.datetime.now().timestamp()),
            "model": model,
            "system_fingerprint": f"fp_{r_hash}",
            "choices": [
                {
                    "index": 0,
                    "delta": {},
                    "logprobs": None,
                    "finish_reason": "stop"
                }
            ]
        }
        yield f"data: {json.dumps(r)}\n\n"
        yield "data: [DONE]\n\n"
        yield "event: end\n\n"

    return Response(stream_with_context(next_word()), mimetype='text/event-stream')

@app.route('/<path:unmatched>', methods=['GET'])
def catch_all(unmatched):
    all_params = request.args.to_dict()
    logging.warning(f"GET {unmatched}: {all_params}")
    return {
        'endpoint': unmatched
    }


@app.route('/<path:unmatched>', methods=['POST'])
def catch_all_post(unmatched):
    data = request.get_json()
    logging.warning(f"POST {unmatched}: {data}")
    return {
        'endpoint': unmatched
    }
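
# Flattens the OpenAI-style message list into a single text prompt: prior
# turns are wrapped in <conversation> tags, and the final user message gets
# the turn markers each model family expects (User/Bot, Human/Assistant,
# [INST], ...).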
def gen_prompt(messages, model_id):
    prompt = ''
    q = messages[-1].get('content')
    for msg in messages[:-1]:
        prompt = prompt + '\n' + \
            msg.get('role') + ': ' + \
            msg.get('content')
    if prompt.strip() != '':
        # prompt = f"[context]{prompt}\n\n[prompt]\n{msgs[-1].get('content')}"
        prompt = f"""Use the following conversation as your interaction history, inside <conversation></conversation> XML tags.
<conversation>
{prompt}
</conversation>
When answering the prompt:
- If you are not sure, ask for clarification.
- If you don't know, just say that you don't know.
And answer according to the language of the user's question.
Given the conversation history so far, answer this query.
"""
    else:
        prompt = ""

    if 'amazon' in model_id:
        prompt = f"{prompt}\nUser: {q}\n\nBot: "
    elif 'anthropic' in model_id:
        prompt = f"{prompt}\nHuman: {q}\n\nAssistant: "
    elif 'meta' in model_id:
        prompt = f"{prompt}\n{q}"
    elif 'mistral' in model_id:
        prompt = f"{prompt}\n<s>[INST] {q} [/INST]"

    logging.info(prompt)
    return prompt
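
# Each Bedrock provider expects a different request JSON shape; only the
# prompt and a max-token limit are set here, with sampling parameters left
# commented out at their provider defaults.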
def generate_bedrock_request_body(prompt, model_id='amazon.titan-text-express-v1'):
    body = {}
    if 'amazon' in model_id:
        body = {
            "inputText": f"{prompt}",
            "textGenerationConfig": {
                # "temperature": float,
                # "topP": float,
                "maxTokenCount": 4096
                # "stopSequences": [string]
            }
        }
    elif 'anthropic' in model_id:
        body = {
            "prompt": prompt,
            "max_tokens_to_sample": 4096,
            # "temperature": 0.5,
            "stop_sequences": ["\n\nHuman:"]
        }
    elif 'meta' in model_id:
        body = {
            "prompt": prompt,
            # "temperature": 0.5,
            # "top_p": 0.9,
            "max_gen_len": 2048
        }
    elif 'mistral' in model_id:
        body = {
            "prompt": prompt,
            "max_tokens": 4096
            # "temperature": 0.5,
        }
    return json.dumps(body)
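
# The response shape also differs per provider: Titan returns "results",
# Anthropic "completion", Meta "generation", and Mistral "outputs".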
def parse_bedrock_response_body(body, model_id='amazon.titan-text-express-v1'):
    logging.info({
        'model_id': model_id,
        'type': 'response',
        'body': body
    })
    output = ''
    if 'amazon' in model_id:
        for result in body['results']:
            r = result['outputText']
            # if 'Bot:' in r:
            r = r.replace('Bot:', '')
            output = output + r
    elif 'anthropic' in model_id:
        output = body["completion"]
    elif 'meta' in model_id:
        output = body["generation"]
        # if 'Expected answer:' in output:
        output = output.replace('Expected answer:', '')
    elif 'mistral' in model_id:
        for o in body['outputs']:
            output = output + o["text"]
    if output == '':
        output = 'ERROR'
    return output
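
# Single synchronous round trip to Bedrock: build the provider-specific
# request body, call invoke_model, and parse the provider-specific reply.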
def query_llm(prompt, model_id='amazon.titan-text-express-v1') -> Optional[str]:
    if not prompt:
        return ''
    bedrock_response = bedrock_runtime.invoke_model(
        accept='application/json',
        body=generate_bedrock_request_body(prompt, model_id),
        contentType='application/json',
        modelId=model_id)
    response_stream_body = bedrock_response['body'].read()
    json_body = json.loads(response_stream_body)
    output = parse_bedrock_response_body(json_body, model_id)
    return output


if __name__ == '__main__':
    app.run(debug=False, host='0.0.0.0', port=9090)
    # app.run(debug=True, host='0.0.0.0', port=443, ssl_context=('cert.pem', 'key.pem'))

Dockerfile:
FROM --platform=linux/amd64 python:3-slim AS build
RUN pip install boto3
RUN pip install flask
RUN mkdir /app
WORKDIR /app
COPY bedrock_api.py bedrock_api.py
EXPOSE 9090
CMD [ "python", "bedrock_api.py" ]