CIFAR10 94% test accuracy in 125s using JAX
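A single-file Flax/Optax training script: a small DAWNBench-style ResNet trained on CIFAR-10 with a one-cycle learning-rate schedule, pad-crop/flip/cutout augmentation, half-precision convolutions, and each full epoch (shuffle, augment, train, eval) compiled as one jitted scan.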
from pathlib import Path
from functools import partial
from collections import namedtuple

import tensorflow as tf
import tensorflow_datasets as tfds

import jax
import jax.numpy as jnp
import jax.random as jr
from jax import jit, vmap, value_and_grad
from jax.tree_util import tree_map

import optax
import augmax
import jax_metrics as jm
import flax
from flax import linen as nn
from flax.training import train_state

# Hide the GPU from TensorFlow so it does not grab memory JAX needs;
# TF is only used here to download and decode the dataset.
tf.config.set_visible_devices([], device_type="GPU")

TFDS_DIR = Path("./tfds")
def dict_to_namedtuple(name, d):
    # Namedtuples are pytrees with named fields, so the datasets below can be
    # threaded through jit/scan while staying attribute-accessible.
    return namedtuple(name, d.keys())(**d)


def rng_iter(key):
    # Infinite stream of fresh PRNG subkeys derived from one root key.
    while True:
        key, subkey = jr.split(key, 2)
        yield subkey
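# Usage sketch: `keys = rng_iter(jr.PRNGKey(0))`, then each `next(keys)`
# returns a fresh subkey while the chain key advances internally.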
# batch_size=-1 loads each split as one full array, so the whole dataset
# lives in device memory and every epoch can run inside a single jit.
tfdata, tfinfo = tfds.load(
    name="cifar10", batch_size=-1, data_dir=TFDS_DIR, with_info=True
)
tfdata = tfds.as_numpy(tfdata)
num_classes = tfinfo.features["label"].num_classes

# Repack {"train": {...}, "test": {...}} into nested namedtuples, dropping
# the string "id" field, which cannot become a JAX array.
data = dict_to_namedtuple(
    "Data",
    {
        k: dict_to_namedtuple(
            "Dataset", {t: jnp.asarray(v) for t, v in ds.items() if t != "id"}
        )
        for k, ds in tfdata.items()
    },
)
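# Resulting shapes (assuming the standard CIFAR-10 split):
#   data.train.image: uint8 (50000, 32, 32, 3), data.train.label: (50000,)
#   data.test.image:  uint8 (10000, 32, 32, 3), data.test.label:  (10000,)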
rng_stream = rng_iter(jr.PRNGKey(0))

# Run the convolutions and the final dense layer in half precision for speed.
dtype = jnp.float16
class ResNetModel(nn.Module):
    # A 9-layer network in the style of the DAWNBench fast-CIFAR ResNets:
    # conv stem, three max-pool stages, and two residual blocks.
    @nn.compact
    def __call__(self, x, train=True):
        norm = partial(nn.BatchNorm, use_running_average=not train)
        act = nn.relu
        # prep: 3x3 conv, no pooling
        print("init            ", x.shape)
        x = nn.Conv(features=64, kernel_size=(3, 3), dtype=dtype)(x)
        x = norm()(x)
        print("after Conv_0    ", x.shape)
        x = act(x)
        # layer1: conv + pool, then a residual block
        x = nn.Conv(features=128, kernel_size=(3, 3), dtype=dtype)(x)
        print("after Conv_1    ", x.shape)
        x = norm()(x)
        x = act(x)
        x = nn.max_pool(x, (2, 2), strides=(2, 2))
        print("after max_pool_0", x.shape)
        residual = x
        y = nn.Conv(features=128, kernel_size=(3, 3), dtype=dtype)(x)
        print("after Conv_2    ", y.shape)
        y = norm()(y)
        y = act(y)
        y = nn.Conv(features=128, kernel_size=(3, 3), dtype=dtype)(y)
        print("after Conv_3    ", y.shape)
        # note: act precedes norm here, the reverse of the other blocks
        y = act(y)
        y = norm()(y)
        x = y + residual
        # layer2: conv + pool, no residual block
        x = nn.Conv(features=256, kernel_size=(3, 3), dtype=dtype)(x)
        print("after Conv_4    ", x.shape)
        x = norm()(x)
        x = act(x)
        x = nn.max_pool(x, (2, 2), strides=(2, 2))
        print("after max_pool_1", x.shape)
        # layer3: conv + pool, then a residual block
        x = nn.Conv(features=512, kernel_size=(3, 3), dtype=dtype)(x)
        print("after Conv_5    ", x.shape)
        x = norm()(x)
        x = act(x)
        x = nn.max_pool(x, (2, 2), strides=(2, 2))
        print("after max_pool_2", x.shape)
        residual = x
        y = nn.Conv(features=512, kernel_size=(3, 3), dtype=dtype)(x)
        print("after Conv_6    ", y.shape)
        y = norm()(y)
        y = act(y)
        y = nn.Conv(features=512, kernel_size=(3, 3), dtype=dtype)(y)
        print("after Conv_7    ", y.shape)
        y = norm()(y)
        y = act(y)
        x = y + residual
        # head: 4x4 max-pool, flatten, linear classifier
        x = nn.max_pool(x, (4, 4), strides=(4, 4))
        print("after max_pool_3", x.shape)
        x = jnp.reshape(x, (x.shape[0], -1))
        x = nn.Dense(10, dtype=dtype, use_bias=False)(x)
        # scale the logits down by 1/8, as in the fast-CIFAR training recipes
        return x * 0.125


model = ResNetModel()
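# The print() calls above run at trace time only: model.init below logs each
# activation shape once, and the compiled training steps never print again.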
BATCH_SIZE = 128
TEST_BATCH_SIZE = 2000
LEARNING_RATE = 1e-2  # unused as written; the peak step size is the 0.1 in the schedule below
WEIGHT_DECAY = 5e-3
EPOCHS = 40

variables = model.init(next(rng_stream), jnp.ones((BATCH_SIZE, 32, 32, 3)))
# Split the trainable params off from the rest (the BatchNorm "batch_stats").
variables, params = variables.pop("params")
# SGD with Nesterov momentum, weight decay added after the momentum filter,
# everything scaled by a linear one-cycle schedule peaking at 0.1, then
# negated for gradient descent.
tx = optax.chain(
    optax.trace(decay=0.9, nesterov=True),
    optax.add_decayed_weights(WEIGHT_DECAY),
    optax.scale_by_schedule(
        optax.linear_onecycle_schedule(
            len(data.train.image) // BATCH_SIZE * EPOCHS, 0.1
        )
    ),
    optax.scale(-1),
)
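# Per step this computes roughly:
#   update = -schedule(t) * (momentum_filtered_grad + WEIGHT_DECAY * params)
# where schedule(t) ramps up to 0.1 and back down over the whole run.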
class CutOut(augmax.imagelevel.ImageLevelTransformation):
    # Zero out one randomly placed cutout_size patch per image, with probability p.
    def __init__(self, cutout_size=(8, 8), p=0.5, input_types=[augmax.InputType.IMAGE]):
        assert input_types == [augmax.InputType.IMAGE]
        super().__init__(input_types)
        self.cutout_size = cutout_size
        self.probability = p

    def apply(self, rng, inputs: jnp.ndarray, input_types=None, invert=False):
        if input_types is None:
            input_types = self.input_types
        assert input_types == [augmax.InputType.IMAGE]
        key1, key2 = jax.random.split(rng)
        do_apply = jax.random.bernoulli(key1, self.probability)
        val = []
        for raw_image, _ in zip(inputs, input_types):
            H, W, C = raw_image.shape
            cx, cy = self.cutout_size
            cutout = jnp.zeros((cx, cy, C), dtype=raw_image.dtype)
            # Sample the top-left corner of the patch uniformly in-bounds.
            x, y = jax.random.randint(key2, (2,), 0, jnp.array([H - cx, W - cy]))
            image = jax.lax.dynamic_update_slice(raw_image, cutout, (x, y, 0))
            # Select patched vs. original without a Python branch, keeping
            # the transform jit/vmap friendly.
            current = jnp.where(do_apply, image, raw_image)
            val.append(current)
        return val
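# Below this is instantiated as CutOut((8, 8), 1.0), i.e. p=1.0, so every
# training image gets one 8x8 patch zeroed out (cutout regularization).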
# Normalization shared by both splits (standard CIFAR-10 channel statistics).
normalize = [
    augmax.ByteToFloat(),
    augmax.Normalize(
        mean=jnp.array([0.4914, 0.4822, 0.4465]),
        std=jnp.array([0.2023, 0.1994, 0.2010]),
    ),
]
test_aug = augmax.Chain(*normalize)
train_aug = augmax.Chain(
    *normalize,
    # growing to 40x40 and taking a random 32x32 window implements the
    # usual 4-pixel pad-and-crop augmentation
    augmax.CenterCrop(32 + 4 * 2, 32 + 4 * 2),
    augmax.RandomCrop(32, 32),
    augmax.HorizontalFlip(),
    CutOut((8, 8), 1.0),
)
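# augmax chains are pure functions of (rng, image), so vmap maps them over a
# batch with one subkey per image; see the vmap(train_aug) call in do_epoch.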
def softmax_cross_entropy_with_integer_labels(logits, labels):
    # Numerically stable cross-entropy for integer labels: the max is
    # subtracted before exponentiation, and its gradient is stopped since
    # the loss is invariant to shifting the logits.
    logits_max = jnp.max(logits, axis=-1, keepdims=True)
    logits -= jax.lax.stop_gradient(logits_max)
    label_logits = jnp.take_along_axis(logits, labels[..., None], axis=-1)[..., 0]
    log_normalizers = jnp.log(jnp.sum(jnp.exp(logits), axis=-1))
    return log_normalizers - label_logits
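# Worked example: logits [2., 0., 0.] with label 0 gives
# log(e**2 + 2) - 2 ≈ 0.2395.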
class TrainState(train_state.TrainState):
    # Extend flax's TrainState so the metric accumulators and the
    # non-trainable variables (BatchNorm statistics) travel with the
    # training state as one pytree.
    train_metrics: jm.LossesAndMetrics
    test_metrics: jm.LossesAndMetrics
    variables: flax.core.frozen_dict.FrozenDict


state = TrainState.create(
    apply_fn=model.apply,
    params=params,
    tx=tx,
    variables=variables,
    train_metrics=jm.LossesAndMetrics(
        metrics={
            "train_acc": jm.metrics.Accuracy(num_classes=num_classes),
        },
        losses={
            "train_loss": jm.losses.Crossentropy(),
            "train_l2": jm.regularizers.L2(WEIGHT_DECAY),
        },
    ).init(),
    test_metrics=jm.LossesAndMetrics(
        metrics={
            "test_acc": jm.metrics.Accuracy(num_classes=num_classes),
        },
        losses={
            "test_loss": jm.losses.Crossentropy(),
        },
    ).init(),
)
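# Because everything (params, optimizer state, BN statistics, metric
# accumulators) lives in this one pytree, the whole epoch below can be a
# single jitted function that scans over batches.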
# Gradient of the loss with respect to params (argnums=1); has_aux=True
# threads back the state with updated batch_stats plus the raw logits.
@partial(value_and_grad, argnums=1, has_aux=True)
def loss(state, params, inputs, targets):
    outputs, variables = state.apply_fn(
        {**state.variables, "params": params}, inputs, mutable=["batch_stats"]
    )
    state = state.replace(variables=variables)
    return softmax_cross_entropy_with_integer_labels(outputs, targets).mean(), (
        state,
        outputs,
    )
def train_update(state, inputs, targets):
    (_, (state, outputs)), grads = loss(state, state.params, inputs, targets)
    state = state.replace(
        train_metrics=state.train_metrics.update(
            preds=outputs, target=targets, parameters=state.params
        ),
    )
    return state.apply_gradients(grads=grads)


def test_update(state, inputs, targets):
    # train=False: BatchNorm uses its running averages and nothing is mutated.
    outputs = state.apply_fn(
        {**state.variables, "params": state.params}, inputs, train=False
    )
    return state.replace(
        test_metrics=state.test_metrics.update(preds=outputs, target=targets)
    )
def batch(batch_size, arr, leftovers=False):
    # Reshape the leading axis into (num_batches, batch_size), dropping the
    # remainder; with leftovers=True the dropped tail is returned as well.
    n = len(arr)
    batches = arr[: n - n % batch_size].reshape(-1, batch_size, *arr.shape[1:])
    if not leftovers:
        return batches
    return batches, arr[-(n % batch_size or n) :]
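# e.g. batch(4, jnp.arange(10)) has shape (2, 4); with leftovers=True it also
# returns the trailing 2 elements that did not fill a batch.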
# One full epoch (train + eval) as a single jitted function: shuffling,
# augmentation, and optimizer updates all happen on-device via lax.scan.
@jit
def do_epoch(key, state, data):
    key, perm_key, aug_key = jr.split(key, 3)
    state = state.replace(
        train_metrics=state.train_metrics.reset(),
        test_metrics=state.test_metrics.reset(),
    )
    # Shuffle the training set, then carve it into fixed-size batches.
    P = jax.random.permutation(perm_key, jnp.arange(len(data.train.label)))
    train_batches = tree_map(lambda t: batch(BATCH_SIZE, t[P]), data.train)
    train_aug_keys = jr.split(aug_key, len(train_batches.label))

    def train_scanner(s, xyk):
        (inputs, targets), ak = xyk
        # One augmentation subkey per image in the batch.
        aks = jr.split(ak, len(targets))
        aug_inputs = vmap(train_aug)(aks, inputs)
        return train_update(s, aug_inputs, targets), None

    state, _ = jax.lax.scan(train_scanner, state, (train_batches, train_aug_keys))

    # The test split must divide evenly, since batch() drops remainders.
    assert len(data.test.label) % TEST_BATCH_SIZE == 0
    test_batches = tree_map(partial(batch, TEST_BATCH_SIZE), data.test)
    test_aug_keys = jr.split(aug_key, len(test_batches.label))

    def test_scanner(s, xyk):
        (inputs, targets), ak = xyk
        # test_aug is deterministic, but the augmax API still takes keys.
        aks = jr.split(ak, len(targets))
        aug_inputs = vmap(test_aug)(aks, inputs)
        return test_update(s, aug_inputs, targets), None

    state, _ = jax.lax.scan(test_scanner, state, (test_batches, test_aug_keys))
    return state
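# data has static shapes and the state pytree keeps a fixed structure, so
# do_epoch compiles once on the first call and is reused for all 40 epochs.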
for epoch, key in zip(range(1, EPOCHS + 1), rng_stream):
    print(f"epoch: {epoch}")
    state = do_epoch(key, state, data)
    print(state.train_metrics.compute())
    print(state.test_metrics.compute())
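# To reproduce (assuming a working GPU install of jax), the remaining deps are
# roughly: pip install tensorflow tensorflow-datasets flax optax augmax jax_metrics
# Expect the first epoch to be slower, since it includes XLA compilation.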