@Birch-san
Last active April 8, 2023 10:32
from torch import FloatTensor, load, baddbmm, zeros
from dataclasses import dataclass
import torch
from os.path import join

@dataclass
class Fixtures:
  q_proj: FloatTensor
  k_proj: FloatTensor

device = torch.device('cuda')
def get_fixtures(
  sigma: float,
  head_dim: int,
  key_length_factor: float,
  key_tokens: int,
) -> Fixtures:
  """
  Imports pre-saved tensors from the filesystem, by building a file path to identify which stage of which diffusion run we want to load.
  Args:
    sigma:
      I believe discretized sigmas for a 22-step Karras schedule *should* be:
        14.6146, 11.9484, 9.7548, 7.9216, 6.3493, 5.0878, 4.0300, 3.1667, 2.4743, 1.9103, 1.4601, 1.1084, 0.8299, 0.6127, 0.4471, 0.3213, 0.2281, 0.1580, 0.1072, 0.0720, 0.0507, 0.0292
      Yet, k-diffusion invoked my Unet with the following sigmas:
        14.6147, 11.9776, 9.7593, 7.9029, 6.3579, 5.0793, 4.0277, 3.1686, 2.4716, 1.9104, 1.4621, 1.1072, 0.8289, 0.6128, 0.4469, 0.3211, 0.2270, 0.1576, 0.1072, 0.0713, 0.0463, 0.0292
    head_dim:
      https://twitter.com/Birchlabs/status/1609605588601675776
      varies depending on which self-attn layer you're on (at least in SD1.5, lol). I have files for:
        40, 80, 160, 160
    key_length_factor:
      1.0 = in-distribution image, 512x512, key length as large as 4096
      2.25 = out-of-distribution image, 768x768, key length as large as 9216
    key_tokens:
      varies depending on which self-attn layer you're on, and how large your latents are
        4096, 1024, 256, 64
        9216, 2304, 576, 144
  """
  root_dir = '/home/birch/git/diffusers-play'
  in_dir = join(root_dir, 'out_tensor')
  tensor_path_prefix = join(in_dir, f'f{key_length_factor}_s{sigma:.4f}_k{key_tokens}_c{head_dim}')
  q_proj: FloatTensor = load(f'{tensor_path_prefix}_q_proj.pt', weights_only=True, map_location=device)
  k_proj: FloatTensor = load(f'{tensor_path_prefix}_k_proj.pt', weights_only=True, map_location=device)
  return Fixtures(
    q_proj=q_proj,
    k_proj=k_proj,
  )
def get_attn_scores(
  q_proj: FloatTensor,
  k_proj: FloatTensor,
  scale: float,
) -> FloatTensor:
  """Computes (q_proj @ k_proj.T)*scale"""
  # no bias, but baddbmm's API requires a tensor even if coefficient is 0
  attn_bias: FloatTensor = zeros(1, 1, 1, dtype=q_proj.dtype, device=q_proj.device)
  attn_scores: FloatTensor = baddbmm(
    attn_bias,
    q_proj,
    k_proj.transpose(-1, -2),
    # beta=0 means don't apply bias
    beta=0,
    alpha=scale,
  )
  return attn_scores
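
# optional sanity check, on hypothetical random tensors (_q, _k are just illustrative names):
# baddbmm with beta=0, alpha=scale should equal a plain scaled batched matmul, (q @ k.T) * scale
_q = torch.randn(2, 4, 8)
_k = torch.randn(2, 4, 8)
assert torch.allclose(
  get_attn_scores(q_proj=_q, k_proj=_k, scale=8 ** -.5),
  (_q @ _k.transpose(-1, -2)) * 8 ** -.5,
  atol=1e-5,
)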
def softmax(x: FloatTensor, dim=-1) -> FloatTensor:
  """Typical softmax. Same as PyTorch's built-in torch.Tensor.softmax(), but step-by-step in case you want to modify it."""
  maxes = x.max(dim, keepdim=True).values
  diffs = x - maxes
  x_exp = diffs.exp()
  x_exp_sum = x_exp.sum(dim, keepdim=True)
  quotient = x_exp / x_exp_sum
  return quotient
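
# optional sanity check, on a hypothetical random tensor:
# the step-by-step softmax should match PyTorch's built-in softmax
_scores = torch.randn(2, 3, 5)
assert torch.allclose(softmax(_scores), _scores.softmax(dim=-1), atol=1e-6)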
def topk_softmax(x: FloatTensor, k: int, dim=-1) -> FloatTensor:
  """Softmax with a modified denominator, which sums fewer elements (the topk only) to help you size it to a magnitude on which the model was trained."""
  maxes = x.max(dim, keepdim=True).values
  diffs = x - maxes
  x_exp = diffs.exp()
  x_exp_sum = x_exp.topk(k=k, dim=dim).values.sum(dim, keepdim=True)
  quotient = x_exp / x_exp_sum
  return quotient
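
# optional sanity check, on a hypothetical random tensor: with k equal to the full key length,
# the topk denominator sums every element, so topk_softmax reduces to a plain softmax
_scores = torch.randn(2, 3, 5)
assert torch.allclose(topk_softmax(_scores, k=_scores.size(-1)), softmax(_scores), atol=1e-6)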
sigma = 14.6147
head_dim = 40
scale = head_dim ** -.5

in_dist_key_tokens = 4096
in_dist: Fixtures = get_fixtures(
  sigma=sigma,
  head_dim=head_dim,
  key_length_factor=in_dist_key_tokens/in_dist_key_tokens,  # 1.0
  key_tokens=in_dist_key_tokens,
)

out_dist_key_tokens = 9216
out_dist: Fixtures = get_fixtures(
  sigma=sigma,
  head_dim=head_dim,
  key_length_factor=out_dist_key_tokens/in_dist_key_tokens,  # 2.25
  key_tokens=out_dist_key_tokens,
)

in_dist_attn_scores: FloatTensor = get_attn_scores(
  q_proj=in_dist.q_proj,
  k_proj=in_dist.k_proj,
  scale=scale,
)
out_dist_attn_scores: FloatTensor = get_attn_scores(
  q_proj=out_dist.q_proj,
  k_proj=out_dist.k_proj,
  scale=scale,
)

in_dist_attn_probs: FloatTensor = in_dist_attn_scores.softmax(dim=-1)
out_dist_attn_probs: FloatTensor = out_dist_attn_scores.softmax(dim=-1)

key_tokens: int = out_dist_attn_scores.size(-1)
preferred_token_count: int = in_dist_key_tokens  # 4096
out_dist_topk_attn_probs: FloatTensor = topk_softmax(out_dist_attn_scores, k=preferred_token_count, dim=-1)
#### you probably want to put the following into a separate Jupyter cell so you can iterate on the charting
import matplotlib.pyplot as plt
from torch import histogram
from typing import NamedTuple

class Histogram(NamedTuple):
  boundaries: FloatTensor
  counts: FloatTensor

density = True
bins = 200

hist, bin_edges = histogram(in_dist_attn_probs[0][0].log().float().cpu(), bins=bins, density=density)
plt.fill_between(bin_edges[:-1], 0, hist, alpha=0.6, label='in-dist')

hist, bin_edges = histogram(out_dist_attn_probs[0][0].log().float().cpu(), bins=bins, density=density)
plt.fill_between(bin_edges[:-1], 0, hist, alpha=0.6, label='out-dist')

hist, bin_edges = histogram((out_dist_attn_probs[0][0]*(out_dist_key_tokens/in_dist_key_tokens)).log().float().cpu(), bins=bins, density=density)
plt.fill_between(bin_edges[:-1], 0, hist, alpha=0.6, label='out-dist (scaled)')

hist, bin_edges = histogram(out_dist_topk_attn_probs[0][0].log().float().cpu(), bins=bins, density=density)
plt.fill_between(bin_edges[:-1], 0, hist, alpha=0.6, label='out-dist (topk)')

plt.title(f'σ={sigma} log self-attn probs for uncond batch, first down-block, head 0, query token 0')
plt.legend()
plt.show()
@Birch-san commented Apr 5, 2023

Get yer tensors here (q_proj, k_proj). From these you can compute attn_scores and attn_probs as above.

These are for the first (biggest) self-attn layer, at the first (noisiest) sigma, 14.6147.
Batch size is 16 (cond and uncond, repeated over 8 attn heads).

512x512 (in-distribution; 4096 tokens):
https://birchlabs.co.uk/share/out_tensor/f1.0_s14.6147_k4096_c40_q_proj.pt (5.2MB)
https://birchlabs.co.uk/share/out_tensor/f1.0_s14.6147_k4096_c40_k_proj.pt (5.2MB)

768x768 (out-of-distribution by 2.25x; 9216 tokens):
https://birchlabs.co.uk/share/out_tensor/f2.25_s14.6147_k9216_c40_q_proj.pt (11.8MB)
https://birchlabs.co.uk/share/out_tensor/f2.25_s14.6147_k9216_c40_k_proj.pt (11.8MB)
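
If you'd rather script the download: a minimal sketch that fetches the in-distribution pair above and recomputes attn_scores/attn_probs with the functions from the gist. The base_url/out_dir/names variables (and using urllib + CPU map_location) are just my scaffolding, not part of the original setup.

from pathlib import Path
from urllib.request import urlretrieve
import torch

base_url = 'https://birchlabs.co.uk/share/out_tensor'
out_dir = Path('out_tensor')
out_dir.mkdir(exist_ok=True)
names = ('f1.0_s14.6147_k4096_c40_q_proj.pt', 'f1.0_s14.6147_k4096_c40_k_proj.pt')
for name in names:
  if not (out_dir / name).exists():
    urlretrieve(f'{base_url}/{name}', str(out_dir / name))
q_proj = torch.load(out_dir / names[0], weights_only=True, map_location='cpu')
k_proj = torch.load(out_dir / names[1], weights_only=True, map_location='cpu')
attn_scores = get_attn_scores(q_proj=q_proj, k_proj=k_proj, scale=40 ** -.5)  # head_dim=40 for this layer
attn_probs = attn_scores.softmax(dim=-1)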

@Birch-san

These are the files I have exported (but not uploaded), in case you want different sigmas or different self-attn layers. I think I only have the down blocks (I don't know an easy way to look up up/down-ness to include it in the filename):

2.6M f1.0_s0.0292_k1024_c80_k_proj.pt
2.6M f1.0_s0.0292_k1024_c80_q_proj.pt
1.3M f1.0_s0.0292_k256_c160_k_proj.pt
1.3M f1.0_s0.0292_k256_c160_q_proj.pt
5.1M f1.0_s0.0292_k4096_c40_k_proj.pt
5.1M f1.0_s0.0292_k4096_c40_q_proj.pt
321K f1.0_s0.0292_k64_c160_k_proj.pt
321K f1.0_s0.0292_k64_c160_q_proj.pt
2.6M f1.0_s0.0463_k1024_c80_k_proj.pt
2.6M f1.0_s0.0463_k1024_c80_q_proj.pt
1.3M f1.0_s0.0463_k256_c160_k_proj.pt
1.3M f1.0_s0.0463_k256_c160_q_proj.pt
5.1M f1.0_s0.0463_k4096_c40_k_proj.pt
5.1M f1.0_s0.0463_k4096_c40_q_proj.pt
321K f1.0_s0.0463_k64_c160_k_proj.pt
321K f1.0_s0.0463_k64_c160_q_proj.pt
2.6M f1.0_s0.0713_k1024_c80_k_proj.pt
2.6M f1.0_s0.0713_k1024_c80_q_proj.pt
1.3M f1.0_s0.0713_k256_c160_k_proj.pt
1.3M f1.0_s0.0713_k256_c160_q_proj.pt
5.1M f1.0_s0.0713_k4096_c40_k_proj.pt
5.1M f1.0_s0.0713_k4096_c40_q_proj.pt
321K f1.0_s0.0713_k64_c160_k_proj.pt
321K f1.0_s0.0713_k64_c160_q_proj.pt
2.6M f1.0_s0.1072_k1024_c80_k_proj.pt
2.6M f1.0_s0.1072_k1024_c80_q_proj.pt
1.3M f1.0_s0.1072_k256_c160_k_proj.pt
1.3M f1.0_s0.1072_k256_c160_q_proj.pt
5.1M f1.0_s0.1072_k4096_c40_k_proj.pt
5.1M f1.0_s0.1072_k4096_c40_q_proj.pt
321K f1.0_s0.1072_k64_c160_k_proj.pt
321K f1.0_s0.1072_k64_c160_q_proj.pt
2.6M f1.0_s0.1576_k1024_c80_k_proj.pt
2.6M f1.0_s0.1576_k1024_c80_q_proj.pt
1.3M f1.0_s0.1576_k256_c160_k_proj.pt
1.3M f1.0_s0.1576_k256_c160_q_proj.pt
5.1M f1.0_s0.1576_k4096_c40_k_proj.pt
5.1M f1.0_s0.1576_k4096_c40_q_proj.pt
321K f1.0_s0.1576_k64_c160_k_proj.pt
321K f1.0_s0.1576_k64_c160_q_proj.pt
2.6M f1.0_s0.2270_k1024_c80_k_proj.pt
2.6M f1.0_s0.2270_k1024_c80_q_proj.pt
1.3M f1.0_s0.2270_k256_c160_k_proj.pt
1.3M f1.0_s0.2270_k256_c160_q_proj.pt
5.1M f1.0_s0.2270_k4096_c40_k_proj.pt
5.1M f1.0_s0.2270_k4096_c40_q_proj.pt
321K f1.0_s0.2270_k64_c160_k_proj.pt
321K f1.0_s0.2270_k64_c160_q_proj.pt
2.6M f1.0_s0.3211_k1024_c80_k_proj.pt
2.6M f1.0_s0.3211_k1024_c80_q_proj.pt
1.3M f1.0_s0.3211_k256_c160_k_proj.pt
1.3M f1.0_s0.3211_k256_c160_q_proj.pt
5.1M f1.0_s0.3211_k4096_c40_k_proj.pt
5.1M f1.0_s0.3211_k4096_c40_q_proj.pt
321K f1.0_s0.3211_k64_c160_k_proj.pt
321K f1.0_s0.3211_k64_c160_q_proj.pt
2.6M f1.0_s0.4469_k1024_c80_k_proj.pt
2.6M f1.0_s0.4469_k1024_c80_q_proj.pt
1.3M f1.0_s0.4469_k256_c160_k_proj.pt
1.3M f1.0_s0.4469_k256_c160_q_proj.pt
5.1M f1.0_s0.4469_k4096_c40_k_proj.pt
5.1M f1.0_s0.4469_k4096_c40_q_proj.pt
321K f1.0_s0.4469_k64_c160_k_proj.pt
321K f1.0_s0.4469_k64_c160_q_proj.pt
2.6M f1.0_s0.6128_k1024_c80_k_proj.pt
2.6M f1.0_s0.6128_k1024_c80_q_proj.pt
1.3M f1.0_s0.6128_k256_c160_k_proj.pt
1.3M f1.0_s0.6128_k256_c160_q_proj.pt
5.1M f1.0_s0.6128_k4096_c40_k_proj.pt
5.1M f1.0_s0.6128_k4096_c40_q_proj.pt
321K f1.0_s0.6128_k64_c160_k_proj.pt
321K f1.0_s0.6128_k64_c160_q_proj.pt
2.6M f1.0_s0.8289_k1024_c80_k_proj.pt
2.6M f1.0_s0.8289_k1024_c80_q_proj.pt
1.3M f1.0_s0.8289_k256_c160_k_proj.pt
1.3M f1.0_s0.8289_k256_c160_q_proj.pt
5.1M f1.0_s0.8289_k4096_c40_k_proj.pt
5.1M f1.0_s0.8289_k4096_c40_q_proj.pt
321K f1.0_s0.8289_k64_c160_k_proj.pt
321K f1.0_s0.8289_k64_c160_q_proj.pt
2.6M f1.0_s1.1072_k1024_c80_k_proj.pt
2.6M f1.0_s1.1072_k1024_c80_q_proj.pt
1.3M f1.0_s1.1072_k256_c160_k_proj.pt
1.3M f1.0_s1.1072_k256_c160_q_proj.pt
5.1M f1.0_s1.1072_k4096_c40_k_proj.pt
5.1M f1.0_s1.1072_k4096_c40_q_proj.pt
321K f1.0_s1.1072_k64_c160_k_proj.pt
321K f1.0_s1.1072_k64_c160_q_proj.pt
2.6M f1.0_s11.9776_k1024_c80_k_proj.pt
2.6M f1.0_s11.9776_k1024_c80_q_proj.pt
1.3M f1.0_s11.9776_k256_c160_k_proj.pt
1.3M f1.0_s11.9776_k256_c160_q_proj.pt
5.1M f1.0_s11.9776_k4096_c40_k_proj.pt
5.1M f1.0_s11.9776_k4096_c40_q_proj.pt
321K f1.0_s11.9776_k64_c160_k_proj.pt
321K f1.0_s11.9776_k64_c160_q_proj.pt
2.6M f1.0_s14.6147_k1024_c80_k_proj.pt
2.6M f1.0_s14.6147_k1024_c80_q_proj.pt
1.3M f1.0_s14.6147_k256_c160_k_proj.pt
1.3M f1.0_s14.6147_k256_c160_q_proj.pt
5.1M f1.0_s14.6147_k4096_c40_k_proj.pt
5.1M f1.0_s14.6147_k4096_c40_q_proj.pt
321K f1.0_s14.6147_k64_c160_k_proj.pt
321K f1.0_s14.6147_k64_c160_q_proj.pt
2.6M f1.0_s1.4621_k1024_c80_k_proj.pt
2.6M f1.0_s1.4621_k1024_c80_q_proj.pt
1.3M f1.0_s1.4621_k256_c160_k_proj.pt
1.3M f1.0_s1.4621_k256_c160_q_proj.pt
5.1M f1.0_s1.4621_k4096_c40_k_proj.pt
5.1M f1.0_s1.4621_k4096_c40_q_proj.pt
321K f1.0_s1.4621_k64_c160_k_proj.pt
321K f1.0_s1.4621_k64_c160_q_proj.pt
2.6M f1.0_s1.9104_k1024_c80_k_proj.pt
2.6M f1.0_s1.9104_k1024_c80_q_proj.pt
1.3M f1.0_s1.9104_k256_c160_k_proj.pt
1.3M f1.0_s1.9104_k256_c160_q_proj.pt
5.1M f1.0_s1.9104_k4096_c40_k_proj.pt
5.1M f1.0_s1.9104_k4096_c40_q_proj.pt
321K f1.0_s1.9104_k64_c160_k_proj.pt
321K f1.0_s1.9104_k64_c160_q_proj.pt
2.6M f1.0_s2.4716_k1024_c80_k_proj.pt
2.6M f1.0_s2.4716_k1024_c80_q_proj.pt
1.3M f1.0_s2.4716_k256_c160_k_proj.pt
1.3M f1.0_s2.4716_k256_c160_q_proj.pt
5.1M f1.0_s2.4716_k4096_c40_k_proj.pt
5.1M f1.0_s2.4716_k4096_c40_q_proj.pt
321K f1.0_s2.4716_k64_c160_k_proj.pt
321K f1.0_s2.4716_k64_c160_q_proj.pt
2.6M f1.0_s3.1686_k1024_c80_k_proj.pt
2.6M f1.0_s3.1686_k1024_c80_q_proj.pt
1.3M f1.0_s3.1686_k256_c160_k_proj.pt
1.3M f1.0_s3.1686_k256_c160_q_proj.pt
5.1M f1.0_s3.1686_k4096_c40_k_proj.pt
5.1M f1.0_s3.1686_k4096_c40_q_proj.pt
321K f1.0_s3.1686_k64_c160_k_proj.pt
321K f1.0_s3.1686_k64_c160_q_proj.pt
2.6M f1.0_s4.0277_k1024_c80_k_proj.pt
2.6M f1.0_s4.0277_k1024_c80_q_proj.pt
1.3M f1.0_s4.0277_k256_c160_k_proj.pt
1.3M f1.0_s4.0277_k256_c160_q_proj.pt
5.1M f1.0_s4.0277_k4096_c40_k_proj.pt
5.1M f1.0_s4.0277_k4096_c40_q_proj.pt
321K f1.0_s4.0277_k64_c160_k_proj.pt
321K f1.0_s4.0277_k64_c160_q_proj.pt
2.6M f1.0_s5.0793_k1024_c80_k_proj.pt
2.6M f1.0_s5.0793_k1024_c80_q_proj.pt
1.3M f1.0_s5.0793_k256_c160_k_proj.pt
1.3M f1.0_s5.0793_k256_c160_q_proj.pt
5.1M f1.0_s5.0793_k4096_c40_k_proj.pt
5.1M f1.0_s5.0793_k4096_c40_q_proj.pt
321K f1.0_s5.0793_k64_c160_k_proj.pt
321K f1.0_s5.0793_k64_c160_q_proj.pt
2.6M f1.0_s6.3579_k1024_c80_k_proj.pt
2.6M f1.0_s6.3579_k1024_c80_q_proj.pt
1.3M f1.0_s6.3579_k256_c160_k_proj.pt
1.3M f1.0_s6.3579_k256_c160_q_proj.pt
5.1M f1.0_s6.3579_k4096_c40_k_proj.pt
5.1M f1.0_s6.3579_k4096_c40_q_proj.pt
321K f1.0_s6.3579_k64_c160_k_proj.pt
321K f1.0_s6.3579_k64_c160_q_proj.pt
2.6M f1.0_s7.9029_k1024_c80_k_proj.pt
2.6M f1.0_s7.9029_k1024_c80_q_proj.pt
1.3M f1.0_s7.9029_k256_c160_k_proj.pt
1.3M f1.0_s7.9029_k256_c160_q_proj.pt
5.1M f1.0_s7.9029_k4096_c40_k_proj.pt
5.1M f1.0_s7.9029_k4096_c40_q_proj.pt
321K f1.0_s7.9029_k64_c160_k_proj.pt
321K f1.0_s7.9029_k64_c160_q_proj.pt
2.6M f1.0_s9.7593_k1024_c80_k_proj.pt
2.6M f1.0_s9.7593_k1024_c80_q_proj.pt
1.3M f1.0_s9.7593_k256_c160_k_proj.pt
1.3M f1.0_s9.7593_k256_c160_q_proj.pt
5.1M f1.0_s9.7593_k4096_c40_k_proj.pt
5.1M f1.0_s9.7593_k4096_c40_q_proj.pt
321K f1.0_s9.7593_k64_c160_k_proj.pt
321K f1.0_s9.7593_k64_c160_q_proj.pt
721K f2.25_s0.0292_k144_c160_k_proj.pt
721K f2.25_s0.0292_k144_c160_q_proj.pt
5.7M f2.25_s0.0292_k2304_c80_k_proj.pt
5.7M f2.25_s0.0292_k2304_c80_q_proj.pt
2.9M f2.25_s0.0292_k576_c160_k_proj.pt
2.9M f2.25_s0.0292_k576_c160_q_proj.pt
 12M f2.25_s0.0292_k9216_c40_k_proj.pt
 12M f2.25_s0.0292_k9216_c40_q_proj.pt
721K f2.25_s0.0463_k144_c160_k_proj.pt
721K f2.25_s0.0463_k144_c160_q_proj.pt
5.7M f2.25_s0.0463_k2304_c80_k_proj.pt
5.7M f2.25_s0.0463_k2304_c80_q_proj.pt
2.9M f2.25_s0.0463_k576_c160_k_proj.pt
2.9M f2.25_s0.0463_k576_c160_q_proj.pt
 12M f2.25_s0.0463_k9216_c40_k_proj.pt
 12M f2.25_s0.0463_k9216_c40_q_proj.pt
721K f2.25_s0.0713_k144_c160_k_proj.pt
721K f2.25_s0.0713_k144_c160_q_proj.pt
5.7M f2.25_s0.0713_k2304_c80_k_proj.pt
5.7M f2.25_s0.0713_k2304_c80_q_proj.pt
2.9M f2.25_s0.0713_k576_c160_k_proj.pt
2.9M f2.25_s0.0713_k576_c160_q_proj.pt
 12M f2.25_s0.0713_k9216_c40_k_proj.pt
 12M f2.25_s0.0713_k9216_c40_q_proj.pt
721K f2.25_s0.1072_k144_c160_k_proj.pt
721K f2.25_s0.1072_k144_c160_q_proj.pt
5.7M f2.25_s0.1072_k2304_c80_k_proj.pt
5.7M f2.25_s0.1072_k2304_c80_q_proj.pt
2.9M f2.25_s0.1072_k576_c160_k_proj.pt
2.9M f2.25_s0.1072_k576_c160_q_proj.pt
 12M f2.25_s0.1072_k9216_c40_k_proj.pt
 12M f2.25_s0.1072_k9216_c40_q_proj.pt
721K f2.25_s0.1576_k144_c160_k_proj.pt
721K f2.25_s0.1576_k144_c160_q_proj.pt
5.7M f2.25_s0.1576_k2304_c80_k_proj.pt
5.7M f2.25_s0.1576_k2304_c80_q_proj.pt
2.9M f2.25_s0.1576_k576_c160_k_proj.pt
2.9M f2.25_s0.1576_k576_c160_q_proj.pt
 12M f2.25_s0.1576_k9216_c40_k_proj.pt
 12M f2.25_s0.1576_k9216_c40_q_proj.pt
721K f2.25_s0.2270_k144_c160_k_proj.pt
721K f2.25_s0.2270_k144_c160_q_proj.pt
5.7M f2.25_s0.2270_k2304_c80_k_proj.pt
5.7M f2.25_s0.2270_k2304_c80_q_proj.pt
2.9M f2.25_s0.2270_k576_c160_k_proj.pt
2.9M f2.25_s0.2270_k576_c160_q_proj.pt
 12M f2.25_s0.2270_k9216_c40_k_proj.pt
 12M f2.25_s0.2270_k9216_c40_q_proj.pt
721K f2.25_s0.3211_k144_c160_k_proj.pt
721K f2.25_s0.3211_k144_c160_q_proj.pt
5.7M f2.25_s0.3211_k2304_c80_k_proj.pt
5.7M f2.25_s0.3211_k2304_c80_q_proj.pt
2.9M f2.25_s0.3211_k576_c160_k_proj.pt
2.9M f2.25_s0.3211_k576_c160_q_proj.pt
 12M f2.25_s0.3211_k9216_c40_k_proj.pt
 12M f2.25_s0.3211_k9216_c40_q_proj.pt
721K f2.25_s0.4469_k144_c160_k_proj.pt
721K f2.25_s0.4469_k144_c160_q_proj.pt
5.7M f2.25_s0.4469_k2304_c80_k_proj.pt
5.7M f2.25_s0.4469_k2304_c80_q_proj.pt
2.9M f2.25_s0.4469_k576_c160_k_proj.pt
2.9M f2.25_s0.4469_k576_c160_q_proj.pt
 12M f2.25_s0.4469_k9216_c40_k_proj.pt
 12M f2.25_s0.4469_k9216_c40_q_proj.pt
721K f2.25_s0.6128_k144_c160_k_proj.pt
721K f2.25_s0.6128_k144_c160_q_proj.pt
5.7M f2.25_s0.6128_k2304_c80_k_proj.pt
5.7M f2.25_s0.6128_k2304_c80_q_proj.pt
2.9M f2.25_s0.6128_k576_c160_k_proj.pt
2.9M f2.25_s0.6128_k576_c160_q_proj.pt
 12M f2.25_s0.6128_k9216_c40_k_proj.pt
 12M f2.25_s0.6128_k9216_c40_q_proj.pt
721K f2.25_s0.8289_k144_c160_k_proj.pt
721K f2.25_s0.8289_k144_c160_q_proj.pt
5.7M f2.25_s0.8289_k2304_c80_k_proj.pt
5.7M f2.25_s0.8289_k2304_c80_q_proj.pt
2.9M f2.25_s0.8289_k576_c160_k_proj.pt
2.9M f2.25_s0.8289_k576_c160_q_proj.pt
 12M f2.25_s0.8289_k9216_c40_k_proj.pt
 12M f2.25_s0.8289_k9216_c40_q_proj.pt
721K f2.25_s1.1072_k144_c160_k_proj.pt
721K f2.25_s1.1072_k144_c160_q_proj.pt
5.7M f2.25_s1.1072_k2304_c80_k_proj.pt
5.7M f2.25_s1.1072_k2304_c80_q_proj.pt
2.9M f2.25_s1.1072_k576_c160_k_proj.pt
2.9M f2.25_s1.1072_k576_c160_q_proj.pt
 12M f2.25_s1.1072_k9216_c40_k_proj.pt
 12M f2.25_s1.1072_k9216_c40_q_proj.pt
721K f2.25_s11.9776_k144_c160_k_proj.pt
721K f2.25_s11.9776_k144_c160_q_proj.pt
5.7M f2.25_s11.9776_k2304_c80_k_proj.pt
5.7M f2.25_s11.9776_k2304_c80_q_proj.pt
2.9M f2.25_s11.9776_k576_c160_k_proj.pt
2.9M f2.25_s11.9776_k576_c160_q_proj.pt
 12M f2.25_s11.9776_k9216_c40_k_proj.pt
 12M f2.25_s11.9776_k9216_c40_q_proj.pt
721K f2.25_s14.6147_k144_c160_k_proj.pt
721K f2.25_s14.6147_k144_c160_q_proj.pt
5.7M f2.25_s14.6147_k2304_c80_k_proj.pt
5.7M f2.25_s14.6147_k2304_c80_q_proj.pt
2.9M f2.25_s14.6147_k576_c160_k_proj.pt
2.9M f2.25_s14.6147_k576_c160_q_proj.pt
 12M f2.25_s14.6147_k9216_c40_k_proj.pt
 12M f2.25_s14.6147_k9216_c40_q_proj.pt
721K f2.25_s1.4621_k144_c160_k_proj.pt
721K f2.25_s1.4621_k144_c160_q_proj.pt
5.7M f2.25_s1.4621_k2304_c80_k_proj.pt
5.7M f2.25_s1.4621_k2304_c80_q_proj.pt
2.9M f2.25_s1.4621_k576_c160_k_proj.pt
2.9M f2.25_s1.4621_k576_c160_q_proj.pt
 12M f2.25_s1.4621_k9216_c40_k_proj.pt
 12M f2.25_s1.4621_k9216_c40_q_proj.pt
721K f2.25_s1.9104_k144_c160_k_proj.pt
721K f2.25_s1.9104_k144_c160_q_proj.pt
5.7M f2.25_s1.9104_k2304_c80_k_proj.pt
5.7M f2.25_s1.9104_k2304_c80_q_proj.pt
2.9M f2.25_s1.9104_k576_c160_k_proj.pt
2.9M f2.25_s1.9104_k576_c160_q_proj.pt
 12M f2.25_s1.9104_k9216_c40_k_proj.pt
 12M f2.25_s1.9104_k9216_c40_q_proj.pt
721K f2.25_s2.4716_k144_c160_k_proj.pt
721K f2.25_s2.4716_k144_c160_q_proj.pt
5.7M f2.25_s2.4716_k2304_c80_k_proj.pt
5.7M f2.25_s2.4716_k2304_c80_q_proj.pt
2.9M f2.25_s2.4716_k576_c160_k_proj.pt
2.9M f2.25_s2.4716_k576_c160_q_proj.pt
 12M f2.25_s2.4716_k9216_c40_k_proj.pt
 12M f2.25_s2.4716_k9216_c40_q_proj.pt
721K f2.25_s3.1686_k144_c160_k_proj.pt
721K f2.25_s3.1686_k144_c160_q_proj.pt
5.7M f2.25_s3.1686_k2304_c80_k_proj.pt
5.7M f2.25_s3.1686_k2304_c80_q_proj.pt
2.9M f2.25_s3.1686_k576_c160_k_proj.pt
2.9M f2.25_s3.1686_k576_c160_q_proj.pt
 12M f2.25_s3.1686_k9216_c40_k_proj.pt
 12M f2.25_s3.1686_k9216_c40_q_proj.pt
721K f2.25_s4.0277_k144_c160_k_proj.pt
721K f2.25_s4.0277_k144_c160_q_proj.pt
5.7M f2.25_s4.0277_k2304_c80_k_proj.pt
5.7M f2.25_s4.0277_k2304_c80_q_proj.pt
2.9M f2.25_s4.0277_k576_c160_k_proj.pt
2.9M f2.25_s4.0277_k576_c160_q_proj.pt
 12M f2.25_s4.0277_k9216_c40_k_proj.pt
 12M f2.25_s4.0277_k9216_c40_q_proj.pt
721K f2.25_s5.0793_k144_c160_k_proj.pt
721K f2.25_s5.0793_k144_c160_q_proj.pt
5.7M f2.25_s5.0793_k2304_c80_k_proj.pt
5.7M f2.25_s5.0793_k2304_c80_q_proj.pt
2.9M f2.25_s5.0793_k576_c160_k_proj.pt
2.9M f2.25_s5.0793_k576_c160_q_proj.pt
 12M f2.25_s5.0793_k9216_c40_k_proj.pt
 12M f2.25_s5.0793_k9216_c40_q_proj.pt
721K f2.25_s6.3579_k144_c160_k_proj.pt
721K f2.25_s6.3579_k144_c160_q_proj.pt
5.7M f2.25_s6.3579_k2304_c80_k_proj.pt
5.7M f2.25_s6.3579_k2304_c80_q_proj.pt
2.9M f2.25_s6.3579_k576_c160_k_proj.pt
2.9M f2.25_s6.3579_k576_c160_q_proj.pt
 12M f2.25_s6.3579_k9216_c40_k_proj.pt
 12M f2.25_s6.3579_k9216_c40_q_proj.pt
721K f2.25_s7.9029_k144_c160_k_proj.pt
721K f2.25_s7.9029_k144_c160_q_proj.pt
5.7M f2.25_s7.9029_k2304_c80_k_proj.pt
5.7M f2.25_s7.9029_k2304_c80_q_proj.pt
2.9M f2.25_s7.9029_k576_c160_k_proj.pt
2.9M f2.25_s7.9029_k576_c160_q_proj.pt
 12M f2.25_s7.9029_k9216_c40_k_proj.pt
 12M f2.25_s7.9029_k9216_c40_q_proj.pt
721K f2.25_s9.7593_k144_c160_k_proj.pt
721K f2.25_s9.7593_k144_c160_q_proj.pt
5.7M f2.25_s9.7593_k2304_c80_k_proj.pt
5.7M f2.25_s9.7593_k2304_c80_q_proj.pt
2.9M f2.25_s9.7593_k576_c160_k_proj.pt
2.9M f2.25_s9.7593_k576_c160_q_proj.pt
 12M f2.25_s9.7593_k9216_c40_k_proj.pt
 12M f2.25_s9.7593_k9216_c40_q_proj.pt

@Birch-san

Instead of taking the topk key tokens (which favours larger scores): I tried a nearest-neighbour downsample of the 9216 key tokens to 4096. Basically like discarding every second token.

resample_softmax attempts to sort the distribution first (which I think could give a fairer resample, assuming you interpolate between the datapoints), prior to resampling. However, the sort ran out of memory, so I wasn't able to evaluate this.

import gc
import torch
from torch import FloatTensor
from torch.nn.functional import interpolate

def resample_softmax(x: FloatTensor, k: int, dim=-1) -> FloatTensor:
  """Softmax with a modified denominator. For each query token: sorts attn_scores, resamples the key dimension to size k; you can use this to increase/decrease the denominator to the magnitude on which the model was trained."""
  maxes = x.max(dim, keepdim=True).values
  diffs = x - maxes
  del maxes
  torch.cuda.empty_cache()
  gc.collect()
  diffs_sorted = diffs.sort(dim=dim).values
  x_exp = diffs.exp()
  del diffs
  # for downsampling:
  #   mode='area' has best PSNR.. but maybe that's not important.
  # for upsampling:
  #   lerping between attn scores (mode='linear') feels reasonable
  #   repeating attn scores (mode='nearest-exact') might be reasonable too
  # not sure whether we'd care about anti-aliasing
  # interpolate's size covers only the final (key) dim of the 3D [batch*heads, query, key] input
  diffs_resampled = interpolate(diffs_sorted, size=k, mode='linear', antialias=False)
  del diffs_sorted
  diffs_exp_sum = diffs_resampled.exp().sum(dim, keepdim=True)
  del diffs_resampled
  quotient = x_exp / diffs_exp_sum
  return quotient

The nearest-neighbour resample ran to completion, but completely destroyed the image.

def resample_crude_softmax(x: FloatTensor, k: int, dim=-1) -> FloatTensor:
  """Softmax with a modified denominator. For each query token: resamples the key dimension to size k; you can use this to increase/decrease the denominator to the magnitude on which the model was trained."""
  maxes = x.max(dim, keepdim=True).values
  diffs = x - maxes
  del maxes
  x_exp = diffs.exp()
  # resamples the final (key) dim of the 3D [batch*heads, query, key] input to roughly k elements (floor of key_len * scale_factor)
  diffs_resampled = interpolate(diffs, scale_factor=k/diffs.size(-1), mode='nearest-exact', antialias=False)
  del diffs
  diffs_exp_sum = diffs_resampled.exp().sum(dim, keepdim=True)
  del diffs_resampled
  quotient = x_exp / diffs_exp_sum
  return quotient
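
For context, a usage sketch reusing the fixtures and plotting code from the main gist (out_dist_attn_scores, in_dist_key_tokens, histogram, plt, bins, density are the names defined there; out_dist_resampled_probs is just my name for the result), adding another trace to the same histogram chart:

out_dist_resampled_probs = resample_crude_softmax(out_dist_attn_scores, k=in_dist_key_tokens)  # 9216-key scores, ~4096-sized denominator
hist, bin_edges = histogram(out_dist_resampled_probs[0][0].log().float().cpu(), bins=bins, density=density)
plt.fill_between(bin_edges[:-1], 0, hist, alpha=0.6, label='out-dist (resampled)')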
