Last active
November 18, 2023 17:09
-
-
Save mgostIH/0d4a5b614f1b9829f7471d888133d99b to your computer and use it in GitHub Desktop.
An implementation of the GateLoop paper using real matrices only. Using jax.checkpoint to avoid storing the scan's intermediate states is an idea from the Mamba paper. Provided as a guide, with type annotations and more descriptive operations.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import jax | |
import jax.numpy as jnp | |
import equinox as eqx | |
from jax.lax import associative_scan | |
from jaxtyping import Float, Array, PRNGKeyArray | |
# This computes the linear recurrence from the GateLoop paper using only
# real-valued vectors and matrices.
#
# k, v and q are sequences of vectors with shape (n, d). The per-step outer
# product k_t^T v_t has shape (d, d), so all of them together form an
# (n, d, d) tensor, and the internal state h_t of the RNN is a d x d matrix.
# All internal states are computed in parallel with associative_scan.
#
# The state transition a_t is conceptually a (d, d) matrix, but only its
# diagonal is used. So a actually has shape (n, d) and is applied with an
# elementwise (broadcast) product that scales row d (the k index) of the state:
#     h_t = diag(a_t) h_{t-1} + k_t^T v_t
# The output contracts q_t with that same row (k) index:
#     y_t = q_t @ h_t
# where @ is the matrix product and q_t is a row vector of size d
# (there are n of them, so q has shape (n, d)).
#
# eqx.filter_checkpoint (the equinox analogue of jax.checkpoint) avoids storing
# all the intermediate (n, d, d) states for backpropagation; they are
# recomputed during the backward pass instead.
@eqx.filter_checkpoint
def gate(
    k: Float[Array, "N D"],
    v: Float[Array, "N D"],
    q: Float[Array, "N D"],
    a: Float[Array, "N D"],
) -> Float[Array, "N D"]:
    """Run the GateLoop recurrence and return y_t = q_t @ h_t for every step.

    Args:
        k: Input gates (keys), shape (N, D).
        v: Values, shape (N, D).
        q: Output gates (queries), shape (N, D).
        a: Diagonal state transitions, shape (N, D). TimeMixer feeds values
            squashed into (0, 1) by a sigmoid.

    Returns:
        Array of shape (N, D) with y[t, e] = sum_d q[t, d] * h_t[d, e].
    """

    def binary_operator(e_i, e_j):
        # Compose two affine steps h -> a * h + kv: applying e_i and then e_j
        # yields h -> (a_j * a_i) * h + (a_j * kv_i + kv_j). This composition
        # is associative, which is what associative_scan requires.
        a_i, kv_i = e_i
        a_j, kv_j = e_j
        return a_j * a_i, a_j * kv_i + kv_j

    # kv[t, d, e] = k[t, d] * v[t, e], i.e. the outer product k_t^T v_t.
    kv = jnp.einsum("nd,ne->nde", k, v)
    # a[..., None] has shape (n, d, 1), so it scales row d (the k index) of
    # every state when broadcast against kv.
    _, h = associative_scan(binary_operator, (a[..., None], kv))
    # y_t = q_t @ h_t: contract q with the row (k) index, as in the paper.
    # Contracting with the column (v) index instead would compute h_t @ q_t.
    y = jnp.einsum("nde,nd->ne", h, q)
    return y
class TimeMixer(eqx.Module):
    """Token-mixing sublayer.

    Projects every token of x to k, v, a and q with d -> d linear layers, then
    runs the `gate` recurrence along the sequence axis.
    """

    # Per-token d -> d projections, applied across the sequence with jax.vmap.
    W_K: eqx.nn.Linear
    W_V: eqx.nn.Linear
    W_A: eqx.nn.Linear
    W_Q: eqx.nn.Linear

    def __init__(self, d: int, *, key: PRNGKeyArray):
        super().__init__()

        def make_projection(subkey):
            # Square projection so every gate stays in the model width d.
            return eqx.nn.Linear(in_features=d, out_features=d, key=subkey)

        subkeys = jax.random.split(key, 4)
        self.W_K = make_projection(subkeys[0])
        self.W_V = make_projection(subkeys[1])
        self.W_A = make_projection(subkeys[2])
        self.W_Q = make_projection(subkeys[3])

    def __call__(self, x: Float[Array, "N D"]) -> Float[Array, "N D"]:
        """Mix information along the sequence axis of x (shape (N, D))."""
        k, v, a_logits, q = [
            jax.vmap(layer)(x)
            for layer in (self.W_K, self.W_V, self.W_A, self.W_Q)
        ]
        # The state transition must lie between 0 and 1, so squash it.
        transitions = jax.nn.sigmoid(a_logits)
        return gate(k, v, q, transitions)
