Let $P^{nk}_d$ be $N_n$ learnable projections from $\mathbf R^{N_d}$ to $\mathbf R^{N_k}$ and $x^{bcd}$ a batch of $N_b$ sequences containing $N_c$ embeddings from $\mathbf R^{N_d}$. The action of these projections is expressed in Ricci notation by
$$p^{bnck} = P^{nk}_d x^{bcd}$$
At the heart of the proposed attention mechanism is a learnable dot product between every pair of projected embeddings. This is achieved using $N_n$ learnable metric tensors $M^{n}_{kk'}$ and is given by
$$q^{bncc'} = M^{n}_{kk'} p^{bnck} p^{bnc'k'}$$
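To make the index bookkeeping concrete, here is a minimal PyTorch sketch of the two expressions above; the toy dimensions and tensor names (`P`, `M`, `x`, `p`, `q`) are chosen to mirror the text, not any actual implementation in this repository.

```python
# Minimal sketch of the projections and the metric dot product, with toy dimensions.
import torch

N_b, N_c, N_d, N_n, N_k = 2, 5, 16, 4, 8

P = torch.randn(N_n, N_k, N_d)            # P^{nk}_d: N_n projections R^{N_d} -> R^{N_k}
M = torch.randn(N_n, N_k, N_k)
M = 0.5 * (M + M.transpose(-1, -2))       # M^{n}_{kk'}: one symmetric metric per projection
x = torch.randn(N_b, N_c, N_d)            # x^{bcd}: batch of sequences of embeddings

# p^{bnck} = P^{nk}_d x^{bcd}
p = torch.einsum("nkd,bcd->bnck", P, x)

# q^{bncc'} = M^{n}_{kk'} p^{bnck} p^{bnc'k'}   (i = c, j = c' below)
q = torch.einsum("nkm,bnik,bnjm->bnij", M, p, p)
```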
The metric tensor is symmetric, so we can reduce the number of computations by grouping the terms over the pairs with $k \leq k'$,
$$q^{bncc'} = \sum_{k \leq k'} M^{n}_{kk'} \left( \delta^{kk'} p^{bnck} p^{bnc'k} + \left(1 - \delta^{kk'}\right)\left( p^{bnck} p^{bnc'k'} + p^{bnck'} p^{bnc'k} \right) \right)$$
Let $F_N(v, w)$ be a pairing function that indexes the elements on and above the diagonal of a matrix from $\mathbf R^{N\times N}$, and let $f$ and $g$ be integer-valued functions that retrieve the first and second arguments of $F_N$, that is
$$ v = f(F_{N}(v, w)) $$
and
$$ w = g(F_{N}(v, w)) $$
Such an arrangement is easily achieved by storing two arrays to be used as lookup tables for $f$ and $g$. Finally, let $l=F_{N_k}(k, k')$, and define
$$ \bar M^n_{l} = M^n_{f(l)g(l)} $$
which we use to rewrite our original expression as
$$q^{bncc'} = \bar M^n_{l} \left( \delta^{f(l)g(l)} p^{bncf(l)} p^{bnc'f(l)} + \tilde \delta^{f(l)g(l)} \left( p^{bncf(l)} p^{bnc'g(l)} + p^{bncg(l)} p^{bnc'f(l)} \right) \right)$$
where $\tilde \delta^{f(l)g(l)} = 1 - \delta^{f(l)g(l)} $.
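Continuing the sketch above, the lookup tables for $f$ and $g$ can be realized with `torch.triu_indices`, and the grouped, flattened form can be checked against the dense $q^{bncc'}$. This is only an illustration of the bookkeeping, not the kernel itself.

```python
# Lookup tables for f and g over the on-and-above-diagonal pairs (k <= k'),
# and the flattened metric bar M^n_l; continues the sketch above.
f_k, g_k = torch.triu_indices(N_k, N_k)    # f(l), g(l) for l = 0 .. N_l - 1
N_l = f_k.numel()                          # N_k (N_k + 1) / 2 pairs
M_bar = M[:, f_k, g_k]                     # bar M^n_l = M^n_{f(l) g(l)}

delta = (f_k == g_k).to(M.dtype)           # delta^{f(l) g(l)}
tilde = 1.0 - delta                        # tilde delta^{f(l) g(l)}

p_f = p[..., f_k]                          # p^{bnc f(l)}
p_g = p[..., g_k]                          # p^{bnc g(l)}

# Grouped, flattened form of q^{bncc'}; should reproduce the dense einsum
q_grouped = torch.einsum("nl,bnil,bnjl->bnij", M_bar * delta, p_f, p_f) \
          + torch.einsum("nl,bnil,bnjl->bnij", M_bar * tilde, p_f, p_g) \
          + torch.einsum("nl,bnil,bnjl->bnij", M_bar * tilde, p_g, p_f)
assert torch.allclose(q_grouped, q, atol=1e-5)
```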
At this point, our expression already fits quite well within a CUDA kernel. Note how the $\delta$'s neatly select which branch needs to be calculated for a given value of $l$, and how easily that can be determined with an if-statement on $l$.
However, a further computational saving is unlocked by the metric tensor: since dot products are commutative, it follows that $q^{bncc'} = q^{bnc'c}$, so we only need to perform the computation once for each pair $cc'$ with $c \leq c'$. Let $u=F_{N_c}(c, c')$ and agree on the convention that when $f$ and $g$ act on $l$ they recover $k$ and $k'$, but when they act on $u$ they recover $c$ and $c'$, so we rewrite the forwards kernel as
$$\bar q^{bnu} = \bar M^n_{l} \left( \delta^{f(l)g(l)} p^{bnf(u)f(l)} p^{bng(u)f(l)} + \tilde \delta^{f(l)g(l)} \left( p^{bnf(u)f(l)} p^{bng(u)g(l)} + p^{bnf(u)g(l)} p^{bng(u)f(l)} \right) \right)$$
Note that for every combination of $l$ and $u$ only one of the two $\delta$-selected branches needs to be computed, and the number of $(l, u)$ index combinations has been reduced from $N_k^2 N_c^2$ to $\frac{N_k(N_k+1)}{2} \cdot \frac{N_c(N_c+1)}{2} \approx N_k^2 N_c^2 / 4$.
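As a slow reference for the kernel's structure (again a sketch under the same toy setup, not the CUDA code), each $(l, u)$ work item touches exactly one branch, selected by the if-statement on $l$:

```python
# Reference loop mirroring the per-(l, u) structure of the forward kernel; the real
# CUDA kernel would assign one thread (or a few) to each (l, u) combination.
f_c, g_c = torch.triu_indices(N_c, N_c)    # u = F_{N_c}(c, c'), pairs with c <= c'
N_u = f_c.numel()

q_half = torch.zeros(N_b, N_n, N_u)        # bar q^{bnu}, one entry per c <= c' pair
for u in range(N_u):
    c, cp = f_c[u], g_c[u]
    for l in range(N_l):
        k, kp = f_k[l], g_k[l]
        if k == kp:                        # diagonal metric entry: single product
            q_half[:, :, u] += M_bar[:, l] * p[:, :, c, k] * p[:, :, cp, k]
        else:                              # off-diagonal entry: grouped symmetric pair
            q_half[:, :, u] += M_bar[:, l] * (p[:, :, c, k]  * p[:, :, cp, kp]
                                            + p[:, :, c, kp] * p[:, :, cp, k])

# Recover the full symmetric q^{bncc'} from its flattened half
q_full = torch.zeros(N_b, N_n, N_c, N_c)
q_full[:, :, f_c, g_c] = q_half
q_full[:, :, g_c, f_c] = q_half
assert torch.allclose(q_full, q, atol=1e-4)
```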
To proceed with the rest of the attention mechanism, $q^{bncc'}$ is recovered from its flattened form $\bar q^{bnu}$, the standard softmax is applied over the $c'$ index, and the projected embeddings are aggregated with the resulting weights,
$$a^{bncc'} = \operatorname{softmax}_{c'}\left(q^{bncc'}\right)$$
$$t^{bnck} = a^{bncc'} p^{bnc'k}$$
The result is then reflattened along the head and feature dimensions, and a final learnable transformation mixes the features and maps back to the original embedding space,
$$
\bar t^{bch} = t^{bnck}
$$
where $h$ indexes the $N_n N_k$ pairs $(n, k)$, and
$$
y^{bcd} = E^d_h \bar t^{bch}
$$
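Completing the forward sketch under the same assumptions, and assuming, as above, that the projected embeddings double as the values:

```python
# Softmax over c', aggregation of the projected embeddings, reflattening of the
# (n, k) pair into h, and the final learnable map E^d_h back to R^{N_d}.
a = torch.softmax(q, dim=-1)                                 # a^{bncc'}
t = torch.einsum("bnij,bnjk->bnik", a, p)                    # t^{bnck} = a^{bncc'} p^{bnc'k}
t_flat = t.permute(0, 2, 1, 3).reshape(N_b, N_c, N_n * N_k)  # bar t^{bch}
E = torch.randn(N_d, N_n * N_k)                              # E^d_h
y = torch.einsum("dh,bch->bcd", E, t_flat)                   # y^{bcd} = E^d_h bar t^{bch}
```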
Backwards Pass
Gradient with respect to the metric components:
$$\frac{\partial \mathcal L}{\partial \bar M^n_{l}} = \frac{\partial \mathcal L}{\partial q^{bncc'}} \left( \delta^{f(l)g(l)} p^{bncf(l)} p^{bnc'f(l)} + \tilde \delta^{f(l)g(l)} \left( p^{bncf(l)} p^{bnc'g(l)} + p^{bncg(l)} p^{bnc'f(l)} \right) \right)$$
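This is the chain rule applied to the grouped forward expression. A sketch of this gradient, checked against autograd under the same toy setup and a toy loss $\mathcal L = \sum q$, might look like:

```python
# Manual chain-rule gradient of a toy loss L = sum(q) with respect to bar M^n_l,
# compared against autograd; continues the sketch above.
grad_q = torch.ones_like(q)                # dL/dq^{bncc'} for L = q.sum()
grad_M_bar = delta * torch.einsum("bnij,bnil,bnjl->nl", grad_q, p_f, p_f) \
           + tilde * (torch.einsum("bnij,bnil,bnjl->nl", grad_q, p_f, p_g)
                    + torch.einsum("bnij,bnil,bnjl->nl", grad_q, p_g, p_f))

# Autograd reference built from the grouped forward expression
M_bar_ag = M_bar.detach().clone().requires_grad_(True)
q_ag = torch.einsum("nl,bnil,bnjl->bnij", M_bar_ag * delta, p_f, p_f) \
     + torch.einsum("nl,bnil,bnjl->bnij", M_bar_ag * tilde, p_f, p_g) \
     + torch.einsum("nl,bnil,bnjl->bnij", M_bar_ag * tilde, p_g, p_f)
q_ag.sum().backward()
assert torch.allclose(grad_M_bar, M_bar_ag.grad, atol=1e-4)
```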