Skip to content

Instantly share code, notes, and snippets.

@vzhong
Created October 17, 2016 02:15
Show Gist options
  • Save vzhong/f006894abc0d626c21394dfa943a4c42 to your computer and use it in GitHub Desktop.
Basic search algorithms
"""
backtracking extended list informed
Depth First Search: y y n
Breadth First Search: n y n
Hill Climbing y y y
Beam Search: y y y
"""
import heapq
class SearchAlgorithm(object):
    """Template for graph-search strategies.

    Subclasses choose the frontier discipline by overriding ``enqueue``
    (mandatory) and, optionally, ``dequeue``, ``reorder_new_states`` and
    ``prune_queue``.

    States are opaque objects exposing a hashable ``current`` attribute
    that identifies the node being visited (used for duplicate pruning).
    """

    def __init__(self, get_actions, take_action, prune_seen_states=True):
        """
        Args:
            get_actions: callable(state) -> iterable of actions available
                from ``state``.
            take_action: callable(state, action) -> successor state.
            prune_seen_states: if True, a node (``state.current``) is never
                expanded more than once across the whole search.
        """
        super().__init__()
        self.get_actions = get_actions
        self.take_action = take_action
        self.prune_seen_states = prune_seen_states

    def initialize_queue(self):
        # The frontier; the base implementation is a plain list.
        return []

    def extend_state(self, state, seen):
        """Expand ``state`` into its successors, skipping seen nodes."""
        new_states = []
        for a in self.get_actions(state):
            new_state = self.take_action(state, a)
            if self.prune_seen_states and new_state.current in seen:
                continue
            seen.add(new_state.current)
            new_states.append(new_state)
        return new_states

    def enqueue(self, queue, new_states):
        """Insert ``new_states`` into the frontier (strategy-specific)."""
        raise NotImplementedError()

    def dequeue(self, queue):
        # NOTE: list.pop(0) is O(n); acceptable for these toy problems.
        return queue.pop(0)

    def reorder_new_states(self, new_states):
        """Hook: order sibling successors before they are enqueued."""
        return new_states

    def prune_queue(self, queue):
        """Hook: shrink the frontier in place (e.g. beam search)."""
        pass

    def __call__(self, state, terminate, callback=None):
        """Search from ``state`` until ``terminate(state)`` is truthy.

        Args:
            state: initial state.
            terminate: callable(state) -> bool goal test.
            callback: optional callable invoked on every dequeued state
                (e.g. to count expansions).

        Returns:
            The first dequeued state satisfying ``terminate``, or None
            when the frontier is exhausted.
        """
        queue = self.initialize_queue()
        self.enqueue(queue, [state])
        seen = set()
        while queue:  # idiomatic truthiness test instead of len(queue)
            s = self.dequeue(queue)
            if callback is not None:
                callback(s)
            if terminate(s):
                return s
            new_states = self.extend_state(s, seen)
            new_states = self.reorder_new_states(new_states)
            self.enqueue(queue, new_states)
            self.prune_queue(queue)
        return None
class DepthFirstSearch(SearchAlgorithm):
    """LIFO frontier: the most recently discovered state is expanded next."""

    def enqueue(self, queue, new_states):
        # Prepend the successors in reverse so the first successor ends up
        # at the head — same net effect as inserting each one at index 0.
        queue[:0] = list(new_states)[::-1]
class BreadthFirstSearch(SearchAlgorithm):
    """FIFO frontier: states are expanded in order of discovery."""

    def enqueue(self, queue, new_states):
        # Append at the tail; combined with pop(0) this yields FIFO order.
        queue.extend(new_states)
class HillClimbingSearch(DepthFirstSearch):
    """Depth-first search whose sibling order is guided by a heuristic.

    Successors are sorted ascending by ``heuristic``; since
    ``DepthFirstSearch.enqueue`` prepends them, the sibling with the
    largest heuristic value ends up at the head of the frontier.
    """

    def __init__(self, get_actions, take_action, heuristic, prune_seen_states=True):
        super().__init__(get_actions, take_action, prune_seen_states=prune_seen_states)
        self.heuristic = heuristic

    def reorder_new_states(self, new_states):
        ordered = list(new_states)
        ordered.sort(key=self.heuristic)
        return ordered
class BeamSearch(BreadthFirstSearch):
    """Breadth-first search whose frontier is truncated after every step.

    Only the ``beam_size`` states with the highest heuristic values are
    kept in the queue.
    """

    def __init__(self, get_actions, take_action, heuristic, beam_size, prune_seen_states=True):
        super().__init__(get_actions, take_action, prune_seen_states=prune_seen_states)
        self.heuristic = heuristic
        self.beam_size = beam_size

    def prune_queue(self, queue):
        # Score every frontier state and keep only the top-scoring beam.
        scored = [(self.heuristic(state), state) for state in queue]
        best = heapq.nlargest(self.beam_size, scored, key=lambda pair: pair[0])
        queue[:] = [state for _, state in best]
if __name__ == '__main__':
    import random
    from collections import namedtuple, defaultdict

    import networkx as nx

    # A search state: the node we are at plus the path taken to reach it.
    State = namedtuple('State', ['current', 'history'])

    def get_dists_to_target(g, target, dists=None):
        """Shortest distance from every node to ``target`` by relaxation.

        Recursively propagates improvements outward from ``target``.
        Returns a defaultdict mapping node -> distance (inf if unreachable).
        """
        if dists is None:
            dists = defaultdict(lambda: float('inf'))
            dists[target] = 0
        for n in g.neighbors(target):
            weight = g[target][n]['weight']
            new_weight = dists[target] + weight
            if new_weight < dists[n]:
                dists[n] = new_weight
                # Only recurse when we improved, otherwise this would loop.
                get_dists_to_target(g, n, dists)
        return dists

    def toy_graph():
        """
        C---E
        |
        4
        |
        B       G
       /|      /
      5 4     5
     /  |    /
    S-3-A-3-D
        """
        g = nx.Graph()
        g.add_edge('S', 'A', weight=3)
        g.add_edge('S', 'B', weight=5)
        g.add_edge('A', 'B', weight=4)
        g.add_edge('B', 'C', weight=4)
        g.add_edge('C', 'E', weight=0)
        g.add_edge('A', 'D', weight=3)
        g.add_edge('D', 'G', weight=5)

        def dist_to_target(state):
            # Hand-computed distances to G, used to sanity-check the
            # recursive relaxation above.
            dists = {'G': 0}
            dists['D'] = dists['G'] + 5
            dists['A'] = dists['D'] + 3
            dists['S'] = dists['A'] + 3
            dists['B'] = dists['A'] + 4
            dists['C'] = dists['B'] + 4
            dists['E'] = dists['C'] + 0
            other = get_dists_to_target(g, 'G')
            for k, v in dists.items():
                assert other[k] == v, 'differ for k: {}, expect {}, got {}'.format(k, v, other[k])
            # Negated so that a larger heuristic value means closer to G.
            return -dists[state.current]

        return g, 'S', 'G', dist_to_target

    def large_graph(num_nodes=50, num_edges=10):
        """Random graph plus a guaranteed chain path 0-1-...-(num_nodes-1)."""
        g = nx.Graph()
        dists = {}
        target = 0
        # add random edges between distinct random endpoints
        for i in range(num_edges):
            start = end = random.randint(0, num_nodes - 1)
            while end == start:
                end = random.randint(0, num_nodes - 1)
            weight = random.randint(1, 5)
            # BUGFIX: previously wired every random edge to ``target``
            # instead of the freshly drawn ``end``.
            g.add_edge(start, end, weight=weight)
        # ensure that there is at least 1 path
        for i in range(1, num_nodes):
            g.add_edge(i - 1, i, weight=5)
        dists = get_dists_to_target(g, 0)
        return g, 0, num_nodes - 1, lambda state: -dists[state.current]

    def get_controllers(g, end):
        """Build the (get_actions, take_action, terminate) triple for ``g``."""
        def get_actions(state):
            # Never revisit a node already on this state's path.
            actions = [n for n in g.neighbors(state.current) if n not in state.history]
            return actions

        def take_action(state, action):
            return State(action, state.history + [state.current])

        def terminate(state):
            return state.current == end

        return get_actions, take_action, terminate

    def run_algorithms(get_graph):
        """Run every search algorithm on the problem built by ``get_graph``."""
        g, start, end, heuristic = get_graph()
        get_actions, take_action, terminate = get_controllers(g, end)
        start_state = State(start, [])
        algs = [
            DepthFirstSearch(get_actions, take_action),
            BreadthFirstSearch(get_actions, take_action),
            HillClimbingSearch(get_actions, take_action, heuristic=heuristic),
            BeamSearch(get_actions, take_action, heuristic=heuristic, beam_size=3),
        ]
        print('heuristic ground truth: {}\n'.format(heuristic(State(end, []))))
        for alg in algs:
            total = [0]  # mutable cell so the callback can count dequeues

            def callback(state, _total=total):  # bind the cell per iteration
                _total[0] += 1

            print('Using algorithm: {}'.format(alg.__class__.__name__))
            path = alg(start_state, terminate, callback=callback)
            print('done in {} steps'.format(total[0]))
            print(path)
            print()

    random.seed(0)
    print('toy problem')
    run_algorithms(toy_graph)
    print('large graph problem')
    run_algorithms(large_graph)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment