Skip to content

Instantly share code, notes, and snippets.

@daveschumaker
Last active March 18, 2024 13:13
Show Gist options
  • Save daveschumaker/0b4358012dc6a2b5c13e186a5ed84491 to your computer and use it in GitHub Desktop.
Save daveschumaker/0b4358012dc6a2b5c13e186a5ed84491 to your computer and use it in GitHub Desktop.
A simple HTTP server for Stable Diffusion that generates images through a RESTful API, built on the lstein / InvokeAI fork of Stable Diffusion.
# pylint: disable=broad-except
# pylint: disable=global-statement
# pylint: disable=invalid-name
# pylint: disable=no-member
# pylint: disable=unsupported-membership-test
"""Creates a proxy server to interface with the stable diffusion engine using http requests"""
import base64
from json import dumps
from io import BytesIO
import random
import threading
import time
from ldm.generate import Generate
from ldm.dream.pngwriter import PngWriter
from flask import Flask, request
import requests
# Flask application that exposes the /ping and /create routes.
app = Flask(__name__)
# Set to True by init_t2i() once the model has loaded; /create refuses
# new jobs while this is still False.
server_online = False
# The ldm Generate instance created by init_t2i(); None until then.
t2i = None
# Settings of the current job (prompt, size, seed, sampler, ...) plus the
# 'is_busy' and 'current_task_id' status flags. The /create route rebinds
# it to a fresh dict for every accepted job.
opt = {}
### STABLE DIFFUSION API
def init_t2i():
    """Load the Stable Diffusion model and mark the server as ready.

    Stores a new ``Generate`` instance in the module-level ``t2i`` and then
    flips ``server_online`` to True so ``/create`` starts accepting jobs.
    """
    global t2i, server_online

    # Generate() is deliberately constructed with no arguments: passing
    # config parameters here (e.g. sampler_name='k_euler_a') kills the
    # author's Linux server, although it works fine on Windows 10 and macOS.
    # Per-image settings such as the sampler are passed to prompt2image
    # later instead. Sampler names the author lists as valid:
    # ddim, k_dpm_2_a, k_dpm_2, k_euler_a, k_euler, k_heun, k_lms, plms.
    model = Generate()
    t2i = model

    print('')
    print('Image server online!')
    server_online = True
def image_writer(image, seed):
    """Save a generated image locally and upload it, with metadata, to the proxy.

    Passed as ``image_callback`` to ``t2i.prompt2image`` by ``create_img``.
    The current job's settings are read from the module-level ``opt`` dict
    that the ``/create`` route fills in.

    Args:
        image: The generated image. It must support
            ``save(fp, format='PNG')`` (presumably a PIL image; confirm
            against the ldm version in use).
        seed: Unused. It is accepted only because prompt2image invokes the
            callback with two arguments; the seed recorded in the metadata
            comes from ``opt['seed']``.

    Returns:
        The JSON string ``{"success": true}``.

    Raises:
        Whatever image encoding or local saving raises. In that case the
        busy flags are still cleared before the exception propagates. Upload
        errors are logged and swallowed.
    """
    try:
        file_writer = PngWriter('./outputs/artbot')
        current_task_id = opt['current_task_id']
        # Use f-strings rather than '+': a missing task id is stored as False,
        # and a client may send a numeric id. Either would make '+' raise
        # TypeError.
        filename = f'{current_task_id}.png'

        # Base64-encode the PNG bytes so they can travel inside a JSON body.
        # Via: https://stackoverflow.com/a/31826470
        buffered = BytesIO()
        image.save(buffered, format='PNG')
        base64_string = base64.b64encode(buffered.getvalue()).decode('utf-8')

        image_info_to_send = {
            'filename': f"{current_task_id}-{opt['seed']}.png",
            'seed': opt['seed'],
            'sampler': opt['sampler'],
            'steps': opt['steps'],
            'cfg_scale': opt['cfg_scale'],
            'prompt': opt['prompt'],
            'height': opt['height'],
            'width': opt['width'],
            'task_id': current_task_id,
            'success': True,
            'encoded_image': base64_string,
        }
        print("image_writer: Image generation done, preparing to upload to proxy!")
        file_writer.save_image_and_prompt_to_png(image, 'an_image', filename)
        try:
            requests.post(opt['upload_api_endpoint'], json=image_info_to_send, timeout=5)
        except Exception as e:
            # The upload is best-effort: log it and still free the server.
            print('image_writer: Uh oh!')
            print(e)
    finally:
        # Always release the job, even if encoding or saving failed.
        # Otherwise every later /create request is rejected as "busy".
        opt['is_busy'] = False
        opt['current_task_id'] = False
    print("image_writer: Done!")
    return dumps({
        'success': True
    })
def create_img():
    """Run Stable Diffusion for the job currently described by ``opt``.

    Every finished image is passed to ``image_writer``, which saves it
    locally and uploads it to the proxy endpoint.

    Returns:
        A JSON string with ``success`` and a ``status`` message.
    """
    generate = t2i.prompt2image
    generation_args = {
        'prompt': opt['prompt'],
        'outdir': "./outputs",
        'image_callback': image_writer,
        'width': int(opt['width']),
        'height': int(opt['height']),
        'seed': int(opt['seed']),
        'sampler_name': opt['sampler'],
        'full_precision': True,
        'steps': int(opt['steps']),
        'cfg_scale': opt['cfg_scale'],
    }
    generate(**generation_args)

    result = {'success': True, 'status': 'Image created...'}
    return dumps(result)
