Skip to content

Instantly share code, notes, and snippets.

@msakai
Created February 24, 2022 14:47
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save msakai/aba4c0b6114e3b6130716a40396813aa to your computer and use it in GitHub Desktop.
Save msakai/aba4c0b6114e3b6130716a40396813aa to your computer and use it in GitHub Desktop.
-- Fisher–Yates shuffle algorithm
-- https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle#Modern_method
--
-- Note that mwc-random already provides functions for permuting vectors
-- (uniformPermutation, uniformShuffle and uniformShuffleM):
-- https://hackage.haskell.org/package/mwc-random-0.14.0.0/docs/System-Random-MWC-Distributions.html#g:5
--
{-# OPTIONS_GHC -Wall #-}
{-# LANGUAGE BangPatterns #-}
module Shufle
( shuffleVector
, shuffleVectorM
, shuffleList
) where
import Control.Monad
import Control.Monad.Primitive
import Data.Foldable (toList)
import qualified Data.Sequence as Seq
import qualified Data.Vector.Generic as VG
import qualified Data.Vector.Generic.Mutable as VGM
import qualified System.Random.MWC as Rand
-- | Return a uniformly random permutation of a vector.
--
-- The input vector is not modified: a fresh mutable vector of the same length
-- is allocated, filled by 'shuffleTo' (the \"inside-out\" Fisher–Yates
-- shuffle), and then frozen.
--
-- Elements are copied with 'VG.unsafeIndexM' rather than
-- @return ('VG.unsafeIndex' xs i)@. For boxed vectors the latter stores an
-- unevaluated thunk in every slot of the result, and that thunk keeps the
-- whole source vector alive until the element is forced. 'VG.unsafeIndexM'
-- extracts the element without forcing it and without retaining @xs@.
shuffleVector :: (PrimMonad m, VG.Vector v a) => v a -> Rand.Gen (PrimState m) -> m (v a)
shuffleVector xs gen = do
  let n = VG.length xs
  ys <- VGM.new n
  shuffleTo n (VG.unsafeIndexM xs) ys gen
  VG.unsafeFreeze ys
-- | Fill the first @len@ slots of @target@ with a uniformly random permutation
-- of the elements produced by @f 0@, …, @f (len - 1)@, using the
-- \"inside-out\" variant of the Fisher–Yates shuffle.
--
-- At step @i@ a position @j@ is drawn uniformly from @[0, i]@. If @j /= i@,
-- the element currently at @j@ is moved to slot @i@. The @i@-th source element
-- is then written to slot @j@. Slots at index @len@ and beyond are not
-- touched. Calls 'error' if @target@ has fewer than @len@ slots.
shuffleTo :: (PrimMonad m, VGM.MVector mv a) => Int -> (Int -> m a) -> mv (PrimState m) a -> Rand.Gen (PrimState m) -> m ()
shuffleTo len f target gen = checkCapacity >> mapM_ step [0 .. len - 1]
  where
    -- The writes below are unchecked, so refuse a target that is too short.
    checkCapacity
      | VGM.length target < len = error "target should be larger than source"
      | otherwise = return ()

    step i = do
      j <- Rand.uniformR (0, i) gen
      if j == i
        then return ()
        else VGM.unsafeRead target j >>= VGM.unsafeWrite target i
      x <- f i
      VGM.unsafeWrite target j x
-- | Shuffle a mutable vector in place, uniformly at random, using the classic
-- (Durstenfeld) Fisher–Yates algorithm.
--
-- For @k@ from @length xs - 1@ down to @0@, slot @k@ is swapped with a slot
-- drawn uniformly from @[0, k]@. No swap happens when the draw is @k@ itself.
-- The @k = 0@ step always draws @0@; it is kept so the generator is advanced
-- exactly as before.
shuffleVectorM :: (PrimMonad m, VGM.MVector mv a) => mv (PrimState m) a -> Rand.Gen (PrimState m) -> m ()
shuffleVectorM xs gen = loop (VGM.length xs - 1)
  where
    loop k
      | k < 0 = return ()
      | otherwise = do
          r <- Rand.uniformR (0, k) gen
          unless (r == k) $ VGM.unsafeSwap xs k r
          loop (k - 1)
-- | Return a uniformly random permutation of a list.
--
-- The work is done by 'shuffleListToSeq'; the resulting 'Seq.Seq' is
-- converted back to a list.
shuffleList :: (PrimMonad m) => [a] -> Rand.Gen (PrimState m) -> m [a]
shuffleList xs gen = toList <$> shuffleListToSeq xs gen
-- | Build a uniformly random permutation of a list as a 'Seq.Seq', using the
-- \"inside-out\" Fisher–Yates shuffle.
--
-- The input elements are processed in order. With @i@ elements already
-- placed, a position @j@ is drawn uniformly from @[0, i]@. If @j == i@, the
-- new element is appended. Otherwise it replaces the element at @j@, and the
-- displaced element is appended to the end.
--
-- The displaced element is fetched with 'Seq.lookup' and pattern-matched
-- immediately. Appending @Seq.index target j@ directly would store a lazy
-- thunk that retains the previous version of the sequence, keeping every
-- intermediate sequence reachable until the result's elements are forced.
-- The match forces only the lookup, never the element itself.
shuffleListToSeq :: (PrimMonad m) => [a] -> Rand.Gen (PrimState m) -> m (Seq.Seq a)
shuffleListToSeq xs0 gen = go xs0 Seq.empty
  where
    go [] !target = return target
    go (x : xs) !target = do
      let i = Seq.length target
      j <- Rand.uniformR (0, i) gen
      if i == j
        then go xs (target Seq.|> x)
        else case Seq.lookup j target of
          Just displaced -> go xs (Seq.update j x target Seq.|> displaced)
          -- Unreachable: here 0 <= j < i == Seq.length target.
          Nothing -> error "Shufle.shuffleListToSeq: index out of range"
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment