import torch
import os
import gc
# About:
#   Tests GPU fast load when initializing LatentDiffusion and using safetensors,
#   to recreate a potential unclean memory state in torch:
#   https://github.com/huggingface/safetensors/issues/115
# Requirements:
#   - ldm (grab the folder from one of:
#       CompVis - https://github.com/CompVis/latent-diffusion/tree/main/ldm (2022-11-30)
#       webui - copy from the repository folder of a webui install)
#   - omegaconf
#   - safetensors
#   - taming folder from:
#       CompVis - https://github.com/CompVis/taming-transformers/tree/master/taming (2022-11-30)
#   - any sufficiently large safetensors model (here SD-1.4)
#   - v1-inference.yaml from webui (included below in this gist)
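# A minimal environment setup might look like this (an assumption, not from the
# original gist; ldm and taming are vendored source folders, not pip packages):
#   pip install torch omegaconf safetensors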
# Enable the experimental fast GPU loading path in safetensors.
os.environ['SAFETENSORS_FAST_GPU'] = '1'

from safetensors.torch import save_file, load_file
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config
import ldm.models.diffusion.ddpm
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
from ldm.models.diffusion.ddpm import LatentDiffusion
from ldm.models.diffusion.plms import PLMSSampler
from ldm.models.diffusion.ddim import DDIMSampler, noise_like
class LatentInpaintDiffusion(LatentDiffusion):
    def __init__(
        self,
        concat_keys=("mask", "masked_image"),
        masked_image_key="masked_image",
        *args,
        **kwargs,
    ):
        print("init LID")
        super().__init__(*args, **kwargs)
        self.masked_image_key = masked_image_key
        assert self.masked_image_key in concat_keys
        self.concat_keys = concat_keys
@torch.no_grad()
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                  temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                  unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None, dynamic_threshold=None):
    # batch size and device taken from the latent tensor
    b, *_, device = *x.shape, x.device

    def get_model_output(x, t):
        if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
            e_t = self.model.apply_model(x, t, c)
        else:
            # classifier-free guidance: batch the unconditional and conditional passes
            x_in = torch.cat([x] * 2)
            t_in = torch.cat([t] * 2)
            if isinstance(c, dict):
                assert isinstance(unconditional_conditioning, dict)
                c_in = dict()
                for k in c:
                    if isinstance(c[k], list):
                        c_in[k] = [
                            torch.cat([unconditional_conditioning[k][i], c[k][i]])
                            for i in range(len(c[k]))
                        ]
                    else:
                        c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
            else:
                c_in = torch.cat([unconditional_conditioning, c])
            e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
            e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)

        if score_corrector is not None:
            assert self.model.parameterization == "eps"
            e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)

        return e_t

    alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
    alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
    sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
    sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas

    def get_x_prev_and_pred_x0(e_t, index):
        # select parameters corresponding to the currently considered timestep
        a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
        a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
        sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
        sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)

        # current prediction for x_0
        pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
        if quantize_denoised:
            pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
        if dynamic_threshold is not None:
            # norm_thresholding is not imported in this gist; in newer ldm versions it
            # lives in ldm.models.diffusion.sampling_util. Unused here, since this
            # test never passes dynamic_threshold.
            pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
        # direction pointing to x_t
        dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
        if noise_dropout > 0.:
            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
        return x_prev, pred_x0

    e_t = get_model_output(x, t)
    if len(old_eps) == 0:
        # Pseudo Improved Euler (2nd order)
        x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
        e_t_next = get_model_output(x_prev, t_next)
        e_t_prime = (e_t + e_t_next) / 2
    elif len(old_eps) == 1:
        # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
        e_t_prime = (3 * e_t - old_eps[-1]) / 2
    elif len(old_eps) == 2:
        # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
        e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
    elif len(old_eps) >= 3:
        # 4th order Pseudo Linear Multistep (Adams-Bashforth)
        e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24

    x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
    return x_prev, pred_x0, e_t
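# Reference note (added annotation): the multistep branches above use the classic
# Adams-Bashforth coefficients, e.g. 4th order
#   x_{n+1} = x_n + (h/24) * (55*f_n - 59*f_{n-1} + 37*f_{n-2} - 9*f_{n-3}),
# applied here to the noise predictions e_t rather than an ODE right-hand side.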
def do_inpainting_hijack():
    # Monkeypatch ldm so configs can target LatentInpaintDiffusion and so
    # PLMSSampler uses the patched step above.
    ldm.models.diffusion.ddpm.LatentInpaintDiffusion = LatentInpaintDiffusion
    ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
def torch_gc():
    # Release cached CUDA memory after the model has been dropped.
    if torch.cuda.is_available():
        with torch.cuda.device('cuda'):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
if __name__ == "__main__":
    sd_config = OmegaConf.load('D:\\SD\\v1-inference.yaml')
    do_inpainting_hijack()
    print("inst model from config", sd_config.model)
    sd_model = instantiate_from_config(sd_config.model)

    # Removing the lines above makes load_file fast again. The slowdown persists
    # even after disposing of sd_model and clearing memory, as done below
    # (see the timing sketch at the end of this file).
    sd_model = None
    gc.collect()
    torch_gc()

    sf_filename = "D:\\SD\\convert\\sd-v1-4.safetensors"  # any model
    sf_loaded = load_file(sf_filename, device='cuda')
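# Timing sketch (added annotation; an assumption of how one might quantify the
# slow vs. fast load, not part of the original repro). Call measure_load() once
# with the instantiate_from_config block above enabled and once with those lines
# removed, then compare the printed times.
import time

def measure_load(path):
    torch.cuda.synchronize()
    start = time.perf_counter()
    tensors = load_file(path, device='cuda')
    torch.cuda.synchronize()  # make sure the device transfer has actually finished
    print(f"load_file took {time.perf_counter() - start:.2f}s")
    return tensors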
# v1-inference.yaml (second file in this gist)
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 10000 ]
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
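# Note (added annotation): ldm.util.instantiate_from_config resolves the `target`
# entries above to classes and passes `params` as keyword arguments, roughly:
#
#   import importlib
#   def instantiate_from_config(config):
#       module, cls = config["target"].rsplit(".", 1)
#       return getattr(importlib.import_module(module), cls)(**config.get("params", {}))
#
# (a sketch of the ldm helper, not its exact source)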