import numpy as np
import theano
import theano.tensor as T
import lasagne
from collections import OrderedDict


def get_adam_steps_and_updates(all_grads, params, learning_rate=0.001,
                               beta1=0.9, beta2=0.999, epsilon=1e-8):
    """Adam (Kingma & Ba, 2014) that returns the per-parameter update
    steps symbolically instead of applying them, so the caller can apply
    (or ship) them externally. The moment accumulators and the timestep
    are still advanced through the returned `updates` dict."""
    t_prev = theano.shared(lasagne.utils.floatX(0.))
    updates = OrderedDict()

    # Using theano constant to prevent upcasting of float32
    one = T.constant(1)

    t = t_prev + 1
    # bias-corrected learning rate for step t
    a_t = learning_rate*T.sqrt(one - beta2**t)/(one - beta1**t)

    adam_steps = []
    for param, g_t in zip(params, all_grads):
        value = param.get_value(borrow=True)
        m_prev = theano.shared(np.zeros(value.shape, dtype=value.dtype),
                               broadcastable=param.broadcastable)
        v_prev = theano.shared(np.zeros(value.shape, dtype=value.dtype),
                               broadcastable=param.broadcastable)

        # exponential moving averages of the gradient and its square
        m_t = beta1*m_prev + (one - beta1)*g_t
        v_t = beta2*v_prev + (one - beta2)*g_t**2
        step = a_t*m_t/(T.sqrt(v_t) + epsilon)

        updates[m_prev] = m_t
        updates[v_prev] = v_t
        adam_steps.append(step)

    updates[t_prev] = t
    return adam_steps, updates
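

# Hedged helper sketch (not in the original gist): the steps returned by
# get_adam_steps_and_updates are symbolic, so a typical pattern is to
# compile them into a function and apply the resulting numeric step
# values to the shared parameters by hand, e.g. after averaging them
# across workers.
def apply_steps(params, step_values):
    """Subtract precomputed Adam step values from shared parameters in place."""
    for param, step in zip(params, step_values):
        param.set_value(param.get_value() - step)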


def build_model(state_size, num_act,
                critic_loss_coeff=0.5,
                entropy_coeff=0.0001,
                learning_rate=0.00025):
    # input tensors
    states = T.matrix('states')
    v_targets = T.vector('v_targets')
    actions = T.matrix('actions')

    l_input = lasagne.layers.InputLayer([None, state_size])

    # actor layers: one hidden layer, then 2*num_act rectified outputs
    l_hid = l_input
    for i in range(1):
        l_hid = lasagne.layers.DenseLayer(
            l_hid, 64,
            nonlinearity=lasagne.nonlinearities.elu
        )
    l_actor = lasagne.layers.DenseLayer(
        l_hid, num_act*2,
        nonlinearity=lasagne.nonlinearities.rectify
    )
    # first num_act outputs are the policy's `a` shape parameters,
    # the remaining num_act are the `b` shape parameters
    l_actor_a = lasagne.layers.SliceLayer(l_actor, slice(0, num_act))
    l_actor_b = lasagne.layers.SliceLayer(l_actor, slice(num_act, None))

    # critic layers: same architecture with separate weights, scalar value head
    l_hid = l_input
    for i in range(1):
        l_hid = lasagne.layers.DenseLayer(
            l_hid, 64,
            nonlinearity=lasagne.nonlinearities.elu
        )
    l_critic = lasagne.layers.DenseLayer(
        l_hid, 1,
        nonlinearity=lasagne.nonlinearities.identity
    )

    # calculate predictions; the small constant keeps a and b strictly positive
    a = lasagne.layers.get_output(l_actor_a, states) + T.constant(10e-8)
    b = lasagne.layers.get_output(l_actor_b, states) + T.constant(10e-8)
    v_vals = lasagne.layers.get_output(l_critic, states)
    v_vals = T.flatten(v_vals)

    # make logvar shared variable (unused Gaussian-policy variant, kept for reference)
    '''
    logsigma_w = theano.shared(np.random.rand(num_act).astype('float32'))
    logsigma = -1*lasagne.nonlinearities.rectify(logsigma_w) - T.constant(1.2)
    logsigma_sum = T.sum(logsigma)
    sigma = T.exp(logsigma)
    '''

    # CRITIC
    td_error = v_targets - v_vals
    critic_loss = 0.5 * (td_error ** 2)
    critic_loss = T.mean(critic_loss)

    # ACTOR
    # entropy term (unused below: the entropy bonus is commented out of the loss)
    # entropy = 0.5*num_act*(1. + T.log(2.*np.pi)) + logsigma_sum
    entropy = T.sum((1. - 1./a) + (1. - 1./b)*T.log(b) + T.log(a*b), axis=1)

    # log-density of the taken actions under a Kumaraswamy(a, b) policy:
    # log f(x) = log(a) + log(b) + (a-1)*log(x) + (b-1)*log(1 - x**a)
    # log_prob = -1.*logsigma_sum - 0.5*(num_act*T.log(2.*np.pi) + T.sum(((actions-mu)/sigma)**2, axis=1))
    log_prob = T.sum(T.log(a) + T.log(b)
                     + (a - 1.)*T.log(actions + 10e-8)
                     + (b - 1.)*T.log(1. - actions**a + 10e-8), axis=1)

    # treat the TD error as a fixed advantage estimate (no gradient into the critic)
    adv = theano.gradient.disconnected_grad(td_error)
    # actor_loss = -1. * (log_prob * adv + entropy_coeff*entropy)
    actor_loss = -1. * (log_prob * adv)
    actor_loss = T.mean(actor_loss)

    # total loss
    total_loss = actor_loss + critic_loss_coeff*critic_loss

    # combine params
    actor_params = lasagne.layers.get_all_params(l_actor)
    crit_params = lasagne.layers.get_all_params(l_critic)
    params = [p for p in crit_params if p not in actor_params] + actor_params
    # params.append(logsigma_w)

    # calculate grads and steps
    grads = T.grad(total_loss, params)
    grads = lasagne.updates.total_norm_constraint(grads, 10)
    steps, updates = get_adam_steps_and_updates(grads, params, learning_rate)

    steps_fn = theano.function([states, v_targets, actions], steps, updates=updates)
    actor_fn = theano.function([states], [a, b])
    val_fn = theano.function([states], v_vals)
    return steps_fn, actor_fn, val_fn, params
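

# Hedged smoke-test sketch (not in the original gist): wires the pieces
# together on made-up sizes. Actions are drawn from the Kumaraswamy
# policy by inverse-transform sampling, x = (1 - u**(1/b))**(1/a), and
# v_targets here are random stand-ins for bootstrapped returns.
if __name__ == '__main__':
    state_size, num_act = 8, 2
    steps_fn, actor_fn, val_fn, params = build_model(state_size, num_act)

    states = np.random.randn(4, state_size).astype('float32')
    a, b = actor_fn(states)
    u = np.random.rand(*a.shape).astype('float32')
    actions = (1. - u**(1./b))**(1./a)               # samples in (0, 1)
    v_targets = np.random.rand(4).astype('float32')  # placeholder returns

    # compute the per-parameter Adam steps and apply them by hand
    apply_steps(params, steps_fn(states, v_targets, actions))
    print('values after one step:', val_fn(states))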