from flax import linen as nn
from flax.training.train_state import TrainState
import jax
import jax.numpy as jnp
import optax
def create_train_state_and_batch(batch_size, hidden_size, use_remat):

    # A stack of Dense layers interleaved with Dropout; Dropout requires an RNG
    # to be threaded through apply(), which is what makes remat interesting here.
    class Layer(nn.Module):

        @nn.compact
        def __call__(self, x, deterministic=True):
            x = nn.Dense(hidden_size, use_bias=False)(x)
            x = nn.Dropout(rate=0.1, deterministic=deterministic)(x)
            x = nn.Dense(hidden_size, use_bias=False)(x)
            x = nn.Dropout(rate=0.1, deterministic=deterministic)(x)
            x = nn.Dense(hidden_size, use_bias=False)(x)
            x = nn.Dropout(rate=0.1, deterministic=deterministic)(x)
            return x

    class Model(nn.Module):

        @nn.compact
        def __call__(self, x, deterministic=True):
            if use_remat:
                # concrete=True lets the rematerialized function branch on the
                # value of `deterministic` inside Dropout.
                layer = nn.remat(Layer, concrete=True)()
            else:
                layer = Layer()
            x = layer(x, deterministic)
            x = nn.Dense(hidden_size, use_bias=False)(x)
            return x
    # Synthetic input/target batch; split the key so x and y differ.
    rngkey = jax.random.PRNGKey(0)
    x_key, y_key = jax.random.split(rngkey)
    batch = {
        "x": jax.random.normal(x_key, (batch_size, hidden_size),
                               dtype=jnp.float32),
        "y": jax.random.normal(y_key, (batch_size, hidden_size),
                               dtype=jnp.float32),
    }
    # Init model and optimizer
    model = Model()
    rngkey = jax.random.PRNGKey(0)
    params = model.init(rngkey, batch["x"])
    tx = optax.sgd(learning_rate=1e-3)
    state = TrainState.create(apply_fn=model.apply, params=params, tx=tx)
    return state, batch
def test_remat(use_remat):

    def train_step(state, batch, rng_key):

        def loss_func(params):
            # Mean-squared-error loss with dropout enabled.
            out = state.apply_fn(params, batch["x"],
                                 deterministic=False,
                                 rngs={"dropout": rng_key})
            return jnp.mean((out - batch["y"])**2)

        grads = jax.grad(loss_func)(state.params)
        new_state = state.apply_gradients(grads=grads)
        return new_state

    state, batch = create_train_state_and_batch(64, 256, use_remat)
    rng_key = jax.random.PRNGKey(0)
    # Trace one training step into a jaxpr. train_step takes only three
    # positional arguments, so the original static_argnums=(3,) was out of
    # range and is dropped here.
    jaxpr = jax.make_jaxpr(train_step)(state, batch, rng_key)
    return jaxpr
if __name__ == "__main__":
    # import alpa
    print("Use remat")
    jaxpr = test_remat(True)
    print(jaxpr)

    print("No remat")
    jaxpr = test_remat(False)
    print(jaxpr)
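
    # A minimal added sketch: instead of eyeballing the two full jaxpr dumps,
    # summarize the primitives in each trace. The remat primitive's name varies
    # across JAX versions (e.g. remat_call, remat2, or checkpoint), so this
    # just scans primitive names; it assumes make_jaxpr returns a ClosedJaxpr,
    # whose equations live at .jaxpr.eqns.
    from collections import Counter
    for use_remat in (True, False):
        closed_jaxpr = test_remat(use_remat)
        counts = Counter(eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns)
        has_remat = any("remat" in n or "checkpoint" in n for n in counts)
        print(f"use_remat={use_remat}: remat-like primitive present: {has_remat}")
        print(sorted(counts.items()))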