@takuseno, created September 1, 2019 12:12
import numpy as np
import nnabla as nn
import nnabla.functions as F
import nnabla.parametric_functions as PF
import nnabla.solvers as S
#------------------------------- neural network ------------------------------#
def q_network(obs, action):
    # Q(s, a): the action is concatenated after the first hidden layer
    with nn.parameter_scope('critic'):
        out = PF.affine(obs, 64, name='fc1')
        out = F.tanh(out)
        out = F.concatenate(out, action, axis=1)
        out = PF.affine(out, 64, name='fc2')
        out = F.tanh(out)
        out = PF.affine(out, 1, name='fc3')
    return out
def policy_network(obs, action_size):
    # deterministic policy: maps an observation to an action in [-1, 1]
    with nn.parameter_scope('actor'):
        out = PF.affine(obs, 64, name='fc1')
        out = F.tanh(out)
        out = PF.affine(out, 64, name='fc2')
        out = F.tanh(out)
        out = PF.affine(out, action_size, name='fc3')
    return F.tanh(out)
#-----------------------------------------------------------------------------#
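#--------------------- sanity check (added, not in the gist) ------------------#
# A minimal smoke test under assumed dummy shapes: build both graphs and verify
# the output dimensions. If used, call it before constructing DDPG below, since
# nn.clear_parameters() wipes the whole global parameter scope.
def check_network_shapes(batch_size=4, obs_dim=3, action_dim=2):
    obs = nn.Variable((batch_size, obs_dim))
    act = nn.Variable((batch_size, action_dim))
    assert q_network(obs, act).shape == (batch_size, 1)
    assert policy_network(obs, action_dim).shape == (batch_size, action_dim)
    nn.clear_parameters()  # drop the temporary 'critic'/'actor' parameters
#-----------------------------------------------------------------------------#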
#-------------------------- DDPG algorithm -----------------------------------#
class DDPG:
    def __init__(self,
                 obs_shape,
                 action_size,
                 batch_size,
                 critic_lr,
                 actor_lr,
                 tau,
                 gamma):
        # inference
        self.infer_obs_t = nn.Variable((1,) + obs_shape)
        with nn.parameter_scope('trainable'):
            self.infer_policy_t = policy_network(self.infer_obs_t, action_size)

        # training
        self.obs_t = nn.Variable((batch_size,) + obs_shape)
        self.act_t = nn.Variable((batch_size, action_size))
        self.rew_tp1 = nn.Variable((batch_size, 1))
        self.obs_tp1 = nn.Variable((batch_size,) + obs_shape)
        self.ter_tp1 = nn.Variable((batch_size, 1))
        # critic training
        with nn.parameter_scope('trainable'):
            q_t = q_network(self.obs_t, self.act_t)
        with nn.parameter_scope('target'):
            policy_tp1 = policy_network(self.obs_tp1, action_size)
            q_tp1 = q_network(self.obs_tp1, policy_tp1)
        y = self.rew_tp1 + gamma * q_tp1 * (1.0 - self.ter_tp1)
        self.critic_loss = F.mean(F.squared_error(q_t, y))

        # actor training
        with nn.parameter_scope('trainable'):
            policy_t = policy_network(self.obs_t, action_size)
            q_t_with_actor = q_network(self.obs_t, policy_t)
        self.actor_loss = -F.mean(q_t_with_actor)
        # get neural network parameters
        with nn.parameter_scope('trainable'):
            with nn.parameter_scope('critic'):
                critic_params = nn.get_parameters()
            with nn.parameter_scope('actor'):
                actor_params = nn.get_parameters()

        # setup optimizers
        self.critic_solver = S.Adam(critic_lr)
        self.critic_solver.set_parameters(critic_params)
        self.actor_solver = S.Adam(actor_lr)
        self.actor_solver.set_parameters(actor_params)

        with nn.parameter_scope('trainable'):
            trainable_params = nn.get_parameters()
        with nn.parameter_scope('target'):
            target_params = nn.get_parameters()

        # build target update
        update_targets = []
        sync_targets = []
        for key, src in trainable_params.items():
            dst = target_params[key]
            update_targets.append(F.assign(dst, (1.0 - tau) * dst + tau * src))
            sync_targets.append(F.assign(dst, src))
        self.update_target_expr = F.sink(*update_targets)
        self.sync_target_expr = F.sink(*sync_targets)
    def infer(self, obs_t):
        # deterministic action for a single observation
        self.infer_obs_t.d = np.array([obs_t])
        self.infer_policy_t.forward(clear_buffer=True)
        return self.infer_policy_t.d[0]

    def train_critic(self, obs_t, actions_t, rewards_tp1, obs_tp1, dones_tp1):
        # fit Q(s_t, a_t) to the TD target r_tp1 + gamma * Q'(s_tp1, pi'(s_tp1))
        self.obs_t.d = np.array(obs_t)
        self.act_t.d = np.array(actions_t)
        self.rew_tp1.d = np.array(rewards_tp1)
        self.obs_tp1.d = np.array(obs_tp1)
        self.ter_tp1.d = np.array(dones_tp1)
        self.critic_loss.forward()
        self.critic_solver.zero_grad()
        self.critic_loss.backward(clear_buffer=True)
        self.critic_solver.update()
        return self.critic_loss.d

    def train_actor(self, obs_t):
        # maximize Q(s_t, pi(s_t)) by minimizing its negation
        self.obs_t.d = np.array(obs_t)
        self.actor_loss.forward()
        self.actor_solver.zero_grad()
        self.actor_loss.backward(clear_buffer=True)
        self.actor_solver.update()
        return self.actor_loss.d

    def update_target(self):
        # soft update: target <- (1 - tau) * target + tau * trainable
        self.update_target_expr.forward(clear_buffer=True)

    def sync_target(self):
        # hard copy of trainable parameters into the target networks
        self.sync_target_expr.forward(clear_buffer=True)
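#--------------------- example usage (added, not in the gist) -----------------#
# A rough sketch of the intended call order, driven by random transitions under
# assumed shapes and hyperparameters instead of a real environment and replay
# buffer: sync_target() once, then train_critic(), train_actor() and
# update_target() on each sampled batch.
if __name__ == '__main__':
    obs_shape = (3,)   # assumed observation shape
    action_size = 1    # assumed action dimensionality
    batch_size = 32
    ddpg = DDPG(obs_shape, action_size, batch_size,
                critic_lr=1e-3, actor_lr=1e-4, tau=0.005, gamma=0.99)
    ddpg.sync_target()  # start with target networks identical to trainable ones
    for step in range(10):
        obs_t = np.random.rand(batch_size, *obs_shape).astype(np.float32)
        act_t = np.random.uniform(-1.0, 1.0, (batch_size, action_size))
        rew_tp1 = np.random.rand(batch_size, 1).astype(np.float32)
        obs_tp1 = np.random.rand(batch_size, *obs_shape).astype(np.float32)
        ter_tp1 = np.zeros((batch_size, 1), dtype=np.float32)
        critic_loss = ddpg.train_critic(obs_t, act_t, rew_tp1, obs_tp1, ter_tp1)
        actor_loss = ddpg.train_actor(obs_t)
        ddpg.update_target()
    print('single-step action:', ddpg.infer(np.zeros(obs_shape, dtype=np.float32)))
#-----------------------------------------------------------------------------#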