@geffy
Created November 24, 2016 16:47

```python
# Solving as MDP using Value Iteration Algorithm
# Written against the 2016 gym API (Taxi-v1, env.nS / env.nA, env.monitor);
# these are not available in current gym or Gymnasium releases.
import gym
import numpy as np


def iterate_value_function(v_inp, gamma, env):
    """Apply one Bellman optimality backup to every state."""
    ret = np.zeros(env.nS)
    for sid in range(env.nS):
        temp_v = np.zeros(env.nA)
        for action in range(env.nA):
            # env.P[s][a] lists (probability, next_state, reward, done) tuples
            for (prob, dst_state, reward, is_final) in env.P[sid][action]:
                # a terminal transition contributes no future value
                temp_v[action] += prob * (reward + gamma * v_inp[dst_state] * (not is_final))
        ret[sid] = max(temp_v)
    return ret


def build_greedy_policy(v_inp, gamma, env):
    """In each state, pick the action with the best one-step lookahead under v_inp."""
    new_policy = np.zeros(env.nS)
    for state_id in range(env.nS):
        profits = np.zeros(env.nA)
        for action in range(env.nA):
            for (prob, dst_state, reward, is_final) in env.P[state_id][action]:
                # use the v_inp argument (not the global v) and mask terminal
                # transitions exactly as iterate_value_function does
                profits[action] += prob * (reward + gamma * v_inp[dst_state] * (not is_final))
        new_policy[state_id] = np.argmax(profits)
    return new_policy


env = gym.make('Taxi-v1')
gamma = 0.999999
cum_reward = 0
n_rounds = 500
env.monitor.start('/tmp/taxi-vi', force=True)
for t_rounds in range(n_rounds):
    # init env and value function
    observation = env.reset()
    v = np.zeros(env.nS)
    # solve MDP
    for _ in range(100):
        v_old = v.copy()
        v = iterate_value_function(v, gamma, env)
        if np.all(v == v_old):
            break
    # np.int was removed in NumPy 1.24; the builtin int works everywhere
    policy = build_greedy_policy(v, gamma, env).astype(int)
    # apply policy
    for t in range(1000):
        action = policy[observation]
        observation, reward, done, info = env.step(action)
        cum_reward += reward
        if done:
            break
    if t_rounds % 50 == 0 and t_rounds > 0:
        # average reward per episode so far
        print(cum_reward * 1.0 / (t_rounds + 1))
env.monitor.close()
```
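
For reference, each call to `iterate_value_function` applies one Bellman optimality backup to every state. Writing each entry of `env.P[s][a]` as $(p, s', r, d)$, where $d = 1$ marks a terminal transition:

$$
V_{k+1}(s) = \max_{a} \sum_{(p,\, s',\, r,\, d) \in P[s][a]} p \,\bigl(r + \gamma\,(1 - d)\, V_k(s')\bigr)
$$

A greedy policy then picks, in each state, the action $a$ that maximizes the same bracketed sum, and the outer loop stops once $V_{k+1} = V_k$ or after 100 sweeps.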

dc2032 commented May 9, 2023

Recently I purchased the book "Deep Reinforcement Learning" by Aske Plaat. He uses your taxi-vi.py code to illustrate how value iteration works. However, when I try to run it, I get errors saying that the 'TaxiEnv' object has no attribute nS or nA. I've been through the gym documentation and various forums and can find no reference to these values. Can you tell me what they represent so I can put them into the code to get it working? Thanks.
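
For context: in the old gym toy-text environments such as Taxi, `env.nS` is the number of discrete states, `env.nA` is the number of actions, and `env.P[s][a]` is a list of `(probability, next_state, reward, done)` tuples. Taxi has 500 states and 6 actions. The sketch below builds a hypothetical four-cell corridor in that same layout, uses a `SimpleNamespace` as a stand-in for `env`, and runs value iteration on it without needing gym. `build_corridor` and `value_iteration` are illustrative names, not part of any library.

```python
from types import SimpleNamespace

import numpy as np

# Hypothetical 4-cell corridor laid out like the old gym toy-text envs:
# states 0..3, actions 0 = left and 1 = right, reward -1 per move,
# and any move that lands in cell 3 ends the episode.
N_STATES, N_ACTIONS = 4, 2


def build_corridor():
    P = {s: {a: [] for a in range(N_ACTIONS)} for s in range(N_STATES)}
    for s in range(N_STATES):
        for a in range(N_ACTIONS):
            nxt = max(s - 1, 0) if a == 0 else min(s + 1, N_STATES - 1)
            done = nxt == N_STATES - 1
            # a single deterministic outcome: (probability, next_state, reward, done)
            P[s][a].append((1.0, nxt, -1.0, done))
    # nS / nA mirror the old env.nS / env.nA attributes
    return SimpleNamespace(nS=N_STATES, nA=N_ACTIONS, P=P)


def value_iteration(env, gamma=0.9, sweeps=100):
    v = np.zeros(env.nS)
    q = np.zeros((env.nS, env.nA))
    for _ in range(sweeps):
        # q[s, a]: one-step lookahead value of action a in state s under v
        q = np.zeros((env.nS, env.nA))
        for s in range(env.nS):
            for a in range(env.nA):
                for prob, nxt, reward, done in env.P[s][a]:
                    q[s, a] += prob * (reward + gamma * v[nxt] * (not done))
        new_v = q.max(axis=1)
        if np.array_equal(new_v, v):  # same exact-equality stop as the gist
            break
        v = new_v
    # greedy policy with respect to the final value estimate
    return v, q.argmax(axis=1)


env = build_corridor()
v, policy = value_iteration(env)
print(v, policy)
```

Every state ends up choosing action 1 (move right), and a state's value is the discounted cost of the remaining steps to the goal cell.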


Bruzie987 commented May 23, 2023

The code provided above is outdated. Replace env.nS with env.observation_space.n and env.nA with env.action_space.n. Use the code from the book.
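
Following that replacement, here is a minimal sketch against the current Gymnasium API. It assumes Gymnasium is installed, uses `Taxi-v3` (the current version of the environment), and reads the transition table through `env.unwrapped.P` so it does not depend on wrapper attribute forwarding. `tabulate` and `run_taxi` are hypothetical helper names. Unlike the original, this sketch packs `P` into dense NumPy arrays so that each sweep is a single matrix product, and it stops on a tolerance rather than on exact equality. `env.monitor` has no counterpart here.

```python
import numpy as np


def tabulate(P, n_states, n_actions):
    """Pack P[s][a] -> [(prob, next_state, reward, done), ...] into dense arrays."""
    T = np.zeros((n_states, n_actions, n_states))  # transition mass to non-terminal s'
    R = np.zeros((n_states, n_actions))            # expected immediate reward
    for s in range(n_states):
        for a in range(n_actions):
            for prob, nxt, reward, done in P[s][a]:
                R[s, a] += prob * reward
                if not done:  # terminal transitions carry no future value
                    T[s, a, nxt] += prob
    return T, R


def value_iteration(T, R, gamma=0.99, tol=1e-10, max_sweeps=10_000):
    v = np.zeros(T.shape[0])
    for _ in range(max_sweeps):
        # (S, A, S) @ (S,) -> (S, A): Bellman backup for all states at once
        new_v = (R + gamma * T @ v).max(axis=1)
        converged = np.max(np.abs(new_v - v)) < tol
        v = new_v
        if converged:
            break
    policy = (R + gamma * T @ v).argmax(axis=1)
    return v, policy


def run_taxi(n_episodes=100, gamma=0.99, seed=0):
    # Imported here so the helpers above work without Gymnasium installed.
    import gymnasium as gym

    env = gym.make("Taxi-v3")
    n_states = env.observation_space.n
    n_actions = env.action_space.n
    T, R = tabulate(env.unwrapped.P, n_states, n_actions)
    _, policy = value_iteration(T, R, gamma)

    total = 0.0
    for episode in range(n_episodes):
        obs, _ = env.reset(seed=seed if episode == 0 else None)
        terminated = truncated = False
        while not (terminated or truncated):
            # Gymnasium's step returns a 5-tuple instead of the old 4-tuple
            obs, reward, terminated, truncated, _ = env.step(int(policy[obs]))
            total += reward
    env.close()
    return total / n_episodes  # average return per episode
```

Calling `print(run_taxi())` prints the average return per episode over 100 episodes. The dense `T` array for Taxi has 500 × 6 × 500 entries (about 12 MB as float64), which is small enough to make the vectorized sweep a reasonable trade-off here.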


dc2032 commented May 23, 2023

Thanks, I'll give that a try.


Taresin commented Feb 26, 2024

I also found this through the textbook. Sorry I'm a year late. Just started reading the book.

Here's my updated code to work with the Gymnasium library:
https://gist.github.com/Taresin/a090274fbaf092ad649e4e32e22ecaf4

Hoping that this might help others.
