Last active
November 12, 2017 22:36
-
-
Save andrewliao11/02284732148d7bc81e35cce89be6b648 to your computer and use it in GitHub Desktop.
Since the structure of each algorithm in baselines is different, we implement the save function in different ways.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# for model in `trpo_mpi`, `ppo` | |
class CnnPolicy(): | |
def __init__(): | |
# build graph | |
_ = conv2d() | |
_ = conv2d() | |
def step(): | |
sess.run(act, feed_dict) | |
def train(...): | |
policy_fn = CnnPolicy | |
# session declared in `algo.learn` | |
algo_learn(policy_fn, ...) | |
# save checkpoint | |
U.save_state(checkpoint) | |
def algo_learn(policy_fn): | |
# build graph | |
policy = policy_fn() | |
# train_loop | |
for _ in range(max_steps): | |
#train body | |
return | |
------------------------------------------------------------------- | |
# for model in `ddpg`, `acktr`, `a2c` | |
class Model(policy_fn): | |
# build graph | |
# declare session here | |
policy = policy_fn() | |
self.sess = tf.Session() | |
self.saver = tf.train.Saver() | |
def save(): | |
self.saver.save(self.sess, 'checkpoint%d'%self.iter) | |
def load(): | |
self.saver.restore(self.sess, load_model_path) | |
def train(...): | |
policy = Model() | |
# training loop | |
for _ in range(max_steps): | |
# train body | |
# save checkpoint | |
policy.save() | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment