Created
June 30, 2017 20:12
-
-
Save tsu-nera/57c4b1c84ce2470e8405d410e9fdfcfa to your computer and use it in GitHub Desktop.
強化学習(Q-Learning)で LEGO Mindstormsの crawlerを動かす — Drive a LEGO Mindstorms crawler robot with reinforcement learning (Q-Learning)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import ev3dev.ev3 as ev3 | |
import random, os, time | |
import pickle | |
class QLearningAgent():
    """Tabular Q-Learning agent.

    Instance variables
      - self.epsilon   exploration probability for epsilon-greedy actions
      - self.alpha     learning rate
      - self.discount  discount rate (aka gamma)
      - self.getLegalActions(state) -> iterable of legal actions; injected
        by the caller so the agent stays environment-agnostic

    The Q table is a dict of dicts: self._qValues[state][action] -> float.
    Unseen (state, action) pairs implicitly have Q = 0.0.
    """

    def __init__(self, alpha, epsilon, discount, getLegalActions):
        """Initialize hyper-parameters and an empty Q table."""
        self.getLegalActions = getLegalActions
        self._qValues = {}
        self.alpha = alpha
        self.epsilon = epsilon
        self.discount = discount
        # File used by saveQ/loadQ to persist the table between runs.
        self._pickle_name = 'qvalue.pickle'

    def getQValue(self, state, action):
        """Return Q(state, action), defaulting to 0.0 for unseen pairs."""
        return self._qValues.get(state, {}).get(action, 0.0)

    def setQValue(self, state, action, value):
        """Set Q(state, action) := value, creating the state row if needed.

        Fix: the original indexed self._qValues[state] directly and raised
        KeyError when the state had never been seen; setdefault makes the
        method safe to call in any order.
        """
        self._qValues.setdefault(state, {})[action] = value

    def getValue(self, state):
        """Return V(state) = max over legal actions of Q(state, action).

        Returns 0.0 when there are no legal actions (terminal state).
        """
        possibleActions = self.getLegalActions(state)
        if not possibleActions:
            return 0.0
        return max(self.getQValue(state, a) for a in possibleActions)

    def getPolicy(self, state):
        """Return the greedy action argmax_a Q(state, a), or None if terminal.

        Ties break on the first maximal action, matching the original
        strict-greater scan.
        """
        possibleActions = self.getLegalActions(state)
        if not possibleActions:
            return None
        return max(possibleActions, key=lambda a: self.getQValue(state, a))

    def getAction(self, state):
        """Epsilon-greedy action selection.

        With probability self.epsilon pick a random legal action, otherwise
        act greedily; returns None for terminal states.
        """
        possibleActions = self.getLegalActions(state)
        if not possibleActions:
            return None
        if random.random() < self.epsilon:
            return random.choice(list(possibleActions))
        return self.getPolicy(state)

    def update(self, state, action, nextState, reward):
        """Q-learning update for one observed transition.

        The first visit to (state, action) initializes Q directly to the
        observed reward (original behavior, kept for compatibility);
        subsequent visits apply the standard rule
            Q <- (1 - alpha) * Q + alpha * (reward + gamma * V(nextState)).
        """
        if action in self._qValues.get(state, {}):
            target = reward + self.discount * self.getValue(nextState)
            value = (1 - self.alpha) * self.getQValue(state, action) \
                    + self.alpha * target
        else:
            value = reward
        self.setQValue(state, action, value)

    def saveQ(self):
        """Persist the Q table to disk with pickle."""
        print("save Q table")
        with open(self._pickle_name, 'wb') as handle:
            pickle.dump(self._qValues, handle, protocol=pickle.HIGHEST_PROTOCOL)

    def loadQ(self):
        """Load a previously saved Q table, if the pickle file exists.

        NOTE(review): pickle.load executes arbitrary code from the file;
        only load Q tables this program created itself.
        """
        if not os.path.exists(self._pickle_name):
            return
        print("load Q table")
        with open(self._pickle_name, 'rb') as handle:
            self._qValues = pickle.load(handle)
class Environment():
    """Gym-style wrapper around the LEGO Mindstorms crawler hardware.

    Exposes reset()/step() returning (observation, reward, done, info)
    tuples. The observation is the rear arm's position counted in
    45-degree increments; there are two actions (arm up / arm down).
    """

    def __init__(self):
        # Hardware bindings: two large motors and an ultrasonic distance
        # sensor ("sonner" is the original author's spelling of sonar).
        self._front_motor = ev3.LargeMotor('outA')
        self._rear_motor = ev3.LargeMotor('outD')
        self._sonner = ev3.UltrasonicSensor('in4')
        self.action_space_n = 2  # 0: arm +45 deg, 1: arm -45 deg
        self._front_motor.reset()
        self._rear_motor.reset()
        self.observation = 0  # arm position in 45-degree steps

    def reset(self):
        """Drive the rear arm back to its start position.

        Returns a Gym-like (observation, reward, done, info) tuple.
        """
        # Undo the accumulated rotation by moving -position relative.
        self._rear_motor.run_to_rel_pos(position_sp=-1*self._rear_motor.position,
                                        speed_sp=1000, stop_action='hold')
        while self._rear_motor.is_running:
            time.sleep(1)
        self._rear_motor.reset()
        self.observation = 0
        return self.observation, 0, False, {}

    def step(self, action):
        """Execute one action and observe the outcome.

        Rewards: -10 and episode end if the motor stalls; +10 and episode
        end if the robot moved forward (ultrasonic reading dropped by more
        than 20 — presumably millimeters, TODO confirm sensor units);
        0 otherwise.
        """
        reward = 0
        done = False
        position_sp = 0  # unknown actions leave the arm where it is
        start_len = self._sonner.value()  # distance before moving
        if action == 0:
            self.observation += 1
            position_sp = 45
        elif action == 1:
            self.observation -= 1
            position_sp = -45
        self._rear_motor.run_to_rel_pos(position_sp=position_sp, speed_sp=1000, stop_action='hold')
        while self._rear_motor.is_running:
            if self._rear_motor.is_stalled:
                # Stalled against an obstacle: punish and end the episode.
                return self.observation, -10, True, {}
            time.sleep(0.5)
        if (start_len - self._sonner.value()) > 20:
            # The robot crawled forward: beep and reward.
            ev3.Sound().beep()
            reward = 10
            done = True
        return self.observation, reward, done, {}
def run(env, agent):
    """Training driver.

    Loads any previously saved Q table, trains for up to 10000 episodes,
    checkpoints the table every 30 episodes, and always saves once more on
    the way out (normal completion, Ctrl-C, or crash).
    """
    agent.loadQ()
    try:
        for episode in range(10000):
            play_and_train(env, agent)
            if episode % 30 == 0:
                agent.saveQ()
    finally:
        agent.saveQ()
def play_and_train(env, agent, t_max=20):
    """Run one episode of at most t_max steps, training the agent online.

    Each step: pick an action, apply it to the environment, and feed the
    observed transition back into agent.update(). Stops early when the
    environment reports done.

    Returns the total reward accumulated over the episode.
    (Fix: the original computed total_rewards but discarded it; returning
    it is backward-compatible with callers that ignore the result.)
    """
    total_rewards = 0
    state, _, _, _ = env.reset()
    for _ in range(t_max):
        action = agent.getAction(state)
        print(state, action)  # trace output for watching training progress
        next_s, r, done, _ = env.step(action)
        agent.update(state, action, next_s, r)
        state = next_s
        total_rewards += r
        if done:
            break
    return total_rewards
if __name__ == '__main__':
    # Wire the crawler hardware to a tabular Q-learning agent and train.
    environment = Environment()
    legal_actions = lambda s: range(environment.action_space_n)
    learner = QLearningAgent(alpha=0.1, epsilon=0.2, discount=0.99,
                             getLegalActions=legal_actions)
    run(environment, learner)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment