Created
April 25, 2024 16:13
-
-
Save priamai/8d98401cb6eab9126fcfd91af06fed47 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# ## Basic setup | |
import io | |
from pathlib import Path | |
from fastapi import FastAPI, Request | |
import modal | |
from modal import ( | |
App, | |
Image, | |
Mount, | |
asgi_app, | |
build, | |
enter, | |
gpu, | |
method, | |
web_endpoint, | |
) | |
# ## Define a container image | |
# | |
# To take advantage of Modal's blazing fast cold-start times, we download our model weights
# into the container image at build time, using the `@build` method of the `Model` class below.
# We ignore binaries, ONNX weights and 32-bit weights.
#
# Tip: avoid using global variables in that method to ensure the download step detects model changes and
# triggers a rebuild.
# Container image used by both the GPU model class and the web endpoint.
# Built in three steps: Debian slim with Python 3.10, then the system
# libraries, then the pinned Python packages for the diffusion stack.
sdxl_image = Image.debian_slim(python_version="3.10")
sdxl_image = sdxl_image.apt_install(
    "libglib2.0-0",
    "libsm6",
    "libxrender1",
    "libxext6",
    "ffmpeg",
    "libgl1",
)
sdxl_image = sdxl_image.pip_install(
    "diffusers==0.26.3",
    "invisible_watermark==0.2.0",
    "transformers~=4.38.2",
    "accelerate==0.27.2",
    "safetensors==0.4.2",
)
# Modal application that the model class, the web endpoint and the local
# entrypoint in this file are registered on.
app = App("sd-xl-api")
with sdxl_image.imports(): | |
import torch | |
from diffusers import DiffusionPipeline | |
from fastapi import Response | |
# ## Load model and run inference | |
# | |
# The container lifecycle [`@enter` decorator](https://modal.com/docs/guide/lifecycle-functions#container-lifecycle-beta)
# loads the model at startup. Then, we evaluate it in the `inference` method.
# | |
# To avoid excessive cold-starts, we set the idle timeout to 240 seconds, meaning once a GPU has loaded the model it will stay | |
# online for 4 minutes before spinning down. This can be adjusted for cost/experience trade-offs. | |
@app.cls(gpu=gpu.A10G(), container_idle_timeout=240, image=sdxl_image)
class Model:
    """Stable Diffusion XL base + refiner pipelines behind a Modal class.

    ``build`` downloads the model snapshots, ``enter`` loads both pipelines,
    and ``inference`` turns a text prompt into JPEG-encoded bytes.
    """

    @build()
    def build(self):
        """Download the base and refiner snapshots from the Hugging Face Hub.

        Files matching the ignore patterns below are not downloaded.
        """
        from huggingface_hub import snapshot_download

        # Kept as locals on purpose: the file's own tip asks that this step
        # avoid global variables so that model changes trigger a rebuild.
        skipped_patterns = [
            "*.bin",
            "*.onnx_data",
            "*/diffusion_pytorch_model.safetensors",
        ]
        repo_ids = (
            "stabilityai/stable-diffusion-xl-base-1.0",
            "stabilityai/stable-diffusion-xl-refiner-1.0",
        )
        for repo_id in repo_ids:
            snapshot_download(repo_id, ignore_patterns=skipped_patterns)

    @enter()
    def enter(self):
        """Load the base and refiner pipelines and store them on ``self``."""
        # Options shared by both ``from_pretrained`` calls.
        shared_kwargs = {
            "torch_dtype": torch.float16,
            "use_safetensors": True,
            "variant": "fp16",
            "device_map": "auto",
        }

        # Load base model
        base_pipeline = DiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0", **shared_kwargs
        )
        self.base = base_pipeline

        # Load refiner model; it is handed the base pipeline's second text
        # encoder and VAE instead of loading its own copies.
        self.refiner = DiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-refiner-1.0",
            text_encoder_2=base_pipeline.text_encoder_2,
            vae=base_pipeline.vae,
            **shared_kwargs,
        )

        # Compiling the model graph is JIT so this will increase inference time for the first run
        # but speed up subsequent runs. Uncomment to enable.
        # self.base.unet = torch.compile(self.base.unet, mode="reduce-overhead", fullgraph=True)
        # self.refiner.unet = torch.compile(self.refiner.unet, mode="reduce-overhead", fullgraph=True)

    def _inference(self, prompt, n_steps=24, high_noise_frac=0.8):
        """Run base then refiner for ``prompt`` and return a JPEG ``io.BytesIO``.

        The base pipeline is called with ``denoising_end=high_noise_frac`` and
        latent output; the refiner is called on that output with
        ``denoising_start=high_noise_frac``. Both receive ``n_steps`` as
        ``num_inference_steps`` and the same fixed negative prompt.
        """
        common = {
            "prompt": prompt,
            "negative_prompt": "disfigured, ugly, deformed",
            "num_inference_steps": n_steps,
        }
        latents = self.base(
            denoising_end=high_noise_frac,
            output_type="latent",
            **common,
        ).images
        refined = self.refiner(
            denoising_start=high_noise_frac,
            image=latents,
            **common,
        ).images[0]

        buffer = io.BytesIO()
        refined.save(buffer, format="JPEG")
        return buffer

    @method()
    def inference(self, prompt, n_steps=24, high_noise_frac=0.8):
        """Return the JPEG bytes produced by ``_inference`` for ``prompt``."""
        buffer = self._inference(
            prompt, n_steps=n_steps, high_noise_frac=high_noise_frac
        )
        return buffer.getvalue()
web_app = FastAPI()


@web_app.get("/txt2img")
async def authme(request: Request):
    """Generate a JPEG image for the ``prompt`` query parameter.

    Args:
        request: Incoming request; its query string must contain ``prompt``.

    Returns:
        An ``image/jpeg`` response holding the generated image, or a
        plain-text 400 response when ``prompt`` is absent.
    """
    # also add token later
    params = request.query_params
    print(params)

    prompt = params.get("prompt")
    if prompt is None:
        # Indexing params['prompt'] would raise KeyError and surface as a
        # 500; a missing parameter is a client error.
        return Response(
            content="Missing required query parameter: prompt",
            status_code=400,
            media_type="text/plain",
        )

    model = Model()
    # `.remote` dispatches the call to the GPU-backed `Model` container; a
    # plain call would execute inside this web container, which requests no
    # GPU. `.aio` makes the call awaitable so the event loop is not blocked.
    # NOTE(review): n_steps=1 looks like a debugging value (the method's own
    # default is 24) -- kept as-is to preserve current output; confirm.
    result = await model.inference.remote.aio(
        prompt=prompt, n_steps=1, high_noise_frac=0.8
    )
    return Response(
        content=result,
        media_type="image/jpeg",
    )
@app.function(
    image=sdxl_image,
    secrets=[modal.Secret.from_name("game")],
)
@asgi_app()
def fastapi_app():
    """Return the module-level FastAPI application for Modal to serve as ASGI."""
    return web_app
@app.local_entrypoint()
def main():
    """Local entrypoint: print a startup message."""
    startup_message = "Server started!"
    print(startup_message)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment