Skip to content

Instantly share code, notes, and snippets.

@cgcardona
Last active September 25, 2023 02:03
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save cgcardona/2709914b11a7376fcc134000c0e92505 to your computer and use it in GitHub Desktop.
Save cgcardona/2709914b11a7376fcc134000c0e92505 to your computer and use it in GitHub Desktop.
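Flask app that serves GET requests at `/`. Pass a Stable Diffusion prompt in the `prompt` query parameter (e.g. `/?prompt=my creative and expressive stable diffusion prompt`); the generated image is returned as a PNG and a copy is saved to disk.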
# all the imports
# standard library
import io
import math
import os
import re
import time

# third-party
import torch
from diffusers import StableDiffusionPipeline
from flask import Flask, request, send_file
from torch import autocast
# create a new flask app
app = Flask(__name__)

# Fail fast at start-up if PyTorch cannot use a CUDA device. An explicit
# raise is used instead of `assert`, which is stripped when Python runs
# with -O and would let the app start and fail later on `.to("cuda")`.
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is not available: this app requires a CUDA-capable GPU")

# Model checkpoints:
#   Stable Diffusion v1.4: CompVis/stable-diffusion-v1-4
#   Stable Diffusion v1.5: runwayml/stable-diffusion-v1-5
# The pipeline is loaded once at import time, moved to the GPU, and shared
# by every request handled by this process.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", use_auth_token=True
).to("cuda")
def _filename_prefix(prompt, max_chars=20):
    """Return a filesystem-safe slug built from the start of *prompt*.

    Only the first ``max_chars`` characters are used. Whitespace becomes
    ``-`` and every character that is not a word character or ``-`` is
    dropped. The result can therefore never contain a path separator,
    ``..`` or a leading ``/``; a leading ``/`` would make ``os.path.join``
    discard the parent directory. Falls back to ``"image"`` when nothing
    survives sanitizing.
    """
    head = prompt[:max_chars]
    # Replace whitespace first so "a cat, sitting" still yields "a-cat-sitting".
    slug = re.sub(r"\s", "-", head)
    slug = re.sub(r"[^\w-]", "", slug)
    return slug or "image"


def _create_unique_dir(parent_dir, name, mode):
    """Create a new directory ``parent_dir/name`` and return ``(path, name)``.

    If that name is already taken (for example two identical prompts in the
    same second), ``-1``, ``-2``, ... is appended until an unused name is
    found, instead of failing with ``FileExistsError``. ``parent_dir`` itself
    is created if it does not exist yet.
    """
    os.makedirs(parent_dir, exist_ok=True)
    candidate = name
    suffix = 0
    while True:
        path = os.path.join(parent_dir, candidate)
        try:
            os.mkdir(path, mode)
            return path, candidate
        except FileExistsError:
            suffix += 1
            candidate = f"{name}-{suffix}"


def run_inference(prompt, parent_dir="/path/to/generated-assets"):
    """Generate an image for *prompt*, save a copy to disk and return it.

    The image is rendered by the module-level ``pipe`` under CUDA autocast
    and encoded as PNG into an in-memory buffer. The same PNG bytes are then
    written to ``<parent_dir>/<title>/<title>.png``, where ``title`` is a
    sanitized prefix of the prompt followed by a Unix timestamp (plus a
    ``-N`` suffix if that name is already taken).

    Args:
        prompt: Text prompt for the Stable Diffusion pipeline. It arrives
            straight from the HTTP query string, so it is treated as
            untrusted when building file names.
        parent_dir: Directory in which a new per-request directory is
            created. Defaults to the original placeholder path.

    Returns:
        An ``io.BytesIO`` holding the PNG data, rewound to position 0.

    Raises:
        OSError: if the output directory or image file cannot be created.
    """
    slug = _filename_prefix(prompt)

    with autocast("cuda"):
        image = pipe(prompt).images[0]

    img_data = io.BytesIO()
    image.save(img_data, "PNG")
    img_data.seek(0)

    # Each run gets its own directory named "<slug>-<timestamp>"; the image
    # inside it shares that name.
    timestamp = math.ceil(time.time())
    title = f"{slug}-{timestamp}"
    # rwxr--r-- for the owner/group/others, further reduced by the process umask
    mode = 0o744
    directory, title = _create_unique_dir(parent_dir, title, mode)

    file_path_and_name = os.path.join(directory, f"{title}.png")
    # Reuse the already-encoded PNG bytes instead of encoding the image again.
    with open(file_path_and_name, "wb") as out:
        out.write(img_data.getvalue())
    print(f"{file_path_and_name} created!")

    # success
    # ✨ 😎 ✨
    sparkle = "\U00002728"
    sunglasses = "\U0001F60E"
    print(f"{sparkle} {sunglasses} {sparkle}")
    print("Winning")
    return img_data
@app.route('/')
def myapp():
    """Serve ``/?prompt=...`` by returning a freshly generated PNG image.

    Returns:
        The PNG produced by ``run_inference`` with mimetype ``image/png``,
        or a ``400`` response with an error message when the ``prompt``
        query parameter is missing, empty, or whitespace only.
    """
    # prompt gets passed in as a query string parameter
    prompt = request.args.get("prompt", "")
    # A bare `?prompt=` would otherwise start a full GPU inference run on an
    # empty prompt.
    if not prompt.strip():
        return "Please specify a prompt parameter", 400
    img_data = run_inference(prompt)
    return send_file(img_data, mimetype='image/png')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment