from julia.CommonRLSpaces import Box
from julia.POMDPs import solve, action
from julia.QMDP import QMDPSolver
from julia.POMCPOW import POMCPOWSolver
from julia.POMDPTools import stepthrough, Deterministic
from julia.Distributions import MvNormal
from quickpomdps import QuickPOMDP

import numpy as np
goal = np.array([3.0, 3.0, 0.0])  # goal pose (x, y, heading); only x and y are checked
goal_radius = 0.1  # assumed tolerance for "at the goal" -- exact equality never fires on continuous states
def reward(s, a, sp):
    # +5 within the (assumed) goal radius, -1 step penalty otherwise
    if np.linalg.norm(np.asarray(s)[:2] - goal[:2]) < goal_radius:
        return 5.0
    else:
        return -1.0
def transition(s, a, dt=0.1):
    # unicycle-style dynamics: a[0] is forward speed; a[1] is assumed to be
    # the turn rate (the original advanced x with sin and left the heading fixed)
    s0 = s[0] + a[0] * np.cos(s[2]) * dt
    s1 = s[1] + a[0] * np.sin(s[2]) * dt
    s2 = s[2] + a[1] * dt
    # Gaussian process noise with standard deviation 0.1 in each coordinate
    return MvNormal([s0, s1, s2], [0.1, 0.1, 0.1])
def observation(s, a, sp):
    # the observation is a low-noise reading of the next state sp
    return MvNormal(sp, [0.001, 0.001, 0.001])
m = QuickPOMDP(
    states = Box([-5,-5,-3], [5,5,3]),
    actions = Box([-5,-5], [5,5]),
    observations = Box([-5,-5,-3], [5,5,3]),
    discount = 0.9,
    isterminal = lambda s: np.linalg.norm(np.asarray(s)[:2] - goal[:2]) < goal_radius,
    transition = transition,
    observation = observation,
    reward = reward,
    initialstate = Deterministic([0.0, 0.0, 0.0]),
)
solver = POMCPOWSolver()
# solver = QMDPSolver()  # alternative solver; note QMDP requires discrete spaces
policy = solve(solver, m)

# query the policy for the best action at the initial belief
a = action(policy, Deterministic([0.0, 0.0, 0.0]))
print(a)
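
# A minimal simulation sketch, assuming the stepthrough iterator imported
# above; max_steps=10 and the printed fields are illustrative choices, not
# part of the original gist.
for step in stepthrough(m, policy, max_steps=10):
    print('s:', step.s, 'a:', step.a, 'o:', step.o, 'r:', step.r)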