Skip to content

Instantly share code, notes, and snippets.

@Ryu1845
Created June 19, 2023 20:29
Show Gist options
  • Save Ryu1845/d510c47575d94b54bff0397bd9138d0a to your computer and use it in GitHub Desktop.
Save Ryu1845/d510c47575d94b54bff0397bd9138d0a to your computer and use it in GitHub Desktop.
JAX implementation of the Block-State Transformer (copied from https://arxiv.org/abs/2306.09539)
"""Block-State Transformer Layer."""
# Block Transformers are non-recurrent and parallelizable.
block_transformer = jax.vmap(BRecT.nonrecurrent_cell)
def BST(u):
"""Block-State Transformer Layer."""
global MF # True if Multi-Filter, False otherwise (SH/MH)
# split inputs into windows (l/w, w, d)
u = jnp.split(u, seq_length // win_length, axis=0)
# collect context states from SSM outputs
context_states = [SH/MH/MF]_context_states(u)
# pass the contexts in place of recurrent states
y = block_transformer(
token_embeddings=u,
recurrent_state=context_states,
use_cross_attn_causal_mask=not MF,
use_cross_positional_emb=MF, # context IDs
)
return rearrange(y, "lw w d -> (lw w) d") # (l, d)
"""Context state collection for BST-MF variant."""
# (MF): Multi-Filter
def MF_context_states(u):
"""Multi-Filter Context Collection."""
h, b = get_filters_[unstruct/s4](channels=num_states)
y_s = multichannel_convolution(u, h, b)
# y_s: (l, d, s)
context_states = jnp.split(
y_s, seq_length // win_length, axis=0)
# context_states: (l/w, w, d, s)
# collect the last context states
context_states = context_states[:, -1, ...] # (l/w, d, s)
context_states = rearrange(
context_states, "lw d s -> lw s d")
# shift context states corresponding to windows
context_states = jnp.roll(context_states, 1, axis=1)
# replace the initial window with trainable weights
init_context = get_init_context(num_states) # (d, s)
context_states[0] = init_context
# lift to multiple heads
context_states = dense(context_states)
return context_states # (l/w, s, d, h)
"""Context state collection for BST-MH variant."""
# (MH): Multi-Head
def MH_context_states(u):
"""Multi-Head Context Collection."""
h, b = get_filters_[unstruct/s4](channels=num_heads)
y_h = multichannel_convolution(u, h, b)
# y_h: (l, d, h)
context_states = jnp.split(
y_h, seq_length // win_length, axis=0)
return context_states # (l/w, w, d, h)
"""Context state collection for BST-SH variant."""
num_heads = 8 # (h)
num_states = 32 # (s)
# (SH): Single-Head
def SH_context_states(u):
"""Single-Head Context Collection."""
h, b = get_filters_[unstruct/s4](channels=1)
y_1 = multichannel_convolution(u, h, b)
# y_1: (l, d, 1)
# lift to multiple heads
y_h = dense(y_1)
# y_h: (l, d, h)
context_states = jnp.split(
y_h, seq_length // win_length, axis=0)
return context_states # (l/w, w, d, h)
"""Unstructured filters and convolutions."""
import jax
from jax import numpy as jnp
from einops import rearrange
seq_length = 4096  # (l) full sequence length
win_length = 512  # (w) window/block length; seq_length is split into seq_length // win_length windows
def get_filters_unstruct(channels):
    """Returns trainable filters and biases.

    The filter is a learned projection of positional features, modulated by
    an exponential decay over normalized time t in [0, 1].

    Args:
      channels: number of filters.

    Returns:
      h: filter of shape (seq_length, channels, dim)
      b: bias of shape (channels, dim)
    """
    # NOTE(review): `channels` is not used below; presumably `dense` is meant
    # to produce the (channels, dim) features per position -- confirm.
    # NOTE(review): alpha, dense, positional_emb and get_bias are not defined
    # in the visible source; alpha is assumed to be a scalar decay rate.
    t = jnp.linspace(0.0, 1.0, seq_length)
    features = dense(positional_emb(t))  # (seq_length, channels, dim) per docstring
    # The decay depends on time, so it must broadcast along axis 0. A bare
    # (seq_length,) vector would broadcast against the LAST axis (dim).
    decay = jnp.exp(-alpha * t).reshape((-1,) + (1,) * (features.ndim - 1))
    h = decay * features
    b = get_bias()
    return h, b
def multichannel_convolution(u, h, b):
    """Multichannel causal convolution function, computed with FFTs.

    Every filter channel is convolved causally along the sequence axis with
    each feature of ``u``. A per-channel, per-feature skip term ``u * b`` is
    then added.

    Args:
      u: input of shape (seq_length, dim)
      h: filters of shape (seq_length, channels, dim)
      b: bias of shape (channels, dim)

    Returns:
      y: output of shape (seq_length, dim, channels)
    """
    seq_len = u.shape[0]
    # Put the sequence axis last: jnp.fft.rfft transforms the last axis.
    u_t = jnp.transpose(u)  # (d, l)
    h_t = jnp.transpose(h, (1, 2, 0))  # (c, d, l)
    # Zero-pad to 2*l so the circular FFT product equals a linear convolution.
    fft_size = 2 * seq_len
    u_f = jnp.fft.rfft(u_t, n=fft_size)  # (d, F)
    h_f = jnp.fft.rfft(h_t, n=fft_size)  # (c, d, F)
    # Default ("backward") normalization scales only the inverse by
    # 1/fft_size, so the round trip is the plain convolution. norm="forward"
    # on the inverse alone would inflate it by fft_size.
    y = jnp.fft.irfft(h_f * u_f, n=fft_size)[..., :seq_len]  # (c, d, l)
    y = y + u_t * b[..., None]  # (c, d, l)
    return jnp.transpose(y, (2, 1, 0))  # (l, d, c)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment