Last active
March 25, 2020 03:51
-
-
Save nuzrub/6fcc46485f435b372ab0004618a26a98 to your computer and use it in GitHub Desktop.
Reproducing the main findings of the paper "Training BatchNorm and Only BatchNorm: On the Expressive Power of Random Features in CNNs"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Reproducing the main findings of the paper "Training BatchNorm and Only BatchNorm: On the Expressive Power of Random Features in CNNs"
# Goal: train a ResNet to classify CIFAR-10 while training only its BatchNorm layers; all other weights stay frozen at their random initial state.
# https://medium.com/@ygorrebouasserpa
# https://www.linkedin.com/in/ygor-rebouças-serpa-11606093/
import tensorflow as tf | |
import numpy as np | |
import pandas as pd | |
# (display label, Keras constructor) for every ResNet depth that gets trained below.
architectures = [
    (f'ResNet-{depth}', getattr(tf.keras.applications.resnet, f'ResNet{depth}'))
    for depth in (50, 101, 152)
]
# Load CIFAR-10, divide pixel values by 255 and one-hot encode the labels.
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.cifar10.load_data()
n_train_images, n_test_images = len(X_train), len(X_test)
# Assumes labels are contiguous 0-based class indices, so max + 1 is the class count.
n_classes = y_train.max() + 1
X_train, X_test = (images.astype(np.float32) / 255 for images in (X_train, X_test))
y_train, y_test = (
    tf.keras.utils.to_categorical(labels, n_classes)
    for labels in (y_train, y_test)
)
# Training hyperparameters.
# NOTE(review): a single epoch looks like a smoke-test setting; raise EPOCHS for a
# full reproduction of the paper's results.
LEARNING_RATE = 0.01
BATCH_SIZE = 1024
EPOCHS = 1


def _build_model(architecture):
    """Return ``(model, backbone)``: a randomly initialised ResNet plus a softmax head.

    ``weights=None`` keeps every convolution at its random initialisation, which is
    what this experiment studies. Loading ImageNet weights would make the frozen
    features pretrained rather than random.
    """
    inputs = tf.keras.layers.Input((32, 32, 3))
    backbone = architecture(include_top=False, weights=None,
                            input_shape=(32, 32, 3), pooling='avg')
    outputs = tf.keras.layers.Dense(n_classes, activation='softmax')(backbone(inputs))
    return tf.keras.models.Model(inputs=inputs, outputs=outputs), backbone


def _train_only_batchnorm(backbone):
    """Re-initialise the BatchNorm layers of *backbone* and freeze every other layer.

    Gamma is drawn from He-normal, beta and the moving mean are zeroed, and the
    moving variance is set to one. The Dense head is not part of *backbone*, so it
    stays trainable.
    """
    he_normal = tf.keras.initializers.he_normal()
    for layer in backbone.layers:
        # Keras' ResNet (v1) builders name their BatchNorm layers '<block>_bn'.
        if layer.name.endswith('_bn'):
            layer.set_weights([
                he_normal(layer.gamma.shape),
                tf.zeros(layer.beta.shape),
                tf.zeros(layer.moving_mean.shape),
                tf.ones(layer.moving_variance.shape),  # variance, not std
            ])
            layer.trainable = True
        else:
            layer.trainable = False


for name, architecture in architectures:
    # Release global Keras state left by the previous (large) model before building the next.
    tf.keras.backend.clear_session()
    model, backbone = _build_model(architecture)
    _train_only_batchnorm(backbone)
    model.summary()

    # The head already applies softmax, so the loss must receive probabilities.
    # from_logits=True would apply a second softmax on top of them.
    model.compile(loss=tf.keras.losses.CategoricalCrossentropy(from_logits=False),
                  optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),
                  metrics=['accuracy'])

    print('Training ' + name + '...')
    history = model.fit(X_train, y_train, batch_size=BATCH_SIZE, epochs=EPOCHS,
                        validation_data=(X_test, y_test), shuffle=True)

    print('Dumping model and history...')
    pd.DataFrame(history.history).to_csv(name + '.csv', sep=';')
    model.save(name + '.h5')

print('Testing Complete!')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment