Wouldn't it be cool to automatically bind class fields to TensorFlow variables in a graph and restore them, without manually fetching each variable back by name and naming every one of them?
Imagine you have a Model class
https://gist.github.com/764c20a0b7c871851f2b6d354fd17372
Usually, you first build your model and then train it. After that, you want to retrieve the old variables from the saved graph without rebuilding the whole model.
https://gist.github.com/e46c216a4883a88930268d3d72860788
<tf.Variable 'variable:0' shape=(1,) dtype=int32_ref>
Now, imagine we have just trained our model and we want to store it. The usual pattern is
https://gist.github.com/b752a437082a584fbc4d0b55046b596f
Now you want to perform inference, i.e. get your stuff back, by loading the stored graph. In our case, we want the variable named variable.
https://gist.github.com/d992c0a431745306635e62cf13d5ae98
INFO:tensorflow:Restoring parameters from /tmp/model.ckpt
Now we can get our variable back from the graph
https://gist.github.com/4310a0bd3b1eccd7b63785fde7cebcfb
name: "variable"
op: "VariableV2"
attr {
key: "container"
value {
s: ""
}
}
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "shape"
value {
shape {
dim {
size: 1
}
}
}
}
attr {
key: "shared_name"
value {
s: ""
}
}
But what if we want to use our Model class again? If we now try to access model.variable, we get None
https://gist.github.com/734dc7d2e3b4677321f427f77af72703
None
One solution is to rebuild the whole model and then restore the graph
https://gist.github.com/23e4054280e934c9a4b4c3d9eb715def
INFO:tensorflow:Restoring parameters from /tmp/model.ckpt
<tf.Variable 'variable:0' shape=(1,) dtype=int32_ref>
You can already see that this is a big waste of time. Instead, we can bind model.variable directly to the correct graph node by
https://gist.github.com/69d878978e8196c710256dfea1b24bf3
name: "variable"
op: "VariableV2"
attr {
key: "container"
value {
s: ""
}
}
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "shape"
value {
shape {
dim {
size: 1
}
}
}
}
attr {
key: "shared_name"
value {
s: ""
}
}
Now imagine we have a very big model with nested variables. In order to correctly restore each variable pointer in the model, you need to:
- name each variable
- get the variables back from the graph
Wouldn't it be cool if we could automatically retrieve all the variables set as fields in the Model class?
I have created a class called TFGraphConvertible. You can use TFGraphConvertible to automatically serialize and deserialize a class.
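The actual implementation is in the gist below; as a rough illustration of the idea, here is a minimal sketch of such a mixin. FakeTensor and FakeGraph are stand-ins for TensorFlow's tensor and graph objects (which expose a .name attribute and a get_tensor_by_name method), so the example runs without TensorFlow:

```python
# Illustrative sketch only: FakeTensor/FakeGraph stand in for TensorFlow's
# tensors and graph; the real TFGraphConvertible works on an actual tf.Graph.

class TFGraphConvertibleSketch:
    """Mixin that maps class fields to graph tensor names and back."""

    def to_graph(self, fields):
        # Serialize: field name -> tensor name (tensors expose a .name attribute)
        return {key: value.name for key, value in fields.items()}

    def from_graph(self, serialized, graph):
        # Deserialize: bind each field to the tensor looked up by its name
        for field, tensor_name in serialized.items():
            setattr(self, field, graph.get_tensor_by_name(tensor_name))


class FakeTensor:
    def __init__(self, name):
        self.name = name


class FakeGraph:
    def __init__(self, tensors):
        self._by_name = {t.name: t for t in tensors}

    def get_tensor_by_name(self, name):
        return self._by_name[name]


class Model(TFGraphConvertibleSketch):
    def build(self):
        self.variable = FakeTensor('variable:0')


model = Model()
model.build()
serialized = model.to_graph(fields={'variable': model.variable})
print(serialized)  # {'variable': 'variable:0'}

# A fresh, unbuilt model gets its field bound back from the "graph"
fresh = Model()
fresh.from_graph(serialized, FakeGraph([FakeTensor('variable:0')]))
print(fresh.variable.name)  # variable:0
```

The key design point is that only names cross the serialization boundary, so the dictionary is cheap to store and independent of the Python objects themselves.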
Let's recreate our model
https://gist.github.com/4bcb5cc60c98d447d65d8a79ee92d8ac
It exposes two methods: to_graph and from_graph.
To serialize a class, you can call the to_graph method, which creates a dictionary of field names -> TensorFlow variable names. You need to pass a fields argument, a dictionary of the fields we want to serialize. In our case, we can just pass all of them.
https://gist.github.com/bd46adf6bce31bb83c1a6ac1baacb83e
{'variable': 'variable_2:0'}
It will create a dictionary with all the fields as keys and the corresponding TensorFlow variable names as values.
To deserialize a class, you can call the from_graph method, which takes the previously created dictionary and binds each class field to the correct TensorFlow variable.
https://gist.github.com/e922b3949aebe63e12f8d2368a749225
None
<tf.Tensor 'variable_2:0' shape=(1,) dtype=int32_ref>
And now you have your model back!
Let's see a more interesting example! We are going to train/restore a model for the MNIST dataset
https://gist.github.com/be6a79ff84d333a0896ff57fc6105bc0
Let's get the dataset!
https://gist.github.com/c7d22cdc77e53fe74c491be60fec4890
Using TensorFlow backend.
Now it is time to train it
https://gist.github.com/7e22a21a3fedbeaf8a2cac3b1ea98e7e
0.125
0.46875
0.8125
0.953125
0.828125
0.890625
0.796875
0.9375
0.953125
0.921875
Perfect! Let's store the serialized model in memory
https://gist.github.com/e5ecc4a5dcefd595f1b503b3c761a7c8
{'x': 'ExpandDims:0', 'y': 'one_hot:0', 'forward_raw': 'dense_1/BiasAdd:0', 'accuracy': 'Mean:0', 'loss': 'Mean_1:0', 'train_step': 'Adam'}
Then we reset the graph and recreate the model
https://gist.github.com/cf720a3eff740cda8bc7a4437a2a9af4
INFO:tensorflow:Restoring parameters from /tmp/model.ckpt
Of course, our variables in the mnist_model do not exist
https://gist.github.com/1e4401a45bd0809af57486891ce99655
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-21-9def5e0d8f6c> in <module>()
----> 1 mnist_model.accuracy
AttributeError: 'MNISTModel' object has no attribute 'accuracy'
Let's recreate them by calling the from_graph method.
https://gist.github.com/44a2aab032610f7b0d3da037c533d84b
<tf.Tensor 'Mean:0' shape=() dtype=float32>
Now mnist_model is ready to go. Let's see the accuracy on a batch of the test set
https://gist.github.com/70b7c85931867f9d3953cb5c1e391fcc
INFO:tensorflow:Restoring parameters from /tmp/model.ckpt
1.0
In this tutorial we have seen how to serialize a class and bind each field back to the correct tensor in the TensorFlow graph. Be aware that you can store the serialized_model in .json format and load it directly where you need it. In this way, you can create your model using object-oriented programming and retrieve all the variables inside it without having to rebuild the whole graph.
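Since the serialized model is just a dictionary of strings, persisting it as JSON is a one-liner. A minimal sketch, reusing the field/tensor-name mapping printed earlier for the MNIST model (the file path is illustrative):

```python
import json

# Example mapping, as printed earlier for the MNIST model
serialized_model = {'x': 'ExpandDims:0', 'y': 'one_hot:0',
                    'forward_raw': 'dense_1/BiasAdd:0', 'accuracy': 'Mean:0',
                    'loss': 'Mean_1:0', 'train_step': 'Adam'}

# Store the mapping next to the checkpoint
with open('/tmp/serialized_model.json', 'w') as f:
    json.dump(serialized_model, f)

# Later, wherever you need the model, load the mapping back
# and pass it to from_graph to rebind the fields
with open('/tmp/serialized_model.json') as f:
    restored = json.load(f)

print(restored == serialized_model)  # True
```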
Thank you for reading
Francesco Saverio Zuppichini