Skip to content

Instantly share code, notes, and snippets.

@python273
Created July 27, 2023 17:20
Show Gist options
  • Save python273/ae9d085ce9f2968b50c6ab90f2017076 to your computer and use it in GitHub Desktop.
Save python273/ae9d085ce9f2968b50c6ab90f2017076 to your computer and use it in GitHub Desktop.
Not really optimal: image generation blocks the event loop.
import asyncio
import torch
from torch import autocast
# from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler
from diffusers import LMSDiscreteScheduler
from my_gen_pipeline import StableDiffusionPipeline
from datetime import datetime
from conf import UNIX_SOCKET_PATH
# torch.cuda.empty_cache()
# Shared server state, filled in by main(): the diffusion pipeline, and the
# asyncio.Lock that handle_connection holds around each generation. Both remain
# None until main() has run.
pipe = lock = None
def _generate_image(prompt):
    """Run the diffusion pipeline synchronously and return the first sample image.

    Intended to run in a worker thread. Autocast is entered here, inside the
    thread that calls ``pipe.gen``, because autocast state is per-thread.
    """
    with autocast("cuda"):
        return pipe.gen(
            prompt,
            height=640,
            width=640,
            guidance_scale=9,
            num_inference_steps=65
        )["sample"][0]


async def handle_connection(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
    """Serve one client: read a prompt line, generate an image, reply with its path.

    The client sends one UTF-8 prompt terminated by ``\\n``. The server saves
    ``images/img-<utc-timestamp>.png`` plus a ``.txt`` file holding the prompt,
    then writes the PNG path followed by ``\\n`` back to the client. A client
    that disconnects before sending a complete line gets no reply. The
    connection is closed on every path, including errors.
    """
    try:
        prompt = await reader.readline()
        if not prompt.endswith(b'\n'):
            # EOF before a newline: the client went away mid-request.
            return
        print()
        print(f'prompt: {prompt!r}')
        prompt = prompt.decode('utf-8').strip()
        print(f'decoded {prompt!r}')
        # One generation at a time on the shared pipeline. Serialising here also
        # keeps the timestamp-based filenames unique.
        async with lock:
            filename = f"images/img-{datetime.utcnow().isoformat().replace('/', '_')}.png"
            prompt_filename = filename + '.txt'
            # Run the blocking pipeline off the event loop so other clients can
            # still connect (and queue on the lock) while an image is generated.
            loop = asyncio.get_running_loop()
            image = await loop.run_in_executor(None, _generate_image, prompt)
            # Explicit encoding: prompts may contain non-ASCII text.
            with open(prompt_filename, 'w', encoding='utf-8') as f:
                f.write(prompt)
            image.save(filename)
        print('sending', filename)
        writer.write(f'{filename}\n'.encode('ascii'))
        await writer.drain()
    finally:
        writer.close()
        try:
            await writer.wait_closed()
        except OSError:
            # Peer already reset the connection; nothing left to clean up.
            pass
# Strong references to in-flight connection handlers. The event loop only keeps
# weak references to tasks, so a handler not stored here could be garbage-
# collected before it finishes.
background_tasks = set()


def create_connection_task(*args, **kwargs):
    """Start ``handle_connection`` for a new client as a tracked task.

    Used as the server's client-connected callback. The task stays in
    ``background_tasks`` until it completes, when the done-callback removes it.
    """
    coro = handle_connection(*args, **kwargs)
    tracked = asyncio.create_task(coro)
    background_tasks.add(tracked)
    tracked.add_done_callback(background_tasks.discard)
async def main():
    """Load the Stable Diffusion pipeline onto the GPU and serve prompts over a Unix socket.

    Initialises the module-level ``pipe`` and ``lock`` read by the connection
    handler, then serves on ``UNIX_SOCKET_PATH`` until cancelled.
    """
    import os

    global pipe, lock
    # Created inside main() so the lock belongs to the loop started by asyncio.run().
    lock = asyncio.Lock()
    # Generated images and their prompt files are written under images/.
    # Create it up front so requests don't fail with FileNotFoundError.
    os.makedirs('images', exist_ok=True)
    # Optional alternative scheduler (also uncomment `scheduler=lms` below):
    # lms = LMSDiscreteScheduler(
    # beta_start=0.00085,
    # beta_end=0.012,
    # beta_schedule="scaled_linear"
    # )
    pipe = StableDiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",
        # scheduler=lms,
        revision="fp16", torch_dtype=torch.float16,
        # NOTE(review): " TOKEN " is a placeholder credential; supply a real
        # Hugging Face token, ideally loaded from the environment, not source.
        use_auth_token=" TOKEN "
    )
    pipe = pipe.to('cuda')
    server = await asyncio.start_unix_server(
        create_connection_task, UNIX_SOCKET_PATH
    )
    print(server)
    async with server:
        await server.serve_forever()
if __name__ == '__main__':
    # Run the server with asyncio debug mode on. In this mode asyncio logs
    # callbacks that hold the loop too long, among other diagnostics.
    entry = main()
    asyncio.run(entry, debug=True)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment