@Birch-san
Created April 3, 2023 00:16
Questionable softmax
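
A drop-in softmax replacement for attention scores, evidently aimed at sampling at a resolution other than the one the model was trained at (here: trained at 512px, targeting 200px, with a VAE downsampling factor of 8). Instead of normalizing over the actual number of key tokens, it pads (or trims) the denominator with a "typical" exponentiated score, so the distribution is normalized as if the trained number of key tokens were present.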
from torch import FloatTensor

vae_scale_factor = 8
# token counts of the flattened latent grid at the trained (512px) and target (200px) resolutions
typical_self_attn_key_length = (512 / vae_scale_factor) * (512 / vae_scale_factor)  # 64*64 = 4096
desired_self_attn_key_length = (200 / vae_scale_factor) * (200 / vae_scale_factor)  # 25*25 = 625
# assumption: whether the scores being normalized come from self-attention
# (this flag is referenced but never defined in the original gist)
is_self_attn = True
key_length_factor = desired_self_attn_key_length / typical_self_attn_key_length if is_self_attn else 1.

def softmax(x: FloatTensor, dim=-1) -> FloatTensor:
    key_tokens = x.size(dim)
    # numerically-stable softmax: subtract the per-row max before exponentiating
    maxes = x.max(dim, keepdim=True).values
    diffs = x - maxes
    x_exp = diffs.exp()
    # a "typical" low-scoring logit (17.5th percentile); quantile() requires float32,
    # so round-trip through it for half-precision inputs
    avg_diff = diffs.float().quantile(.175, dim=dim, keepdim=True).to(diffs.dtype)
    avg_diff_exp = avg_diff.exp()
    x_exp_sum = x_exp.sum(dim, keepdim=True)
    # normalize as if the key sequence had its trained length, padding (or trimming)
    # the denominator with copies of the typical exponentiated logit
    preferred_token_count = key_tokens / key_length_factor
    x_exp_sum = x_exp_sum + avg_diff_exp * (preferred_token_count - key_tokens)
    return x_exp / x_exp_sum
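
A minimal usage sketch (the batch, head, and token shapes below are illustrative assumptions, not part of the gist): attention scores for a 25x25 latent grid, i.e. 625 key tokens, normalized as though the trained 4096 were present.

import torch

# toy scaled-dot-product attention scores: 1 batch, 8 heads,
# 625 query/key tokens (a 25x25 latent grid at vae_scale_factor=8)
q = torch.randn(1, 8, 625, 64)
k = torch.randn(1, 8, 625, 64)
scores = q @ k.transpose(-1, -2) / (q.size(-1) ** 0.5)

probs = softmax(scores, dim=-1)
# rows no longer sum to 1: the denominator was padded up to the trained
# key count (4096), so row sums fall well below 1 here
print(probs.sum(dim=-1).mean())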