@aniquetahir
Created August 21, 2022 12:42
Simple Boilerplate for a Jax MLP with Dropout

import jax
import jax.numpy as jnp
import haiku as hk
import optax
import chex
from tqdm import tqdm

# Hyperparameters
NUM_LAYERS = 3
HIDDEN_DIM = 128
DROPOUT = 0.25
NUM_EPOCHS = 1000
key = jax.random.PRNGKey(6)
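# NOTE: the original gist assumes `train_data` and `train_labels` already exist.
# The block below is only a hypothetical setup (two Gaussian blobs for binary
# classification) so the boilerplate runs end to end; swap in your own data.
data_key = jax.random.PRNGKey(0)
k0, k1 = jax.random.split(data_key)
class_0 = jax.random.normal(k0, (500, 2)) + jnp.array([2.0, 2.0])
class_1 = jax.random.normal(k1, (500, 2)) - jnp.array([2.0, 2.0])
train_data = jnp.concatenate([class_0, class_1], axis=0)
train_labels = jnp.concatenate([jnp.zeros(500, dtype=jnp.int32), jnp.ones(500, dtype=jnp.int32)])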
# Cross-entropy loss. optax.softmax_cross_entropy_with_integer_labels expects
# unnormalized logits, so the MLP below returns raw logits rather than softmax output.
def loss_fn(params, key, x, y, is_training=True):
    logits = mlp_apply(params, key, x, is_training)
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, y)
    return loss.mean()
loss_value_grad = jax.jit(jax.value_and_grad(loss_fn))
# Define a simple MLP
class MLP(hk.Module):
    def __init__(self, num_layers: int, hidden_dim: int, dropout: float):
        super(MLP, self).__init__('mlp')
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.dropout = dropout

    def __call__(self, x: chex.Array, is_training: bool = False):
        x_ = x
        for i in range(self.num_layers - 1):
            x_ = hk.Linear(self.hidden_dim)(x_)
            # x_ = hk.BatchNorm(create_scale=True, create_offset=True, decay_rate=.999)(x_, is_training)
            x_ = jax.nn.leaky_relu(x_)
            # Apply dropout only during training; hk.cond keeps RNG handling inside the transform.
            x_ = hk.cond(is_training, lambda a: hk.dropout(hk.next_rng_key(), self.dropout, a), lambda a: a, x_)
        x_ = hk.Linear(2)(x_)
        # Return raw logits (the loss applies the softmax internally);
        # use jax.nn.softmax(logits) at inference time if probabilities are needed.
        return x_
# hk.transform turns the model into a pair of pure (init, apply) functions.
mlp_init, mlp_apply = hk.transform(lambda x, t: MLP(NUM_LAYERS, HIDDEN_DIM, DROPOUT)(x, t))
key, split = jax.random.split(key)
# Initialise parameters from a small sample batch.
mlp_params = mlp_init(split, jnp.array(train_data[:10]), True)
optim = optax.adamw(0.005)
opt_state = optim.init(mlp_params)
# Training loop
for i in tqdm(range(NUM_EPOCHS)):
    key, split = jax.random.split(key)
    loss, grad = loss_value_grad(mlp_params, split, train_data, train_labels)
    updates, opt_state = optim.update(grad, opt_state, mlp_params)
    mlp_params = optax.apply_updates(mlp_params, updates)
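# A minimal inference sketch (not part of the original gist): run the apply
# function with is_training=False so dropout is skipped, turn logits into
# probabilities, and report accuracy on the (hypothetical) training data.
key, split = jax.random.split(key)
logits = mlp_apply(mlp_params, split, train_data, False)
probs = jax.nn.softmax(logits)
preds = jnp.argmax(probs, axis=-1)
accuracy = jnp.mean(preds == train_labels)
print(f'train accuracy: {float(accuracy):.3f}')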