### ROUTES
@app.route('/ping', methods=["GET"])
def ping_route():
    """Report whether the server is busy and which task it is working on.

    Initializes the ``is_busy`` / ``current_task_id`` flags in ``opt`` to
    False if they have never been set.

    Returns:
        A JSON string with ``success``, ``current_task_id`` and ``is_busy``.
        On an unexpected error, it returns ``success: false`` with a
        ``status`` message.
    """
    try:
        opt.setdefault('is_busy', False)
        opt.setdefault('current_task_id', False)
        return dumps({
            "success": True,
            "current_task_id": opt['current_task_id'],
            "is_busy": opt['is_busy']
        })
    except Exception as e:
        print('ping_route: uh oh!')
        print(e)
        # A Flask view must return a response. Falling off the end here
        # (returning None) would make Flask raise and reply with a 500.
        return dumps({
            "success": False,
            "status": "Unable to read server status"
        })
def _parse_create_options(request_data):
    """Build a fresh per-job ``opt`` dict from a ``/create`` JSON body.

    Optional fields that are absent fall back to the defaults below. Fields
    that are present are coerced with ``int``/``float``, so malformed values
    raise ``ValueError`` or ``TypeError`` for the caller to report.
    ``request_data`` must already contain ``upload_api_endpoint``.
    """
    return {
        'current_task_id': request_data.get('task_id', False),
        # min 128 on MacOS otherwise semaphore error
        'height': int(request_data['height']) if 'height' in request_data else 128,
        'width': int(request_data['width']) if 'width' in request_data else 128,
        'sampler': request_data.get('sampler', 'k_euler'),
        'seed': (int(request_data['seed']) if 'seed' in request_data
                 else random.randint(0, 2**32 - 1)),
        'steps': int(request_data['steps']) if 'steps' in request_data else 32,
        'cfg_scale': (float(request_data['cfg_scale']) if 'cfg_scale' in request_data
                      else 12.0),
        'prompt': request_data.get('prompt'),
        'upload_api_endpoint': request_data['upload_api_endpoint'],
    }


@app.route('/create', methods=["POST"])
def proxy_create_image_route():
    """Validate a generation request and run it in this request thread.

    Expects a JSON object with the required ``prompt`` and
    ``upload_api_endpoint`` fields. Optional fields are ``task_id``,
    ``height``, ``width``, ``sampler``, ``seed``, ``steps`` and
    ``cfg_scale``.

    Returns:
        A JSON string. ``success`` is False (with a ``status`` message) when
        the request is invalid, the server is busy or still booting, or
        generation raised. Otherwise it is ``{"success": true}``.
    """
    global opt
    # silent=True: a missing or malformed JSON body yields None, which is
    # reported below as a JSON error instead of an HTML 400/415 page.
    request_data = request.get_json(silent=True)
    if not isinstance(request_data, dict) or 'prompt' not in request_data:
        return dumps({
            "success": False,
            "status": "Missing prompt"
        })
    if 'upload_api_endpoint' not in request_data:
        return dumps({
            "success": False,
            "status": "Missing upload_api_endpoint"
        })
    if opt.get('is_busy') is True:
        return dumps({
            "success": False,
            "current_task_id": opt.get('current_task_id', False),
            "status": "Server is currently busy"
        })
    if server_online is False:
        return dumps({
            "success": False,
            "status": "Image server is still booting up"
        })

    # Parse before touching the global, so a bad request leaves the
    # previous state intact.
    try:
        new_opt = _parse_create_options(request_data)
    except (TypeError, ValueError) as e:
        return dumps({
            "success": False,
            "status": f"Invalid parameter: {e}"
        })

    new_opt['is_busy'] = True
    opt = new_opt

    try:
        print('Generating image...')
        create_img()
    except Exception as e:
        print('proxy_create_image_route: Uh oh!')
        print(e)
        return dumps({
            "success": False,
            "status": "Image generation failed"
        })
    finally:
        # create_img runs synchronously here, so the job is over either way.
        # Clearing the flags covers failures that happen before
        # image_writer runs; otherwise the server would stay "busy" forever.
        opt['is_busy'] = False
        opt['current_task_id'] = False
    return dumps({
        "success": True
    })
if __name__ == '__main__':
    init_t2i()
    # debug=False on purpose. With debug=True Flask would enable the
    # Werkzeug interactive debugger, which allows executing arbitrary Python
    # from the browser, on every interface because of host='0.0.0.0'. It
    # would also enable the reloader, which re-runs this module in a child
    # process and so loads the model a second time.
    app.run(host='0.0.0.0', port=5003, debug=False)
@Teie
Copy link

Teie commented Oct 26, 2022

Hi.

pylint: disable=no-member

Is this some kind of "login" for SD?
I'm a volunteer teaching computer classes, and I'm making my personal machine available so my students can use SD.
For now, only DDNS port forwarding works well, but it's only a matter of time before some rude person ruins everything.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment