Skip to content

Instantly share code, notes, and snippets.

@danslinky
Last active June 9, 2024 22:47
Show Gist options
  • Save danslinky/3082e14cebb4aa5072abcaa2121e69f5 to your computer and use it in GitHub Desktop.
Save danslinky/3082e14cebb4aa5072abcaa2121e69f5 to your computer and use it in GitHub Desktop.
A text-to-image toy that creates images on a serverless [RunPod](https://www.runpod.io/) endpoint.
# %% [markdown]
# # RunPod Automatic1111 Stable Diffusion
#
# A text-to-image toy that creates images on a serverless
# [RunPod](https://www.runpod.io/) endpoint. Enter a prompt
# in the form, or pass one to the function directly:
#
# ```py
# new("a cow in space")
# ```
#
# Requires `RUNPOD_API_KEY` and `RUNPOD_ENDPOINT_ID` in a .env file.
myprompt = """
Bald topped man, side hair though, a light grey facial hair.
Wearing a suit.
A humble, honest man stands as a witness at the Post Office Horizon IT Inquiry.
He is a decent, respectable individual and a true hero.
The man speaks with sincerity and integrity.
The setting is a formal inquiry room, with an official atmosphere.
He was the Subpostmaster of the Craig-y-Don Post Office, Craig-y-Don, Llandudno, North Wales.
Between 31 March 1998 to 5 November 2003.
The scene conveys a sense of solemn duty and commitment to truth and justice.
It has been a scandal, a disgrace.
Looks like Toby Jones.
"""
import base64
import time
import dotenv
import os
import runpod
from PIL import Image
from IPython.display import display
from IPython.display import Image as DisplayImage
import itertools
from ipywidgets import Label, HTML, Button, Textarea, VBox
import logging
logging.basicConfig(level=logging.INFO)
# Load RunPod credentials from a local .env file into the process environment.
dotenv.load_dotenv('.env')

# Fail fast at import time, naming exactly which credentials are absent.
# RuntimeError subclasses Exception, so existing `except Exception` still works.
_missing_env = [
    name
    for name in ("RUNPOD_API_KEY", "RUNPOD_ENDPOINT_ID")
    if not os.getenv(name)
]
if _missing_env:
    raise RuntimeError(
        "Please set the RUNPOD_API_KEY and RUNPOD_ENDPOINT_ID environment variables. "
        f"Missing: {', '.join(_missing_env)}"
    )

# For local testing against a locally served worker, uncomment:
# runpod.api_url_base = "http://localhost:8000"
runpod.api_key = os.getenv("RUNPOD_API_KEY")
class RunPodRequest:
    """One text-to-image job on the RunPod endpoint named by RUNPOD_ENDPOINT_ID.

    Typical use: ``req = RunPodRequest(prompt); req.run(); req.display_result()``.

    Args:
        prompt: Text sent to the endpoint as ``{"input": {"prompt": prompt}}``.
        poll_interval: Seconds to sleep between status checks while waiting
            (defaults to 1, the previously hard-coded value).
    """

    # Job states after which the status no longer changes. Waiting only for
    # COMPLETED/FAILED would poll forever once a job is CANCELLED or TIMED_OUT.
    TERMINAL_STATUSES = ("COMPLETED", "FAILED", "CANCELLED", "TIMED_OUT")

    def __init__(self, prompt, poll_interval=1):
        self.endpoint = runpod.Endpoint(os.getenv("RUNPOD_ENDPOINT_ID"))
        self.prompt = prompt
        self.poll_interval = poll_interval
        self.run_request = None  # job handle returned by endpoint.run()
        self.status = None  # last status string reported for the job
        self.output = None  # job output; only fetched once the job is COMPLETED

    def run(self):
        """Submit the prompt, then block until the job reaches a terminal status.

        Raises:
            Exception: Whatever submission raises (logged with traceback, re-raised).
            RuntimeError: If the job ends in any status other than COMPLETED.
        """
        try:
            self.run_request = self.endpoint.run({"input": {"prompt": self.prompt}})
        except Exception:
            logging.exception("Submitting the RunPod request failed")
            raise
        self.poll_status()

    def get_output(self):
        """Fetch the job's output from the job handle into ``self.output``."""
        self.output = self.run_request.output()

    def poll_status(self):
        """Poll the job until it reaches a terminal status, showing a text spinner.

        The spinner is removed on every exit path (success, failure, interrupt).

        Raises:
            RuntimeError: If the job ends FAILED, CANCELLED or TIMED_OUT.
        """
        spinner = itertools.cycle(['*', '**', '***'])
        spinner_label = Label()
        display(spinner_label)
        try:
            self.status = self.run_request.status()
            while self.status not in self.TERMINAL_STATUSES:
                spinner_label.value = next(spinner)
                time.sleep(self.poll_interval)
                self.status = self.run_request.status()
        finally:
            spinner_label.close()

        if self.status == "COMPLETED":
            self.get_output()
        elif self.status == "FAILED":
            raise RuntimeError("RunPod request failed")
        else:
            raise RuntimeError(f"RunPod request ended with status {self.status}")

    def display_result(self):
        """Show the first generated image, or an HTML message explaining why there is none."""
        import html  # only needed to escape service-provided text rendered as HTML

        result_label = HTML()
        display(result_label)
        if self.status == "COMPLETED":
            images = self.decode_images()
            if images:
                display(DisplayImage(images[0]))
            else:
                # Completed but no images: echo the (escaped) output for debugging.
                result_label.value = (
                    "<p style='white-space: pre-wrap;'>The request returned no images. "
                    f"{html.escape(str(self.output))}</p>"
                )
        elif self.status == "FAILED":
            result_label.value = (
                "<p style='white-space: pre-wrap;'>The request failed. "
                f"{html.escape(str(self.output))}</p>"
            )
        elif self.status is not None:
            result_label.value = (
                "<p style='white-space: pre-wrap;'>The request ended with status "
                f"{html.escape(str(self.status))}.</p>"
            )

    def get_images(self):
        """Return the base64-encoded images from the output, or [] if there are none.

        Returns [] when the output is missing, is not a dict, or has no/empty
        'images' entry, so callers can test emptiness instead of catching errors.
        """
        if not isinstance(self.output, dict):
            return []
        return self.output.get('images') or []

    def decode_images(self):
        """Decode every base64 image in the output into raw bytes."""
        return [base64.b64decode(image) for image in self.get_images()]
# Prompt editor, pre-filled with the default prompt minus surrounding whitespace.
text = Textarea(
    value=myprompt.strip(),
    placeholder='Enter your prompt here',
    layout=dict(width='auto', height='auto'),
    rows=10,
)

# Submit control for the prompt editor.
button = Button(
    description="Submit",
)
def new(myprompt=None):
    """Generate and display one image.

    Args:
        myprompt: Prompt text to send; when omitted (None), the current contents
            of the form's text box are used instead.
    """
    prompt = text.value if myprompt is None else myprompt
    job = RunPodRequest(prompt)
    job.run()
    job.display_result()
def on_button_clicked(b):
    """Click handler: submit whatever prompt is currently in the text box.

    Args:
        b: The clicked button widget (supplied by ipywidgets; unused).
    """
    new()


button.on_click(on_button_clicked)

# Render the prompt editor with its Submit button stacked underneath.
display(VBox(children=[text, button]))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment