# Created March 17, 2023
import math
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.core import freeze, unfreeze
from mingpt.utils import CfgNode as CN
# -----------------------------------------------------------------------------
class NewGELU(nn.Module):
    """
    Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT).
    Reference: Gaussian Error Linear Units (GELU) paper: https://arxiv.org/abs/1606.08415

    This is the tanh approximation:
        0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))
    """

    def __call__(self, x):
        """Apply the tanh-approximated GELU element-wise to `x` and return the result."""
        # Cubic correction term inside the tanh of the approximation.
        inner = x + 0.044715 * jnp.power(x, 3.0)
        # sqrt(2/pi) is a plain Python float, so it does not affect x's dtype.
        scale = math.sqrt(2.0 / math.pi)
        return 0.5 * x * (1.0 + jax.lax.tanh(scale * inner))
class CausalSelfAttention(nn.Module):
    """
    A vanilla multi-head masked self-attention layer with a projection at the end.

    Attributes:
        config: configuration node. This layer reads `n_embd`, `n_head`,
            `block_size`, `attn_pdrop` and `resid_pdrop` from it.
    """
    config: CN

    def setup(self):
        """Create the projections, dropout layers and the causal mask."""
        cfg = self.config
        assert cfg.n_embd % cfg.n_head == 0
        # key, query, value projections for all heads, computed in one matmul
        self.c_attn = nn.Dense(3 * cfg.n_embd)
        # output projection
        self.c_proj = nn.Dense(cfg.n_embd)
        # regularization
        self.attn_dropout = nn.Dropout(cfg.attn_pdrop)
        self.resid_dropout = nn.Dropout(cfg.resid_pdrop)
        # Flax modules have no `register_buffer`; the causal mask is a constant,
        # so it is stored as a plain attribute. Lower-triangular ones mean that
        # position i may only attend to positions <= i.
        causal = jnp.tril(jnp.ones((cfg.block_size, cfg.block_size)))
        self.bias = causal.reshape((1, 1, cfg.block_size, cfg.block_size))
        self.n_head = cfg.n_head
        self.n_embd = cfg.n_embd

    def __call__(self, x, deterministic=True):
        """
        Run causal multi-head self-attention over `x`.

        Args:
            x: array of shape (B, T, C). C is expected to equal `n_embd` so the
                per-head reshape below succeeds, and T must not exceed
                `block_size` (the mask is sliced to T x T).
            deterministic: when True (default) both dropout layers are
                disabled. Pass False during training; Flax then requires a
                'dropout' RNG to be supplied when applying the module.

        Returns:
            Array of shape (B, T, C).
        """
        B, T, C = x.shape
        head_dim = C // self.n_head

        # jnp.split takes the NUMBER of equal sections (3), not the chunk size.
        q, k, v = jnp.split(self.c_attn(x), 3, axis=-1)

        def split_heads(t):
            # (B, T, C) -> (B, n_head, T, head_dim)
            return t.reshape((B, T, self.n_head, head_dim)).transpose((0, 2, 1, 3))

        q, k, v = split_heads(q), split_heads(k), split_heads(v)

        # Scaled dot-product scores:
        # (B, nh, T, hd) @ (B, nh, hd, T) -> (B, nh, T, T)
        att = (q @ jnp.swapaxes(k, -2, -1)) * (1.0 / math.sqrt(head_dim))
        # Mask out future positions so softmax assigns them zero weight.
        att = jnp.where(self.bias[:, :, :T, :T] == 0, float('-inf'), att)
        att = jax.nn.softmax(att, axis=-1)
        att = self.attn_dropout(att, deterministic=deterministic)

        # (B, nh, T, T) @ (B, nh, T, hd) -> (B, nh, T, hd)
        y = att @ v
        # Re-assemble all head outputs side by side: -> (B, T, C)
        y = y.transpose((0, 2, 1, 3)).reshape((B, T, C))

        # output projection
        y = self.resid_dropout(self.c_proj(y), deterministic=deterministic)
        return y
class Block(nn.Module):
""" an unassuming Transformer block """
config: CN
def setup(self):
self.ln_1 = nn.LayerNorm(self.config.n_embd)
self.attn = CausalSelfAttention(self.config)
self.ln_2 = nn.LayerNorm(self.config.n_embd)
self.mlp = nn.Sequential(
nn.Dense(4 * self.config.n_embd),
