Created
February 24, 2022 14:47
-
-
Save msakai/aba4c0b6114e3b6130716a40396813aa to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- Fisher–Yates shuffle algorithm
-- https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle#Modern_method
--
-- Note that mwc-random already provides its own functions for permuting vectors:
-- https://hackage.haskell.org/package/mwc-random-0.14.0.0/docs/System-Random-MWC-Distributions.html#g:5
--
{-# OPTIONS_GHC -Wall #-} | |
{-# LANGUAGE BangPatterns #-} | |
module Shufle | |
( shuffleVector | |
, shuffleVectorM | |
, shuffleList | |
) where | |
import Control.Monad | |
import Control.Monad.Primitive | |
import Data.Foldable (toList) | |
import qualified Data.Sequence as Seq | |
import qualified Data.Vector.Generic as VG | |
import qualified Data.Vector.Generic.Mutable as VGM | |
import qualified System.Random.MWC as Rand | |
-- | Return a uniformly random permutation of an immutable vector.
--
-- A fresh mutable buffer of the same length is allocated and filled by
-- 'shuffleTo' (the \"inside-out\" Fisher–Yates variant), reading each
-- source element by index. The buffer is then frozen without copying.
-- The input vector itself is never modified.
shuffleVector :: (PrimMonad m, VG.Vector v a) => v a -> Rand.Gen (PrimState m) -> m (v a)
shuffleVector xs gen = do
  let len = VG.length xs
      -- Source accessor handed to 'shuffleTo'; indices are always < len.
      readSource = pure . VG.unsafeIndex xs
  buf <- VGM.new len
  shuffleTo len readSource buf gen
  VG.unsafeFreeze buf
-- | Fill the first @len@ slots of @target@ with a uniformly random
-- permutation of the values @f 0, f 1, .., f (len - 1)@, using the
-- \"inside-out\" variant of the Fisher–Yates shuffle.
--
-- Each source action @f i@ is run exactly once, in increasing index order.
-- Slots of @target@ at index @len@ or above are never written.
-- Calls 'error' if @target@ has fewer than @len@ elements.
shuffleTo :: (PrimMonad m, VGM.MVector mv a) => Int -> (Int -> m a) -> mv (PrimState m) a -> Rand.Gen (PrimState m) -> m ()
shuffleTo len f target gen = do
  when (VGM.length target < len) $
    error "target should be larger than source"
  mapM_ step [0 .. len - 1]
  where
    -- Put source element @i@ into a random slot @j <= i@. If @j@ was
    -- already occupied, its previous value is first moved up into the
    -- newly opened slot @i@, so nothing is lost.
    step i = do
      j <- Rand.uniformR (0, i) gen
      when (j /= i) $ VGM.unsafeRead target j >>= VGM.unsafeWrite target i
      f i >>= VGM.unsafeWrite target j
-- | Shuffle a mutable vector in place into a uniformly random permutation
-- (the \"modern\" Fisher–Yates / Durstenfeld algorithm).
--
-- Walks @i@ from the last index down to 1, swapping slot @i@ with a
-- uniformly chosen slot @j@ in @[0, i]@. The @i = 0@ step is deliberately
-- omitted. There @j@ could only be 0, so it would never swap anything and
-- would only spend a draw from the generator. Because it would have been
-- the final draw, omitting it does not change which permutation is
-- produced for a given generator state.
shuffleVectorM :: (PrimMonad m, VGM.MVector mv a) => mv (PrimState m) a -> Rand.Gen (PrimState m) -> m ()
shuffleVectorM xs gen = do
  let n = VGM.length xs
  -- For n <= 1 this range is empty: such vectors are already "shuffled".
  forM_ [n - 1, n - 2 .. 1] $ \i -> do
    j <- Rand.uniformR (0, i) gen
    when (i /= j) $ VGM.unsafeSwap xs i j
-- | Return a uniformly random permutation of a list.
--
-- The shuffle itself is done by 'shuffleListToSeq'. Its 'Seq.Seq' result
-- is converted back to a list, preserving the sequence's order.
shuffleList :: (PrimMonad m) => [a] -> Rand.Gen (PrimState m) -> m [a]
shuffleList xs gen = toList <$> shuffleListToSeq xs gen
-- | Inside-out Fisher–Yates shuffle of a list into a 'Seq.Seq'.
--
-- List elements are consumed left to right. Each new element is assigned
-- a uniformly chosen position @k@ in @[0, end]@, where @end@ is the
-- current length of the accumulator. If @k == end@ it is appended.
-- Otherwise it overwrites position @k@, and the displaced element is
-- appended instead.
shuffleListToSeq :: (PrimMonad m) => [a] -> Rand.Gen (PrimState m) -> m (Seq.Seq a)
shuffleListToSeq xs0 gen = loop xs0 Seq.empty
  where
    loop [] !acc = return acc
    loop (y : ys) !acc = do
      let end = Seq.length acc
      k <- Rand.uniformR (0, end) gen
      loop ys (insertAt k end y acc)

    -- Pure placement step. Position @k@ receives @y@; when @k@ is an
    -- existing slot, its old occupant is moved to the end so no element
    -- is lost. @k < end@ in the second guard, so 'Seq.index' is in range.
    insertAt k end y acc
      | k == end = acc Seq.|> y
      | otherwise = Seq.update k y acc Seq.|> Seq.index acc k
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment