Last active
February 15, 2023 17:17
-
-
Save gerner/08d9509569979fadb9262d2e30af0da7 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sys | |
import logging | |
import enum | |
import collections | |
import math | |
import time | |
import pdb | |
from dataclasses import dataclass | |
from typing import Sequence, Tuple, List, Mapping, MutableMapping, Any, Set, Collection, Iterator, Optional | |
import sortedcontainers # type: ignore | |
# Smallest value a bound's lower edge may take.
ZERO = 0
# Sentinel used as "no upper limit" for a bound (largest unsigned 64-bit value).
POS_INF = 2**64 - 1
class Location(enum.IntEnum):
    """Places the agent can be at; values are stored in location bounds."""

    invalid = -1
    # 0 doubles as "no location required" in the action compatibility checks.
    none = 0
    woods = 1
    store = 2
class ContextKeys(enum.IntEnum):
    """Names of the integer-valued variables that make up the world state."""

    invalid = -1
    # The remaining members count up from the previous value, so money is 0.
    money = 0
    have_axe = 1
    forest = 2
    wood = 3
    location = 4
class State:
    """Concrete world state: one integer value per context key."""

    def __init__(self, values: "Mapping[ContextKeys, int]") -> None:
        """Store the mapping as-is (it is not copied)."""
        self.values = values

    def __str__(self) -> str:
        """Render as a list of ``'<key name> = <value>'`` strings."""
        return str([f'{k.name} = {v}' for k, v in self.values.items()])

    def __repr__(self) -> str:
        """Return the repr of the underlying mapping."""
        # Was spelled ``__repr`` (no trailing underscores), which Python never
        # calls as the repr hook.
        return repr(self.values)
@dataclass(order=True, eq=True, unsafe_hash=True)
class Bound:
    """Inclusive interval ``low <= key <= high`` constraining one context key.

    ``high == POS_INF`` means there is no upper limit.  Instances are mutable
    but also hashable (``unsafe_hash=True``), so they should not be mutated
    once they are used inside a hashed container.
    """

    key: ContextKeys
    low: int
    high: int

    def inc(self, amount: int) -> None:
        """Shift the interval up by ``amount``; an unlimited ``high`` stays unlimited."""
        assert amount > 0
        self.low += amount
        if self.high < POS_INF:
            self.high = self.high + amount

    def atleast(self, amount: int) -> None:
        """Raise ``low`` (and ``high`` if needed) so both are at least ``amount``."""
        assert amount > 0
        self.low = max(self.low, amount)
        self.high = max(self.high, amount)

    def dec(self, amount: int) -> None:
        """Lower ``low`` by ``amount``, clamping at ``ZERO``."""
        # NOTE(review): unlike inc(), ``high`` is left untouched here --
        # confirm that is intended.
        assert amount > 0
        lowered = self.low - amount
        self.low = lowered if lowered > ZERO else ZERO

    def set(self, amount: int) -> None:
        """Collapse the interval to the single value ``amount``."""
        self.low = self.high = amount

    def nontrivial(self) -> bool:
        """Return True unless the bound is the full range ``[ZERO, POS_INF]``."""
        return not (self.low <= ZERO and self.high >= POS_INF)

    def satisfies(self, other: "Bound") -> bool:
        """ True iff any state that satisfies us satisfies other """
        if self.key != other.key:
            return False
        # our interval has to sit inside other's interval
        return other.low <= self.low and self.high <= other.high

    def __repr__(self) -> str:
        return str(self)

    def __str__(self) -> str:
        """Render as ``low <= name <= high``, omitting edges that do not constrain."""
        lower = f'{self.low} <= ' if self.low > 0 else ""
        upper = f' <= {self.high}' if self.high < POS_INF else ""
        return f'{lower}{self.key.name}{upper}'
# Default returned when a goal has no bound for a key: the full, trivial range.
EMPTY_BOUND = Bound(key=ContextKeys.invalid, low=ZERO, high=POS_INF)
def distance(k: ContextKeys, delta: int) -> float:
    """Weight a violation of size ``delta`` on key ``k``.

    A location mismatch counts as 1 (0 when ``delta`` is not positive), an axe
    shortfall counts 20 per unit, and every other key counts ``delta`` itself.
    """
    if k == ContextKeys.location:
        return 1. if delta > 0 else 0.
    if k == ContextKeys.have_axe:
        return 20*delta
    return delta
class Goal:
    """A conjunction of non-trivial bounds that a state has to meet.

    ``goal_value`` and ``fitness_value`` start at infinity and are assigned by
    the search in ``astar``; they take no part in equality or hashing.
    """

    def __init__(self, bounds: Mapping[ContextKeys, Bound]) -> None:
        """Keep only the bounds that actually constrain something."""
        #assert all(x.low <= x.high for x in bounds.values())
        self.bounds = {k: v for k, v in bounds.items() if v.nontrivial()}
        self.goal_value = math.inf
        self.fitness_value = math.inf

    def delta(self, state: State) -> float:
        """Return how far ``state`` is from meeting every bound (0 when it meets all).

        Each violated bound contributes ``distance(key, shortfall_or_excess)``.
        ``state.values`` is indexed directly, so a bounded key missing from the
        state propagates the mapping's lookup error.
        """
        d = 0.
        for key, bound in self.bounds.items():
            value = state.values[key]
            if value < bound.low:
                d += distance(key, bound.low - value)
            elif value > bound.high:
                d += distance(key, value - bound.high)
            # else we're already in bound
        return d

    def satisfies(self, other: "Goal") -> bool:
        """True iff every bound of ``other`` is present here and at least as narrow."""
        for k in other.bounds:
            if k not in self.bounds:
                return False
            if not self.bounds[k].satisfies(other.bounds[k]):
                return False
        # it's ok for us to have extra conditions
        return True

    def __repr__(self) -> str:
        return str(self)

    def __str__(self) -> str:
        return str(list(self.bounds.values()))

    def __eq__(self, other: Any) -> bool:
        """Goals are equal when their bounds are equal; search scores are ignored."""
        if not isinstance(other, Goal):
            return False
        return self.bounds == other.bounds

    def __lt__(self, other: Any) -> bool:
        """Order goals key by key (in ``ContextKeys`` order), then by ``(low, high)``.

        For the first key where the goals differ: a goal lacking the bound
        sorts before one that has it; otherwise the ``(low, high)`` pairs are
        compared lexicographically.

        Raises:
            ValueError: if ``other`` is not a ``Goal``.
        """
        if not isinstance(other, Goal):
            raise ValueError(f'both items must be goals')
        for k in ContextKeys:
            mine = self.bounds.get(k)
            theirs = other.bounds.get(k)
            if mine is None and theirs is None:
                # neither has k, consider next
                continue
            if mine is None:
                return True
            if theirs is None:
                return False
            mine_range = (mine.low, mine.high)
            their_range = (theirs.low, theirs.high)
            # Decide on the first differing key.  The previous version fell
            # through to ``high`` (or later keys) even when ``low`` was already
            # greater, so a < b and b < a could both hold.
            if mine_range != their_range:
                return mine_range < their_range
            # else bounds for k are equal, consider next
        # they must be equal
        return False

    def __hash__(self) -> int:
        """Hash of the bounds, combined in ``ContextKeys`` order."""
        h = 0
        # always iterate over these (and compute hash) in same order
        for k in ContextKeys:
            if k in self.bounds:
                h = 31 * h + hash(self.bounds[k])
        return h
@dataclass
class Action:
    """One step of a plan: a display name and the cost of taking it."""

    name: str
    cost: float
class ActionFactory:
    """Base class for actions that can be applied backwards to a goal.

    Subclasses override ``compatible`` and ``neighbor``; this base version is
    never compatible with any goal.
    """

    def _copy_bounds(self, goal: Goal) -> MutableMapping[ContextKeys, Bound]:
        """Return fresh ``Bound`` objects for every bound in ``goal``."""
        copied: MutableMapping[ContextKeys, Bound] = {}
        for key, bound in goal.bounds.items():
            copied[key] = Bound(key, bound.low, bound.high)
        return copied

    def compatible(self, goal: Goal) -> bool:
        """Return whether this action can help reach ``goal`` (base: never)."""
        return False

    def neighbor(self, goal: Goal) -> Tuple[Action, Goal]:
        """Return the action and the goal that must hold before taking it.

        Subclasses add the action's preconditions to the goal and back out
        its postconditions; the base version returns a placeholder.
        """
        placeholder = Action("hi", 5.)
        return (placeholder, Goal({}))
class BuyAxe(ActionFactory):
    """Buy an axe at the store: costs 20 money, yields one axe."""

    def compatible(self, goal: Goal) -> bool:
        """Applies when the goal needs an axe and does not require another location."""
        if not goal.bounds.get(ContextKeys.have_axe, EMPTY_BOUND).low > 0:
            return False
        wanted_location = goal.bounds.get(ContextKeys.location, EMPTY_BOUND).low
        return wanted_location in [0, Location.store]

    def neighbor(self, goal: Goal) -> Tuple[Action, Goal]:
        """Return the "buy axe" action and the goal that must hold beforehand."""
        regressed = self._copy_bounds(goal)

        # precondition: location == store (replaces any existing location bound)
        regressed[ContextKeys.location] = Bound(ContextKeys.location, Location.store, Location.store)

        # precondition: money >= 20; postcondition: money -= 20
        money = regressed.get(ContextKeys.money)
        if money is None:
            regressed[ContextKeys.money] = Bound(ContextKeys.money, 20, POS_INF)
        else:
            money.inc(20)

        # postcondition: have_axe += 1, so one fewer is needed beforehand;
        # a goal that did not mention an axe is unaffected
        axe = regressed.get(ContextKeys.have_axe)
        if axe is not None:
            axe.dec(1)

        return Action("buy axe", 20.), Goal(regressed)
class SellWood(ActionFactory):
    """Sell one wood at the store for one money."""

    def compatible(self, goal: Goal) -> bool:
        """Applies when the goal needs money and does not require another location."""
        if not goal.bounds.get(ContextKeys.money, EMPTY_BOUND).low > 0:
            return False
        wanted_location = goal.bounds.get(ContextKeys.location, EMPTY_BOUND).low
        return wanted_location in [0, Location.store]

    def neighbor(self, goal: Goal) -> Tuple[Action, Goal]:
        """Return the "sell wood" action and the goal that must hold beforehand."""
        regressed = self._copy_bounds(goal)

        # precondition: location == store (replaces any existing location bound)
        regressed[ContextKeys.location] = Bound(ContextKeys.location, Location.store, Location.store)

        # precondition: wood >= 1; postcondition: wood -= 1
        wood = regressed.get(ContextKeys.wood)
        if wood is None:
            regressed[ContextKeys.wood] = Bound(ContextKeys.wood, 1, POS_INF)
        else:
            wood.inc(1)

        # postcondition: money += 1, so one less is needed beforehand;
        # a goal that did not mention money is unaffected
        money = regressed.get(ContextKeys.money)
        if money is not None:
            money.dec(1)

        return Action("sell wood", 1.), Goal(regressed)
class ChopWood(ActionFactory):
    """Chop wood in the woods with an axe: uses 10 forest, yields 10 wood."""

    def compatible(self, goal: Goal) -> bool:
        """Applies when the goal needs wood and does not require another location."""
        if not goal.bounds.get(ContextKeys.wood, EMPTY_BOUND).low > 0:
            return False
        wanted_location = goal.bounds.get(ContextKeys.location, EMPTY_BOUND).low
        return wanted_location in [0, Location.woods]

    def neighbor(self, goal: Goal) -> Tuple[Action, Goal]:
        """Return the "chop wood" action and the goal that must hold beforehand."""
        regressed = self._copy_bounds(goal)

        # precondition: location == woods (replaces any existing location bound)
        regressed[ContextKeys.location] = Bound(ContextKeys.location, Location.woods, Location.woods)

        # precondition: have_axe >= 1 (the axe is not consumed)
        axe = regressed.get(ContextKeys.have_axe)
        if axe is None:
            regressed[ContextKeys.have_axe] = Bound(ContextKeys.have_axe, 1, POS_INF)
        else:
            axe.atleast(1)

        # precondition: forest >= 10; postcondition: forest -= 10
        forest = regressed.get(ContextKeys.forest)
        if forest is None:
            regressed[ContextKeys.forest] = Bound(ContextKeys.forest, 10, POS_INF)
        else:
            forest.inc(10)

        # postcondition: wood += 10, so ten fewer are needed beforehand;
        # a goal that did not mention wood is unaffected
        wood = regressed.get(ContextKeys.wood)
        if wood is not None:
            wood.dec(10)

        return Action("chop wood", 10.), Goal(regressed)
class GatherWood(ActionFactory):
    """Gather wood in the woods without an axe: uses 10 forest, yields 10 wood."""

    def compatible(self, goal: Goal) -> bool:
        """Applies when the goal needs wood and does not require another location."""
        return (
            goal.bounds.get(ContextKeys.wood, EMPTY_BOUND).low > 0 and
            goal.bounds.get(ContextKeys.location, EMPTY_BOUND).low in [0, Location.woods]
        )

    def neighbor(self, goal: Goal) -> Tuple[Action, Goal]:
        """Return the "gather wood" action and the goal that must hold beforehand."""
        new_bounds = self._copy_bounds(goal)

        # prec: location = woods
        # An existing location bound used to be overwritten with
        # Location.store, contradicting both the precondition and the
        # no-existing-bound branch; both cases now require the woods.
        if ContextKeys.location in new_bounds:
            new_bounds[ContextKeys.location].set(Location.woods)
        else:
            new_bounds[ContextKeys.location] = Bound(ContextKeys.location, Location.woods, Location.woods)

        # prec: 10 <= forest < POS_INF
        # postc: forest -= 10
        if ContextKeys.forest not in new_bounds:
            new_bounds[ContextKeys.forest] = Bound(ContextKeys.forest, 10, POS_INF)
        else:
            new_bounds[ContextKeys.forest].inc(10)

        # postc: wood += 10
        if ContextKeys.wood in new_bounds:
            new_bounds[ContextKeys.wood].dec(10)
        # else: goal didn't care about wood, no change

        return Action("gather wood", 100.), Goal(new_bounds)
class SellAxe(ActionFactory):
    """Sell one axe at the store for 10 money."""

    def compatible(self, goal: Goal) -> bool:
        """Applies when the goal needs money and does not require another location."""
        if not goal.bounds.get(ContextKeys.money, EMPTY_BOUND).low > 0:
            return False
        wanted_location = goal.bounds.get(ContextKeys.location, EMPTY_BOUND).low
        return wanted_location in [0, Location.store]

    def neighbor(self, goal: Goal) -> Tuple[Action, Goal]:
        """Return the "sell axe" action and the goal that must hold beforehand."""
        regressed = self._copy_bounds(goal)

        # precondition: location == store (replaces any existing location bound)
        regressed[ContextKeys.location] = Bound(ContextKeys.location, Location.store, Location.store)

        # precondition: have_axe >= 1; postcondition: have_axe -= 1
        axe = regressed.get(ContextKeys.have_axe)
        if axe is None:
            regressed[ContextKeys.have_axe] = Bound(ContextKeys.have_axe, 1, POS_INF)
        else:
            axe.inc(1)

        # postcondition: money += 10, so ten less is needed beforehand;
        # a goal that did not mention money is unaffected
        money = regressed.get(ContextKeys.money)
        if money is not None:
            money.dec(10)

        return Action("sell axe", 10.), Goal(regressed)
class GoTo(ActionFactory):
    """Travel to whatever location the goal requires."""

    def compatible(self, goal: Goal) -> bool:
        """Applies when the goal has a location bound with a positive lower edge."""
        location_bound = goal.bounds.get(ContextKeys.location, EMPTY_BOUND)
        return location_bound.low > 0

    def neighbor(self, goal: Goal) -> Tuple[Action, Goal]:
        """Return a ``goto(<location>)`` action and the goal minus its location bound."""
        target = Location(goal.bounds[ContextKeys.location].low)
        remaining = self._copy_bounds(goal)
        # postcondition: location becomes whatever the goal wanted, so the
        # location requirement disappears from the goal that precedes this step
        remaining.pop(ContextKeys.location)
        return Action(f'goto({target.name})', 5.), Goal(remaining)
# Every action the planner may use; get_neighbors tries them in this order.
action_factories: List[ActionFactory] = [
    factory_cls()
    for factory_cls in (BuyAxe, SellWood, ChopWood, GatherWood, SellAxe, GoTo)
]
@dataclass(order=True, eq=True, unsafe_hash=True)
class FitnessAndGoal:
    """Queue entry pairing a goal with the fitness it had when it was queued."""

    # Snapshot of goal.fitness_value taken at insertion time; it does not
    # follow later changes to the goal.
    fitness: float
    goal: Goal
class GoalQueue:
    """Priority queue of goals ordered by fitness, with lookup by goal.

    ``queue`` holds ``FitnessAndGoal`` entries sorted by negated fitness, so
    the entry with the lowest fitness is last and is what ``pop`` returns.
    ``goals`` maps each queued goal to its entry.
    """

    def __init__(self) -> None:
        self.queue = sortedcontainers.SortedSet(key=lambda x: -x.fitness)
        # mapping from a generic goal (with any fitness value) to the one
        # that's actually in the queue
        self.goals: MutableMapping[Goal, FitnessAndGoal] = {}

    def __len__(self) -> int:
        return len(self.goals)

    def __contains__(self, g: Goal) -> bool:
        return g in self.goals

    def __getitem__(self, g: Goal) -> Goal:
        """Return the queued goal object equal to ``g`` (KeyError if absent)."""
        return self.goals[g].goal

    def __iter__(self) -> Iterator[Goal]:
        """Yield goals in queue order (highest fitness first)."""
        for fg in self.queue:
            yield fg.goal

    def add(self, goal: Goal) -> None:
        """Queue ``goal`` under its current ``fitness_value``.

        If an equal goal is already queued, its old entry is replaced;
        otherwise the old entry would linger in ``queue`` after ``goals`` was
        overwritten and be handed out by a later ``pop``.
        """
        stale = self.goals.get(goal)
        if stale is not None:
            self.queue.discard(stale)
        f = FitnessAndGoal(goal.fitness_value, goal)
        self.goals[goal] = f
        self.queue.add(f)

    def remove(self, goal: Goal) -> None:
        """Remove the entry for ``goal`` (KeyError if it is not queued)."""
        # Look the entry up through ``goals``: it carries the fitness the goal
        # was queued with, which is what the sorted set is keyed on.
        f = self.goals[goal]
        self.queue.remove(f)
        del self.goals[goal]

    def pop(self) -> Goal:
        """Remove and return the goal with the lowest fitness."""
        f = self.queue.pop()
        del self.goals[f.goal]
        return f.goal
def heuristic_cost(current_state: Goal, initial_state: State) -> float:
    """Estimate the remaining cost between ``current_state`` and ``initial_state``.

    The estimate is the goal's ``delta`` to the initial state.
    """
    estimate = current_state.delta(initial_state)
    return estimate
def choose_cheapest(open_set: GoalQueue) -> Goal:
    """Remove and return the goal with the lowest path cost + heuristic cost.

    The queue keeps itself ordered by fitness, so this is just a ``pop``.
    """
    cheapest = open_set.pop()
    return cheapest
def reconstruct_path(came_from: Mapping[Goal, Tuple[Goal, Action]], current: Goal) -> Sequence[Tuple[Goal, Action]]:
    """Follow ``came_from`` links from ``current`` and return the visited steps.

    The first element pairs ``current`` with a zero-cost "initial" action; each
    following element is the ``(goal, action)`` pair recorded for the previous
    goal, until a goal with no recorded predecessor is reached.
    """
    steps = [(current, Action("initial", 0.))]
    node = current
    while node in came_from:
        node, edge = came_from[node]
        steps.append((node, edge))
    return steps
def get_neighbors(state: Goal) -> Sequence[Tuple[Action, Goal]]:
    """Return ``(action, resulting goal)`` for every factory compatible with ``state``."""
    return [
        factory.neighbor(state)
        for factory in action_factories
        if factory.compatible(state)
    ]
def cost(edge: Action) -> float:
    """Return the cost of traversing ``edge``."""
    edge_cost = edge.cost
    return edge_cost
class CKeys(enum.IntEnum):
    """Indexes into the COUNTERS list of search statistics."""

    in_open_set = 0
    in_closed_set = 1
    no_improvement = 2
    neighbor_options = 3
# One running tally per CKeys member, indexed by the member's value.
COUNTERS = [0 for _ in CKeys]
def astar(goal_state: Goal, initial_state: State) -> Sequence[Tuple[Goal, Action]]:
    """Search backwards from ``goal_state`` for a goal ``initial_state`` already meets.

    Starting with ``goal_state``, repeatedly take the queued goal with the
    lowest fitness (path cost + heuristic) and expand it through
    ``get_neighbors``.  The search stops at the first goal whose ``delta`` to
    ``initial_state`` is zero and returns ``reconstruct_path`` for it.

    Side effects: overwrites ``goal_value`` / ``fitness_value`` on the goals it
    touches (including ``goal_state``) and increments the module-level
    ``COUNTERS``.

    Raises:
        Exception: when the open set runs empty without finding such a goal.
    """
    global COUNTERS

    goal_state.goal_value = 0
    goal_state.fitness_value = heuristic_cost(goal_state, initial_state)
    open_set = GoalQueue()
    open_set.add(goal_state)

    # maps a goal to the canonical instance that carries its scores
    closed_set: MutableMapping[Goal, Goal] = {}
    came_from: MutableMapping[Goal, Tuple[Goal, Action]] = {}

    best_distance = math.inf
    best_goal: Optional[Goal] = None

    while len(open_set) > 0:
        current = choose_cheapest(open_set)
        logging.debug(f'considering {current} with f_score: {current.fitness_value}')
        closed_set[current] = current

        current_distance = current.delta(initial_state)
        if current_distance == 0.:
            logging.debug(f'compatible with initial state')
            return reconstruct_path(came_from, current)

        if current_distance < best_distance:
            logging.info(f'closest goal: {current} distance: {current_distance}')
            best_distance = current_distance
            best_goal = current

        for edge, destination in get_neighbors(current):
            COUNTERS[CKeys.neighbor_options] += 1

            # swap a freshly built neighbour for the instance already known to
            # the search, so its recorded goal_value is the one compared below
            if destination in closed_set:
                COUNTERS[CKeys.in_closed_set] += 1
                destination = closed_set[destination]
            elif destination in open_set:
                COUNTERS[CKeys.in_open_set] += 1
                destination = open_set[destination]

            tentative_g_score = current.goal_value + cost(edge)
            if not tentative_g_score < destination.goal_value:
                COUNTERS[CKeys.no_improvement] += 1
                continue

            # cheaper route found: record it and (re)queue the destination
            came_from[destination] = (current, edge)
            destination.goal_value = tentative_g_score
            destination.fitness_value = tentative_g_score + heuristic_cost(destination, initial_state)
            logging.debug(f'adding {destination} via {edge.name} to open set with g_score {destination.goal_value} f_score {destination.fitness_value}')

            if destination in open_set:
                # already queued under its old fitness; drop that entry so the
                # add below re-queues it under the new one
                open_set.remove(destination)
            elif destination in closed_set:
                # previously expanded, but now reachable more cheaply: re-open
                del closed_set[destination]
            open_set.add(destination)

    raise Exception("ohnoes")
if __name__ == "__main__":
    # Demo: starting with 25 money and no axe, plan how to reach 50 money.
    try:
        logging.basicConfig(stream=sys.stderr, level=logging.INFO)

        goal = Goal({ContextKeys.money: Bound(ContextKeys.money, 50, POS_INF)})
        initial_state = State({
            ContextKeys.forest: 1000,
            ContextKeys.have_axe: 0,
            ContextKeys.wood: 0,
            ContextKeys.money: 25,
            ContextKeys.location: 0,
        })
        logging.info(f'looking for a solution for: {goal}')
        logging.info(f'starting conditions: {initial_state}')

        starttime = time.perf_counter()
        solution = astar(goal, initial_state)
        endtime = time.perf_counter()

        solution_cost = sum(step_action.cost for _, step_action in solution)
        logging.info(f'solution of length {len(solution)} of cost {solution_cost} found in {endtime-starttime:.2f}s')

        counter_str = "\n".join(f'{k.name}:\t{COUNTERS[k]}' for k in CKeys)
        logging.info(f'counters:\n{counter_str}')

        # one line per step: action, cost, goal, its f-score, and its
        # remaining distance to the initial state
        for goal, action in solution:
            print(f'{action.name} {action.cost} {goal} {goal.fitness_value} {goal.delta(initial_state)}')
    except Exception as e:
        # drop into the debugger at the point of failure
        logging.error(f'handling exception {e}')
        pdb.post_mortem(sys.exc_info()[2])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment