Created
May 6, 2019 11:04
-
-
Save Dviejopomata/ea04a485a0c653fbf708e3417ec4c51f to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import itertools | |
import random | |
import numpy as np | |
from loguru import logger | |
def my_filter(record):
    """Loguru filter: decide whether *record* should be emitted.

    Records bound with a ``warn_only`` extra are let through only when their
    level is WARNING or more severe; every other record always passes.
    """
    if "warn_only" not in record["extra"]:
        return True
    warning_no = logger.level("WARNING").no
    return record["level"].no >= warning_no
import os
import sys

# Start every run with a fresh debug log. On the very first run the file does
# not exist yet, which must not abort the script.
try:
    os.remove("debug.log")
except FileNotFoundError:
    pass

# Drop loguru's previously configured handlers, then send INFO and above to
# stdout and everything from DEBUG up (message text only, filtered through
# my_filter) to debug.log.
logger.remove()
logger.add(sys.stdout, level="INFO")
logger.add("debug.log", format="{message}", enqueue=True, level="DEBUG", filter=my_filter)
class Game:
    """One-dimensional toy game on a board of three cells.

    Each cell of ``state`` holds one of:
        0 - empty cell
        1 - player
        2 - reward

    The player moves one cell left or right per step; stepping onto the
    reward cell ends the episode.
    """

    # Fixed id for every possible board (each permutation of 0, 1, 2), used
    # by state_to_number. Kept at class level so it is built once rather than
    # on every call.
    _STATE_IDS = {
        0: [1, 0, 2],
        1: [0, 1, 2],
        2: [0, 2, 1],
        3: [1, 2, 0],
        4: [2, 1, 0],
        5: [2, 0, 1],
    }

    def __init__(self, random_state=False):
        """Create a game and place the first board.

        Args:
            random_state: if truthy, every (re)start uses a random permutation
                of the three cell values; otherwise the board always starts
                as [1, 0, 2].
        """
        self.random_state = random_state
        self.state = self._initial_state()
        # Action 0 - move left
        # Action 1 - move right
        self.actions = [0, 1]

    def _initial_state(self):
        """Return a new starting board (random permutation or [1, 0, 2])."""
        if self.random_state:
            return random.sample([0, 1, 2], 3)
        return [1, 0, 2]

    def step(self, action):
        """Apply one action and advance the game.

        Args:
            action: 1 moves the player one cell right; any other value moves
                it one cell left.

        Returns:
            A 4-tuple ``(state, reward, done, info)``:
            ``state`` is the board list itself (mutated in place);
            ``reward`` is -10 for trying to move off the board (the board is
            left unchanged), 10 for stepping onto the reward cell and 0
            otherwise; ``done`` is True only when the reward was reached;
            ``info`` is always an empty dict.

        Raises:
            ValueError: if the board holds no player or no reward. This happens
                when step is called again after ``done`` — the reward cell has
                then been overwritten by the player — so call reset() first.
        """
        distance = 1 if action == 1 else -1
        logger.debug({'action': action, 'distance': distance, 'state': self.state})
        player_idx = self.state.index(1)
        reward_idx = self.state.index(2)
        new_idx = player_idx + distance
        if not 0 <= new_idx <= len(self.state) - 1:
            # Bumping into either edge: penalise and keep the board as is.
            return self.state, -10, False, {}
        self.state[player_idx] = 0
        self.state[new_idx] = 1
        done = new_idx == reward_idx
        reward = 10 if done else 0
        return self.state, reward, done, {}

    def reset(self):
        """Start a new episode and return the freshly placed board."""
        self.state = self._initial_state()
        return self.state

    def state_to_number(self):
        """Return the id of the current board wrapped in a list.

        The list has exactly one element for any of the six valid boards and
        is empty when the board matches none of them (for example after the
        reward has been collected and no cell holds 2 any more).
        """
        return [key for key, value in self._STATE_IDS.items() if value == self.state]
game = Game(random_state=True)

# Hyperparameters.
LR = 0.01  # learning rate: fraction of the TD error applied per Q update
GAMMA = 0.9  # discount applied to the best Q value of the next state
NUM_EPISODES = 1000
# Probability of taking a random action; the greedy action is taken with
# probability 1 - epsilon.
epsilon = 0.8

# One Q-table row per possible board. A board is a permutation of its three
# cell values, so there are 3! = 6 states:
#   state 0 [1,0,2]   state 1 [0,1,2]   state 2 [0,2,1]
#   state 3 [1,2,0]   state 4 [2,1,0]   state 5 [2,0,1]
states_len = sum(1 for _ in itertools.permutations(range(len(game.state))))

# One column per action; start from small random values in [-1, 1).
q_table = np.random.uniform(low=-1, high=1, size=(states_len, len(game.actions)))
logger.info(q_table)
# Tabular Q-learning with epsilon-greedy action selection.
for episode in range(NUM_EPISODES):
    state = game.reset()
    state_id = game.state_to_number()
    logger.debug("Episode {} state={} state_id={}", episode, state, state_id)
    done = False
    while not done:
        if np.random.random() >= 1 - epsilon:
            # Explore: uniformly random action.
            action = np.random.randint(0, len(game.actions))
        else:
            # Exploit: action with the highest current Q value.
            action = np.argmax(q_table[state_id])

        state2, reward, done, info = game.step(action)
        logger.debug([state2, reward, done])

        if done:
            # Terminal transition: there is no next state to bootstrap from,
            # so the entry is set straight to the observed reward.
            q_table[state_id, action] = reward
            logger.debug("done {}", q_table)
        else:
            next_id = game.state_to_number()
            logger.debug('state {}', next_id)
            target = reward + GAMMA * np.max(q_table[next_id])
            q_table[state_id, action] += LR * (target - q_table[state_id, action])
            state_id = next_id

logger.info(q_table)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment