Skip to content

Instantly share code, notes, and snippets.

@wong2
Created November 23, 2023 08:21
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save wong2/b630fea734190bcf11360d2f6cb56496 to your computer and use it in GitHub Desktop.
Save wong2/b630fea734190bcf11360d2f6cb56496 to your computer and use it in GitHub Desktop.
from io import BytesIO
import torch
from fastapi.responses import StreamingResponse
from leptonai.photon import Photon
from loguru import logger
class JPEGResponse(StreamingResponse):
    """Streaming response that advertises its body as ``image/jpeg``."""

    media_type: str = "image/jpeg"
class LCM(Photon):
    """Photon serving Stable Diffusion XL accelerated with LCM-LoRA.

    ``init`` loads ``stabilityai/stable-diffusion-xl-base-1.0``, replaces its
    scheduler with ``LCMScheduler`` and loads the
    ``latent-consistency/lcm-lora-sdxl`` LoRA weights. The ``run`` handler turns
    a text prompt into a single JPEG image.
    """

    # Python packages the photon needs installed at deployment time.
    requirement_dependency = ["torch", "diffusers", "Pillow", "peft"]

    # By default, use gpu.a10 as the computation resource shape. This should be
    # fast enough.
    deployment_template = {
        "resource_shape": "gpu.a10",
        "env": {},
    }

    # An A10 should be able to support a maximum concurrency of 8 requests to
    # interleave IO and compute. This is not tuned, by the way. The pipeline
    # call itself is serialized by ``self._pipeline_lock`` (see ``run``), so
    # only the work outside it (request handling, JPEG encoding) overlaps.
    handler_max_concurrency = 8

    def init(self):
        """Load the SDXL pipeline, switch it to the LCM scheduler and apply LCM-LoRA."""
        import threading

        from diffusers import DiffusionPipeline, LCMScheduler

        cuda_available = torch.cuda.is_available()
        self.pipeline = DiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            # Half precision on GPU, full precision when falling back to CPU.
            torch_dtype=torch.float16 if cuda_available else torch.float32,
        )
        self.pipeline.to("cuda" if cuda_available else "cpu")
        # Swap in the LCM scheduler, reusing the base scheduler's config.
        self.pipeline.scheduler = LCMScheduler.from_config(self.pipeline.scheduler.config)
        # Load the LCM-LoRA weights on top of the SDXL base model.
        self.pipeline.load_lora_weights("latent-consistency/lcm-lora-sdxl")
        # This single pipeline is shared by every request, and up to
        # ``handler_max_concurrency`` handlers may run at once. Diffusers
        # pipelines mutate their scheduler's state on each call, so concurrent
        # calls must not overlap.
        self._pipeline_lock = threading.Lock()
        logger.info(f"Initialized model. cuda: {cuda_available}.")

    @Photon.handler(
        "run",
        example={
            "prompt": "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
            "seed": 2159232,
            "steps": 4,
            # NOTE(review): LCM-LoRA is usually run with guidance_scale around
            # 1.0-2.0 (the handler default below is 1.0); confirm that 8.0 is
            # intended for this example.
            "guidance_scale": 8.0,
        },
    )
    def run(
        self,
        prompt: str,
        seed: int = 0,
        steps: int = 4,
        guidance_scale: float = 1.0,
    ) -> JPEGResponse:
        """Generate one image for ``prompt`` and return it as a JPEG.

        Args:
            prompt: Text prompt passed to the pipeline.
            seed: Seed for this request's private random generator.
            steps: Number of inference steps; must be at least 1.
            guidance_scale: Guidance scale forwarded to the pipeline.

        Returns:
            A ``JPEGResponse`` streaming the first generated image.

        Raises:
            HTTPException: 400 if ``steps`` is smaller than 1.
        """
        if steps < 1:
            from fastapi import HTTPException

            raise HTTPException(status_code=400, detail="steps must be at least 1")

        # Use a per-request generator rather than ``torch.manual_seed``: the
        # latter reseeds the process-global RNG, which concurrent requests
        # would clobber. A fresh CPU generator seeded with ``seed`` starts from
        # the same state the global CPU generator had after
        # ``torch.manual_seed(seed)``.
        generator = torch.Generator(device="cpu").manual_seed(seed)
        with self._pipeline_lock:
            output = self.pipeline(
                prompt=prompt,
                num_inference_steps=steps,
                generator=generator,
                guidance_scale=guidance_scale,
            )

        # Encode outside the lock so it can overlap with the next generation.
        img_io = BytesIO()
        output.images[0].save(img_io, format="JPEG")
        img_io.seek(0)
        return JPEGResponse(img_io)
if __name__ == "__main__":
    # Running this file directly instantiates the photon and launches it.
    LCM().launch()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment