This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import jax

# Hyper-parameters for this run; the PRNG key is derived from `seed` so the
# initialisation is reproducible.
seed = 0
learning_rate = 1e-5
rng = jax.random.PRNGKey(seed)
rng, init_rng = jax.random.split(rng)
# create_train_state takes (rng, learning_rate) and builds an Adam optimiser,
# so no momentum argument is passed.
state = create_train_state(init_rng, learning_rate)
del init_rng  # Must not be used anymore.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def evaluate_model(state, batch):
  """Evaluate `state` on one validation batch.

  Args:
    state: Train state whose parameters are evaluated.
    batch: An ``(images, labels)`` pair.

  Returns:
    The metrics pytree produced by ``eval_step``, copied to the host with
    every leaf converted to a Python scalar via ``.item()`` (this assumes
    each metric is a single-element array -- TODO confirm).
  """
  test_imgs, test_lbls = batch
  # eval_step takes the batch as one (images, labels) argument, not two.
  metrics = eval_step(state, (test_imgs, test_lbls))
  # Pull device arrays into host memory before converting to Python scalars.
  metrics = jax.device_get(metrics)
  # jax.tree_map is deprecated (and removed in newer JAX); use jax.tree_util.
  metrics = jax.tree_util.tree_map(lambda x: x.item(), metrics)
  return metrics
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@jax.jit
def eval_step(state, batch):
  """Compute evaluation metrics for a single batch (jit-compiled).

  Args:
    state: Train state; only its ``params`` are used.
    batch: An ``(images, labels)`` pair.

  Returns:
    Whatever ``compute_metrics`` returns for this batch's logits and labels.
  """
  inputs, targets = batch
  variables = {'params': state.params}
  predictions = CNN().apply(variables, inputs)
  return compute_metrics(logits=predictions, labels=targets)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def compute_loss(params, images, labels):
  """Return ``(loss, logits)`` for one batch evaluated with `params`.

  The logits are returned alongside the loss, presumably so a caller can use
  ``jax.value_and_grad(..., has_aux=True)`` -- TODO confirm against train_step.
  """
  model_output = CNN().apply({'params': params}, images)
  batch_loss = cross_entropy_loss(logits=model_output, labels=labels)
  return batch_loss, model_output
@jax.jit | |
def train_step(state, batch): | |
"""Train for a single step.""" | |
images, labels = batch |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def create_train_state(rng, learning_rate, *, input_shape=None):
  """Creates the initial `TrainState` for `CNN`, optimised with Adam.

  Args:
    rng: PRNG key used by ``CNN.init`` to initialise the parameters.
    learning_rate: Step size passed to ``optax.adam``.
    input_shape: Keyword-only. Shape of the dummy batch used to initialise
      the model. Defaults to ``(1, size_image, size_image, 3)``, reading the
      module-level ``size_image`` defined elsewhere in the file.

  Returns:
    A ``train_state.TrainState`` holding the initial params, the Adam
    optimiser, and ``cnn.apply`` as ``apply_fn``.
  """
  if input_shape is None:
    # NOTE(review): assumes channels-last RGB images of side `size_image`.
    input_shape = (1, size_image, size_image, 3)
  cnn = CNN()
  # The all-ones batch is a dummy input; init uses it to infer param shapes.
  params = cnn.init(rng, jnp.ones(input_shape))['params']
  tx = optax.adam(learning_rate)
  return train_state.TrainState.create(
      apply_fn=cnn.apply, params=params, tx=tx)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import flax | |
from flax import linen as nn | |
class CNN(nn.Module): | |
@nn.compact | |
def __call__(self, x): | |
x = nn.Conv(features=32, kernel_size=(3, 3))(x) | |
x = nn.relu(x) | |
x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# NOTE(review): 0.1 / 0.9 look like SGD-with-momentum settings, but
# create_train_state builds an Adam optimiser from learning_rate alone;
# `momentum` is kept only so existing references to the name still resolve.
learning_rate = 0.1
momentum = 0.9
seed = 0
num_epochs = 30
rng = jax.random.PRNGKey(seed)
rng, init_rng = jax.random.split(rng)
# Pass a single key. create_train_state is not pmapped here, so a stacked
# per-device key array would reach `cnn.init` unchanged, and cnn.init
# expects one key. create_train_state also takes no momentum argument.
state = create_train_state(init_rng, learning_rate)
del init_rng  # Must not be used anymore.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
labels = jnp.array([0.20, 0.30, 0.10, 0.20, 0.2])
# smooth_labels returns (1 - alpha) * labels + alpha / num_categories; with
# 5 categories and alpha = 0.4 each entry becomes 0.6 * label + 0.08.
optax.smooth_labels(labels, alpha=0.4)
# Array([0.2 , 0.26, 0.14, 0.2 , 0.2 ], dtype=float32)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
predictions = jnp.array([0.50, 0.60, 0.70, 0.30, 0.25])
targets = jnp.array([0.20, 0.30, 0.10, 0.20, 0.2])
# Element-wise log(cosh(predictions - targets)); no reduction is applied,
# so the result has one loss value per element.
optax.log_cosh(predictions, targets)
# Array([0.04434085, 0.04434085, 0.17013526, 0.00499171, 0.00124949], dtype=float32)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
predictions = jnp.array([0.50, 0.60, 0.70, 0.30, 0.25])
targets = jnp.array([0.20, 0.30, 0.10, 0.20, 0.2])
# Element-wise 0.5 * (predictions - targets) ** 2, i.e. half the squared
# error (e.g. 0.5 * 0.3 ** 2 = 0.045); no reduction is applied.
optax.l2_loss(predictions, targets)
# Array([0.045     , 0.045     , 0.17999998, 0.005     , 0.00125   ], dtype=float32)
Newer / Older