Created June 27, 2017 13:07
Getting GyroBoy to stand up with Q-learning
import numpy as np
import ev3dev.ev3 as ev3
import random, os, time
import pickle

# qlearningAgents.py
# ------------------
## based on http://inst.eecs.berkeley.edu/~cs188/sp09/pacman.html
class QLearningAgent():
    """
    Q-Learning Agent

    Instance variables
      - self.epsilon (exploration probability)
      - self.alpha (learning rate)
      - self.discount (discount rate, aka gamma)

    Functions
      - self.getLegalActions(state)
        which returns the legal actions for a state
      - self.getQValue(state, action)
        which returns Q(state, action)
      - self.setQValue(state, action, value)
        which sets Q(state, action) := value
      - self.saveQ()
        which saves the Q-table
      - self.loadQ()
        which loads the Q-table
    """

    def __init__(self, alpha, epsilon, discount, getLegalActions):
        "We initialize the agent and the Q-values here."
        self.getLegalActions = getLegalActions
        self._qValues = {}
        self.alpha = alpha
        self.epsilon = epsilon
        self.discount = discount
        self._pickle_name = 'qvalue.pickle'
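    # The Q-table is a nested dict, self._qValues[state][action] -> value;
    # unseen (state, action) pairs default to 0.0 in getQValue().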
    def getQValue(self, state, action):
        if not (state in self._qValues) or not (action in self._qValues[state]):
            return 0.0
        else:
            return self._qValues[state][action]

    def setQValue(self, state, action, value):
        self._qValues[state][action] = value

    def getValue(self, state):
        possibleActions = self.getLegalActions(state)
        if len(possibleActions) == 0:
            return 0.0
        best_q = None
        for a in possibleActions:
            v = self.getQValue(state, a)
            if best_q is None or v > best_q:
                best_q = v
        return best_q

    def getPolicy(self, state):
        possibleActions = self.getLegalActions(state)
        if len(possibleActions) == 0:
            return None
        best_q = None
        best_action = None
        for a in possibleActions:
            v = self.getQValue(state, a)
            if best_q is None or v > best_q:
                best_q = v
                best_action = a
        return best_action

    def getAction(self, state):
        possibleActions = self.getLegalActions(state)
        if len(possibleActions) == 0:
            return None
        if random.random() < self.epsilon:
            action = random.choice(possibleActions)
        else:
            action = self.getPolicy(state)
        return action
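    # Standard tabular Q-learning update:
    #   Q(s, a) <- (1 - alpha) * Q(s, a) + alpha * (r + gamma * max_a' Q(s', a'))
    # The first time a (state, action) pair is seen, its Q-value is initialized
    # directly with the observed reward.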
    def update(self, state, action, nextState, reward):
        if not (state in self._qValues):
            self._qValues[state] = {}
        if not (action in self._qValues[state]):
            self._qValues[state][action] = reward
        else:
            gamma = self.discount
            learning_rate = self.alpha
            reference_qvalue = reward + gamma * self.getValue(nextState)
            updated_qvalue = (1 - learning_rate) * self.getQValue(state, action) + learning_rate * reference_qvalue
            self.setQValue(state, action, updated_qvalue)

    def saveQ(self):
        print("save Q table")
        with open(self._pickle_name, 'wb') as handle:
            pickle.dump(self._qValues, handle, protocol=pickle.HIGHEST_PROTOCOL)

    def loadQ(self):
        if not os.path.exists(self._pickle_name):
            return
        with open(self._pickle_name, 'rb') as handle:
            self._qValues = pickle.load(handle)
# Function for fast reading from sensor files
def FastRead(infile):
    infile.seek(0)
    return int(infile.read().decode().strip())

# Function for fast writing to motor files
def FastWrite(outfile, value):
    outfile.truncate(0)
    outfile.write(str(int(value)))
    outfile.flush()

# Function to set the duty cycle of the motors
def SetDuty(motorDutyFileHandle, duty):
    # Apply the signal to the motor
    FastWrite(motorDutyFileHandle, duty)
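
# Gym-style wrapper around the EV3 hardware: two large motors (outA/outD),
# a gyro sensor on in2, and a touch sensor on in3. The duty_cycle_sp sysfs
# files are kept open so motor commands can be written with minimal overhead.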
class Environment():
    def __init__(self):
        self._right_motor = ev3.LargeMotor('outA')
        self._left_motor = ev3.LargeMotor('outD')
        # Open motor files for (fast) writing
        self.motorDutyCycleLeft = open(self._left_motor._path + "/duty_cycle_sp", "w")
        self.motorDutyCycleRight = open(self._right_motor._path + "/duty_cycle_sp", "w")
        self._gyro_sensor = ev3.GyroSensor('in2')
        self._touch_sensor = ev3.TouchSensor('in3')
        self.action_space_n = 2
        # self.observation_space = ???
        self._position = 0
        self._gyro_angle = 0
        self._gyro_rate = 0
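    # An episode starts when the touch sensor is pressed: calibrate_gyro()
    # toggles the gyro mode (which resets its angle reading) and both motors
    # are given an initial push via run_direct.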
    def is_ready(self):
        if self._touch_sensor.is_pressed:
            self.calibrate_gyro()
            self._right_motor.run_direct(duty_cycle_sp=-50)
            self._left_motor.run_direct(duty_cycle_sp=-50)
            return True
        else:
            return False

    def reset(self):
        self._left_motor.position = 0
        self._right_motor.position = 0
        return self._state()

    def stop(self):
        self._left_motor.position = 0
        self._right_motor.position = 0
        self._left_motor.stop()
        self._right_motor.stop()

    def shutdown(self):
        self.motorDutyCycleLeft.close()
        self.motorDutyCycleRight.close()

    def calibrate_gyro(self):
        self._gyro_sensor.mode = self._gyro_sensor.MODE_GYRO_RATE
        self._gyro_sensor.mode = self._gyro_sensor.MODE_GYRO_G_A
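    # One control step: read the sensors, end the episode if the robot has
    # tilted more than 25 degrees, otherwise drive both wheels forward or
    # backward at 50% duty depending on the action. Returns a Gym-like
    # (state, reward, done, info) tuple with reward 1 for every surviving step.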
    def step(self, action):
        self._update()
        # fallen down
        if abs(self._gyro_angle) > 25:
            return (self._state(), 0, True, {})
        duty = 50 if action == 0 else -50
        SetDuty(self.motorDutyCycleRight, duty)
        SetDuty(self.motorDutyCycleLeft, duty)
        # self._right_motor.run_direct(duty_cycle_sp=duty)
        # self._left_motor.run_direct(duty_cycle_sp=duty)
        return (self._state(), 1, False, {})

    def _update(self):
        self._position = self._left_motor.position
        self._gyro_angle, self._gyro_rate = self._gyro_sensor.rate_and_angle

    def _state(self):
        return (self._position / 100, 0, self._gyro_angle / 100, self._gyro_rate / 100)
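
# Discretization of the continuous observation, borrowed from a common
# CartPole Q-learning recipe: each feature is digitized into a bin index and
# the indices are concatenated into a single integer state id.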
def build_state(features):
    """Take our features and put them all together, converting into an integer."""
    return int("".join(map(lambda feature: str(int(feature)), features)))

def to_bin(value, bins):
    return np.digitize(x=[value], bins=bins)[0]

cart_position_bins = np.linspace(-2.4, 2.4, 2)
cart_velocity_bins = np.linspace(-2, 2, 10)
pole_angle_bins = np.linspace(-0.4, 0.4, 50)
pole_velocity_bins = np.linspace(-3.5, 3.5, 20)

def transform(observation):
    # return an int
    cart_pos, cart_vel, pole_angle, pole_vel = observation
    return build_state([
        to_bin(cart_pos, cart_position_bins),
        to_bin(cart_vel, cart_velocity_bins),
        to_bin(pole_angle, pole_angle_bins),
        to_bin(pole_vel, pole_velocity_bins)
    ])
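
# For example, an observation of (0.1, 0, -0.05, 0.2) should map to the bin
# indices (1, 5, 22, 11) and hence to the state id 152211 (worked out from
# the bin edges above, not taken from an actual run).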
def run(env, agent):
    agent.loadQ()
    try:
        for i in range(10000):
            print("wait:", i)
            while not env.is_ready():
                time.sleep(0.3)
            print("go")
            play_and_train(env, agent)
            if i % 30 == 0:
                agent.saveQ()
    finally:
        agent.saveQ()
        env.stop()
        env.shutdown()
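
# One training episode: pick an epsilon-greedy action roughly every 10 ms,
# apply it, update the Q-table with the observed transition, and stop when
# the environment reports that the robot has fallen over.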
def play_and_train(env, agent, t_max=10 ** 4):
    s = env.reset()
    state_id = transform(s)
    total_rewards = 0
    SLEEP_TIME = 0.01
    for _ in range(t_max):
        start_time = time.time()
        action = agent.getAction(state_id)
        next_s, r, done, _ = env.step(action)
        next_state_id = transform(next_s)
        agent.update(state_id, action, next_state_id, r)
        state_id = next_state_id
        total_rewards += r
        if done:
            env.stop()
            break
        elapsed_second = time.time() - start_time
        if elapsed_second < SLEEP_TIME:
            sleep_time = SLEEP_TIME - elapsed_second
            time.sleep(sleep_time)
    print("reward:", total_rewards)

if __name__ == '__main__':
    env = Environment()
    agent = QLearningAgent(alpha=0.1, epsilon=0.2, discount=0.99,
                           getLegalActions=lambda s: range(env.action_space_n))
    run(env, agent)