Navigation Menu

Skip to content

Instantly share code, notes, and snippets.

@yukiB
Last active May 27, 2016 13:13
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save yukiB/422fecce136bf0f2a707f667f36c8db5 to your computer and use it in GitHub Desktop.
TensorFlow学習パラメータのsave, restoreでつまった ref: http://qiita.com/yukiB/items/a7a92af4b27e0c4e6eb2
# Create some variables. tf.train.Saver keys checkpoint entries by each
# variable's `name`, not by the Python identifier (w1, v1, ...).
# NOTE(review): w1 and v1 are both given name="v1", as in the original —
# presumably deliberate, to illustrate the naming pitfall. TensorFlow
# uniquifies duplicate op names (e.g. "v1_1"), so confirm which variable ends
# up stored under the key "v1".
w1 = tf.Variable(..., name="v1")
w2 = tf.Variable(..., name="w2")
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# With no arguments, the Saver covers all variables defined so far.
saver = tf.train.Saver()
...
# Save. Because global_step is given, the step number is appended to the
# path, so this writes "model.ckpt-100" rather than "model.ckpt".
saver.save(sess, "model.ckpt", global_step=100)
# Save v1 under the checkpoint key "my_v1" (a dict maps key -> variable).
saver_my_v1 = tf.train.Saver({'my_v1': v1})
# Save only v1 and v2 (a list uses each variable's own name as its key).
saver_v1_v2 = tf.train.Saver([v1, v2])
# Restore side: re-create the same variables (same `name`s) before restoring,
# since checkpoint entries are matched by variable name.
w1 = tf.Variable(..., name="v1")
w2 = tf.Variable(..., name="w2")
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
# The Saver must be constructed after the variables it should restore exist.
saver = tf.train.Saver()
with tf.Session() as sess:
    # Load the variable values. saver.save(..., global_step=100) wrote
    # "model.ckpt-100", so that is the path to restore from. Restored
    # variables do not need a separate initialization step.
    saver.restore(sess, "model.ckpt-100")
# Read the checkpoint state file in the current directory (None if absent).
ckpt = tf.train.get_checkpoint_state('./')
# Resume from the latest checkpoint if one exists, otherwise start fresh.
# NOTE(review): assumes `saver` was built earlier, after the variables were
# defined — confirm in the surrounding script.
with tf.Session() as sess:
    # Look up the checkpoint state in the current directory.
    ckpt = tf.train.get_checkpoint_state('./')
    if ckpt and ckpt.model_checkpoint_path:  # a checkpoint exists
        last_model = ckpt.model_checkpoint_path  # path to the most recently saved model
        print("load " + last_model)
        saver.restore(sess, last_model)  # load the variable data
        ...
    else:  # no saved data
        init = tf.initialize_all_variables()
        sess.run(init)  # initialize the variables, then run
# Fresh start with no checkpoint: build one op that initializes every
# variable, then execute it in the session.
init = tf.initialize_all_variables()
sess.run(fetches=init)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment