Skip to content

Instantly share code, notes, and snippets.

@zephirefaith
Created October 4, 2022 20:11
Show Gist options
  • Save zephirefaith/bdb7ea164c23e6494582ea9a9d235f42 to your computer and use it in GitHub Desktop.
Python code to instantiate a simple 4 node HTN for an agent in 2D grid. The agent is spawned randomly somewhere in the world, the HTN encodes decisions needed to get the agent to top-right corner of the grid.
#!/bin/env python
import networkx as nx
import numpy as np
from enum import Enum
class Actions(Enum):
    """Primitive moves available to the agent on the 2D grid."""

    MOVE_UP = 0
    MOVE_RIGHT = 1
class States(Enum):
    """Quadrants of the grid; values double as HTN graph node ids."""

    LEFT_BOTTOM = 0
    LEFT_TOP = 1
    RIGHT_BOTTOM = 2
    RIGHT_TOP = 3
class MRPEnv(object):
    """Stochastic environment on the [0, 10] x [0, 10] grid.

    The agent is spawned uniformly at random; each action nudges it a
    random positive amount along one axis, clipped to the grid.  An
    episode is done once the agent is within 0.25 of the (10, 10)
    corner on both axes.
    """

    def __init__(self):
        # Agent position as a length-2 array [x, y]; set by init_agent().
        self.x = None
        self.init_agent()

    def init_agent(self):
        """Spawn the agent uniformly at random in [0, 10) x [0, 10)."""
        self.x = np.array(
            [
                10.00 * np.random.random_sample(),
                10.00 * np.random.random_sample(),
            ]
        )

    def step(self, action: "Actions"):
        """Apply *action*, moving the agent a random step (clipped to grid).

        Fix: the annotation was ``str`` although every caller passes an
        ``Actions`` member (the body compares against them); a string
        annotation keeps it lazy so the class also loads in isolation.
        Unknown actions are silently ignored, matching the original.
        """
        if action == Actions.MOVE_UP:
            # MOVE_UP travels up to 2 units; MOVE_RIGHT only up to 1.5.
            self.x[1] = np.clip(
                self.x[1] + 2 * np.random.random_sample(), 0, 10
            )
        elif action == Actions.MOVE_RIGHT:
            self.x[0] = np.clip(
                self.x[0] + 1.5 * np.random.random_sample(), 0, 10
            )

    def is_done(self):
        """Return True once both coordinates are within 0.25 of 10."""
        return (10.00 - self.x[0]) < 0.25 and (10.00 - self.x[1]) < 0.25
class TaskGraph(object):
    """Four-node HTN over the quadrants of the grid.

    Nodes are the ``States`` values; each directed edge carries the
    action that moves the agent from its quadrant toward RIGHT_TOP.
    RIGHT_TOP keeps two self-loops so the agent can keep moving after
    entering the final quadrant but before reaching the corner.
    """

    def __init__(self) -> None:
        # One node per quadrant, keyed by the enum value, with the
        # readable name stored as node data.
        self._V = [(state.value, {"state": state.name}) for state in States]
        self._E = [
            (0, 1, {"action": Actions.MOVE_UP}),
            (0, 2, {"action": Actions.MOVE_RIGHT}),
            (1, 3, {"action": Actions.MOVE_RIGHT}),
            (2, 3, {"action": Actions.MOVE_UP}),
            (3, 3, {"action": Actions.MOVE_RIGHT}),
            (3, 3, {"action": Actions.MOVE_UP}),
        ]
        # Fix: keep one RNG for the object's lifetime instead of
        # constructing a fresh default_rng() on every action query.
        self._rng = np.random.default_rng()
        self.GT = None
        self.init_graph()

    def init_graph(self):
        """Build the networkx multigraph from the node and edge lists."""
        self.GT = nx.MultiDiGraph()
        self.GT.add_nodes_from(self._V)
        self.GT.add_edges_from(self._E)

    def get_current_node(self, x):
        """Map a continuous position ``x = [col, row]`` to its quadrant."""
        if x[0] <= 5.00:
            return States.LEFT_BOTTOM if x[1] <= 5.00 else States.LEFT_TOP
        return States.RIGHT_BOTTOM if x[1] <= 5.00 else States.RIGHT_TOP

    def get_next_action(self, x):
        """Return the action on a uniformly random outgoing edge of the
        quadrant containing *x*.

        Fix: the original called ``rng.choice()`` directly on a list of
        ``(u, v, data)`` tuples, relying on an implicit conversion to a
        NumPy object array; sampling an integer index instead is robust
        across NumPy versions and avoids the array round-trip.
        """
        current_v = self.get_current_node(x)
        action_list = list(self.GT.out_edges(current_v.value, data=True))
        if len(action_list) == 1:
            return action_list[0][2]["action"]
        idx = self._rng.integers(len(action_list))
        return action_list[idx][2]["action"]
def _run_episode():
    """Drive a fresh agent with HTN-selected actions until it reaches
    the top-right corner, logging each state/action pair."""
    env = MRPEnv()
    task_graph = TaskGraph()
    while not env.is_done():
        curr_state = task_graph.get_current_node(env.x)
        print(f"Current state: {env.x}, {curr_state}")
        next_action = task_graph.get_next_action(env.x)
        env.step(next_action)
        print(f"Next action: {next_action}")


if __name__ == "__main__":
    _run_episode()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment