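"""Stable Diffusion text-to-image inference on OneFlow.

Loads the CLIP text encoder/tokenizer, pre-compiled OneFlow nn.Graph
modules for one or more UNets and VAEs, then runs an Euler-ancestral
denoising loop with classifier-free guidance and saves the result.
"""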
import oneflow as torch  # OneFlow aliased as `torch` so diffusers-style code runs unchanged
import torch as og_torch  # the real PyTorch, kept available under a separate name


def is_accelerate_available():
    # stub out diffusers' accelerate detection; accelerate is not used here
    return False


from transformers import CLIPTextModel, CLIPTokenizer
import os
import oneflow as flow
from lpw import LongPromptWeightingPipeline
from diffusers import (
    OneFlowEulerAncestralDiscreteScheduler,
    OneFlowAutoencoderKL,
    OneFlowUNet2DConditionModel,
)
from PIL import Image
import time
from queue import Queue
def set_envs():
    # OneFlow MLIR compiler and fused-kernel flags for faster inference
    os.environ["ONEFLOW_MLIR_CSE"] = "1"
    os.environ["ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION"] = "1"
    os.environ["ONEFLOW_MLIR_ENABLE_ROUND_TRIP"] = "1"
    os.environ["ONEFLOW_MLIR_FUSE_FORWARD_OPS"] = "1"
    os.environ["ONEFLOW_MLIR_GROUP_MATMUL"] = "1"
    os.environ["ONEFLOW_MLIR_PREFER_NHWC"] = "1"
    os.environ["ONEFLOW_KERNEL_ENABLE_FUSED_CONV_BIAS"] = "1"
    os.environ["ONEFLOW_KERNEL_ENABLE_FUSED_LINEAR"] = "1"
    # note: the "KERENL" spelling below is kept deliberately; it matches the
    # variable names this OneFlow build checks
    os.environ["ONEFLOW_KERENL_CONV_ENABLE_CUTLASS_IMPL"] = "1"
    os.environ["ONEFLOW_KERENL_FMHA_ENABLE_TRT_FLASH_ATTN_IMPL"] = "0"
    os.environ["ONEFLOW_KERNEL_GLU_ENABLE_DUAL_GEMM_IMPL"] = "1"
    os.environ["ONEFLOW_CONV_ALLOW_HALF_PRECISION_ACCUMULATION"] = "1"
    os.environ["ONEFLOW_MATMUL_ALLOW_HALF_PRECISION_ACCUMULATION"] = "1"


print("Initiating...")
# vae_scale_factor = 8
def prepare_latents(batch_size, num_channels_latents, height, width, dtype, device, generator, init_noise_sigma, latents=None):
    shape = (batch_size, num_channels_latents, height // 8, width // 8)
    if latents is None:
        if device.type == "mps":
            # randn does not work reproducibly on mps
            latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
        else:
            latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
    else:
        if latents.shape != shape:
            raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
        latents = latents.to(device)
    # scale the initial noise by the standard deviation required by the scheduler
    latents = latents * init_noise_sigma
    return latents
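
# Worked example: for a 512x512 image, batch size 1, and 4 latent channels,
# the VAE's 8x downsampling gives a latent tensor of shape (1, 4, 64, 64).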
def numpy_to_pil(images):
    """
    Convert a numpy image or a batch of images to a PIL image.
    """
    if images.ndim == 3:
        images = images[None, ...]
    images = (images * 255).round().astype("uint8")
    if images.shape[-1] == 1:
        # special case for grayscale (single channel) images
        pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
    else:
        pil_images = [Image.fromarray(image) for image in images]
    return pil_images
# custom OneFlow graph wrappers (flow.nn.Graph compiles the wrapped module)
class UNetGraph(flow.nn.Graph):
    def __init__(self, unet):
        super().__init__()
        self.unet = unet
        self.config.enable_cudnn_conv_heuristic_search_algo(False)
        self.config.allow_fuse_add_to_output(True)

    def build(self, latent_model_input, t, text_embeddings):
        # hint OneFlow's AMP pass to treat the embeddings as half-precision friendly
        text_embeddings = torch._C.amp_white_identity(text_embeddings)
        return self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample


class VaePostProcess(flow.nn.Module):
    def __init__(self, vae) -> None:
        super().__init__()
        self.vae = vae

    def forward(self, latents):
        latents = 1 / 0.18215 * latents  # undo Stable Diffusion's latent scaling factor
        image = self.vae.decode(latents).sample
        image = (image / 2 + 0.5).clamp(0, 1)  # map [-1, 1] output to [0, 1]
        return image


class VaeGraph(flow.nn.Graph):
    def __init__(self, vae_post_process) -> None:
        super().__init__()
        self.vae_post_process = vae_post_process

    def build(self, latents):
        return self.vae_post_process(latents)


class TextEncoderGraph(flow.nn.Graph):
    def __init__(self, text_encoder) -> None:
        super().__init__()
        self.text_encoder = text_encoder

    def build(self, text_input, attention_mask):
        return self.text_encoder(text_input, attention_mask)[0]
@torch.no_grad()
def engine(config):
    set_envs()  # apply the OneFlow tuning flags before any model is loaded
    default_model = "runwayml/stable-diffusion-v1-5"
    text_model = CLIPTextModel.from_pretrained(default_model, subfolder="text_encoder")
    tokenizer_model = CLIPTokenizer.from_pretrained(default_model, subfolder="tokenizer")
    text_model = text_model.to("cuda")
    lpw_pipe = LongPromptWeightingPipeline(text_model, tokenizer_model, config['prompt_multiplier'])

    def sample_unet_inputs():
        # representative UNet inputs, useful for warming up a freshly compiled graph
        scheduler = OneFlowEulerAncestralDiscreteScheduler.from_pretrained(default_model, subfolder="scheduler")
        text_embeddings = lpw_pipe(prompt="anime girl", negative_prompt="horrible")
        # round-trip through numpy to turn the PyTorch embeddings into an fp16 OneFlow tensor
        text_embeddings = torch.from_numpy(text_embeddings.to("cpu").numpy()).to("cuda").to(torch.float16)
        latents = prepare_latents(
            1,
            4,
            512,
            512,
            text_embeddings.dtype,
            torch.device("cuda"),
            None,
            init_noise_sigma=scheduler.init_noise_sigma,
        )
        _, t = list(enumerate(scheduler.timesteps))[0]
        latent_model_input = torch.cat([latents] * 2)
        return latent_model_input, t, text_embeddings
    # temporary cache key: (height, width, batch_size)
    cache_key = (512, 512, 1)
    from diffusers.oneflow_graph_compile_cache import OneFlowGraphCompileCache
    # defaults are cache_size=1 and enable_graph_share_mem=False

    unet_map = {}
    for model in config['unets']:
        graphs_path = os.path.join(model['model_path'], 'graphs')
        pytorch_path = os.path.join(model['model_path'], 'pytorch')
        tmp_unet = OneFlowUNet2DConditionModel.from_pretrained(pytorch_path)
        tmp_unet.to("cuda")
        tmp_unet.to(torch.float16)

        graph_compile_cache = OneFlowGraphCompileCache(10, True)
        graph_compile_cache.enable_share_mem(True)

        # load the pre-compiled graph from disk instead of recompiling
        graph_class2init_args = dict()
        graph_class2init_args[UNetGraph.__name__] = (UNetGraph, tmp_unet)
        print("load init")
        print(graphs_path)
        graph_compile_cache.load_graph(graphs_path, graph_class2init_args)
        print("load finish")
        unet_graph = graph_compile_cache.get_graph(UNetGraph, cache_key, tmp_unet)
        # alternatively, compile from scratch with representative inputs:
        # latent_model_input, t, text_embeddings = sample_unet_inputs()
        # print("compiling")
        # unet_graph.compile(latent_model_input, t, text_embeddings)
        # print("compile finish")
        unet_map[model['alias']] = unet_graph
    vae_map = {}
    for model in config['vaes']:
        graphs_path = os.path.join(model['model_path'], 'graphs')
        pytorch_path = os.path.join(model['model_path'], 'pytorch')
        tmp_vae = OneFlowAutoencoderKL.from_pretrained(pytorch_path)
        tmp_vae.to("cuda")
        tmp_vae.to(torch.float16)
        vae_post_process = VaePostProcess(tmp_vae)
        vae_post_process.eval()

        # mirror the UNet path: build a fresh cache, load the saved graph, then fetch it
        graph_compile_cache = OneFlowGraphCompileCache(10, True)
        graph_compile_cache.enable_share_mem(True)
        graph_class2init_args = dict()
        graph_class2init_args[VaeGraph.__name__] = (VaeGraph, vae_post_process)
        graph_compile_cache.load_graph(graphs_path, graph_class2init_args)
        vae_post_process_graph = graph_compile_cache.get_graph(VaeGraph, cache_key, vae_post_process)
        vae_map[model['alias']] = vae_post_process_graph
    device = torch.device("cuda")
    num_inference_steps = 50
    width = 512
    height = 512
    cfg = 14
    prompt = "Anime girl"
    negative_prompt = "Disgusting, Low resolution, Low Quality"
    do_classifier_free_guidance = True

    text_embeddings = lpw_pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=cfg,
    )
    print(text_embeddings)
    print(text_embeddings.shape)

    generator = torch.Generator()
    scheduler = OneFlowEulerAncestralDiscreteScheduler.from_pretrained(default_model, subfolder="scheduler")
    scheduler.set_timesteps(num_inference_steps)
    timesteps = scheduler.timesteps
    print("Preps done")

    num_channel_latents = 4
    print(text_embeddings.dtype)
    # round-trip through numpy to turn the PyTorch embeddings into an fp16 OneFlow tensor
    text_embeddings = text_embeddings.to("cpu")
    text_embeddings = text_embeddings.numpy()
    text_embeddings = torch.from_numpy(text_embeddings)
    text_embeddings = text_embeddings.to("cuda")
    text_embeddings = text_embeddings.to(torch.float16)

    cache_key = (height, width, 1)
    vae_post_process_graph = vae_map['ANYTHING-V3']
    unet_graph = unet_map['ANYTHING-V3']
    print(unet_graph)
    latents = prepare_latents(
        batch_size=1,
        num_channels_latents=num_channel_latents,
        height=height,
        width=width,
        dtype=text_embeddings.dtype,
        device=device,
        generator=generator,
        init_noise_sigma=scheduler.init_noise_sigma,
    )
    print("latents dev:", latents.device)
    print("latents type:", latents.dtype)
    for i, t in enumerate(timesteps):
        print("i:", i)
        print("t:", t)
        # expand the latents since we are doing classifier-free guidance
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = scheduler.scale_model_input(latent_model_input, t)

        # predict the noise residual
        torch._oneflow_internal.profiler.RangePush(f"denoise-{i}-unet-graph")
        noise_pred = unet_graph(latent_model_input, t, text_embeddings)
        torch._oneflow_internal.profiler.RangePop()

        # perform guidance
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + cfg * (noise_pred_text - noise_pred_uncond)

        # compute the previous noisy sample x_t -> x_t-1
        latents = scheduler.step(noise_pred, t, latents).prev_sample

    # decode the final latents and save the image
    image = vae_post_process_graph(latents)
    image = image.cpu().permute(0, 2, 3, 1).float().numpy()
    image = numpy_to_pil(image)
    image[0].save("test_out.png")
if __name__ == "__main__":
    import json
    # this module is expected to be importable as `test` (i.e. saved as test.py)
    from test import engine
    engine(json.load(open('/root/node/cfg/basic.json')))
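
# A minimal sketch of the expected basic.json, inferred from the keys this
# script reads; the paths, alias, and multiplier value are hypothetical:
#
# {
#     "prompt_multiplier": 3,
#     "unets": [{"alias": "ANYTHING-V3", "model_path": "/models/anything-v3"}],
#     "vaes":  [{"alias": "ANYTHING-V3", "model_path": "/models/anything-v3"}]
# }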