initial_reprover_in_pypentrograph.py
import heapq
import math
import time

import pexpect
import torch
from pantograph.server import Server, ServerError
from pantograph.expr import GoalState, TacticHave, TacticCalc, Tactic
from loguru import logger
from dataclasses import dataclass, field
from typing import Optional, List, Tuple

# `Status`, `InternalNode`, `ProofFinishedNode`, `ErrorNode`, and `Edge` are
# referenced below but not defined in this gist; they match the node/edge
# classes in ReProver's `prover/search_tree.py` (the import path here is an
# assumption -- adjust it to wherever those classes live in your setup).
from prover.search_tree import Status, InternalNode, ProofFinishedNode, ErrorNode, Edge
@dataclass(frozen=True)
class SearchResult:
    """The result of attempting to prove a theorem."""

    goal: str
    status: Status
    proof: Optional[List[Tactic]]

    # Statistics collected during proof search.
    actor_time: float
    environment_time: float
    total_time: float
    num_total_nodes: int
    num_searched_nodes: int
class BestFirstSearchProver:
    """A prover that uses best-first search to find proofs using a tactic generator."""

    def __init__(
        self,
        tac_gen,  # A given tactic generator.
        timeout: int,
        num_sampled_tactics: int,
        debug: bool,
        server: Server,
    ) -> None:
        self.tac_gen = tac_gen
        self.timeout = timeout
        self.num_sampled_tactics = num_sampled_tactics
        self.debug = debug
        self.server = server

        self.num_expansions = 0
        self.actor_time = 0.0
        self.environment_time = 0.0
        self.total_time = None
    def search(self, goal: str) -> Optional[SearchResult]:
        logger.info(f"Proving {goal}")

        self.goal = goal
        self.actor_time = 0.0
        self.environment_time = 0.0
        self.num_expansions = 0

        init_state = self.server.goal_start(goal)
        self.root = InternalNode(state=init_state, cumulative_logprob=0.0)
        self.nodes = {init_state: self.root}
        self.priority_queue = [self.root]

        with torch.no_grad():
            try:
                self._best_first_search()
            except ServerError as ex:
                logger.warning(f"Server crashed with {ex} when proving {self.goal}")

        if self.root.status == Status.PROVED:
            proof = [e.tactic for e in self.root.extract_proof()]
        else:
            proof = None

        result = SearchResult(
            goal=goal,
            status=self.root.status,
            proof=proof,
            actor_time=self.actor_time,
            environment_time=self.environment_time,
            total_time=self.total_time,
            num_total_nodes=len(self.nodes),
            num_searched_nodes=self.num_expansions,
        )
        logger.info(result)
        return result
    def _best_first_search(self) -> None:
        time_start = time.monotonic()

        while True:
            if len(self.priority_queue) == 0:
                logger.info("Ran out of nodes to search.")
                break

            try:
                self._step()
            except pexpect.exceptions.TIMEOUT:
                assert time.monotonic() - time_start >= self.timeout

            self.total_time = time.monotonic() - time_start
            if self.total_time > self.timeout:
                if self.root.status == Status.PROVED:
                    logger.info("Found a proof but timed out.")
                self.root.status = Status.OPEN
                logger.info("Search timed out.")
                break

            if self.root.status == Status.FAILED:
                logger.info("Failed early!")
                break

            if self.root.status == Status.PROVED:
                logger.info("Found a proof!")
                break
    def _step(self):
        """
        Perform a single step of search.

        Selects the node with the highest priority, queries the model for suggested
        tactics, and tries each tactic in the environment, creating and enqueuing
        a new node for each valid result.
        """
        # Search the node with highest priority.
        search_node = heapq.heappop(self.priority_queue)
        logger.debug(f"Expanding node: {search_node}")

        if self.debug:
            assert all(
                search_node.priority >= node.priority for node in self.priority_queue
            )

        ts = str(search_node.state.goals[-1])
        suggestions = self._generate_tactics(ts)

        # Try all tactics in order of descending logprob, and collect the results. Any
        # new nodes are added to `self.nodes`, and edges are added to the result node.
        results = []
        for tactic, logprob in suggestions:
            edge, finished = self._run_tactic(search_node, tactic, logprob)
            results.append(edge)
            if finished:
                break

        # Store the fixed out edges of this node, marking it as explored.
        # This will trigger recursively recomputing tree statistics.
        search_node.out_edges = results
        self.num_expansions += 1

        # If we're running in debug mode, run a full test suite each step.
        # (`check_invariants` is not defined in this gist; it is assumed to be
        # the debug-only sanity check from ReProver's prover.)
        if self.debug:
            assert self.num_expansions == sum(
                node.is_explored
                for node in self.nodes.values()
                if isinstance(node, InternalNode)
            )
            self.check_invariants()
    def _generate_tactics(self, ts: str) -> List[Tuple[Tactic, float]]:
        t0 = time.monotonic()

        suggestions = self.tac_gen.generate(
            state=ts,
            theorem=self.goal,
            num_samples=self.num_sampled_tactics,
        )

        self.actor_time += time.monotonic() - t0
        logger.debug(f"Tactic suggestions: {suggestions}")
        return suggestions
    def _run_tactic(self, node: InternalNode, tactic: Tactic, logprob: float) -> Tuple[Edge, bool]:
        t0 = time.monotonic()
        try:
            # Note: depending on the Pantograph version, `goal_tactic` may expect
            # a goal index rather than a `Goal` object as its second argument.
            response = self.server.goal_tactic(node.state, node.state.goals[-1], tactic)
        except ServerError as ex:
            response = ex
        elapsed = time.monotonic() - t0
        self.environment_time += elapsed

        try:
            # If we've seen this response before, use the existing node.
            result_node = self.nodes[response]
        except KeyError:
            # Build a new node. Check for errors first: a `ServerError` has no
            # `goals` attribute, so it must be handled before the finished check.
            if isinstance(response, ServerError):
                result_node = ErrorNode(response)
            elif not response.goals:
                result_node = ProofFinishedNode(response)
            else:
                assert isinstance(response, GoalState)
                result_node = InternalNode(
                    state=response,
                    cumulative_logprob=logprob + node.cumulative_logprob,
                )

            if result_node.status == Status.OPEN:  # Don't search proved/failed nodes.
                heapq.heappush(self.priority_queue, result_node)

        # Record the new node and add it to the search queue.
        self.nodes[response] = result_node

        # Build an edge connecting these nodes.
        # Will be added to the source node externally.
        edge = Edge(tactic=tactic, src=node, dst=result_node)

        if isinstance(result_node, InternalNode):
            result_node.in_edges.append(edge)

        return edge, isinstance(result_node, ProofFinishedNode)
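
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original gist). It assumes a toy
# tactic generator exposing the `generate(state=..., theorem=..., num_samples=...)`
# interface used above; `DummyTacticGenerator`, its fixed tactic list, and the
# example goal string are hypothetical stand-ins for a real model such as
# ReProver's. `Server()` is assumed to work with Pantograph's defaults.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    class DummyTacticGenerator:
        """Returns a fixed list of (tactic, logprob) pairs, ignoring the state."""

        def generate(self, state: str, theorem: str, num_samples: int):
            candidates = [("intro n", math.log(0.5)), ("simp", math.log(0.3))]
            return candidates[:num_samples]

    server = Server()
    prover = BestFirstSearchProver(
        tac_gen=DummyTacticGenerator(),
        timeout=60,
        num_sampled_tactics=2,
        debug=False,
        server=server,
    )
    # Whether this particular goal is actually closed depends on the tactics
    # the generator proposes; the point is only to show the calling convention.
    result = prover.search("forall (n : Nat), n + 0 = n")
    print(result)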