
Created July 4, 2018 14:36

Serialize Class to TensorFlow Graph

Francesco Saverio Zuppichini

Wouldn't it be cool to automatically bind class fields to TensorFlow variables in a graph and restore them, without having to name each variable and then fetch it back from the graph by that name?

Imagine you have a Model class.


Usually, you first build your model and then train it. After that, you want to get the old variables back from the saved graph without rebuilding the whole model.

<tf.Variable 'variable:0' shape=(1,) dtype=int32_ref>

Now imagine we have just trained our model and want to store it. The usual pattern is to save a checkpoint with tf.train.Saver.

Now you want to perform inference, that is, get your stuff back, by loading the stored graph. In our case, we want the variable named variable.

INFO:tensorflow:Restoring parameters from /tmp/model.ckpt

Now we can get our variable back from the graph:

name: "variable"
op: "VariableV2"
attr {
  key: "container"
  value {
    s: ""
  }
}
attr {
  key: "dtype"
  value {
    type: DT_INT32
  }
}
attr {
  key: "shape"
  value {
    shape {
      dim {
        size: 1
      }
    }
  }
}
attr {
  key: "shared_name"
  value {
    s: ""
  }
}

But what if we want to use our Model class again? If we now access model.variable, we get None.


One solution is to rebuild the whole model and then restore the graph:

INFO:tensorflow:Restoring parameters from /tmp/model.ckpt
<tf.Variable 'variable:0' shape=(1,) dtype=int32_ref>

You can already see that this is a big waste of time. Instead, we can bind model.variable directly to the correct graph node by looking it up by name:

name: "variable"
op: "VariableV2"
attr {
  key: "container"
  value {
    s: ""
  }
}
attr {
  key: "dtype"
  value {
    type: DT_INT32
  }
}
attr {
  key: "shape"
  value {
    shape {
      dim {
        size: 1
      }
    }
  }
}
attr {
  key: "shared_name"
  value {
    s: ""
  }
}

Now imagine we have a very big model with nested variables. To correctly restore each variable pointer in the model, you need to:

  • name each variable
  • get the variables back from the graph

Wouldn't it be cool if we could automatically retrieve all the variables set as fields in the Model class?


I have created a class called TFGraphConvertible. You can use TFGraphConvertible to automatically serialize and deserialize a class. It exposes two methods: to_graph and from_graph.

Let's recreate our model.

Serialize - to_graph

To serialize a class, call the to_graph method, which creates a dictionary mapping field names to TensorFlow variable names. You need to pass a fields argument: a dictionary of the fields you want to serialize. In our case, we can just pass all of them.
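A minimal, framework-free sketch of the idea behind to_graph, written as a standalone function rather than a method. It assumes that every field worth serializing holds an object with a string name attribute, as tf.Tensor, tf.Variable and tf.Operation do in TensorFlow 1.x, and that other fields are simply skipped. SimpleNamespace objects stand in for real tensors so the example runs without TensorFlow; this is an illustration, not the author's implementation.

```python
from types import SimpleNamespace


def to_graph(fields):
    """Map field names to graph names (hypothetical sketch).

    Assumption: values without a string ``name`` attribute (None, plain
    Python numbers, ...) do not live in the graph and are skipped.
    """
    serialized = {}
    for field, value in fields.items():
        name = getattr(value, "name", None)
        if isinstance(name, str):
            serialized[field] = name
    return serialized


# Stand-ins for TensorFlow objects: only the .name attribute matters here.
fields = {
    "variable": SimpleNamespace(name="variable_2:0"),
    "learning_rate": 0.001,  # plain Python value, not part of the graph
}
print(to_graph(fields))  # {'variable': 'variable_2:0'}
```

With a real model you would pass something like the instance's __dict__ as fields, which matches the "pass all of them" case above.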

{'variable': 'variable_2:0'}

It creates a dictionary with the fields as keys and the corresponding TensorFlow variable names as values.
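The values follow TensorFlow's tensor naming convention, "<operation name>:<output index>", so 'variable_2:0' is output 0 of the operation variable_2 (the _2 suffix is how TensorFlow deduplicates a name that already exists in the graph). Names without a ':' refer to operations rather than tensors. The helper below is a hypothetical illustration of that split, not part of TensorFlow:

```python
def split_tensor_name(tensor_name):
    """Split '<op_name>:<output_index>' into (op_name, index).

    Names without a ':' (operations such as a training step 'Adam')
    return an index of None.
    """
    # rpartition splits on the last ':' so scoped names like
    # 'dense_1/BiasAdd:0' keep their full operation path.
    op_name, sep, index = tensor_name.rpartition(":")
    if not sep:
        return tensor_name, None
    return op_name, int(index)


print(split_tensor_name("variable_2:0"))       # ('variable_2', 0)
print(split_tensor_name("dense_1/BiasAdd:0"))  # ('dense_1/BiasAdd', 0)
print(split_tensor_name("Adam"))               # ('Adam', None)
```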

Deserialize - from_graph

To deserialize a class, call the from_graph method, which takes the previously created dictionary and binds each class field to the correct TensorFlow variable.
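The deserialization half can be sketched the same way. This sketch assumes from_graph receives the graph explicitly (with TensorFlow 1.x you would pass tf.get_default_graph() after restoring the checkpoint) and that names without a ':' refer to operations, like 'Adam' for train_step in the MNIST example. FakeGraph is a hypothetical stand-in exposing only the two tf.Graph lookups used, get_tensor_by_name and get_operation_by_name; the class is not the author's code.

```python
from types import SimpleNamespace


class FakeGraph:
    """Hypothetical stand-in for tf.Graph with the two lookups used below."""

    def __init__(self, names):
        # Every node is a simple object that only knows its own name.
        self._nodes = {name: SimpleNamespace(name=name) for name in names}

    def get_tensor_by_name(self, name):
        return self._nodes[name]

    def get_operation_by_name(self, name):
        return self._nodes[name]


class TFGraphConvertible:
    """Sketch of the deserialization half (hypothetical)."""

    def from_graph(self, serialized, graph):
        for field, name in serialized.items():
            # Tensor names carry an output index ('Mean:0');
            # operation names such as 'Adam' do not.
            if ":" in name:
                node = graph.get_tensor_by_name(name)
            else:
                node = graph.get_operation_by_name(name)
            setattr(self, field, node)
        return self


class Model(TFGraphConvertible):
    """A model class whose fields start out unset."""


serialized_model = {"accuracy": "Mean:0", "train_step": "Adam"}
graph = FakeGraph(["Mean:0", "Adam"])
model = Model().from_graph(serialized_model, graph)
print(model.accuracy.name, model.train_step.name)  # Mean:0 Adam
```

Returning self is only a convenience for chaining construction and restoration; the original method may not do this.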


<tf.Tensor 'variable_2:0' shape=(1,) dtype=int32_ref>

And now you have your model back!

Full Example

Let's see a more interesting example! We are going to train and restore a model for the MNIST dataset.

Let's get the dataset!


Now it is time to train the model.


Perfect! Let's store the serialized model in memory.

{'x': 'ExpandDims:0', 'y': 'one_hot:0', 'forward_raw': 'dense_1/BiasAdd:0', 'accuracy': 'Mean:0', 'loss': 'Mean_1:0', 'train_step': 'Adam'}

Then we reset the graph and recreate the model.

INFO:tensorflow:Restoring parameters from /tmp/model.ckpt

Of course, the variables in mnist_model do not exist yet:


AttributeError                            Traceback (most recent call last)

<ipython-input-21-9def5e0d8f6c> in <module>()
----> 1 mnist_model.accuracy

AttributeError: 'MNISTModel' object has no attribute 'accuracy'

Let's recreate them by calling the from_graph method.

<tf.Tensor 'Mean:0' shape=() dtype=float32>

Now mnist_model is ready to go. Let's check the accuracy on a batch of the test set.

INFO:tensorflow:Restoring parameters from /tmp/model.ckpt


In this tutorial we have seen how to serialize a class and bind each field back to the correct tensor in the TensorFlow graph. Be aware that you can store the serialized_model in .json format and load it directly wherever you need it. This way, you can create your model using object-oriented programming and retrieve all the variables inside it without having to rebuild it.
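Because the serialized model is a plain dictionary of strings, Python's standard json module is enough to persist it. A minimal sketch using the dictionary printed in the MNIST example; the file name mnist_model_fields.json is a hypothetical choice.

```python
import json
import os
import tempfile

# The dictionary produced by to_graph in the MNIST example.
serialized_model = {
    "x": "ExpandDims:0",
    "y": "one_hot:0",
    "forward_raw": "dense_1/BiasAdd:0",
    "accuracy": "Mean:0",
    "loss": "Mean_1:0",
    "train_step": "Adam",
}

# Hypothetical location; any writable path works.
path = os.path.join(tempfile.gettempdir(), "mnist_model_fields.json")

with open(path, "w") as f:
    json.dump(serialized_model, f, indent=2)

# Later, possibly in another process: load the mapping back
# and hand it to from_graph after restoring the checkpoint.
with open(path) as f:
    restored = json.load(f)

print(restored == serialized_model)  # True
```

The JSON file only records names, so it must be paired with the checkpoint that actually holds the graph and variable values.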

Thank you for reading

Francesco Saverio Zuppichini
