Last active
March 18, 2024 13:13
-
-
Save daveschumaker/0b4358012dc6a2b5c13e186a5ed84491 to your computer and use it in GitHub Desktop.
A simple HTTP server that generates images through a RESTful API, built on the lstein / InvokeAI fork of Stable Diffusion.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# pylint: disable=broad-except | |
# pylint: disable=global-statement | |
# pylint: disable=invalid-name | |
# pylint: disable=no-member | |
# pylint: disable=unsupported-membership-test | |
"""Creates a proxy server to interface with the stable diffusion engine using http requests""" | |
import base64 | |
from json import dumps | |
from io import BytesIO | |
import random | |
import threading | |
import time | |
from ldm.generate import Generate | |
from ldm.dream.pngwriter import PngWriter | |
from flask import Flask, request | |
import requests | |
# Flask application that serves the /ping and /create routes.
app = Flask(__name__)

# Flipped to True by init_t2i() once the model has been loaded.
server_online: bool = False

# ldm Generate instance; stays None until init_t2i() runs.
t2i = None

# Settings of the current (or most recent) generation job; /create rebinds it per request.
opt: dict = {}
### STABLE DIFFUSION API | |
def init_t2i():
    """Load the Stable Diffusion model and mark the image server as ready.

    Stores a default-configured ``Generate`` instance in the module-level ``t2i``
    and sets ``server_online`` to True, after which /create accepts jobs.
    """
    global t2i, server_online

    # Generate() is intentionally constructed without arguments: passing
    # configuration here crashed the original author's Linux server (it worked on
    # Windows 10 and macOS). Per-image settings such as the sampler are passed
    # later, to prompt2image(), instead.
    # Sampler choices listed as valid by the original author:
    #   ddim, k_dpm_2_a, k_dpm_2, k_euler_a, k_euler, k_heun, k_lms, plms
    t2i = Generate()

    for line in ('', 'Image server online!'):
        print(line)
    server_online = True
def image_writer(image, seed):
    """Save a finished image locally and upload it, with its settings, to the proxy.

    Passed as ``image_callback`` to ``t2i.prompt2image`` by ``create_img``. Job
    settings are read from the module-level ``opt`` dict. However this function
    exits, it clears ``opt['is_busy']`` and ``opt['current_task_id']`` so a
    failure here cannot leave the server stuck in the busy state.

    Args:
        image: The generated image; must support ``save(fp, format='PNG')``
            (presumably a PIL image -- TODO confirm against ldm's callback).
        seed: Unused. Accepted only because prompt2image passes it to the
            callback; the seed reported upstream is ``opt['seed']``.

    Returns:
        The JSON string ``{"success": true}``.
    """
    try:
        file_writer = PngWriter('./outputs/artbot')
        # task_id is False when the request omitted it and may be a JSON number;
        # f-strings format both, where ``+`` concatenation raised TypeError.
        current_task_id = opt['current_task_id']
        filename = f'{current_task_id}.png'

        # Encode the PNG as base64 text so it can travel inside a JSON body.
        # Via: https://stackoverflow.com/a/31826470
        buffered = BytesIO()
        image.save(buffered, format='PNG')
        base64_string = base64.b64encode(buffered.getvalue()).decode('utf-8')

        image_info_to_send = {
            'filename': f"{current_task_id}-{opt['seed']}.png",
            'seed': opt['seed'],
            'sampler': opt['sampler'],
            'steps': opt['steps'],
            'cfg_scale': opt['cfg_scale'],
            'prompt': opt['prompt'],
            'height': opt['height'],
            'width': opt['width'],
            'task_id': current_task_id,
            'success': True,
            'encoded_image': base64_string
        }
        print("image_writer: Image generation done, imammmm to upload proxy!")
        file_writer.save_image_and_prompt_to_png(image, 'an_image', filename)

        # Upload is best-effort: failures are logged, never raised to the caller.
        try:
            response = requests.post(
                opt['upload_api_endpoint'], json=image_info_to_send, timeout=5)
            # requests does not raise on 4xx/5xx by itself; surface those in the log too.
            response.raise_for_status()
        except Exception as e:
            print('image_writer: Uh oh!')
            print(e)
    finally:
        # Release the job even if encoding/saving raised, so /create is not blocked forever.
        opt['is_busy'] = False
        opt['current_task_id'] = False

    print("image_writer: Done!")
    return dumps({
        'success': True
    })
def create_img():
    """Run one Stable Diffusion generation with the settings held in ``opt``.

    Finished images are handed to ``image_writer`` through prompt2image's
    ``image_callback``; prompt2image's own return value is not used.

    Returns:
        A JSON string reporting success once prompt2image returns.
    """
    # Bind the callee before reading opt, matching the evaluation order of a direct call.
    generate = t2i.prompt2image
    generation_kwargs = {
        'prompt': opt['prompt'],
        'outdir': "./outputs",
        'image_callback': image_writer,
        'width': int(opt['width']),
        'height': int(opt['height']),
        'seed': int(opt['seed']),
        'sampler_name': opt['sampler'],
        'full_precision': True,
        'steps': int(opt['steps']),
        'cfg_scale': opt['cfg_scale'],
    }
    generate(**generation_kwargs)

    result = {'success': True, 'status': 'Image created...'}
    return dumps(result)
### ROUTES | |
@app.route('/ping', methods=["GET"])
def ping_route():
    """Report whether the server is busy and which task it is working on.

    Returns:
        JSON string with ``success``, ``current_task_id`` (False when no job is
        running or the job had no task id) and ``is_busy``. On an unexpected
        error, ``success`` is False instead.
    """
    try:
        # opt starts out empty; seed the keys the response reads.
        opt.setdefault('is_busy', False)
        opt.setdefault('current_task_id', False)
        return dumps({
            "success": True,
            "current_task_id": opt['current_task_id'],
            "is_busy": opt['is_busy']
        })
    except Exception as e:
        print('ping_route: uh oh!')
        print(e)
        # A view returning None makes Flask raise, so answer explicitly instead.
        return dumps({
            "success": False,
            "status": "Unable to read server status"
        })
def _parse_generation_options(request_data):
    """Build a fresh job-options dict from a /create request body.

    Missing fields fall back to defaults: task id False, 128x128, sampler
    'k_euler', a random 32-bit seed, 32 steps and cfg_scale 12.0.

    Raises:
        ValueError, TypeError: a numeric field could not be converted.
    """
    def _int_or(key, default):
        # Convert only when present, so absent fields never hit int().
        return int(request_data[key]) if key in request_data else default

    return {
        'current_task_id': request_data.get('task_id', False),
        # Minimum 128 on macOS, otherwise a semaphore error occurs.
        'height': _int_or('height', 128),
        'width': _int_or('width', 128),
        'sampler': request_data.get('sampler', 'k_euler'),
        # Draw a random seed only when the client did not supply one.
        'seed': (int(request_data['seed']) if 'seed' in request_data
                 else random.randint(0, 2**32 - 1)),
        'steps': _int_or('steps', 32),
        'cfg_scale': (float(request_data['cfg_scale']) if 'cfg_scale' in request_data
                      else 12.0),
        'prompt': request_data.get('prompt'),
        'upload_api_endpoint': request_data['upload_api_endpoint'],
        'is_busy': False,
    }


@app.route('/create', methods=["POST"])
def proxy_create_image_route():
    """Validate a generation request and run it synchronously.

    Expects a JSON object containing ``prompt`` and ``upload_api_endpoint``.
    Optional fields are ``task_id``, ``height``, ``width``, ``sampler``,
    ``seed``, ``steps`` and ``cfg_scale``. The image itself is delivered by
    ``image_writer`` to ``upload_api_endpoint``; this route only reports
    whether generation ran.

    Returns:
        JSON string. ``success`` is False in any of these cases:
        a required field is missing, a numeric field is invalid, the server is
        busy or still loading the model, or generation raised.
    """
    global opt
    # silent=True: a missing or non-JSON body yields None instead of an HTTP
    # error, so the client gets the same "Missing prompt" answer either way.
    request_data = request.get_json(silent=True)
    if not isinstance(request_data, dict) or 'prompt' not in request_data:
        return dumps({
            "success": False,
            "status": "Missing prompt"
        })
    if 'upload_api_endpoint' not in request_data:
        return dumps({
            "success": False,
            "status": "Missing upload_api_endpoint"
        })
    # NOTE(review): this check and the ``opt['is_busy'] = True`` below are not
    # atomic; two simultaneous requests on a threaded server could both pass.
    # A threading.Lock would close that window.
    if opt.get('is_busy') is True:
        return dumps({
            "success": False,
            "current_task_id": opt.get('current_task_id', False),
            "status": "Server is currently busy"
        })
    if server_online is False:
        return dumps({
            "success": False,
            "status": "Image server is still booting up"
        })

    try:
        new_opt = _parse_generation_options(request_data)
    except (TypeError, ValueError) as e:
        return dumps({
            "success": False,
            "status": f"Invalid parameter: {e}"
        })

    # Replace (not update) the shared options so nothing leaks from the previous job.
    opt = new_opt
    opt['is_busy'] = True
    try:
        print('Generating image...')
        create_img()
    except Exception as e:
        print('proxy_create_image_route: Uh oh!')
        print(e)
        return dumps({
            "success": False,
            "status": "Image generation failed"
        })
    finally:
        # Always release the busy flag; a failed generation used to leave the
        # server reporting "busy" forever.
        opt['is_busy'] = False
    return dumps({
        "success": True
    })
if __name__ == '__main__':
    init_t2i()
    # debug must stay off. The server listens on all interfaces, and the
    # Werkzeug debugger lets any client that reaches it execute arbitrary code.
    # Debug mode would also enable the reloader, which re-runs this script in a
    # child process and so loads the model a second time.
    app.run(host='0.0.0.0', port=5003, debug=False)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Hi.
pylint: disable=no-member
Is it some kind of "login" for SD?
I'm a volunteer teaching computer classes, and I'm making my personal machine available so students can use SD.
For now, only DDNS with port forwarding is working well, but it's only a matter of time before some ill-intentioned person ruins everything.