
CUDA Kernel of the Metric Tensor Attention

NOTE: WIP

Forwards Pass

Let $P^{nk}_d$ be $N_n$ learnable projections from $\mathbf R^{N_d}$ to $\mathbf R^{N_k}$ and $x^{bcd}$ a batch of $N_b$ sequences containing $N_c$ embeddings from $\mathbf R^{N_d}$. The action of these projections is expressed in Ricci notation by

$$p^{bnck} = P^{nk}_d x^{bcd}$$
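
As a concrete reference, here is a minimal sketch of this projection as a naive CUDA kernel, one thread per output element; the buffer names, the row-major contiguous float32 layouts, and the launch configuration are assumptions of the sketch, not something fixed by the notation above.

```cuda
#include <cuda_runtime.h>

// Naive projection kernel: p[b][n][c][k] = sum_d P[n][k][d] * x[b][c][d].
// Layouts are assumed row-major and contiguous; one thread per output element.
__global__ void project_kernel(
    const float* __restrict__ P,  // [Nn, Nk, Nd]
    const float* __restrict__ x,  // [Nb, Nc, Nd]
    float* __restrict__ p,        // [Nb, Nn, Nc, Nk]
    int Nb, int Nn, int Nc, int Nk, int Nd)
{
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= Nb * Nn * Nc * Nk) return;

    int k = idx % Nk;
    int c = (idx / Nk) % Nc;
    int n = (idx / (Nk * Nc)) % Nn;
    int b = idx / (Nk * Nc * Nn);

    float acc = 0.0f;
    for (int d = 0; d < Nd; ++d)
        acc += P[(n * Nk + k) * Nd + d] * x[(b * Nc + c) * Nd + d];

    p[idx] = acc;  // idx == ((b*Nn + n)*Nc + c)*Nk + k
}
```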

At the heart of the proposed attention mechanism is a learnable dot product of each projected embedding with each other embedding. This is achieved using $N_n$ learnable metric tensors $M^{n}_{kk'}$ and is given by

$$q^{bncc'} = M^{n}_{kk'} p^{bnck} p^{bnc'k'}$$
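
Before exploiting any symmetry, this definition can be transcribed directly into a kernel. A naive version like the following (again with assumed row-major layouts and one thread per $(b, n, c, c')$ element) is mainly useful as a correctness reference for the reduced forms derived below.

```cuda
#include <cuda_runtime.h>

// Reference kernel: q[b][n][c][c'] = sum_{k,k'} M[n][k][k'] * p[b][n][c][k] * p[b][n][c'][k'].
// One thread per (b, n, c, c') output; no symmetry is exploited yet.
__global__ void metric_scores_naive(
    const float* __restrict__ M,  // [Nn, Nk, Nk]
    const float* __restrict__ p,  // [Nb, Nn, Nc, Nk]
    float* __restrict__ q,        // [Nb, Nn, Nc, Nc]
    int Nb, int Nn, int Nc, int Nk)
{
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= Nb * Nn * Nc * Nc) return;

    int cp = idx % Nc;
    int c  = (idx / Nc) % Nc;
    int n  = (idx / (Nc * Nc)) % Nn;
    int b  = idx / (Nc * Nc * Nn);

    const float* pc  = p + ((b * Nn + n) * Nc + c)  * Nk;
    const float* pcp = p + ((b * Nn + n) * Nc + cp) * Nk;
    const float* Mn  = M + n * Nk * Nk;

    float acc = 0.0f;
    for (int k = 0; k < Nk; ++k)
        for (int kp = 0; kp < Nk; ++kp)
            acc += Mn[k * Nk + kp] * pc[k] * pcp[kp];

    q[idx] = acc;
}
```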

The metric tensor is symmetric, so we can reduce the number of computations by grouping the terms strategically,

$$q^{bncc'} = \delta^{kk'} M^n_{kk'} p^{bnck} p^{bnc'k'} + 2 \delta^{k>k'} M^n_{kk'} p^{bnck} p^{bnc'k'}$$

Let $F_N(v, w)$ be a pairing function that indexes the elements above and including the diagonal of a matrix from $\mathbf R^{N\times N}$, and let $f$ and $g$ be integer-valued functions that retrieve the first and second arguments of $F_N$, that is

$$ v = f(F_{N}(v, w)) $$

and

$$ w = g(F_{N}(v, w)) $$

Such an arrangement is easily achieved by storing two arrays to be used as a lookup table for $f$ and $g$. Finally, let $l=F_{N_k}(k, k')$, and define

$$ \bar M^n_{l} = M^n_{f(l)g(l)} $$
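
A possible host-side construction of the lookup tables and of the packed $\bar M^n_l$ is sketched below; the enumeration order of the pairs is an arbitrary choice (any bijection onto the upper triangle works), and the helper names are mine.

```cuda
#include <vector>

// Builds the lookup tables for f and g: pairs (v, w) with v <= w (upper triangle
// including the diagonal) are enumerated in row-major order, l = F_N(v, w) is the
// position of (v, w) in that enumeration, so f_lut[l] = v and g_lut[l] = w.
void build_triangular_luts(int N, std::vector<int>& f_lut, std::vector<int>& g_lut)
{
    f_lut.clear();
    g_lut.clear();
    for (int v = 0; v < N; ++v) {
        for (int w = v; w < N; ++w) {
            f_lut.push_back(v);  // f(F_N(v, w)) = v
            g_lut.push_back(w);  // g(F_N(v, w)) = w
        }
    }
    // f_lut.size() == g_lut.size() == N * (N + 1) / 2
}

// Packs the symmetric metric tensors into upper-triangular form:
// Mbar[n][l] = M[n][f(l)][g(l)].
std::vector<float> pack_metric(const std::vector<float>& M, int Nn, int Nk,
                               const std::vector<int>& f_lut,
                               const std::vector<int>& g_lut)
{
    const int Nl = Nk * (Nk + 1) / 2;
    std::vector<float> Mbar(Nn * Nl);
    for (int n = 0; n < Nn; ++n)
        for (int l = 0; l < Nl; ++l)
            Mbar[n * Nl + l] = M[(n * Nk + f_lut[l]) * Nk + g_lut[l]];
    return Mbar;
}
```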

which we use to rewrite our original expression as

$$q^{bncc'} = \delta^{f(l)g(l)} \bar M^n_{l} p^{bncf(l)} p^{bnc'f(l)} + 2 \tilde \delta^{f(l)g(l)} \bar M^n_l p^{bncf(l)} p^{bnc'g(l)}$$

where $\tilde \delta^{f(l)g(l)} = 1 - \delta^{f(l)g(l)} $.

At this point, the expression already maps well onto a CUDA kernel. Note how the $\delta$'s determine which term needs to be computed for a given value of $l$, and how easily that choice reduces to an if-statement on $l$ via the lookup tables for $f$ and $g$.
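
A sketch of such a kernel follows, one thread per $(b, n, c, c')$ with the loop over $l$ and the branch on $f(l) = g(l)$ inside; buffer names and layouts are assumptions. Note that the off-diagonal branch below accumulates both mirrored products $p^{bnck}p^{bnc'k'}$ and $p^{bnck'}p^{bnc'k}$, so it reproduces the full double sum for every $c, c'$; the two products coincide, giving the factor of 2, whenever the same embedding sits on both sides (e.g. $c = c'$).

```cuda
#include <cuda_runtime.h>

// q[b][n][c][c'] accumulated over the packed index l. The branch on f(l) == g(l)
// selects the diagonal term; the off-diagonal branch accumulates both mirrored
// products so that the result equals the full symmetric double sum over (k, k').
__global__ void metric_scores_packed(
    const float* __restrict__ Mbar,  // [Nn, Nl]
    const float* __restrict__ p,     // [Nb, Nn, Nc, Nk]
    const int*   __restrict__ f_lut, // [Nl]
    const int*   __restrict__ g_lut, // [Nl]
    float* __restrict__ q,           // [Nb, Nn, Nc, Nc]
    int Nb, int Nn, int Nc, int Nk, int Nl)
{
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= Nb * Nn * Nc * Nc) return;

    int cp = idx % Nc;
    int c  = (idx / Nc) % Nc;
    int n  = (idx / (Nc * Nc)) % Nn;
    int b  = idx / (Nc * Nc * Nn);

    const float* pc  = p + ((b * Nn + n) * Nc + c)  * Nk;
    const float* pcp = p + ((b * Nn + n) * Nc + cp) * Nk;
    const float* Mn  = Mbar + n * Nl;

    float acc = 0.0f;
    for (int l = 0; l < Nl; ++l) {
        int k  = f_lut[l];
        int kp = g_lut[l];
        if (k == kp)  // diagonal of the metric tensor
            acc += Mn[l] * pc[k] * pcp[k];
        else          // one packed entry stands for both M[k][kp] and M[kp][k]
            acc += Mn[l] * (pc[k] * pcp[kp] + pc[kp] * pcp[k]);
    }
    q[idx] = acc;
}
```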

However, the metric tensor unlocks a further computational saving: since it is symmetric, the generalized dot product it defines is symmetric as well, so $q^{bncc'} = q^{bnc'c}$ and the computation only needs to be performed once for each pair $cc'$ with $c \geq c'$. Let $u=F_{N_c}(c, c')$ and adopt the convention that when $f$ and $g$ act on $l$ they recover $k$ and $k'$, while when they act on $u$ they recover $c$ and $c'$. The forwards kernel can then be rewritten as

$$\bar q^{bnu} = \delta^{f(l)g(l)} \bar M^n_{l} p^{bnf(u)f(l)} p^{bng(u)f(l)} + 2 \tilde \delta^{f(l)g(l)} \bar M^n_l p^{bnf(u)f(l)} p^{bng(u)g(l)}$$

To avoid repetition, I'll carry out the manipulation on the following expression

$$\rho^{bncc'l} = p^{bncf(l)} p^{bnc'g(l)}$$

and substitute symbols where necessary to place the result back into the expression we're working on. Direct substitution gives

$$\rho^{bnul} = p^{bnf(u)f(l)} p^{bng(u)g(l)}$$

which we can similarly split into two expressions

$$\rho^{bnul} = \delta^{f(u)g(u)} p^{bnf(u)f(l)} p^{bng(u)g(l)} + 2 \tilde \delta^{f(u)g(u)} p^{bnf(u)f(l)} p^{bng(u)g(l)}$$

and, applying $f(u) = g(u)$ to simplify the first term,

$$\rho^{bnul} = \delta^{f(u)g(u)} p^{bnf(u)f(l)} p^{bnf(u)g(l)} + 2 \tilde \delta^{f(u)g(u)} p^{bnf(u)f(l)} p^{bng(u)g(l)}$$

Substituting this back, and applying the analogous substitution to the first term of the original expression, we get

$$ \begin{aligned} q^{bnu} &= \delta^{f(l)g(l)} \bar M^n_{l} \left [ \delta^{f(u)g(u)} p^{bnf(u)f(l)} p^{bnf(u)f(l)} + 2 \tilde \delta^{f(u)g(u)} p^{bnf(u)f(l)} p^{bng(u)f(l)} \right ] \\ &\quad + 2 \tilde \delta^{f(l)g(l)} \bar M^n_l \left [ \delta^{f(u)g(u)} p^{bnf(u)f(l)} p^{bnf(u)g(l)} + 2 \tilde \delta^{f(u)g(u)} p^{bnf(u)f(l)} p^{bng(u)g(l)} \right ] \end{aligned} $$

which we'll now group according to the $\delta$'s

$$ \begin{aligned} q^{bnu} &= \bar M^n _ {l} p^{bnf(u)f(l)} p^{bnf(u)f(l)} \delta^{f(l)g(l)} \delta^{f(u)g(u)} \\ &\quad + 2 \bar M^n_{l} p^{bnf(u)f(l)} p^{bng(u)f(l)} \delta^{f(l)g(l)} \tilde \delta^{f(u)g(u)} \\ &\quad + 2 \bar M^n_l p^{bnf(u)f(l)} p^{bnf(u)g(l)} \delta^{f(u)g(u)} \tilde \delta^{f(l)g(l)} \\ &\quad + 4 \bar M^n_l p^{bnf(u)f(l)} p^{bng(u)g(l)} \tilde \delta^{f(u)g(u)} \tilde \delta^{f(l)g(l)} \end{aligned} $$

Note that, for every combination of $l$ and $u$, only one term in this expression needs to be computed, so the number of multiply-accumulate terms per $(b, n)$ pair drops from $N_k^2 N_c^2$ to $\frac{N_k(N_k+1)}{2} \cdot \frac{N_c(N_c+1)}{2}$, roughly a factor-of-four reduction.
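
Putting both symmetries together, one possible kernel assigns one thread per $(b, n, u)$ and loops over $l$, reusing the triangular lookup tables built once for $N = N_k$ and once for $N = N_c$; as in the previous sketch, the off-diagonal $l$ branch folds the two mirrored cross products. Names and layouts are again assumptions.

```cuda
#include <cuda_runtime.h>

// qbar[b][n][u] for u = F_{Nc}(c, c'), computed once per unordered pair (c, c').
// fk_lut/gk_lut index the k side, fc_lut/gc_lut index the c side. Off-diagonal l
// again contributes both mirrored products of the symmetric double sum.
__global__ void metric_scores_triangular(
    const float* __restrict__ Mbar,   // [Nn, Nl]
    const float* __restrict__ p,      // [Nb, Nn, Nc, Nk]
    const int*   __restrict__ fk_lut, // [Nl]
    const int*   __restrict__ gk_lut, // [Nl]
    const int*   __restrict__ fc_lut, // [Nu]
    const int*   __restrict__ gc_lut, // [Nu]
    float* __restrict__ qbar,         // [Nb, Nn, Nu]
    int Nb, int Nn, int Nc, int Nk, int Nl, int Nu)
{
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= Nb * Nn * Nu) return;

    int u = idx % Nu;
    int n = (idx / Nu) % Nn;
    int b = idx / (Nu * Nn);

    const float* pc  = p + ((b * Nn + n) * Nc + fc_lut[u]) * Nk;  // p^{bnf(u)}
    const float* pcp = p + ((b * Nn + n) * Nc + gc_lut[u]) * Nk;  // p^{bng(u)}
    const float* Mn  = Mbar + n * Nl;

    float acc = 0.0f;
    for (int l = 0; l < Nl; ++l) {
        int k  = fk_lut[l];
        int kp = gk_lut[l];
        if (k == kp)
            acc += Mn[l] * pc[k] * pcp[k];
        else
            acc += Mn[l] * (pc[k] * pcp[kp] + pc[kp] * pcp[k]);
    }
    qbar[idx] = acc;
}
```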

To proceed with the rest of the attention mechanism, $q^{bncc'}$ is recovered from $\bar q^{bnu}$ and the standard scaled softmax is applied

$$ s^{bncc'} = \textrm{softmax}^{c'} \left ( \frac{q^{bncc'} }{\sqrt{N_k}} \right ) $$
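
A minimal, numerically stable sketch of this step: one thread per $(b, n, c)$ row for clarity (a production kernel would parallelize the reduction), with the $1/\sqrt{N_k}$ scaling applied before the row maximum is subtracted.

```cuda
#include <cuda_runtime.h>
#include <math.h>

// s[b][n][c][c'] = softmax over c' of q[b][n][c][c'] / sqrt(Nk).
// One thread handles one row for clarity; subtracting the row maximum keeps
// the exponentials numerically stable.
__global__ void scaled_softmax_rows(
    const float* __restrict__ q,  // [Nb, Nn, Nc, Nc]
    float* __restrict__ s,        // [Nb, Nn, Nc, Nc]
    int Nb, int Nn, int Nc, int Nk)
{
    int row = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= Nb * Nn * Nc) return;

    const float scale = rsqrtf((float)Nk);
    const float* q_row = q + (size_t)row * Nc;
    float* s_row = s + (size_t)row * Nc;

    float m = -INFINITY;
    for (int cp = 0; cp < Nc; ++cp)
        m = fmaxf(m, q_row[cp] * scale);

    float z = 0.0f;
    for (int cp = 0; cp < Nc; ++cp) {
        float e = expf(q_row[cp] * scale - m);
        s_row[cp] = e;
        z += e;
    }
    for (int cp = 0; cp < Nc; ++cp)
        s_row[cp] /= z;
}
```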

which is then followed by applying the scores to the same projection $p^{bnck}$ (rather than to a separate value projection)

$$ t^{bnck} = s^{bncc'} p^{bnc''k} \delta_{c'c''} = s^{bnc}_ {c'} p^{bnc'k} $$
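
This contraction is a small batched matrix product; a naive sketch with one thread per output element of $t^{bnck}$, under the same layout assumptions as above:

```cuda
#include <cuda_runtime.h>

// t[b][n][c][k] = sum_{c'} s[b][n][c][c'] * p[b][n][c'][k]:
// the attention scores are applied to the same projection p.
__global__ void apply_scores(
    const float* __restrict__ s,  // [Nb, Nn, Nc, Nc]
    const float* __restrict__ p,  // [Nb, Nn, Nc, Nk]
    float* __restrict__ t,        // [Nb, Nn, Nc, Nk]
    int Nb, int Nn, int Nc, int Nk)
{
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= Nb * Nn * Nc * Nk) return;

    int k = idx % Nk;
    int c = (idx / Nk) % Nc;
    int n = (idx / (Nk * Nc)) % Nn;
    int b = idx / (Nk * Nc * Nn);

    const float* s_row = s + ((size_t)(b * Nn + n) * Nc + c) * Nc;
    const float* p_bn  = p + (size_t)(b * Nn + n) * Nc * Nk;

    float acc = 0.0f;
    for (int cp = 0; cp < Nc; ++cp)
        acc += s_row[cp] * p_bn[cp * Nk + k];

    t[idx] = acc;
}
```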

The heads are then flattened back into a single feature axis (the index $l$ below re-enumerates the flattened $(n, k)$ pairs), and a final linear transformation mixes the features and maps the result back to the original embedding dimension

$$ \bar t^{bcl} = t^{bnck} $$

$$ y^{bcd} = E_l^d \bar t^{bcl} $$
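
Since the flattening is only a reinterpretation of the pair $(n, k)$ as a single index, both steps can be fused into one kernel; here $E$ is assumed to be stored as a $[N_n N_k, N_d]$ row-major matrix.

```cuda
#include <cuda_runtime.h>

// y[b][c][d] = sum_{n,k} E[n*Nk + k][d] * t[b][n][c][k].
// The flattening of (n, k) into l is implicit in how E is indexed.
__global__ void output_projection(
    const float* __restrict__ E,  // [Nn * Nk, Nd]
    const float* __restrict__ t,  // [Nb, Nn, Nc, Nk]
    float* __restrict__ y,        // [Nb, Nc, Nd]
    int Nb, int Nn, int Nc, int Nk, int Nd)
{
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= Nb * Nc * Nd) return;

    int d = idx % Nd;
    int c = (idx / Nd) % Nc;
    int b = idx / (Nd * Nc);

    float acc = 0.0f;
    for (int n = 0; n < Nn; ++n)
        for (int k = 0; k < Nk; ++k)
            acc += E[(n * Nk + k) * Nd + d] * t[((b * Nn + n) * Nc + c) * Nk + k];

    y[idx] = acc;
}
```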

Backwards Pass

Gradient with respect to the metric tensor coordinates:

$$\partial_{M^{n'}_ {k'''k''''}} q^{bncc'} = \partial_{M^{n'}_{k'''k''''}} M^{n} _{kk'} p^{bnck} p^{bnc'k'}$$

$$\partial_{M^{n'} _ {k'''k''''}} q^{bncc'} = p^{bnck} p^{bnc'k'} \partial_{M^{n'}_ {k'''k''''}} M^{n}_{kk'}$$

$$ \partial_{M^{n'} _ {k'''k''''}} q^{bncc'} = p^{bnck} p^{bnc'k'} \delta^{nn'} \delta ^ {kk'''} \delta^{k'k''''} $$

$$ \partial_{M^{n}_{k'''k''''}} q^{bncc'} = p^{bnck'''} p^{bnc'k''''} $$

$$ \partial_{M^{n}_{kk'}} q^{bncc'} = p^{bnck} p^{bnc'k'} $$

$$ \partial_{M^n_l} q^{bnu} = p^{bnf(u)f(l)} p^{bng(u)g(l)} $$
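
In code, the accumulation for the unpacked $M^n_{kk'}$ can be sketched as follows, with `grad_q` standing for the incoming gradient $\partial L / \partial q$ (a name introduced here); if $\bar M^n_l$ represents both mirrored entries in the forward pass, its gradient is simply the sum of the two corresponding unpacked entries.

```cuda
#include <cuda_runtime.h>

// grad_M[n][k][k'] = sum_{b,c,c'} grad_q[b][n][c][c'] * p[b][n][c][k] * p[b][n][c'][k'].
// One thread per metric entry; grad_q is the incoming gradient of the loss w.r.t. q.
__global__ void metric_grad(
    const float* __restrict__ grad_q, // [Nb, Nn, Nc, Nc]
    const float* __restrict__ p,      // [Nb, Nn, Nc, Nk]
    float* __restrict__ grad_M,       // [Nn, Nk, Nk]
    int Nb, int Nn, int Nc, int Nk)
{
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= Nn * Nk * Nk) return;

    int kp = idx % Nk;
    int k  = (idx / Nk) % Nk;
    int n  = idx / (Nk * Nk);

    float acc = 0.0f;
    for (int b = 0; b < Nb; ++b) {
        const float* gq  = grad_q + (size_t)(b * Nn + n) * Nc * Nc;
        const float* pbn = p      + (size_t)(b * Nn + n) * Nc * Nk;
        for (int c = 0; c < Nc; ++c)
            for (int cp = 0; cp < Nc; ++cp)
                acc += gq[c * Nc + cp] * pbn[c * Nk + k] * pbn[cp * Nk + kp];
    }
    grad_M[idx] = acc;
}
```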

Gradient with respect to the input coordinates:

$$ \partial_{p^{bnc''k''}} q^{bncc'} = M^{n}_ {kk'} \partial_{p^{bnc''k''}} p^{bnck} p^{bnc'k'} $$

$$ \partial_{p^{bnc''k''}} q^{bncc'} = M^{n} _ {kk'} \left ( p^{bnc'k'} \partial_{p^{bnc''k''}} p^{bnck} + p^{bnck} \partial_{p^{bnc''k''}} p^{bnc'k'} \right ) $$

$$ \partial_{p^{bnc''k''}} q^{bncc'} = M^{n}_ {kk'} \left ( p^{bnc'k'} \delta^{c''c} \delta^{k''k} + p^{bnck} \delta^{c''c'} \delta^{k''k'} \right ) $$

$$ \partial_{p^{bnc''k''}} q^{bncc'} = M^{n} _ {kk'} p^{bnc'k'} \delta^{c''c} \delta^{k''k} + M^{n}_ {kk'} p^{bnck} \delta^{c''c'} \delta^{k''k'} $$

$$ \partial_{p^{bnc''k''}} q^{bncc'} = M^{n} _ {k''k'} p^{bnc'k'} \delta^{c''c} + M^{n}_ {kk''} p^{bnck} \delta^{c''c'} $$

$$ \partial_{p^{bnc''k''}} q^{bncc'} = M^{n} _ {k''k'} p^{bnc'k'} \delta^{c''c} + M^{n}_ {k''k} p^{bnck} \delta^{c''c'} $$

$$ \partial_{p^{bnc''k''}} q^{bncc'} = M^{n} _ {k''k} p^{bnc'k} \delta^{c''c} + M^{n}_ {k''k} p^{bnck} \delta^{c''c'} $$

$$ \partial_{p^{bnc''k'}} q^{bncc'} = M^{n} _ {k'k} p^{bnc'k} \delta^{c''c} + M^{n}_ {k'k} p^{bnck} \delta^{c''c'} $$

$$ \partial_{p^{bnc''k'}} q^{bncc'} = M^{n} _ {kk'} p^{bnc'k} \delta^{c''c} + M^{n}_ {kk'} p^{bnck} \delta^{c''c'} $$

$$ \partial_{p^{bnc''k'}} q^{bnu} = \bar M^{n} _ l p^{bng(u)f(l)} \delta^{c''f(u)} + \bar M^{n}_ {l} p^{bnf(u)f(l)} \delta^{c''g(u)} $$
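
A sketch of the corresponding kernel for the unpacked, symmetric $M$ follows; it covers only the path through $q$, while the contribution of $p$ through $t^{bnck} = s^{bncc'} p^{bnc'k}$ would be accumulated separately. One thread per element of $p$, with the same assumed names and layouts as before.

```cuda
#include <cuda_runtime.h>

// grad_p[b][n][c''][k''] = sum_{c'} grad_q[b][n][c''][c'] * (M[n][k''][:] . p[b][n][c'][:])
//                        + sum_{c } grad_q[b][n][c ][c''] * (M[n][:][k''] . p[b][n][c ][:]),
// i.e. the two delta terms of the derivation above, using the unpacked symmetric M.
__global__ void projection_grad(
    const float* __restrict__ grad_q, // [Nb, Nn, Nc, Nc]
    const float* __restrict__ M,      // [Nn, Nk, Nk]
    const float* __restrict__ p,      // [Nb, Nn, Nc, Nk]
    float* __restrict__ grad_p,       // [Nb, Nn, Nc, Nk]
    int Nb, int Nn, int Nc, int Nk)
{
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= Nb * Nn * Nc * Nk) return;

    int kpp = idx % Nk;            // k''
    int cpp = (idx / Nk) % Nc;     // c''
    int n   = (idx / (Nk * Nc)) % Nn;
    int b   = idx / (Nk * Nc * Nn);

    const float* gq  = grad_q + (size_t)(b * Nn + n) * Nc * Nc;
    const float* pbn = p      + (size_t)(b * Nn + n) * Nc * Nk;
    const float* Mn  = M      + (size_t)n * Nk * Nk;

    float acc = 0.0f;
    for (int c = 0; c < Nc; ++c) {
        // (M p^{c})_{k''}; by symmetry this also equals sum_k M[k][k''] p^{c}_k
        float mp = 0.0f;
        for (int k = 0; k < Nk; ++k)
            mp += Mn[kpp * Nk + k] * pbn[c * Nk + k];
        // first term: delta^{c''c} picks row c'' of grad_q; second term: column c''
        acc += gq[cpp * Nc + c] * mp + gq[c * Nc + cpp] * mp;
    }
    grad_p[idx] = acc;
}
```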
