from torch import FloatTensor, load, baddbmm, zeros
from dataclasses import dataclass
import torch
from os.path import join

@dataclass
class Fixtures:
  q_proj: FloatTensor
  k_proj: FloatTensor

device = torch.device('cuda')

def get_fixtures(
  sigma: float,
  head_dim: int,
  key_length_factor: float,
  key_tokens: int,
) -> Fixtures:
  """
  Imports pre-saved tensors from the filesystem, by building a file path to identify which stage of which diffusion run we want to load.
  Args:
    sigma:
      I believe discretized sigmas for a 22-step Karras schedule *should* be:
      14.6146, 11.9484, 9.7548, 7.9216, 6.3493, 5.0878, 4.0300, 3.1667, 2.4743, 1.9103, 1.4601, 1.1084, 0.8299, 0.6127, 0.4471, 0.3213, 0.2281, 0.1580, 0.1072, 0.0720, 0.0507, 0.0292
      Yet, k-diffusion invoked my Unet with the following sigmas:
      14.6147, 11.9776, 9.7593, 7.9029, 6.3579, 5.0793, 4.0277, 3.1686, 2.4716, 1.9104, 1.4621, 1.1072, 0.8289, 0.6128, 0.4469, 0.3211, 0.2270, 0.1576, 0.1072, 0.0713, 0.0463, 0.0292
    head_dim:
      https://twitter.com/Birchlabs/status/1609605588601675776
      varies depending on which self-attn layer you're on (at least in SD1.5, lol). I have files for:
      40, 80, 160, 160
    key_length_factor:
      1.0 = in-distribution image, 512x512, key length as large as 4096
      2.25 = out-of-distribution image, 768x768, key length as large as 9216
    key_tokens:
      varies depending on which self-attn layer you're on, and how large your latents are
      4096, 1024, 256, 64
      9216, 2304, 576, 144
  """
  root_dir = '/home/birch/git/diffusers-play'
  in_dir = join(root_dir, 'out_tensor')
  tensor_path_prefix = join(in_dir, f'f{key_length_factor}_s{sigma:.4f}_k{key_tokens}_c{head_dim}')
  q_proj: FloatTensor = load(f'{tensor_path_prefix}_q_proj.pt', weights_only=True, map_location=device)
  k_proj: FloatTensor = load(f'{tensor_path_prefix}_k_proj.pt', weights_only=True, map_location=device)
  return Fixtures(
    q_proj=q_proj,
    k_proj=k_proj,
  )

def get_attn_scores(
  q_proj: FloatTensor,
  k_proj: FloatTensor,
  scale: float,
) -> FloatTensor:
  """Computes (q_proj @ k_proj.T)*scale"""
  # no bias, but baddbmm's API requires a tensor even if coefficient is 0
  attn_bias: FloatTensor = zeros(1, 1, 1, dtype=q_proj.dtype, device=q_proj.device)
  attn_scores: FloatTensor = baddbmm(
    attn_bias,
    q_proj,
    k_proj.transpose(-1, -2),
    # means don't apply bias
    beta=0,
    alpha=scale,
  )
  return attn_scores
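# Optional sanity check (added for exposition, not part of the original gist): with beta=0,
# the baddbmm formulation is just a fused, scaled matmul, so it should match (q @ k.T) * scale.
_q = torch.randn(2, 8, 40)
_k = torch.randn(2, 16, 40)
assert torch.allclose(get_attn_scores(_q, _k, scale=40 ** -.5), (_q @ _k.transpose(-1, -2)) * 40 ** -.5, atol=1e-6)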
def softmax(x: FloatTensor, dim=-1) -> FloatTensor:
  """Typical softmax. Same as PyTorch's built-in torch.Tensor.softmax(), but step-by-step in case you want to modify it."""
  maxes = x.max(dim, keepdim=True).values
  diffs = x-maxes
  x_exp = diffs.exp()
  x_exp_sum = x_exp.sum(dim, keepdim=True)
  quotient = x_exp/x_exp_sum
  return quotient

def topk_softmax(x: FloatTensor, k: int, dim=-1) -> FloatTensor:
  """Softmax with a modified denominator, which sums fewer elements (the topk only) to help you size it to a magnitude on which the model was trained."""
  maxes = x.max(dim, keepdim=True).values
  diffs = x-maxes
  x_exp = diffs.exp()
  x_exp_sum = x_exp.topk(k=k, dim=dim).values.sum(dim, keepdim=True)
  quotient = x_exp/x_exp_sum
  return quotient
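# Quick demonstration (added for exposition): with k equal to the full key length, topk_softmax
# reduces to the ordinary softmax; with a smaller k the denominator shrinks, so every probability grows.
_scores = torch.randn(2, 4, 16)
assert torch.allclose(topk_softmax(_scores, k=16), softmax(_scores))
assert (topk_softmax(_scores, k=8) >= softmax(_scores)).all()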
sigma = 14.6147
head_dim = 40
scale = head_dim ** -.5

in_dist_key_tokens = 4096
in_dist: Fixtures = get_fixtures(
  sigma=sigma,
  head_dim=head_dim,
  key_length_factor=in_dist_key_tokens/in_dist_key_tokens, # 1.0
  key_tokens=in_dist_key_tokens,
)
out_dist_key_tokens = 9216
out_dist: Fixtures = get_fixtures(
  sigma=sigma,
  head_dim=head_dim,
  key_length_factor=out_dist_key_tokens/in_dist_key_tokens, # 2.25
  key_tokens=out_dist_key_tokens,
)

in_dist_attn_scores: FloatTensor = get_attn_scores(
  q_proj=in_dist.q_proj,
  k_proj=in_dist.k_proj,
  scale=scale,
)
out_dist_attn_scores: FloatTensor = get_attn_scores(
  q_proj=out_dist.q_proj,
  k_proj=out_dist.k_proj,
  scale=scale,
)

in_dist_attn_probs: FloatTensor = in_dist_attn_scores.softmax(dim=-1)
out_dist_attn_probs: FloatTensor = out_dist_attn_scores.softmax(dim=-1)

key_tokens: int = out_dist_attn_scores.size(-1)
preferred_token_count: int = in_dist_key_tokens # 4096
out_dist_topk_attn_probs: FloatTensor = topk_softmax(out_dist_attn_scores, k=preferred_token_count, dim=-1)

#### you probably want to put the following into a separate Jupyter cell so you can iterate on the charting
import matplotlib.pyplot as plt
from torch import histogram
from typing import NamedTuple

class Histogram(NamedTuple):
  boundaries: FloatTensor
  counts: FloatTensor

density=True
bins=200

hist, bin_edges = histogram(in_dist_attn_probs[0][0].log().float().cpu(), bins=bins, density=density)
plt.fill_between(bin_edges[:-1], 0, hist, alpha=0.6, label='in-dist')
hist, bin_edges = histogram(out_dist_attn_probs[0][0].log().float().cpu(), bins=bins, density=density)
plt.fill_between(bin_edges[:-1], 0, hist, alpha=0.6, label='out-dist')
hist, bin_edges = histogram((out_dist_attn_probs[0][0]*(out_dist_key_tokens/in_dist_key_tokens)).log().float().cpu(), bins=bins, density=density)
plt.fill_between(bin_edges[:-1], 0, hist, alpha=0.6, label='out-dist (scaled)')
hist, bin_edges = histogram(out_dist_topk_attn_probs[0][0].log().float().cpu(), bins=bins, density=density)
plt.fill_between(bin_edges[:-1], 0, hist, alpha=0.6, label='out-dist (topk)')
plt.title(f'σ={sigma} log self-attn probs for uncond batch, first down-block, head 0, query token 0')
plt.legend()
plt.show()
These are the files I have exported (but not uploaded), in case you want different sigmas or different self-attn layers. I think I only have the down blocks (I don't know an easy way to look up up/down-ness to include it in the filename). A sketch of the presumed export naming scheme follows the list:
2.6M f1.0_s0.0292_k1024_c80_k_proj.pt
2.6M f1.0_s0.0292_k1024_c80_q_proj.pt
1.3M f1.0_s0.0292_k256_c160_k_proj.pt
1.3M f1.0_s0.0292_k256_c160_q_proj.pt
5.1M f1.0_s0.0292_k4096_c40_k_proj.pt
5.1M f1.0_s0.0292_k4096_c40_q_proj.pt
321K f1.0_s0.0292_k64_c160_k_proj.pt
321K f1.0_s0.0292_k64_c160_q_proj.pt
2.6M f1.0_s0.0463_k1024_c80_k_proj.pt
2.6M f1.0_s0.0463_k1024_c80_q_proj.pt
1.3M f1.0_s0.0463_k256_c160_k_proj.pt
1.3M f1.0_s0.0463_k256_c160_q_proj.pt
5.1M f1.0_s0.0463_k4096_c40_k_proj.pt
5.1M f1.0_s0.0463_k4096_c40_q_proj.pt
321K f1.0_s0.0463_k64_c160_k_proj.pt
321K f1.0_s0.0463_k64_c160_q_proj.pt
2.6M f1.0_s0.0713_k1024_c80_k_proj.pt
2.6M f1.0_s0.0713_k1024_c80_q_proj.pt
1.3M f1.0_s0.0713_k256_c160_k_proj.pt
1.3M f1.0_s0.0713_k256_c160_q_proj.pt
5.1M f1.0_s0.0713_k4096_c40_k_proj.pt
5.1M f1.0_s0.0713_k4096_c40_q_proj.pt
321K f1.0_s0.0713_k64_c160_k_proj.pt
321K f1.0_s0.0713_k64_c160_q_proj.pt
2.6M f1.0_s0.1072_k1024_c80_k_proj.pt
2.6M f1.0_s0.1072_k1024_c80_q_proj.pt
1.3M f1.0_s0.1072_k256_c160_k_proj.pt
1.3M f1.0_s0.1072_k256_c160_q_proj.pt
5.1M f1.0_s0.1072_k4096_c40_k_proj.pt
5.1M f1.0_s0.1072_k4096_c40_q_proj.pt
321K f1.0_s0.1072_k64_c160_k_proj.pt
321K f1.0_s0.1072_k64_c160_q_proj.pt
2.6M f1.0_s0.1576_k1024_c80_k_proj.pt
2.6M f1.0_s0.1576_k1024_c80_q_proj.pt
1.3M f1.0_s0.1576_k256_c160_k_proj.pt
1.3M f1.0_s0.1576_k256_c160_q_proj.pt
5.1M f1.0_s0.1576_k4096_c40_k_proj.pt
5.1M f1.0_s0.1576_k4096_c40_q_proj.pt
321K f1.0_s0.1576_k64_c160_k_proj.pt
321K f1.0_s0.1576_k64_c160_q_proj.pt
2.6M f1.0_s0.2270_k1024_c80_k_proj.pt
2.6M f1.0_s0.2270_k1024_c80_q_proj.pt
1.3M f1.0_s0.2270_k256_c160_k_proj.pt
1.3M f1.0_s0.2270_k256_c160_q_proj.pt
5.1M f1.0_s0.2270_k4096_c40_k_proj.pt
5.1M f1.0_s0.2270_k4096_c40_q_proj.pt
321K f1.0_s0.2270_k64_c160_k_proj.pt
321K f1.0_s0.2270_k64_c160_q_proj.pt
2.6M f1.0_s0.3211_k1024_c80_k_proj.pt
2.6M f1.0_s0.3211_k1024_c80_q_proj.pt
1.3M f1.0_s0.3211_k256_c160_k_proj.pt
1.3M f1.0_s0.3211_k256_c160_q_proj.pt
5.1M f1.0_s0.3211_k4096_c40_k_proj.pt
5.1M f1.0_s0.3211_k4096_c40_q_proj.pt
321K f1.0_s0.3211_k64_c160_k_proj.pt
321K f1.0_s0.3211_k64_c160_q_proj.pt
2.6M f1.0_s0.4469_k1024_c80_k_proj.pt
2.6M f1.0_s0.4469_k1024_c80_q_proj.pt
1.3M f1.0_s0.4469_k256_c160_k_proj.pt
1.3M f1.0_s0.4469_k256_c160_q_proj.pt
5.1M f1.0_s0.4469_k4096_c40_k_proj.pt
5.1M f1.0_s0.4469_k4096_c40_q_proj.pt
321K f1.0_s0.4469_k64_c160_k_proj.pt
321K f1.0_s0.4469_k64_c160_q_proj.pt
2.6M f1.0_s0.6128_k1024_c80_k_proj.pt
2.6M f1.0_s0.6128_k1024_c80_q_proj.pt
1.3M f1.0_s0.6128_k256_c160_k_proj.pt
1.3M f1.0_s0.6128_k256_c160_q_proj.pt
5.1M f1.0_s0.6128_k4096_c40_k_proj.pt
5.1M f1.0_s0.6128_k4096_c40_q_proj.pt
321K f1.0_s0.6128_k64_c160_k_proj.pt
321K f1.0_s0.6128_k64_c160_q_proj.pt
2.6M f1.0_s0.8289_k1024_c80_k_proj.pt
2.6M f1.0_s0.8289_k1024_c80_q_proj.pt
1.3M f1.0_s0.8289_k256_c160_k_proj.pt
1.3M f1.0_s0.8289_k256_c160_q_proj.pt
5.1M f1.0_s0.8289_k4096_c40_k_proj.pt
5.1M f1.0_s0.8289_k4096_c40_q_proj.pt
321K f1.0_s0.8289_k64_c160_k_proj.pt
321K f1.0_s0.8289_k64_c160_q_proj.pt
2.6M f1.0_s1.1072_k1024_c80_k_proj.pt
2.6M f1.0_s1.1072_k1024_c80_q_proj.pt
1.3M f1.0_s1.1072_k256_c160_k_proj.pt
1.3M f1.0_s1.1072_k256_c160_q_proj.pt
5.1M f1.0_s1.1072_k4096_c40_k_proj.pt
5.1M f1.0_s1.1072_k4096_c40_q_proj.pt
321K f1.0_s1.1072_k64_c160_k_proj.pt
321K f1.0_s1.1072_k64_c160_q_proj.pt
2.6M f1.0_s11.9776_k1024_c80_k_proj.pt
2.6M f1.0_s11.9776_k1024_c80_q_proj.pt
1.3M f1.0_s11.9776_k256_c160_k_proj.pt
1.3M f1.0_s11.9776_k256_c160_q_proj.pt
5.1M f1.0_s11.9776_k4096_c40_k_proj.pt
5.1M f1.0_s11.9776_k4096_c40_q_proj.pt
321K f1.0_s11.9776_k64_c160_k_proj.pt
321K f1.0_s11.9776_k64_c160_q_proj.pt
2.6M f1.0_s14.6147_k1024_c80_k_proj.pt
2.6M f1.0_s14.6147_k1024_c80_q_proj.pt
1.3M f1.0_s14.6147_k256_c160_k_proj.pt
1.3M f1.0_s14.6147_k256_c160_q_proj.pt
5.1M f1.0_s14.6147_k4096_c40_k_proj.pt
5.1M f1.0_s14.6147_k4096_c40_q_proj.pt
321K f1.0_s14.6147_k64_c160_k_proj.pt
321K f1.0_s14.6147_k64_c160_q_proj.pt
2.6M f1.0_s1.4621_k1024_c80_k_proj.pt
2.6M f1.0_s1.4621_k1024_c80_q_proj.pt
1.3M f1.0_s1.4621_k256_c160_k_proj.pt
1.3M f1.0_s1.4621_k256_c160_q_proj.pt
5.1M f1.0_s1.4621_k4096_c40_k_proj.pt
5.1M f1.0_s1.4621_k4096_c40_q_proj.pt
321K f1.0_s1.4621_k64_c160_k_proj.pt
321K f1.0_s1.4621_k64_c160_q_proj.pt
2.6M f1.0_s1.9104_k1024_c80_k_proj.pt
2.6M f1.0_s1.9104_k1024_c80_q_proj.pt
1.3M f1.0_s1.9104_k256_c160_k_proj.pt
1.3M f1.0_s1.9104_k256_c160_q_proj.pt
5.1M f1.0_s1.9104_k4096_c40_k_proj.pt
5.1M f1.0_s1.9104_k4096_c40_q_proj.pt
321K f1.0_s1.9104_k64_c160_k_proj.pt
321K f1.0_s1.9104_k64_c160_q_proj.pt
2.6M f1.0_s2.4716_k1024_c80_k_proj.pt
2.6M f1.0_s2.4716_k1024_c80_q_proj.pt
1.3M f1.0_s2.4716_k256_c160_k_proj.pt
1.3M f1.0_s2.4716_k256_c160_q_proj.pt
5.1M f1.0_s2.4716_k4096_c40_k_proj.pt
5.1M f1.0_s2.4716_k4096_c40_q_proj.pt
321K f1.0_s2.4716_k64_c160_k_proj.pt
321K f1.0_s2.4716_k64_c160_q_proj.pt
2.6M f1.0_s3.1686_k1024_c80_k_proj.pt
2.6M f1.0_s3.1686_k1024_c80_q_proj.pt
1.3M f1.0_s3.1686_k256_c160_k_proj.pt
1.3M f1.0_s3.1686_k256_c160_q_proj.pt
5.1M f1.0_s3.1686_k4096_c40_k_proj.pt
5.1M f1.0_s3.1686_k4096_c40_q_proj.pt
321K f1.0_s3.1686_k64_c160_k_proj.pt
321K f1.0_s3.1686_k64_c160_q_proj.pt
2.6M f1.0_s4.0277_k1024_c80_k_proj.pt
2.6M f1.0_s4.0277_k1024_c80_q_proj.pt
1.3M f1.0_s4.0277_k256_c160_k_proj.pt
1.3M f1.0_s4.0277_k256_c160_q_proj.pt
5.1M f1.0_s4.0277_k4096_c40_k_proj.pt
5.1M f1.0_s4.0277_k4096_c40_q_proj.pt
321K f1.0_s4.0277_k64_c160_k_proj.pt
321K f1.0_s4.0277_k64_c160_q_proj.pt
2.6M f1.0_s5.0793_k1024_c80_k_proj.pt
2.6M f1.0_s5.0793_k1024_c80_q_proj.pt
1.3M f1.0_s5.0793_k256_c160_k_proj.pt
1.3M f1.0_s5.0793_k256_c160_q_proj.pt
5.1M f1.0_s5.0793_k4096_c40_k_proj.pt
5.1M f1.0_s5.0793_k4096_c40_q_proj.pt
321K f1.0_s5.0793_k64_c160_k_proj.pt
321K f1.0_s5.0793_k64_c160_q_proj.pt
2.6M f1.0_s6.3579_k1024_c80_k_proj.pt
2.6M f1.0_s6.3579_k1024_c80_q_proj.pt
1.3M f1.0_s6.3579_k256_c160_k_proj.pt
1.3M f1.0_s6.3579_k256_c160_q_proj.pt
5.1M f1.0_s6.3579_k4096_c40_k_proj.pt
5.1M f1.0_s6.3579_k4096_c40_q_proj.pt
321K f1.0_s6.3579_k64_c160_k_proj.pt
321K f1.0_s6.3579_k64_c160_q_proj.pt
2.6M f1.0_s7.9029_k1024_c80_k_proj.pt
2.6M f1.0_s7.9029_k1024_c80_q_proj.pt
1.3M f1.0_s7.9029_k256_c160_k_proj.pt
1.3M f1.0_s7.9029_k256_c160_q_proj.pt
5.1M f1.0_s7.9029_k4096_c40_k_proj.pt
5.1M f1.0_s7.9029_k4096_c40_q_proj.pt
321K f1.0_s7.9029_k64_c160_k_proj.pt
321K f1.0_s7.9029_k64_c160_q_proj.pt
2.6M f1.0_s9.7593_k1024_c80_k_proj.pt
2.6M f1.0_s9.7593_k1024_c80_q_proj.pt
1.3M f1.0_s9.7593_k256_c160_k_proj.pt
1.3M f1.0_s9.7593_k256_c160_q_proj.pt
5.1M f1.0_s9.7593_k4096_c40_k_proj.pt
5.1M f1.0_s9.7593_k4096_c40_q_proj.pt
321K f1.0_s9.7593_k64_c160_k_proj.pt
321K f1.0_s9.7593_k64_c160_q_proj.pt
721K f2.25_s0.0292_k144_c160_k_proj.pt
721K f2.25_s0.0292_k144_c160_q_proj.pt
5.7M f2.25_s0.0292_k2304_c80_k_proj.pt
5.7M f2.25_s0.0292_k2304_c80_q_proj.pt
2.9M f2.25_s0.0292_k576_c160_k_proj.pt
2.9M f2.25_s0.0292_k576_c160_q_proj.pt
12M f2.25_s0.0292_k9216_c40_k_proj.pt
12M f2.25_s0.0292_k9216_c40_q_proj.pt
721K f2.25_s0.0463_k144_c160_k_proj.pt
721K f2.25_s0.0463_k144_c160_q_proj.pt
5.7M f2.25_s0.0463_k2304_c80_k_proj.pt
5.7M f2.25_s0.0463_k2304_c80_q_proj.pt
2.9M f2.25_s0.0463_k576_c160_k_proj.pt
2.9M f2.25_s0.0463_k576_c160_q_proj.pt
12M f2.25_s0.0463_k9216_c40_k_proj.pt
12M f2.25_s0.0463_k9216_c40_q_proj.pt
721K f2.25_s0.0713_k144_c160_k_proj.pt
721K f2.25_s0.0713_k144_c160_q_proj.pt
5.7M f2.25_s0.0713_k2304_c80_k_proj.pt
5.7M f2.25_s0.0713_k2304_c80_q_proj.pt
2.9M f2.25_s0.0713_k576_c160_k_proj.pt
2.9M f2.25_s0.0713_k576_c160_q_proj.pt
12M f2.25_s0.0713_k9216_c40_k_proj.pt
12M f2.25_s0.0713_k9216_c40_q_proj.pt
721K f2.25_s0.1072_k144_c160_k_proj.pt
721K f2.25_s0.1072_k144_c160_q_proj.pt
5.7M f2.25_s0.1072_k2304_c80_k_proj.pt
5.7M f2.25_s0.1072_k2304_c80_q_proj.pt
2.9M f2.25_s0.1072_k576_c160_k_proj.pt
2.9M f2.25_s0.1072_k576_c160_q_proj.pt
12M f2.25_s0.1072_k9216_c40_k_proj.pt
12M f2.25_s0.1072_k9216_c40_q_proj.pt
721K f2.25_s0.1576_k144_c160_k_proj.pt
721K f2.25_s0.1576_k144_c160_q_proj.pt
5.7M f2.25_s0.1576_k2304_c80_k_proj.pt
5.7M f2.25_s0.1576_k2304_c80_q_proj.pt
2.9M f2.25_s0.1576_k576_c160_k_proj.pt
2.9M f2.25_s0.1576_k576_c160_q_proj.pt
12M f2.25_s0.1576_k9216_c40_k_proj.pt
12M f2.25_s0.1576_k9216_c40_q_proj.pt
721K f2.25_s0.2270_k144_c160_k_proj.pt
721K f2.25_s0.2270_k144_c160_q_proj.pt
5.7M f2.25_s0.2270_k2304_c80_k_proj.pt
5.7M f2.25_s0.2270_k2304_c80_q_proj.pt
2.9M f2.25_s0.2270_k576_c160_k_proj.pt
2.9M f2.25_s0.2270_k576_c160_q_proj.pt
12M f2.25_s0.2270_k9216_c40_k_proj.pt
12M f2.25_s0.2270_k9216_c40_q_proj.pt
721K f2.25_s0.3211_k144_c160_k_proj.pt
721K f2.25_s0.3211_k144_c160_q_proj.pt
5.7M f2.25_s0.3211_k2304_c80_k_proj.pt
5.7M f2.25_s0.3211_k2304_c80_q_proj.pt
2.9M f2.25_s0.3211_k576_c160_k_proj.pt
2.9M f2.25_s0.3211_k576_c160_q_proj.pt
12M f2.25_s0.3211_k9216_c40_k_proj.pt
12M f2.25_s0.3211_k9216_c40_q_proj.pt
721K f2.25_s0.4469_k144_c160_k_proj.pt
721K f2.25_s0.4469_k144_c160_q_proj.pt
5.7M f2.25_s0.4469_k2304_c80_k_proj.pt
5.7M f2.25_s0.4469_k2304_c80_q_proj.pt
2.9M f2.25_s0.4469_k576_c160_k_proj.pt
2.9M f2.25_s0.4469_k576_c160_q_proj.pt
12M f2.25_s0.4469_k9216_c40_k_proj.pt
12M f2.25_s0.4469_k9216_c40_q_proj.pt
721K f2.25_s0.6128_k144_c160_k_proj.pt
721K f2.25_s0.6128_k144_c160_q_proj.pt
5.7M f2.25_s0.6128_k2304_c80_k_proj.pt
5.7M f2.25_s0.6128_k2304_c80_q_proj.pt
2.9M f2.25_s0.6128_k576_c160_k_proj.pt
2.9M f2.25_s0.6128_k576_c160_q_proj.pt
12M f2.25_s0.6128_k9216_c40_k_proj.pt
12M f2.25_s0.6128_k9216_c40_q_proj.pt
721K f2.25_s0.8289_k144_c160_k_proj.pt
721K f2.25_s0.8289_k144_c160_q_proj.pt
5.7M f2.25_s0.8289_k2304_c80_k_proj.pt
5.7M f2.25_s0.8289_k2304_c80_q_proj.pt
2.9M f2.25_s0.8289_k576_c160_k_proj.pt
2.9M f2.25_s0.8289_k576_c160_q_proj.pt
12M f2.25_s0.8289_k9216_c40_k_proj.pt
12M f2.25_s0.8289_k9216_c40_q_proj.pt
721K f2.25_s1.1072_k144_c160_k_proj.pt
721K f2.25_s1.1072_k144_c160_q_proj.pt
5.7M f2.25_s1.1072_k2304_c80_k_proj.pt
5.7M f2.25_s1.1072_k2304_c80_q_proj.pt
2.9M f2.25_s1.1072_k576_c160_k_proj.pt
2.9M f2.25_s1.1072_k576_c160_q_proj.pt
12M f2.25_s1.1072_k9216_c40_k_proj.pt
12M f2.25_s1.1072_k9216_c40_q_proj.pt
721K f2.25_s11.9776_k144_c160_k_proj.pt
721K f2.25_s11.9776_k144_c160_q_proj.pt
5.7M f2.25_s11.9776_k2304_c80_k_proj.pt
5.7M f2.25_s11.9776_k2304_c80_q_proj.pt
2.9M f2.25_s11.9776_k576_c160_k_proj.pt
2.9M f2.25_s11.9776_k576_c160_q_proj.pt
12M f2.25_s11.9776_k9216_c40_k_proj.pt
12M f2.25_s11.9776_k9216_c40_q_proj.pt
721K f2.25_s14.6147_k144_c160_k_proj.pt
721K f2.25_s14.6147_k144_c160_q_proj.pt
5.7M f2.25_s14.6147_k2304_c80_k_proj.pt
5.7M f2.25_s14.6147_k2304_c80_q_proj.pt
2.9M f2.25_s14.6147_k576_c160_k_proj.pt
2.9M f2.25_s14.6147_k576_c160_q_proj.pt
12M f2.25_s14.6147_k9216_c40_k_proj.pt
12M f2.25_s14.6147_k9216_c40_q_proj.pt
721K f2.25_s1.4621_k144_c160_k_proj.pt
721K f2.25_s1.4621_k144_c160_q_proj.pt
5.7M f2.25_s1.4621_k2304_c80_k_proj.pt
5.7M f2.25_s1.4621_k2304_c80_q_proj.pt
2.9M f2.25_s1.4621_k576_c160_k_proj.pt
2.9M f2.25_s1.4621_k576_c160_q_proj.pt
12M f2.25_s1.4621_k9216_c40_k_proj.pt
12M f2.25_s1.4621_k9216_c40_q_proj.pt
721K f2.25_s1.9104_k144_c160_k_proj.pt
721K f2.25_s1.9104_k144_c160_q_proj.pt
5.7M f2.25_s1.9104_k2304_c80_k_proj.pt
5.7M f2.25_s1.9104_k2304_c80_q_proj.pt
2.9M f2.25_s1.9104_k576_c160_k_proj.pt
2.9M f2.25_s1.9104_k576_c160_q_proj.pt
12M f2.25_s1.9104_k9216_c40_k_proj.pt
12M f2.25_s1.9104_k9216_c40_q_proj.pt
721K f2.25_s2.4716_k144_c160_k_proj.pt
721K f2.25_s2.4716_k144_c160_q_proj.pt
5.7M f2.25_s2.4716_k2304_c80_k_proj.pt
5.7M f2.25_s2.4716_k2304_c80_q_proj.pt
2.9M f2.25_s2.4716_k576_c160_k_proj.pt
2.9M f2.25_s2.4716_k576_c160_q_proj.pt
12M f2.25_s2.4716_k9216_c40_k_proj.pt
12M f2.25_s2.4716_k9216_c40_q_proj.pt
721K f2.25_s3.1686_k144_c160_k_proj.pt
721K f2.25_s3.1686_k144_c160_q_proj.pt
5.7M f2.25_s3.1686_k2304_c80_k_proj.pt
5.7M f2.25_s3.1686_k2304_c80_q_proj.pt
2.9M f2.25_s3.1686_k576_c160_k_proj.pt
2.9M f2.25_s3.1686_k576_c160_q_proj.pt
12M f2.25_s3.1686_k9216_c40_k_proj.pt
12M f2.25_s3.1686_k9216_c40_q_proj.pt
721K f2.25_s4.0277_k144_c160_k_proj.pt
721K f2.25_s4.0277_k144_c160_q_proj.pt
5.7M f2.25_s4.0277_k2304_c80_k_proj.pt
5.7M f2.25_s4.0277_k2304_c80_q_proj.pt
2.9M f2.25_s4.0277_k576_c160_k_proj.pt
2.9M f2.25_s4.0277_k576_c160_q_proj.pt
12M f2.25_s4.0277_k9216_c40_k_proj.pt
12M f2.25_s4.0277_k9216_c40_q_proj.pt
721K f2.25_s5.0793_k144_c160_k_proj.pt
721K f2.25_s5.0793_k144_c160_q_proj.pt
5.7M f2.25_s5.0793_k2304_c80_k_proj.pt
5.7M f2.25_s5.0793_k2304_c80_q_proj.pt
2.9M f2.25_s5.0793_k576_c160_k_proj.pt
2.9M f2.25_s5.0793_k576_c160_q_proj.pt
12M f2.25_s5.0793_k9216_c40_k_proj.pt
12M f2.25_s5.0793_k9216_c40_q_proj.pt
721K f2.25_s6.3579_k144_c160_k_proj.pt
721K f2.25_s6.3579_k144_c160_q_proj.pt
5.7M f2.25_s6.3579_k2304_c80_k_proj.pt
5.7M f2.25_s6.3579_k2304_c80_q_proj.pt
2.9M f2.25_s6.3579_k576_c160_k_proj.pt
2.9M f2.25_s6.3579_k576_c160_q_proj.pt
12M f2.25_s6.3579_k9216_c40_k_proj.pt
12M f2.25_s6.3579_k9216_c40_q_proj.pt
721K f2.25_s7.9029_k144_c160_k_proj.pt
721K f2.25_s7.9029_k144_c160_q_proj.pt
5.7M f2.25_s7.9029_k2304_c80_k_proj.pt
5.7M f2.25_s7.9029_k2304_c80_q_proj.pt
2.9M f2.25_s7.9029_k576_c160_k_proj.pt
2.9M f2.25_s7.9029_k576_c160_q_proj.pt
12M f2.25_s7.9029_k9216_c40_k_proj.pt
12M f2.25_s7.9029_k9216_c40_q_proj.pt
721K f2.25_s9.7593_k144_c160_k_proj.pt
721K f2.25_s9.7593_k144_c160_q_proj.pt
5.7M f2.25_s9.7593_k2304_c80_k_proj.pt
5.7M f2.25_s9.7593_k2304_c80_q_proj.pt
2.9M f2.25_s9.7593_k576_c160_k_proj.pt
2.9M f2.25_s9.7593_k576_c160_q_proj.pt
12M f2.25_s9.7593_k9216_c40_k_proj.pt
12M f2.25_s9.7593_k9216_c40_q_proj.pt
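For reference, the filenames above follow the same scheme that get_fixtures consumes, so the export side was presumably something along these lines (a sketch only; save_fixtures is a name I'm introducing, and the hook that actually captures q_proj/k_proj inside the Unet isn't shown):

from os.path import join
from torch import save, FloatTensor

def save_fixtures(
  q_proj: FloatTensor,
  k_proj: FloatTensor,
  sigma: float,
  head_dim: int,
  key_length_factor: float,
  key_tokens: int,
  in_dir: str = '/home/birch/git/diffusers-play/out_tensor',
) -> None:
  # mirrors the path format that get_fixtures (above) expects to load
  prefix = join(in_dir, f'f{key_length_factor}_s{sigma:.4f}_k{key_tokens}_c{head_dim}')
  save(q_proj, f'{prefix}_q_proj.pt')
  save(k_proj, f'{prefix}_k_proj.pt')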
Instead of taking topk key tokens (which favours larger scores), I tried a nearest-neighbour downsample of the 9216 key tokens to 4096: basically like discarding every second token.
resample_softmax attempts to sort the distribution first (which I think could give a fairer resample, assuming you interpolate between the datapoints), prior to resampling. However, the sort ran out of memory, so I wasn't able to evaluate this.
import gc
import torch
from torch.nn.functional import interpolate

def resample_softmax(x: FloatTensor, k: int, dim=-1) -> FloatTensor:
  """Softmax with a modified denominator. For each query token: sorts attn_scores, resamples key dimension to size k; you can use this to increase/decrease denominator to the magnitude on which the model was trained."""
  maxes = x.max(dim, keepdim=True).values
  diffs = x-maxes
  del maxes
  torch.cuda.empty_cache()
  gc.collect()
  diffs_sorted = diffs.sort(dim=dim).values
  x_exp = diffs.exp()
  del diffs
  # for downsampling:
  #   mode='area' has best PSNR.. but maybe that's not important.
  # for upsampling:
  #   lerping between attn scores (mode='linear') feels reasonable
  #   repeating attn scores (mode='nearest-exact') might be reasonable too
  # not sure whether we'd care about anti-aliasing
  # interpolate's size= describes only the resampled (key) dim; mode='linear' expects 3D input [batch*heads, query_tokens, key_tokens]
  diffs_resampled = interpolate(diffs_sorted, size=k, mode='linear', antialias=False)
  del diffs_sorted
  diffs_exp_sum = diffs_resampled.exp().sum(dim, keepdim=True)
  del diffs_resampled
  quotient = x_exp/diffs_exp_sum
  return quotient
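If the full sort is what blows up memory, one possible workaround (an untested sketch; resample_softmax_chunked and chunk_size are names I'm introducing) is to compute the sorted, resampled denominator one slice of query tokens at a time, so only a chunk's worth of sorted scores is resident at once. The numerator still materialises the full exp tensor, which the original version already does:

import torch
from torch import FloatTensor
from torch.nn.functional import interpolate

def resample_softmax_chunked(x: FloatTensor, k: int, chunk_size: int = 1024) -> FloatTensor:
  """Same numerator as resample_softmax, but sorts+resamples the denominator in chunks of query tokens.
  Expects x of shape [batch*heads, query_tokens, key_tokens]; resamples the key dim to size k."""
  maxes = x.max(-1, keepdim=True).values
  x_exp = (x - maxes).exp()
  denoms = []
  for start in range(0, x.size(1), chunk_size):
    chunk = x[:, start:start+chunk_size] - maxes[:, start:start+chunk_size]
    chunk_sorted = chunk.sort(dim=-1).values
    # mode='linear' needs 3D input [N, C, L]; resample the key dim to k
    chunk_resampled = interpolate(chunk_sorted, size=k, mode='linear')
    denoms.append(chunk_resampled.exp().sum(-1, keepdim=True))
  denom = torch.cat(denoms, dim=1)
  return x_exp / denom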
The nearest-neighbour resample (resample_crude_softmax, below) ran to completion, but completely destroyed the image.
from torch.nn.functional import interpolate

def resample_crude_softmax(x: FloatTensor, k: int, dim=-1) -> FloatTensor:
  """Softmax with a modified denominator. For each query token: resamples key dimension to size k; you can use this to increase/decrease denominator to the magnitude on which the model was trained."""
  maxes = x.max(dim, keepdim=True).values
  diffs = x-maxes
  del maxes
  x_exp = diffs.exp()
  # mode='nearest-exact' expects 3D input [batch*heads, query_tokens, key_tokens]
  diffs_resampled = interpolate(diffs, scale_factor=k/diffs.size(-1), mode='nearest-exact', antialias=False)
  del diffs
  diffs_exp_sum = diffs_resampled.exp().sum(dim, keepdim=True)
  del diffs_resampled
  quotient = x_exp/diffs_exp_sum
  return quotient
Get yer tensors here (q_proj, k_proj). From these you can compute attn_scores and attn_probs as above. These are for the first (biggest) self-attn layer, at the first (noisiest) sigma, 14.6147.
Batch size is 16 (cond and uncond, repeated over 8 attn heads).
512x512 (in-distribution; 4096 tokens):
https://birchlabs.co.uk/share/out_tensor/f1.0_s14.6147_k4096_c40_q_proj.pt (5.2MB)
https://birchlabs.co.uk/share/out_tensor/f1.0_s14.6147_k4096_c40_k_proj.pt (5.2MB)
768x768 (out-of-distribution by 2.25x; 9216 tokens):
https://birchlabs.co.uk/share/out_tensor/f2.25_s14.6147_k9216_c40_q_proj.pt (11.8MB)
https://birchlabs.co.uk/share/out_tensor/f2.25_s14.6147_k9216_c40_k_proj.pt (11.8MB)
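If you just want to reproduce the comparison from these two pairs of tensors (without the full fixture directory), something like the following should work. It's an untested sketch: it assumes you've downloaded the four files above into the working directory, and that get_attn_scores and topk_softmax from the gist are in scope.

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
q_in  = torch.load('f1.0_s14.6147_k4096_c40_q_proj.pt',  weights_only=True, map_location=device)
k_in  = torch.load('f1.0_s14.6147_k4096_c40_k_proj.pt',  weights_only=True, map_location=device)
q_out = torch.load('f2.25_s14.6147_k9216_c40_q_proj.pt', weights_only=True, map_location=device)
k_out = torch.load('f2.25_s14.6147_k9216_c40_k_proj.pt', weights_only=True, map_location=device)

scale = 40 ** -.5  # head_dim of this layer is 40
in_dist_attn_probs   = get_attn_scores(q_in,  k_in,  scale=scale).softmax(dim=-1)
out_dist_attn_scores = get_attn_scores(q_out, k_out, scale=scale)
out_dist_attn_probs  = out_dist_attn_scores.softmax(dim=-1)
# resize the out-of-distribution denominator to the in-distribution key length
out_dist_topk_attn_probs = topk_softmax(out_dist_attn_scores, k=4096)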