@FindHao
Created December 1, 2022 19:23
A ResNet training example for JAX/Flax on CIFAR10
# https://github.com/phlippe/uvadlc_notebooks_benchmarking/blob/main/PyTorch/Tutorial5_Inception_ResNet_DenseNet.py
from torchvision.datasets import CIFAR10
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
import torch.utils.data as data
import torch
from flax.training import train_state, checkpoints
from flax import linen as nn
from jax import random
import jax.numpy as jnp
import jax
from tqdm.auto import tqdm
import os
import numpy as np
from typing import Any
from collections import defaultdict
import time
import optax
DATASET_PATH = "../data"
CHECKPOINT_PATH = "../saved_models/tutorial5_jax"
# Make sure the checkpoint and log directories exist before writing to them
os.makedirs(CHECKPOINT_PATH, exist_ok=True)
os.makedirs("../logs", exist_ok=True)
timestr = time.strftime("%Y_%m_%d__%H_%M_%S")
LOG_FILE = open(f'../logs/tutorial5_jax_{timestr}.txt', 'w')
main_rng = random.PRNGKey(42)
print("Device:", jax.devices()[0])
train_dataset = CIFAR10(root=DATASET_PATH, train=True, download=True)
DATA_MEANS = (train_dataset.data / 255.0).mean(axis=(0, 1, 2))
DATA_STD = (train_dataset.data / 255.0).std(axis=(0, 1, 2))
def image_to_numpy(img):
    img = np.array(img, dtype=np.float32)
    img = (img / 255. - DATA_MEANS) / DATA_STD
    return img
def numpy_collate(batch):
    if isinstance(batch[0], np.ndarray):
        return np.stack(batch)
    elif isinstance(batch[0], (tuple, list)):
        transposed = zip(*batch)
        return [numpy_collate(samples) for samples in transposed]
    else:
        return np.array(batch)
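# (Illustrative addition, not part of the original gist.) numpy_collate keeps
# batches as NumPy arrays instead of torch tensors, so they can be fed straight
# into the jitted JAX functions below. A quick sanity check of its behaviour:
_example_batch = numpy_collate([(np.zeros((32, 32, 3), dtype=np.float32), 0),
                                (np.ones((32, 32, 3), dtype=np.float32), 1)])
assert _example_batch[0].shape == (2, 32, 32, 3) and _example_batch[1].shape == (2,)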
test_transform = image_to_numpy
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomResizedCrop((32, 32), scale=(0.8, 1.0), ratio=(0.9, 1.1)),
    image_to_numpy,
])
train_dataset = CIFAR10(root=DATASET_PATH, train=True,
                        transform=train_transform, download=True)
val_dataset = CIFAR10(root=DATASET_PATH, train=True,
                      transform=test_transform, download=True)
train_set, _ = torch.utils.data.random_split(
    train_dataset, [45000, 5000], generator=torch.Generator().manual_seed(42))
_, val_set = torch.utils.data.random_split(
    val_dataset, [45000, 5000], generator=torch.Generator().manual_seed(42))
test_set = CIFAR10(root=DATASET_PATH, train=False,
                   transform=test_transform, download=True)
train_loader = data.DataLoader(
    train_set, batch_size=128, shuffle=True, drop_last=True,
    collate_fn=numpy_collate, num_workers=8, persistent_workers=True)
val_loader = data.DataLoader(
    val_set, batch_size=128, shuffle=False, drop_last=False,
    collate_fn=numpy_collate, num_workers=4, persistent_workers=True)
test_loader = data.DataLoader(
    test_set, batch_size=128, shuffle=False, drop_last=False,
    collate_fn=numpy_collate, num_workers=4, persistent_workers=True)
class TrainState(train_state.TrainState):
    # A simple extension of TrainState to also include batch statistics
    batch_stats: Any
class TrainerModule:
    def __init__(self,
                 model_name: str,
                 model_class: nn.Module,
                 model_hparams: dict,
                 optimizer_name: str,
                 optimizer_hparams: dict,
                 exmp_imgs: Any,
                 seed=42):
        """
        Module for summarizing all training functionalities for classification on CIFAR10.
        Inputs:
            model_name - String of the class name, used for logging and saving
            model_class - Class implementing the neural network
            model_hparams - Hyperparameters of the model, used as input to model constructor
            optimizer_name - String of the optimizer name, supporting ['sgd', 'adam', 'adamw']
            optimizer_hparams - Hyperparameters of the optimizer, including learning rate as 'lr'
            exmp_imgs - Example imgs, used as input to initialize the model
            seed - Seed to use in the model initialization
        """
        super().__init__()
        self.model_name = model_name
        self.model_class = model_class
        self.model_hparams = model_hparams
        self.optimizer_name = optimizer_name
        self.optimizer_hparams = optimizer_hparams
        self.seed = seed
        # Create empty model. Note: no parameters yet
        self.model = self.model_class(**self.model_hparams)
        # Prepare logging
        self.log_dir = os.path.join(CHECKPOINT_PATH, self.model_name)
        self.logger = SummaryWriter(log_dir=self.log_dir)
        # Create jitted training and eval functions
        self.create_functions()
        # Initialize model
        self.init_model(exmp_imgs)
    def create_functions(self):
        # Function to calculate the classification loss and accuracy for a model
        def calculate_loss(params, batch_stats, batch, train):
            imgs, labels = batch
            labels_onehot = jax.nn.one_hot(
                labels, num_classes=self.model.num_classes)
            # Run model. During training, we need to update the BatchNorm statistics.
            outs = self.model.apply({'params': params, 'batch_stats': batch_stats},
                                    imgs,
                                    train=train,
                                    mutable=['batch_stats'] if train else False)
            logits, new_model_state = outs if train else (outs, None)
            loss = optax.softmax_cross_entropy(logits, labels_onehot).mean()
            acc = (logits.argmax(axis=-1) == labels).mean()
            return loss, (acc, new_model_state)
        # Training function
        def train_step(state, batch):
            def loss_fn(params):
                return calculate_loss(params, state.batch_stats, batch, train=True)
            # Get loss, gradients for loss, and other outputs of loss function
            ret, grads = jax.value_and_grad(
                loss_fn, has_aux=True)(state.params)
            loss, acc, new_model_state = ret[0], *ret[1]
            # Update parameters and batch statistics
            state = state.apply_gradients(
                grads=grads, batch_stats=new_model_state['batch_stats'])
            return state, loss, acc
        # Eval function
        def eval_step(state, batch):
            # Return the accuracy for a single batch
            _, (acc, _) = calculate_loss(state.params,
                                         state.batch_stats, batch, train=False)
            return acc
        # jit for efficiency
        self.train_step = jax.jit(train_step)
        self.eval_step = jax.jit(eval_step)
    def init_model(self, exmp_imgs):
        # Initialize model
        init_rng = jax.random.PRNGKey(self.seed)
        variables = self.model.init(init_rng, exmp_imgs, train=True)
        self.init_params, self.init_batch_stats = variables['params'], variables['batch_stats']
        self.state = None
    def init_optimizer(self, num_epochs, num_steps_per_epoch):
        # Initialize learning rate schedule and optimizer
        if self.optimizer_name.lower() == 'adam':
            opt_class = optax.adam
        elif self.optimizer_name.lower() == 'adamw':
            opt_class = optax.adamw
        elif self.optimizer_name.lower() == 'sgd':
            opt_class = optax.sgd
        else:
            assert False, f'Unknown optimizer "{self.optimizer_name}"'
        # We decrease the learning rate by a factor of 0.1 after 60% and 85% of the training
        # (see the schedule example after this class)
        lr_schedule = optax.piecewise_constant_schedule(
            init_value=self.optimizer_hparams.pop('lr'),
            boundaries_and_scales={int(num_steps_per_epoch*num_epochs*0.6): 0.1,
                                   int(num_steps_per_epoch*num_epochs*0.85): 0.1}
        )
        # Clip gradients at a maximum value, and apply weight decay if requested
        transf = [optax.clip(1.0)]
        if opt_class == optax.sgd and 'weight_decay' in self.optimizer_hparams:  # weight decay is built into adamw
            transf.append(optax.add_decayed_weights(
                self.optimizer_hparams.pop('weight_decay')))
        optimizer = optax.chain(
            *transf,
            opt_class(lr_schedule, **self.optimizer_hparams)
        )
        # Initialize training state
        self.state = TrainState.create(apply_fn=self.model.apply,
                                       params=self.init_params if self.state is None else self.state.params,
                                       batch_stats=self.init_batch_stats if self.state is None else self.state.batch_stats,
                                       tx=optimizer)
    def train_model(self, train_loader, val_loader, num_epochs=200):
        # Train model for defined number of epochs
        # We first need to create optimizer and the scheduler for the given number of epochs
        self.init_optimizer(num_epochs, len(train_loader))
        # Track best eval accuracy
        best_eval = 0.0
        for epoch_idx in tqdm(range(1, num_epochs+1)):
            self.train_epoch(train_loader, epoch=epoch_idx)
            if epoch_idx % 2 == 0:
                eval_acc = self.eval_model(val_loader)
                self.logger.add_scalar(
                    'val/acc', eval_acc, global_step=epoch_idx)
                if eval_acc >= best_eval:
                    best_eval = eval_acc
                    self.save_model(step=epoch_idx)
                self.logger.flush()
    def train_epoch(self, train_loader, epoch):
        # Train model for one epoch, and log avg loss and accuracy
        metrics = defaultdict(list)
        for batch in tqdm(train_loader, desc='Training', leave=False):
            self.state, loss, acc = self.train_step(self.state, batch)
            metrics['loss'].append(loss)
            metrics['acc'].append(acc)
        for key in metrics:
            avg_val = np.stack(jax.device_get(metrics[key])).mean()
            self.logger.add_scalar('train/'+key, avg_val, global_step=epoch)
    def eval_model(self, data_loader):
        # Test model on all images of a data loader and return the avg accuracy
        correct_class, count = 0, 0
        for batch in data_loader:
            acc = self.eval_step(self.state, batch)
            correct_class += acc * batch[0].shape[0]
            count += batch[0].shape[0]
        eval_acc = (correct_class / count).item()
        return eval_acc
    def save_model(self, step=0):
        # Save current model at certain training iteration
        checkpoints.save_checkpoint(ckpt_dir=self.log_dir,
                                    target={'params': self.state.params,
                                            'batch_stats': self.state.batch_stats},
                                    step=step,
                                    overwrite=True)
    def load_model(self, pretrained=False):
        # Load model. We use a different checkpoint for pretrained models
        if not pretrained:
            state_dict = checkpoints.restore_checkpoint(
                ckpt_dir=self.log_dir, target=None)
        else:
            state_dict = checkpoints.restore_checkpoint(ckpt_dir=os.path.join(
                CHECKPOINT_PATH, f'{self.model_name}.ckpt'), target=None)
        self.state = TrainState.create(apply_fn=self.model.apply,
                                       params=state_dict['params'],
                                       batch_stats=state_dict['batch_stats'],
                                       tx=self.state.tx if self.state else optax.sgd(0.1)  # Default optimizer
                                       )
    def checkpoint_exists(self):
        # Check whether a pretrained model exists for this classifier
        return os.path.isfile(os.path.join(CHECKPOINT_PATH, f'{self.model_name}.ckpt'))
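# (Illustrative addition, not part of the original gist.) The schedule built in
# init_optimizer is piecewise constant: optax scales the initial learning rate
# by 0.1 once 60% and again once 85% of the total steps have passed. A small
# example with 100 total steps:
_example_schedule = optax.piecewise_constant_schedule(
    init_value=0.1, boundaries_and_scales={60: 0.1, 85: 0.1})
assert np.isclose(_example_schedule(0), 0.1)
assert np.isclose(_example_schedule(70), 0.01)
assert np.isclose(_example_schedule(90), 0.001)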
def train_classifier(*args, num_epochs=200, **kwargs):
    # Create a trainer module with specified hyperparameters
    trainer = TrainerModule(*args, **kwargs)
    start_time = time.time()
    with jax.profiler.trace("/tmp/jax-trace"):
        trainer.train_model(train_loader, val_loader, num_epochs=num_epochs)
    train_time = time.time()
    print(trainer.model_name, ' - Full training time:',
          time.strftime('%H:%M:%S', time.gmtime(train_time - start_time)),
          file=LOG_FILE, flush=True)
    # This benchmarking script only logs the training time; the trainer and its
    # results are intentionally not returned.
    return None, None
resnet_kernel_init = nn.initializers.variance_scaling(
    2.0, mode='fan_out', distribution='normal')
class ResNetBlock(nn.Module):
    act_fn: callable  # Activation function
    c_out: int  # Output feature size
    subsample: bool = False  # If True, we apply a stride inside F
    @nn.compact
    def __call__(self, x, train=True):
        # Network representing F
        z = nn.Conv(self.c_out, kernel_size=(3, 3),
                    strides=(1, 1) if not self.subsample else (2, 2),
                    kernel_init=resnet_kernel_init,
                    use_bias=False)(x)
        z = nn.BatchNorm()(z, use_running_average=not train)
        z = self.act_fn(z)
        z = nn.Conv(self.c_out, kernel_size=(3, 3),
                    kernel_init=resnet_kernel_init,
                    use_bias=False)(z)
        z = nn.BatchNorm()(z, use_running_average=not train)
        if self.subsample:
            x = nn.Conv(self.c_out, kernel_size=(1, 1), strides=(2, 2),
                        kernel_init=resnet_kernel_init)(x)
        x_out = self.act_fn(z + x)
        return x_out
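# (Illustrative addition, not part of the original gist.) A subsampling
# ResNetBlock halves the spatial resolution and projects the skip connection
# with a strided 1x1 convolution, so the residual sum stays shape-compatible:
_block = ResNetBlock(act_fn=nn.relu, c_out=32, subsample=True)
_block_vars = _block.init(main_rng, jnp.ones((1, 32, 32, 16)), train=True)
_block_out, _ = _block.apply(_block_vars, jnp.ones((1, 32, 32, 16)),
                             train=True, mutable=['batch_stats'])
assert _block_out.shape == (1, 16, 16, 32)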
class ResNet(nn.Module):
    num_classes: int
    act_fn: callable
    block_class: nn.Module
    num_blocks: tuple = (3, 3, 3)
    c_hidden: tuple = (16, 32, 64)
    @nn.compact
    def __call__(self, x, train=True):
        # A first convolution on the original image to scale up the channel size
        x = nn.Conv(self.c_hidden[0], kernel_size=(3, 3),
                    kernel_init=resnet_kernel_init, use_bias=False)(x)
        if self.block_class == ResNetBlock:
            # The standard ResNet block expects BatchNorm and the activation here;
            # a pre-activation block would apply them inside the block instead.
            x = nn.BatchNorm()(x, use_running_average=not train)
            x = self.act_fn(x)
        # Creating the ResNet blocks
        for block_idx, block_count in enumerate(self.num_blocks):
            for bc in range(block_count):
                # Subsample the first block of each group, except the very first one.
                subsample = (bc == 0 and block_idx > 0)
                # ResNet block
                x = self.block_class(c_out=self.c_hidden[block_idx],
                                     act_fn=self.act_fn,
                                     subsample=subsample)(x, train=train)
        # Mapping to classification output
        x = x.mean(axis=(1, 2))
        x = nn.Dense(self.num_classes)(x)
        return x
resnet_trainer, resnet_results = train_classifier(
    model_name="ResNet",
    model_class=ResNet,
    model_hparams={"num_classes": 10,
                   "c_hidden": (16, 32, 64),
                   "num_blocks": (3, 3, 3),
                   "act_fn": nn.relu,
                   "block_class": ResNetBlock},
    optimizer_name="SGD",
    optimizer_hparams={"lr": 0.1,
                       "momentum": 0.9,
                       "weight_decay": 1e-4},
    exmp_imgs=jax.device_put(next(iter(train_loader))[0]),
    num_epochs=1)
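# (Illustrative follow-up, not part of the original gist.) test_loader is
# defined above but never used. Because the training loop only evaluates and
# saves a checkpoint on even epochs, the num_epochs=1 run above stores nothing;
# after a longer run, test accuracy could be reported with the existing helpers:
#     trainer = TrainerModule(model_name="ResNet", model_class=ResNet, ...)  # same hparams as above
#     trainer.load_model()
#     print("Test accuracy:", trainer.eval_model(test_loader), file=LOG_FILE, flush=True)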