
@pckujawa
Created October 26, 2012 07:37
Active and passive temporal difference learning in grid world for AI
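
The file below implements TD(0) policy evaluation (the passive case) on the 3x4 lecture grid, plus an active variant that periodically re-derives a greedy policy from the learned utilities and an empirically estimated action model. The core backup, as coded in TemporalDifferenceLearningAlgo.update, is U(s) <- U(s) + alpha(N(s)) * (r + gamma*U(s') - U(s)), where r is a constant 0 and the +1/-1 rewards enter through the terminal cells' initial utilities. A minimal sketch of that single backup (the function name here is illustrative, not part of the gist):

def td_backup(u_s, u_next, reward, gamma, alpha):
    # Move U(s) toward the one-step target: reward + gamma * U(s')
    return u_s + alpha * (reward + gamma * u_next - u_s)

With alpha(N) = 1.0 / (N + 1), the step size shrinks each time a state is revisited.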
#!/usr/bin/env python
#-------------------------------------------------------------------------------
# Author: Pat Kujawa
#-------------------------------------------------------------------------------
import numpy as np
import random
from collections import defaultdict
from pprint import pprint, pformat


class Log(object):
    @classmethod
    def log(cls, msg):
        print msg

    @classmethod
    def i(cls, msg):
        return  # INFO logging is disabled; drop this early return to re-enable it
        cls.log('INFO: {}'.format(msg))
class obj(object):
    def __init__(self, **kwargs):
        '''Create obj with attributes/values as in kwargs dict'''
        self.__dict__.update(kwargs)

    def __str__(self):
        return pformat(self.__dict__)

    def __repr__(self):
        return self.__str__()
class Cell(obj):
    def __init__(self, **kwargs):
        self.is_terminal = False
        self.policy = ''
        self.value = 0
        super(Cell, self).__init__(**kwargs)  # override defaults if set

    def __str__(self):
        return pformat(self.__dict__)

    def __repr__(self):
        return self.__str__()

    def __eq__(self, other):
        if isinstance(other, Cell):
            return self.__dict__ == other.__dict__
        else:
            return self.__dict__ == other
world_bounds_rows = 3
world_bounds_cols = 4


class GridWorld(object):
    def __init__(self, array):
        cells = []
        for row_idx, row in enumerate(array):
            cells.append([])
            for col_idx, value in enumerate(row):
                cell = Cell()
                cell.actual_move_to_count_map = defaultdict(lambda: 0)
                cell.row = row_idx
                cell.col = col_idx
                cell.value = value
                # TODO get reward from input
                # If the value is nonzero, NaN, or +-Inf, set as reward too
                cell.reward = value or 0
                cell.blocks = np.isnan(value)
                cells[row_idx].append(cell)
            assert len(cells[row_idx]) == world_bounds_cols
        assert len(cells) == world_bounds_rows
        # Terminal states are hard-coded for the 3x4 lecture grid (the +1 and -1 cells)
        cells[0][3].is_terminal = True
        cells[1][3].is_terminal = True
        self.cells = cells
        self.policy_to_move_map = {
            'n': self.north, 's': self.south, 'e': self.east, 'w': self.west
        }

    def cells_str(self):
        rows = []
        for row in self.cells:
            rows.append(''.join('{:4},{:5} '.format(c.value, c.policy) for c in row))
        return '\n'.join(rows)

    def policy(self, policy=None):
        '''Get or set the policy of all cells. The policy arg must have the same
        number of rows and columns as the grid.'''
        if not policy:
            return [[c.policy for c in rows] for rows in self.cells]
        assert len(policy) == len(self.cells)
        for row, row_policy in zip(self.cells, policy):
            assert len(row) == len(row_policy)
            for cell, cell_policy in zip(row, row_policy):
                cell.policy = cell_policy

    def __iter__(self):
        for row in self.cells:
            for cell in row:
                yield cell

    def as_array(self):
        a = [[cell.value for cell in row] for row in self.cells]
        return np.array(a)

    def start_cell(self, **kwargs):
        if not kwargs:
            return self._start_cell
        row, col = kwargs['row'], kwargs['col']
        self._start_cell = self.cells[row][col]

    def next_following_policy(self, cell):
        # Map policy to NSEW
        p = cell.policy
        if not p:
            raise ValueError('Cell does not have a policy, so we must be done or '
                             'you made a mistake, e.g. your policy moves into a '
                             'wall. Cell was: ' + str(cell))
        first_char = p[0].lower()
        move = self.policy_to_move_map[first_char]
        next_cell = move(cell)
        return next_cell

    def north(self, cell):
        row = cell.row; col = cell.col
        if row > 0:
            row -= 1
        return self._next_cell_if_not_blocked(cell, row, col, 'n')

    def south(self, cell):
        row = cell.row; col = cell.col
        if row < world_bounds_rows-1:
            row += 1
        return self._next_cell_if_not_blocked(cell, row, col, 's')

    def west(self, cell):
        row = cell.row; col = cell.col
        if col > 0:
            col -= 1
        return self._next_cell_if_not_blocked(cell, row, col, 'w')

    def east(self, cell):
        row = cell.row; col = cell.col
        if col < world_bounds_cols-1:
            col += 1
        return self._next_cell_if_not_blocked(cell, row, col, 'e')

    def _next_cell_if_not_blocked(self, cell, row, col, direction):
        cell.actual_move = direction  # Staying in place is implied if that's the case
        n = self.cells[row][col]
        if n.blocks:
            return cell
        return n

    def __str__(self):
        return pformat(self.as_array())

    def __repr__(self):
        return self.__str__()
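
# Minimal usage sketch of GridWorld (illustrative; this helper is not part of the
# original gist and is never called). It mirrors the setup used further below in
# temporal_difference_learning().
def _gridworld_usage_example():
    world = GridWorld(np.array([[0, 0, 0, 1], [0, np.nan, 0, -1], [0, 0, 0, 0]]))
    world.start_cell(row=2, col=0)  # agent starts in the bottom-left cell
    world.policy([
        ['e', 'e', 'e', ''],
        ['s', '',  'n', ''],
        ['e', 'e', 'n', 's']
    ])
    step = world.next_following_policy(world.start_cell())  # follow 'e' from (2, 0)
    return step.row, step.col  # -> (2, 1)
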
class TemporalDifferenceLearningAlgo(obj):
    def __init__(self, world, **kwargs):
        ## self.learning_factor = lambda Ns: 1  # default
        super(TemporalDifferenceLearningAlgo, self).__init__(
            world=world, **kwargs)
        self.reward = 0
        self._init_itercnts()
        self._init_utilmap()
        north = self.world.north
        south = self.world.south
        east = self.world.east
        west = self.world.west
        self.policy_changed = False
        # Mappings to pick a move given a uniform random value in [0, 1]:
        # 80% the intended direction, 10% each of the two perpendicular directions
        self.stochastic_map = {
            'n': lambda r: north if r <= 0.8 else east if r <= 0.9 else west,
            's': lambda r: south if r <= 0.8 else east if r <= 0.9 else west,
            'e': lambda r: east if r <= 0.8 else north if r <= 0.9 else south,
            'w': lambda r: west if r <= 0.8 else north if r <= 0.9 else south
        }
    def _init_itercnts(self):
        self.iteration_map = []
        for row in self.world.cells:
            self.iteration_map.append([0 for c in row])

    def _init_utilmap(self, and_iters=False):
        self.utility_map = []
        for row in self.world.cells:
            umap = []
            for idx, c in enumerate(row):
                umap.append(c.value)  # could be NaN
                # TODO maybe use 'reward' to initialize
            self.utility_map.append(umap)

    def next_following_stochastic_policy(self, cell):
        policy = cell.policy
        if not policy:
            raise ValueError('The cell you gave me did not have a policy. Maybe '
                             "you didn't check for a terminal. The cell was: {}".format(cell))
        direction = policy[0].lower()
        move_func = self.stochastic_map[direction](random.uniform(0, 1))
        return move_func(cell)
    def update(self, stochastic=False):
        ## Log.i('update()'.center(50, '-'))
        # Run one episode: start at start_cell and follow the policy to a terminal,
        # applying the TD(0) backup U(s) += alpha(Ns) * (reward + gamma*U(s') - U(s))
        cell = self.world.start_cell()
        next_func = self.world.next_following_policy
        if stochastic:
            next_func = self.next_following_stochastic_policy
        icnt = 0
        while cell and not cell.is_terminal:
            icnt += 1
            ## Log.i(' update_iteration: {}'.format(icnt))
            ## Log.i('cell: {}'.format(cell))
            # Update cell's visit count
            self.iteration_map[cell.row][cell.col] += 1
            iter_cnt = self.iteration_map[cell.row][cell.col]
            ## Log.i('Ns[cell]: {}'.format(iter_cnt-1))
            # Follow cell's policy to get next cell's utility
            next_cell = next_func(cell)
            rel_dir = self._movement_relative(cell)
            cell.actual_move_to_count_map[rel_dir] += 1
            ## Log.i('next_cell: {}'.format(next_cell))
            # Update utilities
            utility = self.utility(cell)
            ## Log.i('U[cell]={}'.format(utility))
            next_utility = self.utility(next_cell)
            ## Log.i('U[next_cell]={}'.format(next_utility))
            rhs = self.reward + self.discount_factor*next_utility - utility
            utility += self.learning_factor(iter_cnt) * rhs
            ## Log.i('new U[cell]={}'.format(utility))
            self.utility(cell, utility)
            cell = next_cell
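
    # Worked example of the backup above (illustrative; the numbers come from the
    # deterministic unit test at the bottom of this file, with gamma=1, reward=0,
    # alpha(N) = 1/(N+1), and the +1 terminal just east of cell (0,2)):
    #   first visit:   U(0,2) <- 0   + (1/2)*(0 + 1*1 - 0)   = 1/2
    #   second visit:  U(0,2) <- 1/2 + (1/3)*(0 + 1*1 - 1/2) = 2/3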
    # 'n' appears at both ends so left/right neighbors can be found by index arithmetic
    relative_directions = ['n', 'e', 's', 'w', 'n']

    def _movement_relative(self, cell):
        policy = cell.policy
        actual = cell.actual_move
        dirs = self.relative_directions
        diff = dirs.index(policy) - dirs.index(actual)
        if diff == 0:
            return 'straight'
        if diff == -1 or diff == 3:
            return 'right'
        if diff == 1 or diff == -3:
            return 'left'
        raise ValueError("between {} and {} a diff of {} isn't valid".format(
            policy, actual, diff))

    def _left(self, cell):
        dirs = self.relative_directions
        return dirs[dirs.index(cell.policy, 1) - 1]

    def _right(self, cell):
        dirs = self.relative_directions
        return dirs[dirs.index(cell.policy) + 1]

    def _straight(self, cell):
        return cell.policy

    def _get_estimated_prob(self, cell, direction):
        iter_cnt = self.iteration_map[cell.row][cell.col]
        if iter_cnt == 0:
            return 0
        cell.actual_move = direction
        rel_dir = self._movement_relative(cell)
        return float(cell.actual_move_to_count_map[rel_dir]) / iter_cnt
    def utility(self, cell, value=None):
        if value is None:  # 'is None' so that a 0.0 utility can still be stored
            return self.utility_map[cell.row][cell.col]
        self.utility_map[cell.row][cell.col] = value
    def update_policy_greedily(self, reset_utils=False, reset_counts=False):
        # Copy the learned utilities into cell.value so the lookahead below can read them
        for cell in self.world:
            cell.value = self.utility(cell)
        for cell in self.world:
            if cell.is_terminal or cell.blocks:
                continue
            # Estimated probabilities of actually going left/right/straight,
            # relative to the cell's current policy
            left = self._left(cell)
            right = self._right(cell)
            straight = cell.policy
            pleft = self._get_estimated_prob(cell, left)
            pright = self._get_estimated_prob(cell, right)
            pstraight = self._get_estimated_prob(cell, straight)
            policy = self._get_updated_policy(cell, pleft, pright, pstraight)
            ## self.policy_changed = True  # HACK to print
            if policy != cell.policy:
                self.policy_changed = True
                ## print 'changed policy (now {}) for cell: {}'.format(policy, cell)
            cell.policy = policy
        if reset_utils:
            self._init_utilmap()
        if reset_counts:
            self._init_itercnts()
    def _get_updated_policy(self, state, pleft, pright, pstraight):
        if state.is_terminal or state.blocks:
            return
        max_pv = 0  # probability * value, as used in V(s) calculation
        # Expected one-step values of heading N, S, E, W (in that order)
        moves_pvs = [
            pstraight*self.world.north(state).value + pleft*self.world.west(state).value + pright*self.world.east(state).value,
            pstraight*self.world.south(state).value + pright*self.world.west(state).value + pleft*self.world.east(state).value,
            pstraight*self.world.east(state).value + pleft*self.world.north(state).value + pright*self.world.south(state).value,
            pstraight*self.world.west(state).value + pright*self.world.north(state).value + pleft*self.world.south(state).value
        ]
        moves_directions = ['n', 's', 'e', 'w']
        max_idx = np.argmax(moves_pvs)
        max_pv = moves_pvs[max_idx]
        policy = moves_directions[max_idx]
        if max_pv < 1e-3:
            # If the best expected value is ~zero (e.g. no counts yet), leave the policy as-is
            return state.policy
        return policy
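
    # Note (illustrative, not in the original gist): the scores above estimate
    #   Q(s, a) ~ p_straight*U(ahead of a) + p_left*U(left of a) + p_right*U(right of a)
    # where p_straight/p_left/p_right are relative-move frequencies (taken from
    # actual_move_to_count_map, relative to the cell's current policy). For example,
    # with p_straight=0.8, p_left=p_right=0.1 and neighbor utilities U(north)=0.6,
    # U(west)=0.0, U(east)=0.2, heading north scores
    #   0.8*0.6 + 0.1*0.0 + 0.1*0.2 = 0.50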
def print_grid_like_array(arr, min_cell_width=0, mapper=None):
    if not mapper:
        mapper = lambda x: x
    try:
        for row in arr:
            if min_cell_width > 0:
                print ''.join('{:{}}'.format(mapper(c), min_cell_width) for c in row)
            else:
                print ''.join('{}'.format(mapper(c)) for c in row)
            ## print row
    except AttributeError:
        pass


def print_util(arr):
    for row in arr:
        print ''.join('{:8.3f}'.format(c) for c in row)


def print_actual_move_to_count_map(cells):
    print_grid_like_array(
        [[c.actual_move_to_count_map.items() for c in row] for row in cells])
import timeit


def temporal_difference_learning(active=False):
    global world, algo
    discount_factor = 0.9
    learning_factor = lambda Ns: 1.0 / (Ns + 1)  # alpha function
    world_input = np.array([[0, 0, 0, 1], [0, np.nan, 0, -1], [0, 0, 0, 0]])
    world = GridWorld(world_input)
    world.start_cell(row=2, col=0)
    world.policy([
        ['e', 'e', 'e', ''],
        ['s', '',  'n', ''],
        ['e', 'e', 'n', 's']
    ])
    algo = TemporalDifferenceLearningAlgo(
        world, discount_factor=discount_factor, learning_factor=learning_factor)

    def run():
        print 'starting with policy:'
        print_grid_like_array(world.policy(), 3)

        def printout():
            print 'iteration {}'.format(iter_cnt).center(72, '-')
            print 'utilities:'
            print_util(algo.utility_map)
            ## print 'actual move counts:'
            ## print_actual_move_to_count_map(world.cells)
            ## print 'greedy policy values:'
            ## print_grid_like_array(world.cells, mapper=lambda c: '{:8.3f}'.format(c.greedy_policy_value))

        iter_cnt = 0
        prev_util = None
        end_cnt = 100000
        update_policy_after_cnt = 100
        while True:
            if iter_cnt <= 3 or iter_cnt >= end_cnt:
                printout()
                if iter_cnt > end_cnt:
                    break
            # Convergence check is disabled ('if False'), so the run always goes to end_cnt
            if False and within_epsilon(prev_util, algo.utility_map, 0.0000001):
                printout()
                break
            prev_util = [a[:] for a in algo.utility_map]
            algo.update(True)  # True => follow the stochastic (80/10/10) action model
            if active and iter_cnt % update_policy_after_cnt == 0 and iter_cnt > 0:
                prev_policy = world.policy()
                algo.update_policy_greedily()
                if algo.policy_changed:
                    algo.policy_changed = False
                    printout()
                    print 'previous policy:'
                    print_grid_like_array(prev_policy, 3)
                    print 'updated policy:'
                    print_grid_like_array(world.policy(), 3)
            iter_cnt += 1
        print 'Ns[cells]:'
        print_grid_like_array(algo.iteration_map, 8)

    t = timeit.Timer(run)
    seconds = t.timeit(1)
    print 'Took {}s to run.'.format(seconds)
def active_greedy():
    temporal_difference_learning(active=True)


def within_epsilon(prev, current, epsilon=0.0001):
    if not prev:
        return False  # so there's no need to init prev for the first check
    for p_row, c_row in zip(prev, current):
        for p, c in zip(p_row, c_row):
            if abs(p - c) > epsilon:
                return False
    return True
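
# Illustrative usage of within_epsilon (not in the original gist): it compares two
# utility grids element-wise against a tolerance, e.g.
#   within_epsilon([[0.5, 0.25]], [[0.50005, 0.25]], epsilon=0.0001)  # -> True
#   within_epsilon([[0.5, 0.25]], [[0.51, 0.25]], epsilon=0.0001)     # -> False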
import unittest


class TemporalDifferenceLearningAlgoTests(unittest.TestCase):
    def test_Given_lecture_grid_When_passive_applied_deterministically_Then_matches_lecture_10_10_results(self):
        world = GridWorld(np.array([[0, 0, 0, 1], [0, np.nan, 0, -1], [0, 0, 0, 0]]))
        world.start_cell(row=2, col=0)
        world.policy([
            ['e', 'e', 'e', ''],
            ['n', '',  'n', ''],
            ['n', 'e', 'n', 'w']
        ])
        discount_factor = 1
        learning_factor = lambda Ns: 1.0 / (Ns + 1)  # alpha function
        target = TemporalDifferenceLearningAlgo(
            world, discount_factor=discount_factor, learning_factor=learning_factor)
        expected_utils = [a[:] for a in target.utility_map]
        expected_utils[0][2] = 1.0/2
        target.update()
        self.assertEqual(len(expected_utils), len(target.utility_map))
        for e, a in zip(expected_utils, target.utility_map):
            self.assertAlmostEqual(e, a)  # e and a are whole rows; they match exactly in this deterministic run
        expected_utils[0][2] = 2.0/3
        expected_utils[0][1] = 1.0/6
        target.update()
        for e, a in zip(expected_utils, target.utility_map):
            self.assertAlmostEqual(e, a)


print 'Passive temporal diff learning:'
temporal_difference_learning()
print
print 'Active greedy:'
active_greedy()

##unittest.main()
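
# Note (not in the original gist): the five statements above run whenever the module
# is imported or executed. To run the unit test instead, comment them out and
# uncomment unittest.main().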
@pckujawa (Author) commented:

Passive temporal diff learning:
starting with policy:
e  e  e
s     n
e  e  n  s
------------------------------iteration 0-------------------------------
utilities:
0.000 0.000 0.000 1.000
0.000 nan 0.000 -1.000
0.000 0.000 0.000 0.000
------------------------------iteration 1-------------------------------
utilities:
0.000 0.000 0.450 1.000
0.000 nan 0.000 -1.000
0.000 0.000 0.000 0.000
------------------------------iteration 2-------------------------------
utilities:
0.000 0.000 0.450 1.000
0.000 nan -0.300 -1.000
0.000 0.000 0.000 0.000
------------------------------iteration 3-------------------------------
utilities:
0.000 0.000 0.600 1.000
0.000 nan -0.124 -1.000
0.000 0.000 -0.054 0.000
----------------------------iteration 100000----------------------------
utilities:
0.000 0.000 0.848 1.000
0.257 nan 0.572 -1.000
0.324 0.397 0.460 0.158
----------------------------iteration 100001----------------------------
utilities:
0.000 0.000 0.848 1.000
0.257 nan 0.572 -1.000
0.324 0.397 0.460 0.158
Ns[cells]:
0 0 109364 0
15797 0 123096 0
124979 140660 125044 124240
Took 19.8316876911s to run.

Active greedy:
starting with policy:
e  e  e
s     n
e  e  n  s
------------------------------iteration 0-------------------------------
utilities:
0.000 0.000 0.000 1.000
0.000 nan 0.000 -1.000
0.000 0.000 0.000 0.000
------------------------------iteration 1-------------------------------
utilities:
0.000 0.000 0.450 1.000
0.000 nan 0.000 -1.000
0.000 0.000 0.000 0.000
------------------------------iteration 2-------------------------------
utilities:
0.000 0.000 0.600 1.000
0.000 nan 0.135 -1.000
0.000 0.000 0.000 0.000
------------------------------iteration 3-------------------------------
utilities:
0.000 0.000 0.675 1.000
0.000 nan 0.236 -1.000
0.000 0.000 0.030 0.000
-----------------------------iteration 100------------------------------
utilities:
0.000 0.000 0.840 1.000
0.018 nan 0.484 -1.000
0.073 0.152 0.286 0.044
previous policy:
e  e  e
s     n
e  e  n  s
updated policy:
e  e  e
s     n
e  e  n  w
----------------------------iteration 100000----------------------------
utilities:
0.000 0.000 0.847 1.000
0.269 nan 0.569 -1.000
0.331 0.402 0.467 0.268
----------------------------iteration 100001----------------------------
utilities:
0.000 0.000 0.847 1.000
0.269 nan 0.569 -1.000
0.331 0.402 0.467 0.268
Ns[cells]:
0 0 108109 0
15567 0 121752 0
125065 140608 123332 13747
Took 18.1787866773s to run.
