@fedden
Last active October 12, 2017 19:03
Using DFO to Optimise TensorFlow Neural Networks
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Using DFO to Optimise TensorFlow Neural Networks\n",
"\n",
"\n",
"### Imports\n",
"\n",
"It's well worth noting your TensorFlow and gym versions. The APIs change quickly, so if you have issues with this code, check that you are on the same versions or that your code is updated accordingly. As it stands I am using the following:\n",
"- TensorFlow version: **1.3.0**\n",
"- OpenAI Gym version: **0.9.3**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"import numpy as np\n",
"from collections import deque\n",
"import matplotlib.pyplot as plt\n",
"import threading\n",
"import multiprocessing\n",
"import gym\n",
"import json\n",
"import os\n",
"from time import sleep\n",
"\n",
"print('TensorFlow version:', tf.__version__)\n",
"print('OpenAI Gym version:', gym.__version__)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Setup parameters\n",
"\n",
"The network is set up with a rectified linear unit as the hidden activation function. The hidden layer dimensions are specified in the `hidden_sizes` list. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"hidden_non_linearity = tf.nn.relu\n",
"hidden_sizes = [6, 3]\n",
"output_size = 2\n",
"input_size = 4\n",
"cpu_only = True\n",
"env_name = 'CartPole-v0'\n",
"number_iterations = 200"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": true,
"scrolled": true
},
"source": [
"### Setup TF graph\n",
"\n",
"The ```tf.reset_default_graph()``` method wipes the graph clean, which is handy for trying different dimensionalities. The network is then constructed using a for loop. At the end of the cell a placeholder and an assign op are created for every trainable variable, so that a full set of weights can be inserted into the graph with a single call. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"tf.reset_default_graph()\n",
"\n",
"model_input = tf.placeholder(dtype=tf.float32,\n",
"                             shape=[None, input_size])\n",
"net = model_input\n",
"\n",
"for hidden_size in hidden_sizes:\n",
"    net = tf.layers.dense(inputs=net,\n",
"                          units=hidden_size,\n",
"                          activation=hidden_non_linearity)\n",
"\n",
"net = tf.layers.dense(inputs=net,\n",
"                      units=output_size,\n",
"                      activation=tf.nn.softmax)\n",
"model_output = net\n",
"\n",
"graph = tf.get_default_graph()\n",
"\n",
"# One placeholder and one assign op per trainable variable, so that a whole\n",
"# set of weights can be written into the graph with a single sess.run call.\n",
"restore_dict = {}\n",
"restore_ops = []\n",
"for var in graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):\n",
"    place_holder = tf.placeholder(tf.float32, var.get_shape(),\n",
"                                  'ph%s' % var.name.split(':')[0])\n",
"    restore_dict[var.name] = place_holder\n",
"    restore_ops.append(tf.assign(var, place_holder))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Define multi-threaded DFO class\n",
"\n",
"This class encapsulates the main DFO algorithm. It is used further down when the TensorFlow session is created, and the main method worth checking out is ```run()```."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"class DFOStrategy(object):\n",
"    \"\"\"Class to manage the DFO optimisation.\"\"\"\n",
"\n",
"    def __init__(self,\n",
"                 weights,\n",
"                 get_reward_func,\n",
"                 population_size=100,\n",
"                 disturbance_threshold=0.01,\n",
"                 name='CartPole-v0',\n",
"                 sess=None,\n",
"                 meta=None,\n",
"                 mode='original'):\n",
"        \"\"\"Set up the class's member variables.\n",
"\n",
"        The population is drawn from a Gaussian distribution with the\n",
"        same mean and standard deviation as the passed in weights.\n",
"\n",
"        Note:\n",
"            A legitimate tf.Session must be passed to the class.\n",
"\n",
"        Args:\n",
"            weights (list): flat list of parameters from the mlp.\n",
"            get_reward_func (method): external method that returns\n",
"                reward.\n",
"            population_size (int): amount of flies.\n",
"            disturbance_threshold (float): how regularly the flies\n",
"                re-init.\n",
"            name (str): the name of the OpenAI gym environment to run.\n",
"            sess (tf.Session): a valid sess object to access tensors.\n",
"            meta (tuple): important vars to reconstruct mlp weights.\n",
"            mode (str): which update rule run() applies to the flies.\n",
"        \"\"\"\n",
"        np.random.seed(0)\n",
"        self.weights = weights\n",
"        self.meta = meta\n",
"        self.get_reward = get_reward_func\n",
"        self.population_size = population_size\n",
"        dev = np.std(weights)\n",
"        mean = np.mean(weights)\n",
"        # np.random.normal takes (loc, scale, size): mean first, then std.\n",
"        self.population = np.array(\n",
"            [np.random.normal(mean, dev, len(weights))\n",
"             for _ in range(self.population_size)])\n",
"        self.disturbance_threshold = disturbance_threshold\n",
"        self.env_name = name\n",
"        self.env = gym.make(self.env_name)\n",
"        self.sess = sess\n",
"        self.swarms_best = None\n",
"        self.swarms_best_score = np.finfo(np.float32).max\n",
"        self.all_time_best = None\n",
"        self.all_time_best_score = np.finfo(np.float32).max\n",
"        self.mode = mode\n",
"        self.best_reward_record = deque()\n",
"        self.reward_mean_record = deque()\n",
"        self.reward_sigma_record = deque()\n",
"        assert sess is not None\n",
"\n",
"    def get_weights(self):\n",
"        \"\"\"Returns the best weights generated thus far.\n",
"\n",
"        Returns:\n",
"            An array of weights if run has been called else returns None.\n",
"        \"\"\"\n",
"        return self.all_time_best\n",
"\n",
"    def draw(self):\n",
"        \"\"\"Quick and dirty data plotter.\"\"\"\n",
"        x = np.arange(len(self.best_reward_record))\n",
"\n",
"        f, axarr = plt.subplots(3, sharex=True)\n",
"        axarr[0].plot(x, self.best_reward_record)\n",
"        axarr[0].set_title('Reward, Mean, Std Dev.')\n",
"        axarr[1].plot(x, self.reward_mean_record)\n",
"        axarr[2].plot(x, self.reward_sigma_record)\n",
"        plt.show()\n",
"\n",
"    def print_out(self, iteration, print_step):\n",
"        \"\"\"Quick and dirty data printer.\"\"\"\n",
"        if iteration % print_step == 0 and self.all_time_best is not None:\n",
"            print('iter %d. reward: %f. dt: %f. best: %f.' % (iteration,\n",
"                                                              self.swarms_best_score,\n",
"                                                              self.disturbance_threshold,\n",
"                                                              self.all_time_best_score))\n",
"\n",
"    def run(self, iteration_amount, elitism=0, print_step=10, decay=0.98):\n",
"        \"\"\"The main optimisation method.\n",
"\n",
"        One thread per CPU core is created, each with its own environment,\n",
"        and each evaluates a share of the population. The main DFO algorithm\n",
"        is then computed and the flies are updated.\n",
"\n",
"        Args:\n",
"            iteration_amount (int): How many rounds of optimisation are run.\n",
"            elitism (int): How many of the fittest flies are left unchanged\n",
"                (in 'n_fittest' mode they are used as the leaders instead).\n",
"            print_step (int): How often we should print the reward/fitness.\n",
"            decay (float): Factor applied to the disturbance threshold on\n",
"                every iteration.\n",
"        \"\"\"\n",
"        if self.mode == 'n_fittest' and elitism < 1:\n",
"            raise ValueError('mode n_fittest requires elitism >= 1')\n",
"\n",
"        saver = tf.train.Saver()\n",
"        best_neighbour = np.zeros_like(self.population[0])\n",
"        n_threads = multiprocessing.cpu_count()\n",
"\n",
"        envs = [gym.make(self.env_name) for _ in range(n_threads)]\n",
"        for iteration in range(iteration_amount):\n",
"\n",
"            self.print_out(iteration, print_step)\n",
"\n",
"            amount_per_thread = self.population_size // n_threads\n",
"            left_over = self.population_size - amount_per_thread * n_threads\n",
"\n",
"            fitnesses = np.zeros(len(self.population))\n",
"\n",
"            def get_weights_reward(begin, size, env):\n",
"                # Fitness is the negated reward, because DFO minimises.\n",
"                for i in range(begin, begin + size):\n",
"                    fitnesses[i] = -self.get_reward(self.population[i],\n",
"                                                    self.sess,\n",
"                                                    env,\n",
"                                                    self.meta)\n",
"\n",
"            threads = []\n",
"            idx = 0\n",
"            for i in range(n_threads):\n",
"                amt = (amount_per_thread + 1) if i < left_over else amount_per_thread\n",
"                thread = threading.Thread(target=get_weights_reward,\n",
"                                          args=[idx, amt, envs[i]])\n",
"                threads.append(thread)\n",
"                idx += amt\n",
"\n",
"            assert idx == len(self.population)\n",
"\n",
"            for t in threads:\n",
"                t.start()\n",
"            for t in threads:\n",
"                t.join()\n",
"\n",
"            swarms_best_index = np.argmin(fitnesses)\n",
"            self.swarms_best_score = np.amin(fitnesses)\n",
"            # Copy, because the population rows are modified in place below.\n",
"            self.swarms_best = self.population[swarms_best_index].copy()\n",
"\n",
"            if self.swarms_best_score <= self.all_time_best_score:\n",
"                self.all_time_best_score = self.swarms_best_score\n",
"                self.all_time_best = self.swarms_best.copy()\n",
"                # Re-evaluating loads the best weights into the graph\n",
"                # before the checkpoint is written.\n",
"                self.get_reward(self.all_time_best, self.sess, self.env, self.meta)\n",
"                saver.save(self.sess, self.env_name + '_dfo')\n",
"\n",
"            r = np.random.uniform(0.0, 1.0, self.population.shape)\n",
"            self.lower = np.amin(self.population)\n",
"            self.upper = np.amax(self.population)\n",
"            # Mean and standard deviation of the population's weights.\n",
"            dev = np.std(self.population)\n",
"            mean = np.mean(self.population)\n",
"            self.best_reward_record.append(self.swarms_best_score)\n",
"            self.reward_mean_record.append(mean)\n",
"            self.reward_sigma_record.append(dev)\n",
"\n",
"            if elitism > 0:\n",
"                # Indices of the `elitism` flies with the lowest fitness.\n",
"                n_fittest = np.argsort(fitnesses)[:elitism]\n",
"\n",
"            leader_rate = np.random.uniform(0.0, 1.0)\n",
"            self.disturbance_threshold *= decay\n",
"\n",
"            for i, p in enumerate(self.population):\n",
"\n",
"                # Elite flies are carried over to the next iteration as-is.\n",
"                if self.mode != 'n_fittest' and elitism > 0 and i in n_fittest:\n",
"                    continue\n",
"\n",
"                # The neighbourhood is a ring, so the ends wrap around.\n",
"                left = (i - 1) if i != 0 else len(self.population) - 1\n",
"                right = (i + 1) if i != (len(self.population) - 1) else 0\n",
"\n",
"                if fitnesses[left] < fitnesses[right]:\n",
"                    best_neighbour = self.population[left]\n",
"                else:\n",
"                    best_neighbour = self.population[right]\n",
"\n",
"                for x in range(len(p)):\n",
"\n",
"                    if self.mode == 'original':\n",
"                        if r[i][x] < self.disturbance_threshold:\n",
"                            p[x] = np.random.normal(mean, dev)\n",
"                        else:\n",
"                            leader_rate = np.random.uniform(0.0, 1.0)\n",
"                            update = self.swarms_best[x] - best_neighbour[x]\n",
"                            p[x] = best_neighbour[x] + leader_rate * update\n",
"\n",
"                    elif self.mode == 'hybrid':\n",
"                        if r[i][x] < self.disturbance_threshold:\n",
"                            p[x] = np.random.normal(mean, dev)\n",
"                        else:\n",
"                            leader_rate = np.random.uniform(0.0, 1.0)\n",
"                            update = (best_neighbour[x] + self.swarms_best[x]) / 2.0 - p[x]\n",
"                            p[x] = p[x] + leader_rate * update\n",
"\n",
"                    elif self.mode == 'n_fittest':\n",
"                        if r[i][x] < self.disturbance_threshold:\n",
"                            p[x] = np.random.normal(mean, dev)\n",
"                        else:\n",
"                            leader_rate = np.random.uniform(0.0, 1.0)\n",
"                            update = np.average(self.population[n_fittest]) - best_neighbour[x]\n",
"                            p[x] = best_neighbour[x] + leader_rate * update\n",
"\n",
"                    elif self.mode == 'no_leader_with_random':\n",
"                        if r[i][x] < self.disturbance_threshold:\n",
"                            p[x] = np.random.normal(mean, dev)\n",
"                        else:\n",
"                            update = best_neighbour[x] - p[x]\n",
"                            p[x] = p[x] + leader_rate * update\n",
"\n",
"                    elif self.mode == 'no_leader':\n",
"                        update = best_neighbour[x] - p[x]\n",
"                        p[x] = p[x] + leader_rate * update\n",
"\n",
"                    elif self.mode == 'random_gauss':\n",
"                        p[x] = np.random.normal(mean, dev)\n",
"\n",
"                    elif self.mode == 'random_uniform':\n",
"                        p[x] = np.random.sample()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Getters and setters for TF trainable variables\n",
"\n",
"These functions are the nuts and bolts of getting and setting the weights, which is what lets us optimise them with something other than TensorFlow's built-in backprop optimisers. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def get_weights(sess):\n",
"    \"\"\"Get weights from sess.\n",
"\n",
"    This method obtains all of the trainable tensors and flattens them\n",
"    into a single vector.\n",
"\n",
"    Args:\n",
"        sess (tf.Session): The tf session with the correct tensors in the graph.\n",
"\n",
"    Returns:\n",
"        genotype (np.ndarray): flat array of floats comprising the weights.\n",
"        meta (tuple): important variables to rebuild the weights in set_weights.\n",
"\n",
"    \"\"\"\n",
"    all_variable_names = [v.name for v in tf.trainable_variables()]\n",
"    all_variable_values = sess.run(all_variable_names)\n",
"    all_variable_shapes = [v.shape for v in all_variable_values]\n",
"    all_variable_cutoffs = [int(np.prod(s)) for s in all_variable_shapes]\n",
"    genotype = np.concatenate([v.flatten() for v in all_variable_values])\n",
"    return genotype, (all_variable_names, all_variable_shapes, all_variable_cutoffs)\n",
"\n",
"\n",
"def set_weights(sess, new_genotype, meta):\n",
"    \"\"\"Set the weights in the sess.\n",
"\n",
"    This takes the flat list of weights, chops it up into one piece per\n",
"    variable and feeds each piece to the right tensor. All variables are\n",
"    assigned in a single sess.run call, but they are shared by every thread\n",
"    using the session, so evaluations that rely on them must not overlap.\n",
"\n",
"    Args:\n",
"        sess (tf.Session): The tf session with the correct tensors in the graph.\n",
"        new_genotype (list): A list of floats which are the new weights for the mlp.\n",
"        meta (tuple): important variables to rebuild the weights correctly.\n",
"    \"\"\"\n",
"    names, shapes, cutoffs = meta\n",
"    new_genotype = np.array(new_genotype)\n",
"\n",
"    feed_dict = {}\n",
"    start = 0\n",
"    for name, shape, cutoff in zip(names, shapes, cutoffs):\n",
"        # Each variable owns the next `cutoff` entries of the flat vector.\n",
"        new_variable = new_genotype[start:start + cutoff].reshape(shape)\n",
"        feed_dict[restore_dict[name]] = new_variable\n",
"        start += cutoff\n",
"    sess.run(restore_ops, feed_dict=feed_dict)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Define our reward / fitness function"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Every thread shares the same graph variables, so an evaluation holds this\n",
"# lock from set_weights until its episode ends. This serialises evaluations.\n",
"evaluation_lock = threading.Lock()\n",
"\n",
"\n",
"def get_reward(weights, sess, env, meta):\n",
"    \"\"\"Get the reward from the passed in weights in the passed in environment.\n",
"\n",
"    The env will run until it reports that the episode is done.\n",
"\n",
"    Args:\n",
"        weights (list): A list of floats which are the new weights for the mlp.\n",
"        sess (tf.Session): The tf session with the correct tensors in the graph.\n",
"        env (openai): the OpenAI gym environment to run.\n",
"        meta (tuple): important variables to rebuild the weights correctly.\n",
"\n",
"    Returns:\n",
"        A float representing the total reward. Bigger is better.\n",
"\n",
"    \"\"\"\n",
"    with evaluation_lock:\n",
"        set_weights(sess, weights, meta)\n",
"\n",
"        total_reward = 0\n",
"        done = False\n",
"        observation = env.reset()\n",
"\n",
"        while not done:\n",
"\n",
"            feed_dict = {\n",
"                model_input: observation.reshape((1, -1))\n",
"            }\n",
"            prediction = sess.run(model_output,\n",
"                                  feed_dict=feed_dict)\n",
"            # Act greedily on the softmax output.\n",
"            action = np.argmax(prediction[0])\n",
"\n",
"            observation, reward, done, info = env.step(action)\n",
"            total_reward += reward\n",
"\n",
"    return total_reward"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Run and optimise the MLP\n",
"\n",
"Now we can run the optimisation steps. A number of modifications have been made to the DFO algorithm which are worth playing with. Whenever a new best fly is found, the TensorFlow weights are saved. The parameters to be particularly aware of are `population_size`, `disturbance_threshold`, `mode`, `decay` and `elitism`. The `mode` parameter accepts the following options: \n",
"```\n",
"'original'\n",
"'hybrid'\n",
"'n_fittest'\n",
"'no_leader_with_random'\n",
"'no_leader'\n",
"'random_gauss'\n",
"'random_uniform'\n",
"```\n",
"\n",
"**Please note:** a CartPole-v0 episode is capped at 200 steps, so 200 is the maximum reward (the environment is officially considered solved at an average reward of 195 over 100 consecutive episodes). Note also that we negate the reward, because DFO minimises and hasn't been modified to maximise, so we are trying to reach -200.0! "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"config = tf.ConfigProto(device_count={'GPU': 0}) if cpu_only else None\n",
"\n",
"with tf.Session(config=config) as sess:\n",
"\n",
"    sess.run(tf.global_variables_initializer())\n",
"\n",
"    initial_weights, meta = get_weights(sess)\n",
"\n",
"    es = DFOStrategy(initial_weights,\n",
"                     get_reward,\n",
"                     population_size=1000,\n",
"                     disturbance_threshold=0.1,\n",
"                     name=env_name,\n",
"                     sess=sess,\n",
"                     meta=meta,\n",
"                     mode='original')\n",
"\n",
"    es.run(number_iterations,\n",
"           elitism=20,\n",
"           print_step=1,\n",
"           decay=0.99)\n",
"\n",
"    best_weights = es.get_weights()\n",
"    set_weights(sess, best_weights, meta)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Graph of results\n",
"\n",
"The three panels show the best fitness per iteration, followed by the mean and standard deviation of the population's weights. Note that the reward is negated here, as the stock DFO algorithm minimises an objective function and the OpenAI gym environments return positive rewards for good performance. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"es.draw()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Preview the network"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"from gym import wrappers\n",
"\n",
"with tf.Session(config=config) as sess:\n",
"    saver = tf.train.Saver()\n",
"    saver.restore(sess, env_name + '_dfo')\n",
"\n",
"    env = wrappers.Monitor(es.env, \".\", force=True)\n",
"    # Reset through the wrapper so the monitor sees the start of the episode.\n",
"    observation = env.reset()\n",
"    env.render(close=True)\n",
"\n",
"    for i in range(200):\n",
"\n",
"        env.render()\n",
"\n",
"        feed_dict = {\n",
"            model_input: observation.reshape((1, -1))\n",
"        }\n",
"        prediction = sess.run(model_output,\n",
"                              feed_dict=feed_dict)\n",
"\n",
"        action = np.argmax(prediction[0])\n",
"\n",
"        observation, reward, done, info = env.step(action)\n",
"        if done:\n",
"            print(\"Done early at step\", i)\n",
"            break\n",
"        sleep(1.0/60.0)\n",
"    env.close()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}