Last active
August 4, 2021 19:49
-
-
Save FirefoxMetzger/6b6ccf4f7c344459507e73bbd13ec541 to your computer and use it in GitHub Desktop.
A residual network built with Keras' Sequential() API and trained on CIFAR10
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
'''Train a simple residual network on the CIFAR10 small images dataset. | |
It gets to 75% validation accuracy in 25 epochs, and 79% after 50 epochs. | |
(it's still underfitting at that point, though). | |
''' | |
from __future__ import print_function | |
import keras | |
from keras.datasets import cifar10 | |
from keras.preprocessing.image import ImageDataGenerator | |
from keras.models import Sequential | |
from keras.layers import Dense, Dropout, Activation, Flatten | |
from keras.layers import Conv2D, MaxPooling2D, Add | |
import os | |
from keras.engine.topology import Layer | |
# Define the residual block as a new layer | |
class Residual(Layer):
    """Residual block that can be stacked inside a ``Sequential`` model.

    The block computes ``relu(conv2(relu(conv1(x))) + x)``: two same-padded
    convolutions whose result is added back onto the block input (the skip
    connection) before the final ReLU.

    The sub-layers are created once in ``__init__`` and reused by ``call``.
    Building new ``Conv2D`` layers inside ``call`` creates fresh variables
    every time the layer is traced, which is what triggers
    "tf.function-decorated function tried to create variables on non-first
    call", and keeps the convolution weights out of the block's own weights.

    NOTE(review): this relies on the framework tracking layers assigned as
    attributes (tf.keras, and standalone Keras >= 2.3 as far as I know).
    Confirm against the installed Keras version.

    Args:
        channels_in: Number of filters used by both convolutions. Because the
            convolution output is added to the block input, this is expected
            to equal the channel count of the input.
        kernel: Kernel size forwarded to both ``Conv2D`` layers.
        **kwargs: Forwarded unchanged to the base ``Layer`` constructor.
    """

    def __init__(self, channels_in, kernel, **kwargs):
        super(Residual, self).__init__(**kwargs)
        self.channels_in = channels_in
        self.kernel = kernel
        # Create every sub-layer exactly once so its weights belong to this
        # block instead of being re-created on each call.
        self.conv1 = Conv2D(channels_in, kernel, padding="same")
        self.relu1 = Activation("relu")
        self.conv2 = Conv2D(channels_in, kernel, padding="same")
        self.merge = Add()
        self.relu_out = Activation("relu")

    def call(self, x):
        """Apply the two convolutions, add the skip connection, then ReLU."""
        shortcut = x
        y = self.conv1(x)
        y = self.relu1(y)
        y = self.conv2(y)
        # Skip connection: element-wise sum of the conv path and the input.
        y = self.merge([y, shortcut])
        return self.relu_out(y)

    def compute_output_shape(self, input_shape):
        """Return ``input_shape`` unchanged; the block preserves its shape."""
        return input_shape

    def get_config(self):
        """Return the layer config including the constructor arguments."""
        config = super(Residual, self).get_config()
        config.update({"channels_in": self.channels_in,
                       "kernel": self.kernel})
        return config
# Training hyper-parameters and output location.
batch_size = 32
num_classes = 10
epochs = 100
data_augmentation = True
num_predictions = 20
model_name = 'keras_cifar10_trained_model.h5'
save_dir = os.path.join(os.getcwd(), 'saved_models')

# Load CIFAR10, already split into a training and a test set.
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
print('x_train shape:', x_train.shape)
print(len(x_train), 'train samples')
print(len(x_test), 'test samples')

# One-hot encode the integer class labels.
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)
# Network: conv stem -> five residual blocks -> dense classifier head.
# The layers are listed in the order in which they are applied.
model = Sequential(
    [
        Conv2D(32, (3, 3), padding='same',
               input_shape=x_train.shape[1:]),
        Activation('relu'),
    ]
    + [Residual(32, (3, 3)) for _ in range(5)]
    + [
        Flatten(),
        Dense(512),
        Activation('relu'),
        Dropout(0.5),
        Dense(num_classes),
        Activation('softmax'),
    ]
)
# Create the RMSprop optimizer. The class name `RMSprop` is used rather than
# the lowercase `rmsprop` alias, which only older standalone Keras provides.
opt = keras.optimizers.RMSprop(lr=0.0001, decay=1e-6)

# Train with categorical cross-entropy and report accuracy.
model.compile(loss='categorical_crossentropy',
              optimizer=opt,
              metrics=['accuracy'])

# Convert the images to float32 and divide by 255.
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
if data_augmentation:
    print('Using real-time data augmentation.')
    # Preprocessing / augmentation options applied on the fly to each batch.
    augmentation_options = dict(
        featurewise_center=False,  # set input mean to 0 over the dataset
        samplewise_center=False,  # set each sample mean to 0
        featurewise_std_normalization=False,  # divide inputs by dataset std
        samplewise_std_normalization=False,  # divide each input by its std
        zca_whitening=False,  # apply ZCA whitening
        rotation_range=0,  # random rotation range in degrees (0 to 180)
        width_shift_range=0.1,  # random horizontal shift (fraction of width)
        height_shift_range=0.1,  # random vertical shift (fraction of height)
        horizontal_flip=True,  # randomly flip images horizontally
        vertical_flip=False,  # randomly flip images vertically
    )
    datagen = ImageDataGenerator(**augmentation_options)

    # Compute the statistics needed for feature-wise normalization
    # (std, mean, and principal components if ZCA whitening is applied).
    datagen.fit(x_train)

    # Train on the augmented batches produced by the generator.
    augmented_batches = datagen.flow(x_train, y_train,
                                     batch_size=batch_size)
    model.fit_generator(augmented_batches,
                        epochs=epochs,
                        validation_data=(x_test, y_test),
                        workers=4)
else:
    print('Not using data augmentation.')
    model.fit(x_train, y_train,
              batch_size=batch_size,
              epochs=epochs,
              validation_data=(x_test, y_test),
              shuffle=True)
# Persist the trained model (architecture and weights) under save_dir,
# creating the directory first when it does not exist yet.
if not os.path.isdir(save_dir):
    os.makedirs(save_dir)
model_path = os.path.join(save_dir, model_name)
model.save(model_path)
print('Saved trained model at %s ' % model_path)

# Evaluate on the test set; scores holds the loss followed by the accuracy.
scores = model.evaluate(x_test, y_test, verbose=1)
test_loss, test_accuracy = scores[0], scores[1]
print('Test loss:', test_loss)
print('Test accuracy:', test_accuracy)
The skip connection is added in the Residual class at line 34. The entire thing is then interpreted as a single block / layer with only one input and one output. That is because a Sequential model can't fork; it only goes from one layer to the next in sequence.
You cannot see the skip connections in the model_to_dot output, because the model is of type Sequential. Assuming you are using the TensorFlow backend, you can look at the model in TensorBoard. There, you can expand / zoom into the blocks and will see that there is a skip connection in every Residual layer.
Got It!
Thank You Sir !
This does not work on the most recent version. I get: ValueError: tf.function-decorated function tried to create variables on non-first call.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Hi Sir! I have executed your code to understand how ResNet works, but when I try to view the skip connections made by the network, it displays the residual blocks as connected in sequence. Please explain, as I am new to this. Apologies if I have said something wrong.
Waiting for your positive reply!
Thanks!