oobabooga text-generation-webui blocking API, extended with a GET /api/v1/characters endpoint
import json
import ssl
import base64
import yaml

from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from threading import Thread
from pathlib import Path
from urllib.parse import urlparse, parse_qs

from extensions.api.util import build_parameters, try_start_cloudflared
from modules import shared
from modules.chat import generate_chat_reply
from modules.LoRA import add_lora_to_model
from modules.models import load_model, unload_model
from modules.models_settings import get_model_metadata, update_model_parameters
from modules.text_generation import (
    encode,
    generate_reply,
    stop_everything_event
)
from modules.utils import get_available_models, get_available_characters
from modules.logging_colors import logger


def get_model_info():
    return {
        'model_name': shared.model_name,
        'lora_names': shared.lora_names,
        # dump
        'shared.settings': shared.settings,
        'shared.args': vars(shared.args),
    }


class Handler(BaseHTTPRequestHandler):
    def do_GET(self):
        if self.path == '/api/v1/model':
            self.send_response(200)
            self.end_headers()

            response = json.dumps({
                'result': shared.model_name
            })

            self.wfile.write(response.encode('utf-8'))

        elif self.path.startswith('/api/v1/characters'):
            # GET /api/v1/characters lists the available characters;
            # GET /api/v1/characters?name=<name> returns that character's
            # definition (yaml/json) and portrait (base64-encoded image).
            args = parse_qs(urlparse(self.path).query)
            name = args.get('name', None)
            # picture_response_format = args.get('picture_response_format', 'b64_json')

            if name:
                name = name[0]
                picture_data = None
                data = None

                for extension in ['png', 'jpg', 'jpeg']:
                    filepath = Path(f"characters/{name}.{extension}")
                    if filepath.exists():
                        with open(filepath, 'rb') as f:
                            file_contents = f.read()
                            encoded_bytes = base64.b64encode(file_contents)
                            # Turn raw base64 encoded bytes into ASCII
                            # TODO: support 'url' and 'data': url ? data_url?
                            img_data = encoded_bytes.decode('ascii')
                            img_filename = f"{name}.{extension}"
                            img_encoding = 'b64_json'  # like SD, maybe also accept 'url'.. 'data_url'?
                            picture_data = {
                                'filename': img_filename,
                                'encoding': img_encoding,
                                # or 'data': url, and/or 'url': f"data:image/png;base64,{img_data}",
                                'data': img_data,
                            }

                        break

                for extension in ["yml", "yaml", "json"]:
                    filepath = Path(f'characters/{name}.{extension}')
                    if filepath.exists():
                        with open(filepath, 'r', encoding='utf-8') as f:
                            file_contents = f.read()
                            data = json.loads(file_contents) if extension == "json" else yaml.safe_load(file_contents)

                        break

                # Send the 404 before any success headers have gone out.
                if not (picture_data or data):
                    self.send_error(404, message="Character not found")
                    return

                resp = {
                    'data': data,
                    'picture': picture_data,
                }

                response = json.dumps({
                    'results': resp
                })
            else:
                response = json.dumps({
                    'results': get_available_characters()
                })

            self.send_response(200)
            self.send_header('Content-Type', 'application/json')
            self.end_headers()
            self.wfile.write(response.encode('utf-8'))

        else:
            self.send_error(404)
    def do_POST(self):
        content_length = int(self.headers['Content-Length'])
        body = json.loads(self.rfile.read(content_length).decode('utf-8'))

        if self.path == '/api/v1/generate':
            self.send_response(200)
            self.send_header('Content-Type', 'application/json')
            self.end_headers()

            prompt = body['prompt']
            generate_params = build_parameters(body)
            stopping_strings = generate_params.pop('stopping_strings')
            generate_params['stream'] = False

            generator = generate_reply(
                prompt, generate_params, stopping_strings=stopping_strings, is_chat=False)

            answer = ''
            for a in generator:
                answer = a

            response = json.dumps({
                'results': [{
                    'text': answer
                }]
            })

            self.wfile.write(response.encode('utf-8'))

        elif self.path == '/api/v1/chat':
            self.send_response(200)
            self.send_header('Content-Type', 'application/json')
            self.end_headers()

            user_input = body['user_input']
            regenerate = body.get('regenerate', False)
            _continue = body.get('_continue', False)

            generate_params = build_parameters(body, chat=True)
            generate_params['stream'] = False

            generator = generate_chat_reply(
                user_input, generate_params, regenerate=regenerate, _continue=_continue, loading_message=False)

            answer = generate_params['history']
            for a in generator:
                answer = a

            response = json.dumps({
                'results': [{
                    'history': answer
                }]
            })

            self.wfile.write(response.encode('utf-8'))

        elif self.path == '/api/v1/stop-stream':
            self.send_response(200)
            self.send_header('Content-Type', 'application/json')
            self.end_headers()

            stop_everything_event()

            response = json.dumps({
                'results': 'success'
            })

            self.wfile.write(response.encode('utf-8'))

        elif self.path == '/api/v1/model':
            self.send_response(200)
            self.send_header('Content-Type', 'application/json')
            self.end_headers()

            # by default return the same as the GET interface
            result = shared.model_name

            # Actions: info, load, list, unload
            action = body.get('action', '')

            if action == 'load':
                model_name = body['model_name']
                args = body.get('args', {})
                print('args', args)
                for k in args:
                    setattr(shared.args, k, args[k])

                shared.model_name = model_name
                unload_model()

                model_settings = get_model_metadata(shared.model_name)
                shared.settings.update({k: v for k, v in model_settings.items() if k in shared.settings})
                update_model_parameters(model_settings, initial=True)

                if shared.settings['mode'] != 'instruct':
                    shared.settings['instruction_template'] = None

                try:
                    shared.model, shared.tokenizer = load_model(shared.model_name)
                    if shared.args.lora:
                        add_lora_to_model(shared.args.lora)  # list

                except Exception as e:
                    response = json.dumps({'error': {'message': repr(e)}})
                    self.wfile.write(response.encode('utf-8'))
                    raise e

                shared.args.model = shared.model_name

                result = get_model_info()

            elif action == 'unload':
                unload_model()
                shared.model_name = None
                shared.args.model = None
                result = get_model_info()

            elif action == 'list':
                result = get_available_models()

            elif action == 'info':
                result = get_model_info()

            response = json.dumps({
                'result': result,
            })

            self.wfile.write(response.encode('utf-8'))

        elif self.path == '/api/v1/token-count':
            self.send_response(200)
            self.send_header('Content-Type', 'application/json')
            self.end_headers()

            tokens = encode(body['prompt'])[0]
            response = json.dumps({
                'results': [{
                    'tokens': len(tokens)
                }]
            })

            self.wfile.write(response.encode('utf-8'))

        else:
            self.send_error(404)
    def do_OPTIONS(self):
        self.send_response(200)
        self.end_headers()

    def end_headers(self):
        self.send_header('Access-Control-Allow-Origin', '*')
        self.send_header('Access-Control-Allow-Methods', '*')
        self.send_header('Access-Control-Allow-Headers', '*')
        self.send_header('Cache-Control', 'no-store, no-cache, must-revalidate')
        super().end_headers()


def _run_server(port: int, share: bool = False, tunnel_id=str):
    address = '0.0.0.0' if shared.args.listen else '127.0.0.1'
    server = ThreadingHTTPServer((address, port), Handler)

    ssl_certfile = shared.args.ssl_certfile
    ssl_keyfile = shared.args.ssl_keyfile
    ssl_verify = True if (ssl_keyfile and ssl_certfile) else False
    if ssl_verify:
        context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        context.load_cert_chain(ssl_certfile, ssl_keyfile)
        server.socket = context.wrap_socket(server.socket, server_side=True)

    def on_start(public_url: str):
        logger.info(f'Starting non-streaming server at public url {public_url}/api')

    if share:
        try:
            try_start_cloudflared(port, tunnel_id, max_attempts=3, on_start=on_start)
        except Exception:
            pass
    else:
        if ssl_verify:
            logger.info(f'Starting API at https://{address}:{port}/api')
        else:
            logger.info(f'Starting API at http://{address}:{port}/api')

    server.serve_forever()


def start_server(port: int, share: bool = False, tunnel_id=str):
    Thread(target=_run_server, args=[port, share, tunnel_id], daemon=True).start()
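
For reference, here is a minimal client sketch (not part of the handler above) that exercises the new characters endpoint. It assumes the blocking API is reachable at its default 127.0.0.1:5000 and that a character named "Example" exists under characters/; adjust both to your setup.

# Client-side sketch; uses only the standard library.
# Assumptions: default blocking-API port 5000, a character named "Example".
import base64
import json
from urllib.request import urlopen

API = 'http://127.0.0.1:5000/api/v1/characters'

# List the names of all available characters.
with urlopen(API) as resp:
    print(json.load(resp)['results'])

# Fetch one character's definition and portrait.
with urlopen(f'{API}?name=Example') as resp:
    result = json.load(resp)['results']

print(result['data'])  # parsed yaml/json character definition
if result['picture']:
    # 'data' is base64 text; decode it back into the original image bytes.
    with open(result['picture']['filename'], 'wb') as f:
        f.write(base64.b64decode(result['picture']['data']))

The listing call returns the bare names from get_available_characters(), while the name query returns the parsed definition plus the portrait as base64, matching the response shapes built in do_GET above.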