"""Minimal example of DLR (diagonal linear recurrent) layer in JAX
https://arxiv.org/pdf/2212.00768.pdf
"""

from typing import Any, Callable, Tuple

from flax import linen
import jax
import jax.numpy as jnp

from numga.algebra.algebra import Algebra


class GADLR(linen.Module):
    """GA-DLR module: a diagonal linear recurrent layer whose diagonal blocks
    are geometric-algebra rotors with decay."""
    n_rotors: int
    n_outputs: int
    log_decay: Tuple[float, float] = (-4., 0.)
    algebra: Algebra = Algebra.from_pqr(2, 0, 0)
    bias: bool = False
    kernel_init: Callable[..., Any] = jax.nn.initializers.glorot_normal()
    # kernel_init: Callable[..., Any] = jax.nn.initializers.lecun_uniform()

    @linen.compact
    def __call__(self, x, u):
        even = self.algebra.subspace.even_grade()
        bivec = self.algebra.subspace.bivector()

        def prod(a, b):
            """Batched geometric product of even-grade multivectors:
            contract the product kernel with both operands per batch element."""
            op = self.algebra.operator.product(even, even)
            return jnp.einsum('ijk,bi,bj->bk', op.kernel, a, b)

        def exp(b: "Bivector", steps=4):
            """Approximate the exponential of a bivector by argument halving
            followed by repeated squaring.

            We don't need a high step count in this approximate exponential;
            we just need the ability to seed the space of rotors in a
            parametrizable manner."""
            sel = self.algebra.operator.select(bivec, even)
            r = jnp.einsum('ij,bi->bj', sel.kernel, b / (2 ** steps))
            r = r.at[:, 0].set(1)
            r = r / jnp.linalg.norm(r, axis=1, keepdims=True)
            for i in range(steps):
                r = prod(r, r)
            return r
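
        # The loop above exploits exp(b) = exp(b / 2**steps) ** (2**steps):
        # embed the halved bivector into the even subalgebra, take a normalized
        # first-order approximation of its exponential, then square `steps` times.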

        key = jax.random.PRNGKey(1)

        def make_log_scalar(key):
            """Sample one log-decay rate per rotor, uniform over the configured range."""
            log_scalar = jax.random.uniform(
                key, shape=(self.n_rotors, 1),
                minval=self.log_decay[0], maxval=self.log_decay[1])
            return log_scalar
        # log_scalar = self.param("log_scalar", make_log_scalar)
        log_scalar = make_log_scalar(key)

        def make_bivec(key):
            """Sample one bivector per rotor from the unit ball, reweighted by its own norm."""
            bivector = jax.random.ball(key, shape=(self.n_rotors,), d=len(bivec))
            bn = jnp.linalg.norm(bivector, axis=1, keepdims=True)
            return bivector * bn
        # for the time being, let's leave these as fixed parameters?
        # bivector = self.param("bivector", make_bivec)
        bivector = make_bivec(key)
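
        # Transition element per rotor: a rotor exp(pi * bivector), scaled by a
        # decay magnitude exp(-exp(log_scalar)) in (0, 1) that damps the state
        # over time.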
        l = exp(jnp.pi * bivector) * jnp.exp(-jnp.exp(log_scalar))

        A = lambda x: prod(l, x)
        B = linen.Dense(
            len(even) * self.n_rotors,
            name='hidden_B',
            kernel_init=self.kernel_init,
            use_bias=self.bias)
        C = linen.Dense(
            self.n_outputs,
            name='hidden_C',
            kernel_init=self.kernel_init,
            use_bias=self.bias)
        D = linen.Dense(
            self.n_outputs,
            name='hidden_D',
            kernel_init=self.kernel_init,
            use_bias=self.bias)
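
        # State update and readout of the linear recurrence:
        # x' = A x + B u, then y = C x' + D u.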
        x = A(x) + B(u).reshape(x.shape)
        y = C(x.flatten()) + D(u)
        # y = C(x[:, 0]) + D(u)
        return x, y

    def init_carry(self):
        """Zero-initialized hidden state: one even-grade multivector per rotor."""
        even = self.algebra.subspace.even_grade()
        return jnp.zeros((self.n_rotors, len(even)))


def test():
    key = jax.random.PRNGKey(1)
    dlr = GADLR(
        algebra=Algebra.from_pqr(3, 0, 0),
        n_outputs=32,
        n_rotors=32,
        bias=False,
    )
    x = dlr.init_carry()
    u = jnp.ones((1,))
    params = dlr.init(key, x, u)
    # print(params)
    # quit()

    # unroll the recurrence for 100 steps on a constant input, and plot the outputs
    apply = jax.jit(dlr.apply)
    r = []
    for i in range(100):
        x, y = apply(params, x, u)
        r.append(y)
    r = jnp.array(r)

    import matplotlib.pyplot as plt
    plt.plot(r)
    plt.show()
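

# A minimal sketch (not part of the original gist) of how one might unroll the
# layer over a full input sequence with jax.lax.scan instead of a Python loop;
# `unroll` and `us` are hypothetical names introduced here for illustration.
def unroll(params, dlr: GADLR, us):
    """Scan the layer over a sequence of inputs `us` of shape (T, n_in)."""
    def step(x, u):
        # dlr.apply returns (new_carry, output), exactly the pair scan expects
        return dlr.apply(params, x, u)
    x0 = dlr.init_carry()
    _, ys = jax.lax.scan(step, x0, us)
    return ys  # shape (T, n_outputs)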


if __name__ == '__main__':
    test()