Skip to content

Instantly share code, notes, and snippets.

Created September 5, 2011 07:27
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save anonymous/1194308 to your computer and use it in GitHub Desktop.
Save anonymous/1194308 to your computer and use it in GitHub Desktop.
Smarter do notation desugaring
module Main where
import qualified Data.Set as S
-- | Variable names, represented as plain strings.
type Var = String

-- | A nested source term. For the purposes of the desugaring the only
-- important thing about a nested term is its set of free variables; its
-- text is passed through to the output unchanged.
data HsTerm = HsTerm { dsHsTerm' :: String, hsTermFVs' :: [Var] }

-- | Embed a source term in the output as a 'Var' that carries the term's
-- text verbatim (a hack :-)).
dsHsTerm :: HsTerm -> Term
dsHsTerm = Var . dsHsTerm'

-- | The free variables of a source term, as a set.
hsTermFVs :: HsTerm -> S.Set Var
hsTermFVs = S.fromList . hsTermFVs'
-- | A @do@ block: its statements, in source order, followed by the final
-- expression.
data HsDo = HsDo [HsQual] HsTerm

-- | One statement of a @do@ block.
data HsQual
  = Bind Var HsTerm  -- ^ @x <- e@
  | Seq HsTerm       -- ^ @e@ on its own; its result is discarded
-- | Target language of the desugaring: a small lambda calculus with pairs,
-- 'Let', and the monadic / applicative combinators the desugarer emits.
data Term
  = Unit                       -- ^ @()@
  | Pair Term Term             -- ^ @(e1, e2)@
  | Fst Term                   -- ^ @fst e@
  | Snd Term                   -- ^ @snd e@
  | MonadBind Term Term        -- ^ @e1 >>= e2@
  | MonadSeq Term Term         -- ^ @e1 >> e2@
  | Pure Term                  -- ^ @pure e@
  | LiftA2Pair Term Term       -- ^ @liftA2 (,) e1 e2@
  | ApplicativeSeqL Term Term  -- ^ @e1 <* e2@
  | ApplicativeSeqR Term Term  -- ^ @e1 *> e2@
  | Let Var Term Term          -- ^ @let x = e1 in e2@
  | Var Var                    -- ^ a variable (also used to embed opaque text)
  | Lam Var Term               -- ^ @\\x -> e@, rendered as @\\x. e@

-- | Renders a 'Term' in Haskell-like concrete syntax. The operands of
-- @fst@, @snd@, @pure@, @liftA2@ and of the infix operators are always
-- wrapped in parentheses by 'showP' (even atoms), so no precedence
-- handling is needed. Lambdas use @\\x. e@ notation and a 'Let' body starts
-- on a new line.
instance Show Term where
  show Unit                    = "()"
  show (Pair e1 e2)            = "(" ++ show e1 ++ ", " ++ show e2 ++ ")"
  show (Fst e)                 = "fst " ++ showP e
  show (Snd e)                 = "snd " ++ showP e
  show (MonadBind e1 e2)       = showP e1 ++ " >>= " ++ showP e2
  show (MonadSeq e1 e2)        = showP e1 ++ " >> " ++ showP e2
  show (Pure e)                = "pure " ++ showP e
  show (LiftA2Pair e1 e2)      = "liftA2 (,) " ++ showP e1 ++ " " ++ showP e2
  show (ApplicativeSeqL e1 e2) = showP e1 ++ " <* " ++ showP e2
  -- Must be "*>": this keeps the right-hand result, unlike "<*" above.
  show (ApplicativeSeqR e1 e2) = showP e1 ++ " *> " ++ showP e2
  show (Let x e1 e2)           = "let " ++ x ++ " = " ++ show e1 ++ "\nin " ++ show e2
  show (Var x)                 = x
  show (Lam x e)               = "\\" ++ x ++ ". " ++ show e

-- | 'show' a term wrapped in parentheses.
showP :: Term -> String
showP e = "(" ++ show e ++ ")"
-- | Desugar a @do@ block. Applicative combinators are used wherever the
-- statements allow it; @>>=@ is used only where a statement needs the
-- result of an earlier one.
--
-- Statements are collected into groups in which no statement mentions a
-- variable bound earlier in the same group.
--
-- * Within a group, @x <- e@ becomes @liftA2 (,) acc e@. The bound values
--   accumulate in a left-nested tuple that starts as @pure ()@.
-- * A bare @e@ becomes @acc <* e@.
-- * When a statement depends on the group's variables, the group is closed
--   by one @>>=@ whose lambda unpacks the tuple with 'Let's. The rest of the
--   block is then desugared as a new group inside that lambda.
-- * The final expression is attached with @*>@ if it is independent of the
--   current group, and through @>>=@ otherwise.
dsHsDo :: HsDo -> Term
dsHsDo whole@(HsDo stmts final) = start whole
  where
    -- Name for the lambda-bound tuple. It is chosen to differ from every
    -- variable bound or mentioned in the block. A fixed "tup" could capture
    -- a user's free variable. A user binding named "tup" would also produce
    -- @let tup = snd tup@, which Haskell's recursive let never finishes.
    tupVar :: Var
    tupVar = until (`S.notMember` mentioned) (++ "'") "tup"

    -- Every variable bound or referenced anywhere in the block.
    mentioned :: S.Set Var
    mentioned = S.unions (hsTermFVs final : map qualVars stmts)

    qualVars :: HsQual -> S.Set Var
    qualVars (Bind x eq) = S.insert x (hsTermFVs eq)
    qualVars (Seq eq)    = hsTermFVs eq

    -- True when the source term mentions none of the given variables.
    avoids :: [Var] -> HsTerm -> Bool
    avoids vs t = all (`S.notMember` hsTermFVs t) vs

    -- Desugar starting with a fresh, empty group.
    start :: HsDo -> Term
    start = go [] (Pure Unit)

    -- @go vs acc block@: @vs@ are the variables bound by the current group,
    -- most recently bound first. @acc@ is the applicative computation that
    -- produces them as a left-nested tuple.
    go :: [Var] -> Term -> HsDo -> Term
    go vs acc (HsDo [] e)
      | avoids vs e = acc `ApplicativeSeqR` dsHsTerm e
      | otherwise   = bind vs acc (dsHsTerm e)
    go vs acc (HsDo (q:qs) e) = case q of
      Bind x eq | avoids vs eq ->
        go (x:vs) (acc `LiftA2Pair` dsHsTerm eq) (HsDo qs e)
      Seq eq | avoids vs eq ->
        go vs (acc `ApplicativeSeqL` dsHsTerm eq) (HsDo qs e)
      -- q needs a variable from this group. Close the group and desugar
      -- from q onwards inside the continuation, where those variables are
      -- bound.
      _ -> bind vs acc (start (HsDo (q:qs) e))

    -- Close a group: run @acc@, then unpack its tuple into @vs@ around
    -- @body@.
    --
    -- The most recently bound variable is the rightmost (shallowest) tuple
    -- component, so @vs@ pairs up with snd tup, snd (fst tup), and so on.
    -- The earliest-bound variable gets the outermost 'Let'. A later binding
    -- of the same name therefore shadows an earlier one, as in the source.
    bind :: [Var] -> Term -> Term -> Term
    bind vs acc body = acc `MonadBind` Lam tupVar (foldr (uncurry Let) body bindings)
      where
        projections = map Snd (iterate Fst (Var tupVar))
        bindings = reverse (zip vs projections)
-- | Print the desugaring of two sample blocks:
--
-- > do { x <- computation1; y <- computation2; return x y }
-- > do { x <- computation1; y <- computation2; z <- computation3 y; computation4 x }
main :: IO ()
main = mapM_ (print . dsHsDo) [example1, example2]
  where
    comp1 = HsTerm "computation1" []
    comp2 = HsTerm "computation2" []
    -- The two binds are independent of each other, but the final
    -- expression uses both x and y.
    example1 =
      HsDo [Bind "x" comp1, Bind "y" comp2] (HsTerm "return x y" ["x", "y"])
    -- computation3 needs y, so it cannot join the group that binds y.
    example2 =
      HsDo
        [Bind "x" comp1, Bind "y" comp2, Bind "z" (HsTerm "computation3 y" ["y"])]
        (HsTerm "computation4 x" ["x"])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment