from io import BytesIO

import torch
from fastapi.responses import StreamingResponse
from leptonai.photon import Photon
from loguru import logger


class JPEGResponse(StreamingResponse):
    media_type = "image/jpeg"


class LCM(Photon):
    requirement_dependency = ["torch", "diffusers", "Pillow", "peft"]

    # By default, use gpu.a10 as the computation resource shape. This should be fast enough.
    deployment_template = {
        "resource_shape": "gpu.a10",
        "env": {},
    }

    # An A10 should be able to support a maximum concurrency of 8 requests to
    # interleave IO and compute. Note that this value has not been tuned.
    handler_max_concurrency = 8

    def init(self):
        from diffusers import DiffusionPipeline, LCMScheduler

        cuda_available = torch.cuda.is_available()
        self.pipeline = DiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            torch_dtype=torch.float16 if cuda_available else torch.float32,
        )
        self.pipeline.to("cuda" if cuda_available else "cpu")
        # Swap in the LCM scheduler.
        self.pipeline.scheduler = LCMScheduler.from_config(self.pipeline.scheduler.config)
        # Load the LCM-LoRA weights.
        self.pipeline.load_lora_weights("latent-consistency/lcm-lora-sdxl")
        logger.info(f"Initialized model. cuda: {cuda_available}.")

    @Photon.handler(
        "run",
        example={
            "prompt": "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
            "seed": 2159232,
            "steps": 4,
            "guidance_scale": 8.0,
        },
    )
    def run(
        self,
        prompt: str,
        seed: int = 0,
        steps: int = 4,
        guidance_scale: float = 1.0,
    ) -> JPEGResponse:
        generator = torch.manual_seed(seed)
        output = self.pipeline(
            prompt=prompt,
            num_inference_steps=steps,
            generator=generator,
            guidance_scale=guidance_scale,
        )
        # Encode the first generated image as JPEG and stream it back.
        img_io = BytesIO()
        output.images[0].save(img_io, format="JPEG")
        img_io.seek(0)
        return JPEGResponse(img_io)


if __name__ == "__main__":
    p = LCM()
    p.launch()
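
Once the photon is running, the handler is exposed as a POST endpoint named "run" (per the @Photon.handler decorator above). The client sketch below is an assumption, not part of the gist: it assumes the script is launched locally and listening on 127.0.0.1:8080 (adjust the URL and port for your setup or for a Lepton cloud deployment), and the output filename is arbitrary.

# Minimal client sketch, assuming a local launch on port 8080.
import requests

resp = requests.post(
    "http://127.0.0.1:8080/run",
    json={
        "prompt": "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
        "seed": 2159232,
        "steps": 4,
        "guidance_scale": 1.0,
    },
)
resp.raise_for_status()
# The endpoint streams back a JPEG, so the raw response body is the image.
with open("output.jpg", "wb") as f:
    f.write(resp.content)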