Skip to content

Instantly share code, notes, and snippets.

@chizuchizu
Created March 13, 2021 11:39
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save chizuchizu/6d82870a4f4512a3e0f34bb4b6a587c3 to your computer and use it in GitHub Desktop.
from flax import linen as nn
from flax import optim
import jax
import jax.numpy as jnp
import numpy as np
import tensorflow as tf
import os
# https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html
# Prevent OOM: stop JAX from preallocating most of the GPU memory up front.
# NOTE: this must be set before JAX touches the GPU for it to take effect.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# Hide all GPUs from TensorFlow so the Keras data loading below does not
# grab GPU memory that JAX needs.
tf.config.experimental.set_visible_devices([], "GPU")
def cross_entropy_loss(logits, labels, num_classes=10):
    """Mean negative log-likelihood of integer labels under log-prob logits.

    Args:
        logits: (batch, num_classes) array of log-probabilities (the model
            ends in log_softmax, so these are already logs).
        labels: (batch,) integer class labels.
        num_classes: number of classes for the one-hot encoding.
            Defaults to 10 (the hard-coded MNIST value of the original),
            so existing callers are unaffected.

    Returns:
        Scalar loss averaged over the batch.
    """
    one_hot_labels = jax.nn.one_hot(labels, num_classes=num_classes)
    # Sum over the class axis picks out the log-probability of the true
    # class; negate and average over the batch.
    return -jnp.mean(jnp.sum(one_hot_labels * logits, axis=-1))
def compute_metrics(logits, labels):
    """Return a dict with the batch loss and classification accuracy."""
    predicted = jnp.argmax(logits, -1)
    return {
        'loss': cross_entropy_loss(logits, labels),
        'accuracy': jnp.mean(predicted == labels),
    }
class CNN(nn.Module):
    """A simple CNN: two conv/avg-pool stages, a dense layer, log-softmax head."""

    @nn.compact
    def __call__(self, x):
        # Feature extractor. Submodule creation order is unchanged from the
        # original (Conv_0, Conv_1, Dense_0, Dense_1), so parameter names
        # and initialization match exactly.
        for n_features in (32, 64):
            x = nn.relu(nn.Conv(features=n_features, kernel_size=(3, 3))(x))
            x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten to (batch, features)
        x = nn.relu(nn.Dense(features=256)(x))
        # 10-way classification head returning log-probabilities.
        return nn.log_softmax(nn.Dense(features=10)(x))
@jax.jit
def train_step(optimizer, x, y):
    """Run one gradient step; return the updated optimizer and batch metrics."""

    def loss_fn(params):
        # Forward pass; logits are returned as the auxiliary output so the
        # metrics can be computed without a second forward pass.
        preds = CNN().apply({'params': params}, x)
        return cross_entropy_loss(preds, y), preds

    (_, logits), grads = jax.value_and_grad(loss_fn, has_aux=True)(optimizer.target)
    return optimizer.apply_gradient(grads), compute_metrics(logits, y)
@jax.jit
def eval_step(optimizer, x, y):
    """Evaluate the optimizer's current parameters on (x, y).

    Consistency fix: the original referenced the module-level global `model`,
    which is only defined near the bottom of the file, while `train_step`
    constructs `CNN()` inline. A linen Module holds no parameter state, so
    `CNN().apply` with the same params is equivalent — this removes the
    hidden dependency on import-order.

    Returns:
        The metrics dict from `compute_metrics` (loss and accuracy).
    """
    logits = CNN().apply({'params': optimizer.target}, x)
    return compute_metrics(logits, y)
class Trainer:
    """Owns the RNG, parameters, and optimizer state for a training run.

    The jitted `train_step`/`eval_step` functions live at module level; this
    class wires data batching and epoch bookkeeping around them.
    """

    def __init__(self, model, criterion, epochs, batch_size):
        # `criterion` is a flax.optim optimizer *definition* (e.g.
        # optim.Momentum), not a loss function, despite the name.
        self.model = model
        self.epochs = epochs
        self.batch_size = batch_size
        self.rng = jax.random.PRNGKey(0)
        # Split so parameter initialization and epoch shuffling use
        # independent random streams.
        self.rng, init_rng = jax.random.split(self.rng)
        self.params = self.get_initial_params(init_rng)
        self.optimizer = self.create_optimizer(
            criterion
        )

    def get_initial_params(self, key):
        """Initialize model parameters with a dummy MNIST-shaped input."""
        # (1, 28, 28, 1): one 28x28 single-channel image, NHWC layout.
        init_val = jnp.ones((1, 28, 28, 1), jnp.float32)
        initial_params = self.model.init(key, init_val)['params']
        return initial_params

    def create_optimizer(self, criterion):
        """Bind the optimizer definition to the freshly initialized params."""
        optimizer = criterion.create(self.params)
        return optimizer

    def train_epoch(self, train, target, batch_size, epoch, rng):
        """Train for a single epoch; returns the epoch's mean metrics."""
        train_ds_size = train.shape[0]
        steps_per_epoch = train_ds_size // batch_size
        # Shuffle the example indices (translated from the original
        # Japanese comment).
        perms = jax.random.permutation(rng, train.shape[0])
        perms = perms[:steps_per_epoch * batch_size]  # skip incomplete batch
        perms = perms.reshape((steps_per_epoch, batch_size))
        batch_metrics = []
        for perm in perms:
            batch_x = train[perm, ...]
            batch_y = target[perm]
            self.optimizer, metrics = train_step(self.optimizer, batch_x, batch_y)
            batch_metrics.append(metrics)
        # compute mean of metrics across each batch in epoch.
        batch_metrics_np = jax.device_get(batch_metrics)
        epoch_metrics_np = {
            k: np.mean([metrics[k] for metrics in batch_metrics_np])
            for k in batch_metrics_np[0]}
        print(
            f'train epoch: {epoch}, loss: {round(epoch_metrics_np["loss"], 4)}, accuracy: {round(epoch_metrics_np["accuracy"] * 100, 2)}')
        return epoch_metrics_np

    def eval_model(self, x, y):
        """Evaluate on (x, y); returns (loss, accuracy) as Python scalars."""
        metrics = eval_step(self.optimizer, x, y)
        # Pull device arrays back to host, then unwrap to Python numbers.
        metrics = jax.device_get(metrics)
        summary = jax.tree_map(lambda x_: x_.item(), metrics)
        return summary['loss'], summary['accuracy']

    def training(self, X_train, y_train, X_test, y_test):
        """Full loop: one training epoch followed by one evaluation per epoch."""
        for epoch in range(1, self.epochs + 1):
            # Fresh RNG stream per epoch so each epoch shuffles differently.
            self.rng, input_rng = jax.random.split(self.rng)
            train_metrics = self.train_epoch(
                X_train, y_train, self.batch_size, epoch, input_rng
            )
            loss, accuracy = self.eval_model(X_test, y_test)
            print(f'eval epoch: {epoch}, loss: {round(loss, 4)}, accuracy: {round(accuracy * 100, 2)}')
def get_datasets():
    """Load MNIST and return (X_train, y_train, X_test, y_test).

    Images are reshaped to NHWC (n, 28, 28, 1) and scaled to [0, 1] as
    float32 JAX arrays; labels are returned as the integer arrays Keras
    provides.
    """
    (X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data()

    def preprocess_images(images):
        # Add a channel axis and scale uint8 pixels into [0, 1].
        # Fix: jnp.float32(...) already produces a float32 array, so the
        # original's trailing .astype("float32") was a redundant second cast.
        return jnp.float32(images.reshape((images.shape[0], 28, 28, 1)) / 255.0)

    X_train = preprocess_images(X_train)
    X_test = preprocess_images(X_test)
    return X_train, y_train, X_test, y_test
# Script entry point: load MNIST, build the model and optimizer, then train.
X_train, y_train, X_test, y_test = get_datasets()
model = CNN()
# SGD with momentum; despite the name, this is the optimizer definition,
# not a loss criterion.
criterion = optim.Momentum(learning_rate=0.01, beta=0.9)
trainer = Trainer(
    model=model,
    criterion=criterion,
    epochs=5,
    batch_size=64
)
trainer.training(X_train, y_train, X_test, y_test)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment