Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save lucasdavid/2f0fa565357e3e20496dd99f93fcc98a to your computer and use it in GitHub Desktop.
Save lucasdavid/2f0fa565357e3e20496dd99f93fcc98a to your computer and use it in GitHub Desktop.
Training a classifier on CIFAR-100 using full TensorBoard functionality
import tensorflow as tf
from keras.callbacks import TensorBoard
from keras.datasets import cifar100
from keras.layers import Dense, Conv2D, MaxPooling2D, GlobalAveragePooling2D, BatchNormalization, Activation, Dropout
from keras.layers import Input
from keras.models import Model
from keras.preprocessing.image import ImageDataGenerator
from sacred import Experiment
# Sacred experiment that owns the config scope and the main function below.
ex = Experiment('tb-efficiency')


# Sacred config scope: every local variable assigned in this body becomes a
# config entry, supplied by name to captured functions such as `main`.
# Do not add helper variables here -- they would become config entries too.
@ex.config
def my_config():
    epochs = 10  # number of passes over the training subset
    batch_size = 256  # batch size for the data iterator and the TensorBoard callback
    device = '/gpu:0'  # TensorFlow device the model is built and compiled under
def conv2d_bn(x, filters, kernel_size=(3, 3), dropout=0.2):
    """Apply a bias-free Conv2D -> BatchNormalization -> ReLU stack, then dropout.

    Args:
        x: input Keras tensor.
        filters: number of convolution filters.
        kernel_size: convolution kernel size.
        dropout: dropout rate applied after the ReLU activation. Pass 0 (or
            None) to skip the Dropout layer entirely.

    Returns:
        The output Keras tensor.
    """
    # The convolution has no bias term; the following BatchNormalization
    # layer supplies its own learned offset.
    y = Conv2D(filters, kernel_size, use_bias=False)(x)
    y = BatchNormalization()(y)
    y = Activation('relu')(y)
    # Previously `dropout` was accepted but silently ignored; honor it now.
    if dropout:
        y = Dropout(rate=dropout)(y)
    return y
def _build_model(device):
    """Build and compile the conv-net classifier (100-way softmax) on `device`."""
    with tf.device(device):
        x = Input(shape=(32, 32, 3))
        y = conv2d_bn(x, 32)
        y = conv2d_bn(y, 32)
        y = MaxPooling2D()(y)
        y = conv2d_bn(y, 64)
        y = conv2d_bn(y, 64)
        y = MaxPooling2D()(y)
        y = conv2d_bn(y, 128)
        y = conv2d_bn(y, 128)
        y = GlobalAveragePooling2D()(y)
        y = Dropout(rate=0.5)(y)
        y = Dense(100, activation='softmax')(y)
        model = Model(x, y)
        model.compile(optimizer='adam',
                      loss='sparse_categorical_crossentropy',
                      metrics=['accuracy'])
    return model


# NOTE: `_build_model` must be defined above this function: `ex.automain`
# starts the run as soon as the decorator executes when run as a script.
@ex.automain
def main(epochs, batch_size, device):
    """Train a CIFAR-100 classifier with TensorBoard histogram/gradient logging.

    Args:
        epochs: number of training epochs (from the Sacred config).
        batch_size: batch size for training and for the TensorBoard callback.
        device: TensorFlow device string the model is built under.
    """
    # The test split is not used by this script, so it is not normalized.
    (x_train, y_train), _ = cifar100.load_data()

    # Map uint8 pixels from [0, 255] onto [-1, 1] (127.5 is the midpoint;
    # the previous 127.0 overshot to ~1.008). float32 matches what the model
    # consumes and halves memory compared with float64.
    x_train = x_train.astype('float32') / 127.5 - 1.0

    # Hold out the first third of the training set for validation, kept as
    # in-memory arrays: TensorBoard's histogram_freq/write_grads cannot
    # compute on generator-supplied validation data.
    split = int(len(x_train) * (1 / 3))
    x_valid, y_valid = x_train[:split], y_train[:split]
    x_train, y_train = x_train[split:], y_train[split:]

    train = ImageDataGenerator().flow(x_train, y_train, batch_size=batch_size)

    model = _build_model(device)
    model.fit_generator(train,
                        epochs=epochs,
                        validation_data=(x_valid, y_valid),
                        callbacks=[
                            TensorBoard(histogram_freq=1,
                                        batch_size=batch_size,
                                        write_grads=True)
                        ],
                        verbose=2)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment