# Source: GitHub gist by @fullstackwebdev, created April 10, 2024
# (gist id: fullstackwebdev/1c41e65a65af1adf0c6d6466f0369770)
import json
from os import link
from pprint import pprint
import queue
import random
import threading
import dspy
from more_itertools import only
from networkx import nodes
from pydantic import BaseModel
from typing import List
from tqdm import tqdm
from utils.load_documents import load_data
from dspy.teleprompt.random_search import BootstrapFewShotWithRandomSearch
from dspy.teleprompt.bootstrap import BootstrapFewShot
# Candidate theorems to try: Addition_Commutativity, Subtraction_Nonnegativity, Multiplication_Associativity.
class CoqGenSignature(dspy.Signature):
    """The task is to generate valid coq code for a given theorem or proof statement. The generated code should be syntactically correct and logically valid. Think out your proof step by step, then implement a pseudocode version like coq, then using the pseudocode as a reference, generate a 100% syntax correct coq code including any necessary imports."""

    # NOTE: the docstring above is written as the LM-facing task instruction
    # (DSPy signature convention), so its wording is part of the prompt.

    # Natural-language statement of what should be proved.
    theorem = dspy.InputField(desc="theorem to be proved")
    # Error text from an earlier attempt, or the literal "N/A".
    previous_error = dspy.InputField(desc="previous error message or N/A")
    # proof = dspy.OutputField(desc="proof in natural language")
    # The attribute keeps its original (misspelled) name for compatibility;
    # the explicit prefix makes the label shown to the model read correctly.
    psuedocode = dspy.OutputField(prefix="Pseudocode:")
    # The prefix opens a ```coq fence, so the completion may end with a
    # closing ``` that has to be stripped before the code is run.
    coq_code = dspy.OutputField(desc="machine runnable code no comments", prefix="```coq\n")
    # final_comments = dspy.OutputField()
class CoqRewrite(dspy.Signature):
    """You are given coq code and an error message. Your goal is to understand the error message and rewrite the coq code to fix the error. The generated code should be syntactically correct and logically valid. Think out your proof step by step, then implement a pseudocode version like coq, then using the pseudocode as a reference, generate a 100% syntax correct coq code including any necessary imports."""

    # NOTE: the docstring above is the LM-facing repair instruction. Unlike
    # CoqGenSignature there is no dedicated pseudocode output field here.

    # Coq source that failed to check.
    bad_coq_code = dspy.InputField()
    # Error text produced when checking bad_coq_code, or the literal "N/A".
    previous_error = dspy.InputField(desc="previous error message or N/A")
    # Same ```coq prefix as CoqGenSignature; the output may end with a closing fence.
    coq_code = dspy.OutputField(desc="machine runnable code no comments", prefix="```coq\n")
from dspy import Example
# Define the Coq examples: hand-written (theorem, gold proof) pairs that are
# split into the training / validation sets further down.
#
# The gold proof is stored under `coq_code` so it lines up with the
# `coq_code` output field of CoqGenSignature and with the `.coq_code`
# attribute that run_example_and_check() executes.
#
# Lemma names must not clash with names Coq already defines (eq_refl,
# eq_sym and eq_trans in the prelude; app_assoc in Coq.Lists.List).
# Redefining one makes coqtop reject the whole script ("... already exists").
#
# Proofs containing Coq's `\/` use raw strings, because `\/` is not a valid
# Python escape sequence.
examples = [
    Example(
        theorem="Proof that Mapping a Composed Function Equals Composing the Maps",
        coq_code="""
Require Import List.
Theorem map_compose_fusion : forall (A B C : Type) (l : list A) (f : A -> B) (g : B -> C),
map (fun x => g (f x)) l = map g (map f l).
Proof.
intros A B C l f g.
induction l as [| a l' IHl'].
- simpl. reflexivity.
- simpl. rewrite IHl'. reflexivity.
Qed."""),
    Example(
        theorem="Proof of Multiplication by Zero",
        coq_code="""Theorem mult_0_r : forall n : nat, n * 0 = 0.
Proof.
intros n.
induction n as [| n' IHn'].
- reflexivity.
- simpl. rewrite IHn'. reflexivity.
Qed."""),
    Example(
        theorem="Proof of Commutativity of Addition",
        coq_code="""Theorem plus_comm : forall x y : nat, x + y = y + x.
Proof.
intros x y.
induction x as [| x' IHx'].
- simpl. rewrite <- plus_n_O. reflexivity.
- simpl. rewrite IHx'. rewrite <- plus_n_Sm. reflexivity.
Qed."""),
    Example(
        theorem="Proof of Lemma: 1 + 1 = 2",
        coq_code="""Lemma sample_lemma: 1 + 1 = 2.
Proof.
simpl.
reflexivity.
Qed."""),
    Example(
        theorem="Proof of Transitivity of Equality",
        coq_code="""Lemma eq_trans_nat : forall (a b c : nat), a = b -> b = c -> a = c.
Proof.
intros a b c H1 H2.
rewrite H1. rewrite H2. reflexivity.
Qed."""),
    Example(
        theorem="Proof of Addition by Zero",
        coq_code="""Lemma add_0_r : forall n : nat, n + 0 = n.
Proof.
intros n.
induction n as [| n' IHn'].
- reflexivity.
- simpl. rewrite IHn'. reflexivity.
Qed."""),
    Example(
        theorem="Proof of Commutativity of Addition",
        coq_code="""Lemma add_comm : forall a b : nat, a + b = b + a.
Proof.
intros a b.
induction a as [| a' IHa'].
- simpl. rewrite <- plus_n_O. reflexivity.
- simpl. rewrite IHa'. rewrite <- plus_n_Sm. reflexivity.
Qed."""
    ),
    Example(
        theorem="Proof of Symmetry of Equality",
        coq_code="""
Lemma eq_sym_nat : forall (a b : nat), a = b -> b = a.
Proof.
intros a b H.
rewrite H. reflexivity.
Qed.
"""),
    Example(
        theorem="Proof that 0 is the Identity Element for Addition",
        coq_code="""
Lemma zero_add : forall n : nat, 0 + n = n.
Proof.
intros n. simpl. reflexivity.
Qed.
"""),
    Example(
        theorem="Proof of Multiplication by One",
        coq_code="""
Lemma mult_1_r : forall n : nat, n * 1 = n.
Proof.
intros n.
induction n as [| n' IHn'].
- reflexivity.
- simpl. rewrite IHn'. reflexivity.
Qed.
"""),
    Example(
        theorem="Proof of Reflexivity of Equality",
        coq_code="""
Lemma eq_refl_nat : forall (a : nat), a = a.
Proof.
intros a. reflexivity.
Qed.
"""),
    Example(
        theorem="Proof of Non-Zero Divisor",
        coq_code=r"""
Lemma nonzero_div : forall a b : nat, a * b = 0 -> a = 0 \/ b = 0.
Proof.
intros a b H. destruct a.
- left. reflexivity.
- right. destruct b.
+ reflexivity.
+ simpl in H. inversion H.
Qed.
"""),
    Example(
        theorem="Proof of Simplification of Multiplication",
        coq_code=r"""
Lemma mult_simplify : forall a b : nat, a * b = 0 -> a = 0 \/ b = 0.
Proof.
intros a b H.
destruct a.
- left. reflexivity.
- destruct b.
+ right. reflexivity.
+ inversion H.
Qed.
"""),
    Example(
        theorem="Proof of Neutral Element for Multiplication",
        coq_code="""
Lemma mult_1_l : forall n : nat, 1 * n = n.
Proof.
intros n.
simpl. destruct n.
- reflexivity.
- simpl. rewrite <- plus_n_O. reflexivity.
Qed.
"""),
    Example(
        theorem="Proof that Subtraction by Itself Equals Zero",
        coq_code="""
Lemma sub_self : forall n : nat, n - n = 0.
Proof.
intros n.
induction n as [| n' IHn'].
- simpl. reflexivity.
- simpl. apply IHn'.
Qed.
"""),
    Example(
        theorem="Proof of List Concatenation is Associative",
        coq_code="""
Require Import Coq.Lists.List.
Import ListNotations.
Lemma concat_assoc : forall A (l m n: list A), (l ++ m) ++ n = l ++ (m ++ n).
Proof.
intros A l m n. induction l as [| a l' IHl'].
- reflexivity.
- simpl. rewrite IHl'. reflexivity.
Qed.
"""),
    Example(
        theorem="Proof that Zero is the Right Identity for Addition",
        coq_code="""
Lemma add_0_r_corrected : forall n : nat, n + 0 = n.
Proof.
intros n. induction n as [| n' IHn'].
- reflexivity.
- simpl. rewrite IHn'. reflexivity.
Qed.
"""),
    Example(
        theorem="Proof of the Reflexivity of Equality (Simplified)",
        coq_code="""
Lemma eq_refl_simplified : forall (a : nat), a = a.
Proof.
reflexivity.
Qed.
"""),
    Example(
        theorem="Proof of the Symmetry of Equality for Natural Numbers",
        coq_code="""
Lemma eq_symm : forall (a b : nat), a = b -> b = a.
Proof.
intros a b H. rewrite H. reflexivity.
Qed.
"""),
    # Negative example (kept disabled): a deliberately broken proof.
    # Example(
    #     theorem="Failing Theory",
    #     coq_code="""
    # Lemma example_lemma : forall n : nat, n * 0 = n.
    # Proof.
    # intros n.
    # induction n as [| n' IHn'].
    # - reflexivity.
    # - simpl. (* Mistake: Should be IHn', not IHn *)
    # rewrite IHn. (* This line contains the deliberate mistake *)
    # reflexivity.
    # Qed.
    # """
    # ),
]
class CoqModule(dspy.Module):
    """Generate a Coq proof for a theorem, then repair it using coqtop feedback."""

    def __init__(self):
        super().__init__()
        # First attempt: theorem (+ previous error) -> pseudocode + Coq code.
        self.coq_generate = dspy.ChainOfThought(CoqGenSignature)
        # Repair step: failing Coq code + its error text -> rewritten Coq code.
        self.coq_fix = dspy.ChainOfThought(CoqRewrite)

    def forward(self, theorem: str, num_hops: int = 2) -> dspy.Prediction:
        """Produce Coq code for *theorem*, attempting up to *num_hops* repairs.

        Each hop checks the current candidate with run_example_and_check().
        The first candidate that passes is returned immediately. Otherwise
        the candidate is handed to the repair predictor together with the
        coqtop error text, and the repaired code becomes the next candidate.
        If no checked candidate passes, the most recent repair attempt is
        returned. factuality_metric re-checks whatever is returned.

        Args:
            theorem: natural-language statement to prove.
            num_hops: maximum number of repair attempts (default 2).

        Returns:
            A ``dspy.Prediction`` with a single ``coq_code`` field.
        """
        coq_code = self.coq_generate(theorem=theorem, previous_error="N/A").coq_code
        for _ in range(num_hops):
            candidate = dspy.Example(theorem=theorem, coq_code=coq_code)
            passed, _output_str, error_str = run_example_and_check(candidate)
            if passed:
                break
            # Predictors take their input fields as keyword arguments. Feed
            # the *current* candidate back so each hop builds on the last fix.
            coq_code = self.coq_fix(
                bad_coq_code=coq_code,
                previous_error=error_str or "N/A",
            ).coq_code
        return dspy.Prediction(coq_code=coq_code)
import subprocess
def run_example_and_check(example):
output = subprocess.run(['coqtop', '-quiet'], input=example.coq_code, text=True, capture_output=True)
if "No more goals." in output.stdout:
result = True
output_str = "Success"
error_str = None # Set error to None for success case
elif "Goals remaining or error occurred" in output.stdout:
result = False
output_str = "Fail"
error_str = output.stderr
else:
result = False
output_str = "Unknown"
error_str = output.stderr
# print(output_str, error_str)
# turbo.inspect_history(n=1)
# print(f"Coq inputs:\n----\n{example.coq_code}\n----")
# print(f"Coq output.stderr:\n----\n{output.stderr}\n----")
# input('enter')
return result, output_str, error_str
def factuality_metric(gold, pred, trace=None) -> bool:
    """Metric: True when the predicted Coq code passes run_example_and_check.

    ``gold`` and ``trace`` are accepted (presumably to match DSPy's metric
    call signature) but are not consulted. Only ``pred.coq_code`` is
    executed, and the verdict is the first element of the checker's tuple.
    """
    verdict = run_example_and_check(pred)
    return verdict[0]
# Mark `theorem` as the input field of every example.
examples = [ex.with_inputs('theorem') for ex in examples]

# Fixed positional split: indices 0-9 for training, 10-14 for validation.
# Examples from index 15 onward are not used by either set.
training, validation = examples[:10], examples[10:15]
import requests
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

# OpenAI-compatible endpoint on localhost. The model id could be discovered via
#   requests.get(API_BASE + 'models').json()['data'][0]['id']
# but it is pinned here instead.
API_BASE = 'http://localhost:6000/v1/'
MODEL_NAME = 'mixtral-turbo'

turbo = dspy.OpenAI(
    model=MODEL_NAME,
    api_base=API_BASE,
    api_key='asdf',  # placeholder; presumably ignored by the local server
    timeout=200,
    temperature=0.2,
    presence_penalty=0.1,
    max_tokens=1500,
    top_p=0.89,
)
dspy.settings.configure(lm=turbo)

# Single-round few-shot bootstrapping judged by factuality_metric.
# Alternative that was tried: BootstrapFewShotWithRandomSearch(
#     max_bootstrapped_demos=4, max_labeled_demos=16, max_rounds=1,
#     num_candidate_programs=16, num_threads=20, metric=factuality_metric)
teleprompter = BootstrapFewShot(
    metric=factuality_metric,
    max_rounds=1,
    max_bootstrapped_demos=4,
    max_labeled_demos=16,
)
coqgen = teleprompter.compile(CoqModule(), trainset=training, valset=validation)

# Dump recorded LM calls, then run the compiled program on one theorem and
# print coqtop's verdict tuple.
turbo.inspect_history(n=99)
print(run_example_and_check(coqgen(theorem="Proof of Commutativity of Addition")))
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment