Created
May 4, 2018 01:12
-
-
Save nekketsuuu/6733a0f3f9074161480cf6a8e8485624 to your computer and use it in GitHub Desktop.
『Recursion Scheme テクニック』 (@eldesh) の内容の一部を OCaml で書きました。
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
(* General signature for a recursive type used with recursion schemes:
   [t] is the recursive type and ['a f] is its base functor, i.e. one
   layer of [t] with ['a] standing in the recursive positions. *)
module type TYP = sig
  (* One layer of the recursive type; ['a] marks the recursive slots. *)
  type 'a f

  (* The recursive type itself. *)
  type t

  (* Map a function over the recursive slots of a single layer. *)
  val fmap : ('a -> 'b) -> 'a f -> 'b f

  (* Expose the outermost layer of a [t]. *)
  val prj : t -> t f

  (* Rebuild a [t] from one layer whose slots are already [t]s. *)
  val inj : t f -> t
end
(* Generic recursion schemes for any [TYP].
   [fold alg] consumes a [t]: it unrolls one layer, folds every child,
   then applies the algebra [alg] to the resulting layer.
   [unfold coalg] produces a [t]: it expands the seed into one layer with
   [coalg], unfolds every child seed, then rolls the layer up.
   Neither is tail-recursive; stack depth follows the structure's depth. *)
module Rec (T : TYP) : sig
  val fold : ('a T.f -> 'a) -> T.t -> 'a
  val unfold : ('a -> 'a T.f) -> 'a -> T.t
end = struct
  let rec fold alg x =
    let layer = T.prj x in
    alg (T.fmap (fold alg) layer)

  let rec unfold coalg seed =
    let layer = coalg seed in
    T.inj (T.fmap (unfold coalg) layer)
end
(* Multi-layer "rollers": [injN] turns N nested layers of the base functor
   (with [t]s in the innermost slots) back into a single [t].
   [inj1] is [T.inj]; [inj_succ injN] handles one extra outer layer by
   first rolling up everything inside it with [fmap injN], then rolling
   the outer layer with [T.inj]. *)
module Injs (T : TYP) = struct
  let inj_succ inner layer = T.inj (T.fmap inner layer)

  let inj1 = T.inj
  let inj2 = inj_succ inj1
  let inj3 = inj_succ inj2
  let inj4 = inj_succ inj3
  let inj5 = inj_succ inj4
end
(* Multi-layer "unrollers": [prjN] exposes the top N layers of a [t] at
   once. [prj1] is [T.prj]; [prj_succ prjN] first unrolls the outermost
   layer with [T.prj], then unrolls N more layers under each of its slots
   with [fmap prjN]. *)
module Prjs (T : TYP) = struct
  let prj_succ inner x =
    let outer = T.prj x in
    T.fmap inner outer

  let prj1 = T.prj
  let prj2 = prj_succ prj1
  let prj3 = prj_succ prj2
  let prj4 = prj_succ prj3
  let prj5 = prj_succ prj4
end
(* Every helper derived from a [TYP], bundled in one module: the recursion
   schemes [fold]/[unfold] (from [Rec]), the rollers [inj1]..[inj5] with
   [inj_succ] (from [Injs]), and the unrollers [prj1]..[prj5] with
   [prj_succ] (from [Prjs]). *)
module AuxDefs (T : TYP) : sig
  (* Recursion schemes. *)
  val fold : ('a T.f -> 'a) -> T.t -> 'a
  val unfold : ('a -> 'a T.f) -> 'a -> T.t

  (* Roll up 1..5 nested layers at once; [inj_succ] adds one more layer. *)
  val inj1 : T.t T.f -> T.t
  val inj2 : T.t T.f T.f -> T.t
  val inj3 : T.t T.f T.f T.f -> T.t
  val inj4 : T.t T.f T.f T.f T.f -> T.t
  val inj5 : T.t T.f T.f T.f T.f T.f -> T.t
  val inj_succ : ('a -> T.t) -> 'a T.f -> T.t

  (* Unroll 1..5 layers at once; [prj_succ] adds one more layer. *)
  val prj1 : T.t -> T.t T.f
  val prj2 : T.t -> T.t T.f T.f
  val prj3 : T.t -> T.t T.f T.f T.f
  val prj4 : T.t -> T.t T.f T.f T.f T.f
  val prj5 : T.t -> T.t T.f T.f T.f T.f T.f
  val prj_succ : (T.t -> 'a) -> T.t -> 'a T.f
end = struct
  include Rec (T)
  include Injs (T)
  include Prjs (T)
end
(* Signature for natural-number representations: a [TYP] whose base
   functor is fixed to the two constructors [ZERO] and [SUCC]. *)
module type NAT_TYP = sig
  type 'a f = ZERO | SUCC of 'a
  type t
  val fmap : ('a -> 'b) -> 'a f -> 'b f
  val prj : t -> t f
  val inj : t f -> t
end
(* Natural-number operations written once for every NAT_TYP encoding,
   using only [prj]/[inj] and the helpers derived by [AuxDefs]. *)
module NatOp (NatTyp : NAT_TYP) : sig
  val add : NatTyp.t -> NatTyp.t -> NatTyp.t
  val fib : NatTyp.t -> NatTyp.t
  val toInt : NatTyp.t -> int
  val fromInt : int -> NatTyp.t
end = struct
  include AuxDefs (NatTyp)
  open NatTyp

  (* The number 1, built by rolling up two layers at once. *)
  let one = inj2 (SUCC ZERO)

  (* Strip one [SUCC] from each operand per step and put two back on top
     of the recursive result. As soon as either operand is zero, return
     the other one. Recursion depth is the smaller operand, and the
     recursion is not tail-recursive. *)
  let rec add n m =
    match (prj n, prj m) with
    | ZERO, _ -> m
    | SUCC _, ZERO -> n
    | SUCC n', SUCC m' -> inj2 (SUCC (SUCC (add n' m')))

  (* Naive doubly recursive Fibonacci with [fib 0 = fib 1 = 1]. [prj2]
     exposes two layers so that both base cases match directly. *)
  let rec fib n =
    match prj2 n with
    | ZERO | SUCC ZERO -> one
    | SUCC (SUCC n_minus_2) ->
      let n_minus_1 = inj (SUCC n_minus_2) in
      add (fib n_minus_1) (fib n_minus_2)

  (* Convert to [int] with a fold: ZERO counts as 0 and each SUCC adds 1. *)
  let toInt n =
    let count = function ZERO -> 0 | SUCC k -> succ k in
    fold count n

  (* Build from an [int] with an unfold.
     Raises [Failure "Negative number"] for a negative argument. *)
  let fromInt nn =
    let step k =
      if k > 0 then SUCC (k - 1)
      else if k = 0 then ZERO
      else failwith "Negative number"
    in
    unfold step nn
end
(* Peano naturals as a dedicated inductive type: [Z] is zero, [S n] is n+1.
   [inj]/[prj] convert one layer to and from the base functor. *)
module Nat1Typ : NAT_TYP = struct
  type t = Z | S of t
  type 'a f = ZERO | SUCC of 'a

  let fmap g layer =
    match layer with
    | ZERO -> ZERO
    | SUCC x -> SUCC (g x)

  let inj layer =
    match layer with
    | ZERO -> Z
    | SUCC n -> S n

  let prj n =
    match n with
    | Z -> ZERO
    | S p -> SUCC p
end

(* Natural-number operations over the Peano encoding. *)
module Nat1 = NatOp (Nat1Typ)
(* Naturals represented directly as OCaml [int]s. [prj] peels off one
   [SUCC] by decrementing, and raises [Failure "Negative value"] on a
   negative int. *)
module Nat2Typ : NAT_TYP = struct
  type t = int
  type 'a f = ZERO | SUCC of 'a

  let fmap g layer =
    match layer with
    | ZERO -> ZERO
    | SUCC x -> SUCC (g x)

  let inj layer =
    match layer with
    | ZERO -> 0
    | SUCC n -> succ n

  let prj n =
    if n > 0 then SUCC (pred n)
    else if n = 0 then ZERO
    else failwith "Negative value"
end

(* Natural-number operations over the int encoding. *)
module Nat2 = NatOp (Nat2Typ)
(* [timeit f count] calls [f ()] [count] times in a row and returns the
   mean wall-clock time per call, in seconds. Timing uses
   [Unix.gettimeofday], so the program must be linked with [unix].
   @raise Invalid_argument if [count <= 0]. Without this guard, the final
   division by zero would return nan or infinity. *)
let timeit f count =
  if count <= 0 then invalid_arg "timeit: count must be positive";
  let start = Unix.gettimeofday () in
  for _i = 1 to count do
    f ()
  done;
  let finish = Unix.gettimeofday () in
  (finish -. start) /. float_of_int count
(* Benchmark both natural-number encodings. For each one, time
   [fib (fromInt 25)] averaged over 100 runs, and print a line of the form
   "<label><mean seconds>". Nat1 is printed first, then Nat2. *)
let nat_time () =
  let n = 25 and count = 100 in
  let report label run =
    print_string label;
    print_float (timeit run count);
    print_newline ()
  in
  report "Nat1: " (fun () -> ignore (Nat1.fib (Nat1.fromInt n)));
  report "Nat2: " (fun () -> ignore (Nat2.fib (Nat2.fromInt n)))
(* Signature for boolean-formula representations: a [TYP] whose base
   functor has the formula connectives as constructors, with the
   sub-formulas in the ['a] positions. *)
module type BOOL_TYP = sig
  type 'a f =
    | TRUE
    | FALSE
    | VAR of string
    | AND of 'a * 'a
    | OR of 'a * 'a
    | NOT of 'a
  type t
  val fmap : ('a -> 'b) -> 'a f -> 'b f
  val prj : t -> t f
  val inj : t f -> t
end
(* Boolean formulas as a dedicated inductive type [t]. [inj]/[prj]
   translate one layer between [t]'s constructors and the base functor's
   upper-case constructors. *)
module BoolTyp : BOOL_TYP = struct
  type t =
    | True
    | False
    | Var of string
    | And of t * t
    | Or of t * t
    | Not of t

  type 'a f =
    | TRUE
    | FALSE
    | VAR of string
    | AND of 'a * 'a
    | OR of 'a * 'a
    | NOT of 'a

  let fmap g layer =
    match layer with
    | TRUE -> TRUE
    | FALSE -> FALSE
    | VAR name -> VAR name
    | AND (l, r) -> AND (g l, g r)
    | OR (l, r) -> OR (g l, g r)
    | NOT x -> NOT (g x)

  let inj layer =
    match layer with
    | TRUE -> True
    | FALSE -> False
    | VAR name -> Var name
    | AND (l, r) -> And (l, r)
    | OR (l, r) -> Or (l, r)
    | NOT x -> Not x

  let prj formula =
    match formula with
    | True -> TRUE
    | False -> FALSE
    | Var name -> VAR name
    | And (l, r) -> AND (l, r)
    | Or (l, r) -> OR (l, r)
    | Not x -> NOT x
end
(* Formula simplification for any BOOL_TYP, re-exporting the
   representation and all [AuxDefs] helpers alongside [simplify].

   [simplify] works top-down with [unfold]. At each node it views two
   layers at once ([prj2]), applies the first matching rewrite rule, and
   then simplifies the children of the resulting layer. Each node is
   rewritten only once, so the output is not necessarily fully simplified
   (e.g. no rule handles [AND (TRUE, TRUE)]). *)
module Simplify (B : BOOL_TYP) = struct
  module BoolDefs = AuxDefs (B)
  include B
  include BoolDefs

  let simplify formula =
    let rewrite two_layers =
      match two_layers with
      | NOT TRUE -> FALSE
      | NOT FALSE -> TRUE
      | NOT (NOT inner) -> prj1 inner
      | AND (FALSE, _) | AND (_, FALSE) -> FALSE
      | OR (TRUE, _) | OR (_, TRUE) -> TRUE
      (* ... more rules here ... *)
      | other -> fmap inj1 other
    in
    unfold (fun node -> rewrite (prj2 node)) formula
end

(* The simplifier instantiated on the concrete formula type. *)
module Bool = Simplify (BoolTyp)
(* Print a [Bool.t] formula to stdout, fully parenthesised, followed by a
   newline. [True]/[False] are printed literally and variables by name;
   compound formulas print as "(p and q)", "(p or q)" and "(not p)".

   The base-functor constructors are qualified with [Bool.] because they
   are not in scope at top level. The unqualified form only compiled
   through type-directed disambiguation of out-of-scope names, which is
   fragile and which OCaml reports as warning 40 (name-out-of-scope)
   when that warning is enabled. *)
let print p =
  let rec print_it p =
    match Bool.prj1 p with
    | Bool.TRUE -> print_string "True"
    | Bool.FALSE -> print_string "False"
    | Bool.VAR x -> print_string x
    | Bool.AND (l, r) -> print_bi "and" l r
    | Bool.OR (l, r) -> print_bi "or" l r
    | Bool.NOT q ->
      print_string "(not ";
      print_it q;
      print_string ")"
  and print_bi op l r =
    print_string "(";
    print_it l;
    print_string (" " ^ op ^ " ");
    print_it r;
    print_string ")"
  in
  print_it p;
  print_newline ()
(* Demo: build the formula AND (TRUE, NOT FALSE) by rolling up three
   layers at once with [inj3], simplify it, and print the result. *)
let () =
  let formula = Bool.(inj3 (AND (TRUE, NOT FALSE))) in
  print (Bool.simplify formula)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment