Last active
November 18, 2023 17:09
-
-
Save mgostIH/0d4a5b614f1b9829f7471d888133d99b to your computer and use it in GitHub Desktop.
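An implementation of the GateLoop paper using real matrices only. Using jax.checkpoint (recomputing the scan on the backward pass instead of storing every intermediate state) is an idea taken from the Mamba paper. Provided as a guide, with type annotations and descriptive operations.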
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import jax | |
import jax.numpy as jnp | |
import equinox as eqx | |
from jax.lax import associative_scan | |
from jaxtyping import Float, Array, PRNGKeyArray | |
# Linear recurrence from the GateLoop paper, using real-valued vectors and matrices only.
#
# k, v, q and a all have shape (n, d). For every time step t:
#   * k_t outer v_t is a (d, d) matrix, so the RNN state h_t is a d x d matrix.
#   * a_t is conceptually a diagonal (d, d) matrix; only its diagonal is stored (shape (d,))
#     and it is applied with an element-wise product along the first (key) axis of h_t.
#
#   h_t = diag(a_t) @ h_{t-1} + k_t outer v_t
#   y_t = q_t @ h_t            (q_t is a row vector of size d)
#
# All n states are computed in parallel with associative_scan.
# filter_checkpoint recomputes the scan on the backward pass instead of storing
# every (n, d, d) intermediate state for backpropagation.
@eqx.filter_checkpoint
def gate(
    k: Float[Array, "N D"],
    v: Float[Array, "N D"],
    q: Float[Array, "N D"],
    a: Float[Array, "N D"],
) -> Float[Array, "N D"]:
    """Run the GateLoop linear recurrence over a sequence and read it out with q.

    Args:
        k: keys, one row per time step.
        v: values, one row per time step.
        q: queries, one row per time step.
        a: diagonal gates; each entry decays the matching key-axis row of the state.

    Returns:
        y with y_t = q_t @ h_t, where h_t = a_t * h_{t-1} + k_t outer v_t.
        Equivalently y_t = sum_{s<=t} ((q_t * prod_{s<r<=t} a_r) . k_s) v_s.
    """

    def binary_operator(e_i, e_j):
        # Compose two affine updates h -> a * h + kv, with e_i applied before e_j:
        # a_j * (a_i * h + kv_i) + kv_j == (a_j * a_i) * h + (a_j * kv_i + kv_j)
        a_i, kv_i = e_i
        a_j, kv_j = e_j
        return a_j * a_i, a_j * kv_i + kv_j

    # kv[n, d, e] = k[n, d] * v[n, e]  (outer product per time step)
    kv = jnp.einsum("nd,ne->nde", k, v)
    # a[..., None] has shape (n, d, 1), so it broadcasts across the value axis e and
    # scales each key-axis row of the state (the diagonal-matrix product).
    _, h = associative_scan(binary_operator, (a[..., None], kv))
    # y_t = q_t @ h_t: contract q over the key axis d, the same axis the gate decays.
    # (Contracting over e instead would pair q with v rather than with k.)
    y = jnp.einsum("nd,nde->ne", q, h)
    return y
class TimeMixer(eqx.Module):
    """Project an (N, D) sequence into keys, values, gates and queries, then mix over time with `gate`."""

    # Each projection is a D -> D linear layer applied independently at every time step.
    W_K: eqx.nn.Linear  # key projection
    W_V: eqx.nn.Linear  # value projection
    W_A: eqx.nn.Linear  # gate projection (squashed to (0, 1) with a sigmoid)
    W_Q: eqx.nn.Linear  # query projection

    def __init__(self, d: int, *, key: PRNGKeyArray):
        super().__init__()
        # One subkey per projection, consumed in the order K, V, A, Q.
        subkeys = jax.random.split(key, 4)
        self.W_K, self.W_V, self.W_A, self.W_Q = (
            eqx.nn.Linear(in_features=d, out_features=d, key=subkey)
            for subkey in subkeys
        )

    def __call__(self, x: Float[Array, "N D"]) -> Float[Array, "N D"]:
        def per_step(layer):
            # Apply a Linear layer to every row (time step) of x.
            return jax.vmap(layer)(x)

        keys_seq = per_step(self.W_K)
        values_seq = per_step(self.W_V)
        queries_seq = per_step(self.W_Q)
        # The gate must lie in (0, 1), hence the sigmoid.
        gates_seq = jax.nn.sigmoid(per_step(self.W_A))
        return gate(keys_seq, values_seq, queries_seq, gates_seq)
class GateLoopLayer(eqx.Module):
    """One GateLoop block: a time-mixing residual branch followed by a channel-MLP residual branch.

    Each branch output is layer-normalised before being added back onto the residual stream.
    """

    time_mixer: TimeMixer  # mixes information across time steps
    layer_norm_time: eqx.nn.LayerNorm  # normalises the time-mixer output
    layer_norm_channel: eqx.nn.LayerNorm  # normalises the MLP output
    mlp: eqx.nn.MLP  # per-time-step channel mixing, one hidden layer of width d

    def __init__(self, d: int, *, key: PRNGKeyArray):
        super().__init__()
        mixer_subkey, mlp_subkey = jax.random.split(key, 2)
        self.time_mixer = TimeMixer(d, key=mixer_subkey)
        self.layer_norm_time = eqx.nn.LayerNorm(shape=d)
        self.layer_norm_channel = eqx.nn.LayerNorm(shape=d)
        self.mlp = eqx.nn.MLP(
            in_size=d,
            out_size=d,
            width_size=d,
            depth=1,
            key=mlp_subkey,
        )

    def __call__(self, x: Float[Array, "N D"]) -> Float[Array, "N D"]:
        # Time-mixing branch (the mixer itself handles the whole sequence).
        mixed = self.time_mixer(x)
        residual = x + jax.vmap(self.layer_norm_time)(mixed)
        # Channel-mixing branch, applied independently at each time step.
        channel_out = jax.vmap(self.mlp)(residual)
        channel_out = jax.vmap(self.layer_norm_channel)(channel_out)
        return residual + channel_out
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Nice. I'll keep this gist as it is for now, but I'd argue that the memory bottleneck can be fixed with jax.checkpoint (or, in this specific script, equinox.filter_checkpoint, which is more appropriate), as the Mamba paper says they do themselves.
Granted, their formulation isn't exactly the same, but it is close enough that I'd argue insights from one transfer to the other.