Skip to content

Instantly share code, notes, and snippets.

@andrewliao11
Last active November 12, 2017 22:36
Show Gist options
  • Save andrewliao11/02284732148d7bc81e35cce89be6b648 to your computer and use it in GitHub Desktop.
Save andrewliao11/02284732148d7bc81e35cce89be6b648 to your computer and use it in GitHub Desktop.
Since the structure of each algorithm in baselines is different, we implement the save function in different ways.
# for model in `trpo_mpi`, `ppo`
class CnnPolicy():
    """Pseudocode policy for `trpo_mpi` / `ppo`.

    The TF graph is built at construction time, but the session is NOT
    owned here — it is declared inside `algo.learn`, so checkpointing
    must go through `baselines.common.tf_util` (see `train` below).
    """

    def __init__(self):
        # build graph
        # NOTE(review): `conv2d` is an external helper — presumably a
        # baselines/TF conv layer; confirm against the real source.
        _ = conv2d()
        _ = conv2d()

    def step(self):
        # Run one forward pass; `sess` belongs to the enclosing
        # `algo.learn` scope, not to this object.
        sess.run(act, feed_dict)
def train(*args, **kwargs):
    """Pseudocode training entry point for `trpo_mpi` / `ppo`.

    The session is created inside `algo_learn`, so the checkpoint is
    saved through the module-level TF utility `U.save_state` rather
    than through a method on the policy object.
    """
    policy_fn = CnnPolicy
    # session declared in `algo.learn`
    algo_learn(policy_fn, *args, **kwargs)
    # save checkpoint
    U.save_state(checkpoint)
def algo_learn(policy_fn):
    """Pseudocode for the algorithm's `learn` function.

    Builds the policy from the factory `policy_fn`, then runs the
    training loop for `max_steps` iterations (a free/global name in
    this sketch).
    """
    # build graph
    policy = policy_fn()
    # train loop
    for _ in range(max_steps):
        # train body
        pass
    return
-------------------------------------------------------------------
# for ddpg, acktr, 'a2c'
class Model():
    """Pseudocode model for `ddpg` / `acktr` / `a2c`.

    Unlike trpo_mpi/ppo, the model owns its own session and saver,
    so save/load are methods on the model itself.
    """

    def __init__(self, policy_fn):
        # build graph — `policy_fn` is a factory called here, so it is
        # a constructor argument rather than a base class.
        policy = policy_fn()
        # declare session here (this is the key structural difference
        # from the trpo_mpi/ppo layout above)
        self.sess = tf.Session()
        self.saver = tf.train.Saver()

    def save(self):
        # assumes `self.iter` is maintained by the training loop —
        # TODO confirm against the real implementation
        self.saver.save(self.sess, 'checkpoint%d' % self.iter)

    def load(self, load_model_path):
        self.saver.restore(self.sess, load_model_path)
def train(*args, **kwargs):
    """Pseudocode training entry point for `ddpg` / `acktr` / `a2c`.

    Since `Model` owns session and saver, checkpointing is a simple
    method call on the model from within the training loop.
    """
    policy = Model()
    # training loop
    for _ in range(max_steps):
        # train body

        # save checkpoint
        policy.save()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment