Detailed description and tidying up of code to come later; for now:
import tensorflow as tf
import numpy as np
from multiprocessing.pool import Pool
from time import time
class MetaModel:
def __init__(
self, initial_val=-1.0,
num_inner_steps=5,
num_outer_steps=5,
):
        # Define an arbitrary toy model, in order to test functionality...
self.num_inner_steps = num_inner_steps
self.num_outer_steps = num_outer_steps
# Define model:
self.variable = tf.Variable(
initial_value=initial_val,
# name='param'
)
self.variable2 = tf.Variable(
initial_value=[2*initial_val, 3.5*initial_val]
)
self.variable3 = tf.Variable(
initial_value=[
[initial_val + 2, initial_val + 3.5],
[initial_val + 6.3, initial_val + 11],
]
)
# Placeholder to hold a tensor of numerical data:
self.data_placeholder = tf.placeholder(dtype=tf.float32)
        # Arbitrary update rule, to verify functionality:
self.update_params = self.variable.assign_add(self.data_placeholder)
        # ***
        # NB the following statements all depend on calls to
        # tf.global_variables(), so any tf.Variables added to the graph
        # after this point will not be taken into account (NB tf.Operations
        # and tf.placeholders are not instances of tf.Variable):
        # ***
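        # (E.g. a hypothetical variable created after this point, say
        # `late_var = tf.Variable(0.0)`, would not be initialised by
        # self.init_op, and would get no placeholder or assign op below.)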
        # Op to initialise all global variables currently in the graph:
self.init_op = tf.initializers.global_variables()
# (does this want to be defined here, or within a method?):
self.uninitialised_op = tf.report_uninitialized_variables()
        # List of all tf.Variables currently in the graph:
self.global_vars = tf.global_variables()
        # One placeholder per variable, for feeding in new parameter values:
self.parameter_placeholder_list = [
tf.placeholder(dtype=tf.float32) for _ in self.global_vars
]
# var.assign actually returns a Tensor, not an Operation...
# see https://www.tensorflow.org/api_docs/python/tf/Variable#assign
# ... but I like to think of assignment as an operation
self.parameter_assign_op_list = [
var.assign(placeholder) for (var, placeholder) in zip(
self.global_vars, self.parameter_placeholder_list
)
]
        # Each placeholder will be fed a list of per-task parameter arrays,
        # stacked along axis 0, so taking the mean over axis 0 averages each
        # parameter across tasks:
self.reduced_mean_tensor_list = [
tf.reduce_mean(
input_tensor=placeholder,
axis=0
) for placeholder in self.parameter_placeholder_list
]
        # An operation could be inserted here, between reduce_mean and
        # assign, e.g. Bayesian inference for the parameters, using the
        # previous value as a prior (see the commented sketch below):
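        # A minimal sketch of that idea (hypothetical; `prior_weight` is not
        # used anywhere else in this script), shrinking the task-mean towards
        # the previous meta-parameter values:
        # prior_weight = 0.5
        # self.assign_shrunk_mean_op_list = [
        #     var.assign(
        #         prior_weight * var + (1 - prior_weight) * reduced_mean
        #     ) for (var, reduced_mean) in zip(
        #         self.global_vars, self.reduced_mean_tensor_list
        #     )
        # ]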
self.assign_reduced_mean_op_list = [
var.assign(reduced_mean) for (var, reduced_mean) in zip(
self.global_vars, self.reduced_mean_tensor_list
)
]
def initialise_variables(self, sess):
sess.run(self.init_op)
def get_global_vars(self, sess):
return sess.run(self.global_vars)
def set_global_vars(self, sess, global_vars):
        # `global_vars` must come from a call to the `get_global_vars`
        # method, so that the values match the graph's variables in number,
        # order and shape:
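        # One cheap check along those lines (length only; per-parameter shape
        # checks would need the variable shapes recorded at build time):
        assert len(global_vars) == len(self.parameter_assign_op_list)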
sess.run(
fetches=self.parameter_assign_op_list,
feed_dict={
placeholder: value for placeholder, value in zip(
self.parameter_placeholder_list, global_vars
)
}
)
def adapt(self, sess, task_data):
        # In reality this will be some steps of gradient descent (see the
        # commented sketch after this method):
for _ in range(self.num_inner_steps):
sess.run(
fetches=self.update_params,
feed_dict={self.data_placeholder: task_data}
)
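        # A gradient-based version of this method would look something like
        # the following sketch, assuming a `self.loss` tensor and a train op
        # defined at graph-construction time (neither exists in this script):
        # self.train_op = tf.train.GradientDescentOptimizer(
        #     learning_rate=1e-2
        # ).minimize(self.loss)
        # ...
        # for _ in range(self.num_inner_steps):
        #     sess.run(
        #         fetches=self.train_op,
        #         feed_dict={self.data_placeholder: task_data}
        #     )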
    def meta_update(self, sess, params_list):
        # `params_list` contains one list of parameter values per task; feed
        # each placeholder the corresponding parameter from every task,
        # stacked along axis 0, ready for reduce_mean to average over tasks:
        sess.run(
            fetches=self.assign_reduced_mean_op_list,
            feed_dict={
                placeholder: [params[i] for params in params_list]
                for (i, placeholder) in enumerate(
                    self.parameter_placeholder_list
                )
            }
        )
def report_uninitialised(self, sess):
uninitialised_list = sess.run(self.uninitialised_op)
if len(uninitialised_list) > 0:
print("Remaining uninitialised variables:")
for var in uninitialised_list:
print(" - {}".format(var))
def inner_update(initial_params, task_data):
# Create an instance of the model:
model = MetaModel()
# Start a session:
with tf.Session() as sess:
# Set initial parameters:
model.set_global_vars(sess, initial_params)
# print("Model vars:")
# for p in model.get_global_vars(sess): print(p)
# Check that all variables have been initialised:
model.report_uninitialised(sess)
# Perform fast adaptation:
model.adapt(sess, task_data)
# Retrieve the learned task-specific parameters:
task_specific_params = model.get_global_vars(sess)
    # Reset the default graph: otherwise the variables created above would
    # persist in this worker's process after the session closes (with their
    # values lost), and would still be present in the graph when the worker
    # starts handling a new task:
    tf.reset_default_graph()
    # Aside: since everything persists in the pool worker between tasks,
    # could the model object itself be cached and reused, instead of being
    # rebuilt for every task?
    # Return the task-specific parameters to the outer loop:
return task_specific_params
def train_parallel(model, task_set, worker_pool):
with tf.Session() as sess:
model.initialise_variables(sess)
model.report_uninitialised(sess)
# Retrieve meta-parameters, ready to pass to the inner update:
meta_params = model.get_global_vars(sess)
print("Initial params:")
for p in meta_params: print(p)
inner_arg_list = [(meta_params, task) for task in task_set]
# Each elem of the result is a list of task-specific params:
task_specific_params_list = worker_pool.starmap(
inner_update, inner_arg_list
)
for ps in task_specific_params_list:
print("Task-specific params:")
for p in ps: print(p)
model.meta_update(sess, task_specific_params_list)
meta_params = model.get_global_vars(sess)
print("Final params:")
for p in meta_params: print(p)
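        # NB only a single meta-update is performed above; a full outer loop,
        # using the currently-unused `model.num_outer_steps`, might look like:
        # for _ in range(model.num_outer_steps):
        #     meta_params = model.get_global_vars(sess)
        #     inner_arg_list = [(meta_params, task) for task in task_set]
        #     task_specific_params_list = worker_pool.starmap(
        #         inner_update, inner_arg_list
        #     )
        #     model.meta_update(sess, task_specific_params_list)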
if __name__ == "__main__":
num_tasks = 3
initial_param = 2.134
task_set = range(5, 5+num_tasks)
model = MetaModel(initial_param)
    with Pool() as p:
        # Run a trivial job first, so that the one-off cost of starting the
        # worker processes is not included in the timing below:
        print("Warming up the pool...")
        p.map(int, [0])
print("Starting training...")
start_time = time()
train_parallel(model, task_set, worker_pool=p)
end_time = time()
print("{} tasks took {:.3f}s".format(num_tasks, end_time - start_time))
Output:
Warming up the pool...
Starting training...
Initial params:
2.134
[4.268 7.469]
[[ 4.134 5.634]
[ 8.434 13.134]]
Task-specific params:
27.133999
[4.268 7.469]
[[ 4.134 5.634]
[ 8.434 13.134]]
Task-specific params:
32.134
[4.268 7.469]
[[ 4.134 5.634]
[ 8.434 13.134]]
Task-specific params:
37.134
[4.268 7.469]
[[ 4.134 5.634]
[ 8.434 13.134]]
Final params:
32.134
[4.268 7.469]
[[ 4.134 5.6340003]
[ 8.434 13.134 ]]
3 tasks took 0.293s