Gist by @merrymercy, last active June 19, 2022 00:22
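"""Compare the jaxpr of one SGD training step for a small Flax MLP with
dropout, with and without flax.linen.remat (rematerialization /
gradient checkpointing)."""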
from flax import linen as nn
from flax.training.train_state import TrainState
import jax
import jax.numpy as jnp
import numpy as np
import optax


def create_train_state_and_batch(batch_size, hidden_size, use_remat):

    class Layer(nn.Module):
        @nn.compact
        def __call__(self, x, deterministic=True):
            x = nn.Dense(hidden_size, use_bias=False)(x)
            x = nn.Dropout(rate=0.1, deterministic=deterministic)(x)
            x = nn.Dense(hidden_size, use_bias=False)(x)
            x = nn.Dropout(rate=0.1, deterministic=deterministic)(x)
            x = nn.Dense(hidden_size, use_bias=False)(x)
            x = nn.Dropout(rate=0.1, deterministic=deterministic)(x)
            return x

    class Model(nn.Module):
        @nn.compact
        def __call__(self, x, deterministic=True):
            if use_remat:
                # concrete=True lets the Python bool `deterministic` be used in
                # control flow (e.g. inside nn.Dropout) under rematerialization.
                layer = nn.remat(Layer, concrete=True)()
            else:
                layer = Layer()
            x = layer(x, deterministic)
            x = nn.Dense(hidden_size, use_bias=False)(x)
            return x

    rngkey = jax.random.PRNGKey(0)
    batch = {
        "x": jax.random.normal(rngkey, (batch_size, hidden_size),
                               dtype=jnp.float32),
        "y": jax.random.normal(rngkey, (batch_size, hidden_size),
                               dtype=jnp.float32),
    }

    # Init model and optimizer
    model = Model()
    rngkey = jax.random.PRNGKey(0)
    params = model.init(rngkey, batch["x"])
    tx = optax.sgd(learning_rate=1e-3)
    state = TrainState.create(apply_fn=model.apply, params=params, tx=tx)
    return state, batch

def test_remat(use_remat):

    def train_step(state, batch, rng_key):
        def loss_func(params):
            out = state.apply_fn(params, batch['x'],
                                 deterministic=False,
                                 rngs={"dropout": rng_key})
            return jnp.mean((out - batch['y'])**2)

        grads = jax.grad(loss_func)(state.params)
        new_state = state.apply_gradients(grads=grads)
        return new_state

    state, batch = create_train_state_and_batch(64, 256, use_remat)
    rng_key = jax.random.PRNGKey(0)
    # train_step takes only three traced arguments, so no static_argnums
    # is passed to make_jaxpr.
    jaxpr = jax.make_jaxpr(train_step)(state, batch, rng_key)
    return jaxpr

if __name__ == "__main__":
    # import alpa
    print("Use remat")
    jaxpr = test_remat(True)
    print(jaxpr)

    print("No remat")
    jaxpr = test_remat(False)
    print(jaxpr)
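
    # Optional sanity check: a minimal sketch comparing the two jaxprs.
    # With remat enabled, the jaxpr is expected to contain a remat/checkpoint
    # primitive wrapping the Layer computation; the exact primitive name
    # (e.g. "remat2", "checkpoint", or "remat_call") depends on the JAX version.
    jaxpr_remat = test_remat(True)
    jaxpr_no_remat = test_remat(False)
    print("Num eqns with remat:   ", len(jaxpr_remat.jaxpr.eqns))
    print("Num eqns without remat:", len(jaxpr_no_remat.jaxpr.eqns))
    found = any("remat" in str(eqn.primitive) or "checkpoint" in str(eqn.primitive)
                for eqn in jaxpr_remat.jaxpr.eqns)
    print("Found remat/checkpoint primitive:", found)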