Skip to content

Instantly share code, notes, and snippets.

@koshian2
Created December 5, 2018 06:39
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save koshian2/f300153f475cd5ac3bb86ff19b919bba to your computer and use it in GitHub Desktop.
Save koshian2/f300153f475cd5ac3bb86ff19b919bba to your computer and use it in GitHub Desktop.
RICAP on CIFAR
import tensorflow as tf
from tensorflow.keras.applications import MobileNet
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import History, Callback
from tensorflow.contrib.tpu.python.tpu import keras_support
import tensorflow.keras.backend as K
from keras.datasets import cifar10
from keras.utils import to_categorical
from ricap import ricap
import numpy as np
import os, pickle, glob, zipfile
def create_network(num_classes=10, input_shape=(128, 128, 3), freeze_until="conv_dw_6"):
    """Build a MobileNet classifier initialised from ImageNet weights.

    The ImageNet-pretrained MobileNet backbone (without its top) is loaded,
    every layer that precedes the layer named ``freeze_until`` in
    ``net.layers`` is frozen, and a global-average-pool + softmax head is
    attached to the backbone's last layer.

    Args:
        num_classes: Number of units in the softmax head (10 for CIFAR-10).
        input_shape: Input shape passed to ``MobileNet``. This script feeds
            CIFAR-10 images upsampled 4x, hence the default ``(128, 128, 3)``.
        freeze_until: Name of the first backbone layer that stays trainable.
            If no layer has this name, every backbone layer is frozen.

    Returns:
        An uncompiled Keras ``Model``.
    """
    net = MobileNet(include_top=False, weights="imagenet", input_shape=input_shape)
    # Freeze layers in order until the named layer is reached; it and all
    # later layers remain trainable.
    for layer in net.layers:
        if layer.name == freeze_until:
            break
        layer.trainable = False
    x = GlobalAveragePooling2D()(net.layers[-1].output)
    x = Dense(num_classes, activation="softmax")(x)
    return Model(net.inputs, x)
class RICAPGenerator(ImageDataGenerator):
    """ImageDataGenerator whose ``flow`` applies RICAP and upsamples each batch.

    Args:
        use_batchwise_random: Forwarded as the fourth argument of ``ricap``.
        beta: RICAP beta. It is re-read for every batch, so it may be changed
            during training (``RICAPBetaCallback`` does this). RICAP is only
            applied while ``beta > 0``.
        *args, **kwargs: Passed unchanged to ``ImageDataGenerator``.
    """

    def __init__(self, use_batchwise_random, beta, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.use_batchwise_random = use_batchwise_random
        self.beta = beta

    def flow(self, *args, **kwargs):
        """Yield ``(images, labels)`` batches drawn from the parent ``flow``.

        Each batch is mixed with ``ricap`` when ``self.beta > 0`` and left
        untouched otherwise. The images are then enlarged 4x along axes 1
        and 2 by repeating every element (nearest-neighbour upsampling).
        """
        parent_batches = super().flow(*args, **kwargs)
        for images, labels in parent_batches:
            mixed_images, mixed_labels = (
                ricap(images, labels, self.beta, self.use_batchwise_random)
                if self.beta > 0
                else (images, labels)  # RICAP disabled
            )
            # Enlarge 4x: repeat each element 4 times along axis 1, then axis 2.
            enlarged = mixed_images.repeat(4, axis=1).repeat(4, axis=2)
            yield enlarged, mixed_labels
class RICAPBetaCallback(Callback):
    """Keras callback that anneals a generator's RICAP ``beta`` once per epoch.

    When scheduling is enabled, the beta used for 0-based epoch ``e`` is::

        beta_min + (beta_max - beta_min) * (1 - cos(pi * e / phase)) / 2

    It equals ``beta_min`` at ``e = 0``, reaches ``beta_max`` at
    ``e = phase`` and returns to ``beta_min`` at ``e = 2 * phase``.

    Args:
        phase: Number of epochs for beta to rise from ``beta_min`` to
            ``beta_max``. Must be non-zero when ``use_scheduling`` is true.
        beta_min: Lower end of the schedule.
        beta_max: Upper end of the schedule.
        generator: Object whose ``beta`` attribute is overwritten (a
            ``RICAPGenerator`` in this script).
        use_scheduling: If false, the callback never changes beta.

    Raises:
        ValueError: If ``use_scheduling`` is true and ``phase`` is 0.
    """

    def __init__(self, phase, beta_min, beta_max, generator, use_scheduling):
        # Initialise the attributes Keras' base Callback sets up (e.g. ``model``).
        super().__init__()
        if use_scheduling and phase == 0:
            # Fail at construction instead of with ZeroDivisionError after epoch 0.
            raise ValueError("phase must be non-zero when use_scheduling is True")
        self.gen = generator
        self.beta_min = beta_min
        self.beta_max = beta_max
        self.phase = phase
        self.use_scheduling = use_scheduling

    def _scheduled_beta(self, epoch):
        """Return the cosine-scheduled beta for 0-based ``epoch``."""
        ramp = (1.0 - np.cos(epoch / self.phase * np.pi)) / 2.0
        return self.beta_min + (self.beta_max - self.beta_min) * ramp

    def on_epoch_end(self, epoch, logs=None):
        """Set the generator's beta for the upcoming epoch ``epoch + 1``."""
        if not self.use_scheduling:
            return
        # This hook runs after ``epoch`` has finished, so the value set here is
        # used by the next epoch. Evaluating the schedule at ``epoch`` would
        # make every epoch lag one step behind its schedule. Batches that Keras
        # has already prefetched may still carry the previous beta.
        beta = self._scheduled_beta(epoch + 1)
        print("set beta to", beta)
        self.gen.beta = beta
def train(use_scheduling, phase, beta_max, *, epochs=100, batch_size=1024):
    """Run one RICAP experiment: train MobileNet on CIFAR-10 on a Colab TPU.

    Training batches come from ``RICAPGenerator``: rescaled by 1/255,
    RICAP-mixed while beta > 0, and upsampled 4x. Test images are upsampled
    4x up front and rescaled by 1/255. The Keras history is pickled to
    ``mobilenet_result/scheduling_{use_scheduling}_phase_{phase}_betamax_{beta_max}.dat``.

    Args:
        use_scheduling: If true, beta starts at 0.01 (or 0 when
            ``beta_max <= 0``) and is annealed each epoch by
            ``RICAPBetaCallback``. Otherwise beta is fixed at ``beta_max``.
        phase: Cosine-schedule phase in epochs. When not scheduling it only
            appears in the output filename.
        beta_max: Fixed beta, or the schedule's upper bound.
        epochs: Number of training epochs.
        batch_size: Batch size for both training and validation.

    Raises:
        KeyError: If ``COLAB_TPU_ADDR`` is not set in the environment.
    """
    # Start each run from a fresh Keras session/graph. Otherwise every call
    # keeps adding another model to the same default graph.
    K.clear_session()
    print(use_scheduling, beta_max, "Starts")
    (X_train, y_train), (X_test, y_test) = cifar10.load_data()
    y_train, y_test = to_categorical(y_train), to_categorical(y_test)
    # Upsample test images 4x to match what RICAPGenerator yields for training.
    X_test = X_test.repeat(4, axis=1).repeat(4, axis=2)

    beta_min = 0.01  # start of the schedule and the callback's lower bound
    if use_scheduling:
        initial_beta = beta_min if beta_max > 0 else 0
    else:
        initial_beta = beta_max

    train_gen_instance = RICAPGenerator(rescale=1.0/255.0, use_batchwise_random=False, beta=initial_beta)
    train_generator = train_gen_instance.flow(X_train, y_train, batch_size=batch_size)
    test_generator = ImageDataGenerator(rescale=1.0/255.0).flow(
        X_test, y_test, batch_size=batch_size)

    model = create_network()
    # Compile before conversion: keras_to_tpu_model reuses the compiled settings.
    model.compile(tf.train.RMSPropOptimizer(1e-4), "categorical_crossentropy", ["acc"])
    tpu_grpc_url = "grpc://"+os.environ["COLAB_TPU_ADDR"]
    tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(tpu_grpc_url)
    strategy = keras_support.TPUDistributionStrategy(tpu_cluster_resolver)
    model = tf.contrib.tpu.keras_to_tpu_model(model, strategy=strategy)

    ricap_scheduler = RICAPBetaCallback(phase, beta_min, beta_max, train_gen_instance, use_scheduling)
    os.makedirs("mobilenet_result", exist_ok=True)
    hist = History()
    model.fit_generator(train_generator, steps_per_epoch=X_train.shape[0]//batch_size,
                        validation_data=test_generator, validation_steps=X_test.shape[0]//batch_size,
                        epochs=epochs, callbacks=[hist, ricap_scheduler])
    history = hist.history
    with open(f"mobilenet_result/scheduling_{use_scheduling}_phase_{phase}_betamax_{beta_max}.dat", "wb") as fp:
        pickle.dump(history, fp)
if __name__ == "__main__":
    K.clear_session()
    # (use_scheduling, phase, beta_max) for each experiment, run in order.
    experiments = [
        (False, 0, 0),      # beta fixed at 0: no RICAP
        (False, 0, 0.3),    # fixed beta
        (False, 0, 0.5),
        (True, 20, 0.5),    # cosine-scheduled beta
        (True, 50, 0.5),
        (True, 100, 0.5),
    ]
    for use_scheduling, phase, beta_max in experiments:
        train(use_scheduling, phase, beta_max)
    # Bundle the pickled histories into one archive. The handle is named ``zf``
    # so it does not shadow the builtin ``zip``. Sorting makes the order deterministic.
    with zipfile.ZipFile("mobilenet_result.zip", "w") as zf:
        for path in sorted(glob.glob("mobilenet_result/*.dat")):
            zf.write(path)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment