Skip to content

Instantly share code, notes, and snippets.

@PabRod
Created July 24, 2019 14:02
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save PabRod/f093ebf33fab36dabcc22baa2d5efa9e to your computer and use it in GitHub Desktop.
Save PabRod/f093ebf33fab36dabcc22baa2d5efa9e to your computer and use it in GitHub Desktop.
Simple example of Q-learning
# This gist reproduces the algorithm available at:
# http://mnemstudio.org/path-finding-q-learning-tutorial.htm
import numpy as np
import random
## Reward matrix: R[s, a] is the reward for moving from state s to state a.
## An entry of -1 is treated as a forbidden transition by accessible_states.
R = np.array(
    [
        [-1, -1, -1, -1, 0, -1],
        [-1, -1, -1, 0, -1, 100],
        [-1, -1, -1, 0, -1, -1],
        [-1, 0, 0, -1, 0, -1],
        [0, -1, -1, 0, -1, 100],
        [-1, 0, -1, -1, 0, 100],
    ]
)

## Q-table: one row per state and one column per action, zero before training
Nstates, Nactions = 6, 6
Q = np.zeros(shape=(Nstates, Nactions))
def random_indices(mat):
    """Return a uniformly random position inside a 2-D matrix.

    Args:
        mat: Non-empty 2-D array or nested sequence. The number of columns
            is taken from the first row.

    Returns:
        A two-element list ``[row, col]`` with ``0 <= row < len(mat)`` and
        ``0 <= col < len(mat[0])``.

    Raises:
        IndexError: If ``mat`` has no rows.
        ValueError: If the first row of ``mat`` has no columns.
    """
    nrows = len(mat)
    if nrows == 0:
        raise IndexError("cannot pick a random position: matrix has no rows")
    ncols = len(mat[0])
    if ncols == 0:
        raise ValueError("cannot pick a random position: matrix has no columns")
    # randrange excludes its upper bound, so no "- 1" correction is needed.
    return [random.randrange(nrows), random.randrange(ncols)]
def accessible_states(current_state, R):
    """Return the indices of the states reachable from ``current_state``.

    A transition counts as accessible when its reward differs from -1.

    Args:
        current_state: Row index into ``R``.
        R: 2-D reward matrix (NumPy array or nested sequence).

    Returns:
        The tuple produced by ``np.where`` on the selected row: a
        one-element tuple whose item is the array of accessible column
        indices (possibly empty).
    """
    # asarray lets plain nested lists be indexed with [row, :] as well.
    rewards = np.asarray(R)[current_state, :]
    return np.where(rewards != -1)
# def optimal_action(current_state, R):
# """ Returns the optimal action
# """
# return (np.argmax(R[current_state, :])) #TODO: return multiple coincidences
def updateQ(Q_current, R, goal_state, g=0.8, init='random'):
    """Simulate a single episode, updating the Q-table along the way.

    Starting from ``init``, the agent repeatedly moves to a uniformly random
    accessible state (one whose reward in ``R`` is not -1) until it reaches
    ``goal_state``. After each move from ``s`` to ``s_next`` the entry
    ``Q[s, s_next]`` is set to ``R[s, s_next] + g * max(Q[s_next, :])``.

    Args:
        Q_current: 2-D Q-table. It is updated in place.
        R: 2-D reward matrix; -1 marks a forbidden transition.
        goal_state: State index at which the episode stops.
        g: Discount factor applied to the best future Q-value.
        init: Starting state index, or ``'random'`` to draw a random row
            index of ``Q_current``.

    Returns:
        The same array object that was passed as ``Q_current``, after the
        updates of this episode. If the episode starts at ``goal_state``
        nothing is updated.

    Raises:
        ValueError: If a visited state has no accessible transition, so the
            episode cannot continue.
    """
    ## Choose a starting point
    # The isinstance check keeps integer / NumPy-integer starting states from
    # being compared against a string.
    if isinstance(init, str) and init == 'random':  # Randomly
        state_current = random.randrange(len(Q_current))
    else:  # Or provided as an input
        state_current = init

    # NOTE: this is an alias, not a copy -- the caller's table is mutated.
    Q_updated = Q_current
    rewards = np.asarray(R)

    while state_current != goal_state:
        # From the available transitions...
        possibilities = np.where(rewards[state_current, :] != -1)[0]
        if len(possibilities) == 0:
            raise ValueError(
                "state {} has no accessible transitions".format(state_current)
            )
        # ... choose one randomly
        state_next = possibilities[random.randrange(len(possibilities))]

        ## Update
        Q_updated[state_current, state_next] = (
            rewards[state_current, state_next]
            + g * np.max(Q_updated[state_next, :])
        )
        state_current = state_next

    return Q_updated
## Train!
# Each episode refines the same Q-table, starting from a random state and
# stopping when state 5 (the goal) is reached.
episodes = 500
for _ in range(episodes):
    Q = updateQ(Q, R, goal_state=5)

## Show the learned table scaled so that its largest entry equals 1.
# An all-zero table (nothing learned) is printed as is, to avoid dividing by 0.
q_max = np.max(Q)
print(Q / q_max if q_max else Q)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment