@gehring
Created July 30, 2019 15:00
import tensorflow as tf


def build_eager_policy(name, postprocess_fn, loss_fn, stats_fn):

    class EagerPolicy:

        def __init__(self, action_space, obs_space):
            # `get_model` and `make_optimizer` are assumed to be provided
            # by the surrounding code.
            self.model = get_model(action_space, obs_space)
            self.optimizer = make_optimizer()

        def postprocess_trajectory(self, batch):
            return postprocess_fn(batch)

        def compute_actions(self, obs):
            outputs = self.model(obs)
            actions = outputs.action_dist.sample()
            return actions

        def learn_on_batch(self, batch):
            batch = {k: tf.convert_to_tensor(v) for k, v in batch.items()}
            with tf.GradientTape() as tape:
                outputs = self.model(batch["obs"])
                # Allow outputs to be an arbitrary, potentially nested
                # collection of tensors/arrays. The model packs what is
                # necessary.
                loss = loss_fn(outputs, batch)
                stats = stats_fn(outputs, batch)
            grads = tape.gradient(loss, self.model.trainable_variables)
            self.optimizer.apply_gradients(
                zip(grads, self.model.trainable_variables))
            return stats

    EagerPolicy.__name__ = name
    return EagerPolicy
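
A minimal usage sketch of the factory above. The callback bodies here are hypothetical placeholders (the gist does not define them), and it assumes `get_model` and `make_optimizer` exist, the model's `action_dist` exposes `log_prob`, and the batch contains "obs" and "actions" keys.

# Hypothetical callbacks; the gist leaves their contents to the user.
def my_postprocess(batch):
    # e.g. compute advantages or returns here; identity in this sketch
    return batch

def my_loss(outputs, batch):
    # Placeholder: negative log-likelihood of the taken actions.
    return -tf.reduce_mean(outputs.action_dist.log_prob(batch["actions"]))

def my_stats(outputs, batch):
    return {"mean_obs": tf.reduce_mean(batch["obs"])}

MyEagerPolicy = build_eager_policy(
    "MyEagerPolicy", my_postprocess, my_loss, my_stats)

policy = MyEagerPolicy(action_space, obs_space)  # spaces from the environment
actions = policy.compute_actions(obs_batch)
stats = policy.learn_on_batch(sample_batch)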