"""Minimal example of DLR (diagonal linear recurrent) layer in JAX
https://arxiv.org/pdf/2212.00768.pdf
"""
from typing import Any, Callable, Tuple

from flax import linen
import jax
import jax.numpy as jnp

# numga: geometric algebra library, https://github.com/EelcoHoogendoorn/numga
from numga.algebra.algebra import Algebra
from numga.backend.jax.context import JaxContext

class GADLR(linen.Module):
    """GA-DLR module."""
    ga: object          # numga context (e.g. JaxContext) wrapping the algebra
    n_rotors: int       # number of independent recurrent rotors
    n_outputs: int      # output feature dimension
    log_decay: Tuple[float, float] = (-4., 0.)  # init range for the log decay rates
    bias: bool = False
    kernel_init: Callable[..., Any] = jax.nn.initializers.glorot_normal()
    # kernel_init: Callable[..., Any] = jax.nn.initializers.lecun_uniform()
    @linen.compact
    def __call__(self, x, u):
        even = self.ga.subspace.even_grade()
        bivec = self.ga.subspace.bivector()
        # fixed key: the rotor bank is frozen, not learned
        key = jax.random.PRNGKey(1)

        def make_log_scalar(key):
            log_scalar = jax.random.uniform(
                key, shape=(self.n_rotors,),
                minval=self.log_decay[0], maxval=self.log_decay[1])
            return log_scalar
        # log_scalar = self.param("log_scalar", make_log_scalar)
        log_scalar = make_log_scalar(key)

        def make_bivec(key) -> "Bivector":
            # sample bivector coefficients in the unit ball, rescaled by their own norm
            b = jax.random.ball(key, shape=(self.n_rotors,), d=len(bivec))
            bn = jnp.linalg.norm(b, axis=1, keepdims=True)
            return self.ga.multivector(bivec, b * bn)
        # for the time being, let's leave these as fixed parameters;
        # according to gateloop/mamba, we would want to condition these on u (sketched below)
        # bivector = self.param("bivector", make_bivec)
        bivector = make_bivec(key)
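        # Hedged sketch, not part of the original gist: the gateloop/mamba-style
        # variant hinted at above would condition the rotor plane on the input u;
        # 'hidden_bivec' is a hypothetical layer name:
        #   to_bivec = linen.Dense(len(bivec) * self.n_rotors, name='hidden_bivec')
        #   bivector = self.ga.multivector(bivec, to_bivec(u).reshape(self.n_rotors, len(bivec)))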
l: "Even" = (bivector * jnp.pi).exp() * jnp.exp(-jnp.exp(log_scalar))
# wrap recurrent state x here in mv type and unpack again after product
# alternatively we would leak the multivector type in the broader codebase / arch
# may or may not be desirable?
A = lambda x: (l * self.ga.multivector(even, x)).values
        # input-to-state, state-to-output, and feedthrough maps
        B = linen.Dense(
            len(even) * self.n_rotors,
            name='hidden_B',
            kernel_init=self.kernel_init,
            use_bias=self.bias)
        C = linen.Dense(
            self.n_outputs,
            name='hidden_C',
            kernel_init=self.kernel_init,
            use_bias=self.bias)
        D = linen.Dense(
            self.n_outputs,
            name='hidden_D',
            kernel_init=self.kernel_init,
            use_bias=self.bias)
        x = A(x) + B(u).reshape(x.shape)
        y = C(x.flatten()) + D(u)
        # y = C(x[:, 0]) + D(u)
        return x, y
    def init_carry(self):
        # zero-initialized recurrent state: one even-grade element per rotor
        even = self.ga.subspace.even_grade()
        c = jnp.zeros((self.n_rotors, len(even)))
        return c
        # return ga.multivector(c, subspace=even)
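
# Hedged usage sketch, not part of the original gist: unrolling the recurrence
# over a whole input sequence with jax.lax.scan instead of the Python loop in
# test() below; `us` is assumed to have shape (T, n_in).
def unroll(dlr, params, x0, us):
    def step(x, u):
        return dlr.apply(params, x, u)  # returns (new carry, output)
    _, ys = jax.lax.scan(step, x0, us)
    return ys  # stacked outputs, shape (T, n_outputs)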

def test():
    key = jax.random.PRNGKey(1)
    dlr = GADLR(
        # algebra=Algebra.from_pqr(3, 0, 0),
        ga=JaxContext(Algebra.from_pqr(3, 0, 0)),
        n_outputs=32,
        n_rotors=32,
        bias=False,
    )
    x = dlr.init_carry()
    u = jnp.ones((1,))
    params = dlr.init(key, x, u)
    # print(params)
    apply = jax.jit(dlr.apply)

    # drive the recurrence with a constant input and plot the 100-step output
    # trajectories; the decaying rotors yield damped oscillations
    r = []
    for i in range(100):
        x, y = apply(params, x, u)
        r.append(y)
    r = jnp.array(r)

    import matplotlib.pyplot as plt
    plt.plot(r)
    plt.show()
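    # Hedged sketch, not part of the original gist: the step can also be batched
    # over a leading axis with jax.vmap, mapping over carry and input but not params:
    #   batched = jax.vmap(apply, in_axes=(None, 0, 0))
    #   xs, ys = batched(params, jnp.stack([x] * 8), jnp.ones((8, 1)))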

if __name__ == '__main__':
    test()