(* Coq gist erutuf/5523067c5ff6792430c63c4a2cd29bd3 by @erutuf,
   created September 25, 2018 23:34.
   The GitHub page text captured with it (Skip to content, Star / Fork / Save
   buttons) is page navigation, not part of the development. *)
Require Import List PeanoNat Arith.
Import ListNotations.
Set Implicit Arguments.
(* Mapping before a left fold can be fused into the folding function:
   folding [g] over [map f l] is folding [fun x y => g x (f y)] over [l]. *)
Lemma fold_left_map : forall (A B C : Type)(f : A -> B) g l (init:C),
  fold_left g (map f l) init = fold_left (fun x y => g x (f y)) l init.
Proof.
  intros A B C f g l.
  (* Keep the accumulator universally quantified so the induction
     hypothesis applies to the updated accumulator. *)
  induction l as [| x xs IH]; intro init; simpl.
  - reflexivity.
  - apply IH.
Qed.
(* Sum of a list of naturals, accumulated left to right starting from 0. *)
Definition sum : list nat -> nat :=
  fun xs => fold_left Nat.add xs 0.
(* Sum of the squares of the elements of a list. *)
Definition sum_of_square : list nat -> nat :=
  fun xs => sum (map Nat.square xs).
(* Definitions that [optimize my_db ...] is allowed to unfold. *)
Hint Unfold sum : my_db.
Hint Unfold sum_of_square : my_db.
(* Rewrite rules that [optimize ... my_rewrite_db] applies after unfolding. *)
Hint Rewrite fold_left_map map_map : my_rewrite_db.
(* [optimize db rdb] solves a goal such as
     { prog | forall args, spec args = prog args }
   by computing the witness [prog]:
   - unfold the definitions registered in hint database [db];
   - turn [prog] into an evar and introduce the arguments;
   - hide the right-hand side (the one containing the evar) behind a local
     definition [P := fun y => y = rhs], so that the goal [autorewrite] sees,
     [P lhs], does not mention the evar;
   - rewrite [lhs] with database [rdb], then let [reflexivity] instantiate
     the evar with the rewritten term.
   The name of [P] is generated with [fresh]. A fixed name made [set] fail
   whenever one of the introduced arguments was itself called P. *)
Ltac optimize db rdb :=
  repeat autounfold with db;
  eexists; intros;
  match goal with
  | |- ?X = ?Y =>
      let P := fresh "P" in
      set (P := fun y => y = Y);
      enough (P X) by auto
  end;
  autorewrite with rdb; reflexivity.
(* Ask [optimize] for a program provably equal to [sum_of_square] on every
   input. The result is the witness paired with its correctness proof. *)
Definition ss_optimized_sig : { prog | forall xs, sum_of_square xs = prog xs }.
Proof. optimize my_db my_rewrite_db. Defined.

(* Drop the proof, keep the computed program, and simplify it for display. *)
Definition ss_optimized := Eval simpl in proj1_sig ss_optimized_sig.
Print ss_optimized.
(* result :
[[[
ss_optimized =
fun xs : list nat => fold_left (fun x y : nat => x + y * y) xs 0
: list nat -> nat
]]]
*)
(* [fold_nat f start len init] threads an accumulator through [f] over the
   [len] consecutive naturals start, start+1, ..., start+len-1, left to right:
     fold_nat f start 0 init     = init
     fold_nat f start (S k) init = fold_nat f (S start) k (f init start)
   It is the counting-loop counterpart of folding over [seq start len]. *)
Fixpoint fold_nat (A : Type) (f : A -> nat -> A) (start len : nat) (init : A) : A :=
  match len with
  | 0 => init
  | S remaining => fold_nat f (S start) remaining (f init start)
  end.
(* [n ,, m] is the inclusive range n, n+1, ..., m of naturals, and the empty
   list when m < n. The length is written [S m - n] rather than [S (m - n)]:
   with truncated subtraction the latter is 1 when m < n, which made the
   range [1,,0] equal to the one-element list containing 1 instead of empty. *)
Notation "[ n ,, m ]" := (seq n (S m - n)).

(* Folding over the range [st,,ed] equals counting with [fold_nat] from [st].
   The length argument is written exactly as in the notation, so rewriting
   with this lemma matches [fold_left f [st,,ed] init] syntactically. *)
Lemma fold_seq_fold_nat : forall (A : Type) st ed (f : A -> nat -> A) init,
  fold_left f [st,,ed] init = fold_nat f st (S ed - st) init.
Proof.
  intros.
  (* The equation holds for any length, so generalise it before inducting. *)
  remember (S ed - st) as m.
  clear Heqm. revert st init.
  induction m; intros; simpl; auto.
Qed.
Hint Rewrite fold_seq_fold_nat : my_rewrite_db.

(* Sum of the squares 1^2 + ... + n^2; this is 0 when n = 0. *)
Definition example n := sum (map Nat.square [1,,n]).
Hint Unfold example : my_db.
Definition example_optimized_sig : { prog | forall n, example n = prog n }.
Proof.
  optimize my_db my_rewrite_db.
Defined.
Definition example_optimized := Eval simpl in proj1_sig example_optimized_sig.
Print example_optimized.
(* expected result, derived by hand for the corrected [n ,, m] notation
   (NOTE(review): re-run the Print above to confirm the exact printed form):
[[[
example_optimized =
fun n : nat => fold_nat (fun x y : nat => x + Nat.square y) 1 (n - 0) 0
     : nat -> nat
]]]
*)
(* GitHub page footer: Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment *)