Navigation Menu

Skip to content

Instantly share code, notes, and snippets.

@snowleopard
Last active August 28, 2018 10:37
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save snowleopard/2dd93951cfd42e03aa04a4aa696ca029 to your computer and use it in GitHub Desktop.
Save snowleopard/2dd93951cfd42e03aa04a4aa696ca029 to your computer and use it in GitHub Desktop.
Typed constant folding
{-# LANGUAGE GADTs, DataKinds, TypeOperators #-}
{-# OPTIONS_GHC -Wno-unticked-promoted-constructors #-}
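-- This is an attempt to find a safer implementation of GHC's constant folding
-- algorithm, in which the type checker rules out incorrect rewrite rules.
-- See https://ghc.haskell.org/trac/ghc/ticket/15569
-- (now https://gitlab.haskell.org/ghc/ghc/-/issues/15569).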
-- | Shape of an expression tree. It is promoted (via DataKinds) and used as
-- the index of 'Expr', so rewrite rules can state in their types which tree
-- shapes they accept and produce.
data Shape
  = L               -- ^ a literal leaf
  | V               -- ^ a variable leaf
  | Shape :+: Shape -- ^ a sum of two sub-shapes
  | Shape :*: Shape -- ^ a product of two sub-shapes
-- | Arithmetic expressions over literals of type @a@ and variables of type
-- @b@, indexed by their 'Shape' @s@. Each leaf carries a 'Polarity' (its
-- sign); interior nodes combine the shapes of their operands.
data Expr s a b where
  -- | A signed literal leaf.
  Lit :: Polarity -> a -> Expr L a b
  -- | A signed variable leaf.
  Var :: Polarity -> b -> Expr V a b
  -- | Sum node: its shape is the sum of the operands' shapes.
  Add :: Expr l a b -> Expr r a b -> Expr (l :+: r) a b
  -- | Product node: its shape is the product of the operands' shapes.
  Mul :: Expr l a b -> Expr r a b -> Expr (l :*: r) a b
-- | Transform the left operand of an 'Add' or 'Mul' node, keeping both the
-- operator and the right operand unchanged. The @(z :: Shape)@ annotation
-- fixes the kind of the untouched operand, which the rest of the signature
-- does not determine.
mapLeft :: (Expr x a b -> Expr y a b) -> Expr (op x (z :: Shape)) a b -> Expr (op y z) a b
mapLeft f node = case node of
  Add l r -> Add (f l) r
  Mul l r -> Mul (f l) r
-- | Transform the right operand of an 'Add' or 'Mul' node, keeping both the
-- operator and the left operand unchanged. The @(z :: Shape)@ annotation
-- fixes the kind of the untouched operand, which the rest of the signature
-- does not determine.
mapRight :: (Expr x a b -> Expr y a b) -> Expr (op (z :: Shape) x) a b -> Expr (op z y) a b
mapRight f node = case node of
  Add l r -> Add l (f r)
  Mul l r -> Mul l (f r)
-- | Sign of a leaf ('Lit' or 'Var'). Subtraction has no constructor of its
-- own: it is encoded by negating the subtrahend with 'neg', which flips leaf
-- polarities.
--
-- 'Eq' and 'Show' are derived so that polarities (and anything built from
-- them) can be compared and printed when testing or debugging rewrites.
data Polarity = Positive | Negative
  deriving (Eq, Show)
-- | Negate an expression without changing its shape, by flipping leaf
-- polarities. A sum negates both operands, @-(x + y) = (-x) + (-y)@, while a
-- product negates only its left factor, @-(x * y) = (-x) * y@.
neg :: Expr s a b -> Expr s a b
neg = negateTree
  where
    negateTree :: Expr t p q -> Expr t p q
    negateTree (Lit sign x) = Lit (flipPolarity sign) x
    negateTree (Var sign v) = Var (flipPolarity sign) v
    negateTree (Add l r)    = Add (negateTree l) (negateTree r)
    negateTree (Mul l r)    = Mul (negateTree l) r

    flipPolarity Positive = Negative
    flipPolarity Negative = Positive
-- | The signed value of a literal leaf. The 'L' shape index guarantees that
-- the argument is a 'Lit', so this match is total.
getLit :: Num a => Expr L a b -> a
getLit (Lit sign x) = case sign of
  Positive -> x
  Negative -> negate x
-- A few smart constructors: leaves start out positive, and subtraction is
-- expressed as addition of a negated operand.

-- | A positive literal leaf.
lit :: a -> Expr L a b
lit = Lit Positive

-- | A positive variable leaf.
var :: b -> Expr V a b
var = Var Positive

-- | Sum of two expressions.
add :: Expr x a b -> Expr y a b -> Expr (x :+: y) a b
add = Add

-- | Difference @x - y@, encoded as @x + (-y)@; the result therefore has a
-- sum (':+:') shape.
sub :: Expr x a b -> Expr y a b -> Expr (x :+: y) a b
sub x = add x . neg

-- | Product of two expressions.
mul :: Expr x a b -> Expr y a b -> Expr (x :*: y) a b
mul = Mul
-- Axioms of addition and multiplication
-- | Commutativity: swap the operands of an 'Add' or 'Mul' node.
comm :: Expr (op (x :: Shape) y) a b -> Expr (op y x) a b
comm node = case node of
  Add l r -> Add r l
  Mul l r -> Mul r l
-- | Associativity, regrouping to the left: @x op (y op z)@ becomes
-- @(x op y) op z@. The type requires both nodes to use the same operator.
assoc1 :: Expr (op (x :: Shape) (op y z)) a b -> Expr (op (op x y) z) a b
assoc1 node = case node of
  Add p (Add q r) -> Add (Add p q) r
  Mul p (Mul q r) -> Mul (Mul p q) r
-- | Associativity, regrouping to the right: @(x op y) op z@ becomes
-- @x op (y op z)@. The type requires both nodes to use the same operator.
assoc2 :: Expr (op (op (x :: Shape) y) z) a b -> Expr (op x (op y z)) a b
assoc2 node = case node of
  Add (Add p q) r -> Add p (Add q r)
  Mul (Mul p q) r -> Mul p (Mul q r)
-- | Left distributivity of multiplication over addition:
-- @x * (y + z)@ becomes @(x * y) + (x * z)@.
distr :: Expr (x :*: (y :+: z)) a b -> Expr ((x :*: y) :+: (x :*: z)) a b
distr node = case node of
  Mul k (Add p q) -> Add (Mul k p) (Mul k q)
-- | The main constant folding step: collapse a binary node whose operands are
-- both literals into one positive literal holding the sum or product of
-- their signed values.
eval :: Num a => Expr (op L L) a b -> Expr L a b
eval node = Lit Positive (fold node)
  where
    -- Local signature keeps the GADT match rigid and the result type fixed.
    fold :: Num n => Expr (o L L) n m -> n
    fold (Add l r) = getLit l + getLit r
    fold (Mul l r) = getLit l * getLit r
-- Constant folding rules that are checked by the compiler.
-- Ideally these rules would live in a separate module that sees
-- `Expr` as an abstract data type whose only operations are the
-- transformations `eval`, `mapLeft`, `comm`, etc., each verified by hand.
-- This way, I think, it should be impossible to write an incorrect
-- constant folding rule.
-- | @l1 op (l2 op x)@ becomes @l op x@, where @l@ is @l1 op l2@ folded.
r1 :: Num a => Expr (op L (op L x)) a b -> Expr (op L x) a b
r1 e = mapLeft eval (assoc1 e)
-- | @l1 op (x op l2)@ becomes @l op x@, with the two literals folded.
r2 :: Num a => Expr (op L (op x L)) a b -> Expr (op L x) a b
r2 e = r1 (mapRight comm e)
-- | @(l1 op x) op l2@ becomes @l op x@, with the two literals folded.
r3 :: Num a => Expr (op (op L x) L) a b -> Expr (op L x) a b
r3 e = r1 (comm e)
-- | @(x op l1) op l2@ becomes @l op x@, with the two literals folded.
r4 :: Num a => Expr (op (op x L) L) a b -> Expr (op L x) a b
r4 e = r3 (mapLeft comm e)
-- | @(l1 op x) op (l2 op y)@ becomes @l op (x op y)@, with the two literals
-- folded.
r5 :: Num a => Expr (op (op L x) (op L y)) a b -> Expr (op L (op x y)) a b
r5 e = assoc2 (mapLeft r3 (assoc1 e))
-- | @(l1 op x) op (y op l2)@ becomes @l op (x op y)@, with the two literals
-- folded.
r6 :: Num a => Expr (op (op L x) (op y L)) a b -> Expr (op L (op x y)) a b
r6 e = r5 (mapRight comm e)
-- | @(x op l1) op (l2 op y)@ becomes @l op (x op y)@, with the two literals
-- folded.
r7 :: Num a => Expr (op (op x L) (op L y)) a b -> Expr (op L (op x y)) a b
r7 e = r5 (mapLeft comm e)
-- | @(x op l1) op (y op l2)@ becomes @l op (x op y)@, with the two literals
-- folded.
r8 :: Num a => Expr (op (op x L) (op y L)) a b -> Expr (op L (op x y)) a b
r8 e = r6 (mapLeft comm e)
-- | @l1 * (l2 + x)@ becomes @l + (l1 * x)@ by distributing, where @l@ is
-- @l1 * l2@ folded.
r9 :: Num a => Expr (L :*: (L :+: x)) a b -> Expr (L :+: (L :*: x)) a b
r9 e = mapLeft eval (distr e)
-- | @l1 * (x + l2)@ becomes @l + (l1 * x)@ by distributing, where @l@ is
-- @l1 * l2@ folded.
r10 :: Num a => Expr (L :*: (x :+: L)) a b -> Expr (L :+: (L :*: x)) a b
r10 e = r9 (mapRight comm e)
-- | Float a literal to the front: @x op (l op y)@ becomes @l op (x op y)@.
-- No folding happens here, so no 'Num' constraint is needed.
r11 :: Expr (op x (op L y)) a b -> Expr (op L (op x y)) a b
r11 e = assoc2 (mapLeft comm (assoc1 e))
-- | Float a literal to the front: @x op (y op l)@ becomes @l op (x op y)@.
r12 :: Expr (op x (op y L)) a b -> Expr (op L (op x y)) a b
r12 e = r11 (mapRight comm e)
-- | Float a literal to the front: @(l op x) op y@ becomes @l op (x op y)@.
r13 :: Expr (op (op L x) y) a b -> Expr (op L (op x y)) a b
r13 e = mapRight comm (r11 (comm e))
-- | Float a literal to the front: @(x op l) op y@ becomes @l op (x op y)@.
r14 :: Expr (op (op x L) y) a b -> Expr (op L (op x y)) a b
r14 e = mapRight comm (r12 (comm e))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment