`mamba_simple`-style RWKV implementation
# Copyright © 2024 Ronsor Labs. All rights reserved.
# Except as otherwise indicated, you may only use this software in accordance with
# the MIT license.

# A self-contained RWKV-6 x060 implementation, inspired by mamba_simple.
# TimeMix and wkv_op_torch adapted from the Apache-2.0-licensed:
#   - https://github.com/SmerkyG/gptcore/blob/main/model/experimental/rwkv6_0.py
#   - https://github.com/SmerkyG/gptcore/blob/main/model/experimental/rwkv_inner.py
#   - https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v5/src/model.py
# No guarantees of correctness are made. You'll still have to do proper initialization
# of the linear layers (a hedged sketch of one possible scheme follows the imports below).
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import warnings
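
# Illustrative sketch: one possible way to initialize the Linear layers mentioned above.
# The recipe (orthogonal projections, zero residual output projections, zero biases) is an
# assumption for demonstration, not the reference RWKV initialization; the helper name
# init_linear_weights is likewise made up for this sketch.
def init_linear_weights(module):
    for name, m in module.named_modules():
        if not isinstance(m, nn.Linear):
            continue
        if name.split(".")[-1] == "output":
            nn.init.zeros_(m.weight)  # let the TimeMix residual branch start at zero
        else:
            nn.init.orthogonal_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
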
def wkv_op_torch(r, k, v, w, u, kv_state, chunk_len=64, precision_dtype=torch.float32):
    # Expected shapes: r, k (B,H,L,K); v (B,H,L,V); w (B or 1,H,L,K) with entries in (0,1);
    # u (1,H,1,K); kv_state (B,H,K,V). Returns (out, new_kv_state).
    B, H, L, K = k.size()
    V = v.size(-1)
    T = chunk_len
    if L == 1:
        # Single-token recurrent step; kv and kv_state are (B,H,K,V).
        kv = k.mT @ v
        out = r @ (kv_state + u.mT * kv)
        kv_state = w.mT * kv_state + kv
        return out, kv_state
    else:
        if L % T != 0:
            warnings.warn(
                "Sequence length should be evenly divisible by chunk_len. Finding a new chunk_len for you, but this IS slow.",
                RuntimeWarning
            )
            if L % 2 != 0:
                T = 1
            else:
                while L % T != 0:
                    T -= 2
        N = L // T
        dtype_min = {
            torch.float32: 0.005,
            torch.float64: 1e-10,
        }
        assert precision_dtype in dtype_min
        w = w.clamp(dtype_min[precision_dtype])
        w_log = w.float().log()
        wc_log = w_log.view(w.size(0), H, N, T, K)
        wc_log_cum = wc_log.cumsum(dim=-2)
        shifted_wc_log_cum = F.pad(wc_log_cum, (0, 0, 1, -1))
        ws = wc_log.sum(dim=-2, keepdim=True)  # 1HN1K or BHN1K
        w_inter = ws - wc_log_cum  # 1HNTK or BHNTK (w^(T-1) ... w^0)
        w_intra = wc_log_cum - wc_log  # 1HNTK or BHNTK (w^0 ... w^(T-2))
        ws = list(ws.mT.exp().to(r.dtype).unbind(dim=-3))  # N x 1HK1 or BHK1 !!NOTE THE .mT HERE!!
        w_inter = w_inter.exp().to(r.dtype)  # 1HNTK or BHNTK
        w_intra = w_intra.exp().to(r.dtype)  # 1HNTK or BHNTK
        r = r.view(B, H, N, T, K)
        k = k.view(B, H, N, T, K)
        v = v.view(B, H, N, T, V)
        u = u.unsqueeze(2).to(r.dtype)  # (1,H,1,1,K)
        wc_log_offset = shifted_wc_log_cum[..., T//2:T//2+1, :]  # B,H,N,1,K
        r_decay = (shifted_wc_log_cum - wc_log_offset).to(precision_dtype).exp()  # B,H,N,T,K
        k_inv_decay = (wc_log_offset - wc_log_cum).to(precision_dtype).exp()  # B,H,N,T,K
        a = ((r * r_decay) @ (k * k_inv_decay).mT).to(r.dtype).tril(-1)  # B,H,N,T,T
        a = a + torch.einsum('bhntk,bhntk->bhnt', r, u * k).diag_embed()
        out = a @ v  # BHNTV
        wkv = (k * w_inter).mT @ v  # BHNKV
        wkv = list(wkv.unbind(dim=-3))  # N x BHKV
        states = []
        for i in range(N):
            states.append(kv_state)
            kv_state = kv_state * ws[i] + wkv[i]  # BHKV
        states = torch.stack(states, dim=2)  # BHNKV
        out = out + (r * w_intra) @ states  # BHNTV
        out = out.view(B, H, L, V)
        return out, kv_state
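
# Illustrative sketch (an addition for clarity, not taken from the adapted sources): a naive
# O(L) reference recurrence that computes the same quantity as wkv_op_torch, useful for
# checking the chunked path on small inputs. Shape assumptions match the comment at the top
# of wkv_op_torch: r/k (B,H,L,K), v (B,H,L,V), w (B or 1,H,L,K) in (0,1), u (1,H,1,K),
# kv_state (B,H,K,V).
def wkv_reference_torch(r, k, v, w, u, kv_state):
    out = []
    for t in range(r.size(2)):
        kv = k[:, :, t:t+1].mT @ v[:, :, t:t+1]              # (B,H,K,V)
        out.append(r[:, :, t:t+1] @ (kv_state + u.mT * kv))  # (B,H,1,V)
        kv_state = w[:, :, t:t+1].mT * kv_state + kv
    return torch.cat(out, dim=2), kv_state
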
class LayerNorm(nn.Module):
    def __init__(self, ndim, bias, eps=1e-5, device=None, dtype=None):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(ndim, device=device, dtype=dtype))
        self.bias = nn.Parameter(torch.zeros(ndim, device=device, dtype=dtype)) if bias else None

    def forward(self, input):
        return F.layer_norm(input, self.weight.shape, self.weight, self.bias, self.eps)
class TimeMix(nn.Module):
    def __init__(
        self,
        d_model,
        d_head=None,
        bias=False,
        layer_idx=0,
        n_layer=1,
        wkv_op="torch",
        wkv_chunk_len=64,
        wkv_dtype=torch.float32,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.d_model = d_model
        self.d_head = d_head
        self.n_head = int(self.d_model / self.d_head)
        self.n_layer = n_layer
        self.layer_idx = layer_idx
        self.wkv_op = wkv_op
        self.wkv_chunk_len = wkv_chunk_len
        self.wkv_dtype = wkv_dtype

        with torch.no_grad():
            TIME_MIX_EXTRA_DIM = 32
            W_MIX_EXTRA_DIM = 64
            ddd = torch.arange(d_model, **factory_kwargs).view(1, 1, -1) / d_model
            r0to1 = layer_idx / max(n_layer - 1, 1)
            r1to0 = 1.0 - (layer_idx / n_layer)
            self.x_maa = nn.Parameter(1 - torch.pow(ddd, r1to0))
            self.r_maa = nn.Parameter(1 - torch.pow(ddd, 0.5 * r1to0))
            self.w_maa = nn.Parameter(1 - torch.pow(ddd, r1to0))
            self.k_maa = nn.Parameter(1 - torch.pow(ddd, r1to0))
            self.v_maa = nn.Parameter(1 - (torch.pow(ddd, r1to0) + 0.3 * r0to1))
            self.g_maa = nn.Parameter(1 - torch.pow(ddd, 0.5 * r1to0))
            self.tm_w1 = nn.Parameter(torch.empty(self.d_model, TIME_MIX_EXTRA_DIM * 5, **factory_kwargs).uniform_(-0.01, 0.01))
            self.tm_w2 = nn.Parameter(torch.zeros(5, TIME_MIX_EXTRA_DIM, self.d_model, **factory_kwargs))
            self.td_w1 = nn.Parameter(torch.empty(self.d_model, W_MIX_EXTRA_DIM, **factory_kwargs).uniform_(-0.01, 0.01))
            self.td_w2 = nn.Parameter(torch.zeros(W_MIX_EXTRA_DIM, self.d_model, **factory_kwargs))
            decay_speed = torch.arange(self.d_model, **factory_kwargs)
            decay_speed = (decay_speed / max(self.d_model - 1, 1)).pow(0.7 + 1.3 * r0to1) * 5 - 6
            self.time_decay = nn.Parameter(decay_speed.reshape(self.n_head, self.d_head))
            first_speed = torch.arange(self.d_model, **factory_kwargs)
            first_speed = r0to1 * (1 - (first_speed / max(self.d_model - 1, 1))) + ((first_speed + 1) % 3 - 1) * 0.1
            self.time_first = nn.Parameter(first_speed.reshape(self.n_head, self.d_head))

        if wkv_op == "torch":
            self.wkv_op_fn = wkv_op_torch
        elif isinstance(wkv_op, str):
            raise ValueError(f"Unknown WKV operator implementation: {wkv_op}")
        else:
            self.wkv_op_fn = wkv_op

        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        self.receptance = nn.Linear(d_model, self.d_model, bias=bias, **factory_kwargs)
        self.key = nn.Linear(d_model, self.d_model, bias=bias, **factory_kwargs)
        self.value = nn.Linear(d_model, self.d_model, bias=bias, **factory_kwargs)
        self.output = nn.Linear(self.d_model, d_model, bias=bias, **factory_kwargs)
        self.gate = nn.Linear(d_model, self.d_model, bias=bias, **factory_kwargs)
        self.ln_x = nn.GroupNorm(self.n_head, self.d_model, **factory_kwargs)
    def forward(self, x, kv_state=None):
        chunk_len = self.wkv_chunk_len
        B, T, C = x.size()

        xx = x
        sx = self.time_shift(x) - xx
        xxx = xx + sx * self.x_maa
        xxx = torch.tanh(xxx @ self.tm_w1).view(B*T, 5, -1).transpose(0, 1)
        xxx = torch.bmm(xxx, self.tm_w2).view(5, B, T, -1)
        mw, mk, mv, mr, mg = xxx.unbind(dim=0)

        wx = xx + sx * (self.w_maa + mw)
        rx = xx + sx * (self.r_maa + mr)
        kx = xx + sx * (self.k_maa + mk)
        vx = xx + sx * (self.v_maa + mv)
        gx = xx + sx * (self.g_maa + mg)

        r = self.receptance(rx).view(B, T, self.n_head, self.d_head).transpose(1, 2)
        k = self.key(kx).view(B, T, self.n_head, self.d_head).transpose(1, 2)
        v = self.value(vx).view(B, T, self.n_head, self.d_head).transpose(1, 2)
        g = F.silu(self.gate(gx))

        time_decay = self.time_decay.float()
        time_first = self.time_first.float()

        has_kv_state = kv_state is not None
        kv_state_in = kv_state  # keep the caller's tensor so it can be updated in place below
        if not has_kv_state:
            kv_state = torch.zeros(B, self.n_head, self.d_head, self.d_head, device=r.device, dtype=r.dtype)
        if r.dtype == torch.bfloat16 and kv_state.dtype != torch.bfloat16:
            kv_state = kv_state.contiguous().to(torch.bfloat16)

        w = time_decay.view(1, self.n_head, 1, self.d_head)
        w = w + (torch.tanh(wx @ self.td_w1) @ self.td_w2).view(B, T, self.n_head, self.d_head).transpose(1, 2)
        w = torch.exp(-torch.exp(w))
        u = time_first.view(1, self.n_head, 1, self.d_head)

        out, s = self.wkv_op_fn(r, k, v, w, u, kv_state, chunk_len, precision_dtype=self.wkv_dtype)

        out = out.transpose(1, 2).reshape(B*T, self.n_head*self.d_head)
        out = self.ln_x(out).view(B, T, self.n_head*self.d_head)
        out = self.output(out * g)

        if has_kv_state:
            kv_state_in.copy_(s)
        return out
class ChannelMix(nn.Module):
    def __init__(
        self,
        d_model,
        expand=3,
        bias=False,
        layer_idx=0,
        n_layer=1,
        device=None,
        dtype=None
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        with torch.no_grad():
            ddd = torch.arange(d_model, **factory_kwargs).view(1, 1, -1) / d_model
            r1to0 = 1.0 - (layer_idx / n_layer)
            self.k_maa = nn.Parameter(1.0 - torch.pow(ddd, r1to0))
            self.r_maa = nn.Parameter(1.0 - torch.pow(ddd, r1to0))
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        self.key = nn.Linear(d_model, expand * d_model, bias=bias, **factory_kwargs)
        self.value = nn.Linear(expand * d_model, d_model, bias=bias, **factory_kwargs)
        self.receptance = nn.Linear(d_model, d_model, bias=bias, **factory_kwargs)

    def forward(self, x):
        xx = self.time_shift(x) - x
        xk = x + xx * self.k_maa
        xr = x + xx * self.r_maa
        x = self.key(xk)
        x = F.relu(x) ** 2
        x = self.value(x)
        x = torch.sigmoid(self.receptance(xr)) * x
        return x
class CrossChannelMix(nn.Module):
    def __init__(
        self,
        d_model,
        d_receptance=None,
        d_key=None,
        expand=3,
        bias=False,
        layer_idx=0,
        n_layer=1,
        device=None,
        dtype=None
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        if d_receptance is None:
            d_receptance = d_model
        if d_key is None:
            d_key = d_model
        with torch.no_grad():
            ddk = torch.arange(d_key, **factory_kwargs).view(1, 1, -1) / d_key
            ddr = torch.arange(d_receptance, **factory_kwargs).view(1, 1, -1) / d_receptance
            r1to0 = 1.0 - (layer_idx / n_layer)
            self.k_maa = nn.Parameter(1.0 - torch.pow(ddk, r1to0))
            self.r_maa = nn.Parameter(1.0 - torch.pow(ddr, r1to0))
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        self.key = nn.Linear(d_key, expand * d_model, bias=bias, **factory_kwargs)
        self.value = nn.Linear(expand * d_model, d_model, bias=bias, **factory_kwargs)
        self.receptance = nn.Linear(d_receptance, d_model, bias=bias, **factory_kwargs)

    def forward(self, xr, xk):
        xxk = self.time_shift(xk) - xk
        xxr = self.time_shift(xr) - xr
        xk = xk + xxk * self.k_maa
        xr = xr + xxr * self.r_maa
        x = self.key(xk)
        x = F.relu(x) ** 2
        x = self.value(x)
        x = torch.sigmoid(self.receptance(xr)) * x
        return x
class Block(nn.Module):
    def __init__(
        self,
        d_model,
        d_head,
        expand=3,
        bias=False,
        layer_idx=0,
        n_layer=1,
        tmix_kwargs={},
        cmix_kwargs={},
        device=None,
        dtype=None
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.ln_1 = LayerNorm(d_model, bias, **factory_kwargs)
        self.tmix = TimeMix(d_model, d_head, bias, layer_idx, n_layer, **factory_kwargs, **tmix_kwargs)
        self.ln_2 = LayerNorm(d_model, bias, **factory_kwargs)
        self.cmix = ChannelMix(d_model, expand, bias, layer_idx, n_layer, **factory_kwargs, **cmix_kwargs)

    def forward(self, x, *, kv_state=None):
        x = x + self.tmix(self.ln_1(x), kv_state)
        x = x + self.cmix(self.ln_2(x))
        return x
class CrossBlock(nn.Module):
    def __init__(
        self,
        d_model,
        d_cross,
        d_head,
        expand=3,
        bias=False,
        cross_mode="key",
        layer_idx=0,
        n_layer=1,
        tmix_kwargs={},
        cmix_kwargs={},
        device=None,
        dtype=None
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.cross_mode = cross_mode
        if cross_mode == "key":
            self.xmix = CrossChannelMix(
                d_model, d_model, d_cross, expand, bias, layer_idx, n_layer, **factory_kwargs, **cmix_kwargs
            )
        elif cross_mode == "receptance":
            self.xmix = CrossChannelMix(
                d_model, d_cross, d_model, expand, bias, layer_idx, n_layer, **factory_kwargs, **cmix_kwargs
            )
        else:
            raise ValueError(f"Unknown cross_mode: {cross_mode}")
        self.ln_x = LayerNorm(d_model, bias, **factory_kwargs)
        self.ln_c = LayerNorm(d_cross, bias, **factory_kwargs)
        self.ln_1 = LayerNorm(d_model, bias, **factory_kwargs)
        self.tmix = TimeMix(d_model, d_head, bias, layer_idx, n_layer, **factory_kwargs, **tmix_kwargs)
        self.ln_2 = LayerNorm(d_model, bias, **factory_kwargs)
        self.cmix = ChannelMix(d_model, expand, bias, layer_idx, n_layer, **factory_kwargs, **cmix_kwargs)

    def forward(self, x, xc=None, *, kv_state=None):
        if xc is not None:
            if self.cross_mode == "key":
                x = x + self.xmix(self.ln_x(x), self.ln_c(xc))
            elif self.cross_mode == "receptance":
                x = x + self.xmix(self.ln_c(xc), self.ln_x(x))
        x = x + self.tmix(self.ln_1(x), kv_state)
        x = x + self.cmix(self.ln_2(x))
        return x
if __name__ == "__main__":
    torch.manual_seed(42)
    rwkv = Block(d_model=8, d_head=4)
    n = 4096
    x = torch.randn(1, n, 8)
    kv_state = torch.zeros(1, 8 // 4, 4, 4)
    print('parallel', rwkv(x, kv_state=kv_state)[0, -1, :])
    print(kv_state)
    kv_state = torch.zeros(1, 8 // 4, 4, 4)
    print('chunk 1', rwkv(x[:, :1512, :], kv_state=kv_state)[0, -1, :])
    print('chunk 2', rwkv(x[:, 1512:, :], kv_state=kv_state)[0, -1, :])
    print(kv_state)
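
    # Illustrative extra check (a sketch added for clarity): compare the chunked
    # wkv_op_torch path against token-by-token calls through its L == 1 branch.
    # Shapes and the decay range below are arbitrary test values.
    B, H, L, K = 1, 2, 64, 4
    r = torch.randn(B, H, L, K)
    k = torch.randn(B, H, L, K)
    v = torch.randn(B, H, L, K)
    w = torch.rand(1, H, L, K) * 0.8 + 0.1  # per-token decays kept inside (0.1, 0.9)
    u = torch.randn(1, H, 1, K)
    state = torch.zeros(B, H, K, K)
    par, _ = wkv_op_torch(r, k, v, w, u, state.clone(), chunk_len=16)
    outs = []
    for t in range(L):
        o, state = wkv_op_torch(r[:, :, t:t+1], k[:, :, t:t+1], v[:, :, t:t+1],
                                w[:, :, t:t+1], u, state)
        outs.append(o)
    rec = torch.cat(outs, dim=2)
    print('wkv chunked vs recurrent, max abs diff:', (par - rec).abs().max().item())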