Skip to content

Instantly share code, notes, and snippets.

Last active November 6, 2022 18:39
Show Gist options
  • Save Birch-san/e995e76b42bb8c27d16e992398f5cf4b to your computer and use it in GitHub Desktop.
Save Birch-san/e995e76b42bb8c27d16e992398f5cf4b to your computer and use it in GitHub Desktop.
# Dynamic thresholding of Stable Diffusion latents, using the dynamic range of a known-good CFG scale (7.5) as the reference.
from typing import Protocol, Optional

import torch
from torch import Tensor, FloatTensor
from torch import nn
from k_diffusion.external import CompVisDenoiser
from k_diffusion.sampling import sample_heun
class DiffusionModel(Protocol):
    """Structural type for a model that can be called as ``model(x, sigma, **kwargs)``.

    Anything whose ``__call__`` accepts two tensors plus arbitrary keyword
    arguments and returns a tensor satisfies this protocol.
    """

    def __call__(self, x: Tensor, sigma: Tensor, **kwargs) -> Tensor: ...
class DiffusionModelMixin(DiffusionModel):
    """Mixin declaring that a wrapper exposes the model it wraps as ``inner_model``."""

    # the wrapped model; assigned by the concrete wrapper class
    inner_model: DiffusionModel
# workaround until k-diffusion introduces official base model wrapper,
# to make the wrapper forward all method calls to the wrapped model
class CompVisDenoiserWrapper(CompVisDenoiser, DiffusionModelMixin):
    """``CompVisDenoiser`` that is additionally typed as a ``DiffusionModelMixin``.

    Construction is delegated unchanged to ``CompVisDenoiser``.
    """

    # the wrapped model (presumably set by CompVisDenoiser.__init__ — confirm
    # against the installed k-diffusion version)
    inner_model: DiffusionModel

    def __init__(self, model: DiffusionModel, quantize=False):
        """Wrap ``model``; ``quantize`` is passed straight to ``CompVisDenoiser``."""
        # Call the concrete base explicitly rather than via super(), so the
        # mixin in the MRO never receives these constructor arguments.
        CompVisDenoiser.__init__(self, model, quantize=quantize)
class BaseModelWrapper(nn.Module, DiffusionModelMixin):
    """Minimal ``nn.Module`` that stores another model as ``inner_model``.

    Subclasses implement ``forward`` and call ``self.inner_model`` from it.
    """

    # the wrapped model
    inner_model: DiffusionModel

    def __init__(self, inner_model: DiffusionModel):
        """Store ``inner_model`` on the module."""
        # nn.Module must be initialised before any attribute is assigned:
        # assigning a sub-module first raises
        # "cannot assign module before Module.__init__() call".
        super().__init__()
        self.inner_model = inner_model
def repeat_along_dim_0(t: Tensor, factor: int) -> Tensor:
    """Repeat a tensor's contents along its 0th dim ``factor`` times.

    Examples::

        repeat_along_dim_0(torch.tensor([[0,1]]), 2)
        tensor([[0, 1],
                [0, 1]])
        # shape changes from (1, 2)
        #                 to (2, 2)

        repeat_along_dim_0(torch.tensor([[0,1],[2,3]]), 2)
        tensor([[0, 1],
                [2, 3],
                [0, 1],
                [2, 3]])
        # shape changes from (2, 2)
        #                 to (4, 2)

    Args:
        t: tensor with at least one dimension.
        factor: number of repetitions; must be >= 1.

    Returns:
        ``t`` itself when ``factor == 1``; a non-copying expanded view when
        ``t`` has size 1 along dim 0; otherwise a repeated copy.

    Raises:
        AssertionError: if ``factor`` is less than 1.
    """
    assert factor >= 1, "factor must be >= 1"
    if factor == 1:
        return t
    if t.size(dim=0) == 1:
        # prefer expand() whenever we can, since it doesn't copy
        return t.expand(factor, *t.shape[1:])
    # repeat only along dim 0; every other dim keeps a multiplier of 1
    return t.repeat((factor, *(1,) * (t.ndim - 1)))
class CFGDynTheshDenoiser(BaseModelWrapper):
    """Classifier-free-guidance wrapper with dynamic thresholding.

    Guidance is applied at the requested ``cond_scale``. The result is then
    re-centred, clamped and rescaled so that the peak deviation from the mean
    does not exceed what guidance at ``dynamic_thresholding_mimic_scale``
    would have produced. Statistics are taken over everything after the first
    two dims (assumes (batch, channel, ...) inputs — confirm against caller).
    """

    # quantile in [0, 1] of the absolute centred values used as clamp threshold
    dynamic_thresholding_percentile: float
    # guidance scale whose dynamic range serves as the reference
    dynamic_thresholding_mimic_scale: float

    def __init__(
        self,
        model: DiffusionModel,
        dynamic_thresholding_percentile: float,
        dynamic_thresholding_mimic_scale: float,
    ):
        """Wrap ``model`` and record the thresholding parameters."""
        super().__init__(model)
        self.dynamic_thresholding_percentile = dynamic_thresholding_percentile
        self.dynamic_thresholding_mimic_scale = dynamic_thresholding_mimic_scale

    def forward(
        self,
        x: FloatTensor,
        sigma: FloatTensor,
        cond: FloatTensor,
        cond_scale: float = 1.0,
        uncond: Optional[FloatTensor] = None,
    ) -> FloatTensor:
        """Return the guided, dynamically-thresholded model output.

        Args:
            x: input passed to the wrapped model.
            sigma: noise level(s) passed to the wrapped model.
            cond: conditioning, forwarded as the ``cond`` keyword argument.
            cond_scale: guidance scale; at exactly 1.0 guidance is skipped.
            uncond: unconditional conditioning; when ``None`` guidance is
                skipped.

        Returns:
            The wrapped model's output for ``cond`` when guidance is skipped,
            otherwise the thresholded guided output with the same shape as a
            single (un-batched-up) model output.
        """
        if uncond is None or cond_scale == 1.0:
            return self.inner_model(x, sigma, cond=cond)

        # Evaluate uncond and cond in one batched call. The batch is always
        # doubled (uncond half, then cond half); using cond_in.size(0) as the
        # factor only happened to be right when the batch size was 1.
        cond_in = torch.cat([uncond, cond])
        del uncond, cond
        x_in = repeat_along_dim_0(x, 2)
        del x
        sigma_in = repeat_along_dim_0(sigma, 2)
        uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
        del x_in, sigma_in, cond_in
        diff: Tensor = cond - uncond

        # Reference: guidance at the mimic scale; take its peak deviation from
        # the mean over the flattened trailing dims.
        dynthresh_target: Tensor = uncond + diff * self.dynamic_thresholding_mimic_scale
        dt_flattened: Tensor = dynthresh_target.flatten(2)
        dt_means: Tensor = dt_flattened.mean(dim=2, keepdim=True)
        dt_recentered: Tensor = dt_flattened - dt_means
        dt_max: Tensor = dt_recentered.abs().max(dim=2, keepdim=True).values

        # Actual: guidance at the requested scale, centred the same way.
        ut: Tensor = uncond + diff * cond_scale
        ut_flattened: Tensor = ut.flatten(2)
        ut_means: Tensor = ut_flattened.mean(dim=2, keepdim=True)
        ut_centered: Tensor = ut_flattened - ut_means

        ut_q: Tensor = torch.quantile(
            ut_centered.abs(), self.dynamic_thresholding_percentile, dim=2, keepdim=True
        )
        # Never clamp tighter than the reference range; the lower bound keeps
        # the division below finite when both statistics are exactly zero.
        s: Tensor = torch.maximum(ut_q, dt_max).clamp_min(torch.finfo(ut_centered.dtype).tiny)
        t_clamped: Tensor = ut_centered.clamp(-s, s)
        t_normalized: Tensor = t_clamped / s
        # rescale from [-1, 1] to the reference range, then restore the mean
        t_renormalized: Tensor = t_normalized * dt_max

        uncentered: Tensor = t_renormalized + ut_means
        unflattened: Tensor = uncentered.unflatten(2, dynthresh_target.shape[2:])
        return unflattened
# Illustrative usage only: `model` and `x` are placeholders to be supplied by
# the caller, so this snippet will not run as-is.
model = ...  # put your LatentDiffusionModel here
model_k_wrapped = CompVisDenoiserWrapper(model, quantize=True)
model_k_guidance = CFGDynTheshDenoiser(
    model_k_wrapped,
    # clamp away latent values that exceed the 99.9%ile
    dynamic_thresholding_percentile=0.999,
    # use CFG7.5 as our reference for a "known-good" dynamic range
    dynamic_thresholding_mimic_scale=7.5,
)
# now sample from model_k_guidance the same way you'd sample from any k-diffusion wrapped model
# (`...` stands for the remaining sampler arguments, e.g. the sigma schedule and
# the extra args carrying cond / uncond / cond_scale — see k-diffusion's docs)
sample_heun(model_k_guidance, x, ...)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment