Skip to content

Instantly share code, notes, and snippets.

@varun19299
Created November 26, 2020 09:40
Show Gist options
  • Save varun19299/405ef27e85d79dee61d11a1b3cfacd5d to your computer and use it in GitHub Desktop.
Save varun19299/405ef27e85d79dee61d11a1b3cfacd5d to your computer and use it in GitHub Desktop.
WRN-22-2
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from six.moves import range
import os
import logging
logging.basicConfig(level=logging.DEBUG)
import sys
#sys.stdout = sys.stderr
# Prevent reaching to maximum recursion depth in `theano.tensor.grad`
#sys.setrecursionlimit(2 ** 20)
import numpy as np
np.random.seed(2 ** 10)
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D, AveragePooling2D, BatchNormalization, Dropout, Input, Activation, Add, Dense, Flatten
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.regularizers import l2
from tensorflow.keras.callbacks import LearningRateScheduler, ModelCheckpoint
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.utils import to_categorical
from tensorflow.keras import backend as K
# ================================================
# NETWORK/TRAINING CONFIGURATION:
logging.debug("Loading network/training configuration...")

depth = 22  # table 5 on page 8 indicates best value (4.17) CIFAR-10
k = 2  # 'widen_factor'; table 5 on page 8 indicates best value (4.17) CIFAR-10
dropout_probability = 0  # table 6 on page 10 indicates best value (4.17) CIFAR-10
weight_decay = 0.0005  # page 10: "Used in all experiments"
batch_size = 128  # page 8: "Used in all experiments"

# Other config from code; used throughout all layers:
use_bias = False  # following functions 'FCinit(model)' and 'DisableBias(model)' in utils.lua
weight_init = "he_normal"  # follows the 'MSRinit(model)' function in utils.lua
nb_classes = 10

# Keras specific: choose the channel axis and the input shape that match the
# backend's image data format.
# `K.image_data_format()` reports "channels_first" or "channels_last"; it never
# returns the legacy "th"/"tf" names, so comparing against "th" made the
# channels-first branch unreachable. The log messages keep the legacy labels.
if K.image_data_format() == "channels_first":
    logging.debug("image_dim_ordering = 'th'")
    channel_axis = 1
    input_shape = (3, 32, 32)
else:
    logging.debug("image_dim_ordering = 'tf'")
    channel_axis = -1
    input_shape = (32, 32, 3)
# Wide residual network http://arxiv.org/abs/1605.07146
def _wide_basic(n_input_plane, n_output_plane, stride):
    """Return a function that applies one B(3,3) "basic" wide residual block.

    The block is BN -> ReLU -> 3x3 conv (with `stride`) -> BN -> ReLU ->
    optional Dropout -> 3x3 conv (stride 1), added to a shortcut branch.

    Args:
        n_input_plane: number of channels of the tensor entering the block.
        n_output_plane: number of filters of both 3x3 convolutions.
        stride: strides of the first 3x3 convolution and of the 1x1
            shortcut convolution (when one is used).

    Returns:
        A callable taking a tensor and returning the block's output tensor.
    """
    # When the channel count changes, the shortcut is a 1x1 convolution fed
    # by the pre-activated input; otherwise it is the untouched input.
    needs_projection = n_input_plane != n_output_plane

    def bn_relu(tensor):
        # Batch normalisation over the channel axis followed by ReLU.
        normalised = BatchNormalization(axis=channel_axis)(tensor)
        return Activation("relu")(normalised)

    def conv(kernel_size, strides):
        # Every convolution in the block shares the same padding, initialiser,
        # weight decay and bias settings; only kernel size and strides vary.
        return Conv2D(n_output_plane,
                      kernel_size,
                      strides=strides,
                      padding="same",
                      kernel_initializer=weight_init,
                      kernel_regularizer=l2(weight_decay),
                      use_bias=use_bias)

    def f(net):
        # First half of the residual branch: pre-activation + strided 3x3 conv.
        pre_activated = bn_relu(net)
        if needs_projection:
            # The projection shortcut below branches off after BN/ReLU.
            net = pre_activated
        residual = conv((3, 3), stride)(pre_activated)

        # Second half: BN/ReLU, optional dropout, then a stride-1 3x3 conv.
        residual = bn_relu(residual)
        if dropout_probability > 0:
            residual = Dropout(dropout_probability)(residual)
        residual = conv((3, 3), (1, 1))(residual)

        # Shortcut connection: 1x1 convolution when the channel count changes
        # (the first block of a group, see _layer()), identity otherwise.
        shortcut = conv((1, 1), stride)(net) if needs_projection else net

        return Add()([residual, shortcut])

    return f
# "Stacking Residual Units on the same stage"
def _layer(block, n_input_plane, n_output_plane, count, stride):
def f(net):
net = block(n_input_plane, n_output_plane, stride)(net)
for i in range(2,int(count+1)):
net = block(n_output_plane, n_output_plane, stride=(1,1))(net)
return net
return f
def create_model():
    """Build the wide residual network described by the module-level config.

    Reads `depth`, `k`, `input_shape`, `channel_axis`, `nb_classes` and the
    shared layer settings (`weight_init`, `weight_decay`, `use_bias`).

    Returns:
        A Keras `Model` mapping an `input_shape` image to `nb_classes`
        softmax probabilities. The model is not compiled.

    Raises:
        ValueError: if `depth` is not of the form 6n + 4.
    """
    logging.debug("Creating model...")

    # Validate with an explicit raise: an `assert` is stripped under
    # `python -O`, and a non-conforming depth would then silently build a
    # network with a truncated block count. divmod also keeps the block
    # count an integer despite `from __future__ import division`.
    n, remainder = divmod(depth - 4, 6)
    if remainder != 0:
        raise ValueError("depth must be of the form 6n + 4, got %r" % (depth,))

    inputs = Input(shape=input_shape)

    n_stages = [16, 16 * k, 32 * k, 64 * k]

    conv1 = Conv2D(n_stages[0],
                   (3, 3),
                   strides=1,
                   padding="same",
                   kernel_initializer=weight_init,
                   kernel_regularizer=l2(weight_decay),
                   use_bias=use_bias)(inputs)  # "One conv at the beginning (spatial size: 32x32)"

    # Add wide residual blocks
    block_fn = _wide_basic
    conv2 = _layer(block_fn, n_input_plane=n_stages[0], n_output_plane=n_stages[1], count=n, stride=(1, 1))(conv1)  # "Stage 1 (spatial size: 32x32)"
    conv3 = _layer(block_fn, n_input_plane=n_stages[1], n_output_plane=n_stages[2], count=n, stride=(2, 2))(conv2)  # "Stage 2 (spatial size: 16x16)"
    conv4 = _layer(block_fn, n_input_plane=n_stages[2], n_output_plane=n_stages[3], count=n, stride=(2, 2))(conv3)  # "Stage 3 (spatial size: 8x8)"

    batch_norm = BatchNormalization(axis=channel_axis)(conv4)
    relu = Activation("relu")(batch_norm)

    # Classifier block
    # NOTE(review): with strides=(1, 1) and padding="same" this pooling
    # presumably keeps the 8x8 spatial size instead of reducing it to 1x1, so
    # Flatten would feed 8*8*n_stages[3] features into the Dense layer.
    # Left unchanged to stay compatible with existing weights; confirm
    # whether a global average (padding="valid") was intended.
    pool = AveragePooling2D(pool_size=(8, 8), strides=(1, 1), padding="same")(relu)
    flatten = Flatten()(pool)
    predictions = Dense(units=nb_classes, kernel_initializer=weight_init, use_bias=use_bias,
                        kernel_regularizer=l2(weight_decay), activation="softmax")(flatten)

    model = Model(inputs=inputs, outputs=predictions)
    return model
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment