From scaled dot product to metric tensor

NOTE: WIP

In this section, we point out that the multi-headed scaled dot-product attention introduced in 2017 is equivalent to a general quadratic form that lends itself to a more efficient reformulation. Furthermore, we argue, on the grounds of efficiency, interpretability, and regularization, for imposing that this form be a metric. What follows is a short exposition of scaled dot-product attention using Ricci calculus, transitioning into the proposed quadratic and metric attentions.

Let $Q_d^{nk}$, $K_d^{nk}$ and $V_d^{nk}$ each be $n$ learnable linear maps from $\mathbb{R}^d$ to $\mathbb{R}^k$ that act on $b$ sequences of $c$ input embeddings to produce the well-known keys, queries and values,

$$ k^{bnck} = K_d^{nk} x^{bcd} $$

$$ q^{bnck} = Q_d^{nk} x^{bcd} $$

$$ v^{bnck} = V_d^{nk} x^{bcd} $$
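In code, these contractions translate almost verbatim into `einsum` expressions. The following is a minimal PyTorch sketch, not part of the gist; the shapes, variable names, and the use of `torch.einsum` are illustrative assumptions:

```python
import torch

b, c, d = 2, 16, 64   # batch size, sequence length, embedding dim (illustrative)
n, k = 4, 16          # number of heads, head dim (illustrative)

x = torch.randn(b, c, d)   # input embeddings x^{bcd}
Q = torch.randn(n, k, d)   # query maps Q_d^{nk}
K = torch.randn(n, k, d)   # key maps   K_d^{nk}
V = torch.randn(n, k, d)   # value maps V_d^{nk}

# k^{bnck} = K_d^{nk} x^{bcd}, and likewise for queries and values
keys    = torch.einsum('nkd,bcd->bnck', K, x)
queries = torch.einsum('nkd,bcd->bnck', Q, x)
values  = torch.einsum('nkd,bcd->bnck', V, x)
```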

Each query is dotted with every key and the result is inversely scaled by the square root of the dimensionality of the projection space before being softmaxed along one of the directions, producing

$$ s^{bncc'} = \textrm{softmax}^{c'} \left ( \frac{1}{\sqrt{k}} q^{bnck} k^{bnc'k'} \delta_{kk'} \right ) $$

where $s^{bncc'}$ represents the influence of embedding $c'$ on embedding $c$. The scaling by $1/\sqrt{k}$ is what gives this core mechanism the name of scaled dot-product attention. The scores are then used in a weighted sum of the values to produce new representations
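Continuing the sketch above, the scores are a single contraction followed by scaling and a softmax (`e` stands in for the primed index $c'$, since `einsum` labels must be single letters):

```python
import math

# r^{bncc'} = q^{bnck} k^{bnc'k'} δ_{kk'}: contract queries and keys over the head axis
r = torch.einsum('bnck,bnek->bnce', queries, keys)

# s^{bncc'} = softmax over c' of the scaled scores
s = torch.softmax(r / math.sqrt(k), dim=-1)
```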

$$ t^{bnck} = s^{bncc'} v^{bnc''k} \delta_{c'c''} $$

and the result is flattened and projected back to the original embedding space

$$ \bar t^{bcl} = t^{bnck} $$

where the head index $n$ and the feature index $k$ are merged into a single index $l$ of size $nk$, and

$$ y^{bcd} = E_l^d \bar t^{bcl} $$
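In the same illustrative sketch, the weighted sum, the flattening of the head and feature axes into $l = nk$, and the output projection read as follows (`E` is a stand-in for the learnable output map, an assumption not named in the gist beyond $E_l^d$):

```python
# t^{bnck} = s^{bncc'} v^{bnc'k}: weighted sum of the values
t = torch.einsum('bnce,bnek->bnck', s, values)

# bar t^{bcl}: merge the head axis n and the feature axis k into l = n*k
t_flat = t.permute(0, 2, 1, 3).reshape(b, c, n * k)

# y^{bcd} = E_l^d bar t^{bcl}: project back to the embedding space
E = torch.randn(n * k, d)
y = torch.einsum('ld,bcl->bcd', E, t_flat)
```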

Our focus is on the step right before the application of the softmax,

$$ r^{bncc'} = q^{bnck} k^{bnc'k'} \delta_{kk'} $$

By substituting the operations that produced the queries and keys,

$$ r^{bncc'} = Q_d^{nk} K_{d'}^{nk'} \delta_{kk'} x^{bcd} x^{bc'd'} $$

and by defining $U^n_{dd'}=Q_d^{nk} K_{d'}^{nk'} \delta_{kk'} $, we can see how the quadratic form emerges

$$ r^{bncc'} = U^n_{dd'} x^{bcd} x^{bc'd'} $$
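The equivalence can be checked numerically within the same sketch: building $U^n_{dd'}$ from the query and key maps and contracting it twice with the inputs reproduces the pre-softmax scores up to floating-point error (`f` stands in for $c'$ here).

```python
# U^n_{dd'} = Q_d^{nk} K_{d'}^{nk'} δ_{kk'}: fuse the query and key maps
U = torch.einsum('nkd,nke->nde', Q, K)

# r^{bncc'} = U^n_{dd'} x^{bcd} x^{bc'd'}: the quadratic form
r_quadratic = torch.einsum('nde,bcd,bfe->bncf', U, x, x)

assert torch.allclose(r, r_quadratic, atol=1e-4)
```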

Disregarding training dynamics and efficiency considerations, it is evident that this is a complete mathematical equivalence. However, there is good reason not to keep this form. Indeed, the motivation for using multiple heads that operate on a smaller-dimensional space is that, whereas the quadratic form makes use of $nd^2$ parameters, the original formulation uses $2ndk$; thus, as long as $k < d/2$, that approach is more memory efficient.
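For instance, at the scale of the original Transformer, with $d = 512$, $n = 8$ and $k = 64$, the factored maps use $2ndk = 524{,}288$ parameters, whereas the explicit quadratic form would need $nd^2 = 2{,}097{,}152$, a fourfold increase.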

However, it is not the most efficient reformulation that can be squeezed out of the quadratic form.
