@mrsteyk
Last active November 11, 2023 06:54
dalle_runner_api.model_infra.modules.public_diff_vae

Memes

UPD8: cleaned up and it runs now (couldn't be bothered to test it earlier); demo below. Fixes are from gh:madebyollin; check out his gist, which actually works.

Weights: now with a ready-to-use checkpoint (will be up in a few minutes).

Confirmed working with this code:

import time

import torch
from torchvision import transforms as tfms

from diffusers import AutoencoderKL
from diffusers.utils import load_image
from safetensors.torch import load_file

from cd_oai import ConsistencyDecoder, save_image

from consistency_decoder import ConsistencyDecoder as UNet

with torch.no_grad():
    to_tensor = tfms.ToTensor()

    # Slightly better VAE to test against, not used in original SD1.4+
    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float32).cuda().eval().requires_grad_(False)

    # What do you call a fish wearing a bowtie? Sofishticated.
    image = load_image("__hatsune_miku_vocaloid_drawn_by_qys3__719b1e28d4bb1431325c1c663d9d247e.jpg").convert("RGB").resize((512, 768)) # https://danbooru.donmai.us/posts/6853904
    image = to_tensor(image).to(device=vae.device, dtype=vae.dtype).unsqueeze(0)*2 - 1

    latent = vae.encode(image.cuda()).latent_dist.sample()
    del image
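    # Note (added): the usual diffusers convention is to multiply by
    # vae.config.scaling_factor (0.18215 for SD1.x VAEs) after encode() and divide
    # before decode(); here the unscaled latent goes straight back into decode(),
    # so the factor would cancel out anyway.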

    # decode with the improved VAE
    # (no manual latent scaling needed here; see the note above)
    sample_gan = vae.decode(latent).sample.detach()
    save_image(sample_gan, "gan.png")
    # del latent
    del sample_gan
    del vae
    torch.cuda.empty_cache()

    # funny ConvResNet, doesn't fit on 3060
    model = UNet()
    model.load_state_dict(load_file("stk_consistency_decoder_amalgamated.safetensors"))
    model = model.eval().requires_grad_(False)

    # construct cd with different model
    decoder_consistency = ConsistencyDecoder(ckpt=model, device="cuda:0")
    torch.cuda.empty_cache()
    torch.cuda.manual_seed_all(228)  # for reproducibility
    t = time.time()
    sample_consistency = decoder_consistency(latent)
    print(f"FP32 took {time.time()-t}s")
    del model
    del decoder_consistency
    torch.cuda.empty_cache()
    save_image(sample_consistency, "cdec.png")

    # fp16 speeds things up with no perceptual quality change (as expected);
    # the birchlabs guy already said that 99.99% of the time fp16 is all you need for a UNet at inference time

    # removing this amp autocast improves speed and VRAM usage (I'm clearly misusing something), so the model is cast with .half() directly
    # with torch.cuda.amp.autocast(dtype=torch.float16):
    # funny ConvResNet, this time in float16
    # I'm pretty sure torch would yell at me for fp16 on CPU
    model = UNet()
    model.load_state_dict(load_file("stk_consistency_decoder_amalgamated.safetensors"))
    model = model.cuda().half().eval().requires_grad_(False)
    torch.cuda.empty_cache()

    # construct cd with different model
    decoder_consistency = ConsistencyDecoder(ckpt=model, device="cuda:0")
    torch.cuda.empty_cache()
    torch.cuda.manual_seed_all(228)  # for reproducibility
    t = time.time()
    sample_consistency = decoder_consistency(latent.half()).float()
    print(f"FP16 took {time.time()-t}s")
    del model
    del decoder_consistency
    torch.cuda.empty_cache()
    save_image(sample_consistency, "cdec_half.png")

    # with torch.cuda.amp.autocast(dtype=torch.bfloat16):
    # funny ConvResNet, this time in brain float 16
    # I'm pretty sure torch would yell at me for bf16 on CPU
    model = UNet()
    model.load_state_dict(load_file("stk_consistency_decoder_amalgamated.safetensors"))
    model = model.cuda().bfloat16().eval().requires_grad_(False)
    torch.cuda.empty_cache()

    # construct cd with different model
    decoder_consistency = ConsistencyDecoder(ckpt=model, device="cuda:0")
    torch.cuda.empty_cache()
    torch.cuda.manual_seed_all(228)  # for reproducibility
    t = time.time()
    sample_consistency = decoder_consistency(latent.bfloat16()).float()
    print(f"BF16 took {time.time()-t}s")
    del model
    del decoder_consistency
    torch.cuda.empty_cache()
    save_image(sample_consistency, "cdec_brain.png")

# Certified tensor core and cudnn moment
# cudnn used is 8.9.5.30
"""
FP32 took 30.042385578155518s
FP16 took 2.8654606342315674s
BF16 took 2.710512399673462s
"""

Ground truth comparison (image grid):

GT-resize (gt.png) | VAE-MSE (gan.png) | CD-FP32 (cdec.png) | CD-FP16 (cdec_half.png) | CD-BF16 (cdec_brain.png)

original 3 AM message

I might've messed up paddings and other non-weight-related hparams in a few places, mainly the upsample/downsample ConvResBlocks and the embeddings; test at your own risk, I didn't run it. Eepy time, bye bye.

import torch


class TimestepEmbedding(torch.nn.Module):
    def __init__(self, n_time=1024, n_emb=320, n_out=1280) -> None:
        super().__init__()
        self.emb = torch.nn.Embedding(n_time, n_emb)
        self.f_1 = torch.nn.Linear(n_emb, n_out)
        # self.act = torch.nn.SiLU()
        self.f_2 = torch.nn.Linear(n_out, n_out)

    def forward(self, x) -> torch.Tensor:
        x = self.emb(x)
        x = self.f_1(x)
        x = torch.nn.functional.silu(x)
        return self.f_2(x)
class ImageEmbedding(torch.nn.Module):
    def __init__(self, in_channels=7, out_channels=320) -> None:
        super().__init__()
        self.f = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x) -> torch.Tensor:
        return self.f(x)


class ImageUnembedding(torch.nn.Module):
    def __init__(self, in_channels=320, out_channels=6) -> None:
        super().__init__()
        self.gn = torch.nn.GroupNorm(32, in_channels)
        self.f = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x) -> torch.Tensor:
        return self.f(torch.nn.functional.silu(self.gn(x)))
# Improved universal block with fixes from gh:madebyollin
class ConvResblock(torch.nn.Module):
    def __init__(self, in_features=320, out_features=320, skip_conv=False, up=False, down=False) -> None:
        super().__init__()
        self.f_t = torch.nn.Linear(1280, out_features * 2)
        self.gn_1 = torch.nn.GroupNorm(32, in_features)
        self.f_1 = torch.nn.Conv2d(in_features, out_features, kernel_size=3, padding=1)
        self.gn_2 = torch.nn.GroupNorm(32, out_features)
        self.f_2 = torch.nn.Conv2d(out_features, out_features, kernel_size=3, padding=1)
        self.skip_conv = skip_conv
        self.f_s = torch.nn.Identity() if not skip_conv else torch.nn.Conv2d(in_features, out_features, kernel_size=1, padding=0)
        self.f_x = torch.nn.Identity()
        self.up = up
        self.down = down
        assert not (up and down), "Can't be up and down at the same time!"
        if up:
            # torch.nn.functional.upsample_nearest(gn_1, scale_factor=2)
            self.f_x = torch.nn.UpsamplingNearest2d(scale_factor=2)
        elif down:
            # torch.nn.functional.avg_pool2d(gn_1, kernel_size=(2, 2), stride=None)
            self.f_x = torch.nn.AvgPool2d(kernel_size=(2, 2), stride=None)

    def forward(self, x, t):
        x_skip = x
        t: torch.Tensor = self.f_t(torch.nn.functional.silu(t))
        t = t.chunk(2, dim=1)
        # scale/shift conditioning from the timestep embedding:
        # the first chunk is the (1 + scale) term, the second is the shift
        t_1 = t[0].unsqueeze(dim=2).unsqueeze(dim=3) + 1
        t_2 = t[1].unsqueeze(dim=2).unsqueeze(dim=3)
        gn_1 = torch.nn.functional.silu(self.gn_1(x))
        f_1 = self.f_1(self.f_x(gn_1))
        gn_2 = self.gn_2(f_1)
        # addcmul routing: torch.addcmul(t_2, gn_2, t_1) == gn_2 * t_1 + t_2
        addcmul = torch.nn.functional.silu(gn_2 * t_1 + t_2)
        return self.f_s(self.f_x(x_skip)) + self.f_2(addcmul)
# ConsistencyDecoder aka super resolution from 4 to 3 channels!
class ConsistencyDecoder(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.embed_image = ImageEmbedding()
        self.embed_time = TimestepEmbedding()
        # No attention is needed here!
        # We only "upscale" (48x that is, or 64x if you don't count the channel diff).
        # I was close to doing that, but I had CrossAttn over the VAE latent reshaped
        # to Bx(HW divided by whatever, or -1 if you prefer)x1024 alongside DiffNeXt's skip.
        # 3 ResBlocks before each downsample, repeated 4 times;
        # the down widths are [320, 640, 1024, 1024].
        # In reality it distinguishes between conv and downsamp blocks.
        # Chess Battle Advanced
        down_0 = torch.nn.ModuleList([
            ConvResblock(320, 320),
            ConvResblock(320, 320),
            ConvResblock(320, 320),
            # Downsample(320),
            ConvResblock(320, 320, down=True),
        ])
        down_1 = torch.nn.ModuleList([
            ConvResblock(320, 640, skip_conv=True),
            ConvResblock(640, 640),
            ConvResblock(640, 640),
            # Downsample(640),
            ConvResblock(640, 640, down=True),
        ])
        down_2 = torch.nn.ModuleList([
            ConvResblock(640, 1024, skip_conv=True),
            ConvResblock(1024, 1024),
            ConvResblock(1024, 1024),
            # Downsample(1024),
            ConvResblock(1024, 1024, down=True),
        ])
        down_3 = torch.nn.ModuleList([
            ConvResblock(1024, 1024),
            ConvResblock(1024, 1024),
            ConvResblock(1024, 1024),
        ])
        self.down = torch.nn.ModuleList([
            down_0,
            down_1,
            down_2,
            down_3,
        ])
        # mid has 2
        self.mid = torch.nn.ModuleList([
            ConvResblock(1024, 1024),
            ConvResblock(1024, 1024),
        ])
        # Again, Chess Battle Advanced
        up_3 = torch.nn.ModuleList([
            ConvResblock(1024 * 2, 1024, skip_conv=True),
            ConvResblock(1024 * 2, 1024, skip_conv=True),
            ConvResblock(1024 * 2, 1024, skip_conv=True),
            ConvResblock(1024 * 2, 1024, skip_conv=True),
            # Upsample(1024),
            ConvResblock(1024, 1024, up=True),
        ])
        up_2 = torch.nn.ModuleList([
            ConvResblock(1024 * 2, 1024, skip_conv=True),
            ConvResblock(1024 * 2, 1024, skip_conv=True),
            ConvResblock(1024 * 2, 1024, skip_conv=True),
            ConvResblock(1024 + 640, 1024, skip_conv=True),
            # Upsample(1024),
            ConvResblock(1024, 1024, up=True),
        ])
        up_1 = torch.nn.ModuleList([
            ConvResblock(1024 + 640, 640, skip_conv=True),
            ConvResblock(640 * 2, 640, skip_conv=True),
            ConvResblock(640 * 2, 640, skip_conv=True),
            ConvResblock(320 + 640, 640, skip_conv=True),
            # Upsample(640),
            ConvResblock(640, 640, up=True),
        ])
        up_0 = torch.nn.ModuleList([
            ConvResblock(320 + 640, 320, skip_conv=True),
            ConvResblock(320 * 2, 320, skip_conv=True),
            ConvResblock(320 * 2, 320, skip_conv=True),
            ConvResblock(320 * 2, 320, skip_conv=True),
        ])
        self.up = torch.nn.ModuleList([
            up_0,
            up_1,
            up_2,
            up_3,
        ])
        # ImageUnembedding
        self.output = ImageUnembedding()

    def forward(self, x, t, features) -> torch.Tensor:
        t = self.embed_time(t)
        # LITERAL SUPER RESOLUTION: concat the noisy image with the
        # nearest-upsampled (8x) VAE latent along channels (3 + 4 = 7)
        x = torch.cat(
            # upsample_nearest is deprecated; interpolate(mode="nearest") is equivalent
            [x, torch.nn.functional.interpolate(features, scale_factor=8, mode="nearest")],
            dim=1,
        )
        x = self.embed_image(x)
        # DOWN
        block_outs = [x]
        for mod in self.down:
            for f in mod:
                x = f(x, t)
                block_outs.append(x)
        # mid
        for f in self.mid:
            x = f(x, t)
        # UP
        for mod in self.up[::-1]:
            for f in mod:
                if not f.up:
                    x = torch.concat([x, block_outs.pop()], dim=1)
                x = f(x, t)
        # OUT: GN -> silu -> conv
        x = self.output(x)
        return x
if __name__ == "__main__":
    model = ConsistencyDecoder()
    print(model.state_dict().keys(), model.embed_time.emb.weight.shape)

    import safetensors.torch

    cd_orig = safetensors.torch.load_file("consistency_decoder.safetensors")
    # print(cd_orig.keys())
    # Rename the original checkpoint's keys to this module's names;
    # order matters, so the replacements are applied sequentially.
    replacements = [
        # prefix
        ("blocks.", ""),
        # layer names
        ("down_0_", "down.0."), ("down_1_", "down.1."),
        ("down_2_", "down.2."), ("down_3_", "down.3."),
        ("up_0_", "up.0."), ("up_1_", "up.1."), ("up_2_", "up.2."),
        ("up_3_", "up.3."), ("up_4_", "up.4."),
        ("conv_0.", "0."), ("conv_1.", "1."), ("conv_2.", "2."), ("conv_3.", "3."),
        ("upsamp.", "4."), ("downsamp.", "3."),
        ("mid_0", "mid.0"), ("mid_1", "mid.1"),
        # conv + linear
        ("f_t.w", "f_t.weight"), ("f_t.b", "f_t.bias"),
        ("f_1.w", "f_1.weight"), ("f_1.b", "f_1.bias"),
        ("f_2.w", "f_2.weight"), ("f_2.b", "f_2.bias"),
        ("f_s.w", "f_s.weight"), ("f_s.b", "f_s.bias"),
        ("f.w", "f.weight"), ("f.b", "f.bias"),
        # GroupNorm
        ("gn_1.g", "gn_1.weight"), ("gn_1.b", "gn_1.bias"),
        ("gn_2.g", "gn_2.weight"), ("gn_2.b", "gn_2.bias"),
        ("gn.g", "gn.weight"), ("gn.b", "gn.bias"),
    ]
    for old, new in replacements:
        cd_orig = {k.replace(old, new): v for k, v in cd_orig.items()}
    print(cd_orig.keys())
    cd_orig["embed_time.emb.weight"] = safetensors.torch.load_file("embedding.safetensors")["weight"]
    model.load_state_dict(cd_orig)
    print(cd_orig["embed_time.emb.weight"][1][0])

    def round_timesteps(timesteps, total_timesteps, n_distilled_steps, truncate_start=True):
        with torch.no_grad():
            space = torch.div(total_timesteps, n_distilled_steps, rounding_mode="floor")
            rounded_timesteps = (torch.div(timesteps, space, rounding_mode="floor") + 1) * space
            if truncate_start:
                rounded_timesteps[rounded_timesteps == total_timesteps] -= space
            else:
                rounded_timesteps[rounded_timesteps == total_timesteps] -= space
                rounded_timesteps[rounded_timesteps == 0] += space
            return rounded_timesteps

    ts = round_timesteps(torch.arange(0, 1024), 1024, 64, truncate_start=False)
    print(ts[0], ts.shape)
    # model.forward(torch.zeros(1, 3, 256, 256), torch.zeros(1, dtype=torch.int), torch.zeros(1, 4, 256 // 8, 256 // 8))
    # smoke test: one forward pass on dummy tensors
    model.forward(torch.zeros(1, 3, 256, 256), torch.tensor([ts[0].item()] * 1), torch.zeros(1, 4, 256 // 8, 256 // 8))
    safetensors.torch.save_file(model.state_dict(), "stk_consistency_decoder_amalgamated.safetensors")
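
As a sanity check on the scale/shift routing in ConvResblock.forward, here is a tiny standalone snippet (my addition) confirming that torch.addcmul(t_2, gn_2, t_1), i.e. input + tensor1 * tensor2, matches the gn_2 * t_1 + t_2 expression used above:

import torch

gn_2 = torch.randn(1, 320, 8, 8)     # normalized conv features
t_1 = torch.randn(1, 320, 1, 1) + 1  # the (1 + scale) term from the timestep MLP
t_2 = torch.randn(1, 320, 1, 1)      # the shift term from the timestep MLP
assert torch.allclose(torch.addcmul(t_2, gn_2, t_1), gn_2 * t_1 + t_2)
print("addcmul routing matches")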
madebyollin commented Nov 7, 2023

Thanks for this! Made some minor fixes & it seems to work https://gist.github.com/madebyollin/865fa6a18d9099351ddbdfbe7299ccbf
