Policy Iteration in Python
#!/usr/bin/env python
# coding=utf-8
import numpy as np
"""
1: Procedure Policy_Iteration(S,A,P,R)
2: Inputs
3: S is the set of all states
4: A is the set of all actions
5: P is state transition function specifying P(s'|s,a)
6: R is a reward function R(s,a,s')
7: Output
8: optimal policy π
9: Local
10: action array π[S]
11: Boolean variable noChange
12: real array V[S]
13: set π arbitrarily
14: repeat
15: noChange ←true
16: Solve V[s] = ∑s'∈S P(s'|s,π[s])(R(s,a,s')+γV[s'])
17: for each s∈S do
18: Let QBest=V[s]
19: for each a ∈A do
20: Let Qsa=∑s'∈S P(s'|s,a)(R(s,a,s')+γV[s'])
21: if (Qsa > QBest) then
22: π[s]←a
23: QBest ←Qsa
24: noChange ←false
25: until noChange
26: return π
"""
states = [0,1,2,3,4]
actions = [0,1]
N_STATES = len(states)
N_ACTIONS = len(actions)
P = np.zeros((N_STATES, N_ACTIONS, N_STATES)) # transition probability
R = np.zeros((N_STATES, N_ACTIONS, N_STATES)) # rewards
P[0,0,1] = 1.0
P[1,1,2] = 1.0
P[2,0,3] = 1.0
P[3,1,4] = 1.0
P[4,0,4] = 1.0
R[0,0,1] = 1
R[1,1,2] = 10
R[2,0,3] = 100
R[3,1,4] = 1000
R[4,0,4] = 1.0
gamma = 0.75
# initialize policy and value arbitrarily
policy = [0 for s in range(N_STATES)]
V = np.zeros(N_STATES)
print "Initial policy", policy
# print V
# print P
# print R
is_value_changed = True
iterations = 0
while is_value_changed:
    is_value_changed = False
    iterations += 1
    # policy evaluation: one sweep of the Bellman equation for the current policy
    for s in range(N_STATES):
        V[s] = sum([P[s, policy[s], s1] * (R[s, policy[s], s1] + gamma * V[s1]) for s1 in range(N_STATES)])
        # print("Run for state", s)

    # policy improvement: act greedily with respect to the current V
    for s in range(N_STATES):
        q_best = V[s]
        # print("State", s, "q_best", q_best)
        for a in range(N_ACTIONS):
            q_sa = sum([P[s, a, s1] * (R[s, a, s1] + gamma * V[s1]) for s1 in range(N_STATES)])
            if q_sa > q_best:
                print("State", s, ": q_sa", q_sa, "q_best", q_best)
                policy[s] = a
                q_best = q_sa
                is_value_changed = True

    print("Iterations:", iterations)
    # print("Policy now", policy)

print("Final policy")
print(policy)
print(V)
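The loop above approximates the pseudocode's "Solve V[s] = ..." step with a single evaluation sweep per iteration. A minimal sketch of doing that step exactly instead, by solving the linear system (I - γP_π)V = r_π with NumPy, reusing the arrays defined above (the helper name evaluate_policy_exactly is mine, not part of the gist):

def evaluate_policy_exactly(P, R, policy, gamma):
    """Solve V = r_pi + gamma * P_pi V exactly for a deterministic policy."""
    n = P.shape[0]
    # transition matrix and expected one-step reward under the current policy
    P_pi = np.array([P[s, policy[s], :] for s in range(n)])
    r_pi = np.array([np.sum(P[s, policy[s], :] * R[s, policy[s], :]) for s in range(n)])
    # V = (I - gamma * P_pi)^-1 r_pi
    return np.linalg.solve(np.eye(n) - gamma * P_pi, r_pi)

# e.g. V = evaluate_policy_exactly(P, R, policy, gamma)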
@trentwoodbury

I think you may want to change your definition of q_sa to be:

q_sa = sum([R[s, a, s1] + ( P[s, a, s1] * gamma * V[s1]) for s1 in range(N_STATES)])

This is because Q = R(s,a) + gamma * sum(P(s,a,s')*V(s')).
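A small check one could run with the arrays from the gist above (the choice of s and a is arbitrary): for the deterministic P in this example the two expressions give the same number, but they differ in general, since one weights the reward by the transition probability and the other does not.

s, a = 0, 0
q_gist = sum(P[s, a, s1] * (R[s, a, s1] + gamma * V[s1]) for s1 in range(N_STATES))
q_alt = sum(R[s, a, s1] + (P[s, a, s1] * gamma * V[s1]) for s1 in range(N_STATES))
print(q_gist, q_alt)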

@YingxiaoKong

For policy iteration, do you need to run policy evaluation to convergence before doing each policy improvement step?
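For reference, textbook policy iteration runs the evaluation step to convergence (or solves it exactly, as in the pseudocode) before each improvement step. A rough sketch of such an inner evaluation loop, reusing the arrays from the gist (the tolerance value is my choice, not from the gist):

theta = 1e-8  # convergence tolerance for the evaluation sweep (arbitrary choice)
while True:
    delta = 0.0
    for s in range(N_STATES):
        v_old = V[s]
        V[s] = sum(P[s, policy[s], s1] * (R[s, policy[s], s1] + gamma * V[s1]) for s1 in range(N_STATES))
        delta = max(delta, abs(V[s] - v_old))
    if delta < theta:
        break
# ...then do the greedy improvement pass over states as in the gist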

@sasmitBiceps

Love you my guy, this helped me solve my assignment.
