WIP: A mock API to be used in place of the OpenAI endpoint in Open WebUI
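# Overview: this single-file Flask app exposes a small OpenAI- and Ollama-compatible
# surface backed by AWS Bedrock:
#   GET  /                  -> health check
#   POST /ask               -> minimal {"p": ...} / {"r": ...} helper endpoint
#   GET  /models            -> OpenAI-style model list (Bearer token required)
#   GET  /api/tags          -> Ollama-style tag list
#   GET  /api/version       -> Ollama-style version probe
#   POST /api/generate      -> Ollama-style generate (Bearer token required)
#   POST /chat/completions  -> OpenAI-style streaming chat completions
# To use it from Open WebUI, point an OpenAI-compatible connection at this server
# (e.g. base URL http://<host>:9090) and supply the AUTH_TOKEN value as the API key;
# the exact names of the Open WebUI settings may vary between versions.
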
from typing import Optional
from flask import Flask, render_template, request, jsonify, Response, stream_with_context
import os
import boto3
import json
import datetime
import logging
import sys
import secrets
import string
import random
from functools import wraps
logging.basicConfig(level=logging.WARN,
                    stream=sys.stdout,
                    format='%(name)s - %(levelname)s - %(message)s')
app = Flask(__name__)
session = boto3.Session()
bedrock_runtime = session.client('bedrock-runtime')
r_hash = secrets.token_hex(10)
SECRET_TOKEN = os.getenv("AUTH_TOKEN")
def token_required(f):
    @wraps(f)
    def decorated_function(*args, **kwargs):
        auth_header = request.headers.get('Authorization')
        if not auth_header:
            return jsonify({'message': 'Missing token'}), 403
        try:
            token_type, token = auth_header.split(" ")
            if token_type.lower() != 'bearer' or token != SECRET_TOKEN:
                raise ValueError('Invalid token')
        except Exception as e:
            return jsonify({'message': 'Invalid token', 'error': str(e)}), 403
        return f(*args, **kwargs)
    return decorated_function
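
# A minimal sketch of the header this decorator expects: protected requests must
# carry "Authorization: Bearer <AUTH_TOKEN>", where the token matches the server's
# AUTH_TOKEN environment variable. The helper below is illustrative only and is
# never called by the app.
def _example_auth_header(token: str = "my-secret-token") -> dict:
    # Build the header dict a client would send; any other scheme or a mismatched
    # token yields a 403 from decorated_function above.
    return {"Authorization": f"Bearer {token}"}
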
@app.route('/')
def main():
    return {'response': 'OK'}


@app.route('/ask', methods=['POST'])
@token_required
def ask_bedrock():
    data = request.get_json()
    try:
        return {
            'r': query_llm(data['p'])
        }, 200
    except Exception as e:
        logging.error(e, exc_info=True)
        return {
            'r': "ERROR"
        }, 503
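
# Illustrative client sketch for /ask (not invoked anywhere in this file). It
# assumes the server runs locally on port 9090, that AUTH_TOKEN is exported in the
# client's environment, and that the `requests` package is installed; none of these
# are part of the gist itself.
def _example_ask_client():
    import requests  # assumed third-party dependency, used only here
    resp = requests.post(
        "http://localhost:9090/ask",
        headers={"Authorization": f"Bearer {os.environ.get('AUTH_TOKEN', '')}"},
        json={"p": "Summarize what Amazon Bedrock is in one sentence."},
    )
    # On success the route returns {"r": "<model output>"}; on failure it returns
    # {"r": "ERROR"} with HTTP 503.
    print(resp.status_code, resp.json())
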
@app.route('/models', methods=['GET'])
@token_required
def models():
    all_params = request.args.to_dict()
    logging.info(f"Models Query: {all_params}")
    return {
        "object": "list",
        "data": [
            {
                "id": "amazon.titan-text-lite-v1",
                "object": "model",
                "created": 1711171627,
                "owned_by": "system"
            },
            {
                "id": "amazon.titan-text-express-v1",
                "object": "model",
                "created": 1711171627,
                "owned_by": "system"
            },
            {
                "id": "anthropic.claude-instant-v1",
                "object": "model",
                "created": 1711171627,
                "owned_by": "system"
            },
            {
                "id": "anthropic.claude-v2",
                "object": "model",
                "created": 1711171627,
                "owned_by": "system"
            },
            {
                "id": "anthropic.claude-v2:1",
                "object": "model",
                "created": 1711171627,
                "owned_by": "system"
            },
            {
                "id": "meta.llama2-13b-chat-v1",
                "object": "model",
                "created": 1711171627,
                "owned_by": "system"
            },
            {
                "id": "meta.llama2-70b-chat-v1",
                "object": "model",
                "created": 1711171627,
                "owned_by": "system"
            },
            # {
            #     "id": "meta.llama2-13b-v1",
            #     "object": "model",
            #     "created": 1711171627,
            #     "owned_by": "system"
            # },
            # {
            #     "id": "meta.llama2-70b-v1",
            #     "object": "model",
            #     "created": 1711171627,
            #     "owned_by": "system"
            # },
            {
                "id": "mistral.mistral-7b-instruct-v0:2",
                "object": "model",
                "created": 1711171627,
                "owned_by": "system"
            },
            {
                "id": "mistral.mixtral-8x7b-instruct-v0:1",
                "object": "model",
                "created": 1711171627,
                "owned_by": "system"
            }
        ]
    }
@app.route('/api/tags', methods=['GET'])
def tags():
    all_params = request.args.to_dict()
    logging.info(f"Tags Query: {all_params}")
    return {
        "models": [
            {
                "name": "meta.llama2-70b-v1",
                "model": "meta.llama2-70b-v1",
                "modified_at": "2024-03-22T12:24:00.806097832Z",
                "size": 4998664424,
                "digest": "69b94d2f208641931b7a10fbf9cb9c749d7725e2980226939283a03b76e84b7f",
                "details": {
                    "parent_model": "",
                    "format": "gguf",
                    "family": "llama",
                    "families": [
                        "llama"
                    ],
                    "parameter_size": "70B",
                    "quantization_level": "Q5_0"
                },
                "urls": [
                    0
                ]
            }
        ]
    }


@app.route('/api/version', methods=['GET'])
def version():
    all_params = request.args.to_dict()
    logging.info(f"Version Query: {all_params}")
    return {'version': '0.1.29'}
@app.route('/api/generate', methods=['POST'])
@token_required
def generate():
    data = request.get_json()
    logging.info(f"/api/generate: {data}")
    response_data = {
        "model": "llama2",
        "created_at": datetime.datetime.now().isoformat(),
        "response": query_llm(data['prompt']),
        "done": True
    }
    return jsonify(response_data), 200
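
# Illustrative client sketch for the Ollama-style /api/generate route (never called
# by the app). Host, port, and the `requests` dependency are assumptions.
def _example_generate_client():
    import requests  # assumed third-party dependency, used only here
    resp = requests.post(
        "http://localhost:9090/api/generate",
        headers={"Authorization": f"Bearer {os.environ.get('AUTH_TOKEN', '')}"},
        json={"prompt": "Write a haiku about mock APIs."},
    )
    # The route answers with a single non-streaming JSON object:
    # {"model": "llama2", "created_at": "...", "response": "...", "done": true}
    print(resp.json())
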
def generate_string(length=29):
    chars = string.ascii_letters + string.digits
    return ''.join(random.choice(chars) for _ in range(length))
@app.route('/chat/completions', methods=['POST'])
@token_required
def chat():
    data = request.get_json()
    model = data.get('model')
    msgs = data.get('messages')
    logging.info({
        'model_id': model,
        'type': 'response',
        'body': data
    })
    prompt = gen_prompt(msgs, model)
    try:
        response = query_llm(prompt, model)
        if not response:
            response = 'ERROR'
    except Exception as e:
        response = 'ERROR'
        logging.error(e, exc_info=True)
    # ip_token = int(len(prompt)/6)+1
    # op_token = int(len(response)/6)+1
    r_str = generate_string()

    def next_word():
        yield "event: start\n\n"
        for result in response.strip().split(' '):
            r = {
                "id": f"chatcmpl-{r_str}",
                "object": "chat.completion.chunk",
                "created": int(datetime.datetime.now().timestamp()),
                "model": model,
                "system_fingerprint": f"fp_{r_hash}",
                "choices": [
                    {
                        "index": 0,
                        "delta": {
                            "content": f"{result} "
                        },
                        "logprobs": None,
                        "finish_reason": None
                    }
                ]
            }
            yield f"data: {json.dumps(r)}\n\n"
        r = {
            "id": f"chatcmpl-{r_str}",
            "object": "chat.completion.chunk",
            "created": int(datetime.datetime.now().timestamp()),
            "model": model,
            "system_fingerprint": f"fp_{r_hash}",
            "choices": [
                {
                    "index": 0,
                    "delta": {},
                    "logprobs": None,
                    "finish_reason": "stop"
                }
            ]
        }
        yield f"data: {json.dumps(r)}\n\n"
        yield "data: [DONE]\n\n"
        yield "event: end\n\n"

    return Response(stream_with_context(next_word()), mimetype='text/event-stream')
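
# Illustrative sketch of consuming the /chat/completions SSE stream above (not
# invoked anywhere). It assumes a local server on port 9090 and the `requests`
# package; OpenAI-style clients such as Open WebUI parse the same "data: {...}"
# chunks and stop at "data: [DONE]".
def _example_chat_stream_client():
    import requests  # assumed third-party dependency, used only here
    with requests.post(
        "http://localhost:9090/chat/completions",
        headers={"Authorization": f"Bearer {os.environ.get('AUTH_TOKEN', '')}"},
        json={
            "model": "anthropic.claude-instant-v1",
            "messages": [{"role": "user", "content": "Hello!"}],
        },
        stream=True,
    ) as resp:
        for line in resp.iter_lines(decode_unicode=True):
            if not line or not line.startswith("data: "):
                continue  # skip the start/end events and blank separator lines
            payload = line[len("data: "):]
            if payload == "[DONE]":
                break
            chunk = json.loads(payload)
            # Each chunk carries one whitespace-delimited word in choices[0].delta
            print(chunk["choices"][0]["delta"].get("content", ""), end="")
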
@app.route('/<path:unmatched>', methods=['GET'])
def catch_all(unmatched):
    all_params = request.args.to_dict()
    logging.warning(f"GET {unmatched}: {all_params}")
    return {
        'endpoint': unmatched
    }


@app.route('/<path:unmatched>', methods=['POST'])
def catch_all_post(unmatched):
    data = request.get_json()
    logging.warning(f"POST {unmatched}: {data}")
    return {
        'endpoint': unmatched
    }
def gen_prompt(messages, model_id):
    prompt = ''
    q = messages[-1].get('content')
    for msg in messages[:-1]:
        prompt = prompt + '\n' + \
            msg.get('role') + ': ' + \
            msg.get('content')
    if prompt.strip() != '':
        # prompt = f"[context]{prompt}\n\n[prompt]\n{msgs[-1].get('content')}"
        prompt = f"""Use the following conversation as your interaction history, inside <conversation></conversation> XML tags.
<conversation>
{prompt}
</conversation>
When answering the prompt:
- If you are not sure, ask for clarification.
- If you don't know, just say that you don't know.
Answer in the language of the user's question.
Given the conversation history so far, answer this query.
"""
    else:
        prompt = ""
    if 'amazon' in model_id:
        prompt = f"{prompt}\nUser: {q}\n\nBot: "
    elif 'anthropic' in model_id:
        prompt = f"{prompt}\nHuman: {q}\n\nAssistant: "
    elif 'meta' in model_id:
        prompt = f"{prompt}\n{q}"
    elif 'mistral' in model_id:
        prompt = f"{prompt}\n<s>[INST] {q} [/INST]"
    logging.info(prompt)
    return prompt
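
# A small sketch of what gen_prompt() produces (illustrative, never called). For an
# Anthropic model, a two-turn history plus a new question becomes a <conversation>
# block followed by the "Human: ... Assistant:" framing used above.
def _example_gen_prompt():
    messages = [
        {"role": "user", "content": "What is Bedrock?"},
        {"role": "assistant", "content": "A managed foundation-model service."},
        {"role": "user", "content": "Which models does it offer?"},
    ]
    print(gen_prompt(messages, "anthropic.claude-v2"))
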
def generate_bedrock_request_body(prompt, model_id='amazon.titan-text-express-v1'):
    body = {}
    if 'amazon' in model_id:
        body = {
            "inputText": f"{prompt}",
            "textGenerationConfig": {
                # "temperature": float,
                # "topP": float,
                "maxTokenCount": 4096
                # "stopSequences": [string]
            }
        }
    elif 'anthropic' in model_id:
        body = {
            "prompt": prompt,
            "max_tokens_to_sample": 4096,
            # "temperature": 0.5,
            "stop_sequences": ["\n\nHuman:"]
        }
    elif 'meta' in model_id:
        body = {
            "prompt": prompt,
            # "temperature": 0.5,
            # "top_p": 0.9,
            "max_gen_len": 2048
        }
    elif 'mistral' in model_id:
        body = {
            "prompt": prompt,
            "max_tokens": 4096
            # "temperature": 0.5,
        }
    return json.dumps(body)
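
# Illustrative check of the per-provider request bodies built above (never called by
# the app); the prompt text is arbitrary.
def _example_request_bodies():
    for model_id in ("amazon.titan-text-express-v1",
                     "anthropic.claude-v2",
                     "meta.llama2-70b-chat-v1",
                     "mistral.mistral-7b-instruct-v0:2"):
        body = generate_bedrock_request_body("Hello", model_id)
        # e.g. amazon    -> {"inputText": "Hello", "textGenerationConfig": {...}}
        #      anthropic -> {"prompt": "Hello", "max_tokens_to_sample": 4096, ...}
        print(model_id, body)
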
def parse_bedrock_response_body(body, model_id='amazon.titan-text-express-v1'):
    logging.info({
        'model_id': model_id,
        'type': 'response',
        'body': body
    })
    output = ''
    if 'amazon' in model_id:
        for result in body['results']:
            r = result['outputText']
            # if 'Bot:' in r:
            r = r.replace('Bot:', '')
            output = output + r
    elif 'anthropic' in model_id:
        output = body["completion"]
    elif 'meta' in model_id:
        output = body["generation"]
        # if 'Expected answer:' in output:
        output = output.replace('Expected answer:', '')
    elif 'mistral' in model_id:
        for o in body['outputs']:
            output = output + o["text"]
    if output == '':
        output = 'ERROR'
    return output
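
# Illustrative sketch of parsing a Titan-shaped Bedrock response (never called by
# the app); the dict below is a hand-written stand-in, not real Bedrock output.
def _example_parse_response():
    fake_titan_body = {"results": [{"outputText": "Bot: Hello from Titan."}]}
    # The parser strips the "Bot:" prefix, so this prints " Hello from Titan."
    print(parse_bedrock_response_body(fake_titan_body, "amazon.titan-text-express-v1"))
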
def query_llm(prompt, model_id='amazon.titan-text-express-v1') -> Optional[str]:
    if not prompt:
        return ''
    bedrock_response = bedrock_runtime.invoke_model(
        accept='application/json',
        body=generate_bedrock_request_body(prompt, model_id),
        contentType='application/json',
        modelId=model_id)
    response_stream_body = bedrock_response['body'].read()
    json_body = json.loads(response_stream_body)
    output = parse_bedrock_response_body(json_body, model_id)
    return output
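
# Illustrative end-to-end smoke test (never called by the app). It assumes the
# process has AWS credentials and Bedrock model access in the active region;
# running it makes a real invoke_model call and may incur cost.
def _example_query_llm():
    print(query_llm("Say hello in one short sentence.",
                    "anthropic.claude-instant-v1"))
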
if __name__ == '__main__':
    app.run(debug=False, host='0.0.0.0', port=9090)
    # app.run(debug=True, host='0.0.0.0', port=443, ssl_context=('cert.pem', 'key.pem'))
FROM --platform=linux/amd64 python:3-slim as build
RUN pip install boto3
RUN pip install flask
RUN mkdir /app
WORKDIR /app
COPY bedrock_api.py bedrock_api.py
EXPOSE 9090
CMD [ "python", "bedrock_api.py" ]