Skip to content

Instantly share code, notes, and snippets.

@LambdaP
Last active August 29, 2015 14:23
Show Gist options
  • Save LambdaP/f725d392814f34bfb597 to your computer and use it in GitHub Desktop.
Save LambdaP/f725d392814f34bfb597 to your computer and use it in GitHub Desktop.
In-place quicksort in Haskell using STArrays.
{-# LANGUAGE FlexibleContexts #-}
import Control.Monad.ST
import Data.Array.ST
-- | Sort a list by copying it into a mutable 'STArray', quicksorting that
-- array in place inside 'runST', and reading the elements back out.
--
-- Each recursion step asks 'switchPivot' to put a pivot at its final index,
-- lets 'switchSides' partition the rest of the range around it, and then
-- recurses on the sub-ranges to the left and right of the pivot.
quickSort :: (Ord a) => [a] -> [a]
quickSort l = runST (newListArray (0, length l - 1) l >>= sortAll)
  where
    -- Sort the whole array, then return its contents as a list.
    sortAll buf = getBounds buf >>= sortRange buf >> getElems buf

    -- Sort the inclusive index range (lo, hi). Ranges holding fewer than
    -- two elements are already sorted, so nothing is done for them.
    sortRange buf (lo, hi)
      | lo < hi = do
          (mid, p) <- switchPivot (lo, hi) buf
          switchSides (lo, hi) (mid, p) buf
          sortRange buf (lo, mid - 1)
          sortRange buf (mid + 1, hi)
      | otherwise = return ()
-- | Count how many elements in the inclusive index range @(i, j)@ of the
-- array are strictly less than @pivot@.
--
-- An empty range (@i > j@) yields 0. The array is only read, never written.
nLower :: (Ord e) => e -> (Int, Int) -> STArray s Int e -> ST s Int
nLower pivot (i, j) arr = go i 0
  where
    go k acc
      | k > j = return acc
      | otherwise = do
          cur <- readArray arr k
          let acc' = if cur < pivot then acc + 1 else acc
          -- Force the counter on every step: a lazy accumulator would pile
          -- up a chain of (+1) thunks as long as the range being scanned.
          go (k + 1) $! acc'
-- | Choose a pivot for the inclusive index range @(i, j)@, move it to its
-- final sorted index, and return that index together with the pivot value.
--
-- The final index is @i@ plus the number of elements in the range that are
-- strictly less than the pivot (computed with 'nLower'). The pivot is
-- swapped with whatever element currently sits at that index; no other
-- element is moved, so the range still has to be partitioned afterwards.
-- A single-element range is returned unchanged.
switchPivot :: (Ord e) => (Int, Int) -> STArray s Int e -> ST s (Int, e)
switchPivot (i, j) arr
  | i == j = do
      pivot <- readArray arr i
      return (i, pivot)
  | otherwise = do
      ixpivot <- getIxPivot
      pivot <- readArray arr (i + ixpivot)
      ixtmp <- nLower pivot (i, j) arr
      tmp <- readArray arr (i + ixtmp)
      writeArray arr (i + ixpivot) tmp
      writeArray arr (i + ixtmp) pivot
      return (i + ixtmp, pivot)
  where
    -- Pivot-selection strategy: an action yielding the pivot's offset from
    -- @i@. Replace it (e.g. with @return 0@ to always take the first
    -- element) to try a different strategy.
    getIxPivot = getMedianPivot

    -- Median-of-three: the offset of whichever of the first, middle and
    -- last elements of the range lies between the other two.
    getMedianPivot = do
      let ixend = j - i
          ixmiddle = ixend `div` 2
      first <- readArray arr i
      mid <- readArray arr (i + ixmiddle)
      lst <- readArray arr j
      return (medianOffset first mid lst ixmiddle ixend)

    -- Pure choice of the median's offset. Each branch is annotated with the
    -- ordering of the three values it handles.
    medianOffset first mid lst ixmiddle ixend
      | first <= mid =
          if mid <= lst
            then ixmiddle -- first <= mid <= lst
            else if first <= lst
              then ixend -- first <= lst < mid
              else 0 -- lst < first <= mid
      | first <= lst = 0 -- mid < first <= lst
      | mid <= lst = ixend -- mid <= lst < first
      | otherwise = ixmiddle -- lst < mid < first
-- | Partition the inclusive range @(i, j)@ around a pivot that already sits
-- at index @k@.
--
-- Precondition: @k - i@ is exactly the number of elements in the range that
-- are strictly less than @pivot@ (this is how 'switchPivot' picks @k@).
-- Postcondition: every element at an index in @[i, k)@ is @< pivot@ and
-- every element at an index in @(k, j]@ is @>= pivot@.
--
-- @go1@ walks the left side upward looking for an element that does not
-- belong there (@>= pivot@). @go2@ then walks the right side downward
-- looking for one that does (@< pivot@), and the two are swapped.
switchSides :: (Ord e) => (Int, Int)
            -> (Int, e)
            -> STArray s Int e
            -> ST s ()
switchSides (i, j) (k, pivot) arr
  | i == j = return ()
  | otherwise = go1 i j
  where
    go1 l h
      | l == k = return ()
      | otherwise = do
          cur <- readArray arr l
          if cur < pivot then go1 (l + 1) h else go2 l h
    go2 l h
      | h == k = return ()
      | otherwise = do
          cur <- readArray arr h
          -- Elements equal to the pivot must stay on the right: the left
          -- side only has room for the elements strictly less than it.
          -- Swapping an equal element left would strand a smaller element
          -- on the right once go1 reaches k.
          if cur >= pivot
            then go2 l (h - 1)
            else do
              tmp <- readArray arr l
              writeArray arr h tmp
              writeArray arr l cur
              go1 (l + 1) (h - 1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment