Skip to content

Instantly share code, notes, and snippets.

import jax
# Run configuration. `seed` drives the PRNG so the run is reproducible from
# a single knob instead of a literal buried in the PRNGKey call.
seed = 0
learning_rate = 1e-5

# Split the root key: `rng` is kept for later use, `init_rng` is consumed
# once for parameter initialisation.
rng = jax.random.PRNGKey(seed)
rng, init_rng = jax.random.split(rng)

# create_train_state takes exactly (rng, learning_rate).
state = create_train_state(init_rng, learning_rate)
del init_rng  # Must not be used anymore.
def evaluate_model(state, batch):
    """Evaluate on the validation set.

    Args:
        state: Train state forwarded unchanged to `eval_step`.
        batch: An `(images, labels)` pair, forwarded unchanged to
            `eval_step`, which unpacks it itself.

    Returns:
        The metrics pytree produced by `eval_step`, copied to host memory,
        with every leaf converted to a Python scalar via `.item()`.
    """
    # eval_step is declared as eval_step(state, batch); pass the batch whole
    # rather than as separate image/label arguments.
    metrics = eval_step(state, batch)
    metrics = jax.device_get(metrics)
    # jax.tree_util.tree_map is the stable spelling; the top-level
    # jax.tree_map alias is deprecated and absent from newer JAX releases.
    return jax.tree_util.tree_map(lambda leaf: leaf.item(), metrics)
@jax.jit
def eval_step(state, batch):
    """Run the CNN forward on one batch and compute its metrics.

    Args:
        state: Object exposing the model parameters as `state.params`.
        batch: An `(images, labels)` pair.

    Returns:
        Whatever `compute_metrics` returns for the predicted logits and the
        batch labels.
    """
    inputs, targets = batch
    variables = {'params': state.params}
    predictions = CNN().apply(variables, inputs)
    return compute_metrics(logits=predictions, labels=targets)
def compute_loss(params, images, labels):
    """Forward `images` through the CNN and score the result against `labels`.

    Args:
        params: Parameter pytree passed to `CNN().apply` under 'params'.
        images: Model inputs.
        labels: Targets handed to `cross_entropy_loss`.

    Returns:
        A `(loss, logits)` tuple: the cross-entropy loss and the model
        outputs it was computed from.
    """
    model = CNN()
    outputs = model.apply({'params': params}, images)
    return cross_entropy_loss(logits=outputs, labels=labels), outputs
@jax.jit
def train_step(state, batch):
    """Train for a single step.

    Args:
        state: Current train state (not read by the visible body).
        batch: An `(images, labels)` pair.

    NOTE(review): only the batch unpacking is present; no loss, gradient or
    state update is computed and nothing is returned. The snippet looks
    truncated -- confirm against the full source before relying on it.
    """
    images, labels = batch
def create_train_state(rng, learning_rate):
    """Creates initial `TrainState`.

    Args:
        rng: PRNG key used to initialise the CNN parameters.
        learning_rate: Learning rate handed to the Adam optimizer.

    Returns:
        A `TrainState` bundling the CNN's apply function, its freshly
        initialised parameters and an Adam optimizer.
    """
    model = CNN()
    # Dummy input used only to trace shapes during initialisation
    # (presumably one 3-channel image of side `size_image`).
    dummy_input = jnp.ones([1, size_image, size_image, 3])
    variables = model.init(rng, dummy_input)
    return train_state.TrainState.create(
        apply_fn=model.apply,
        params=variables['params'],
        tx=optax.adam(learning_rate),
    )
import flax
from flax import linen as nn
class CNN(nn.Module):
    """Small convolutional block: 3x3 conv (32 features) -> ReLU -> 2x2 avg-pool."""

    @nn.compact
    def __call__(self, x):
        """Apply the block to `x` and return the pooled feature map.

        Args:
            x: Input array accepted by `nn.Conv` (presumably a batch of
                images -- confirm against the data pipeline).

        Returns:
            The activations after convolution, ReLU and average pooling.
        """
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        # Without an explicit return, `CNN().apply(...)` would yield None.
        # NOTE(review): callers treat the output as logits, but no
        # flatten/Dense head is visible here -- the definition may be
        # truncated; confirm.
        return x
# Hyperparameters for the training run.
learning_rate = 0.1
momentum = 0.9
# NOTE(review): `seed` is not referenced in the lines below; the PRNG key is
# built from a literal 0.
seed = 0
num_epochs = 30
rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)
# `init_rng` is split into one key per available device before being passed on.
# NOTE(review): this passes a batch of keys plus `momentum`, so it presumably
# targets a parallel variant of create_train_state taking
# (rng, learning_rate, momentum) -- verify which definition is in scope.
state = create_train_state(jax.random.split(init_rng,
jax.device_count()),learning_rate, momentum)
del init_rng # Must not be used anymore.
# --- Label smoothing -------------------------------------------------------
# Blends each label with a uniform distribution over the categories;
# `alpha` sets the blend weight.
labels = jnp.array([0.20, 0.30, 0.10, 0.20, 0.2])
optax.smooth_labels(labels, alpha=0.4)
# DeviceArray([0.2 , 0.26, 0.14, 0.2 , 0.2 ], dtype=float32)

# Shared inputs for the two element-wise regression losses below.
predictions = jnp.array([0.50, 0.60, 0.70, 0.30, 0.25])
targets = jnp.array([0.20, 0.30, 0.10, 0.20, 0.2])

# --- Log-cosh loss: log(cosh(predictions - targets)), per element ----------
optax.log_cosh(predictions, targets)
# DeviceArray([0.04434085, 0.04434085, 0.17013526, 0.00499171, 0.00124949], dtype=float32)

# --- L2 loss: 0.5 * (predictions - targets) ** 2, per element --------------
optax.l2_loss(predictions, targets)
# DeviceArray([0.045 , 0.045 , 0.17999998, 0.005 , 0.00125 ], dtype=float32)