Skip to content

Instantly share code, notes, and snippets.

@fatho
Created September 15, 2014 13:56
Show Gist options
  • Save fatho/2f68c0cfb6ec9df1511c to your computer and use it in GitHub Desktop.
Save fatho/2f68c0cfb6ec9df1511c to your computer and use it in GitHub Desktop.
Cooperative interleaved threading in Haskell.
{-# LANGUAGE DeriveFunctor, TemplateHaskell, GeneralizedNewtypeDeriving, FlexibleContexts #-}
module Threads where
import Control.Applicative
import Control.Arrow
import Control.Monad
import Control.Monad.Except
import Control.Monad.State.Class
import Control.Monad.Trans
import Control.Monad.Trans.Free
import Control.Monad.Trans.State (StateT, evalStateT)
import Control.Lens
import Data.Dynamic
import Data.Foldable
import Data.List as L
import Data.Maybe
import Data.Monoid
import Data.Sequence as S hiding ((|>), (<|))
import qualified Data.Map as M
-- * definition of thread data type
-- | Result of 'fork': tells each side of the split which one it is and
-- who its counterpart is.
data Forked
  = -- | Seen by the original thread; carries the id of the new child.
    IsParent { _fChildId :: ThreadId }
  | -- | Seen by the new thread; carries the id of its parent.
    IsChild { _fParentId :: ThreadId }
-- | One primitive operation of a cooperative thread. Every constructor
-- except 'Ret' carries the continuation ("what to do next").
data ThreadF next
  = -- | Give up the processor, then continue.
    Yield next
  | -- | Split into parent and child; each side learns its role.
    Fork (Forked -> next)
  | -- | Ask for the id of the running thread.
    MyThreadId (ThreadId -> next)
  | -- | Deliver a dynamically typed message to the given thread.
    Send ThreadId Dynamic next
  | -- | Wait for a message; the continuation gets the sender's id and payload.
    Recv ((ThreadId, Dynamic) -> next)
  | -- | Terminate the thread (there is no continuation).
    Ret
  deriving (Functor)
-- | A cooperative thread: the free monad transformer over 'ThreadF',
-- running on top of a base monad (e.g. @ThreadT IO ()@).
type ThreadT = FreeT ThreadF
-- * primitive thread operations
-- | Voluntarily give up the processor. (Under 'myRoundRobin' the thread is
-- moved to the back of the run queue.)
yield :: Monad m => ThreadT m ()
yield = liftF (Yield ())

-- | Split the current thread in two. The original thread continues with
-- 'IsParent' (carrying the child's id), the new one with 'IsChild'
-- (carrying the parent's id).
fork :: Monad m => ThreadT m Forked
fork = liftF (Fork id)

-- | The id of the calling thread.
myThreadId :: Monad m => ThreadT m ThreadId
myThreadId = liftF (MyThreadId id)

-- | Send a dynamically typed message to the given thread. (Under
-- 'myRoundRobin' the sender stays blocked until the target receives it.)
send' :: Monad m => ThreadId -> Dynamic -> ThreadT m ()
send' target payload = liftF (Send target payload ())

-- | Wait for the next message; yields the sender's id and the payload.
recv' :: Monad m => ThreadT m (ThreadId, Dynamic)
recv' = liftF (Recv id)

-- | Terminate the calling thread. 'Ret' has no continuation, so this never
-- returns; hence the fully polymorphic result type.
ret :: Monad m => ThreadT m r
ret = liftF Ret
-- * derived thread operations
-- | Spawn a new thread running @threadInit@, which receives the parent's
-- thread id as its argument. The child terminates (via 'ret') as soon as
-- @threadInit@ finishes. The parent gets back the id of the child.
spawn :: Monad m => (ThreadId -> ThreadT m ()) -> ThreadT m ThreadId
spawn threadInit = fork >>= continueAs
  where
    -- parent side: simply report the child's id
    continueAs (IsParent childId) = return childId
    -- child side: run the body, then end this thread
    continueAs (IsChild parentId) = threadInit parentId >> ret
-- | Typed wrapper around 'send'': the payload is wrapped with 'toDyn'.
send :: (Typeable a, Monad m) => ThreadId -> a -> ThreadT m ()
send tid = send' tid . toDyn

-- | Typed wrapper around 'recv'': the payload is 'Nothing' when the message
-- received does not have the requested type @a@.
recv :: (Typeable a, Monad m) => ThreadT m (ThreadId, Maybe a)
recv = fmap (\ ~(sender, payload) -> (sender, fromDynamic payload)) recv'
-- * scheduler
-- | Identifier of a thread. The scheduler hands these out from a counter
-- (see 'newThreadId').
newtype ThreadId = ThreadId
  { _threadIdInt :: Int
  }
  deriving (Show, Eq, Ord)

-- | Scheduler errors are reported as plain messages.
type Error = String
-- | The round-robin scheduler monad: the scheduler state lives in 'StateT',
-- and fatal scheduler errors are raised through 'ExceptT'.
newtype RoundRobin m a = RoundRobin
  { runRoundRobin' :: ExceptT Error (StateT (SchedState m) m) a
  }
  deriving
    ( Functor
    , Applicative
    , Monad
    , MonadError Error
    , MonadState (SchedState m)
    , MonadIO
    , Alternative
    , MonadPlus
    )

-- | Lift a base-monad action through both transformer layers.
instance MonadTrans RoundRobin where
  lift action = RoundRobin (lift (lift action))
-- | What the scheduler keeps for a single thread.
data ThreadContext m = ThreadContext
  { _tcId :: ThreadId
    -- ^ the thread's id
  , _tcThread :: ThreadT m ()
    -- ^ the remaining computation of the thread
  }
-- | Complete state of the round-robin scheduler.
data SchedState m = SchedState
  { _sActiveThreads :: Seq (ThreadContext m)
    -- ^ run queue; the thread at the front runs next
  , _sPendingRecv :: M.Map ThreadId ((ThreadId, Dynamic) -> ThreadContext m)
    -- ^ threads blocked in 'recv'', keyed by their own id; the value resumes
    -- the thread once it is given (sender id, message)
  , _sPendingSend :: M.Map ThreadId (Seq (ThreadContext m, Dynamic))
    -- ^ threads blocked in 'send'', keyed by the id of the TARGET thread,
    -- in arrival order, each paired with the message it wants to deliver
  , _sNextId :: ThreadId
    -- ^ id that will be given to the next forked thread
  }
-- Generate lenses for every record above: the leading underscore is dropped
-- from each field name (e.g. '_tcId' yields 'tcId').
$(concat <$> traverse makeLenses [''ThreadId, ''ThreadContext, ''SchedState, ''Forked])
-- | Run a scheduler computation from the given initial state. The final
-- state is discarded; the result is either a scheduler error or the value.
runRoundRobin :: Monad m => RoundRobin m a -> SchedState m -> m (Either Error a)
runRoundRobin scheduler = evalStateT (runExceptT (runRoundRobin' scheduler))
-- | The id following the given one.
incThreadId :: ThreadId -> ThreadId
incThreadId (ThreadId n) = ThreadId (n + 1)
-- | Initial scheduler state. The given root thread gets id 1 and is the only
-- runnable thread; nothing is blocked; fresh ids continue after the root's.
setupRoundRobin :: Monad m => ThreadT m () -> SchedState m
setupRoundRobin root =
  let rootId = ThreadId 1
  in SchedState
       { _sActiveThreads = S.singleton (ThreadContext rootId root)
       , _sPendingRecv = M.empty
       , _sPendingSend = M.empty
       , _sNextId = incThreadId rootId
       }
-- | Run the action on the contained value, or do nothing for 'Nothing'.
doMaybe :: (Monad m) => Maybe a -> (a -> m ()) -> m ()
doMaybe optional action = maybe (return ()) action optional
-- | Allocate a fresh thread id: return the current counter value and
-- advance the counter.
newThreadId :: (Monad m) => RoundRobin m ThreadId
newThreadId = sNextId <<%= incThreadId
-- | Remove the thread at the front of the run queue and return it, or
-- 'Nothing' when the queue is empty.
nextThread :: (Monad m) => RoundRobin m (Maybe (ThreadContext m))
nextThread = do
  queue <- use sActiveThreads
  -- Use the qualified 'S.viewl' rather than a bare 'uncons'. This module
  -- imports both "Control.Lens" and "Data.List" unqualified, and both export
  -- an 'uncons' (Data.List since base-4.8), so the bare name is ambiguous.
  case S.viewl queue of
    S.EmptyL -> return Nothing
    ctx S.:< rest -> do
      sActiveThreads .= rest
      return (Just ctx)
-- | Put a thread at the front of the run queue, so it runs next.
scheduleFirst :: Monad m => ThreadContext m -> RoundRobin m ()
scheduleFirst ctx = modify (over sActiveThreads (ctx <|))
-- | Put a thread at the back of the run queue, behind all runnable threads.
scheduleLast :: Monad m => ThreadContext m -> RoundRobin m ()
scheduleLast ctx = modify (over sActiveThreads (|> ctx))
-- | Swap the remaining computation stored in a context, keeping its id.
stepContext :: Monad m => ThreadT m () -> ThreadContext m -> ThreadContext m
stepContext th ctx = ctx { _tcThread = th }
-- | Round-robin scheduler loop.
--
-- Repeatedly removes the thread at the front of the run queue, runs it in
-- the base monad up to its next primitive operation and interprets that
-- operation. Returns normally once the run queue is empty and no thread is
-- blocked. If threads are still blocked in 'send'' or 'recv'' at that point,
-- it fails with a @"Stalled!"@ error listing the ids of the blocked receivers
-- and of the blocked senders.
myRoundRobin :: Monad m => RoundRobin m ()
myRoundRobin = do
  candidate <- nextThread
  case candidate of
    -- Run queue exhausted: success only if nobody is left waiting.
    Nothing -> do
      pendingRecvs <- gets $ toListOf
        (sPendingRecv . to M.keys . traverse . threadIdInt)
      pendingSends <- gets $ toListOf
        (sPendingSend . traverse . traverse . _1 . tcId . threadIdInt)
      unless (L.null pendingRecvs && L.null pendingSends) $
        throwError $ L.concat
          [ "Stalled! Pending recv in: "
          , L.intercalate " " (fmap show pendingRecvs)
          , ". Pending send in: "
          , L.intercalate " " (fmap show pendingSends)
          ]
    -- Run one step of the next thread, then keep going.
    Just ctx -> do
      runThread ctx
      myRoundRobin
  where
    -- Advance a thread to its next primitive operation and interpret it.
    runThread ctx = do
      step <- lift $ runFreeT (ctx ^. tcThread)
      case step of
        Pure _ -> return ()          -- thread finished normally; drop it
        Free op -> runAction ctx op

    -- Interpret one primitive operation issued by the thread in @ctx@.
    runAction ctx op = case op of
      Yield next ->
        scheduleLast $ stepContext next ctx
      Fork cont -> do
        childId <- newThreadId
        let parentId = ctx ^. tcId
        scheduleLast $ ThreadContext childId (cont (IsChild parentId))
        -- the parent keeps the processor
        scheduleFirst $ stepContext (cont (IsParent childId)) ctx
      MyThreadId cont ->
        scheduleFirst $ stepContext (cont (ctx ^. tcId)) ctx
      Send tid msg next -> do
        waiting <- use (sPendingRecv . at tid)
        case waiting of
          -- receiver already blocked in recv': hand the message over
          Just resume -> do
            sPendingRecv . at tid .= Nothing
            scheduleLast (resume (ctx ^. tcId, msg))
            scheduleFirst $ stepContext next ctx
          -- nobody waiting: park the sender until the target calls recv'
          Nothing ->
            sPendingSend . at tid
              %= Just . (|> (stepContext next ctx, msg)) . fromMaybe S.empty
      Recv cont -> do
        let me = ctx ^. tcId
        queue <- use (sPendingSend . at me)
        -- qualified 'S.viewl': a bare 'uncons' is ambiguous between
        -- Control.Lens and Data.List with this module's imports
        case fmap S.viewl queue of
          Just ((sender, msg) S.:< rest) -> do
            scheduleLast sender
            scheduleFirst $ stepContext (cont (sender ^. tcId, msg)) ctx
            -- drop the entry once the last blocked sender is released, so
            -- the map does not accumulate empty queues
            sPendingSend . at me .= (if S.null rest then Nothing else Just rest)
          _ ->
            sPendingRecv . at me ?= flip stepContext ctx . cont
      Ret -> return ()
-- * test programs
-- | Demo thread: prints "Hello" and "World" (yielding after each), then
-- throws an 'IOError' that is handled by 'errHandler'.
progErr :: ThreadT IO ()
progErr = do
  me <- show <$> myThreadId
  let say suffix = liftIO (putStrLn (me ++ suffix))
  (`catchError` errHandler) $ do
    say ": Hello"
    yield
    say ": World"
    yield
    throwError (userError "BADABOOM")
-- | Demo error handler: announces the error, yields once, then prints the
-- error value, each line prefixed with the thread id.
errHandler :: Show a => a -> ThreadT IO ()
errHandler e = do
  me <- show <$> myThreadId
  liftIO $ putStrLn (me ++ ": Error")
  yield
  liftIO $ putStrLn (me ++ ": " ++ show e)
-- | Demo thread: prints "Hello", yields once, then prints "World"; each line
-- is prefixed with the thread id.
prog1 :: ThreadT IO ()
prog1 = do
  me <- show <$> myThreadId
  let say suffix = liftIO (putStrLn (me ++ suffix))
  say ": Hello"
  yield
  say ": World"
-- | Demo: prints the master thread's id, then runs 10 rounds. Each round
-- spawns one child, which in turn spawns three 'progErr' threads, and then
-- spawns three 'prog1' threads directly.
prog2 :: ThreadT IO ()
prog2 = do
  me <- show <$> myThreadId
  liftIO $ putStrLn ("Master: " ++ me)
  replicateM_ 10 $ do
    void $ spawn (const $ replicateM_ 3 $ void $ spawn (const progErr))
    replicateM_ 3 $ void $ spawn (const prog1)
-- | Echo server: forever waits for a message and sends it straight back to
-- its sender. The argument (the parent's id when started via 'spawn') is
-- ignored.
echoServer :: ThreadId -> ThreadT IO ()
echoServer _ = forever $ do
  me <- show <$> myThreadId
  liftIO $ putStrLn (me ++ " waiting for message")
  (sender, payload) <- recv'
  liftIO $ putStrLn (me ++ " received " ++ show payload ++ " from " ++ show sender)
  send' sender payload
-- | Interactive demo: spawns an 'echoServer', then forever reads a line from
-- stdin, sends it to the server and prints the echo that comes back.
commTest :: ThreadT IO ()
commTest = do
  me <- fmap show myThreadId
  server <- spawn echoServer
  forever $ do
    str <- liftIO getLine
    send server str
    (tid, reply) <- recv
    -- 'recv' gives Nothing when the payload is not a String. Handle that
    -- explicitly instead of with a refutable pattern, whose failure would go
    -- through MonadFail in the base monad and abort the whole scheduler.
    liftIO $ case reply of
      Just msg -> putStrLn $ me ++ " received echo " ++ msg ++ " from " ++ show tid
      Nothing -> putStrLn $ me ++ " received a non-String message from " ++ show tid
-- | Deliberate deadlock: after 'fork', both parent and child block in
-- 'recv'' and nobody ever sends. Under 'myRoundRobin' the run therefore
-- ends with its "Stalled!" error.
deadLock :: ThreadT IO ()
deadLock = do
  _ <- fork   -- the role (parent/child) is irrelevant: both sides wait
  void recv'
-- | Run the interactive 'commTest' demo under the round-robin scheduler and
-- report how the run ended.
main :: IO ()
main = do
  outcome <- runRoundRobin myRoundRobin (setupRoundRobin commTest)
  either putStrLn (\() -> putStrLn "Success!") outcome
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment