@koshian2
Created May 14, 2019 00:05
GPU vs TPU 1
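The script below trains two networks on CIFAR-10, a plain 10-layer CNN and a Wide ResNet (N=4, k=10), at several batch sizes on a single GPU, on two GPUs via `multi_gpu_model`, or on a Colab TPU, and pickles each run's Keras history together with per-epoch wall-clock times under `result/`.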
import tensorflow as tf
import tensorflow.python.keras as keras
import tensorflow.python.keras.layers as layers
from tensorflow.contrib.tpu.python.tpu import keras_support
import datetime
import time
import pickle
import os

def conv_bn_relu(input, ch, reps):
    x = input
    for i in range(reps):
        x = layers.Conv2D(ch, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)
        x = layers.Activation("relu")(x)
    return x

def create_10layers_model():
    input = layers.Input((32,32,3))
    x = conv_bn_relu(input, 64, 3)
    x = layers.AveragePooling2D(2)(x)
    x = conv_bn_relu(x, 128, 3)
    x = layers.AveragePooling2D(2)(x)
    x = conv_bn_relu(x, 256, 3)
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dense(10, activation="softmax")(x)
    return keras.models.Model(input, x)

def _create_normal_residual_block(inputs, ch, N, stride):
    # Conv with skip connections
    x = inputs
    for i in range(N):
        # adjust channels
        if i == 0:
            skip = layers.Conv2D(ch, 1, strides=stride)(x)
            skip = layers.BatchNormalization()(skip)
            skip = layers.Activation("relu")(skip)
        else:
            skip = x
        s = stride if i == 0 else 1  # downsampling
        x = layers.Conv2D(ch, 3, padding="same", strides=s)(x)
        x = layers.BatchNormalization()(x)
        x = layers.Activation("relu")(x)
        x = layers.Conv2D(ch, 3, padding="same", strides=1)(x)
        x = layers.BatchNormalization()(x)
        x = layers.Activation("relu")(x)
        x = layers.Add()([x, skip])
    return x

def create_normal_wide_resnet(N=4, k=10):
    """
    Create vanilla conv Wide ResNet (N=4, k=10)
    """
    # input
    input = layers.Input((32,32,3))
    # 16 channels block
    x = layers.Conv2D(16, 3, padding="same")(input)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)
    # 1st block
    x = _create_normal_residual_block(x, 16*k, N, 1)
    # 2nd block
    x = _create_normal_residual_block(x, 32*k, N, 2)
    # 3rd block
    x = _create_normal_residual_block(x, 64*k, N, 2)
    # FC
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dense(10, activation="softmax")(x)
    model = keras.models.Model(input, x)
    return model

class TimeCallback(keras.callbacks.Callback):
    # Records training start/end timestamps and per-epoch wall-clock times
    def __init__(self):
        self.times = []
        self.last_time = time.time()

    def on_train_begin(self, logs):
        self.train_begin = datetime.datetime.now()

    def on_train_end(self, logs):
        self.train_end = datetime.datetime.now()

    def on_epoch_end(self, epoch, logs):
        current_time = time.time()
        self.times.append(current_time - self.last_time)
        self.last_time = current_time

def train(batch_size, network, device):
    assert device in ["gpu", "multigpu", "tpu"]
    assert network in [0, 1]
    (X_train, y_train), (X_test, y_test) = keras.datasets.cifar10.load_data()
    X_train, X_test = X_train/255.0, X_test/255.0
    y_train = keras.utils.to_categorical(y_train)
    y_test = keras.utils.to_categorical(y_test)

    if network == 0:
        model = create_10layers_model()
    elif network == 1:
        model = create_normal_wide_resnet()
    if device == "multigpu":
        model = keras.utils.multi_gpu_model(model, gpus=2)

    # Scale the learning rate linearly with the batch size
    initial_lr = 0.1 * batch_size / 128
    model.compile(keras.optimizers.SGD(initial_lr, 0.9), "categorical_crossentropy", ["acc"])

    # Drop the learning rate by 10x at epochs 50 and 80
    def scheduler(epoch):
        x = initial_lr
        if epoch >= 50: x /= 10.0
        if epoch >= 80: x /= 10.0
        return x

    if device == "tpu":
        tpu_grpc_url = "grpc://" + os.environ["COLAB_TPU_ADDR"]
        tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(tpu_grpc_url)
        strategy = keras_support.TPUDistributionStrategy(tpu_cluster_resolver)
        model = tf.contrib.tpu.keras_to_tpu_model(model, strategy=strategy)

    hist = keras.callbacks.History()
    lr_step = keras.callbacks.LearningRateScheduler(scheduler)
    times = TimeCallback()

    model.fit(X_train, y_train, validation_data=(X_test, y_test),
              batch_size=batch_size, epochs=100, callbacks=[hist, times, lr_step],
              verbose=1 if device != "tpu" else 0)

    history = {**hist.history,
               "train_begin": times.train_begin,
               "train_end": times.train_end,
               "times": times.times}
    if not os.path.exists("result"):
        os.mkdir("result")
    with open(f"result/{device}_{network}_{batch_size}.pkl", "wb") as fp:
        pickle.dump(history, fp)

# main
def train_gpus():
    network = [0, 1]
    device = ["gpu", "multigpu"]
    batch_size = [128, 256, 512, 1024, 2048]
    for net in network:
        for dev in device:
            for batch in batch_size:
                # Do not benchmark WideResNet with overly large batches
                if net == 1 and batch > 512:
                    continue
                keras.backend.clear_session()
                print("Network", net, dev, "Batch", batch, "Starts")
                train(batch, net, dev)

def train_tpu():
    tf.logging.set_verbosity(tf.logging.FATAL)
    network = [0, 1]
    batch_size = [128, 256, 512, 1024, 2048]
    for net in network:
        for batch in batch_size:
            keras.backend.clear_session()
            print("Network", net, "Batch", batch, "Starts")
            train(batch, net, "tpu")

if __name__ == "__main__":
    train_gpus()
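
As written, the main guard only runs the GPU and multi-GPU benchmarks. To run the TPU benchmark instead, one would call `train_tpu()` from a Colab TPU runtime; a minimal sketch, assuming TensorFlow 1.x (which provides `tf.contrib.tpu.keras_to_tpu_model`) and that Colab has set the `COLAB_TPU_ADDR` environment variable:

# Minimal sketch (not part of the original script): run the TPU benchmark.
# Assumes a TPU-enabled Colab runtime, where COLAB_TPU_ADDR is set automatically.
if __name__ == "__main__":
    train_tpu()  # writes result/tpu_{network}_{batch_size}.pkl for each run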