Skip to content

Instantly share code, notes, and snippets.

Last active December 1, 2022 03:43
Show Gist options
  • Save pattontim/864469ef1f7cb7ebab8ef810b2dc6b3d to your computer and use it in GitHub Desktop.
Save pattontim/864469ef1f7cb7ebab8ef810b2dc6b3d to your computer and use it in GitHub Desktop.
import torch
import os
import gc
# About:
# Tests the safetensors GPU fast-load path after a LatentDiffusion model has
# been initialized, to recreate a potentially unclean memory state in torch.
# Requirements:
# - ldm: grab the folder from one of:
#     CompVis - (2022-11-30)
#     webui - copy it from the repository folder of a webui install
# - omegaconf
# - safetensors
# - taming: grab the folder from:
#     CompVis - (2022-11-30)
# - any sufficiently large safetensors model (here SD-1.4)
# - v1-inference.yaml from webui (included in this gist)
os.environ['SAFETENSORS_FAST_GPU'] = '1'
from safetensors.torch import save_file, load_file
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config
import ldm.models.diffusion.ddpm
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
from ldm.models.diffusion.ddpm import LatentDiffusion
from ldm.models.diffusion.plms import PLMSSampler
from ldm.models.diffusion.ddim import DDIMSampler, noise_like
class LatentInpaintDiffusion(LatentDiffusion):
    """LatentDiffusion subclass that remembers which batch keys hold inpainting inputs.

    The class adds no behavior of its own beyond storing ``concat_keys`` and
    ``masked_image_key``; everything else is inherited from ``LatentDiffusion``.
    """

    def __init__(
        self,
        concat_keys=("mask", "masked_image"),
        # NOTE(review): the default below is not visible in the scraped source;
        # "masked_image" is assumed because the assert further down requires
        # the key to be a member of ``concat_keys`` -- confirm against upstream.
        masked_image_key="masked_image",
        *args,
        **kwargs,
    ):
        """Store the inpainting key names and initialize the base model.

        Args:
            concat_keys: Names of the keys associated with inpainting inputs.
            masked_image_key: Name of the masked-image key; must be one of
                ``concat_keys``.
            *args: Forwarded unchanged to ``LatentDiffusion.__init__``.
            **kwargs: Forwarded unchanged to ``LatentDiffusion.__init__``.

        Raises:
            AssertionError: If ``masked_image_key`` is not in ``concat_keys``.
        """
        # Trace output so the repro script shows when this class is built.
        print("init LID")
        super().__init__(*args, **kwargs)
        self.masked_image_key = masked_image_key
        assert self.masked_image_key in concat_keys
        self.concat_keys = concat_keys
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                  temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                  unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None,
                  dynamic_threshold=None):
    """Perform a single PLMS sampling step.

    Defined at module level with an explicit ``self`` because
    ``do_inpainting_hijack`` assigns it onto ``PLMSSampler``.

    Args:
        self: Sampler instance. Must expose ``model`` plus the schedule tables
            ``ddim_alphas``, ``ddim_alphas_prev``,
            ``ddim_sqrt_one_minus_alphas`` and ``ddim_sigmas``.
        x: Current sample tensor; ``x.shape[0]`` is used as the batch size.
            Schedule values are expanded to ``(batch, 1, 1, 1)``, so ``x`` is
            presumably 4-D -- TODO confirm.
        c: Conditioning passed to ``self.model.apply_model``. Either a tensor
            or a dict whose values are tensors or lists of tensors.
        t: Timestep tensor passed to the model.
        index: Index into the schedule tables for the current step.
        repeat_noise: Forwarded to ``noise_like``.
        use_original_steps: If true, read the schedule tables from
            ``self.model`` instead of the sampler's ``ddim_*`` attributes.
        quantize_denoised: If true, pass the predicted ``x_0`` through
            ``self.model.first_stage_model.quantize``.
        temperature: Multiplier applied to the sampled noise.
        noise_dropout: Dropout probability applied to the noise when ``> 0``.
        score_corrector: Optional object whose ``modify_score`` adjusts the
            model output; requires ``self.model.parameterization == "eps"``.
        corrector_kwargs: Extra keyword arguments for ``modify_score``.
            ``None`` is treated as no extra arguments.
        unconditional_guidance_scale: Guidance weight ``s`` in
            ``e_uncond + s * (e_cond - e_uncond)``. A value of ``1.`` disables
            the guided path.
        unconditional_conditioning: Unconditional counterpart of ``c`` with
            the same structure; ``None`` disables the guided path.
        old_eps: Previous model outputs, most recent last. ``None`` is treated
            as an empty history.
        t_next: Timestep used for the second model evaluation when there is
            no history yet.
        dynamic_threshold: If not ``None``, the predicted ``x_0`` is passed to
            ``norm_thresholding`` with this value.

    Returns:
        Tuple ``(x_prev, pred_x0, e_t)``: the sample for the previous
        timestep, the current prediction of ``x_0``, and the raw model output
        for this step.
    """
    b, *_, device = *x.shape, x.device
    # Tolerate the declared ``None`` defaults instead of failing on
    # ``len(None)`` / ``**None``.
    eps_history = [] if old_eps is None else old_eps
    corrector_extra = {} if corrector_kwargs is None else corrector_kwargs

    def get_model_output(x, t):
        if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
            e_t = self.model.apply_model(x, t, c)
        else:
            # Evaluate the unconditional and conditional branches in one
            # batched forward pass: unconditional first, conditional second.
            x_in = torch.cat([x] * 2)
            t_in = torch.cat([t] * 2)
            if isinstance(c, dict):
                assert isinstance(unconditional_conditioning, dict)
                c_in = dict()
                for k in c:
                    if isinstance(c[k], list):
                        c_in[k] = [
                            torch.cat([unconditional_conditioning[k][i], cond])
                            for i, cond in enumerate(c[k])
                        ]
                    else:
                        c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
            else:
                c_in = torch.cat([unconditional_conditioning, c])
            e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
            e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)

        if score_corrector is not None:
            assert self.model.parameterization == "eps"
            e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_extra)

        return e_t

    alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
    alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
    sqrt_one_minus_alphas = (
        self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
    )
    sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas

    def get_x_prev_and_pred_x0(e_t, index):
        # select parameters corresponding to the currently considered timestep
        a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
        a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
        sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
        sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)

        # current prediction for x_0
        pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
        if quantize_denoised:
            pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
        if dynamic_threshold is not None:
            # NOTE(review): norm_thresholding is not imported in the visible
            # part of this file; confirm it is in scope before relying on
            # dynamic_threshold.
            pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
        # direction pointing to x_t
        dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
        if noise_dropout > 0.:
            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
        return x_prev, pred_x0

    e_t = get_model_output(x, t)
    if len(eps_history) == 0:
        # Pseudo Improved Euler (2nd order)
        x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
        e_t_next = get_model_output(x_prev, t_next)
        e_t_prime = (e_t + e_t_next) / 2
    elif len(eps_history) == 1:
        # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
        e_t_prime = (3 * e_t - eps_history[-1]) / 2
    elif len(eps_history) == 2:
        # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
        e_t_prime = (23 * e_t - 16 * eps_history[-1] + 5 * eps_history[-2]) / 12
    else:
        # 4th order Pseudo Linear Multistep (Adams-Bashforth); only the three
        # most recent history entries are used.
        e_t_prime = (55 * e_t - 59 * eps_history[-1] + 37 * eps_history[-2] - 9 * eps_history[-3]) / 24

    x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
    return x_prev, pred_x0, e_t
def do_inpainting_hijack():
    """Install this file's inpainting overrides into the ``ldm`` package.

    Sets ``LatentInpaintDiffusion`` as an attribute of the
    ``ldm.models.diffusion.ddpm`` module and replaces
    ``PLMSSampler.p_sample_plms`` with the module-level ``p_sample_plms``.
    """
    ddpm_module = ldm.models.diffusion.ddpm
    plms_sampler_cls = ldm.models.diffusion.plms.PLMSSampler

    ddpm_module.LatentInpaintDiffusion = LatentInpaintDiffusion
    plms_sampler_cls.p_sample_plms = p_sample_plms
def torch_gc():
    """Ask torch to release cached CUDA memory; do nothing without CUDA.

    Returns:
        None.
    """
    if not torch.cuda.is_available():
        return
    with torch.cuda.device('cuda'):
        # NOTE(review): the body of this ``with`` was lost in the scraped
        # source; these are the standard torch cache-release calls -- confirm
        # against the original gist.
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
if __name__ == "__main__":
    # Step 1: build the model described by the YAML config. This step is what
    # makes the later load_file call slow.
    config_path = 'D:\\SD\\v1-inference.yaml'
    model_config = OmegaConf.load(config_path).model
    print("inst model from config", model_config)
    sd_model = instantiate_from_config(model_config)
    # removing the above lines causes load_file to be fast
    # you can dispose of sd_model, clear the memory, etc...
    sd_model = None

    # Step 2: load a safetensors checkpoint straight onto the GPU.
    sf_filename = "D:\\SD\\convert\\sd-v1-4.safetensors"  # any model
    sf_loaded = load_file(sf_filename, device='cuda')
# v1-inference.yaml: model configuration loaded by the script via OmegaConf;
# the script reads the top-level `model` key.
# NOTE(review): the nesting keys (`model`, `params`, `unet_config`,
# `first_stage_config`, `ddconfig`, `ch_mult`, `lossconfig`,
# `cond_stage_config`) and all indentation were lost in extraction. They are
# restored here following the layout of the public Stable Diffusion v1
# inference config -- confirm against the gist's original file.
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 10000 ]
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment