from torch import FloatTensor, load, baddbmm, zeros
from dataclasses import dataclass
import torch
from os.path import join

@dataclass
class Fixtures:
  q_proj: FloatTensor
  k_proj: FloatTensor

device = torch.device('cuda')
def get_fixtures(
  sigma: float,
  head_dim: int,
  key_length_factor: float,
  key_tokens: int,
) -> Fixtures:
  """
  Loads pre-saved tensors from the filesystem, building a file path that identifies which stage of which diffusion run we want to load.
  Args:
    sigma:
      I believe discretized sigmas for a 22-step Karras schedule *should* be:
        14.6146, 11.9484, 9.7548, 7.9216, 6.3493, 5.0878, 4.0300, 3.1667, 2.4743, 1.9103, 1.4601, 1.1084, 0.8299, 0.6127, 0.4471, 0.3213, 0.2281, 0.1580, 0.1072, 0.0720, 0.0507, 0.0292
      Yet, k-diffusion invoked my Unet with the following sigmas:
        14.6147, 11.9776, 9.7593, 7.9029, 6.3579, 5.0793, 4.0277, 3.1686, 2.4716, 1.9104, 1.4621, 1.1072, 0.8289, 0.6128, 0.4469, 0.3211, 0.2270, 0.1576, 0.1072, 0.0713, 0.0463, 0.0292
    head_dim:
      https://twitter.com/Birchlabs/status/1609605588601675776
      varies depending on which self-attn layer you're on (at least in SD1.5, lol). I have files for:
        40, 80, 160, 160
    key_length_factor:
      1.0 = in-distribution image, 512x512, key length as large as 4096
      2.25 = out-of-distribution image, 768x768, key length as large as 9216
    key_tokens:
      varies depending on which self-attn layer you're on, and how large your latents are:
        4096, 1024, 256, 64
        9216, 2304, 576, 144
  """
  root_dir='/home/birch/git/diffusers-play'
  in_dir=join(root_dir, 'out_tensor')
  tensor_path_prefix=join(in_dir, f'f{key_length_factor}_s{sigma:.4f}_k{key_tokens}_c{head_dim}')
  q_proj: FloatTensor = load(f'{tensor_path_prefix}_q_proj.pt', weights_only=True, map_location=device)
  k_proj: FloatTensor = load(f'{tensor_path_prefix}_k_proj.pt', weights_only=True, map_location=device)
  return Fixtures(
    q_proj=q_proj,
    k_proj=k_proj,
  )
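# Example (illustrative; not in the original gist): for the fixtures loaded below
# (sigma=14.6147, head_dim=40), the path prefix resolves to:
#   in-distribution:     out_tensor/f1.0_s14.6147_k4096_c40   (then _q_proj.pt / _k_proj.pt)
#   out-of-distribution: out_tensor/f2.25_s14.6147_k9216_c40  (then _q_proj.pt / _k_proj.pt)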
def get_attn_scores(
  q_proj: FloatTensor,
  k_proj: FloatTensor,
  scale: float,
) -> FloatTensor:
  """Computes (q_proj @ k_proj.T)*scale"""
  # no bias, but baddbmm's API requires a tensor even if coefficient is 0
  attn_bias: FloatTensor = zeros(1, 1, 1, dtype=q_proj.dtype, device=q_proj.device)
  attn_scores: FloatTensor = baddbmm(
    attn_bias,
    q_proj,
    k_proj.transpose(-1, -2),
    # beta=0 means: don't apply the bias
    beta=0,
    alpha=scale,
  )
  return attn_scores
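# Quick equivalence check on toy tensors (illustrative; not in the original gist):
# with beta=0 the bias is ignored, so get_attn_scores is just a scaled batched matmul.
_q = torch.randn(2, 8, 4, device=device)
_k = torch.randn(2, 16, 4, device=device)
assert torch.allclose(
  get_attn_scores(q_proj=_q, k_proj=_k, scale=4 ** -.5),
  (4 ** -.5) * _q @ _k.transpose(-1, -2),
  atol=1e-5,
)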
def softmax(x: FloatTensor, dim=-1) -> FloatTensor:
  """Typical softmax. same as PyTorch's built-in torch.Tensor.softmax(), but step-by-step in case you want to modify it."""
  maxes = x.max(dim, keepdim=True).values
  diffs = x-maxes
  x_exp = diffs.exp()
  x_exp_sum = x_exp.sum(dim, keepdim=True)
  quotient = x_exp/x_exp_sum
  return quotient
def topk_softmax(x: FloatTensor, k:int, dim=-1) -> FloatTensor:
  """Softmax with a modified denominator, which sums fewer elements (the topk only) to help you size it to a magnitude on which the model was trained."""
  maxes = x.max(dim, keepdim=True).values
  diffs = x-maxes
  x_exp = diffs.exp()
  x_exp_sum = x_exp.topk(k=k, dim=dim).values.sum(dim, keepdim=True)
  quotient = x_exp/x_exp_sum
  return quotient
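# Sanity checks on toy data (illustrative; not in the original gist): the step-by-step
# softmax matches torch's built-in, and topk_softmax with k equal to the full key length
# reduces to the ordinary softmax (every element enters the denominator).
_scores = torch.randn(2, 3, 5, device=device)
assert torch.allclose(softmax(_scores), _scores.softmax(dim=-1), atol=1e-6)
assert torch.allclose(topk_softmax(_scores, k=_scores.size(-1)), softmax(_scores), atol=1e-6)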
sigma = 14.6147
head_dim = 40
scale = head_dim ** -.5
in_dist_key_tokens = 4096
in_dist: Fixtures = get_fixtures(
  sigma=sigma,
  head_dim=head_dim,
  key_length_factor = in_dist_key_tokens/in_dist_key_tokens, # 1.0
  key_tokens = in_dist_key_tokens,
)
out_dist_key_tokens = 9216
out_dist: Fixtures = get_fixtures(
  sigma=sigma,
  head_dim=head_dim,
  key_length_factor = out_dist_key_tokens/in_dist_key_tokens, # 2.25
  key_tokens = out_dist_key_tokens,
)
in_dist_attn_scores: FloatTensor = get_attn_scores(
  q_proj=in_dist.q_proj,
  k_proj=in_dist.k_proj,
  scale=scale,
)
out_dist_attn_scores: FloatTensor = get_attn_scores(
  q_proj=out_dist.q_proj,
  k_proj=out_dist.k_proj,
  scale=scale,
)
in_dist_attn_probs: FloatTensor = in_dist_attn_scores.softmax(dim=-1)
out_dist_attn_probs: FloatTensor = out_dist_attn_scores.softmax(dim=-1)
key_tokens: int = out_dist_attn_scores.size(-1)
preferred_token_count: int = in_dist_key_tokens # 4096
out_dist_topk_attn_probs: FloatTensor = topk_softmax(out_dist_attn_scores, k=preferred_token_count, dim=-1)
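# Illustrative check (not in the original gist): ordinary softmax rows sum to 1, whereas
# topk_softmax rows sum to >= 1, because its denominator omits the
# (key_tokens - preferred_token_count) smallest exponentials.
print('softmax row sum:     ', out_dist_attn_probs[0][0].sum().item())       # ~1.0
print('topk softmax row sum:', out_dist_topk_attn_probs[0][0].sum().item())  # >= 1.0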
#### you probably want to put the following into a separate Jupyter cell so you can iterate on the charting
import matplotlib.pyplot as plt
from torch import histogram
from typing import NamedTuple
class Histogram(NamedTuple):
  boundaries: FloatTensor
  counts: FloatTensor
density=True
bins=200
hist, bin_edges = histogram(in_dist_attn_probs[0][0].log().float().cpu(), bins=bins, density=density)
plt.fill_between(bin_edges[:-1], 0, hist, alpha=0.6, label='in-dist')
hist, bin_edges = histogram(out_dist_attn_probs[0][0].log().float().cpu(), bins=bins, density=density)
plt.fill_between(bin_edges[:-1], 0, hist, alpha=0.6, label='out-dist')
hist, bin_edges = histogram((out_dist_attn_probs[0][0]*(out_dist_key_tokens/in_dist_key_tokens)).log().float().cpu(), bins=bins, density=density)
plt.fill_between(bin_edges[:-1], 0, hist, alpha=0.6, label='out-dist (scaled)')
hist, bin_edges = histogram(out_dist_topk_attn_probs[0][0].log().float().cpu(), bins=bins, density=density)
plt.fill_between(bin_edges[:-1], 0, hist, alpha=0.6, label='out-dist (topk)')
plt.title(f'σ={sigma} log self-attn probs for uncond batch, first down-block, head 0, query token 0')
plt.legend()
plt.show()
@Birch-san (Author) commented:
Instead of taking the topk key tokens (which favours larger scores): I tried a nearest-neighbour downsample of the 9216 key tokens to 4096, basically like discarding every second token.

resample_softmax attempts to sort the distribution (which I think could give a fairer resample, assuming you interpolate between the datapoints) prior to resampling. However, the sort ran out of memory, so I wasn't able to evaluate this.

import gc
import torch
from torch import FloatTensor
from torch.nn.functional import interpolate

def resample_softmax(x: FloatTensor, k:int, dim=-1) -> FloatTensor:
  """Softmax with a modified denominator. for each query token: sorts attn_scores, resamples key dimension to size k; you can use this to increase/decrease denominator to the magnitude on which the model was trained."""
  maxes = x.max(dim, keepdim=True).values
  diffs = x-maxes
  del maxes
  torch.cuda.empty_cache()
  gc.collect()
  diffs_sorted = diffs.sort(dim=dim).values
  x_exp = diffs.exp()
  del diffs
  # for downsampling:
  #   mode='area' has best PSNR.. but maybe that's not important.
  # for upsampling:
  #   lerping between attn scores (mode='linear') feels reasonable
  #   repeating attn scores (mode='nearest-exact') might be reasonable too
  # not sure whether we'd care about anti-aliasing
  # interpolate's size expects just the target length of the resampled (last) dim
  diffs_resampled = interpolate(diffs_sorted, size=k, mode='linear', antialias=False)
  del diffs_sorted
  diffs_exp_sum = diffs_resampled.exp().sum(dim, keepdim=True)
  del diffs_resampled
  quotient = x_exp/diffs_exp_sum
  return quotient

The nearest-neighbour resample (resample_crude_softmax below) ran to completion, but completely destroyed the image.

def resample_crude_softmax(x: FloatTensor, k:int, dim=-1) -> FloatTensor:
  """Softmax with a modified denominator. for each query token: resamples key dimension to size k; you can use this to increase/decrease denominator to the magnitude on which the model was trained."""
  maxes = x.max(dim, keepdim=True).values
  diffs = x-maxes
  del maxes
  x_exp = diffs.exp()
  diffs_resampled = interpolate(diffs, scale_factor=k/diffs.size(-1), mode='nearest-exact', antialias=False)
  del diffs
  diffs_exp_sum = diffs_resampled.exp().sum(dim, keepdim=True)
  del diffs_resampled
  quotient = x_exp/diffs_exp_sum
  return quotient
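
For reference, a usage sketch (illustrative, not from the original comment; it assumes the gist's out_dist_attn_scores and preferred_token_count are in scope, plus the imports from the snippet above):

# hypothetical call site: slots in where topk_softmax was used in the gist above
out_dist_resampled_attn_probs = resample_crude_softmax(
  out_dist_attn_scores, k=preferred_token_count, dim=-1,
)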
