
@treeowl
Last active December 13, 2023 15:12
Fast total sorting of arbitrary Traversable containers

UPDATE: This gist is now available as a proper git repository.

Will Fancher recently wrote a blog post (see also the Reddit thread) about sorting arbitrary Traversable containers without any of the ugly incomplete pattern matches that accompany the well-known technique of dumping all the entries into a list and then sucking them back out in State. Fancher used a custom applicative based on the usual free applicative. Unfortunately, this type is rather hard to work with, and Fancher was not immediately able to find a way to use anything better than insertion sort. This gist demonstrates an asymptotically optimal, O(n log n), heap sort using a heap-merging applicative.
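For contrast, here is a minimal sketch of the list-and-State technique described above. It is not Fancher's code: `mapAccumL` from `Data.Traversable` stands in for an explicit State monad, and the name `sortPartial` is hypothetical. The interesting part is `refill`, whose empty-list case can never be reached but still has to be written.

```haskell
import Data.Foldable (toList)
import Data.List (sort)
import Data.Traversable (mapAccumL)

-- Hypothetical name (not from the blog post): sort any Traversable by
-- copying its elements into a list, sorting that list, and refilling
-- the container left to right. mapAccumL threads the not-yet-used
-- sorted elements through the traversal, standing in for State.
sortPartial :: (Ord a, Traversable t) => t a -> t a
sortPartial xs = snd (mapAccumL refill (sort (toList xs)) xs)
  where
    -- The [] case cannot happen, because the list holds exactly as
    -- many elements as the container, but the types do not record
    -- that, so the match has to be partial.
    refill (y : ys) _ = (ys, y)
    refill [] _ = error "sortPartial: ran out of elements"

main :: IO ()
main = do
  print (sortPartial [3, 1, 2 :: Int]) -- prints [1,2,3]
  print (sortPartial (Just 'x'))       -- prints Just 'x'
```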

The three modules:

  • BasicNat: unary natural numbers, singletons, and properties
  • IndexedPairingHeap: size-indexed pairing heaps
  • HSTrav: the big payoff: heap-sorting anything

Note: If you came here from Reddit, be aware that I've updated this gist slightly based on comments there, rendering some of the comments a bit outdated. Nothing major has changed, just:

  • The Heap field of Sort is now strict.
  • Some functions are now marked INLINABLE or INLINE.
  • The explanation of why the Applicative instance is valid has been simplified and is also now more complete.
{-# LANGUAGE DataKinds, TypeFamilies, TypeOperators, GADTs,
             ScopedTypeVariables #-}
-- | Type-level natural numbers and singletons, with proofs of
-- a few basic properties.
module BasicNat (
  -- * Type-level natural numbers
    Nat (..)
  , type (+)
  -- * Natural number singletons
  , Natty (..)
  , plus
  -- * Basic properties
  , plusCommutative
  , plusZero
  , plusSucc
  , plusAssoc
  ) where

import Data.Type.Equality ((:~:)(..))
import Unsafe.Coerce (unsafeCoerce)

-- | Type-level natural numbers
data Nat = Z | S Nat

-- | Type-level natural number addition
type family (+) m n where
  'Z + n = n
  'S m + n = 'S (m + n)

-- | Singletons for natural numbers
data Natty n where
  Zy :: Natty 'Z
  Sy :: Natty n -> Natty ('S n)

-- | Singleton addition
plus :: Natty m -> Natty n -> Natty (m + n)
plus Zy n = n
plus (Sy m) n = Sy (plus m n)

----------------------------------------------------------
-- Proofs of basic arithmetic
--
-- The legitimate proofs are accompanied by rewrite rules that
-- effectively assert termination. These rules prevent us from
-- actually having to run the proof code, which would be slow.
-- Rewrite rules only fire when compiling with optimization (-O),
-- so unoptimized builds will still run the proofs.

plusCommutative :: Natty m -> Natty n -> (m + n) :~: (n + m)
plusCommutative Zy n = case plusZero n of Refl -> Refl
plusCommutative (Sy m) n =
  case plusCommutative m n of { Refl ->
  case plusSucc n m of Refl -> Refl }
{-# NOINLINE plusCommutative #-}

plusZero :: Natty m -> (m + 'Z) :~: m
plusZero Zy = Refl
plusZero (Sy n) = case plusZero n of Refl -> Refl
{-# NOINLINE plusZero #-}

plusSucc :: Natty m -> proxy n -> (m + 'S n) :~: ('S (m + n))
plusSucc Zy _ = Refl
plusSucc (Sy n) p = case plusSucc n p of Refl -> Refl
{-# NOINLINE plusSucc #-}

plusAssoc :: Natty m -> p1 n -> p2 o -> (m + (n + o)) :~: ((m + n) + o)
plusAssoc Zy _ _ = Refl
plusAssoc (Sy m) p1 p2 = case plusAssoc m p1 p2 of Refl -> Refl
{-# NOINLINE plusAssoc #-}

{-# RULES
  "plusCommutative" forall m n. plusCommutative m n = unsafeCoerce (Refl :: 'Z :~: 'Z)
  "plusZero" forall m . plusZero m = unsafeCoerce (Refl :: 'Z :~: 'Z)
  "plusSucc" forall m n. plusSucc m n = unsafeCoerce (Refl :: 'Z :~: 'Z)
  "plusAssoc" forall m p1 p2. plusAssoc m p1 p2 = unsafeCoerce (Refl :: 'Z :~: 'Z)
  #-}
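
To see why proofs such as `plusZero` are needed at all, here is a small self-contained sketch. It copies `Nat`, `(+)`, `Natty`, and `plusZero` from BasicNat rather than importing them; the length-indexed `Vec` type and the functions `vlength`, `append`, `appendNil`, and `toList` are hypothetical, for illustration only. GHC can reduce `'Z + n` using the first equation of `(+)`, but it cannot reduce `m + 'Z` when `m` is unknown, so a proof has to supply that equation. The heap code uses `plusZero` in exactly this way, for instance in `merge xs E` and in `runSort`.

```haskell
{-# LANGUAGE DataKinds, GADTs, KindSignatures, TypeFamilies, TypeOperators #-}
import Data.Type.Equality ((:~:) (..))

-- Copied from BasicNat so that this sketch stands alone.
data Nat = Z | S Nat

type family (+) (m :: Nat) (n :: Nat) :: Nat where
  'Z + n = n
  'S m + n = 'S (m + n)

data Natty (n :: Nat) where
  Zy :: Natty 'Z
  Sy :: Natty n -> Natty ('S n)

plusZero :: Natty m -> (m + 'Z) :~: m
plusZero Zy = Refl
plusZero (Sy n) = case plusZero n of Refl -> Refl

-- Hypothetical length-indexed list, for illustration only.
data Vec (n :: Nat) a where
  VNil :: Vec 'Z a
  VCons :: a -> Vec n a -> Vec ('S n) a

-- Compute the length singleton (the analogue of 'size').
vlength :: Vec n a -> Natty n
vlength VNil = Zy
vlength (VCons _ xs) = Sy (vlength xs)

-- Recursing on the first argument follows the equations of (+).
append :: Vec m a -> Vec n a -> Vec (m + n) a
append VNil ys = ys
append (VCons x xs) ys = VCons x (append xs ys)

-- 'append xs VNil' has type Vec (m + 'Z) a, which GHC cannot
-- simplify to Vec m a by itself; matching on Refl supplies the
-- missing equation.
appendNil :: Vec m a -> Vec m a
appendNil xs = case plusZero (vlength xs) of Refl -> append xs VNil

toList :: Vec n a -> [a]
toList VNil = []
toList (VCons x xs) = x : toList xs

main :: IO ()
main = print (toList (appendNil (VCons 'a' (VCons 'b' VNil)))) -- prints "ab"
```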

{-# LANGUAGE ScopedTypeVariables, GADTs, TypeOperators,
             RankNTypes, InstanceSigs, DataKinds #-}
module HSTrav where

import IndexedPairingHeap (Heap, Sized (..), empty, singleton, merge, minView)
import Data.Proxy
import Data.Type.Equality ((:~:) (..))
import BasicNat (type (+), Nat (..), plusAssoc, plusZero)

-- | A heap of some size whose elements have type @a@ and a
-- function that, applied to any heap at least that large,
-- will produce a result and the rest of the heap.
data Sort a r where
  Sort :: (forall n. Proxy n -> Heap (m + n) a -> (Heap n a, r))
       -> !(Heap m a)
       -> Sort a r

instance Functor (Sort x) where
  fmap f (Sort g h) =
    Sort (\p h' -> case g p h' of (remn, r) -> (remn, f r)) h
  {-# INLINE fmap #-}

instance Ord x => Applicative (Sort x) where
  {-# INLINE pure #-}
  {-# INLINABLE (<*>) #-}
  pure x = Sort (\_ h -> (h, x)) empty

  -- Combine two 'Sort's by merging their heaps and composing
  -- their functions.
  (<*>) :: forall a b . Sort x (a -> b) -> Sort x a -> Sort x b
  Sort f (xs :: Heap m x) <*> Sort g (ys :: Heap n x) =
    Sort h (merge xs ys)
    where
      h :: forall o . Proxy o -> Heap ((m + n) + o) x -> (Heap o x, b)
      h p v = case plusAssoc (size xs) (size ys) p of
        Refl -> case f (Proxy :: Proxy (n + o)) v of { (v', a) ->
                case g (Proxy :: Proxy o) v' of { (v'', b) ->
                (v'', a b)}}

-- Produce a 'Sort' with a singleton heap and a function that will
-- produce the smallest element of a heap.
liftSort :: Ord x => x -> Sort x x
liftSort a = Sort (\_ h -> case minView h of (x, h') -> (h', x)) (singleton a)
{-# INLINABLE liftSort #-}

-- Apply the function in a 'Sort' to the heap within, producing a
-- result.
runSort :: forall x a . Sort x a -> a
runSort (Sort f xs) = case plusZero (size xs) of
  Refl -> snd $ f (Proxy :: Proxy 'Z) xs

-- | Sort an arbitrary 'Traversable' container using a heap.
sortTraversable :: (Ord a, Traversable t) => t a -> t a
sortTraversable = runSort . traverse liftSort
{-# INLINABLE sortTraversable #-}
-- | Sort an arbitrary container using a 'Traversal' (in the
-- 'lens' sense).
sortTraversal :: Ord a => ((a -> Sort a a) -> t -> Sort a t) -> t -> t
sortTraversal trav = runSort . trav liftSort
{-# INLINABLE sortTraversal #-}

{-# LANGUAGE DataKinds, ScopedTypeVariables, TypeOperators, GADTs, RoleAnnotations #-}
module IndexedPairingHeap (
    Heap
  , Sized (..)
  , empty
  , singleton
  , insert
  , merge
  , minView
  ) where

import BasicNat
import Data.Type.Equality ((:~:)(..))

-- | Okasaki's simple representation of a pairing heap, but with
-- a size index.
data Heap n a where
  E :: Heap 'Z a
  T :: a -> HVec n a -> Heap ('S n) a

-- Coercing a heap could destroy the heap property, so we declare both
-- type parameters nominal.
type role Heap nominal nominal

-- | A vector of heaps whose sizes sum to the index.
data HVec n a where
  HNil :: HVec 'Z a
  HCons :: Heap m a -> HVec n a -> HVec (m + n) a

class Sized h where
  -- | Calculate the size of a structure
  size :: h n a -> Natty n

instance Sized Heap where
  size E = Zy
  size (T _ xs) = Sy (size xs)

instance Sized HVec where
  size HNil = Zy
  size (HCons h hs) = size h `plus` size hs

-- Produce an empty heap
empty :: Heap 'Z a
empty = E
-- Produce a heap with one element
singleton :: a -> Heap ('S 'Z) a
singleton a = T a HNil
-- Insert an element into a heap
insert :: Ord a => a -> Heap n a -> Heap ('S n) a
insert x xs = merge (singleton x) xs
{-# INLINABLE insert #-}

-- Merge two heaps
merge :: Ord a => Heap m a -> Heap n a -> Heap (m + n) a
merge E ys = ys
merge xs E = case plusZero (size xs) of Refl -> xs
merge h1@(T x xs) h2@(T y ys)
  -- h2 becomes a child of x: commute the sum so the sizes line up.
  | x <= y = case plusCommutative (size h2) (size xs) of Refl -> T x (HCons h2 xs)
  -- h1 becomes a child of y: move the extra 'S out of the sum.
  | otherwise = case plusSucc (size xs) (size ys) of Refl -> T y (HCons h1 ys)
{-# INLINABLE merge #-}

-- Get the smallest element of a non-empty heap, and the rest of
-- the heap
minView :: Ord a => Heap ('S n) a -> (a, Heap n a)
minView (T x hs) = (x, mergePairs hs)
{-# INLINABLE minView #-}

-- Combine the subheaps left behind by removing the root, merging
-- them in pairs (Okasaki's mergePairs).
mergePairs :: Ord a => HVec n a -> Heap n a
mergePairs HNil = E
mergePairs (HCons h HNil) = case plusZero (size h) of Refl -> h
mergePairs (HCons h1 (HCons h2 hs)) =
  case plusAssoc (size h1) (size h2) (size hs) of
    Refl -> merge (merge h1 h2) (mergePairs hs)
{-# INLINABLE mergePairs #-}
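
For comparison, here is a minimal sketch of the same Okasaki-style pairing heap without the size index, with an ordinary list standing in for `HVec`. The names mirror the module above, but this is a separate illustration rather than code from the gist, and `heapSort` is a hypothetical helper. Without the index, `minView` has to return a `Maybe`, and every consumer has to deal with the `Nothing` case; the indexed type `Heap ('S n) a` is what lets the module above drop it.

```haskell
-- Okasaki-style pairing heap without a size index.
data Heap a = E | T a [Heap a]

-- Merge two heaps: the root with the larger element becomes a
-- child of the other root.
merge :: Ord a => Heap a -> Heap a -> Heap a
merge h E = h
merge E h = h
merge h1@(T x hs1) h2@(T y hs2)
  | x <= y = T x (h2 : hs1)
  | otherwise = T y (h1 : hs2)

insert :: Ord a => a -> Heap a -> Heap a
insert x = merge (T x [])

-- Combine the children of a removed root, as in the indexed version.
mergePairs :: Ord a => [Heap a] -> Heap a
mergePairs [] = E
mergePairs [h] = h
mergePairs (h1 : h2 : hs) = merge (merge h1 h2) (mergePairs hs)

-- Without a size index, emptiness is only discovered at run time.
minView :: Ord a => Heap a -> Maybe (a, Heap a)
minView E = Nothing
minView (T x hs) = Just (x, mergePairs hs)

-- Heap sort on lists: build a heap, then drain it in order.
heapSort :: Ord a => [a] -> [a]
heapSort = drain . foldr insert E
  where
    drain h = case minView h of
      Nothing -> []
      Just (x, h') -> x : drain h'

main :: IO ()
main = print (heapSort [4, 2, 7, 1 :: Int]) -- prints [1,2,4,7]
```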

It is not immediately obvious that Sort is a lawful Applicative at all. Let's see if we can figure it out! The indices just get in the way here, so let's clean up the Applicative instance a bit. It won't compile like this, but that doesn't matter.

pure x = Sort (\_ h -> (h, x)) empty

(<*>) :: forall a b . Sort x (a -> b) -> Sort x a -> Sort x b
Sort f xs <*> Sort g ys =
  Sort h (merge xs ys)
  where
    h :: forall o. Proxy o -> Heap ((m + n) + o) x -> (Heap o x, b)
    h p v = case f Proxy v of { (v', a) ->
              case g Proxy v' of { (v'', b) ->
                (v'', a b)}}

As "paf31" noted, Sort a is (indices, proxies, and strictness annotation aside) the Product of two applicative functors:

Sort a ~= Product (State (Heap a)) (Const (Heap a))

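with precisely the Applicative instance that suggests. Because the Product of two lawful Applicatives is lawful, and State s is lawful for any s, this reduces the question to whether Const (Heap a) is lawful, which holds as long as merge and empty behave as a monoid on heaps.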
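
The correspondence can be made concrete with a small unindexed sketch built on `Data.Functor.Product` and `Data.Functor.Const` from base. Assumptions: `State` is defined locally (base does not provide one), with its pair in the same (state, result) order as the gist's functions, and a sorted list merged like a heap stands in for `Heap`. The names `SortedList`, `mergeSorted`, `State`, `Sort'`, `liftSort'`, `runSort'`, `sortTraversable'`, `sortTraversal'`, and `both` are hypothetical. Dropping the indices brings back a partial match in `pop`, which is exactly what the size-indexed version rules out.

```haskell
import Data.Functor.Const (Const (..))
import Data.Functor.Product (Product (..))

-- Stand-in for Heap: a sorted list whose monoid operation merges,
-- playing the roles of 'merge' and 'empty'.
newtype SortedList a = SortedList [a]

instance Ord a => Semigroup (SortedList a) where
  SortedList xs <> SortedList ys = SortedList (mergeSorted xs ys)

instance Ord a => Monoid (SortedList a) where
  mempty = SortedList []

mergeSorted :: Ord a => [a] -> [a] -> [a]
mergeSorted [] ys = ys
mergeSorted xs [] = xs
mergeSorted (x : xs) (y : ys)
  | x <= y = x : mergeSorted xs (y : ys)
  | otherwise = y : mergeSorted (x : xs) ys

-- A local State returning (new state, result), like the gist's functions.
newtype State s a = State (s -> (s, a))

instance Functor (State s) where
  fmap f (State g) = State (\s -> case g s of (s', a) -> (s', f a))

instance Applicative (State s) where
  pure a = State (\s -> (s, a))
  -- Run the left computation first, then the right one on what remains.
  State f <*> State g = State $ \s ->
    case f s of
      (s', h) -> case g s' of
        (s'', a) -> (s'', h a)

-- Unindexed analogue of 'Sort x': the Const half collects (merges)
-- all the elements, and the State half pops them back out in order.
type Sort' x = Product (State (SortedList x)) (Const (SortedList x))

liftSort' :: x -> Sort' x x
liftSort' a = Pair (State pop) (Const (SortedList [a]))
  where
    -- Partial again: without size indices, nothing rules out [].
    pop (SortedList (y : ys)) = (SortedList ys, y)
    pop (SortedList []) = error "liftSort': empty heap"

-- Feed all the collected elements to the composed State function.
runSort' :: Sort' x a -> a
runSort' (Pair (State f) (Const h)) = snd (f h)

sortTraversable' :: (Ord a, Traversable t) => t a -> t a
sortTraversable' = runSort' . traverse liftSort'

-- Works with any lens-style traversal, e.g. a hand-written 'both'.
sortTraversal' :: Ord a => ((a -> Sort' a a) -> t -> Sort' a t) -> t -> t
sortTraversal' trav = runSort' . trav liftSort'

both :: Applicative f => (a -> f b) -> (a, a) -> f (b, b)
both f (x, y) = (,) <$> f x <*> f y

main :: IO ()
main = do
  print (sortTraversable' [5, 3, 8, 1 :: Int]) -- prints [1,3,5,8]
  print (sortTraversal' both ('z', 'a'))       -- prints ('a','z')
```

Here the Product instance in base combines the two halves componentwise, which is the same shape as the gist's `(<*>)`: compose the functions, merge the heaps.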
