Created
November 13, 2011 01:31
-
-
Save jotbe/1361445 to your computer and use it in GitHub Desktop.
AI-Class: 9.15 Value Iteration 3
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/bin/env python | |
""" | |
9.15 Value Iteration 3 | |
============================== | |
We got the following world:: | |
+--+-----+-----+-----+-----+ | |
| | 1 | 2 | 3 | 4 | | |
+--+-----+-----+-----+-----+ | |
|A | | | | +100| | |
+--+-----+-----+-----+-----+ | |
|B | |#####| | -100| | |
+--+-----+-----+-----+-----+ | |
|C | | | | | | |
+--+-----+-----+-----+-----+ | |
The goal is in A4 with a reward of +100. B4 is a trap. B2 is blocked. | |
When heading into the target direction the target field is successfully | |
reached with a probability of 0.8. Adjacent fields (90 deg angle) with 0.1. | |
We should calculate the optimal policy. | |
Result:: | |
+--+---------------+----------------+----------------+---------------+ | |
|0 | 85.0 | 89.0 | 93.0 | 100| | |
+--+---------------+----------------+----------------+---------------+ | |
|1 | 81.0 | None | 68.0 | -100| | |
+--+---------------+----------------+----------------+---------------+ | |
|2 | 77.0 | 73.0 | 70.0 | 47.0| | |
+--+---------------+----------------+----------------+---------------+ | |
""" | |
__author__ = 'Jan Beilicke <dev +at+ jotbe-fx +dot+ de>' | |
__date__ = '2011-11-11' | |
import sys | |
class ValueIterator: | |
def __init__(self, world, read_only, prob_target=1.0, prob_left=0.0, prob_right=0.0, \ | |
prob_back=0.0, state_reward=0.0, discount_factor=1.0): | |
self.world = world | |
self.prob_target = prob_target | |
self.prob_left = prob_left | |
self.prob_right = prob_right | |
self.prob_back = prob_back | |
self.state_reward = state_reward | |
self.read_only_states = read_only | |
self.discount_factor = discount_factor | |
def iterate(self, iterations): | |
for i in range(iterations): | |
print('Iteration %d' % int(i + 1)) | |
for row in range(len(self.world)): | |
for col in range(len(self.world[row])): | |
if (row, col) in self.read_only_states: | |
continue | |
print('Estimating S(%d,%d)' % (row, col)) | |
self.world[row][col] = self.bellman_update((row, col)) | |
self.print_world(decimal_places=4) | |
def bellman_update(self, state_coords): | |
r, c = state_coords | |
e_coords = (r, c + 1) | |
s_coords = (r + 1, c) | |
w_coords = (r, c - 1) | |
n_coords = (r - 1, c) | |
e_util = self.get_utility_of_state(self.world[r][c], e_coords) | |
s_util = self.get_utility_of_state(self.world[r][c], s_coords) | |
w_util = self.get_utility_of_state(self.world[r][c], w_coords) | |
n_util = self.get_utility_of_state(self.world[r][c], n_coords) | |
print('own_value%s = %d' % ((r, c), self.world[r][c])) | |
print('e_util%s = %d' % (e_coords, e_util)) | |
print('s_util%s = %d' % (s_coords, s_util)) | |
print('w_util%s = %d' % (w_coords, w_util)) | |
print('n_util%s = %d' % (n_coords, n_util)) | |
e_value = self.value_function(e_util, n_util, s_util, w_util) | |
s_value = self.value_function(s_util, e_util, w_util, n_util) | |
w_value = self.value_function(w_util, s_util, n_util, e_util) | |
n_value = self.value_function(n_util, w_util, e_util, s_util) | |
return (self.state_reward) + self.discount_factor * max(e_value, s_value, w_value, n_value) | |
def get_utility_of_state(self, own_value, target_coords): | |
row, col = target_coords | |
if row < 0 or col < 0: | |
print('Bumping against a wall in S(%d,%d)' % (row, col)) | |
return own_value | |
try: | |
value = self.world[row][col] or own_value | |
except IndexError: | |
print('Bumping against a wall in S(%d,%d)' % (row, col)) | |
value = own_value | |
return float(value) | |
def value_function(self, target, left=0, right=0, back=0): | |
return self.prob_target * target + \ | |
self.prob_left * left + \ | |
self.prob_right * right +\ | |
self.prob_back * back | |
def print_world(self, decimal_places=0): | |
for row in range(len(self.world)): | |
print('-' * 70) | |
sys.stdout.write(str(row)) | |
for col in range(len(self.world[row])): | |
val = self.world[row][col] | |
if type(val) == float: | |
val = round(val, decimal_places) | |
sys.stdout.write(' | %14s' % val) | |
print('|') | |
print('-' * 70) | |
if __name__ == '__main__': | |
world = [ | |
[0, 0, 0, 100], | |
[0, None, 0, -100], | |
[0, 0, 0, 0] | |
] | |
read_only_states = [(0, 3), (1, 1), (1, 3)] | |
# For stochastic actions | |
#Vi_stochastic = ValueIterator(world, read_only_states, prob_target=0.8, prob_left=0.1, prob_right=0.1, state_reward=-3) | |
#Vi_stochastic.iterate(50) | |
# For stochastic actions with no costs | |
#Vi_stochastic = ValueIterator(world, read_only_states, prob_target=0.8, prob_left=0.1, prob_right=0.1, state_reward=0, discount_factor=0.1) | |
#Vi_stochastic.iterate(100) | |
# For deterministic actions | |
Vi_deterministic = ValueIterator(world, read_only_states, prob_target=1, state_reward=-3) | |
Vi_deterministic.iterate(50) | |
# For stochastic actions with high costs (even higher then the ditch in (0,4)) | |
# "This is an extreme case. I don't know why it would make sense to set a penalty for life that is | |
# so negative that even negative death is worse than living." | |
#Vi_stochastic = ValueIterator(world, read_only_states, prob_target=0.8, prob_left=0.1, prob_right=0.1, state_reward=-200) | |
#Vi_stochastic.iterate(50) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment