@madebyollin
Forked from mrsteyk/README.md
Last active April 2, 2024 13:50
dalle_runner_api.model_infra.modules.public_diff_vae

Consistency Decoder PyTorch Model Code

A cleaned-up version of https://gist.github.com/mrsteyk/74ad3ec2f6f823111ae4c90e168505ac, which is in turn based on the public_diff_vae.ConvUNetVAE model from https://github.com/openai/consistencydecoder.

Example Usage

Install the consistency decoder code (for the inference logic) and download the extracted weights:

pip install -q git+https://github.com/openai/consistencydecoder.git
git clone https://huggingface.co/mrsteyk/consistency-decoder-sd15/
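
Optionally, you can sanity-check the downloaded weight files before wiring anything up (a small sketch; it just confirms the two safetensors files load and shows the timestep-embedding shape that rename_state_dict below expects):

from safetensors.torch import load_file

sd = load_file("consistency-decoder-sd15/consistency_decoder.safetensors")
emb = load_file("consistency-decoder-sd15/embedding.safetensors")
print(f"{len(sd)} decoder tensors, embedding shape {tuple(emb['weight'].shape)}")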

Then, run the standard sample code (but replace the jitted checkpoint with a ConvUNetVAE instance):

import torch
from diffusers import StableDiffusionPipeline
from consistencydecoder import ConsistencyDecoder, save_image, load_image

from conv_unet_vae import ConvUNetVAE, rename_state_dict
from safetensors.torch import load_file as stl

# encode with stable diffusion vae
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
pipe.vae.cuda()

# construct original decoder with jitted model
decoder_consistency = ConsistencyDecoder(device="cuda:0")

# construct UNet code, overwrite the decoder with conv_unet_vae
model = ConvUNetVAE()
model.load_state_dict(
    rename_state_dict(
        stl("consistency-decoder-sd15/consistency_decoder.safetensors"),
        stl("consistency-decoder-sd15/embedding.safetensors"),
    )
)
model = model.cuda()
decoder_consistency.ckpt = model

image = load_image("test_dog_image.jpg", size=(256, 256), center_crop=True)
latent = pipe.vae.encode(image.half().cuda()).latent_dist.sample()

# decode with gan
sample_gan = pipe.vae.decode(latent).sample.detach()
save_image(sample_gan, "gan.png")

# decode with conv_unet_vae
sample_consistency = decoder_consistency(latent)
save_image(sample_consistency, "con.png")

The result should be a faithful reconstruction of the original image:

[image: example reconstruction]
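
For a quick numerical check on top of the saved images, you can append a rough PSNR comparison to the script above (a minimal sketch; the psnr helper is mine, and it assumes image, sample_gan, and sample_consistency all share the [-1, 1] value range that the VAE encode step expects):

import torch
import torch.nn.functional as F

def psnr(a, b, data_range=2.0):
    # peak signal-to-noise ratio; data_range=2.0 corresponds to values in [-1, 1]
    mse = F.mse_loss(a.float().cpu(), b.float().cpu())
    return 10 * torch.log10(data_range**2 / mse)

print("GAN decoder PSNR vs. input:        ", psnr(sample_gan, image).item())
print("Consistency decoder PSNR vs. input:", psnr(sample_consistency, image).item())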

#!/usr/bin/env python3
"""
Cleaned up reimplementation of public_diff_vae.ConvUNetVAE,
thanks to https://gist.github.com/mrsteyk/74ad3ec2f6f823111ae4c90e168505ac.
"""
import torch
import torch.nn.functional as F
import torch.nn as nn
class TimestepEmbedding(nn.Module):
    def __init__(self, n_time=1024, n_emb=320, n_out=1280) -> None:
        super().__init__()
        self.emb = nn.Embedding(n_time, n_emb)
        self.f_1 = nn.Linear(n_emb, n_out)
        self.f_2 = nn.Linear(n_out, n_out)

    def forward(self, x) -> torch.Tensor:
        x = self.emb(x)
        x = self.f_1(x)
        x = F.silu(x)
        return self.f_2(x)


class ImageEmbedding(nn.Module):
    def __init__(self, in_channels=7, out_channels=320) -> None:
        super().__init__()
        self.f = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x) -> torch.Tensor:
        return self.f(x)


class ImageUnembedding(nn.Module):
    def __init__(self, in_channels=320, out_channels=6) -> None:
        super().__init__()
        self.gn = nn.GroupNorm(32, in_channels)
        self.f = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x) -> torch.Tensor:
        return self.f(F.silu(self.gn(x)))


class ConvResblock(nn.Module):
    def __init__(self, in_features=320, out_features=320) -> None:
        super().__init__()
        self.f_t = nn.Linear(1280, out_features * 2)
        self.gn_1 = nn.GroupNorm(32, in_features)
        self.f_1 = nn.Conv2d(in_features, out_features, kernel_size=3, padding=1)
        self.gn_2 = nn.GroupNorm(32, out_features)
        self.f_2 = nn.Conv2d(out_features, out_features, kernel_size=3, padding=1)
        skip_conv = in_features != out_features
        self.f_s = (
            nn.Conv2d(in_features, out_features, kernel_size=1, padding=0)
            if skip_conv
            else nn.Identity()
        )

    def forward(self, x, t):
        x_skip = x
        # scale / shift conditioning derived from the timestep embedding
        t = self.f_t(F.silu(t))
        t = t.chunk(2, dim=1)
        t_1 = t[0].unsqueeze(dim=2).unsqueeze(dim=3) + 1
        t_2 = t[1].unsqueeze(dim=2).unsqueeze(dim=3)

        gn_1 = F.silu(self.gn_1(x))
        f_1 = self.f_1(gn_1)
        gn_2 = self.gn_2(f_1)

        return self.f_s(x_skip) + self.f_2(F.silu(gn_2 * t_1 + t_2))


# Also ConvResblock
class Downsample(nn.Module):
    def __init__(self, in_channels=320) -> None:
        super().__init__()
        self.f_t = nn.Linear(1280, in_channels * 2)
        self.gn_1 = nn.GroupNorm(32, in_channels)
        self.f_1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        self.gn_2 = nn.GroupNorm(32, in_channels)
        self.f_2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)

    def forward(self, x, t) -> torch.Tensor:
        x_skip = x

        t = self.f_t(F.silu(t))
        t_1, t_2 = t.chunk(2, dim=1)
        t_1 = t_1.unsqueeze(2).unsqueeze(3) + 1
        t_2 = t_2.unsqueeze(2).unsqueeze(3)

        gn_1 = F.silu(self.gn_1(x))
        avg_pool2d = F.avg_pool2d(gn_1, kernel_size=(2, 2), stride=None)

        f_1 = self.f_1(avg_pool2d)
        gn_2 = self.gn_2(f_1)
        f_2 = self.f_2(F.silu(t_2 + (t_1 * gn_2)))

        return f_2 + F.avg_pool2d(x_skip, kernel_size=(2, 2), stride=None)


# Also ConvResblock
class Upsample(nn.Module):
    def __init__(self, in_channels=1024) -> None:
        super().__init__()
        self.f_t = nn.Linear(1280, in_channels * 2)
        self.gn_1 = nn.GroupNorm(32, in_channels)
        self.f_1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        self.gn_2 = nn.GroupNorm(32, in_channels)
        self.f_2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)

    def forward(self, x, t) -> torch.Tensor:
        x_skip = x

        t = self.f_t(F.silu(t))
        t_1, t_2 = t.chunk(2, dim=1)
        t_1 = t_1.unsqueeze(2).unsqueeze(3) + 1
        t_2 = t_2.unsqueeze(2).unsqueeze(3)

        gn_1 = F.silu(self.gn_1(x))
        upsample = F.upsample_nearest(gn_1, scale_factor=2)

        f_1 = self.f_1(upsample)
        gn_2 = self.gn_2(f_1)
        f_2 = self.f_2(F.silu(t_2 + (t_1 * gn_2)))

        return f_2 + F.upsample_nearest(x_skip, scale_factor=2)


class ConvUNetVAE(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.embed_image = ImageEmbedding()
        self.embed_time = TimestepEmbedding()

        down_0 = nn.ModuleList(
            [
                ConvResblock(320, 320),
                ConvResblock(320, 320),
                ConvResblock(320, 320),
                Downsample(320),
            ]
        )
        down_1 = nn.ModuleList(
            [
                ConvResblock(320, 640),
                ConvResblock(640, 640),
                ConvResblock(640, 640),
                Downsample(640),
            ]
        )
        down_2 = nn.ModuleList(
            [
                ConvResblock(640, 1024),
                ConvResblock(1024, 1024),
                ConvResblock(1024, 1024),
                Downsample(1024),
            ]
        )
        down_3 = nn.ModuleList(
            [
                ConvResblock(1024, 1024),
                ConvResblock(1024, 1024),
                ConvResblock(1024, 1024),
            ]
        )
        self.down = nn.ModuleList(
            [
                down_0,
                down_1,
                down_2,
                down_3,
            ]
        )

        self.mid = nn.ModuleList(
            [
                ConvResblock(1024, 1024),
                ConvResblock(1024, 1024),
            ]
        )

        up_3 = nn.ModuleList(
            [
                ConvResblock(1024 * 2, 1024),
                ConvResblock(1024 * 2, 1024),
                ConvResblock(1024 * 2, 1024),
                ConvResblock(1024 * 2, 1024),
                Upsample(1024),
            ]
        )
        up_2 = nn.ModuleList(
            [
                ConvResblock(1024 * 2, 1024),
                ConvResblock(1024 * 2, 1024),
                ConvResblock(1024 * 2, 1024),
                ConvResblock(1024 + 640, 1024),
                Upsample(1024),
            ]
        )
        up_1 = nn.ModuleList(
            [
                ConvResblock(1024 + 640, 640),
                ConvResblock(640 * 2, 640),
                ConvResblock(640 * 2, 640),
                ConvResblock(320 + 640, 640),
                Upsample(640),
            ]
        )
        up_0 = nn.ModuleList(
            [
                ConvResblock(320 + 640, 320),
                ConvResblock(320 * 2, 320),
                ConvResblock(320 * 2, 320),
                ConvResblock(320 * 2, 320),
            ]
        )
        self.up = nn.ModuleList(
            [
                up_0,
                up_1,
                up_2,
                up_3,
            ]
        )

        self.output = ImageUnembedding()

    def forward(self, x, t, features) -> torch.Tensor:
        # concatenate the (upsampled) latent features onto the image-space input
        x = torch.cat([x, F.upsample_nearest(features, scale_factor=8)], dim=1)
        t = self.embed_time(t)
        x = self.embed_image(x)

        skips = [x]
        for down in self.down:
            for block in down:
                x = block(x, t)
                skips.append(x)
        for i in range(2):
            x = self.mid[i](x, t)
        for up in self.up[::-1]:
            for block in up:
                if isinstance(block, ConvResblock):
                    x = torch.concat([x, skips.pop()], dim=1)
                x = block(x, t)

        return self.output(x)


def rename_state_dict_key(k):
    k = k.replace("blocks.", "")
    for i in range(5):
        k = k.replace(f"down_{i}_", f"down.{i}.")
        k = k.replace(f"conv_{i}.", f"{i}.")
        k = k.replace(f"up_{i}_", f"up.{i}.")
        k = k.replace(f"mid_{i}", f"mid.{i}")
    k = k.replace("upsamp.", "4.")
    k = k.replace("downsamp.", "3.")
    k = k.replace("f_t.w", "f_t.weight").replace("f_t.b", "f_t.bias")
    k = k.replace("f_1.w", "f_1.weight").replace("f_1.b", "f_1.bias")
    k = k.replace("f_2.w", "f_2.weight").replace("f_2.b", "f_2.bias")
    k = k.replace("f_s.w", "f_s.weight").replace("f_s.b", "f_s.bias")
    k = k.replace("f.w", "f.weight").replace("f.b", "f.bias")
    k = k.replace("gn_1.g", "gn_1.weight").replace("gn_1.b", "gn_1.bias")
    k = k.replace("gn_2.g", "gn_2.weight").replace("gn_2.b", "gn_2.bias")
    k = k.replace("gn.g", "gn.weight").replace("gn.b", "gn.bias")
    return k


def rename_state_dict(sd, embedding):
    sd = {rename_state_dict_key(k): v for k, v in sd.items()}
    sd["embed_time.emb.weight"] = embedding["weight"]
    return sd


if __name__ == "__main__":
    model = ConvUNetVAE()
    import safetensors.torch
    cd_orig = safetensors.torch.load_file("consistency_decoder.safetensors")
    embedding = safetensors.torch.load_file("embedding.safetensors")
    print(model.load_state_dict(rename_state_dict(cd_orig, embedding)))
@patrickvonplaten

Amazing work!

@mrsteyk

mrsteyk commented Nov 7, 2023

Uploaded a ready-to-use checkpoint and combined Downsample and Upsample into ConvResblock in my original code (the internal name for all of them is the same), which should make the code slightly cleaner. It probably still needs f_x decoupled into f_x and f_s, skip_conv=True removed, and an if on the channel counts instead, to be more faithful to the original OpenAI UNet code.
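
For readers following along, here is a rough sketch of what such a combined block could look like, derived from the three classes in the gist above; the down/up flags and the class name are illustrative, not mrsteyk's actual implementation:

import torch.nn as nn
import torch.nn.functional as F


class CombinedConvResblock(nn.Module):
    """Illustrative sketch: ConvResblock / Downsample / Upsample as one class."""

    def __init__(self, in_features, out_features, down=False, up=False) -> None:
        super().__init__()
        assert not (down and up)
        self.down, self.up = down, up
        self.f_t = nn.Linear(1280, out_features * 2)
        self.gn_1 = nn.GroupNorm(32, in_features)
        self.f_1 = nn.Conv2d(in_features, out_features, kernel_size=3, padding=1)
        self.gn_2 = nn.GroupNorm(32, out_features)
        self.f_2 = nn.Conv2d(out_features, out_features, kernel_size=3, padding=1)
        # the skip path only needs a 1x1 conv when the channel count changes
        self.f_s = (
            nn.Conv2d(in_features, out_features, kernel_size=1, padding=0)
            if in_features != out_features
            else nn.Identity()
        )

    def resample(self, x):
        if self.down:
            return F.avg_pool2d(x, kernel_size=(2, 2))
        if self.up:
            return F.interpolate(x, scale_factor=2, mode="nearest")
        return x

    def forward(self, x, t):
        x_skip = self.resample(x)
        t_1, t_2 = self.f_t(F.silu(t)).chunk(2, dim=1)
        t_1 = t_1.unsqueeze(2).unsqueeze(3) + 1
        t_2 = t_2.unsqueeze(2).unsqueeze(3)
        h = self.resample(F.silu(self.gn_1(x)))
        h = self.gn_2(self.f_1(h))
        return self.f_s(x_skip) + self.f_2(F.silu(h * t_1 + t_2))

Here the 1x1 skip conv is gated on the channel counts as suggested, and ConvResblock, Downsample, and Upsample become the default, down=True, and up=True cases respectively.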

@city96

city96 commented Nov 7, 2023

With a few changes it seems to work in fp16/bf16 too, which makes it a lot more feasible than being locked to loading the model with torch.jit. Great job on this!

@madebyollin

@mrsteyk Excellent! I like the combined resblock idea.

@city96 Yeah, I think float16 should work out of the box and bfloat16 only needed adjusted upsampling code to work on my machine (torch 2.0.1+cu117), so hopefully it's easy to integrate.

@city96

city96 commented Nov 8, 2023

@madebyollin I just ended up casting it to float and then back to the dtype before/after the upsample bit. I'd imagine the perf loss isn't too bad but I might test it against your code sometime. Thanks!
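
For anyone hitting the same issue, that workaround is roughly a float32 round-trip around the nearest-neighbour upsample calls (my paraphrase as a helper function, not city96's exact patch):

import torch
import torch.nn.functional as F

def upsample_nearest_f32(x: torch.Tensor, scale_factor: int) -> torch.Tensor:
    # work around the limited bf16 support for nearest upsampling discussed above
    # by computing in float32, then casting back to the original dtype
    return F.interpolate(x.float(), scale_factor=scale_factor, mode="nearest").to(x.dtype)

Swapping this in for the F.upsample_nearest calls in Upsample.forward and ConvUNetVAE.forward should be enough to run the rest of the model in bfloat16.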

@RomainSF

RomainSF commented Nov 8, 2023

What does it take to make this compatible with SDXL? I guess the training source needed to retrain it is missing?

@madebyollin

@city96 Makes sense! Most of the time should be convs anyway so it probably doesn't matter much.

@RomainSF Yeah, you'd need to train / fine-tune the CD model on SDXL latents to make it compatible. You can technically use CD+SDXL as-is using city96's SDXL->SD latent interposer Birch-san/sdxl-play#4 (comment) but it seems like the quality lost from running the interposer usually exceeds the quality gained from using CD decoding.

@city96

city96 commented Nov 11, 2023

@madebyollin I commented on that linked thread, but might as well mention it here as well: the interposer 100% causes more quality loss than the consistency decoder makes up for. It's fairly crude, mostly meant to let you use SDXL latents at a high denoise with SDv1 models. I'd be happy to improve it, if I knew how to lol.

@madebyollin

@city96 Yeah, your interposer is difficult to beat 😅. Switching to diffusion for the interposer itself might help a bit (instead of L1 / perceptual losses), but there will still inevitably be some destruction of info if we do SDXL Latents -> Ideal Interposer -> SD Latents -> CD, since SDXL's latent format can express some information which the SD latent format can't.

For super high quality SDXL decoding, I expect fine-tuning the CD model would ultimately be the best option.

@city96

city96 commented Nov 11, 2023

@madebyollin

switching to diffusion for the interposer itself might help a bit

Not sure I follow. Do you mean doing something similar to what CD does (i.e. a multi-step process)? I fear that it might get too slow if I do something like that. And if it isn't really faster than doing a VAE decode->encode between the two models then the whole thing becomes a bit useless lol.

there will still inevitably be some destruction of info if we do SDXL Latents->Ideal Interposer->SD Latents->CD

Yup. So even if the XL latent encoded a letter on the image correctly, that'd get lost on the way and CD would just make up a similar looking letter from the data that's there, since it was only ever trained to work with the KL-F8 encoder. (Reverse also applies, hence why the v1->xl interposer has higher loss. At that point you're asking the tiny 5MB model to not only match the format, but to make up fake details...)

For super high quality SDXL decoding, I expect fine-tuning the CD model would ultimately be the best option.

Agreed. Hope they end up releasing the training code. Would be interesting to see CD fine-tuned for specialized use cases as well (realism vs. art, etc.).

@madebyollin

@city96 Yeah, I was imagining a multistep interposer (small diffusion model, small # sampling steps). But I agree - it's easier to just roundtrip to pixel space at that point 😆

@Doctor-James

Thank you for your excellent work. May I ask how to train a ConsistencyDecoder from scratch? OpenAI's repository does not provide a complete training process. Could you offer some suggestions?
