@AranKomat
Created March 17, 2023 06:36
import math
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.core import freeze, unfreeze
from mingpt.utils import CfgNode as CN
# -----------------------------------------------------------------------------
class NewGELU(nn.Module):
"""
Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT).
Reference: Gaussian Error Linear Units (GELU) paper: https://arxiv.org/abs/1606.08415
"""
def __call__(self, x):
return 0.5 * x * (1.0 + jax.lax.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * jnp.power(x, 3.0))))
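
# Sanity check (added for illustration, not part of the original gist): the tanh
# approximation above should match jax.nn.gelu with approximate=True, which uses the
# same formula. The helper name below is hypothetical.
def _check_new_gelu_matches_jax():
    x = jax.random.normal(jax.random.PRNGKey(0), (4, 8))
    ours = NewGELU().apply({}, x)            # NewGELU has no parameters, so an empty variable dict suffices
    ref = jax.nn.gelu(x, approximate=True)   # JAX's built-in tanh-approximate GELU
    assert jnp.allclose(ours, ref, atol=1e-5)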
class CausalSelfAttention(nn.Module):
"""
A vanilla multi-head masked self-attention layer with a projection at the end.
"""
config: CN
    def setup(self):
        assert self.config.n_embd % self.config.n_head == 0
        # key, query, value projections for all heads, in a batch
        self.c_attn = nn.Dense(3 * self.config.n_embd)
        # output projection
        self.c_proj = nn.Dense(self.config.n_embd)
        # regularization
        self.attn_dropout = nn.Dropout(self.config.attn_pdrop)
        self.resid_dropout = nn.Dropout(self.config.resid_pdrop)
        # causal mask to ensure that attention is only applied to the left in the input
        # sequence; Flax has no register_buffer, so the mask is stored as a plain constant
        self.bias = jnp.tril(jnp.ones((self.config.block_size, self.config.block_size))).reshape(
            (1, 1, self.config.block_size, self.config.block_size))
        self.n_head = self.config.n_head
        self.n_embd = self.config.n_embd
    def __call__(self, x, deterministic: bool = True):
        B, T, C = x.shape  # batch size, sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch and move head dim forward to be the batch dim
        q, k, v = jnp.split(self.c_attn(x), 3, axis=-1)
        k = k.reshape((B, T, self.n_head, C // self.n_head)).transpose((0, 2, 1, 3))  # (B, nh, T, hs)
        q = q.reshape((B, T, self.n_head, C // self.n_head)).transpose((0, 2, 1, 3))  # (B, nh, T, hs)
        v = v.reshape((B, T, self.n_head, C // self.n_head)).transpose((0, 2, 1, 3))  # (B, nh, T, hs)

        # causal self-attention: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        att = (q @ jnp.swapaxes(k, -2, -1)) * (1.0 / jnp.sqrt(k.shape[-1]))
        att = jnp.where(self.bias[:, :, :T, :T] == 0, float('-inf'), att)
        att = jax.nn.softmax(att, axis=-1)
        # pass deterministic=False (with an rng named 'dropout') to enable dropout during training
        att = self.attn_dropout(att, deterministic=deterministic)
        y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose((0, 2, 1, 3)).reshape((B, T, C))  # re-assemble all head outputs side by side

        # output projection
        y = self.resid_dropout(self.c_proj(y), deterministic=deterministic)
        return y
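
# Usage sketch (added for illustration, not part of the original gist): initialise and run
# the attention layer on dummy inputs. SimpleNamespace stands in for mingpt's CfgNode, and
# the hyperparameter values below are illustrative assumptions.
def _demo_causal_self_attention():
    from types import SimpleNamespace
    cfg = SimpleNamespace(n_embd=48, n_head=3, block_size=16,
                          attn_pdrop=0.1, resid_pdrop=0.1)
    attn = CausalSelfAttention(cfg)
    x = jnp.ones((2, 16, 48))                     # (batch, sequence length, n_embd)
    params = attn.init(jax.random.PRNGKey(0), x)  # deterministic=True by default, so no dropout rng needed
    y = attn.apply(params, x)
    assert y.shape == x.shape
    return y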
class Block(nn.Module):
""" an unassuming Transformer block """
config: CN
    def setup(self):
        # flax.linen.LayerNorm infers the feature size from its input, so it takes no
        # n_embd argument (its first positional argument would be epsilon)
        self.ln_1 = nn.LayerNorm()
        self.attn = CausalSelfAttention(self.config)
        self.ln_2 = nn.LayerNorm()
        # flax.linen.Sequential takes a single list of layers; the trailing Dropout is built
        # with deterministic=True because Sequential cannot forward a `deterministic` flag
        self.mlp = nn.Sequential([
            nn.Dense(4 * self.config.n_embd),
            NewGELU(),
            nn.Dense(self.config.n_embd),
            nn.Dropout(self.config.resid_pdrop, deterministic=True),
        ])
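
# The excerpt ends before Block.__call__. As context, a minimal sketch of the standard
# minGPT-style pre-norm residual wiring is given below as a free function; it is an
# assumption about how the block composes its sub-layers, not code from the gist. It can
# be run against a Block's variables via Block(cfg).apply(variables, x, method=_block_forward_sketch).
def _block_forward_sketch(block, x):
    x = x + block.attn(block.ln_1(x))  # self-attention sub-block with residual connection
    x = x + block.mlp(block.ln_2(x))   # MLP sub-block with residual connection
    return x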