
@AurelianTactics
Created January 9, 2019 18:30
tf_target_network_update.py
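```python
# Copies the mainQN values to the targetQN.
# From Denny Britz's excellent RL repo:
# https://github.com/dennybritz/reinforcement-learning/blob/master/DQN/Double%20DQN%20Solution.ipynb
import tensorflow as tf  # TensorFlow 1.x graph-mode API (tf.trainable_variables, sessions)


def copy_model_parameters(sess, estimator1, estimator2):
    """
    Copies the model parameters of one estimator to another.

    Args:
      sess: TensorFlow session instance
      estimator1: Estimator to copy the parameters from
      estimator2: Estimator to copy the parameters to

    Each estimator is expected to expose a `name` attribute equal to the
    variable-scope prefix its trainable variables were created under.
    """
    # Collect each network's trainable variables by name prefix and sort them
    # by name so that corresponding variables line up pairwise in zip().
    e1_params = [t for t in tf.trainable_variables() if t.name.startswith(estimator1.name)]
    e1_params = sorted(e1_params, key=lambda v: v.name)
    e2_params = [t for t in tf.trainable_variables() if t.name.startswith(estimator2.name)]
    e2_params = sorted(e2_params, key=lambda v: v.name)

    # One assign op per variable pair (estimator2 <- estimator1), run together.
    # Note: every call adds new assign ops to the graph; if this runs often,
    # consider building update_ops once and reusing them.
    update_ops = []
    for e1_v, e2_v in zip(e1_params, e2_params):
        op = e2_v.assign(e1_v)
        update_ops.append(op)
    sess.run(update_ops)
```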
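The same hard-update logic (filter by scope prefix, sort by name, assign pairwise) can be sketched without TensorFlow, which makes the pairing easy to check. This is a minimal illustration, not the gist's code: `FakeVariable`, `copy_params`, and `TARGET_UPDATE_EVERY` are hypothetical names, and the sketch assumes the networks' variables live under scopes named `mainQN` and `targetQN`. It also makes two assumed changes relative to the original: it matches on `prefix + "/"` so that a scope such as `mainQN_old` is not swept up by `mainQN`, and it raises an error on mismatched variable lists instead of letting `zip()` silently truncate.

```python
from dataclasses import dataclass


@dataclass
class FakeVariable:
    """Hypothetical stand-in for a tf.Variable: a name plus a value."""
    name: str
    value: float

    def assign(self, other):
        # Eager copy; in TF1 graph mode, assign() instead returns an op to run.
        self.value = other.value


def copy_params(variables, src_prefix, dst_prefix):
    """Copy values from variables under src_prefix/ to those under dst_prefix/."""
    src = sorted((v for v in variables if v.name.startswith(src_prefix + "/")),
                 key=lambda v: v.name)
    dst = sorted((v for v in variables if v.name.startswith(dst_prefix + "/")),
                 key=lambda v: v.name)
    if len(src) != len(dst):
        raise ValueError("networks have different numbers of variables")
    for s, d in zip(src, dst):
        # After sorting, paired variables should share the same path below the scope.
        if s.name[len(src_prefix):] != d.name[len(dst_prefix):]:
            raise ValueError(f"mismatched pair: {s.name} vs {d.name}")
        d.assign(s)


variables = [
    FakeVariable("mainQN/dense/kernel:0", 1.0),
    FakeVariable("mainQN/dense/bias:0", 2.0),
    FakeVariable("mainQN_old/dense/kernel:0", 99.0),  # must NOT match "mainQN/"
    FakeVariable("targetQN/dense/kernel:0", 0.0),
    FakeVariable("targetQN/dense/bias:0", 0.0),
]

TARGET_UPDATE_EVERY = 3  # hypothetical sync interval, in training steps

for step in range(1, 7):
    # Pretend a gradient step changed every main-network variable.
    for v in variables:
        if v.name.startswith("mainQN/"):
            v.value += 1.0
    # Periodically sync the target network, as in DQN-style training.
    if step % TARGET_UPDATE_EVERY == 0:
        copy_params(variables, "mainQN", "targetQN")
```

Between syncs the target network holds a frozen copy of the main network's values; right after a sync (here at steps 3 and 6) the two match exactly.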