# population-based-training-tensorflow
import tensorflow as tf
import tensorflow.contrib as tfc
import numpy as np
import matplotlib.pyplot as plt
import observations
from functools import lru_cache
tf.reset_default_graph()

# Load CIFAR-10 and build the input pipelines
train_data, test_data = observations.cifar10('data/cifar')
test_data = test_data[0], test_data[1].astype(np.uint8)  # Fix test_data dtype
train = tf.data.Dataset.from_tensor_slices(train_data).repeat().shuffle(10000).batch(64)
test = tf.data.Dataset.from_tensors(test_data).repeat()

# A feedable iterator lets us switch between the training and test set
# at run time by feeding a different string handle
handle = tf.placeholder(tf.string, [])
itr = tf.data.Iterator.from_string_handle(handle, train.output_types, train.output_shapes)
inputs, labels = itr.get_next()

def make_handle(sess, dataset):
    iterator = dataset.make_initializable_iterator()
    handle, _ = sess.run([iterator.string_handle(), iterator.initializer])
    return handle

# Preprocess: scale pixels to [0, 1], flatten images, cast labels for the loss
inputs = tf.cast(inputs, tf.float32) / 255.0
inputs = tf.layers.flatten(inputs)
labels = tf.cast(labels, tf.int32)
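
# Sanity check (illustrative addition, not in the original gist): fetch one
# training batch through the feedable iterator to confirm the pipeline shapes.
with tf.Session() as sanity_sess:
    sanity_handle = make_handle(sanity_sess, train)
    x, y = sanity_sess.run([inputs, labels], {handle: sanity_handle})
    assert x.shape == (64, 32 * 32 * 3) and y.shape == (64,)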

class Model:

    def __init__(self, model_id: int, regularize=True):
        self.model_id = model_id
        self.name_scope = tf.get_default_graph().get_name_scope()
        # Regularization
        if regularize:
            l1_reg = self._create_regularizer()
        else:
            l1_reg = None
        # Network and loglikelihood
        logits = self._create_network(l1_reg)
        # We maximize the loglikelihood of the data as a training objective
        distr = tf.distributions.Categorical(logits)
        loglikelihood = distr.log_prob(labels)
        # Define accuracy of prediction
        prediction = tf.argmax(logits, axis=-1, output_type=tf.int32)
        self.accuracy = tf.reduce_mean(tf.cast(tf.equal(prediction, labels), tf.float32))
        # Loss and optimization
        self.loss = -tf.reduce_mean(loglikelihood)
        # Retrieve only the trainable weights of this model; the l1 scale is
        # non-trainable, so it is never touched by gradient descent
        trainable = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.name_scope + '/')
        # The loss to optimize is the negative loglikelihood + the l1-regularizer,
        # scoped so each model only picks up its own regularization terms
        reg_loss = self.loss + tf.losses.get_regularization_loss(scope=self.name_scope)
        self.optimize = tf.train.AdamOptimizer().minimize(reg_loss, var_list=trainable)

    def _create_network(self, l1_reg):
        # Our deep neural network will have two hidden layers with plenty of units
        hidden = tf.layers.dense(inputs, 1024, activation=tf.nn.relu,
                                 kernel_regularizer=l1_reg)
        hidden = tf.layers.dense(hidden, 1024, activation=tf.nn.relu,
                                 kernel_regularizer=l1_reg)
        logits = tf.layers.dense(hidden, 10,
                                 kernel_regularizer=l1_reg)
        return logits

    def _create_regularizer(self):
        # We will define the l1 regularizer scale in log2 space
        # This allows changing one unit to half or double the effective l1 scale
        self.l1_scale = tf.get_variable('l1_scale', [], tf.float32, trainable=False,
                                        initializer=tf.constant_initializer(np.log2(1e-5)))
        # We define a 'perturb' operation that adds some noise to our regularizer scale
        # We will use this perturbation during exploration in our population based training
        noise = tf.random_normal([], stddev=0.5)
        self.perturb = self.l1_scale.assign_add(noise)
        return tfc.layers.l1_regularizer(2 ** self.l1_scale)

    @lru_cache(maxsize=None)
    def copy_from(self, other_model):
        # This method is used for exploitation. We copy all weights and hyper-parameters
        # from other_model to this model. The lru_cache ensures the assign ops are
        # only built once per (self, other_model) pair instead of on every call.
        my_weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, self.name_scope + '/')
        their_weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, other_model.name_scope + '/')
        assign_ops = [mine.assign(theirs).op for mine, theirs in zip(my_weights, their_weights)]
        return tf.group(*assign_ops)

def create_model(*args, **kwargs):
    with tf.variable_scope(None, 'model'):
        return Model(*args, **kwargs)

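# Illustration (addition, not in the original gist): the exploit/explore
# primitives on two throwaway models. Since copy_from is wrapped in lru_cache,
# the assign ops are created once per (target, source) pair, not on every call.
demo_a, demo_b = create_model(-1), create_model(-2)
with tf.Session() as demo_sess:
    demo_sess.run(tf.global_variables_initializer())
    # Exploit: demo_b inherits demo_a's weights, optimizer state and l1_scale
    demo_sess.run(demo_b.copy_from(demo_a))
    # Explore: jitter demo_b's l1_scale in log2 space
    demo_sess.run(demo_b.perturb)
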
# Baseline: a single non-regularized model trained for the full budget
ITERATIONS = 50000
nonreg_accuracy_hist = np.zeros((ITERATIONS // 100,))
model = create_model(0, regularize=False)
with tf.Session() as sess:
    train_handle = make_handle(sess, train)
    test_handle = make_handle(sess, test)
    sess.run(tf.global_variables_initializer())
    feed_dict = {handle: train_handle}
    test_feed_dict = {handle: test_handle}
    for i in range(ITERATIONS):
        # Training
        sess.run(model.optimize, feed_dict)
        # Evaluate
        if i % 100 == 0:
            nonreg_accuracy_hist[i // 100] = sess.run(model.accuracy, test_feed_dict)

POPULATION_SIZE = 10
BEST_THRES = 3    # top members whose hyper-parameters are left un-perturbed
WORST_THRES = 3   # bottom members that are overwritten by the current best
POPULATION_STEPS = 500
ITERATIONS = 100
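# Note (addition, not in the original gist): 500 population steps of 100
# iterations each give every member the same 50000-iteration budget as the
# baseline above, and both are evaluated once per 100 iterations, so the two
# accuracy histories are directly comparable point by point.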
accuracy_hist = np.zeros((POPULATION_SIZE, POPULATION_STEPS))
l1_scale_hist = np.zeros((POPULATION_SIZE, POPULATION_STEPS))
best_accuracy_hist = np.zeros((POPULATION_STEPS,))
best_l1_scale_hist = np.zeros((POPULATION_STEPS,))
models = [create_model(i) for i in range(POPULATION_SIZE)]
with tf.Session() as sess:
    train_handle = make_handle(sess, train)
    test_handle = make_handle(sess, test)
    sess.run(tf.global_variables_initializer())
    feed_dict = {handle: train_handle}
    test_feed_dict = {handle: test_handle}
    for i in range(POPULATION_STEPS):
        # Exploit: copy the best model into the worst ones
        sess.run([m.copy_from(models[0]) for m in models[-WORST_THRES:]])
        # Explore: perturb the l1 scale of all but the best models
        sess.run([m.perturb for m in models[BEST_THRES:]])
        # Training
        for _ in range(ITERATIONS):
            sess.run([m.optimize for m in models], feed_dict)
        # Evaluate and sort the population by test accuracy, best first
        l1_scales = sess.run({m: m.l1_scale for m in models})
        accuracies = sess.run({m: m.accuracy for m in models}, test_feed_dict)
        models.sort(key=lambda m: accuracies[m], reverse=True)
        # Logging
        best_accuracy_hist[i] = accuracies[models[0]]
        best_l1_scale_hist[i] = l1_scales[models[0]]
        for m in models:
            l1_scale_hist[m.model_id, i] = l1_scales[m]
            accuracy_hist[m.model_id, i] = accuracies[m]
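
# Summary (addition, not in the original gist): report the best member found.
print('Best final test accuracy: %.3f with an effective l1 scale of %g'
      % (best_accuracy_hist[-1], 2 ** best_l1_scale_hist[-1]))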
f = plt.figure(figsize=(10, 5))
ax = f.add_subplot(1, 1, 1)
ax.plot(best_accuracy_hist)
ax.plot(nonreg_accuracy_hist, c='red')
ax.set(xlabel='Hundreds of training iterations', ylabel='Test accuracy')
ax.legend(['Population based training', 'Non-regularized baseline'])
plt.show()
f = plt.figure(figsize=(10, 5))
ax = f.add_subplot(1, 1, 1)
ax.plot(2 ** l1_scale_hist.T)
ax.set_yscale('log')
ax.set(xlabel='Hundreds of training iterations', ylabel='L1 regularizer scale')
plt.show()
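
# Optional (addition, not in the original gist): accuracy_hist is recorded in
# the training loop above but never visualized; plotting every member shows
# how the accuracy spread across the population evolves.
f = plt.figure(figsize=(10, 5))
ax = f.add_subplot(1, 1, 1)
ax.plot(accuracy_hist.T)
ax.set(xlabel='Hundreds of training iterations', ylabel='Test accuracy')
plt.show()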