Created
January 9, 2019 18:30
-
-
Save AurelianTactics/e1bb1536871064ffa4ad0462eb49c4ae to your computer and use it in GitHub Desktop.
tf_target_network_update.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copies the mainQN parameter values to the targetQN.
# From Denny Britz's excellent RL repo:
# https://github.com/dennybritz/reinforcement-learning/blob/master/DQN/Double%20DQN%20Solution.ipynb
def copy_model_parameters(sess, estimator1, estimator2):
    """
    Copies the model parameters of one estimator to another.

    Trainable variables are selected by name prefix (``estimator.name``),
    sorted by name and paired positionally, so both estimators are expected
    to define the same variables in the same name order.

    Args:
      sess: Tensorflow session instance
      estimator1: Estimator to copy the parameters from
      estimator2: Estimator to copy the parameters to

    Raises:
      ValueError: If no trainable variable matches ``estimator1.name``, or
        the two estimators do not own the same number of trainable variables.
    """
    src_prefix = estimator1.name
    dst_prefix = estimator2.name
    trainable = tf.trainable_variables()

    def _variables_for(prefix, other_prefix):
        # If the other estimator's name extends this one (e.g. "q" vs.
        # "q_target"), a bare startswith(prefix) would also pick up the other
        # estimator's variables and corrupt the positional pairing below, so
        # those variables are filtered out here.
        shadowed = other_prefix != prefix and other_prefix.startswith(prefix)
        owned = [
            v for v in trainable
            if v.name.startswith(prefix)
            and not (shadowed and v.name.startswith(other_prefix))
        ]
        return sorted(owned, key=lambda v: v.name)

    src_params = _variables_for(src_prefix, dst_prefix)
    dst_params = _variables_for(dst_prefix, src_prefix)

    # Fail loudly instead of silently copying nothing or only a subset:
    # zip() would otherwise truncate to the shorter list.
    if not src_params:
        raise ValueError(
            "No trainable variables found with prefix {!r}".format(src_prefix)
        )
    if len(src_params) != len(dst_params):
        raise ValueError(
            "Trainable variable count mismatch: {!r} has {} but {!r} has {}".format(
                src_prefix, len(src_params), dst_prefix, len(dst_params)
            )
        )

    # NOTE: new assign ops are created on every call; callers that sync very
    # frequently may prefer to build these ops once and re-run them.
    update_ops = [dst.assign(src) for src, dst in zip(src_params, dst_params)]
    sess.run(update_ops)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment