Skip to content

Instantly share code, notes, and snippets.

@oisdk
Last active April 29, 2026 07:56
Show Gist options
  • Select an option

  • Save oisdk/a10b27a54a3fef6b489eb9202406387c to your computer and use it in GitHub Desktop.

Select an option

Save oisdk/a10b27a54a3fef6b489eb9202406387c to your computer and use it in GitHub Desktop.
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE DeriveGeneric #-}
import qualified Data.OrdPSQ as Map
import Data.OrdPSQ (OrdPSQ)
import Data.Ord (Down(..))
import Data.List (foldl',intersperse,unfoldr)
import Data.Functor ((<&>))
import Control.Lens
import GHC.Generics (Generic)
import Test.QuickCheck
--------------------------------------------------------------------------------
-- Basic Data Type, basic instances
--------------------------------------------------------------------------------
-- | A polynomial stored as a trie of monomials.
--
-- @c :<+ ts@ denotes @c + Σ v·tᵥ@, summed over the entries
-- @(Down v, Down d, tᵥ)@ of @ts@ (this is exactly what 'eval' computes).
-- A monomial is therefore the word of variables read along a path from the
-- root. Because '*' only ever prepends the left operand's variables, the
-- variables appear to be treated as non-commuting.
--
-- Invariants, checked by 'isNorm' and 'validDepth': no stored sub-term is
-- zero, and the priority of each entry is @Down (depth tᵥ)@, so the minimum
-- of the queue is the deepest child.
data Poly v c = (:<+) !c (SubTerms v c)

-- | Children of a trie node, keyed by the next variable and prioritised by
-- the child's depth. Both are wrapped in 'Down'.
type SubTerms v c = OrdPSQ (Down v) (Down Word) (Poly v c)

-- 'fmap' maps coefficients only and does not re-normalise the trie.
deriving instance Functor (Poly v)

-- Structural equality of the tries.
deriving instance (Eq c, Ord v) => Eq (Poly v c)
-- | True exactly for the zero polynomial: a zero constant and no children.
isZero :: (Num c, Eq c) => Poly v c -> Bool
isZero (c :<+ children) = Map.null children && c == 0
-- | Length of the longest path below this node: 0 for a constant, otherwise
-- one more than the largest child depth.
--
-- This is O(1) because priorities are @Down depth@, so 'Map.findMin' returns
-- the deepest child.
depth :: Poly v c -> Word
depth (_ :<+ children) =
  case Map.findMin children of
    Nothing -> 0
    Just (_, Down deepest, _) -> succ deepest
-- | Renders a polynomial as a sum of terms.
--
-- The output is the constant (omitted when zero) followed by one term per
-- child, in 'Map.toAscList' order:
--
-- * a child whose sub-term is exactly @1@ prints as the bare variable;
-- * a child whose sub-term is exactly @-1@ prints as @-v@;
-- * any other child prints as @v*sub@.
--
-- A polynomial with no terms prints as @0@. Sums of two or more terms are
-- parenthesised above precedence 6.
instance (Eq c, Num c, Show c, Show v) => Show (Poly v c) where
  showsPrec d (c0 :<+ children) =
    case pieces of
      [] -> showChar '0'
      [single] -> single d
      _ ->
        showParen (d > 6)
          (sepShowS (showString " + ") [piece 6 | piece <- pieces])
    where
      constant
        | c0 == 0 = []
        | otherwise = [\prec -> showsPrec prec c0]
      pieces =
        constant ++ [term x sub | (Down x, _, sub) <- Map.toAscList children]
      term x sub prec
        | isUnit 1 sub = showsPrec prec x
        | isUnit (-1) sub = showParen (prec > 9) (showChar '-' . showsPrec 9 x)
        | otherwise =
            showParen (prec > 8)
              (showsPrec 8 x . showString "*" . showsPrec 8 sub)
      -- A sub-term is the unit u when it is a bare constant equal to u.
      isUnit u (m :<+ ms) = m == u && Map.null ms
-- | Concatenate the given 'ShowS' pieces, putting the separator between
-- neighbouring pieces. With no pieces it is the identity.
sepShowS :: ShowS -> [ShowS] -> ShowS
sepShowS sep parts rest = foldr ($) rest (intersperse sep parts)
--------------------------------------------------------------------------------
-- Numeric
--------------------------------------------------------------------------------
-- | Ring operations on trie polynomials.
--
-- '+' and '*' keep the invariants the rest of the file relies on (checked by
-- 'isNorm' and 'validDepth'):
--
-- * no stored sub-term is zero;
-- * every stored priority is @Down (depth sub)@.
instance (Ord v, Num c, Eq c) => Num (Poly v c) where
  fromInteger n = fromInteger n :<+ Map.empty

  -- Add the constants and merge the child maps, folding the smaller map
  -- into the larger one. Colliding children are added recursively and sent
  -- back through 'entry'. A zero sum deletes the key; otherwise the child's
  -- depth is recomputed.
  (n :<+ ns) + (m :<+ ms) = (n + m) :<+ mns
    where
      mns
        | Map.size ms < Map.size ns = foldl' ins ns (Map.toList ms)
        | otherwise = foldl' ins ms (Map.toList ns)
      ins acc (v, p, t) =
        snd (Map.alter ((,) () . maybe (Just (p, t)) (add t . snd)) v acc)
      add t t' = view (from entry) (t + t')

  -- Left distribution over the trie:
  --
  --   (n + Σ v·tᵥ) * q  =  n·q + Σ v·(tᵥ * q)
  --
  -- Child maps are rebuilt by 'normalised' rather than keeping the old
  -- shape and shifting priorities by @depth q@. That shortcut assumed
  -- @depth (t * q) == depth t + depth q@ and that scaling never zeroes a
  -- coefficient. Both assumptions fail when the coefficient type has zero
  -- divisors (e.g. 'Word' or 'Int' on 64-bit, where 2^32 * 2^32 == 0).
  -- There they left zero sub-terms and stale depths in the trie, so a
  -- mathematically zero product compared unequal to 0.
  (n :<+ ns) * q
    | isZero q = 0
    | otherwise = scaled + (0 :<+ shifted)
    where
      scaled
        | n == 0 = 0
        | otherwise = scalePoly n q
      shifted = normalised [(v, t * q) | (v, _, t) <- Map.toList ns]
      -- Multiply every coefficient by a scalar, dropping children that vanish.
      scalePoly k (m :<+ ms) =
        (k * m) :<+ normalised [(v, scalePoly k t) | (v, _, t) <- Map.toList ms]
      -- Build a child map from (key, sub-term) pairs. Zero sub-terms are
      -- skipped, and each kept sub-term's actual depth is its priority.
      normalised kts =
        Map.fromList [(v, Down (depth t), t) | (v, t) <- kts, not (isZero t)]

  negate = fmap negate
  abs = fmap abs
  -- Only the constant term contributes to the sign.
  signum (n :<+ _) = signum n :<+ Map.empty
-- | The polynomial consisting of a single variable with coefficient 1.
var :: Num c => v -> Poly v c
var v = 0 :<+ Map.singleton (Down v) (Down 0) one
  where
    -- The child below the variable is the constant 1, whose depth is 0.
    one = 1 :<+ Map.empty
-- | Evaluate a polynomial by substituting a value for each variable.
--
-- Each child is evaluated as @f v * eval f child@, with the variable on the
-- left, which matches how monomials are stored as words.
eval :: Num c => (v -> c) -> Poly v c -> c
eval f (n :<+ ns) = n + sum (map term (Map.toList ns))
  where
    term (Down v, _, t) = f v * eval f t
-- | Evaluating a polynomial with 'var' as the substitution rebuilds it.
prop_selfEval :: Poly Var Integer -> Property
prop_selfEval p = p === eval var (fromInteger <$> p)
--------------------------------------------------------------------------------
-- Lenses
--------------------------------------------------------------------------------
-- | Lens onto the constant term at the root of the trie.
coeff :: Lens' (Poly v c) c
coeff f (c :<+ rest) = fmap (\c' -> c' :<+ rest) (f c)
-- | Lens onto the root's child map. It is type-changing in the variable
-- type.
--
-- Setting through this lens does not re-check the trie invariants.
vars :: Lens (Poly v c) (Poly v' c) (SubTerms v c) (SubTerms v' c)
vars f (c :<+ rest) = (\rest' -> c :<+ rest') <$> f rest
-- | Lens indexing for 'OrdPSQ': a key indexes a (priority, value) pair.
--
-- These are orphan instances. 'OrdPSQ' and the lens classes are both
-- imported, not defined here.
type instance Index (OrdPSQ key prio val) = key
type instance IxValue (OrdPSQ key prio val) = (prio, val)

-- | Focus on the optional (priority, value) stored at a key.
--
-- Writing @Just (p, v)@ inserts or replaces the entry. Writing 'Nothing'
-- deletes the key if it is present and leaves the queue untouched otherwise.
instance (Ord k, Ord p) => At (OrdPSQ k p v) where
  at key f psq = rebuild <$> f current
    where
      current = Map.lookup key psq
      rebuild (Just (prio, val)) = Map.insert key prio val psq
      rebuild Nothing =
        case current of
          Nothing -> psq
          Just _ -> Map.delete key psq

-- | No methods are given. 'ix' presumably comes from lens's default, which
-- is defined via 'At'.
instance (Ord k, Ord p) => Ixed (OrdPSQ k p v)
-- | View an optional child-map entry as a polynomial.
--
-- A missing entry reads as zero. Writing a zero polynomial yields 'Nothing',
-- which removes the entry. Writing any other polynomial stores it with its
-- depth as the priority. This is what keeps tries normalised when edited
-- through 'factored'.
entry :: (Num c, Eq c) => Iso' (Maybe (Down Word, Poly v c)) (Poly v c)
entry = iso fromEntry toEntry
  where
    fromEntry Nothing = 0 :<+ Map.empty
    fromEntry (Just (_, p)) = p
    toEntry p
      | isZero p = Nothing
      | otherwise = Just (Down (depth p), p)
-- | Lens onto the sub-polynomial reached by following the given word of
-- variables from the root. Absent paths read as 0.
--
-- Every step goes through 'entry', so writing zero prunes the node.
factored :: (Ord v, Num c, Eq c) => [v] -> Lens' (Poly v c) (Poly v c)
factored [] = id
factored (v : vs) = vars . at (Down v) . entry . factored vs
-- | Lens onto the coefficient of the monomial spelled by the given word.
coeffOf :: (Ord v, Num c, Eq c) => [v] -> Lens' (Poly v c) c
coeffOf path = factored path . coeff
--------------------------------------------------------------------------------
-- Enumeration
--------------------------------------------------------------------------------
-- | A lazily built, unbounded stream of "levels".
--
-- 'monos' and 'monosDesc' use it to collect monomials level by level in a
-- single knot-tied pass over the trie. Level k gathers the monomials of
-- word length k.
data Knots a
  = Knot
  { tied :: !Bool
    -- ^ True when the level was produced by 'tie' visiting a trie node at
    -- this depth; levels no node reached carry False.
  , yank :: [a]
    -- ^ Items accumulated for this level.
  , ends :: Knots a
    -- ^ The next level.
  }

-- | An untied copy of a level stream.
--
-- Each output level keeps the corresponding input level's items only if
-- that input level was tied; otherwise the level's list is empty.
--
-- The irrefutable pattern is essential. 'monos' and 'monosDesc' apply
-- 'tighten' to the very stream they are still building, and 'tie' matches
-- its argument strictly. A strict match here would demand the result before
-- it exists and loop.
tighten :: Knots a -> Knots a
tighten ~(Knot t y e) = Knot False (if t then y else []) (tighten e)
-- | All monomials with non-zero coefficients, as (word, coefficient) pairs,
-- in order of increasing word length.
--
-- Within one length, the order follows the trie traversal (children taken
-- in 'Map.toAscList' order).
--
-- The work is a single lazy pass. 'tie' threads a level stream through the
-- trie and conses each node's monomial onto the list for its depth. The
-- result's own later levels are fed back, via 'tighten', as the tail of each
-- level's list. Level 0's list therefore ends up being the whole answer.
monos :: (Eq c, Num c) => Poly v c -> [([v],c)]
monos p = y
  where
    -- Knot-tied: the input stream is built from the output's own tail @e@.
    Knot _ y e = tie [] p (tighten e)
    -- Zero coefficients mark interior trie nodes and are not monomials.
    cons sv 0 ms = ms
    -- @sv@ holds the path in reverse (innermost variable first).
    cons sv c ms = (reverse sv, c) : ms
    -- Visit a node: record its monomial at this level, then thread the next
    -- level's stream through all children in order.
    tie sv (n :<+ m) (Knot _ ms ps) =
      Knot True (cons sv n ms)
        (foldr (\(Down v,_,t) -> tie (v:sv) t) ps (Map.toAscList m))
-- | Skip every tied level and return the items of the first untied one.
--
-- This would not terminate on a stream whose levels are all tied.
pull :: Knots a -> [a]
pull knot
  | tied knot = pull (ends knot)
  | otherwise = yank knot
-- | All monomials with non-zero coefficients, as (word, coefficient) pairs,
-- in order of decreasing word length.
--
-- 'prop_leadingMonos' checks that this order agrees with repeatedly taking
-- 'leading'.
--
-- It uses the same single-pass knot as 'monos', but the levels fed back in
-- are shifted down by one: an explicit empty, untied level is placed in
-- front. As a result, each level's list becomes "this depth's monomials,
-- then everything shallower". The first untied level, found by 'pull',
-- sits just past the deepest node and so carries the complete list.
monosDesc :: (Eq c, Num c) => Poly v c -> [([v],c)]
monosDesc p = pull r
  where
    -- Knot-tied: level k+1 of the input is the tightened level k of @r@.
    r = tie [] p (Knot False [] (tighten r))
    -- Zero coefficients mark interior trie nodes and are not monomials.
    cons sv 0 ms = ms
    -- @sv@ holds the path in reverse (innermost variable first).
    cons sv c ms = (reverse sv,c) : ms
    -- Visit a node: record its monomial at this level, then thread the next
    -- level's stream through all children in order.
    tie vs (n :<+ m) (Knot _ ms ps) =
      Knot True (cons vs n ms)
        (foldr (\(Down v,_,p) -> tie (v:vs) p) ps (Map.toAscList m))
-- | Build a polynomial by setting the coefficient of each listed word,
-- starting from 0.
--
-- 'foldr' applies later entries first. When a word appears more than once,
-- its earliest occurrence therefore wins. A zero coefficient leaves the
-- monomial absent.
fromMonos :: (Ord v, Num c, Eq c) => [([v],c)] -> Poly v c
fromMonos = foldr place 0
  where
    place (vs, c) poly = poly & coeffOf vs .~ c
-- | Listing the monomials and rebuilding from them is the identity.
prop_monosRoundTrip :: Poly Var Word -> Property
prop_monosRoundTrip p = p === (fromMonos . monosDesc) p
--------------------------------------------------------------------------------
-- Leading
--------------------------------------------------------------------------------
-- | Split off the leading monomial, or return 'Nothing' for the zero
-- polynomial.
--
-- The leading monomial is found by repeatedly descending into the minimum
-- child, i.e. the deepest one, since priorities are @Down depth@. The
-- result is paired with the rest of the polynomial. The emptied path is
-- pruned, and the shortened child is re-inserted with its new depth.
leading :: (Num c, Eq c, Ord v) => Poly v c -> Maybe (([v],c),Poly v c)
leading p@(n :<+ ns)
  | isZero p = Nothing
  | otherwise = Just (rebuild (Map.alterMin pick ns))
  where
    rebuild ((mono, n'), ns') = (mono, n' :<+ ns')
    pick = \case
      -- No children: the constant itself is the leading term.
      Nothing -> ((([], n), 0), Nothing)
      Just (Down v, _, child) -> (((v : vs, c), n), remainder)
        where
          -- Relies on the invariant that stored children are non-zero.
          Just ((vs, c), child') = leading child
          remainder
            | isZero child' = Nothing
            | otherwise = Just (Down v, Down (depth child'), child')
-- | Repeatedly taking 'leading' enumerates the same monomials, in the same
-- order, as 'monosDesc'.
prop_leadingMonos :: Poly Var Word -> Property
prop_leadingMonos p = expected === actual
  where
    expected = monosDesc p
    actual = unfoldr leading p
--------------------------------------------------------------------------------
-- Division
--------------------------------------------------------------------------------
-- | Divide by a single monomial, given as its word and coefficient, from
-- the left.
--
-- Returns a pair:
--
-- * the quotient: the sub-polynomial found under the word, divided by the
--   coefficient;
-- * the remainder: the input with that whole sub-polynomial removed (set to
--   0 through 'factored').
divModPrefM :: (Fractional c, Eq c, Ord v)
  => Poly v c -> ([v],c) -> (Poly v c, Poly v c)
divModPrefM p (vs, i) = factored vs carve p
  where
    -- Runs 'factored' in the pair functor: return the scaled sub-term and
    -- write 0 back in its place.
    carve sub = (fmap (/ i) sub, 0)
-- | Left division by a polynomial, returning (quotient, remainder).
--
-- The loop repeatedly strips the multiples of the divisor's leading word
-- from the running remainder. Each round subtracts the rest of the divisor
-- times the partial quotient. It stops once 'divModPrefM' finds nothing
-- under that word. 'prop_divReconstruct' checks
-- @x == divisor * quotient + remainder@.
--
-- Calls 'error' when the divisor is zero.
--
-- NOTE(review): termination relies on @rest * q@ not re-introducing the
-- leading word forever; the visible code does not establish this.
divModPref :: (Fractional c, Eq c, Ord v)
  => Poly v c -> Poly v c -> (Poly v c, Poly v c)
divModPref num divisor =
  case leading divisor of
    Nothing -> error "Divide by zero"
    Just (lt, rest) ->
      let step !acc !r =
            case divModPrefM r lt of
              (0, _) -> (acc, r)
              (q, r') -> step (acc + q) (r' - rest * q)
      in step 0 num
-- | For a non-zero divisor, @x == y * quotient + remainder@.
prop_divReconstruct :: Poly Var Rational -> Poly Var Rational -> Property
prop_divReconstruct x y = not (isZero y) ==> (x === y * q + r)
  where
    (q, r) = divModPref x y
--------------------------------------------------------------------------------
-- Testing
--------------------------------------------------------------------------------
-- | Property: no stored sub-term anywhere in the trie is zero.
--
-- A failure is labelled with the offending variable, as @v*0@.
isNorm :: (Num c, Eq c, Show v) => Poly v c -> Property
isNorm (_ :<+ ns) = conjoin (map check (Map.toList ns))
  where
    check (Down v, _, p) =
      counterexample (show v ++ "*0") (not (isZero p)) .&&. isNorm p
-- | Property: every stored priority equals the actual 'depth' of its
-- sub-term, checked recursively.
validDepth :: (Show c, Show v, Num c, Eq c) => Poly v c -> Property
validDepth (_ :<+ ns) = conjoin (map checkChild (Map.toList ns))
  where
    checkChild (_, Down d, p) =
      counterexample ("depth " ++ showsPrec 11 p "") (depth p === d)
        .&&. validDepth p
-- | Property: the trie is normalised ('isNorm') and its depths are
-- consistent ('validDepth').
valid :: (Show c, Show v, Num c, Eq c) => Poly v c -> Property
valid p = normalised .&&. depthsOk
  where
    normalised = isNorm p
    depthsOk = validDepth p
-- | Random polynomials are built from a random monomial list via
-- 'fromMonos'. Earlier duplicates of a word win, and zero coefficients
-- leave the monomial absent.
--
-- Shrinking goes through the same monomial list. This lets QuickCheck drop
-- terms, shorten words and shrink coefficients when minimising a
-- counterexample. Previously there was no 'shrink' at all.
instance (Ord v, Num c, Eq c, Arbitrary v, Arbitrary c)
    => Arbitrary (Poly v c) where
  arbitrary = fmap fromMonos arbitrary
  shrink = map fromMonos . shrink . monosDesc
-- | A small variable alphabet for property tests.
data Var = A | B | C | D
  deriving (Eq, Ord, Show, Enum, Bounded, Generic)

-- | Uniform over the four variables. A variable shrinks to every variable
-- that precedes it.
instance Arbitrary Var where
  arbitrary = arbitraryBoundedEnum
  shrink x = takeWhile (/= x) [minBound ..]

-- The empty instances use the classes' default methods, which presumably
-- rely on the derived 'Generic' representation.
instance CoArbitrary Var
instance Function Var
-- | A syntax tree of ring expressions over variables of type @var@. It is
-- used to test that building a 'Poly' agrees with direct evaluation.
data Expr var
  = Var' var
  | Lit Integer
  | Expr var :+: Expr var
  | Expr var :*: Expr var
  | Negate (Expr var)
  deriving (Eq, Ord, Show)
-- | Size-bounded random expressions.
--
-- Generation: a size of at most 1 gives a literal or a variable. Larger
-- sizes give a negation or a binary node; each node spends one unit of
-- size, and a binary node splits the remainder between its two children.
--
-- Shrinking: an expression shrinks to its immediate children, to copies
-- with one child shrunk, or, for leaves, to shrunk literals and variables.
instance Arbitrary a => Arbitrary (Expr a) where
  arbitrary = sized go
    where
      go n
        | n <= 1 = oneof [fmap Lit arbitrary, fmap Var' arbitrary]
        | otherwise =
            frequency
              [ (1, fmap Negate (go (n - 1)))
              , (n, go' (:+:) (n - 1))
              , (n, go' (:*:) (n - 1))
              ]
      -- Split a budget of n >= 1 between two children. At n == 1 the old
      -- code asked for @choose (1, 0)@, an inverted range. Clamping the
      -- upper bound keeps the range valid and leaves n >= 2 unchanged.
      go' f n = do
        m <- choose (1, max 1 (n - 1))
        f <$> go m <*> go (n - m)

  shrink (xs :+: ys) =
    [xs, ys] ++ [xs' :+: ys | xs' <- shrink xs] ++ [xs :+: ys' | ys' <- shrink ys]
  shrink (xs :*: ys) =
    [xs, ys] ++ [xs' :*: ys | xs' <- shrink xs] ++ [xs :*: ys' | ys' <- shrink ys]
  shrink (Negate xs) = xs : map Negate (shrink xs)
  shrink (Lit i) = map Lit (shrink i)
  shrink (Var' x) = map Var' (shrink x)
-- | Evaluate an expression in any 'Num', given a value for each variable.
eval' :: Num e => (a -> e) -> Expr a -> e
eval' f = go
  where
    go (Var' x) = f x
    go (Lit i) = fromInteger i
    go (l :+: r) = go l + go r
    go (l :*: r) = go l * go r
    go (Negate x) = negate (go x)
-- | Evaluating an expression directly gives the same result as first
-- building its 'Poly' and then evaluating that.
prop_evalExpr :: Fun Var Int -> Expr Var -> Property
prop_evalExpr (Fn f) e = direct === viaPoly
  where
    direct = eval' f e
    viaPoly = eval f (eval' var e)
-- | Polynomials built by the 'Num' operations satisfy the trie invariants.
-- The coefficient type is left to defaulting.
prop_validExpr :: Expr Var -> Property
prop_validExpr = valid . eval' var
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment