Skip to content

Instantly share code, notes, and snippets.

@michaelt
Last active August 29, 2015 14:23
Show Gist options
  • Save michaelt/eb738a5b6a7524471e61 to your computer and use it in GitHub Desktop.
Save michaelt/eb738a5b6a7524471e61 to your computer and use it in GitHub Desktop.
{-#LANGUAGE TypeOperators, LambdaCase, BangPatterns, RankNTypes #-}
-- Needed for the MonadBase instance
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE UndecidableInstances #-}
{-#LANGUAGE ScopedTypeVariables #-}
module Loop where
import Control.Monad.Trans
import Control.Monad
import Control.Monad.Base
import Control.Monad.ST
import Data.STRef
-- * Loop bodies that can 'continue' to the next iteration or 'exit' the whole loop.
-- | A loop body in continuation-passing (Church-encoded) style. Running it
-- requires three continuations: one for 'continue', one for 'exit', and one
-- for an ordinary result.
newtype LoopT c e m a = LoopT
  { runLoopT
      :: forall r.
         (c -> m r)  -- continue: go on to the next iteration
      -> (e -> m r)  -- exit: leave the loop with a value
      -> (a -> m r)  -- return a value from the body
      -> m r
      -- Because r is universally quantified, a LoopT computation has no way
      -- to produce an @m r@ other than by calling one of these continuations.
  }
-- | A loop body represented as data rather than continuations: the action
-- yields @Left c@ to continue, @Right (Left e)@ to exit, or
-- @Right (Right a)@ for an ordinary result.
newtype Loop c e m a = Loop
  { runLoop :: m (Either c (Either e a))
  }
-- | Maps over the ordinary result only; 'continue' and 'exit' values pass
-- through untouched. Only 'Functor' is needed from the base monad, so the
-- constraint is weaker than (and implied by) the original @Monad m@.
instance Functor m => Functor (Loop c e m) where
  fmap f = Loop . fmap (fmap (fmap f)) . runLoop
-- | 'pure' wraps a value as an ordinary result. It is defined directly
-- rather than as @pure = return@: should the 'Monad' instance ever fall back
-- to the default @return = pure@, that pairing would loop forever.
instance Monad m => Applicative (Loop c e m) where
  pure = Loop . return . Right . Right
  (<*>) = ap
-- | Sequencing: the continuation @k@ only sees ordinary results; a
-- 'continue' or 'exit' outcome is passed through and skips the rest.
instance Monad m => Monad (Loop c e m) where
  return a = Loop (return (Right (Right a)))
  Loop body >>= k = Loop $ body >>= \case
    Left c          -> return (Left c)
    Right (Left e)  -> return (Right (Left e))
    Right (Right a) -> runLoop (k a)
-- | A lifted action always finishes with an ordinary result.
instance MonadTrans (Loop c e) where
  lift action = Loop $ do
    a <- action
    return (Right (Right a))
-- | Run an IO action in the base monad, then lift it into the loop.
instance MonadIO m => MonadIO (Loop c e m) where
  liftIO io = lift (liftIO io)
-- | Maps over the ordinary result by pre-composing @f@ onto the result
-- continuation; the continue/exit continuations are passed through.
instance Functor (LoopT c e m) where
  fmap f body = LoopT $ \onContinue onExit onResult ->
    runLoopT body onContinue onExit (\a -> onResult (f a))
-- | Runs the function body, then the argument body, threading the same
-- continue/exit continuations through both.
instance Applicative (LoopT c e m) where
  pure x = LoopT (\_ _ onResult -> onResult x)
  bodyF <*> bodyX = LoopT $ \onContinue onExit onResult ->
    runLoopT bodyF onContinue onExit $ \f ->
      runLoopT bodyX onContinue onExit (\x -> onResult (f x))
-- | Bind threads the continue/exit continuations through both computations;
-- only an ordinary result of @m@ reaches @k@.
--
-- 'return' is the canonical @return = pure@ instead of a second copy of the
-- 'pure' definition ('pure' is defined directly in the 'Applicative'
-- instance, so this cannot loop).
instance Monad (LoopT c e m) where
  return = pure
  m >>= k = LoopT $ \next fin cont ->
    runLoopT m next fin $ \a ->
      runLoopT (k a) next fin cont
-- | A lifted action feeds its result straight to the result continuation.
instance MonadTrans (LoopT c e) where
  lift action = LoopT (\_ _ onResult -> action >>= onResult)
-- | Run an IO action in the base monad, then lift it into the loop.
instance MonadIO m => MonadIO (LoopT c e m) where
  liftIO io = lift (liftIO io)
-- | Lift a base-monad action into the underlying monad, then into 'LoopT'.
instance MonadBase b m => MonadBase b (LoopT c e m) where
  liftBase action = lift (liftBase action)
--
-- | Lift a base-monad action into the underlying monad, then into 'Loop'.
instance MonadBase b m => MonadBase b (Loop c e m) where
  liftBase action = lift (liftBase action)
-- | Skip the rest of the loop body and go to the next iteration.
continue :: LoopT () e m a
continue = LoopT (\onContinue _ _ -> onContinue ())

-- | 'Loop' version of 'continue': skip the rest of the loop body and go to
-- the next iteration.
continue_ :: Monad m => Loop () e m a
continue_ = Loop (return (Left ()))
-- | Break out of the loop entirely.
exit :: LoopT c () m a
exit = LoopT (\_ onExit _ -> onExit ())

-- | 'Loop' version of 'exit': break out of the loop entirely.
exit_ :: Monad m => Loop c () m a
exit_ = Loop (return (Right (Left ())))
-- | Like 'continue', but return a value from the loop body.
continueWith :: c -> LoopT c e m a
continueWith c = LoopT (\onContinue _ _ -> onContinue c)

-- | 'Loop' version of 'continueWith'.
continueWith_ :: Monad m => c -> Loop c e m a
continueWith_ = Loop . return . Left
-- | Like 'exit', but return a value from the loop as a whole.
-- See 'count' for an example.
exitWith :: e -> LoopT c e m a
exitWith e = LoopT (\_ onExit _ -> onExit e)

-- | 'Loop' version of 'exitWith'.
exitWith_ :: Monad m => e -> Loop c e m a
exitWith_ = Loop . return . Right . Left
------------------------------------------------------------------------
-- Looping constructs
-- | Call the loop body with each item in the list.
--
-- If you do not need to 'continue' or 'exit' the loop, consider using
-- 'Control.Monad.forM_' instead.
foreach :: Monad m => [a] -> (a -> LoopT c () m c) -> m ()
foreach list body = foldr step (return ()) list
  where
    -- Whether the body continues or finishes normally, move on to the
    -- remaining items; an 'exit' makes 'stepLoopT' return without running
    -- them. The rest of the traversal is only demanded when reached, so
    -- infinite lists work too.
    step x rest = stepLoopT (body x) (const rest)
--
-- | 'Loop' version of 'foreach': call the loop body with each item in the
-- list until the body calls 'exit_'.
foreach_ :: Monad m => [a] -> (a -> Loop c () m c) -> m ()
foreach_ list body = foldr step (return ()) list
  where
    -- Continue or normal finish: go on with the rest of the list; an exit
    -- makes 'stepLoop' return without running the remaining items.
    step x rest = stepLoop (body x) (const rest)
-- | Repeat the loop body while the predicate holds. Like a @while@ loop in C,
-- the condition is tested first.
while :: Monad m => m Bool -> LoopT c () m c -> m ()
while cond body = go
  where
    go = cond >>= \ok -> when ok (stepLoopT body (const go))
-- | Like a @do while@ loop in C: the condition is tested after the loop body.
--
-- 'doWhile' returns the result of the last iteration. This is possible
-- because, unlike 'foreach' and 'while', the loop body is guaranteed to be
-- executed at least once.
doWhile :: Monad m => LoopT a a m a -> m Bool -> m a
doWhile body cond = go
  where
    go = stepLoopT body afterBody
    -- Runs after each iteration that did not exit: re-check the condition.
    afterBody result = do
      again <- cond
      if again then go else return result
-- | Execute the loop body once. This is a convenient way to introduce early
-- exit support to a block of code.
--
-- 'continue' and 'exit' do the same thing inside of 'once': all three
-- continuations simply return the value they receive.
once :: Monad m => LoopT a a m a -> m a
once body = runLoopT body finish finish finish
  where
    finish x = return x
-- | Execute the loop body again and again. The only way to exit
-- 'repeatLoopT' is to call 'exit' or 'exitWith'; the exit value is returned.
repeatLoopT :: Monad m => LoopT c e m a -> m e
repeatLoopT body = go
  where
    -- Both a 'continue' and an ordinary result start the next iteration.
    go = runLoopT body (const go) return (const go)
-- | 'Loop' version of 'repeatLoopT': execute the loop body again and again.
-- The only way out is 'exit_' or 'exitWith_', whose value is returned.
repeatLoop :: Monad m => Loop c e m a -> m e
repeatLoop (Loop body) = go
  where
    go = body >>= \case
      Right (Left e) -> return e
      -- A 'continue_' or an ordinary result: run the body again.
      _              -> go
-- | Example use of 'iterateLoopT', which calls the loop body again and again,
-- passing it the result of the previous iteration each time around. The only
-- way to exit 'iterateLoopT' is to call 'exit' or 'exitWith'.
--
-- Example:
--
-- Prints the counter from 0 while it is below n, then exits the loop with
-- the final counter value (n, or 0 when n <= 0).
count :: Int -> IO Int
count n = iterateLoopT 0 step
  where
    step i
      | i < n = do
          lift (print i)
          return (i + 1)
      | otherwise = exitWith i
--
-- | 'Loop' version of 'count': prints the counter from 0 while it is below
-- n, then exits with the final counter value (n, or 0 when n <= 0).
count_ :: Int -> IO Int
count_ n = iterateLoop 0 step
  where
    step i
      | i < n = do
          lift (print i)
          return (i + 1)
      | otherwise = exitWith_ i
-- | Call the loop body again and again, feeding each iteration the value
-- produced by the previous one (a 'continueWith' value counts as well).
-- The loop only ends through 'exit' or 'exitWith'.
iterateLoopT :: Monad m => c -> (c -> LoopT c e m c) -> m e
iterateLoopT z body = stepLoopT (body z) (`iterateLoopT` body)
-- | 'Loop' version of 'iterateLoopT': feed each iteration the previous
-- iteration's value until the body calls 'exit_' or 'exitWith_'.
iterateLoop :: Monad m => c -> (c -> Loop c e m c) -> m e
iterateLoop z body = stepLoop (body z) (`iterateLoop` body)
-- | Run one iteration of a loop body. A 'continueWith' value and an ordinary
-- result are both passed to @next@; an 'exitWith' value is returned as is.
stepLoopT :: Monad m => LoopT c e m c -> (c -> m e) -> m e
stepLoopT body next = runLoopT body onContinue onExit onResult
  where
    onContinue = next
    onExit     = return
    onResult   = next
-- | 'Loop' version of 'stepLoopT': a continue value or an ordinary result
-- goes to @next@, while an exit value is returned directly.
stepLoop :: Monad m => Loop c e m c -> (c -> m e) -> m e
stepLoop (Loop body) next = body >>= either next (either return next)
------------------------------------------------------------------------
-- Lifting other operations
-- | Lift a function like 'Control.Monad.Trans.Reader.local' or
-- 'Control.Exception.mask_'.
--
-- The body is run under @f@ with continuations that only /return/ the
-- follow-up action; that action is then run after @f@ has finished, so the
-- chosen continuation executes outside of @f@.
liftLocalLoopT :: Monad m => (forall a. m a -> m a) -> LoopT c e m b -> LoopT c e m b
liftLocalLoopT f cb = LoopT $ \next fin cont ->
  join (f (runLoopT cb (return . next) (return . fin) (return . cont)))
{-#LANGUAGE ScopedTypeVariables, BangPatterns #-}
module Main where
import Loop
import Control.Monad.ST
import Data.Array.ST
import Data.Array.Unboxed
import Data.STRef
import Control.Monad
import Control.Monad.Trans
import Control.Monad.Base
-- | Count the circular primes no greater than @e@: primes every decimal
-- rotation of which is also prime.
--
-- A rotation can be larger than @e@ (with @e = 150@, the prime 103 rotates
-- to 310), so the sieve and the visited bitmap cover every number with at
-- most as many digits as @e@. Rotations above @e@ are marked as visited but
-- are not counted.
countCircularPrimes :: Int -> Int
countCircularPrimes e =
  runST $ do
    bmp <- newArray (1, lim) False :: ST s (STUArray s Int Bool)
    total <- newSTRef 0
    foreach (filter isPrime [2 .. e]) $ \i -> do
      -- If i is marked, we've already visited i and its rotations,
      -- so go on to the next prime.
      whenM (liftBase $ readArray bmp i)
        continue
      let rs = rotateDigits i
      -- Count the number of unique rotations. We may end up marking a
      -- number with fewer digits, but that's okay because:
      --
      -- * We've already visited numbers with fewer digits.
      --
      -- * A circular prime will never contain the digit 0.
      --
      -- Thus, any counts affected by truncation will be discarded anyway.
      count <- liftBase $ newSTRef 0
      foreach rs $ \j -> do
        whenM (liftBase $ readArray bmp j)
          exit
        liftBase $ writeArray bmp j True
        -- Only rotations within the requested range contribute to the total.
        when (j <= e) $
          liftBase $ modifySTRef' count (+1)
      when (all isPrime rs) $ liftBase $
        total += count
    readSTRef total
  where
    -- Largest number with as many decimal digits as e; every rotation of a
    -- number up to e is at most this.
    lim = 10 ^ length (show e) - 1
    sieve = mkSieve lim
    isPrime p = p >= 2 && sieve ! p
--
-- | 'Loop'-based version of 'countCircularPrimes': count the circular primes
-- no greater than @e@.
--
-- A rotation can be larger than @e@ (with @e = 150@, the prime 103 rotates
-- to 310), so the sieve and the visited bitmap cover every number with at
-- most as many digits as @e@. Rotations above @e@ are marked as visited but
-- are not counted.
countCircularPrimes_ :: Int -> Int
countCircularPrimes_ e =
  runST $ do
    bmp <- newArray (1, lim) False :: ST s (STUArray s Int Bool)
    total <- newSTRef 0
    foreach_ (filter isPrime [2 .. e]) $ \i -> do
      -- If i is marked, we've already visited i and its rotations,
      -- so go on to the next prime.
      whenM (liftBase $ readArray bmp i)
        continue_
      let rs = rotateDigits i
      -- Count the number of unique rotations. We may end up marking a
      -- number with fewer digits, but that's okay because:
      --
      -- * We've already visited numbers with fewer digits.
      --
      -- * A circular prime will never contain the digit 0.
      --
      -- Thus, any counts affected by truncation will be discarded anyway.
      count <- liftBase $ newSTRef 0
      foreach_ rs $ \j -> do
        whenM (liftBase $ readArray bmp j)
          exit_
        liftBase $ writeArray bmp j True
        -- Only rotations within the requested range contribute to the total.
        when (j <= e) $
          liftBase $ modifySTRef' count (+1)
      when (all isPrime rs) $ liftBase $
        total += count
    readSTRef total
  where
    -- Largest number with as many decimal digits as e; every rotation of a
    -- number up to e is at most this.
    lim = 10 ^ length (show e) - 1
    sieve = mkSieve lim
    isPrime p = p >= 2 && sieve ! p
-- main :: IO ()
-- main = print $ countCircularPrimes_ 999999
-- | Nested-loop demo: a plain 'exit' leaves the inner loop, while
-- @lift continue@ and @lift exit@ act on the outer loop.
main :: IO ()
main = do
  foreach [1 .. 10] $ \(i :: Int) -> do
    foreach [1 .. 10] $ \(j :: Int) -> do
      -- Abandon the inner loop and move on to the next i.
      when (j > i) $ lift continue
      -- Leave the inner loop only.
      when (i == 2 && j == 2) exit
      -- Leave the outer loop (and with it the inner one).
      when (i == 9 && j == 9) $ lift exit
      liftBase $ print (i, j)
    liftBase $ putStrLn "Inner loop finished"
  putStrLn "Outer loop finished"
-- | 'Loop' version of the nested-loop demo: 'exit_' leaves the inner loop,
-- while @lift continue_@ and @lift exit_@ act on the outer loop.
main_ :: IO ()
main_ = do
  foreach_ [1 .. 10] $ \(i :: Int) -> do
    foreach_ [1 .. 10] $ \(j :: Int) -> do
      -- Abandon the inner loop and move on to the next i.
      when (j > i) $ lift continue_
      -- Leave the inner loop only.
      when (i == 2 && j == 2) exit_
      -- Leave the outer loop (and with it the inner one).
      when (i == 9 && j == 9) $ lift exit_
      liftBase $ print (i, j)
    liftBase $ putStrLn "Inner loop finished"
  putStrLn "Outer loop finished"
------------------------------------------------------------------------
-- Helper functions
-- | Sieve of Eratosthenes over @[2 .. e]@: the element at index @i@ is
-- 'True' exactly when @i@ is prime. For @e < 2@ the array is empty.
mkSieve :: Int -> UArray Int Bool
mkSieve e = runSTUArray $ do
  marks <- newArray (2, e) True
  forM_ [2 .. e] $ \p -> do
    stillPrime <- readArray marks p
    -- Strike out every further multiple of a prime.
    when stillPrime $
      mapM_ (\m -> writeArray marks m False) [2 * p, 3 * p .. e]
  return marks
--
-- | Return a list of every rotation of the decimal digits of a number,
-- starting with the number itself and repeatedly moving the last digit to
-- the front. Leading zeros produced by a rotation are dropped (101 gives
-- [101, 110, 11]).
rotateDigits :: Int -> [Int]
rotateDigits start
  | start < 1 = error "rotateDigits: n < 1"
  | otherwise = take digitCount (iterate rotate start)
  where
    -- factor is the place value of the leading digit and digitCount the
    -- number of digits; the search stops early if p * 10 overflows.
    (factor, digitCount) = leading 1 1
    leading :: Int -> Int -> (Int, Int)
    leading p k
      | next > start || next < p = (p, k)
      | otherwise                = leading next (k + 1)
      where
        next = p * 10
    -- Move the last digit to the front.
    rotate x =
      let (rest, lastDigit) = x `divMod` 10
      in lastDigit * factor + rest
-- | Add the value currently held in the second reference to the first one,
-- strictly.
(+=) :: STRef s Int -> STRef s Int -> ST s ()
total += n = do
  amount <- readSTRef n
  modifySTRef' total (amount +)
-- | Like 'when', but the condition is itself a monadic action.
whenM :: Monad m => m Bool -> m () -> m ()
whenM cond action = cond >>= \ok -> when ok action
-- This solves Google Code Jam 2012 Qualification Problem C "Recycled Numbers" [1].
-- The problem is: given a range of numbers with the same number of digits,
-- count how many pairs of them are the same modulo rotation of digits.
--
-- [1]: http://code.google.com/codejam/contest/1460488/dashboard#s=p2
{-# LANGUAGE ScopedTypeVariables #-}
module Main where
import Loop
import Control.Applicative ((<$>))
import Control.Monad
import Control.Monad.ST
import Control.Monad.Trans.Class
import Data.Array.ST
import Data.STRef
-- | Solve a single hard-coded case and print it in Code Jam output format.
main :: IO ()
main = do
  -- t <- readLn
  -- forM_ [1..t] $ \(x :: Int) -> do
  --   [a, b] <- map read . words <$> getLine
  let caseNo = 1 :: Int
      answer = recycledNumbers (10000000, 20000000)
  putStrLn ("Case #" ++ show caseNo ++ ": " ++ show answer)
-- | Count pairs (n, m) with @lb <= n < m <= ub@ where m is a rotation of the
-- digits of n. Both bounds must be positive and have the same number of
-- digits.
recycledNumbers :: (Int, Int) -> Int
recycledNumbers (lb, ub)
  | not (1 <= lb && lb <= ub && factor == rotateFactor ub)
  = error "recycledNumbers: invalid bounds"
  | otherwise = runST $ do
      seen <- newArray (lb, ub) False :: ST s (STUArray s Int Bool)
      total <- newSTRef 0
      forM_ [lb .. ub] $ \i -> do
        -- Number of distinct, not-yet-seen rotations of i in [i, ub].
        members <- newSTRef 0
        -- The rotations cycle back to i, which is then already marked, so
        -- the loop over the infinite list always exits.
        foreach (iterate rotate i) $ \j -> do
          unless (i <= j && j <= ub) continue
          whenM (lift $ readArray seen j) exit
          lift $ do
            writeArray seen j True
            modifySTRef' members (+1)
        readSTRef members >>= modifySTRef' total . (+) . numPairs
      readSTRef total
  where
    factor = rotateFactor lb
    -- Move the last digit to the front.
    rotate x = let (n, d) = x `divMod` 10 in d * factor + n
    -- Number of unordered pairs among n items.
    numPairs n = n * (n - 1) `div` 2
--
-- | Same count as 'recycledNumbers' (pairs n < m in [lb, ub] where m is a
-- digit rotation of n), but built on the non-continuation 'Loop' monad
-- ('foreach_', 'continue_', 'exit_').
recycledNumbers_ :: (Int, Int) -> Int
recycledNumbers_ (lb, ub)
  | not (1 <= lb && lb <= ub && factor == rotateFactor ub)
  -- Report this function's own name in the error.
  = error "recycledNumbers_: invalid bounds"
  | otherwise = runST $ do
      bmp <- newArray (lb, ub) False :: ST s (STUArray s Int Bool)
      total <- newSTRef 0
      forM_ [lb .. ub] $ \i -> do
        -- Number of distinct, not-yet-seen rotations of i in [i, ub].
        count <- newSTRef 0
        -- The rotations cycle back to i, which is then already marked, so
        -- the loop over the infinite list always exits.
        foreach_ (iterate rotate i) $ \j -> do
          when (not $ j >= i && j <= ub)
            continue_
          whenM (lift $ readArray bmp j)
            exit_
          lift $ writeArray bmp j True
          lift $ modifySTRef' count (+1)
        readSTRef count >>= modifySTRef' total . (+) . numPairs
      readSTRef total
  where
    factor = rotateFactor lb
    -- Move the last digit to the front.
    rotate x = let (n, d) = x `divMod` 10
               in d * factor + n
    -- Number of unordered pairs among n items.
    numPairs n = (n - 1) * n `div` 2
------------------------------------------------------------------------
-- Helper functions
-- | Return the power of 10 corresponding to the most significant digit in the
-- number (e.g. 100 for 123).
rotateFactor :: Int -> Int
rotateFactor n
  | n < 1 = error "rotateFactor: n < 1"
  | otherwise = go 1
  where
    go p =
      let p' = p * 10
      -- p' < p means p * 10 overflowed, so p is already the largest power.
      in if p' > n || p' < p
           then p
           else go p'
-- | Add the value currently held in the second reference to the first one,
-- strictly.
(+=) :: STRef s Int -> STRef s Int -> ST s ()
total += n = do
  amount <- readSTRef n
  modifySTRef' total (amount +)
-- | Like 'when', but the condition is itself a monadic action.
whenM :: Monad m => m Bool -> m () -> m ()
whenM cond action = cond >>= \ok -> when ok action
@michaelt
Copy link
Author

Control.Monad.Loop with a non-Church-encoded variant.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment