Skip to content

Instantly share code, notes, and snippets.

@nuzrub
Last active March 25, 2020 03:51
Show Gist options
  • Save nuzrub/6fcc46485f435b372ab0004618a26a98 to your computer and use it in GitHub Desktop.
Save nuzrub/6fcc46485f435b372ab0004618a26a98 to your computer and use it in GitHub Desktop.
Reproducing the main findings of the paper "Training BatchNorm and Only BatchNorm: On the Expressive Power of Random Features in CNNs"
# Reproducing the main findings of the paper "Training BatchNorm and Only BatchNorm: On the Expressive Power of Random Features in CNNs"
# Goal: train ResNet-50/101/152 on CIFAR-10 while updating only the BatchNorm layers (and the new classification head); every other weight is meant to stay frozen at its random initial value.
# https://medium.com/@ygorrebouasserpa
# https://www.linkedin.com/in/ygor-rebouças-serpa-11606093/
import tensorflow as tf
import numpy as np
import pandas as pd
# Backbones to evaluate, in order, as (run name, Keras ResNet constructor)
# pairs. The run name is also used as the stem of the saved .csv/.h5 files.
architectures = list({
    'ResNet-50': tf.keras.applications.resnet.ResNet50,
    'ResNet-101': tf.keras.applications.resnet.ResNet101,
    'ResNet-152': tf.keras.applications.resnet.ResNet152,
}.items())
# Load CIFAR-10 and convert it into model-ready arrays.
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.cifar10.load_data()
n_train_images, n_test_images = X_train.shape[0], X_test.shape[0]
# Number of classes = largest integer label + 1 (assumes labels start at 0).
n_classes = y_train.max() + 1

# Convert images to float32 and divide by 255 (pixels presumably uint8 0..255).
X_train, X_test = [split.astype(np.float32) / 255 for split in (X_train, X_test)]
# One-hot encode the integer labels to match CategoricalCrossentropy.
y_train, y_test = [
    tf.keras.utils.to_categorical(split, n_classes) for split in (y_train, y_test)
]
def _build_model(architecture, n_classes):
    """Build a randomly initialised backbone with a softmax classification head.

    ``weights=None`` is essential: the experiment keeps every non-BatchNorm
    layer frozen at its *random* initial value, so no pretrained weights
    may be loaded.

    Returns:
        A ``(model, backbone)`` tuple, where ``backbone`` is the nested
        ResNet sub-model whose layers get frozen.
    """
    inputs = tf.keras.layers.Input((32, 32, 3))
    backbone = architecture(include_top=False, weights=None,
                            input_shape=(32, 32, 3), pooling='avg')
    features = backbone(inputs)
    # The head outputs probabilities (softmax). It is not frozen, so it is
    # trained alongside the BatchNorm parameters.
    outputs = tf.keras.layers.Dense(n_classes, activation='softmax')(features)
    model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
    return model, backbone


def _freeze_all_but_batchnorm(backbone, gamma_initializer):
    """Freeze every layer of ``backbone`` except its BatchNormalization layers.

    Each BatchNormalization layer is re-initialised: gamma is drawn from
    ``gamma_initializer``, beta and the moving mean are set to zero, and
    the moving variance is set to one.

    Raises:
        RuntimeError: if ``backbone`` contains no BatchNormalization layer.
            In that case the run would silently train only the head.
    """
    n_batchnorm = 0
    for layer in backbone.layers:
        if isinstance(layer, tf.keras.layers.BatchNormalization):
            # Order must match the layer's weights: gamma, beta, moving mean,
            # moving variance.
            layer.set_weights([
                gamma_initializer(layer.gamma.shape),
                tf.zeros(layer.beta.shape),
                tf.zeros(layer.moving_mean.shape),
                tf.ones(layer.moving_variance.shape),
            ])
            layer.trainable = True
            n_batchnorm += 1
        else:
            layer.trainable = False
    if n_batchnorm == 0:
        raise RuntimeError('No BatchNormalization layers found in ' + backbone.name)


for name, architecture in architectures:
    model, backbone = _build_model(architecture, n_classes)
    _freeze_all_but_batchnorm(backbone, tf.keras.initializers.he_normal())
    model.summary()

    # The head already applies softmax, so the loss receives probabilities.
    # from_logits=True would apply softmax a second time.
    model.compile(
        loss=tf.keras.losses.CategoricalCrossentropy(from_logits=False),
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.01),
        metrics=['accuracy'],
    )

    print('Training ' + name + '...')
    history = model.fit(X_train, y_train, batch_size=1024, epochs=1,
                        validation_data=(X_test, y_test), shuffle=True)
    history_df = pd.DataFrame(history.history)

    print('Dumping model and history...')
    history_df.to_csv(name + '.csv', sep=';')
    model.save(name + '.h5')

print('Testing Complete!')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment