
@mgostIH
Last active November 18, 2023 17:09
An implementation of the GateLoop paper using real matrices only. The use of jax.checkpoint is an idea from the Mamba paper. Provided as a guide, with type annotations and more descriptive operations.
import jax
import jax.numpy as jnp
import equinox as eqx
from jax.lax import associative_scan
from jaxtyping import Float, Array, PRNGKeyArray
# This computes the linear recursion from the GateLoop paper
# We work with real-valued vectors and matrices
# k and v are sequences of vectors, each of shape (n, d)
# Their outer products form a sequence of matrices of shape (n, d, d)
# The internal state of the RNN is therefore a d x d matrix
# We compute the internal states of all time steps in parallel using associative_scan
# a is technically of shape (n, d, d), but we only use the diagonal of each d x d matrix
# Therefore a actually has shape (n, d) and we use the element-wise product (not the matrix product)
# The next state is obtained as:
# h_t = a_t * h_{t-1} + (k_t outer v_t)
# While the output is:
# y_t = h_t @ q_t
# Where @ is the matrix-vector product and q_t is a vector of size d (we have n of them, so shape (n, d))
# jax.checkpoint avoids storing all the intermediate states for backpropagation
@eqx.filter_checkpoint
def gate(k: Float[Array, "N D"], v: Float[Array, "N D"], q: Float[Array, "N D"], a: Float[Array, "N D"]) -> Float[Array, "N D"]:
    def binary_operator(e_i, e_j):
        a_i, kv_i = e_i
        a_j, kv_j = e_j
        return a_j * a_i, a_j * kv_i + kv_j
    # Compute the outer product between k and v
    kv = jnp.einsum("nd,ne->nde", k, v)
    # Compute the linear recursion
    # a[..., None] makes a of shape (n, d, 1) so it can be broadcast when multiplying with kv
    _, y = associative_scan(binary_operator, (a[..., None], kv))
    y = jnp.einsum("nde,ne->nd", y, q)
    return y
class TimeMixer(eqx.Module):
    W_K: eqx.nn.Linear
    W_V: eqx.nn.Linear
    W_A: eqx.nn.Linear
    W_Q: eqx.nn.Linear

    def __init__(self, d: int, *, key: PRNGKeyArray):
        super().__init__()
        key_k, key_v, key_a, key_q = jax.random.split(key, 4)
        self.W_K = eqx.nn.Linear(in_features=d, out_features=d, key=key_k)
        self.W_V = eqx.nn.Linear(in_features=d, out_features=d, key=key_v)
        self.W_A = eqx.nn.Linear(in_features=d, out_features=d, key=key_a)
        self.W_Q = eqx.nn.Linear(in_features=d, out_features=d, key=key_q)

    def __call__(self, x: Float[Array, "N D"]) -> Float[Array, "N D"]:
        k = jax.vmap(self.W_K)(x)
        v = jax.vmap(self.W_V)(x)
        a = jax.vmap(self.W_A)(x)
        q = jax.vmap(self.W_Q)(x)
        # a needs to be from 0 to 1
        a = jax.nn.sigmoid(a)
        return gate(k, v, q, a)
class GateLoopLayer(eqx.Module):
    time_mixer: TimeMixer
    layer_norm_time: eqx.nn.LayerNorm
    layer_norm_channel: eqx.nn.LayerNorm
    mlp: eqx.nn.MLP

    def __init__(self, d: int, *, key: PRNGKeyArray):
        super().__init__()
        mixer_key, mlp_key = jax.random.split(key, 2)
        self.time_mixer = TimeMixer(d, key=mixer_key)
        self.layer_norm_time = eqx.nn.LayerNorm(shape=d)
        self.layer_norm_channel = eqx.nn.LayerNorm(shape=d)
        self.mlp = eqx.nn.MLP(in_size=d, out_size=d, width_size=d, depth=1, key=mlp_key)

    def __call__(self, x: Float[Array, "N D"]) -> Float[Array, "N D"]:
        x = x + jax.vmap(self.layer_norm_time)(self.time_mixer(x))
        return x + jax.vmap(self.layer_norm_channel)(jax.vmap(self.mlp)(x))
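
# A minimal usage sketch: the sizes n = 16 and d = 8 are arbitrary illustrative choices,
# and the Python loop below is only a reference check that the parallel associative_scan
# matches the sequential recursion h_t = a_t * h_{t-1} + (k_t outer v_t), y_t = h_t @ q_t.
if __name__ == "__main__":
    key = jax.random.PRNGKey(0)
    layer_key, x_key = jax.random.split(key)
    n, d = 16, 8
    layer = GateLoopLayer(d, key=layer_key)
    x = jax.random.normal(x_key, (n, d))
    print(layer(x).shape)  # (16, 8)

    # Sequential reference for gate
    k, v, q = (jax.random.normal(jax.random.fold_in(key, i), (n, d)) for i in range(3))
    a = jax.nn.sigmoid(jax.random.normal(jax.random.fold_in(key, 3), (n, d)))
    h = jnp.zeros((d, d))
    ys = []
    for t in range(n):
        h = a[t][:, None] * h + jnp.outer(k[t], v[t])
        ys.append(h @ q[t])
    print(jnp.allclose(gate(k, v, q, a), jnp.stack(ys), atol=1e-5))  # True
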
@tobiaskatsch

We used complex state transitions, with distinct linear projections for both the magnitude and the phase. Curious if your variant works too.

@EelcoHoogendoorn

I don't think the kv term is meant to be an outer product; rather, it's intended to be kv = jnp.einsum("nd,nd->nd", k, v)

@mgostIH (Author) commented Nov 12, 2023

@EelcoHoogendoorn

I don't think the kv term is meant to be an outer product; rather, it's intended to be kv = jnp.einsum("nd,nd->nd", k, v)

I was also considering this; the paper is unclear about it, for the following reasons.
On page 4:

Note, that for generality we define an outer product k_n^T v_n entering the gate loop. Therefore, k_n^T v_n and h_n are of shape C^(d_h×d_h)

Here it's highlighted that the operation is indeed an outer product, but the notation is non-standard, since using standard column notation it would be k_n v_n^T. Having asked the author via email, he mentioned that the paper uses row notation.

Secondly, in the paper's code on page 5, the line kv = jnp.matmul(k, v) makes me think it's not simply a scalar product, but this depends on the shapes of the operands. Given what page 4 says about the shapes of a, q, k, v, it's not clear what operation that matmul is supposed to perform: matmul between shapes like (n, d) @ (n, d) doesn't work, (n, d, 1) @ (n, 1, d) would give (n, d, d), while (n, 1, d) @ (n, d, 1) would give (n, 1, 1), so there's ambiguity in the description.

Yet the paper does mention at page 4 that:

Choosing a max-headed variant, that is d_h = 1, we obtain the SISO case which coincides with previous definitions and element-wise gating when parallelized across multiple channels.

However, it doesn't state anywhere whether this is actually what they chose for the architecture. It makes sense as an operation and would then lead to your suggestion, but both approaches seem realizable, as well as mixed head sizes that don't have to be d×d.

I suspect the author chose the latter, but I haven't received confirmation and am waiting for some corrections to the paper that make those sections clearer. Besides, the d×d state might still work very well: it allows an amount of memory that is quadratic in the dimensionality rather than linear, so if your task is bottlenecked by state size it should be a reasonable thing to do.
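
For concreteness, here is a quick shape check of the two batched readings (toy arrays, with n = 4 and d = 3 picked arbitrarily):

import jax.numpy as jnp

n, d = 4, 3
k = jnp.ones((n, d))
v = jnp.ones((n, d))

# Batched outer product: one d x d state update per time step
print(jnp.matmul(k[:, :, None], v[:, None, :]).shape)  # (4, 3, 3)

# Batched inner product: a single scalar per time step
print(jnp.matmul(k[:, None, :], v[:, :, None]).shape)  # (4, 1, 1)

# (n, d) @ (n, d) is a plain 2-D matmul; with n != d the inner dimensions
# don't match and jnp.matmul raises an error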

@EelcoHoogendoorn commented Nov 12, 2023

Thanks; I guess we are on the same page in our state of confusion.

I noticed that the lucidrains reproduction is also trying to figure this out; its JAX implementation uses the SISO approach, but the torch implementation makes this configurable, and the latest update comments on the performance implications of that.

@tobiaskatsch commented Nov 12, 2023

Hello, I want to clarify the shapes (see the updated code example). We used d_h = 1 for all reported experiments due to the memory bottleneck imposed by the parallel scan. The general case is included to highlight the relation to standard MHA.

def gate_loop_operator(k, v, q, a):     
    """
    :param k: Input gates           (l, nr_heads, d_h, 1)
    :param v: Values                (l, nr_heads, 1, d_h)
    :param q: Output gates          (l, nr_heads, 1, d_h)
    :param a: State transitions     (l, nr_heads, d_h, 1)
    """
    def binary_operator(e_i, e_j):
        a_i, kv_i = e_i
        a_j, kv_j = e_j
        return a_j * a_i, a_j * kv_i + kv_j

    kv = jnp.matmul(k, v)
    # Scan over the time axis l (axis 0, given the shapes documented above)
    _, y = associative_scan(binary_operator, (a, kv), axis=0)
    y = jnp.matmul(q, y)
    return y
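
With d_h = 1 the (d_h, 1) @ (1, d_h) products above are 1x1, so the operator collapses to element-wise gating per channel (the SISO case mentioned in the paper). A minimal check of that reduction, with toy shapes picked arbitrarily:

import jax
import jax.numpy as jnp

l, nr_heads, d_h = 6, 4, 1  # arbitrary toy sizes for the d_h = 1 case
key = jax.random.PRNGKey(0)
k = jax.random.uniform(jax.random.fold_in(key, 0), (l, nr_heads, d_h, 1))
v = jax.random.uniform(jax.random.fold_in(key, 1), (l, nr_heads, 1, d_h))

# The 1x1 matmul is just a product, i.e. the element-wise gating of the SISO case
kv = jnp.matmul(k, v)           # shape (l, nr_heads, 1, 1)
print(jnp.allclose(kv, k * v))  # True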

@mgostIH (Author) commented Nov 13, 2023

Nice, I'll keep this gist as it is now, but I'd argue that the memory bottleneck could be addressed with jax.checkpoint (or, in this specific script, equinox.filter_checkpoint would be more correct), as the Mamba paper suggests its authors do themselves.

Granted, their formulation isn't exactly the same, but it's close enough that I'd argue insights from one transfer to the other.
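
As a sketch of what I mean, assuming the gate_loop_operator and the (k, v, q, a) shapes from the snippet above (the toy sizes here are arbitrary), wrapping the operator in jax.checkpoint makes the backward pass recompute the intermediate states instead of storing them all:

import jax
import jax.numpy as jnp

# Arbitrary toy shapes matching the docstring of gate_loop_operator
l, nr_heads, d_h = 8, 2, 4
key = jax.random.PRNGKey(0)
k = jax.random.normal(jax.random.fold_in(key, 0), (l, nr_heads, d_h, 1))
v = jax.random.normal(jax.random.fold_in(key, 1), (l, nr_heads, 1, d_h))
q = jax.random.normal(jax.random.fold_in(key, 2), (l, nr_heads, 1, d_h))
a = jax.nn.sigmoid(jax.random.normal(jax.random.fold_in(key, 3), (l, nr_heads, d_h, 1)))

# jax.checkpoint (a.k.a. jax.remat) rematerializes the scan's (d_h, d_h) states
# on the backward pass instead of keeping them all in memory
checkpointed = jax.checkpoint(gate_loop_operator)
grads = jax.grad(lambda q: checkpointed(k, v, q, a).sum())(q)
print(grads.shape)  # (8, 2, 1, 4)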
