@ingenieroariel
Created April 8, 2023 14:20
import torch
import torch.nn as nn
from torch.autograd import Function

import wkv_cuda  # pre-built WKV CUDA extension


# Forward pass of the WKV operator used by RWKV.
# w, u: per-channel decay / bonus parameters, shape (B, C)
# k, v: key and value tensors, shape (B, T, C)
# returns y, shape (B, T, C)
def forward_wkv(B, T, C, w, u, k, v):
    # allocate the output tensor and pass it to the kernel, matching the
    # wkv_cuda.forward(B, T, C, w, u, k, v, y) interface used later in this gist
    y = torch.empty((B, T, C), device=k.device, memory_format=torch.contiguous_format)
    wkv_cuda.forward(B, T, C, w, u, k, v, y)
    return y


# Backward pass of the WKV operator.
# Accumulates gradients with respect to w, u, k, v into gw, gu, gk, gv.
def backward_wkv(B, T, C, gy, w, u, k, v, gw, gu, gk, gv):
    wkv_cuda.backward(B, T, C, gy, w, u, k, v, gw, gu, gk, gv)


# Convenience wrapper: infer B and T from k and run the forward kernel.
def rwkv(w, u, k, v, C):
    B, T = k.shape[:2]
    y = forward_wkv(B, T, C, w, u, k, v)
    return y
if __name__ == '__main__':
    # Set seeds for reproducibility
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)

    # Parameters
    # C: number of channels, B: batch size, T: number of time steps
    C = 8
    B = 32
    T = 32

    # Prepare input data
    # w: decay, shape (B, C)
    # u: bonus, shape (B, C)
    # k: keys, shape (B, T, C)
    # v: values, shape (B, T, C)
    w = torch.randn((B, C), device='cuda', requires_grad=True)
    u = torch.randn((B, C), device='cuda', requires_grad=True)
    k = torch.randn((B, T, C), device='cuda', requires_grad=True)
    v = torch.randn((B, T, C), device='cuda', requires_grad=True)

    # Forward pass through the WKV kernel
    y = rwkv(w, u, k, v, C)

    # Compute loss and backpropagate
    # L = ||y||_F^2 (squared Frobenius norm of y)
    # NOTE: the kernel is called directly here (not via a torch.autograd.Function),
    # so the output is detached from the autograd graph and this backward pass
    # cannot reach w, u, k, v; the WKV autograd wrapper later in this gist fixes that.
    loss = y.norm(p='fro') ** 2
    loss.backward()

    # Print gradients with respect to w, u, k, v
    print('gw', w.grad)
    print('gu', u.grad)
    print('gk', k.grad)
    print('gv', v.grad)
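For reference, here is a naive pure-PyTorch sketch of the WKV recurrence that the CUDA kernel computes, assuming the standard RWKV-v4 formulation with decay w and bonus u. The kernel source is not included in this gist, so treat this as an illustrative, numerically naive O(T^2) check rather than a drop-in replacement; note that the WKV autograd wrapper below applies w = -exp(time_decay) before calling the kernel, so a negative decay should be passed here as well.

import torch

def wkv_reference(w, u, k, v):
    # Naive O(T^2) WKV sketch (assumed formulation): for each step t,
    #   y_t = (sum_{i<t} e^{w*(t-1-i) + k_i} * v_i + e^{u + k_t} * v_t)
    #         / (sum_{i<t} e^{w*(t-1-i) + k_i} + e^{u + k_t})
    # Shapes follow the test above: w, u are (B, C); k, v are (B, T, C).
    B, T, C = k.shape
    y = torch.empty_like(v)
    for t in range(T):
        past = torch.arange(t, device=k.device, dtype=k.dtype)  # i = 0 .. t-1
        # decayed exponents for past positions, shape (B, t, C)
        e_past = torch.exp(w.unsqueeze(1) * (t - 1 - past).view(1, -1, 1) + k[:, :t])
        # current position gets the "bonus" u instead of the decay, shape (B, C)
        e_cur = torch.exp(u + k[:, t])
        num = (e_past * v[:, :t]).sum(dim=1) + e_cur * v[:, t]
        den = e_past.sum(dim=1) + e_cur
        y[:, t] = num / den
    return y

Comparing this reference against the kernel output (with matching sign conventions for w) is a quick sanity check; it will overflow for large exponents, which the real kernel avoids with a running-maximum trick.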
class L2Wrap(torch.autograd.Function):
This class is a custom autograd function that wraps the training loss. The forward method takes the loss and the output tensor y and simply returns the loss. The backward method passes the incoming gradient for the loss through unchanged and adds a small extra gradient on y (a scaled value at the position of the largest logit in each row), which encourages the logits to stay close to 0.
class WKV(torch.autograd.Function):
This class implements a custom autograd function to perform the forward and backward operations for the RWKV Language Model. The forward method takes the input dimensions B (batch size), T (sequence length), and C (number of channels) and the tensors w, u, k, and v. It applies the CUDA kernel to perform the forward pass and returns the output tensor y. The backward method computes the gradients with respect to the input tensors and accumulates the gradient values in gw, gu, gk, and gv.
def RUN_CUDA(B, T, C, w, u, k, v):
This is a simple wrapper function that applies the WKV custom autograd function to the input tensors. It takes the dimensions B (batch size), T (sequence length), and C (number of channels) and the tensors w, u, k, and v. It then calls the apply method of the WKV class with the input tensors and returns the result.
class RWKV_TimeMix(torch.jit.ScriptModule):
This class implements the time-mix part of the RWKV language model. It inherits from torch.jit.ScriptModule and contains the methods jit_func and forward. The jit_func method takes an input tensor x and produces the keys (k), values (v), and sigmoid of the receptance (sr) by blending x with a time-shifted copy of itself and applying linear projections. The forward method computes the WKV output with the custom CUDA function, gates it with the receptance, and applies a final linear projection to produce the output tensor.
class RWKV_ChannelMix(torch.jit.ScriptModule):
This class implements the channel-mix part of the RWKV language model. It inherits from torch.jit.ScriptModule and contains a forward method. The forward method blends the input tensor x with a time-shifted copy of itself to form the key (k) and receptance (r) inputs and applies linear projections; the key is passed through a squared ReLU and projected back by the value layer, and the result is multiplied element-wise by the sigmoid of the receptance to produce the output tensor.
class GPTConfig:
This class is a simple configuration container for the GPT model. It takes a vocabulary size and a context length, along with other keyword arguments, and stores them as attributes. This allows easy access and management of the model's configuration parameters.
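For instance, building a config for the model classes below might look like this. The field names n_layer, n_embd, and model_type come from the code further down; the concrete values are only illustrative.

config = GPTConfig(
    vocab_size=50277,   # illustrative value
    ctx_len=1024,       # must not exceed the T_MAX used when compiling the CUDA kernel
    model_type='RWKV',  # or 'RWKV-ffnPre'
    n_layer=12,
    n_embd=768,
)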
class Block(nn.Module):
This class implements a single block of the GPT model: layer normalization followed by an RWKV_TimeMix layer (replaced by an extra RWKV_ChannelMix, ffnPre, for layer 0 when model_type is 'RWKV-ffnPre') and an RWKV_ChannelMix layer, each added back to the input through a residual connection. The forward method applies these layers sequentially to the input tensor and returns the output tensor.
class GPT(nn.Module):
This class represents the main GPT model and contains an embedding layer, a sequence of Block modules, a final layer normalization, and a linear head that projects the output to the vocabulary size. The forward method takes a tensor of token ids, applies these layers sequentially, and (when targets are provided) returns the cross-entropy loss wrapped by L2Wrap; sampling tokens from the output probabilities is handled by the generation utilities described below.
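A quick usage sketch for the GPT class (shapes and values are illustrative; RWKV_FLOAT_MODE must be set because the WKV kernel branches on it, and a CUDA device plus the compiled kernel are required):

import os, torch
os.environ['RWKV_FLOAT_MODE'] = 'fp32'   # read inside the WKV autograd function

config = GPTConfig(vocab_size=100, ctx_len=128, model_type='RWKV', n_layer=2, n_embd=64)
model = GPT(config).cuda()

idx = torch.randint(0, config.vocab_size, (4, 128), device='cuda')      # input token ids
targets = torch.randint(0, config.vocab_size, (4, 128), device='cuda')  # shifted targets
loss = model(idx, targets)   # cross-entropy loss wrapped by L2Wrap
loss.backward()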
class Trainer:
This class is responsible for training and evaluating the GPT model. It takes a GPT model, a DataLoader for training and validation data, a learning rate, and other optional parameters. It contains methods for training (train), evaluation (evaluate), and model checkpointing (save_checkpoint). The train method iterates through the training data, computes the loss, and updates the model parameters using the optimizer. The evaluate method calculates the average loss on the validation dataset and returns it.
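The Trainer itself is not included in this gist; a minimal sketch of the loop it describes might look like the following hypothetical class, assuming the GPT model above returns the loss when targets are passed.

import torch

class Trainer:
    """Minimal sketch of train/evaluate/checkpoint; hypothetical, not the original class."""

    def __init__(self, model, train_loader, val_loader, learning_rate=3e-4, max_epochs=1, ckpt_path='model.pt'):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.max_epochs = max_epochs
        self.ckpt_path = ckpt_path
        # plain Adam for the sketch; the GPT class also provides configure_optimizers()
        self.optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    def train(self):
        self.model.train()
        for epoch in range(self.max_epochs):
            for idx, targets in self.train_loader:
                loss = self.model(idx, targets)  # GPT.forward returns the loss when targets are given
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
            print(f'epoch {epoch}: val loss {self.evaluate():.4f}')
            self.save_checkpoint()

    @torch.no_grad()
    def evaluate(self):
        self.model.eval()
        losses = [self.model(idx, targets).item() for idx, targets in self.val_loader]
        self.model.train()
        return sum(losses) / max(len(losses), 1)

    def save_checkpoint(self):
        torch.save(self.model.state_dict(), self.ckpt_path)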
def top_k_logits(logits, k):
This function is a utility for the token sampling process during the generation of text. It takes the logits of the model's output and a value for k, and it returns the logits after setting all but the top k logits to a very negative value, effectively zeroing out the probabilities of these low-ranked tokens.
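The helper is not included in the gist; a common implementation of the described behavior (a sketch, assuming 2-D logits of shape (batch, vocab)) is:

import torch

def top_k_logits(logits, k):
    # Keep only the k largest logits per row; set the rest to -inf so that
    # softmax assigns them effectively zero probability.
    v, _ = torch.topk(logits, k)
    out = logits.clone()
    out[out < v[:, [-1]]] = -float('inf')
    return out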
def sample_sequence(model, length, context, temperature=1.0, top_k=None):
This function is responsible for generating a sequence of tokens given a model, a desired output length, an input context, and optional temperature and top_k values. It applies the GPT model to the context, samples the next token based on the output probabilities, and updates the context with the new token. This process is repeated until the desired length is reached. The temperature parameter controls the randomness of the sampling process, with higher values leading to more random outputs.
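The function itself is not in this gist. A plausible sketch of the described loop follows, using the top_k_logits sketch above; it assumes a model whose forward(idx) returns logits of shape (B, T, vocab_size) when called without targets (the training-oriented GPT class below returns the L2Wrap-wrapped loss instead, so the RWKV repo uses a separate inference path for generation).

import torch
import torch.nn.functional as F

@torch.no_grad()
def sample_sequence(model, length, context, temperature=1.0, top_k=None):
    # context: LongTensor of shape (B, T0) holding the prompt token ids
    x = context
    for _ in range(length):
        x_cond = x[:, -model.get_ctx_len():]           # crop to the model's context length
        logits = model(x_cond)[:, -1, :] / temperature  # assumes model returns (B, T, vocab) logits
        if top_k is not None:
            logits = top_k_logits(logits, top_k)
        probs = F.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        x = torch.cat((x, next_token), dim=1)
    return x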
def main():
This function is the main entry point of the script. It parses command-line arguments for configuring the GPT model, training, and evaluation. It initializes the GPT model, DataLoader, and Trainer instances and proceeds with the training and evaluation process based on the provided arguments. It also supports loading a pre-trained model checkpoint for fine-tuning or text generation. Finally, it can save the trained model for future use.
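main() is likewise not included here; a condensed sketch of the flow it describes, with hypothetical argument names and illustrative defaults, could look like:

import argparse
import torch

def main():
    parser = argparse.ArgumentParser(description='Train or sample from an RWKV GPT model')
    parser.add_argument('--vocab_size', type=int, default=50277)   # illustrative defaults
    parser.add_argument('--ctx_len', type=int, default=1024)
    parser.add_argument('--n_layer', type=int, default=12)
    parser.add_argument('--n_embd', type=int, default=768)
    parser.add_argument('--model_type', type=str, default='RWKV')
    parser.add_argument('--lr', type=float, default=3e-4)
    parser.add_argument('--checkpoint', type=str, default=None, help='optional pre-trained checkpoint')
    args = parser.parse_args()

    config = GPTConfig(args.vocab_size, args.ctx_len, model_type=args.model_type,
                       n_layer=args.n_layer, n_embd=args.n_embd)
    model = GPT(config).cuda()
    if args.checkpoint is not None:
        model.load_state_dict(torch.load(args.checkpoint, map_location='cuda'))

    # DataLoader construction is dataset-specific and omitted here;
    # see the Trainer sketch above for the training/evaluation loop.
    # trainer = Trainer(model, train_loader, val_loader, learning_rate=args.lr)
    # trainer.train()

if __name__ == '__main__':
    main()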
import math, os
import numpy as np
import logging
import torch
import torch.nn as nn
from torch.nn import functional as F

try:
    from deepspeed.ops.adam import FusedAdam
except:
    pass

logger = logging.getLogger(__name__)

RWKV_HEAD_QK_DIM = 0
print(f'\nRWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM}\n')


class L2Wrap(torch.autograd.Function):
    @staticmethod
    def forward(ctx, loss, y):
        ctx.save_for_backward(y)
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        y = ctx.saved_tensors[0]
        factor = 1e-4 / (y.shape[0] * y.shape[1])
        maxx, ids = torch.max(y, -1, keepdim=True)
        gy = torch.zeros_like(y)
        gy.scatter_(-1, ids, maxx * factor)
        return (grad_output, gy)


T_MAX = 1024

from torch.utils.cpp_extension import load
wkv_cuda = load(name="wkv", sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"],
                verbose=True, extra_cuda_cflags=['-res-usage', '--maxrregcount 60', '--use_fast_math', '-O3', '-Xptxas -O3', f'-DTmax={T_MAX}'])
class WKV(torch.autograd.Function):
    @staticmethod
    def forward(ctx, B, T, C, w, u, k, v):
        ctx.B = B
        ctx.T = T
        ctx.C = C
        assert T <= T_MAX
        assert B * C % min(C, 1024) == 0
        if '32' in os.environ['RWKV_FLOAT_MODE']:
            w = -torch.exp(w.contiguous())
            u = u.contiguous()
            k = k.contiguous()
            v = v.contiguous()
        else:
            w = -torch.exp(w.float().contiguous())
            u = u.float().contiguous()
            k = k.float().contiguous()
            v = v.float().contiguous()
        ctx.save_for_backward(w, u, k, v)
        y = torch.empty((B, T, C), device='cuda', memory_format=torch.contiguous_format)
        wkv_cuda.forward(B, T, C, w, u, k, v, y)
        if '32' in os.environ['RWKV_FLOAT_MODE']:
            return y
        elif os.environ['RWKV_FLOAT_MODE'] == 'fp16':
            return y.half()
        elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
            return y.bfloat16()

    @staticmethod
    def backward(ctx, gy):
        B = ctx.B
        T = ctx.T
        C = ctx.C
        assert T <= T_MAX
        assert B * C % min(C, 1024) == 0
        w, u, k, v = ctx.saved_tensors
        gw = torch.zeros((B, C), device='cuda').contiguous()
        gu = torch.zeros((B, C), device='cuda').contiguous()
        gk = torch.zeros((B, T, C), device='cuda').contiguous()
        gv = torch.zeros((B, T, C), device='cuda').contiguous()
        # gy follows w, u, k, v, matching the wkv_op.cpp binding used in the full model below
        if '32' in os.environ['RWKV_FLOAT_MODE']:
            wkv_cuda.backward(B, T, C, w, u, k, v, gy.contiguous(), gw, gu, gk, gv)
        else:
            wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv)
        if '32' in os.environ['RWKV_FLOAT_MODE']:
            return (None, None, None, gw, gu, gk, gv)
        elif os.environ['RWKV_FLOAT_MODE'] == 'fp16':
            return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half())
        elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
            return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16())
class RWKV(nn.Module):
    def __init__(self, C, float_mode='fp16'):
        super(RWKV, self).__init__()
        self.C = C
        self.float_mode = float_mode
        os.environ['RWKV_FLOAT_MODE'] = float_mode

    def forward(self, w, u, k, v):
        B, T = k.shape[:2]
        C = self.C
        y = WKV.apply(B, T, C, w, u, k, v)
        return y


if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)

    # Instantiate model (fp32 mode so gradient dtypes match the fp32 inputs below)
    C = 8
    model = RWKV(C, float_mode='fp32').cuda()

    # Prepare input data
    B = 32
    T = 32
    w = torch.randn((B, C), device='cuda', requires_grad=True)
    u = torch.randn((B, C), device='cuda', requires_grad=True)
    k = torch.randn((B, T, C), device='cuda', requires_grad=True)
    v = torch.randn((B, T, C), device='cuda', requires_grad=True)

    # Forward pass
    y = model(w, u, k, v)

    # Compute loss and backpropagate
    loss = y.norm(p='fro') ** 2
    loss = L2Wrap.apply(loss, y)
    loss.backward()

    # Print gradients
    print('gw', w.grad)
    print('gu', u.grad)
    print('gk', k.grad)
    print('gv', v.grad)
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import math, os
import numpy as np
import logging
import torch
import torch.nn as nn
from torch.nn import functional as F

try:
    from deepspeed.ops.adam import FusedAdam
except:
    pass  # some poor windows users cant install deepspeed

logger = logging.getLogger(__name__)

RWKV_HEAD_QK_DIM = 0
print(f'\nRWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM}\n')


class L2Wrap(torch.autograd.Function):
    @staticmethod
    def forward(ctx, loss, y):
        ctx.save_for_backward(y)
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        y = ctx.saved_tensors[0]
        # to encourage the logits to be close to 0
        factor = 1e-4 / (y.shape[0] * y.shape[1])
        maxx, ids = torch.max(y, -1, keepdim=True)
        gy = torch.zeros_like(y)
        gy.scatter_(-1, ids, maxx * factor)
        return (grad_output, gy)

########################################################################################################
# CUDA Kernel
########################################################################################################

T_MAX = 1024  # increase this if your ctx_len is long [NOTE: TAKES LOTS OF VRAM!]
# it's possible to go beyond CUDA limitations if you slice the ctx and pass the hidden state in each slice

from torch.utils.cpp_extension import load
wkv_cuda = load(name="wkv", sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"],
                verbose=True, extra_cuda_cflags=['-res-usage', '--maxrregcount 60', '--use_fast_math', '-O3', '-Xptxas -O3', f'-DTmax={T_MAX}'])
class WKV(torch.autograd.Function):
    @staticmethod
    def forward(ctx, B, T, C, w, u, k, v):
        ctx.B = B
        ctx.T = T
        ctx.C = C
        assert T <= T_MAX
        assert B * C % min(C, 1024) == 0
        if '32' in os.environ['RWKV_FLOAT_MODE']:
            w = -torch.exp(w.contiguous())
            u = u.contiguous()
            k = k.contiguous()
            v = v.contiguous()
        else:
            w = -torch.exp(w.float().contiguous())
            u = u.float().contiguous()
            k = k.float().contiguous()
            v = v.float().contiguous()
        ctx.save_for_backward(w, u, k, v)
        y = torch.empty((B, T, C), device='cuda', memory_format=torch.contiguous_format)
        wkv_cuda.forward(B, T, C, w, u, k, v, y)
        if '32' in os.environ['RWKV_FLOAT_MODE']:
            return y
        elif os.environ['RWKV_FLOAT_MODE'] == 'fp16':
            return y.half()
        elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
            return y.bfloat16()

    @staticmethod
    def backward(ctx, gy):
        B = ctx.B
        T = ctx.T
        C = ctx.C
        assert T <= T_MAX
        assert B * C % min(C, 1024) == 0
        w, u, k, v = ctx.saved_tensors
        gw = torch.zeros((B, C), device='cuda').contiguous()
        gu = torch.zeros((B, C), device='cuda').contiguous()
        gk = torch.zeros((B, T, C), device='cuda').contiguous()
        gv = torch.zeros((B, T, C), device='cuda').contiguous()
        if '32' in os.environ['RWKV_FLOAT_MODE']:
            wkv_cuda.backward(B, T, C, w, u, k, v, gy.contiguous(), gw, gu, gk, gv)
        else:
            wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv)
        gw = torch.sum(gw, dim=0)
        gu = torch.sum(gu, dim=0)
        if '32' in os.environ['RWKV_FLOAT_MODE']:
            return (None, None, None, gw, gu, gk, gv)
        elif os.environ['RWKV_FLOAT_MODE'] == 'fp16':
            return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half())
        elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
            return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16())


def RUN_CUDA(B, T, C, w, u, k, v):
    return WKV.apply(B, T, C, w.cuda(), u.cuda(), k.cuda(), v.cuda())
########################################################################################################
# RWKV: RWKV Time-mix + RWKV Channel-mix
########################################################################################################

def RWKV_Init(model, args):  # fancy initialization of all lin & emb layer in the model
    print("\n[--> first run, init model params (very slow for large models) <--]")
    print("[so you shall only do it for 1 single GPU and save the checkpt and load it when using multiple GPU]\n")
    for mm in model.modules():
        if "RecursiveScriptModule" in str(type(mm)):
            if mm.original_name not in ["Linear"]:
                continue
            ww = None
            for name, param in mm.named_parameters():
                if name == "weight":
                    ww = param
        else:
            m = mm
            if not isinstance(m, (nn.Linear, nn.Embedding)):
                continue
            ww = m.weight
        with torch.no_grad():
            name = "[unknown weight]"
            for name, parameter in model.named_parameters():  # find the name of the weight
                if id(ww) == id(parameter):
                    break

            shape = ww.shape
            gain = 1.0
            scale = 1.0  # extra scale for gain

            if isinstance(m, nn.Embedding):
                gain = math.sqrt(max(shape[0], shape[1]))
                if shape[0] == args.vocab_size and shape[1] == args.n_embd:  # token emb?
                    scale = 1e-4
                else:
                    scale = 0

            if isinstance(m, nn.Linear):
                if shape[0] > shape[1]:
                    gain = math.sqrt(shape[0] / shape[1])
                if shape[0] == args.vocab_size and shape[1] == args.n_embd:  # final projection?
                    scale = 0.5

            if hasattr(m, "scale_init"):
                scale = m.scale_init

            # print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {str(scale).ljust(4)} {name}")

            gain *= scale
            if scale == -999:
                nn.init.eye_(ww)
            elif gain == 0:
                # zero init is great for some RWKV matrices
                nn.init.zeros_(ww)
            elif gain > 0:
                nn.init.orthogonal_(ww, gain=gain)
            else:
                nn.init.normal_(ww, mean=0.0, std=-scale)
class RWKV_TimeMix(torch.jit.ScriptModule):
    def __init__(self, config, layer_id):
        super().__init__()
        self.layer_id = layer_id
        self.ctx_len = config.ctx_len
        self.n_embd = config.n_embd
        attn_sz = config.n_embd

        with torch.no_grad():  # fancy init
            ratio_0_to_1 = (layer_id / (config.n_layer - 1))  # 0 to 1
            ratio_1_to_almost0 = (1.0 - (layer_id / config.n_layer))  # 1 to ~0

            # fancy time_decay
            decay_speed = torch.ones(attn_sz)
            for h in range(attn_sz):
                decay_speed[h] = -5 + 8 * (h / (attn_sz - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
            self.time_decay = nn.Parameter(decay_speed)
            # print(layer_id, self.time_decay.flatten()[:3].cpu().numpy(), '...', self.time_decay.flatten()[-3:].cpu().numpy())

            # fancy time_first
            zigzag = (torch.tensor([(i + 1) % 3 - 1 for i in range(attn_sz)]) * 0.5)
            self.time_first = nn.Parameter(torch.ones(attn_sz) * math.log(0.3) + zigzag)

            # fancy time_mix
            x = torch.ones(1, 1, config.n_embd)
            for i in range(config.n_embd):
                x[0, 0, i] = i / config.n_embd
            self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
            self.time_mix_v = nn.Parameter(torch.pow(x, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
            self.time_mix_r = nn.Parameter(torch.pow(x, 0.5 * ratio_1_to_almost0))

        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

        self.key = nn.Linear(config.n_embd, attn_sz, bias=False)
        self.value = nn.Linear(config.n_embd, attn_sz, bias=False)
        self.receptance = nn.Linear(config.n_embd, attn_sz, bias=False)

        self.output = nn.Linear(attn_sz, config.n_embd, bias=False)

        self.key.scale_init = 0
        self.receptance.scale_init = 0
        self.output.scale_init = 0

    @torch.jit.script_method
    def jit_func(self, x):
        # Mix x with the previous timestep to produce xk, xv, xr
        xx = self.time_shift(x)
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)

        # Use xk, xv, xr to produce k, v, r
        k = self.key(xk)
        v = self.value(xv)
        r = self.receptance(xr)
        sr = torch.sigmoid(r)

        return sr, k, v

    def forward(self, x):
        B, T, C = x.size()  # x = (Batch,Time,Channel)

        sr, k, v = self.jit_func(x)

        rwkv = sr * RUN_CUDA(B, T, C, self.time_decay, self.time_first, k, v)
        rwkv = self.output(rwkv)
        return rwkv
class RWKV_ChannelMix(torch.jit.ScriptModule):
    def __init__(self, config, layer_id):
        super().__init__()
        self.layer_id = layer_id

        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

        with torch.no_grad():  # fancy init of time_mix
            ratio_1_to_almost0 = (1.0 - (layer_id / config.n_layer))  # 1 to ~0
            x = torch.ones(1, 1, config.n_embd)
            for i in range(config.n_embd):
                x[0, 0, i] = i / config.n_embd
            self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
            self.time_mix_r = nn.Parameter(torch.pow(x, ratio_1_to_almost0))

        hidden_sz = 4 * config.n_embd
        self.key = nn.Linear(config.n_embd, hidden_sz, bias=False)
        self.receptance = nn.Linear(config.n_embd, config.n_embd, bias=False)
        self.value = nn.Linear(hidden_sz, config.n_embd, bias=False)

        self.value.scale_init = 0
        self.receptance.scale_init = 0

    @torch.jit.script_method
    def forward(self, x):
        xx = self.time_shift(x)
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)

        k = self.key(xk)
        k = torch.square(torch.relu(k))
        kv = self.value(k)

        rkv = torch.sigmoid(self.receptance(xr)) * kv
        return rkv
########################################################################################################
# The GPT Model with our blocks
########################################################################################################

class GPTConfig:
    def __init__(self, vocab_size, ctx_len, **kwargs):
        self.vocab_size = vocab_size
        self.ctx_len = ctx_len
        for k, v in kwargs.items():
            setattr(self, k, v)


class Block(nn.Module):
    def __init__(self, config, layer_id):
        super().__init__()
        self.config = config
        self.layer_id = layer_id

        self.ln1 = nn.LayerNorm(config.n_embd)
        self.ln2 = nn.LayerNorm(config.n_embd)

        if self.layer_id == 0:
            self.ln0 = nn.LayerNorm(config.n_embd)

        if self.layer_id == 0 and self.config.model_type == 'RWKV-ffnPre':
            self.ffnPre = RWKV_ChannelMix(config, 0)
        else:
            self.att = RWKV_TimeMix(config, layer_id)

        self.ffn = RWKV_ChannelMix(config, layer_id)

    def forward(self, x):
        if self.layer_id == 0:
            x = self.ln0(x)
        if self.layer_id == 0 and self.config.model_type == 'RWKV-ffnPre':
            x = x + self.ffnPre(self.ln1(x))  # better in some cases
        else:
            x = x + self.att(self.ln1(x))
        x = x + self.ffn(self.ln2(x))
        return x
class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.step = 0
        self.config = config

        self.emb = nn.Embedding(config.vocab_size, config.n_embd)

        self.blocks = nn.Sequential(*[Block(config, i)
                                      for i in range(config.n_layer)])

        self.ln_out = nn.LayerNorm(config.n_embd)
        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        if RWKV_HEAD_QK_DIM > 0:
            self.head_q = nn.Linear(config.n_embd, RWKV_HEAD_QK_DIM, bias=False)
            self.head_q.scale_init = 0
            self.head_k = nn.Linear(config.n_embd, RWKV_HEAD_QK_DIM, bias=False)
            self.head_k.scale_init = 0.1
            self.register_buffer("copy_mask", torch.tril(
                torch.ones(config.ctx_len, config.ctx_len)))

        self.ctx_len = config.ctx_len

        try:
            if os.environ['RWKV_LOAD_MODEL'] == str(False):
                RWKV_Init(self, config)
        except:
            pass

        logger.info("number of parameters: %e", sum(p.numel()
                                                    for p in self.parameters()))

    def get_ctx_len(self):
        return self.ctx_len

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear)):
            module.weight.data.normal_(mean=0.0, std=0.01)
        if isinstance(module, (nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=1e-5)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def configure_optimizers(self, train_config):
        no_decay = set()

        for mn, m in self.named_modules():  # here we disable weight_decay
            for pn, p in m.named_parameters():
                fpn = '%s.%s' % (mn, pn) if mn else pn  # full param name
                no_decay.add(fpn)

        param_dict = {pn: p for pn, p in self.named_parameters()}
        optim_groups = [
            {"params": [param_dict[pn]
                        for pn in sorted(list(no_decay))], "weight_decay": 0.0},
        ]

        try:
            optimizer = FusedAdam(optim_groups, lr=train_config.learning_rate, betas=train_config.betas, eps=train_config.eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False)
        except:
            print('\n\nDeepSpeed not found. Using torch optimizer instead (probably slower)\n\n')
            optimizer = torch.optim.Adam(optim_groups, lr=train_config.learning_rate, betas=train_config.betas, eps=train_config.eps)

        return optimizer

    def forward(self, idx, targets=None):
        idx = idx.to(self.emb.weight.device)

        self.step += 1
        B, T = idx.size()
        assert T <= self.ctx_len, "Cannot forward, because len(input) > model ctx_len."

        x = self.emb(idx)
        x = self.blocks(x)
        x = self.ln_out(x)

        if RWKV_HEAD_QK_DIM > 0:
            q = self.head_q(x)[:, :T, :]
            k = self.head_k(x)[:, :T, :]
            c = (q @ k.transpose(-2, -1)) * (1.0 / RWKV_HEAD_QK_DIM)
            c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)

            if '32' in os.environ['RWKV_FLOAT_MODE']:
                c = c @ F.one_hot(idx, num_classes=self.config.vocab_size)
            elif os.environ['RWKV_FLOAT_MODE'] == 'fp16':
                c = c @ F.one_hot(idx, num_classes=self.config.vocab_size).half()
            elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
                c = c @ F.one_hot(idx, num_classes=self.config.vocab_size).bfloat16()

            x = self.head(x) + c
        else:
            x = self.head(x)

        loss = None
        if targets is not None:
            loss = F.cross_entropy(x.view(-1, x.size(-1)), targets.to(x.device).view(-1))

        return L2Wrap.apply(loss, x)