Skip to content

Instantly share code, notes, and snippets.

@tillahoffmann
Created July 28, 2016 13:48
Show Gist options
  • Save tillahoffmann/7aef8e89b04f1370f24c4262e30fce95 to your computer and use it in GitHub Desktop.
Save tillahoffmann/7aef8e89b04f1370f24c4262e30fce95 to your computer and use it in GitHub Desktop.
Serialisation of TensorFlow models without using collections
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Definition and serialisation"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Fix the random seed so that the coefficients and the test data drawn below are\n",
"# reproducible; otherwise a failing assertion could not be re-run deterministically.\n",
"np.random.seed(42)\n",
"\n",
"# Define p coefficients of a linear model, drawn from a standard normal distribution\n",
"p = 5\n",
"coefficient_values = np.random.normal(0, 1, p)\n",
"\n",
"# Build the graph: predictor[i] = sum_j coefficients[j] * placeholder[i, j]\n",
"with tf.Graph().as_default() as original_graph:\n",
"    placeholder = tf.placeholder(tf.float32, [None, p], name='placeholder')\n",
"    # The float64 numpy values are stored in a float32 variable\n",
"    coefficients = tf.Variable(coefficient_values, name='coefficients', dtype=tf.float32)\n",
"    predictor = tf.reduce_sum(coefficients * placeholder, 1, name='predictor')\n",
"    init_op = tf.initialize_all_variables()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Create a session for the original graph and initialise its variables\n",
"session = tf.Session(graph=original_graph)\n",
"session.run(init_op)\n",
"\n",
"# Make sure the graph performs as expected. X is fed into a float32 placeholder, so\n",
"# the result carries float32 rounding error; a purely relative tolerance fails\n",
"# spuriously whenever an entry of desired is close to zero, hence the small atol.\n",
"X = np.random.normal(0, 1, (100, p))\n",
"actual = session.run(predictor, {placeholder: X})\n",
"desired = np.dot(X, coefficient_values)\n",
"np.testing.assert_allclose(actual, desired, rtol=1e-3, atol=1e-4)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Serialise the graph and its variables. The deserialisation section below reads the graph\n",
"# definition back from 'linear_model.tf.meta' and the variable values from 'linear_model.tf'.\n",
"with original_graph.as_default():\n",
" saver = tf.train.Saver()\n",
" saver.save(session, 'linear_model.tf')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Deserialisation"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Deserialise the graph structure from the .meta file and the variable values from the checkpoint\n",
"with tf.Graph().as_default() as restored_graph:\n",
"    restored_session = tf.Session(graph=restored_graph)\n",
"    saver = tf.train.import_meta_graph('linear_model.tf.meta')\n",
"    # Restore the checkpoint into restored_session, the session that owns restored_graph\n",
"    # (restoring into the original session would leave restored_session untouched).\n",
"    # No initialiser is run afterwards: it would overwrite the restored values with the\n",
"    # initial values stored in the graph, hiding whether the restore actually worked.\n",
"    saver.restore(restored_session, 'linear_model.tf')"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Option 1: look up the tensors by name (see https://github.com/tensorflow/tensorflow/issues/3378 and\n",
"# http://stackoverflow.com/a/37870634/419116 for information on the :0 suffix)\n",
"restored_placeholder = restored_graph.get_tensor_by_name('placeholder:0')\n",
"restored_predictor = restored_graph.get_tensor_by_name('predictor:0')\n",
"actual = restored_session.run(restored_predictor, {restored_placeholder: X})\n",
"# atol guards against spurious failures for entries of desired close to zero\n",
"np.testing.assert_allclose(actual, desired, rtol=1e-3, atol=1e-4)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Option 2: pass the tensor names directly, both as the fetch and as the feed key\n",
"actual = restored_session.run('predictor:0', {'placeholder:0': X})\n",
"# atol guards against spurious failures for entries of desired close to zero\n",
"np.testing.assert_allclose(actual, desired, rtol=1e-3, atol=1e-4)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor 'placeholder:0' shape=(?, 5) dtype=float32>"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# The placeholder's static shape, (None, p), survives the round trip through the meta graph\n",
"restored_graph.get_tensor_by_name('placeholder:0')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment