Created
March 13, 2021 11:39
-
-
Save chizuchizu/6d82870a4f4512a3e0f34bb4b6a587c3 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from flax import linen as nn | |
from flax import optim | |
import jax | |
import jax.numpy as jnp | |
import numpy as np | |
import tensorflow as tf | |
import os | |
# https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html
# Prevent OOM: stop JAX/XLA from preallocating the bulk of GPU memory,
# and make zero GPUs visible to TensorFlow so it cannot grab device
# memory that JAX needs (TF is only used here for dataset loading).
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
tf.config.experimental.set_visible_devices([], "GPU")
def cross_entropy_loss(logits, labels):
    """Mean cross-entropy of log-probability `logits` against integer `labels`.

    `logits` are expected to already be log-probabilities (the model ends in
    log_softmax); `labels` are integer class ids in [0, 10).
    """
    targets = jax.nn.one_hot(labels, num_classes=10)
    per_example_ll = jnp.sum(targets * logits, axis=-1)
    return -jnp.mean(per_example_ll)
def compute_metrics(logits, labels):
    """Return a dict with mean cross-entropy 'loss' and top-1 'accuracy'."""
    predictions = jnp.argmax(logits, -1)
    return {
        'loss': cross_entropy_loss(logits, labels),
        'accuracy': jnp.mean(predictions == labels),
    }
class CNN(nn.Module):
    """A simple CNN: two conv/avg-pool stages, then a dense head.

    Emits log-probabilities over 10 classes (log_softmax output).
    """

    @nn.compact
    def __call__(self, x):
        # Feature extractor: 32 then 64 channels, halving spatial size each stage.
        # (Loop keeps the same submodule creation order as two explicit stages.)
        for features in (32, 64):
            x = nn.Conv(features=features, kernel_size=(3, 3))(x)
            x = nn.relu(x)
            x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        # Classifier head on the flattened feature map.
        x = x.reshape((x.shape[0], -1))
        x = nn.relu(nn.Dense(features=256)(x))
        return nn.log_softmax(nn.Dense(features=10)(x))
@jax.jit
def train_step(optimizer, x, y):
    """Run one optimization step on a batch.

    Returns the updated optimizer (which carries the new parameters) and the
    batch metrics computed from the pre-update forward pass.
    """

    def loss_fn(params):
        logits = CNN().apply({'params': params}, x)
        return cross_entropy_loss(logits, y), logits

    (_, logits), grads = jax.value_and_grad(loss_fn, has_aux=True)(optimizer.target)
    new_optimizer = optimizer.apply_gradient(grads)
    return new_optimizer, compute_metrics(logits, y)
@jax.jit
def eval_step(optimizer, x, y):
    """Evaluate the current parameters on a batch and return its metrics.

    Builds the model inline (matching `train_step`) instead of reading the
    module-level `model` global, which made this jitted function depend on a
    variable defined much later in the file and inconsistent with `train_step`.
    """
    logits = CNN().apply({'params': optimizer.target}, x)
    return compute_metrics(logits, y)
class Trainer:
    """Bundles parameter initialization, the optimizer, and the train/eval loops."""

    def __init__(self, model, criterion, epochs, batch_size):
        # NOTE(review): despite the name, `criterion` is a flax.optim optimizer
        # *definition* (e.g. optim.Momentum), not a loss function.
        self.model = model
        self.epochs = epochs
        self.batch_size = batch_size
        # One PRNG key threaded through the whole run; split before each use
        # so init and per-epoch shuffles get independent streams.
        self.rng = jax.random.PRNGKey(0)
        self.rng, init_rng = jax.random.split(self.rng)
        self.params = self.get_initial_params(init_rng)
        self.optimizer = self.create_optimizer(
            criterion
        )

    def get_initial_params(self, key):
        """Initialize model parameters with a dummy MNIST-shaped batch (1, 28, 28, 1)."""
        init_val = jnp.ones((1, 28, 28, 1), jnp.float32)
        initial_params = self.model.init(key, init_val)['params']
        return initial_params

    def create_optimizer(self, criterion):
        """Wrap the initial parameters in the given optimizer definition."""
        optimizer = criterion.create(self.params)
        return optimizer

    def train_epoch(self, train, target, batch_size, epoch, rng):
        """Train for a single epoch; returns the epoch-averaged metrics dict."""
        train_ds_size = train.shape[0]
        steps_per_epoch = train_ds_size // batch_size
        # Shuffle example indices for this epoch, then carve into fixed-size batches.
        perms = jax.random.permutation(rng, train.shape[0])
        perms = perms[:steps_per_epoch * batch_size]  # skip incomplete batch
        perms = perms.reshape((steps_per_epoch, batch_size))
        batch_metrics = []
        for perm in perms:
            batch_x = train[perm, ...]
            batch_y = target[perm]
            # train_step returns a fresh optimizer carrying the updated params.
            self.optimizer, metrics = train_step(self.optimizer, batch_x, batch_y)
            batch_metrics.append(metrics)
        # compute mean of metrics across each batch in epoch.
        batch_metrics_np = jax.device_get(batch_metrics)
        epoch_metrics_np = {
            k: np.mean([metrics[k] for metrics in batch_metrics_np])
            for k in batch_metrics_np[0]}
        print(
            f'train epoch: {epoch}, loss: {round(epoch_metrics_np["loss"], 4)}, accuracy: {round(epoch_metrics_np["accuracy"] * 100, 2)}')
        return epoch_metrics_np

    def eval_model(self, x, y):
        """Evaluate on (x, y); returns (loss, accuracy) as Python scalars."""
        metrics = eval_step(self.optimizer, x, y)
        metrics = jax.device_get(metrics)
        summary = jax.tree_map(lambda x_: x_.item(), metrics)
        return summary['loss'], summary['accuracy']

    def training(self, X_train, y_train, X_test, y_test):
        """Full run: one shuffled training epoch plus one test-set evaluation per epoch."""
        for epoch in range(1, self.epochs + 1):
            self.rng, input_rng = jax.random.split(self.rng)
            train_metrics = self.train_epoch(
                X_train, y_train, self.batch_size, epoch, input_rng
            )
            loss, accuracy = self.eval_model(X_test, y_test)
            print(f'eval epoch: {epoch}, loss: {round(loss, 4)}, accuracy: {round(accuracy * 100, 2)}')
def get_datasets():
    """Load MNIST via Keras and return (X_train, y_train, X_test, y_test).

    Images are reshaped to NHWC (n, 28, 28, 1) and scaled to float32 in
    [0, 1]; labels are returned unchanged as integer arrays.
    """
    (X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data()

    def preprocess_images(images):
        # jnp.float32(...) already yields a float32 array, so the original
        # trailing .astype("float32") was a redundant second cast — dropped.
        return jnp.float32(images.reshape((images.shape[0], 28, 28, 1)) / 255.0)

    X_train = preprocess_images(X_train)
    X_test = preprocess_images(X_test)
    return X_train, y_train, X_test, y_test
# ---- Script entry: load data, build model + optimizer, run training. ----
X_train, y_train, X_test, y_test = get_datasets()
model = CNN()
# Plain SGD with momentum; `criterion` here is an optimizer definition.
criterion = optim.Momentum(learning_rate=0.01, beta=0.9)
trainer = Trainer(
    model=model,
    criterion=criterion,
    epochs=5,
    batch_size=64
)
trainer.training(X_train, y_train, X_test, y_test)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment