initial_reprover_in_pypentrograph.py
import heapq
import math
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import List, Optional, Tuple, Union

import pexpect
import torch
from loguru import logger

from pantograph.server import Server, ServerError
from pantograph.expr import GoalState, Tactic, TacticCalc, TacticHave
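
# --- Search-tree classes (minimal sketch, an assumption) ---------------------
# This gist uses Status, InternalNode, ProofFinishedNode, ErrorNode, and Edge
# without defining them; in ReProver they live in prover/search_tree.py. The
# definitions below are a hedged, minimal sketch sufficient to run this file,
# not the authoritative implementations.


class Status(Enum):
    PROVED = "Proved"  # A proof through this node has been found.
    FAILED = "Failed"  # No proof can pass through this node.
    OPEN = "Open"      # Still a candidate for expansion.


@dataclass
class ProofFinishedNode:
    """A terminal node: the tactic closed all remaining goals."""
    inner: GoalState
    status: Status = Status.PROVED


@dataclass
class ErrorNode:
    """A terminal node: the tactic failed with a server error."""
    inner: ServerError
    status: Status = Status.FAILED


class InternalNode:
    """An open goal state; explored once its out-edges have been set."""

    def __init__(self, state: GoalState, cumulative_logprob: float) -> None:
        self.state = state
        self.cumulative_logprob = cumulative_logprob
        self.status = Status.OPEN
        self.in_edges: List["Edge"] = []
        self._out_edges: Optional[List["Edge"]] = None

    @property
    def is_explored(self) -> bool:
        return self._out_edges is not None

    @property
    def out_edges(self) -> Optional[List["Edge"]]:
        return self._out_edges

    @out_edges.setter
    def out_edges(self, edges: List["Edge"]) -> None:
        # Setting the out-edges marks the node as explored and triggers the
        # status recomputation that the comment in `_step` refers to.
        self._out_edges = edges
        self._recompute_status()

    @property
    def priority(self) -> float:
        return self.cumulative_logprob

    def __lt__(self, other: "InternalNode") -> bool:
        # `heapq` is a min-heap; invert the comparison so the node with the
        # highest cumulative logprob pops first.
        return self.cumulative_logprob > other.cumulative_logprob

    def _recompute_status(self) -> None:
        assert self._out_edges is not None
        old_status = self.status
        if any(e.dst.status == Status.PROVED for e in self._out_edges):
            self.status = Status.PROVED
        elif all(e.dst.status == Status.FAILED for e in self._out_edges):
            self.status = Status.FAILED
        if self.status != old_status:
            # Propagate the change upward to explored parents.
            for edge in self.in_edges:
                if edge.src.is_explored:
                    edge.src._recompute_status()

    def extract_proof(self) -> Optional[List["Edge"]]:
        """Follow proved edges down to a ProofFinishedNode."""
        if self.status != Status.PROVED or self._out_edges is None:
            return None
        for edge in self._out_edges:
            if isinstance(edge.dst, ProofFinishedNode):
                return [edge]
            if isinstance(edge.dst, InternalNode) and edge.dst.status == Status.PROVED:
                tail = edge.dst.extract_proof()
                if tail is not None:
                    return [edge] + tail
        return None


@dataclass
class Edge:
    """An edge in the search tree: applying `tactic` at `src` yields `dst`."""
    tactic: Tactic
    src: InternalNode
    dst: Union[InternalNode, ProofFinishedNode, ErrorNode]
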
@dataclass(frozen=True)
class SearchResult:
    """The result of attempting to prove a theorem."""

    goal: str
    status: Status
    proof: Optional[List[Tactic]]
    actor_time: float
    environment_time: float
    total_time: float
    num_total_nodes: int
    num_searched_nodes: int

class BestFirstSearchProver:
    """A prover that uses best-first search to find proofs using a tactic generator."""

    def __init__(
        self,
        tac_gen,  # A given tactic generator.
        timeout: int,
        num_sampled_tactics: int,
        debug: bool,
        server: Server,
    ) -> None:
        self.tac_gen = tac_gen
        self.timeout = timeout
        self.num_sampled_tactics = num_sampled_tactics
        self.debug = debug
        self.server = server

        self.num_expansions = 0
        self.actor_time = 0.0
        self.environment_time = 0.0
        self.total_time = None
    def search(self, goal: str) -> Optional[SearchResult]:
        logger.info(f"Proving {goal}")
        self.goal = goal
        self.actor_time = 0.0
        self.environment_time = 0.0
        self.num_expansions = 0

        init_state = self.server.goal_start(goal)
        self.root = InternalNode(state=init_state, cumulative_logprob=0.0)
        self.nodes = {init_state: self.root}
        self.priority_queue = [self.root]

        with torch.no_grad():
            try:
                self._best_first_search()
            except ServerError as ex:
                logger.warning(f"Server crashed with {ex} when proving {self.goal}")

        if self.root.status == Status.PROVED:
            proof = [e.tactic for e in self.root.extract_proof()]
        else:
            proof = None

        result = SearchResult(
            goal=goal,
            status=self.root.status,
            proof=proof,
            actor_time=self.actor_time,
            environment_time=self.environment_time,
            total_time=self.total_time,
            num_total_nodes=len(self.nodes),
            num_searched_nodes=self.num_expansions,
        )
        logger.info(result)
        return result
    def _best_first_search(self) -> None:
        time_start = time.monotonic()

        while True:
            if len(self.priority_queue) == 0:
                logger.info("Ran out of nodes to search.")
                break

            try:
                self._step()
            except pexpect.exceptions.TIMEOUT:
                # Inherited from ReProver's LeanDojo backend; Pantograph's
                # server may surface timeouts as ServerError instead.
                assert time.monotonic() - time_start >= self.timeout

            self.total_time = time.monotonic() - time_start
            if self.total_time > self.timeout:
                if self.root.status == Status.PROVED:
                    logger.info("Found a proof but timed out.")
                self.root.status = Status.OPEN
                logger.info("Search timed out.")
                break

            if self.root.status == Status.FAILED:
                logger.info("Failed early!")
                break

            if self.root.status == Status.PROVED:
                logger.info("Found a proof!")
                break
    def _step(self):
        """
        Perform a single step of search.

        Selects the node with the highest priority, queries the model for suggested
        tactics, and tries each tactic in the environment, creating and enqueuing
        a new node for each valid result.
        """
        # Expand the node with the highest priority.
        search_node = heapq.heappop(self.priority_queue)
        logger.debug(f"Expanding node: {search_node}")

        if self.debug:
            assert all(
                search_node.priority >= node.priority for node in self.priority_queue
            )

        ts = str(search_node.state.goals[-1])
        suggestions = self._generate_tactics(ts)

        # Try all tactics in order of descending logprob, and collect the results. Any
        # new nodes are added to `self.nodes`, and edges are added to the result node.
        results = []
        for tactic, logprob in suggestions:
            edge, finished = self._run_tactic(search_node, tactic, logprob)
            results.append(edge)
            if finished:
                break

        # Store the fixed out-edges of this node, marking it as explored.
        # This will trigger recursively recomputing tree statistics.
        search_node.out_edges = results
        self.num_expansions += 1

        # If we're running in debug mode, run a full test suite each step.
        if self.debug:
            assert self.num_expansions == sum(
                node.is_explored
                for node in self.nodes.values()
                if isinstance(node, InternalNode)
            )
            self.check_invariants()
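
    # `check_invariants` is called above but never defined in the gist; in
    # ReProver the prover runs a full consistency suite here. This is a hedged
    # stub that only checks the queue's basic invariant, not the real suite.
    def check_invariants(self) -> None:
        for node in self.priority_queue:
            assert isinstance(node, InternalNode)
            assert node.status == Status.OPEN and not node.is_explored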
    def _generate_tactics(self, ts: str) -> List[Tuple[Tactic, float]]:
        t0 = time.monotonic()

        suggestions = self.tac_gen.generate(
            state=ts,
            theorem=self.goal,
            num_samples=self.num_sampled_tactics,
        )

        self.actor_time += time.monotonic() - t0
        logger.debug(f"Tactic suggestions: {suggestions}")
        return suggestions
    def _run_tactic(
        self, node: InternalNode, tactic: Tactic, logprob: float
    ) -> Tuple[Edge, bool]:
        t0 = time.monotonic()
        try:
            response = self.server.goal_tactic(node.state, node.state.goals[-1], tactic)
        except ServerError as ex:
            response = ex
        elapsed = time.monotonic() - t0
        self.environment_time += elapsed

        try:
            # If we've seen this response before, use the existing node.
            result_node = self.nodes[response]
        except KeyError:
            # Build a new node. Check for errors before touching `.goals`,
            # since a ServerError has no `goals` attribute.
            if isinstance(response, ServerError):
                result_node = ErrorNode(response)
            elif not response.goals:
                result_node = ProofFinishedNode(response)
            else:
                assert isinstance(response, GoalState)
                result_node = InternalNode(
                    state=response,
                    cumulative_logprob=logprob + node.cumulative_logprob,
                )

            if result_node.status == Status.OPEN:  # Don't search proved/failed nodes.
                heapq.heappush(self.priority_queue, result_node)

        # Record the new node and add it to the search queue.
        self.nodes[response] = result_node

        # Build an edge connecting these nodes.
        # Will be added to the source node externally.
        edge = Edge(tactic=tactic, src=node, dst=result_node)
        if isinstance(result_node, InternalNode):
            result_node.in_edges.append(edge)

        return edge, isinstance(result_node, ProofFinishedNode)
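
# --- Usage sketch (assumptions flagged) ---------------------------------------
# A minimal way to drive the prover. `DummyTacticGenerator` is hypothetical:
# it stands in for ReProver's learned tactic generator and simply replays a
# fixed tactic list with made-up logprobs, matching the
# `generate(state=..., theorem=..., num_samples=...)` interface this file assumes.


class DummyTacticGenerator:
    def __init__(self, tactics: List[str]) -> None:
        self.tactics = tactics

    def generate(
        self, state: str, theorem: str, num_samples: int
    ) -> List[Tuple[Tactic, float]]:
        # Earlier suggestions get higher (less negative) fake logprobs.
        return [(tac, -float(i)) for i, tac in enumerate(self.tactics[:num_samples])]


if __name__ == "__main__":
    server = Server()  # assumes a local Lean toolchain that Pantograph can find
    prover = BestFirstSearchProver(
        tac_gen=DummyTacticGenerator(["intro n", "simp"]),
        timeout=60,
        num_sampled_tactics=2,
        debug=False,
        server=server,
    )
    result = prover.search("forall (n : Nat), n + 0 = n")
    print(result)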