Skip to content

Instantly share code, notes, and snippets.

@damhiya
Last active January 31, 2022 06:14
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save damhiya/c3db201dae7bb9a70bc0072be9a5ab93 to your computer and use it in GitHub Desktop.
Save damhiya/c3db201dae7bb9a70bc0072be9a5ab93 to your computer and use it in GitHub Desktop.
defunctionalization
/// A rose tree: every node stores a `value` and an arbitrary (possibly empty)
/// list of child sub-trees.
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct Tree<A> {
    /// Payload stored at this node.
    value: A,
    /// Child sub-trees; empty for a leaf.
    children: Vec<Tree<A>>,
}
/// One frame of the explicit continuation stack used by [`Tree::fold`]
/// (the defunctionalized continuation of the recursive fold).
enum KontElem<'a, A, B> {
    /// A node whose children are still being folded, paired with the siblings
    /// to its right that have not been visited yet.
    Kont1(&'a A, &'a [Tree<A>]),
    /// A node whose children have already been folded into the `Vec<B>`; the
    /// node itself is finalized with `f` after its right siblings are done.
    Kont2(&'a A, Vec<B>),
}
impl<A> Tree<A> {
    /// Folds the tree bottom-up: `f` receives a node's value together with the
    /// already-folded results of that node's children, in left-to-right order.
    ///
    /// This is the defunctionalized form of the naive recursive fold. Pending
    /// work is kept on an explicit `Vec` of [`KontElem`] frames, so `fold`
    /// itself does not recurse, whatever the depth of the tree.
    ///
    /// `f` is called exactly once per node. For any node, it is called after
    /// it has been called for every node in that node's subtree. The order of
    /// calls across sibling subtrees is an implementation detail. Because `f`
    /// is `FnMut`, it may keep mutable state such as counters or caches.
    pub fn fold<B, F>(&self, mut f: F) -> B
    where
        F: FnMut(&A, Vec<B>) -> B,
    {
        let mut stack: Vec<KontElem<'_, A, B>> = Vec::new();
        // The sibling list currently being descended into (the "go" phase).
        let mut ts: &[Tree<A>] = &self.children;
        let ys = 'main_loop: loop {
            // go: walk down the leftmost spine, remembering each node together
            // with the siblings to its right.
            while let Some((head, tail)) = ts.split_first() {
                stack.push(KontElem::Kont1(&head.value, tail));
                ts = &head.children;
            }
            // apply: unwind frames. Results are pushed to the back, so this
            // vector holds the current sibling results in reverse order.
            let mut ys_reversed: Vec<B> = Vec::new();
            loop {
                match stack.pop() {
                    Some(KontElem::Kont1(x, rest)) => {
                        // All children of `x` are folded: park them with `x`
                        // and continue with the siblings to the right of `x`.
                        ys_reversed.reverse();
                        stack.push(KontElem::Kont2(x, ys_reversed));
                        ts = rest;
                        break;
                    }
                    Some(KontElem::Kont2(x, ys)) => ys_reversed.push(f(x, ys)),
                    None => {
                        // Stack exhausted: these are the root's child results.
                        ys_reversed.reverse();
                        break 'main_loop ys_reversed;
                    }
                }
            }
        };
        f(&self.value, ys)
    }

    /// Applies `f` to every value and returns a tree of the same shape.
    ///
    /// Implemented on top of [`Tree::fold`], so it does not recurse either.
    /// `f` is called exactly once per node. The call order is unspecified.
    pub fn map<B, F>(&self, mut f: F) -> Tree<B>
    where
        F: FnMut(&A) -> B,
    {
        self.fold(|x, ys| Tree {
            value: f(x),
            children: ys,
        })
    }
}
/// Smoke test: rebuilds a sample tree with `fold` (using the `Tree` constructor
/// as the algebra) and with an identity `map`, and checks both reproduce it.
fn main() {
    // Shorthand constructors that keep the sample tree readable.
    fn node(value: i32, children: Vec<Tree<i32>>) -> Tree<i32> {
        Tree { value, children }
    }
    fn leaf(value: i32) -> Tree<i32> {
        node(value, vec![])
    }

    let t1 = node(
        1,
        vec![
            node(2, vec![node(3, vec![leaf(4), leaf(5)]), leaf(6)]),
            node(7, vec![leaf(8), leaf(9)]),
            node(
                10,
                vec![node(11, vec![leaf(12)]), node(13, vec![leaf(14)]), leaf(15)],
            ),
        ],
    );

    // Folding with the constructor itself must give back an identical tree.
    let t2 = t1.fold::<Tree<i32>, _>(|x: &i32, ys| Tree {
        value: *x,
        children: ys,
    });
    // `assert_eq!` prints both trees on failure, unlike `assert!(a == b)`.
    assert_eq!(t1, t2);

    // `map` with the identity function must preserve the tree as well.
    let t3 = t1.map(|x| *x);
    assert_eq!(t1, t3);

    println!("correct!");
}
{-# LANGUAGE ExplicitForAll #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Rose where
-- | A rose tree: each 'Node' carries a value and a (possibly empty) list of
-- subtrees.
data Tree a = Node a [Tree a] deriving (Eq, Show)
-- | Sample tree used to exercise the folds. It holds the values 1 to 15 in
-- pre-order and has the same shape as the tree built in the Rust @main@.
t1 :: Tree Int
t1 =
  Node 1
    [ Node 2 [Node 3 [leaf 4, leaf 5], leaf 6]
    , Node 7 [leaf 8, leaf 9]
    , Node 10 [Node 11 [leaf 12], Node 13 [leaf 14], leaf 15]
    ]
  where
    leaf x = Node x []
-- naive implementation using recursion
-- | Reference fold: the result for a node is @f@ applied to its value and to
-- the results for its children.
fold :: (a -> [b] -> b) -> Tree a -> b
fold f = go
  where
    go (Node x ts) = f x (map go ts)
-- make the mutual recursion explicit: one function per tree, one per list of trees
-- | 'fold' with the recursion over the child list written out by hand:
-- @goTree@ folds one tree and @goForest@ folds a list of trees, and each
-- calls the other.
foldMut :: (a -> [b] -> b) -> Tree a -> b
foldMut f = goTree
  where
    goTree (Node x children) = f x (goForest children)
    goForest forest = case forest of
      [] -> []
      t : rest -> goTree t : goForest rest
-- inline the per-tree function into the per-list function
-- | Single-recursion variant of 'foldMut': the per-tree step is inlined into
-- the per-list function, and the root is handled by the outer lambda.
foldInline :: (a -> [b] -> b) -> Tree a -> b
foldInline f = \(Node x children) -> f x (goForest children)
  where
    goForest forest = case forest of
      [] -> []
      Node y grandchildren : rest -> f y (goForest grandchildren) : goForest rest
-- apply CPS conversion
-- | 'foldInline' after CPS conversion: instead of returning the list of
-- results, @goForest@ passes it to its continuation @k@.
foldCPS :: (a -> [b] -> b) -> Tree a -> b
foldCPS f = \(Node x children) -> f x (goForest children id)
  where
    goForest [] k = k []
    goForest (Node y grandchildren : rest) k =
      goForest grandchildren (\childResults ->
        goForest rest (\siblingResults ->
          k (f y childResults : siblingResults)))
-- factor the continuation lambdas out into named functions of their free variables
-- | 'foldCPS' with each anonymous continuation turned into a named function
-- (@kont1@, @kont2@) that takes its free variables as explicit arguments.
foldCPS' :: forall a b. (a -> [b] -> b) -> Tree a -> b
foldCPS' f = \(Node x children) -> f x (goForest children (id :: [b] -> [b]))
  where
    goForest :: forall r. [Tree a] -> ([b] -> r) -> r
    goForest forest k = case forest of
      [] -> k []
      Node y grandchildren : rest -> goForest grandchildren (kont1 y rest k)

    -- Runs once the children of @y@ are folded; continues with the siblings.
    kont1 :: forall r. a -> [Tree a] -> ([b] -> r) -> [b] -> r
    kont1 y rest k childResults = goForest rest (kont2 y childResults k)

    -- Runs once the siblings are folded; finalizes @y@ and prepends it.
    kont2 :: forall r. a -> [b] -> ([b] -> r) -> [b] -> r
    kont2 y childResults k siblingResults = k (f y childResults : siblingResults)
-- defunctionalization
-- | Defunctionalized continuations of 'foldCPS'': 'Id' is the initial
-- continuation, while 'Kont1' and 'Kont2' capture the free variables of
-- @kont1@ and @kont2@ respectively, plus the continuation they wrap.
data Kont a b = Id | Kont1 a [Tree a] (Kont a b) | Kont2 a [b] (Kont a b)
-- | 'foldCPS'' with its continuations defunctionalized: each continuation is a
-- 'Kont' value, and @apply@ interprets it.
foldDefunc :: forall a b. (a -> [b] -> b) -> Tree a -> b
foldDefunc f = \(Node x children) -> f x (goForest children Id)
  where
    goForest :: [Tree a] -> Kont a b -> [b]
    goForest forest k = case forest of
      [] -> apply k []
      Node y grandchildren : rest -> goForest grandchildren (Kont1 y rest k)

    apply :: Kont a b -> [b] -> [b]
    apply k results = case k of
      Id -> results
      Kont1 y rest k' -> goForest rest (Kont2 y results k')
      Kont2 y childResults k' -> apply k' (f y childResults : results)
-- represent Kont as a plain list used as a stack of frames
-- | 'Kont' encoded as a list (a stack): 'Id' becomes the empty list, and
-- 'Kont1' / 'Kont2' frames become 'Left' / 'Right' elements pushed on the front.
type Kont' a b = [Either (a, [Tree a]) (a, [b])]
-- | 'foldDefunc' with the continuation stored as a plain list of frames
-- ('Kont''), which is the shape the Rust version mirrors with a @Vec@.
foldDefunc' :: forall a b. (a -> [b] -> b) -> Tree a -> b
foldDefunc' f = \(Node x children) -> f x (goForest children [])
  where
    goForest :: [Tree a] -> Kont' a b -> [b]
    goForest forest stack = case forest of
      [] -> apply stack []
      Node y grandchildren : rest -> goForest grandchildren (Left (y, rest) : stack)

    apply :: Kont' a b -> [b] -> [b]
    apply stack results = case stack of
      [] -> results
      Left (y, rest) : stack' -> goForest rest (Right (y, results) : stack')
      Right (y, childResults) : stack' -> apply stack' (f y childResults : results)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment