Created
November 29, 2019 05:02
-
-
Save mir-ikbch/6df11b5ba869c1d17fed5681104f5eb5 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Require Import List PeanoNat Arith Morphisms Setoid. | |
Import ListNotations. | |
Set Implicit Arguments. | |
(* [optimize db] solves goals of the shape [{ f | forall args, lhs = f args }].
   How it proceeds:
   - [eexists] turns [f] into an existential variable.
   - The right-hand side is tucked into a local definition [keep_rhs], so the
     goal becomes [keep_rhs lhs]. The unfoldings and rewrites of database [db]
     below therefore act on [lhs] only.
   - [reflexivity] finally instantiates the evar with the rewritten [lhs]. *)
Ltac optimize db :=
  eexists; intros;
  lazymatch goal with
  | |- ?lhs = ?rhs =>
      let keep_rhs := fresh in
      set (keep_rhs := fun y => y = rhs);
      enough (keep_rhs lhs) by auto
  end;
  repeat autounfold with db;
  autorewrite with db;
  reflexivity.
(* Squares every element of a list of naturals. *)
Definition sqr_list (xs : list nat) : list nat :=
  map (fun x => x * x) xs.

(* Squares every element and then adds 2 to each square. *)
Definition sqr_add2_list (xs : list nat) : list nat :=
  map (fun x => x + 2) (sqr_list xs).

Hint Unfold sqr_list sqr_add2_list : my_db.
Hint Rewrite map_map : my_db.

(* An equivalent single-pass version of [sqr_add2_list], obtained by
   unfolding both definitions and fusing the two [map]s with [map_map]. *)
Definition sa_optimized_sig :
  { f | forall xs, sqr_add2_list xs = f xs } :=
  ltac:(optimize my_db).

Eval simpl in proj1_sig sa_optimized_sig.
(* result :
[[[
     = map (fun x : nat => x * x + 2)
     : list nat -> list nat
]]]
*)
(* Map/fold fusion: folding [g] over [map f l] is the same as folding over
   [l] with [f] applied to each element inside the step function. *)
Lemma fold_left_map : forall (A B C : Type)(f : A -> B) g l (init:C),
  fold_left g (map f l) init = fold_left (fun x y => g x (f y)) l init.
Proof.
  intros A B C f g l.
  (* [init] stays universally quantified so the hypothesis is general. *)
  induction l as [| a l IH]; intros init; simpl.
  - reflexivity.
  - apply IH.
Qed.
(* Sum of a list of naturals, computed as a left fold starting from 0. *)
Definition sum (xs : list nat) : nat :=
  fold_left Nat.add xs 0.

(* Sum of the squares of the elements of [xs]. *)
Definition sum_of_square (xs : list nat) : nat :=
  sum (map (fun x => x * x) xs).

Hint Unfold sum sum_of_square : my_db.
Hint Rewrite fold_left_map map_map : my_db.

(* An equivalent program for [sum_of_square] in which the intermediate
   [map] is fused into the fold by [fold_left_map]. *)
Definition ss_optimized_sig :
  { prog | forall xs, sum_of_square xs = prog xs } :=
  ltac:(optimize my_db).

Eval simpl in proj1_sig ss_optimized_sig.
(* result :
[[[
     = fun xs : list nat => fold_left (fun x y : nat => x + y * y) xs 0
     : list nat -> nat
]]]
*)
Require Import ZArith. | |
Open Scope Z_scope. | |
(* Splits a list by position parity.
   - The first component holds the elements at indices 0, 2, 4, ...
   - The second component holds those at indices 1, 3, 5, ...
   Each step prepends the head to the odd part of the tail's split, and the
   two parts swap roles. *)
Fixpoint divide_list (A : Type)(xs : list A) :=
  match xs with
  | nil => (nil, nil)
  | x :: rest =>
      match divide_list rest with
      | (evens, odds) => (x :: odds, evens)
      end
  end.
(* Combines two lists pointwise with [f]. The result has the length of the
   shorter input; surplus elements of the longer list are dropped. *)
Fixpoint map2 (A B C : Type)(f : A -> B -> C)(xs : list A)(ys : list B) :=
  match xs with
  | nil => nil
  | x :: xs' =>
      match ys with
      | nil => nil
      | y :: ys' => f x y :: map2 f xs' ys'
      end
  end.
Section Fft.

(* Exponent parameter of the modulus 2^(2^N) + 1 used throughout. *)
Variable N : Z.

(* The modulus 2^(2^N) + 1, i.e. the Fermat number F_N. It is at least 2 for
   every N, since [Z.pow] with a negative exponent yields 0. *)
Definition modulus : Z := 2 ^ (2 ^ N) + 1.

(* Modular addition, subtraction and multiplication. [Z.modulo] by the
   positive [modulus] puts every result in [0, modulus). *)
Definition add (x y : Z) : Z := (x + y) mod modulus.
Definition sub (x y : Z) : Z := (x - y) mod modulus.
Definition mul (x y : Z) : Z := (x * y) mod modulus.
(* [pow_pos x p] applies [mul x] p times to 1, i.e. x^p reduced modulo
   [modulus]. *)
Definition pow_pos (x : Z) : positive -> Z :=
  Pos.iter (mul x) 1.

(* Modular exponentiation x^y.
   - A zero exponent gives 1.
   - A negative exponent gives 0; it is not interpreted as a modular inverse. *)
Definition pow (x y : Z) : Z :=
  match y with
  | Z0 => 1
  | Zpos p => pow_pos x p
  | Zneg _ => 0
  end.
(* [prt n] is 2^(2^(N + 1 - n)) computed with the modular [pow]. It serves as
   the twiddle-factor base in [fft_aux].
   NOTE(review): it is presumably meant to be a primitive 2^n-th root of unity
   modulo [modulus], since 2 has order 2^(N+1) there. That reading only makes
   sense for 0 <= n <= N + 1. *)
Definition prt (n : Z) : Z :=
  pow 2 (pow 2 (N + 1 - n)).

(* (-1)^n as an integer: 1 when [n] is even, -1 when it is odd.
   It is not referenced elsewhere in this file. *)
Definition pow_neg1 (n : Z) : Z :=
  match Z.even n with
  | true => 1
  | false => -1
  end.
(* [shiftl x y] is (x mod modulus) times 2^y, reduced modulo [modulus].
   A negative shift amount [y] is treated as 0.

   The computation stays in binary [Z]. The earlier formulation converted to
   unary [nat] with [Z.to_nat], shifted with [Nat.shiftl] and converted back
   with [Z.of_nat]. That builds a [nat] as large as the shifted value, which
   is impractical to evaluate for any realistic modulus.

   The two versions are equal:
   - [x mod modulus] is non-negative, because [modulus] is positive.
   - [Z.of_nat (Z.to_nat y)] equals [Z.max 0 y].
   - [Z.shiftl a k] equals a * 2^k for k >= 0. *)
Definition shiftl x y :=
  Z.shiftl (x mod modulus) (Z.max 0 y) mod modulus.
(* Rewrite rules used by [optimize'] to turn modular powers of 2 into shifts.

   NOTE(review): both are assumed without proof, and neither holds for every
   argument.
   - [shiftl_mul_pow2] fails for negative [x]. There [pow 2 x] is 0, so the
     left side is 0, while [shiftl] treats a negative shift as 0 and returns
     [y mod modulus]. Taking x := -1 and y := 1 gives 0 = 1, since
     modulus >= 2. The axiom therefore makes the section inconsistent.
   - [pow_pow] fails because [mul] reduces the exponent product modulo
     [modulus] rather than modulo the multiplicative order. For example, with
     N = 0 (modulus = 3), pow (pow 2 2) 2 = 1 but pow 2 (mul 2 2) = 2.
   Equations derived from them, such as [fft_aux_sig], are not trustworthy
   as proofs.
   TODO: replace them with lemmas that carry side conditions. *)
Axiom shiftl_mul_pow2 : forall x y : Z,
  mul (pow 2 x) y = shiftl y x.
Axiom pow_pow : forall x y z : Z,
  pow (pow x y) z = pow x (mul y z).
(* [Zseq x n] lists the [n] consecutive integers x, x + 1, ..., x + n - 1. *)
Fixpoint Zseq (x : Z) (n : nat) : list Z :=
  match n with
  | O => nil
  | S m => cons x (Zseq (1 + x) m)
  end.

(* The inclusive integer range [x, y]. It is empty when y < x, because
   [Z.to_nat] maps non-positive counts to 0. *)
Definition Zrange (x y : Z) : list Z :=
  Zseq x (Z.to_nat (1 + y - x)).

Notation "[ x ;..; y ]" := (Zrange x y) (at level 0).
(* Extensionality: [map2] with pointwise-equal combining functions gives
   equal results. *)
Lemma map2_ext : forall A B C (f g : A -> B -> C) xs ys,
  (forall a b, f a b = g a b) ->
  map2 f xs ys = map2 g xs ys.
Proof.
  intros A B C f g xs.
  induction xs as [| x xs IH]; intros ys Hfg.
  - reflexivity.
  - destruct ys as [| y ys]; simpl.
    + reflexivity.
    + rewrite Hfg, (IH ys Hfg).
      reflexivity.
Qed.
(* Declares [map2] compatible with pointwise equality of its function
   argument. This lets [setoid_rewrite] rewrite inside the combining function
   passed to [map2]. *)
Instance map2_morphism A B C:
  Proper (pointwise_relation _ (pointwise_relation _ eq) ==> @eq (list A) ==> @eq (list B) ==> @eq (list C)) (@map2 A B C).
Proof.
  unfold Proper, respectful, pointwise_relation.
  intros f g Hfg xs xs' Hxs ys ys' Hys.
  subst xs' ys'.
  apply map2_ext.
  exact Hfg.
Qed.
(* Fusion of a [map] on the first argument of [map2] into its combining
   function. *)
Lemma map2_map : forall A A' B C (f : A -> B -> C) (g : A' -> A) xs ys,
  map2 f (map g xs) ys = map2 (fun x y => f (g x) y) xs ys.
Proof.
  intros A A' B C f g xs.
  induction xs as [| x xs IH]; intros ys; destruct ys as [| y ys]; simpl;
    try reflexivity.
  (* Only the cons/cons case is left; its tails agree by the hypothesis. *)
  rewrite IH.
  reflexivity.
Qed.
(* Multiplies the k-th element of [xs] (k = 0, 1, ...) by [prt (n + 1)]^k.
   This is the twiddle-factor step applied to the odd half in [fft].

   Note: the range [0 ;..; 2 ^ n] is inclusive and yields 2^n + 1 factors.
   [map2] truncates to the shorter of the two lists, so for an [xs] of
   length 2^n the last factor is simply unused. *)
Definition fft_aux (xs : list Z) (n : Z) : list Z :=
  let twiddles :=
    map (fun k => pow (prt (n + 1)) k) [0 ;..; 2 ^ n]
  in
  map2 mul twiddles xs.
(* [setoid_autorewrite' rew] performs one [setoid_rewrite] step with some
   lemma from [rew].
   - [rew] is either a single lemma or a left-nested tuple (l1, l2, ..., lk),
     which Coq parses as nested pairs ending in lk.
   - Lemmas are tried from the last one backwards.
   - The final single-lemma clause is new. Previously a lone lemma matched
     no clause, so [setoid_autorewrite lem] silently did nothing (the
     failure was swallowed by [repeat]).
   NOTE(review): the recursive clause combines alternatives with the
   backtracking [+], while the pair clause uses [||]. Both are kept as they
   were. *)
Ltac setoid_autorewrite' rew :=
  lazymatch rew with
  | (?rew',?lem',?lem) => setoid_rewrite lem + setoid_autorewrite' (rew',lem')
  | (?lem',?lem) => setoid_rewrite lem || setoid_rewrite lem'
  | ?lem => setoid_rewrite lem
  end.

(* Rewrites with the lemmas of [rew] until none of them applies any more. *)
Ltac setoid_autorewrite rew := repeat setoid_autorewrite' rew.
(* Variant of the [optimize] scheme for setoid rewriting. It solves goals of
   the shape [{ f | forall args, lhs = f args }].
   - The right-hand side, which holds the evar for [f], is hidden in the
     local definition [keep_rhs], so the goal becomes [keep_rhs lhs].
   - [lhs] is unfolded with hint database [db].
   - [lhs] is then rewritten with the lemma tuple [rew] via
     [setoid_autorewrite], which can rewrite under binders.
   - [reflexivity] instantiates the evar with the result. *)
Ltac optimize' db rew :=
  eexists; intros;
  lazymatch goal with
  | |- ?lhs = ?rhs =>
      let keep_rhs := fresh in
      set (keep_rhs := fun y => y = rhs);
      enough (keep_rhs lhs) by auto
  end;
  repeat autounfold with db;
  setoid_autorewrite rew;
  reflexivity.
Hint Unfold fft_aux prt : my_db.

(* Derives a shift-based formulation of [fft_aux].
   - [fft_aux] and [prt] are unfolded.
   - The result is rewritten with [map2_map], [pow_pow] and [shiftl_mul_pow2].
   The resulting equation depends on the last two, which are axioms. *)
Definition fft_aux_sig :
  { f | forall xs n, fft_aux xs n = f xs n } :=
  ltac:(optimize' my_db (shiftl_mul_pow2,pow_pow,map2_map)).

Eval simpl in proj1_sig fft_aux_sig.
(* result :
[[[
     = fun (xs : list Z) (n : Z) =>
       map2 (fun x y : Z => shiftl y (shiftl x (N + 1 - (n + 1)))) [0;..; 2 ^ n] xs
     : list Z -> Z -> list Z
]]]
*)
(* Radix-2 FFT modulo [modulus], recursing [n] levels deep.
   NOTE(review): presumably [xs] is expected to have length 2^n; this is not
   checked.
   Each level:
   - splits [xs] into even- and odd-indexed elements with [divide_list];
   - transforms both halves recursively;
   - multiplies the odd half by the twiddle factors, using the derived
     [fft_aux_sig];
   - combines the halves: the output is the pointwise sums [xs0' + xs1']
     followed by the pointwise differences [xs0' - xs1'].

   The recursion is structural on [n], so a plain [Fixpoint] suffices. The
   [Function] command belongs to the FunInd plugin, which since Coq 8.7 must
   be loaded with [Require Import FunInd], and this file never loads it.
   Unlike [Function], [Fixpoint] does not generate [fft_equation], [R_fft] or
   [fft_ind]; nothing in this file uses them. *)
Fixpoint fft (xs : list Z) (n:nat) :=
  match n with
  | O => xs
  | S n' =>
      let (xs0,xs1) := divide_list xs in
      let xs0' := fft xs0 n' in
      let xs1' := proj1_sig fft_aux_sig (fft xs1 n') (Z.of_nat n') in
      let xs' := map2 add xs0' xs1' in
      let xs'' := map2 sub xs0' xs1' in
      xs' ++ xs''
  end.
End Fft.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment