@priamai
Created April 25, 2024 16:13
# ## Basic setup
import io
from pathlib import Path
from fastapi import FastAPI, Request
import modal
from modal import (
    App,
    Image,
    Mount,
    asgi_app,
    build,
    enter,
    gpu,
    method,
    web_endpoint,
)
# ## Define a container image
#
# To take advantage of Modal's blazing fast cold-start times, we'll need to download our model weights
# inside our container image with a download function. We ignore binaries, ONNX weights and 32-bit weights.
#
# Tip: avoid using global variables in this function to ensure the download step detects model changes and
# triggers a rebuild.
sdxl_image = (
    Image.debian_slim(python_version="3.10")
    .apt_install(
        "libglib2.0-0", "libsm6", "libxrender1", "libxext6", "ffmpeg", "libgl1"
    )
    .pip_install(
        "diffusers==0.26.3",
        "invisible_watermark==0.2.0",
        "transformers~=4.38.2",
        "accelerate==0.27.2",
        "safetensors==0.4.2",
    )
)
app = App("sd-xl-api")
with sdxl_image.imports():
    import torch
    from diffusers import DiffusionPipeline
    from fastapi import Response
# ## Load model and run inference
#
# The container lifecycle [`@enter` decorator](https://modal.com/docs/guide/lifecycle-functions#container-lifecycle-beta)
# loads the model at startup. Then, we generate images in the `inference` method defined below.
#
# To avoid excessive cold-starts, we set the idle timeout to 240 seconds, meaning once a GPU has loaded the model it will stay
# online for 4 minutes before spinning down. This can be adjusted for cost/experience trade-offs.
@app.cls(gpu=gpu.A10G(), container_idle_timeout=240, image=sdxl_image)
class Model:
    @build()
    def build(self):
        from huggingface_hub import snapshot_download

        ignore = [
            "*.bin",
            "*.onnx_data",
            "*/diffusion_pytorch_model.safetensors",
        ]
        snapshot_download(
            "stabilityai/stable-diffusion-xl-base-1.0", ignore_patterns=ignore
        )
        snapshot_download(
            "stabilityai/stable-diffusion-xl-refiner-1.0",
            ignore_patterns=ignore,
        )
    @enter()
    def enter(self):
        load_options = dict(
            torch_dtype=torch.float16,
            use_safetensors=True,
            variant="fp16",
            device_map="auto",
        )

        # Load base model
        self.base = DiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0", **load_options
        )

        # Load refiner model
        self.refiner = DiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-refiner-1.0",
            text_encoder_2=self.base.text_encoder_2,
            vae=self.base.vae,
            **load_options,
        )

        # Model graph compilation is JIT, so the first inference will be slower,
        # but subsequent runs will be faster. Uncomment to enable.
        # self.base.unet = torch.compile(self.base.unet, mode="reduce-overhead", fullgraph=True)
        # self.refiner.unet = torch.compile(self.refiner.unet, mode="reduce-overhead", fullgraph=True)
    def _inference(self, prompt, n_steps=24, high_noise_frac=0.8):
        negative_prompt = "disfigured, ugly, deformed"
        image = self.base(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=n_steps,
            denoising_end=high_noise_frac,
            output_type="latent",
        ).images
        image = self.refiner(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=n_steps,
            denoising_start=high_noise_frac,
            image=image,
        ).images[0]

        byte_stream = io.BytesIO()
        image.save(byte_stream, format="JPEG")
        return byte_stream

    @method()
    def inference(self, prompt, n_steps=24, high_noise_frac=0.8):
        return self._inference(
            prompt, n_steps=n_steps, high_noise_frac=high_noise_frac
        ).getvalue()
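# A hypothetical local test sketch (not part of the original gist): from a `modal run`
# entrypoint, the class method is invoked with `.remote()`, which spins up a GPU
# container, runs the generation, and returns the JPEG bytes.
#
#     image_bytes = Model().inference.remote("an astronaut riding a green horse")
#     Path("output.jpg").write_bytes(image_bytes)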
web_app = FastAPI()


@web_app.get("/txt2img")
async def authme(request: Request):
    # TODO: also add a token check later (see the sketch below this endpoint)
    params = request.query_params
    print(params)
    model = Model()
    result = model.inference.remote(
        prompt=params["prompt"], n_steps=1, high_noise_frac=0.8
    )
    return Response(
        content=result,
        media_type="image/jpeg",
    )
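# Sketch of the token check mentioned in the TODO above (an assumption, not part of the
# original gist). Modal injects the attached "game" secret's key/value pairs as
# environment variables in the serving container, so a `token` query parameter can be
# compared against one of them; the AUTH_TOKEN key name below is hypothetical.
def _is_authorized(params) -> bool:
    import os

    return params.get("token") == os.environ.get("AUTH_TOKEN")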
@app.function(image=sdxl_image, secrets=[modal.Secret.from_name("game")])
@asgi_app()
def fastapi_app():
    return web_app
@app.local_entrypoint()
def main():
print("Server started!")