Created
December 5, 2018 06:39
-
-
Save koshian2/f300153f475cd5ac3bb86ff19b919bba to your computer and use it in GitHub Desktop.
RICAP on CIFAR
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
from tensorflow.keras.applications import MobileNet | |
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense | |
from tensorflow.keras.models import Model | |
from tensorflow.keras.preprocessing.image import ImageDataGenerator | |
from tensorflow.keras.callbacks import History, Callback | |
from tensorflow.contrib.tpu.python.tpu import keras_support | |
import tensorflow.keras.backend as K | |
from keras.datasets import cifar10 | |
from keras.utils import to_categorical | |
from ricap import ricap | |
import numpy as np | |
import os, pickle, glob, zipfile | |
def create_network():
    """Build a 10-class softmax classifier on an ImageNet-pretrained MobileNet.

    Backbone layers are frozen in model order until the layer named
    ``conv_dw_6`` is reached; that layer and every later one are left
    untouched. If no layer has that name, every backbone layer is frozen.

    Returns:
        A ``Model`` mapping the backbone inputs (built for 128x128x3 images)
        to a 10-way softmax computed from the globally average-pooled output
        of the backbone's last layer.
    """
    backbone = MobileNet(
        include_top=False, weights="imagenet", input_shape=(128, 128, 3)
    )
    layers = backbone.layers

    # Index of the first layer that must NOT be frozen.
    first_unfrozen = next(
        (idx for idx, layer in enumerate(layers) if layer.name == "conv_dw_6"),
        len(layers),
    )
    for layer in layers[:first_unfrozen]:
        layer.trainable = False

    pooled = GlobalAveragePooling2D()(layers[-1].output)
    predictions = Dense(10, activation="softmax")(pooled)
    return Model(backbone.inputs, predictions)
class RICAPGenerator(ImageDataGenerator):
    """ImageDataGenerator whose batches go through RICAP and a 4x upscale.

    ``beta`` is a plain attribute that ``flow`` reads again for every batch,
    so assigning a new value takes effect on the next batch produced.
    """

    def __init__(self, use_batchwise_random, beta, *args, **kwargs):
        """Store the RICAP settings.

        Args:
            use_batchwise_random: forwarded unchanged to ``ricap``.
            beta: RICAP beta parameter; RICAP is skipped while it is <= 0.
            *args, **kwargs: forwarded to ``ImageDataGenerator.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.use_batchwise_random = use_batchwise_random
        self.beta = beta

    def flow(self, *args, **kwargs):
        """Yield ``(images, labels)`` batches from the parent ``flow``.

        Each batch is passed through ``ricap`` when ``self.beta`` is positive,
        then every element of the image array is repeated 4 times along
        axes 1 and 2 (presumably height and width -- confirm against the
        data layout in use).
        """
        batches = super().flow(*args, **kwargs)
        for images, labels in batches:
            # A non-positive beta means "no RICAP": pass the batch through.
            if self.beta > 0:
                images, labels = ricap(
                    images, labels, self.beta, self.use_batchwise_random
                )
            # Upscale 4x.
            upscaled = images.repeat(4, axis=1)
            upscaled = upscaled.repeat(4, axis=2)
            yield upscaled, labels
class RICAPBetaCallback(Callback):
    """Callback that rewrites a generator's ``beta`` at the end of each epoch.

    With scheduling enabled, beta follows

        beta_min + (beta_max - beta_min) * (1 - cos(pi * epoch / phase)) / 2

    which equals ``beta_min`` at epoch 0 and ``beta_max`` at ``epoch == phase``.
    With scheduling disabled the callback does nothing.
    """

    def __init__(self, phase, beta_min, beta_max, generator, use_scheduling):
        """Store the schedule parameters.

        Args:
            phase: number of epochs for beta to go from ``beta_min`` to
                ``beta_max``.
            beta_min: lowest beta of the schedule.
            beta_max: highest beta of the schedule.
            generator: object whose ``beta`` attribute is overwritten.
            use_scheduling: when false, ``on_epoch_end`` is a no-op.

        Raises:
            ValueError: if scheduling is enabled and ``phase`` is 0.
        """
        # Initialise the base Callback state as well.
        super().__init__()
        # Fail before training starts instead of dividing by zero after the
        # first epoch has already been trained.
        if use_scheduling and phase == 0:
            raise ValueError("phase must be non-zero when use_scheduling is enabled")
        self.gen = generator
        self.beta_min = beta_min
        self.beta_max = beta_max
        self.phase = phase
        self.use_scheduling = use_scheduling

    def on_epoch_end(self, epoch, logs=None):
        """Compute beta for the finished ``epoch`` index and set it on the generator.

        Args:
            epoch: epoch index passed in by Keras.
            logs: unused; defaults to ``None`` to match the Keras hook signature.
        """
        if not self.use_scheduling:
            return
        amplitude = self.beta_max - self.beta_min
        beta = self.beta_min + amplitude*(1.0-np.cos(epoch/self.phase*np.pi))/2.0
        print("set beta to", beta)
        self.gen.beta = beta
def train(use_scheduling, phase, beta_max):
    """Run one CIFAR-10 training experiment on a Colab TPU and pickle its history.

    Args:
        use_scheduling: when true, training starts at beta 0.01 (if
            ``beta_max`` > 0, otherwise 0) and ``RICAPBetaCallback`` updates
            beta after every epoch; when false, beta stays at ``beta_max``.
        phase: passed to ``RICAPBetaCallback`` and used in the result filename.
        beta_max: RICAP beta (fixed value, or upper bound of the schedule).

    Reads the ``COLAB_TPU_ADDR`` environment variable and writes the history
    dict to ``mobilenet_result/``.
    """
    print(use_scheduling, beta_max, "Starts")

    (X_train, y_train), (X_test, y_test) = cifar10.load_data()
    y_train = to_categorical(y_train)
    y_test = to_categorical(y_test)
    # Test images are upscaled 4x once here; RICAPGenerator.flow applies the
    # same upscale to every training batch.
    X_test = X_test.repeat(4, axis=1).repeat(4, axis=2)
    batch_size = 1024

    # Beta used until the callback (if scheduling) first changes it.
    if not use_scheduling:
        initial_beta = beta_max
    elif beta_max > 0:
        initial_beta = 0.01
    else:
        initial_beta = 0

    train_gen_instance = RICAPGenerator(
        rescale=1.0/255.0, use_batchwise_random=False, beta=initial_beta
    )
    train_generator = train_gen_instance.flow(X_train, y_train, batch_size=batch_size)
    plain_gen = ImageDataGenerator(rescale=1.0/255.0)
    test_generator = plain_gen.flow(X_test, y_test, batch_size=batch_size)

    model = create_network()
    model.compile(tf.train.RMSPropOptimizer(1e-4), "categorical_crossentropy", ["acc"])

    # Convert the compiled Keras model into a TPU model.
    tpu_grpc_url = "grpc://" + os.environ["COLAB_TPU_ADDR"]
    resolver = tf.contrib.cluster_resolver.TPUClusterResolver(tpu_grpc_url)
    strategy = keras_support.TPUDistributionStrategy(resolver)
    model = tf.contrib.tpu.keras_to_tpu_model(model, strategy=strategy)

    # The callback mutates train_gen_instance.beta, which flow() re-reads.
    ricap_scheduler = RICAPBetaCallback(
        phase, 0.01, beta_max, train_gen_instance, use_scheduling
    )

    if not os.path.exists("mobilenet_result"):
        os.mkdir("mobilenet_result")

    hist = History()
    # Integer division: a trailing partial batch is not counted as a step.
    steps_per_epoch = X_train.shape[0] // batch_size
    validation_steps = X_test.shape[0] // batch_size
    model.fit_generator(
        train_generator,
        steps_per_epoch=steps_per_epoch,
        validation_data=test_generator,
        validation_steps=validation_steps,
        epochs=100,
        callbacks=[hist, ricap_scheduler],
    )

    result_path = f"mobilenet_result/scheduling_{use_scheduling}_phase_{phase}_betamax_{beta_max}.dat"
    with open(result_path, "wb") as fp:
        pickle.dump(hist.history, fp)
if __name__ == "__main__":
    K.clear_session()

    # One (use_scheduling, phase, beta_max) tuple per run, executed in order.
    runs = (
        (False, 0, 0),
        (False, 0, 0.3),
        (False, 0, 0.5),
        (True, 20, 0.5),
        (True, 50, 0.5),
        (True, 100, 0.5),
    )
    for run_args in runs:
        train(*run_args)

    # Bundle every pickled history into a single archive.
    with zipfile.ZipFile("mobilenet_result.zip", "w") as archive:
        for result_file in glob.glob("mobilenet_result/*.dat"):
            archive.write(result_file)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment