Skip to content

Instantly share code, notes, and snippets.

@denny0323
Last active January 25, 2018 16:47
Show Gist options
  • Save denny0323/cbacc8a5417ad122b62f50e1a834274a to your computer and use it in GitHub Desktop.
Save denny0323/cbacc8a5417ad122b62f50e1a834274a to your computer and use it in GitHub Desktop.
Chapter 4. Dynamic Programming_Value Iteration
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
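"# Value Iteration\n",
"\n",
"The function below is still named `policy_evaluation`, but each sweep backs a state up with the maximum over actions, so it performs value iteration rather than evaluating a fixed policy."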
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def get_state(state, action, grid_height=4, grid_width=4):\n",
"    \"\"\"Return the cell reached by taking `action` from `state`.\n",
"\n",
"    A move that would leave the grid is clamped, so the agent stays on the\n",
"    border cell. `state` itself is never modified.\n",
"\n",
"    Args:\n",
"        state: Two-item sequence `[row, col]` of the current cell.\n",
"        action: Index into the move table (0 up, 1 down, 2 left, 3 right).\n",
"        grid_height: Number of rows. Defaults to 4, which reproduces the\n",
"            upper bound of 3 that used to be hard-coded.\n",
"        grid_width: Number of columns. Defaults to 4 for the same reason.\n",
"\n",
"    Returns:\n",
"        Tuple `(row, col)` of the next cell.\n",
"    \"\"\"\n",
"    # (row delta, column delta) for each action index.\n",
"    action_grid = [(-1, 0), (1, 0), (0, -1), (0, 1)]\n",
"    d_row, d_col = action_grid[action]\n",
"\n",
"    # Build new values instead of writing into `state`, so a caller that\n",
"    # reuses its list does not see it change.\n",
"    row = min(max(state[0] + d_row, 0), grid_height - 1)\n",
"    col = min(max(state[1] + d_col, 0), grid_width - 1)\n",
"\n",
"    return row, col"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def policy_evaluation(grid_width, grid_height, action, policy, iter_num, reward=-1, dis=1):\n",
"    \"\"\"Run `iter_num` synchronous value-iteration sweeps over the grid.\n",
"\n",
"    Despite the name, every non-terminal cell is backed up with the best\n",
"    action (`max` over `action`), so `policy` is never read; it is kept only\n",
"    so existing calls keep working.\n",
"\n",
"    Args:\n",
"        grid_width: Number of columns of the value table.\n",
"        grid_height: Number of rows of the value table.\n",
"        action: Iterable of action indices accepted by `get_state`.\n",
"        policy: Unused.\n",
"        iter_num: Number of sweeps. Zero prints and returns the all-zero\n",
"            table; a negative value returns the all-zero table silently.\n",
"        reward: Reward added for every move.\n",
"        dis: Discount factor applied to the successor's value.\n",
"\n",
"    Returns:\n",
"        `numpy.ndarray` of shape `(grid_height, grid_width)` with the values\n",
"        after the last sweep.\n",
"    \"\"\"\n",
"    # table initialize\n",
"    value_table = np.zeros([grid_height, grid_width], dtype=float)\n",
"\n",
"    if iter_num == 0:\n",
"        print('Iteration: {} \\n{}\\n'.format(iter_num, value_table))\n",
"        return value_table\n",
"\n",
"    # The top-left and bottom-right corners are terminal and keep value 0.\n",
"    terminal_states = {(0, 0), (grid_height - 1, grid_width - 1)}\n",
"\n",
"    # Sweeps are numbered from 1 so the printed counter needs no manual bump.\n",
"    for iteration in range(1, iter_num + 1):\n",
"        next_value_table = np.zeros([grid_height, grid_width], dtype=float)\n",
"        for i in range(grid_height):\n",
"            for j in range(grid_width):\n",
"                if (i, j) in terminal_states:\n",
"                    continue\n",
"                # Best one-step return; get_state applies the move and keeps\n",
"                # the successor on the board.\n",
"                successors = (get_state([i, j], act) for act in action)\n",
"                next_value_table[i][j] = max(\n",
"                    reward + dis * value_table[i_][j_] for i_, j_ in successors\n",
"                )\n",
"\n",
"        # print result: every 10th sweep (every 20th past 100), plus the last\n",
"        # sweep. The old test `(iteration % 10) != iter_num` could not spot\n",
"        # the last sweep once iter_num reached 10.\n",
"        print_every = 20 if iteration > 100 else 10\n",
"        if iteration == iter_num or iteration % print_every == 0:\n",
"            print('Iteration: {} \\n{}\\n'.format(iteration, next_value_table))\n",
"\n",
"        value_table = next_value_table\n",
"\n",
"    # Returning the tracked table keeps a negative iter_num from hitting an\n",
"    # unbound local.\n",
"    return value_table"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"grid_width = 4\n",
"grid_height = grid_width\n",
"action = [0, 1, 2, 3] # up, down, left, right\n",
"\n",
"# Uniform random policy: every action is equally likely in every cell.\n",
"policy = np.full([grid_height, grid_width, len(action)], 1.0 / len(action), dtype=float)\n",
"\n",
"# The two terminal corners get probability 0 for every action. Scalar\n",
"# assignment fills the whole action axis, whatever its length.\n",
"policy[0, 0] = 0.0\n",
"policy[grid_height - 1, grid_width - 1] = 0.0"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Iteration: 1 \n",
"[[ 0. -1. -1. -1.]\n",
" [-1. -1. -1. -1.]\n",
" [-1. -1. -1. -1.]\n",
" [-1. -1. -1. 0.]]\n",
"\n",
"Iteration: 2 \n",
"[[ 0. -1. -2. -2.]\n",
" [-1. -2. -2. -2.]\n",
" [-2. -2. -2. -1.]\n",
" [-2. -2. -1. 0.]]\n",
"\n",
"Iteration: 3 \n",
"[[ 0. -1. -2. -3.]\n",
" [-1. -2. -3. -2.]\n",
" [-2. -3. -2. -1.]\n",
" [-3. -2. -1. 0.]]\n",
"\n",
"Iteration: 10 \n",
"[[ 0. -1. -2. -3.]\n",
" [-1. -2. -3. -2.]\n",
" [-2. -3. -2. -1.]\n",
" [-3. -2. -1. 0.]]\n",
"\n"
]
}
],
"source": [
"# Show the value table after 1, 2, 3 and 10 sweeps.\n",
"for n_sweeps in (1, 2, 3, 10):\n",
"    value = policy_evaluation(grid_width, grid_height, action, policy, n_sweeps)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Iteration: 10 \n",
"[[ 0. -1. -2. -3.]\n",
" [-1. -2. -3. -2.]\n",
" [-2. -3. -2. -1.]\n",
" [-3. -2. -1. 0.]]\n",
"\n",
"Iteration: 20 \n",
"[[ 0. -1. -2. -3.]\n",
" [-1. -2. -3. -2.]\n",
" [-2. -3. -2. -1.]\n",
" [-3. -2. -1. 0.]]\n",
"\n",
"Iteration: 30 \n",
"[[ 0. -1. -2. -3.]\n",
" [-1. -2. -3. -2.]\n",
" [-2. -3. -2. -1.]\n",
" [-3. -2. -1. 0.]]\n",
"\n",
"Iteration: 40 \n",
"[[ 0. -1. -2. -3.]\n",
" [-1. -2. -3. -2.]\n",
" [-2. -3. -2. -1.]\n",
" [-3. -2. -1. 0.]]\n",
"\n",
"Iteration: 50 \n",
"[[ 0. -1. -2. -3.]\n",
" [-1. -2. -3. -2.]\n",
" [-2. -3. -2. -1.]\n",
" [-3. -2. -1. 0.]]\n",
"\n",
"Iteration: 60 \n",
"[[ 0. -1. -2. -3.]\n",
" [-1. -2. -3. -2.]\n",
" [-2. -3. -2. -1.]\n",
" [-3. -2. -1. 0.]]\n",
"\n",
"Iteration: 70 \n",
"[[ 0. -1. -2. -3.]\n",
" [-1. -2. -3. -2.]\n",
" [-2. -3. -2. -1.]\n",
" [-3. -2. -1. 0.]]\n",
"\n",
"Iteration: 80 \n",
"[[ 0. -1. -2. -3.]\n",
" [-1. -2. -3. -2.]\n",
" [-2. -3. -2. -1.]\n",
" [-3. -2. -1. 0.]]\n",
"\n",
"Iteration: 90 \n",
"[[ 0. -1. -2. -3.]\n",
" [-1. -2. -3. -2.]\n",
" [-2. -3. -2. -1.]\n",
" [-3. -2. -1. 0.]]\n",
"\n",
"Iteration: 100 \n",
"[[ 0. -1. -2. -3.]\n",
" [-1. -2. -3. -2.]\n",
" [-2. -3. -2. -1.]\n",
" [-3. -2. -1. 0.]]\n",
"\n"
]
}
],
"source": [
"# Run 100 sweeps; the table is printed on every 10th sweep.\n",
"value = policy_evaluation(\n",
"    grid_width=grid_width,\n",
"    grid_height=grid_height,\n",
"    action=action,\n",
"    policy=policy,\n",
"    iter_num=100,\n",
")"
]
}
],
"metadata": {
"anaconda-cloud": {},
"kernelspec": {
"display_name": "Python [default]",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment