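# Stable Diffusion worker built on OneFlow's diffusers fork. At bootup it
# loads every configured pipeline and warm-compiles it at each supported
# resolution; it then serves generation requests from a queue, pushing
# base64-encoded PNGs plus metadata onto a response queue.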
import oneflow as torch  # oneflow ships a torch-compatible API; aliasing it lets the code below read like plain PyTorch
import torch as og_torch  # the real PyTorch, kept around for components that expect it
from queue import Queue


def is_accelerate_available():
    # stub: presumably patched over the diffusers/transformers helper so the
    # accelerate code path is skipped
    return False


from transformers import CLIPTextModel, CLIPTokenizer
import os
import oneflow as flow
from .lpw import LongPromptWeightingPipeline
from diffusers import (
    OneFlowDDIMScheduler,
    OneFlowDPMSolverMultistepScheduler,
    OneFlowEulerAncestralDiscreteScheduler,
    OneFlowEulerDiscreteScheduler,
    OneFlowPNDMScheduler,
    OneFlowStableDiffusionPipeline,
)
import time
from PIL import Image
import random
from io import BytesIO
import base64
import traceback

def prepare_latents(batch_size, num_channels_latents, height, width, dtype, device, generator, init_noise_sigma, latents=None):
    """Create (or validate) the initial latent noise tensor for the pipeline."""
    shape = (batch_size, num_channels_latents, height // 8, width // 8)
    if latents is None:
        if device.type == "mps":
            # randn does not work reproducibly on mps
            latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
        else:
            latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
    else:
        if latents.shape != shape:
            raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
        latents = latents.to(device)
    # scale the initial noise by the standard deviation required by the scheduler
    latents = latents * init_noise_sigma
    return latents
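
# Worked example of the shape math above (comments only): a 512x512 request
# with batch_size=1 and the SD v1 UNet's 4 latent channels allocates a
# (1, 4, 64, 64) tensor, since the VAE downsamples each spatial dimension by 8.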

@torch.no_grad()
def image_generator(
    config,
    request_queue: Queue,
    image_queue: Queue,
):
os.environ["ONEFLOW_MLIR_CSE"] = "1" | |
os.environ["ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION"] = "1" | |
os.environ["ONEFLOW_MLIR_ENABLE_ROUND_TRIP"] = "1" | |
os.environ["ONEFLOW_MLIR_FUSE_FORWARD_OPS"] = "1" | |
os.environ["ONEFLOW_MLIR_GROUP_MATMUL"] = "1" | |
os.environ["ONEFLOW_MLIR_PREFER_NHWC"] = "1" | |
os.environ["ONEFLOW_KERNEL_ENABLE_FUSED_CONV_BIAS"] = "1" | |
os.environ["ONEFLOW_KERNEL_ENABLE_FUSED_LINEAR"] = "1" | |
os.environ["ONEFLOW_KERENL_CONV_ENABLE_CUTLASS_IMPL"] = "1" | |
os.environ["ONEFLOW_KERENL_FMHA_ENABLE_TRT_FLASH_ATTN_IMPL"] = "0" | |
os.environ["ONEFLOW_KERNEL_GLU_ENABLE_DUAL_GEMM_IMPL"] = "1" | |
os.environ["ONEFLOW_CONV_ALLOW_HALF_PRECISION_ACCUMULATION"] = "1" | |
os.environ["ONEFLOW_MATMUL_ALLOW_HALF_PRECISION_ACCUMULATION"] = "1" | |
print("Initiating...") | |
def imgq(status: str, content): | |
response = { | |
"status": status, | |
"content": content | |
} | |
image_queue.put(response) | |
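
    # Response contract on image_queue (inferred from the imgq calls below):
    #   success: {"status": "done", "content": {"image": <base64 PNG>, "metadata": {...}}}
    #   failure: {"status": "fail", "content": "<error message>"}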

    # BOOTUP START
    # config
    prompt_multiplier = 20
    device = torch.device("cuda")
    default_model = "runwayml/stable-diffusion-v1-5"
    # the models below consume a lot of VRAM
    text_model = CLIPTextModel.from_pretrained(default_model, subfolder="text_encoder")
    tokenizer_model = CLIPTokenizer.from_pretrained(default_model, subfolder="tokenizer")
    # don't pass `device` here: it is a oneflow device, and transformers expects a PyTorch (og_torch) one
    text_model = text_model.to("cuda")
    lpw_pipe = LongPromptWeightingPipeline(text_model, tokenizer_model, prompt_multiplier)

    # every width/height combination we want compiled ahead of time, largest first
    resolutions = [256, 512, 768, 1024]
    resultant_resolutions = []
    for width in resolutions:
        for height in resolutions:
            resultant_resolutions.append([width, height])
    resultant_resolutions = sorted(resultant_resolutions, key=lambda res: res[0] * res[1], reverse=True)

    # UNet compilation: run each pipeline once per resolution so OneFlow
    # compiles (and, via enable_graph_share_mem, shares memory between) a
    # graph for every shape before requests arrive
    pipe_map = dict()
    for model in config['models']:
        print("Loading model:", model['model_path'])
        tmp_pipe = OneFlowStableDiffusionPipeline.from_pretrained(
            pretrained_model_name_or_path=model['model_path'],
            use_auth_token=True,
            torch_dtype=torch.float16,
        )
        tmp_pipe.to("cuda")
        tmp_pipe._encode_prompt = lpw_pipe._encode_prompt  # swap in the long-prompt-weighting encoder
        tmp_pipe.enable_graph_share_mem()
        tmp_prompt = "Anime girl, beautiful"
        tmp_neg_prompt = "Disgusting, Horrible"
        for resolution in resultant_resolutions:
            print("Doing resolution:", resolution)
            with torch.autocast("cuda"):
                tmp_pipe(
                    prompt=tmp_prompt,
                    negative_prompt=tmp_neg_prompt,
                    height=resolution[1],
                    width=resolution[0],
                )
        pipe_map[model['alias']] = tmp_pipe

    # SCHEDULERS INIT
    sch_source = default_model
    schedulers = {
        "DDIM": OneFlowDDIMScheduler.from_pretrained(sch_source, subfolder="scheduler"),
        "EULER-A": OneFlowEulerAncestralDiscreteScheduler.from_pretrained(sch_source, subfolder="scheduler"),
        "EULER": OneFlowEulerDiscreteScheduler.from_pretrained(sch_source, subfolder="scheduler"),
        "HEUN": OneFlowDPMSolverMultistepScheduler.from_pretrained(sch_source, subfolder="scheduler", solver_type="heun"),
        "DPM++": OneFlowDPMSolverMultistepScheduler.from_pretrained(sch_source, subfolder="scheduler"),
        "DPM": OneFlowDPMSolverMultistepScheduler.from_pretrained(sch_source, subfolder="scheduler", algorithm_type="dpmsolver"),
        "PNDM": OneFlowPNDMScheduler.from_pretrained(sch_source, subfolder="scheduler"),
    }
    # BOOTUP END

    # Request gathering loop
    while True:
        request = request_queue.get()
        data = request
        try:
            start_time = time.time()
            print("Got request:", request)
            prompt = data['prompt']
            negative_prompt = data['negative_prompt']
            image_height = data['height']
            image_width = data['width']
            scheduler = data['scheduler']
            steps = data['steps']
            cfg = data['cfg']
            seed = data['seed']
            model = data['model']
            vae = data['vae']
            if model in pipe_map:
                curr_pipe = pipe_map[model]
            else:
                imgq('fail', f'Pipeline with name {model} was not found.')
                continue
            if scheduler in schedulers:
                curr_scheduler = schedulers[scheduler]
            else:
                # without this branch an unknown scheduler name would surface
                # later as a NameError instead of a clean failure response
                imgq('fail', f'Scheduler with name {scheduler} was not found.')
                continue
            if seed is None or seed == -1:
                seed = random.randint(1, 100000000000000)
            generator = torch.Generator().manual_seed(seed)
            batch_size = 1
            preparation_time = time.time()
            latents = prepare_latents(
                batch_size=batch_size,
                num_channels_latents=4,
                height=image_height,
                width=image_width,
                dtype=torch.float16,
                device=device,
                generator=generator,
                init_noise_sigma=curr_scheduler.init_noise_sigma,
            )
            late_time = time.time()
            curr_pipe.scheduler = curr_scheduler
            with torch.autocast("cuda"):
                image = curr_pipe(
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    generator=generator,
                    height=image_height,
                    width=image_width,
                    num_inference_steps=steps,  # was read from the request but never passed, so the pipeline silently used its default
                    guidance_scale=cfg,  # likewise
                    latents=latents,
                )[0][0]
            pipe_time = time.time()
            buffered = BytesIO()
            image.save(buffered, format="PNG")
            imgq("done", {
                "image": base64.b64encode(buffered.getvalue()).decode('utf-8'),
                "metadata": {
                    "prompt": prompt,
                    "negative_prompt": negative_prompt,
                    "model": model,
                    "vae": vae,
                    "steps": steps,
                    "width": image_width,
                    "height": image_height,
                    "cfg": cfg,
                    "seed": seed,
                    "scheduler": scheduler,
                    "compute_time": pipe_time - preparation_time,
                },
            })
            benchmark_time = {
                "PREP": preparation_time - start_time,
                "LATENTS": late_time - preparation_time,
                "PIPE": pipe_time - late_time,
                "TOTAL": pipe_time - start_time,
            }
            print("Benchmarks: (check notes)")
            for i in benchmark_time:
                # keep fractional milliseconds; the old int() cast truncated
                # them before the .2f format was applied
                print('| {:^14} | {:>9.2f} ms |'.format(i, benchmark_time[i] * 1000))
            print(f'w{image_width} x h{image_height}')
            print('scheduler: {}'.format(scheduler))
        except Exception as e:
            traceback.print_exc()
            imgq('fail', f'general exception, got {str(e)}')
            continue
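

# ---------------------------------------------------------------------------
# Hypothetical driver, not part of the original gist: a minimal sketch of how
# this worker is presumably wired up. It assumes a config shaped like the one
# consumed above (a 'models' list of {'model_path', 'alias'} dicts) and a
# request dict carrying the keys unpacked at the top of the loop. queue.Queue
# is thread-safe but not process-safe, so a thread is used. Because this
# module uses a relative import (.lpw), a guard like this would actually live
# in a separate entry-point script.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import threading

    config = {
        "models": [
            {"model_path": "runwayml/stable-diffusion-v1-5", "alias": "sd15"},
        ]
    }
    request_queue = Queue()
    image_queue = Queue()

    threading.Thread(
        target=image_generator,
        args=(config, request_queue, image_queue),
        daemon=True,
    ).start()

    request_queue.put({
        "prompt": "a lighthouse at dusk, oil painting",
        "negative_prompt": "blurry, low quality",
        "height": 512,
        "width": 512,
        "scheduler": "EULER-A",
        "steps": 30,
        "cfg": 7.5,
        "seed": -1,  # -1 asks the worker to pick a random seed
        "model": "sd15",
        "vae": None,
    })
    print(image_queue.get()["status"])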