# chavinlo/base.py
import os
import time
import random
import base64
import traceback
from io import BytesIO
from queue import Queue

import oneflow as torch
import oneflow as flow
import torch as og_torch  # real PyTorch; OneFlow is aliased as `torch` above
from PIL import Image
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import (
    OneFlowDDIMScheduler,
    OneFlowDPMSolverMultistepScheduler,
    OneFlowEulerAncestralDiscreteScheduler,
    OneFlowEulerDiscreteScheduler,
    OneFlowPNDMScheduler,
    OneFlowStableDiffusionPipeline,
)

from .lpw import LongPromptWeightingPipeline


def is_accelerate_available():
    # Stub: report `accelerate` as unavailable so nothing tries to use it.
    return False
def prepare_latents(batch_size, num_channels_latents, height, width, dtype, device, generator, init_noise_sigma, latents=None):
    # latent spatial dims are 1/8 of the pixel dims (the VAE downsampling factor)
    shape = (batch_size, num_channels_latents, height // 8, width // 8)
    if latents is None:
        if device.type == "mps":
            # randn does not work reproducibly on mps
            latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
        else:
            latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
    else:
        if latents.shape != shape:
            raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
        latents = latents.to(device)
    # scale the initial noise by the standard deviation required by the scheduler
    latents = latents * init_noise_sigma
    return latents
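
# Example (sketch, not in the original gist): fixed-seed latents for a 512x512
# image on CUDA. The `init_noise_sigma=1.0` value is illustrative; in real use
# it should come from the active scheduler, as the request loop below does.
#
#   gen = torch.Generator().manual_seed(1234)
#   lat = prepare_latents(batch_size=1, num_channels_latents=4, height=512, width=512,
#                         dtype=torch.float16, device=torch.device("cuda"),
#                         generator=gen, init_noise_sigma=1.0)
#   assert tuple(lat.shape) == (1, 4, 64, 64)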
@torch.no_grad()
def image_generator(
    config,
    request_queue: Queue,
    image_queue: Queue,
):
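    # Worker loop: compiles one pipeline per configured model at startup, then
    # serves requests pulled from `request_queue`. Each request is a dict with
    # prompt, negative_prompt, height, width, scheduler, steps, cfg, seed,
    # model, and vae keys; results go to `image_queue` via imgq() below.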
os.environ["ONEFLOW_MLIR_CSE"] = "1"
os.environ["ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION"] = "1"
os.environ["ONEFLOW_MLIR_ENABLE_ROUND_TRIP"] = "1"
os.environ["ONEFLOW_MLIR_FUSE_FORWARD_OPS"] = "1"
os.environ["ONEFLOW_MLIR_GROUP_MATMUL"] = "1"
os.environ["ONEFLOW_MLIR_PREFER_NHWC"] = "1"
os.environ["ONEFLOW_KERNEL_ENABLE_FUSED_CONV_BIAS"] = "1"
os.environ["ONEFLOW_KERNEL_ENABLE_FUSED_LINEAR"] = "1"
os.environ["ONEFLOW_KERENL_CONV_ENABLE_CUTLASS_IMPL"] = "1"
os.environ["ONEFLOW_KERENL_FMHA_ENABLE_TRT_FLASH_ATTN_IMPL"] = "0"
os.environ["ONEFLOW_KERNEL_GLU_ENABLE_DUAL_GEMM_IMPL"] = "1"
os.environ["ONEFLOW_CONV_ALLOW_HALF_PRECISION_ACCUMULATION"] = "1"
os.environ["ONEFLOW_MATMUL_ALLOW_HALF_PRECISION_ACCUMULATION"] = "1"
print("Initiating...")
def imgq(status: str, content):
response = {
"status": status,
"content": content
}
image_queue.put(response)
    # BOOTUP START
    # config
    prompt_multiplier = 20
    device = torch.device("cuda")
    default_model = "runwayml/stable-diffusion-v1-5"
    # the models below are what consume most of the VRAM
    text_model = CLIPTextModel.from_pretrained(default_model, subfolder="text_encoder")
    tokenizer_model = CLIPTokenizer.from_pretrained(default_model, subfolder="tokenizer")
    # don't pass `device` here: it is a oneflow device, and transformers
    # likely expects a regular torch (og_torch) one, so use the "cuda" string
    text_model = text_model.to("cuda")
    lpw_pipe = LongPromptWeightingPipeline(text_model, tokenizer_model, prompt_multiplier)
    resolutions = [256, 512, 768, 1024]
    resultant_resolutions = []
    for width in resolutions:
        for height in resolutions:
            resultant_resolutions.append([width, height])
    resultant_resolutions = sorted(resultant_resolutions, key=lambda res: res[0]*res[1], reverse=True)
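    # resultant_resolutions now holds all 16 [width, height] pairs, sorted by
    # pixel area, largest first: [1024, 1024], [1024, 768], ..., [256, 256]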
    # Unet compilation
    pipe_map = dict()
    for model in config['models']:
        print("Loading model:", model['model_path'])
        tmp_pipe = OneFlowStableDiffusionPipeline.from_pretrained(
            pretrained_model_name_or_path=model['model_path'],
            use_auth_token=True,
            torch_dtype=torch.float16
        )
        tmp_pipe.to("cuda")
        # swap in the long-prompt-weighting encoder from lpw_pipe
        tmp_pipe._encode_prompt = lpw_pipe._encode_prompt
        tmp_pipe.enable_graph_share_mem()
        tmp_prompt = "Anime girl, beautiful"
        tmp_neg_prompt = "Disgusting, Horrible"
        # warm-up pass at every supported resolution so graphs are compiled
        # up front rather than on the first user request
        for resolution in resultant_resolutions:
            print("Doing resolution:", resolution)
            with torch.autocast("cuda"):
                tmp_pipe(
                    prompt=tmp_prompt,
                    negative_prompt=tmp_neg_prompt,
                    height=resolution[1],
                    width=resolution[0]
                )
        pipe_map[model['alias']] = tmp_pipe
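    # Config shape assumed by the loop above (illustrative values, not from
    # the original gist):
    #   config = {"models": [{"model_path": "runwayml/stable-diffusion-v1-5",
    #                         "alias": "sd15"}]}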
    # SCHEDULERS INIT
    sch_source = default_model
    schedulers = {
        "DDIM": OneFlowDDIMScheduler.from_pretrained(sch_source, subfolder="scheduler"),
        "EULER-A": OneFlowEulerAncestralDiscreteScheduler.from_pretrained(sch_source, subfolder="scheduler"),
        "EULER": OneFlowEulerDiscreteScheduler.from_pretrained(sch_source, subfolder="scheduler"),
        "HEUN": OneFlowDPMSolverMultistepScheduler.from_pretrained(sch_source, subfolder="scheduler", solver_type="heun"),
        "DPM++": OneFlowDPMSolverMultistepScheduler.from_pretrained(sch_source, subfolder="scheduler"),
        "DPM": OneFlowDPMSolverMultistepScheduler.from_pretrained(sch_source, subfolder="scheduler", algorithm_type="dpmsolver"),
        "PNDM": OneFlowPNDMScheduler.from_pretrained(sch_source, subfolder="scheduler"),
    }
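    # All schedulers are built from the default model's scheduler config and
    # shared across every pipeline; the request loop below swaps the chosen
    # one onto the pipe per request.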
    # BOOTUP END
    # Request gathering loop
    while True:
        request = request_queue.get()
        data = request
        try:
            start_time = time.time()
            print("Got request:", request)
            prompt = data['prompt']
            negative_prompt = data['negative_prompt']
            image_height = data['height']
            image_width = data['width']
            scheduler = data['scheduler']
            steps = data['steps']
            cfg = data['cfg']
            seed = data['seed']
            model = data['model']
            vae = data['vae']
            if model in pipe_map:
                curr_pipe = pipe_map[model]
            else:
                imgq('fail', f'Pipeline with name {model} was not found.')
                continue
            if scheduler in schedulers:
                curr_scheduler = schedulers[scheduler]
            else:
                imgq('fail', f'Scheduler with name {scheduler} was not found.')
                continue
            if seed is None or seed == -1:
                seed = random.randint(1, 100000000000000)
            generator = torch.Generator().manual_seed(seed)
            batch_size = 1
            preparation_time = time.time()
            latents = prepare_latents(
                batch_size=batch_size,
                num_channels_latents=4,
                height=image_height,
                width=image_width,
                dtype=torch.float16,
                device=device,
                generator=generator,
                init_noise_sigma=curr_scheduler.init_noise_sigma
            )
            late_time = time.time()
            curr_pipe.scheduler = curr_scheduler
            with torch.autocast("cuda"):
                image = curr_pipe(
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    generator=generator,
                    height=image_height,
                    width=image_width,
                    num_inference_steps=steps,
                    guidance_scale=cfg,
                    latents=latents
                )[0][0]  # first image of the pipeline output
            pipe_time = time.time()
            buffered = BytesIO()
            image.save(buffered, format="PNG")
            imgq("done", {
                "image": base64.b64encode(buffered.getvalue()).decode('utf-8'),
                "metadata": {
                    "prompt": prompt,
                    "negative_prompt": negative_prompt,
                    "model": model,
                    "vae": vae,
                    "steps": steps,
                    "width": image_width,
                    "height": image_height,
                    "cfg": cfg,
                    "seed": seed,
                    "scheduler": scheduler,
                    "compute_time": pipe_time - preparation_time
                }
            })
            benchmark_time = {
                "PREP": preparation_time - start_time,
                "LATENTS": late_time - preparation_time,
                "PIPE": pipe_time - late_time,
                "TOTAL": pipe_time - start_time,
            }
            print("Benchmarks (check notes):")
            for stage in benchmark_time:
                # times are collected in seconds; print as milliseconds
                print('| {:^14} | {:>9.2f} ms |'.format(stage, benchmark_time[stage] * 1000))
            print(f'w{image_width} x h{image_height}')
            print('scheduler: {}'.format(scheduler))
        except Exception as e:
            traceback.print_exc()
            imgq('fail', f'general exception, got {str(e)}')
            continue
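
# Usage sketch (not part of the original gist; names and config values are
# illustrative). image_generator blocks forever, so it is meant to run in its
# own process, fed through the two queues:
#
#   if __name__ == "__main__":
#       from multiprocessing import Process, Queue as MPQueue
#       config = {"models": [{"model_path": "runwayml/stable-diffusion-v1-5",
#                             "alias": "sd15"}]}
#       req_q, img_q = MPQueue(), MPQueue()
#       Process(target=image_generator, args=(config, req_q, img_q),
#               daemon=True).start()
#       req_q.put({"prompt": "a lighthouse at dawn", "negative_prompt": "",
#                  "height": 512, "width": 512, "scheduler": "EULER-A",
#                  "steps": 30, "cfg": 7.5, "seed": -1,
#                  "model": "sd15", "vae": None})
#       result = img_q.get()  # {"status": "done", "content": {...}} on success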