This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
if __name__ == '__main__':
    # Entry point: build the CartPole environment and run DQN training on it.
    environment = gym.make('CartPole-v1')
    cartpole_agent = Agent(environment)
    cartpole_agent.train()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Agent():
    """Deep-Q-learning agent bound to a single gym environment."""

    def __init__(self, env):
        """Store the environment handle and initialise training hyperparameters.

        env: the (gym-style) environment the agent will interact with.
        """
        # Training-budget limits: episodes per run, steps per episode.
        self.max_episodes, self.max_actions = 10000, 10000
        # Epsilon-greedy schedule: start fully exploratory, decay slowly.
        self.exploration_rate = 1.0
        self.exploration_decay = 1e-4
        # Environment used throughout training.
        self.env = env
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class QNET(): | |
# Fragment of the online Q-network class; the class body continues beyond this excerpt.
def update(self): | |
"""Run the prebuilt update op to refresh the target network weights.""" | |
self.session.run(self.update_opt) | |
def get_action(self, state, e_rate): | |
""" for training stage of the Agent, exploitation or exploration""" | |
# Epsilon-greedy: with probability e_rate take a uniformly random action.
if np.random.random()<e_rate: # exploration | |
# np.random.choice on an int draws from range(self.out_units),
# i.e. a random action index.
return np.random.choice(self.out_units) | |
# NOTE(review): the exploitation branch is truncated in this excerpt —
# presumably it returns argmax of the network's Q-values; confirm upstream.
else: # exploitation |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class QNET(): | |
# Fragment: batch-training method of the Double-DQN network (rest of class not shown).
def batch_train(self, batch_size=64): | |
"""Implement Double DQN Algorithm, batch training""" | |
# Skip training until the replay buffer holds the minimum number of experiences.
if self.exp.get_num() < self.exp.get_min(): | |
#The number of experiences is not enough for batch training | |
return | |
# get a batch of experiences | |
state, action, reward, next_state, done = self.exp.get_batch(batch_size) | |
# Flatten the sampled states to (batch_size, in_units) for the network input.
# (Method is truncated here in this excerpt.)
state = state.reshape(batch_size, self.in_units) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class QNET(): | |
# Fragment: constructor excerpt plus the start of the batch-learning graph builder.
def __init__(self, in_units, out_units, exp, hidden_units=250): | |
# experience replay | |
self.exp = exp | |
def _batch_learning_model(self): | |
"""For batch learning""" | |
# All training tensors live under the 'qnet' variable scope (TF1-style graph).
with tf.variable_scope('qnet'): | |
# TD-target | |
# Placeholder fed with the per-sample TD targets, shape (batch,).
# (Method is truncated here in this excerpt.)
self.target = tf.placeholder(tf.float32, shape=(None, )) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class QNET(): | |
# Fragment: constructor excerpt — wires up the target network and layer sizes.
def __init__(self, in_units, out_units, exp, hidden_units=250): | |
# Target Network | |
self.tnet = TNET(in_units, out_units) | |
# Q network architecture | |
self.in_units = in_units | |
self.out_units = out_units | |
self.hidden_units = hidden_units | |
# Build the online Q-network graph (definition not shown in this excerpt).
self._model() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class TNET():
    """
    Target network: provides the maximum estimated Q-value for a given action,
    decoupled from the online network to stabilise Double-DQN training.
    """

    def __init__(self, in_units, out_units, hidden_units=250):
        """Record the layer sizes, then build the network graph.

        in_units: input layer width; out_units: number of actions;
        hidden_units: hidden layer width.
        """
        self.in_units, self.out_units, self.hidden_units = (
            in_units,
            out_units,
            hidden_units,
        )
        # Construct the graph (``_model`` is defined elsewhere in the file).
        self._model()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class ExpReplay():
    """Fixed-capacity experience-replay store for DQN training."""

    def __init__(self, e_max=15000, e_min=100):
        """Create an empty replay buffer.

        e_max: capacity ceiling for stored experiences.
        e_min: number of experiences required before batch training starts.
        """
        self._max = e_max
        self._min = e_min
        # One parallel list per transition component, all initially empty.
        self.exp = {key: [] for key in ('state', 'action', 'reward', 'next_state', 'done')}

    def get_max(self):
        """Return the buffer's maximum capacity."""
        return self._max
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import gym | |
# Illustrative tabular Q-learning snippet on the FrozenLake task.
env = gym.make('FrozenLake-v0') | |
# Gym's classic 4-tuple step API: (observation, reward, done, info).
# NOTE(review): `action` is not defined in this excerpt — illustrative only.
next_state, rewards, done, _ = env.step(action) | |
''' | |
α - learning rate | |
γ - discount factor | |
''' | |
# Q-learning / Double-DQN target update rule (pseudocode, not runnable Python):
QA(state, action) = QA(state, action) + α(rewards + γ*expected_q - QA(state, action)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# RegionRouting Overview
#
# Pseudocode: partition the graph, then route each active region in turn,
# chaining the start point from the previous route's endpoint.
G = GraphPartition()
start = depot
# Loop while regions remain active. The original condition was
# `while ActiveRegions() is None:`, which contradicts the body's
# `ActiveRegions().get()` call (it would dereference None); the intended
# guard is `is not None`.
while ActiveRegions() is not None:
    subG = ActiveRegions().get()  # apply fast kNN search
    result = RegionRouting(subG)
    enqueue(Route, result)
    # Next route starts where this one ended.
    start = result.end_object
NewerOlder