Last active: September 4, 2019 12:17
[TensorFlow] Save model weights to a CSV file with NumPy and load them back
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Save a TensorFlow variable's value to a CSV file and load it back.
#
# Phase 1 builds a (3, 4) variable, evaluates it and writes it to CSV_PATH
# with NumPy. Phase 2 reads the CSV back and assigns its first two columns to
# a second, differently-shaped variable by feeding them through a placeholder.
#
# The tf.compat.v1 graph/session API is used so the script runs on
# TensorFlow 1.x (>= 1.13) as well as 2.x.
import numpy as np
import tensorflow as tf

CSV_PATH = "weight.csv"

# --- Save ---------------------------------------------------------------------
# Each phase builds its own graph. This replaces the global
# tf.reset_default_graph() call and keeps graph-mode ops (sessions,
# placeholders) working even when TF2 eager execution is enabled.
with tf.Graph().as_default():
    # (3, 4) shape weight; the integer literals give the variable an int32 dtype.
    W = tf.Variable([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], name="W")
    with tf.compat.v1.Session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        weight = sess.run(W)
        print("Original weight:\n", weight)
        np.savetxt(CSV_PATH, weight, delimiter=",")  # save [3, 4] weights

# --- Load ---------------------------------------------------------------------
# ndmin=2 keeps the result 2-D even if the CSV holds a single row, so the
# column slice below cannot fail on a 1-D array.
loaded = np.loadtxt(CSV_PATH, delimiter=",", ndmin=2).astype(np.float32)
# The target variable is narrower than the saved matrix, so only the first two
# columns are used. The row count comes from the data instead of being
# hard-coded.
target_shape = (loaded.shape[0], 2)

with tf.Graph().as_default():
    W2 = tf.compat.v1.get_variable("W2", target_shape)
    # Feed the NumPy array through a placeholder and copy it into W2 via assign.
    M = tf.compat.v1.placeholder(tf.float32, shape=target_shape)
    W2_init = W2.assign(M)
    with tf.compat.v1.Session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        # Different variable shape, so cut [rows, 2] from the [rows, 4] weights.
        weight2 = sess.run(W2_init, feed_dict={M: loaded[:, :2]})
        print("\nLoaded weight:\n", weight2)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.