public
anonymous / gist:1194308
Last active

Smarter do notation desugaring

  • Download Gist
gistfile1.hs
Haskell
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84
module Main where
 
import qualified Data.Set as S
 
-- | Variable names, used both in the surface @do@ syntax and in the
-- desugared 'Term's.
type Var = String

-- | A term nested inside a @do@ block. For the purposes of the desugaring
-- the only important thing about a nested term is its free variables; its
-- source text is carried along only so the result can be printed.
data HsTerm = HsTerm
  { dsHsTerm' :: String -- ^ source text of the term, printed verbatim
  , hsTermFVs' :: [Var] -- ^ free variables mentioned by the term
  }

-- | Desugar a nested term. A hack :-) the whole source text is wrapped in a
-- single 'Var', so it is shown verbatim in the output.
dsHsTerm :: HsTerm -> Term
dsHsTerm = Var . dsHsTerm'

-- | The free variables of a nested term as a set (duplicates removed).
hsTermFVs :: HsTerm -> S.Set Var
hsTermFVs = S.fromList . hsTermFVs'

-- | A @do@ block: a sequence of statements followed by a final expression.
data HsDo = HsDo [HsQual] HsTerm

-- | A single statement of a @do@ block.
data HsQual
  = Bind Var HsTerm -- ^ @x <- e@
  | Seq HsTerm -- ^ @e@ (result discarded)

-- | The target language of the desugaring.
data Term
  = Unit
  | Pair Term Term
  | Fst Term
  | Snd Term
  | MonadBind Term Term -- ^ @e1 >>= e2@
  | MonadSeq Term Term -- ^ @e1 >> e2@
  | Pure Term
  | LiftA2Pair Term Term -- ^ @liftA2 (,) e1 e2@
  | ApplicativeSeqL Term Term -- ^ @e1 <* e2@
  | ApplicativeSeqR Term Term -- ^ @e1 *> e2@
  | Let Var Term Term -- ^ @let x = e1 in e2@
  | Var Var -- ^ a variable; also embeds nested terms verbatim (see 'dsHsTerm')
  | Lam Var Term -- ^ a lambda abstraction binding the given variable
 
-- | Render a desugared term as Haskell-like text. Operands of operators and
-- arguments of functions are always parenthesised via 'showP', so no
-- precedence reasoning is needed. Lambdas are written as "\x. body" and
-- @let@ puts its body on a new line.
instance Show Term where
  show Unit = "()"
  show (Pair e1 e2) = "(" ++ show e1 ++ ", " ++ show e2 ++ ")"
  show (Fst e) = "fst " ++ showP e
  show (Snd e) = "snd " ++ showP e
  show (MonadBind e1 e2) = showP e1 ++ " >>= " ++ showP e2
  show (MonadSeq e1 e2) = showP e1 ++ " >> " ++ showP e2
  show (Pure e) = "pure " ++ showP e
  show (LiftA2Pair e1 e2) = "liftA2 (,) " ++ showP e1 ++ " " ++ showP e2
  show (ApplicativeSeqL e1 e2) = showP e1 ++ " <* " ++ showP e2
  -- Must be "*>", not "<*": the two operators keep different results, so
  -- printing them the same would misrepresent the desugaring.
  show (ApplicativeSeqR e1 e2) = showP e1 ++ " *> " ++ showP e2
  show (Let x e1 e2) = "let " ++ x ++ " = " ++ show e1 ++ "\nin " ++ show e2
  show (Var x) = x
  show (Lam x e) = "\\" ++ x ++ ". " ++ show e

-- | Show a term wrapped in parentheses.
showP :: Term -> String
showP e = "(" ++ show e ++ ")"
 
 
-- | Desugar a @do@ block. Runs of statements that do not mention any
-- variable bound earlier in the same run are combined with Applicative
-- operators; '>>=' is used only when a statement (or the final expression)
-- needs one of those variables.
--
-- Independent @x <- e@ statements are accumulated into a left-nested tuple
-- with @liftA2 (,)@, starting from @pure ()@. For example, after
-- @x <- e1; y <- e2@ the accumulated computation yields @(((), x), y)@.
dsHsDo :: HsDo -> Term
dsHsDo = start
  where
    start :: HsDo -> Term
    start = go [] (Pure Unit)

    -- go vs e_vs rest: vs are the variables bound so far in the current
    -- applicative run (most recently bound first), e_vs is the computation
    -- producing the nested tuple of their values, rest is what remains.
    go :: [Var] -> Term -> HsDo -> Term
    go vs e_vs rest@(HsDo [] e)
      | independentOf vs e = e_vs `ApplicativeSeqR` dsHsTerm e
      | otherwise = bind vs e_vs rest (dsHsTerm e)
    go vs e_vs rest@(HsDo (q : qs) e) =
      case q of
        Bind x eq
          | independentOf vs eq ->
              go (x : vs) (e_vs `LiftA2Pair` dsHsTerm eq) (HsDo qs e)
        Seq eq
          | independentOf vs eq ->
              go vs (e_vs `ApplicativeSeqL` dsHsTerm eq) (HsDo qs e)
        -- q needs a variable of the current run: bind the run monadically
        -- and desugar the remainder (including q) as a new run.
        _ -> bind vs e_vs rest (start rest)

    -- True when the term mentions none of the given variables.
    independentOf :: [Var] -> HsTerm -> Bool
    independentOf vs t = all (`S.notMember` fvs) vs
      where
        fvs = hsTermFVs t

    -- Build e_vs >>= \tup. let ... in e_body, unpacking each variable of vs
    -- from the nested tuple. rest is the do-block remainder e_body was
    -- desugared from. The lambda binder is chosen so that it shadows neither
    -- a variable of vs nor any variable mentioned in rest, which would
    -- otherwise be captured.
    bind :: [Var] -> Term -> HsDo -> Term -> Term
    bind vs e_vs rest e_body =
      e_vs `MonadBind` Lam tup (foldr unpack e_body (reverse (zip vs projections)))
      where
        tup = freshVar (S.fromList vs `S.union` doVars rest)
        -- The head of vs (bound last) is the rightmost (shallowest) tuple
        -- component: snd tup; the next is snd (fst tup); and so on.
        projections = iterate Fst (Var tup)
        -- Folding over the reversed list puts the earliest-bound variable in
        -- the outermost let. A later binding of the same name then correctly
        -- shadows an earlier one.
        unpack (v, proj) body = Let v (Snd proj) body

    -- Every variable mentioned in a do block, bound or free.
    doVars :: HsDo -> S.Set Var
    doVars (HsDo qs e) = S.unions (hsTermFVs e : map qualVars qs)

    qualVars :: HsQual -> S.Set Var
    qualVars (Bind x eq) = S.insert x (hsTermFVs eq)
    qualVars (Seq eq) = hsTermFVs eq

    -- "tup" if it is not in the avoid set, otherwise the first of "tup1",
    -- "tup2", ... that is not.
    freshVar :: S.Set Var -> Var
    freshVar avoid = candidate (until (\i -> candidate i `S.notMember` avoid) (+ 1) 0)
      where
        candidate :: Int -> Var
        candidate 0 = "tup"
        candidate i = "tup" ++ show i
 
 
-- | Print the desugaring of two sample do blocks:
--
-- 1. x <- computation1; y <- computation2; return x y
--    (neither statement mentions a previously bound variable)
-- 2. x <- computation1; y <- computation2; z <- computation3 y; computation4 x
--    (the third statement mentions y)
main :: IO ()
main = mapM_ (print . dsHsDo) [example1, example2]
  where
    computation1 = HsTerm "computation1" []
    computation2 = HsTerm "computation2" []

    example1 =
      HsDo
        [Bind "x" computation1, Bind "y" computation2]
        (HsTerm "return x y" ["x", "y"])

    example2 =
      HsDo
        [ Bind "x" computation1
        , Bind "y" computation2
        , Bind "z" (HsTerm "computation3 y" ["y"])
        ]
        (HsTerm "computation4 x" ["x"])

Please sign in to comment on this gist.

Something went wrong with that request. Please try again.