@kvablack
Created February 6, 2024 05:36
import numpy as np
from functools import partial
import jax.numpy as jnp
import jax
from jax.experimental.shard_map import shard_map
from jax.experimental import mesh_utils
from jax.sharding import NamedSharding, PartitionSpec
from jax.experimental.compilation_cache import compilation_cache
import tqdm
from transformer import Transformer, common_transformer_sizes


def main():
    compilation_cache.initialize_cache("/tmp/jax_compilation_cache")

    # 1-D device mesh: all devices form a single data-parallel ("dp") axis.
    mesh = jax.sharding.Mesh(
        mesh_utils.create_device_mesh([jax.device_count()]), ["dp"]
    )
    rep_spec = PartitionSpec()
    rep_sharding = NamedSharding(mesh, rep_spec)  # replicated (params)
    dp_spec = PartitionSpec("dp")
    dp_sharding = NamedSharding(mesh, dp_spec)  # batch-sharded (data)

    token_dim, transformer_kwargs = common_transformer_sizes("vit_l")
    model = Transformer(**transformer_kwargs, remat_policy="full")
    rng = jax.random.PRNGKey(0)

    def data_iterator():
        # dummy data: constant tokens plus a causal attention mask
        i = 0
        attention_mask = np.broadcast_to(np.tri(768), (128, 1, 768, 768))
        while True:
            tokens = np.full((128, 768, token_dim), i / 100)
            yield tokens, attention_mask

    data_iter = data_iterator()

    @jax.jit
    def init_fn(*args):
        params = model.init(rng, *args, train=False)["params"]
        return params

    params = init_fn(*next(data_iter))

    def loss_fn(params, tokens, attention_mask):
        embeddings = model.apply(
            {"params": params},
            tokens,
            attention_mask,
            train=True,
            rngs={"dropout": rng},
        )
        # fake loss that is just the mean of the last embedding
        return jnp.mean(embeddings[:, -1])

    # params replicated, data sharded over "dp"; the params buffer is donated.
    @partial(
        jax.jit,
        in_shardings=(rep_sharding, dp_sharding, dp_sharding),
        out_shardings=rep_sharding,
        donate_argnums=0,
    )
    # Alternative (commented out): express the step as a shard_map over "dp";
    # pairs with the pmean below.
    # @partial(
    #     shard_map,
    #     mesh=mesh,
    #     in_specs=(rep_spec, dp_spec, dp_spec),
    #     out_specs=rep_spec,
    #     check_rep=False,
    # )
    def train_step(params, tokens, attention_mask):
        grads = jax.grad(loss_fn)(params, tokens, attention_mask)
        # grads = jax.lax.pmean(grads, axis_name="dp")
        new_params = jax.tree_map(lambda p, g: p - 1e-4 * g, params, grads)  # plain SGD
        return new_params

    for _ in tqdm.tqdm(range(10000), dynamic_ncols=True):
        tokens, attention_mask = next(data_iter)
        params = train_step(params, tokens, attention_mask)
        jax.block_until_ready(params)
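

# Illustrative sketch, not called anywhere: what the two shardings used in
# main() mean on the same 1-D "dp" mesh. `dp_sharding` splits the leading
# batch axis across devices, while `rep_sharding` keeps a full copy on every
# device. Only standard JAX APIs are used (`jax.device_put`,
# `jax.debug.visualize_array_sharding`); the shapes below are arbitrary.
def visualize_shardings():
    mesh = jax.sharding.Mesh(
        mesh_utils.create_device_mesh([jax.device_count()]), ["dp"]
    )
    dp_sharding = NamedSharding(mesh, PartitionSpec("dp"))  # shard axis 0 over "dp"
    rep_sharding = NamedSharding(mesh, PartitionSpec())  # fully replicated
    batch = jnp.zeros((128, 16))
    jax.debug.visualize_array_sharding(jax.device_put(batch, dp_sharding))
    jax.debug.visualize_array_sharding(jax.device_put(batch, rep_sharding))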


if __name__ == "__main__":
    main()


# transformer.py (imported above as `from transformer import Transformer, common_transformer_sizes`)
from functools import partial
from typing import Callable, Tuple, Any, Union
import flax.linen as nn
from flax.linen.attention import dot_product_attention_weights
import jax
from jax.ad_checkpoint import checkpoint_name
import jax.numpy as jnp
from jax.typing import DTypeLike


class AttentionBlock(nn.Module):
    """Transformer self-attention block."""

    num_heads: int
    dtype: DTypeLike = jnp.float32
    kernel_init: Callable[
        [Any, tuple, DTypeLike], jax.Array
    ] = nn.initializers.xavier_uniform()

    @nn.compact
    def __call__(self, x, mask):
        assert x.ndim == 3, f"Expected (batch, length, embed), got {x.shape}"
        embed_dim = x.shape[-1]
        assert embed_dim % self.num_heads == 0, (
            f"Memory dimension ({embed_dim}) must be divisible by number of"
            f" heads ({self.num_heads})."
        )
        head_dim = embed_dim // self.num_heads

        # project to query, key, value: [batch, length, n_heads, head_dim]
        dense = partial(
            nn.DenseGeneral,
            axis=-1,
            dtype=self.dtype,
            features=(self.num_heads, head_dim),
            kernel_init=nn.with_logical_partitioning(
                # workaround for Flax bug https://github.com/google/flax/issues/3676
                lambda *args: self.kernel_init(*args).reshape(
                    embed_dim, self.num_heads, head_dim
                ),
                ("embed", "heads", "head_dim"),
            ),
        )
        constraint = partial(
            nn.with_logical_constraint,
            logical_axis_resources=("batch", "length", "heads", "head_dim"),
        )
        query, key, value = (
            constraint(dense(name="query")(x)),
            constraint(dense(name="key")(x)),
            constraint(dense(name="value")(x)),
        )

        # apply Flax dot_product_attention_weights (no attention dropout)
        # shape [batch, n_heads, length, length]
        attn_weights = dot_product_attention_weights(
            query, key, mask=mask, dtype=self.dtype
        )
        attn_weights = nn.with_logical_constraint(
            attn_weights, ("batch", "heads", None, None)
        )
        x = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value)
        x = nn.with_logical_constraint(x, ("batch", "length", "heads", "head_dim"))

        # merge heads and project to output: [batch, length, embed]
        x = nn.DenseGeneral(
            features=embed_dim,
            axis=(-2, -1),
            kernel_init=nn.with_logical_partitioning(
                # workaround for Flax bug https://github.com/google/flax/issues/3676
                lambda *args: self.kernel_init(*args).reshape(
                    self.num_heads,
                    head_dim,
                    embed_dim,
                ),
                ("heads", "head_dim", "embed"),
            ),
            dtype=self.dtype,
            name="out",
        )(x)
        x = nn.with_logical_constraint(x, ("batch", "length", "embed"))
        return x
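

# Illustrative sketch, not used by this gist: the logical axis names above
# ("batch", "length", "embed", "heads", "head_dim", "mlp") only turn into real
# sharding constraints once they are bound to mesh axes via Flax's logical
# axis rules. The training script in this gist never sets any rules, so the
# annotations are effectively inert there and sharding comes purely from jit's
# in_shardings/out_shardings. The rule set below is a hypothetical example
# that binds "batch" to a "dp" mesh axis.
def example_logical_axis_rules(x):
    rules = (
        ("batch", "dp"),  # shard the batch axis over the "dp" mesh axis
        ("length", None),
        ("embed", None),
        ("heads", None),
        ("head_dim", None),
        ("mlp", None),
    )
    with nn.logical_axis_rules(rules):
        # inside the rules context (and under an active mesh / jit), this acts
        # like a sharding constraint of PartitionSpec("dp", None, None)
        return nn.with_logical_constraint(x, ("batch", "length", "embed"))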


class MlpBlock(nn.Module):
    """Transformer MLP / feed-forward block."""

    mlp_dim: int
    deterministic: bool
    dtype: DTypeLike = jnp.float32
    dropout_rate: float = 0.1
    kernel_init: Callable[
        [Any, tuple, DTypeLike], jax.Array
    ] = nn.initializers.xavier_uniform()
    bias_init: Callable[[Any, tuple, DTypeLike], jax.Array] = nn.initializers.normal(
        stddev=1e-6
    )

    @nn.compact
    def __call__(self, x):
        """Applies Transformer MlpBlock module."""
        assert x.ndim == 3, f"Expected (batch, length, embed), got {x.shape}"
        embed_dim = x.shape[-1]
        x = nn.Dense(
            features=self.mlp_dim,
            dtype=self.dtype,
            kernel_init=nn.with_logical_partitioning(
                self.kernel_init, ("embed", "mlp")  # "mlp" is the hidden dim
            ),
            bias_init=self.bias_init,
        )(x)
        x = nn.gelu(x)
        x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=self.deterministic)
        x = nn.with_logical_constraint(x, ("batch", "length", "mlp"))
        x = nn.Dense(
            features=embed_dim,
            dtype=self.dtype,
            kernel_init=nn.with_logical_partitioning(
                self.kernel_init, ("mlp", "embed")
            ),
            bias_init=self.bias_init,
        )(x)
        x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=self.deterministic)
        x = nn.with_logical_constraint(x, ("batch", "length", "embed"))
        return x


class Encoder1DBlock(nn.Module):
    """Transformer encoder layer.

    Attributes:
      mlp_dim: dimension of the MLP on top of the attention block.
      num_heads: number of heads in the self-attention block.
      deterministic: bool, whether to disable dropout.
      dtype: the dtype of the computation (default: float32).
      dropout_rate: dropout rate.
    """

    mlp_dim: int
    num_heads: int
    deterministic: bool
    dtype: DTypeLike = jnp.float32
    dropout_rate: float = 0.1

    @nn.compact
    def __call__(self, inputs, attention_mask):
        """Applies Encoder1DBlock module.

        Args:
          inputs: Inputs to the layer, shape (batch, length, embed).
          attention_mask: Attention mask, shape (batch, 1, length, length).

        Returns:
          A (carry, output) pair as required by nn.scan: the block output and a
          dummy per-step output (None).
        """
        # Attention block.
        assert inputs.ndim == 3, f"Expected (batch, length, embed) got {inputs.shape}"
        assert (
            attention_mask.ndim == 4
        ), f"Expected (batch, 1, length, length) got {attention_mask.shape}"
        inputs = nn.with_logical_constraint(inputs, ("batch", "length", "embed"))
        x = nn.LayerNorm(dtype=self.dtype)(inputs)
        x = nn.with_logical_constraint(x, ("batch", "length", "embed"))
        x = AttentionBlock(
            dtype=self.dtype,
            num_heads=self.num_heads,
        )(x, mask=attention_mask)
        x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=self.deterministic)
        x = x + inputs
        x = nn.with_logical_constraint(x, ("batch", "length", "embed"))
        x = checkpoint_name(x, "mid_attention")

        # MLP block.
        y = nn.LayerNorm(dtype=self.dtype)(x)
        y = nn.with_logical_constraint(y, ("batch", "length", "embed"))
        y = MlpBlock(
            mlp_dim=self.mlp_dim,
            deterministic=self.deterministic,
            dtype=self.dtype,
            dropout_rate=self.dropout_rate,
        )(y)

        out = x + y
        out = nn.with_logical_constraint(out, ("batch", "length", "embed"))
        # nn.scan in Transformer treats the second element as the (unused)
        # per-step scan output.
        return out, None


class Transformer(nn.Module):
    """Transformer encoder stack.

    Attributes:
      num_layers: number of layers.
      mlp_dim: dimension of the MLP on top of the attention block.
      num_attention_heads: number of heads in the self-attention blocks.
      dropout_rate: dropout rate.
      dtype: the dtype of the computation (default: float32).
      remat_policy: one of "none", "full", "dots_with_no_batch_dims", "mid_attention".
    """

    num_layers: int
    mlp_dim: int
    num_attention_heads: int
    dropout_rate: float = 0.1
    dtype: Union[DTypeLike, str] = jnp.float32
    remat_policy: str = "none"

    @nn.compact
    def __call__(self, x, attention_mask, *, train):
        """Applies the Transformer encoder to the inputs.

        Args:
          x: Inputs to the layer, shape (batch, length, embed).
          attention_mask: Attention mask, shape (batch, 1, length, length).
          train: Set to `True` when training (enables dropout).

        Returns:
          Output of the transformer encoder, shape (batch, length, embed).
        """
        assert x.ndim == 3  # (batch, length, embed)
        assert attention_mask.ndim == 4  # (batch, 1, length, length)
        seq_len = x.shape[1]
        assert attention_mask.shape[2] == seq_len
        assert attention_mask.shape[3] == seq_len
        dtype = jax.dtypes.canonicalize_dtype(self.dtype)

        # Encoder stack: rematerialized blocks scanned over a "layers" axis.
        remat_block = nn.remat(
            Encoder1DBlock,
            policy=get_remat_policy(self.remat_policy),
            prevent_cse=True,
        )
        x, _ = nn.scan(
            remat_block,
            length=self.num_layers,
            variable_axes={"params": 0},
            in_axes=(nn.broadcast,),
            split_rngs={"params": True, "dropout": True},
            metadata_params={nn.PARTITION_NAME: "layers"},
        )(
            dtype=dtype,
            mlp_dim=self.mlp_dim,
            dropout_rate=self.dropout_rate,
            name="encoder_stack",
            num_heads=self.num_attention_heads,
            deterministic=not train,
        )(
            x,
            attention_mask,
        )
        encoded = nn.LayerNorm(name="encoder_norm", dtype=self.dtype)(x)
        return encoded


def get_remat_policy(remat_policy: str) -> Callable:
    if remat_policy == "none":
        return jax.checkpoint_policies.everything_saveable
    elif remat_policy == "full":
        return jax.checkpoint_policies.nothing_saveable
    elif remat_policy == "dots_with_no_batch_dims":
        return jax.checkpoint_policies.dots_with_no_batch_dims_saveable
    elif remat_policy == "mid_attention":
        return jax.checkpoint_policies.save_only_these_names("mid_attention")
    raise ValueError(f"Unknown remat_policy: {remat_policy}")
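

# Illustrative sketch, not used by the model above: how the "mid_attention"
# policy interacts with `checkpoint_name`. Under
# `save_only_these_names("mid_attention")`, only the tagged intermediate may be
# kept for the backward pass; everything else is recomputed. The toy function
# below is hypothetical.
def example_remat_policy():
    def block(x):
        y = jnp.sin(x) * 2.0
        y = checkpoint_name(y, "mid_attention")  # eligible to be saved
        return jnp.sum(jnp.tanh(y))  # recomputed during the backward pass

    f = jax.checkpoint(block, policy=get_remat_policy("mid_attention"))
    return jax.grad(f)(jnp.ones((4, 4)))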


def common_transformer_sizes(transformer_size: str) -> Tuple[int, dict]:
    """
    Args:
        transformer_size (str): The size of the transformer. One of "dummy",
            "vanilla", "vit_t", "vit_s", "vit_b", "vit_l", "vit_h".

    Returns:
        token_embedding_size (int): The size of the token embeddings.
        transformer_kwargs (dict): The kwargs to pass to the transformer.
    """
    TRANSFORMER_SIZES = {
        "dummy": dict(
            num_layers=1,
            mlp_dim=256,
            num_attention_heads=2,
            dropout_rate=0.1,
        ),
        "vanilla": dict(
            num_layers=4,
            mlp_dim=1024,
            num_attention_heads=8,
            dropout_rate=0.1,
        ),
        "vit_t": dict(
            num_layers=12,
            mlp_dim=768,
            num_attention_heads=3,
            dropout_rate=0.0,
        ),
        "vit_s": dict(
            num_layers=12,
            mlp_dim=1536,
            num_attention_heads=6,
            dropout_rate=0.0,
        ),
        "vit_b": dict(
            num_layers=12,
            mlp_dim=3072,
            num_attention_heads=12,
            dropout_rate=0.0,
        ),
        "vit_l": dict(
            num_layers=24,
            mlp_dim=4096,
            num_attention_heads=16,
            dropout_rate=0.1,
        ),
        "vit_h": dict(
            num_layers=32,
            mlp_dim=5120,
            num_attention_heads=16,
            dropout_rate=0.1,
        ),
    }
    TOKEN_DIMS = {
        "dummy": 256,
        "vanilla": 256,
        "vit_t": 192,
        "vit_s": 384,
        "vit_b": 768,
        "vit_l": 1024,
        "vit_h": 1280,
    }
    return TOKEN_DIMS[transformer_size], {**TRANSFORMER_SIZES[transformer_size]}
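

# Illustrative usage sketch, mirroring the training script in this gist: build
# a Transformer from `common_transformer_sizes` and run a forward pass. The
# "dummy" size and the shapes below are arbitrary. Note that because the
# encoder blocks are combined with nn.scan, every parameter leaf carries a
# leading "layers" axis of length num_layers.
def example_usage():
    token_dim, kwargs = common_transformer_sizes("dummy")
    model = Transformer(**kwargs, remat_policy="none")
    tokens = jnp.zeros((2, 16, token_dim))  # (batch, length, embed)
    mask = jnp.broadcast_to(jnp.tri(16), (2, 1, 16, 16))  # (batch, 1, length, length)
    params = model.init(jax.random.PRNGKey(0), tokens, mask, train=False)["params"]
    return model.apply({"params": params}, tokens, mask, train=False)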