@denny0323
Last active November 7, 2022 15:01
Chapter 4. Dynamic Programming_Policy Iteration
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Policy Evaluation"
]
},
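{
"cell_type": "markdown",
"metadata": {},
"source": [
"The loop below performs iterative policy evaluation: one full sweep applies the Bellman expectation backup to every state. Since the gridworld's moves are deterministic (reward $r = -1$ per step, discount $\\gamma$ = `dis`), the update is\n",
"\n",
"$$v_{k+1}(s) \\;=\\; \\sum_{a} \\pi(a \\mid s)\\,\\bigl(r + \\gamma\\, v_k(s')\\bigr),$$\n",
"\n",
"where $s'$ is the successor state returned by `get_state`."
]
},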
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def get_state(state, action):\n",
" \n",
" action_grid = [(-1, 0), (1, 0), (0, -1), (0, 1)]\n",
" \n",
" state[0]+=action_grid[action][0]\n",
" state[1]+=action_grid[action][1]\n",
" \n",
" if state[0] < 0 :\n",
" state[0] = 0\n",
" elif state[0] > 3 :\n",
" state[0] = 3\n",
" \n",
" if state[1] < 0 :\n",
" state[1] = 0\n",
" elif state[1] > 3 :\n",
" state[1] = 3\n",
" \n",
" return state[0], state[1]"
]
},
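{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check (added for illustration, not part of the original notebook): moves that would leave the $4 \\times 4$ grid are clamped, so the agent stays in place at the edges."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# action 0 is 'up' (row - 1); the top row should be clamped in place\n",
"print(get_state([0, 2], 0))  # expected: (0, 2)\n",
"print(get_state([1, 2], 0))  # expected: (0, 2)\n",
"print(get_state([3, 3], 3))  # 'right' from the corner stays at (3, 3)"
]
},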
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def policy_evaluation(grid_width, grid_height, action, policy, iter_num, reward=-1, dis=1):\n",
" \n",
" # table initialize\n",
" post_value_table = np.zeros([grid_height, grid_width], dtype=float)\n",
" \n",
" # iteration\n",
" if iter_num == 0:\n",
" print('Iteration: {} \\n{}\\n'.format(iter_num, post_value_table))\n",
" return post_value_table\n",
" \n",
" for iteration in range(iter_num):\n",
" next_value_table = np.zeros([grid_height, grid_width], dtype=float)\n",
" for i in range(grid_height):\n",
" for j in range(grid_width):\n",
" if i == j and ((i == 0) or (i == 3)):\n",
" value_t = 0\n",
" else :\n",
" value_t = 0\n",
" for act in action:\n",
" i_, j_ = get_state([i,j], act)\n",
" value = policy[i][j][act] * (reward + dis*post_value_table[i_][j_])\n",
" value_t += value\n",
" next_value_table[i][j] = round(value_t, 3)\n",
" iteration += 1\n",
" \n",
" # print result\n",
" if (iteration % 10) != iter_num: \n",
" # print result \n",
" if iteration > 100 :\n",
" if (iteration % 20) == 0: \n",
" print('Iteration: {} \\n{}\\n'.format(iteration, next_value_table))\n",
" else :\n",
" if (iteration % 10) == 0:\n",
" print('Iteration: {} \\n{}\\n'.format(iteration, next_value_table))\n",
" else :\n",
" print('Iteration: {} \\n{}\\n'.format(iteration, next_value_table ))\n",
" \n",
" \n",
" post_value_table = next_value_table\n",
" \n",
" \n",
" return next_value_table"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"grid_width = 4\n",
"grid_height = grid_width\n",
"action = [0, 1, 2, 3] # up, down, left, right\n",
"policy = np.empty([grid_height, grid_width, len(action)], dtype=float)\n",
"for i in range(grid_height):\n",
" for j in range(grid_width):\n",
" for k in range(len(action)):\n",
" if i==j and ((i==0) or (i==3)):\n",
" policy[i][j]=0.00\n",
" else :\n",
" policy[i][j]=0.25\n",
"policy[0][0] = [0] * grid_width\n",
"policy[3][3] = [0] * grid_width"
]
},
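{
"cell_type": "markdown",
"metadata": {},
"source": [
"A small check (added for illustration): each non-terminal state's action probabilities should sum to 1, and the two terminal states to 0."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(policy.sum(axis=2))"
]
},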
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Iteration: 10 \n",
"[[ 0. -6.138 -8.352 -8.968]\n",
" [-6.138 -7.737 -8.428 -8.352]\n",
" [-8.352 -8.428 -7.737 -6.138]\n",
" [-8.968 -8.352 -6.138 0. ]]\n",
"\n",
"Iteration: 20 \n",
"[[ 0. -9.45 -13.257 -14.454]\n",
" [ -9.45 -12.06 -13.302 -13.257]\n",
" [-13.257 -13.302 -12.06 -9.45 ]\n",
" [-14.454 -13.257 -9.45 0. ]]\n",
"\n",
"Iteration: 30 \n",
"[[ 0. -11.366 -16.096 -17.632]\n",
" [-11.366 -14.562 -16.123 -16.097]\n",
" [-16.096 -16.123 -14.562 -11.366]\n",
" [-17.632 -16.097 -11.366 0. ]]\n",
"\n",
"Iteration: 40 \n",
"[[ 0. -12.475 -17.74 -19.471]\n",
" [-12.475 -16.01 -17.755 -17.74 ]\n",
" [-17.74 -17.755 -16.01 -12.475]\n",
" [-19.471 -17.74 -12.475 0. ]]\n",
"\n",
"Iteration: 50 \n",
"[[ 0. -13.117 -18.691 -20.536]\n",
" [-13.117 -16.847 -18.7 -18.691]\n",
" [-18.691 -18.7 -16.847 -13.117]\n",
" [-20.536 -18.691 -13.117 0. ]]\n",
"\n",
"Iteration: 60 \n",
"[[ 0. -13.489 -19.242 -21.152]\n",
" [-13.489 -17.333 -19.248 -19.242]\n",
" [-19.242 -19.248 -17.333 -13.489]\n",
" [-21.152 -19.242 -13.489 0. ]]\n",
"\n",
"Iteration: 70 \n",
"[[ 0. -13.704 -19.562 -21.51 ]\n",
" [-13.704 -17.614 -19.564 -19.562]\n",
" [-19.562 -19.564 -17.614 -13.704]\n",
" [-21.51 -19.562 -13.704 0. ]]\n",
"\n",
"Iteration: 80 \n",
"[[ 0. -13.829 -19.746 -21.716]\n",
" [-13.829 -17.777 -19.748 -19.746]\n",
" [-19.746 -19.748 -17.777 -13.829]\n",
" [-21.716 -19.746 -13.829 0. ]]\n",
"\n",
"Iteration: 90 \n",
"[[ 0. -13.9 -19.853 -21.835]\n",
" [-13.901 -17.87 -19.854 -19.853]\n",
" [-19.853 -19.854 -17.87 -13.901]\n",
" [-21.835 -19.853 -13.9 0. ]]\n",
"\n",
"Iteration: 100 \n",
"[[ 0. -13.942 -19.915 -21.905]\n",
" [-13.942 -17.925 -19.916 -19.915]\n",
" [-19.915 -19.916 -17.925 -13.942]\n",
" [-21.905 -19.915 -13.942 0. ]]\n",
"\n"
]
}
],
"source": [
"value = policy_evaluation(grid_width, grid_height, action, policy, 100)"
]
},
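{
"cell_type": "markdown",
"metadata": {},
"source": [
"After 100 sweeps the table is close to the exact value function of the equiprobable random policy. This appears to be the same $4 \\times 4$ gridworld as Example 4.1 in Sutton & Barto, where the limiting first row is $(0, -14, -20, -22)$; more sweeps would close the remaining gap."
]
},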
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Policy Improvement"
]
},
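{
"cell_type": "markdown",
"metadata": {},
"source": [
"Policy improvement acts greedily with respect to the evaluated values: for each non-terminal state it forms the one-step look-ahead values\n",
"\n",
"$$q_{\\pi}(s, a) \\;=\\; r + \\gamma\\, v_{\\pi}(s'),$$\n",
"\n",
"and moves all probability onto the maximizing actions, splitting it evenly when several actions tie."
]
},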
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"def policy_improvement(value, action, policy, reward = -1, grid_width = 4):\n",
" \n",
" grid_height = grid_width\n",
" \n",
" action_match = ['Up', 'Down', 'Left', 'Right']\n",
" action_table = []\n",
" \n",
" # get Q-func.\n",
" for i in range(grid_height):\n",
" for j in range(grid_width):\n",
" q_func_list=[]\n",
" if i==j and ((i==0)or (i==3)):\n",
" action_table.append('T')\n",
" else:\n",
" for k in range(len(action)):\n",
" i_, j_ = get_state([i, j], k)\n",
" q_func_list.append(value[i_][j_])\n",
" max_actions = [action_v for action_v, x in enumerate(q_func_list) if x == max(q_func_list)] \n",
"\n",
" # update policy\n",
" policy[i][j]= [0]*len(action) # initialize q-func_list\n",
" for y in max_actions :\n",
" policy[i][j][y] = (1 / len(max_actions))\n",
"\n",
" # get action\n",
" idx = np.argmax(policy[i][j])\n",
" action_table.append(action_match[idx])\n",
" action_table=np.asarray(action_table).reshape((grid_height, grid_width)) \n",
" \n",
" print('Updated policy is :\\n{}\\n'.format(policy))\n",
" print('at each state, chosen action is :\\n{}'.format(action_table))\n",
" \n",
" return policy"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Updated policy is :\n",
"[[[ 0. 0. 0. 0. ]\n",
" [ 0. 0. 1. 0. ]\n",
" [ 0. 0. 1. 0. ]\n",
" [ 0. 0.5 0.5 0. ]]\n",
"\n",
" [[ 1. 0. 0. 0. ]\n",
" [ 0.5 0. 0.5 0. ]\n",
" [ 0. 0.5 0.5 0. ]\n",
" [ 0. 1. 0. 0. ]]\n",
"\n",
" [[ 1. 0. 0. 0. ]\n",
" [ 0.5 0. 0. 0.5]\n",
" [ 0. 0.5 0. 0.5]\n",
" [ 0. 1. 0. 0. ]]\n",
"\n",
" [[ 0.5 0. 0. 0.5]\n",
" [ 0. 0. 0. 1. ]\n",
" [ 0. 0. 0. 1. ]\n",
" [ 0. 0. 0. 0. ]]]\n",
"\n",
"at each state, chosen action is :\n",
"[['T' 'Left' 'Left' 'Down']\n",
" ['Up' 'Up' 'Down' 'Down']\n",
" ['Up' 'Up' 'Down' 'Down']\n",
" ['Up' 'Right' 'Right' 'T']]\n"
]
}
],
"source": [
"updated_policy = policy_improvement(value, action, policy)"
]
},
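{
"cell_type": "markdown",
"metadata": {},
"source": [
"Full policy iteration simply alternates the two steps above until the policy stops changing. A minimal sketch built from the functions in this notebook (added for illustration; the sweep cap of 10 is an assumption, not part of the original):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# continue policy iteration from the current (already once-improved) policy\n",
"pi = policy.copy()\n",
"for step in range(10):  # safety cap; this gridworld stabilizes much sooner\n",
"    v = policy_evaluation(grid_width, grid_height, action, pi, 100)\n",
"    old_pi = pi.copy()\n",
"    pi = policy_improvement(v, action, pi)  # mutates and returns pi\n",
"    if np.array_equal(pi, old_pi):\n",
"        print('policy stable after {} improvement step(s)'.format(step + 1))\n",
"        break"
]
}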
],
"metadata": {
"anaconda-cloud": {},
"kernelspec": {
"display_name": "Python [default]",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}