NickleDave / lru.py
Created October 20, 2023 01:32 — forked from Ryu1845/lru.py
Linear Recurrent Unit (LRU) from the paper ["Resurrecting Recurrent Neural Networks for Long Sequences"](https://arxiv.org/abs/2303.06349)
"""
Simplified Implementation of the Linear Recurrent Unit
------------------------------------------------------
We present here a simplified JAX implementation of the Linear Recurrent Unit (LRU).
The state of the LRU is driven by the input $(u_k)_{k=1}^L$ of sequence length $L$
according to the following formula (and efficiently parallelized using an associative scan):
$x_{k} = \Lambda x_{k-1} + \exp(\gamma^{\log}) \odot (B u_{k})$,
and the output is computed at each timestamp $k$ as follows: $y_k = C x_k + D u_k$.
In our code, $B,C$ follow Glorot initialization, with $B$ scaled additionally by a factor 2
to account for halving the state variance by taking the real part of the output projection.
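As a minimal sketch of how the recurrence above can be parallelized with `jax.lax.associative_scan`, assuming the paper's diagonal parameterization of $\Lambda$ via `nu_log`/`theta_log` and the normalization `gamma_log` (parameter names are illustrative here; the rest of the gist defines the actual parameters):

```python
import jax
import jax.numpy as jnp


def binary_operator_diag(element_i, element_j):
    """Associative operator composing two steps of the diagonal recurrence x_k = a_k * x_{k-1} + b_k."""
    a_i, bu_i = element_i
    a_j, bu_j = element_j
    return a_j * a_i, a_j * bu_i + bu_j


def lru_forward(lru_parameters, input_sequence):
    """Sketch of an LRU forward pass; input_sequence has shape (L, H)."""
    nu_log, theta_log, B_re, B_im, C_re, C_im, D, gamma_log = lru_parameters

    # Diagonal state matrix Lambda = exp(-exp(nu_log) + i * exp(theta_log)), shape (N,)
    Lambda = jnp.exp(-jnp.exp(nu_log) + 1j * jnp.exp(theta_log))
    # Input projection scaled by exp(gamma_log), matching the recurrence above
    B_norm = (B_re + 1j * B_im) * jnp.expand_dims(jnp.exp(gamma_log), axis=-1)
    C = C_re + 1j * C_im

    # Run the linear recurrence in parallel over the sequence with an associative scan
    Lambda_elements = jnp.repeat(Lambda[None, ...], input_sequence.shape[0], axis=0)
    Bu_elements = jax.vmap(lambda u: B_norm @ u)(input_sequence)
    _, inner_states = jax.lax.associative_scan(
        binary_operator_diag, (Lambda_elements, Bu_elements)
    )  # inner_states[k] == x_k
    # Output projection: y_k = Re(C x_k) + D * u_k
    return jax.vmap(lambda x, u: (C @ x).real + D * u)(inner_states, input_sequence)
```

Because the binary operator is associative, the scan computes all states $x_1, \dots, x_L$ in $O(\log L)$ parallel depth rather than a sequential loop of length $L$.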