Skip to content

Instantly share code, notes, and snippets.

@madebyollin
Forked from mrsteyk/README.md
Last active April 2, 2024 13:50
Show Gist options
  • Save madebyollin/865fa6a18d9099351ddbdfbe7299ccbf to your computer and use it in GitHub Desktop.
Save madebyollin/865fa6a18d9099351ddbdfbe7299ccbf to your computer and use it in GitHub Desktop.
dalle_runner_api.model_infra.modules.public_diff_vae

Consistency Decoder PyTorch Model Code

Cleaned-up version of https://gist.github.com/mrsteyk/74ad3ec2f6f823111ae4c90e168505ac, which is in turn based on `public_diff_vae.ConvUNetVAE` from https://github.com/openai/consistencydecoder.

Example Usage

Install the consistency decoder code (for the inference logic) and download the extracted weights:

pip install -q git+https://github.com/openai/consistencydecoder.git
git clone https://huggingface.co/mrsteyk/consistency-decoder-sd15/

Then, run the standard sample code (but replace the jitted checkpoint with a ConvUNetVAE instance):

# Example: compare the Stable Diffusion VAE decoder with the consistency decoder,
# where the consistency decoder's jitted network is swapped for ConvUNetVAE.
# NOTE(review): assumes a CUDA device, the consistencydecoder package, the cloned
# consistency-decoder-sd15/ weights, a local "test_dog_image.jpg", and that this
# gist's module is saved as conv_unet_vae.py (per the import below).
import torch
from diffusers import StableDiffusionPipeline
from consistencydecoder import ConsistencyDecoder, save_image, load_image

from conv_unet_vae import ConvUNetVAE, rename_state_dict
from safetensors.torch import load_file as stl

# encode with stable diffusion vae
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
pipe.vae.cuda()

# construct original decoder with jitted model
decoder_consistency = ConsistencyDecoder(device="cuda:0")

# construct UNet code, overwrite the decoder with conv_unet_vae
# (load the extracted safetensors weights, renamed to ConvUNetVAE's state-dict
# keys, then replace the decoder's jitted `ckpt` with the plain PyTorch model)
model = ConvUNetVAE()
model.load_state_dict(
    rename_state_dict(
        stl("consistency-decoder-sd15/consistency_decoder.safetensors"),
        stl("consistency-decoder-sd15/embedding.safetensors"),
    )
)
model = model.cuda()
decoder_consistency.ckpt = model

image = load_image("test_dog_image.jpg", size=(256, 256), center_crop=True)
# .half() matches the pipeline's torch_dtype=torch.float16
latent = pipe.vae.encode(image.half().cuda()).latent_dist.sample()

# decode with gan
# (the pipeline's own VAE decoder, used as the comparison baseline)
sample_gan = pipe.vae.decode(latent).sample.detach()
save_image(sample_gan, "gan.png")

# decode with conv_unet_vae
# (the ConsistencyDecoder now runs the ConvUNetVAE assigned to .ckpt above)
sample_consistency = decoder_consistency(latent)
save_image(sample_consistency, "con.png")

The result should be a faithful reconstruction of the original image (the comparison image shown on the original page is not reproduced here).

#!/usr/bin/env python3
"""
Cleaned up reimplementation of public_diff_vae.ConvUNetVAE,
thanks to https://gist.github.com/mrsteyk/74ad3ec2f6f823111ae4c90e168505ac.
"""
import torch
import torch.nn.functional as F
import torch.nn as nn
class TimestepEmbedding(nn.Module):
    """Turn integer timestep indices into conditioning vectors.

    Pipeline: lookup table with ``n_time`` rows of width ``n_emb`` ->
    Linear(n_emb, n_out) -> SiLU -> Linear(n_out, n_out).
    """

    def __init__(self, n_time=1024, n_emb=320, n_out=1280) -> None:
        super().__init__()
        # Attribute names double as state-dict keys (e.g. "emb.weight"); keep them.
        self.emb = nn.Embedding(n_time, n_emb)
        self.f_1 = nn.Linear(n_emb, n_out)
        self.f_2 = nn.Linear(n_out, n_out)

    def forward(self, x) -> torch.Tensor:
        hidden = F.silu(self.f_1(self.emb(x)))
        return self.f_2(hidden)
