dqn_keras_cartpole
Created February 1, 2017 14:03
# coding:utf-8

# args
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--test", action="store_true")
parser.add_argument("--prms", type=int, action="store", default=0)
args = parser.parse_args()
TEST = args.test
TRAIN = not TEST

# imports
import os
import gym
import random
import numpy as np
import pandas as pd
import tensorflow as tf
from collections import deque
from keras.models import Model
from keras.layers import Input, Dense, merge, Reshape, Lambda, Dropout
from keras import backend as K
from keras.optimizers import RMSprop
from keras.utils import np_utils

# log
import logging
logger = logging.getLogger("my_logging")
logger.setLevel(logging.INFO)
fh = logging.FileHandler('result.log')
ch = logging.StreamHandler()
formatter = logging.Formatter('%(asctime)s : %(message)s')
fh.setFormatter(formatter)
ch.setFormatter(formatter)
logger.addHandler(fh)
logger.addHandler(ch)
# environment attributes
ENV_NAME = 'CartPole-v0'  # Environment name
STATE_WIDTH = 4           # Size of the observation vector
NUM_ACTIONS = 2           # push cart left / right
# parameters
params = {}
params["NUM_EPISODES"] = 1500
params["EPISODE_MAX_STEPS"] = 300
params["NUM_EPISODES_AT_TEST"] = 20
params["INITIAL_REPLAY_SIZE"] = 10000
params["EXPLORATION_STEPS"] = 20000
params["NUM_REPLAY_MEMORY"] = 400000
params["TARGET_UPDATE_INTERVAL"] = 1000
params["NO_OP_STEPS"] = 2
params["TRAIN_INTERVAL"] = 4
params["INITIAL_EPSILON"] = 1.0
params["FINAL_EPSILON"] = 0.1
params["LEARNING_RATE"] = 0.001
params["GAMMA"] = 0.99
params["BATCH_SIZE"] = 256
params["DUMP_DATA"] = False
params["SAVE_NETWORK_PATH"] = 'saved_networks/{}.weights.hdf5'.format(ENV_NAME)
params["REWARD_TYPE"] = "normal"
params["NETWORK_TYPE"] = "network1"
params["LOSS_TYPE"] = "mse_clip"

if args.prms == 1:
    params["GAMMA"] = 0.90
elif args.prms == 2:
    params["NETWORK_TYPE"] = "network3"
elif args.prms == 3:
    params["LEARNING_RATE"] = 0.00025
class RLNetwork:
    def __init__(self):
        self.target_network = self.create_target_network()
        self.q_network = self.create_q_network()

    def create_network_structure(self):
        if params["NETWORK_TYPE"] == "network1":
            Layer1 = 4
            Layer2 = 4
            input_state = Input(shape=(STATE_WIDTH,), dtype='float32', name='inputs_state')
            x = Dense(Layer1, activation='relu', name="D1")(input_state)
            x = Dense(Layer2, activation='relu', name="D2")(x)
            output_by_action = Dense(NUM_ACTIONS, name="D3")(x)
            return input_state, output_by_action
        if params["NETWORK_TYPE"] == "network2":
            Layer1 = 4
            Layer2 = 4
            input_state = Input(shape=(STATE_WIDTH,), dtype='float32', name='inputs_state')
            x = Dense(Layer1, activation='relu', name="D1")(input_state)
            x = Dropout(0.2)(x)
            x = Dense(Layer2, activation='relu', name="D2")(x)
            x = Dropout(0.2)(x)
            output_by_action = Dense(NUM_ACTIONS, name="D3")(x)
            return input_state, output_by_action
        if params["NETWORK_TYPE"] == "network3":
            Layer1 = 20
            Layer2 = 4
            input_state = Input(shape=(STATE_WIDTH,), dtype='float32', name='inputs_state')
            x = Dense(Layer1, activation='relu', name="D1")(input_state)
            x = Dropout(0.2)(x)
            x = Dense(Layer2, activation='relu', name="D2")(x)
            output_by_action = Dense(NUM_ACTIONS, name="D3")(x)
            return input_state, output_by_action

    def weights_to_update(self):
        return ["D1", "D2", "D3"]

    def create_target_network(self):
        input_state, output_by_action = self.create_network_structure()
        model = Model(input=input_state, output=output_by_action)
        return model

    def create_q_network(self):
        input_state, output_by_action = self.create_network_structure()
        input_action = Input(shape=(NUM_ACTIONS,), dtype='float32', name='inputs_action')
        # merge with mode="dot" did not work, so multiply element-wise and sum instead
        x = merge([output_by_action, input_action], mode='mul')
        x = Lambda(lambda x: K.sum(x, axis=1), output_shape=(1,))(x)
        output = Reshape((1,))(x)
        model = Model(input=[input_state, input_action], output=output)
        optimizer = RMSprop(lr=params["LEARNING_RATE"])
        if params["LOSS_TYPE"] == "mse":
            loss = "mse"
        elif params["LOSS_TYPE"] == "mse_clip":
            def mean_squared_error_clip(y_true, y_pred):
                error = K.abs(y_true - y_pred)
                quadratic_part = K.clip(error, 0.0, 1.0)
                linear_part = error - quadratic_part
                return K.mean(K.square(quadratic_part) + linear_part * 2.0, axis=-1)
            loss = mean_squared_error_clip
        else:
            raise Exception
        model.compile(optimizer=optimizer, loss=loss)
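        # Note: the action input is a one-hot vector, so multiplying it element-wise
        # with the per-action Q outputs and summing selects the scalar Q(s, a) for the
        # chosen action, which can then be regressed toward a per-sample scalar target.
        # The "mse_clip" loss is quadratic for errors up to 1 and linear (2*|e| - 1)
        # beyond, i.e. twice the Huber loss with delta=1, which bounds the gradient
        # magnitude for large TD errors.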
        return model
    def update_weights(self):
        def get_layer(layers, name):
            return [l for l in layers if l.name == name][0]
        for target in self.weights_to_update():
            l1 = get_layer(self.target_network.layers, target)
            l2 = get_layer(self.q_network.layers, target)
            l1.set_weights(l2.get_weights())
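        # Note: this is a hard target-network update. Only the Dense layers D1-D3
        # carry weights (the merge/Lambda/Reshape layers have none), so copying them
        # from the online q_network fully syncs the target network. Agent.run_step
        # triggers this every TARGET_UPDATE_INTERVAL training steps.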
    def get_q_values(self, state):
        # state.shape -> (1, STATE_WIDTH)
        state = state.reshape(1, -1)
        return self.target_network.predict(state)

    def get_q_values_batch(self, state_batch):
        return self.target_network.predict_on_batch(state_batch)

    def train_q_network(self, state_batch, action_batch, y_batch):
        return self.q_network.train_on_batch([state_batch, action_batch], y_batch)


class EnvUtil:
    @classmethod
    def process_observation(cls, observation):
        return np.array(observation).astype(np.float32).reshape(-1)

    @classmethod
    def get_initial_state(cls, observation, last_observations=None):
        # depending on the task, a list of past observations can be used here
        return cls.process_observation(observation)

    @classmethod
    def next_state(cls, state, observation, last_observations=None):
        # depending on the task, a list of past observations can be used here
        return cls.process_observation(observation)

    @classmethod
    def train_reward(cls, reward, terminal):
        if params["REWARD_TYPE"] == "normal":
            if terminal:
                return -1
            else:
                return 1
        elif params["REWARD_TYPE"] == "negative":
            if terminal:
                return -1
            else:
                return 0
        else:
            raise Exception
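        # Note: the reward used for training replaces the raw environment reward.
        # "normal" returns +1 for every non-terminal step and -1 at termination;
        # "negative" returns 0 for non-terminal steps and -1 at termination
        # (the `reward` argument is not used by either variant).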
class Status():
    def __init__(self):
        self.training_steps = 0
        self.total_steps = 0
        self.ep_reward = 0
        self.ep_q_max = 0
        self.ep_loss = 0
        self.ep_steps = 0
        self.episode = 0
        self.ep_rewards = []

    def episode_initialize(self):
        self.episode += 1
        self.ep_reward = 0
        self.ep_q_max = 0
        self.ep_loss = 0
        self.ep_steps = 0
        self.total_steps = 0

    def episode_finalize(self):
        self.ep_rewards.append(self.ep_reward)


class Agent():
    def __init__(self, is_test=False):
        # reinforcement learning parameters
        self.NUM_EPISODES = params["NUM_EPISODES"]
        self.EPISODE_MAX_STEPS = params["EPISODE_MAX_STEPS"]
        self.NUM_EPISODES_AT_TEST = params["NUM_EPISODES_AT_TEST"]
        self.INITIAL_REPLAY_SIZE = params["INITIAL_REPLAY_SIZE"]
        self.EXPLORATION_STEPS = params["EXPLORATION_STEPS"]
        self.NUM_REPLAY_MEMORY = params["NUM_REPLAY_MEMORY"]
        self.TARGET_UPDATE_INTERVAL = params["TARGET_UPDATE_INTERVAL"]
        self.NO_OP_STEPS = params["NO_OP_STEPS"]
        self.TRAIN_INTERVAL = params["TRAIN_INTERVAL"]
        self.INITIAL_EPSILON = params["INITIAL_EPSILON"]
        self.FINAL_EPSILON = params["FINAL_EPSILON"]
        self.GAMMA = params["GAMMA"]
        self.BATCH_SIZE = params["BATCH_SIZE"]
        self.DUMP_DATA = params["DUMP_DATA"]
        self.SAVE_NETWORK_PATH = params["SAVE_NETWORK_PATH"]
        self.epsilon = self.INITIAL_EPSILON
        self.epsilon_step = (self.INITIAL_EPSILON - self.FINAL_EPSILON) / self.EXPLORATION_STEPS
        self.st = Status()
        self.replay_memory = deque()
        self.networks = RLNetwork()
        if TEST:
            self.load_weights()

    def load_weights(self):
        print("load weights")
        self.networks.q_network.load_weights(self.SAVE_NETWORK_PATH)
        self.networks.update_weights()

    def episode_start(self):
        self.st.episode_initialize()

    def get_action_random(self):
        action = random.randrange(NUM_ACTIONS)
        return action

    def get_action(self, state):
        # epsilon greedy
        if self.epsilon >= random.random() or self.st.training_steps < self.INITIAL_REPLAY_SIZE:
            action = random.randrange(NUM_ACTIONS)
        else:
            action = np.argmax(self.networks.get_q_values(state))
        # anneal epsilon linearly over time
        if self.epsilon > self.FINAL_EPSILON and self.st.training_steps >= self.INITIAL_REPLAY_SIZE:
            self.epsilon -= self.epsilon_step
        return action
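        # Note: actions are purely random until training_steps reaches
        # INITIAL_REPLAY_SIZE (10,000), i.e. while the replay memory is being filled.
        # After that, epsilon is annealed linearly from 1.0 to 0.1 over
        # EXPLORATION_STEPS (20,000) steps: (1.0 - 0.1) / 20000 = 4.5e-5 per step.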
    def get_action_at_test(self, state):
        action = np.argmax(self.networks.get_q_values(state))
        return action

    def run_step_no_op(self, reward):
        self.st.ep_reward += reward
        self.st.ep_steps += 1
        self.st.total_steps += 1

    def run_step(self, state, action, next_state, reward, train_reward, terminal):
        # store transition in replay memory
        self.replay_memory.append((state, action, train_reward, next_state, terminal))
        if len(self.replay_memory) > self.NUM_REPLAY_MEMORY:
            self.replay_memory.popleft()
        if self.st.training_steps == self.INITIAL_REPLAY_SIZE:
            self.dump_data()
        if self.st.training_steps >= self.INITIAL_REPLAY_SIZE:
            # train network
            if self.st.training_steps % self.TRAIN_INTERVAL == 0:
                self.train_network()
            # update target network
            if self.st.training_steps % self.TARGET_UPDATE_INTERVAL == 0:
                self.update_weights()
        self.st.ep_reward += reward
        self.st.ep_q_max += np.max(self.networks.get_q_values(state))
        self.st.ep_steps += 1
        self.st.training_steps += 1
        self.st.total_steps += 1
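        # Note: once the replay memory is warmed up, a minibatch update runs every
        # TRAIN_INTERVAL (4) steps and the target network is hard-synced every
        # TARGET_UPDATE_INTERVAL (1,000) steps.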
    def run_step_at_test(self, state, action, next_state, reward, train_reward, terminal):
        self.st.ep_reward += reward
        self.st.ep_q_max += np.max(self.networks.get_q_values(state))
        self.st.ep_steps += 1
        self.st.training_steps += 1
        self.st.total_steps += 1

    def episode_end(self):
        if TRAIN:
            # update status
            self.st.episode_finalize()
            # debug
            if self.st.training_steps < self.INITIAL_REPLAY_SIZE:
                mode = 'random'
            elif self.INITIAL_REPLAY_SIZE <= self.st.training_steps < self.INITIAL_REPLAY_SIZE + self.EXPLORATION_STEPS:
                mode = 'explore'
            else:
                mode = 'exploit'
            print(('EPISODE: {0:6d} / TRAININGSTEP: {1:8d} / DURATION: {2:5d} / EPSILON: {3:.5f} / ' +
                   'REWARDS: {4:3.0f} / AVG_MAX_Q: {5:2.4f} / ' +
                   'AVG_LOSS: {6:.5f} / MODE: {7}').format(
                self.st.episode, self.st.training_steps, self.st.ep_steps, self.epsilon,
                self.st.ep_reward, self.st.ep_q_max / float(self.st.ep_steps),
                self.st.ep_loss / (float(self.st.ep_steps) / float(self.TRAIN_INTERVAL)), mode))
            # score
            if self.st.episode == self.NUM_EPISODES:
                msg = "prms: {}, last 100 episodes reward average: {}".format(args.prms, np.array(self.st.ep_rewards[-100:]).mean())
                logger.info(msg)
            # save weights
            if self.st.episode == self.NUM_EPISODES:
                print("save weights")
                save_dir = os.path.dirname(params["SAVE_NETWORK_PATH"])
                if save_dir and not os.path.exists(save_dir):
                    os.makedirs(save_dir)  # make sure the output directory exists
                self.networks.q_network.save_weights(params["SAVE_NETWORK_PATH"], overwrite=True)
        else:
            print(('EPISODE: {0:6d} / DURATION: {1:5d} / ' +
                   'REWARDS: {2:3.0f} / AVG_MAX_Q: {3:2.4f}').format(
                self.st.episode, self.st.ep_steps,
                self.st.ep_reward, self.st.ep_q_max / float(self.st.ep_steps)))

    def process_batch(self, batch_size=None):
        state_batch = []
        action_batch = []
        reward_batch = []
        next_state_batch = []
        terminal_batch = []
        if batch_size is None:
            minibatch = self.replay_memory
        else:
            minibatch = random.sample(self.replay_memory, batch_size)
        for data in minibatch:
            state_batch.append(data[0])
            action_batch.append(data[1])
            reward_batch.append(data[2])
            next_state_batch.append(data[3])
            terminal_batch.append(data[4])
        state_batch = np.array(state_batch)
        action_batch = np.array(action_batch)
        reward_batch = np.array(reward_batch)
        next_state_batch = np.array(next_state_batch)
        terminal_batch = np.array(terminal_batch).astype(np.float32)  # cast bools for the (1 - terminal) arithmetic
        action_batch = np_utils.to_categorical(action_batch, NUM_ACTIONS)  # one-hot actions
        return state_batch, action_batch, reward_batch, next_state_batch, terminal_batch

    def train_network(self):
        state_batch, action_batch, reward_batch, next_state_batch, terminal_batch = self.process_batch(self.BATCH_SIZE)
        # target: reward + (if not terminal) gamma * max(q_value)
        target_q_values_batch = self.networks.get_q_values_batch(next_state_batch)
        y_batch = reward_batch + (1 - terminal_batch) * self.GAMMA * np.max(target_q_values_batch, axis=1)
        loss = self.networks.train_q_network(state_batch, action_batch, y_batch)
        self.st.ep_loss += loss
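        # Note: the regression target is the one-step Bellman backup
        #     y = r + (1 - terminal) * GAMMA * max_a' Q_target(s', a')
        # computed with the periodically synced target network; the one-hot action
        # input makes the q_network output exactly the Q(s, a) being updated.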
    def update_weights(self):
        self.networks.update_weights()

    def dump_data(self):
        # for debugging, dump state data
        if self.DUMP_DATA:
            state_batch, action_batch, reward_batch, next_state_batch, terminal_batch = self.process_batch()
            ary = np.hstack([state_batch, action_batch, reward_batch.reshape(-1, 1), next_state_batch, terminal_batch.reshape(-1, 1)])
            df = pd.DataFrame(ary)
            df.to_csv("data/data.txt", sep="\t")


def main():
    # logger.info("run started")
    env = gym.make(ENV_NAME)
    agent = Agent(is_test=TEST)
    if TRAIN:
        num_episodes = agent.NUM_EPISODES
    else:
        num_episodes = agent.NUM_EPISODES_AT_TEST
    for _ in xrange(num_episodes):
        observation = env.reset()
        no_op_steps = random.randint(1, agent.NO_OP_STEPS)
        state = None
        agent.episode_start()
        while True:
            # last_observation = observation
            if agent.st.ep_steps < no_op_steps:
                action = agent.get_action_random()
                observation, reward, terminal, info = env.step(action)
                if TEST:
                    env.render()
                agent.run_step_no_op(reward)
                if agent.st.ep_steps == no_op_steps:
                    state = EnvUtil.get_initial_state(observation)
            else:
                if TRAIN:
                    action = agent.get_action(state)
                else:
                    action = agent.get_action_at_test(state)
                observation, reward, terminal, info = env.step(action)
                if TEST:
                    env.render()
                train_reward = EnvUtil.train_reward(reward, terminal)
                next_state = EnvUtil.next_state(state, observation)
                if TRAIN:
                    agent.run_step(state, action, next_state, reward, train_reward, terminal)
                else:
                    agent.run_step_at_test(state, action, next_state, reward, train_reward, terminal)
                state = next_state
            if terminal or agent.st.ep_steps == agent.EPISODE_MAX_STEPS:
                agent.episode_end()
                break


if __name__ == '__main__':
    main()
Runner script (a second file in the gist) that trains each parameter setting five times:
import subprocess

def run(cmd, times=1):
    for i in range(times):
        subprocess.check_call(cmd, shell=True)

run("python dqn_keras.py --prms 0", times=5)
run("python dqn_keras.py --prms 1", times=5)
run("python dqn_keras.py --prms 2", times=5)
run("python dqn_keras.py --prms 3", times=5)
Results (result.log):
2017-02-01 22:22:12,111 : prms: 0, last 100 episodes reward average: 14.62
2017-02-01 22:23:18,734 : prms: 0, last 100 episodes reward average: 59.79
2017-02-01 22:23:55,972 : prms: 0, last 100 episodes reward average: 13.37
2017-02-01 22:25:06,106 : prms: 0, last 100 episodes reward average: 43.1
2017-02-01 22:25:45,898 : prms: 0, last 100 episodes reward average: 13.18
2017-02-01 22:26:58,870 : prms: 1, last 100 episodes reward average: 37.07
2017-02-01 22:27:37,843 : prms: 1, last 100 episodes reward average: 14.15
2017-02-01 22:31:34,060 : prms: 1, last 100 episodes reward average: 247.56
2017-02-01 22:34:28,860 : prms: 1, last 100 episodes reward average: 182.57
2017-02-01 22:37:05,031 : prms: 1, last 100 episodes reward average: 145.93
2017-02-01 22:39:55,277 : prms: 2, last 100 episodes reward average: 106.5
2017-02-01 22:44:34,832 : prms: 2, last 100 episodes reward average: 146.17
2017-02-01 22:46:15,321 : prms: 2, last 100 episodes reward average: 153.11
2017-02-01 22:47:54,614 : prms: 2, last 100 episodes reward average: 90.85
2017-02-01 22:54:06,135 : prms: 2, last 100 episodes reward average: 226.38
2017-02-01 22:54:45,047 : prms: 3, last 100 episodes reward average: 13.2
2017-02-01 22:55:24,530 : prms: 3, last 100 episodes reward average: 14.03
2017-02-01 22:56:02,366 : prms: 3, last 100 episodes reward average: 12.04
2017-02-01 22:56:37,736 : prms: 3, last 100 episodes reward average: 13.36
2017-02-01 22:57:15,393 : prms: 3, last 100 episodes reward average: 13.2
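For a quick look at result.log, a small helper along these lines could average the logged scores per --prms setting. It is not part of the gist; it only assumes the log format produced by the logger above.

# Hypothetical helper (not part of the gist): average the logged scores per --prms setting.
import re
from collections import defaultdict

pattern = re.compile(r"prms: (\d+), last 100 episodes reward average: ([0-9.]+)")
scores = defaultdict(list)
with open("result.log") as f:
    for line in f:
        m = pattern.search(line)
        if m:
            scores[int(m.group(1))].append(float(m.group(2)))

for prms in sorted(scores):
    vals = scores[prms]
    print("prms {}: runs={}, mean reward={:.2f}".format(prms, len(vals), sum(vals) / len(vals)))

On the runs shown above this gives roughly 28.8 (prms 0), 125.5 (prms 1), 144.6 (prms 2) and 13.2 (prms 3), with large variance across the five runs of each setting.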