Skip to content

Instantly share code, notes, and snippets.

@darichey
Last active April 17, 2023 05:26
Show Gist options
  • Save darichey/cb18ee71460255401988bd9a9b58bb8f to your computer and use it in GitHub Desktop.
Save darichey/cb18ee71460255401988bd9a9b58bb8f to your computer and use it in GitHub Desktop.
Code re-use with recursion schemes
{-# LANGUAGE DeriveTraversable, TemplateHaskell, TypeFamilies, LambdaCase, FlexibleInstances #-}
module Main where
import Data.Functor.Foldable.TH (makeBaseFunctor)
import Control.Lens
import Control.Monad.Reader
import Data.Functor.Foldable
import Data.List
import Control.Monad.Writer (Writer)
-- | Arithmetic expressions over integer literals.
--
-- 'makeBaseFunctor' generates the base functor 'ExprF' (constructors 'NumF',
-- 'AddF', 'MulF'), in which each recursive 'Expr' position becomes the
-- functor's type parameter; this module folds it with 'cata'.
--
-- 'Show' and 'Eq' are derived so expressions can be inspected and compared
-- (e.g. in GHCi or tests).
data Expr
  = Num Int
  | Add Expr Expr
  | Mul Expr Expr
  deriving (Show, Eq)

makeBaseFunctor ''Expr
-- | Nesting depth of a node below the root of the expression.
type Depth = Int

-- | Environments from which a 'Depth' can be read and updated.
class HasDepth r where
  -- | Lens onto the depth stored in the environment.
  depthL :: Lens' r Depth

-- | In a @(Depth, Prec)@ pair the depth is the first component.
instance HasDepth (Depth, Prec) where
  depthL = lens getDepth setDepth
    where
      getDepth (depth, _) = depth
      setDepth (_, prec) depth = (depth, prec)
-- | Precedence level of the context a node is rendered in; a larger value
-- means the node is an operand of a tighter-binding operator.
type Prec = Int

-- | Environments from which a 'Prec' can be read and updated.
class HasPrec r where
  -- | Lens onto the precedence stored in the environment.
  precL :: Lens' r Prec

-- | In a @(Depth, Prec)@ pair the precedence is the second component.
instance HasPrec (Depth, Prec) where
  precL = lens getPrec setPrec
    where
      getPrec (_, prec) = prec
      setPrec (depth, _) prec = (depth, prec)

-- | A bare 'Prec' is its own environment: the lens focuses on the whole value
-- and setting simply replaces it.
instance HasPrec Prec where
  precL = lens id (const id)
-- | Push every child computation of a node one level deeper: each child
-- 'Reader' runs with the depth in its environment incremented by one.
increaseDepth :: HasDepth r => ExprF (Reader r a) -> ExprF (Reader r a)
increaseDepth node = local (depthL +~ 1) <$> node
-- | Tell each child computation of a node which precedence context it runs in:
-- both operands of 'AddF' see precedence 0 and both operands of 'MulF' see
-- precedence 1. A 'NumF' leaf has no children and is returned unchanged.
setPrecedence :: HasPrec r => ExprF (Reader r a) -> ExprF (Reader r a)
setPrecedence node = case node of
  leaf@(NumF _) -> leaf
  AddF lhs rhs -> AddF (atPrec 0 lhs) (atPrec 0 rhs)
  MulF lhs rhs -> MulF (atPrec 1 lhs) (atPrec 1 rhs)
  where
    -- Run a child with the precedence field of its environment overwritten.
    atPrec p = local (precL .~ p)
-- | Adapt a monadic algebra @f a -> m b@ to a layer of pending computations
-- @f (m a)@: first run every computation in the layer (in 'Traversable'
-- order) with 'sequence', then hand the resulting pure layer to the algebra.
seqAp :: (Traversable f, Monad m) => (f a -> m b) -> f (m a) -> m b
seqAp k layer = sequence layer >>= k
-- | Render an expression as an outline with one node per line, annotated with
-- the precedence context each node is rendered in.
--
-- Each line has the form @"* <node> -- prec: <n>"@ and is indented two spaces
-- per level of depth. The root starts with depth 0 and precedence 0.
debugPrint :: Expr -> String
debugPrint expr = runReader (cata algebra expr) (0, 0)
  where
    -- Children get depth + 1 and their operator's precedence; their
    -- computations are run before this node's own line is produced.
    algebra :: ExprF (Reader (Depth, Prec) String) -> Reader (Depth, Prec) String
    algebra = seqAp debugAlg . setPrecedence . increaseDepth

    -- Render this node's line followed by its already-rendered children.
    debugAlg :: ExprF String -> Reader (Depth, Prec) String
    debugAlg node = do
      (depth, prec) <- ask
      let (nodeLabel, childLines) = describe node
          indent = replicate (depth * 2) ' '
          headerLine =
            concat [indent, "* ", nodeLabel, " -- prec: ", show prec]
      pure (intercalate "\n" (headerLine : childLines))

    -- Node label plus the rendered child outlines, in left-to-right order.
    describe :: ExprF String -> (String, [String])
    describe (NumF n) = ("Num " ++ show n, [])
    describe (AddF lhs rhs) = ("Add", [lhs, rhs])
    describe (MulF lhs rhs) = ("Mul", [lhs, rhs])
-- | Render an expression in infix notation, inserting only the parentheses
-- that operator precedence requires.
--
-- The root is rendered at precedence 0. 'setPrecedence' gives operands of
-- @+@ precedence 0 and operands of @*@ precedence 1.
--
-- >>> prettyPrint (Mul (Num 1) (Add (Num 2) (Num 3)))
-- "1 * (2 + 3)"
prettyPrint :: Expr -> String
prettyPrint expr = runReader (cata (seqAp prettyAlg . setPrecedence) expr) 0
  where
    -- Render one node from its already-rendered children and the precedence
    -- of the context it appears in.
    prettyAlg :: ExprF String -> Reader Prec String
    prettyAlg node = do
      prec <- ask
      pure $ case node of
        NumF n -> show n
        -- '+' needs parentheses in any context above 0, i.e. as an operand of '*'.
        AddF e1 e2 -> parensIf (prec > 0) $ e1 ++ " + " ++ e2
        MulF e1 e2 -> parensIf (prec > 1) $ e1 ++ " * " ++ e2

    -- Wrap a string in parentheses when the flag is set. Deliberately not named
    -- 'showParen': that name is already imported from Prelude
    -- (Bool -> ShowS -> ShowS), and redefining it either makes unqualified uses
    -- ambiguous (top level) or shadows the Prelude function (local).
    parensIf :: Bool -> String -> String
    parensIf wrap s = if wrap then "(" <> s <> ")" else s
-- | Demo: print the debug outline and then the pretty-printed form of the
-- expression @1 * (2 + 3)@.
main :: IO ()
main = do
  let expr = Mul (Num 1) (Add (Num 2) (Num 3))
  putStrLn $ debugPrint expr
  putStrLn $ prettyPrint expr
* Mul -- prec: 0
  * Num 1 -- prec: 1
  * Add -- prec: 1
    * Num 2 -- prec: 0
    * Num 3 -- prec: 0
1 * (2 + 3)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment