Skip to content

Instantly share code, notes, and snippets.

@codingPingjun
Created April 4, 2017 01:32
Show Gist options
  • Save codingPingjun/f7837609a4f6b15be482d168a0511e61 to your computer and use it in GitHub Desktop.
Save codingPingjun/f7837609a4f6b15be482d168a0511e61 to your computer and use it in GitHub Desktop.
Model Save and Load
def save_model(sess, saver, checkpoint_dir, model_name, step):
    """Write a checkpoint of *sess* to ``checkpoint_dir/model_name``.

    The directory (and any missing parents) is created if needed.

    Args:
        sess: Session whose variables are saved; passed straight to
            ``saver.save``.
        saver: Object exposing ``save(sess, path, global_step=...)``
            (e.g. a ``tf.train.Saver``).
        checkpoint_dir: Directory that receives the checkpoint files.
        model_name: File-name prefix for the checkpoint inside
            ``checkpoint_dir``.
        step: Forwarded as ``global_step`` to ``saver.save``.

    Returns:
        Whatever ``saver.save`` returns. For ``tf.train.Saver`` this is
        documented to be the path prefix of the written checkpoint.
        Previously the value was discarded, so existing callers are
        unaffected.
    """
    # exist_ok=True replaces the old exists()/makedirs() pair, which raced
    # when another process created the directory between the two calls.
    os.makedirs(checkpoint_dir, exist_ok=True)
    return saver.save(sess, os.path.join(checkpoint_dir, model_name), global_step=step)
def load_model(sess, saver, checkpoint_dir):
    """Restore the most recent checkpoint recorded in *checkpoint_dir*.

    Args:
        sess: Session into which variables are restored; passed straight
            to ``saver.restore``.
        saver: Object exposing ``restore(sess, path)``
            (e.g. a ``tf.train.Saver``).
        checkpoint_dir: Directory queried via
            ``tf.train.get_checkpoint_state``.

    Returns:
        True if a checkpoint path was found and ``saver.restore`` was
        called; False if no checkpoint state or path was available.
    """
    print("[*] Reading checkpoints...")
    state = tf.train.get_checkpoint_state(checkpoint_dir)
    if not state or not state.model_checkpoint_path:
        return False
    # Keep only the file name from the recorded path and re-anchor it on the
    # caller's checkpoint_dir. This presumably lets checkpoints load after the
    # directory has been moved (TODO confirm intent).
    latest = os.path.join(checkpoint_dir, os.path.basename(state.model_checkpoint_path))
    saver.restore(sess, latest)
    return True
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment