Created
April 2, 2018 08:01
-
-
Save immars/7d22e32104528199ab56ef48f957c8a3 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class SVGDUpdater(NetworkUpdater):
    """Stein variational gradient descent (SVGD) updater.

    Trains a sampling network so that the actions it generates follow the
    density described by a log-probability network, using a Gaussian kernel
    between the particles generated for each state.
    """

    def __init__(self, generator_net, logprob_net, m_particles=16, alpha_exploration=1.0):
        """
        Build the SVGD graph: placeholders, particle generation, kernel matrix,
        SVGD direction and the surrogate loss minimized w.r.t. the generator.

        :param generator_net: generator_net(state, noise) => action
        :type generator_net: Network
        :param logprob_net: logprob_net(state, action) => log pdf of action
        :type logprob_net: Network
        :param m_particles: number of action particles generated per state
        :param alpha_exploration: value fed to the ``input_alpha`` placeholder
            on every update; it scales the kernel-gradient term
        """
        super(SVGDUpdater, self).__init__()
        self._m_particles = m_particles
        self._alpha_exploration = alpha_exploration
        self._generator_net, self._logprob_net = generator_net, logprob_net

        n_particles = m_particles
        bandwidth_floor = 1e-3  # added to the bandwidth so it never reaches zero

        # Static shapes are read from the networks' declared inputs.
        shape_of_state = generator_net.inputs[0].shape.as_list()
        shape_of_noise = generator_net.inputs[1].shape.as_list()
        shape_of_action = logprob_net.inputs[1].shape.as_list()
        state_dim = shape_of_state[1]
        self._dim_noise = shape_of_noise[1]
        action_dim = shape_of_action[1]

        with tf.name_scope("inputs"):
            self._input_state = tf.placeholder(tf.float32, shape_of_state, "input_state")
            self._input_noise = tf.placeholder(tf.float32, shape_of_noise, "input_noise")
            self._input_alpha = tf.placeholder(tf.float32, [], "input_alpha")

        # NOTE(review): the source's indentation was lost, so the exact extent
        # of this name scope is reconstructed; it only affects op names.
        with tf.name_scope("generate"):
            batch_size = tf.shape(self._input_state)[0]
            # Repeat every state row n_particles times: s0, s0, ..., s1, s1, ...
            state_column = tf.reshape(self._input_state, shape=[-1, 1, state_dim])
            state_repeated = tf.tile(state_column, (1, n_particles, 1))
            tiled_state = tf.reshape(state_repeated, shape=[-1, state_dim])
            # Cycle the noise rows once per state: n0, n1, ..., n0, n1, ...
            tiled_noise = tf.tile(self._input_noise, (batch_size, 1))
            # Pairs fed to the generator:
            # (s0, n0), (s0, n1), ..., (s1, n0), (s1, n1), ...
            batch_generator = generator_net([tiled_state, tiled_noise],
                                            name_scope="batch_generator")
            # Take the first declared output of the generator.
            for output_name in batch_generator.outputs:
                actions = batch_generator[output_name].op
                break

        # actions: [bs * m, dim_a] -> pairwise layout [bs, m, m, dim_a]
        actions_row = tf.reshape(actions, [-1, 1, n_particles, action_dim])
        action_grid = tf.tile(actions_row, (1, n_particles, 1, 1))
        # pairwise differences: [bs, m, m, dim_a]
        action_diff = tf.transpose(action_grid, perm=[0, 2, 1, 3]) - action_grid
        # squared distances: [bs, m, m]
        sq_dist = tf.reduce_sum(tf.square(action_diff), axis=3)

        # Bandwidth h: [bs], from the mean pairwise distance
        # (a median-based variant was left disabled in the original).
        mean_dist = tf.reduce_mean(tf.sqrt(sq_dist), axis=(1, 2))
        bandwidth = mean_dist ** 2 / (np.log(n_particles + 1))
        bandwidth = bandwidth + bandwidth_floor
        self._h = bandwidth

        # Kernel matrix: [bs, m, m]
        neg_inv_h = -1.0 / tf.reshape(bandwidth, (-1, 1, 1))
        kernel = tf.exp(neg_inv_h * sq_dist)
        kernel_4d = tf.reshape(kernel, (-1, n_particles, n_particles, 1))
        # Kernel gradient: [bs, m, m, dim_a]
        two_over_h = 2 / tf.reshape(bandwidth, (-1, 1, 1, 1))
        kernel_grad = kernel_4d * two_over_h * action_diff

        # Log-probability of every generated action and its gradient w.r.t. the action.
        batch_logprob = logprob_net([tiled_state, actions], name_scope="batch_logprob")
        for output_name in batch_logprob.outputs:
            action_logprob = batch_logprob[output_name].op
            break
        logprob_grad = tf.gradients(action_logprob, actions)
        # dlogp/da laid out pairwise: [bs, m, m, dim_a]
        logprob_grad_row = tf.reshape(logprob_grad, (batch_size, 1, n_particles, action_dim))
        logprob_grad_grid = tf.tile(logprob_grad_row, (1, n_particles, 1, 1))

        # SVGD direction per pair: [bs, m, m, dim_a]
        svgd_pairwise = kernel_4d * logprob_grad_grid + kernel_grad * self._input_alpha
        # Average over the second particle axis: [bs, m, dim_a]
        svgd_per_particle = tf.reduce_mean(svgd_pairwise, axis=2)
        # Flatten back to the generator's output layout: [bs * m, dim_a]
        svgd_flat = tf.reshape(svgd_per_particle, (-1, action_dim))

        # Surrogate loss: the SVGD direction is held constant, so the gradient
        # w.r.t. the actions is proportional to -svgd_flat.
        surrogate_loss = tf.reduce_mean(-tf.stop_gradient(svgd_flat) * actions)
        self._loss = surrogate_loss
        self._grad_svgd = svgd_flat
        self._op = MinimizeLoss(self._loss, var_list=batch_generator.variables)

    def declare_update(self):
        """Return the MinimizeLoss operation created in the constructor."""
        return self._op

    def update(self, sess, batch, *args, **kwargs):
        """
        Draw fresh unit-Gaussian noise and describe one SVGD update run.

        :param sess: not used by this method
        :param batch: mapping that provides the states under key "state"
        :return: UpdateRun feeding state, noise and alpha, and fetching
            the loss, the SVGD gradient and the kernel bandwidth
        """
        sampled_noise = np.random.normal(0, 1, (self._m_particles, self._dim_noise))
        feeds = {
            self._input_state: batch["state"],
            self._input_noise: sampled_noise,
            self._input_alpha: self._alpha_exploration,
        }
        fetches = {
            "loss": self._loss,
            "svgd_grad": self._grad_svgd,
            "h": self._h,
        }
        return UpdateRun(feed_dict=feeds, fetch_dict=fetches)
class SoftVFunction(Function):
    """Sample-based soft state value.

    For each state, draws ``m_particles`` actions and returns
    ``alpha * log(mean(exp(Q(s, a) / alpha)))`` over those samples.
    Actions come from ``actor_func`` when it is given, otherwise they are
    drawn uniformly from [-1, 1). No importance weight is applied to the
    samples.
    """

    def __init__(self, q_func, actor_func=None, m_particles=16, alpha_exploration=1.0):
        """
        :param q_func: q_func(state, action) => Q value; its second input's
            static shape provides the action dimension
        :param actor_func: optional actor_func(state, noise) => action; its
            second input's static shape provides the noise dimension
        :param m_particles: number of actions sampled per state
        :param alpha_exploration: temperature alpha of the soft value
        """
        self._q_func, self._actor_func = q_func, actor_func
        self._m_particles = m_particles
        self._alpha_exploration = alpha_exploration
        if actor_func is not None:
            noise_shape = actor_func.inputs[1].shape.as_list()
            self._dim_noise = noise_shape[1]
        else:
            self._dim_noise = None
        # Needed by the uniform-sampling branch of __call__.
        self._dim_action = q_func.inputs[1].shape.as_list()[1]
        super(SoftVFunction, self).__init__()

    def __call__(self, *args, **kwargs):
        """
        Estimate the soft value of a batch of states.

        :param args: ``args[0]`` is the batch of states, an array whose first
            axis is the batch axis
        :return: array of shape (batch_size,) with one soft value per state
        """
        next_state = args[0]
        b_s, m = next_state.shape[0], self._m_particles
        # Each state is repeated m times in a row: s0, s0, ..., s1, s1, ...
        next_states = np.repeat(next_state, m, axis=0)
        if self._actor_func is not None:
            # sample action according to {actor_func} distribution
            noises = np.random.normal(0, 1, (b_s * m, self._dim_noise))
            next_actions = self._actor_func(next_states, noises)
        else:
            # sample action uniformly
            # (the original referenced an undefined bare name `dim_action` here)
            next_actions = np.random.uniform(-1, 1, [b_s * m, self._dim_action])
        # Equation 10
        next_qs = self._q_func(next_states, next_actions).reshape((b_s, m))
        alpha = self._alpha_exploration
        scaled_qs = 1.0 / alpha * next_qs
        # Log-mean-exp with the row maximum factored out, so that exp() cannot
        # overflow for large Q / alpha; mathematically identical to
        # log(mean(exp(scaled_qs))).
        row_max = np.max(scaled_qs, axis=1, keepdims=True)
        # A non-finite maximum would turn the subtraction into NaN; skip the shift then.
        row_max = np.where(np.isfinite(row_max), row_max, 0.0)
        log_mean_exp = row_max[:, 0] + np.log(np.mean(np.exp(scaled_qs - row_max), axis=1))
        next_value = alpha * log_mean_exp
        return next_value
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment