Skip to content

Instantly share code, notes, and snippets.

@Dviejopomata
Created May 6, 2019 11:04
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Dviejopomata/ea04a485a0c653fbf708e3417ec4c51f to your computer and use it in GitHub Desktop.
Save Dviejopomata/ea04a485a0c653fbf708e3417ec4c51f to your computer and use it in GitHub Desktop.
import itertools
import random
import numpy as np
from loguru import logger
def my_filter(record):
    """Decide whether a log record is passed on to the sink.

    Records whose ``extra`` mapping contains the key ``warn_only`` are
    accepted only when their level is WARNING or more severe. Every other
    record is accepted unconditionally.

    Args:
        record: Log record mapping; ``record["extra"]`` and
            ``record["level"].no`` are read.

    Returns:
        ``True`` if the record should be emitted, ``False`` otherwise.
    """
    # Guard clause: records without the marker are never filtered out.
    if "warn_only" not in record["extra"]:
        return True
    severity = record["level"].no
    return severity >= logger.level("WARNING").no
import os
import sys

# Start every run with an empty debug log. The file does not exist on the
# very first run, so a missing file must not abort the script.
try:
    os.remove("debug.log")
except FileNotFoundError:
    pass

# Remove the sinks registered so far, then install our own: INFO and above on
# stdout, and bare DEBUG messages (subject to my_filter) in debug.log.
logger.remove()
logger.add(sys.stdout, level="INFO")
logger.add("debug.log", format="{message}", enqueue=True, level="DEBUG", filter=my_filter)
class Game:
    """One-dimensional game played on a board of three cells.

    Each cell of ``state`` holds one of:

    * ``0`` - empty cell
    * ``1`` - the player
    * ``2`` - the reward

    Two actions are available, listed in ``actions``: ``0`` moves the player
    one cell to the left and ``1`` moves it one cell to the right.
    """

    # Starting board used when the game is not randomized.
    _DEFAULT_STATE = (1, 0, 2)

    # Numeric id of every board layout that still contains the reward.
    _STATE_IDS = {
        (1, 0, 2): 0,
        (0, 1, 2): 1,
        (0, 2, 1): 2,
        (1, 2, 0): 3,
        (2, 1, 0): 4,
        (2, 0, 1): 5,
    }

    def __init__(self, random_state=False):
        """Create a game.

        Args:
            random_state: When true, the starting board (here and on every
                ``reset``) is a random ordering of the three cell values;
                otherwise it is always ``[1, 0, 2]``.
        """
        self.random_state = random_state
        self.state = self._initial_state()
        self.actions = [0, 1]

    def _initial_state(self):
        """Return a fresh starting board as a new list."""
        if self.random_state:
            return random.sample([0, 1, 2], 3)
        return list(self._DEFAULT_STATE)

    def step(self, action):
        """Apply ``action`` and return ``(state, reward, done, info)``.

        Args:
            action: ``1`` moves right; any other value moves left.

        Returns:
            A tuple of the board (the same list object as ``self.state``),
            the reward, a flag telling whether the reward cell was reached,
            and an empty info dict. Walking into a wall leaves the board
            untouched and yields a reward of ``-10``; reaching the reward
            yields ``10``; any other move yields ``0``.

        Raises:
            ValueError: If the board no longer contains both the player and
                the reward, e.g. when called again after the reward was
                collected without calling ``reset`` first.
        """
        distance = 1 if action == 1 else -1
        logger.debug({'action': action, 'distance': distance, 'state': self.state})
        player_idx = self.state.index(1)
        reward_idx = self.state.index(2)
        new_idx = player_idx + distance
        if not 0 <= new_idx < len(self.state):
            # The move would leave the board: penalize it and stay in place.
            return self.state, -10, False, {}
        self.state[player_idx] = 0
        # Overwrites the reward marker when the player lands on it.
        self.state[new_idx] = 1
        done = new_idx == reward_idx
        reward = 10 if done else 0
        return self.state, reward, done, {}

    def reset(self):
        """Replace the board with a fresh starting board and return it."""
        self.state = self._initial_state()
        return self.state

    def state_to_number(self):
        """Return the numeric id of the current board layout.

        Returns:
            A one-element list holding the id (0-5) of the current layout,
            or an empty list when the layout has no id (for instance once
            the reward has been collected).
        """
        state_id = self._STATE_IDS.get(tuple(self.state))
        return [] if state_id is None else [state_id]
# Environment: every episode starts from a random board layout.
game = Game(random_state=True)

# Q-learning hyper-parameters.
LR = 0.01  # learning rate applied to each temporal-difference update
GAMMA = 0.9  # discount factor for the best value of the next state
NUM_EPISODES = 1000
epsilon = 0.8  # probability of taking a random (exploratory) action

# One Q-table row per board layout, i.e. per ordering of the three cell values.
states_len = sum(1 for _ in itertools.permutations([0, 1, 2]))

# Row index -> board layout, as numbered by Game.state_to_number():
#   0 -> [1, 0, 2]
#   1 -> [0, 1, 2]
#   2 -> [0, 2, 1]
#   3 -> [1, 2, 0]
#   4 -> [2, 1, 0]
#   5 -> [2, 0, 1]
# Columns are the actions (0 = left, 1 = right). Values start out random.
q_table = np.random.uniform(low=-1, high=1, size=(states_len, len(game.actions)))
logger.info(q_table)
# Tabular Q-learning: play NUM_EPISODES episodes, updating q_table in place.
for i in range(NUM_EPISODES):
    state = game.reset()
    # state_to_number() wraps the id in a one-element list; unpack it so the
    # Q-table is indexed with a plain integer row number.
    (state_id,) = game.state_to_number()
    logger.debug("Episode {} state={} state_id={}", i, state, state_id)
    while True:
        # Epsilon-greedy policy: exploit with probability 1 - epsilon,
        # otherwise pick a uniformly random action.
        if np.random.random() < 1 - epsilon:
            action = np.argmax(q_table[state_id])
        else:
            action = np.random.randint(0, len(game.actions))
        state2, reward, done, info = game.step(action)
        logger.debug([state2, reward, done])
        if done:
            # Terminal transition: there is no next state to bootstrap from,
            # so the entry is set to the final reward directly.
            q_table[state_id, action] = reward
            logger.debug("done {}", q_table)
            break
        (state2_id,) = game.state_to_number()
        logger.debug('state {}', state2_id)
        # Temporal-difference update towards reward + discounted best next value.
        delta = LR * (reward + GAMMA * np.max(q_table[state2_id]) - q_table[state_id, action])
        q_table[state_id, action] += delta
        state_id = state2_id
logger.info(q_table)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment