`mamba_simple` style RWKV implementation
# Copyright © 2024 Ronsor Labs. All rights reserved
# Except as otherwise indicated, you may only use this software in accordance with
# the MIT license.
# A self-contained RWKV-6 x060 implementation, inspired by mamba_simple.
# TimeMix and wkv_op_torch adapted from Apache-2.0-licensed:
# - https://github.com/SmerkyG/gptcore/blob/main/model/experimental/rwkv6_0.py
# - https://github.com/SmerkyG/gptcore/blob/main/model/experimental/rwkv_inner.py
# - https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v5/src/model.py
# No guarantees of correctness are made. You'll still have to do proper initialization
# of the linear layers (see the TinyRWKV sketch near the bottom of this file for one
# possible scheme).
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import warnings
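
# wkv_op_torch below is a pure-PyTorch reference for the per-head RWKV-6 "WKV"
# recurrence. In recurrent form (the L == 1 branch), with kv_state of shape
# (B, H, K, V):
#
#   out_t   = r_t @ (state_{t-1} + diag(u) @ k_t^T @ v_t)
#   state_t = diag(w_t) @ state_{t-1} + k_t^T @ v_t
#
# The chunked branch (L > 1) computes the same thing for a whole sequence: inside
# each length-T chunk it builds the lower-triangular T x T interaction matrix
# directly, and across chunks it carries kv_state forward using cumulative
# log-decays.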
def wkv_op_torch(r, k, v, w, u, kv_state, chunk_len=64, precision_dtype=torch.float32):
    B, H, L, K = k.size()
    V = v.size(-1)
    T = chunk_len
    if L == 1:
        # Single-token recurrent step; kv_state is (B, H, K, V).
        kv = k.mT @ v
        out = r @ (kv_state + u.mT * kv)
        kv_state = w.mT * kv_state + kv
        return out, kv_state
    else:
        if L % T != 0:
            warnings.warn(
                "Sequence length should be evenly divisible by chunk_len. Finding a new chunk_len for you, but this IS slow.",
                RuntimeWarning
            )
            if L % 2 != 0:
                T = 1
            else:
                while L % T != 0:
                    T -= 2
        N = L // T

        dtype_min = {
            torch.float32: 0.005,
            torch.float64: 1e-10,
        }
        assert precision_dtype in dtype_min
        w = w.clamp(dtype_min[precision_dtype])

        w_log = w.float().log()
        wc_log = w_log.view(w.size(0), H, N, T, K)
        wc_log_cum = wc_log.cumsum(dim=-2)
        shifted_wc_log_cum = F.pad(wc_log_cum, (0, 0, 1, -1))

        ws = wc_log.sum(dim=-2, keepdim=True)  # 1HN1K or BHN1K
        w_inter = ws - wc_log_cum              # 1HNTK or BHNTK (w^(T-1) ... w^0)
        w_intra = wc_log_cum - wc_log          # 1HNTK or BHNTK (w^0 ... w^(T-2))

        ws = list(ws.mT.exp().to(r.dtype).unbind(dim=-3))  # N x 1HK1 or BHK1 !!NOTE THE .mT HERE!!
        w_inter = w_inter.exp().to(r.dtype)  # 1HNTK or BHNTK
        w_intra = w_intra.exp().to(r.dtype)  # 1HNTK or BHNTK

        r = r.view(B, H, N, T, K)
        k = k.view(B, H, N, T, K)
        v = v.view(B, H, N, T, V)
        u = u.unsqueeze(2).to(r.dtype)  # (1,H,1,1,K)

        wc_log_offset = shifted_wc_log_cum[..., T//2:T//2+1, :]                    # B,H,N,1,K
        r_decay = (shifted_wc_log_cum - wc_log_offset).to(precision_dtype).exp()   # B,H,N,T,K
        k_inv_decay = (wc_log_offset - wc_log_cum).to(precision_dtype).exp()       # B,H,N,T,K

        a = ((r * r_decay) @ (k * k_inv_decay).mT).to(r.dtype).tril(-1)  # B,H,N,T,T
        a = a + torch.einsum('bhntk,bhntk->bhnt', r, u * k).diag_embed()
        out = a @ v  # BHNTV

        wkv = (k * w_inter).mT @ v      # BHNKV
        wkv = list(wkv.unbind(dim=-3))  # N x BHKV

        states = []
        for i in range(N):
            states.append(kv_state)
            kv_state = kv_state * ws[i] + wkv[i]  # BHKV
        states = torch.stack(states, dim=2)       # BHNKV

        out = out + (r * w_intra) @ states  # BHNTV
        out = out.view(B, H, L, V)
        return out, kv_state
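
# Illustrative only (not part of the original API): a tiny helper, defined here
# but never called, that compares the chunked path of wkv_op_torch against the
# one-token-at-a-time recurrent path. All shapes and magnitudes are arbitrary
# demo assumptions; the two paths should agree up to floating-point error.
def _wkv_op_torch_selftest(B=1, H=2, L=64, K=4, V=4, chunk_len=16):
    r = torch.randn(B, H, L, K)
    k = torch.randn(B, H, L, K)
    v = torch.randn(B, H, L, V)
    w = torch.rand(B, H, L, K) * 0.5 + 0.5  # per-channel decays in (0.5, 1.0)
    u = torch.randn(1, H, 1, K)
    state = torch.zeros(B, H, K, V)

    out_chunked, state_chunked = wkv_op_torch(r, k, v, w, u, state.clone(), chunk_len)

    outs = []
    for t in range(L):
        o, state = wkv_op_torch(r[:, :, t:t+1], k[:, :, t:t+1], v[:, :, t:t+1],
                                w[:, :, t:t+1], u, state)
        outs.append(o)
    out_recurrent = torch.cat(outs, dim=2)

    return ((out_chunked - out_recurrent).abs().max(),
            (state_chunked - state).abs().max())
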
class LayerNorm(nn.Module):
    def __init__(self, ndim, bias, eps=1e-5, device=None, dtype=None):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(ndim, device=device, dtype=dtype))
        self.bias = nn.Parameter(torch.zeros(ndim, device=device, dtype=dtype)) if bias else None

    def forward(self, input):
        return F.layer_norm(input, self.weight.shape, self.weight, self.bias, self.eps)

class TimeMix(nn.Module):
    def __init__(
        self,
        d_model,
        d_head=None,
        bias=False,
        layer_idx=0,
        n_layer=1,
        wkv_op="torch",
        wkv_chunk_len=64,
        wkv_dtype=torch.float32,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.d_model = d_model
        self.d_head = d_head
        self.n_head = int(self.d_model / self.d_head)
        self.n_layer = n_layer
        self.layer_idx = layer_idx
        self.wkv_op = wkv_op
        self.wkv_chunk_len = wkv_chunk_len
        self.wkv_dtype = wkv_dtype

        with torch.no_grad():
            TIME_MIX_EXTRA_DIM = 32
            W_MIX_EXTRA_DIM = 64

            ddd = torch.arange(d_model, **factory_kwargs).view(1, 1, -1) / d_model
            r0to1 = layer_idx / max(n_layer - 1, 1)
            r1to0 = 1.0 - (layer_idx / n_layer)

            self.x_maa = nn.Parameter(1 - torch.pow(ddd, r1to0))
            self.r_maa = nn.Parameter(1 - torch.pow(ddd, 0.5 * r1to0))
            self.w_maa = nn.Parameter(1 - torch.pow(ddd, r1to0))
            self.k_maa = nn.Parameter(1 - torch.pow(ddd, r1to0))
            self.v_maa = nn.Parameter(1 - (torch.pow(ddd, r1to0) + 0.3 * r0to1))
            self.g_maa = nn.Parameter(1 - torch.pow(ddd, 0.5 * r1to0))

            self.tm_w1 = nn.Parameter(torch.empty(self.d_model, TIME_MIX_EXTRA_DIM * 5, **factory_kwargs).uniform_(-0.01, 0.01))
            self.tm_w2 = nn.Parameter(torch.zeros(5, TIME_MIX_EXTRA_DIM, self.d_model, **factory_kwargs))
            self.td_w1 = nn.Parameter(torch.empty(self.d_model, W_MIX_EXTRA_DIM, **factory_kwargs).uniform_(-0.01, 0.01))
            self.td_w2 = nn.Parameter(torch.zeros(W_MIX_EXTRA_DIM, self.d_model, **factory_kwargs))

            decay_speed = torch.arange(self.d_model, **factory_kwargs)
            decay_speed = (decay_speed / max(self.d_model - 1, 1)).pow(0.7 + 1.3 * r0to1) * 5 - 6
            self.time_decay = nn.Parameter(decay_speed.reshape(self.n_head, self.d_head))

            first_speed = torch.arange(self.d_model, **factory_kwargs)
            first_speed = r0to1 * (1 - (first_speed / max(self.d_model - 1, 1))) + ((first_speed + 1) % 3 - 1) * 0.1
            self.time_first = nn.Parameter(first_speed.reshape(self.n_head, self.d_head))

        if wkv_op == "torch":
            self.wkv_op_fn = wkv_op_torch
        elif isinstance(wkv_op, str):
            raise ValueError(f"Unknown WKV operator implementation: {wkv_op}")
        else:
            self.wkv_op_fn = wkv_op

        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        self.receptance = nn.Linear(d_model, self.d_model, bias=bias, **factory_kwargs)
        self.key = nn.Linear(d_model, self.d_model, bias=bias, **factory_kwargs)
        self.value = nn.Linear(d_model, self.d_model, bias=bias, **factory_kwargs)
        self.output = nn.Linear(self.d_model, d_model, bias=bias, **factory_kwargs)
        self.gate = nn.Linear(d_model, self.d_model, bias=bias, **factory_kwargs)
        self.ln_x = nn.GroupNorm(self.n_head, self.d_model, **factory_kwargs)
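
    # The static *_maa parameters plus the low-rank tm_w1/tm_w2 pair implement the
    # RWKV-6 data-dependent token shift: per-channel mix ratios between x[t] and
    # x[t-1], modulated by the current input. Likewise, td_w1/td_w2 form a low-rank
    # projection that makes the per-channel decay w input-dependent (see the w
    # computation in forward below).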
    def forward(self, x, kv_state=None):
        chunk_len = self.wkv_chunk_len
        B, T, C = x.size()

        xx = x
        sx = self.time_shift(x) - xx
        xxx = xx + sx * self.x_maa
        xxx = torch.tanh(xxx @ self.tm_w1).view(B*T, 5, -1).transpose(0, 1)
        xxx = torch.bmm(xxx, self.tm_w2).view(5, B, T, -1)
        mw, mk, mv, mr, mg = xxx.unbind(dim=0)

        wx = xx + sx * (self.w_maa + mw)
        rx = xx + sx * (self.r_maa + mr)
        kx = xx + sx * (self.k_maa + mk)
        vx = xx + sx * (self.v_maa + mv)
        gx = xx + sx * (self.g_maa + mg)

        r = self.receptance(rx).view(B, T, self.n_head, self.d_head).transpose(1, 2)
        k = self.key(kx).view(B, T, self.n_head, self.d_head).transpose(1, 2)
        v = self.value(vx).view(B, T, self.n_head, self.d_head).transpose(1, 2)
        g = F.silu(self.gate(gx))

        time_decay = self.time_decay.float()
        time_first = self.time_first.float()

        has_kv_state = kv_state is not None
        if not has_kv_state:
            kv_state = torch.zeros(B, self.n_head, self.d_head, self.d_head, device=r.device, dtype=r.dtype)
        if r.dtype == torch.bfloat16 and kv_state.dtype != torch.bfloat16:
            kv_state = kv_state.contiguous().to(torch.bfloat16)

        w = time_decay.view(1, self.n_head, 1, self.d_head)
        w = w + (torch.tanh(wx @ self.td_w1) @ self.td_w2).view(B, T, self.n_head, self.d_head).transpose(1, 2)
        w = torch.exp(-torch.exp(w))
        u = time_first.view(1, self.n_head, 1, self.d_head)

        out, s = self.wkv_op_fn(r, k, v, w, u, kv_state, chunk_len, precision_dtype=self.wkv_dtype)

        out = out.transpose(1, 2).reshape(B*T, self.n_head*self.d_head)
        out = self.ln_x(out).view(B, T, self.n_head*self.d_head)
        out = self.output(out * g)

        if has_kv_state:
            kv_state.copy_(s)
        return out

class ChannelMix(nn.Module):
    def __init__(
        self,
        d_model,
        expand=3,
        bias=False,
        layer_idx=0,
        n_layer=1,
        device=None,
        dtype=None
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        with torch.no_grad():
            ddd = torch.arange(d_model, **factory_kwargs).view(1, 1, -1) / d_model
            r1to0 = 1.0 - (layer_idx / n_layer)
            self.k_maa = nn.Parameter(1.0 - torch.pow(ddd, r1to0))
            self.r_maa = nn.Parameter(1.0 - torch.pow(ddd, r1to0))

        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        self.key = nn.Linear(d_model, expand * d_model, bias=bias, **factory_kwargs)
        self.value = nn.Linear(expand * d_model, d_model, bias=bias, **factory_kwargs)
        self.receptance = nn.Linear(d_model, d_model, bias=bias, **factory_kwargs)

    def forward(self, x):
        xx = self.time_shift(x) - x
        xk = x + xx * self.k_maa
        xr = x + xx * self.r_maa

        x = self.key(xk)
        x = F.relu(x) ** 2
        x = self.value(x)
        x = torch.sigmoid(self.receptance(xr)) * x
        return x
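
# CrossChannelMix generalizes ChannelMix (the token-shifted squared-ReLU FFN of
# RWKV) so that the receptance and key inputs can come from two different
# streams with their own widths; CrossBlock below uses it to fold a conditioning
# sequence xc into the main stream x.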
class CrossChannelMix(nn.Module):
    def __init__(
        self,
        d_model,
        d_receptance=None,
        d_key=None,
        expand=3,
        bias=False,
        layer_idx=0,
        n_layer=1,
        device=None,
        dtype=None
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        if d_receptance is None:
            d_receptance = d_model
        if d_key is None:
            d_key = d_model

        with torch.no_grad():
            ddk = torch.arange(d_key, **factory_kwargs).view(1, 1, -1) / d_key
            ddr = torch.arange(d_receptance, **factory_kwargs).view(1, 1, -1) / d_receptance
            r1to0 = 1.0 - (layer_idx / n_layer)
            self.k_maa = nn.Parameter(1.0 - torch.pow(ddk, r1to0))
            self.r_maa = nn.Parameter(1.0 - torch.pow(ddr, r1to0))

        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        self.key = nn.Linear(d_key, expand * d_model, bias=bias, **factory_kwargs)
        self.value = nn.Linear(expand * d_model, d_model, bias=bias, **factory_kwargs)
        self.receptance = nn.Linear(d_receptance, d_model, bias=bias, **factory_kwargs)

    def forward(self, xr, xk):
        xxk = self.time_shift(xk) - xk
        xxr = self.time_shift(xr) - xr
        xk = xk + xxk * self.k_maa
        xr = xr + xxr * self.r_maa

        x = self.key(xk)
        x = F.relu(x) ** 2
        x = self.value(x)
        x = torch.sigmoid(self.receptance(xr)) * x
        return x

class Block(nn.Module):
    def __init__(
        self,
        d_model,
        d_head,
        expand=3,
        bias=False,
        layer_idx=0,
        n_layer=1,
        tmix_kwargs={},
        cmix_kwargs={},
        device=None,
        dtype=None
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.ln_1 = LayerNorm(d_model, bias, **factory_kwargs)
        self.tmix = TimeMix(d_model, d_head, bias, layer_idx, n_layer, **factory_kwargs, **tmix_kwargs)
        self.ln_2 = LayerNorm(d_model, bias, **factory_kwargs)
        self.cmix = ChannelMix(d_model, expand, bias, layer_idx, n_layer, **factory_kwargs, **cmix_kwargs)

    def forward(self, x, *, kv_state=None):
        x = x + self.tmix(self.ln_1(x), kv_state)
        x = x + self.cmix(self.ln_2(x))
        return x

class CrossBlock(nn.Module):
    def __init__(
        self,
        d_model,
        d_cross,
        d_head,
        expand=3,
        bias=False,
        cross_mode="key",
        layer_idx=0,
        n_layer=1,
        tmix_kwargs={},
        cmix_kwargs={},
        device=None,
        dtype=None
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.cross_mode = cross_mode
        if cross_mode == "key":
            self.xmix = CrossChannelMix(
                d_model, d_model, d_cross, expand, bias, layer_idx, n_layer, **factory_kwargs, **cmix_kwargs
            )
        elif cross_mode == "receptance":
            self.xmix = CrossChannelMix(
                d_model, d_cross, d_model, expand, bias, layer_idx, n_layer, **factory_kwargs, **cmix_kwargs
            )
        else:
            raise ValueError(f"Unknown cross_mode: {cross_mode}")
        self.ln_x = LayerNorm(d_model, bias, **factory_kwargs)
        self.ln_c = LayerNorm(d_cross, bias, **factory_kwargs)
        self.ln_1 = LayerNorm(d_model, bias, **factory_kwargs)
        self.tmix = TimeMix(d_model, d_head, bias, layer_idx, n_layer, **factory_kwargs, **tmix_kwargs)
        self.ln_2 = LayerNorm(d_model, bias, **factory_kwargs)
        self.cmix = ChannelMix(d_model, expand, bias, layer_idx, n_layer, **factory_kwargs, **cmix_kwargs)

    def forward(self, x, xc=None, *, kv_state=None):
        if xc is not None:
            if self.cross_mode == "key":
                x = x + self.xmix(self.ln_x(x), self.ln_c(xc))
            elif self.cross_mode == "receptance":
                x = x + self.xmix(self.ln_c(xc), self.ln_x(x))
        x = x + self.tmix(self.ln_1(x), kv_state)
        x = x + self.cmix(self.ln_2(x))
        return x
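
# A minimal sketch (not part of the original gist) of stacking Block into a toy
# model, with one *possible* initialization of the linear layers. The class name
# TinyRWKV and the init scheme below (zero-init for the residual-writing
# tmix.output / cmix.value projections, orthogonal elsewhere, tiny-uniform
# embedding) are assumptions loosely in the spirit of RWKV-LM, not a
# prescription from the original author.
class TinyRWKV(nn.Module):
    def __init__(self, vocab_size, d_model, d_head, n_layer, **block_kwargs):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, d_model)
        self.blocks = nn.ModuleList([
            Block(d_model, d_head, layer_idx=i, n_layer=n_layer, **block_kwargs)
            for i in range(n_layer)
        ])
        self.ln_out = LayerNorm(d_model, bias=False)
        self.head = nn.Linear(d_model, vocab_size, bias=False)
        self._init_weights()

    @torch.no_grad()
    def _init_weights(self):
        nn.init.uniform_(self.emb.weight, -1e-4, 1e-4)
        for name, module in self.named_modules():
            if isinstance(module, nn.Linear):
                if name.endswith(("tmix.output", "cmix.value")):
                    nn.init.zeros_(module.weight)  # residual branches start as a no-op
                else:
                    nn.init.orthogonal_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(self, idx, kv_states=None):
        x = self.emb(idx)
        for i, block in enumerate(self.blocks):
            x = block(x, kv_state=None if kv_states is None else kv_states[i])
        return self.head(self.ln_out(x))
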
if __name__ == "__main__":
    torch.manual_seed(42)
    rwkv = Block(d_model=8, d_head=4)

    n = 4096
    x = torch.randn(1, n, 8)

    kv_state = torch.zeros(1, 8 // 4, 4, 4)
    print('parallel', rwkv(x, kv_state=kv_state)[0, -1, :])
    print(kv_state)

    kv_state = torch.zeros(1, 8 // 4, 4, 4)
    print('chunk 1', rwkv(x[:, :1512, :], kv_state=kv_state)[0, -1, :])
    print('chunk 2', rwkv(x[:, 1512:, :], kv_state=kv_state)[0, -1, :])
    print(kv_state)
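
    # Note: the 'parallel' run and the two-chunk run carry the WKV kv_state the
    # same way, but the one-step token-shift state is not carried across calls
    # (time_shift zero-pads at the chunk boundary), so the chunked outputs and
    # final state are close to, but not exactly, the single-pass results. The
    # chunk lengths here (1512 and 2584) are also not multiples of 64, so the
    # slow chunk_len search warning fires.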