Created
February 14, 2022 16:23
-
-
Save fabianp/3ead88c85bc43836dde19cc425d5e81b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright 2021 Google LLC | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# https://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
""" | |
Resnet example with Flax and JAXopt. | |
==================================== | |
""" | |
from absl import app | |
from absl import flags | |
from datetime import datetime | |
from functools import partial | |
from typing import Any, Callable, Sequence, Tuple | |
from flax import linen as nn | |
import jax | |
import jax.numpy as jnp | |
from jaxopt import loss | |
from jaxopt import OptaxSolver | |
from jaxopt import tree_util | |
import optax | |
import tensorflow_datasets as tfds | |
import tensorflow as tf | |
from matplotlib import pyplot as plt | |
import torchvision | |
from torchvision import transforms | |
import torch | |
from torch.utils.data import Dataset | |
# Dataset names accepted by the --dataset flag.
dataset_names = [
    "mnist",
    "kmnist",
    "emnist",
    "fashion_mnist",
    "cifar10",
    "cifar100",
]

# Optimization hyper-parameters.
flags.DEFINE_float(name="l2reg", default=0., help="L2 regularization.")
flags.DEFINE_float(name="learning_rate", default=0.001, help="Learning rate.")
flags.DEFINE_integer(
    name="epochs", default=60, help="Number of passes over the dataset.")
flags.DEFINE_float(name="momentum", default=0.9, help="Momentum strength.")

# Data and model selection.
flags.DEFINE_enum(
    name="dataset",
    default="cifar10",
    enum_values=dataset_names,
    help="Dataset to train on.")
flags.DEFINE_enum(
    name="model",
    default="resnet18",
    enum_values=["resnet1", "resnet18", "resnet34"],
    help="Model architecture.")

# Batch sizes.
flags.DEFINE_integer(
    name="train_batch_size", default=128, help="Batch size at train time.")
flags.DEFINE_integer(
    name="test_batch_size", default=1024, help="Batch size at test time.")

# Parsed flag values; populated once absl's app.run() has processed argv.
FLAGS = flags.FLAGS
# def load_dataset(split, *, is_training, batch_size): | |
# version = 3 | |
# ds, ds_info = tfds.load( | |
# f"{FLAGS.dataset}:{version}.*.*", | |
# as_supervised=True, # remove useless keys | |
# split=split, | |
# with_info=True) | |
# ds = ds.cache().repeat() | |
# if is_training: | |
# ds = ds.shuffle(10 * batch_size, seed=0) | |
# ds = ds.batch(batch_size) | |
# return iter(tfds.as_numpy(ds)), ds_info | |
def load_dataset(split, *, is_training, batch_size):
    """Build a torchvision CIFAR-10 ``DataLoader``.

    Args:
      split: Unused. Kept so existing call sites keep working; the split is
        chosen through ``is_training``.
      is_training: If True, load the training set, apply random-crop and
        horizontal-flip augmentation and shuffle the examples. If False, load
        the test set without augmentation and keep its order fixed so that
        evaluation is reproducible.
      batch_size: Number of examples per batch.

    Returns:
      A ``torch.utils.data.DataLoader`` yielding ``(images, labels)`` batches
      of torch tensors. Images come out of ``ToTensor`` + ``Normalize``, i.e.
      they are float tensors in channel-first layout.
    """
    del split  # torchvision selects the split via ``train=`` below.

    # Augmentation is applied to training data only.
    steps = []
    if is_training:
        steps.extend([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
        ])
    # Tensor conversion and per-channel normalization are shared by both
    # splits.
    steps.extend([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    dataset = torchvision.datasets.CIFAR10(
        root=".",
        train=is_training,
        download=True,
        transform=transforms.Compose(steps))

    # Honour the requested batch size (it used to be hard-coded to 128) and
    # only shuffle when training.
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=is_training,
        num_workers=2)
class ResNetBlock(nn.Module):
    """Residual block made of two 3x3 convolutions.

    Main path: conv -> norm -> act -> conv -> norm, where the scale of the
    second norm is initialized to zero. The block input is added back before
    the final activation; when its shape differs from the main-path output it
    first goes through a 1x1 convolution and a norm.
    """
    # Number of output features of every convolution in the block.
    filters: int
    # Factory returning a convolution module; called positionally as
    # conv(features, kernel_size[, strides]).
    conv: Any
    # Factory returning a normalization module.
    norm: Any
    # Activation function.
    act: Callable
    # Strides of the first 3x3 convolution and of the 1x1 projection.
    strides: Tuple[int, int] = (1, 1)

    @nn.compact
    def __call__(self, x):
        """Apply the block to ``x`` and return the activated sum."""
        kernel = (3, 3)

        # Main path.
        out = self.conv(self.filters, kernel, self.strides)(x)
        out = self.act(self.norm()(out))
        out = self.conv(self.filters, kernel)(out)
        out = self.norm(scale_init=nn.initializers.zeros)(out)

        # Skip path: project only when the shapes disagree.
        shortcut = x
        if shortcut.shape != out.shape:
            shortcut = self.conv(self.filters, (1, 1),
                                 self.strides, name='conv_proj')(shortcut)
            shortcut = self.norm(name='norm_proj')(shortcut)

        return self.act(shortcut + out)
class ResNet(nn.Module):
    """ResNetV1 classifier that uses GroupNorm (8 groups) for normalization.

    Layout: 7x7 stride-2 convolution, norm, ReLU, 3x3 stride-2 max pool, then
    one stage of ``block_cls`` blocks per entry of ``stage_sizes``. Stage ``i``
    uses ``num_filters * 2**i`` features and, for every stage but the first,
    its first block has stride 2. The result is averaged over axes 1 and 2 and
    mapped to ``num_classes`` outputs by a dense layer.
    """
    # Number of blocks in each stage.
    stage_sizes: Sequence[int]
    # Block constructor; called as
    # block_cls(filters, strides=..., conv=..., norm=..., act=...).
    block_cls: Any
    # Size of the output (logit) vector.
    num_classes: int
    # Feature count of the stem and of the first stage.
    num_filters: int = 64
    # dtype forwarded to the conv, norm and dense layers and used for the
    # returned array.
    dtype: Any = jnp.float32
    # Activation handed to the blocks (the stem always uses nn.relu).
    act: Callable = nn.relu

    @nn.compact
    def __call__(self, x, train: bool = True):
        """Return the logits for ``x``.

        ``train`` is accepted for interface compatibility but is not read by
        this implementation.
        """
        make_conv = partial(nn.Conv, use_bias=False, dtype=self.dtype)
        make_norm = partial(nn.GroupNorm, num_groups=8, dtype=self.dtype)

        # Stem.
        stem_conv = make_conv(self.num_filters, (7, 7), (2, 2),
                              padding=[(3, 3), (3, 3)],
                              name='conv_init')
        x = stem_conv(x)
        x = nn.relu(make_norm(name='bn_init')(x))
        x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')

        # Residual stages.
        for stage, depth in enumerate(self.stage_sizes):
            for position in range(depth):
                downsample = stage > 0 and position == 0
                block = self.block_cls(
                    self.num_filters * 2 ** stage,
                    strides=(2, 2) if downsample else (1, 1),
                    conv=make_conv,
                    norm=make_norm,
                    act=self.act)
                x = block(x)

        # Head.
        pooled = jnp.mean(x, axis=(1, 2))
        logits = nn.Dense(self.num_classes, dtype=self.dtype)(pooled)
        return jnp.asarray(logits, self.dtype)
# Ready-made model constructors; they differ only in the number of
# ResNetBlock instances per stage. Callers still supply num_classes.
ResNet1 = partial(ResNet, block_cls=ResNetBlock, stage_sizes=[1])
ResNet18 = partial(ResNet, block_cls=ResNetBlock, stage_sizes=[2, 2, 2, 2])
ResNet34 = partial(ResNet, block_cls=ResNetBlock, stage_sizes=[3, 4, 6, 3])
def main(argv):
    """Train the selected ResNet with Adam and plot per-epoch accuracy.

    Once per epoch (on the last training batch, before the update on it) the
    current parameters are scored on that training batch and on one test
    batch; the results are printed and collected for the final plot.

    Args:
      argv: Remaining command-line arguments passed by absl; unused.
    """
    del argv
    # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
    # it unavailable to JAX.
    tf.config.experimental.set_visible_devices([], 'GPU')

    trainloader_ds = load_dataset("train", is_training=True,
                                  batch_size=FLAGS.train_batch_size)
    testloader_ds = load_dataset("test", is_training=False,
                                 batch_size=FLAGS.test_batch_size)

    def to_jax(minibatch):
        """Convert a torch ``(images, labels)`` batch to NHWC JAX arrays."""
        images, labels = minibatch
        # torchvision's ToTensor yields channel-first (N, C, H, W) images,
        # whereas Flax convolutions treat the last axis as the feature axis.
        images = jnp.transpose(jnp.asarray(images.numpy()), (0, 2, 3, 1))
        return images, jnp.asarray(labels.numpy())

    # Read the image shape from a single example rather than spinning up the
    # loader's worker processes just to discard a batch.
    sample_image, _ = trainloader_ds.dataset[0]
    channels, height, width = sample_image.shape
    input_shape = (1, height, width, channels)
    num_classes = 10
    # Use the loader's real length: it keeps the final partial batch, so
    # ``50000 // batch_size`` would under-count and the once-per-epoch
    # evaluation below would drift away from the epoch boundary.
    iter_per_epoch = len(trainloader_ds)

    # Set up model.
    if FLAGS.model == "resnet1":
        net = ResNet1(num_classes=num_classes)
    elif FLAGS.model == "resnet18":
        net = ResNet18(num_classes=num_classes)
    elif FLAGS.model == "resnet34":
        net = ResNet34(num_classes=num_classes)
    else:
        raise ValueError("Unknown model.")

    def predict(params, inputs, train=False):
        """Return the logits of ``net`` for a batch of NHWC images."""
        # The torchvision transforms already scale and normalize the images,
        # so no additional division by 255 is applied here.
        x = inputs.astype(jnp.float32)
        return net.apply({"params": params}, x, train=train)

    logistic_loss = jax.vmap(loss.multiclass_logistic_loss)

    def loss_from_logits(params, l2reg, logits, labels):
        """Mean multiclass logistic loss plus ``0.5 * l2reg * ||params||^2``."""
        mean_loss = jnp.mean(logistic_loss(labels, logits))
        sqnorm = tree_util.tree_l2_norm(params, squared=True)
        return mean_loss + 0.5 * l2reg * sqnorm

    def accuracy_and_loss(params, l2reg, data):
        """Return ``(accuracy, loss)`` of ``params`` on one batch."""
        inputs, labels = data
        logits = predict(params, inputs)
        accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == labels)
        return accuracy, loss_from_logits(params, l2reg, logits, labels)

    def loss_fun(params, l2reg, data):
        """Training objective minimized by the solver."""
        inputs, labels = data
        logits = predict(params, inputs, train=True)
        return loss_from_logits(params, l2reg, logits, labels)

    # Initialize solver. NOTE: FLAGS.momentum is not used by this optimizer
    # configuration.
    opt = optax.adam(learning_rate=FLAGS.learning_rate)
    solver = OptaxSolver(opt=opt, fun=loss_fun,
                         maxiter=FLAGS.epochs * iter_per_epoch)

    # Initialize parameters.
    rng = jax.random.PRNGKey(0)
    init_vars = net.init(rng, jnp.zeros(input_shape), train=True)
    params = init_vars["params"]

    start = datetime.now().replace(microsecond=0)

    # Run training loop.
    state = solver.init_state(params)
    jitted_update = jax.jit(solver.update)

    train_accuracy = []
    test_accuracy = []

    for epoch in range(FLAGS.epochs):
        for step, train_minibatch in enumerate(trainloader_ds):
            train_minibatch = to_jax(train_minibatch)

            if step == iter_per_epoch - 1:
                # Once per epoch evaluate the model on one train batch and
                # one test batch.
                test_minibatch = to_jax(next(iter(testloader_ds)))
                train_acc, train_loss = accuracy_and_loss(
                    params, FLAGS.l2reg, train_minibatch)
                test_acc, test_loss = accuracy_and_loss(
                    params, FLAGS.l2reg, test_minibatch)

                train_acc, train_loss, test_acc, test_loss = jax.device_get(
                    (train_acc, train_loss, test_acc, test_loss))
                train_accuracy.append(train_acc)
                test_accuracy.append(test_acc)

                # time elapsed without microseconds
                time_elapsed = datetime.now().replace(microsecond=0) - start
                print(f"[Epoch {epoch + 1}/{FLAGS.epochs}] "
                      f"Train acc: {train_acc:.3f}, train loss: {train_loss:.3f}. "
                      f"Test acc: {test_acc:.3f}, test loss: {test_loss:.3f}. "
                      f"Time elapsed: {time_elapsed}")

            params, state = jitted_update(params=params,
                                          state=state,
                                          l2reg=FLAGS.l2reg,
                                          data=train_minibatch)

    plt.title(FLAGS.dataset)
    plt.plot(test_accuracy, lw=3, label='test accuracy')
    plt.plot(train_accuracy, lw=3, label='train accuracy')
    plt.grid()
    plt.legend()
    plt.show()


if __name__ == "__main__":
    app.run(main)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment