Skip to content

Instantly share code, notes, and snippets.

@jorendorff
Created June 4, 2020 06:19
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jorendorff/0b1c492b81687e43ebc749ea1b353080 to your computer and use it in GitHub Desktop.
Save jorendorff/0b1c492b81687e43ebc749ea1b353080 to your computer and use it in GitHub Desktop.
-- https://leanprover.github.io/theorem_proving_in_lean/induction_and_recursion.html#exercises
-- 6. Consider the following type of arithmetic expressions. The idea is that
-- `var n` is a variable, v_n, and `const n` is the constant whose value
-- is n.
/-- Arithmetic expressions over natural numbers: literal constants, numbered
variables, and binary sums and products. -/
inductive aexpr : Type
-- `const n`: the literal constant whose value is `n`.
| const : ℕ → aexpr
-- `var n`: the variable vₙ; its value is supplied by an environment `ℕ → ℕ` (see `aeval`).
| var : ℕ → aexpr
-- `plus a b`: the sum of the two subexpressions.
| plus : aexpr → aexpr → aexpr
-- `times a b`: the product of the two subexpressions.
| times : aexpr → aexpr → aexpr
-- Bring the constructors into scope so the rest of the file can write them unqualified.
open aexpr
-- Write a function that evaluates such an expression, evaluating each
-- variable `var n` to `env n`, where `env : ℕ → ℕ` assigns a value to every variable.
/-- Evaluate an expression under the variable assignment `env`.
A constant evaluates to itself, `var i` evaluates to `env i`, and a
`plus`/`times` node evaluates to the sum/product of its evaluated operands. -/
def aeval (env : ℕ → ℕ) : aexpr → ℕ
| (const c)     := c
| (var i)       := env i
| (plus e₁ e₂)  := aeval e₁ + aeval e₂
| (times e₁ e₂) := aeval e₁ * aeval e₂
-- Implement “constant fusion,” a procedure that simplifies subterms like
-- 5 + 7 to 12. Using the auxiliary function simp_const, define a function
-- “fuse”: to simplify a plus or a times, first simplify the arguments
-- recursively, and then apply simp_const to try to simplify the result.
/-- A single constant-folding step.
A `times` or `plus` node whose two operands are both literal constants is
replaced by the constant it denotes. Any other expression is returned as is.
This looks only at the top node and its immediate operands; it does not recurse. -/
def simp_const : aexpr → aexpr
| (times (const x) (const y)) := const (x * y)
| (plus (const x) (const y))  := const (x + y)
| other                       := other
/-- Constant fusion. Both operands of a `plus` or `times` node are first
simplified recursively with `fuse`. `simp_const` is then applied to the rebuilt
node, so it collapses to a single constant when both simplified operands are
constants. Leaves (`const`, `var`) are returned unchanged.

The operands must be simplified with `fuse` itself, not with `simp_const`.
`simp_const` only looks one level deep, so applying it to the operands would
leave nested terms such as `((1 + 1) + 1) + 0` unfused. -/
def fuse : aexpr → aexpr
| (plus a b)  := simp_const (plus (fuse a) (fuse b))
| (times a b) := simp_const (times (fuse a) (fuse b))
| e           := e

/-- `simp_const` never changes the value of an expression, under any
environment `v`. -/
theorem simp_const_eq (v : ℕ → ℕ) :
  ∀ e : aexpr, aeval v (simp_const e) = aeval v e :=
begin
  intro e,
  -- `simp_const`'s catch-all equation is split by the equation compiler into one
  -- case per constructor, so `simp_const e = e` is NOT definitional for an
  -- arbitrary `e`. Instead, case-split down to everything `simp_const` inspects
  -- (the node and both of its operands). Each resulting goal then holds by `refl`.
  cases e with n k a₁ b₁ a₂ b₂,
  { refl },
  { refl },
  { cases a₁; cases b₁; refl },
  { cases a₂; cases b₂; refl }
end

/-- `fuse` never changes the value of an expression, under any environment `v`. -/
theorem fuse_eq (v : ℕ → ℕ) :
  ∀ e : aexpr, aeval v (fuse e) = aeval v e :=
begin
  intro e,
  -- Induction rather than `cases`: `fuse` recurses into both operands, so the
  -- `plus`/`times` cases need the hypotheses for `a` and `b`.
  induction e,
  case const : n { refl },
  case var : k { refl },
  case plus : a b iha ihb {
    exact calc
      aeval v (fuse (plus a b))
          = aeval v (simp_const (plus (fuse a) (fuse b))) : by refl
      ... = aeval v (plus (fuse a) (fuse b))              : simp_const_eq v _
      ... = aeval v (fuse a) + aeval v (fuse b)           : by refl
      ... = aeval v a + aeval v b                          : by rw [iha, ihb]
      ... = aeval v (plus a b)                             : by refl
  },
  case times : a b iha ihb {
    exact calc
      aeval v (fuse (times a b))
          = aeval v (simp_const (times (fuse a) (fuse b))) : by refl
      ... = aeval v (times (fuse a) (fuse b))              : simp_const_eq v _
      ... = aeval v (fuse a) * aeval v (fuse b)            : by refl
      ... = aeval v a * aeval v b                           : by rw [iha, ihb]
      ... = aeval v (times a b)                             : by refl
  }
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment