@gelisam
Created December 23, 2023 06:10
Tracking whether the combination of two functions is still strictly-monotonic
-- In response to https://twitter.com/kmett/status/1738168271357026634
--
-- The challenge is to implement a version of
--
-- > mapKeys :: Ord k2 => (k1 -> k2) -> Map k1 a -> Map k2 a
--
-- which costs O(1) if the (k1 -> k2) function is coerce and the coercion
-- preserves the ordering, O(n) if the function is injective, and O(n log n)
-- otherwise. Obviously, the implementation can't inspect a pure function in
-- order to determine which case it is, so our version will need to use a
-- different type.
--
-- The challenge asks to do this "without making your user cry or accidentally
-- violate invariants".
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
module Main where
import Prelude hiding (id, (.))
import Control.Category (Category(..))
import Data.Coerce (Coercible, coerce)
import Data.Functor.Identity (Identity(Identity))
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Ord (Down(Down))
import Unsafe.Coerce (unsafeCoerce)
-- Haskell's type system is great, but you'll get into trouble if you try to
-- track too much at the type level. Suppose we tracked whether a function is
-- strictly-monotone at the type level, and then used a typeclass to pick which
-- 'mapKeys' implementation to use. Simple examples would probably work fine,
-- but higher-order functions would require more complex types, which might
-- fail the "don't make the user cry" criterion.
--
-- The secret to this challenge is thus: let's track this at the value level!
--
-- The simplest solution would be to define a sum type for the three cases and
-- to let the user specify at the value level which case they want:
--
-- > data Fun a b where
-- >   Coerce
-- >     :: Coercible a b
-- >     => Fun a b
-- >   StrictlyMonotone
-- >     :: (a -> b)
-- >     -> Fun a b
-- >   NotMonotone
-- >     :: (a -> b)
-- >     -> Fun a b
-- >
-- > fancyMapKeys (Coerce :: Fun String (Identity String)) exampleMap1
-- > fancyMapKeys (StrictlyMonotone (\x -> [x])) exampleMap1
-- > fancyMapKeys (NotMonotone Down) exampleMap1
--
-- But this might fail the "don't accidentally violate invariants" criterion, as
-- the user has to manually label their function at every single call site and
-- is bound to make a mistake sooner or later.
--
-- The second secret to this challenge is thus: define a combinator library!
-- The library defines a bunch of primitives which are already
-- correctly-labeled, plus a bunch of combinators which correctly update the
-- label. When prototyping, it is easy for the user to use 'arr' to get a
-- correct program which might not be as efficient as possible. When the user
-- needs the extra performance, then they can use the primitives to define a
-- more efficient version. And if the primitives are insufficient, they can use
-- 'unsafeStrictlyMonotone', and it is only then that the user needs to think
-- about whether that label is correct.
--
-- In the implementation below, I split 'Fun' into 'CFun' and 'MFun', because
-- the decisions regarding whether coerce can be used are orthogonal to the
-- decisions regarding whether a function is strictly-monotone. This in turn
-- leads to some typeclasses, just to reuse the same combinator name for (->),
-- 'CFun', and 'MFun'. Those typeclasses are entirely unnecessary, as my
-- solution is at the value level, not the type level.
data CFun a b where
  Coerce :: Coercible a b => CFun a b
  NotCoerce :: (a -> b) -> CFun a b

runCFun :: CFun a b -> a -> b
runCFun Coerce = coerce
runCFun (NotCoerce f) = f
-- | In a real library, these data constructors would be exported from a
-- ".Internal" module, so that dedicated users can define their own combinators.
data MFun k a b
  = -- | if x < y then f x < f y
    StrictlyMonotone (k a b)
  | NotMonotone (k a b)
runMFun :: MFun k a b -> k a b
runMFun (StrictlyMonotone f) = f
runMFun (NotMonotone f) = f
-- | The caller promises that the function really is strictly-monotone.
unsafeStrictlyMonotone
  :: k a b
  -> MFun k a b
unsafeStrictlyMonotone = StrictlyMonotone
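
-- To see why that promise matters, here is a small sketch (not part of the
-- original gist; 'brokenMap' and 'brokenMapIsValid' are hypothetical names).
-- 'Down' reverses the ordering, so labeling it as strictly-monotone makes
-- 'fancyMapKeys', defined below, take the 'Map.mapKeysMonotonic' path and
-- leaves the keys in the wrong order inside the tree. 'Map.valid' from
-- containers checks the internal invariants, so it should return False here.
brokenMap :: Map (Down String) Int
brokenMap
  = fancyMapKeys
      (unsafeStrictlyMonotone (NotCoerce Down))  -- wrong label, on purpose
      (Map.fromList [("a",1),("b",2)])

brokenMapIsValid :: Bool
brokenMapIsValid = Map.valid brokenMap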
type MCFun = MFun CFun
runMCFun :: MCFun a b -> a -> b
runMCFun = runCFun . runMFun
arr :: (a -> b) -> MCFun a b
arr = NotMonotone . NotCoerce
fancyMapKeys
  :: Ord k2
  => MCFun k1 k2
  -> Map k1 a
  -> Map k2 a
fancyMapKeys (StrictlyMonotone Coerce)
  = -- O(1) case.
    -- Coerce brings a Coercible instance in scope, so we could call coerce,
    -- but that intentionally doesn't type-check because Map's key parameter
    -- has a nominal role, which prevents e.g.
    -- > coerce :: Map k1 a -> Map (Down k1) a
    -- from breaking the invariant. In this case, we know that the invariant
    -- will not be broken because this coerce is labeled strictly-monotone, but
    -- the type system doesn't know that, so we have to use unsafeCoerce.
    unsafeCoerce
fancyMapKeys (StrictlyMonotone (NotCoerce f))
  = -- O(n) case.
    Map.mapKeysMonotonic f
fancyMapKeys (NotMonotone f)
  = -- O(n log n) case.
    Map.mapKeys (runCFun f)
-- That's it, that's the core of the library, and it was trivial! The rest
-- defines a bunch of combinators which carefully update the labels.
-- A lot more could be defined, I am sure.
instance Category CFun where
  id = Coerce
  Coerce . Coerce
    = Coerce
  f . g
    = NotCoerce (runCFun f . runCFun g)
instance Category k => Category (MFun k) where
  id = StrictlyMonotone id
  StrictlyMonotone f . StrictlyMonotone g
    = -- if x < y
      -- then g x < g y
      -- then f (g x) < f (g y)
      StrictlyMonotone (f . g)
  f . g
    = NotMonotone (runMFun f . runMFun g)
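
-- Note that the fallback case is conservative: a 'NotMonotone' label means
-- "not known to be strictly-monotone", not "known to scramble the order". As
-- a small sketch (not part of the original gist; 'doubleDown' is a
-- hypothetical name), reversing the order twice preserves it, yet the
-- composition below is labeled 'NotMonotone', so 'fancyMapKeys' would take
-- the slower but still correct O(n log n) path.
doubleDown :: MCFun a (Down (Down a))
doubleDown = arr Down . arr Down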
class Product k where
  (***) :: k a1 b1 -> k a2 b2 -> k (a1,a2) (b1,b2)

instance Product (->) where
  (f1 *** f2) (x1,x2) = (f1 x1, f2 x2)

instance Product CFun where
  Coerce *** Coerce
    = Coerce
  f1 *** f2
    = NotCoerce $ \(x1,x2)
   -> (runCFun f1 x1, runCFun f2 x2)
instance Product k => Product (MFun k) where
  StrictlyMonotone f1 *** StrictlyMonotone f2
    = -- if (x1,x2) < (y1,y2)
      --   case 1: x1 < y1
      --     then f1 x1 < f1 y1
      --     then (f1 x1, _) < (f1 y1, _)
      --     then (f1 x1, f2 x2) < (f1 y1, f2 y2)
      --   case 2: x1 == y1 and x2 < y2
      --     then f1 x1 == f1 y1 && f2 x2 < f2 y2
      --     then (f1 x1, f2 x2) < (f1 y1, f2 y2)
      StrictlyMonotone (f1 *** f2)
  f1 *** f2
    = NotMonotone (runMFun f1 *** runMFun f2)
dup :: MCFun a (a,a)
dup
  = -- if x < y
    -- then (x,_) < (y,_)
    -- then (x,x) < (y,y)
    StrictlyMonotone
  $ NotCoerce $ \x
 -> (x,x)
(&&&) :: MCFun a b1 -> MCFun a b2 -> MCFun a (b1,b2)
f1 &&& f2 = (f1 *** f2) . dup
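
-- The same pattern extends to other familiar combinators. As a sketch (not
-- part of the original gist; the names first' and second' are hypothetical),
-- mapping over one component of a pair can be built from '(***)' and 'id', so
-- the labels are updated by the instances above and no new proof is needed.
-- For example, first' f stays coercible and strictly-monotone whenever f is.
first' :: (Category k, Product k) => k a b -> k (a,c) (b,c)
first' f = f *** id

second' :: (Category k, Product k) => k a b -> k (c,a) (c,b)
second' f = id *** f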
class Sum k where
  (+++) :: k a1 b1 -> k a2 b2 -> k (Either a1 a2) (Either b1 b2)

instance Sum (->) where
  f1 +++ f2 = \case
    Left x1
      -> Left (f1 x1)
    Right x2
      -> Right (f2 x2)

instance Sum CFun where
  Coerce +++ Coerce
    = Coerce
  f1 +++ f2
    = NotCoerce $ \case
        Left x1 -> Left (runCFun f1 x1)
        Right x2 -> Right (runCFun f2 x2)
instance Sum k => Sum (MFun k) where
  StrictlyMonotone f1 +++ StrictlyMonotone f2
    = -- if x < y
      --   case 1: x = Left x1 && y = Left y1 && x1 < y1
      --     then f1 x1 < f1 y1
      --     then Left (f1 x1) < Left (f1 y1)
      --     then (f1 +++ f2) x < (f1 +++ f2) y
      --   case 2: x = Left x1 && y = Right y2
      --     then Left _ < Right _
      --     then Left (f1 x1) < Right (f2 y2)
      --     then (f1 +++ f2) x < (f1 +++ f2) y
      --   case 3: x = Right x2 && y = Right y2 && x2 < y2
      --     then f2 x2 < f2 y2
      --     then Right (f2 x2) < Right (f2 y2)
      --     then (f1 +++ f2) x < (f1 +++ f2) y
      StrictlyMonotone (f1 +++ f2)
  f1 +++ f2
    = NotMonotone (runMFun f1 +++ runMFun f2)
class MapList k where
  mapList :: k a b -> k [a] [b]

instance MapList (->) where
  mapList = map

instance MapList CFun where
  mapList Coerce
    = Coerce
  mapList (NotCoerce f)
    = NotCoerce (map f)
instance MapList k => MapList (MFun k) where
  mapList (StrictlyMonotone f)
    = -- if xs < ys
      --   case 1: xs = [] && ys = y:ys'
      --     then [] < _:_
      --     then map f [] < _:_
      --     then map f xs < f y : map f ys'
      --     then map f xs < map f ys
      --   case 2: xs = x:xs' && ys = y:ys' && x < y
      --     then f x < f y
      --     then f x : _ < f y : _
      --     then f x : map f xs' < f y : map f ys'
      --     then map f xs < map f ys
      --   case 3: xs = x:xs' && ys = y:ys' && x == y && xs' < ys'
      --     then f x == f y && by induction, map f xs' < map f ys'
      --     then f x : map f xs' < f y : map f ys'
      --     then map f xs < map f ys
      StrictlyMonotone (mapList f)
  mapList (NotMonotone f)
    = NotMonotone (mapList f)
-- Finally, let's write some tests. From now on we refrain from using the data
-- constructors from the ".Internal" module, and we imagine that it is the user
-- who is writing the code below using the library above. Thus, the user spends
-- their cognitive budget on making sure that 'wrapIdentity', 'singleton', and
-- 'addPrefix' really are strictly-monotone, and then they reap the benefits
-- below, in 'main', where they can compose those functions in a bunch of
-- different ways without having to think about monotonicity anymore.
exampleMap1 :: Map String Int
exampleMap1 = Map.fromList [("a",1),("b",2)]
wrapIdentity :: MCFun a (Identity a)
wrapIdentity
  = unsafeStrictlyMonotone
  $ Coerce

singleton :: MCFun a [a]
singleton
  = unsafeStrictlyMonotone
  $ NotCoerce (:[])

addPrefix :: String -> MCFun String String
addPrefix prefix
  = unsafeStrictlyMonotone
  $ NotCoerce (prefix ++)
down :: MCFun a (Down a)
down = arr Down
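
-- Before reaching for 'unsafeStrictlyMonotone', a cheap sanity check can catch
-- obvious mislabelings. The helper below is a sketch (not part of the
-- original gist; 'looksStrictlyMonotoneOn' is a hypothetical name): it
-- compares the images of every pair of sample inputs with x < y. Passing it
-- is only evidence, not a proof, since the samples may miss a counterexample.
looksStrictlyMonotoneOn :: (Ord a, Ord b) => [a] -> MCFun a b -> Bool
looksStrictlyMonotoneOn samples f
  = and
      [ runMCFun f x < runMCFun f y
      | x <- samples
      , y <- samples
      , x < y
      ]
-- For instance, the check holds for
-- > looksStrictlyMonotoneOn ["", "a", "ab", "b"] (addPrefix "./")
-- while it fails for
-- > looksStrictlyMonotoneOn ["a", "b"] down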
printComplexity
  :: MCFun a b
  -> IO ()
printComplexity (StrictlyMonotone Coerce) = do
  putStrLn "O(1)"
printComplexity (StrictlyMonotone (NotCoerce _)) = do
  putStrLn "O(n)"
printComplexity (NotMonotone _) = do
  putStrLn "O(n log n)"

test
  :: (Eq a, Ord b, Show b)
  => MCFun a b
  -> Map a Int
  -> IO ()
test f input = do
  let expected = Map.mapKeys (runMCFun f) input
  let actual = fancyMapKeys f input
  if expected == actual
    then do
      printComplexity f
    else do
      putStrLn "expected:"
      print expected
      putStrLn "actual:"
      print actual
main :: IO ()
main = do
  test id exampleMap1 -- O(1)
  test wrapIdentity exampleMap1 -- O(1)
  test singleton exampleMap1 -- O(n)
  test (addPrefix "./") exampleMap1 -- O(n)
  test down exampleMap1 -- O(n log n)
  test (wrapIdentity &&& wrapIdentity) exampleMap1 -- O(n)
  test (id . wrapIdentity . id) exampleMap1 -- O(1)
  test (addPrefix "../" . addPrefix "../") exampleMap1 -- O(n)
  test (mapList id) exampleMap1 -- O(1)
  test (mapList wrapIdentity) exampleMap1 -- O(1)
  test (arr (map (\c -> "./" ++ [c]))) exampleMap1 -- when prototyping: O(n log n)
  test (mapList (addPrefix "./" . singleton)) exampleMap1 -- O(n)
  test (mapList down) exampleMap1 -- O(n log n)