@andrewliao11
Created November 8, 2017 16:58
Since each algorithm in baselines is structured differently, we implement the sample/load function differently for each of them.
# code located in baselines/gail
def sample(algo, load_model_path, policy_fn, sample_steps):
    assert algo in ['trpo', 'ppo', 'acktr', 'ddpg', 'a2c']
    if algo in ['trpo', 'ppo']:
        # note: this session is closed again when the with block exits
        with tf.Session() as sess:
            # manually build graph
            policy = policy_fn()
            # load model
            U.load_state(load_model_path)
    elif algo in ['acktr', 'ddpg', 'a2c']:
        policy = Model(policy_fn)  # sess/graph declared inside
        policy.load(load_model_path)
    # sample expert
    Sampler(algo, policy, sample_steps)

def Sampler(algo, policy, sample_steps):
    # start sampling code (env, obs, dones and stochastic are assumed to be set up elsewhere)
    for _ in range(sample_steps):
        if algo == 'ddpg':
            actions, q = act(obs)
        elif algo in ['a2c', 'acktr']:
            actions, values, policy.states = policy.step(obs, policy.states, dones)
        elif algo in ['ppo', 'trpo']:
            actions, v = policy.act(stochastic, obs)
        obs, rw, dones, _ = env.step(actions)
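The per-algorithm branching in the sampling loop can also be pulled into a small helper that hides the different call signatures. The following is only a minimal sketch: `DummyActPolicy`, `DummyStepPolicy` and `select_action` are hypothetical stand-ins and not part of baselines; they mimic the call shapes used above (`policy.act(stochastic, obs)` for trpo/ppo, `policy.step(obs, states, dones)` for a2c/acktr). It also shows why a condition such as `algo == 'a2c' or 'acktr'` needs to be a membership test.

```python
# Hypothetical stand-in policies; they only mimic the call shapes in the gist.
class DummyActPolicy:
    def act(self, stochastic, obs):
        # returns (action, value prediction)
        return obs * 2, 0.0


class DummyStepPolicy:
    def __init__(self):
        self.states = None

    def step(self, obs, states, dones):
        # returns (actions, values, new recurrent states)
        return obs + 1, 0.0, states


def select_action(algo, policy, obs, dones=False, stochastic=True):
    """Return only the action, hiding the per-algorithm call signature."""
    if algo in ('trpo', 'ppo'):
        action, _value = policy.act(stochastic, obs)
    elif algo in ('a2c', 'acktr'):
        action, _value, policy.states = policy.step(obs, policy.states, dones)
    else:
        raise ValueError('unsupported algo: %s' % algo)
    return action


# `algo == 'a2c' or 'acktr'` is always truthy, because the non-empty string
# 'acktr' is itself truthy; a membership test avoids that pitfall.
always_true = bool('ppo' == 'a2c' or 'acktr')
```

With a helper like this, the loop body would reduce to one `select_action(...)` call followed by `env.step(...)`, at the cost of discarding the value estimates.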

andrewliao11 commented Nov 8, 2017

How about this?
The code in ppo and trpo uses act() like this:

def sample(algo, load_model_path, policy_fn):
    if algo in ['trpo', 'ppo']:
        U.make_session(num_cpu=1).__enter__()
        # manually build graph
        policy = policy_fn()
        # load model
        U.load_state(load_model_path)
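
Calling `__enter__()` directly on a context manager leaves it active after the function returns, whereas a `with` block runs `__exit__` as soon as the block ends. A minimal sketch of that difference, using a hypothetical `FakeSession` class (a plain Python stand-in, not a TensorFlow or baselines class):

```python
class FakeSession:
    """Hypothetical stand-in for a session object; not a TensorFlow class."""

    def __init__(self):
        self.closed = False

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # mark the session as closed when the context is left
        self.closed = True
        return False


def build_with_block():
    with FakeSession() as sess:
        pass  # build graph / load weights here
    return sess  # already closed: __exit__ ran at the end of the block


def build_with_enter():
    sess = FakeSession().__enter__()
    return sess  # still open: __exit__ has not been called


closed_after_with = build_with_block().closed
closed_after_enter = build_with_enter().closed
```

The trade-off is that nothing calls `__exit__` automatically in the second form, so cleanup is left to the caller or to process exit.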
