Skip to content

Instantly share code, notes, and snippets.

@albertpb
Created October 22, 2023 00:26
Show Gist options
  • Save albertpb/e393cffe17e42e37bc007a99f9175ab5 to your computer and use it in GitHub Desktop.
oobabooga text-generation-webui api GET characters
import json
import ssl
import base64
import yaml
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from threading import Thread
from pathlib import Path
from extensions.api.util import build_parameters, try_start_cloudflared
from modules import shared
from modules.chat import generate_chat_reply
from modules.LoRA import add_lora_to_model
from modules.models import load_model, unload_model
from urllib.parse import urlparse, parse_qs
from modules.models_settings import get_model_metadata, update_model_parameters
from modules.text_generation import (
encode,
generate_reply,
stop_everything_event
)
from modules.utils import (get_available_models,
get_available_characters)
from modules.logging_colors import logger
def get_model_info():
    """Return a snapshot of the loaded model state for API clients.

    Includes the current model name, the active LoRA list, and a full
    dump of the shared settings/args so callers can inspect runtime
    configuration.
    """
    info = {
        'model_name': shared.model_name,
        'lora_names': shared.lora_names,
    }
    # Full configuration dump — useful for debugging from the client side.
    info['shared.settings'] = shared.settings
    info['shared.args'] = vars(shared.args)
    return info
class Handler(BaseHTTPRequestHandler):
    """Non-streaming JSON API handler for text-generation-webui.

    GET routes:
        /api/v1/model        -> name of the currently loaded model
        /api/v1/characters   -> list of characters, or a single character
                                definition (+ avatar) when ``?name=`` is given
    POST routes:
        /api/v1/generate     -> one-shot text completion
        /api/v1/chat         -> one-shot chat reply (returns history)
        /api/v1/stop-stream  -> abort any in-flight generation
        /api/v1/model        -> actions: info / load / list / unload
        /api/v1/token-count  -> token count of a prompt

    All responses are JSON; CORS headers are added in end_headers().
    """

    def _send_json(self, payload, status: int = 200):
        # Serialize *payload* and send it with status + headers in one go.
        # Centralizing this guarantees the status line is written exactly
        # once per request (see the 404 handling in do_GET).
        self.send_response(status)
        self.send_header('Content-Type', 'application/json')
        self.end_headers()
        self.wfile.write(json.dumps(payload).encode('utf-8'))

    def _load_character_picture(self, name):
        # Return the character's avatar as a base64 payload dict, or None
        # when no png/jpg/jpeg file exists for that name.
        for extension in ['png', 'jpg', 'jpeg']:
            filepath = Path(f"characters/{name}.{extension}")
            if filepath.exists():
                # Turn raw base64-encoded bytes into ASCII text.
                # TODO: support 'url' and 'data' / data_url encodings.
                img_data = base64.b64encode(filepath.read_bytes()).decode('ascii')
                return {
                    'filename': f"{name}.{extension}",
                    'encoding': 'b64_json',  # like SD; maybe also accept 'url'/'data_url'
                    'data': img_data,
                }
        return None

    def _load_character_data(self, name):
        # Return the parsed character definition (JSON or YAML), or None
        # when no yml/yaml/json file exists for that name.
        for extension in ["yml", "yaml", "json"]:
            filepath = Path(f'characters/{name}.{extension}')
            if filepath.exists():
                file_contents = filepath.read_text(encoding='utf-8')
                return json.loads(file_contents) if extension == "json" else yaml.safe_load(file_contents)
        return None

    def do_GET(self):
        if self.path == '/api/v1/model':
            # Report the currently loaded model. (Content-Type added for
            # consistency with every other JSON endpoint.)
            self._send_json({'result': shared.model_name})
        elif self.path.startswith('/api/v1/characters'):
            args = parse_qs(urlparse(self.path).query)
            name = args.get('name', None)
            if name:
                name = name[0]
                picture_data = self._load_character_picture(name)
                data = self._load_character_data(name)
                if not (picture_data or data):
                    # BUGFIX: the 200 status and headers were previously sent
                    # *before* this lookup, so send_error() wrote a second
                    # status line into an already-started response. The
                    # status is now only sent once the outcome is known.
                    self.send_error(404, message="Character not found")
                    return
                self._send_json({'results': {
                    'data': data,
                    'picture': picture_data,
                }})
            else:
                # No name given: list every available character.
                self._send_json({'results': get_available_characters()})
        else:
            self.send_error(404)

    def do_POST(self):
        content_length = int(self.headers['Content-Length'])
        body = json.loads(self.rfile.read(content_length).decode('utf-8'))

        if self.path == '/api/v1/generate':
            prompt = body['prompt']
            generate_params = build_parameters(body)
            stopping_strings = generate_params.pop('stopping_strings')
            generate_params['stream'] = False

            generator = generate_reply(
                prompt, generate_params, stopping_strings=stopping_strings, is_chat=False)

            # Drain the generator; only the final chunk is returned since
            # this endpoint is non-streaming.
            answer = ''
            for a in generator:
                answer = a

            self._send_json({'results': [{'text': answer}]})

        elif self.path == '/api/v1/chat':
            user_input = body['user_input']
            regenerate = body.get('regenerate', False)
            _continue = body.get('_continue', False)

            generate_params = build_parameters(body, chat=True)
            generate_params['stream'] = False

            generator = generate_chat_reply(
                user_input, generate_params, regenerate=regenerate, _continue=_continue, loading_message=False)

            # Fall back to the unmodified history if the generator yields nothing.
            answer = generate_params['history']
            for a in generator:
                answer = a

            self._send_json({'results': [{'history': answer}]})

        elif self.path == '/api/v1/stop-stream':
            stop_everything_event()
            self._send_json({'results': 'success'})

        elif self.path == '/api/v1/model':
            # By default return the same as the GET interface.
            result = shared.model_name

            # Actions: info, load, list, unload
            action = body.get('action', '')

            if action == 'load':
                model_name = body['model_name']
                args = body.get('args', {})
                logger.info(f'Loading model with args: {args}')
                for k in args:
                    setattr(shared.args, k, args[k])

                shared.model_name = model_name
                unload_model()

                model_settings = get_model_metadata(shared.model_name)
                shared.settings.update(
                    {k: v for k, v in model_settings.items() if k in shared.settings})
                update_model_parameters(model_settings, initial=True)

                if shared.settings['mode'] != 'instruct':
                    shared.settings['instruction_template'] = None

                try:
                    shared.model, shared.tokenizer = load_model(shared.model_name)
                    if shared.args.lora:
                        add_lora_to_model(shared.args.lora)  # expects a list
                except Exception as e:
                    # Report the failure to the client, then re-raise so the
                    # server logs retain the full traceback.
                    self._send_json({'error': {'message': repr(e)}})
                    raise

                shared.args.model = shared.model_name
                result = get_model_info()

            elif action == 'unload':
                unload_model()
                shared.model_name = None
                shared.args.model = None
                result = get_model_info()

            elif action == 'list':
                result = get_available_models()

            elif action == 'info':
                result = get_model_info()

            self._send_json({'result': result})

        elif self.path == '/api/v1/token-count':
            tokens = encode(body['prompt'])[0]
            self._send_json({'results': [{'tokens': len(tokens)}]})

        else:
            self.send_error(404)

    def do_OPTIONS(self):
        # CORS preflight: an empty 200 is enough — the CORS headers
        # themselves are injected by end_headers().
        self.send_response(200)
        self.end_headers()

    def end_headers(self):
        # Permissive CORS so browser-based frontends on other origins can
        # call the API; disable caching of responses.
        self.send_header('Access-Control-Allow-Origin', '*')
        self.send_header('Access-Control-Allow-Methods', '*')
        self.send_header('Access-Control-Allow-Headers', '*')
        self.send_header('Cache-Control', 'no-store, no-cache, must-revalidate')
        super().end_headers()
def _run_server(port: int, share: bool = False, tunnel_id: str = None):
    """Run the blocking API server (never returns).

    Args:
        port: TCP port to listen on.
        share: when True, expose the server via a cloudflared tunnel
            instead of logging a local URL.
        tunnel_id: optional cloudflared tunnel id, forwarded to
            try_start_cloudflared. (BUGFIX: the default was previously the
            ``str`` type object — ``tunnel_id=str`` — a typo for the
            annotation; it is now ``None``.)
    """
    # Bind publicly only when --listen was requested; otherwise loopback only.
    address = '0.0.0.0' if shared.args.listen else '127.0.0.1'
    server = ThreadingHTTPServer((address, port), Handler)

    # Enable TLS only when both a certificate and a key were provided.
    ssl_certfile = shared.args.ssl_certfile
    ssl_keyfile = shared.args.ssl_keyfile
    ssl_verify = bool(ssl_keyfile and ssl_certfile)
    if ssl_verify:
        context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        context.load_cert_chain(ssl_certfile, ssl_keyfile)
        server.socket = context.wrap_socket(server.socket, server_side=True)

    def on_start(public_url: str):
        logger.info(
            f'Starting non-streaming server at public url {public_url}/api')

    if share:
        # Best effort: a failed tunnel must not take down the application.
        try:
            try_start_cloudflared(
                port, tunnel_id, max_attempts=3, on_start=on_start)
        except Exception:
            pass
    else:
        scheme = 'https' if ssl_verify else 'http'
        logger.info(f'Starting API at {scheme}://{address}:{port}/api')

    server.serve_forever()
def start_server(port: int, share: bool = False, tunnel_id: str = None):
    """Start the API server on a daemon thread and return immediately.

    The daemon flag lets the interpreter exit without waiting for the
    server. (BUGFIX: ``tunnel_id`` previously defaulted to the ``str``
    type object — ``tunnel_id=str`` — a typo for the annotation.)
    """
    Thread(target=_run_server,
           args=[port, share, tunnel_id],
           daemon=True).start()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment