Skip to content

Instantly share code, notes, and snippets.

@jonobr1
Created January 6, 2023 03:58
Show Gist options
  • Save jonobr1/284dab4b8241a3ca22c5fd10a2143a56 to your computer and use it in GitHub Desktop.
Save jonobr1/284dab4b8241a3ca22c5fd10a2143a56 to your computer and use it in GitHub Desktop.
runtime: python37
service: prototype

env_variables:
  STABILITY_HOST: "grpc.stability.ai:443"
  # SECURITY: never commit the real API key. The key previously written here
  # was published and must be revoked in the Stability dashboard. Inject the
  # real value at deploy time (e.g. from a git-ignored file pulled in with
  # `includes:`, or from Secret Manager) instead of editing this line.
  STABILITY_KEY: "REPLACE_AT_DEPLOY_TIME"
  CLOUD_STORAGE_BUCKET: "cdn.rifff.com"

handlers:
# API routes are handled by the Python app; HTTPS and an admin login are required.
- url: /api/.*
  script: auto
  secure: always
  login: admin

# Static assets, matched by file extension and served straight from public/.
# NOTE(review): this handler sets neither `secure` nor `login`, and handlers
# are matched in order, so a path such as /index.html is served here without
# the admin login the last handler enforces — confirm that is intended.
- url: /(.*\.(gif|png|jpg|jpeg|ico|css|map|json|js|eot|svg|ttf|woff|woff2|ogg|mp3|wav|mp4|webm|xml|html|fbx|gltf))$
  static_files: public/\1
  upload: .*\.(gif|png|jpg|jpeg|ico|css|map|json|js|eot|svg|ttf|woff|woff2|ogg|mp3|wav|mp4|webm|xml|html|fbx|gltf)$

# Every other path serves public/index.html, over HTTPS, to admins only.
- url: /.*
  static_files: public/index.html
  upload: public/index.html
  secure: always
  login: admin
import mimetypes
import os
import uuid

from flask import Flask
from flask_cors import CORS
from google.cloud import storage
from stability_sdk import client
import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation
# Lower-case file extensions that allowed() accepts.
IMAGE_EXTENSIONS = {'gif', 'jpeg', 'jpg', 'png'}

# Cloud Storage client plus the bucket named by the CLOUD_STORAGE_BUCKET
# environment variable; upload() writes its objects into this bucket.
cs = storage.Client()
bucket = cs.bucket(bucket_name=os.environ['CLOUD_STORAGE_BUCKET'])
def allowed(filename: str) -> bool:
    """Return True if *filename* ends in an extension from IMAGE_EXTENSIONS.

    Only the text after the last dot is considered, and it is lower-cased
    before the lookup. A name with no dot at all is rejected.
    """
    if '.' not in filename:
        return False
    extension = filename.rsplit('.', 1)[-1]
    return extension.lower() in IMAGE_EXTENSIONS
def upload(src: str, contents, content_type=None):
    """Store *contents* in the configured bucket under the object name *src*.

    Args:
        src: Object name inside the bucket, e.g. ``'prompts/<uuid>.png'``.
        contents: Data passed straight to ``Blob.upload_from_string``.
        content_type: MIME type recorded on the object. When omitted it is
            guessed from the extension of *src*, falling back to
            ``'application/octet-stream'`` when no guess is possible.

    Returns:
        ``{'src': url}`` where ``url`` is the object's
        ``https://storage.googleapis.com`` URL (readable by others only if the
        bucket's permissions allow it).
    """
    if content_type is None:
        # upload_from_string() would otherwise label every object 'text/plain',
        # so generated PNGs would be served with the wrong Content-Type.
        content_type = mimetypes.guess_type(src)[0] or 'application/octet-stream'
    blob = bucket.blob(src)
    blob.upload_from_string(contents, content_type=content_type)
    return { 'src': 'https://storage.googleapis.com/{}/{}'.format(os.environ['CLOUD_STORAGE_BUCKET'], src) }
#
app = Flask(__name__)
# Enable CORS on every route using flask_cors' default settings.
CORS(app)

# Client for the Stability generation API. The endpoint and the key both come
# from env_variables in app.yaml; the host fallback matches the value set there.
stability = client.StabilityInference(
    host=os.environ.get('STABILITY_HOST', 'grpc.stability.ai:443'),
    key=os.environ['STABILITY_KEY'],
    verbose=True,
)
#
@app.route('/api/prompt/<prompt>', methods=['GET'])
def on_prompt(prompt=None):
    """Generate image(s) for *prompt* with Stability and upload them to the bucket.

    Returns:
        ``{'results': [{'src': url}, ...]}`` on success, or ``{'error': message}``
        when the prompt is missing/blank or an artifact was rejected by the
        Stability safety filter.
    """
    # Flask always passes a str for <prompt>; the None default only matters when
    # the view is called directly. A whitespace-only prompt counts as missing.
    if not isinstance(prompt, str) or not prompt.strip():
        return { 'error': 'API Error: prompt does not exist.' }

    answers = stability.generate(
        prompt=prompt,
        seed=34567,  # a fixed seed makes results deterministic for a given prompt
        steps=30,    # defaults to 50 if not specified
    )

    # Collect every image before uploading anything, so a filtered artifact
    # arriving after an earlier image cannot leave orphaned objects in the bucket.
    images = []
    for resp in answers:
        for artifact in resp.artifacts:
            if artifact.finish_reason == generation.FILTER:
                return { 'error': 'API Error: Stable Diffusion safety filters. Please modify the prompt and try again.' }
            if artifact.type == generation.ARTIFACT_IMAGE:
                images.append(artifact.binary)

    results = [
        upload('prompts/{}.png'.format(uuid.uuid4()), image)
        for image in images
    ]
    return { 'results': results }
if __name__ == '__main__':
    # Local development only: run Flask's built-in server with the debugger on.
    # On App Engine a production web server such as Gunicorn serves `app`
    # instead; its startup can be customised via an `entrypoint` in app.yaml.
    app.run(debug=True, host='127.0.0.1', port=3000)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment