# ---- training script ----

from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
import tqdm
from jax.experimental import mesh_utils
from jax.experimental.compilation_cache import compilation_cache
from jax.experimental.shard_map import shard_map
from jax.sharding import NamedSharding, PartitionSpec

from transformer import Transformer, common_transformer_sizes

def main():
    compilation_cache.initialize_cache("/tmp/jax_compilation_cache")

    # One-dimensional device mesh with a single data-parallel ("dp") axis.
    mesh = jax.sharding.Mesh(
        mesh_utils.create_device_mesh([jax.device_count()]), ["dp"]
    )
    rep_spec = PartitionSpec()  # fully replicated
    rep_sharding = NamedSharding(mesh, rep_spec)
    dp_spec = PartitionSpec("dp")  # sharded along the leading (batch) axis
    dp_sharding = NamedSharding(mesh, dp_spec)

    token_dim, transformer_kwargs = common_transformer_sizes("vit_l")
    model = Transformer(**transformer_kwargs, remat_policy="full")
    rng = jax.random.PRNGKey(0)

    def data_iterator():
        # Synthetic batch: fixed causal mask, token values that drift per step.
        i = 0
        attention_mask = np.broadcast_to(np.tri(768), (128, 1, 768, 768))
        while True:
            tokens = np.full((128, 768, token_dim), i / 100)
            yield tokens, attention_mask
            i += 1

    data_iter = data_iterator()

    @jax.jit
    def init_fn(*args):
        params = model.init(rng, *args, train=False)["params"]
        return params

    params = init_fn(*next(data_iter))

    def loss_fn(params, tokens, attention_mask):
        embeddings = model.apply(
            {"params": params},
            tokens,
            attention_mask,
            train=True,
            rngs={"dropout": rng},
        )
        # fake loss that is just the mean of the last embedding
        return jnp.mean(embeddings[:, -1])

    @partial(
        jax.jit,
        in_shardings=(rep_sharding, dp_sharding, dp_sharding),
        out_shardings=rep_sharding,
        donate_argnums=0,  # donate params so the update reuses their buffers
    )
    # Alternative decorator: shard_map instead of jit-with-shardings
    # (see the sketch after this file).
    # @partial(
    #     shard_map,
    #     mesh=mesh,
    #     in_specs=(rep_spec, dp_spec, dp_spec),
    #     out_specs=rep_spec,
    #     check_rep=False,
    # )
    def train_step(params, tokens, attention_mask):
        grads = jax.grad(loss_fn)(params, tokens, attention_mask)
        # grads = jax.lax.pmean(grads, axis_name="dp")  # needed under shard_map
        new_params = jax.tree_util.tree_map(lambda p, g: p - 1e-4 * g, params, grads)
        return new_params

    for _ in tqdm.tqdm(range(10000), dynamic_ncols=True):
        tokens, attention_mask = next(data_iter)
        params = train_step(params, tokens, attention_mask)
        jax.block_until_ready(params)  # so tqdm reports real per-step time


if __name__ == "__main__":
    main()
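
# A sketch of the shard_map variant hinted at by the commented-out decorator
# above. It assumes the same mesh/specs defined in main() and would replace
# the jitted train_step there; it is left commented because it cannot run
# standalone. Under shard_map each device sees only its local shard, so the
# gradient average over the "dp" axis must be spelled out with jax.lax.pmean,
# whereas jit with in_shardings/out_shardings inserts the collective for you.
#
# @partial(
#     shard_map,
#     mesh=mesh,
#     in_specs=(rep_spec, dp_spec, dp_spec),
#     out_specs=rep_spec,
#     check_rep=False,
# )
# def train_step(params, tokens, attention_mask):
#     grads = jax.grad(loss_fn)(params, tokens, attention_mask)
#     grads = jax.lax.pmean(grads, axis_name="dp")  # all-reduce across devices
#     return jax.tree_util.tree_map(lambda p, g: p - 1e-4 * g, params, grads)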

# ---- transformer.py ----

from functools import partial
from typing import Any, Callable, Tuple, Union

import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.linen.attention import dot_product_attention_weights
from jax.ad_checkpoint import checkpoint_name
from jax.typing import DTypeLike


class AttentionBlock(nn.Module):
    """Transformer self-attention block."""

    num_heads: int
    dtype: DTypeLike = jnp.float32
    kernel_init: Callable[
        [Any, tuple, DTypeLike], jax.Array
    ] = nn.initializers.xavier_uniform()

    @nn.compact
    def __call__(self, x, mask):
        assert x.ndim == 3, f"Expected (batch, length, embed), got {x.shape}"
        embed_dim = x.shape[-1]
        assert embed_dim % self.num_heads == 0, (
            f"Embedding dimension ({embed_dim}) must be divisible by the"
            f" number of heads ({self.num_heads})."
        )
        head_dim = embed_dim // self.num_heads

        # project to query, key, value: [batch, length, n_heads, head_dim]
        dense = partial(
            nn.DenseGeneral,
            axis=-1,
            dtype=self.dtype,
            features=(self.num_heads, head_dim),
            kernel_init=nn.with_logical_partitioning(
                # workaround for Flax bug https://github.com/google/flax/issues/3676
                lambda *args: self.kernel_init(*args).reshape(
                    embed_dim, self.num_heads, head_dim
                ),
                ("embed", "heads", "head_dim"),
            ),
        )
        constraint = partial(
            nn.with_logical_constraint,
            logical_axis_resources=("batch", "length", "heads", "head_dim"),
        )
        query, key, value = (
            constraint(dense(name="query")(x)),
            constraint(dense(name="key")(x)),
            constraint(dense(name="value")(x)),
        )
        # apply Flax dot_product_attention_weights (no attention dropout);
        # shape [batch, n_heads, length, length]
        attn_weights = dot_product_attention_weights(
            query, key, mask=mask, dtype=self.dtype
        )
        attn_weights = nn.with_logical_constraint(
            attn_weights, ("batch", "heads", None, None)
        )

        x = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value)
        x = nn.with_logical_constraint(x, ("batch", "length", "heads", "head_dim"))

        # merge heads and project to output: [batch, length, embed]
        x = nn.DenseGeneral(
            features=embed_dim,
            axis=(-2, -1),
            kernel_init=nn.with_logical_partitioning(
                # workaround for Flax bug https://github.com/google/flax/issues/3676
                lambda *args: self.kernel_init(*args).reshape(
                    self.num_heads, head_dim, embed_dim
                ),
                ("heads", "head_dim", "embed"),
            ),
            dtype=self.dtype,
            name="out",
        )(x)
        x = nn.with_logical_constraint(x, ("batch", "length", "embed"))
        return x

class MlpBlock(nn.Module):
    """Transformer MLP / feed-forward block."""

    mlp_dim: int
    deterministic: bool
    dtype: DTypeLike = jnp.float32
    dropout_rate: float = 0.1
    kernel_init: Callable[
        [Any, tuple, DTypeLike], jax.Array
    ] = nn.initializers.xavier_uniform()
    bias_init: Callable[[Any, tuple, DTypeLike], jax.Array] = nn.initializers.normal(
        stddev=1e-6
    )

    @nn.compact
    def __call__(self, x):
        """Applies the Transformer MlpBlock module."""
        assert x.ndim == 3, f"Expected (batch, length, embed), got {x.shape}"
        embed_dim = x.shape[-1]
        x = nn.Dense(
            features=self.mlp_dim,
            dtype=self.dtype,
            kernel_init=nn.with_logical_partitioning(
                self.kernel_init, ("embed", "mlp")  # "mlp" is the hidden dim
            ),
            bias_init=self.bias_init,
        )(x)
        x = nn.gelu(x)
        x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=self.deterministic)
        x = nn.with_logical_constraint(x, ("batch", "length", "mlp"))
        x = nn.Dense(
            features=embed_dim,
            dtype=self.dtype,
            kernel_init=nn.with_logical_partitioning(
                self.kernel_init, ("mlp", "embed")
            ),
            bias_init=self.bias_init,
        )(x)
        x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=self.deterministic)
        x = nn.with_logical_constraint(x, ("batch", "length", "embed"))
        return x

class Encoder1DBlock(nn.Module):
    """Transformer encoder layer.

    Attributes:
      mlp_dim: dimension of the MLP on top of the attention block.
      num_heads: number of attention heads.
      deterministic: if True, dropout is not applied.
      dtype: the dtype of the computation (default: float32).
      dropout_rate: dropout rate.
    """

    mlp_dim: int
    num_heads: int
    deterministic: bool
    dtype: DTypeLike = jnp.float32
    dropout_rate: float = 0.1

    @nn.compact
    def __call__(self, inputs, attention_mask):
        """Applies the Encoder1DBlock module.

        Args:
          inputs: inputs to the layer, shape (batch, length, embed).
          attention_mask: attention mask, shape (batch, 1, length, length).

        Returns:
          Output after the transformer encoder block.
        """
        # Attention block.
        assert inputs.ndim == 3, f"Expected (batch, length, embed), got {inputs.shape}"
        assert (
            attention_mask.ndim == 4
        ), f"Expected (batch, 1, length, length), got {attention_mask.shape}"
        inputs = nn.with_logical_constraint(inputs, ("batch", "length", "embed"))
        x = nn.LayerNorm(dtype=self.dtype)(inputs)
        x = nn.with_logical_constraint(x, ("batch", "length", "embed"))
        x = AttentionBlock(
            dtype=self.dtype,
            num_heads=self.num_heads,
        )(x, mask=attention_mask)
        x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=self.deterministic)
        x = x + inputs
        x = nn.with_logical_constraint(x, ("batch", "length", "embed"))
        # Tag the post-attention residual so the "mid_attention" remat policy
        # can save exactly this activation.
        x = checkpoint_name(x, "mid_attention")

        # MLP block.
        y = nn.LayerNorm(dtype=self.dtype)(x)
        y = nn.with_logical_constraint(y, ("batch", "length", "embed"))
        y = MlpBlock(
            mlp_dim=self.mlp_dim,
            deterministic=self.deterministic,
            dtype=self.dtype,
            dropout_rate=self.dropout_rate,
        )(y)
        out = x + y
        out = nn.with_logical_constraint(out, ("batch", "length", "embed"))
        # nn.scan treats the first element as the carry; the second
        # (per-layer scan output) is unused here.
        return out, None

class Transformer(nn.Module):
    """Transformer encoder stack.

    Attributes:
      num_layers: number of layers.
      mlp_dim: dimension of the MLP on top of the attention block.
      num_attention_heads: number of attention heads.
      dropout_rate: dropout rate.
      dtype: the dtype of the computation (default: float32).
      remat_policy: one of "none", "full", "dots_with_no_batch_dims",
        "mid_attention".
    """

    num_layers: int
    mlp_dim: int
    num_attention_heads: int
    dropout_rate: float = 0.1
    dtype: Union[DTypeLike, str] = jnp.float32
    remat_policy: str = "none"

    @nn.compact
    def __call__(self, x, attention_mask, *, train):
        """Applies the Transformer encoder to the inputs.

        Args:
          x: inputs, shape (batch, length, embed).
          attention_mask: attention mask, shape (batch, 1, length, length).
          train: set to True when training (enables dropout).

        Returns:
          Output of the transformer encoder.
        """
        assert x.ndim == 3  # (batch, length, embed)
        assert attention_mask.ndim == 4  # (batch, 1, length, length)
        seq_len = x.shape[1]
        assert attention_mask.shape[2] == seq_len
        assert attention_mask.shape[3] == seq_len
        dtype = jax.dtypes.canonicalize_dtype(self.dtype)

        # Encoder stack: scan a single (optionally rematerialized) block over
        # the layer axis; parameters are stacked along a new "layers" axis.
        remat_block = nn.remat(
            Encoder1DBlock,
            policy=get_remat_policy(self.remat_policy),
            prevent_cse=True,
        )
        x, _ = nn.scan(
            remat_block,
            length=self.num_layers,
            variable_axes={"params": 0},
            in_axes=(nn.broadcast,),  # the attention mask is shared across layers
            split_rngs={"params": True, "dropout": True},
            metadata_params={nn.PARTITION_NAME: "layers"},
        )(
            dtype=dtype,
            mlp_dim=self.mlp_dim,
            dropout_rate=self.dropout_rate,
            name="encoder_stack",
            num_heads=self.num_attention_heads,
            deterministic=not train,
        )(x, attention_mask)
        encoded = nn.LayerNorm(name="encoder_norm", dtype=dtype)(x)
        return encoded

def get_remat_policy(remat_policy: str) -> Callable:
    if remat_policy == "none":
        return jax.checkpoint_policies.everything_saveable
    elif remat_policy == "full":
        return jax.checkpoint_policies.nothing_saveable
    elif remat_policy == "dots_with_no_batch_dims":
        return jax.checkpoint_policies.dots_with_no_batch_dims_saveable
    elif remat_policy == "mid_attention":
        return jax.checkpoint_policies.save_only_these_names("mid_attention")
    raise ValueError(f"Unknown remat_policy: {remat_policy}")
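
# Rough effect of each policy on the scanned encoder blocks: "none" saves all
# intermediates (no recompute in the backward pass), "full" saves none and
# recomputes each block, "dots_with_no_batch_dims" saves only matmul outputs
# that carry no batch dimension, and "mid_attention" saves only the
# activations tagged checkpoint_name(x, "mid_attention") in Encoder1DBlock,
# i.e. one (batch, length, embed) residual per layer.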

def common_transformer_sizes(transformer_size: str) -> Tuple[int, dict]:
    """
    Args:
        transformer_size (str): The size of the transformer. One of "dummy",
            "vanilla", "vit_t", "vit_s", "vit_b", "vit_l", "vit_h".

    Returns:
        token_embedding_size (int): The size of the token embeddings.
        transformer_kwargs (dict): The kwargs to pass to the transformer.
    """
    TRANSFORMER_SIZES = {
        "dummy": dict(
            num_layers=1,
            mlp_dim=256,
            num_attention_heads=2,
            dropout_rate=0.1,
        ),
        "vanilla": dict(
            num_layers=4,
            mlp_dim=1024,
            num_attention_heads=8,
            dropout_rate=0.1,
        ),
        "vit_t": dict(
            num_layers=12,
            mlp_dim=768,
            num_attention_heads=3,
            dropout_rate=0.0,
        ),
        "vit_s": dict(
            num_layers=12,
            mlp_dim=1536,
            num_attention_heads=6,
            dropout_rate=0.0,
        ),
        "vit_b": dict(
            num_layers=12,
            mlp_dim=3072,
            num_attention_heads=12,
            dropout_rate=0.0,
        ),
        "vit_l": dict(
            num_layers=24,
            mlp_dim=4096,
            num_attention_heads=16,
            dropout_rate=0.1,
        ),
        "vit_h": dict(
            num_layers=32,
            mlp_dim=5120,
            num_attention_heads=16,
            dropout_rate=0.1,
        ),
    }
    TOKEN_DIMS = {
        "dummy": 256,
        "vanilla": 256,
        "vit_t": 192,
        "vit_s": 384,
        "vit_b": 768,
        "vit_l": 1024,
        "vit_h": 1280,
    }
    return TOKEN_DIMS[transformer_size], {**TRANSFORMER_SIZES[transformer_size]}
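
# A minimal single-host smoke test for this module. The "dummy" size and the
# tiny shapes below are arbitrary illustration choices, not values taken from
# the training script.
if __name__ == "__main__":
    import numpy as np

    token_dim, kwargs = common_transformer_sizes("dummy")
    model = Transformer(**kwargs, remat_policy="none")
    tokens = np.zeros((2, 16, token_dim), dtype=np.float32)
    mask = np.broadcast_to(np.tri(16), (2, 1, 16, 16))  # causal mask
    variables = model.init(jax.random.PRNGKey(0), tokens, mask, train=False)
    out = model.apply(variables, tokens, mask, train=False)
    print(out.shape)  # (2, 16, 256): embedding width is preserved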