class ImageEmbedding(nn.Module):
    """Project the network input to the UNet's base width with one 3x3 conv.

    Spatial size is preserved (padding 1). The default ``in_channels=7`` is the
    channel count of the concatenated input built in ``ConvUNetVAE.forward`` —
    presumably 3 image + 4 latent channels; confirm against the checkpoint.
    """

    def __init__(self, in_channels=7, out_channels=320) -> None:
        super().__init__()
        self.f = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=(3, 3),
            padding=(1, 1),
        )

    def forward(self, x) -> torch.Tensor:
        conv = self.f
        return conv(x)
class ImageUnembedding(nn.Module):
    """Map UNet features to the output channels: GroupNorm -> SiLU -> 3x3 conv.

    ``in_channels`` must be divisible by 32 (the GroupNorm group count).
    Spatial size is preserved.
    """

    def __init__(self, in_channels=320, out_channels=6) -> None:
        super().__init__()
        self.gn = nn.GroupNorm(num_groups=32, num_channels=in_channels)
        self.f = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x) -> torch.Tensor:
        normed = self.gn(x)
        activated = F.silu(normed)
        return self.f(activated)
class ConvResblock(nn.Module):
    """Residual conv block conditioned on a timestep embedding.

    Computes::

        h = f_1(silu(gn_1(x)))
        scale, shift = split(f_t(silu(t)), 2, dim=1)
        out = f_s(x) + f_2(silu(gn_2(h) * (1 + scale) + shift))

    ``x`` is (N, in_features, H, W) and ``t`` is (N, 1280); the conditioning
    Linear's input width is fixed at 1280. ``f_s`` is a 1x1 conv when the
    channel count changes and an identity otherwise. Spatial size is preserved.
    """

    def __init__(self, in_features=320, out_features=320) -> None:
        super().__init__()
        self.f_t = nn.Linear(1280, out_features * 2)
        self.gn_1 = nn.GroupNorm(32, in_features)
        self.f_1 = nn.Conv2d(in_features, out_features, kernel_size=3, padding=1)
        self.gn_2 = nn.GroupNorm(32, out_features)
        self.f_2 = nn.Conv2d(out_features, out_features, kernel_size=3, padding=1)
        # The residual only needs a learned projection when widths differ.
        if in_features == out_features:
            self.f_s = nn.Identity()
        else:
            self.f_s = nn.Conv2d(in_features, out_features, kernel_size=1, padding=0)

    def forward(self, x, t):
        scale, shift = self.f_t(F.silu(t)).chunk(2, dim=1)
        # Broadcast the per-channel scale (centred on 1) and shift over H and W.
        scale = scale[:, :, None, None] + 1
        shift = shift[:, :, None, None]
        h = self.f_1(F.silu(self.gn_1(x)))
        h = F.silu(self.gn_2(h) * scale + shift)
        return self.f_s(x) + self.f_2(h)
# Also a ConvResblock-style block (timestep-conditioned residual), with 2x average-pool downsampling
class Downsample(nn.Module):
    """Timestep-conditioned residual block that halves the spatial resolution.

    Same layout as ``ConvResblock`` with equal in/out channels, except that the
    main path 2x2-average-pools after the first norm + SiLU, and the residual
    path is a 2x2 average pool of the input (no projection).
    ``t`` is (N, 1280), matching the fixed Linear input width.
    """

    def __init__(self, in_channels=320) -> None:
        super().__init__()
        self.f_t = nn.Linear(1280, in_channels * 2)
        self.gn_1 = nn.GroupNorm(32, in_channels)
        self.f_1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        self.gn_2 = nn.GroupNorm(32, in_channels)
        self.f_2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)

    def forward(self, x, t) -> torch.Tensor:
        scale, shift = self.f_t(F.silu(t)).chunk(2, dim=1)
        scale = scale[:, :, None, None] + 1
        shift = shift[:, :, None, None]
        h = F.avg_pool2d(F.silu(self.gn_1(x)), kernel_size=(2, 2), stride=None)
        h = self.gn_2(self.f_1(h))
        h = self.f_2(F.silu(shift + scale * h))
        residual = F.avg_pool2d(x, kernel_size=(2, 2), stride=None)
        return h + residual
# Also a ConvResblock-style block (timestep-conditioned residual), with 2x nearest-neighbour upsampling
class Upsample(nn.Module):
    """Timestep-conditioned residual block that doubles the spatial resolution.

    Same layout as ``ConvResblock`` with equal in/out channels, except that the
    main path nearest-upsamples 2x after the first norm + SiLU, and the residual
    path is a 2x nearest upsample of the input (no projection).
    ``t`` is (N, 1280), matching the fixed Linear input width.
    """

    def __init__(self, in_channels=1024) -> None:
        super().__init__()
        self.f_t = nn.Linear(1280, in_channels * 2)
        self.gn_1 = nn.GroupNorm(32, in_channels)
        self.f_1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        self.gn_2 = nn.GroupNorm(32, in_channels)
        self.f_2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)

    def forward(self, x, t) -> torch.Tensor:
        x_skip = x
        t = self.f_t(F.silu(t))
        t_1, t_2 = t.chunk(2, dim=1)
        # Per-channel scale (centred on 1) and shift, broadcast over H and W.
        t_1 = t_1.unsqueeze(2).unsqueeze(3) + 1
        t_2 = t_2.unsqueeze(2).unsqueeze(3)
        gn_1 = F.silu(self.gn_1(x))
        # F.upsample_nearest is deprecated and warns on every call;
        # F.interpolate(mode="nearest") is its replacement with identical output.
        upsample = F.interpolate(gn_1, scale_factor=2, mode="nearest")
        f_1 = self.f_1(upsample)
        gn_2 = self.gn_2(f_1)
        f_2 = self.f_2(F.silu(t_2 + (t_1 * gn_2)))
        return f_2 + F.interpolate(x_skip, scale_factor=2, mode="nearest")
class ConvUNetVAE(nn.Module):
    """Conv UNet used as the consistency decoder's network.

    ``forward(x, t, features)``:
      * ``features`` is nearest-upsampled 8x and concatenated to ``x`` on the
        channel axis (presumably the 4-channel SD VAE latent — confirm), then
        projected by ``embed_image``;
      * ``t`` (integer timestep indices) is embedded by ``embed_time`` and fed to
        every residual block;
      * the encoder (``down``) stores every activation as a skip; the decoder
        (``up``, walked from ``up[3]`` to ``up[0]``) concatenates one skip before
        each ``ConvResblock``;
      * ``output`` maps back to image space (6 channels by default).

    The attribute names and ModuleList nesting define the state-dict keys that
    ``rename_state_dict`` targets (e.g. ``down.0.3`` is a Downsample and
    ``up.1.4`` an Upsample), so they must not be reorganised.
    """

    def __init__(self) -> None:
        super().__init__()
        self.embed_image = ImageEmbedding()
        self.embed_time = TimestepEmbedding()
        down_0 = nn.ModuleList(
            [
                ConvResblock(320, 320),
                ConvResblock(320, 320),
                ConvResblock(320, 320),
                Downsample(320),
            ]
        )
        down_1 = nn.ModuleList(
            [
                ConvResblock(320, 640),
                ConvResblock(640, 640),
                ConvResblock(640, 640),
                Downsample(640),
            ]
        )
        down_2 = nn.ModuleList(
            [
                ConvResblock(640, 1024),
                ConvResblock(1024, 1024),
                ConvResblock(1024, 1024),
                Downsample(1024),
            ]
        )
        down_3 = nn.ModuleList(
            [
                ConvResblock(1024, 1024),
                ConvResblock(1024, 1024),
                ConvResblock(1024, 1024),
            ]
        )
        self.down = nn.ModuleList(
            [
                down_0,
                down_1,
                down_2,
                down_3,
            ]
        )
        self.mid = nn.ModuleList(
            [
                ConvResblock(1024, 1024),
                ConvResblock(1024, 1024),
            ]
        )
        # Decoder input widths = current width + width of the skip popped for
        # that block (skips are consumed in reverse order of creation).
        up_3 = nn.ModuleList(
            [
                ConvResblock(1024 * 2, 1024),
                ConvResblock(1024 * 2, 1024),
                ConvResblock(1024 * 2, 1024),
                ConvResblock(1024 * 2, 1024),
                Upsample(1024),
            ]
        )
        up_2 = nn.ModuleList(
            [
                ConvResblock(1024 * 2, 1024),
                ConvResblock(1024 * 2, 1024),
                ConvResblock(1024 * 2, 1024),
                ConvResblock(1024 + 640, 1024),
                Upsample(1024),
            ]
        )
        up_1 = nn.ModuleList(
            [
                ConvResblock(1024 + 640, 640),
                ConvResblock(640 * 2, 640),
                ConvResblock(640 * 2, 640),
                ConvResblock(320 + 640, 640),
                Upsample(640),
            ]
        )
        up_0 = nn.ModuleList(
            [
                ConvResblock(320 + 640, 320),
                ConvResblock(320 * 2, 320),
                ConvResblock(320 * 2, 320),
                ConvResblock(320 * 2, 320),
            ]
        )
        # Stored in index order (up.0 .. up.3) to match the checkpoint keys;
        # forward() walks them in reverse.
        self.up = nn.ModuleList(
            [
                up_0,
                up_1,
                up_2,
                up_3,
            ]
        )
        self.output = ImageUnembedding()

    def forward(self, x, t, features) -> torch.Tensor:
        # F.upsample_nearest is deprecated; interpolate(mode="nearest") is the
        # equivalent replacement.
        upsampled = F.interpolate(features, scale_factor=8, mode="nearest")
        x = torch.cat([x, upsampled], dim=1)
        t = self.embed_time(t)
        x = self.embed_image(x)

        # Encoder: keep every intermediate activation for the skip connections.
        skips = [x]
        for down in self.down:
            for block in down:
                x = block(x, t)
                skips.append(x)

        for block in self.mid:
            x = block(x, t)

        # Decoder: only ConvResblocks consume a skip; Upsample blocks just
        # change resolution.
        for up in reversed(self.up):
            for block in up:
                if isinstance(block, ConvResblock):
                    x = torch.cat([x, skips.pop()], dim=1)
                x = block(x, t)
        return self.output(x)
def rename_state_dict_key(k):
    """Translate one extracted-checkpoint key into a ``ConvUNetVAE`` state-dict key.

    Structural renames (flat names -> nested ModuleList indices), e.g.::

        blocks.down_0_conv_1.f_t.w   -> down.0.1.f_t.weight
        blocks.down_2_downsamp.f_1.b -> down.2.3.f_1.bias
        blocks.up_1_upsamp.gn_1.g    -> up.1.4.gn_1.weight
        blocks.mid_0.f_t.w           -> mid.0.f_t.weight

    The parameter suffix is expanded only on the final path component
    (``w``/``b`` on the conv/linear layers ``f``, ``f_t``, ``f_1``, ``f_2``,
    ``f_s``; ``g``/``b`` on the GroupNorms ``gn``, ``gn_1``, ``gn_2``). Keys that
    are already in PyTorch form pass through unchanged, so the function is
    idempotent. The previous chained substring replacement turned
    ``"f.weight"`` into ``"f.weighteight"``.
    """
    k = k.replace("blocks.", "")
    for i in range(5):
        k = k.replace(f"down_{i}_", f"down.{i}.")
        k = k.replace(f"conv_{i}.", f"{i}.")
        k = k.replace(f"up_{i}_", f"up.{i}.")
        k = k.replace(f"mid_{i}", f"mid.{i}")
    # Index of the resampling block inside each stage's ModuleList: after the
    # 4 ConvResblocks of an up stage, after the 3 of a down stage.
    k = k.replace("upsamp.", "4.")
    k = k.replace("downsamp.", "3.")

    parent, dot, leaf = k.rpartition(".")
    module = parent.rpartition(".")[2]
    if module in {"f_t", "f_1", "f_2", "f_s", "f"}:
        leaf = {"w": "weight", "b": "bias"}.get(leaf, leaf)
    elif module in {"gn_1", "gn_2", "gn"}:
        leaf = {"g": "weight", "b": "bias"}.get(leaf, leaf)
    return parent + dot + leaf
def rename_state_dict(sd, embedding):
    """Build a ``ConvUNetVAE``-compatible state dict from the extracted weights.

    Every key of ``sd`` is passed through ``rename_state_dict_key``. The
    timestep-embedding table, stored separately, is taken from
    ``embedding["weight"]`` and placed under ``"embed_time.emb.weight"``
    (overriding any same-named renamed entry). ``sd`` itself is not modified.
    """
    renamed = {}
    for old_key, tensor in sd.items():
        renamed[rename_state_dict_key(old_key)] = tensor
    renamed["embed_time.emb.weight"] = embedding["weight"]
    return renamed
if __name__ == "__main__":
    # Smoke check: load the extracted checkpoint files from the working
    # directory into a fresh model and print load_state_dict's key report.
    import safetensors.torch

    model = ConvUNetVAE()
    cd_orig = safetensors.torch.load_file("consistency_decoder.safetensors")
    embedding = safetensors.torch.load_file("embedding.safetensors")
    state_dict = rename_state_dict(cd_orig, embedding)
    load_report = model.load_state_dict(state_dict)
    print(load_report)
@madebyollin
Copy link
Author

@city96 Yeah, I was imagining a multistep interposer (small diffusion model, small # sampling steps). But I agree - it's easier to just roundtrip to pixel space at that point 😆

@Doctor-James
Copy link

Thank you for your excellent work. May I ask how to train a ConsistencyDecoder from scratch? OpenAI's repository does not provide a complete training process. Could you offer some suggestions?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment