chavinlo/test.py (secret gist)
Created February 16, 2023 05:52

import oneflow as torch  # OneFlow is imported under the name "torch" so the code below reads like PyTorch
import torch as og_torch  # the real PyTorch stays importable as og_torch (unused in this script)

# presumably a stub so nothing routes through the accelerate code path
def is_accelerate_available():
    return False

from transformers import CLIPTextModel, CLIPTokenizer
import os
import oneflow as flow
from lpw import LongPromptWeightingPipeline
from diffusers import OneFlowEulerAncestralDiscreteScheduler, OneFlowAutoencoderKL, OneFlowUNet2DConditionModel
from PIL import Image
import time
from queue import Queue
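
# Judging by their names, the flags in set_envs() switch on OneFlow's MLIR graph rewrites and
# fused CUDA kernels (fused conv bias / linear, cutlass conv, dual-GEMM GLU, NHWC layout,
# half-precision accumulation) and turn the TensorRT flash-attention path off.
# Note: set_envs() is defined but never called in this script; it would presumably need to be
# invoked before engine() (or the variables exported in the shell) for the flags to take effect.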
def set_envs():
os.environ["ONEFLOW_MLIR_CSE"] = "1"
os.environ["ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION"] = "1"
os.environ["ONEFLOW_MLIR_ENABLE_ROUND_TRIP"] = "1"
os.environ["ONEFLOW_MLIR_FUSE_FORWARD_OPS"] = "1"
os.environ["ONEFLOW_MLIR_GROUP_MATMUL"] = "1"
os.environ["ONEFLOW_MLIR_PREFER_NHWC"] = "1"
os.environ["ONEFLOW_KERNEL_ENABLE_FUSED_CONV_BIAS"] = "1"
os.environ["ONEFLOW_KERNEL_ENABLE_FUSED_LINEAR"] = "1"
os.environ["ONEFLOW_KERENL_CONV_ENABLE_CUTLASS_IMPL"] = "1"
os.environ["ONEFLOW_KERENL_FMHA_ENABLE_TRT_FLASH_ATTN_IMPL"] = "0"
os.environ["ONEFLOW_KERNEL_GLU_ENABLE_DUAL_GEMM_IMPL"] = "1"
os.environ["ONEFLOW_CONV_ALLOW_HALF_PRECISION_ACCUMULATION"] = "1"
os.environ["ONEFLOW_MATMUL_ALLOW_HALF_PRECISION_ACCUMULATION"] = "1"
print("Initiating...")
#self.vae_scale_factor = 8
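# prepare_latents() draws the initial Gaussian noise for the diffusion process. The latent grid is
# height//8 x width//8 because the SD VAE downsamples by a factor of 8 (hence the commented-out
# vae_scale_factor above), and the noise is scaled by the scheduler's init_noise_sigma.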
def prepare_latents(batch_size, num_channels_latents, height, width, dtype, device, generator, init_noise_sigma, latents=None):
    shape = (batch_size, num_channels_latents, height // 8, width // 8)
    if latents is None:
        if device.type == "mps":
            # randn does not work reproducibly on mps
            latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
        else:
            latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
    else:
        if latents.shape != shape:
            raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
        latents = latents.to(device)
    # scale the initial noise by the standard deviation required by the scheduler
    latents = latents * init_noise_sigma
    return latents

def numpy_to_pil(images):
"""
Convert a numpy image or a batch of images to a PIL image.
"""
if images.ndim == 3:
images = images[None, ...]
images = (images * 255).round().astype("uint8")
if images.shape[-1] == 1:
# special case for grayscale (single channel) images
pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
else:
pil_images = [Image.fromarray(image) for image in images]
return pil_images
# custom oneflow classes
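# These flow.nn.Graph subclasses wrap the eager modules so OneFlow can trace and compile them into
# static graphs (which is what OneFlowGraphCompileCache saves and loads below).
# amp_white_identity presumably just marks the text embeddings so AMP keeps them in fp16.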
class UNetGraph(flow.nn.Graph):
    def __init__(self, unet):
        super().__init__()
        self.unet = unet
        self.config.enable_cudnn_conv_heuristic_search_algo(False)
        self.config.allow_fuse_add_to_output(True)

    def build(self, latent_model_input, t, text_embeddings):
        text_embeddings = torch._C.amp_white_identity(text_embeddings)
        return self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

class VaePostProcess(flow.nn.Module):
    def __init__(self, vae) -> None:
        super().__init__()
        self.vae = vae

    def forward(self, latents):
        latents = 1 / 0.18215 * latents  # undo the SD VAE latent scaling factor (0.18215)
        image = self.vae.decode(latents).sample
        image = (image / 2 + 0.5).clamp(0, 1)
        return image

class VaeGraph(flow.nn.Graph):
    def __init__(self, vae_post_process) -> None:
        super().__init__()
        self.vae_post_process = vae_post_process

    def build(self, latents):
        return self.vae_post_process(latents)

class TextEncoderGraph(flow.nn.Graph):
    def __init__(self, text_encoder) -> None:
        super().__init__()
        self.text_encoder = text_encoder

    def build(self, text_input, attention_mask):
        return self.text_encoder(text_input, attention_mask)[0]

@torch.no_grad()
def engine(config):
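    # The config dict is loaded from a JSON file in __main__. Judging by the keys used below, it
    # presumably looks something like this (illustrative only; paths and aliases are made up):
    # {
    #   "prompt_multiplier": 1.0,
    #   "unets": [{"alias": "ANYTHING-V3", "model_path": "/models/anything-v3/unet"}],
    #   "vaes":  [{"alias": "ANYTHING-V3", "model_path": "/models/anything-v3/vae"}]
    # }
    # Each model_path is expected to contain a "pytorch" subdir (weights) and a "graphs" subdir
    # (previously saved OneFlow graphs).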
    default_model = "runwayml/stable-diffusion-v1-5"
    text_model = CLIPTextModel.from_pretrained(default_model, subfolder="text_encoder")
    tokenizer_model = CLIPTokenizer.from_pretrained(default_model, subfolder="tokenizer")
    text_model = text_model.to("cuda")
    lpw_pipe = LongPromptWeightingPipeline(text_model, tokenizer_model, config['prompt_multiplier'])

    def sample_unet_inputs():
        scheduler = OneFlowEulerAncestralDiscreteScheduler.from_pretrained(default_model, subfolder="scheduler")
        text_embeddings = lpw_pipe(prompt="anime girl", negative_prompt="horrible")
        # round-trip through numpy to turn the PyTorch embeddings into a OneFlow fp16 tensor
        # (lpw_pipe runs the transformers CLIP text encoder, while "torch" here is oneflow)
        text_embeddings = text_embeddings.to("cpu")
        text_embeddings = text_embeddings.numpy()
        text_embeddings = torch.from_numpy(text_embeddings)
        text_embeddings = text_embeddings.to("cuda")
        text_embeddings = text_embeddings.to(torch.float16)
        latents = prepare_latents(
            1,
            4,
            512,
            512,
            text_embeddings.dtype,
            torch.device("cuda"),
            None,
            init_noise_sigma=scheduler.init_noise_sigma
        )
        _, t = list(enumerate(scheduler.timesteps))[0]
        latent_model_input = torch.cat([latents] * 2)
        return latent_model_input, t, text_embeddings

    # tmp cache key
    cache_key = (512, 512, 1)
    from diffusers.oneflow_graph_compile_cache import OneFlowGraphCompileCache
    # defaults are cache_size = 1 and enable_graph_share_mem=False
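    # Load every UNet listed in the config: restore the PyTorch weights, move them to CUDA/fp16,
    # then pull the previously saved OneFlow graph back out of the compile cache.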
    unet_map = {}
    for model in config['unets']:
        graphs_path = os.path.join(model['model_path'], 'graphs')
        pytorch_path = os.path.join(model['model_path'], 'pytorch')
        tmp_unet = OneFlowUNet2DConditionModel.from_pretrained(pytorch_path)
        tmp_unet.to("cuda")
        tmp_unet.to(torch.float16)
        graph_compile_cache = OneFlowGraphCompileCache(10, True)
        graph_compile_cache.enable_share_mem(True)
        # load saved graph
        unet_graph = None
        graph_class2init_args = dict()
        unet_graph_args = (UNetGraph, tmp_unet)
        graph_class2init_args[UNetGraph.__name__] = unet_graph_args
        print("load init")
        print(graphs_path)
        graph_compile_cache.load_graph(graphs_path, graph_class2init_args)
        print("load finish")
        unet_graph = graph_compile_cache.get_graph(UNetGraph, cache_key, tmp_unet)
        graph_compile_cache.enable_share_mem(True)
        # latent_model_input, t, text_embeddings = sample_unet_inputs()
        # print("compiling")
        # unet_graph.compile(latent_model_input, t, text_embeddings)
        # print("compile finish")
        unet_map[model['alias']] = unet_graph

    vae_map = {}
    for model in config['vaes']:
        graphs_path = os.path.join(model['model_path'], 'graphs')
        pytorch_path = os.path.join(model['model_path'], 'pytorch')
        tmp_vae = OneFlowAutoencoderKL.from_pretrained(pytorch_path)
        tmp_vae.to("cuda")
        tmp_vae.to(torch.float16)
        vae_post_process = VaePostProcess(tmp_vae)
        vae_post_process.eval()
        graph_class2init_args = dict()
        vae_post_process_graph = graph_compile_cache.get_graph(VaeGraph, cache_key, vae_post_process)
        vae_graph_args = (VaeGraph, vae_post_process_graph)
        graph_class2init_args[VaeGraph.__name__] = vae_graph_args
        graph_compile_cache = OneFlowGraphCompileCache(10, True)
        graph_compile_cache.enable_share_mem()
        graph_compile_cache.load_graph(graphs_path, graph_class2init_args)
        vae_map[model['alias']] = vae_post_process_graph

    device = torch.device("cuda")
    num_inference_steps = 50
    width = 512
    height = 512
    cfg = 14
    prompt = "Anime girl"
    negative_prompt = "Disgusting, Low resolution, Low Quality"
    do_classifier_free_guidance = True
    text_embeddings = lpw_pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=cfg
    )
    print(text_embeddings)
    print(text_embeddings.shape)
    generator = torch.Generator()
    scheduler = OneFlowEulerAncestralDiscreteScheduler.from_pretrained(default_model, subfolder="scheduler")
    scheduler.set_timesteps(num_inference_steps)
    timesteps = scheduler.timesteps
    print("Preps done")
    num_channel_latents = 4
    print(text_embeddings.dtype)
    # convert the text embeddings from a PyTorch tensor to a OneFlow fp16 tensor via numpy
    text_embeddings = text_embeddings.to("cpu")
    text_embeddings = text_embeddings.numpy()
    text_embeddings = torch.from_numpy(text_embeddings)
    text_embeddings = text_embeddings.to("cuda")
    text_embeddings = text_embeddings.to(torch.float16)
    cache_key = (height, width, 1)
    vae_post_process_graph = vae_map['ANYTHING-V3']  # hardcoded alias; must match an alias in the config
    unet_graph = unet_map['ANYTHING-V3']
    print(unet_graph)
    latents = prepare_latents(
        batch_size=1,
        num_channels_latents=num_channel_latents,
        height=height,
        width=width,
        dtype=text_embeddings.dtype,
        device=device,
        generator=generator,
        init_noise_sigma=scheduler.init_noise_sigma
    )
    print("latents dev:", latents.device)
    print("latents type:", latents.dtype)
    for i, t in enumerate(timesteps):
        # expand the latents if we are doing classifier free guidance
        print("i:", i)
        print("t:", t)
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = scheduler.scale_model_input(latent_model_input, t)
        # predict the noise residual
        torch._oneflow_internal.profiler.RangePush(f"denoise-{i}-unet-graph")
        noise_pred = unet_graph(latent_model_input, t, text_embeddings)
        torch._oneflow_internal.profiler.RangePop()
        # perform guidance
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + cfg * (noise_pred_text - noise_pred_uncond)
        # compute the previous noisy sample x_t -> x_t-1
        latents = scheduler.step(noise_pred, t, latents).prev_sample

    # decode the final latents with the compiled VAE graph and save the result
    image = vae_post_process_graph(latents)
    image = image.cpu().permute(0, 2, 3, 1).float().numpy()
    image = numpy_to_pil(image)
    image[0].save("test_out.png")

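# When run directly, the config path is hardcoded; "from test import engine" re-imports this same
# module (the file is test.py), presumably so engine is resolved as a module attribute.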
if __name__ == "__main__":
    import json
    from test import engine

    engine(json.load(open('/root/node/cfg/basic.json')))