Created
April 8, 2023 14:20
-
-
Save ingenieroariel/3f9ad99f1ab8a4f71991b841ccc9b4a0 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
from torch.autograd import Function | |
import wkv_cuda | |
# Forward function
# Runs the compiled WKV kernel over a batch of B sequences, each T steps long
# with C channels, and hands back the kernel's result.
def forward_wkv(B, T, C, w, u, k, v):
    """Call ``wkv_cuda.forward`` and return what it returns.

    NOTE(review): this wrapper assumes the bound extension returns the output
    tensor (rather than filling a caller-provided buffer) — confirm against
    the extension's signature.
    """
    result = wkv_cuda.forward(B, T, C, w, u, k, v)
    return result
# Backward function
# Writes gradients for w, u, k, v into the caller-allocated buffers
# gw, gu, gk, gv; nothing is returned.
def backward_wkv(B, T, C, gy, w, u, k, v, gw, gu, gk, gv):
    """Call ``wkv_cuda.backward`` with the upstream gradient ``gy``.

    The gradient buffers are mutated in place; the function returns None.
    """
    kernel_args = (B, T, C, gy, w, u, k, v, gw, gu, gk, gv)
    wkv_cuda.backward(*kernel_args)
# RWKV function
# Differentiable entry point: forward_wkv computes the output and
# backward_wkv supplies the gradients through torch.autograd.
class _RWKVFunction(Function):
    """Autograd bridge around forward_wkv / backward_wkv.

    A raw extension call returns a tensor with no autograd history, so
    without this wrapper ``loss.backward()`` cannot reach w, u, k or v.
    """

    @staticmethod
    def forward(ctx, w, u, k, v, C):
        B, T = k.shape[:2]
        # The kernel is handed raw buffers, so make sure they are contiguous.
        w, u, k, v = (t.contiguous() for t in (w, u, k, v))
        ctx.dims = (B, T, C)
        ctx.save_for_backward(w, u, k, v)
        return forward_wkv(B, T, C, w, u, k, v)

    @staticmethod
    def backward(ctx, gy):
        B, T, C = ctx.dims
        w, u, k, v = ctx.saved_tensors
        # Gradient buffers mirror each input's shape/dtype/device; the
        # kernel fills them in place.
        gw = torch.zeros_like(w)
        gu = torch.zeros_like(u)
        gk = torch.zeros_like(k)
        gv = torch.zeros_like(v)
        backward_wkv(B, T, C, gy.contiguous(), w, u, k, v, gw, gu, gk, gv)
        # C is a plain int and receives no gradient.
        return gw, gu, gk, gv, None


def rwkv(w, u, k, v, C):
    """Apply the WKV kernel to (w, u, k, v) with autograd support.

    Args:
        w, u: kernel parameters (the demo below uses shape (B, C)).
        k, v: tensors whose first two dims are (batch B, time T).
        C: number of channels.

    Returns:
        The kernel output y; gradients flow back to w, u, k and v.
    """
    return _RWKVFunction.apply(w, u, k, v, C)
if __name__ == '__main__':
    # Seed both the CPU and CUDA generators so the demo is reproducible.
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)

    # Sizes: C channels, a batch of B sequences, T time steps each.
    C, B, T = 8, 32, 32

    # Random leaf tensors on the GPU, in the order w, u, k, v:
    # w and u are (B, C); k and v are (B, T, C).
    w, u, k, v = (
        torch.randn(shape, device='cuda', requires_grad=True)
        for shape in ((B, C), (B, C), (B, T, C), (B, T, C))
    )

    # Run the WKV kernel.
    y = rwkv(w, u, k, v, C)

    # Squared Frobenius norm of the output as a scalar loss, then backprop.
    loss = y.norm(p='fro') ** 2
    loss.backward()

    # Show the gradient that reached each input.
    for label, leaf in (('gw', w), ('gu', u), ('gk', k), ('gv', v)):
        print(label, leaf.grad)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class L2Wrap(torch.autograd.Function): | |
This class is a custom autograd function that wraps the training loss so that an extra gradient term is added to the logits. It has two static methods, forward and backward. The forward method takes the loss and the logits tensor y, saves y for the backward pass, and returns the loss unchanged. The backward method passes the incoming gradient straight through to the loss. For y, it returns a tensor that is zero everywhere except at the maximum of each last-dimension row; there it holds that maximum multiplied by 1e-4 / (y.shape[0] * y.shape[1]). The aim is to push the largest logits toward 0. | |
class WKV(torch.autograd.Function): | |
This class implements a custom autograd function to perform the forward and backward operations for the RWKV Language Model. The forward method takes the input dimensions B (batch size), T (sequence length), and C (number of channels) and the tensors w, u, k, and v. It applies the CUDA kernel to perform the forward pass and returns the output tensor y. The backward method computes the gradients with respect to the input tensors and accumulates the gradient values in gw, gu, gk, and gv. | |
def RUN_CUDA(B, T, C, w, u, k, v): | |
This is a simple wrapper function that applies the WKV custom autograd function to the input tensors. It takes the dimensions B (batch size), T (sequence length), and C (number of channels) and the tensors w, u, k, and v. It then calls the apply method of the WKV class with the input tensors and returns the result. | |
class RWKV_TimeMix(torch.jit.ScriptModule): | |
This class implements the time-mix part of the RWKV language model. It inherits from torch.jit.ScriptModule and defines the methods jit_func and forward. The jit_func method mixes the input x with a one-step-shifted copy of itself (using the learned time_mix_k/v/r weights). It then applies the key, value and receptance linear layers and returns sigmoid(r), k and v. The forward method passes k and v, together with the learned time_decay and time_first parameters, to the custom CUDA WKV kernel. It gates the kernel output elementwise with sigmoid(r) and applies the output linear projection. No attention matrix is formed. | |
class RWKV_ChannelMix(torch.jit.ScriptModule): | |
This class implements the channel-mix part of the RWKV language model. It inherits from torch.jit.ScriptModule and defines a forward method. The forward method mixes the input x with a one-step-shifted copy of itself to form the key input xk and receptance input xr. It computes k = relu(key(xk))^2 and projects k back to the embedding size with the value layer. The output is sigmoid(receptance(xr)) multiplied elementwise by that value projection. | |
class GPTConfig: | |
This class is a simple configuration container for the GPT model. It takes a vocabulary size and a context length, along with other keyword arguments, and stores them as attributes. This allows easy access and management of the model's configuration parameters. | |
class Block(nn.Module): | |
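This class implements a single residual block of the model. Its layers are:
- two layer normalizations, ln1 and ln2, plus an extra ln0 in layer 0 that normalizes the block input;
- an RWKV_TimeMix sub-layer, which in layer 0 is replaced by an extra RWKV_ChannelMix (ffnPre) when config.model_type is 'RWKV-ffnPre';
- an RWKV_ChannelMix sub-layer.
Its forward method applies each sub-layer to the normalized input and adds the result back to the input as a residual connection. | |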
class GPT(nn.Module): | |
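This class represents the main GPT model. It contains an embedding layer, a sequence of Block modules, a final layer normalization, and a linear head that projects to the vocabulary size. When RWKV_HEAD_QK_DIM > 0 it also adds a causal q/k "copy" term to the logits. The forward method takes a tensor of token indices (and optional targets) and applies these layers sequentially to produce logits. When targets are given, it returns the cross-entropy loss wrapped by L2Wrap. The class also provides configure_optimizers, which builds an Adam optimizer (DeepSpeed FusedAdam when available) with weight decay disabled. | |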
class Trainer: | |
This class is responsible for training and evaluating the GPT model. It takes a GPT model, a DataLoader for training and validation data, a learning rate, and other optional parameters. It contains methods for training (train), evaluation (evaluate), and model checkpointing (save_checkpoint). The train method iterates through the training data, computes the loss, and updates the model parameters using the optimizer. The evaluate method calculates the average loss on the validation dataset and returns it. | |
def top_k_logits(logits, k): | |
This function is a utility for the token sampling process during the generation of text. It takes the logits of the model's output and a value for k, and it returns the logits after setting all but the top k logits to a very negative value, effectively zeroing out the probabilities of these low-ranked tokens. | |
def sample_sequence(model, length, context, temperature=1.0, top_k=None): | |
This function is responsible for generating a sequence of tokens given a model, a desired output length, an input context, and optional temperature and top_k values. It applies the GPT model to the context, samples the next token based on the output probabilities, and updates the context with the new token. This process is repeated until the desired length is reached. The temperature parameter controls the randomness of the sampling process, with higher values leading to more random outputs. | |
def main(): | |
This function is the main entry point of the script. It parses command-line arguments for configuring the GPT model, training, and evaluation. It initializes the GPT model, DataLoader, and Trainer instances and proceeds with the training and evaluation process based on the provided arguments. It also supports loading a pre-trained model checkpoint for fine-tuning or text generation. Finally, it can save the trained model for future use. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import math, os | |
import numpy as np | |
import logging | |
import torch | |
import torch.nn as nn | |
from torch.nn import functional as F | |
try: | |
from deepspeed.ops.adam import FusedAdam | |
except: | |
pass | |
# Module-level logger.
logger = logging.getLogger(__name__)

# Head q/k dimension, announced on stdout when the module is imported.
RWKV_HEAD_QK_DIM = 0
print(f'\nRWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM}\n')
class L2Wrap(torch.autograd.Function):
    """Pass a loss through unchanged while adding a small gradient on ``y``.

    backward returns the incoming gradient for ``loss``. For ``y`` it returns
    a tensor that is zero except at the arg-max of each last-dimension row,
    where it holds that maximum times ``1e-4 / (y.shape[0] * y.shape[1])``.
    This encourages the logits to stay close to 0.
    """

    @staticmethod
    def forward(ctx, loss, y):
        ctx.save_for_backward(y)
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        (y,) = ctx.saved_tensors
        scale = 1e-4 / (y.shape[0] * y.shape[1])
        top_vals, top_idx = y.max(dim=-1, keepdim=True)
        grad_y = torch.zeros_like(y).scatter_(-1, top_idx, top_vals * scale)
        return grad_output, grad_y
# Longest sequence length the compiled kernel accepts (baked in via -DTmax).
T_MAX = 1024
from torch.utils.cpp_extension import load

# JIT-compile and import the WKV CUDA extension.
_WKV_SOURCES = ["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"]
_WKV_CUDA_FLAGS = ['-res-usage', '--maxrregcount 60', '--use_fast_math', '-O3', '-Xptxas -O3', f'-DTmax={T_MAX}']
wkv_cuda = load(name="wkv", sources=_WKV_SOURCES, verbose=True, extra_cuda_cflags=_WKV_CUDA_FLAGS)
class WKV(torch.autograd.Function):
    """Autograd wrapper around the compiled ``wkv_cuda`` forward/backward kernels.

    The kernels are fed float32 tensors (half-precision inputs are upcast).
    The ``RWKV_FLOAT_MODE`` environment variable selects the dtype handed back
    to autograd: any mode containing '32' keeps float32, 'fp16' casts to half
    and 'bf16' casts to bfloat16. Any other value raises ValueError.

    Note that the kernel receives ``-exp(w)`` rather than ``w`` itself.
    """

    @staticmethod
    def _float_mode():
        """Return RWKV_FLOAT_MODE, rejecting values this class cannot honour.

        Raises:
            KeyError: if RWKV_FLOAT_MODE is not set.
            ValueError: if it is not a '32' mode, 'fp16' or 'bf16'.
        """
        mode = os.environ['RWKV_FLOAT_MODE']
        if '32' in mode or mode in ('fp16', 'bf16'):
            return mode
        raise ValueError(f"unsupported RWKV_FLOAT_MODE: {mode!r} (expected a '32' mode, 'fp16' or 'bf16')")

    @staticmethod
    def _to_float_mode(t, mode):
        """Cast a float32 kernel result to the dtype selected by ``mode``."""
        if '32' in mode:
            return t
        return t.half() if mode == 'fp16' else t.bfloat16()

    @staticmethod
    def forward(ctx, B, T, C, w, u, k, v):
        ctx.B = B
        ctx.T = T
        ctx.C = C
        assert T <= T_MAX
        assert B * C % min(C, 1024) == 0
        # Validate the mode up front so no kernel launch is wasted, and keep it
        # so backward casts gradients the same way forward cast the output.
        mode = WKV._float_mode()
        ctx.mode = mode
        if '32' in mode:
            w = -torch.exp(w.contiguous())
            u = u.contiguous()
            k = k.contiguous()
            v = v.contiguous()
        else:
            # Half-precision modes: compute in float32.
            w = -torch.exp(w.float().contiguous())
            u = u.float().contiguous()
            k = k.float().contiguous()
            v = v.float().contiguous()
        ctx.save_for_backward(w, u, k, v)
        y = torch.empty((B, T, C), device='cuda', memory_format=torch.contiguous_format)
        wkv_cuda.forward(B, T, C, w, u, k, v, y)
        return WKV._to_float_mode(y, mode)

    @staticmethod
    def backward(ctx, gy):
        B, T, C = ctx.B, ctx.T, ctx.C
        mode = ctx.mode
        assert T <= T_MAX
        assert B * C % min(C, 1024) == 0
        w, u, k, v = ctx.saved_tensors
        gw = torch.zeros((B, C), device='cuda')
        gu = torch.zeros((B, C), device='cuda')
        gk = torch.zeros((B, T, C), device='cuda')
        gv = torch.zeros((B, T, C), device='cuda')
        gy = gy.contiguous() if '32' in mode else gy.float().contiguous()
        # The extension's backward takes (B, T, C, w, u, k, v, gy, gw, gu, gk, gv):
        # the saved inputs come BEFORE the upstream gradient. Passing gy first
        # (as this file previously did) silently feeds the wrong buffers to
        # the kernel.
        wkv_cuda.backward(B, T, C, w, u, k, v, gy, gw, gu, gk, gv)
        # gw / gu come back as (B, C) per-batch partials; when w / u were
        # per-channel (C,) vectors, reduce over the batch to match their shape.
        if w.dim() == 1:
            gw = gw.sum(dim=0)
        if u.dim() == 1:
            gu = gu.sum(dim=0)
        cast = WKV._to_float_mode
        return (None, None, None, cast(gw, mode), cast(gu, mode), cast(gk, mode), cast(gv, mode))
class RWKV(nn.Module):
    """Thin nn.Module front-end that applies the WKV autograd function.

    Constructing it records ``float_mode`` on the instance and also writes it
    to the process-wide ``RWKV_FLOAT_MODE`` environment variable, which WKV
    reads.
    """

    def __init__(self, C, float_mode='fp16'):
        super().__init__()
        self.C, self.float_mode = C, float_mode
        os.environ['RWKV_FLOAT_MODE'] = float_mode

    def forward(self, w, u, k, v):
        # k is expected to start with (batch, time) dimensions.
        batch, seq_len = k.shape[:2]
        return WKV.apply(batch, seq_len, self.C, w, u, k, v)
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)

    # Instantiate the model; this also sets RWKV_FLOAT_MODE (default 'fp16').
    C = 8
    model = RWKV(C).cuda()

    # Prepare input data.
    # w (decay) and u (first-step bonus) are per-channel vectors of shape (C,),
    # matching how the full model drives this kernel with its (n_embd,)
    # time_decay / time_first parameters. Previously they were (B, C).
    # NOTE(review): the kernel presumably indexes w/u by channel only, so extra
    # batch rows would be unused yet still receive gradients — confirm against
    # cuda/wkv_cuda.cu.
    B = 32
    T = 32
    w = torch.randn((C,), device='cuda', requires_grad=True)
    u = torch.randn((C,), device='cuda', requires_grad=True)
    k = torch.randn((B, T, C), device='cuda', requires_grad=True)
    v = torch.randn((B, T, C), device='cuda', requires_grad=True)

    # Forward pass
    y = model(w, u, k, v)

    # Compute loss (squared Frobenius norm) and backpropagate through L2Wrap.
    loss = y.norm(p='fro') ** 2
    loss = L2Wrap.apply(loss, y)
    loss.backward()

    # Print gradients
    print('gw', w.grad)
    print('gu', u.grad)
    print('gk', k.grad)
    print('gv', v.grad)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
######################################################################################################## | |
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM | |
######################################################################################################## | |
import math, os | |
import numpy as np | |
import logging | |
import torch | |
import torch.nn as nn | |
from torch.nn import functional as F | |
try: | |
from deepspeed.ops.adam import FusedAdam | |
except: | |
pass # some poor windows users cant install deepspeed | |
# Module-level logger (GPT reports its parameter count through it).
logger = logging.getLogger(__name__)

# Width of GPT's optional q/k "copy" head; 0 disables it. Echoed at import.
RWKV_HEAD_QK_DIM = 0
print(f'\nRWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM}\n')
class L2Wrap(torch.autograd.Function):
    """Identity on the loss, plus an auxiliary gradient on the logits ``y``.

    forward returns ``loss`` untouched and stashes ``y``. backward adds, for
    each last-dimension row of ``y``, a gradient equal to the row maximum
    scaled by ``1e-4 / (y.shape[0] * y.shape[1])`` at the arg-max position
    (zero elsewhere). This encourages the logits to be close to 0.
    """

    @staticmethod
    def forward(ctx, loss, y):
        ctx.save_for_backward(y)
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        y = ctx.saved_tensors[0]
        denom = y.shape[0] * y.shape[1]
        peak, peak_idx = torch.max(y, -1, keepdim=True)
        extra = torch.zeros_like(y)
        extra.scatter_(-1, peak_idx, peak * (1e-4 / denom))
        return grad_output, extra
######################################################################################################## | |
# CUDA Kernel | |
######################################################################################################## | |
# Longest ctx_len the compiled kernel supports (passed in as -DTmax).
# Raising it costs a lot of VRAM. Longer contexts are possible by slicing the
# sequence and carrying the hidden state from one slice to the next.
T_MAX = 1024

from torch.utils.cpp_extension import load

# JIT-compile and import the WKV CUDA extension.
wkv_cuda = load(
    name="wkv",
    sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"],
    verbose=True,
    extra_cuda_cflags=['-res-usage', '--maxrregcount 60', '--use_fast_math', '-O3', '-Xptxas -O3', f'-DTmax={T_MAX}'],
)
class WKV(torch.autograd.Function):
    """Autograd wrapper around the compiled ``wkv_cuda`` forward/backward kernels.

    The kernels are fed float32 tensors (half-precision inputs are upcast).
    The ``RWKV_FLOAT_MODE`` environment variable selects the dtype handed back
    to autograd: any mode containing '32' keeps float32, 'fp16' casts to half
    and 'bf16' casts to bfloat16. Any other value raises ValueError; previously
    forward/backward silently returned None in that case.

    The kernel receives ``-exp(w)`` rather than ``w``. Gradients for w and u
    are produced per batch row and summed over the batch.
    """

    @staticmethod
    def _float_mode():
        """Return RWKV_FLOAT_MODE, rejecting values this class cannot honour.

        Raises:
            KeyError: if RWKV_FLOAT_MODE is not set.
            ValueError: if it is not a '32' mode, 'fp16' or 'bf16'.
        """
        mode = os.environ['RWKV_FLOAT_MODE']
        if '32' in mode or mode in ('fp16', 'bf16'):
            return mode
        raise ValueError(f"unsupported RWKV_FLOAT_MODE: {mode!r} (expected a '32' mode, 'fp16' or 'bf16')")

    @staticmethod
    def _to_float_mode(t, mode):
        """Cast a float32 kernel result to the dtype selected by ``mode``."""
        if '32' in mode:
            return t
        return t.half() if mode == 'fp16' else t.bfloat16()

    @staticmethod
    def forward(ctx, B, T, C, w, u, k, v):
        ctx.B = B
        ctx.T = T
        ctx.C = C
        assert T <= T_MAX
        assert B * C % min(C, 1024) == 0
        # Validate once and remember the mode so backward casts consistently.
        mode = WKV._float_mode()
        ctx.mode = mode
        if '32' in mode:
            w = -torch.exp(w.contiguous())
            u = u.contiguous()
            k = k.contiguous()
            v = v.contiguous()
        else:
            # Half-precision modes: compute in float32.
            w = -torch.exp(w.float().contiguous())
            u = u.float().contiguous()
            k = k.float().contiguous()
            v = v.float().contiguous()
        ctx.save_for_backward(w, u, k, v)
        y = torch.empty((B, T, C), device='cuda', memory_format=torch.contiguous_format)
        wkv_cuda.forward(B, T, C, w, u, k, v, y)
        return WKV._to_float_mode(y, mode)

    @staticmethod
    def backward(ctx, gy):
        B, T, C = ctx.B, ctx.T, ctx.C
        mode = ctx.mode
        assert T <= T_MAX
        assert B * C % min(C, 1024) == 0
        w, u, k, v = ctx.saved_tensors
        gw = torch.zeros((B, C), device='cuda')
        gu = torch.zeros((B, C), device='cuda')
        gk = torch.zeros((B, T, C), device='cuda')
        gv = torch.zeros((B, T, C), device='cuda')
        gy = gy.contiguous() if '32' in mode else gy.float().contiguous()
        wkv_cuda.backward(B, T, C, w, u, k, v, gy, gw, gu, gk, gv)
        # Reduce the (B, C) per-batch partials to per-channel gradients.
        gw = torch.sum(gw, dim=0)
        gu = torch.sum(gu, dim=0)
        cast = WKV._to_float_mode
        return (None, None, None, cast(gw, mode), cast(gu, mode), cast(gk, mode), cast(gv, mode))
def RUN_CUDA(B, T, C, w, u, k, v):
    """Move w, u, k, v to the current CUDA device and apply WKV to them.

    Returns the WKV output of shape (B, T, C).
    """
    w, u, k, v = (t.cuda() for t in (w, u, k, v))
    return WKV.apply(B, T, C, w, u, k, v)
######################################################################################################## | |
# RWKV: RWKV Time-mix + RWKV Channel-mix | |
######################################################################################################## | |
def RWKV_Init(model, args):  # fancy initialization of all lin & emb layer in the model
    """Initialise every Linear / Embedding weight of ``model`` in place.

    For a weight of shape (rows, cols):
      * Embedding: gain = sqrt(max(rows, cols)). scale = 1e-4 for the token
        embedding (vocab_size x n_embd), otherwise 0.
      * Linear: gain = sqrt(rows / cols) when rows > cols. scale = 0.5 for the
        final projection (vocab_size x n_embd).
      * A ``scale_init`` attribute on the module overrides ``scale``.
    After gain *= scale, the weight is set as follows:
      * scale == -999: identity;
      * gain == 0: zeros;
      * gain > 0: orthogonal(gain);
      * otherwise: normal(0, std=-scale).

    TorchScript-compiled submodules (RecursiveScriptModule) are included. Only
    those whose ``original_name`` is "Linear" are touched, and they are treated
    as Linear layers.

    Args:
        model: module whose submodules are initialised.
        args: object providing ``vocab_size`` and ``n_embd``.
    """
    print("\n[--> first run, init model params (very slow for large models) <--]")
    print("[so you shall only do it for 1 single GPU and save the checkpt and load it when using multiple GPU]\n")
    for mm in model.modules():
        if "RecursiveScriptModule" in str(type(mm)):
            # Scripted submodules are not nn.Linear instances, so classify
            # them by their original class name. Previously the module used
            # for classification / scale_init here was a stale binding left
            # over from an earlier loop iteration.
            if mm.original_name not in ["Linear"]:
                continue
            ww = None
            for name, param in mm.named_parameters():
                if name == "weight":
                    ww = param
            if ww is None:
                continue
            is_linear, is_embedding = True, False
        else:
            if not isinstance(mm, (nn.Linear, nn.Embedding)):
                continue
            ww = mm.weight
            is_linear = isinstance(mm, nn.Linear)
            is_embedding = isinstance(mm, nn.Embedding)

        with torch.no_grad():
            shape = ww.shape
            gain = 1.0
            scale = 1.0  # extra scale for gain

            if is_embedding:
                gain = math.sqrt(max(shape[0], shape[1]))
                if shape[0] == args.vocab_size and shape[1] == args.n_embd:  # token emb?
                    scale = 1e-4
                else:
                    scale = 0

            if is_linear:
                if shape[0] > shape[1]:
                    gain = math.sqrt(shape[0] / shape[1])
                if shape[0] == args.vocab_size and shape[1] == args.n_embd:  # final projection?
                    scale = 0.5

            if hasattr(mm, "scale_init"):
                scale = mm.scale_init

            gain *= scale
            if scale == -999:
                nn.init.eye_(ww)
            elif gain == 0:
                # zero init is great for some RWKV matrices
                nn.init.zeros_(ww)
            elif gain > 0:
                nn.init.orthogonal_(ww, gain=gain)
            else:
                nn.init.normal_(ww, mean=0.0, std=-scale)
class RWKV_TimeMix(torch.jit.ScriptModule):
    """Time-mixing sub-layer of an RWKV block.

    It mixes each position with the previous one (zero-padded shift), projects
    the result to k, v and r, and runs the WKV CUDA kernel with the learned
    per-channel ``time_decay`` / ``time_first``. The kernel output is gated by
    sigmoid(r) and passed through an output projection.

    Args:
        config: object providing ``ctx_len``, ``n_embd`` and ``n_layer``.
        layer_id: index of this layer, expected in [0, n_layer).
    """

    def __init__(self, config, layer_id):
        super().__init__()
        self.layer_id = layer_id
        self.ctx_len = config.ctx_len
        self.n_embd = config.n_embd

        attn_sz = config.n_embd

        with torch.no_grad():  # fancy init
            # 0 -> 1 across layers. A single-layer model has n_layer - 1 == 0,
            # which previously raised ZeroDivisionError; treat it as 0.
            if config.n_layer > 1:
                ratio_0_to_1 = layer_id / (config.n_layer - 1)
            else:
                ratio_0_to_1 = 0.0
            ratio_1_to_almost0 = 1.0 - (layer_id / config.n_layer)  # 1 to ~0

            # fancy time_decay: per-channel values rising from -5 (first
            # channel) to 3 (last channel).
            decay_speed = torch.ones(attn_sz)
            for h in range(attn_sz):
                frac = h / (attn_sz - 1) if attn_sz > 1 else 0.0
                decay_speed[h] = -5 + 8 * frac ** (0.7 + 1.3 * ratio_0_to_1)
            self.time_decay = nn.Parameter(decay_speed)

            # fancy time_first: log(0.3) plus a repeating (0, +0.5, -0.5) zigzag
            zigzag = (torch.tensor([(i + 1) % 3 - 1 for i in range(attn_sz)]) * 0.5)
            self.time_first = nn.Parameter(torch.ones(attn_sz) * math.log(0.3) + zigzag)

            # fancy time_mix: ramp i / n_embd raised to layer-dependent powers
            x = torch.ones(1, 1, config.n_embd)
            for i in range(config.n_embd):
                x[0, 0, i] = i / config.n_embd
            self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
            self.time_mix_v = nn.Parameter(torch.pow(x, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
            self.time_mix_r = nn.Parameter(torch.pow(x, 0.5 * ratio_1_to_almost0))

        # Shift the time axis by one step: pad one zero row at the start and
        # drop the last row.
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

        self.key = nn.Linear(config.n_embd, attn_sz, bias=False)
        self.value = nn.Linear(config.n_embd, attn_sz, bias=False)
        self.receptance = nn.Linear(config.n_embd, attn_sz, bias=False)

        self.output = nn.Linear(attn_sz, config.n_embd, bias=False)

        # Read by RWKV_Init to choose zero initialisation for these layers.
        self.key.scale_init = 0
        self.receptance.scale_init = 0
        self.output.scale_init = 0

    @torch.jit.script_method
    def jit_func(self, x):
        # Mix x with the previous timestep to produce xk, xv, xr
        xx = self.time_shift(x)
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)

        # Use xk, xv, xr to produce k, v, r
        k = self.key(xk)
        v = self.value(xv)
        r = self.receptance(xr)
        sr = torch.sigmoid(r)

        return sr, k, v

    def forward(self, x):
        B, T, C = x.size()  # x = (Batch,Time,Channel)

        sr, k, v = self.jit_func(x)

        rwkv = sr * RUN_CUDA(B, T, C, self.time_decay, self.time_first, k, v)
        rwkv = self.output(rwkv)
        return rwkv
class RWKV_ChannelMix(torch.jit.ScriptModule):
    """Channel-mixing (feed-forward) sub-layer of an RWKV block.

    Output = sigmoid(receptance(xr)) * value(relu(key(xk)) ** 2). Here xk and
    xr blend each position with the previous position (zero-padded shift)
    using the learned time_mix_k / time_mix_r weights.

    Args:
        config: object providing ``n_embd`` and ``n_layer``.
        layer_id: index of this layer.
    """

    def __init__(self, config, layer_id):
        super().__init__()
        self.layer_id = layer_id
        # One-step shift along the time axis (zero row in front, last row dropped).
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

        n_embd = config.n_embd
        with torch.no_grad():  # fancy init of time_mix
            exponent = 1.0 - (layer_id / config.n_layer)  # 1 to ~0
            ramp = torch.ones(1, 1, n_embd)
            for idx in range(n_embd):
                ramp[0, 0, idx] = idx / n_embd
            self.time_mix_k = nn.Parameter(torch.pow(ramp, exponent))
            self.time_mix_r = nn.Parameter(torch.pow(ramp, exponent))

        hidden_sz = 4 * n_embd
        self.key = nn.Linear(n_embd, hidden_sz, bias=False)
        self.receptance = nn.Linear(n_embd, n_embd, bias=False)
        self.value = nn.Linear(hidden_sz, n_embd, bias=False)

        # Read by RWKV_Init to choose zero initialisation for these layers.
        self.value.scale_init = 0
        self.receptance.scale_init = 0

    @torch.jit.script_method
    def forward(self, x):
        prev = self.time_shift(x)
        mix_k = self.time_mix_k
        mix_r = self.time_mix_r
        xk = x * mix_k + prev * (1 - mix_k)
        xr = x * mix_r + prev * (1 - mix_r)

        hidden = torch.square(torch.relu(self.key(xk)))
        gate = torch.sigmoid(self.receptance(xr))
        return gate * self.value(hidden)
######################################################################################################## | |
# The GPT Model with our blocks | |
######################################################################################################## | |
class GPTConfig:
    """Plain attribute bag for model hyper-parameters.

    ``vocab_size`` and ``ctx_len`` are required. Every extra keyword argument
    (for example ``n_embd``, ``n_layer``, ``model_type``) becomes an attribute
    of the same name.
    """

    def __init__(self, vocab_size, ctx_len, **kwargs):
        self.vocab_size, self.ctx_len = vocab_size, ctx_len
        vars(self).update(kwargs)
class Block(nn.Module):
    """One residual RWKV layer: a time-mix (or ffnPre) step, then a channel-mix step.

    Layer 0 first normalises its input with ``ln0``. When
    ``config.model_type == 'RWKV-ffnPre'``, layer 0 uses an extra
    RWKV_ChannelMix (``ffnPre``) in place of the time-mix sub-layer.
    """

    def __init__(self, config, layer_id):
        super().__init__()
        self.config = config
        self.layer_id = layer_id
        is_first = layer_id == 0

        self.ln1 = nn.LayerNorm(config.n_embd)
        self.ln2 = nn.LayerNorm(config.n_embd)
        if is_first:
            self.ln0 = nn.LayerNorm(config.n_embd)

        if is_first and config.model_type == 'RWKV-ffnPre':
            self.ffnPre = RWKV_ChannelMix(config, 0)
        else:
            self.att = RWKV_TimeMix(config, layer_id)

        self.ffn = RWKV_ChannelMix(config, layer_id)

    def forward(self, x):
        if self.layer_id == 0:
            x = self.ln0(x)
            use_pre = self.config.model_type == 'RWKV-ffnPre'
        else:
            use_pre = False

        # ffnPre is "better in some cases" for the first layer.
        first_mixer = self.ffnPre if use_pre else self.att
        x = x + first_mixer(self.ln1(x))
        return x + self.ffn(self.ln2(x))
class GPT(nn.Module):
    """RWKV language model: embedding -> Blocks -> LayerNorm -> vocabulary head.

    When RWKV_HEAD_QK_DIM > 0, a causal q/k "copy" term built from the final
    hidden states is added to the logits.
    """

    def __init__(self, config):
        super().__init__()
        self.step = 0
        self.config = config

        self.emb = nn.Embedding(config.vocab_size, config.n_embd)

        self.blocks = nn.Sequential(*[Block(config, i)
                                    for i in range(config.n_layer)])

        self.ln_out = nn.LayerNorm(config.n_embd)
        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        if RWKV_HEAD_QK_DIM > 0:
            self.head_q = nn.Linear(config.n_embd, RWKV_HEAD_QK_DIM, bias=False)
            self.head_q.scale_init = 0
            self.head_k = nn.Linear(config.n_embd, RWKV_HEAD_QK_DIM, bias=False)
            self.head_k.scale_init = 0.1
            # Lower-triangular mask: position t may copy only from positions <= t.
            self.register_buffer("copy_mask", torch.tril(
                torch.ones(config.ctx_len, config.ctx_len)))

        self.ctx_len = config.ctx_len

        # Run the custom init only for a fresh model (RWKV_LOAD_MODEL == 'False').
        # An unset variable still means "skip". Errors raised inside
        # RWKV_Init are no longer swallowed by a bare except.
        if os.environ.get('RWKV_LOAD_MODEL') == str(False):
            RWKV_Init(self, config)

        logger.info("number of parameters: %e", sum(p.numel()
                    for p in self.parameters()))

    def get_ctx_len(self):
        return self.ctx_len

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear)):
            module.weight.data.normal_(mean=0.0, std=0.01)
        if isinstance(module, (nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=1e-5)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def configure_optimizers(self, train_config):
        """Build an Adam optimizer over all parameters with weight decay disabled.

        Parameters are ordered by full name (as before), so existing optimizer
        state dicts stay compatible. DeepSpeed's FusedAdam is used when it can
        be constructed; otherwise this falls back to torch.optim.Adam.

        Args:
            train_config: provides ``learning_rate``, ``betas`` and ``eps``.
        """
        # One name -> parameter map replaces the former nested
        # modules x parameters walk, which produced the same set of names.
        param_dict = dict(self.named_parameters())
        optim_groups = [
            {"params": [param_dict[pn] for pn in sorted(param_dict)], "weight_decay": 0.0},
        ]

        try:
            optimizer = FusedAdam(optim_groups, lr=train_config.learning_rate, betas=train_config.betas, eps=train_config.eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False)
        except Exception:
            # Covers a missing DeepSpeed install (FusedAdam never imported ->
            # NameError) as well as FusedAdam failing to construct.
            print('\n\nDeepSpeed not found. Using torch optimizer instead (probably slower)\n\n')
            optimizer = torch.optim.Adam(optim_groups, lr=train_config.learning_rate, betas=train_config.betas, eps=train_config.eps)

        return optimizer

    def forward(self, idx, targets=None):
        """Run the model on token indices ``idx`` of shape (B, T).

        Returns:
            The L2Wrap-wrapped cross-entropy loss when ``targets`` is given,
            otherwise the logits of shape (B, T, vocab_size). Previously the
            no-targets path returned None.
        """
        idx = idx.to(self.emb.weight.device)

        self.step += 1
        B, T = idx.size()
        assert T <= self.ctx_len, "Cannot forward, because len(input) > model ctx_len."

        x = self.emb(idx)
        x = self.blocks(x)
        x = self.ln_out(x)

        if RWKV_HEAD_QK_DIM > 0:
            q = self.head_q(x)[:, :T, :]
            k = self.head_k(x)[:, :T, :]
            c = (q @ k.transpose(-2, -1)) * (1.0 / RWKV_HEAD_QK_DIM)
            c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)
            # one_hot yields int64; cast to c's dtype so the matmul works in
            # every float mode. The fp32 path previously multiplied float by int64.
            c = c @ F.one_hot(idx, num_classes=self.config.vocab_size).to(c.dtype)
            x = self.head(x) + c
        else:
            x = self.head(x)

        if targets is None:
            return x

        loss = F.cross_entropy(x.view(-1, x.size(-1)), targets.to(x.device).view(-1))
        return L2Wrap.apply(loss, x)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment