Last active
August 29, 2015 14:21
-
-
Save pacak/65a48eba7c38dc6b614c to your computer and use it in GitHub Desktop.
Template Haskell generation of base-functor datatypes and Foldable/Unfoldable instances for recursion-schemes
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE DeriveFunctor, StandaloneDeriving, FlexibleInstances #-} | |
{-# LANGUAGE TypeFamilies, TemplateHaskell, DataKinds #-} | |
{-# LANGUAGE BangPatterns #-} | |
import TH | |
import Data.Functor.Foldable | |
import Criterion.Main | |
import Prelude hiding (Foldable) | |
-- | A binary tree whose values live in the internal 'T' nodes; 'Leaf' carries
-- no value.
data Tree a
  = Leaf
  | T (Tree a) a (Tree a)
  deriving (Show)

-- Generate the base functor @Prim (Tree a) r@ (constructors 'LeafF' and 'TF'),
-- the @Base@ type instance, and the Foldable/Unfoldable instances for 'Tree'.
makePrim ''Tree
-- | Summing algebra: collapses one layer of the tree, whose children have
-- already been reduced to their sums, into the sum of that subtree.
alg :: Prim (Tree Int) Int -> Int
alg layer = case layer of
  LeafF -> 0
  -- Fields are forced before adding, exactly as the accumulating sum needs.
  TF !leftSum !value !rightSum -> leftSum + value + rightSum
-- | Coalgebra used to build 'sample': a seed @n@ becomes a node holding @n@
-- whose two children are unfolded from @n - 1@, giving a perfect tree of
-- depth @n@.
--
-- Seeds @<= 0@ yield 'LeafF', so a negative seed terminates instead of
-- unfolding an infinite tree.
alg2 :: (Ord a, Num a) => a -> Prim (Tree a) a
alg2 n
  | n <= 0 = LeafF
  | otherwise = TF (n - 1) n (n - 1)
-- | Benchmark a hand-written recursive sum against the catamorphism-based
-- sum over the same tree.
main :: IO ()
main = do
  -- Build the shared 'sample' tree completely before any timing starts, so
  -- its construction is not charged to whichever benchmark forces it first.
  -- Criterion's 'env' is the usual tool, but it needs an NFData instance that
  -- 'Tree' lacks; summing the tree visits (and thus forces) every node.
  _ <- return $! recSum sample
  defaultMain
    [ bgroup "cata"
        [ bench "rec" $ whnf recSum sample
        , bench "cata" $ whnf cataSum sample
        ]
    ]
-- | Sum of all values in the tree, by direct structural recursion.
recSum :: Tree Int -> Int
recSum tree = case tree of
  Leaf -> 0
  T left value right -> recSum left + value + recSum right
-- | Sum of all values in the tree, expressed as a catamorphism over 'alg'.
cataSum :: Tree Int -> Int
cataSum tree = cata alg tree
-- | Benchmark input: the tree unfolded by 'alg2' from a seed of 20.
sample :: Tree Int
sample = ana alg2 depth
  where
    depth :: Int
    depth = 20
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE LambdaCase, OverloadedStrings, FlexibleInstances #-} | |
{-# LANGUAGE ViewPatterns, TemplateHaskell #-} | |
module TH where | |
import Control.Applicative | |
import Control.Lens | |
import Control.Monad | |
import Data.Functor.Foldable | |
import Language.Haskell.TH | |
import Language.Haskell.TH.Lens | |
import qualified Data.Set as Set | |
import Prelude hiding (Foldable) | |
-- | Values that store Template Haskell 'Type's in their components.
class HasType t where
  -- | Visit each 'Type' held by the value's components. The instances in this
  -- module do not descend into the structure of those types themselves.
  typeVal :: Traversal' t Type
-- | Visit the field types of a constructor, looking through @forall@s.
instance HasType Con where
  typeVal f = \case
    NormalC conName fields -> NormalC conName <$> typeVal f fields
    RecC conName fields -> RecC conName <$> typeVal f fields
    InfixC lhs conName rhs ->
      (\lhs' rhs' -> InfixC lhs' conName rhs') <$> typeVal f lhs <*> typeVal f rhs
    ForallC binders ctx inner -> ForallC binders ctx <$> typeVal f inner
-- | Visit the types of every element, in list order.
instance HasType t => HasType [t] where
  typeVal f = traverse (typeVal f)
-- | The 'Type' is the last component (as in a record field triple).
instance HasType (a, b, Type) where
  typeVal f (a, b, typ) = fmap (\typ' -> (a, b, typ')) (f typ)

-- | The 'Type' is the second component (as in a strict-type pair).
instance HasType (a, Type) where
  typeVal f (a, typ) = fmap (\typ' -> (a, typ')) (f typ)
-- | Splice entry point: @makePrim ''T@ reifies @T@ and generates its @Prim@
-- base functor plus the associated instances. Fails at compile time if the
-- name does not reify to a type constructor.
makePrim :: Name -> DecsQ
makePrim tyName = do
  info <- reify tyName
  case info of
    TyConI dec -> makePrimForDec dec
    _ -> fail "makePrim: Expected type constructor name"
-- | Name of the base-functor counterpart of a constructor or record field.
--
-- Alphanumeric names get an @F@ suffix (@T@ becomes @TF@, @left@ becomes
-- @leftF@). Operator names get a @$@ suffix instead (@:+:@ becomes @:+:$@).
-- Appending a letter to an operator would not give a valid Haskell name. This
-- matches the naming used by recursion-schemes' @makeBaseFunctor@.
toFName :: Name -> Name
toFName n = mkName (base ++ suffix)
  where
    base = nameBase n
    suffix
      | isOperator base = "$"
      | otherwise = "F"
    -- Operator names start with a symbol character (constructor operators
    -- always with ':'). Only the ASCII symbol set is recognised here.
    isOperator (c : _) = c `elem` ("!#$%&*+./<=>?@\\^|-~:" :: String)
    isOperator [] = False
-- | View a type-variable binder as the type variable it binds, discarding any
-- kind annotation.
varBindName :: Getter TyVarBndr Type
varBindName = to binderType
  where
    binderType (PlainTV v) = VarT v
    binderType (KindedTV v _) = VarT v
-- | Generate the @Prim@ base functor and the recursion-schemes instances for a
-- reified declaration.
--
-- @data@ declarations and @newtype@s are accepted; a newtype is handled as a
-- data type with a single constructor. Any other declaration is rejected with
-- an explicit compile-time error rather than a pattern-match failure.
makePrimForDec :: Dec -> DecsQ
makePrimForDec = \case
  DataD _ tyName vars cons _ -> makePrimForCons tyName vars cons
  NewtypeD _ tyName vars con _ -> makePrimForCons tyName vars [con]
  _ -> fail "makePrim: expected a data or newtype declaration"

-- | For the type @T@ named @tyName@, with binders @vars@ and constructors
-- @cons@, produce:
--
-- * @data instance Prim (T a ..) r@: a copy of the constructors in which every
--   field whose type is exactly @T a ..@ becomes @r@, and constructor and
--   record-field names are renamed with 'toFName' (deriving Functor and Show);
--
-- * @type instance Base (T a ..) = Prim (T a ..)@;
--
-- * a Foldable instance whose 'project' relabels each constructor to its
--   @Prim@ counterpart, and an Unfoldable instance whose 'embed' reverses it.
makePrimForCons :: Name -> [TyVarBndr] -> [Con] -> DecsQ
makePrimForCons tyName vars cons = do
  -- With no constructors 'project'/'embed' would get no clauses, which GHC
  -- rejects; report that case directly with a clearer message.
  when (null cons) $
    fail ("makePrim: " ++ nameBase tyName ++ " has no constructors")
  r <- VarT <$> newName "r"
  -- The declared type applied to all of its parameters, e.g. @Tree a@.
  let fullType = foldlOf (traverse . varBindName) AppT (ConT tyName) vars :: Type
      primType = ConT ''Prim `AppT` fullType
  -- data instance declaration
  let toFunctor = set (traverse . typeVal . filtered (== fullType)) r
      renameRecs = over (traverse . _RecC . _2 . traverse . _1) toFName
      renameCons = over (traverse . name) toFName
      cons' = renameRecs . renameCons . toFunctor $ cons
      dataInstance = DataInstD [] ''Prim [fullType, r] cons' [''Functor, ''Show]
  -- type synonym instance declaration
  let typeInstance = TySynInstD ''Base (TySynEqn [fullType] primType)
  -- 'project' and 'embed' only relabel constructors, so both reuse one list of
  -- fresh argument names per constructor.
  let normalized = map normalizeConstructor cons
      conNames = map fst normalized
      fConNames = map toFName conNames
      arities = map (length . snd) normalized
  args <- forM arities $ \arity -> replicateM arity (newName "a")
  let projD = FunD 'project (mkMorphism conNames fConNames args)
      foldInstance = InstanceD [] (ConT ''Foldable `AppT` fullType) [projD]
      embD = FunD 'embed (mkMorphism fConNames conNames args)
      unfInstance = InstanceD [] (ConT ''Unfoldable `AppT` fullType) [embD]
  return [dataInstance, typeInstance, foldInstance, unfInstance]
-- | Build one clause per constructor that rebuilds a value under a new
-- constructor name: for source constructor @C@, target @D@ and argument names
-- @x1 .. xk@ the clause is @C x1 .. xk = D x1 .. xk@. The three lists are
-- consumed in parallel; the result is as long as the shortest of them.
mkMorphism :: [Name] -> [Name] -> [[Name]] -> [Clause]
mkMorphism nFrom nTo args = zipWith3 clauseFor nFrom nTo args
  where
    clauseFor src dst vars =
      let pat = ConP src (map VarP vars)
          body = foldl AppE (ConE dst) (map VarE vars)
      in Clause [pat] (NormalB body) []
-- | Normalizes a 'Con' into a uniform positional representation, removing the
-- differences between record, infix and normal constructors. Field names are
-- kept only for plain record constructors; under a @forall@ they are dropped.
normalizeConstructor ::
  Con -> (Name, [(Maybe Name, Type)]) -- ^ constructor name, field name, field type
normalizeConstructor = \case
  RecC conName fields ->
    (conName, [(Just fieldName, ty) | (fieldName, _, ty) <- fields])
  NormalC conName fields ->
    (conName, [(Nothing, ty) | (_, ty) <- fields])
  InfixC (_, lhsTy) conName (_, rhsTy) ->
    (conName, [(Nothing, lhsTy), (Nothing, rhsTy)])
  ForallC _ _ inner ->
    let (conName, fields) = normalizeConstructor inner
    in (conName, [(Nothing, ty) | (_, ty) <- fields])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment