@cmdr2
Created November 24, 2022 04:39
Diff between https://github.com/CompVis/stable-diffusion (sd1) and https://github.com/Stability-AI/stablediffusion (sd2). Ran `diff -w -x '*.gif' -x '*.png' -x '*.jpg' -x LICENSE -x '*.md' -x '*.git' -r sd1 sd2 > d.txt`
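
For reference, a minimal Python sketch of reproducing this dump (assuming `sd1/` and `sd2/` are side-by-side local clones of the two repos; the flags and the `d.txt` output name simply mirror the command above):

```python
# Sketch only: run the same diff command with the same exclusions and capture it in d.txt.
import subprocess

cmd = [
    "diff", "-w", "-r",
    "-x", "*.gif", "-x", "*.png", "-x", "*.jpg",
    "-x", "LICENSE", "-x", "*.md", "-x", "*.git",
    "sd1", "sd2",
]
with open("d.txt", "w") as out:
    # diff exits with status 1 when the trees differ; that's expected here
    subprocess.run(cmd, stdout=out)
```
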
Only in sd2: LICENSE-MODEL
Only in sd2/assets: stable-inpainting
Only in sd2/assets/stable-samples: depth2img
Only in sd2/assets/stable-samples/txt2img: 768
Only in sd2/assets/stable-samples: upscaling
Only in sd1/configs: autoencoder
Only in sd1/configs: latent-diffusion
Only in sd1/configs: retrieval-augmented-diffusion
Only in sd1/configs/stable-diffusion: v1-inference.yaml
Only in sd2/configs/stable-diffusion: v2-inference-v.yaml
Only in sd2/configs/stable-diffusion: v2-inference.yaml
Only in sd2/configs/stable-diffusion: v2-inpainting-inference.yaml
Only in sd2/configs/stable-diffusion: v2-midas-inference.yaml
Only in sd2/configs/stable-diffusion: x4-upscaling.yaml
Only in sd1: data
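
sd2 replaces the single `v1-inference.yaml` with several task-specific configs (768-v, inpainting, depth/midas, x4 upscaling). A rough sketch of how one of these configs is typically consumed, based on the usual OmegaConf pattern in these repos; the checkpoint path is a placeholder, not something taken from this diff, and the sd2 repo is assumed to be on `PYTHONPATH`:

```python
# Sketch: load an sd2 config and build the model it describes.
import torch
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config

config = OmegaConf.load("sd2/configs/stable-diffusion/v2-inference.yaml")
model = instantiate_from_config(config.model)                       # builds the LatentDiffusion model
state = torch.load("path/to/sd2-checkpoint.ckpt", map_location="cpu")  # hypothetical checkpoint path
model.load_state_dict(state["state_dict"], strict=False)
model.eval()
```
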
diff -w -x '*.gif' -x '*.png' -x '*.jpg' -x LICENSE -x '*.md' -x '*.git' -r sd1/environment.yaml sd2/environment.yaml
9,11c9,11
< - pytorch=1.11.0
< - torchvision=0.12.0
< - numpy=1.19.2
---
> - pytorch=1.12.1
> - torchvision=0.13.1
> - numpy=1.23.1
13,17c13,14
< - albumentations==0.4.3
< - diffusers
< - opencv-python==4.1.2.30
< - pudb==2019.2
< - invisible-watermark
---
> - albumentations==1.3.0
> - opencv-python==4.6.0.66
23c20
< - streamlit>=0.73.1
---
> - streamlit==1.12.1
25d21
< - torch-fidelity==0.3.0
27c23
< - torchmetrics==0.6.0
---
> - webdataset==0.2.5
29,30c25,28
< - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
< - -e git+https://github.com/openai/CLIP.git@main#egg=clip
---
> - open_clip_torch==2.0.2
> - invisible-watermark>=0.1.5
> - streamlit-drawable-canvas==0.8.0
> - torchmetrics==0.6.0
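
The environment change is mostly pin bumps (PyTorch 1.11 to 1.12.1, numpy 1.19.2 to 1.23.1, albumentations 1.3.0, opencv 4.6) plus dropping the editable taming-transformers/CLIP installs in favour of `open_clip_torch==2.0.2`. A small, hedged check that an existing environment matches the sd2 pins shown above; the distribution names are my guesses where they differ from the import names:

```python
# Sketch: compare installed versions against the sd2 pins from the hunk above.
from importlib.metadata import version, PackageNotFoundError

expected = {
    "torch": "1.12.1",
    "torchvision": "0.13.1",
    "numpy": "1.23.1",
    "albumentations": "1.3.0",
    "open-clip-torch": "2.0.2",   # assumed distribution name for open_clip_torch
}

for dist, want in expected.items():
    try:
        have = version(dist)
    except PackageNotFoundError:
        have = "not installed"
    print(f"{dist}: pinned {want}, found {have}")
```
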
Only in sd1/ldm/data: base.py
Only in sd1/ldm/data: imagenet.py
Only in sd1/ldm/data: lsun.py
Only in sd2/ldm/data: util.py
Only in sd1/ldm: lr_scheduler.py
diff -w -x '*.gif' -x '*.png' -x '*.jpg' -x LICENSE -x '*.md' -x '*.git' -r sd1/ldm/models/autoencoder.py sd2/ldm/models/autoencoder.py
6,7d5
< from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
<
11a10
> from ldm.modules.ema import LitEma
14c13
< class VQModel(pl.LightningModule):
---
> class AutoencoderKL(pl.LightningModule):
18d16
< n_embed,
25,30c23,24
< batch_resize_range=None,
< scheduler_config=None,
< lr_g_factor=1.0,
< remap=None,
< sane_index_shape=False, # tell vector quantizer to return indices as bhw
< use_ema=False
---
> ema_decay=None,
> learn_logvar=False
33,34c27
< self.embed_dim = embed_dim
< self.n_embed = n_embed
---
> self.learn_logvar = learn_logvar
39,42c32,33
< self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
< remap=remap,
< sane_index_shape=sane_index_shape)
< self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
---
> assert ddconfig["double_z"]
> self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
43a35
> self.embed_dim = embed_dim
49,51d40
< self.batch_resize_range = batch_resize_range
< if self.batch_resize_range is not None:
< print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")
53c42
< self.use_ema = use_ema
---
> self.use_ema = ema_decay is not None
55c44,46
< self.model_ema = LitEma(self)
---
> self.ema_decay = ema_decay
> assert 0. < ema_decay < 1.
> self.model_ema = LitEma(self, decay=ema_decay)
60,61c51,61
< self.scheduler_config = scheduler_config
< self.lr_g_factor = lr_g_factor
---
>
> def init_from_ckpt(self, path, ignore_keys=list()):
> sd = torch.load(path, map_location="cpu")["state_dict"]
> keys = list(sd.keys())
> for k in keys:
> for ik in ignore_keys:
> if k.startswith(ik):
> print("Deleting key {} from state_dict.".format(k))
> del sd[k]
> self.load_state_dict(sd, strict=False)
> print(f"Restored from {path}")
78,91d77
< def init_from_ckpt(self, path, ignore_keys=list()):
< sd = torch.load(path, map_location="cpu")["state_dict"]
< keys = list(sd.keys())
< for k in keys:
< for ik in ignore_keys:
< if k.startswith(ik):
< print("Deleting key {} from state_dict.".format(k))
< del sd[k]
< missing, unexpected = self.load_state_dict(sd, strict=False)
< print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
< if len(missing) > 0:
< print(f"Missing Keys: {missing}")
< print(f"Unexpected Keys: {unexpected}")
<
98,325d83
< h = self.quant_conv(h)
< quant, emb_loss, info = self.quantize(h)
< return quant, emb_loss, info
<
< def encode_to_prequant(self, x):
< h = self.encoder(x)
< h = self.quant_conv(h)
< return h
<
< def decode(self, quant):
< quant = self.post_quant_conv(quant)
< dec = self.decoder(quant)
< return dec
<
< def decode_code(self, code_b):
< quant_b = self.quantize.embed_code(code_b)
< dec = self.decode(quant_b)
< return dec
<
< def forward(self, input, return_pred_indices=False):
< quant, diff, (_,_,ind) = self.encode(input)
< dec = self.decode(quant)
< if return_pred_indices:
< return dec, diff, ind
< return dec, diff
<
< def get_input(self, batch, k):
< x = batch[k]
< if len(x.shape) == 3:
< x = x[..., None]
< x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
< if self.batch_resize_range is not None:
< lower_size = self.batch_resize_range[0]
< upper_size = self.batch_resize_range[1]
< if self.global_step <= 4:
< # do the first few batches with max size to avoid later oom
< new_resize = upper_size
< else:
< new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
< if new_resize != x.shape[2]:
< x = F.interpolate(x, size=new_resize, mode="bicubic")
< x = x.detach()
< return x
<
< def training_step(self, batch, batch_idx, optimizer_idx):
< # https://github.com/pytorch/pytorch/issues/37142
< # try not to fool the heuristics
< x = self.get_input(batch, self.image_key)
< xrec, qloss, ind = self(x, return_pred_indices=True)
<
< if optimizer_idx == 0:
< # autoencode
< aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
< last_layer=self.get_last_layer(), split="train",
< predicted_indices=ind)
<
< self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
< return aeloss
<
< if optimizer_idx == 1:
< # discriminator
< discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
< last_layer=self.get_last_layer(), split="train")
< self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
< return discloss
<
< def validation_step(self, batch, batch_idx):
< log_dict = self._validation_step(batch, batch_idx)
< with self.ema_scope():
< log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
< return log_dict
<
< def _validation_step(self, batch, batch_idx, suffix=""):
< x = self.get_input(batch, self.image_key)
< xrec, qloss, ind = self(x, return_pred_indices=True)
< aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
< self.global_step,
< last_layer=self.get_last_layer(),
< split="val"+suffix,
< predicted_indices=ind
< )
<
< discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
< self.global_step,
< last_layer=self.get_last_layer(),
< split="val"+suffix,
< predicted_indices=ind
< )
< rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
< self.log(f"val{suffix}/rec_loss", rec_loss,
< prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
< self.log(f"val{suffix}/aeloss", aeloss,
< prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
< if version.parse(pl.__version__) >= version.parse('1.4.0'):
< del log_dict_ae[f"val{suffix}/rec_loss"]
< self.log_dict(log_dict_ae)
< self.log_dict(log_dict_disc)
< return self.log_dict
<
< def configure_optimizers(self):
< lr_d = self.learning_rate
< lr_g = self.lr_g_factor*self.learning_rate
< print("lr_d", lr_d)
< print("lr_g", lr_g)
< opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
< list(self.decoder.parameters())+
< list(self.quantize.parameters())+
< list(self.quant_conv.parameters())+
< list(self.post_quant_conv.parameters()),
< lr=lr_g, betas=(0.5, 0.9))
< opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
< lr=lr_d, betas=(0.5, 0.9))
<
< if self.scheduler_config is not None:
< scheduler = instantiate_from_config(self.scheduler_config)
<
< print("Setting up LambdaLR scheduler...")
< scheduler = [
< {
< 'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
< 'interval': 'step',
< 'frequency': 1
< },
< {
< 'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
< 'interval': 'step',
< 'frequency': 1
< },
< ]
< return [opt_ae, opt_disc], scheduler
< return [opt_ae, opt_disc], []
<
< def get_last_layer(self):
< return self.decoder.conv_out.weight
<
< def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
< log = dict()
< x = self.get_input(batch, self.image_key)
< x = x.to(self.device)
< if only_inputs:
< log["inputs"] = x
< return log
< xrec, _ = self(x)
< if x.shape[1] > 3:
< # colorize with random projection
< assert xrec.shape[1] > 3
< x = self.to_rgb(x)
< xrec = self.to_rgb(xrec)
< log["inputs"] = x
< log["reconstructions"] = xrec
< if plot_ema:
< with self.ema_scope():
< xrec_ema, _ = self(x)
< if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
< log["reconstructions_ema"] = xrec_ema
< return log
<
< def to_rgb(self, x):
< assert self.image_key == "segmentation"
< if not hasattr(self, "colorize"):
< self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
< x = F.conv2d(x, weight=self.colorize)
< x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
< return x
<
<
< class VQModelInterface(VQModel):
< def __init__(self, embed_dim, *args, **kwargs):
< super().__init__(embed_dim=embed_dim, *args, **kwargs)
< self.embed_dim = embed_dim
<
< def encode(self, x):
< h = self.encoder(x)
< h = self.quant_conv(h)
< return h
<
< def decode(self, h, force_not_quantize=False):
< # also go through quantization layer
< if not force_not_quantize:
< quant, emb_loss, info = self.quantize(h)
< else:
< quant = h
< quant = self.post_quant_conv(quant)
< dec = self.decoder(quant)
< return dec
<
<
< class AutoencoderKL(pl.LightningModule):
< def __init__(self,
< ddconfig,
< lossconfig,
< embed_dim,
< ckpt_path=None,
< ignore_keys=[],
< image_key="image",
< colorize_nlabels=None,
< monitor=None,
< ):
< super().__init__()
< self.image_key = image_key
< self.encoder = Encoder(**ddconfig)
< self.decoder = Decoder(**ddconfig)
< self.loss = instantiate_from_config(lossconfig)
< assert ddconfig["double_z"]
< self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
< self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
< self.embed_dim = embed_dim
< if colorize_nlabels is not None:
< assert type(colorize_nlabels)==int
< self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
< if monitor is not None:
< self.monitor = monitor
< if ckpt_path is not None:
< self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
<
< def init_from_ckpt(self, path, ignore_keys=list()):
< sd = torch.load(path, map_location="cpu")["state_dict"]
< keys = list(sd.keys())
< for k in keys:
< for ik in ignore_keys:
< if k.startswith(ik):
< print("Deleting key {} from state_dict.".format(k))
< del sd[k]
< self.load_state_dict(sd, strict=False)
< print(f"Restored from {path}")
<
< def encode(self, x):
< h = self.encoder(x)
372a131,136
> log_dict = self._validation_step(batch, batch_idx)
> with self.ema_scope():
> log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
> return log_dict
>
> def _validation_step(self, batch, batch_idx, postfix=""):
376c140
< last_layer=self.get_last_layer(), split="val")
---
> last_layer=self.get_last_layer(), split="val"+postfix)
379c143
< last_layer=self.get_last_layer(), split="val")
---
> last_layer=self.get_last_layer(), split="val"+postfix)
381c145
< self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
---
> self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
388,391c152,157
< opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
< list(self.decoder.parameters())+
< list(self.quant_conv.parameters())+
< list(self.post_quant_conv.parameters()),
---
> ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list(
> self.quant_conv.parameters()) + list(self.post_quant_conv.parameters())
> if self.learn_logvar:
> print(f"{self.__class__.__name__}: Learning logvar")
> ae_params_list.append(self.loss.logvar)
> opt_ae = torch.optim.Adam(ae_params_list,
401c167
< def log_images(self, batch, only_inputs=False, **kwargs):
---
> def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
413a180,188
> if log_ema or self.use_ema:
> with self.ema_scope():
> xrec_ema, posterior_ema = self(x)
> if x.shape[1] > 3:
> # colorize with random projection
> assert xrec_ema.shape[1] > 3
> xrec_ema = self.to_rgb(xrec_ema)
> log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample()))
> log["reconstructions_ema"] = xrec_ema
428c203
< self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff
---
> self.vq_interface = vq_interface
443a219
>
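
In `autoencoder.py` the VQ classes (`VQModel`, `VQModelInterface`) are removed and `AutoencoderKL` now turns EMA on implicitly by passing `ema_decay` (plus an optional learnable logvar). Below is a self-contained toy illustration of that decay semantics only; sd2 itself uses `ldm.modules.ema.LitEma`, so this is not the repo's implementation:

```python
# Toy EMA sketch: passing a decay value implies use_ema, as in the sd2 hunk above.
import torch

class ToyEma:
    def __init__(self, model: torch.nn.Module, decay: float):
        assert 0. < decay < 1.
        self.decay = decay
        # shadow copy of the parameters, updated towards the live weights
        self.shadow = {k: v.detach().clone() for k, v in model.named_parameters()}

    @torch.no_grad()
    def update(self, model: torch.nn.Module):
        for k, v in model.named_parameters():
            self.shadow[k].mul_(self.decay).add_(v.detach(), alpha=1. - self.decay)

model = torch.nn.Linear(4, 4)
ema = ToyEma(model, decay=0.9999)   # ema_decay=None would mean "no EMA" in sd2
ema.update(model)
```
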
Only in sd1/ldm/models/diffusion: classifier.py
diff -w -x '*.gif' -x '*.png' -x '*.jpg' -x LICENSE -x '*.md' -x '*.git' -r sd1/ldm/models/diffusion/ddim.py sd2/ldm/models/diffusion/ddim.py
6d5
< from functools import partial
8,9c7
< from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, \
< extract_into_tensor
---
> from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
77,78c75,77
< unconditional_conditioning=None,
< # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
---
> unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
> dynamic_threshold=None,
> ucg_schedule=None,
83c82,84
< cbs = conditioning[list(conditioning.keys())[0]].shape[0]
---
> ctmp = conditioning[list(conditioning.keys())[0]]
> while isinstance(ctmp, list): ctmp = ctmp[0]
> cbs = ctmp.shape[0]
85a87,92
>
> elif isinstance(conditioning, list):
> for ctmp in conditioning:
> if ctmp.shape[0] != batch_size:
> print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
>
109a117,118
> dynamic_threshold=dynamic_threshold,
> ucg_schedule=ucg_schedule
119c128,129
< unconditional_guidance_scale=1., unconditional_conditioning=None,):
---
> unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
> ucg_schedule=None):
148a159,162
> if ucg_schedule is not None:
> assert len(ucg_schedule) == len(time_range)
> unconditional_guidance_scale = ucg_schedule[i]
>
154c168,169
< unconditional_conditioning=unconditional_conditioning)
---
> unconditional_conditioning=unconditional_conditioning,
> dynamic_threshold=dynamic_threshold)
168c183,184
< unconditional_guidance_scale=1., unconditional_conditioning=None):
---
> unconditional_guidance_scale=1., unconditional_conditioning=None,
> dynamic_threshold=None):
172c188
< e_t = self.model.apply_model(x, t, c)
---
> model_output = self.model.apply_model(x, t, c)
175a192,209
> if isinstance(c, dict):
> assert isinstance(unconditional_conditioning, dict)
> c_in = dict()
> for k in c:
> if isinstance(c[k], list):
> c_in[k] = [torch.cat([
> unconditional_conditioning[k][i],
> c[k][i]]) for i in range(len(c[k]))]
> else:
> c_in[k] = torch.cat([
> unconditional_conditioning[k],
> c[k]])
> elif isinstance(c, list):
> c_in = list()
> assert isinstance(unconditional_conditioning, list)
> for i in range(len(c)):
> c_in.append(torch.cat([unconditional_conditioning[i], c[i]]))
> else:
177,178c211,217
< e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
< e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
---
> model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
> model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
>
> if self.model.parameterization == "v":
> e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
> else:
> e_t = model_output
181c220
< assert self.model.parameterization == "eps"
---
> assert self.model.parameterization == "eps", 'not implemented'
194a234
> if self.model.parameterization != "v":
195a236,238
> else:
> pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
>
197a241,244
>
> if dynamic_threshold is not None:
> raise NotImplementedError()
>
206a254,300
> def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
> unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None):
> num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0]
>
> assert t_enc <= num_reference_steps
> num_steps = t_enc
>
> if use_original_steps:
> alphas_next = self.alphas_cumprod[:num_steps]
> alphas = self.alphas_cumprod_prev[:num_steps]
> else:
> alphas_next = self.ddim_alphas[:num_steps]
> alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
>
> x_next = x0
> intermediates = []
> inter_steps = []
> for i in tqdm(range(num_steps), desc='Encoding Image'):
> t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long)
> if unconditional_guidance_scale == 1.:
> noise_pred = self.model.apply_model(x_next, t, c)
> else:
> assert unconditional_conditioning is not None
> e_t_uncond, noise_pred = torch.chunk(
> self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),
> torch.cat((unconditional_conditioning, c))), 2)
> noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)
>
> xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
> weighted_noise_pred = alphas_next[i].sqrt() * (
> (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
> x_next = xt_weighted + weighted_noise_pred
> if return_intermediates and i % (
> num_steps // return_intermediates) == 0 and i < num_steps - 1:
> intermediates.append(x_next)
> inter_steps.append(i)
> elif return_intermediates and i >= num_steps - 2:
> intermediates.append(x_next)
> inter_steps.append(i)
> if callback: callback(i)
>
> out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
> if return_intermediates:
> out.update({'intermediates': intermediates})
> return x_next, out
>
> @torch.no_grad()
224c318
< use_original_steps=False):
---
> use_original_steps=False, callback=None):
240a335
> if callback: callback(i)
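
The `ddim.py` changes mostly thread new options through the sampler: dict/list conditionings for classifier-free guidance, an optional per-step `ucg_schedule` for the guidance scale, v-parameterization support, a `dynamic_threshold` placeholder, and an `encode()` method for DDIM inversion. A tiny stand-alone sketch of the guidance combine with a per-step schedule, using placeholder tensors rather than real model outputs; the schedule values here are made up:

```python
# Sketch only: the CFG combine used in p_sample_ddim, with the per-step
# ucg_schedule override added in sd2 (if ucg_schedule is not None: scale = ucg_schedule[i]).
import torch

def cfg_combine(model_uncond: torch.Tensor, model_cond: torch.Tensor, scale: float) -> torch.Tensor:
    return model_uncond + scale * (model_cond - model_uncond)

ucg_schedule = [9.0] * 10 + [7.5] * 40      # hypothetical 50-step schedule
uncond = torch.zeros(1, 4, 64, 64)          # placeholder "unconditional" model output
cond = torch.ones(1, 4, 64, 64)             # placeholder "conditional" model output
for step, scale in enumerate(ucg_schedule):
    guided = cfg_combine(uncond, cond, scale)
```
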
diff -w -x '*.gif' -x '*.png' -x '*.jpg' -x LICENSE -x '*.md' -x '*.git' -r sd1/ldm/models/diffusion/ddpm.py sd2/ldm/models/diffusion/ddpm.py
15c15
< from contextlib import contextmanager
---
> from contextlib import contextmanager, nullcontext
16a17
> import itertools
19a21
> from omegaconf import ListConfig
24c26
< from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL
---
> from ldm.models.autoencoder import IdentityFirstStage, AutoencoderKL
73a76,79
> make_it_fit=False,
> ucg_training=None,
> reset_ema=False,
> reset_num_ema_updates=False,
76c82
< assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"'
---
> assert parameterization in ["eps", "x0", "v"], 'currently only supporting "eps" and "x0" and "v"'
102a109,110
> self.make_it_fit = make_it_fit
> if reset_ema: assert exists(ckpt_path)
104a113,120
> if reset_ema:
> assert self.use_ema
> print(f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.")
> self.model_ema = LitEma(self.model)
> if reset_num_ema_updates:
> print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ")
> assert self.use_ema
> self.model_ema.reset_num_updates()
115a132,134
> self.ucg_training = ucg_training or dict()
> if self.ucg_training:
> self.ucg_prng = np.random.RandomState()
163a183,185
> elif self.parameterization == "v":
> lvlb_weights = torch.ones_like(self.betas ** 2 / (
> 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)))
166d187
< # TODO how to choose this term
185a207
> @torch.no_grad()
195a218,261
> if self.make_it_fit:
> n_params = len([name for name, _ in
> itertools.chain(self.named_parameters(),
> self.named_buffers())])
> for name, param in tqdm(
> itertools.chain(self.named_parameters(),
> self.named_buffers()),
> desc="Fitting old weights to new weights",
> total=n_params
> ):
> if not name in sd:
> continue
> old_shape = sd[name].shape
> new_shape = param.shape
> assert len(old_shape) == len(new_shape)
> if len(new_shape) > 2:
> # we only modify first two axes
> assert new_shape[2:] == old_shape[2:]
> # assumes first axis corresponds to output dim
> if not new_shape == old_shape:
> new_param = param.clone()
> old_param = sd[name]
> if len(new_shape) == 1:
> for i in range(new_param.shape[0]):
> new_param[i] = old_param[i % old_shape[0]]
> elif len(new_shape) >= 2:
> for i in range(new_param.shape[0]):
> for j in range(new_param.shape[1]):
> new_param[i, j] = old_param[i % old_shape[0], j % old_shape[1]]
>
> n_used_old = torch.ones(old_shape[1])
> for j in range(new_param.shape[1]):
> n_used_old[j % old_shape[1]] += 1
> n_used_new = torch.zeros(new_shape[1])
> for j in range(new_param.shape[1]):
> n_used_new[j] = n_used_old[j % old_shape[1]]
>
> n_used_new = n_used_new[None, :]
> while len(n_used_new.shape) < len(new_shape):
> n_used_new = n_used_new.unsqueeze(-1)
> new_param /= n_used_new
>
> sd[name] = new_param
>
200c266
< print(f"Missing Keys: {missing}")
---
> print(f"Missing Keys:\n {missing}")
202c268
< print(f"Unexpected Keys: {unexpected}")
---
> print(f"\nUnexpected Keys:\n {unexpected}")
221a288,301
> def predict_start_from_z_and_v(self, x_t, t, v):
> # self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
> # self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
> return (
> extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
> extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
> )
>
> def predict_eps_from_z_and_v(self, x_t, t, v):
> return (
> extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v +
> extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * x_t
> )
>
278a359,364
> def get_v(self, x, noise, t):
> return (
> extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise -
> extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
> )
>
303a390,391
> elif self.parameterization == "v":
> target = self.get_v(x_start, noise, t)
342a431,439
> for k in self.ucg_training:
> p = self.ucg_training[k]["p"]
> val = self.ucg_training[k]["val"]
> if val is None:
> val = ""
> for i in range(len(batch[k])):
> if self.ucg_prng.choice(2, p=[1 - p, p]):
> batch[k][i] = val
>
425a523
>
436a535
> force_null_conditioning=False,
437a537
> self.force_null_conditioning = force_null_conditioning
444c544
< if cond_stage_config == '__is_unconditional__':
---
> if cond_stage_config == '__is_unconditional__' and not self.force_null_conditioning:
446a547,548
> reset_ema = kwargs.pop("reset_ema", False)
> reset_num_ema_updates = kwargs.pop("reset_num_ema_updates", False)
469a572,580
> if reset_ema:
> assert self.use_ema
> print(
> f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.")
> self.model_ema = LitEma(self.model)
> if reset_num_ema_updates:
> print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ")
> assert self.use_ema
> self.model_ema.reset_num_updates()
655c766
< cond_key=None, return_original_cond=False, bs=None):
---
> cond_key=None, return_original_cond=False, bs=None, return_x=False):
663c774
< if self.model.conditioning_key is not None:
---
> if self.model.conditioning_key is not None and not self.force_null_conditioning:
667c778
< if cond_key in ['caption', 'coordinates_bbox']:
---
> if cond_key in ['caption', 'coordinates_bbox', "txt"]:
669c780
< elif cond_key == 'class_label':
---
> elif cond_key in ['class_label', 'cls']:
677d787
< # import pudb; pudb.set_trace()
700a811,812
> if return_x:
> out.extend([x])
714,822d825
<
< if hasattr(self, "split_input_params"):
< if self.split_input_params["patch_distributed_vq"]:
< ks = self.split_input_params["ks"] # eg. (128, 128)
< stride = self.split_input_params["stride"] # eg. (64, 64)
< uf = self.split_input_params["vqf"]
< bs, nc, h, w = z.shape
< if ks[0] > h or ks[1] > w:
< ks = (min(ks[0], h), min(ks[1], w))
< print("reducing Kernel")
<
< if stride[0] > h or stride[1] > w:
< stride = (min(stride[0], h), min(stride[1], w))
< print("reducing stride")
<
< fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
<
< z = unfold(z) # (bn, nc * prod(**ks), L)
< # 1. Reshape to img shape
< z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
<
< # 2. apply model loop over last dim
< if isinstance(self.first_stage_model, VQModelInterface):
< output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
< force_not_quantize=predict_cids or force_not_quantize)
< for i in range(z.shape[-1])]
< else:
<
< output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
< for i in range(z.shape[-1])]
<
< o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
< o = o * weighting
< # Reverse 1. reshape to img shape
< o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
< # stitch crops together
< decoded = fold(o)
< decoded = decoded / normalization # norm is shape (1, 1, h, w)
< return decoded
< else:
< if isinstance(self.first_stage_model, VQModelInterface):
< return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
< else:
< return self.first_stage_model.decode(z)
<
< else:
< if isinstance(self.first_stage_model, VQModelInterface):
< return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
< else:
< return self.first_stage_model.decode(z)
<
< # same as above but without decorator
< def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
< if predict_cids:
< if z.dim() == 4:
< z = torch.argmax(z.exp(), dim=1).long()
< z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
< z = rearrange(z, 'b h w c -> b c h w').contiguous()
<
< z = 1. / self.scale_factor * z
<
< if hasattr(self, "split_input_params"):
< if self.split_input_params["patch_distributed_vq"]:
< ks = self.split_input_params["ks"] # eg. (128, 128)
< stride = self.split_input_params["stride"] # eg. (64, 64)
< uf = self.split_input_params["vqf"]
< bs, nc, h, w = z.shape
< if ks[0] > h or ks[1] > w:
< ks = (min(ks[0], h), min(ks[1], w))
< print("reducing Kernel")
<
< if stride[0] > h or stride[1] > w:
< stride = (min(stride[0], h), min(stride[1], w))
< print("reducing stride")
<
< fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
<
< z = unfold(z) # (bn, nc * prod(**ks), L)
< # 1. Reshape to img shape
< z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
<
< # 2. apply model loop over last dim
< if isinstance(self.first_stage_model, VQModelInterface):
< output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
< force_not_quantize=predict_cids or force_not_quantize)
< for i in range(z.shape[-1])]
< else:
<
< output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
< for i in range(z.shape[-1])]
<
< o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
< o = o * weighting
< # Reverse 1. reshape to img shape
< o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
< # stitch crops together
< decoded = fold(o)
< decoded = decoded / normalization # norm is shape (1, 1, h, w)
< return decoded
< else:
< if isinstance(self.first_stage_model, VQModelInterface):
< return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
< else:
< return self.first_stage_model.decode(z)
<
< else:
< if isinstance(self.first_stage_model, VQModelInterface):
< return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
< else:
827,862d829
< if hasattr(self, "split_input_params"):
< if self.split_input_params["patch_distributed_vq"]:
< ks = self.split_input_params["ks"] # eg. (128, 128)
< stride = self.split_input_params["stride"] # eg. (64, 64)
< df = self.split_input_params["vqf"]
< self.split_input_params['original_image_size'] = x.shape[-2:]
< bs, nc, h, w = x.shape
< if ks[0] > h or ks[1] > w:
< ks = (min(ks[0], h), min(ks[1], w))
< print("reducing Kernel")
<
< if stride[0] > h or stride[1] > w:
< stride = (min(stride[0], h), min(stride[1], w))
< print("reducing stride")
<
< fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df)
< z = unfold(x) # (bn, nc * prod(**ks), L)
< # Reshape to img shape
< z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
<
< output_list = [self.first_stage_model.encode(z[:, :, :, :, i])
< for i in range(z.shape[-1])]
<
< o = torch.stack(output_list, axis=-1)
< o = o * weighting
<
< # Reverse reshape to img shape
< o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
< # stitch crops together
< decoded = fold(o)
< decoded = decoded / normalization
< return decoded
<
< else:
< return self.first_stage_model.encode(x)
< else:
881,890d847
< def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset
< def rescale_bbox(bbox):
< x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
< y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
< w = min(bbox[2] / crop_coordinates[2], 1 - x0)
< h = min(bbox[3] / crop_coordinates[3], 1 - y0)
< return x0, y0, w, h
<
< return [rescale_bbox(b) for b in bboxes]
<
892d848
<
894c850
< # hybrid case, cond is exptected to be a dict
---
> # hybrid case, cond is expected to be a dict
902,986d857
< if hasattr(self, "split_input_params"):
< assert len(cond) == 1 # todo can only deal with one conditioning atm
< assert not return_ids
< ks = self.split_input_params["ks"] # eg. (128, 128)
< stride = self.split_input_params["stride"] # eg. (64, 64)
<
< h, w = x_noisy.shape[-2:]
<
< fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride)
<
< z = unfold(x_noisy) # (bn, nc * prod(**ks), L)
< # Reshape to img shape
< z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
< z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])]
<
< if self.cond_stage_key in ["image", "LR_image", "segmentation",
< 'bbox_img'] and self.model.conditioning_key: # todo check for completeness
< c_key = next(iter(cond.keys())) # get key
< c = next(iter(cond.values())) # get value
< assert (len(c) == 1) # todo extend to list with more than one elem
< c = c[0] # get element
<
< c = unfold(c)
< c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1])) # (bn, nc, ks[0], ks[1], L )
<
< cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])]
<
< elif self.cond_stage_key == 'coordinates_bbox':
< assert 'original_image_size' in self.split_input_params, 'BoudingBoxRescaling is missing original_image_size'
<
< # assuming padding of unfold is always 0 and its dilation is always 1
< n_patches_per_row = int((w - ks[0]) / stride[0] + 1)
< full_img_h, full_img_w = self.split_input_params['original_image_size']
< # as we are operating on latents, we need the factor from the original image size to the
< # spatial latent size to properly rescale the crops for regenerating the bbox annotations
< num_downs = self.first_stage_model.encoder.num_resolutions - 1
< rescale_latent = 2 ** (num_downs)
<
< # get top left postions of patches as conforming for the bbbox tokenizer, therefore we
< # need to rescale the tl patch coordinates to be in between (0,1)
< tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w,
< rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h)
< for patch_nr in range(z.shape[-1])]
<
< # patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w)
< patch_limits = [(x_tl, y_tl,
< rescale_latent * ks[0] / full_img_w,
< rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates]
< # patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates]
<
< # tokenize crop coordinates for the bounding boxes of the respective patches
< patch_limits_tknzd = [torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device)
< for bbox in patch_limits] # list of length l with tensors of shape (1, 2)
< print(patch_limits_tknzd[0].shape)
< # cut tknzd crop position from conditioning
< assert isinstance(cond, dict), 'cond must be dict to be fed into model'
< cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device)
< print(cut_cond.shape)
<
< adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd])
< adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n')
< print(adapted_cond.shape)
< adapted_cond = self.get_learned_conditioning(adapted_cond)
< print(adapted_cond.shape)
< adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1])
< print(adapted_cond.shape)
<
< cond_list = [{'c_crossattn': [e]} for e in adapted_cond]
<
< else:
< cond_list = [cond for i in range(z.shape[-1])] # Todo make this more efficient
<
< # apply model by loop over crops
< output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])]
< assert not isinstance(output_list[0],
< tuple) # todo cant deal with multiple model outputs check this never happens
<
< o = torch.stack(output_list, axis=-1)
< o = o * weighting
< # Reverse reshape to img shape
< o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
< # stitch crops together
< x_recon = fold(o) / normalization
<
< else:
1023a895,896
> elif self.parameterization == "v":
> target = self.get_v(x_start, noise, t)
1236d1108
<
1248a1121,1144
> @torch.no_grad()
> def get_unconditional_conditioning(self, batch_size, null_label=None):
> if null_label is not None:
> xc = null_label
> if isinstance(xc, ListConfig):
> xc = list(xc)
> if isinstance(xc, dict) or isinstance(xc, list):
> c = self.get_learned_conditioning(xc)
> else:
> if hasattr(xc, "to"):
> xc = xc.to(self.device)
> c = self.get_learned_conditioning(xc)
> else:
> if self.cond_stage_key in ["class_label", "cls"]:
> xc = self.cond_stage_model.get_unconditional_conditioning(batch_size, device=self.device)
> return self.get_learned_conditioning(xc)
> else:
> raise NotImplementedError("todo")
> if isinstance(c, list): # in case the encoder gives us a list
> for i in range(len(c)):
> c[i] = repeat(c[i], '1 ... -> b ...', b=batch_size).to(self.device)
> else:
> c = repeat(c, '1 ... -> b ...', b=batch_size).to(self.device)
> return c
1251c1147
< def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
---
> def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=50, ddim_eta=0., return_keys=None,
1253,1254c1149,1152
< plot_diffusion_rows=True, **kwargs):
<
---
> plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None,
> use_ema_scope=True,
> **kwargs):
> ema_scope = self.ema_scope if use_ema_scope else nullcontext
1271,1272c1169,1170
< elif self.cond_stage_key in ["caption"]:
< xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["caption"])
---
> elif self.cond_stage_key in ["caption", "txt"]:
> xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
1274,1275c1172,1174
< elif self.cond_stage_key == 'class_label':
< xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
---
> elif self.cond_stage_key in ['class_label', "cls"]:
> try:
> xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
1276a1176,1178
> except KeyError:
> # probably no "human_label" in batch
> pass
1302c1204
< with self.ema_scope("Plotting"):
---
> with ema_scope("Sampling"):
1315c1217
< with self.ema_scope("Plotting Quantized Denoised"):
---
> with ema_scope("Plotting Quantized Denoised"):
1323a1226,1238
> if unconditional_guidance_scale > 1.0:
> uc = self.get_unconditional_conditioning(N, unconditional_guidance_label)
> if self.model.conditioning_key == "crossattn-adm":
> uc = {"c_crossattn": [uc], "c_adm": c["c_adm"]}
> with ema_scope("Sampling with classifier-free guidance"):
> samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
> ddim_steps=ddim_steps, eta=ddim_eta,
> unconditional_guidance_scale=unconditional_guidance_scale,
> unconditional_conditioning=uc,
> )
> x_samples_cfg = self.decode_first_stage(samples_cfg)
> log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
>
1331,1332c1246
< with self.ema_scope("Plotting Inpaint"):
<
---
> with ema_scope("Plotting Inpaint"):
1340c1254,1255
< with self.ema_scope("Plotting Outpaint"):
---
> mask = 1. - mask
> with ema_scope("Plotting Outpaint"):
1347c1262
< with self.ema_scope("Plotting Progressives"):
---
> with ema_scope("Plotting Progressives"):
1397a1313
> self.sequential_cross_attn = diff_model_config.pop("sequential_crossattn", False)
1400c1316
< assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm']
---
> assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm']
1402c1318
< def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
---
> def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None):
1408a1325
> if not self.sequential_cross_attn:
1409a1327,1328
> else:
> cc = c_crossattn
1414a1334,1342
> elif self.conditioning_key == 'hybrid-adm':
> assert c_adm is not None
> xc = torch.cat([x] + c_concat, dim=1)
> cc = torch.cat(c_crossattn, 1)
> out = self.diffusion_model(xc, t, context=cc, y=c_adm)
> elif self.conditioning_key == 'crossattn-adm':
> assert c_adm is not None
> cc = torch.cat(c_crossattn, 1)
> out = self.diffusion_model(x, t, context=cc, y=c_adm)
1424,1445c1352,1795
< class Layout2ImgDiffusion(LatentDiffusion):
< # TODO: move all layout-specific hacks to this class
< def __init__(self, cond_stage_key, *args, **kwargs):
< assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
< super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs)
<
< def log_images(self, batch, N=8, *args, **kwargs):
< logs = super().log_images(batch=batch, N=N, *args, **kwargs)
<
< key = 'train' if self.training else 'validation'
< dset = self.trainer.datamodule.datasets[key]
< mapper = dset.conditional_builders[self.cond_stage_key]
<
< bbox_imgs = []
< map_fn = lambda catno: dset.get_textual_label(dset.get_category_id(catno))
< for tknzd_bbox in batch[self.cond_stage_key][:N]:
< bboximg = mapper.plot(tknzd_bbox.detach().cpu(), map_fn, (256, 256))
< bbox_imgs.append(bboximg)
<
< cond_img = torch.stack(bbox_imgs, dim=0)
< logs['bbox_image'] = cond_img
< return logs
---
> class LatentUpscaleDiffusion(LatentDiffusion):
> def __init__(self, *args, low_scale_config, low_scale_key="LR", noise_level_key=None, **kwargs):
> super().__init__(*args, **kwargs)
> # assumes that neither the cond_stage nor the low_scale_model contain trainable params
> assert not self.cond_stage_trainable
> self.instantiate_low_stage(low_scale_config)
> self.low_scale_key = low_scale_key
> self.noise_level_key = noise_level_key
>
> def instantiate_low_stage(self, config):
> model = instantiate_from_config(config)
> self.low_scale_model = model.eval()
> self.low_scale_model.train = disabled_train
> for param in self.low_scale_model.parameters():
> param.requires_grad = False
>
> @torch.no_grad()
> def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False):
> if not log_mode:
> z, c = super().get_input(batch, k, force_c_encode=True, bs=bs)
> else:
> z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
> force_c_encode=True, return_original_cond=True, bs=bs)
> x_low = batch[self.low_scale_key][:bs]
> x_low = rearrange(x_low, 'b h w c -> b c h w')
> x_low = x_low.to(memory_format=torch.contiguous_format).float()
> zx, noise_level = self.low_scale_model(x_low)
> if self.noise_level_key is not None:
> # get noise level from batch instead, e.g. when extracting a custom noise level for bsr
> raise NotImplementedError('TODO')
>
> all_conds = {"c_concat": [zx], "c_crossattn": [c], "c_adm": noise_level}
> if log_mode:
> # TODO: maybe disable if too expensive
> x_low_rec = self.low_scale_model.decode(zx)
> return z, all_conds, x, xrec, xc, x_low, x_low_rec, noise_level
> return z, all_conds
>
> @torch.no_grad()
> def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
> plot_denoise_rows=False, plot_progressive_rows=True, plot_diffusion_rows=True,
> unconditional_guidance_scale=1., unconditional_guidance_label=None, use_ema_scope=True,
> **kwargs):
> ema_scope = self.ema_scope if use_ema_scope else nullcontext
> use_ddim = ddim_steps is not None
>
> log = dict()
> z, c, x, xrec, xc, x_low, x_low_rec, noise_level = self.get_input(batch, self.first_stage_key, bs=N,
> log_mode=True)
> N = min(x.shape[0], N)
> n_row = min(x.shape[0], n_row)
> log["inputs"] = x
> log["reconstruction"] = xrec
> log["x_lr"] = x_low
> log[f"x_lr_rec_@noise_levels{'-'.join(map(lambda x: str(x), list(noise_level.cpu().numpy())))}"] = x_low_rec
> if self.model.conditioning_key is not None:
> if hasattr(self.cond_stage_model, "decode"):
> xc = self.cond_stage_model.decode(c)
> log["conditioning"] = xc
> elif self.cond_stage_key in ["caption", "txt"]:
> xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
> log["conditioning"] = xc
> elif self.cond_stage_key in ['class_label', 'cls']:
> xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
> log['conditioning'] = xc
> elif isimage(xc):
> log["conditioning"] = xc
> if ismap(xc):
> log["original_conditioning"] = self.to_rgb(xc)
>
> if plot_diffusion_rows:
> # get diffusion row
> diffusion_row = list()
> z_start = z[:n_row]
> for t in range(self.num_timesteps):
> if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
> t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
> t = t.to(self.device).long()
> noise = torch.randn_like(z_start)
> z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
> diffusion_row.append(self.decode_first_stage(z_noisy))
>
> diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
> diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
> diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
> diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
> log["diffusion_row"] = diffusion_grid
>
> if sample:
> # get denoise row
> with ema_scope("Sampling"):
> samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
> ddim_steps=ddim_steps, eta=ddim_eta)
> # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
> x_samples = self.decode_first_stage(samples)
> log["samples"] = x_samples
> if plot_denoise_rows:
> denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
> log["denoise_row"] = denoise_grid
>
> if unconditional_guidance_scale > 1.0:
> uc_tmp = self.get_unconditional_conditioning(N, unconditional_guidance_label)
> # TODO explore better "unconditional" choices for the other keys
> # maybe guide away from empty text label and highest noise level and maximally degraded zx?
> uc = dict()
> for k in c:
> if k == "c_crossattn":
> assert isinstance(c[k], list) and len(c[k]) == 1
> uc[k] = [uc_tmp]
> elif k == "c_adm": # todo: only run with text-based guidance?
> assert isinstance(c[k], torch.Tensor)
> #uc[k] = torch.ones_like(c[k]) * self.low_scale_model.max_noise_level
> uc[k] = c[k]
> elif isinstance(c[k], list):
> uc[k] = [c[k][i] for i in range(len(c[k]))]
> else:
> uc[k] = c[k]
>
> with ema_scope("Sampling with classifier-free guidance"):
> samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
> ddim_steps=ddim_steps, eta=ddim_eta,
> unconditional_guidance_scale=unconditional_guidance_scale,
> unconditional_conditioning=uc,
> )
> x_samples_cfg = self.decode_first_stage(samples_cfg)
> log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
>
> if plot_progressive_rows:
> with ema_scope("Plotting Progressives"):
> img, progressives = self.progressive_denoising(c,
> shape=(self.channels, self.image_size, self.image_size),
> batch_size=N)
> prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
> log["progressive_row"] = prog_row
>
> return log
>
>
> class LatentFinetuneDiffusion(LatentDiffusion):
> """
> Basis for different finetunas, such as inpainting or depth2image
> To disable finetuning mode, set finetune_keys to None
> """
>
> def __init__(self,
> concat_keys: tuple,
> finetune_keys=("model.diffusion_model.input_blocks.0.0.weight",
> "model_ema.diffusion_modelinput_blocks00weight"
> ),
> keep_finetune_dims=4,
> # if model was trained without concat mode before and we would like to keep these channels
> c_concat_log_start=None, # to log reconstruction of c_concat codes
> c_concat_log_end=None,
> *args, **kwargs
> ):
> ckpt_path = kwargs.pop("ckpt_path", None)
> ignore_keys = kwargs.pop("ignore_keys", list())
> super().__init__(*args, **kwargs)
> self.finetune_keys = finetune_keys
> self.concat_keys = concat_keys
> self.keep_dims = keep_finetune_dims
> self.c_concat_log_start = c_concat_log_start
> self.c_concat_log_end = c_concat_log_end
> if exists(self.finetune_keys): assert exists(ckpt_path), 'can only finetune from a given checkpoint'
> if exists(ckpt_path):
> self.init_from_ckpt(ckpt_path, ignore_keys)
>
> def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
> sd = torch.load(path, map_location="cpu")
> if "state_dict" in list(sd.keys()):
> sd = sd["state_dict"]
> keys = list(sd.keys())
> for k in keys:
> for ik in ignore_keys:
> if k.startswith(ik):
> print("Deleting key {} from state_dict.".format(k))
> del sd[k]
>
> # make it explicit, finetune by including extra input channels
> if exists(self.finetune_keys) and k in self.finetune_keys:
> new_entry = None
> for name, param in self.named_parameters():
> if name in self.finetune_keys:
> print(
> f"modifying key '{name}' and keeping its original {self.keep_dims} (channels) dimensions only")
> new_entry = torch.zeros_like(param) # zero init
> assert exists(new_entry), 'did not find matching parameter to modify'
> new_entry[:, :self.keep_dims, ...] = sd[k]
> sd[k] = new_entry
>
> missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
> sd, strict=False)
> print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
> if len(missing) > 0:
> print(f"Missing Keys: {missing}")
> if len(unexpected) > 0:
> print(f"Unexpected Keys: {unexpected}")
>
> @torch.no_grad()
> def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
> quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
> plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None,
> use_ema_scope=True,
> **kwargs):
> ema_scope = self.ema_scope if use_ema_scope else nullcontext
> use_ddim = ddim_steps is not None
>
> log = dict()
> z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, bs=N, return_first_stage_outputs=True)
> c_cat, c = c["c_concat"][0], c["c_crossattn"][0]
> N = min(x.shape[0], N)
> n_row = min(x.shape[0], n_row)
> log["inputs"] = x
> log["reconstruction"] = xrec
> if self.model.conditioning_key is not None:
> if hasattr(self.cond_stage_model, "decode"):
> xc = self.cond_stage_model.decode(c)
> log["conditioning"] = xc
> elif self.cond_stage_key in ["caption", "txt"]:
> xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
> log["conditioning"] = xc
> elif self.cond_stage_key in ['class_label', 'cls']:
> xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
> log['conditioning'] = xc
> elif isimage(xc):
> log["conditioning"] = xc
> if ismap(xc):
> log["original_conditioning"] = self.to_rgb(xc)
>
> if not (self.c_concat_log_start is None and self.c_concat_log_end is None):
> log["c_concat_decoded"] = self.decode_first_stage(c_cat[:, self.c_concat_log_start:self.c_concat_log_end])
>
> if plot_diffusion_rows:
> # get diffusion row
> diffusion_row = list()
> z_start = z[:n_row]
> for t in range(self.num_timesteps):
> if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
> t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
> t = t.to(self.device).long()
> noise = torch.randn_like(z_start)
> z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
> diffusion_row.append(self.decode_first_stage(z_noisy))
>
> diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
> diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
> diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
> diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
> log["diffusion_row"] = diffusion_grid
>
> if sample:
> # get denoise row
> with ema_scope("Sampling"):
> samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
> batch_size=N, ddim=use_ddim,
> ddim_steps=ddim_steps, eta=ddim_eta)
> # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
> x_samples = self.decode_first_stage(samples)
> log["samples"] = x_samples
> if plot_denoise_rows:
> denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
> log["denoise_row"] = denoise_grid
>
> if unconditional_guidance_scale > 1.0:
> uc_cross = self.get_unconditional_conditioning(N, unconditional_guidance_label)
> uc_cat = c_cat
> uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
> with ema_scope("Sampling with classifier-free guidance"):
> samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
> batch_size=N, ddim=use_ddim,
> ddim_steps=ddim_steps, eta=ddim_eta,
> unconditional_guidance_scale=unconditional_guidance_scale,
> unconditional_conditioning=uc_full,
> )
> x_samples_cfg = self.decode_first_stage(samples_cfg)
> log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
>
> return log
>
>
> class LatentInpaintDiffusion(LatentFinetuneDiffusion):
> """
> can either run as pure inpainting model (only concat mode) or with mixed conditionings,
> e.g. mask as concat and text via cross-attn.
> To disable finetuning mode, set finetune_keys to None
> """
>
> def __init__(self,
> concat_keys=("mask", "masked_image"),
> masked_image_key="masked_image",
> *args, **kwargs
> ):
> super().__init__(concat_keys, *args, **kwargs)
> self.masked_image_key = masked_image_key
> assert self.masked_image_key in concat_keys
>
> @torch.no_grad()
> def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False):
> # note: restricted to non-trainable encoders currently
> assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for inpainting'
> z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
> force_c_encode=True, return_original_cond=True, bs=bs)
>
> assert exists(self.concat_keys)
> c_cat = list()
> for ck in self.concat_keys:
> cc = rearrange(batch[ck], 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float()
> if bs is not None:
> cc = cc[:bs]
> cc = cc.to(self.device)
> bchw = z.shape
> if ck != self.masked_image_key:
> cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
> else:
> cc = self.get_first_stage_encoding(self.encode_first_stage(cc))
> c_cat.append(cc)
> c_cat = torch.cat(c_cat, dim=1)
> all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
> if return_first_stage_outputs:
> return z, all_conds, x, xrec, xc
> return z, all_conds
>
> @torch.no_grad()
> def log_images(self, *args, **kwargs):
> log = super(LatentInpaintDiffusion, self).log_images(*args, **kwargs)
> log["masked_image"] = rearrange(args[0]["masked_image"],
> 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float()
> return log
>
>
> class LatentDepth2ImageDiffusion(LatentFinetuneDiffusion):
> """
> condition on monocular depth estimation
> """
>
> def __init__(self, depth_stage_config, concat_keys=("midas_in",), *args, **kwargs):
> super().__init__(concat_keys=concat_keys, *args, **kwargs)
> self.depth_model = instantiate_from_config(depth_stage_config)
> self.depth_stage_key = concat_keys[0]
>
> @torch.no_grad()
> def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False):
> # note: restricted to non-trainable encoders currently
> assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for depth2img'
> z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
> force_c_encode=True, return_original_cond=True, bs=bs)
>
> assert exists(self.concat_keys)
> assert len(self.concat_keys) == 1
> c_cat = list()
> for ck in self.concat_keys:
> cc = batch[ck]
> if bs is not None:
> cc = cc[:bs]
> cc = cc.to(self.device)
> cc = self.depth_model(cc)
> cc = torch.nn.functional.interpolate(
> cc,
> size=z.shape[2:],
> mode="bicubic",
> align_corners=False,
> )
>
> depth_min, depth_max = torch.amin(cc, dim=[1, 2, 3], keepdim=True), torch.amax(cc, dim=[1, 2, 3],
> keepdim=True)
> cc = 2. * (cc - depth_min) / (depth_max - depth_min + 0.001) - 1.
> c_cat.append(cc)
> c_cat = torch.cat(c_cat, dim=1)
> all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
> if return_first_stage_outputs:
> return z, all_conds, x, xrec, xc
> return z, all_conds
>
> @torch.no_grad()
> def log_images(self, *args, **kwargs):
> log = super().log_images(*args, **kwargs)
> depth = self.depth_model(args[0][self.depth_stage_key])
> depth_min, depth_max = torch.amin(depth, dim=[1, 2, 3], keepdim=True), \
> torch.amax(depth, dim=[1, 2, 3], keepdim=True)
> log["depth"] = 2. * (depth - depth_min) / (depth_max - depth_min) - 1.
> return log
>
>
> class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion):
> """
> condition on low-res image (and optionally on some spatial noise augmentation)
> """
> def __init__(self, concat_keys=("lr",), reshuffle_patch_size=None,
> low_scale_config=None, low_scale_key=None, *args, **kwargs):
> super().__init__(concat_keys=concat_keys, *args, **kwargs)
> self.reshuffle_patch_size = reshuffle_patch_size
> self.low_scale_model = None
> if low_scale_config is not None:
> print("Initializing a low-scale model")
> assert exists(low_scale_key)
> self.instantiate_low_stage(low_scale_config)
> self.low_scale_key = low_scale_key
>
> def instantiate_low_stage(self, config):
> model = instantiate_from_config(config)
> self.low_scale_model = model.eval()
> self.low_scale_model.train = disabled_train
> for param in self.low_scale_model.parameters():
> param.requires_grad = False
>
> @torch.no_grad()
> def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False):
> # note: restricted to non-trainable encoders currently
> assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for upscaling-ft'
> z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
> force_c_encode=True, return_original_cond=True, bs=bs)
>
> assert exists(self.concat_keys)
> assert len(self.concat_keys) == 1
> # optionally make spatial noise_level here
> c_cat = list()
> noise_level = None
> for ck in self.concat_keys:
> cc = batch[ck]
> cc = rearrange(cc, 'b h w c -> b c h w')
> if exists(self.reshuffle_patch_size):
> assert isinstance(self.reshuffle_patch_size, int)
> cc = rearrange(cc, 'b c (p1 h) (p2 w) -> b (p1 p2 c) h w',
> p1=self.reshuffle_patch_size, p2=self.reshuffle_patch_size)
> if bs is not None:
> cc = cc[:bs]
> cc = cc.to(self.device)
> if exists(self.low_scale_model) and ck == self.low_scale_key:
> cc, noise_level = self.low_scale_model(cc)
> c_cat.append(cc)
> c_cat = torch.cat(c_cat, dim=1)
> if exists(noise_level):
> all_conds = {"c_concat": [c_cat], "c_crossattn": [c], "c_adm": noise_level}
> else:
> all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
> if return_first_stage_outputs:
> return z, all_conds, x, xrec, xc
> return z, all_conds
>
> @torch.no_grad()
> def log_images(self, *args, **kwargs):
> log = super().log_images(*args, **kwargs)
> log["lr"] = rearrange(args[0]["lr"], 'b h w c -> b c h w')
> return log
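
Note on the ddpm.py hunk above: LatentUpscaleFinetuneDiffusion concatenates the low-res image to the latent along the channel axis and, when a noise-augmenting low-scale model is configured, also passes its noise level as "c_adm". A minimal standalone sketch of that conditioning dict, with dummy tensors and illustrative shapes (not code from either repo):

import torch

z = torch.randn(2, 4, 64, 64)               # latent
c = torch.randn(2, 77, 1024)                # cross-attention (text) conditioning, shape illustrative
c_cat = torch.randn(2, 3, 64, 64)           # low-res image already rearranged to b c h w
noise_level = torch.randint(0, 1000, (2,))  # from the low-scale noise-augmentation model

all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
if noise_level is not None:
    all_conds["c_adm"] = noise_level
print(sorted(all_conds.keys()))  # ['c_adm', 'c_concat', 'c_crossattn']
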
diff -w -x '*.gif' -x '*.png' -x '*.jpg' -x LICENSE -x '*.md' -x '*.git' -r sd1/ldm/models/diffusion/dpm_solver/dpm_solver.py sd2/ldm/models/diffusion/dpm_solver/dpm_solver.py
3a4
> from tqdm import tqdm
16d16
<
21d20
<
25d23
<
29d26
<
31d27
<
33d28
<
35d29
<
37d30
<
39d31
<
44d35
<
48d38
<
50d39
<
58,59d46
<
<
61d47
<
64d49
<
71d55
<
73d56
<
81d63
<
83d64
<
86d66
<
89d68
<
92d70
<
96c74,76
< raise ValueError("Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(schedule))
---
> raise ValueError(
> "Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(
> schedule))
115c95,96
< self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
---
> self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (
> 1. + self.cosine_s) / math.pi - self.cosine_s
130c111,112
< return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device)).reshape((-1))
---
> return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device),
> self.log_alpha_array.to(t.device)).reshape((-1))
168c150,151
< t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), torch.flip(self.t_array.to(lamb.device), [1]))
---
> t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]),
> torch.flip(self.t_array.to(lamb.device), [1]))
172c155,156
< t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
---
> t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (
> 1. + self.cosine_s) / math.pi - self.cosine_s
190d173
<
193d175
<
195d176
<
197d177
<
199d178
<
202d180
<
213d190
<
220d196
<
226d201
<
231d205
<
234d207
<
241d213
<
245d216
<
248d218
<
256d225
<
258d226
<
354d321
<
360d326
<
412d377
<
437c402,403
< raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))
---
> raise ValueError(
> "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))
442d407
<
456d420
<
495c459,460
< timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[torch.cumsum(torch.tensor([0,] + orders)).to(device)]
---
> timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[
> torch.cumsum(torch.tensor([0, ] + orders)).to(device)]
507d471
<
551c515,516
< def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False, solver_type='dpm_solver'):
---
> def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False,
> solver_type='dpm_solver'):
554d518
<
578c542,543
< log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(t)
---
> log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(
> s1), ns.marginal_log_mean_coeff(t)
603c568,569
< + (1. / r1) * expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * (model_s1 - model_s)
---
> + (1. / r1) * expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * (
> model_s1 - model_s)
633c599,600
< def singlestep_dpm_solver_third_update(self, x, s, t, r1=1./3., r2=2./3., model_s=None, model_s1=None, return_intermediate=False, solver_type='dpm_solver'):
---
> def singlestep_dpm_solver_third_update(self, x, s, t, r1=1. / 3., r2=2. / 3., model_s=None, model_s1=None,
> return_intermediate=False, solver_type='dpm_solver'):
636d602
<
667,668c633,636
< log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t)
< sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(s2), ns.marginal_std(t)
---
> log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff(
> s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t)
> sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(
> s2), ns.marginal_std(t)
758d725
<
775c742,743
< lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t)
---
> lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda(
> t_prev_0), ns.marginal_lambda(t)
815d782
<
830c797,798
< lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda(t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t)
---
> lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda(
> t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t)
859c827,828
< def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpm_solver', r1=None, r2=None):
---
> def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpm_solver', r1=None,
> r2=None):
862d830
<
879c847,848
< return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1)
---
> return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate,
> solver_type=solver_type, r1=r1)
881c850,851
< return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1, r2=r2)
---
> return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate,
> solver_type=solver_type, r1=r1, r2=r2)
888d857
<
909c878,879
< def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5, solver_type='dpm_solver'):
---
> def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5,
> solver_type='dpm_solver'):
912d881
<
928d896
<
941c909,911
< higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, solver_type=solver_type, **kwargs)
---
> higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1,
> solver_type=solver_type,
> **kwargs)
944,945c914,919
< lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, return_intermediate=True, solver_type=solver_type)
< higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2, solver_type=solver_type, **kwargs)
---
> lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1,
> return_intermediate=True,
> solver_type=solver_type)
> higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2,
> solver_type=solver_type,
> **kwargs)
971d944
<
973d945
<
1012d983
<
1014d984
<
1028d997
<
1033d1001
<
1053d1020
<
1069d1035
<
1076c1042,1043
< x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol, solver_type=solver_type)
---
> x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol,
> solver_type=solver_type)
1086c1053
< for init_order in range(1, order):
---
> for init_order in tqdm(range(1, order), desc="DPM init order"):
1088c1055,1056
< x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, init_order, solver_type=solver_type)
---
> x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, init_order,
> solver_type=solver_type)
1092c1060
< for step in range(order, steps + 1):
---
> for step in tqdm(range(order, steps + 1), desc="DPM multistep"):
1098c1066,1067
< x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, step_order, solver_type=solver_type)
---
> x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, step_order,
> solver_type=solver_type)
1108c1077,1080
< timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, order=order, skip_type=skip_type, t_T=t_T, t_0=t_0, device=device)
---
> timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, order=order,
> skip_type=skip_type,
> t_T=t_T, t_0=t_0,
> device=device)
1115c1087,1088
< timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=t_T_inner.item(), t_0=t_0_inner.item(), N=order, device=device)
---
> timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=t_T_inner.item(), t_0=t_0_inner.item(),
> N=order, device=device)
1127d1099
<
1137d1108
<
1177d1147
<
diff -w -x '*.gif' -x '*.png' -x '*.jpg' -x LICENSE -x '*.md' -x '*.git' -r sd1/ldm/models/diffusion/dpm_solver/sampler.py sd2/ldm/models/diffusion/dpm_solver/sampler.py
2d1
<
7a7,12
> MODEL_TYPES = {
> "eps": "noise",
> "v": "v"
> }
>
>
59c64
< # print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}')
---
> print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}')
72c77
< model_type="noise",
---
> model_type=MODEL_TYPES[self.model.parameterization],
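
Note on the sampler.py hunk above: sd2 checkpoints can use either eps- or v-parameterization, and the new MODEL_TYPES table maps that onto the model type string DPM-Solver's wrapper expects. A tiny standalone illustration (DummyModel is made up for the example):

MODEL_TYPES = {
    "eps": "noise",  # model predicts the added noise
    "v": "v",        # model predicts the v-parameterization target (e.g. the 768-v checkpoint)
}

class DummyModel:
    parameterization = "v"

print(MODEL_TYPES[DummyModel.parameterization])  # -> "v"
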
diff -w -x '*.gif' -x '*.png' -x '*.jpg' -x LICENSE -x '*.md' -x '*.git' -r sd1/ldm/models/diffusion/plms.py sd2/ldm/models/diffusion/plms.py
8a9
> from ldm.models.diffusion.sampling_util import norm_thresholding
79a81
> dynamic_threshold=None,
110a113
> dynamic_threshold=dynamic_threshold,
120c123,124
< unconditional_guidance_scale=1., unconditional_conditioning=None,):
---
> unconditional_guidance_scale=1., unconditional_conditioning=None,
> dynamic_threshold=None):
158c162,163
< old_eps=old_eps, t_next=ts_next)
---
> old_eps=old_eps, t_next=ts_next,
> dynamic_threshold=dynamic_threshold)
175c180,181
< unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None):
---
> unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None,
> dynamic_threshold=None):
209a216,217
> if dynamic_threshold is not None:
> pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
Only in sd2/ldm/models/diffusion: sampling_util.py
diff -w -x '*.gif' -x '*.png' -x '*.jpg' -x LICENSE -x '*.md' -x '*.git' -r sd1/ldm/modules/attention.py sd2/ldm/modules/attention.py
6a7
> from typing import Optional, Any
10a12,19
> try:
> import xformers
> import xformers.ops
> XFORMERS_IS_AVAILBLE = True
> except:
> XFORMERS_IS_AVAILBLE = False
>
>
80,98d88
< class LinearAttention(nn.Module):
< def __init__(self, dim, heads=4, dim_head=32):
< super().__init__()
< self.heads = heads
< hidden_dim = dim_head * heads
< self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
< self.to_out = nn.Conv2d(hidden_dim, dim, 1)
<
< def forward(self, x):
< b, c, h, w = x.shape
< qkv = self.to_qkv(x)
< q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
< k = k.softmax(dim=-1)
< context = torch.einsum('bhdn,bhen->bhde', k, v)
< out = torch.einsum('bhde,bhdn->bhen', context, q)
< out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
< return self.to_out(out)
<
<
180a171
> del q, k
189c180
< attn = sim.softmax(dim=-1)
---
> sim = sim.softmax(dim=-1)
191c182
< out = einsum('b i j, b j d -> b i d', attn, v)
---
> out = einsum('b i j, b j d -> b i d', sim, v)
195a187,235
> class MemoryEfficientCrossAttention(nn.Module):
> # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
> def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
> super().__init__()
> print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
> f"{heads} heads.")
> inner_dim = dim_head * heads
> context_dim = default(context_dim, query_dim)
>
> self.heads = heads
> self.dim_head = dim_head
>
> self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
> self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
> self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
>
> self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
> self.attention_op: Optional[Any] = None
>
> def forward(self, x, context=None, mask=None):
> q = self.to_q(x)
> context = default(context, x)
> k = self.to_k(context)
> v = self.to_v(context)
>
> b, _, _ = q.shape
> q, k, v = map(
> lambda t: t.unsqueeze(3)
> .reshape(b, t.shape[1], self.heads, self.dim_head)
> .permute(0, 2, 1, 3)
> .reshape(b * self.heads, t.shape[1], self.dim_head)
> .contiguous(),
> (q, k, v),
> )
>
> # actually compute the attention, what we cannot get enough of
> out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
>
> if exists(mask):
> raise NotImplementedError
> out = (
> out.unsqueeze(0)
> .reshape(b, self.heads, out.shape[1], self.dim_head)
> .permute(0, 2, 1, 3)
> .reshape(b, out.shape[1], self.heads * self.dim_head)
> )
> return self.to_out(out)
>
>
197c237,242
< def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
---
> ATTENTION_MODES = {
> "softmax": CrossAttention, # vanilla attention
> "softmax-xformers": MemoryEfficientCrossAttention
> }
> def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
> disable_self_attn=False):
199c244,249
< self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention
---
> attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
> assert attn_mode in self.ATTENTION_MODES
> attn_cls = self.ATTENTION_MODES[attn_mode]
> self.disable_self_attn = disable_self_attn
> self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
> context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn
201c251
< self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
---
> self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim,
212c262
< x = self.attn1(self.norm1(x)) + x
---
> x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
224a275
> NEW: use_linear for more efficiency instead of the 1x1 convs
227c278,280
< depth=1, dropout=0., context_dim=None):
---
> depth=1, dropout=0., context_dim=None,
> disable_self_attn=False, use_linear=False,
> use_checkpoint=True):
228a282,283
> if exists(context_dim) and not isinstance(context_dim, list):
> context_dim = [context_dim]
232c287
<
---
> if not use_linear:
237a293,294
> else:
> self.proj_in = nn.Linear(in_channels, inner_dim)
240c297,298
< [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
---
> [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
> disable_self_attn=disable_self_attn, checkpoint=use_checkpoint)
243c301
<
---
> if not use_linear:
248a307,309
> else:
> self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
> self.use_linear = use_linear
251a313,314
> if not isinstance(context, list):
> context = [context]
254a318,321
> if not self.use_linear:
> x = self.proj_in(x)
> x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
> if self.use_linear:
256,259c323,328
< x = rearrange(x, 'b c h w -> b (h w) c')
< for block in self.transformer_blocks:
< x = block(x, context=context)
< x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
---
> for i, block in enumerate(self.transformer_blocks):
> x = block(x, context=context[i])
> if self.use_linear:
> x = self.proj_out(x)
> x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
> if not self.use_linear:
261a331
>
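
Note on the attention.py hunk above: MemoryEfficientCrossAttention folds the head dimension into the batch dimension before calling xformers.ops.memory_efficient_attention. A shape-only sketch of that reshape with dummy tensors (plain PyTorch, xformers not required for this part):

import torch

b, seq, heads, dim_head = 2, 4096, 8, 40
q = torch.randn(b, seq, heads * dim_head)   # output of to_q, shape illustrative

q_xf = (q.unsqueeze(3)
         .reshape(b, seq, heads, dim_head)
         .permute(0, 2, 1, 3)
         .reshape(b * heads, seq, dim_head)
         .contiguous())
print(q_xf.shape)  # torch.Size([16, 4096, 40]), the layout fed to memory_efficient_attention above
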
diff -w -x '*.gif' -x '*.png' -x '*.jpg' -x LICENSE -x '*.md' -x '*.git' -r sd1/ldm/modules/diffusionmodules/model.py sd2/ldm/modules/diffusionmodules/model.py
6a7
> from typing import Optional, Any
8,9c9,17
< from ldm.util import instantiate_from_config
< from ldm.modules.attention import LinearAttention
---
> from ldm.modules.attention import MemoryEfficientCrossAttention
>
> try:
> import xformers
> import xformers.ops
> XFORMERS_IS_AVAILBLE = True
> except:
> XFORMERS_IS_AVAILBLE = False
> print("No module 'xformers'. Proceeding without it.")
144,149d151
< class LinAttnBlock(LinearAttention):
< """to match AttnBlock usage"""
< def __init__(self, in_channels):
< super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
<
<
177d178
<
203a205,258
> class MemoryEfficientAttnBlock(nn.Module):
> """
> Uses xformers efficient implementation,
> see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
> Note: this is a single-head self-attention operation
> """
> #
> def __init__(self, in_channels):
> super().__init__()
> self.in_channels = in_channels
>
> self.norm = Normalize(in_channels)
> self.q = torch.nn.Conv2d(in_channels,
> in_channels,
> kernel_size=1,
> stride=1,
> padding=0)
> self.k = torch.nn.Conv2d(in_channels,
> in_channels,
> kernel_size=1,
> stride=1,
> padding=0)
> self.v = torch.nn.Conv2d(in_channels,
> in_channels,
> kernel_size=1,
> stride=1,
> padding=0)
> self.proj_out = torch.nn.Conv2d(in_channels,
> in_channels,
> kernel_size=1,
> stride=1,
> padding=0)
> self.attention_op: Optional[Any] = None
>
> def forward(self, x):
> h_ = x
> h_ = self.norm(h_)
> q = self.q(h_)
> k = self.k(h_)
> v = self.v(h_)
>
> # compute attention
> B, C, H, W = q.shape
> q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v))
>
> q, k, v = map(
> lambda t: t.unsqueeze(3)
> .reshape(B, t.shape[1], 1, C)
> .permute(0, 2, 1, 3)
> .reshape(B * 1, t.shape[1], C)
> .contiguous(),
> (q, k, v),
> )
> out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
205,206c260,283
< def make_attn(in_channels, attn_type="vanilla"):
< assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
---
> out = (
> out.unsqueeze(0)
> .reshape(B, 1, out.shape[1], C)
> .permute(0, 2, 1, 3)
> .reshape(B, out.shape[1], C)
> )
> out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
> out = self.proj_out(out)
> return x+out
>
>
> class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
> def forward(self, x, context=None, mask=None):
> b, c, h, w = x.shape
> x = rearrange(x, 'b c h w -> b (h w) c')
> out = super().forward(x, context=context, mask=mask)
> out = rearrange(out, 'b (h w) c -> b c h w', h=h, w=w, c=c)
> return x + out
>
>
> def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
> assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown'
> if XFORMERS_IS_AVAILBLE and attn_type == "vanilla":
> attn_type = "vanilla-xformers"
208a286
> assert attn_kwargs is None
209a288,293
> elif attn_type == "vanilla-xformers":
> print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
> return MemoryEfficientAttnBlock(in_channels)
> elif type == "memory-efficient-cross-attn":
> attn_kwargs["query_dim"] = in_channels
> return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)
213c297
< return LinAttnBlock(in_channels)
---
> raise NotImplementedError()
769,835d852
<
< class FirstStagePostProcessor(nn.Module):
<
< def __init__(self, ch_mult:list, in_channels,
< pretrained_model:nn.Module=None,
< reshape=False,
< n_channels=None,
< dropout=0.,
< pretrained_config=None):
< super().__init__()
< if pretrained_config is None:
< assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
< self.pretrained_model = pretrained_model
< else:
< assert pretrained_config is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
< self.instantiate_pretrained(pretrained_config)
<
< self.do_reshape = reshape
<
< if n_channels is None:
< n_channels = self.pretrained_model.encoder.ch
<
< self.proj_norm = Normalize(in_channels,num_groups=in_channels//2)
< self.proj = nn.Conv2d(in_channels,n_channels,kernel_size=3,
< stride=1,padding=1)
<
< blocks = []
< downs = []
< ch_in = n_channels
< for m in ch_mult:
< blocks.append(ResnetBlock(in_channels=ch_in,out_channels=m*n_channels,dropout=dropout))
< ch_in = m * n_channels
< downs.append(Downsample(ch_in, with_conv=False))
<
< self.model = nn.ModuleList(blocks)
< self.downsampler = nn.ModuleList(downs)
<
<
< def instantiate_pretrained(self, config):
< model = instantiate_from_config(config)
< self.pretrained_model = model.eval()
< # self.pretrained_model.train = False
< for param in self.pretrained_model.parameters():
< param.requires_grad = False
<
<
< @torch.no_grad()
< def encode_with_pretrained(self,x):
< c = self.pretrained_model.encode(x)
< if isinstance(c, DiagonalGaussianDistribution):
< c = c.mode()
< return c
<
< def forward(self,x):
< z_fs = self.encode_with_pretrained(x)
< z = self.proj_norm(z_fs)
< z = self.proj(z)
< z = nonlinearity(z)
<
< for submodel, downmodel in zip(self.model,self.downsampler):
< z = submodel(z,temb=None)
< z = downmodel(z)
<
< if self.do_reshape:
< z = rearrange(z,'b c h w -> b (h w) c')
< return z
<
diff -w -x '*.gif' -x '*.png' -x '*.jpg' -x LICENSE -x '*.md' -x '*.git' -r sd1/ldm/modules/diffusionmodules/openaimodel.py sd2/ldm/modules/diffusionmodules/openaimodel.py
2d1
< from functools import partial
4d2
< from typing import Iterable
20a19
> from ldm.util import exists
468a468,471
> disable_self_attentions=None,
> num_attention_blocks=None,
> disable_middle_self_attn=False,
> use_linear_in_transformer=False,
492a496,501
> if isinstance(num_res_blocks, int):
> self.num_res_blocks = len(channel_mult) * [num_res_blocks]
> else:
> if len(num_res_blocks) != len(channel_mult):
> raise ValueError("provide num_res_blocks either as an int (globally constant) or "
> "as a list/tuple (per-level) with the same length as channel_mult")
493a503,513
> if disable_self_attentions is not None:
> # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
> assert len(disable_self_attentions) == len(channel_mult)
> if num_attention_blocks is not None:
> assert len(num_attention_blocks) == len(self.num_res_blocks)
> assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
> print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
> f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
> f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
> f"attention will still not be set.")
>
513a534
> if isinstance(self.num_classes, int):
514a536,540
> elif self.num_classes == "continuous":
> print("setting up linear c_adm embedding layer")
> self.label_emb = nn.Linear(1, time_embed_dim)
> else:
> raise ValueError()
528c554
< for _ in range(num_res_blocks):
---
> for nr in range(self.num_res_blocks[level]):
549a576,581
> if exists(disable_self_attentions):
> disabled_sa = disable_self_attentions[level]
> else:
> disabled_sa = False
>
> if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
558c590,592
< ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
---
> ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
> disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
> use_checkpoint=use_checkpoint
612,613c646,649
< ) if not use_spatial_transformer else SpatialTransformer(
< ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
---
> ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
> ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
> disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
> use_checkpoint=use_checkpoint
628c664
< for i in range(num_res_blocks + 1):
---
> for i in range(self.num_res_blocks[level] + 1):
650a687,692
> if exists(disable_self_attentions):
> disabled_sa = disable_self_attentions[level]
> else:
> disabled_sa = False
>
> if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
659c701,703
< ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
---
> ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
> disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
> use_checkpoint=use_checkpoint
662c706
< if level and i == num_res_blocks:
---
> if level and i == self.num_res_blocks[level]:
727c771
< assert y.shape == (x.shape[0],)
---
> assert y.shape[0] == x.shape[0]
743,961d786
<
<
< class EncoderUNetModel(nn.Module):
< """
< The half UNet model with attention and timestep embedding.
< For usage, see UNet.
< """
<
< def __init__(
< self,
< image_size,
< in_channels,
< model_channels,
< out_channels,
< num_res_blocks,
< attention_resolutions,
< dropout=0,
< channel_mult=(1, 2, 4, 8),
< conv_resample=True,
< dims=2,
< use_checkpoint=False,
< use_fp16=False,
< num_heads=1,
< num_head_channels=-1,
< num_heads_upsample=-1,
< use_scale_shift_norm=False,
< resblock_updown=False,
< use_new_attention_order=False,
< pool="adaptive",
< *args,
< **kwargs
< ):
< super().__init__()
<
< if num_heads_upsample == -1:
< num_heads_upsample = num_heads
<
< self.in_channels = in_channels
< self.model_channels = model_channels
< self.out_channels = out_channels
< self.num_res_blocks = num_res_blocks
< self.attention_resolutions = attention_resolutions
< self.dropout = dropout
< self.channel_mult = channel_mult
< self.conv_resample = conv_resample
< self.use_checkpoint = use_checkpoint
< self.dtype = th.float16 if use_fp16 else th.float32
< self.num_heads = num_heads
< self.num_head_channels = num_head_channels
< self.num_heads_upsample = num_heads_upsample
<
< time_embed_dim = model_channels * 4
< self.time_embed = nn.Sequential(
< linear(model_channels, time_embed_dim),
< nn.SiLU(),
< linear(time_embed_dim, time_embed_dim),
< )
<
< self.input_blocks = nn.ModuleList(
< [
< TimestepEmbedSequential(
< conv_nd(dims, in_channels, model_channels, 3, padding=1)
< )
< ]
< )
< self._feature_size = model_channels
< input_block_chans = [model_channels]
< ch = model_channels
< ds = 1
< for level, mult in enumerate(channel_mult):
< for _ in range(num_res_blocks):
< layers = [
< ResBlock(
< ch,
< time_embed_dim,
< dropout,
< out_channels=mult * model_channels,
< dims=dims,
< use_checkpoint=use_checkpoint,
< use_scale_shift_norm=use_scale_shift_norm,
< )
< ]
< ch = mult * model_channels
< if ds in attention_resolutions:
< layers.append(
< AttentionBlock(
< ch,
< use_checkpoint=use_checkpoint,
< num_heads=num_heads,
< num_head_channels=num_head_channels,
< use_new_attention_order=use_new_attention_order,
< )
< )
< self.input_blocks.append(TimestepEmbedSequential(*layers))
< self._feature_size += ch
< input_block_chans.append(ch)
< if level != len(channel_mult) - 1:
< out_ch = ch
< self.input_blocks.append(
< TimestepEmbedSequential(
< ResBlock(
< ch,
< time_embed_dim,
< dropout,
< out_channels=out_ch,
< dims=dims,
< use_checkpoint=use_checkpoint,
< use_scale_shift_norm=use_scale_shift_norm,
< down=True,
< )
< if resblock_updown
< else Downsample(
< ch, conv_resample, dims=dims, out_channels=out_ch
< )
< )
< )
< ch = out_ch
< input_block_chans.append(ch)
< ds *= 2
< self._feature_size += ch
<
< self.middle_block = TimestepEmbedSequential(
< ResBlock(
< ch,
< time_embed_dim,
< dropout,
< dims=dims,
< use_checkpoint=use_checkpoint,
< use_scale_shift_norm=use_scale_shift_norm,
< ),
< AttentionBlock(
< ch,
< use_checkpoint=use_checkpoint,
< num_heads=num_heads,
< num_head_channels=num_head_channels,
< use_new_attention_order=use_new_attention_order,
< ),
< ResBlock(
< ch,
< time_embed_dim,
< dropout,
< dims=dims,
< use_checkpoint=use_checkpoint,
< use_scale_shift_norm=use_scale_shift_norm,
< ),
< )
< self._feature_size += ch
< self.pool = pool
< if pool == "adaptive":
< self.out = nn.Sequential(
< normalization(ch),
< nn.SiLU(),
< nn.AdaptiveAvgPool2d((1, 1)),
< zero_module(conv_nd(dims, ch, out_channels, 1)),
< nn.Flatten(),
< )
< elif pool == "attention":
< assert num_head_channels != -1
< self.out = nn.Sequential(
< normalization(ch),
< nn.SiLU(),
< AttentionPool2d(
< (image_size // ds), ch, num_head_channels, out_channels
< ),
< )
< elif pool == "spatial":
< self.out = nn.Sequential(
< nn.Linear(self._feature_size, 2048),
< nn.ReLU(),
< nn.Linear(2048, self.out_channels),
< )
< elif pool == "spatial_v2":
< self.out = nn.Sequential(
< nn.Linear(self._feature_size, 2048),
< normalization(2048),
< nn.SiLU(),
< nn.Linear(2048, self.out_channels),
< )
< else:
< raise NotImplementedError(f"Unexpected {pool} pooling")
<
< def convert_to_fp16(self):
< """
< Convert the torso of the model to float16.
< """
< self.input_blocks.apply(convert_module_to_f16)
< self.middle_block.apply(convert_module_to_f16)
<
< def convert_to_fp32(self):
< """
< Convert the torso of the model to float32.
< """
< self.input_blocks.apply(convert_module_to_f32)
< self.middle_block.apply(convert_module_to_f32)
<
< def forward(self, x, timesteps):
< """
< Apply the model to an input batch.
< :param x: an [N x C x ...] Tensor of inputs.
< :param timesteps: a 1-D batch of timesteps.
< :return: an [N x K] Tensor of outputs.
< """
< emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
<
< results = []
< h = x.type(self.dtype)
< for module in self.input_blocks:
< h = module(h, emb)
< if self.pool.startswith("spatial"):
< results.append(h.type(x.dtype).mean(dim=(2, 3)))
< h = self.middle_block(h, emb)
< if self.pool.startswith("spatial"):
< results.append(h.type(x.dtype).mean(dim=(2, 3)))
< h = th.cat(results, axis=-1)
< return self.out(h)
< else:
< h = h.type(x.dtype)
< return self.out(h)
<
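
Note on the openaimodel.py hunk above: num_res_blocks may now be a single int or a per-level list matching channel_mult. A standalone restatement of that normalization step (illustrative values):

def normalize_num_res_blocks(num_res_blocks, channel_mult=(1, 2, 4, 4)):
    if isinstance(num_res_blocks, int):
        return len(channel_mult) * [num_res_blocks]
    if len(num_res_blocks) != len(channel_mult):
        raise ValueError("provide num_res_blocks either as an int (globally constant) or "
                         "as a list/tuple (per-level) with the same length as channel_mult")
    return list(num_res_blocks)

print(normalize_num_res_blocks(2))            # [2, 2, 2, 2]
print(normalize_num_res_blocks([2, 2, 1, 1])) # [2, 2, 1, 1]
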
Only in sd2/ldm/modules/diffusionmodules: upscaling.py
diff -w -x '*.gif' -x '*.png' -x '*.jpg' -x LICENSE -x '*.md' -x '*.git' -r sd1/ldm/modules/diffusionmodules/util.py sd2/ldm/modules/diffusionmodules/util.py
125c125,127
<
---
> ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
> "dtype": torch.get_autocast_gpu_dtype(),
> "cache_enabled": torch.is_autocast_cache_enabled()}
133c135,136
< with torch.enable_grad():
---
> with torch.enable_grad(), \
> torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
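
Note on the diffusionmodules/util.py hunk above: CheckpointFunction now records the autocast state during the forward pass so the recomputation in backward runs under the same AMP settings. A minimal sketch of capturing and re-entering that state (dummy computation standing in for the checkpointed function; assumes a PyTorch version that provides these autocast helpers, as pinned in sd2's environment.yaml):

import torch

gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
                       "dtype": torch.get_autocast_gpu_dtype(),
                       "cache_enabled": torch.is_autocast_cache_enabled()}

x = torch.randn(4, 4, requires_grad=True)
with torch.enable_grad(), torch.cuda.amp.autocast(**gpu_autocast_kwargs):
    y = (x * 2).sum()   # stand-in for re-running the checkpointed function
y.backward()
print(gpu_autocast_kwargs["enabled"], x.grad.shape)
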
diff -w -x '*.gif' -x '*.png' -x '*.jpg' -x LICENSE -x '*.md' -x '*.git' -r sd1/ldm/modules/ema.py sd2/ldm/modules/ema.py
24a25,28
> def reset_num_updates(self):
> del self.num_updates
> self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int))
>
diff -w -x '*.gif' -x '*.png' -x '*.jpg' -x LICENSE -x '*.md' -x '*.git' -r sd1/ldm/modules/encoders/modules.py sd2/ldm/modules/encoders/modules.py
3,7c3
< from functools import partial
< import clip
< from einops import rearrange, repeat
< from transformers import CLIPTokenizer, CLIPTextModel
< import kornia
---
> from torch.utils.checkpoint import checkpoint
9c5,8
< from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test
---
> from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
>
> import open_clip
> from ldm.util import default, count_params
19a19,23
> class IdentityEncoder(AbstractEncoder):
>
> def encode(self, x):
> return x
>
22c26
< def __init__(self, embed_dim, n_classes=1000, key='class'):
---
> def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1):
25a30,31
> self.n_classes = n_classes
> self.ucg_rate = ucg_rate
27c33
< def forward(self, batch, key=None):
---
> def forward(self, batch, key=None, disable_dropout=False):
31a38,41
> if self.ucg_rate > 0. and not disable_dropout:
> mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
> c = mask * c + (1-mask) * torch.ones_like(c)*(self.n_classes-1)
> c = c.long()
34a45,49
> def get_unconditional_conditioning(self, bs, device="cuda"):
> uc_class = self.n_classes - 1 # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
> uc = torch.ones((bs,), device=device) * uc_class
> uc = {self.key: uc}
> return uc
36,42d50
< class TransformerEmbedder(AbstractEncoder):
< """Some transformer encoder layers"""
< def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"):
< super().__init__()
< self.device = device
< self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
< attn_layers=Encoder(dim=n_embed, depth=n_layer))
44,50c52,55
< def forward(self, tokens):
< tokens = tokens.to(self.device) # meh
< z = self.transformer(tokens, return_embeddings=True)
< return z
<
< def encode(self, x):
< return self(x)
---
> def disabled_train(self, mode=True):
> """Overwrite model.train with this function to make sure train/eval mode
> does not change anymore."""
> return self
53,55c58,60
< class BERTTokenizer(AbstractEncoder):
< """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
< def __init__(self, device="cuda", vq_interface=True, max_length=77):
---
> class FrozenT5Embedder(AbstractEncoder):
> """Uses the T5 transformer encoder for text"""
> def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
57,58c62,63
< from transformers import BertTokenizerFast # TODO: add to reuquirements
< self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
---
> self.tokenizer = T5Tokenizer.from_pretrained(version)
> self.transformer = T5EncoderModel.from_pretrained(version)
60,61c65,73
< self.vq_interface = vq_interface
< self.max_length = max_length
---
> self.max_length = max_length # TODO: typical value?
> if freeze:
> self.freeze()
>
> def freeze(self):
> self.transformer = self.transformer.eval()
> #self.train = disabled_train
> for param in self.parameters():
> param.requires_grad = False
67,91c79
< return tokens
<
< @torch.no_grad()
< def encode(self, text):
< tokens = self(text)
< if not self.vq_interface:
< return tokens
< return None, None, [None, None, tokens]
<
< def decode(self, text):
< return text
<
<
< class BERTEmbedder(AbstractEncoder):
< """Uses the BERT tokenizr model and add some transformer encoder layers"""
< def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
< device="cuda",use_tokenizer=True, embedding_dropout=0.0):
< super().__init__()
< self.use_tknz_fn = use_tokenizer
< if self.use_tknz_fn:
< self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len)
< self.device = device
< self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
< attn_layers=Encoder(dim=n_embed, depth=n_layer),
< emb_dropout=embedding_dropout)
---
> outputs = self.transformer(input_ids=tokens)
93,98c81
< def forward(self, text):
< if self.use_tknz_fn:
< tokens = self.tknz_fn(text)#.to(self.device)
< else:
< tokens = text
< z = self.transformer(tokens, return_embeddings=True)
---
> z = outputs.last_hidden_state
102d84
< # output of length 77
106,136d87
< class SpatialRescaler(nn.Module):
< def __init__(self,
< n_stages=1,
< method='bilinear',
< multiplier=0.5,
< in_channels=3,
< out_channels=None,
< bias=False):
< super().__init__()
< self.n_stages = n_stages
< assert self.n_stages >= 0
< assert method in ['nearest','linear','bilinear','trilinear','bicubic','area']
< self.multiplier = multiplier
< self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
< self.remap_output = out_channels is not None
< if self.remap_output:
< print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.')
< self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias)
<
< def forward(self,x):
< for stage in range(self.n_stages):
< x = self.interpolator(x, scale_factor=self.multiplier)
<
<
< if self.remap_output:
< x = self.channel_mapper(x)
< return x
<
< def encode(self, x):
< return self(x)
<
138,139c89,96
< """Uses the CLIP transformer encoder for text (from Hugging Face)"""
< def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):
---
> """Uses the CLIP transformer encoder for text (from huggingface)"""
> LAYERS = [
> "last",
> "pooled",
> "hidden"
> ]
> def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77,
> freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32
140a98
> assert layer in self.LAYERS
144a103
> if freeze:
145a105,109
> self.layer = layer
> self.layer_idx = layer_idx
> if layer == "hidden":
> assert layer_idx is not None
> assert 0 <= abs(layer_idx) <= 12
148a113
> #self.train = disabled_train
156,157c121,122
< outputs = self.transformer(input_ids=tokens)
<
---
> outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
> if self.layer == "last":
158a124,127
> elif self.layer == "pooled":
> z = outputs.pooler_output[:, None, :]
> else:
> z = outputs.hidden_states[self.layer_idx]
165c134
< class FrozenCLIPTextEmbedder(nn.Module):
---
> class FrozenOpenCLIPEmbedder(AbstractEncoder):
167c136
< Uses the CLIP transformer encoder for text.
---
> Uses the OpenCLIP transformer encoder for text
169,171c138,150
< def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True):
< super().__init__()
< self.model, _ = clip.load(version, jit=False, device="cpu")
---
> LAYERS = [
> #"pooled",
> "last",
> "penultimate"
> ]
> def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
> freeze=True, layer="last"):
> super().__init__()
> assert layer in self.LAYERS
> model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version)
> del model.visual
> self.model = model
>
174,175c153,161
< self.n_repeat = n_repeat
< self.normalize = normalize
---
> if freeze:
> self.freeze()
> self.layer = layer
> if self.layer == "last":
> self.layer_idx = 0
> elif self.layer == "penultimate":
> self.layer_idx = 1
> else:
> raise NotImplementedError()
183,186c169,170
< tokens = clip.tokenize(text).to(self.device)
< z = self.model.encode_text(tokens)
< if self.normalize:
< z = z / torch.linalg.norm(z, dim=1, keepdim=True)
---
> tokens = open_clip.tokenize(text)
> z = self.encode_with_transformer(tokens.to(self.device))
188a173,191
> def encode_with_transformer(self, text):
> x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
> x = x + self.model.positional_embedding
> x = x.permute(1, 0, 2) # NLD -> LND
> x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
> x = x.permute(1, 0, 2) # LND -> NLD
> x = self.model.ln_final(x)
> return x
>
> def text_transformer_forward(self, x: torch.Tensor, attn_mask = None):
> for i, r in enumerate(self.model.transformer.resblocks):
> if i == len(self.model.transformer.resblocks) - self.layer_idx:
> break
> if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting():
> x = checkpoint(r, x, attn_mask)
> else:
> x = r(x, attn_mask=attn_mask)
> return x
>
190,194c193
< z = self(text)
< if z.ndim==2:
< z = z[:, None, :]
< z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat)
< return z
---
> return self(text)
197,224c196,206
< class FrozenClipImageEmbedder(nn.Module):
< """
< Uses the CLIP image encoder.
< """
< def __init__(
< self,
< model,
< jit=False,
< device='cuda' if torch.cuda.is_available() else 'cpu',
< antialias=False,
< ):
< super().__init__()
< self.model, _ = clip.load(name=model, device=device, jit=jit)
<
< self.antialias = antialias
<
< self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
< self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
<
< def preprocess(self, x):
< # normalize to [0,1]
< x = kornia.geometry.resize(x, (224, 224),
< interpolation='bicubic',align_corners=True,
< antialias=self.antialias)
< x = (x + 1.) / 2.
< # renormalize according to clip
< x = kornia.enhance.normalize(x, self.mean, self.std)
< return x
---
> class FrozenCLIPT5Encoder(AbstractEncoder):
> def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda",
> clip_max_length=77, t5_max_length=77):
> super().__init__()
> self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length)
> self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
> print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, "
> f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params.")
>
> def encode(self, text):
> return self(text)
226,228c208,211
< def forward(self, x):
< # x is assumed to be in range [-1,1]
< return self.model.encode_image(self.preprocess(x))
---
> def forward(self, text):
> clip_z = self.clip_encoder.encode(text)
> t5_z = self.t5_encoder.encode(text)
> return [clip_z, t5_z]
231,234d213
< if __name__ == "__main__":
< from ldm.util import count_params
< model = FrozenCLIPEmbedder()
< count_params(model, verbose=True)
\ No newline at end of file
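
Note on the encoders/modules.py hunk above: the text encoder moves from OpenAI CLIP (via the clip package) to open_clip's ViT-H-14, reading either the last or the penultimate transformer block. A hedged usage sketch of the calls involved, taken from the code shown above (downloads the large laion2b_s32b_b79k weights on first run; the prompt text is arbitrary):

import torch
import open_clip  # sd2 pins open_clip_torch==2.0.2

model, _, _ = open_clip.create_model_and_transforms(
    "ViT-H-14", device=torch.device("cpu"), pretrained="laion2b_s32b_b79k")
del model.visual  # only the text tower is kept, as in FrozenOpenCLIPEmbedder

tokens = open_clip.tokenize(["a photograph of an astronaut riding a horse"])
x = model.token_embedding(tokens) + model.positional_embedding
x = x.permute(1, 0, 2)  # NLD -> LND, as in encode_with_transformer
print(tokens.shape, x.shape)
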
diff -w -x '*.gif' -x '*.png' -x '*.jpg' -x LICENSE -x '*.md' -x '*.git' -r sd1/ldm/modules/image_degradation/bsrgan_light.py sd2/ldm/modules/image_degradation/bsrgan_light.py
28d27
<
257c256
< x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror'
---
> x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror'
280c279
< x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
---
> x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
293c292
< x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
---
> x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
338c337
< img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
---
> img = ndimage.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
500c499
< img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
---
> img = ndimage.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
534c533
< def degradation_bsrgan_variant(image, sf=4, isp_model=None):
---
> def degradation_bsrgan_variant(image, sf=4, isp_model=None, up=False):
592c591
< image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
---
> image = ndimage.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
619a619,620
> if up:
> image = cv2.resize(image, (w1, h1), interpolation=cv2.INTER_CUBIC) # todo: random, as above? want to condition on it then
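
Note on the bsrgan_light.py hunk above: the calls move from the older scipy.ndimage.filters namespace to scipy.ndimage.convolve, which newer SciPy versions expect; the behaviour is unchanged. A quick standalone check with a dummy image and box kernel:

import numpy as np
from scipy import ndimage

x = np.random.rand(16, 16, 3).astype(np.float32)
k = np.ones((5, 5), dtype=np.float32) / 25.0
y = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
print(y.shape)  # (16, 16, 3)
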
Only in sd1/ldm/modules: losses
Only in sd2/ldm/modules: midas
Only in sd1/ldm/modules: x_transformer.py
diff -w -x '*.gif' -x '*.png' -x '*.jpg' -x LICENSE -x '*.md' -x '*.git' -r sd1/ldm/util.py sd2/ldm/util.py
3a4
> from torch import optim
5,11d5
< from collections import abc
< from einops import rearrange
< from functools import partial
<
< import multiprocessing as mp
< from threading import Thread
< from queue import Queue
96,169c90,195
< def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
< # create dummy dataset instance
<
< # run prefetching
< if idx_to_fn:
< res = func(data, worker_id=idx)
< else:
< res = func(data)
< Q.put([idx, res])
< Q.put("Done")
<
<
< def parallel_data_prefetch(
< func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False
< ):
< # if target_data_type not in ["ndarray", "list"]:
< # raise ValueError(
< # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
< # )
< if isinstance(data, np.ndarray) and target_data_type == "list":
< raise ValueError("list expected but function got ndarray.")
< elif isinstance(data, abc.Iterable):
< if isinstance(data, dict):
< print(
< f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
< )
< data = list(data.values())
< if target_data_type == "ndarray":
< data = np.asarray(data)
< else:
< data = list(data)
< else:
< raise TypeError(
< f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
< )
<
< if cpu_intensive:
< Q = mp.Queue(1000)
< proc = mp.Process
< else:
< Q = Queue(1000)
< proc = Thread
< # spawn processes
< if target_data_type == "ndarray":
< arguments = [
< [func, Q, part, i, use_worker_id]
< for i, part in enumerate(np.array_split(data, n_proc))
< ]
< else:
< step = (
< int(len(data) / n_proc + 1)
< if len(data) % n_proc != 0
< else int(len(data) / n_proc)
< )
< arguments = [
< [func, Q, part, i, use_worker_id]
< for i, part in enumerate(
< [data[i: i + step] for i in range(0, len(data), step)]
< )
< ]
< processes = []
< for i in range(n_proc):
< p = proc(target=_do_parallel_data_prefetch, args=arguments[i])
< processes += [p]
<
< # start processes
< print(f"Start prefetching...")
< import time
<
< start = time.time()
< gather_res = [[] for _ in range(n_proc)]
< try:
< for p in processes:
< p.start()
---
> class AdamWwithEMAandWings(optim.Optimizer):
> # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298
> def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using
> weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code
> ema_power=1., param_names=()):
> """AdamW that saves EMA versions of the parameters."""
> if not 0.0 <= lr:
> raise ValueError("Invalid learning rate: {}".format(lr))
> if not 0.0 <= eps:
> raise ValueError("Invalid epsilon value: {}".format(eps))
> if not 0.0 <= betas[0] < 1.0:
> raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
> if not 0.0 <= betas[1] < 1.0:
> raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
> if not 0.0 <= weight_decay:
> raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
> if not 0.0 <= ema_decay <= 1.0:
> raise ValueError("Invalid ema_decay value: {}".format(ema_decay))
> defaults = dict(lr=lr, betas=betas, eps=eps,
> weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay,
> ema_power=ema_power, param_names=param_names)
> super().__init__(params, defaults)
>
> def __setstate__(self, state):
> super().__setstate__(state)
> for group in self.param_groups:
> group.setdefault('amsgrad', False)
>
> @torch.no_grad()
> def step(self, closure=None):
> """Performs a single optimization step.
> Args:
> closure (callable, optional): A closure that reevaluates the model
> and returns the loss.
> """
> loss = None
> if closure is not None:
> with torch.enable_grad():
> loss = closure()
>
> for group in self.param_groups:
> params_with_grad = []
> grads = []
> exp_avgs = []
> exp_avg_sqs = []
> ema_params_with_grad = []
> state_sums = []
> max_exp_avg_sqs = []
> state_steps = []
> amsgrad = group['amsgrad']
> beta1, beta2 = group['betas']
> ema_decay = group['ema_decay']
> ema_power = group['ema_power']
>
> for p in group['params']:
> if p.grad is None:
> continue
> params_with_grad.append(p)
> if p.grad.is_sparse:
> raise RuntimeError('AdamW does not support sparse gradients')
> grads.append(p.grad)
>
> state = self.state[p]
>
> # State initialization
> if len(state) == 0:
> state['step'] = 0
> # Exponential moving average of gradient values
> state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
> # Exponential moving average of squared gradient values
> state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
> if amsgrad:
> # Maintains max of all exp. moving avg. of sq. grad. values
> state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
> # Exponential moving average of parameter values
> state['param_exp_avg'] = p.detach().float().clone()
>
> exp_avgs.append(state['exp_avg'])
> exp_avg_sqs.append(state['exp_avg_sq'])
> ema_params_with_grad.append(state['param_exp_avg'])
>
> if amsgrad:
> max_exp_avg_sqs.append(state['max_exp_avg_sq'])
>
> # update the steps for each param group update
> state['step'] += 1
> # record the step after step update
> state_steps.append(state['step'])
>
> optim._functional.adamw(params_with_grad,
> grads,
> exp_avgs,
> exp_avg_sqs,
> max_exp_avg_sqs,
> state_steps,
> amsgrad=amsgrad,
> beta1=beta1,
> beta2=beta2,
> lr=group['lr'],
> weight_decay=group['weight_decay'],
> eps=group['eps'],
> maximize=False)
>
> cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power)
> for param, ema_param in zip(params_with_grad, ema_params_with_grad):
> ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay)
171,203c197
< k = 0
< while k < n_proc:
< # get result
< res = Q.get()
< if res == "Done":
< k += 1
< else:
< gather_res[res[0]] = res[1]
<
< except Exception as e:
< print("Exception: ", e)
< for p in processes:
< p.terminate()
<
< raise e
< finally:
< for p in processes:
< p.join()
< print(f"Prefetching complete. [{time.time() - start} sec.]")
<
< if target_data_type == 'ndarray':
< if not isinstance(gather_res[0], np.ndarray):
< return np.concatenate([np.asarray(r) for r in gather_res], axis=0)
<
< # order outputs
< return np.concatenate(gather_res, axis=0)
< elif target_data_type == 'list':
< out = []
< for r in gather_res:
< out.extend(r)
< return out
< else:
< return gather_res
---
> return loss
\ No newline at end of file
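
Note on the ldm/util.py hunk above: AdamWwithEMAandWings warms up the EMA decay with min(ema_decay, 1 - step ** -ema_power), so early steps track the raw weights closely before settling at ema_decay. A few sample values with the defaults:

ema_decay, ema_power = 0.9999, 1.0
for step in (1, 10, 100, 10_000):
    print(step, min(ema_decay, 1 - step ** -ema_power))
# 1 -> 0.0, 10 -> 0.9, 100 -> 0.99, 10000 -> 0.9999
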
Only in sd1: main.py
Only in sd1: models
Only in sd1: notebook_helpers.py
Only in sd2: requirements.txt
Only in sd1/scripts: download_first_stages.sh
Only in sd1/scripts: download_models.sh
Only in sd2/scripts: gradio
diff -w -x '*.gif' -x '*.png' -x '*.jpg' -x LICENSE -x '*.md' -x '*.git' -r sd1/scripts/img2img.py sd2/scripts/img2img.py
3c3
< import argparse, os, sys, glob
---
> import argparse, os
15d14
< import time
16a16
> from imwatermark import WatermarkEncoder
17a18,19
>
> from scripts.txt2img import put_watermark
20d21
< from ldm.models.diffusion.plms import PLMSSampler
52c53
< w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
---
> w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64
87,98d87
< "--skip_grid",
< action='store_true',
< help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
< )
<
< parser.add_argument(
< "--skip_save",
< action='store_true',
< help="do not save indiviual samples. For speed measurements.",
< )
<
< parser.add_argument(
106,110d94
< "--plms",
< action='store_true',
< help="use plms sampling",
< )
< parser.add_argument(
127a112
>
139a125
>
145a132
>
151a139
>
155c143
< default=5.0,
---
> default=9.0,
162c150
< default=0.75,
---
> default=0.8,
164a153
>
173c162
< default="configs/stable-diffusion/v1-inference.yaml",
---
> default="configs/stable-diffusion/v2-inference.yaml",
179d167
< default="models/ldm/stable-diffusion-v1/model.ckpt",
205,208d192
< if opt.plms:
< raise NotImplementedError("PLMS sampler not (yet) supported")
< sampler = PLMSSampler(model)
< else:
213a198,202
> print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...")
> wm = "SDV2"
> wm_encoder = WatermarkEncoder()
> wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
>
247d235
< tic = time.time()
267d254
< if not opt.skip_save:
270,271c257,259
< Image.fromarray(x_sample.astype(np.uint8)).save(
< os.path.join(sample_path, f"{base_count:05}.png"))
---
> img = Image.fromarray(x_sample.astype(np.uint8))
> img = put_watermark(img, wm_encoder)
> img.save(os.path.join(sample_path, f"{base_count:05}.png"))
275d262
< if not opt.skip_grid:
283c270,272
< Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
---
> grid = Image.fromarray(grid.astype(np.uint8))
> grid = put_watermark(grid, wm_encoder)
> grid.save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
286,289c275
< toc = time.time()
<
< print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
< f" \nEnjoy.")
---
> print(f"Your samples are ready and waiting for you here: \n{outpath} \nEnjoy.")
Only in sd1/scripts: inpaint.py
Only in sd1/scripts: knn2img.py
Only in sd1/scripts: latent_imagenet_diffusion.ipynb
Only in sd1/scripts: sample_diffusion.py
Only in sd2/scripts: streamlit
Only in sd1/scripts: train_searcher.py
diff -w -x '*.gif' -x '*.png' -x '*.jpg' -x LICENSE -x '*.md' -x '*.git' -r sd1/scripts/txt2img.py sd2/scripts/txt2img.py
1c1
< import argparse, os, sys, glob
---
> import argparse, os
8d7
< from imwatermark import WatermarkEncoder
12d10
< import time
15c13,14
< from contextlib import contextmanager, nullcontext
---
> from contextlib import nullcontext
> from imwatermark import WatermarkEncoder
22,30c21
< from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
< from transformers import AutoFeatureExtractor
<
<
< # load safety model
< safety_model_id = "CompVis/stable-diffusion-safety-checker"
< safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
< safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
<
---
> torch.set_grad_enabled(False)
37,48d27
< def numpy_to_pil(images):
< """
< Convert a numpy image or a batch of images to a PIL image.
< """
< if images.ndim == 3:
< images = images[None, ...]
< images = (images * 255).round().astype("uint8")
< pil_images = [Image.fromarray(image) for image in images]
<
< return pil_images
<
<
69,98c48
< def put_watermark(img, wm_encoder=None):
< if wm_encoder is not None:
< img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
< img = wm_encoder.encode(img, 'dwtDct')
< img = Image.fromarray(img[:, :, ::-1])
< return img
<
<
< def load_replacement(x):
< try:
< hwc = x.shape
< y = Image.open("assets/rick.jpeg").convert("RGB").resize((hwc[1], hwc[0]))
< y = (np.array(y)/255.0).astype(x.dtype)
< assert y.shape == x.shape
< return y
< except Exception:
< return x
<
<
< def check_safety(x_image):
< safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
< x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
< assert x_checked_image.shape[0] == len(has_nsfw_concept)
< for i in range(len(has_nsfw_concept)):
< if has_nsfw_concept[i]:
< x_checked_image[i] = load_replacement(x_checked_image[i])
< return x_checked_image, has_nsfw_concept
<
<
< def main():
---
> def parse_args():
100d49
<
105c54
< default="a painting of a virus monster playing guitar",
---
> default="a professional photograph of an astronaut riding a triceratops",
116,126c65
< "--skip_grid",
< action='store_true',
< help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
< )
< parser.add_argument(
< "--skip_save",
< action='store_true',
< help="do not save individual samples. For speed measurements.",
< )
< parser.add_argument(
< "--ddim_steps",
---
> "--steps",
137,142c76
< "--dpm_solver",
< action='store_true',
< help="use dpm_solver sampling",
< )
< parser.add_argument(
< "--laion400m",
---
> "--dpm",
144c78
< help="uses the LAION400M model",
---
> help="use DPM (2) sampler",
149c83
< help="if enabled, uses the same starting code across samples ",
---
> help="if enabled, uses the same starting code across all samples ",
160c94
< default=2,
---
> default=3,
185c119
< help="downsampling factor",
---
> help="downsampling factor, most often 8 or 16",
191c125
< help="how many samples to produce for each given prompt. A.k.a. batch size",
---
> help="how many samples to produce for each given prompt. A.k.a batch size",
202c136
< default=7.5,
---
> default=9.0,
208c142
< help="if specified, load prompts from this file",
---
> help="if specified, load prompts from this file, separated by newlines",
213c147
< default="configs/stable-diffusion/v1-inference.yaml",
---
> default="configs/stable-diffusion/v2-inference.yaml",
219d152
< default="models/ldm/stable-diffusion-v1/model.ckpt",
234a168,173
> parser.add_argument(
> "--repeat",
> type=int,
> default=1,
> help="repeat each prompt in file this often",
> )
235a175
> return opt
237,241d176
< if opt.laion400m:
< print("Falling back to LAION 400M model...")
< opt.config = "configs/latent-diffusion/txt2img-1p4B-eval.yaml"
< opt.ckpt = "models/ldm/text2img-large/model.ckpt"
< opt.outdir = "outputs/txt2img-samples-laion400m"
242a178,186
> def put_watermark(img, wm_encoder=None):
> if wm_encoder is not None:
> img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
> img = wm_encoder.encode(img, 'dwtDct')
> img = Image.fromarray(img[:, :, ::-1])
> return img
>
>
> def main(opt):
251,253c195
< if opt.dpm_solver:
< sampler = DPMSolverSampler(model)
< elif opt.plms:
---
> if opt.plms:
254a197,198
> elif opt.dpm:
> sampler = DPMSolverSampler(model)
262c206
< wm = "StableDiffusionV1"
---
> wm = "SDV2"
276a221
> data = [p for p in data for i in range(opt.repeat)]
280a226
> sample_count = 0
289,292c235,237
< with torch.no_grad():
< with precision_scope("cuda"):
< with model.ema_scope():
< tic = time.time()
---
> with torch.no_grad(), \
> precision_scope("cuda"), \
> model.ema_scope():
303c248
< samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
---
> samples, _ = sampler.sample(S=opt.steps,
313,319c258,259
< x_samples_ddim = model.decode_first_stage(samples_ddim)
< x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
< x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy()
<
< x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim)
<
< x_checked_image_torch = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2)
---
> x_samples = model.decode_first_stage(samples)
> x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
321,322c261
< if not opt.skip_save:
< for x_sample in x_checked_image_torch:
---
> for x_sample in x_samples:
327a267
> sample_count += 1
329,330c269
< if not opt.skip_grid:
< all_samples.append(x_checked_image_torch)
---
> all_samples.append(x_samples)
332d270
< if not opt.skip_grid:
340,342c278,280
< img = Image.fromarray(grid.astype(np.uint8))
< img = put_watermark(img, wm_encoder)
< img.save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
---
> grid = Image.fromarray(grid.astype(np.uint8))
> grid = put_watermark(grid, wm_encoder)
> grid.save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
345,346d282
< toc = time.time()
<
352c288,289
< main()
---
> opt = parse_args()
> main(opt)
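
Note on the txt2img.py hunk above: sd2 drops the safety checker, renames --ddim_steps to --steps, and watermarks every output as "SDV2" via invisible-watermark. A usage sketch of the put_watermark path shown in the diff (black dummy image, payload as in sd2):

import cv2
import numpy as np
from PIL import Image
from imwatermark import WatermarkEncoder

wm_encoder = WatermarkEncoder()
wm_encoder.set_watermark('bytes', "SDV2".encode('utf-8'))

img = Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))
bgr = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)   # PIL RGB -> OpenCV BGR
bgr = wm_encoder.encode(bgr, 'dwtDct')                  # embed the watermark
img = Image.fromarray(bgr[:, :, ::-1])                  # back to RGB
print(img.size)
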
diff -w -x '*.gif' -x '*.png' -x '*.jpg' -x LICENSE -x '*.md' -x '*.git' -r sd1/setup.py sd2/setup.py
4c4
< name='latent-diffusion',
---
> name='stable-diffusion',