@zkytony
Last active May 28, 2024 00:05
Quick example of computing reachable beliefs and running value iteration with infinite horizon using pomdp-py
import random
import pprint
import pomdp_py
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
from pomdp_py.algorithms.value_function import expected_reward, belief_observation_model
from pomdp_py.problems.tiger.tiger_problem import TigerProblem, TigerState

PRECISION = 4  # decimal places used when rounding belief probabilities

def _to_tuple(b: dict) -> tuple:
    """Given a belief dictionary, return a hashable, sorted tuple representation.
    Probabilities are rounded to PRECISION so that nearly identical beliefs
    map to the same tuple, keeping the reachable belief set finite."""
    return tuple(
        sorted(((s, round(b[s], PRECISION)) for s in b), key=lambda elm: str(elm[0]))
    )

def _to_dict(b_tuple: tuple) -> dict:
    """Given a tuple belief, return a dictionary mapping state to probability."""
    return {elm[0]: elm[1] for elm in b_tuple}

def reachable_belief(b, A, Z, T, O) -> tuple:
    """Given an initial belief b, compute the set of beliefs reachable from b
    under the belief-space transitions induced by A, Z, T, O. Returns the
    reachable set along with a map from (b, a, z) to the next belief b'."""
    reachable_set = {_to_tuple(b)}
    transitions = {}  # maps (b, a, z) to b'
    _reachable_belief(b, A, Z, T, O, reachable_set, transitions)
    return reachable_set, transitions

def _reachable_belief(b, A, Z, T, O, reachable_set, transitions) -> None:
    """Recursive helper for a POMDP <S, A, Z, T, O, R>: expands the set of
    beliefs reachable from b, recording each (b, a, z) -> b' transition.
    Only A, Z, T, O are needed; updates reachable_set and transitions in place."""
    for a in A:
        for z in Z:
            b_next = pomdp_py.belief_update(b, a, z, T, O)
            b_next_tuple = _to_tuple(b_next)
            transitions[(_to_tuple(b), a, z)] = b_next_tuple
            if b_next_tuple not in reachable_set:
                reachable_set.add(b_next_tuple)
                _reachable_belief(b_next, A, Z, T, O, reachable_set, transitions)
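
# A minimal sketch of the discrete Bayes filter that pomdp_py.belief_update is
# assumed to perform for histogram beliefs. Illustration only: the helper name
# _belief_update_sketch and this implementation are not part of pomdp_py.
#
#     b'(s') ∝ O(z | s', a) * sum_s T(s' | s, a) * b(s)
#
def _belief_update_sketch(b: dict, a, z, T, O) -> dict:
    new_b = {}
    for sp in b:
        pred = sum(T.probability(sp, s, a) * b[s] for s in b)  # prediction step
        new_b[sp] = O.probability(z, sp, a) * pred  # observation correction
    total = sum(new_b.values())  # normalizing constant, equals Pr(z | b, a)
    return {sp: prob / total for sp, prob in new_b.items()}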

def value_iteration_infinite_horizon(
    Rb0, A, Z, T, O, R, gamma, belief_transitions, max_iter=1000, epsilon=1e-4
):
    """Perform infinite-horizon value iteration over the reachable belief
    states Rb0; also returns the greedy policy."""
    V = {b: random.uniform(-5, 5) for b in Rb0}  # arbitrary initialization
    pi = {}
    for step in range(max_iter):
        Vp = {}
        for b in Rb0:
            Vp[b], pi[b] = _value(
                _to_dict(b), A, Z, T, O, R, gamma, V, belief_transitions
            )
        diff = _value_difference(V, Vp)
        if diff < epsilon:
            print(f"Value Iteration converged after {step + 1} iterations.")
            return V, pi
        V = Vp
    return V, pi
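
# Convergence note: with gamma < 1 the Bellman backup above is a gamma-
# contraction in the max norm (a standard MDP result), so stopping once
# max_b |Vp(b) - V(b)| < epsilon leaves V within roughly
# epsilon * gamma / (1 - gamma) of the optimal value on Rb0.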

def _value(b, A, Z, T, O, R, gamma, V, belief_transitions):
    """Compute the value at belief b using future values from V; returns
    the best Q-value together with the action that achieves it."""
    max_qval = float("-inf")
    best_action = None
    for a in A:
        qval = _qvalue(b, a, Z, T, O, R, gamma, V, belief_transitions)
        if qval > max_qval:
            max_qval = qval
            best_action = a
    return max_qval, best_action

def _qvalue(b, a, Z, T, O, R, gamma, V, belief_transitions):
    """Compute the Q-value at (b, a) using future values from V."""
    r = expected_reward(b, R, a, T)
    expected_future_value = 0.0
    for z in Z:
        # accumulate Pr(z|b,a) * V(b')
        prob_z = belief_observation_model(z, b, a, T, O)
        # only observations with non-zero probability contribute
        if prob_z > 0:
            next_b = belief_transitions[(_to_tuple(b), a, z)]
            next_value = V[next_b]
            expected_future_value += prob_z * next_value
    return r + gamma * expected_future_value
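
# The backup in _qvalue implements the belief-MDP Bellman equation:
#
#     Q(b, a) = E_{s ~ b}[R(s, a)] + gamma * sum_z Pr(z | b, a) * V(b')
#
# where b' = belief_update(b, a, z). Skipping zero-probability observations
# also avoids looking up transitions that were never recorded.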

def _value_difference(V1, V2):
    """Max-norm (L-infinity) distance between two value functions."""
    diffs = [abs(V1[b] - V2[b]) for b in V1]
    return max(diffs)

def _create_pomdp(noise=0.15, init_state="tiger-left"):
    """Instantiate the classic Tiger POMDP with a uniform initial belief."""
    tiger = TigerProblem(
        noise,
        TigerState(init_state),
        pomdp_py.Histogram(
            {TigerState("tiger-left"): 0.5, TigerState("tiger-right"): 0.5}
        ),
    )
    T = tiger.agent.transition_model
    O = tiger.agent.observation_model
    S = list(T.get_all_states())
    Z = list(O.get_all_observations())
    A = list(tiger.agent.policy_model.get_all_actions())
    R = tiger.agent.reward_model
    gamma = 0.95  # discount factor
    b0 = tiger.agent.belief
    s0 = tiger.env.state
    return b0, s0, S, A, Z, T, O, R, gamma
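
# For reference (as implemented in pomdp_py's Tiger example, to the best of
# my understanding): listening costs -1, opening the tiger-free door yields
# +10, opening the tiger's door yields -100, and `noise` is the probability
# that listening reports the wrong door.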

def test():
    b0, s0, S, A, Z, T, O, R, gamma = _create_pomdp()
    Rb0, btrans = reachable_belief(b0, A, Z, T, O)
    print(f"Reachable belief from {b0} has {len(Rb0)} belief states:")
    pprint.pp(Rb0)

    value, policy = value_iteration_infinite_horizon(Rb0, A, Z, T, O, R, gamma, btrans)
    print("Value function:")
    pprint.pp(value)
    print("Policy:")
    pprint.pp(policy)

    # Let's make a plot
    data = {"b_tiger_left": [], "value": [], "action": []}
    for b in value:
        data["b_tiger_left"].append(_to_dict(b)[TigerState("tiger-left")])
        data["value"].append(value[b])
        data["action"].append(policy[b])
    sns.scatterplot(pd.DataFrame(data), x="b_tiger_left", y="value", hue="action")
    plt.show()


if __name__ == "__main__":
    test()
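
# Example (hypothetical) use of the results: policy is keyed by belief tuples,
# so the greedy action at any belief b is policy[_to_tuple(b)], and the belief
# after taking that action and observing z is btrans[(_to_tuple(b), action, z)].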
@zkytony (Author) commented May 28, 2024:
Output:

Reachable belief from {TigerState(tiger-left): 0.5, TigerState(tiger-right): 0.5} has 13 belief states:
{((TigerState(tiger-left), 0.0), (TigerState(tiger-right), 1.0)),
 ((TigerState(tiger-left), 0.0002), (TigerState(tiger-right), 0.9998)),
 ((TigerState(tiger-left), 0.001), (TigerState(tiger-right), 0.999)),
 ((TigerState(tiger-left), 0.0055), (TigerState(tiger-right), 0.9945)),
 ((TigerState(tiger-left), 0.0302), (TigerState(tiger-right), 0.9698)),
 ((TigerState(tiger-left), 0.15), (TigerState(tiger-right), 0.85)),
 ((TigerState(tiger-left), 0.5), (TigerState(tiger-right), 0.5)),
 ((TigerState(tiger-left), 0.85), (TigerState(tiger-right), 0.15)),
 ((TigerState(tiger-left), 0.9698), (TigerState(tiger-right), 0.0302)),
 ((TigerState(tiger-left), 0.9945), (TigerState(tiger-right), 0.0055)),
 ((TigerState(tiger-left), 0.999), (TigerState(tiger-right), 0.001)),
 ((TigerState(tiger-left), 0.9998), (TigerState(tiger-right), 0.0002)),
 ((TigerState(tiger-left), 1.0), (TigerState(tiger-right), 0.0))}
Value Iteration converged after 183 iterations.
Value function:
{((TigerState(tiger-left), 0.85), (TigerState(tiger-right), 0.15)): 21.442398887166696,
 ((TigerState(tiger-left), 0.0302), (TigerState(tiger-right), 0.9698)): 25.079575480861333,
 ((TigerState(tiger-left), 0.5), (TigerState(tiger-right), 0.5)): 19.370181755728048,
 ((TigerState(tiger-left), 1.0), (TigerState(tiger-right), 0.0)): 28.401575480861332,
 ((TigerState(tiger-left), 0.001), (TigerState(tiger-right), 0.999)): 28.291575480861333,
 ((TigerState(tiger-left), 0.15), (TigerState(tiger-right), 0.85)): 21.442398887166696,
 ((TigerState(tiger-left), 0.0), (TigerState(tiger-right), 1.0)): 28.401575480861332,
 ((TigerState(tiger-left), 0.999), (TigerState(tiger-right), 0.001)): 28.291575480861326,
 ((TigerState(tiger-left), 0.0055), (TigerState(tiger-right), 0.9945)): 27.796575480861332,
 ((TigerState(tiger-left), 0.9945), (TigerState(tiger-right), 0.0055)): 27.796575480861332,
 ((TigerState(tiger-left), 0.9998), (TigerState(tiger-right), 0.0002)): 28.379575480861334,
 ((TigerState(tiger-left), 0.9698), (TigerState(tiger-right), 0.0302)): 25.079575480861333,
 ((TigerState(tiger-left), 0.0002), (TigerState(tiger-right), 0.9998)): 28.379575480861334}
Policy:
{((TigerState(tiger-left), 0.85), (TigerState(tiger-right), 0.15)): TigerAction(listen),
 ((TigerState(tiger-left), 0.0302), (TigerState(tiger-right), 0.9698)): TigerAction(open-left),
 ((TigerState(tiger-left), 0.5), (TigerState(tiger-right), 0.5)): TigerAction(listen),
 ((TigerState(tiger-left), 1.0), (TigerState(tiger-right), 0.0)): TigerAction(open-right),
 ((TigerState(tiger-left), 0.001), (TigerState(tiger-right), 0.999)): TigerAction(open-left),
 ((TigerState(tiger-left), 0.15), (TigerState(tiger-right), 0.85)): TigerAction(listen),
 ((TigerState(tiger-left), 0.0), (TigerState(tiger-right), 1.0)): TigerAction(open-left),
 ((TigerState(tiger-left), 0.999), (TigerState(tiger-right), 0.001)): TigerAction(open-right),
 ((TigerState(tiger-left), 0.0055), (TigerState(tiger-right), 0.9945)): TigerAction(open-left),
 ((TigerState(tiger-left), 0.9945), (TigerState(tiger-right), 0.0055)): TigerAction(open-right),
 ((TigerState(tiger-left), 0.9998), (TigerState(tiger-right), 0.0002)): TigerAction(open-right),
 ((TigerState(tiger-left), 0.9698), (TigerState(tiger-right), 0.0302)): TigerAction(open-right),
 ((TigerState(tiger-left), 0.0002), (TigerState(tiger-right), 0.9998)): TigerAction(open-left)}

Plot: scatter of value vs. b_tiger_left, with points colored by the policy's action.
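
A quick sanity check on the numbers: opening a door appears to reset the belief to uniform (the uniform belief is reachable from every belief above), so a fully certain belief should satisfy V = 10 + gamma * V(uniform) = 10 + 0.95 * 19.3702 ≈ 28.4017, which matches the printed 28.4016 up to the 1e-4 stopping tolerance.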
