CIFAR10 94% test accuracy in 125s using JAX
@PythonNut · Created June 25, 2022
from pathlib import Path
from functools import partial
from collections import namedtuple
import tensorflow as tf
import tensorflow_datasets as tfds
import jax
import jax.numpy as jnp
import jax.random as jr
from jax import jit, vmap, value_and_grad
from jax.tree_util import tree_map
import optax
import augmax
import jax_metrics as jm
import flax
from flax import linen as nn
from flax.training import train_state
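
# Hide all GPUs from TensorFlow so tf.data / TFDS stays on the CPU and JAX keeps
# exclusive use of accelerator memory.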
tf.config.set_visible_devices([], device_type="GPU")
TFDS_DIR = Path("./tfds")

def dict_to_namedtuple(name, d):
    return namedtuple(name, d.keys())(**d)

def rng_iter(key):
    while True:
        key, subkey = jr.split(key, 2)
        yield subkey
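
# rng_iter turns one root PRNGKey into an endless stream of fresh subkeys, e.g.
#   keys = rng_iter(jr.PRNGKey(42)); k1 = next(keys); k2 = next(keys)
# so each consumer below gets an independent key without manual bookkeeping.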

tfdata, tfinfo = tfds.load(
    name="cifar10", batch_size=-1, data_dir=TFDS_DIR, with_info=True
)
tfdata = tfds.as_numpy(tfdata)
num_classes = tfinfo.features["label"].num_classes
data = dict_to_namedtuple(
    "Data",
    {
        k: dict_to_namedtuple(
            "Dataset", {t: jnp.asarray(v) for t, v in ds.items() if t != "id"}
        )
        for k, ds in tfdata.items()
    },
)
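
# data.train.image / data.train.label (and likewise data.test.*) are now plain
# device arrays: CIFAR-10 has 50,000 training and 10,000 test images of shape
# (32, 32, 3) uint8, with integer labels in [0, 10).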
rng_stream = rng_iter(jr.PRNGKey(0))
dtype = jnp.float16

class ResNetModel(nn.Module):
    @nn.compact
    def __call__(self, x, train=True):
        norm = partial(nn.BatchNorm, use_running_average=not train)
        act = nn.relu
        # prep
        print("init            ", x.shape)
        x = nn.Conv(features=64, kernel_size=(3, 3), dtype=dtype)(x)
        x = norm()(x)
        print("after Conv_0    ", x.shape)
        x = act(x)
        # no pool in the "prep" layer
        # layer1
        x = nn.Conv(features=128, kernel_size=(3, 3), dtype=dtype)(x)
        print("after Conv_1    ", x.shape)
        x = norm()(x)
        x = act(x)
        x = nn.max_pool(x, (2, 2), strides=(2, 2))
        print("after max_pool_0", x.shape)
        residual = x
        y = nn.Conv(features=128, kernel_size=(3, 3), dtype=dtype)(x)
        print("after Conv_2    ", y.shape)
        y = norm()(y)
        y = act(y)
        y = nn.Conv(features=128, kernel_size=(3, 3), dtype=dtype)(y)
        print("after Conv_3    ", y.shape)
        y = act(y)
        y = norm()(y)
        x = y + residual
        # layer2
        x = nn.Conv(features=256, kernel_size=(3, 3), dtype=dtype)(x)
        print("after Conv_4    ", x.shape)
        x = norm()(x)
        x = act(x)
        x = nn.max_pool(x, (2, 2), strides=(2, 2))
        print("after max_pool_1", x.shape)
        # layer3
        x = nn.Conv(features=512, kernel_size=(3, 3), dtype=dtype)(x)
        print("after Conv_5    ", x.shape)
        x = norm()(x)
        x = act(x)
        x = nn.max_pool(x, (2, 2), strides=(2, 2))
        print("after max_pool_2", x.shape)
        residual = x
        y = nn.Conv(features=512, kernel_size=(3, 3), dtype=dtype)(x)
        print("after Conv_6    ", y.shape)
        y = norm()(y)
        y = act(y)
        y = nn.Conv(features=512, kernel_size=(3, 3), dtype=dtype)(y)
        print("after Conv_7    ", y.shape)
        y = norm()(y)
        y = act(y)
        x = y + residual
        x = nn.max_pool(x, (4, 4), strides=(4, 4))
        print("after max_pool_3", x.shape)
        x = jnp.reshape(x, (x.shape[0], -1))
        x = nn.Dense(10, dtype=dtype, use_bias=False)(x)
        return x * 0.125
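
# The network above appears to follow the small "ResNet-9"-style architecture used
# in fast CIFAR-10 training write-ups: a prep conv plus three downsampling stages,
# two of which carry a single residual block. The final 0.125 factor scales the
# logits down, a common trick for keeping this network stable at high learning rates.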
model = ResNetModel()
BATCH_SIZE = 128
TEST_BATCH_SIZE = 2000
LEARNING_RATE = 1e-2
WEIGHT_DECAY = 5e-3
EPOCHS = 40
variables = model.init(next(rng_stream), jnp.ones((BATCH_SIZE, 32, 32, 3)))
variables, params = variables.pop("params")

tx = optax.chain(
    optax.trace(decay=0.9, nesterov=True),
    optax.add_decayed_weights(WEIGHT_DECAY),
    optax.scale_by_schedule(
        optax.linear_onecycle_schedule(
            len(data.train.image) // BATCH_SIZE * EPOCHS, 0.1
        )
    ),
    optax.scale(-1),
)
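
# Reading the chain top to bottom: Nesterov momentum (optax.trace), decoupled weight
# decay added to the update, a one-cycle schedule whose peak multiplier is 0.1 (note
# that the LEARNING_RATE constant above is not referenced here), and optax.scale(-1)
# to turn the scaled update into a descent step.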

class CutOut(augmax.imagelevel.ImageLevelTransformation):
    def __init__(self, cutout_size=(8, 8), p=0.5, input_types=[augmax.InputType.IMAGE]):
        assert input_types == [augmax.InputType.IMAGE]
        super().__init__(input_types)
        self.cutout_size = cutout_size
        self.probability = p

    def apply(self, rng, inputs, input_types=None, invert=False):
        assert input_types == [augmax.InputType.IMAGE]
        key1, key2 = jax.random.split(rng)
        do_apply = jax.random.bernoulli(key1, self.probability)
        val = []
        for raw_image, _ in zip(inputs, input_types):
            H, W, C = raw_image.shape
            cx, cy = self.cutout_size
            cutout = jnp.zeros((cx, cy, C), dtype=raw_image.dtype)
            # random top-left corner, chosen so the cutout patch fits inside the image
            x, y = jax.random.randint(key2, (2,), 0, jnp.array([H - cx, W - cy]))
            image = jax.lax.dynamic_update_slice(raw_image, cutout, (x, y, 0))
            val.append(jnp.where(do_apply, image, raw_image))
        return val
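
# Minimal usage sketch (hypothetical standalone example, assuming a single float HWC
# image `img` already converted by augmax.ByteToFloat):
#   cut = augmax.Chain(CutOut((8, 8), p=1.0))
#   out = cut(jr.PRNGKey(0), img)
# In train_aug below the transform runs with p=1.0, so every training image gets one
# 8x8 patch zeroed out (in normalized units) at a random location.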

test_aug = augmax.Chain(
    augmax.ByteToFloat(),
    augmax.Normalize(
        mean=jnp.array([0.4914, 0.4822, 0.4465]),
        std=jnp.array([0.2023, 0.1994, 0.2010]),
    ),
)
train_aug = augmax.Chain(
    *test_aug.transforms[0].transforms,
    augmax.CenterCrop(32 + 4 * 2, 32 + 4 * 2),
    augmax.RandomCrop(32, 32),
    augmax.HorizontalFlip(),
    CutOut((8, 8), 1.0),
)
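
# CenterCrop(40, 40) followed by RandomCrop(32, 32) is presumably used to emulate the
# classic "pad 4 pixels on each side, then take a random 32x32 crop" CIFAR
# augmentation; the test pipeline only converts to float and normalizes.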

def softmax_cross_entropy_with_integer_labels(logits, labels):
    logits_max = jnp.max(logits, axis=-1, keepdims=True)
    logits -= jax.lax.stop_gradient(logits_max)
    label_logits = jnp.take_along_axis(logits, labels[..., None], axis=-1)[..., 0]
    log_normalizers = jnp.log(jnp.sum(jnp.exp(logits), axis=-1))
    return log_normalizers - label_logits
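
# Numerically stable cross entropy with integer labels: after subtracting the
# (stop-gradient) per-row max, the returned value is
#   log(sum_j exp(logits_j)) - logits_label,
# i.e. -log softmax(logits)[label] for each example.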

class TrainState(train_state.TrainState):
    train_metrics: jm.LossesAndMetrics
    test_metrics: jm.LossesAndMetrics
    variables: flax.core.frozen_dict.FrozenDict

state = TrainState.create(
    apply_fn=model.apply,
    params=params,
    tx=tx,
    variables=variables,
    train_metrics=jm.LossesAndMetrics(
        metrics={
            "train_acc": jm.metrics.Accuracy(num_classes=num_classes),
        },
        losses={
            "train_loss": jm.losses.Crossentropy(),
            "train_l2": jm.regularizers.L2(WEIGHT_DECAY),
        },
    ).init(),
    test_metrics=jm.LossesAndMetrics(
        metrics={
            "test_acc": jm.metrics.Accuracy(num_classes=num_classes),
        },
        losses={
            "test_loss": jm.losses.Crossentropy(),
        },
    ).init(),
)
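
# TrainState bundles the params, the optimizer state, the mutable BatchNorm
# statistics ("variables"), and jax_metrics accumulators for the train and test
# passes, so an entire epoch can live inside a single jitted function below.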

@partial(value_and_grad, argnums=1, has_aux=True)
def loss(state, params, inputs, targets):
    outputs, variables = state.apply_fn(
        {**state.variables, "params": params}, inputs, mutable=["batch_stats"]
    )
    state = state.replace(variables=variables)
    return softmax_cross_entropy_with_integer_labels(outputs, targets).mean(), (
        state,
        outputs,
    )

def train_update(state, inputs, targets):
    (_, (state, outputs)), grads = loss(state, state.params, inputs, targets)
    state = state.replace(
        train_metrics=state.train_metrics.update(
            preds=outputs, target=targets, parameters=state.params
        ),
    )
    return state.apply_gradients(grads=grads)

def test_update(state, inputs, targets):
    outputs = state.apply_fn(
        {**state.variables, "params": state.params}, inputs, train=False
    )
    return state.replace(
        test_metrics=state.test_metrics.update(preds=outputs, target=targets)
    )

def batch(batch_size, arr, leftovers=False):
    n = len(arr)
    batches = arr[: n - n % batch_size].reshape(-1, batch_size, *arr.shape[1:])
    if not leftovers:
        return batches
    return batches, arr[-(n % batch_size or n) :]
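
# e.g. batch(128, images) on the 50,000 CIFAR-10 training images yields an array of
# shape (390, 128, 32, 32, 3); the 80 leftover examples are silently dropped unless
# leftovers=True is passed.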

@jit
def do_epoch(key, state, data):
    key, perm_key, aug_key = jr.split(key, 3)
    state = state.replace(
        train_metrics=state.train_metrics.reset(),
        test_metrics=state.test_metrics.reset(),
    )
    P = jax.random.permutation(perm_key, jnp.arange(len(data.train.label)))
    train_batches = tree_map(lambda t: batch(BATCH_SIZE, t[P]), data.train)
    train_aug_keys = jr.split(aug_key, len(train_batches.label))

    def train_scanner(s, xyk):
        (inputs, targets), ak = xyk
        aks = jr.split(ak, len(targets))
        aug_inputs = vmap(train_aug)(aks, inputs)
        return train_update(s, aug_inputs, targets), None

    state, _ = jax.lax.scan(train_scanner, state, (train_batches, train_aug_keys))

    assert len(data.test.label) % TEST_BATCH_SIZE == 0
    test_batches = tree_map(partial(batch, TEST_BATCH_SIZE), data.test)
    test_aug_keys = jr.split(aug_key, len(test_batches.label))

    def test_scanner(s, xyk):
        (inputs, targets), ak = xyk
        aks = jr.split(ak, len(targets))
        aug_inputs = vmap(test_aug)(aks, inputs)
        return test_update(s, aug_inputs, targets), None

    state, _ = jax.lax.scan(test_scanner, state, (test_batches, test_aug_keys))
    return state
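
# The whole epoch is one jit-compiled function: jax.lax.scan replaces the Python
# batch loop, vmap applies the augmentation pipeline to each image within a batch,
# and aug_key is reused to derive keys for the test pass, which is harmless because
# the test-time transforms are deterministic.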

for epoch, key in zip(range(1, EPOCHS + 1), rng_stream):
    print(f"epoch: {epoch}")
    state = do_epoch(key, state, data)
    print(state.train_metrics.compute())
    print(state.test_metrics.compute())
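
# Note: the shape prints inside ResNetModel.__call__ only fire while JAX traces the
# model (at init and on the first do_epoch call); later epochs run fully compiled.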