class GateLoopLayer(eqx.Module):
    """One GateLoop block.

    It applies a time-mixing residual branch, then a channel-MLP residual
    branch. Each branch's output is layer-normalised before being added back
    to the residual stream.
    """

    time_mixer: TimeMixer
    layer_norm_time: eqx.nn.LayerNorm
    layer_norm_channel: eqx.nn.LayerNorm
    mlp: eqx.nn.MLP

    def __init__(self, d: int, *, key: PRNGKeyArray):
        super().__init__()
        k_mixer, k_mlp = jax.random.split(key, 2)
        self.time_mixer = TimeMixer(d, key=k_mixer)
        self.layer_norm_time = eqx.nn.LayerNorm(shape=d)
        self.layer_norm_channel = eqx.nn.LayerNorm(shape=d)
        # Per-token d -> d MLP with hidden width d and depth=1.
        self.mlp = eqx.nn.MLP(
            in_size=d, out_size=d, width_size=d, depth=1, key=k_mlp
        )

    def __call__(self, x: Float[Array, "N D"]) -> Float[Array, "N D"]:
        """Apply both residual branches to x of shape (N, D)."""
        time_branch = jax.vmap(self.layer_norm_time)(self.time_mixer(x))
        stream = x + time_branch
        channel_branch = jax.vmap(self.layer_norm_channel)(
            jax.vmap(self.mlp)(stream)
        )
        return stream + channel_branch
Hello, I want to clarify the shapes (see the updated code example below). We used d_h=1 for all reported experiments due to the memory bottleneck imposed by the parallel scan. The general case is included to highlight the relation to standard multi-head attention (MHA).
def gate_loop_operator(k, v, q, a, *, axis=0):
    """
    GateLoop recurrence in its general multi-head form.

    For every time step t and head it computes
        h_t = a_t * h_{t-1} + k_t @ v_t
        y_t = q_t @ h_t
    where a_t (shape (d_h, 1)) scales the rows of the (d_h, d_h) state.

    :param k: Input gates (l, nr_heads, d_h, 1)
    :param v: Values (l, nr_heads, 1, d_h)
    :param q: Output gates (l, nr_heads, 1, d_h)
    :param a: State transitions (l, nr_heads, d_h, 1)
    :param axis: Sequence axis to scan over. Defaults to 0, which matches the
        shapes above. Pass axis=1 for batched inputs of shape (b, l, ...).
    :return: Outputs of shape (l, nr_heads, 1, d_h).
    """
    def binary_operator(e_i, e_j):
        # Composition of the affine steps h -> a * h + kv (associative).
        a_i, kv_i = e_i
        a_j, kv_j = e_j
        return a_j * a_i, a_j * kv_i + kv_j

    # (d_h, 1) @ (1, d_h) -> (d_h, d_h): per-step, per-head outer product.
    kv = jnp.matmul(k, v)
    # Scan along the sequence axis. The previous hard-coded axis=1 scanned
    # over heads for the documented (l, nr_heads, ...) layout.
    _, h = associative_scan(binary_operator, (a, kv), axis=axis)
    # (1, d_h) @ (d_h, d_h) -> (1, d_h)
    y = jnp.matmul(q, h)
    return y
Nice, I'll keep this gist as it is for now, but I'd argue that the memory bottleneck could be fixed with jax.checkpoint (in this specific script, equinox.filter_checkpoint would be more correct), as the Mamba paper says they do themselves.
Granted, their formulation isn't exactly the same, but it's close enough that I'd argue insights transfer from one to the other.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Thanks; I guess we are on the same page in our state of confusion.
I noticed that the lucidrains reproduction is also trying to figure this out: its JAX implementation uses the SISO approach, but the torch implementation makes this configurable, and the latest update comments on the performance implications of that choice.