Skip to content

Instantly share code, notes, and snippets.

@AurelianTactics
Created October 26, 2018 15:09
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save AurelianTactics/1098480190b543b53bf08e7508a06caf to your computer and use it in GitHub Desktop.
# Build the critic graph and the bootstrapped TD target.
# Two modes: plain DDPG (single critic) vs. a TD3-style variant that uses
# twin critics with clipped double-Q learning (min of the two target heads)
# to reduce Q-value overestimation.
# NOTE(review): this is a fragment of a larger setup method — it assumes
# `critic`/`target_critic` are callables returning one (DDPG) or two (TD3)
# normalized Q heads, and that `denormalize`, `normalized_obs0`,
# `normalized_obs1`, `target_actor`, `logger`, and `gamma` are in scope;
# confirm against the enclosing class.
if self.td3_variant:
    logger.info('using TD3 variant model')
    # Twin Q heads evaluated on the sampled actions; Q1 is the primary critic.
    self.normalized_critic_tf, self.normalized_critic_tf2 = critic(normalized_obs0, self.actions)
    self.critic_tf = denormalize(
        tf.clip_by_value(self.normalized_critic_tf, self.return_range[0], self.return_range[1]),
        self.ret_rms)
    # Actor loss uses only the first head; the second output is discarded.
    self.normalized_critic_with_actor_tf, _ = critic(normalized_obs0, self.actor_tf, reuse=True)
    self.critic_with_actor_tf = denormalize(
        tf.clip_by_value(self.normalized_critic_with_actor_tf, self.return_range[0], self.return_range[1]),
        self.ret_rms)
    # Clipped double-Q: bootstrap from the element-wise minimum of the two
    # target-critic heads to curb overestimation bias.
    out_q1, out_q2 = target_critic(normalized_obs1, target_actor(normalized_obs1))
    min_q1 = tf.minimum(out_q1, out_q2)
    Q_obs1 = denormalize(min_q1, self.ret_rms)
else:
    # Standard DDPG: a single critic head for both TD error and actor loss.
    self.normalized_critic_tf = critic(normalized_obs0, self.actions)
    self.critic_tf = denormalize(
        tf.clip_by_value(self.normalized_critic_tf, self.return_range[0], self.return_range[1]),
        self.ret_rms)
    self.normalized_critic_with_actor_tf = critic(normalized_obs0, self.actor_tf, reuse=True)
    self.critic_with_actor_tf = denormalize(
        tf.clip_by_value(self.normalized_critic_with_actor_tf, self.return_range[0], self.return_range[1]),
        self.ret_rms)
    Q_obs1 = denormalize(target_critic(normalized_obs1, target_actor(normalized_obs1)), self.ret_rms)
# One-step TD target; (1 - terminals1) zeroes the bootstrap on terminal steps.
self.target_Q = self.rewards + (1. - self.terminals1) * gamma * Q_obs1
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment