
@mgostIH
Last active November 18, 2023 17:09
An implementation of the GateLoop paper using real matrices only. Using jax.checkpoint is an idea from the Mamba paper. Provided as a guide, with types and more descriptive operations.
import jax
import jax.numpy as jnp
import equinox as eqx
from jax.lax import associative_scan
from jaxtyping import Float, Array, PRNGKeyArray
# This will compute the linear recursion from the paper GateLoop
# We work with real valued vectors and matrices
# v and k are vectors of shape (n, d)
# Together they will be outer producted to a matrix of shape (n, d, d)
# The internal state of the RNN is therefore a DxD matrix
# We'll compute all the internal states in parallel at each time using associative_scan
# a is technically a matrix of shape (n, d, d) but we'll only use the diagonal
# Therefore we'll actually have a shape (n, d) and use the standard product (Not matrix product)
# We'll obtain the next state like the following:
# h_t = a_t * h_{t-1} + v_t outer k_t
# While for the output we'll use the following:
# y_t = q @ h_t
# Where @ is the matrix product, q is a vector of size d (we have n of them, so shape (n, d))
# jax.checkpoint avoids storing all the intermediate states for backpropagation
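# Why associative_scan applies: each time step acts on the state as an affine map
# h -> a_t * h + kv_t (with kv_t the outer product formed below). Composing the map at
# time i with the map at time j yields (a_j * a_i, a_j * kv_i + kv_j), which is exactly
# what binary_operator computes, and composition of affine maps is associative,
# so all prefix states can be evaluated in parallel.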
@eqx.filter_checkpoint
def gate(k: Float[Array, "N D"], v: Float[Array, "N D"], q: Float[Array, "N D"], a: Float[Array, "N D"]) -> Float[Array, "N D"]:
    def binary_operator(e_i, e_j):
        a_i, kv_i = e_i
        a_j, kv_j = e_j
        return a_j * a_i, a_j * kv_i + kv_j
    # Compute outer product between k and v
    kv = jnp.einsum("nd,ne->nde", k, v)
    # Compute the linear recursion
    # a[..., None] makes a of shape (n, d, 1) so it can be broadcast when multiplying with kv
    _, y = associative_scan(binary_operator, (a[..., None], kv))
    y = jnp.einsum("nde,ne->nd", y, q)
    return y
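
# For clarity, a naive sequential version of the same recursion (an illustrative reference only,
# not used by the model below); it should agree with gate up to floating point error.
def gate_sequential(k: Float[Array, "N D"], v: Float[Array, "N D"], q: Float[Array, "N D"], a: Float[Array, "N D"]) -> Float[Array, "N D"]:
    n, d = k.shape
    h = jnp.zeros((d, d))
    ys = []
    for t in range(n):
        # Same update as the scan above: each row i of h is scaled by a[t, i], then the outer product is added
        h = a[t][:, None] * h + jnp.outer(k[t], v[t])
        # Read the state out with the query, matching the final einsum in gate
        ys.append(h @ q[t])
    return jnp.stack(ys)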
class TimeMixer(eqx.Module):
    W_K : eqx.nn.Linear
    W_V : eqx.nn.Linear
    W_A : eqx.nn.Linear
    W_Q : eqx.nn.Linear
    def __init__(self, d: int, *, key : PRNGKeyArray):
        super().__init__()
        key_k, key_v, key_a, key_q = jax.random.split(key, 4)
        self.W_K = eqx.nn.Linear(in_features=d, out_features=d, key = key_k)
        self.W_V = eqx.nn.Linear(in_features=d, out_features=d, key = key_v)
        self.W_A = eqx.nn.Linear(in_features=d, out_features=d, key = key_a)
        self.W_Q = eqx.nn.Linear(in_features=d, out_features=d, key = key_q)
    def __call__(self, x: Float[Array, "N D"]) -> Float[Array, "N D"]:
        k = jax.vmap(self.W_K)(x)
        v = jax.vmap(self.W_V)(x)
        a = jax.vmap(self.W_A)(x)
        q = jax.vmap(self.W_Q)(x)
        # a needs to be from 0 to 1
        a = jax.nn.sigmoid(a)
        return gate(k, v, q, a)
class GateLoopLayer(eqx.Module):
    time_mixer: TimeMixer
    layer_norm_time : eqx.nn.LayerNorm
    layer_norm_channel : eqx.nn.LayerNorm
    mlp : eqx.nn.MLP
    def __init__(self, d: int, *, key : PRNGKeyArray):
        super().__init__()
        mixer_key, mlp_key = jax.random.split(key, 2)
        self.time_mixer = TimeMixer(d, key = mixer_key)
        self.layer_norm_time = eqx.nn.LayerNorm(shape=d)
        self.layer_norm_channel = eqx.nn.LayerNorm(shape=d)
        self.mlp = eqx.nn.MLP(in_size=d, out_size=d, width_size=d, depth=1, key = mlp_key)
    def __call__(self, x: Float[Array, "N D"]) -> Float[Array, "N D"]:
        x = x + jax.vmap(self.layer_norm_time)(self.time_mixer(x))
        return x + jax.vmap(self.layer_norm_channel)(jax.vmap(self.mlp)(x))
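
# An illustrative usage sketch (the sizes below are arbitrary placeholders):
# build one GateLoopLayer and run it over a random sequence of shape (N, D).
if __name__ == "__main__":
    d_model = 32   # feature dimension D
    seq_len = 128  # sequence length N
    layer_key, data_key = jax.random.split(jax.random.PRNGKey(0))
    layer = GateLoopLayer(d_model, key=layer_key)
    x = jax.random.normal(data_key, (seq_len, d_model))
    y = layer(x)
    print(y.shape)  # (128, 32)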
EelcoHoogendoorn commented Nov 12, 2023

Thanks; I guess we are on the same page in our state of confusion.

I noticed that the lucidrains reproduction is also trying to figure this out; its JAX implementation uses the SISO approach, but the torch implementation makes this configurable, and the latest update comments on the performance implications of that.

tobiaskatsch commented Nov 12, 2023

Hello, I want to clarify the shapes (see updated code example). We used d_h=1 for all reported experiments due to the memory bottleneck imposed by the parallel scan. The general case is included to highlight the relation to standard MHA.

def gate_loop_operator(k, v, q, a):     
    """
    :param k: Input gates           (l, nr_heads, d_h, 1)
    :param v: Values                (l, nr_heads, 1, d_h)
    :param q: Output gates          (l, nr_heads, 1, d_h)
    :param a: State transitions     (l, nr_heads, d_h, 1)
    """
    def binary_operator(e_i, e_j):
        a_i, kv_i = e_i
        a_j, kv_j = e_j
        return a_j * a_i, a_j * kv_i + kv_j

    kv = jnp.matmul(k, v)
    _, y = associative_scan(binary_operator, (a, kv), axis=1)
    y = jnp.matmul(q, y)
    return y

mgostIH commented Nov 13, 2023

Nice, I'll keep this gist as it is now, but I'd argue that the memory bottleneck could be fixed with jax.checkpoint (or, in this specific script, equinox.filter_checkpoint would be more correct), as the Mamba authors suggest they do themselves.

Granted, their formulation isn't exactly the same, but it's close enough that I'd argue insights from one transfer to the other.
