This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# In graph construction: build the op that copies the main Q-network's
# trainable variables into the target Q-network.
# Argument order is (target_variables, source_variables); tau=1.0 requests a
# full (hard) copy rather than a soft/Polyak update (see TRFL docs for the
# exact blending rule).
target_network_update_ops = trfl.update_target_variables(
    targetQN.get_qnetwork_variables(),
    mainQN.get_qnetwork_variables(),
    tau=1.0,
)
#in session | |
with tf.Session() as sess: | |
#... | |
for ep in range(1, train_episodes): | |
#... | |
#update target q network |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class QNetwork: | |
def __init__(self, name, learning_rate=0.01, state_size=4, | |
action_size=2, hidden_size=10, batch_size=20): | |
#same code here | |
#... | |
#method to get trainable variables for TRFL | |
def get_qnetwork_variables(self):
    """Return the trainable variables that belong to this network.

    Used to hand matching variable lists to ``trfl.update_target_variables``.
    Variables are selected by the name prefix ``<self.name>/``. This assumes
    the network's variables are created under a variable scope named
    ``self.name`` -- confirm in ``__init__``. Requiring the trailing ``/``
    keeps a network named e.g. ``"q"`` from also claiming the variables of a
    network named ``"q_target"``, which would make the source and target
    variable lists mismatch.

    Returns:
        list of tf.Variable, in the order ``tf.trainable_variables()`` yields them.
    """
    prefix = self.name + '/'
    return [t for t in tf.trainable_variables() if t.name.startswith(prefix)]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Loss and optimizer construction (fragment, presumably from QNetwork.__init__;
# the enclosing code and original indentation are not visible here).
#TRFL qlearning | |
# Plain Q-learning alternative, kept for reference:
#qloss, q_learning = trfl.qlearning(self.output,self.actions_,self.reward,self.discount,self.targetQs_) | |
#TRFL double qlearning | |
# Per the TRFL docs the arguments are (q_tm1, a_tm1, r_t, pcont_t, q_t_value,
# q_t_selector). qloss is a per-example loss; q_learning holds TRFL's extras.
# NOTE(review): self.output is passed both as q_tm1 and as q_t_selector.
# self.output looks like the online network's Q-values for the states fed to
# inputs_ (the transition's starting states). Double Q-learning expects
# q_t_selector to be the online network's Q-values for the NEXT states.
# As written, the greedy next action is picked from the wrong states.
# Fixing this presumably needs an extra input fed with mainQN's outputs on
# next_states -- confirm and change together with the training loop.
qloss, q_learning = trfl.double_qlearning(self.output,self.actions_,self.reward,self.discount,self.targetQs_,self.output) | |
# Average the per-example TRFL loss over the batch before minimizing it.
self.loss = tf.reduce_mean(qloss) | |
self.opt = tf.train.AdamOptimizer(learning_rate).minimize(self.loss) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#tutorial way | |
#targets = rewards + gamma * np.max(target_Qs, axis=1) | |
# loss, _ = sess.run([mainQN.loss, mainQN.opt], | |
# feed_dict={mainQN.inputs_: states, | |
# mainQN.targetQs_: targets, | |
# mainQN.actions_: actions}) | |
#TRFL way, calculate td_error within TRFL | |
loss, _ = sess.run([mainQN.loss, mainQN.opt], | |
feed_dict={mainQN.inputs_: states, |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Inputs for the TRFL loss (fragment, presumably from QNetwork.__init__).
#standard way from tutorial: https://github.com/udacity/deep-learning/blob/master/reinforcement/Q-learning-cart.ipynb | |
#self.Q = tf.reduce_sum(tf.multiply(self.output, one_hot_actions), axis=1) | |
#self.loss = tf.reduce_mean(tf.square(self.targetQs_ - self.Q)) | |
#self.opt = tf.train.AdamOptimizer(learning_rate).minimize(self.loss) | |
#TRFL way | |
# Unlike the tutorial's per-example scalar targets, the TRFL version feeds a
# full [batch_size, action_size] matrix. It is passed to TRFL as the
# next-state Q-values (q_t), presumably from the target network -- confirm
# in the training loop.
self.targetQs_ = tf.placeholder(tf.float32, [batch_size,action_size], name='target') | |
# Per-transition rewards, one per batch row.
self.reward = tf.placeholder(tf.float32,[batch_size],name="reward") | |
# Discount of 0.99 applied identically to every transition.
# NOTE(review): because this is a constant, terminal transitions never get a
# discount of 0 here. The caller must zero the targetQs_ rows of terminal
# next states, or this should become a fed placeholder.
# All shapes are fixed to batch_size, so every fed batch must have exactly
# batch_size rows.
self.discount = tf.constant(0.99,shape=[batch_size],dtype=tf.float32,name="discount") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# One tabular Q-learning step for a single transition (fragment of an episode
# loop; the surrounding code and indentation are not visible here).
#standard q_learning | |
# max_q_value = np.max(q_table[next_obs_vel_index,next_obs_angle_index,:]) | |
# q_table[obs_vel_index,obs_angle_index,action] = q_table[obs_vel_index,obs_angle_index,action] \ | |
# + alpha *(reward + gamma*max_q_value - q_table[obs_vel_index,obs_angle_index,action]) | |
#with trfl.qlearning | |
# q_table appears to be indexed as [velocity bin, angle bin, action]. The
# current and next state's action rows, the reward and the action each get a
# leading batch axis of size 1 via np.expand_dims(..., axis=0). This matches
# the batch_size=1 placeholders qt_tab / qt_next_tab / reward_tab / action_tab.
# The fetch is a one-element list, so qlearning_output is a one-element list.
qlearning_output = sess.run([q_learning_tab],feed_dict={qt_tab:np.expand_dims(q_table[obs_vel_index,obs_angle_index,:],axis=0), | |
qt_next_tab:np.expand_dims(q_table[next_obs_vel_index,next_obs_angle_index,:],axis=0), | |
reward_tab:np.expand_dims(reward,axis=0), | |
action_tab:np.expand_dims(action,axis=0)}) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Tabular Q-learning on CartPole using trfl.qlearning.
# Builds batch-of-one inputs for one transition:
# - the current state's row of the Q-table,
# - the next state's row,
# - the reward received,
# - the action taken.
num_actions = env.action_space.n
batch_size = 1
# Q-values of the current (discretized) state.
qt_tab = tf.placeholder(dtype=tf.float32, shape=[batch_size, num_actions], name="qt_tab")
# Q-values of the next state. Given its own graph name so it is not confused
# with qt_tab (a duplicate name would be silently uniquified, e.g. "qt_tab_1").
qt_next_tab = tf.placeholder(dtype=tf.float32, shape=[batch_size, num_actions], name="qt_next_tab")
# Reward for the transition, one per batch row.
reward_tab = tf.placeholder(dtype=tf.float32, shape=[batch_size], name="reward_tab")
# Index of the action taken, one per batch row.
action_tab = tf.placeholder(dtype=tf.int32, shape=[batch_size], name="action_tab")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#create a vector of multiple environment | |
#cleaner make_env example here: https://github.com/openai/baselines/blob/6e607efa905a5d5aedd8260afaecb5ad981d713c/baselines/common/cmd_util.py | |
def make_vec_env(args, time_int, start_index=0): | |
""" | |
Create a wrapped, monitored SubprocVecEnv | |
""" | |
mpi_rank = MPI.COMM_WORLD.Get_rank() if MPI else 0 | |
seed = args.seed + 10000 * mpi_rank if args.seed is not None else None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from baselines.common.models import register | |
from baselines.a2c import utils | |
from baselines.a2c.utils import conv, fc, conv_to_fc, batch_to_seq, seq_to_batch | |
#define your network. this is the nature CNN with tf.nn.leaky_relu instead of relu | |
def custom_cnn(unscaled_images, **conv_kwargs): | |
scaled_images = tf.cast(unscaled_images, tf.float32) / 255. | |
activ = tf.nn.leaky_relu | |
h = activ(conv(scaled_images, 'c1', nf=32, rf=8, stride=4, init_scale=np.sqrt(2), | |
**conv_kwargs)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# TD3-style policy-noise step followed by the synced gradient computation.
# This is a fragment of a training method; its indentation is lost and the
# surrounding code is not shown. When td3_policy_noise > 0, it draws Gaussian
# noise shaped like the sampled actions and clips it to
# [-td3_noise_clip, td3_noise_clip].
# NOTE(review): `noise` is only assigned inside this `if`. Unless it is
# initialized earlier in the method (not visible here), td3_policy_noise <= 0
# raises UnboundLocalError at the feed_dict below. Consider `noise = 0.0`
# beforehand.
if self.td3_policy_noise > 0: | |
noise = np.random.normal(loc=0.0,scale=self.td3_policy_noise,size=np.shape(batch['actions'])) | |
noise = np.clip(noise,-self.td3_noise_clip,self.td3_noise_clip) | |
# Get all gradients and perform a synced update. | |
# NOTE(review): TD3 target-policy smoothing adds the clipped noise to the
# target policy's action when computing target_Q. Here the noise is added
# instead to the replay-buffer actions fed to self.actions, and the result is
# clipped to action_range. Confirm this is intended.
ops = [self.actor_grads, self.actor_loss, self.critic_grads, self.critic_loss] | |
actor_grads, actor_loss, critic_grads, critic_loss = self.sess.run(ops, feed_dict={ | |
self.obs0: batch['obs0'], | |
self.actions: np.clip(batch['actions'] + noise,self.action_range[0],self.action_range[1]), | |
self.critic_target: target_Q, | |
}) |