Last active
May 27, 2016 13:13
-
-
Save yukiB/422fecce136bf0f2a707f667f36c8db5 to your computer and use it in GitHub Desktop.
TensorFlow学習パラメータのsave, restoreでつまった ref: http://qiita.com/yukiB/items/a7a92af4b27e0c4e6eb2
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Create a few variables.
# NOTE(review): w1 is given name="v1", the same name as v1 below. TensorFlow
# uniquifies duplicate names within a graph (the later one presumably becomes
# "v1_1"), so checkpoint keys will not match the Python names. name="w1" was
# probably intended — confirm, and keep the restore code's names in sync.
w1 = tf.Variable(..., name="v1")
w2 = tf.Variable(..., name="w2")
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# Create the Saver only after every variable it should track has been defined;
# with no arguments it saves all variables, keyed by their `name`.
saver = tf.train.Saver()
...
# Save. Because global_step=100 is passed, the step number is appended to the
# path, so the checkpoint is written as "model.ckpt-100", not "model.ckpt".
saver.save(sess, "model.ckpt", global_step=100)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Save v1 under the checkpoint key "my_v1": to rename entries, pass the Saver
# a dict mapping checkpoint names to variables.
tf.train.Saver({'my_v1': v1})
# Save only v1 and v2 (each under its own variable name): pass a list.
tf.train.Saver([v1, v2])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Re-create the variables with exactly the same names (and creation order) used
# when saving, because restore() matches checkpoint entries to variables by name.
w1 = tf.Variable(..., name="v1")
w2 = tf.Variable(..., name="w2")
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
# restore() needs a Saver built after the variables exist.
saver = tf.train.Saver()
with tf.Session() as sess:
    # Load the variable values; restored variables need no separate
    # initialization. A checkpoint saved with global_step=100 is written as
    # "model.ckpt-100", so "model.ckpt" by itself does not exist. Use the
    # step-suffixed path, or look up the latest path with
    # tf.train.get_checkpoint_state.
    saver.restore(sess, "model.ckpt-100")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Read the checkpoint state recorded in the current directory; returns None
# when no checkpoint state is available there.
ckpt = tf.train.get_checkpoint_state('./')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
with tf.Session() as sess:
    # Look up the checkpoint state in the current directory (None if absent).
    ckpt = tf.train.get_checkpoint_state('./')
    # Also require a non-empty model_checkpoint_path before trying to restore.
    if ckpt and ckpt.model_checkpoint_path:  # a checkpoint exists
        last_model = ckpt.model_checkpoint_path  # path to the last saved model
        print("load " + last_model)
        # Load the saved variable values (restored variables need no init).
        # Assumes `saver` was created earlier, after all variables were defined.
        saver.restore(sess, last_model)
        ...
    else:  # no saved data
        # Start from scratch: initialize all variables and run the init op.
        init = tf.initialize_all_variables()
        sess.run(init)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Build the op that initializes all variables, then run it in the session.
# NOTE(review): newer TensorFlow 1.x releases deprecate this call in favour of
# tf.global_variables_initializer(); keep it only if targeting the older API.
init = tf.initialize_all_variables()
sess.run(init)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment