Created November 24, 2022 04:39
Diff between https://github.com/CompVis/stable-diffusion (sd1) and https://github.com/Stability-AI/stablediffusion (sd2), generated by running `diff -w -x '*.gif' -x '*.png' -x '*.jpg' -x LICENSE -x '*.md' -x '*.git' -r sd1 sd2 > d.txt`.
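To reproduce a diff like this, a minimal sketch is below. It assumes both repos are cloned from their default branches into directories named `sd1` and `sd2` as in the command above; the exact revisions this file was generated from are not recorded here, so a fresh run may show additional changes.

```sh
# Clone both repositories into the directory names used throughout the diff.
git clone https://github.com/CompVis/stable-diffusion sd1
git clone https://github.com/Stability-AI/stablediffusion sd2

# Recursive diff, ignoring whitespace changes and excluding images, licenses,
# markdown files and git metadata.
diff -w -x '*.gif' -x '*.png' -x '*.jpg' -x LICENSE -x '*.md' -x '*.git' -r sd1 sd2 > d.txt
```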
Only in sd2: LICENSE-MODEL
Only in sd2/assets: stable-inpainting
Only in sd2/assets/stable-samples: depth2img
Only in sd2/assets/stable-samples/txt2img: 768
Only in sd2/assets/stable-samples: upscaling
Only in sd1/configs: autoencoder
Only in sd1/configs: latent-diffusion
Only in sd1/configs: retrieval-augmented-diffusion
Only in sd1/configs/stable-diffusion: v1-inference.yaml
Only in sd2/configs/stable-diffusion: v2-inference-v.yaml
Only in sd2/configs/stable-diffusion: v2-inference.yaml
Only in sd2/configs/stable-diffusion: v2-inpainting-inference.yaml
Only in sd2/configs/stable-diffusion: v2-midas-inference.yaml
Only in sd2/configs/stable-diffusion: x4-upscaling.yaml
Only in sd1: data
diff -w -x '*.gif' -x '*.png' -x '*.jpg' -x LICENSE -x '*.md' -x '*.git' -r sd1/environment.yaml sd2/environment.yaml
9,11c9,11
< - pytorch=1.11.0
< - torchvision=0.12.0
< - numpy=1.19.2
---
> - pytorch=1.12.1
> - torchvision=0.13.1
> - numpy=1.23.1
13,17c13,14
< - albumentations==0.4.3
< - diffusers
< - opencv-python==4.1.2.30
< - pudb==2019.2
< - invisible-watermark
---
> - albumentations==1.3.0
> - opencv-python==4.6.0.66
23c20
< - streamlit>=0.73.1
---
> - streamlit==1.12.1
25d21
< - torch-fidelity==0.3.0
27c23
< - torchmetrics==0.6.0
---
> - webdataset==0.2.5
29,30c25,28
< - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
< - -e git+https://github.com/openai/CLIP.git@main#egg=clip
---
> - open_clip_torch==2.0.2
> - invisible-watermark>=0.1.5
> - streamlit-drawable-canvas==0.8.0
> - torchmetrics==0.6.0
Only in sd1/ldm/data: base.py
Only in sd1/ldm/data: imagenet.py
Only in sd1/ldm/data: lsun.py
Only in sd2/ldm/data: util.py
Only in sd1/ldm: lr_scheduler.py
diff -w -x '*.gif' -x '*.png' -x '*.jpg' -x LICENSE -x '*.md' -x '*.git' -r sd1/ldm/models/autoencoder.py sd2/ldm/models/autoencoder.py | |
6,7d5 | |
< from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer | |
< | |
11a10 | |
> from ldm.modules.ema import LitEma | |
14c13 | |
< class VQModel(pl.LightningModule): | |
--- | |
> class AutoencoderKL(pl.LightningModule): | |
18d16 | |
< n_embed, | |
25,30c23,24 | |
< batch_resize_range=None, | |
< scheduler_config=None, | |
< lr_g_factor=1.0, | |
< remap=None, | |
< sane_index_shape=False, # tell vector quantizer to return indices as bhw | |
< use_ema=False | |
--- | |
> ema_decay=None, | |
> learn_logvar=False | |
33,34c27 | |
< self.embed_dim = embed_dim | |
< self.n_embed = n_embed | |
--- | |
> self.learn_logvar = learn_logvar | |
39,42c32,33 | |
< self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25, | |
< remap=remap, | |
< sane_index_shape=sane_index_shape) | |
< self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1) | |
--- | |
> assert ddconfig["double_z"] | |
> self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) | |
43a35 | |
> self.embed_dim = embed_dim | |
49,51d40 | |
< self.batch_resize_range = batch_resize_range | |
< if self.batch_resize_range is not None: | |
< print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.") | |
53c42 | |
< self.use_ema = use_ema | |
--- | |
> self.use_ema = ema_decay is not None | |
55c44,46 | |
< self.model_ema = LitEma(self) | |
--- | |
> self.ema_decay = ema_decay | |
> assert 0. < ema_decay < 1. | |
> self.model_ema = LitEma(self, decay=ema_decay) | |
60,61c51,61 | |
< self.scheduler_config = scheduler_config | |
< self.lr_g_factor = lr_g_factor | |
--- | |
> | |
> def init_from_ckpt(self, path, ignore_keys=list()): | |
> sd = torch.load(path, map_location="cpu")["state_dict"] | |
> keys = list(sd.keys()) | |
> for k in keys: | |
> for ik in ignore_keys: | |
> if k.startswith(ik): | |
> print("Deleting key {} from state_dict.".format(k)) | |
> del sd[k] | |
> self.load_state_dict(sd, strict=False) | |
> print(f"Restored from {path}") | |
78,91d77 | |
< def init_from_ckpt(self, path, ignore_keys=list()): | |
< sd = torch.load(path, map_location="cpu")["state_dict"] | |
< keys = list(sd.keys()) | |
< for k in keys: | |
< for ik in ignore_keys: | |
< if k.startswith(ik): | |
< print("Deleting key {} from state_dict.".format(k)) | |
< del sd[k] | |
< missing, unexpected = self.load_state_dict(sd, strict=False) | |
< print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") | |
< if len(missing) > 0: | |
< print(f"Missing Keys: {missing}") | |
< print(f"Unexpected Keys: {unexpected}") | |
< | |
98,325d83 | |
< h = self.quant_conv(h) | |
< quant, emb_loss, info = self.quantize(h) | |
< return quant, emb_loss, info | |
< | |
< def encode_to_prequant(self, x): | |
< h = self.encoder(x) | |
< h = self.quant_conv(h) | |
< return h | |
< | |
< def decode(self, quant): | |
< quant = self.post_quant_conv(quant) | |
< dec = self.decoder(quant) | |
< return dec | |
< | |
< def decode_code(self, code_b): | |
< quant_b = self.quantize.embed_code(code_b) | |
< dec = self.decode(quant_b) | |
< return dec | |
< | |
< def forward(self, input, return_pred_indices=False): | |
< quant, diff, (_,_,ind) = self.encode(input) | |
< dec = self.decode(quant) | |
< if return_pred_indices: | |
< return dec, diff, ind | |
< return dec, diff | |
< | |
< def get_input(self, batch, k): | |
< x = batch[k] | |
< if len(x.shape) == 3: | |
< x = x[..., None] | |
< x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() | |
< if self.batch_resize_range is not None: | |
< lower_size = self.batch_resize_range[0] | |
< upper_size = self.batch_resize_range[1] | |
< if self.global_step <= 4: | |
< # do the first few batches with max size to avoid later oom | |
< new_resize = upper_size | |
< else: | |
< new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16)) | |
< if new_resize != x.shape[2]: | |
< x = F.interpolate(x, size=new_resize, mode="bicubic") | |
< x = x.detach() | |
< return x | |
< | |
< def training_step(self, batch, batch_idx, optimizer_idx): | |
< # https://github.com/pytorch/pytorch/issues/37142 | |
< # try not to fool the heuristics | |
< x = self.get_input(batch, self.image_key) | |
< xrec, qloss, ind = self(x, return_pred_indices=True) | |
< | |
< if optimizer_idx == 0: | |
< # autoencode | |
< aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, | |
< last_layer=self.get_last_layer(), split="train", | |
< predicted_indices=ind) | |
< | |
< self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) | |
< return aeloss | |
< | |
< if optimizer_idx == 1: | |
< # discriminator | |
< discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, | |
< last_layer=self.get_last_layer(), split="train") | |
< self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True) | |
< return discloss | |
< | |
< def validation_step(self, batch, batch_idx): | |
< log_dict = self._validation_step(batch, batch_idx) | |
< with self.ema_scope(): | |
< log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema") | |
< return log_dict | |
< | |
< def _validation_step(self, batch, batch_idx, suffix=""): | |
< x = self.get_input(batch, self.image_key) | |
< xrec, qloss, ind = self(x, return_pred_indices=True) | |
< aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, | |
< self.global_step, | |
< last_layer=self.get_last_layer(), | |
< split="val"+suffix, | |
< predicted_indices=ind | |
< ) | |
< | |
< discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, | |
< self.global_step, | |
< last_layer=self.get_last_layer(), | |
< split="val"+suffix, | |
< predicted_indices=ind | |
< ) | |
< rec_loss = log_dict_ae[f"val{suffix}/rec_loss"] | |
< self.log(f"val{suffix}/rec_loss", rec_loss, | |
< prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) | |
< self.log(f"val{suffix}/aeloss", aeloss, | |
< prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) | |
< if version.parse(pl.__version__) >= version.parse('1.4.0'): | |
< del log_dict_ae[f"val{suffix}/rec_loss"] | |
< self.log_dict(log_dict_ae) | |
< self.log_dict(log_dict_disc) | |
< return self.log_dict | |
< | |
< def configure_optimizers(self): | |
< lr_d = self.learning_rate | |
< lr_g = self.lr_g_factor*self.learning_rate | |
< print("lr_d", lr_d) | |
< print("lr_g", lr_g) | |
< opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ | |
< list(self.decoder.parameters())+ | |
< list(self.quantize.parameters())+ | |
< list(self.quant_conv.parameters())+ | |
< list(self.post_quant_conv.parameters()), | |
< lr=lr_g, betas=(0.5, 0.9)) | |
< opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), | |
< lr=lr_d, betas=(0.5, 0.9)) | |
< | |
< if self.scheduler_config is not None: | |
< scheduler = instantiate_from_config(self.scheduler_config) | |
< | |
< print("Setting up LambdaLR scheduler...") | |
< scheduler = [ | |
< { | |
< 'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule), | |
< 'interval': 'step', | |
< 'frequency': 1 | |
< }, | |
< { | |
< 'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule), | |
< 'interval': 'step', | |
< 'frequency': 1 | |
< }, | |
< ] | |
< return [opt_ae, opt_disc], scheduler | |
< return [opt_ae, opt_disc], [] | |
< | |
< def get_last_layer(self): | |
< return self.decoder.conv_out.weight | |
< | |
< def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs): | |
< log = dict() | |
< x = self.get_input(batch, self.image_key) | |
< x = x.to(self.device) | |
< if only_inputs: | |
< log["inputs"] = x | |
< return log | |
< xrec, _ = self(x) | |
< if x.shape[1] > 3: | |
< # colorize with random projection | |
< assert xrec.shape[1] > 3 | |
< x = self.to_rgb(x) | |
< xrec = self.to_rgb(xrec) | |
< log["inputs"] = x | |
< log["reconstructions"] = xrec | |
< if plot_ema: | |
< with self.ema_scope(): | |
< xrec_ema, _ = self(x) | |
< if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema) | |
< log["reconstructions_ema"] = xrec_ema | |
< return log | |
< | |
< def to_rgb(self, x): | |
< assert self.image_key == "segmentation" | |
< if not hasattr(self, "colorize"): | |
< self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) | |
< x = F.conv2d(x, weight=self.colorize) | |
< x = 2.*(x-x.min())/(x.max()-x.min()) - 1. | |
< return x | |
< | |
< | |
< class VQModelInterface(VQModel): | |
< def __init__(self, embed_dim, *args, **kwargs): | |
< super().__init__(embed_dim=embed_dim, *args, **kwargs) | |
< self.embed_dim = embed_dim | |
< | |
< def encode(self, x): | |
< h = self.encoder(x) | |
< h = self.quant_conv(h) | |
< return h | |
< | |
< def decode(self, h, force_not_quantize=False): | |
< # also go through quantization layer | |
< if not force_not_quantize: | |
< quant, emb_loss, info = self.quantize(h) | |
< else: | |
< quant = h | |
< quant = self.post_quant_conv(quant) | |
< dec = self.decoder(quant) | |
< return dec | |
< | |
< | |
< class AutoencoderKL(pl.LightningModule): | |
< def __init__(self, | |
< ddconfig, | |
< lossconfig, | |
< embed_dim, | |
< ckpt_path=None, | |
< ignore_keys=[], | |
< image_key="image", | |
< colorize_nlabels=None, | |
< monitor=None, | |
< ): | |
< super().__init__() | |
< self.image_key = image_key | |
< self.encoder = Encoder(**ddconfig) | |
< self.decoder = Decoder(**ddconfig) | |
< self.loss = instantiate_from_config(lossconfig) | |
< assert ddconfig["double_z"] | |
< self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) | |
< self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) | |
< self.embed_dim = embed_dim | |
< if colorize_nlabels is not None: | |
< assert type(colorize_nlabels)==int | |
< self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) | |
< if monitor is not None: | |
< self.monitor = monitor | |
< if ckpt_path is not None: | |
< self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) | |
< | |
< def init_from_ckpt(self, path, ignore_keys=list()): | |
< sd = torch.load(path, map_location="cpu")["state_dict"] | |
< keys = list(sd.keys()) | |
< for k in keys: | |
< for ik in ignore_keys: | |
< if k.startswith(ik): | |
< print("Deleting key {} from state_dict.".format(k)) | |
< del sd[k] | |
< self.load_state_dict(sd, strict=False) | |
< print(f"Restored from {path}") | |
< | |
< def encode(self, x): | |
< h = self.encoder(x) | |
372a131,136 | |
> log_dict = self._validation_step(batch, batch_idx) | |
> with self.ema_scope(): | |
> log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema") | |
> return log_dict | |
> | |
> def _validation_step(self, batch, batch_idx, postfix=""): | |
376c140 | |
< last_layer=self.get_last_layer(), split="val") | |
--- | |
> last_layer=self.get_last_layer(), split="val"+postfix) | |
379c143 | |
< last_layer=self.get_last_layer(), split="val") | |
--- | |
> last_layer=self.get_last_layer(), split="val"+postfix) | |
381c145 | |
< self.log("val/rec_loss", log_dict_ae["val/rec_loss"]) | |
--- | |
> self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"]) | |
388,391c152,157 | |
< opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ | |
< list(self.decoder.parameters())+ | |
< list(self.quant_conv.parameters())+ | |
< list(self.post_quant_conv.parameters()), | |
--- | |
> ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list( | |
> self.quant_conv.parameters()) + list(self.post_quant_conv.parameters()) | |
> if self.learn_logvar: | |
> print(f"{self.__class__.__name__}: Learning logvar") | |
> ae_params_list.append(self.loss.logvar) | |
> opt_ae = torch.optim.Adam(ae_params_list, | |
401c167 | |
< def log_images(self, batch, only_inputs=False, **kwargs): | |
--- | |
> def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs): | |
413a180,188 | |
> if log_ema or self.use_ema: | |
> with self.ema_scope(): | |
> xrec_ema, posterior_ema = self(x) | |
> if x.shape[1] > 3: | |
> # colorize with random projection | |
> assert xrec_ema.shape[1] > 3 | |
> xrec_ema = self.to_rgb(xrec_ema) | |
> log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample())) | |
> log["reconstructions_ema"] = xrec_ema | |
428c203 | |
< self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff | |
--- | |
> self.vq_interface = vq_interface | |
443a219 | |
> | |
Only in sd1/ldm/models/diffusion: classifier.py | |
diff -w -x '*.gif' -x '*.png' -x '*.jpg' -x LICENSE -x '*.md' -x '*.git' -r sd1/ldm/models/diffusion/ddim.py sd2/ldm/models/diffusion/ddim.py | |
6d5 | |
< from functools import partial | |
8,9c7 | |
< from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, \ | |
< extract_into_tensor | |
--- | |
> from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor | |
77,78c75,77 | |
< unconditional_conditioning=None, | |
< # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... | |
--- | |
> unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... | |
> dynamic_threshold=None, | |
> ucg_schedule=None, | |
83c82,84 | |
< cbs = conditioning[list(conditioning.keys())[0]].shape[0] | |
--- | |
> ctmp = conditioning[list(conditioning.keys())[0]] | |
> while isinstance(ctmp, list): ctmp = ctmp[0] | |
> cbs = ctmp.shape[0] | |
85a87,92 | |
> | |
> elif isinstance(conditioning, list): | |
> for ctmp in conditioning: | |
> if ctmp.shape[0] != batch_size: | |
> print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") | |
> | |
109a117,118 | |
> dynamic_threshold=dynamic_threshold, | |
> ucg_schedule=ucg_schedule | |
119c128,129 | |
< unconditional_guidance_scale=1., unconditional_conditioning=None,): | |
--- | |
> unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None, | |
> ucg_schedule=None): | |
148a159,162 | |
> if ucg_schedule is not None: | |
> assert len(ucg_schedule) == len(time_range) | |
> unconditional_guidance_scale = ucg_schedule[i] | |
> | |
154c168,169 | |
< unconditional_conditioning=unconditional_conditioning) | |
--- | |
> unconditional_conditioning=unconditional_conditioning, | |
> dynamic_threshold=dynamic_threshold) | |
168c183,184 | |
< unconditional_guidance_scale=1., unconditional_conditioning=None): | |
--- | |
> unconditional_guidance_scale=1., unconditional_conditioning=None, | |
> dynamic_threshold=None): | |
172c188 | |
< e_t = self.model.apply_model(x, t, c) | |
--- | |
> model_output = self.model.apply_model(x, t, c) | |
175a192,209 | |
> if isinstance(c, dict): | |
> assert isinstance(unconditional_conditioning, dict) | |
> c_in = dict() | |
> for k in c: | |
> if isinstance(c[k], list): | |
> c_in[k] = [torch.cat([ | |
> unconditional_conditioning[k][i], | |
> c[k][i]]) for i in range(len(c[k]))] | |
> else: | |
> c_in[k] = torch.cat([ | |
> unconditional_conditioning[k], | |
> c[k]]) | |
> elif isinstance(c, list): | |
> c_in = list() | |
> assert isinstance(unconditional_conditioning, list) | |
> for i in range(len(c)): | |
> c_in.append(torch.cat([unconditional_conditioning[i], c[i]])) | |
> else: | |
177,178c211,217 | |
< e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) | |
< e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) | |
--- | |
> model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) | |
> model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond) | |
> | |
> if self.model.parameterization == "v": | |
> e_t = self.model.predict_eps_from_z_and_v(x, t, model_output) | |
> else: | |
> e_t = model_output | |
181c220 | |
< assert self.model.parameterization == "eps" | |
--- | |
> assert self.model.parameterization == "eps", 'not implemented' | |
194a234 | |
> if self.model.parameterization != "v": | |
195a236,238 | |
> else: | |
> pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output) | |
> | |
197a241,244 | |
> | |
> if dynamic_threshold is not None: | |
> raise NotImplementedError() | |
> | |
206a254,300 | |
> def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None, | |
> unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None): | |
> num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0] | |
> | |
> assert t_enc <= num_reference_steps | |
> num_steps = t_enc | |
> | |
> if use_original_steps: | |
> alphas_next = self.alphas_cumprod[:num_steps] | |
> alphas = self.alphas_cumprod_prev[:num_steps] | |
> else: | |
> alphas_next = self.ddim_alphas[:num_steps] | |
> alphas = torch.tensor(self.ddim_alphas_prev[:num_steps]) | |
> | |
> x_next = x0 | |
> intermediates = [] | |
> inter_steps = [] | |
> for i in tqdm(range(num_steps), desc='Encoding Image'): | |
> t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long) | |
> if unconditional_guidance_scale == 1.: | |
> noise_pred = self.model.apply_model(x_next, t, c) | |
> else: | |
> assert unconditional_conditioning is not None | |
> e_t_uncond, noise_pred = torch.chunk( | |
> self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)), | |
> torch.cat((unconditional_conditioning, c))), 2) | |
> noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond) | |
> | |
> xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next | |
> weighted_noise_pred = alphas_next[i].sqrt() * ( | |
> (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred | |
> x_next = xt_weighted + weighted_noise_pred | |
> if return_intermediates and i % ( | |
> num_steps // return_intermediates) == 0 and i < num_steps - 1: | |
> intermediates.append(x_next) | |
> inter_steps.append(i) | |
> elif return_intermediates and i >= num_steps - 2: | |
> intermediates.append(x_next) | |
> inter_steps.append(i) | |
> if callback: callback(i) | |
> | |
> out = {'x_encoded': x_next, 'intermediate_steps': inter_steps} | |
> if return_intermediates: | |
> out.update({'intermediates': intermediates}) | |
> return x_next, out | |
> | |
> @torch.no_grad() | |
224c318 | |
< use_original_steps=False): | |
--- | |
> use_original_steps=False, callback=None): | |
240a335 | |
> if callback: callback(i) | |
diff -w -x '*.gif' -x '*.png' -x '*.jpg' -x LICENSE -x '*.md' -x '*.git' -r sd1/ldm/models/diffusion/ddpm.py sd2/ldm/models/diffusion/ddpm.py | |
15c15 | |
< from contextlib import contextmanager | |
--- | |
> from contextlib import contextmanager, nullcontext | |
16a17 | |
> import itertools | |
19a21 | |
> from omegaconf import ListConfig | |
24c26 | |
< from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL | |
--- | |
> from ldm.models.autoencoder import IdentityFirstStage, AutoencoderKL | |
73a76,79 | |
> make_it_fit=False, | |
> ucg_training=None, | |
> reset_ema=False, | |
> reset_num_ema_updates=False, | |
76c82 | |
< assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"' | |
--- | |
> assert parameterization in ["eps", "x0", "v"], 'currently only supporting "eps" and "x0" and "v"' | |
102a109,110 | |
> self.make_it_fit = make_it_fit | |
> if reset_ema: assert exists(ckpt_path) | |
104a113,120 | |
> if reset_ema: | |
> assert self.use_ema | |
> print(f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.") | |
> self.model_ema = LitEma(self.model) | |
> if reset_num_ema_updates: | |
> print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ") | |
> assert self.use_ema | |
> self.model_ema.reset_num_updates() | |
115a132,134 | |
> self.ucg_training = ucg_training or dict() | |
> if self.ucg_training: | |
> self.ucg_prng = np.random.RandomState() | |
163a183,185 | |
> elif self.parameterization == "v": | |
> lvlb_weights = torch.ones_like(self.betas ** 2 / ( | |
> 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))) | |
166d187 | |
< # TODO how to choose this term | |
185a207 | |
> @torch.no_grad() | |
195a218,261 | |
> if self.make_it_fit: | |
> n_params = len([name for name, _ in | |
> itertools.chain(self.named_parameters(), | |
> self.named_buffers())]) | |
> for name, param in tqdm( | |
> itertools.chain(self.named_parameters(), | |
> self.named_buffers()), | |
> desc="Fitting old weights to new weights", | |
> total=n_params | |
> ): | |
> if not name in sd: | |
> continue | |
> old_shape = sd[name].shape | |
> new_shape = param.shape | |
> assert len(old_shape) == len(new_shape) | |
> if len(new_shape) > 2: | |
> # we only modify first two axes | |
> assert new_shape[2:] == old_shape[2:] | |
> # assumes first axis corresponds to output dim | |
> if not new_shape == old_shape: | |
> new_param = param.clone() | |
> old_param = sd[name] | |
> if len(new_shape) == 1: | |
> for i in range(new_param.shape[0]): | |
> new_param[i] = old_param[i % old_shape[0]] | |
> elif len(new_shape) >= 2: | |
> for i in range(new_param.shape[0]): | |
> for j in range(new_param.shape[1]): | |
> new_param[i, j] = old_param[i % old_shape[0], j % old_shape[1]] | |
> | |
> n_used_old = torch.ones(old_shape[1]) | |
> for j in range(new_param.shape[1]): | |
> n_used_old[j % old_shape[1]] += 1 | |
> n_used_new = torch.zeros(new_shape[1]) | |
> for j in range(new_param.shape[1]): | |
> n_used_new[j] = n_used_old[j % old_shape[1]] | |
> | |
> n_used_new = n_used_new[None, :] | |
> while len(n_used_new.shape) < len(new_shape): | |
> n_used_new = n_used_new.unsqueeze(-1) | |
> new_param /= n_used_new | |
> | |
> sd[name] = new_param | |
> | |
200c266 | |
< print(f"Missing Keys: {missing}") | |
--- | |
> print(f"Missing Keys:\n {missing}") | |
202c268 | |
< print(f"Unexpected Keys: {unexpected}") | |
--- | |
> print(f"\nUnexpected Keys:\n {unexpected}") | |
221a288,301 | |
> def predict_start_from_z_and_v(self, x_t, t, v): | |
> # self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) | |
> # self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) | |
> return ( | |
> extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t - | |
> extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v | |
> ) | |
> | |
> def predict_eps_from_z_and_v(self, x_t, t, v): | |
> return ( | |
> extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v + | |
> extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * x_t | |
> ) | |
> | |
278a359,364 | |
> def get_v(self, x, noise, t): | |
> return ( | |
> extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise - | |
> extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x | |
> ) | |
> | |
303a390,391 | |
> elif self.parameterization == "v": | |
> target = self.get_v(x_start, noise, t) | |
342a431,439 | |
> for k in self.ucg_training: | |
> p = self.ucg_training[k]["p"] | |
> val = self.ucg_training[k]["val"] | |
> if val is None: | |
> val = "" | |
> for i in range(len(batch[k])): | |
> if self.ucg_prng.choice(2, p=[1 - p, p]): | |
> batch[k][i] = val | |
> | |
425a523 | |
> | |
436a535 | |
> force_null_conditioning=False, | |
437a537 | |
> self.force_null_conditioning = force_null_conditioning | |
444c544 | |
< if cond_stage_config == '__is_unconditional__': | |
--- | |
> if cond_stage_config == '__is_unconditional__' and not self.force_null_conditioning: | |
446a547,548 | |
> reset_ema = kwargs.pop("reset_ema", False) | |
> reset_num_ema_updates = kwargs.pop("reset_num_ema_updates", False) | |
469a572,580 | |
> if reset_ema: | |
> assert self.use_ema | |
> print( | |
> f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.") | |
> self.model_ema = LitEma(self.model) | |
> if reset_num_ema_updates: | |
> print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ") | |
> assert self.use_ema | |
> self.model_ema.reset_num_updates() | |
655c766 | |
< cond_key=None, return_original_cond=False, bs=None): | |
--- | |
> cond_key=None, return_original_cond=False, bs=None, return_x=False): | |
663c774 | |
< if self.model.conditioning_key is not None: | |
--- | |
> if self.model.conditioning_key is not None and not self.force_null_conditioning: | |
667c778 | |
< if cond_key in ['caption', 'coordinates_bbox']: | |
--- | |
> if cond_key in ['caption', 'coordinates_bbox', "txt"]: | |
669c780 | |
< elif cond_key == 'class_label': | |
--- | |
> elif cond_key in ['class_label', 'cls']: | |
677d787 | |
< # import pudb; pudb.set_trace() | |
700a811,812 | |
> if return_x: | |
> out.extend([x]) | |
714,822d825 | |
< | |
< if hasattr(self, "split_input_params"): | |
< if self.split_input_params["patch_distributed_vq"]: | |
< ks = self.split_input_params["ks"] # eg. (128, 128) | |
< stride = self.split_input_params["stride"] # eg. (64, 64) | |
< uf = self.split_input_params["vqf"] | |
< bs, nc, h, w = z.shape | |
< if ks[0] > h or ks[1] > w: | |
< ks = (min(ks[0], h), min(ks[1], w)) | |
< print("reducing Kernel") | |
< | |
< if stride[0] > h or stride[1] > w: | |
< stride = (min(stride[0], h), min(stride[1], w)) | |
< print("reducing stride") | |
< | |
< fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf) | |
< | |
< z = unfold(z) # (bn, nc * prod(**ks), L) | |
< # 1. Reshape to img shape | |
< z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) | |
< | |
< # 2. apply model loop over last dim | |
< if isinstance(self.first_stage_model, VQModelInterface): | |
< output_list = [self.first_stage_model.decode(z[:, :, :, :, i], | |
< force_not_quantize=predict_cids or force_not_quantize) | |
< for i in range(z.shape[-1])] | |
< else: | |
< | |
< output_list = [self.first_stage_model.decode(z[:, :, :, :, i]) | |
< for i in range(z.shape[-1])] | |
< | |
< o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L) | |
< o = o * weighting | |
< # Reverse 1. reshape to img shape | |
< o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) | |
< # stitch crops together | |
< decoded = fold(o) | |
< decoded = decoded / normalization # norm is shape (1, 1, h, w) | |
< return decoded | |
< else: | |
< if isinstance(self.first_stage_model, VQModelInterface): | |
< return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) | |
< else: | |
< return self.first_stage_model.decode(z) | |
< | |
< else: | |
< if isinstance(self.first_stage_model, VQModelInterface): | |
< return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) | |
< else: | |
< return self.first_stage_model.decode(z) | |
< | |
< # same as above but without decorator | |
< def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): | |
< if predict_cids: | |
< if z.dim() == 4: | |
< z = torch.argmax(z.exp(), dim=1).long() | |
< z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) | |
< z = rearrange(z, 'b h w c -> b c h w').contiguous() | |
< | |
< z = 1. / self.scale_factor * z | |
< | |
< if hasattr(self, "split_input_params"): | |
< if self.split_input_params["patch_distributed_vq"]: | |
< ks = self.split_input_params["ks"] # eg. (128, 128) | |
< stride = self.split_input_params["stride"] # eg. (64, 64) | |
< uf = self.split_input_params["vqf"] | |
< bs, nc, h, w = z.shape | |
< if ks[0] > h or ks[1] > w: | |
< ks = (min(ks[0], h), min(ks[1], w)) | |
< print("reducing Kernel") | |
< | |
< if stride[0] > h or stride[1] > w: | |
< stride = (min(stride[0], h), min(stride[1], w)) | |
< print("reducing stride") | |
< | |
< fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf) | |
< | |
< z = unfold(z) # (bn, nc * prod(**ks), L) | |
< # 1. Reshape to img shape | |
< z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) | |
< | |
< # 2. apply model loop over last dim | |
< if isinstance(self.first_stage_model, VQModelInterface): | |
< output_list = [self.first_stage_model.decode(z[:, :, :, :, i], | |
< force_not_quantize=predict_cids or force_not_quantize) | |
< for i in range(z.shape[-1])] | |
< else: | |
< | |
< output_list = [self.first_stage_model.decode(z[:, :, :, :, i]) | |
< for i in range(z.shape[-1])] | |
< | |
< o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L) | |
< o = o * weighting | |
< # Reverse 1. reshape to img shape | |
< o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) | |
< # stitch crops together | |
< decoded = fold(o) | |
< decoded = decoded / normalization # norm is shape (1, 1, h, w) | |
< return decoded | |
< else: | |
< if isinstance(self.first_stage_model, VQModelInterface): | |
< return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) | |
< else: | |
< return self.first_stage_model.decode(z) | |
< | |
< else: | |
< if isinstance(self.first_stage_model, VQModelInterface): | |
< return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) | |
< else: | |
827,862d829 | |
< if hasattr(self, "split_input_params"): | |
< if self.split_input_params["patch_distributed_vq"]: | |
< ks = self.split_input_params["ks"] # eg. (128, 128) | |
< stride = self.split_input_params["stride"] # eg. (64, 64) | |
< df = self.split_input_params["vqf"] | |
< self.split_input_params['original_image_size'] = x.shape[-2:] | |
< bs, nc, h, w = x.shape | |
< if ks[0] > h or ks[1] > w: | |
< ks = (min(ks[0], h), min(ks[1], w)) | |
< print("reducing Kernel") | |
< | |
< if stride[0] > h or stride[1] > w: | |
< stride = (min(stride[0], h), min(stride[1], w)) | |
< print("reducing stride") | |
< | |
< fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df) | |
< z = unfold(x) # (bn, nc * prod(**ks), L) | |
< # Reshape to img shape | |
< z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) | |
< | |
< output_list = [self.first_stage_model.encode(z[:, :, :, :, i]) | |
< for i in range(z.shape[-1])] | |
< | |
< o = torch.stack(output_list, axis=-1) | |
< o = o * weighting | |
< | |
< # Reverse reshape to img shape | |
< o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) | |
< # stitch crops together | |
< decoded = fold(o) | |
< decoded = decoded / normalization | |
< return decoded | |
< | |
< else: | |
< return self.first_stage_model.encode(x) | |
< else: | |
881,890d847 | |
< def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset | |
< def rescale_bbox(bbox): | |
< x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2]) | |
< y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3]) | |
< w = min(bbox[2] / crop_coordinates[2], 1 - x0) | |
< h = min(bbox[3] / crop_coordinates[3], 1 - y0) | |
< return x0, y0, w, h | |
< | |
< return [rescale_bbox(b) for b in bboxes] | |
< | |
892d848 | |
< | |
894c850 | |
< # hybrid case, cond is exptected to be a dict | |
--- | |
> # hybrid case, cond is expected to be a dict | |
902,986d857 | |
< if hasattr(self, "split_input_params"): | |
< assert len(cond) == 1 # todo can only deal with one conditioning atm | |
< assert not return_ids | |
< ks = self.split_input_params["ks"] # eg. (128, 128) | |
< stride = self.split_input_params["stride"] # eg. (64, 64) | |
< | |
< h, w = x_noisy.shape[-2:] | |
< | |
< fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride) | |
< | |
< z = unfold(x_noisy) # (bn, nc * prod(**ks), L) | |
< # Reshape to img shape | |
< z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) | |
< z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])] | |
< | |
< if self.cond_stage_key in ["image", "LR_image", "segmentation", | |
< 'bbox_img'] and self.model.conditioning_key: # todo check for completeness | |
< c_key = next(iter(cond.keys())) # get key | |
< c = next(iter(cond.values())) # get value | |
< assert (len(c) == 1) # todo extend to list with more than one elem | |
< c = c[0] # get element | |
< | |
< c = unfold(c) | |
< c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1])) # (bn, nc, ks[0], ks[1], L ) | |
< | |
< cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])] | |
< | |
< elif self.cond_stage_key == 'coordinates_bbox': | |
< assert 'original_image_size' in self.split_input_params, 'BoudingBoxRescaling is missing original_image_size' | |
< | |
< # assuming padding of unfold is always 0 and its dilation is always 1 | |
< n_patches_per_row = int((w - ks[0]) / stride[0] + 1) | |
< full_img_h, full_img_w = self.split_input_params['original_image_size'] | |
< # as we are operating on latents, we need the factor from the original image size to the | |
< # spatial latent size to properly rescale the crops for regenerating the bbox annotations | |
< num_downs = self.first_stage_model.encoder.num_resolutions - 1 | |
< rescale_latent = 2 ** (num_downs) | |
< | |
< # get top left postions of patches as conforming for the bbbox tokenizer, therefore we | |
< # need to rescale the tl patch coordinates to be in between (0,1) | |
< tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w, | |
< rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h) | |
< for patch_nr in range(z.shape[-1])] | |
< | |
< # patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w) | |
< patch_limits = [(x_tl, y_tl, | |
< rescale_latent * ks[0] / full_img_w, | |
< rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates] | |
< # patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates] | |
< | |
< # tokenize crop coordinates for the bounding boxes of the respective patches | |
< patch_limits_tknzd = [torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device) | |
< for bbox in patch_limits] # list of length l with tensors of shape (1, 2) | |
< print(patch_limits_tknzd[0].shape) | |
< # cut tknzd crop position from conditioning | |
< assert isinstance(cond, dict), 'cond must be dict to be fed into model' | |
< cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device) | |
< print(cut_cond.shape) | |
< | |
< adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd]) | |
< adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n') | |
< print(adapted_cond.shape) | |
< adapted_cond = self.get_learned_conditioning(adapted_cond) | |
< print(adapted_cond.shape) | |
< adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1]) | |
< print(adapted_cond.shape) | |
< | |
< cond_list = [{'c_crossattn': [e]} for e in adapted_cond] | |
< | |
< else: | |
< cond_list = [cond for i in range(z.shape[-1])] # Todo make this more efficient | |
< | |
< # apply model by loop over crops | |
< output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])] | |
< assert not isinstance(output_list[0], | |
< tuple) # todo cant deal with multiple model outputs check this never happens | |
< | |
< o = torch.stack(output_list, axis=-1) | |
< o = o * weighting | |
< # Reverse reshape to img shape | |
< o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) | |
< # stitch crops together | |
< x_recon = fold(o) / normalization | |
< | |
< else: | |
1023a895,896 | |
> elif self.parameterization == "v": | |
> target = self.get_v(x_start, noise, t) | |
1236d1108 | |
< | |
1248a1121,1144 | |
> @torch.no_grad() | |
> def get_unconditional_conditioning(self, batch_size, null_label=None): | |
> if null_label is not None: | |
> xc = null_label | |
> if isinstance(xc, ListConfig): | |
> xc = list(xc) | |
> if isinstance(xc, dict) or isinstance(xc, list): | |
> c = self.get_learned_conditioning(xc) | |
> else: | |
> if hasattr(xc, "to"): | |
> xc = xc.to(self.device) | |
> c = self.get_learned_conditioning(xc) | |
> else: | |
> if self.cond_stage_key in ["class_label", "cls"]: | |
> xc = self.cond_stage_model.get_unconditional_conditioning(batch_size, device=self.device) | |
> return self.get_learned_conditioning(xc) | |
> else: | |
> raise NotImplementedError("todo") | |
> if isinstance(c, list): # in case the encoder gives us a list | |
> for i in range(len(c)): | |
> c[i] = repeat(c[i], '1 ... -> b ...', b=batch_size).to(self.device) | |
> else: | |
> c = repeat(c, '1 ... -> b ...', b=batch_size).to(self.device) | |
> return c | |
1251c1147 | |
< def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None, | |
--- | |
> def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=50, ddim_eta=0., return_keys=None, | |
1253,1254c1149,1152 | |
< plot_diffusion_rows=True, **kwargs): | |
< | |
--- | |
> plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None, | |
> use_ema_scope=True, | |
> **kwargs): | |
> ema_scope = self.ema_scope if use_ema_scope else nullcontext | |
1271,1272c1169,1170 | |
< elif self.cond_stage_key in ["caption"]: | |
< xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["caption"]) | |
--- | |
> elif self.cond_stage_key in ["caption", "txt"]: | |
> xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25) | |
1274,1275c1172,1174 | |
< elif self.cond_stage_key == 'class_label': | |
< xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) | |
--- | |
> elif self.cond_stage_key in ['class_label', "cls"]: | |
> try: | |
> xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25) | |
1276a1176,1178 | |
> except KeyError: | |
> # probably no "human_label" in batch | |
> pass | |
1302c1204 | |
< with self.ema_scope("Plotting"): | |
--- | |
> with ema_scope("Sampling"): | |
1315c1217 | |
< with self.ema_scope("Plotting Quantized Denoised"): | |
--- | |
> with ema_scope("Plotting Quantized Denoised"): | |
1323a1226,1238 | |
> if unconditional_guidance_scale > 1.0: | |
> uc = self.get_unconditional_conditioning(N, unconditional_guidance_label) | |
> if self.model.conditioning_key == "crossattn-adm": | |
> uc = {"c_crossattn": [uc], "c_adm": c["c_adm"]} | |
> with ema_scope("Sampling with classifier-free guidance"): | |
> samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, | |
> ddim_steps=ddim_steps, eta=ddim_eta, | |
> unconditional_guidance_scale=unconditional_guidance_scale, | |
> unconditional_conditioning=uc, | |
> ) | |
> x_samples_cfg = self.decode_first_stage(samples_cfg) | |
> log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg | |
> | |
1331,1332c1246 | |
< with self.ema_scope("Plotting Inpaint"): | |
< | |
--- | |
> with ema_scope("Plotting Inpaint"): | |
1340c1254,1255 | |
< with self.ema_scope("Plotting Outpaint"): | |
--- | |
> mask = 1. - mask | |
> with ema_scope("Plotting Outpaint"): | |
1347c1262 | |
< with self.ema_scope("Plotting Progressives"): | |
--- | |
> with ema_scope("Plotting Progressives"): | |
1397a1313 | |
> self.sequential_cross_attn = diff_model_config.pop("sequential_crossattn", False) | |
1400c1316 | |
< assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm'] | |
--- | |
> assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm'] | |
1402c1318 | |
< def forward(self, x, t, c_concat: list = None, c_crossattn: list = None): | |
--- | |
> def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None): | |
1408a1325 | |
> if not self.sequential_cross_attn: | |
1409a1327,1328 | |
> else: | |
> cc = c_crossattn | |
1414a1334,1342 | |
> elif self.conditioning_key == 'hybrid-adm': | |
> assert c_adm is not None | |
> xc = torch.cat([x] + c_concat, dim=1) | |
> cc = torch.cat(c_crossattn, 1) | |
> out = self.diffusion_model(xc, t, context=cc, y=c_adm) | |
> elif self.conditioning_key == 'crossattn-adm': | |
> assert c_adm is not None | |
> cc = torch.cat(c_crossattn, 1) | |
> out = self.diffusion_model(x, t, context=cc, y=c_adm) | |
1424,1445c1352,1795 | |
< class Layout2ImgDiffusion(LatentDiffusion): | |
< # TODO: move all layout-specific hacks to this class | |
< def __init__(self, cond_stage_key, *args, **kwargs): | |
< assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"' | |
< super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs) | |
< | |
< def log_images(self, batch, N=8, *args, **kwargs): | |
< logs = super().log_images(batch=batch, N=N, *args, **kwargs) | |
< | |
< key = 'train' if self.training else 'validation' | |
< dset = self.trainer.datamodule.datasets[key] | |
< mapper = dset.conditional_builders[self.cond_stage_key] | |
< | |
< bbox_imgs = [] | |
< map_fn = lambda catno: dset.get_textual_label(dset.get_category_id(catno)) | |
< for tknzd_bbox in batch[self.cond_stage_key][:N]: | |
< bboximg = mapper.plot(tknzd_bbox.detach().cpu(), map_fn, (256, 256)) | |
< bbox_imgs.append(bboximg) | |
< | |
< cond_img = torch.stack(bbox_imgs, dim=0) | |
< logs['bbox_image'] = cond_img | |
< return logs | |
--- | |
> class LatentUpscaleDiffusion(LatentDiffusion): | |
> def __init__(self, *args, low_scale_config, low_scale_key="LR", noise_level_key=None, **kwargs): | |
> super().__init__(*args, **kwargs) | |
> # assumes that neither the cond_stage nor the low_scale_model contain trainable params | |
> assert not self.cond_stage_trainable | |
> self.instantiate_low_stage(low_scale_config) | |
> self.low_scale_key = low_scale_key | |
> self.noise_level_key = noise_level_key | |
> | |
> def instantiate_low_stage(self, config): | |
> model = instantiate_from_config(config) | |
> self.low_scale_model = model.eval() | |
> self.low_scale_model.train = disabled_train | |
> for param in self.low_scale_model.parameters(): | |
> param.requires_grad = False | |
> | |
> @torch.no_grad() | |
> def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False): | |
> if not log_mode: | |
> z, c = super().get_input(batch, k, force_c_encode=True, bs=bs) | |
> else: | |
> z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True, | |
> force_c_encode=True, return_original_cond=True, bs=bs) | |
> x_low = batch[self.low_scale_key][:bs] | |
> x_low = rearrange(x_low, 'b h w c -> b c h w') | |
> x_low = x_low.to(memory_format=torch.contiguous_format).float() | |
> zx, noise_level = self.low_scale_model(x_low) | |
> if self.noise_level_key is not None: | |
> # get noise level from batch instead, e.g. when extracting a custom noise level for bsr | |
> raise NotImplementedError('TODO') | |
> | |
> all_conds = {"c_concat": [zx], "c_crossattn": [c], "c_adm": noise_level} | |
> if log_mode: | |
> # TODO: maybe disable if too expensive | |
> x_low_rec = self.low_scale_model.decode(zx) | |
> return z, all_conds, x, xrec, xc, x_low, x_low_rec, noise_level | |
> return z, all_conds | |
> | |
> @torch.no_grad() | |
> def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None, | |
> plot_denoise_rows=False, plot_progressive_rows=True, plot_diffusion_rows=True, | |
> unconditional_guidance_scale=1., unconditional_guidance_label=None, use_ema_scope=True, | |
> **kwargs): | |
> ema_scope = self.ema_scope if use_ema_scope else nullcontext | |
> use_ddim = ddim_steps is not None | |
> | |
> log = dict() | |
> z, c, x, xrec, xc, x_low, x_low_rec, noise_level = self.get_input(batch, self.first_stage_key, bs=N, | |
> log_mode=True) | |
> N = min(x.shape[0], N) | |
> n_row = min(x.shape[0], n_row) | |
> log["inputs"] = x | |
> log["reconstruction"] = xrec | |
> log["x_lr"] = x_low | |
> log[f"x_lr_rec_@noise_levels{'-'.join(map(lambda x: str(x), list(noise_level.cpu().numpy())))}"] = x_low_rec | |
> if self.model.conditioning_key is not None: | |
> if hasattr(self.cond_stage_model, "decode"): | |
> xc = self.cond_stage_model.decode(c) | |
> log["conditioning"] = xc | |
> elif self.cond_stage_key in ["caption", "txt"]: | |
> xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25) | |
> log["conditioning"] = xc | |
> elif self.cond_stage_key in ['class_label', 'cls']: | |
> xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25) | |
> log['conditioning'] = xc | |
> elif isimage(xc): | |
> log["conditioning"] = xc | |
> if ismap(xc): | |
> log["original_conditioning"] = self.to_rgb(xc) | |
> | |
> if plot_diffusion_rows: | |
> # get diffusion row | |
> diffusion_row = list() | |
> z_start = z[:n_row] | |
> for t in range(self.num_timesteps): | |
> if t % self.log_every_t == 0 or t == self.num_timesteps - 1: | |
> t = repeat(torch.tensor([t]), '1 -> b', b=n_row) | |
> t = t.to(self.device).long() | |
> noise = torch.randn_like(z_start) | |
> z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) | |
> diffusion_row.append(self.decode_first_stage(z_noisy)) | |
> | |
> diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W | |
> diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') | |
> diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') | |
> diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) | |
> log["diffusion_row"] = diffusion_grid | |
> | |
> if sample: | |
> # get denoise row | |
> with ema_scope("Sampling"): | |
> samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, | |
> ddim_steps=ddim_steps, eta=ddim_eta) | |
> # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) | |
> x_samples = self.decode_first_stage(samples) | |
> log["samples"] = x_samples | |
> if plot_denoise_rows: | |
> denoise_grid = self._get_denoise_row_from_list(z_denoise_row) | |
> log["denoise_row"] = denoise_grid | |
> | |
> if unconditional_guidance_scale > 1.0: | |
> uc_tmp = self.get_unconditional_conditioning(N, unconditional_guidance_label) | |
> # TODO explore better "unconditional" choices for the other keys | |
> # maybe guide away from empty text label and highest noise level and maximally degraded zx? | |
> uc = dict() | |
> for k in c: | |
> if k == "c_crossattn": | |
> assert isinstance(c[k], list) and len(c[k]) == 1 | |
> uc[k] = [uc_tmp] | |
> elif k == "c_adm": # todo: only run with text-based guidance? | |
> assert isinstance(c[k], torch.Tensor) | |
> #uc[k] = torch.ones_like(c[k]) * self.low_scale_model.max_noise_level | |
> uc[k] = c[k] | |
> elif isinstance(c[k], list): | |
> uc[k] = [c[k][i] for i in range(len(c[k]))] | |
> else: | |
> uc[k] = c[k] | |
> | |
> with ema_scope("Sampling with classifier-free guidance"): | |
> samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, | |
> ddim_steps=ddim_steps, eta=ddim_eta, | |
> unconditional_guidance_scale=unconditional_guidance_scale, | |
> unconditional_conditioning=uc, | |
> ) | |
> x_samples_cfg = self.decode_first_stage(samples_cfg) | |
> log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg | |
> | |
> if plot_progressive_rows: | |
> with ema_scope("Plotting Progressives"): | |
> img, progressives = self.progressive_denoising(c, | |
> shape=(self.channels, self.image_size, self.image_size), | |
> batch_size=N) | |
> prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation") | |
> log["progressive_row"] = prog_row | |
> | |
> return log | |
> | |
> | |
> class LatentFinetuneDiffusion(LatentDiffusion): | |
> """ | |
> Basis for different finetunas, such as inpainting or depth2image | |
> To disable finetuning mode, set finetune_keys to None | |
> """ | |
> | |
> def __init__(self, | |
> concat_keys: tuple, | |
> finetune_keys=("model.diffusion_model.input_blocks.0.0.weight", | |
> "model_ema.diffusion_modelinput_blocks00weight" | |
> ), | |
> keep_finetune_dims=4, | |
> # if model was trained without concat mode before and we would like to keep these channels | |
> c_concat_log_start=None, # to log reconstruction of c_concat codes | |
> c_concat_log_end=None, | |
> *args, **kwargs | |
> ): | |
> ckpt_path = kwargs.pop("ckpt_path", None) | |
> ignore_keys = kwargs.pop("ignore_keys", list()) | |
> super().__init__(*args, **kwargs) | |
> self.finetune_keys = finetune_keys | |
> self.concat_keys = concat_keys | |
> self.keep_dims = keep_finetune_dims | |
> self.c_concat_log_start = c_concat_log_start | |
> self.c_concat_log_end = c_concat_log_end | |
> if exists(self.finetune_keys): assert exists(ckpt_path), 'can only finetune from a given checkpoint' | |
> if exists(ckpt_path): | |
> self.init_from_ckpt(ckpt_path, ignore_keys) | |
> | |
> def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): | |
> sd = torch.load(path, map_location="cpu") | |
> if "state_dict" in list(sd.keys()): | |
> sd = sd["state_dict"] | |
> keys = list(sd.keys()) | |
> for k in keys: | |
> for ik in ignore_keys: | |
> if k.startswith(ik): | |
> print("Deleting key {} from state_dict.".format(k)) | |
> del sd[k] | |
> | |
> # make it explicit, finetune by including extra input channels | |
> if exists(self.finetune_keys) and k in self.finetune_keys: | |
> new_entry = None | |
> for name, param in self.named_parameters(): | |
> if name in self.finetune_keys: | |
> print( | |
> f"modifying key '{name}' and keeping its original {self.keep_dims} (channels) dimensions only") | |
> new_entry = torch.zeros_like(param) # zero init | |
> assert exists(new_entry), 'did not find matching parameter to modify' | |
> new_entry[:, :self.keep_dims, ...] = sd[k] | |
> sd[k] = new_entry | |
> | |
> missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( | |
> sd, strict=False) | |
> print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") | |
> if len(missing) > 0: | |
> print(f"Missing Keys: {missing}") | |
> if len(unexpected) > 0: | |
> print(f"Unexpected Keys: {unexpected}") | |
> | |
> @torch.no_grad() | |
> def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None, | |
> quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True, | |
> plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None, | |
> use_ema_scope=True, | |
> **kwargs): | |
> ema_scope = self.ema_scope if use_ema_scope else nullcontext | |
> use_ddim = ddim_steps is not None | |
> | |
> log = dict() | |
> z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, bs=N, return_first_stage_outputs=True) | |
> c_cat, c = c["c_concat"][0], c["c_crossattn"][0] | |
> N = min(x.shape[0], N) | |
> n_row = min(x.shape[0], n_row) | |
> log["inputs"] = x | |
> log["reconstruction"] = xrec | |
> if self.model.conditioning_key is not None: | |
> if hasattr(self.cond_stage_model, "decode"): | |
> xc = self.cond_stage_model.decode(c) | |
> log["conditioning"] = xc | |
> elif self.cond_stage_key in ["caption", "txt"]: | |
> xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25) | |
> log["conditioning"] = xc | |
> elif self.cond_stage_key in ['class_label', 'cls']: | |
> xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25) | |
> log['conditioning'] = xc | |
> elif isimage(xc): | |
> log["conditioning"] = xc | |
> if ismap(xc): | |
> log["original_conditioning"] = self.to_rgb(xc) | |
> | |
> if not (self.c_concat_log_start is None and self.c_concat_log_end is None): | |
> log["c_concat_decoded"] = self.decode_first_stage(c_cat[:, self.c_concat_log_start:self.c_concat_log_end]) | |
> | |
> if plot_diffusion_rows: | |
> # get diffusion row | |
> diffusion_row = list() | |
> z_start = z[:n_row] | |
> for t in range(self.num_timesteps): | |
> if t % self.log_every_t == 0 or t == self.num_timesteps - 1: | |
> t = repeat(torch.tensor([t]), '1 -> b', b=n_row) | |
> t = t.to(self.device).long() | |
> noise = torch.randn_like(z_start) | |
> z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) | |
> diffusion_row.append(self.decode_first_stage(z_noisy)) | |
> | |
> diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W | |
> diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') | |
> diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') | |
> diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) | |
> log["diffusion_row"] = diffusion_grid | |
> | |
> if sample: | |
> # get denoise row | |
> with ema_scope("Sampling"): | |
> samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]}, | |
> batch_size=N, ddim=use_ddim, | |
> ddim_steps=ddim_steps, eta=ddim_eta) | |
> # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) | |
> x_samples = self.decode_first_stage(samples) | |
> log["samples"] = x_samples | |
> if plot_denoise_rows: | |
> denoise_grid = self._get_denoise_row_from_list(z_denoise_row) | |
> log["denoise_row"] = denoise_grid | |
> | |
> if unconditional_guidance_scale > 1.0: | |
> uc_cross = self.get_unconditional_conditioning(N, unconditional_guidance_label) | |
> uc_cat = c_cat | |
> uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]} | |
> with ema_scope("Sampling with classifier-free guidance"): | |
> samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]}, | |
> batch_size=N, ddim=use_ddim, | |
> ddim_steps=ddim_steps, eta=ddim_eta, | |
> unconditional_guidance_scale=unconditional_guidance_scale, | |
> unconditional_conditioning=uc_full, | |
> ) | |
> x_samples_cfg = self.decode_first_stage(samples_cfg) | |
> log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg | |
> | |
> return log | |
> | |
> | |
> class LatentInpaintDiffusion(LatentFinetuneDiffusion): | |
> """ | |
> can either run as pure inpainting model (only concat mode) or with mixed conditionings, | |
> e.g. mask as concat and text via cross-attn. | |
> To disable finetuning mode, set finetune_keys to None | |
> """ | |
> | |
> def __init__(self, | |
> concat_keys=("mask", "masked_image"), | |
> masked_image_key="masked_image", | |
> *args, **kwargs | |
> ): | |
> super().__init__(concat_keys, *args, **kwargs) | |
> self.masked_image_key = masked_image_key | |
> assert self.masked_image_key in concat_keys | |
> | |
> @torch.no_grad() | |
> def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False): | |
> # note: restricted to non-trainable encoders currently | |
> assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for inpainting' | |
> z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True, | |
> force_c_encode=True, return_original_cond=True, bs=bs) | |
> | |
> assert exists(self.concat_keys) | |
> c_cat = list() | |
> for ck in self.concat_keys: | |
> cc = rearrange(batch[ck], 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float() | |
> if bs is not None: | |
> cc = cc[:bs] | |
> cc = cc.to(self.device) | |
> bchw = z.shape | |
> if ck != self.masked_image_key: | |
> cc = torch.nn.functional.interpolate(cc, size=bchw[-2:]) | |
> else: | |
> cc = self.get_first_stage_encoding(self.encode_first_stage(cc)) | |
> c_cat.append(cc) | |
> c_cat = torch.cat(c_cat, dim=1) | |
> all_conds = {"c_concat": [c_cat], "c_crossattn": [c]} | |
> if return_first_stage_outputs: | |
> return z, all_conds, x, xrec, xc | |
> return z, all_conds | |
> | |
> @torch.no_grad() | |
> def log_images(self, *args, **kwargs): | |
> log = super(LatentInpaintDiffusion, self).log_images(*args, **kwargs) | |
> log["masked_image"] = rearrange(args[0]["masked_image"], | |
> 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float() | |
> return log | |
> | |
> | |
> class LatentDepth2ImageDiffusion(LatentFinetuneDiffusion): | |
> """ | |
> condition on monocular depth estimation | |
> """ | |
> | |
> def __init__(self, depth_stage_config, concat_keys=("midas_in",), *args, **kwargs): | |
> super().__init__(concat_keys=concat_keys, *args, **kwargs) | |
> self.depth_model = instantiate_from_config(depth_stage_config) | |
> self.depth_stage_key = concat_keys[0] | |
> | |
> @torch.no_grad() | |
> def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False): | |
> # note: restricted to non-trainable encoders currently | |
> assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for depth2img' | |
> z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True, | |
> force_c_encode=True, return_original_cond=True, bs=bs) | |
> | |
> assert exists(self.concat_keys) | |
> assert len(self.concat_keys) == 1 | |
> c_cat = list() | |
> for ck in self.concat_keys: | |
> cc = batch[ck] | |
> if bs is not None: | |
> cc = cc[:bs] | |
> cc = cc.to(self.device) | |
> cc = self.depth_model(cc) | |
> cc = torch.nn.functional.interpolate( | |
> cc, | |
> size=z.shape[2:], | |
> mode="bicubic", | |
> align_corners=False, | |
> ) | |
> | |
> depth_min, depth_max = torch.amin(cc, dim=[1, 2, 3], keepdim=True), torch.amax(cc, dim=[1, 2, 3], | |
> keepdim=True) | |
> cc = 2. * (cc - depth_min) / (depth_max - depth_min + 0.001) - 1. | |
> c_cat.append(cc) | |
> c_cat = torch.cat(c_cat, dim=1) | |
> all_conds = {"c_concat": [c_cat], "c_crossattn": [c]} | |
> if return_first_stage_outputs: | |
> return z, all_conds, x, xrec, xc | |
> return z, all_conds | |
> | |
> @torch.no_grad() | |
> def log_images(self, *args, **kwargs): | |
> log = super().log_images(*args, **kwargs) | |
> depth = self.depth_model(args[0][self.depth_stage_key]) | |
> depth_min, depth_max = torch.amin(depth, dim=[1, 2, 3], keepdim=True), \ | |
> torch.amax(depth, dim=[1, 2, 3], keepdim=True) | |
> log["depth"] = 2. * (depth - depth_min) / (depth_max - depth_min) - 1. | |
> return log | |
> | |
> | |
> class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion): | |
> """ | |
> condition on low-res image (and optionally on some spatial noise augmentation) | |
> """ | |
> def __init__(self, concat_keys=("lr",), reshuffle_patch_size=None, | |
> low_scale_config=None, low_scale_key=None, *args, **kwargs): | |
> super().__init__(concat_keys=concat_keys, *args, **kwargs) | |
> self.reshuffle_patch_size = reshuffle_patch_size | |
> self.low_scale_model = None | |
> if low_scale_config is not None: | |
> print("Initializing a low-scale model") | |
> assert exists(low_scale_key) | |
> self.instantiate_low_stage(low_scale_config) | |
> self.low_scale_key = low_scale_key | |
> | |
> def instantiate_low_stage(self, config): | |
> model = instantiate_from_config(config) | |
> self.low_scale_model = model.eval() | |
> self.low_scale_model.train = disabled_train | |
> for param in self.low_scale_model.parameters(): | |
> param.requires_grad = False | |
> | |
> @torch.no_grad() | |
> def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False): | |
> # note: restricted to non-trainable encoders currently | |
> assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for upscaling-ft' | |
> z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True, | |
> force_c_encode=True, return_original_cond=True, bs=bs) | |
> | |
> assert exists(self.concat_keys) | |
> assert len(self.concat_keys) == 1 | |
> # optionally make spatial noise_level here | |
> c_cat = list() | |
> noise_level = None | |
> for ck in self.concat_keys: | |
> cc = batch[ck] | |
> cc = rearrange(cc, 'b h w c -> b c h w') | |
> if exists(self.reshuffle_patch_size): | |
> assert isinstance(self.reshuffle_patch_size, int) | |
> cc = rearrange(cc, 'b c (p1 h) (p2 w) -> b (p1 p2 c) h w', | |
> p1=self.reshuffle_patch_size, p2=self.reshuffle_patch_size) | |
> if bs is not None: | |
> cc = cc[:bs] | |
> cc = cc.to(self.device) | |
> if exists(self.low_scale_model) and ck == self.low_scale_key: | |
> cc, noise_level = self.low_scale_model(cc) | |
> c_cat.append(cc) | |
> c_cat = torch.cat(c_cat, dim=1) | |
> if exists(noise_level): | |
> all_conds = {"c_concat": [c_cat], "c_crossattn": [c], "c_adm": noise_level} | |
> else: | |
> all_conds = {"c_concat": [c_cat], "c_crossattn": [c]} | |
> if return_first_stage_outputs: | |
> return z, all_conds, x, xrec, xc | |
> return z, all_conds | |
> | |
> @torch.no_grad() | |
> def log_images(self, *args, **kwargs): | |
> log = super().log_images(*args, **kwargs) | |
> log["lr"] = rearrange(args[0]["lr"], 'b h w c -> b c h w') | |
> return log | |
diff -w -x '*.gif' -x '*.png' -x '*.jpg' -x LICENSE -x '*.md' -x '*.git' -r sd1/ldm/models/diffusion/dpm_solver/dpm_solver.py sd2/ldm/models/diffusion/dpm_solver/dpm_solver.py | |
3a4 | |
> from tqdm import tqdm | |
16d16 | |
< | |
21d20 | |
< | |
25d23 | |
< | |
29d26 | |
< | |
31d27 | |
< | |
33d28 | |
< | |
35d29 | |
< | |
37d30 | |
< | |
39d31 | |
< | |
44d35 | |
< | |
48d38 | |
< | |
50d39 | |
< | |
58,59d46 | |
< | |
< | |
61d47 | |
< | |
64d49 | |
< | |
71d55 | |
< | |
73d56 | |
< | |
81d63 | |
< | |
83d64 | |
< | |
86d66 | |
< | |
89d68 | |
< | |
92d70 | |
< | |
96c74,76 | |
< raise ValueError("Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(schedule)) | |
--- | |
> raise ValueError( | |
> "Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format( | |
> schedule)) | |
115c95,96 | |
< self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s | |
--- | |
> self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * ( | |
> 1. + self.cosine_s) / math.pi - self.cosine_s | |
130c111,112 | |
< return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device)).reshape((-1)) | |
--- | |
> return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), | |
> self.log_alpha_array.to(t.device)).reshape((-1)) | |
168c150,151 | |
< t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), torch.flip(self.t_array.to(lamb.device), [1])) | |
--- | |
> t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), | |
> torch.flip(self.t_array.to(lamb.device), [1])) | |
172c155,156 | |
< t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s | |
--- | |
> t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * ( | |
> 1. + self.cosine_s) / math.pi - self.cosine_s | |
190d173 | |
< | |
193d175 | |
< | |
195d176 | |
< | |
197d177 | |
< | |
199d178 | |
< | |
202d180 | |
< | |
213d190 | |
< | |
220d196 | |
< | |
226d201 | |
< | |
231d205 | |
< | |
234d207 | |
< | |
241d213 | |
< | |
245d216 | |
< | |
248d218 | |
< | |
256d225 | |
< | |
258d226 | |
< | |
354d321 | |
< | |
360d326 | |
< | |
412d377 | |
< | |
437c402,403 | |
< raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type)) | |
--- | |
> raise ValueError( | |
> "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type)) | |
442d407 | |
< | |
456d420 | |
< | |
495c459,460 | |
< timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[torch.cumsum(torch.tensor([0,] + orders)).to(device)] | |
--- | |
> timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[ | |
> torch.cumsum(torch.tensor([0, ] + orders)).to(device)] | |
507d471 | |
< | |
551c515,516 | |
< def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False, solver_type='dpm_solver'): | |
--- | |
> def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False, | |
> solver_type='dpm_solver'): | |
554d518 | |
< | |
578c542,543 | |
< log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(t) | |
--- | |
> log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff( | |
> s1), ns.marginal_log_mean_coeff(t) | |
603c568,569 | |
< + (1. / r1) * expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * (model_s1 - model_s) | |
--- | |
> + (1. / r1) * expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * ( | |
> model_s1 - model_s) | |
633c599,600 | |
< def singlestep_dpm_solver_third_update(self, x, s, t, r1=1./3., r2=2./3., model_s=None, model_s1=None, return_intermediate=False, solver_type='dpm_solver'): | |
--- | |
> def singlestep_dpm_solver_third_update(self, x, s, t, r1=1. / 3., r2=2. / 3., model_s=None, model_s1=None, | |
> return_intermediate=False, solver_type='dpm_solver'): | |
636d602 | |
< | |
667,668c633,636 | |
< log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t) | |
< sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(s2), ns.marginal_std(t) | |
--- | |
> log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff( | |
> s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t) | |
> sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std( | |
> s2), ns.marginal_std(t) | |
758d725 | |
< | |
775c742,743 | |
< lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t) | |
--- | |
> lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda( | |
> t_prev_0), ns.marginal_lambda(t) | |
815d782 | |
< | |
830c797,798 | |
< lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda(t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t) | |
--- | |
> lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda( | |
> t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t) | |
859c827,828 | |
< def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpm_solver', r1=None, r2=None): | |
--- | |
> def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpm_solver', r1=None, | |
> r2=None): | |
862d830 | |
< | |
879c847,848 | |
< return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1) | |
--- | |
> return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate, | |
> solver_type=solver_type, r1=r1) | |
881c850,851 | |
< return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1, r2=r2) | |
--- | |
> return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate, | |
> solver_type=solver_type, r1=r1, r2=r2) | |
888d857 | |
< | |
909c878,879 | |
< def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5, solver_type='dpm_solver'): | |
--- | |
> def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5, | |
> solver_type='dpm_solver'): | |
912d881 | |
< | |
928d896 | |
< | |
941c909,911 | |
< higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, solver_type=solver_type, **kwargs) | |
--- | |
> higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, | |
> solver_type=solver_type, | |
> **kwargs) | |
944,945c914,919 | |
< lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, return_intermediate=True, solver_type=solver_type) | |
< higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2, solver_type=solver_type, **kwargs) | |
--- | |
> lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, | |
> return_intermediate=True, | |
> solver_type=solver_type) | |
> higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2, | |
> solver_type=solver_type, | |
> **kwargs) | |
971d944 | |
< | |
973d945 | |
< | |
1012d983 | |
< | |
1014d984 | |
< | |
1028d997 | |
< | |
1033d1001 | |
< | |
1053d1020 | |
< | |
1069d1035 | |
< | |
1076c1042,1043 | |
< x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol, solver_type=solver_type) | |
--- | |
> x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol, | |
> solver_type=solver_type) | |
1086c1053 | |
< for init_order in range(1, order): | |
--- | |
> for init_order in tqdm(range(1, order), desc="DPM init order"): | |
1088c1055,1056 | |
< x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, init_order, solver_type=solver_type) | |
--- | |
> x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, init_order, | |
> solver_type=solver_type) | |
1092c1060 | |
< for step in range(order, steps + 1): | |
--- | |
> for step in tqdm(range(order, steps + 1), desc="DPM multistep"): | |
1098c1066,1067 | |
< x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, step_order, solver_type=solver_type) | |
--- | |
> x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, step_order, | |
> solver_type=solver_type) | |
1108c1077,1080 | |
< timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, order=order, skip_type=skip_type, t_T=t_T, t_0=t_0, device=device) | |
--- | |
> timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, order=order, | |
> skip_type=skip_type, | |
> t_T=t_T, t_0=t_0, | |
> device=device) | |
1115c1087,1088 | |
< timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=t_T_inner.item(), t_0=t_0_inner.item(), N=order, device=device) | |
--- | |
> timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=t_T_inner.item(), t_0=t_0_inner.item(), | |
> N=order, device=device) | |
1127d1099 | |
< | |
1137d1108 | |
< | |
1177d1147 | |
< | |
diff -w -x '*.gif' -x '*.png' -x '*.jpg' -x LICENSE -x '*.md' -x '*.git' -r sd1/ldm/models/diffusion/dpm_solver/sampler.py sd2/ldm/models/diffusion/dpm_solver/sampler.py | |
2d1 | |
< | |
7a7,12 | |
> MODEL_TYPES = { | |
> "eps": "noise", | |
> "v": "v" | |
> } | |
> | |
> | |
59c64 | |
< # print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}') | |
--- | |
> print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}') | |
72c77 | |
< model_type="noise", | |
--- | |
> model_type=MODEL_TYPES[self.model.parameterization], | |
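The sampler.py hunk above is what lets DPM-Solver handle sd2's v-prediction checkpoints: instead of hard-coding model_type="noise", the wrapper's model type is looked up from the LDM parameterization. A short sketch of that lookup in isolation (the surrounding sampler plumbing is omitted):

# mirrors the MODEL_TYPES dict added in sd2: LDM parameterization -> DPM-Solver model_type
MODEL_TYPES = {"eps": "noise", "v": "v"}

def dpm_model_type(parameterization: str) -> str:
    try:
        return MODEL_TYPES[parameterization]
    except KeyError:
        raise ValueError(f"unsupported parameterization {parameterization!r}; expected 'eps' or 'v'")

# e.g. model_type=dpm_model_type(model.parameterization) when building the model wrapper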
diff -w -x '*.gif' -x '*.png' -x '*.jpg' -x LICENSE -x '*.md' -x '*.git' -r sd1/ldm/models/diffusion/plms.py sd2/ldm/models/diffusion/plms.py | |
8a9 | |
> from ldm.models.diffusion.sampling_util import norm_thresholding | |
79a81 | |
> dynamic_threshold=None, | |
110a113 | |
> dynamic_threshold=dynamic_threshold, | |
120c123,124 | |
< unconditional_guidance_scale=1., unconditional_conditioning=None,): | |
--- | |
> unconditional_guidance_scale=1., unconditional_conditioning=None, | |
> dynamic_threshold=None): | |
158c162,163 | |
< old_eps=old_eps, t_next=ts_next) | |
--- | |
> old_eps=old_eps, t_next=ts_next, | |
> dynamic_threshold=dynamic_threshold) | |
175c180,181 | |
< unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None): | |
--- | |
> unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None, | |
> dynamic_threshold=None): | |
209a216,217 | |
> if dynamic_threshold is not None: | |
> pred_x0 = norm_thresholding(pred_x0, dynamic_threshold) | |
Only in sd2/ldm/models/diffusion: sampling_util.py | |
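For context on the plms.py hunks: sd2 threads an optional dynamic_threshold through p_sample_plms and, when it is set, rescales the predicted x0 with norm_thresholding from the new sampling_util.py. A hedged sketch of a norm-based threshold of that kind (not copied from the repo; treat the exact formula as an assumption):

import torch

def norm_thresholding(pred_x0: torch.Tensor, value: float) -> torch.Tensor:
    """Rescale pred_x0 so its per-sample RMS norm does not exceed `value`."""
    dims = tuple(range(1, pred_x0.ndim))
    # per-sample RMS over all non-batch dims, clamped from below so small norms stay untouched
    norm = pred_x0.pow(2).mean(dim=dims, keepdim=True).sqrt().clamp(min=value)
    return pred_x0 * (value / norm)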
diff -w -x '*.gif' -x '*.png' -x '*.jpg' -x LICENSE -x '*.md' -x '*.git' -r sd1/ldm/modules/attention.py sd2/ldm/modules/attention.py | |
6a7 | |
> from typing import Optional, Any | |
10a12,19 | |
> try: | |
> import xformers | |
> import xformers.ops | |
> XFORMERS_IS_AVAILBLE = True | |
> except: | |
> XFORMERS_IS_AVAILBLE = False | |
> | |
> | |
80,98d88 | |
< class LinearAttention(nn.Module): | |
< def __init__(self, dim, heads=4, dim_head=32): | |
< super().__init__() | |
< self.heads = heads | |
< hidden_dim = dim_head * heads | |
< self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) | |
< self.to_out = nn.Conv2d(hidden_dim, dim, 1) | |
< | |
< def forward(self, x): | |
< b, c, h, w = x.shape | |
< qkv = self.to_qkv(x) | |
< q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) | |
< k = k.softmax(dim=-1) | |
< context = torch.einsum('bhdn,bhen->bhde', k, v) | |
< out = torch.einsum('bhde,bhdn->bhen', context, q) | |
< out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) | |
< return self.to_out(out) | |
< | |
< | |
180a171 | |
> del q, k | |
189c180 | |
< attn = sim.softmax(dim=-1) | |
--- | |
> sim = sim.softmax(dim=-1) | |
191c182 | |
< out = einsum('b i j, b j d -> b i d', attn, v) | |
--- | |
> out = einsum('b i j, b j d -> b i d', sim, v) | |
195a187,235 | |
> class MemoryEfficientCrossAttention(nn.Module): | |
> # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 | |
> def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): | |
> super().__init__() | |
> print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using " | |
> f"{heads} heads.") | |
> inner_dim = dim_head * heads | |
> context_dim = default(context_dim, query_dim) | |
> | |
> self.heads = heads | |
> self.dim_head = dim_head | |
> | |
> self.to_q = nn.Linear(query_dim, inner_dim, bias=False) | |
> self.to_k = nn.Linear(context_dim, inner_dim, bias=False) | |
> self.to_v = nn.Linear(context_dim, inner_dim, bias=False) | |
> | |
> self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) | |
> self.attention_op: Optional[Any] = None | |
> | |
> def forward(self, x, context=None, mask=None): | |
> q = self.to_q(x) | |
> context = default(context, x) | |
> k = self.to_k(context) | |
> v = self.to_v(context) | |
> | |
> b, _, _ = q.shape | |
> q, k, v = map( | |
> lambda t: t.unsqueeze(3) | |
> .reshape(b, t.shape[1], self.heads, self.dim_head) | |
> .permute(0, 2, 1, 3) | |
> .reshape(b * self.heads, t.shape[1], self.dim_head) | |
> .contiguous(), | |
> (q, k, v), | |
> ) | |
> | |
> # actually compute the attention, what we cannot get enough of | |
> out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) | |
> | |
> if exists(mask): | |
> raise NotImplementedError | |
> out = ( | |
> out.unsqueeze(0) | |
> .reshape(b, self.heads, out.shape[1], self.dim_head) | |
> .permute(0, 2, 1, 3) | |
> .reshape(b, out.shape[1], self.heads * self.dim_head) | |
> ) | |
> return self.to_out(out) | |
> | |
> | |
197c237,242 | |
< def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True): | |
--- | |
> ATTENTION_MODES = { | |
> "softmax": CrossAttention, # vanilla attention | |
> "softmax-xformers": MemoryEfficientCrossAttention | |
> } | |
> def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, | |
> disable_self_attn=False): | |
199c244,249 | |
< self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention | |
--- | |
> attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax" | |
> assert attn_mode in self.ATTENTION_MODES | |
> attn_cls = self.ATTENTION_MODES[attn_mode] | |
> self.disable_self_attn = disable_self_attn | |
> self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, | |
> context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn | |
201c251 | |
< self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, | |
--- | |
> self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim, | |
212c262 | |
< x = self.attn1(self.norm1(x)) + x | |
--- | |
> x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x | |
224a275 | |
> NEW: use_linear for more efficiency instead of the 1x1 convs | |
227c278,280 | |
< depth=1, dropout=0., context_dim=None): | |
--- | |
> depth=1, dropout=0., context_dim=None, | |
> disable_self_attn=False, use_linear=False, | |
> use_checkpoint=True): | |
228a282,283 | |
> if exists(context_dim) and not isinstance(context_dim, list): | |
> context_dim = [context_dim] | |
232c287 | |
< | |
--- | |
> if not use_linear: | |
237a293,294 | |
> else: | |
> self.proj_in = nn.Linear(in_channels, inner_dim) | |
240c297,298 | |
< [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim) | |
--- | |
> [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d], | |
> disable_self_attn=disable_self_attn, checkpoint=use_checkpoint) | |
243c301 | |
< | |
--- | |
> if not use_linear: | |
248a307,309 | |
> else: | |
> self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) | |
> self.use_linear = use_linear | |
251a313,314 | |
> if not isinstance(context, list): | |
> context = [context] | |
254a318,321 | |
> if not self.use_linear: | |
> x = self.proj_in(x) | |
> x = rearrange(x, 'b c h w -> b (h w) c').contiguous() | |
> if self.use_linear: | |
256,259c323,328 | |
< x = rearrange(x, 'b c h w -> b (h w) c') | |
< for block in self.transformer_blocks: | |
< x = block(x, context=context) | |
< x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) | |
--- | |
> for i, block in enumerate(self.transformer_blocks): | |
> x = block(x, context=context[i]) | |
> if self.use_linear: | |
> x = self.proj_out(x) | |
> x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous() | |
> if not self.use_linear: | |
261a331 | |
> | |
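The attention.py diff drops the old LinearAttention path and adds an xformers-backed MemoryEfficientCrossAttention, selected per BasicTransformerBlock through the ATTENTION_MODES dict whenever the xformers import succeeds. A minimal sketch of the head folding that hunk performs around xformers.ops.memory_efficient_attention, with a vanilla softmax fallback (shapes only; the real module's projection layers are omitted):

import torch

try:
    import xformers.ops as xops
    HAVE_XFORMERS = True
except ImportError:
    HAVE_XFORMERS = False

def attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, heads: int) -> torch.Tensor:
    """q, k, v: (B, N, heads * dim_head); returns (B, N_q, heads * dim_head)."""
    b = q.shape[0]
    dim_head = q.shape[-1] // heads

    def fold(t):
        # (B, N, H*D) -> (B*H, N, D), the layout the xformers kernel is fed here
        return t.reshape(b, t.shape[1], heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, t.shape[1], dim_head)

    q, k, v = map(fold, (q, k, v))
    if HAVE_XFORMERS:
        out = xops.memory_efficient_attention(q, k, v)
    else:
        # plain softmax attention, as in the non-xformers CrossAttention branch
        sim = torch.einsum('b i d, b j d -> b i j', q, k) * dim_head ** -0.5
        out = torch.einsum('b i j, b j d -> b i d', sim.softmax(dim=-1), v)
    return out.reshape(b, heads, -1, dim_head).permute(0, 2, 1, 3).reshape(b, -1, heads * dim_head)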
diff -w -x '*.gif' -x '*.png' -x '*.jpg' -x LICENSE -x '*.md' -x '*.git' -r sd1/ldm/modules/diffusionmodules/model.py sd2/ldm/modules/diffusionmodules/model.py | |
6a7 | |
> from typing import Optional, Any | |
8,9c9,17 | |
< from ldm.util import instantiate_from_config | |
< from ldm.modules.attention import LinearAttention | |
--- | |
> from ldm.modules.attention import MemoryEfficientCrossAttention | |
> | |
> try: | |
> import xformers | |
> import xformers.ops | |
> XFORMERS_IS_AVAILBLE = True | |
> except: | |
> XFORMERS_IS_AVAILBLE = False | |
> print("No module 'xformers'. Proceeding without it.") | |
144,149d151 | |
< class LinAttnBlock(LinearAttention): | |
< """to match AttnBlock usage""" | |
< def __init__(self, in_channels): | |
< super().__init__(dim=in_channels, heads=1, dim_head=in_channels) | |
< | |
< | |
177d178 | |
< | |
203a205,258 | |
> class MemoryEfficientAttnBlock(nn.Module): | |
> """ | |
> Uses xformers efficient implementation, | |
> see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 | |
> Note: this is a single-head self-attention operation | |
> """ | |
> # | |
> def __init__(self, in_channels): | |
> super().__init__() | |
> self.in_channels = in_channels | |
> | |
> self.norm = Normalize(in_channels) | |
> self.q = torch.nn.Conv2d(in_channels, | |
> in_channels, | |
> kernel_size=1, | |
> stride=1, | |
> padding=0) | |
> self.k = torch.nn.Conv2d(in_channels, | |
> in_channels, | |
> kernel_size=1, | |
> stride=1, | |
> padding=0) | |
> self.v = torch.nn.Conv2d(in_channels, | |
> in_channels, | |
> kernel_size=1, | |
> stride=1, | |
> padding=0) | |
> self.proj_out = torch.nn.Conv2d(in_channels, | |
> in_channels, | |
> kernel_size=1, | |
> stride=1, | |
> padding=0) | |
> self.attention_op: Optional[Any] = None | |
> | |
> def forward(self, x): | |
> h_ = x | |
> h_ = self.norm(h_) | |
> q = self.q(h_) | |
> k = self.k(h_) | |
> v = self.v(h_) | |
> | |
> # compute attention | |
> B, C, H, W = q.shape | |
> q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v)) | |
> | |
> q, k, v = map( | |
> lambda t: t.unsqueeze(3) | |
> .reshape(B, t.shape[1], 1, C) | |
> .permute(0, 2, 1, 3) | |
> .reshape(B * 1, t.shape[1], C) | |
> .contiguous(), | |
> (q, k, v), | |
> ) | |
> out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) | |
205,206c260,283 | |
< def make_attn(in_channels, attn_type="vanilla"): | |
< assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown' | |
--- | |
> out = ( | |
> out.unsqueeze(0) | |
> .reshape(B, 1, out.shape[1], C) | |
> .permute(0, 2, 1, 3) | |
> .reshape(B, out.shape[1], C) | |
> ) | |
> out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C) | |
> out = self.proj_out(out) | |
> return x+out | |
> | |
> | |
> class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention): | |
> def forward(self, x, context=None, mask=None): | |
> b, c, h, w = x.shape | |
> x = rearrange(x, 'b c h w -> b (h w) c') | |
> out = super().forward(x, context=context, mask=mask) | |
> out = rearrange(out, 'b (h w) c -> b c h w', h=h, w=w, c=c) | |
> return x + out | |
> | |
> | |
> def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None): | |
> assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown' | |
> if XFORMERS_IS_AVAILBLE and attn_type == "vanilla": | |
> attn_type = "vanilla-xformers" | |
208a286 | |
> assert attn_kwargs is None | |
209a288,293 | |
> elif attn_type == "vanilla-xformers": | |
> print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...") | |
> return MemoryEfficientAttnBlock(in_channels) | |
> elif type == "memory-efficient-cross-attn": | |
> attn_kwargs["query_dim"] = in_channels | |
> return MemoryEfficientCrossAttentionWrapper(**attn_kwargs) | |
213c297 | |
< return LinAttnBlock(in_channels) | |
--- | |
> raise NotImplementedError() | |
769,835d852 | |
< | |
< class FirstStagePostProcessor(nn.Module): | |
< | |
< def __init__(self, ch_mult:list, in_channels, | |
< pretrained_model:nn.Module=None, | |
< reshape=False, | |
< n_channels=None, | |
< dropout=0., | |
< pretrained_config=None): | |
< super().__init__() | |
< if pretrained_config is None: | |
< assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None' | |
< self.pretrained_model = pretrained_model | |
< else: | |
< assert pretrained_config is not None, 'Either "pretrained_model" or "pretrained_config" must not be None' | |
< self.instantiate_pretrained(pretrained_config) | |
< | |
< self.do_reshape = reshape | |
< | |
< if n_channels is None: | |
< n_channels = self.pretrained_model.encoder.ch | |
< | |
< self.proj_norm = Normalize(in_channels,num_groups=in_channels//2) | |
< self.proj = nn.Conv2d(in_channels,n_channels,kernel_size=3, | |
< stride=1,padding=1) | |
< | |
< blocks = [] | |
< downs = [] | |
< ch_in = n_channels | |
< for m in ch_mult: | |
< blocks.append(ResnetBlock(in_channels=ch_in,out_channels=m*n_channels,dropout=dropout)) | |
< ch_in = m * n_channels | |
< downs.append(Downsample(ch_in, with_conv=False)) | |
< | |
< self.model = nn.ModuleList(blocks) | |
< self.downsampler = nn.ModuleList(downs) | |
< | |
< | |
< def instantiate_pretrained(self, config): | |
< model = instantiate_from_config(config) | |
< self.pretrained_model = model.eval() | |
< # self.pretrained_model.train = False | |
< for param in self.pretrained_model.parameters(): | |
< param.requires_grad = False | |
< | |
< | |
< @torch.no_grad() | |
< def encode_with_pretrained(self,x): | |
< c = self.pretrained_model.encode(x) | |
< if isinstance(c, DiagonalGaussianDistribution): | |
< c = c.mode() | |
< return c | |
< | |
< def forward(self,x): | |
< z_fs = self.encode_with_pretrained(x) | |
< z = self.proj_norm(z_fs) | |
< z = self.proj(z) | |
< z = nonlinearity(z) | |
< | |
< for submodel, downmodel in zip(self.model,self.downsampler): | |
< z = submodel(z,temb=None) | |
< z = downmodel(z) | |
< | |
< if self.do_reshape: | |
< z = rearrange(z,'b c h w -> b (h w) c') | |
< return z | |
< | |
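The model.py hunks apply the same xformers swap to the autoencoder's single-head AttnBlock: the 1x1-conv q/k/v feature maps are flattened from (B, C, H, W) to (B, H*W, C), attended over spatial positions, and folded back before the residual add, with make_attn silently upgrading "vanilla" to "vanilla-xformers" when the library is importable. A rough sketch of that flatten/attend/unflatten pattern (single head, softmax path only, names illustrative):

import torch
from einops import rearrange

def spatial_self_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """q, k, v: (B, C, H, W) maps from 1x1 convs; returns (B, C, H, W)."""
    b, c, h, w = q.shape
    q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
    sim = torch.einsum('b i c, b j c -> b i j', q, k) * c ** -0.5   # (B, HW, HW)
    out = torch.einsum('b i j, b j c -> b i c', sim.softmax(dim=-1), v)
    return rearrange(out, 'b (h w) c -> b c h w', h=h, w=w)

# in the block (hypothetical names): x = x + proj_out(spatial_self_attention(q(h_), k(h_), v(h_)))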
diff -w -x '*.gif' -x '*.png' -x '*.jpg' -x LICENSE -x '*.md' -x '*.git' -r sd1/ldm/modules/diffusionmodules/openaimodel.py sd2/ldm/modules/diffusionmodules/openaimodel.py | |
2d1 | |
< from functools import partial | |
4d2 | |
< from typing import Iterable | |
20a19 | |
> from ldm.util import exists | |
468a468,471 | |
> disable_self_attentions=None, | |
> num_attention_blocks=None, | |
> disable_middle_self_attn=False, | |
> use_linear_in_transformer=False, | |
492a496,501 | |
> if isinstance(num_res_blocks, int): | |
> self.num_res_blocks = len(channel_mult) * [num_res_blocks] | |
> else: | |
> if len(num_res_blocks) != len(channel_mult): | |
> raise ValueError("provide num_res_blocks either as an int (globally constant) or " | |
> "as a list/tuple (per-level) with the same length as channel_mult") | |
493a503,513 | |
> if disable_self_attentions is not None: | |
> # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not | |
> assert len(disable_self_attentions) == len(channel_mult) | |
> if num_attention_blocks is not None: | |
> assert len(num_attention_blocks) == len(self.num_res_blocks) | |
> assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks)))) | |
> print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. " | |
> f"This option has LESS priority than attention_resolutions {attention_resolutions}, " | |
> f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " | |
> f"attention will still not be set.") | |
> | |
513a534 | |
> if isinstance(self.num_classes, int): | |
514a536,540 | |
> elif self.num_classes == "continuous": | |
> print("setting up linear c_adm embedding layer") | |
> self.label_emb = nn.Linear(1, time_embed_dim) | |
> else: | |
> raise ValueError() | |
528c554 | |
< for _ in range(num_res_blocks): | |
--- | |
> for nr in range(self.num_res_blocks[level]): | |
549a576,581 | |
> if exists(disable_self_attentions): | |
> disabled_sa = disable_self_attentions[level] | |
> else: | |
> disabled_sa = False | |
> | |
> if not exists(num_attention_blocks) or nr < num_attention_blocks[level]: | |
558c590,592 | |
< ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim | |
--- | |
> ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, | |
> disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, | |
> use_checkpoint=use_checkpoint | |
612,613c646,649 | |
< ) if not use_spatial_transformer else SpatialTransformer( | |
< ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim | |
--- | |
> ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn | |
> ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, | |
> disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer, | |
> use_checkpoint=use_checkpoint | |
628c664 | |
< for i in range(num_res_blocks + 1): | |
--- | |
> for i in range(self.num_res_blocks[level] + 1): | |
650a687,692 | |
> if exists(disable_self_attentions): | |
> disabled_sa = disable_self_attentions[level] | |
> else: | |
> disabled_sa = False | |
> | |
> if not exists(num_attention_blocks) or i < num_attention_blocks[level]: | |
659c701,703 | |
< ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim | |
--- | |
> ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, | |
> disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, | |
> use_checkpoint=use_checkpoint | |
662c706 | |
< if level and i == num_res_blocks: | |
--- | |
> if level and i == self.num_res_blocks[level]: | |
727c771 | |
< assert y.shape == (x.shape[0],) | |
--- | |
> assert y.shape[0] == x.shape[0] | |
743,961d786 | |
< | |
< | |
< class EncoderUNetModel(nn.Module): | |
< """ | |
< The half UNet model with attention and timestep embedding. | |
< For usage, see UNet. | |
< """ | |
< | |
< def __init__( | |
< self, | |
< image_size, | |
< in_channels, | |
< model_channels, | |
< out_channels, | |
< num_res_blocks, | |
< attention_resolutions, | |
< dropout=0, | |
< channel_mult=(1, 2, 4, 8), | |
< conv_resample=True, | |
< dims=2, | |
< use_checkpoint=False, | |
< use_fp16=False, | |
< num_heads=1, | |
< num_head_channels=-1, | |
< num_heads_upsample=-1, | |
< use_scale_shift_norm=False, | |
< resblock_updown=False, | |
< use_new_attention_order=False, | |
< pool="adaptive", | |
< *args, | |
< **kwargs | |
< ): | |
< super().__init__() | |
< | |
< if num_heads_upsample == -1: | |
< num_heads_upsample = num_heads | |
< | |
< self.in_channels = in_channels | |
< self.model_channels = model_channels | |
< self.out_channels = out_channels | |
< self.num_res_blocks = num_res_blocks | |
< self.attention_resolutions = attention_resolutions | |
< self.dropout = dropout | |
< self.channel_mult = channel_mult | |
< self.conv_resample = conv_resample | |
< self.use_checkpoint = use_checkpoint | |
< self.dtype = th.float16 if use_fp16 else th.float32 | |
< self.num_heads = num_heads | |
< self.num_head_channels = num_head_channels | |
< self.num_heads_upsample = num_heads_upsample | |
< | |
< time_embed_dim = model_channels * 4 | |
< self.time_embed = nn.Sequential( | |
< linear(model_channels, time_embed_dim), | |
< nn.SiLU(), | |
< linear(time_embed_dim, time_embed_dim), | |
< ) | |
< | |
< self.input_blocks = nn.ModuleList( | |
< [ | |
< TimestepEmbedSequential( | |
< conv_nd(dims, in_channels, model_channels, 3, padding=1) | |
< ) | |
< ] | |
< ) | |
< self._feature_size = model_channels | |
< input_block_chans = [model_channels] | |
< ch = model_channels | |
< ds = 1 | |
< for level, mult in enumerate(channel_mult): | |
< for _ in range(num_res_blocks): | |
< layers = [ | |
< ResBlock( | |
< ch, | |
< time_embed_dim, | |
< dropout, | |
< out_channels=mult * model_channels, | |
< dims=dims, | |
< use_checkpoint=use_checkpoint, | |
< use_scale_shift_norm=use_scale_shift_norm, | |
< ) | |
< ] | |
< ch = mult * model_channels | |
< if ds in attention_resolutions: | |
< layers.append( | |
< AttentionBlock( | |
< ch, | |
< use_checkpoint=use_checkpoint, | |
< num_heads=num_heads, | |
< num_head_channels=num_head_channels, | |
< use_new_attention_order=use_new_attention_order, | |
< ) | |
< ) | |
< self.input_blocks.append(TimestepEmbedSequential(*layers)) | |
< self._feature_size += ch | |
< input_block_chans.append(ch) | |
< if level != len(channel_mult) - 1: | |
< out_ch = ch | |
< self.input_blocks.append( | |
< TimestepEmbedSequential( | |
< ResBlock( | |
< ch, | |
< time_embed_dim, | |
< dropout, | |
< out_channels=out_ch, | |
< dims=dims, | |
< use_checkpoint=use_checkpoint, | |
< use_scale_shift_norm=use_scale_shift_norm, | |
< down=True, | |
< ) | |
< if resblock_updown | |
< else Downsample( | |
< ch, conv_resample, dims=dims, out_channels=out_ch | |
< ) | |
< ) | |
< ) | |
< ch = out_ch | |
< input_block_chans.append(ch) | |
< ds *= 2 | |
< self._feature_size += ch | |
< | |
< self.middle_block = TimestepEmbedSequential( | |
< ResBlock( | |
< ch, | |
< time_embed_dim, | |
< dropout, | |
< dims=dims, | |
< use_checkpoint=use_checkpoint, | |
< use_scale_shift_norm=use_scale_shift_norm, | |
< ), | |
< AttentionBlock( | |
< ch, | |
< use_checkpoint=use_checkpoint, | |
< num_heads=num_heads, | |
< num_head_channels=num_head_channels, | |
< use_new_attention_order=use_new_attention_order, | |
< ), | |
< ResBlock( | |
< ch, | |
< time_embed_dim, | |
< dropout, | |
< dims=dims, | |
< use_checkpoint=use_checkpoint, | |
< use_scale_shift_norm=use_scale_shift_norm, | |
< ), | |
< ) | |
< self._feature_size += ch | |
< self.pool = pool | |
< if pool == "adaptive": | |
< self.out = nn.Sequential( | |
< normalization(ch), | |
< nn.SiLU(), | |
< nn.AdaptiveAvgPool2d((1, 1)), | |
< zero_module(conv_nd(dims, ch, out_channels, 1)), | |
< nn.Flatten(), | |
< ) | |
< elif pool == "attention": | |
< assert num_head_channels != -1 | |
< self.out = nn.Sequential( | |
< normalization(ch), | |
< nn.SiLU(), | |
< AttentionPool2d( | |
< (image_size // ds), ch, num_head_channels, out_channels | |
< ), | |
< ) | |
< elif pool == "spatial": | |
< self.out = nn.Sequential( | |
< nn.Linear(self._feature_size, 2048), | |
< nn.ReLU(), | |
< nn.Linear(2048, self.out_channels), | |
< ) | |
< elif pool == "spatial_v2": | |
< self.out = nn.Sequential( | |
< nn.Linear(self._feature_size, 2048), | |
< normalization(2048), | |
< nn.SiLU(), | |
< nn.Linear(2048, self.out_channels), | |
< ) | |
< else: | |
< raise NotImplementedError(f"Unexpected {pool} pooling") | |
< | |
< def convert_to_fp16(self): | |
< """ | |
< Convert the torso of the model to float16. | |
< """ | |
< self.input_blocks.apply(convert_module_to_f16) | |
< self.middle_block.apply(convert_module_to_f16) | |
< | |
< def convert_to_fp32(self): | |
< """ | |
< Convert the torso of the model to float32. | |
< """ | |
< self.input_blocks.apply(convert_module_to_f32) | |
< self.middle_block.apply(convert_module_to_f32) | |
< | |
< def forward(self, x, timesteps): | |
< """ | |
< Apply the model to an input batch. | |
< :param x: an [N x C x ...] Tensor of inputs. | |
< :param timesteps: a 1-D batch of timesteps. | |
< :return: an [N x K] Tensor of outputs. | |
< """ | |
< emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) | |
< | |
< results = [] | |
< h = x.type(self.dtype) | |
< for module in self.input_blocks: | |
< h = module(h, emb) | |
< if self.pool.startswith("spatial"): | |
< results.append(h.type(x.dtype).mean(dim=(2, 3))) | |
< h = self.middle_block(h, emb) | |
< if self.pool.startswith("spatial"): | |
< results.append(h.type(x.dtype).mean(dim=(2, 3))) | |
< h = th.cat(results, axis=-1) | |
< return self.out(h) | |
< else: | |
< h = h.type(x.dtype) | |
< return self.out(h) | |
< | |
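In the openaimodel.py diff, sd2 also lets num_res_blocks be either a single int or a per-level list matching channel_mult, normalizing it once in the constructor so the down/up loops can always index per level. A small standalone sketch of that normalization (outside the UNet class):

from typing import Sequence, Union

def per_level_res_blocks(num_res_blocks: Union[int, Sequence[int]], channel_mult: Sequence[int]) -> list:
    """Expand num_res_blocks to one entry per resolution level."""
    if isinstance(num_res_blocks, int):
        return len(channel_mult) * [num_res_blocks]
    if len(num_res_blocks) != len(channel_mult):
        raise ValueError("num_res_blocks must be an int or a per-level list the same length as channel_mult")
    return list(num_res_blocks)

# per_level_res_blocks(2, (1, 2, 4, 4))            -> [2, 2, 2, 2]
# per_level_res_blocks([2, 2, 2, 1], (1, 2, 4, 4)) -> [2, 2, 2, 1]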
Only in sd2/ldm/modules/diffusionmodules: upscaling.py | |
diff -w -x '*.gif' -x '*.png' -x '*.jpg' -x LICENSE -x '*.md' -x '*.git' -r sd1/ldm/modules/diffusionmodules/util.py sd2/ldm/modules/diffusionmodules/util.py | |
125c125,127 | |
< | |
--- | |
> ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(), | |
> "dtype": torch.get_autocast_gpu_dtype(), | |
> "cache_enabled": torch.is_autocast_cache_enabled()} | |
133c135,136 | |
< with torch.enable_grad(): | |
--- | |
> with torch.enable_grad(), \ | |
> torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs): | |
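The util.py hunk fixes gradient checkpointing under mixed precision: the forward pass records the CUDA autocast state on ctx, and the recomputation in backward re-enters autocast with the same settings, so recomputed activations match the original dtypes. A reduced sketch of the pattern, assuming a custom autograd Function along the lines of the repo's CheckpointFunction (single-tensor input, simplified):

import torch

class CheckpointSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, run_fn, x):
        ctx.run_fn = run_fn
        # remember the autocast settings active during the original forward
        ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
                                   "dtype": torch.get_autocast_gpu_dtype(),
                                   "cache_enabled": torch.is_autocast_cache_enabled()}
        ctx.save_for_backward(x)
        with torch.no_grad():
            return run_fn(x)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        x = x.detach().requires_grad_(True)
        # recompute under the same autocast state as the original forward
        with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
            y = ctx.run_fn(x)
        (grad_in,) = torch.autograd.grad(y, (x,), grad_out)
        return None, grad_in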
diff -w -x '*.gif' -x '*.png' -x '*.jpg' -x LICENSE -x '*.md' -x '*.git' -r sd1/ldm/modules/ema.py sd2/ldm/modules/ema.py | |
24a25,28 | |
> def reset_num_updates(self): | |
> del self.num_updates | |
> self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int)) | |
> | |
diff -w -x '*.gif' -x '*.png' -x '*.jpg' -x LICENSE -x '*.md' -x '*.git' -r sd1/ldm/modules/encoders/modules.py sd2/ldm/modules/encoders/modules.py | |
3,7c3 | |
< from functools import partial | |
< import clip | |
< from einops import rearrange, repeat | |
< from transformers import CLIPTokenizer, CLIPTextModel | |
< import kornia | |
--- | |
> from torch.utils.checkpoint import checkpoint | |
9c5,8 | |
< from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test | |
--- | |
> from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel | |
> | |
> import open_clip | |
> from ldm.util import default, count_params | |
19a19,23 | |
> class IdentityEncoder(AbstractEncoder): | |
> | |
> def encode(self, x): | |
> return x | |
> | |
22c26 | |
< def __init__(self, embed_dim, n_classes=1000, key='class'): | |
--- | |
> def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1): | |
25a30,31 | |
> self.n_classes = n_classes | |
> self.ucg_rate = ucg_rate | |
27c33 | |
< def forward(self, batch, key=None): | |
--- | |
> def forward(self, batch, key=None, disable_dropout=False): | |
31a38,41 | |
> if self.ucg_rate > 0. and not disable_dropout: | |
> mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate) | |
> c = mask * c + (1-mask) * torch.ones_like(c)*(self.n_classes-1) | |
> c = c.long() | |
34a45,49 | |
> def get_unconditional_conditioning(self, bs, device="cuda"): | |
> uc_class = self.n_classes - 1 # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000) | |
> uc = torch.ones((bs,), device=device) * uc_class | |
> uc = {self.key: uc} | |
> return uc | |
36,42d50 | |
< class TransformerEmbedder(AbstractEncoder): | |
< """Some transformer encoder layers""" | |
< def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"): | |
< super().__init__() | |
< self.device = device | |
< self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, | |
< attn_layers=Encoder(dim=n_embed, depth=n_layer)) | |
44,50c52,55 | |
< def forward(self, tokens): | |
< tokens = tokens.to(self.device) # meh | |
< z = self.transformer(tokens, return_embeddings=True) | |
< return z | |
< | |
< def encode(self, x): | |
< return self(x) | |
--- | |
> def disabled_train(self, mode=True): | |
> """Overwrite model.train with this function to make sure train/eval mode | |
> does not change anymore.""" | |
> return self | |
53,55c58,60 | |
< class BERTTokenizer(AbstractEncoder): | |
< """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)""" | |
< def __init__(self, device="cuda", vq_interface=True, max_length=77): | |
--- | |
> class FrozenT5Embedder(AbstractEncoder): | |
> """Uses the T5 transformer encoder for text""" | |
> def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl | |
57,58c62,63 | |
< from transformers import BertTokenizerFast # TODO: add to reuquirements | |
< self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased") | |
--- | |
> self.tokenizer = T5Tokenizer.from_pretrained(version) | |
> self.transformer = T5EncoderModel.from_pretrained(version) | |
60,61c65,73 | |
< self.vq_interface = vq_interface | |
< self.max_length = max_length | |
--- | |
> self.max_length = max_length # TODO: typical value? | |
> if freeze: | |
> self.freeze() | |
> | |
> def freeze(self): | |
> self.transformer = self.transformer.eval() | |
> #self.train = disabled_train | |
> for param in self.parameters(): | |
> param.requires_grad = False | |
67,91c79 | |
< return tokens | |
< | |
< @torch.no_grad() | |
< def encode(self, text): | |
< tokens = self(text) | |
< if not self.vq_interface: | |
< return tokens | |
< return None, None, [None, None, tokens] | |
< | |
< def decode(self, text): | |
< return text | |
< | |
< | |
< class BERTEmbedder(AbstractEncoder): | |
< """Uses the BERT tokenizr model and add some transformer encoder layers""" | |
< def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77, | |
< device="cuda",use_tokenizer=True, embedding_dropout=0.0): | |
< super().__init__() | |
< self.use_tknz_fn = use_tokenizer | |
< if self.use_tknz_fn: | |
< self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len) | |
< self.device = device | |
< self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, | |
< attn_layers=Encoder(dim=n_embed, depth=n_layer), | |
< emb_dropout=embedding_dropout) | |
--- | |
> outputs = self.transformer(input_ids=tokens) | |
93,98c81 | |
< def forward(self, text): | |
< if self.use_tknz_fn: | |
< tokens = self.tknz_fn(text)#.to(self.device) | |
< else: | |
< tokens = text | |
< z = self.transformer(tokens, return_embeddings=True) | |
--- | |
> z = outputs.last_hidden_state | |
102d84 | |
< # output of length 77 | |
106,136d87 | |
< class SpatialRescaler(nn.Module): | |
< def __init__(self, | |
< n_stages=1, | |
< method='bilinear', | |
< multiplier=0.5, | |
< in_channels=3, | |
< out_channels=None, | |
< bias=False): | |
< super().__init__() | |
< self.n_stages = n_stages | |
< assert self.n_stages >= 0 | |
< assert method in ['nearest','linear','bilinear','trilinear','bicubic','area'] | |
< self.multiplier = multiplier | |
< self.interpolator = partial(torch.nn.functional.interpolate, mode=method) | |
< self.remap_output = out_channels is not None | |
< if self.remap_output: | |
< print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.') | |
< self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias) | |
< | |
< def forward(self,x): | |
< for stage in range(self.n_stages): | |
< x = self.interpolator(x, scale_factor=self.multiplier) | |
< | |
< | |
< if self.remap_output: | |
< x = self.channel_mapper(x) | |
< return x | |
< | |
< def encode(self, x): | |
< return self(x) | |
< | |
138,139c89,96 | |
< """Uses the CLIP transformer encoder for text (from Hugging Face)""" | |
< def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77): | |
--- | |
> """Uses the CLIP transformer encoder for text (from huggingface)""" | |
> LAYERS = [ | |
> "last", | |
> "pooled", | |
> "hidden" | |
> ] | |
> def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, | |
> freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32 | |
140a98 | |
> assert layer in self.LAYERS | |
144a103 | |
> if freeze: | |
145a105,109 | |
> self.layer = layer | |
> self.layer_idx = layer_idx | |
> if layer == "hidden": | |
> assert layer_idx is not None | |
> assert 0 <= abs(layer_idx) <= 12 | |
148a113 | |
> #self.train = disabled_train | |
156,157c121,122 | |
< outputs = self.transformer(input_ids=tokens) | |
< | |
--- | |
> outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden") | |
> if self.layer == "last": | |
158a124,127 | |
> elif self.layer == "pooled": | |
> z = outputs.pooler_output[:, None, :] | |
> else: | |
> z = outputs.hidden_states[self.layer_idx] | |
165c134 | |
< class FrozenCLIPTextEmbedder(nn.Module): | |
--- | |
> class FrozenOpenCLIPEmbedder(AbstractEncoder): | |
167c136 | |
< Uses the CLIP transformer encoder for text. | |
--- | |
> Uses the OpenCLIP transformer encoder for text | |
169,171c138,150 | |
< def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True): | |
< super().__init__() | |
< self.model, _ = clip.load(version, jit=False, device="cpu") | |
--- | |
> LAYERS = [ | |
> #"pooled", | |
> "last", | |
> "penultimate" | |
> ] | |
> def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77, | |
> freeze=True, layer="last"): | |
> super().__init__() | |
> assert layer in self.LAYERS | |
> model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version) | |
> del model.visual | |
> self.model = model | |
> | |
174,175c153,161 | |
< self.n_repeat = n_repeat | |
< self.normalize = normalize | |
--- | |
> if freeze: | |
> self.freeze() | |
> self.layer = layer | |
> if self.layer == "last": | |
> self.layer_idx = 0 | |
> elif self.layer == "penultimate": | |
> self.layer_idx = 1 | |
> else: | |
> raise NotImplementedError() | |
183,186c169,170 | |
< tokens = clip.tokenize(text).to(self.device) | |
< z = self.model.encode_text(tokens) | |
< if self.normalize: | |
< z = z / torch.linalg.norm(z, dim=1, keepdim=True) | |
--- | |
> tokens = open_clip.tokenize(text) | |
> z = self.encode_with_transformer(tokens.to(self.device)) | |
188a173,191 | |
> def encode_with_transformer(self, text): | |
> x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model] | |
> x = x + self.model.positional_embedding | |
> x = x.permute(1, 0, 2) # NLD -> LND | |
> x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) | |
> x = x.permute(1, 0, 2) # LND -> NLD | |
> x = self.model.ln_final(x) | |
> return x | |
> | |
> def text_transformer_forward(self, x: torch.Tensor, attn_mask = None): | |
> for i, r in enumerate(self.model.transformer.resblocks): | |
> if i == len(self.model.transformer.resblocks) - self.layer_idx: | |
> break | |
> if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(): | |
> x = checkpoint(r, x, attn_mask) | |
> else: | |
> x = r(x, attn_mask=attn_mask) | |
> return x | |
> | |
190,194c193 | |
< z = self(text) | |
< if z.ndim==2: | |
< z = z[:, None, :] | |
< z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat) | |
< return z | |
--- | |
> return self(text) | |
197,224c196,206 | |
< class FrozenClipImageEmbedder(nn.Module): | |
< """ | |
< Uses the CLIP image encoder. | |
< """ | |
< def __init__( | |
< self, | |
< model, | |
< jit=False, | |
< device='cuda' if torch.cuda.is_available() else 'cpu', | |
< antialias=False, | |
< ): | |
< super().__init__() | |
< self.model, _ = clip.load(name=model, device=device, jit=jit) | |
< | |
< self.antialias = antialias | |
< | |
< self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) | |
< self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) | |
< | |
< def preprocess(self, x): | |
< # normalize to [0,1] | |
< x = kornia.geometry.resize(x, (224, 224), | |
< interpolation='bicubic',align_corners=True, | |
< antialias=self.antialias) | |
< x = (x + 1.) / 2. | |
< # renormalize according to clip | |
< x = kornia.enhance.normalize(x, self.mean, self.std) | |
< return x | |
--- | |
> class FrozenCLIPT5Encoder(AbstractEncoder): | |
> def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda", | |
> clip_max_length=77, t5_max_length=77): | |
> super().__init__() | |
> self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length) | |
> self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length) | |
> print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, " | |
> f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params.") | |
> | |
> def encode(self, text): | |
> return self(text) | |
226,228c208,211 | |
< def forward(self, x): | |
< # x is assumed to be in range [-1,1] | |
< return self.model.encode_image(self.preprocess(x)) | |
--- | |
> def forward(self, text): | |
> clip_z = self.clip_encoder.encode(text) | |
> t5_z = self.t5_encoder.encode(text) | |
> return [clip_z, t5_z] | |
231,234d213 | |
< if __name__ == "__main__": | |
< from ldm.util import count_params | |
< model = FrozenCLIPEmbedder() | |
< count_params(model, verbose=True) | |
\ No newline at end of file | |
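The encoders/modules.py diff is the core of the sd1 to sd2 text-encoder change: the Hugging Face CLIP ViT-L/14 embedder gives way to FrozenOpenCLIPEmbedder, which loads OpenCLIP ViT-H-14 (laion2b_s32b_b79k), deletes the visual tower, and walks the text transformer's resblocks itself so it can stop one block early when layer="penultimate". A condensed sketch of that early-exit walk (weights download, dtype handling, and error handling omitted; the flow follows the diff above):

import torch
import open_clip

def penultimate_text_features(texts, arch="ViT-H-14", pretrained="laion2b_s32b_b79k", device="cpu"):
    """Per-token features from the second-to-last transformer block of an OpenCLIP text tower."""
    model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device("cpu"), pretrained=pretrained)
    del model.visual                                          # text-only use, as in FrozenOpenCLIPEmbedder
    model = model.eval().to(device)

    tokens = open_clip.tokenize(texts).to(device)             # (B, 77)
    x = model.token_embedding(tokens) + model.positional_embedding
    x = x.permute(1, 0, 2)                                     # NLD -> LND
    stop = len(model.transformer.resblocks) - 1                # skip the final block ("penultimate")
    for i, block in enumerate(model.transformer.resblocks):
        if i == stop:
            break
        x = block(x, attn_mask=model.attn_mask)
    x = x.permute(1, 0, 2)                                     # LND -> NLD
    return model.ln_final(x)                                   # (B, 77, d_model)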
diff -w -x '*.gif' -x '*.png' -x '*.jpg' -x LICENSE -x '*.md' -x '*.git' -r sd1/ldm/modules/image_degradation/bsrgan_light.py sd2/ldm/modules/image_degradation/bsrgan_light.py | |
28d27 | |
< | |
257c256 | |
< x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror' | |
--- | |
> x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror' | |
280c279 | |
< x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') | |
--- | |
> x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap') | |
293c292 | |
< x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') | |
--- | |
> x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap') | |
338c337 | |
< img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror') | |
--- | |
> img = ndimage.convolve(img, np.expand_dims(k, axis=2), mode='mirror') | |
500c499 | |
< img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror') | |
--- | |
> img = ndimage.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror') | |
534c533 | |
< def degradation_bsrgan_variant(image, sf=4, isp_model=None): | |
--- | |
> def degradation_bsrgan_variant(image, sf=4, isp_model=None, up=False): | |
592c591 | |
< image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror') | |
--- | |
> image = ndimage.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror') | |
619a619,620 | |
> if up: | |
> image = cv2.resize(image, (w1, h1), interpolation=cv2.INTER_CUBIC) # todo: random, as above? want to condition on it then | |
Only in sd1/ldm/modules: losses | |
Only in sd2/ldm/modules: midas | |
Only in sd1/ldm/modules: x_transformer.py | |
diff -w -x '*.gif' -x '*.png' -x '*.jpg' -x LICENSE -x '*.md' -x '*.git' -r sd1/ldm/util.py sd2/ldm/util.py | |
3a4 | |
> from torch import optim | |
5,11d5 | |
< from collections import abc | |
< from einops import rearrange | |
< from functools import partial | |
< | |
< import multiprocessing as mp | |
< from threading import Thread | |
< from queue import Queue | |
96,169c90,195 | |
< def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False): | |
< # create dummy dataset instance | |
< | |
< # run prefetching | |
< if idx_to_fn: | |
< res = func(data, worker_id=idx) | |
< else: | |
< res = func(data) | |
< Q.put([idx, res]) | |
< Q.put("Done") | |
< | |
< | |
< def parallel_data_prefetch( | |
< func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False | |
< ): | |
< # if target_data_type not in ["ndarray", "list"]: | |
< # raise ValueError( | |
< # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray." | |
< # ) | |
< if isinstance(data, np.ndarray) and target_data_type == "list": | |
< raise ValueError("list expected but function got ndarray.") | |
< elif isinstance(data, abc.Iterable): | |
< if isinstance(data, dict): | |
< print( | |
< f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.' | |
< ) | |
< data = list(data.values()) | |
< if target_data_type == "ndarray": | |
< data = np.asarray(data) | |
< else: | |
< data = list(data) | |
< else: | |
< raise TypeError( | |
< f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}." | |
< ) | |
< | |
< if cpu_intensive: | |
< Q = mp.Queue(1000) | |
< proc = mp.Process | |
< else: | |
< Q = Queue(1000) | |
< proc = Thread | |
< # spawn processes | |
< if target_data_type == "ndarray": | |
< arguments = [ | |
< [func, Q, part, i, use_worker_id] | |
< for i, part in enumerate(np.array_split(data, n_proc)) | |
< ] | |
< else: | |
< step = ( | |
< int(len(data) / n_proc + 1) | |
< if len(data) % n_proc != 0 | |
< else int(len(data) / n_proc) | |
< ) | |
< arguments = [ | |
< [func, Q, part, i, use_worker_id] | |
< for i, part in enumerate( | |
< [data[i: i + step] for i in range(0, len(data), step)] | |
< ) | |
< ] | |
< processes = [] | |
< for i in range(n_proc): | |
< p = proc(target=_do_parallel_data_prefetch, args=arguments[i]) | |
< processes += [p] | |
< | |
< # start processes | |
< print(f"Start prefetching...") | |
< import time | |
< | |
< start = time.time() | |
< gather_res = [[] for _ in range(n_proc)] | |
< try: | |
< for p in processes: | |
< p.start() | |
--- | |
> class AdamWwithEMAandWings(optim.Optimizer): | |
> # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298 | |
> def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using | |
> weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code | |
> ema_power=1., param_names=()): | |
> """AdamW that saves EMA versions of the parameters.""" | |
> if not 0.0 <= lr: | |
> raise ValueError("Invalid learning rate: {}".format(lr)) | |
> if not 0.0 <= eps: | |
> raise ValueError("Invalid epsilon value: {}".format(eps)) | |
> if not 0.0 <= betas[0] < 1.0: | |
> raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) | |
> if not 0.0 <= betas[1] < 1.0: | |
> raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) | |
> if not 0.0 <= weight_decay: | |
> raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) | |
> if not 0.0 <= ema_decay <= 1.0: | |
> raise ValueError("Invalid ema_decay value: {}".format(ema_decay)) | |
> defaults = dict(lr=lr, betas=betas, eps=eps, | |
> weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay, | |
> ema_power=ema_power, param_names=param_names) | |
> super().__init__(params, defaults) | |
> | |
> def __setstate__(self, state): | |
> super().__setstate__(state) | |
> for group in self.param_groups: | |
> group.setdefault('amsgrad', False) | |
> | |
> @torch.no_grad() | |
> def step(self, closure=None): | |
> """Performs a single optimization step. | |
> Args: | |
> closure (callable, optional): A closure that reevaluates the model | |
> and returns the loss. | |
> """ | |
> loss = None | |
> if closure is not None: | |
> with torch.enable_grad(): | |
> loss = closure() | |
> | |
> for group in self.param_groups: | |
> params_with_grad = [] | |
> grads = [] | |
> exp_avgs = [] | |
> exp_avg_sqs = [] | |
> ema_params_with_grad = [] | |
> state_sums = [] | |
> max_exp_avg_sqs = [] | |
> state_steps = [] | |
> amsgrad = group['amsgrad'] | |
> beta1, beta2 = group['betas'] | |
> ema_decay = group['ema_decay'] | |
> ema_power = group['ema_power'] | |
> | |
> for p in group['params']: | |
> if p.grad is None: | |
> continue | |
> params_with_grad.append(p) | |
> if p.grad.is_sparse: | |
> raise RuntimeError('AdamW does not support sparse gradients') | |
> grads.append(p.grad) | |
> | |
> state = self.state[p] | |
> | |
> # State initialization | |
> if len(state) == 0: | |
> state['step'] = 0 | |
> # Exponential moving average of gradient values | |
> state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) | |
> # Exponential moving average of squared gradient values | |
> state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) | |
> if amsgrad: | |
> # Maintains max of all exp. moving avg. of sq. grad. values | |
> state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) | |
> # Exponential moving average of parameter values | |
> state['param_exp_avg'] = p.detach().float().clone() | |
> | |
> exp_avgs.append(state['exp_avg']) | |
> exp_avg_sqs.append(state['exp_avg_sq']) | |
> ema_params_with_grad.append(state['param_exp_avg']) | |
> | |
> if amsgrad: | |
> max_exp_avg_sqs.append(state['max_exp_avg_sq']) | |
> | |
> # update the steps for each param group update | |
> state['step'] += 1 | |
> # record the step after step update | |
> state_steps.append(state['step']) | |
> | |
> optim._functional.adamw(params_with_grad, | |
> grads, | |
> exp_avgs, | |
> exp_avg_sqs, | |
> max_exp_avg_sqs, | |
> state_steps, | |
> amsgrad=amsgrad, | |
> beta1=beta1, | |
> beta2=beta2, | |
> lr=group['lr'], | |
> weight_decay=group['weight_decay'], | |
> eps=group['eps'], | |
> maximize=False) | |
> | |
> cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power) | |
> for param, ema_param in zip(params_with_grad, ema_params_with_grad): | |
> ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay) | |
171,203c197 | |
< k = 0 | |
< while k < n_proc: | |
< # get result | |
< res = Q.get() | |
< if res == "Done": | |
< k += 1 | |
< else: | |
< gather_res[res[0]] = res[1] | |
< | |
< except Exception as e: | |
< print("Exception: ", e) | |
< for p in processes: | |
< p.terminate() | |
< | |
< raise e | |
< finally: | |
< for p in processes: | |
< p.join() | |
< print(f"Prefetching complete. [{time.time() - start} sec.]") | |
< | |
< if target_data_type == 'ndarray': | |
< if not isinstance(gather_res[0], np.ndarray): | |
< return np.concatenate([np.asarray(r) for r in gather_res], axis=0) | |
< | |
< # order outputs | |
< return np.concatenate(gather_res, axis=0) | |
< elif target_data_type == 'list': | |
< out = [] | |
< for r in gather_res: | |
< out.extend(r) | |
< return out | |
< else: | |
< return gather_res | |
--- | |
> return loss | |
\ No newline at end of file | |
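Note: sd2's ldm/util.py swaps sd1's parallel_data_prefetch helpers for AdamWwithEMAandWings, an AdamW variant that keeps an EMA shadow copy of every parameter. A rough usage sketch, assuming the sd2 repo is on PYTHONPATH; the tiny linear module is only a stand-in:

    import torch
    from ldm.util import AdamWwithEMAandWings

    model = torch.nn.Linear(4, 4)                       # stand-in module
    opt = AdamWwithEMAandWings(model.parameters(), lr=1e-4, ema_decay=0.9999)

    loss = model(torch.randn(2, 4)).pow(2).mean()
    loss.backward()
    opt.step()                                          # AdamW update, then EMA update of the weights
    opt.zero_grad()
    # the EMA copy of each parameter p is kept in opt.state[p]['param_exp_avg']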
Only in sd1: main.py | |
Only in sd1: models | |
Only in sd1: notebook_helpers.py | |
Only in sd2: requirements.txt | |
Only in sd1/scripts: download_first_stages.sh | |
Only in sd1/scripts: download_models.sh | |
Only in sd2/scripts: gradio | |
diff -w -x '*.gif' -x '*.png' -x '*.jpg' -x LICENSE -x '*.md' -x '*.git' -r sd1/scripts/img2img.py sd2/scripts/img2img.py | |
3c3 | |
< import argparse, os, sys, glob | |
--- | |
> import argparse, os | |
15d14 | |
< import time | |
16a16 | |
> from imwatermark import WatermarkEncoder | |
17a18,19 | |
> | |
> from scripts.txt2img import put_watermark | |
20d21 | |
< from ldm.models.diffusion.plms import PLMSSampler | |
52c53 | |
< w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 | |
--- | |
> w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64 | |
87,98d87 | |
< "--skip_grid", | |
< action='store_true', | |
< help="do not save a grid, only individual samples. Helpful when evaluating lots of samples", | |
< ) | |
< | |
< parser.add_argument( | |
< "--skip_save", | |
< action='store_true', | |
< help="do not save indiviual samples. For speed measurements.", | |
< ) | |
< | |
< parser.add_argument( | |
106,110d94 | |
< "--plms", | |
< action='store_true', | |
< help="use plms sampling", | |
< ) | |
< parser.add_argument( | |
127a112 | |
> | |
139a125 | |
> | |
145a132 | |
> | |
151a139 | |
> | |
155c143 | |
< default=5.0, | |
--- | |
> default=9.0, | |
162c150 | |
< default=0.75, | |
--- | |
> default=0.8, | |
164a153 | |
> | |
173c162 | |
< default="configs/stable-diffusion/v1-inference.yaml", | |
--- | |
> default="configs/stable-diffusion/v2-inference.yaml", | |
179d167 | |
< default="models/ldm/stable-diffusion-v1/model.ckpt", | |
205,208d192 | |
< if opt.plms: | |
< raise NotImplementedError("PLMS sampler not (yet) supported") | |
< sampler = PLMSSampler(model) | |
< else: | |
213a198,202 | |
> print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...") | |
> wm = "SDV2" | |
> wm_encoder = WatermarkEncoder() | |
> wm_encoder.set_watermark('bytes', wm.encode('utf-8')) | |
> | |
247d235 | |
< tic = time.time() | |
267d254 | |
< if not opt.skip_save: | |
270,271c257,259 | |
< Image.fromarray(x_sample.astype(np.uint8)).save( | |
< os.path.join(sample_path, f"{base_count:05}.png")) | |
--- | |
> img = Image.fromarray(x_sample.astype(np.uint8)) | |
> img = put_watermark(img, wm_encoder) | |
> img.save(os.path.join(sample_path, f"{base_count:05}.png")) | |
275d262 | |
< if not opt.skip_grid: | |
283c270,272 | |
< Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png')) | |
--- | |
> grid = Image.fromarray(grid.astype(np.uint8)) | |
> grid = put_watermark(grid, wm_encoder) | |
> grid.save(os.path.join(outpath, f'grid-{grid_count:04}.png')) | |
286,289c275 | |
< toc = time.time() | |
< | |
< print(f"Your samples are ready and waiting for you here: \n{outpath} \n" | |
< f" \nEnjoy.") | |
--- | |
> print(f"Your samples are ready and waiting for you here: \n{outpath} \nEnjoy.") | |
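Note: besides the v2 config defaults (scale 9.0, strength 0.8, sizes rounded to multiples of 64), sd2's img2img now routes every saved image through scripts.txt2img.put_watermark, which embeds an invisible "SDV2" mark. A minimal sketch of that call pattern (the input path is a placeholder):

    from PIL import Image
    from imwatermark import WatermarkEncoder
    from scripts.txt2img import put_watermark   # defined in sd2's txt2img.py (see the diff below)

    wm_encoder = WatermarkEncoder()
    wm_encoder.set_watermark('bytes', "SDV2".encode('utf-8'))

    img = Image.open("input.png").convert("RGB")        # placeholder image
    img = put_watermark(img, wm_encoder)                # encodes via the 'dwtDct' method
    img.save("input_watermarked.png")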
Only in sd1/scripts: inpaint.py | |
Only in sd1/scripts: knn2img.py | |
Only in sd1/scripts: latent_imagenet_diffusion.ipynb | |
Only in sd1/scripts: sample_diffusion.py | |
Only in sd2/scripts: streamlit | |
Only in sd1/scripts: train_searcher.py | |
diff -w -x '*.gif' -x '*.png' -x '*.jpg' -x LICENSE -x '*.md' -x '*.git' -r sd1/scripts/txt2img.py sd2/scripts/txt2img.py | |
1c1 | |
< import argparse, os, sys, glob | |
--- | |
> import argparse, os | |
8d7 | |
< from imwatermark import WatermarkEncoder | |
12d10 | |
< import time | |
15c13,14 | |
< from contextlib import contextmanager, nullcontext | |
--- | |
> from contextlib import nullcontext | |
> from imwatermark import WatermarkEncoder | |
22,30c21 | |
< from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker | |
< from transformers import AutoFeatureExtractor | |
< | |
< | |
< # load safety model | |
< safety_model_id = "CompVis/stable-diffusion-safety-checker" | |
< safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id) | |
< safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id) | |
< | |
--- | |
> torch.set_grad_enabled(False) | |
37,48d27 | |
< def numpy_to_pil(images): | |
< """ | |
< Convert a numpy image or a batch of images to a PIL image. | |
< """ | |
< if images.ndim == 3: | |
< images = images[None, ...] | |
< images = (images * 255).round().astype("uint8") | |
< pil_images = [Image.fromarray(image) for image in images] | |
< | |
< return pil_images | |
< | |
< | |
69,98c48 | |
< def put_watermark(img, wm_encoder=None): | |
< if wm_encoder is not None: | |
< img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR) | |
< img = wm_encoder.encode(img, 'dwtDct') | |
< img = Image.fromarray(img[:, :, ::-1]) | |
< return img | |
< | |
< | |
< def load_replacement(x): | |
< try: | |
< hwc = x.shape | |
< y = Image.open("assets/rick.jpeg").convert("RGB").resize((hwc[1], hwc[0])) | |
< y = (np.array(y)/255.0).astype(x.dtype) | |
< assert y.shape == x.shape | |
< return y | |
< except Exception: | |
< return x | |
< | |
< | |
< def check_safety(x_image): | |
< safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt") | |
< x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values) | |
< assert x_checked_image.shape[0] == len(has_nsfw_concept) | |
< for i in range(len(has_nsfw_concept)): | |
< if has_nsfw_concept[i]: | |
< x_checked_image[i] = load_replacement(x_checked_image[i]) | |
< return x_checked_image, has_nsfw_concept | |
< | |
< | |
< def main(): | |
--- | |
> def parse_args(): | |
100d49 | |
< | |
105c54 | |
< default="a painting of a virus monster playing guitar", | |
--- | |
> default="a professional photograph of an astronaut riding a triceratops", | |
116,126c65 | |
< "--skip_grid", | |
< action='store_true', | |
< help="do not save a grid, only individual samples. Helpful when evaluating lots of samples", | |
< ) | |
< parser.add_argument( | |
< "--skip_save", | |
< action='store_true', | |
< help="do not save individual samples. For speed measurements.", | |
< ) | |
< parser.add_argument( | |
< "--ddim_steps", | |
--- | |
> "--steps", | |
137,142c76 | |
< "--dpm_solver", | |
< action='store_true', | |
< help="use dpm_solver sampling", | |
< ) | |
< parser.add_argument( | |
< "--laion400m", | |
--- | |
> "--dpm", | |
144c78 | |
< help="uses the LAION400M model", | |
--- | |
> help="use DPM (2) sampler", | |
149c83 | |
< help="if enabled, uses the same starting code across samples ", | |
--- | |
> help="if enabled, uses the same starting code across all samples ", | |
160c94 | |
< default=2, | |
--- | |
> default=3, | |
185c119 | |
< help="downsampling factor", | |
--- | |
> help="downsampling factor, most often 8 or 16", | |
191c125 | |
< help="how many samples to produce for each given prompt. A.k.a. batch size", | |
--- | |
> help="how many samples to produce for each given prompt. A.k.a batch size", | |
202c136 | |
< default=7.5, | |
--- | |
> default=9.0, | |
208c142 | |
< help="if specified, load prompts from this file", | |
--- | |
> help="if specified, load prompts from this file, separated by newlines", | |
213c147 | |
< default="configs/stable-diffusion/v1-inference.yaml", | |
--- | |
> default="configs/stable-diffusion/v2-inference.yaml", | |
219d152 | |
< default="models/ldm/stable-diffusion-v1/model.ckpt", | |
234a168,173 | |
> parser.add_argument( | |
> "--repeat", | |
> type=int, | |
> default=1, | |
> help="repeat each prompt in file this often", | |
> ) | |
235a175 | |
> return opt | |
237,241d176 | |
< if opt.laion400m: | |
< print("Falling back to LAION 400M model...") | |
< opt.config = "configs/latent-diffusion/txt2img-1p4B-eval.yaml" | |
< opt.ckpt = "models/ldm/text2img-large/model.ckpt" | |
< opt.outdir = "outputs/txt2img-samples-laion400m" | |
242a178,186 | |
> def put_watermark(img, wm_encoder=None): | |
> if wm_encoder is not None: | |
> img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR) | |
> img = wm_encoder.encode(img, 'dwtDct') | |
> img = Image.fromarray(img[:, :, ::-1]) | |
> return img | |
> | |
> | |
> def main(opt): | |
251,253c195 | |
< if opt.dpm_solver: | |
< sampler = DPMSolverSampler(model) | |
< elif opt.plms: | |
--- | |
> if opt.plms: | |
254a197,198 | |
> elif opt.dpm: | |
> sampler = DPMSolverSampler(model) | |
262c206 | |
< wm = "StableDiffusionV1" | |
--- | |
> wm = "SDV2" | |
276a221 | |
> data = [p for p in data for i in range(opt.repeat)] | |
280a226 | |
> sample_count = 0 | |
289,292c235,237 | |
< with torch.no_grad(): | |
< with precision_scope("cuda"): | |
< with model.ema_scope(): | |
< tic = time.time() | |
--- | |
> with torch.no_grad(), \ | |
> precision_scope("cuda"), \ | |
> model.ema_scope(): | |
303c248 | |
< samples_ddim, _ = sampler.sample(S=opt.ddim_steps, | |
--- | |
> samples, _ = sampler.sample(S=opt.steps, | |
313,319c258,259 | |
< x_samples_ddim = model.decode_first_stage(samples_ddim) | |
< x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) | |
< x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy() | |
< | |
< x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim) | |
< | |
< x_checked_image_torch = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2) | |
--- | |
> x_samples = model.decode_first_stage(samples) | |
> x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0) | |
321,322c261 | |
< if not opt.skip_save: | |
< for x_sample in x_checked_image_torch: | |
--- | |
> for x_sample in x_samples: | |
327a267 | |
> sample_count += 1 | |
329,330c269 | |
< if not opt.skip_grid: | |
< all_samples.append(x_checked_image_torch) | |
--- | |
> all_samples.append(x_samples) | |
332d270 | |
< if not opt.skip_grid: | |
340,342c278,280 | |
< img = Image.fromarray(grid.astype(np.uint8)) | |
< img = put_watermark(img, wm_encoder) | |
< img.save(os.path.join(outpath, f'grid-{grid_count:04}.png')) | |
--- | |
> grid = Image.fromarray(grid.astype(np.uint8)) | |
> grid = put_watermark(grid, wm_encoder) | |
> grid.save(os.path.join(outpath, f'grid-{grid_count:04}.png')) | |
345,346d282 | |
< toc = time.time() | |
< | |
352c288,289 | |
< main() | |
--- | |
> opt = parse_args() | |
> main(opt) | |
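Note: the txt2img CLI changes between sd1 and sd2: --ddim_steps becomes --steps, --dpm_solver becomes --dpm, --skip_grid/--skip_save, --laion400m and the safety checker are removed, a --repeat flag is added, the default scale rises from 7.5 to 9.0, and there is no default --ckpt anymore. A hypothetical pair of roughly equivalent invocations (checkpoint path is a placeholder):

    # sd1: python scripts/txt2img.py --prompt "a photo" --ddim_steps 50 --scale 7.5 --plms
    # sd2: python scripts/txt2img.py --prompt "a photo" --steps 50 --scale 9.0 --dpm \
    #          --config configs/stable-diffusion/v2-inference.yaml --ckpt <path-to-sd2-checkpoint>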
diff -w -x '*.gif' -x '*.png' -x '*.jpg' -x LICENSE -x '*.md' -x '*.git' -r sd1/setup.py sd2/setup.py | |
4c4 | |
< name='latent-diffusion', | |
--- | |
> name='stable-diffusion', |