Skip to content

Instantly share code, notes, and snippets.

@tmoux
Created February 24, 2024 06:20
Show Gist options
  • Save tmoux/06ed21574f402f56cad5fe0b9bcdfb81 to your computer and use it in GitHub Desktop.
Save tmoux/06ed21574f402f56cad5fe0b9bcdfb81 to your computer and use it in GitHub Desktop.
import Mathlib.Util.AtomM
import Mathlib.Data.Nat.Basic
open Lean Elab Meta Tactic Mathlib.Tactic
open Lean.Expr
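-- `monoid`: a tactic that proves equalities between monoid expressions by reflection.
-- Both sides are normalized by flattening them to a list of atoms and eliminating identity
-- elements; the goal is closed if the two normal forms are definitionally equal.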
namespace Monoid
/-- Data shared by one invocation of the `monoid` tactic (built by `mkContext`). -/
structure Context where
  /-- The type of the expression, i.e. the carrier of the monoid. -/
  α : Expr
  /-- The universe level `u` such that `α : Type u`. -/
  univ : Level
  /-- The `Monoid α` instance. -/
  inst : Expr
  /-- The identity of the monoid: the term `1 : α`. -/
  α1 : Expr
/-- Build the `Context` for one run of the tactic from a sample term `e`.

The monoid type is taken to be the type of `e`. Throws if `Monoid α` cannot be built or
synthesized, or if the universe level of `α` cannot be decremented (i.e. `α` is not a
`Type u`). -/
def mkContext (e : Expr) : MetaM Context := do
  let α ← inferType e
  let inst ← synthInstance (← mkAppM ``Monoid #[α])
  -- `α : Type u`; `getDecLevel` recovers `u` directly (throwing on failure) instead of
  -- unifying a fresh level metavariable and discarding the result of `isDefEq`.
  let u ← getDecLevel α
  let one ← Expr.ofNat α 1
  return ⟨α, u, inst, one⟩
/-- Apply the constant `n`, instantiated at universe `c.univ`, to the monoid type `c.α`
and then to `args`. Meant for constants of the form `{α : Type u} → …`. -/
def Context.app (c : Context) (n : Name) : Array Expr → Expr :=
  fun args => mkAppN (mkApp (Expr.const n [c.univ]) c.α) args
/-- Like `Context.app`, but also supplies the `Monoid` instance `c.inst` right after `c.α`.
Meant for constants of the form `{α : Type u} → [Monoid α] → …`. -/
def Context.appInst (c : Context) (n : Name) : Array Expr → Expr :=
  fun args => c.app n (#[c.inst] ++ args)
/-- The tactic's working monad: `AtomM` with read-only access to the `Context`. -/
abbrev M : Type → Type := ReaderT Context AtomM
/-- Syntax trees for monoid expressions whose leaves are terms of type `α`. -/
inductive exp (α : Type u) : Type u where
  /-- A leaf term, treated as opaque. -/
  | atom : α → exp α
  /-- A product `x * y`. -/
  | mult : exp α → exp α → exp α
  /-- The identity element `1`. -/
  | mempty : exp α
/-- Interpret an `exp α` tree in the monoid `α`: leaves map to themselves, `mult` to `*`,
and `mempty` to `1`. -/
@[simp] def exp.denote [Monoid α] (e : exp α) : α :=
  match e with
  | .atom x => x
  | .mult l r => l.denote * r.denote
  | .mempty => 1
/-- Flatten an `exp α` tree into the left-to-right list of its leaves. Identity nodes
contribute nothing and the bracketing of products is forgotten. -/
@[simp] def exp.normalize [Monoid α] : exp α → List α
  | .atom x => [x]
  | .mult l r => l.normalize ++ r.normalize
  | .mempty => []
/-- Evaluate a normal form (list of atoms) as the right-nested product of its elements;
the empty list evaluates to `1`. -/
@[simp] def nf.denote [Monoid α] : List α → α
  | [] => 1
  | head :: tail => head * nf.denote tail
/-- `nf.denote` sends list concatenation to multiplication.
(Despite the name, no commutativity is involved; only associativity and the unit laws.) -/
lemma denote_mult_commute [Monoid α] (xs ys : List α) :
    nf.denote xs * nf.denote ys = nf.denote (xs ++ ys) := by
  induction xs with
  | nil => simp
  | cons x rest ih => simp [mul_assoc, ih]
/-- Normalizing and then evaluating agrees with evaluating the tree directly. -/
theorem normalize_correct [Monoid α] (e : exp α) :
    nf.denote (exp.normalize e) = exp.denote e := by
  induction e with
  | atom => simp
  | mult x y ihx ihy => simp [ihx, ihy, ← denote_mult_commute]
  | mempty => simp
/-- Reflection principle: trees whose normal forms evaluate to the same value denote
equal elements. The tactic applies this with a `rfl` proof of the hypothesis. -/
theorem monoid_reflect [Monoid α] (a b : exp α) :
    nf.denote (exp.normalize a) = nf.denote (exp.normalize b) →
    exp.denote a = exp.denote b :=
  fun h => (normalize_correct a).symm.trans (h.trans (normalize_correct b))
/-- Reify a term `e : α` (where `α` is the context's monoid type) into an `Expr` of type
`exp α`:
* `x * y` whose three `HMul` type arguments are all `α` becomes `exp.mult` of the
  reified factors;
* a term definitionally equal to the identity `1` becomes `exp.mempty`;
* anything else becomes `exp.atom e`.

Atoms are not registered with `AtomM`; they are compared only when the caller checks the
normal forms for definitional equality. -/
partial def exp.reify (e : Expr) : M Expr := do
  let c ← read
  if let (``HMul.hMul, #[α₁, α₂, α₃, _, e₁, e₂]) := e.getAppFnArgs then
    -- Only descend into homogeneous products `α → α → α`. A heterogeneous `HMul` would
    -- yield an ill-typed `exp.mult` term, so such a product is left to the cases below.
    if (← isDefEq α₁ c.α) && (← isDefEq α₂ c.α) && (← isDefEq α₃ c.α) then
      let t₁ ← exp.reify e₁
      let t₂ ← exp.reify e₂
      return c.app ``exp.mult #[t₁, t₂]
  if ← isDefEq e c.α1 then
    return c.app ``exp.mempty #[]
  return c.app ``exp.atom #[e]
/-- `monoid` closes a goal `a = b` whose sides are products in a type with a `Monoid`
instance. Both sides are reified into `exp` trees and flattened into lists of atoms with
identity elements dropped. If the lists are definitionally equal, the goal is proved via
`monoid_reflect`. -/
syntax (name := monoid) "monoid" : tactic
elab_rules : tactic | `(tactic| monoid) => withMainContext do
  let goal ← getMainTarget
  let some (_, e₁, e₂) := (← whnfR goal).eq?
    | throwError "monoid: requires an equality goal"
  let c ← mkContext e₁
  let pf ← AtomM.run .default <| ReaderT.run (r := c) do
    -- Reify both sides into `exp α` trees.
    let t₁ ← exp.reify e₁
    let t₂ ← exp.reify e₂
    -- Normalize both trees; the normal forms must agree up to definitional equality.
    let n₁ := c.appInst ``exp.normalize #[t₁]
    let n₂ := c.appInst ``exp.normalize #[t₂]
    unless ← isDefEq n₁ n₂ do
      throwError "monoid: normalized forms not equal"
    -- `rfl : nf.denote n₁ = nf.denote n₁` is accepted as the hypothesis of `monoid_reflect`
    -- (which mentions `n₂` on the right) because `n₁ =?= n₂` succeeded above.
    let eq ← mkAppM ``Eq.refl #[c.appInst ``nf.denote #[n₁]]
    mkAppM ``monoid_reflect #[t₁, t₂, eq]
  -- `pf : exp.denote t₁ = exp.denote t₂`. Confirm this really is the goal before assigning
  -- it, so a mismatch (e.g. a `*` that is not the monoid's multiplication) is reported here
  -- with a clear message instead of surfacing later.
  unless ← isDefEq (← inferType pf) goal do
    throwError "monoid: constructed proof does not match the goal"
  closeMainGoal pf
-- Examples
/-- Demo instance: lists form a monoid under `List.append`, with `[]` as the identity. -/
instance : Monoid (List α) where
  one := []
  mul := List.append
  one_mul xs := List.nil_append xs
  mul_one xs := List.append_nil xs
  mul_assoc xs ys zs := List.append_assoc xs ys zs
-- `f` unfolds to `3`, so `ex1` also checks that atoms are compared up to definitional
-- unfolding (`3` on the left against `f` on the right).
def f := 3

lemma ex1 : 1 * 2 * 3 = 1 * (2 * f) * 1 * 1 := by monoid

lemma ex2 : ∀ (a b c : ℕ), a * 1 * b * (a + c) = a * (b * (a + c)) := by
  intros; monoid

lemma ex3 : ∀ (a b c d : List Bool), [] * a * (b * c) * d = a * (b * c * d) := by
  intros; monoid

lemma ex4 : ∀ (a b c d : ℕ), a * b * (b * d * c * a * 1 * d) = 1 * 1 * a * (1 * b) * (b * d) * c * (a * d) * 1 := by
  intros; monoid

lemma ex5 : ∀ (a b c d : List α), a * b * c * d * [] = a * (b * c) * (1 * 1 * d) := by
  intros; monoid
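-- This fails: `++` elaborates to `HAppend.hAppend`, not `HMul.hMul`, so `monoid` treats
-- each side as a single opaque atom and the normal forms do not match: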
-- lemma ex6 : ∀ (a b c : List α), a ++ b ++ c = a ++ (b ++ c) := by
-- intros
-- monoid
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment