@wdevazelhes
Last active June 9, 2019 18:43
Trying to use kernel approximations by explicit feature maps for softmax self-attention.
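The idea, in short: softmax attention weights are proportional to exp(x . k), and for gamma = 1/2 the RBF kernel factorizes as exp(-||x - k||^2 / 2) = exp(-||x||^2 / 2) * exp(-||k||^2 / 2) * exp(x . k). Rescaling an RBF random feature map by exp(||x||^2 / 2) therefore yields features whose inner products approximate exp(x . k); that rescaling is exactly what the tweaked samplers below do.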
import numpy as np
from sklearn.utils.extmath import softmax
from sklearn.kernel_approximation import RBFSampler
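# Fastfood comes from the scikit-learn-extra package (pip install scikit-learn-extra)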
from sklearn_extra.kernel_approximation import Fastfood
seed = 42
rng = np.random.RandomState(seed)
D = 20
# This does not seem to work in every case: it seems to work better if the
# samples are positive and normalized to sum to one. See this example:
# https://github.com/scikit-learn/scikit-learn/commit/3f7cec39997e28b4056bdea4fd04572cfaad0080#diff-364d5b0b1ecfc510277a8f8072d884e7
X = rng.random_sample(size=(100, D))
X /= X.sum(axis=1)[:, np.newaxis]
# Let's take the mean of the samples as the key, for instance
key = np.mean(X, axis=0, keepdims=True)
class TweakedRBFSampler(RBFSampler):
    """RBFSampler whose features are rescaled so that their inner products
    approximate exp(x . y) rather than the RBF kernel (assumes gamma = 1/2)."""

    def transform(self, X):
        tweak = np.exp((np.linalg.norm(X, axis=1, keepdims=True) ** 2) / 2)
        return super(TweakedRBFSampler, self).transform(X) * tweak
class TweakedFastfoodSampler(Fastfood):
    """Fastfood sampler with the same exp(||x||^2 / 2) rescaling
    (assumes sigma = 1, i.e. gamma = 1/2)."""

    def transform(self, X):
        tweak = np.exp((np.linalg.norm(X, axis=1, keepdims=True) ** 2) / 2)
        return super(TweakedFastfoodSampler, self).transform(X) * tweak
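# Quick sanity check (an illustrative sketch): with enough components, inner
# products of the tweaked features should approach exp(x . y) directly.
check_sampler = TweakedRBFSampler(n_components=10000, gamma=0.5,
                                  random_state=seed)
X_check = check_sampler.fit_transform(X)
print('max abs error of the exp-kernel approximation:',
      np.abs(X_check.dot(X_check.T) - np.exp(X.dot(X.T))).max())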
def attention_projection(X, key):
    # Exact softmax self-attention of the samples onto a single key
    return softmax(X.dot(key.T).T)[0].dot(X)
def attention_projection_approx(sampler, X, key):
    X_f = sampler.fit_transform(X)   # feature map of the samples
    key_f = sampler.transform(key)   # feature map of the key
    A = X.T.dot(X_f)                 # numerator statistics, independent of the key
    Z = X_f.sum(axis=0)              # denominator statistics
    return A.dot(key_f.T) / Z.dot(key_f.T)
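# Why this works: the exact self-attention output is
#     softmax(X k^T)^T X = (sum_i exp(x_i . k) * x_i) / (sum_j exp(x_j . k)).
# Substituting exp(x . k) ~= phi(x) . phi(k) and pulling phi(k) out of both
# sums by linearity gives (X^T X_f) key_f^T / (Z key_f^T), so A and Z can be
# precomputed once and reused for any new key.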
# The two samplers use different names for the kernel coefficient: RBFSampler
# takes gamma and Fastfood takes sigma, with gamma = 1 / (2 * sigma**2), so
# gamma=0.5 and sigma=1. parameterize the same kernel exp(-||x - y||^2 / 2).
sampler_1 = TweakedRBFSampler(n_components=10000, gamma=0.5)
sampler_2 = TweakedFastfoodSampler(n_components=10000, sigma=1.)
print('True value of self-attention:')
print(attention_projection(X, key))
print("Approximation of self-attention by a modified version of scikit-learn's"
"RBFSampler:")
print(attention_projection_approx(sampler_1, X, key))
print('Approximation of self-attention by a modified FastFood Sampler:')
print(attention_projection_approx(sampler_2, X, key))
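# Illustrative extra check: relative error of each approximation against the
# exact value. Note each call re-fits the sampler, drawing fresh random
# features.
exact = attention_projection(X, key)
for name, sampler in [('RBFSampler', sampler_1), ('Fastfood', sampler_2)]:
    approx = attention_projection_approx(sampler, X, key).ravel()
    print(name, 'relative error:',
          np.linalg.norm(approx - exact) / np.linalg.norm(exact))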