Last active
October 20, 2017 06:59
-
-
Save yukiB/8ade337cfe26076b243d3c6a9cfa63f8 to your computer and use it in GitHub Desktop.
[Python]強化学習(DQN)を実装しながらKerasに慣れる ref: http://qiita.com/yukiB/items/0a3faa759ca5561e12f8
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Build the Q-network graph with the raw TensorFlow 1.x graph API.
# input layer (8 x 8)
self.x = tf.placeholder(tf.float32, [None, 8, 8])
# flatten (64)
x_flat = tf.reshape(self.x, [-1, 64])
# fully connected hidden layer (32 units); the weights were previously built
# 64 wide, contradicting this comment and the Dense(32) Keras version.
n_hidden = 32
W_fc1 = tf.Variable(tf.truncated_normal([64, n_hidden], stddev=0.01))
b_fc1 = tf.Variable(tf.zeros([n_hidden]))
h_fc1 = tf.nn.relu(tf.matmul(x_flat, W_fc1) + b_fc1)
# output layer (n_actions): one linear Q-value per action
W_out = tf.Variable(tf.truncated_normal([n_hidden, self.n_actions], stddev=0.01))
b_out = tf.Variable(tf.zeros([self.n_actions]))
self.y = tf.matmul(h_fc1, W_out) + b_out
# loss function: mean squared error against the targets fed into self.y_
self.y_ = tf.placeholder(tf.float32, [None, self.n_actions])
self.loss = tf.reduce_mean(tf.square(self.y_ - self.y))
# train operation
optimizer = tf.train.RMSPropOptimizer(self.learning_rate)
self.training = optimizer.minimize(self.loss)
# saver
self.saver = tf.train.Saver()
# session
self.sess = tf.Session()
self.sess.run(tf.global_variables_initializer())
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Keras version of the Q-network: 8x8 input -> flatten -> Dense(32, ReLU)
# -> one linear output per action, compiled with MSE loss and RMSprop.
self.model = Sequential()
for layer in (InputLayer(input_shape=(8, 8)),
              Flatten(),
              Dense(32, activation='relu'),
              Dense(self.n_actions)):
    self.model.add(layer)
optimizer = RMSprop(lr=self.learning_rate)
self.model.compile(
    loss='mean_squared_error',
    optimizer=optimizer,
    metrics=['accuracy'],
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from keras.layers.core import Dense, Dropout, Activation, Flatten | |
from keras.layers import Lambda, Input | |
# Per-output losses for the two-headed model:
#  - 'loss': that output already holds the computed loss value, so the
#    prediction is passed through unchanged (dummy loss func).
#  - 'main_output': a constant-zero loss, so this head adds no gradient.
def _passthrough_loss(y_true, y_pred):
    return y_pred


def _zero_loss(y_true, y_pred):
    return K.zeros_like(y_pred)


losses = {'loss': _passthrough_loss,
          'main_output': _zero_loss}
def customized_loss(args):
    """Huber-style (delta = 1) loss on the Q-value of the action taken.

    Intended as the function of a Keras ``Lambda`` layer.

    Args:
        args: ``[y_true, y_pred, action]`` -- the target value(s), the
            network's Q-value predictions (one column per action) and the
            integer index of the action that was taken.

    Returns:
        Scalar tensor: the clipped error summed over all elements.
    """
    import tensorflow as tf
    y_true, y_pred, action = args
    # One-hot mask that keeps only the taken action's column.
    a_one_hot = tf.one_hot(action, K.shape(y_pred)[1], 1.0, 0.0)
    # tf.mul was removed in TensorFlow 1.0; tf.multiply replaces it, and
    # `axis` replaces the deprecated `reduction_indices`.
    q_value = tf.reduce_sum(tf.multiply(y_pred, a_one_hot), axis=1)
    # NOTE(review): q_value is rank 1 while y_true may be (batch, 1); the
    # subtraction then broadcasts -- confirm the intended y_true shape.
    error = tf.abs(q_value - y_true)
    # Quadratic for |error| <= 1, linear beyond (error clipping).
    quadratic_part = tf.clip_by_value(error, 0.0, 1.0)
    linear_part = error - quadratic_part
    loss = tf.reduce_sum(0.5 * tf.square(quadratic_part) + linear_part)
    return loss
... | |
def init_model(self):
    """Build and compile the Q-network with an in-graph, action-masked loss.

    Inputs (by name): 'state', 'action', 'y_true'.
    Outputs: 'loss' (the Lambda layer running customized_loss) and
    'main_output' (one Q-value per action).
    """
    state_input = Input(shape=(1, 8, 8), name='state')
    action_input = Input(shape=[None], name='action', dtype='int32')
    x = Flatten()(state_input)
    x = Dense(32, activation='relu')(x)
    # One linear Q-value per action (was hard-coded to 3).
    y_pred = Dense(self.n_actions, activation='linear', name='main_output')(x)
    y_true = Input(shape=(1, ), name='y_true')
    # The loss is computed inside the graph so it can see the chosen action.
    loss_out = Lambda(customized_loss, output_shape=(1, ), name='loss')([y_true, y_pred, action_input])
    self.model = Model(input=[state_input, action_input, y_true], output=[loss_out, y_pred])
    self.model.compile(loss=losses,
                       optimizer=RMSprop(lr=self.learning_rate),
                       metrics=['accuracy'])
# Build and compile the network (was misspelled `slef`, a NameError).
self.init_model()
... | |
res = model.predict({'state': np.array([states]), | |
'action': np.array([0]), #dummy | |
'y_true': np.array([[0] * self.n_actions]) #dummy | |
}) | |
return res[1][0] | |
... | |
train_inputs = {
    'state': np.array(state_minibatch),
    'action': np.array(action_minibatch),
    'y_true': np.array(y_minibatch),
}
# One target per output head: a dummy zero vector for the 'loss' head and
# the Q-value targets for 'main_output'.
train_targets = [np.zeros([minibatch_size]), np.array(y_minibatch)]
self.model.fit(train_inputs, train_targets,
               batch_size=minibatch_size, nb_epoch=1, verbose=0)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def Q_values(self, states):
    """Return the predicted Q-value row for a single state.

    The state is wrapped into a batch of one before ``self.model.predict``;
    the first (only) row of the prediction is returned.
    """
    batch = np.array([states])
    predictions = self.model.predict(batch)
    return predictions[0]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# training: one epoch over the replayed minibatch
inputs, targets = np.array(state_minibatch), np.array(y_minibatch)
self.model.fit(inputs, targets, batch_size=minibatch_size, nb_epoch=1, verbose=0)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def load_model(self, model_path=None):
    """Rebuild the network from its YAML architecture and saved weights.

    Reads ``model_filename`` and ``weights_filename`` from the ``f_model``
    directory, then recompiles with MSE loss and RMSprop.

    Args:
        model_path: accepted for API compatibility but currently unused.
            NOTE(review): presumably meant to override ``f_model`` -- confirm.
    """
    # NOTE(review): the architecture is parsed from YAML; only load trusted files.
    with open(os.path.join(f_model, model_filename)) as f:
        yaml_string = f.read()
    self.model = model_from_yaml(yaml_string)
    self.model.load_weights(os.path.join(f_model, weights_filename))
    # Keras names this optimizer RMSprop; there is no RMSProp class.
    self.model.compile(loss='mean_squared_error',
                       optimizer=RMSprop(lr=self.learning_rate),
                       metrics=['accuracy'])
def save_model(self, num=None):
    """Save the architecture (YAML) and weights (HDF5) into ``f_model``.

    Args:
        num: optional suffix appended to both file names. ``None`` means no
            suffix; ``0`` is now a real suffix instead of silently
            overwriting the unnumbered files.
    """
    suffix = '' if num is None else str(num)
    model_name = 'dqn_model{0}.yaml'.format(suffix)
    weight_name = 'dqn_model_weights{0}.hdf5'.format(suffix)
    # Close (and flush) the YAML file deterministically.
    with open(os.path.join(f_model, model_name), 'w') as f:
        f.write(self.model.to_yaml())
    self.model.save_weights(os.path.join(f_model, weight_name))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from keras.models import model_from_config | |
def clone_model(model, custom_objects=None):
    """Return an independent copy of a Keras model (same config, same weights).

    The architecture is rebuilt from ``model.get_config()`` and the current
    weights are copied, so later training of either model does not affect
    the other.

    Args:
        model: the model to copy.
        custom_objects: optional mapping of custom layer/function names,
            forwarded to ``model_from_config``. Defaults to a fresh empty
            dict on every call; the original used a shared mutable default.
    """
    custom_objects = {} if custom_objects is None else custom_objects
    config = {
        'class_name': model.__class__.__name__,
        'config': model.get_config(),
    }
    clone = model_from_config(config, custom_objects=custom_objects)
    clone.set_weights(model.get_weights())
    return clone
# Target network: a separate model with the online model's config and weights.
self.target_model = clone_model(model=self.model)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import copy | |
# copy.copy is shallow: the "target" would share the online model's layers and
# weights, so it would track every update and never act as a frozen target.
# (copy.deepcopy raises an error on the model.) Rebuild the model from its
# config and copy the current weights to get an independent network.
from keras.models import model_from_config
self.target_model = model_from_config({
    'class_name': self.model.__class__.__name__,
    'config': self.model.get_config(),
})
self.target_model.set_weights(self.model.get_weights())
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def Q_values(self, states, isTarget=False):
    """Predict the Q-value row for one state.

    Args:
        states: a single state; it is wrapped into a batch of one.
        isTarget: when True, query ``self.target_model``; otherwise ``self.model``.

    Returns:
        The first (only) row of the prediction.
    """
    if isTarget:
        net = self.target_model
    else:
        net = self.model
    return net.predict(np.array([states]))[0]
def store_experience(self, states, action, reward, states_1, terminal):
    """Append one transition to the replay memory ``self.D``.

    Returns:
        True once ``self.D`` holds at least ``self.replay_memory_size``
        transitions, otherwise False.
    """
    transition = (states, action, reward, states_1, terminal)
    self.D.append(transition)
    return not len(self.D) < self.replay_memory_size
def experience_replay(self):
    """Train the online network on a random minibatch from replay memory.

    Transitions are sampled uniformly, with replacement, from ``self.D``.
    Each target vector starts as the online network's current Q-values for
    the state. Only the entry of the taken action is replaced by:

    * ``reward`` for terminal transitions;
    * ``reward + discount * max_a Q_target(s', a)`` for DQN;
    * ``reward + discount * Q_target(s', argmax_a Q_online(s', a))`` for
      Double DQN (``self.use_ddqn``): the online network selects the next
      action and the target network evaluates it.

    Does nothing when the replay memory is empty.
    """
    if not self.D:
        # No indices would be sampled and fit() would receive empty arrays.
        return
    state_minibatch = []
    y_minibatch = []
    # sample random minibatch
    minibatch_size = min(len(self.D), self.minibatch_size)
    minibatch_indexes = np.random.randint(0, len(self.D), minibatch_size)
    for j in minibatch_indexes:
        state_j, action_j, reward_j, state_j_1, terminal = self.D[j]
        action_j_index = self.enable_actions.index(action_j)
        y_j = self.Q_values(state_j)
        if terminal:
            y_j[action_j_index] = reward_j
        else:
            if self.use_ddqn:
                # Select the next action with the online net and evaluate it
                # with the target net (not the action taken at state_j).
                next_action = int(np.argmax(self.Q_values(state_j_1)))
                v = self.Q_values(state_j_1, isTarget=True)[next_action]
            else:
                v = np.max(self.Q_values(state_j_1, isTarget=True))
            y_j[action_j_index] = reward_j + self.discount_factor * v
        state_minibatch.append(state_j)
        y_minibatch.append(y_j)
    # training: one gradient step over the whole minibatch (Keras would
    # otherwise split it using its default batch size of 32).
    self.model.fit(np.array(state_minibatch), np.array(y_minibatch),
                   batch_size=minibatch_size, verbose=0)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def loss_func(y_true, y_pred):
    """Huber loss (delta = 1) summed over every element of the batch.

    Absolute errors up to 1 cost 0.5 * e**2; larger errors cost |e| - 0.5.
    This keeps the gradient magnitude at most 1 (error clipping).
    """
    abs_err = tf.abs(y_pred - y_true)
    small = tf.clip_by_value(abs_err, 0.0, 1.0)
    return tf.reduce_sum(0.5 * tf.square(small) + (abs_err - small))
# 'rmsprop' is the Keras identifier for RMSprop; 'rmsprops' is not a known optimizer.
self.model.compile(loss=loss_func, optimizer='rmsprop', metrics=['accuracy'])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Keep the placeholders on self so the feed_dict below refers to the same tensors.
self.state = tf.placeholder(tf.float32, [None, 8, 8])  # state
self.action = tf.placeholder(tf.int64, [None])  # action
self.super_visor = tf.placeholder(tf.float32, [None])  # supervisor (target) signal
state, a, supervisor = self.state, self.action, self.super_visor
output = self.inference(state)
# lossfunc is a method taking (a, output, supervisor); the action was missing.
loss = self.lossfunc(a, output, supervisor)
...
loss_val = sess.run(loss, feed_dict={
    self.state: np.float32(np.array(state_batch)),
    self.action: action_batch,
    self.super_visor: y_batch,
})
def lossfunc(self, a, output, supervisor):
    """Huber loss (delta = 1) on the Q-value of the action taken.

    Args:
        a: integer action index per sample.
        output: Q-value predictions with ``self.num_actions`` columns.
        supervisor: target (teacher) value per sample.

    Returns:
        Scalar tensor: batch mean of the clipped error.
    """
    # Convert the actions to one-hot vectors.
    a_one_hot = tf.one_hot(a, self.num_actions, 1.0, 0.0)
    # Q-value of the taken action (tf.mul was removed in TF 1.0 -> tf.multiply).
    q_value = tf.reduce_sum(tf.multiply(output, a_one_hot), axis=1)
    # Error clipping: quadratic for |error| <= 1, linear beyond.
    error = tf.abs(supervisor - q_value)
    quadratic_part = tf.clip_by_value(error, 0.0, 1.0)
    linear_part = error - quadratic_part
    loss = tf.reduce_mean(0.5 * tf.square(quadratic_part) + linear_part)  # loss function
    return loss
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment