import math
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.core import freeze, unfreeze
from mingpt.utils import CfgNode as CN

# -----------------------------------------------------------------------------
class NewGELU(nn.Module):
    """
    Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT).
    Reference: Gaussian Error Linear Units (GELU) paper: https://arxiv.org/abs/1606.08415
    """
    def __call__(self, x):
        return 0.5 * x * (1.0 + jax.lax.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * jnp.power(x, 3.0))))
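
# A quick sanity check (an illustrative assumption, not part of the original
# gist): this tanh formula should match jax.nn.gelu's approximate mode, which
# implements the same approximation. NewGELU has no parameters, so it can be
# applied with an empty variable dict:
#
#     x = jnp.linspace(-3.0, 3.0, 7)
#     y = NewGELU().apply({}, x)
#     assert jnp.allclose(y, jax.nn.gelu(x, approximate=True), atol=1e-6)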
class CausalSelfAttention(nn.Module):
    """
    A vanilla multi-head masked self-attention layer with a projection at the end.
    """
    config: CN

    def setup(self):
        assert self.config.n_embd % self.config.n_head == 0
        # key, query, value projections for all heads, batched in one Dense layer
        self.c_attn = nn.Dense(3 * self.config.n_embd)
        # output projection
        self.c_proj = nn.Dense(self.config.n_embd)
        # regularization
        self.attn_dropout = nn.Dropout(self.config.attn_pdrop)
        self.resid_dropout = nn.Dropout(self.config.resid_pdrop)
        # causal mask to ensure that attention is only applied to the left in
        # the input sequence; Flax has no register_buffer, so keep the constant
        # mask as a plain attribute
        self.bias = jnp.tril(jnp.ones((self.config.block_size, self.config.block_size))) \
                       .reshape((1, 1, self.config.block_size, self.config.block_size))
        self.n_head = self.config.n_head
        self.n_embd = self.config.n_embd
    def __call__(self, x, train: bool = False):
        B, T, C = x.shape  # batch size, sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch; note that
        # jnp.split takes the number of sections (3), not the chunk size
        # that torch's Tensor.split expects
        q, k, v = jnp.split(self.c_attn(x), 3, axis=-1)
        k = k.reshape((B, T, self.n_head, C // self.n_head)).transpose((0, 2, 1, 3))  # (B, nh, T, hs)
        q = q.reshape((B, T, self.n_head, C // self.n_head)).transpose((0, 2, 1, 3))  # (B, nh, T, hs)
        v = v.reshape((B, T, self.n_head, C // self.n_head)).transpose((0, 2, 1, 3))  # (B, nh, T, hs)

        # causal self-attention; jnp.transpose needs a full permutation, so use
        # swapaxes for torch's transpose(-2, -1)
        att = (q @ jnp.swapaxes(k, -2, -1)) * (1.0 / jnp.sqrt(k.shape[-1]))
        # JAX arrays are immutable, so mask with jnp.where instead of in-place assignment
        att = jnp.where(self.bias[:, :, :T, :T] == 0, float('-inf'), att)
        att = jax.nn.softmax(att, axis=-1)
        att = self.attn_dropout(att, deterministic=not train)
        y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose((0, 2, 1, 3)).reshape((B, T, C))  # re-assemble all head outputs side by side

        # output projection
        y = self.resid_dropout(self.c_proj(y), deterministic=not train)
        return y
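
# Illustrative usage sketch (not from the original gist; the hyperparameters
# below are arbitrary assumptions). With the default train=False, dropout is
# deterministic and no 'dropout' rng is needed:
#
#     cfg = CN(n_embd=48, n_head=3, block_size=16, attn_pdrop=0.1, resid_pdrop=0.1)
#     attn = CausalSelfAttention(cfg)
#     x = jnp.zeros((2, 16, 48))                   # (batch, seq_len, n_embd)
#     params = attn.init(jax.random.PRNGKey(0), x)
#     y = attn.apply(params, x)
#     assert y.shape == x.shape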
class Block(nn.Module):
    """ an unassuming Transformer block """
    config: CN

    def setup(self):
        # Flax LayerNorm infers the feature size from its input; its first
        # positional argument is epsilon, not the embedding width
        self.ln_1 = nn.LayerNorm()
        self.attn = CausalSelfAttention(self.config)
        self.ln_2 = nn.LayerNorm()
        # the MLP is kept as individual submodules rather than nn.Sequential
        # so the dropout's `deterministic` flag can be threaded through
        self.c_fc = nn.Dense(4 * self.config.n_embd)
        self.act = NewGELU()
        self.c_proj = nn.Dense(self.config.n_embd)
        self.dropout = nn.Dropout(self.config.resid_pdrop)

    def __call__(self, x, train: bool = False):
        x = x + self.attn(self.ln_1(x), train=train)
        x = x + self.dropout(self.c_proj(self.act(self.c_fc(self.ln_2(x)))),
                             deterministic=not train)
        return x
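
# Minimal smoke test (an assumption added for illustration, not part of the
# original gist): builds one Block with toy hyperparameters and runs a forward
# pass in both eval and train mode.
if __name__ == "__main__":
    cfg = CN(n_embd=48, n_head=3, block_size=16, attn_pdrop=0.1, resid_pdrop=0.1)
    block = Block(cfg)
    x = jnp.ones((2, 16, 48))  # (batch, seq_len, n_embd)
    params = block.init(jax.random.PRNGKey(0), x)  # train=False: no dropout rng needed
    y = block.apply(params, x)
    print(y.shape)  # (2, 16, 48)

    # training-mode call draws dropout masks from a 'dropout' rng stream
    y_train = block.apply(params, x, train=True,
                          rngs={'dropout': jax.random.PRNGKey(1)})
    print(y_train.shape)  # (2, 16, 48)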