#! /usr/bin/env python3
# E. Culurciello, example of reinforcement learning in Python
#
# http://mnemstudio.org/path-finding-q-learning-tutorial.htm
# Game: 6 rooms (numbered 0-5) connected through doors. Room 5 is the
# goal-room; the aim of the game is to reach it from any starting room.
import numpy as np

# this is how the rooms are connected:
# reward table: 0 = connected rooms, -1 = no connection, 100 = goal-room
# rows are states (rooms) and columns are possible actions (move to next room)
R = np.array([[-1, -1, -1, -1,  0,  -1],
              [-1, -1, -1,  0, -1, 100],
              [-1, -1, -1,  0, -1,  -1],
              [-1,  0,  0, -1,  0,  -1],
              [ 0, -1, -1,  0, -1, 100],
              [-1,  0, -1, -1,  0, 100]])
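
# reading the table: R[0, 4] = 0 means rooms 0 and 4 are connected (the
# move is allowed but carries no reward), R[0, 1] = -1 means there is no
# door between rooms 0 and 1, and R[4, 5] = 100 rewards stepping into
# the goal-room.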
goalState = 5  # goal-room number
gamma = 0.8    # discount factor for the Q update
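
# the update rule used below is the one-step Q-learning rule from the
# tutorial (learning rate 1, so each visit overwrites the entry):
#   Q(s, a) = R(s, a) + gamma * max_a' Q(s', a')
# where s' is the state reached by taking action a in state s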
print('This is the reward table:\n', R)
# example lookup: print(R[1, 2])

# initialize the Q table:
Q = np.zeros((6, 6))
print('This is the Q table:\n', Q)
# Q-learning algorithm in unsupervised mode:
for i in range(0, 100):  # 100 training episodes
    # select an initial state at random:
    currState = np.random.randint(0, 6)
    # print('Current state: ', currState)
    print('Learning: action sequence: ', end=" ")
    print(currState, end=" ")
    while True:
        # make an allowed move at random:
        while True:
            nextState = np.random.randint(0, 6)
            # print('Next state: ', nextState)
            if R[currState, nextState] != -1:
                break
        # update the Q table with rewards:
        maxQ = Q[nextState].max()
        # print('Max Q of next state:', maxQ)
        Q[currState, nextState] = R[currState, nextState] + gamma * maxQ
        # next move:
        currState = nextState
        print(currState, end=" ")
        if nextState == goalState:
            print('Goal reached')
            break

# normalize the Q matrix to a 0-100 scale:
Q = Q / Q.max() * 100
print('Q table after learning:\n', Q.astype(int))  # as integers
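
# --- illustrative addition, not part of the original tutorial script ---
# after training, the greedy policy simply follows the largest Q entry in
# each row; starting from any room this walks to the goal-room.
startState = 2  # hypothetical starting room chosen for this demo
state = startState
path = [state]
while state != goalState and len(path) < 10:  # cap guards an untrained table
    state = int(np.argmax(Q[state]))
    path.append(state)
print('Greedy path from room %d:' % startState, path)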