Skip to content

Instantly share code, notes, and snippets.

@geffy
Created November 21, 2016 10:34
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save geffy/dbf48e3f1c174f6bbd8061b2b20f5498 to your computer and use it in GitHub Desktop.
Save geffy/dbf48e3f1c174f6bbd8061b2b20f5498 to your computer and use it in GitHub Desktop.
import gym
import numpy as np

# Build the stochastic 8x8 FrozenLake gridworld (64 states, 4 actions)
# and reset it to the start state before reading its model below.
env = gym.make('FrozenLake8x8-v0')
env.reset()
# find terminals
def find_terminals(mdp_raw):
    """Return the set of state ids reachable via a 'done' transition.

    `mdp_raw[state][action]` is a list of (prob, next_state, reward, done)
    tuples; any next_state flagged done is a terminal state.
    """
    return {
        dst
        for node in mdp_raw.values()
        for transitions in node.values()
        for (_, dst, _, done) in transitions
        if done
    }
def evaluate_policy(v, policy, terminals, mdp_raw, gamma=1.0):
    """One sweep of iterative policy evaluation (Bellman backup).

    Args:
        v: current state-value estimates, indexable by state id.
        policy: action index per state; its length defines the state count
            (previously hard-coded to 64).
        terminals: set of terminal state ids; their value is pinned to 0.
        mdp_raw: transition model, mdp_raw[state][action] ->
            list of (prob, next_state, reward, done) tuples.
        gamma: discount factor. Defaults to 1.0, matching the script's
            module-level `gamma`, so existing call sites are unchanged.

    Returns:
        np.ndarray of updated state values, one per state.
    """
    n_states = len(policy)
    ret = np.zeros(n_states)
    for sid in range(n_states):
        if sid in terminals:
            continue  # terminal states keep value 0 (the original's `nuller`)
        act = policy[sid]
        transitions = mdp_raw[sid][act]
        # Expected immediate reward plus discounted expected next-state value.
        exp_reward = sum(pr * rew for (pr, _, rew, _) in transitions)
        exp_value = sum(pr * v[dst] for (pr, dst, _, _) in transitions)
        ret[sid] = exp_reward + gamma * exp_value
    return ret
def build_greedy_policy(v, mdp_raw, gamma=1.0):
    """Return the greedy policy with respect to value estimates `v`.

    For each state, picks the action maximizing the one-step lookahead
    sum(prob * (reward + gamma * v[next_state])) over its transitions.
    State and action counts are derived from the inputs (previously
    hard-coded to 64 and 4).

    Args:
        v: state-value estimates; its length defines the state count.
        mdp_raw: transition model, mdp_raw[state][action] ->
            list of (prob, next_state, reward, done) tuples.
        gamma: discount factor. Defaults to 1.0, matching the script's
            module-level `gamma`, so existing call sites are unchanged.

    Returns:
        np.ndarray (float dtype, as before) of greedy action indices;
        callers cast with .astype(int).
    """
    n_states = len(v)
    new_policy = np.zeros(n_states)
    for state_id in range(n_states):
        actions = mdp_raw[state_id]  # dict: action id -> transition list
        profits = np.zeros(len(actions))
        for action, transitions in actions.items():
            for (prob, dst_state, reward, _) in transitions:
                profits[action] += prob * (reward + gamma * v[dst_state])
        new_policy[state_id] = np.argmax(profits)
    return new_policy
gamma = 1.0  # discount factor; undiscounted episodic task
# init policy: every state starts with action 1 (constant, not random,
# despite the original comment)
policy = np.array([1]*64).astype(int)
# init v: all state values start at zero
v = np.zeros(64)
# copy info about env: env.P maps state -> action -> list of
# (prob, next_state, reward, done) transition tuples
mdp_raw = env.P.copy()
terminals = find_terminals(mdp_raw)
# solve MDP by policy iteration: alternate approximate policy evaluation
# (1000 Bellman sweeps) with greedy policy improvement, 100 rounds
for n_iter in range(100):
    for eval_iter in range(1000):
        v = evaluate_policy(v, policy, terminals, mdp_raw)
    policy = build_greedy_policy(v, mdp_raw).astype(int)
# run environment: roll out the learned policy and record with the monitor.
# NOTE(review): env.monitor is the legacy pre-0.20 gym recording API; on
# newer gym versions this would be gym.wrappers.Monitor — left as-is to
# match the gym version this gist targets.
env.monitor.start('/tmp/frozenlake-exp', force=True)
cum_reward = 0
for t_rounds in range(200000):
    # Fix: use the observation returned by reset() rather than discarding
    # it and hard-coding 0 (only accidentally correct for FrozenLake).
    observation = env.reset()
    for t in range(10000):
        action = policy[observation]
        observation, reward, done, info = env.step(action)
        if done:
            cum_reward += reward  # FrozenLake reward is 1 only on success
            break
env.monitor.close()
print(cum_reward)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment