Last active
December 5, 2018 06:39
-
-
Save koshian2/aa697ceea8918303edfc0f21a9054c10 to your computer and use it in GitHub Desktop.
RICAP on CIFAR
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, Input, AveragePooling2D, GlobalAveragePooling2D, Dense | |
from tensorflow.keras.models import Model | |
from tensorflow.keras.preprocessing.image import ImageDataGenerator | |
from tensorflow.keras.callbacks import History, Callback | |
from tensorflow.contrib.tpu.python.tpu import keras_support | |
import tensorflow.keras.backend as K | |
from keras.datasets import cifar10 | |
from keras.utils import to_categorical | |
from ricap import ricap | |
import numpy as np | |
import os, pickle, glob, zipfile | |
def create_block(input, ch, reps):
    """Stack ``reps`` conv stages on top of ``input`` and return the result.

    Each stage applies a 3x3 ``Conv2D`` with ``ch`` filters and "same"
    padding, then ``BatchNormalization``, then a ReLU ``Activation``.
    When ``reps`` is zero or negative no layer is created and ``input``
    is returned unchanged.
    """
    tensor = input
    for _ in range(reps):
        convolved = Conv2D(ch, 3, padding="same")(tensor)
        normalized = BatchNormalization()(convolved)
        tensor = Activation("relu")(normalized)
    return tensor
def create_network():
    """Build the convolutional classifier used for the experiments.

    The model takes 32x32x3 inputs, runs three conv blocks of three stages
    each (64, 128 and 256 filters), halves the spatial size with average
    pooling after the first two blocks, and finishes with global average
    pooling and a 10-way softmax ``Dense`` layer.
    """
    inputs = Input((32, 32, 3))
    features = inputs
    # The first two blocks are each followed by 2x2 average pooling.
    for width in (64, 128):
        features = create_block(features, width, 3)
        features = AveragePooling2D(2)(features)
    # The last block is reduced by global pooling instead.
    features = create_block(features, 256, 3)
    pooled = GlobalAveragePooling2D()(features)
    probabilities = Dense(10, activation="softmax")(pooled)
    return Model(inputs, probabilities)
class RICAPGenerator(ImageDataGenerator):
    """``ImageDataGenerator`` whose ``flow`` batches go through ``ricap``.

    ``beta`` is an ordinary attribute that is re-read for every batch, so
    it can be changed from outside while the generator is being consumed.
    All remaining positional/keyword arguments are forwarded to
    ``ImageDataGenerator``.
    """

    def __init__(self, use_batchwise_random, beta, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.beta = beta
        self.use_batchwise_random = use_batchwise_random

    def flow(self, *args, **kwargs):
        """Yield ``(X, y)`` batches from the parent ``flow``.

        While ``self.beta`` is positive each batch is handed to the external
        ``ricap`` helper together with ``beta`` and ``use_batchwise_random``;
        otherwise the batch is yielded as produced by the parent class.
        """
        parent_batches = super().flow(*args, **kwargs)
        for images, labels in parent_batches:
            if self.beta > 0:
                mixed_images, mixed_labels = ricap(
                    images, labels, self.beta, self.use_batchwise_random)
                yield mixed_images, mixed_labels
                continue
            # No RICAP: pass the batch through untouched.
            yield images, labels
class RICAPBetaCallback(Callback):
    """Keras callback that reschedules a generator's ``beta`` after each epoch.

    At the end of epoch ``e`` (the index Keras passes to ``on_epoch_end``)
    the generator's ``beta`` attribute is set to

        beta_min + (beta_max - beta_min) * (1 - cos(pi * e / period)) / 2

    so the value equals ``beta_min`` for ``e == 0``, reaches ``beta_max``
    when ``e == period`` and is back at ``beta_min`` when ``e == 2 * period``.

    Args:
        period: Number of epochs for beta to go from ``beta_min`` to
            ``beta_max``; must be non-zero because it is used as a divisor.
        beta_min: Lower bound of the schedule.
        beta_max: Upper bound of the schedule.
        generator: Object whose ``beta`` attribute is overwritten.
    """

    def __init__(self, period, beta_min, beta_max, generator):
        # Initialise the Keras base class so its own attributes exist even
        # before the training loop attaches a model to this callback.
        super().__init__()
        self.gen = generator
        self.beta_min = beta_min
        self.beta_max = beta_max
        self.period = period

    def on_epoch_end(self, epoch, logs=None):
        """Compute the scheduled beta for ``epoch`` and store it on the generator.

        ``logs`` is accepted for compatibility with the Keras hook signature
        and is not used.
        """
        beta = self.beta_min + (self.beta_max-self.beta_min)*(1.0-np.cos(epoch/self.period*np.pi))/2.0
        print("set beta to", beta)
        self.gen.beta = beta
def train(period, beta_max):
    """Train the network on CIFAR-10 on a Colab TPU and pickle the history.

    Training batches come from a ``RICAPGenerator`` whose ``beta`` starts at
    0.01 and is rescheduled after every epoch by a ``RICAPBetaCallback``
    running between 0.01 and ``beta_max``. The Keras history dict is written
    to ``active_result/period_{period}_betamax_{beta_max}.dat``.

    Args:
        period: Schedule period handed to ``RICAPBetaCallback``.
        beta_max: Upper bound of the scheduled beta.

    Raises:
        KeyError: If the ``COLAB_TPU_ADDR`` environment variable is not set.
        OSError: If the ``active_result`` directory cannot be created
            (for example because a regular file with that name exists).
    """
    (X_train, y_train), (X_test, y_test) = cifar10.load_data()
    y_train, y_test = to_categorical(y_train), to_categorical(y_test)

    batch_size = 2048
    # Keep a handle on the generator instance: the callback mutates its beta.
    train_gen_instance = RICAPGenerator(rescale=1.0/255.0, use_batchwise_random=False, beta=0.01)
    train_generator = train_gen_instance.flow(X_train, y_train, batch_size=batch_size)
    test_generator = ImageDataGenerator(rescale=1.0/255.0).flow(
        X_test, y_test, batch_size=batch_size)

    model = create_network()
    model.compile(tf.train.AdamOptimizer(), "categorical_crossentropy", ["acc"])

    # Convert the compiled Keras model into a TPU model for the Colab runtime.
    tpu_grpc_url = "grpc://"+os.environ["COLAB_TPU_ADDR"]
    tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(tpu_grpc_url)
    strategy = keras_support.TPUDistributionStrategy(tpu_cluster_resolver)
    model = tf.contrib.tpu.keras_to_tpu_model(model, strategy=strategy)

    ricap_scheduler = RICAPBetaCallback(period, 0.01, beta_max, train_gen_instance)

    # Create the output directory up front, without a check-then-create race,
    # so an unusable path fails before the long training run, not after it.
    os.makedirs("active_result", exist_ok=True)

    hist = History()
    # Integer division: a trailing partial batch is not counted as a step.
    model.fit_generator(train_generator, steps_per_epoch=X_train.shape[0]//batch_size,
                        validation_data=test_generator, validation_steps=X_test.shape[0]//batch_size,
                        epochs=150, callbacks=[hist, ricap_scheduler])
    history = hist.history

    with open(f"active_result/period_{period}_betamax_{beta_max}.dat", "wb") as fp:
        pickle.dump(history, fp)
if __name__ == "__main__":
    # Reset the Keras backend session before starting the sweep.
    K.clear_session()

    # Run one training job per (period, beta_max) combination.
    for period in (5, 10, 20, 50, 100):
        for betamax in (0.5, 1.0):
            print(period, betamax, "Starts")
            train(period, betamax)

    # Collect every pickled history file into a single zip archive.
    with zipfile.ZipFile("active_result.zip", "w") as archive:
        for result_path in glob.glob("active_result/*.dat"):
            archive.write(result_path)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment