Skip to content

Instantly share code, notes, and snippets.

@asimshankar
Last active October 10, 2021 12:45
Show Gist options
  • Save asimshankar/fb1f42c3bd91e1bb041f34a848e59fe1 to your computer and use it in GitHub Desktop.
Save asimshankar/fb1f42c3bd91e1bb041f34a848e59fe1 to your computer and use it in GitHub Desktop.
TensorFlow: Saving and restoring variables in Go
# Builds a tiny TensorFlow graph (two variables plus an assign_add op), writes
# its GraphDef to /tmp/graph.pb for the Go program, prints Go save/restore
# helpers bound to this graph's Saver operation names, and writes a first
# checkpoint to /tmp/ckpt1.
#
# Uses the TensorFlow 1.x graph-mode API (tf.assign_add, tf.train.Saver,
# tf.Session); the script itself runs under either Python 2 or Python 3.
import tensorflow as tf
# Construct the graph
x = tf.Variable(1, name='x')
y = tf.Variable(2, name='y')
# Bound to sum_op so the builtin sum() is not shadowed; the graph op itself is
# still named 'sum', which is the name the Go program looks up.
sum_op = tf.assign_add(x, y, name='sum')
# Add operations to save and restore checkpoints
saver = tf.train.Saver()
# Save the graph. SerializeToString() produces bytes, so open the file in
# binary mode ('w' raises TypeError on Python 3 and can mangle bytes on Windows).
with open('/tmp/graph.pb', 'wb') as f:
    f.write(tf.get_default_graph().as_graph_def().SerializeToString())
# Print out Go code snippet to save/restore
# Perhaps it may make sense for tf.Session to return a pointer
# to the tf.Graph it operates on instead of having to pass both
# the graph and session consistently.
sd = saver.saver_def
# Tensor names have the form "op_name:index"; [:-2] drops a trailing ":0" to
# get the operation name (assumes a single-digit output index of 0).
# restore_op_name is already an operation name, so it is used as-is.
# The % is applied inside print(...) so this also works on Python 3, where
# print is a function returning None.
print('''
// save saves the current value of variables in graph/sess in files with the
// given prefix and returns the string to provide to restore.
func save(graph *tf.Graph, sess *tf.Session, prefix string) (string, error) {
t, err := tf.NewTensor(prefix)
if err != nil {
return "", err
}
o := graph.Operation("%s").Output(0)
ret, err := sess.Run(map[tf.Output]*tf.Tensor{o:t}, []tf.Output{graph.Operation("%s").Output(0)}, nil)
if err != nil {
return "", err
}
return ret[0].Value().(string), nil
}
// restore restores the value of variables previously saved using save.
func restore(graph *tf.Graph, sess *tf.Session, path string) error {
t, err := tf.NewTensor(path)
if err != nil {
return err
}
o := graph.Operation("%s").Output(0)
_, err = sess.Run(map[tf.Output]*tf.Tensor{o:t}, nil, []*tf.Operation{graph.Operation("%s")})
return err
}
''' % (sd.filename_tensor_name[:-2], sd.save_tensor_name[:-2],
       sd.filename_tensor_name[:-2], sd.restore_op_name))
# For fun, save the checkpoint where x=3 (x starts at 1 and sum_op adds y=2).
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(sum_op)
    print("Saved to: " + saver.save(sess, "/tmp/ckpt1"))
package main
import (
"fmt"
"io/ioutil"
"log"
tf "github.com/tensorflow/tensorflow/tensorflow/go"
)
// main loads the graph exported by the Python script, restores the checkpoint
// at /tmp/ckpt1, applies one "sum" update and writes a new checkpoint with the
// prefix /tmp/ckpt2. Any failure is logged and the process exits non-zero.
func main() {
	if err := run(); err != nil {
		log.Fatal(err)
	}
}

// run does the work for main and returns the first error encountered.
//
// It is separate from main because log.Fatal calls os.Exit, which skips
// deferred functions; returning errors here lets the deferred sess.Close run
// on every path, and its error is reported if nothing else failed first.
func run() (err error) {
	gdef, err := ioutil.ReadFile("/tmp/graph.pb")
	if err != nil {
		return fmt.Errorf("reading graph: %w", err)
	}
	graph := tf.NewGraph()
	if err := graph.Import(gdef, ""); err != nil {
		return fmt.Errorf("importing graph: %w", err)
	}
	sess, err := tf.NewSession(graph, nil)
	if err != nil {
		return fmt.Errorf("creating session: %w", err)
	}
	defer func() {
		// Surface a Close failure only when no earlier error is being returned.
		if cerr := sess.Close(); cerr != nil && err == nil {
			err = fmt.Errorf("closing session: %w", cerr)
		}
	}()

	// Restore an existing checkpoint.
	if err := restore(graph, sess, "/tmp/ckpt1"); err != nil {
		return fmt.Errorf("restoring checkpoint: %w", err)
	}

	// Run an update and save a new checkpoint. Graph.Operation returns nil for
	// an unknown name, so check it before handing it to Session.Run.
	sumOp := graph.Operation("sum")
	if sumOp == nil {
		return fmt.Errorf("operation %q not found in graph", "sum")
	}
	if _, err := sess.Run(nil, nil, []*tf.Operation{sumOp}); err != nil {
		return fmt.Errorf("running %q: %w", "sum", err)
	}
	path, err := save(graph, sess, "/tmp/ckpt2")
	if err != nil {
		return fmt.Errorf("saving checkpoint: %w", err)
	}
	fmt.Println("Saved checkpoint to", path)
	return nil
}
// Code below was generated by the python script above (the operation names
// come from its tf.train.Saver's SaverDef). save has since been edited to
// return errors, rather than panic, when an operation is missing from the
// graph or the save output is not a string.

// save saves the current value of variables in graph/sess in files with the
// given prefix and returns the string to provide to restore.
//
// It feeds prefix into the Saver's filename tensor ("save/Const") and fetches
// "save/control_dependency", whose string value is returned.
func save(graph *tf.Graph, sess *tf.Session, prefix string) (string, error) {
	const (
		filenameOpName = "save/Const"
		saveOpName     = "save/control_dependency"
	)
	// Graph.Operation returns nil for unknown names; Session.Run would panic
	// on a nil operation, so report it as an error instead.
	filenameOp := graph.Operation(filenameOpName)
	if filenameOp == nil {
		return "", fmt.Errorf("operation %q not found in graph", filenameOpName)
	}
	saveOp := graph.Operation(saveOpName)
	if saveOp == nil {
		return "", fmt.Errorf("operation %q not found in graph", saveOpName)
	}
	t, err := tf.NewTensor(prefix)
	if err != nil {
		return "", fmt.Errorf("building prefix tensor: %w", err)
	}
	ret, err := sess.Run(
		map[tf.Output]*tf.Tensor{filenameOp.Output(0): t},
		[]tf.Output{saveOp.Output(0)},
		nil,
	)
	if err != nil {
		return "", fmt.Errorf("running %q: %w", saveOpName, err)
	}
	// Use the two-value assertion so an unexpected dtype is an error, not a panic.
	path, ok := ret[0].Value().(string)
	if !ok {
		return "", fmt.Errorf("%q produced %T, want string", saveOpName, ret[0].Value())
	}
	return path, nil
}
// restore restores the value of variables previously saved using save.
//
// It feeds path into the Saver's filename tensor ("save/Const") and runs the
// "save/restore_all" operation. A missing operation is reported as an error
// rather than causing a panic inside Session.Run.
func restore(graph *tf.Graph, sess *tf.Session, path string) error {
	const (
		filenameOpName = "save/Const"
		restoreOpName  = "save/restore_all"
	)
	// Graph.Operation returns nil for unknown names.
	filenameOp := graph.Operation(filenameOpName)
	if filenameOp == nil {
		return fmt.Errorf("operation %q not found in graph", filenameOpName)
	}
	restoreOp := graph.Operation(restoreOpName)
	if restoreOp == nil {
		return fmt.Errorf("operation %q not found in graph", restoreOpName)
	}
	t, err := tf.NewTensor(path)
	if err != nil {
		return fmt.Errorf("building path tensor: %w", err)
	}
	feeds := map[tf.Output]*tf.Tensor{filenameOp.Output(0): t}
	if _, err := sess.Run(feeds, nil, []*tf.Operation{restoreOp}); err != nil {
		return fmt.Errorf("running %q: %w", restoreOpName, err)
	}
	return nil
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment