Created
September 15, 2014 13:56
-
-
Save fatho/2f68c0cfb6ec9df1511c to your computer and use it in GitHub Desktop.
Cooperative interleaved threading in Haskell.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE DeriveFunctor, TemplateHaskell, GeneralizedNewtypeDeriving, FlexibleContexts #-} | |
module Threads where | |
import Control.Applicative | |
import Control.Arrow | |
import Control.Monad | |
import Control.Monad.Except | |
import Control.Monad.State.Class | |
import Control.Monad.Trans | |
import Control.Monad.Trans.Free | |
import Control.Monad.Trans.State (StateT, evalStateT) | |
import Control.Lens | |
import Data.Dynamic | |
import Data.Foldable | |
import Data.List as L | |
import Data.Maybe | |
import Data.Monoid | |
import Data.Sequence as S hiding ((|>), (<|)) | |
import qualified Data.Map as M | |
-- * definition of thread data type | |
-- | Outcome of a 'fork', handed to the continuation on each side of the fork.
-- Each constructor carries the identifier of the /other/ side.
data Forked
  = -- | Seen by the forking thread; holds the id of the newly created child.
    IsParent { _fChildId :: ThreadId }
  | -- | Seen by the new thread; holds the id of the thread that forked it.
    IsChild { _fParentId :: ThreadId }
-- | Instruction set of a cooperative thread. The scheduler interprets one
-- instruction at a time; @next@ is the remainder of the thread.
data ThreadF next
  = -- | Give up control and continue with @next@ later.
    Yield next
  | -- | Split the thread in two; the continuation is run once per side.
    Fork (Forked -> next)
  | -- | Ask for the identifier of the running thread.
    MyThreadId (ThreadId -> next)
  | -- | Send a dynamically typed message to the given thread.
    Send ThreadId Dynamic next
  | -- | Wait for a message; the continuation gets the sender id and payload.
    Recv ((ThreadId, Dynamic) -> next)
  | -- | Terminate the thread; there is no continuation.
    Ret
  deriving (Functor)
-- | A cooperative thread over base monad @m@: a free monad transformer over
-- the 'ThreadF' instruction set.
type ThreadT = FreeT ThreadF
-- * primitive thread operations | |
-- | Emit a 'Yield' instruction, handing control back to the scheduler.
yield :: Monad m => ThreadT m ()
yield = liftF (Yield ())
-- | Emit a 'Fork' instruction. The result tells the caller whether it is
-- running as the parent or as the child.
fork :: Monad m => ThreadT m Forked
fork = liftF (Fork (\side -> side))
-- | Emit a 'MyThreadId' instruction and return the identifier supplied by the
-- scheduler.
myThreadId :: Monad m => ThreadT m ThreadId
myThreadId = liftF (MyThreadId (\tid -> tid))
-- | Emit a 'Send' instruction carrying an already-wrapped 'Dynamic' payload
-- addressed to the given thread.
send' :: Monad m => ThreadId -> Dynamic -> ThreadT m ()
send' tid msg = liftF (Send tid msg ())
-- | Emit a 'Recv' instruction and return the sender id together with the raw
-- 'Dynamic' payload.
recv' :: Monad m => ThreadT m (ThreadId, Dynamic)
recv' = liftF (Recv (\received -> received))
-- | Emit a 'Ret' instruction. 'Ret' has no continuation, so the result type
-- is unconstrained.
ret :: Monad m => ThreadT m r
ret = liftF Ret
-- * derived thread operations | |
-- | Forks a child thread that runs @threadInit@, passing it the parent's
-- thread id, and then terminates via 'ret'.
-- In the parent, returns the thread id of the child.
spawn :: Monad m => (ThreadId -> ThreadT m ()) -> ThreadT m ThreadId
spawn threadInit =
  fork >>= \side -> case side of
    IsParent childId -> return childId
    IsChild parentId -> threadInit parentId >> ret
-- | Typed wrapper around @send'@: wraps the message with 'toDyn' before
-- sending it to the given thread.
send :: (Typeable a, Monad m) => ThreadId -> a -> ThreadT m ()
send tid = send' tid . toDyn
-- | Typed wrapper around @recv'@: receives a message and tries to unwrap the
-- payload with 'fromDynamic'. The payload is 'Nothing' when the message does
-- not have the requested type.
recv :: (Typeable a, Monad m) => ThreadT m (ThreadId, Maybe a)
recv = do
  ~(sender, payload) <- recv'
  return (sender, fromDynamic payload)
-- * scheduler | |
-- | Identifier of a thread managed by the scheduler, wrapping an 'Int'.
newtype ThreadId = ThreadId
  { _threadIdInt :: Int
  }
  deriving (Show, Eq, Ord)
-- | Scheduler errors are plain human-readable messages.
type Error = String
-- | Scheduler monad: 'ExceptT' for reporting an 'Error', layered over a
-- 'StateT' holding the 'SchedState', layered over the base monad @m@.
newtype RoundRobin m a = RoundRobin
  { runRoundRobin' :: ExceptT Error (StateT (SchedState m) m) a
  }
  deriving
    ( Functor
    , Applicative
    , Monad
    , MonadError Error
    , MonadState (SchedState m)
    , MonadIO
    , Alternative
    , MonadPlus
    )
-- | Lifts a base-monad action through both transformer layers
-- ('StateT', then 'ExceptT') and wraps it in 'RoundRobin'.
instance MonadTrans RoundRobin where
  lift action = RoundRobin (lift (lift action))
-- | Per-thread bookkeeping kept by the scheduler.
data ThreadContext m = ThreadContext
  { -- | Identifier of the thread.
    _tcId :: ThreadId
  , -- | The remaining computation of the thread.
    _tcThread :: ThreadT m ()
  }
-- | State of the scheduler.
data SchedState m = SchedState
  { -- | Queue of threads that are ready to run.
    _sActiveThreads :: Seq (ThreadContext m)
  , -- | Threads blocked on a receive, keyed by their own id. Applying the
    -- stored function to a @(sender, message)@ pair yields the resumed thread.
    _sPendingRecv :: M.Map ThreadId ((ThreadId, Dynamic) -> ThreadContext m)
  , -- | Threads blocked on a send, keyed by the destination id. Each entry
    -- pairs the sender's context with the message it wants to deliver.
    _sPendingSend :: M.Map ThreadId (Seq (ThreadContext m, Dynamic))
  , -- | The identifier that will be given to the next created thread.
    _sNextId :: ThreadId
  }
-- Template Haskell splices that generate a lens for every underscore-prefixed
-- record field of the types above (e.g. @_tcId@ yields @tcId@).
$(makeLenses ''ThreadId)
$(makeLenses ''ThreadContext)
$(makeLenses ''SchedState)
$(makeLenses ''Forked)
-- | Runs a scheduler computation from the given initial state in the base
-- monad. The final state is discarded; the result is either the scheduler's
-- 'Error' or its return value.
runRoundRobin :: Monad m => RoundRobin m a -> SchedState m -> m (Either Error a)
runRoundRobin sc st = flip evalStateT st . runExceptT . runRoundRobin' $ sc
-- | Returns the thread id whose underlying integer is one larger.
incThreadId :: ThreadId -> ThreadId
incThreadId (ThreadId n) = ThreadId (n + 1)
-- | Builds the initial scheduler state for a single root thread.
-- The root thread gets id 1 and is the only active thread; nothing is
-- pending, and the next id to be handed out is the one following the root's.
setupRoundRobin :: Monad m => ThreadT m () -> SchedState m
setupRoundRobin root =
  SchedState
    { _sActiveThreads = S.singleton rootCtx
    , _sPendingRecv = M.empty
    , _sPendingSend = M.empty
    , _sNextId = firstFreeId
    }
  where
    rootId = ThreadId 1
    rootCtx = ThreadContext rootId root
    firstFreeId = incThreadId rootId
-- | Runs the action on the contained value if there is one; does nothing for
-- 'Nothing'.
doMaybe :: (Monad m) => Maybe a -> (a -> m ()) -> m ()
doMaybe mx f = maybe (return ()) f mx
-- | Hands out a fresh thread id: returns the current value of the id counter
-- and advances the counter with 'incThreadId'.
newThreadId :: (Monad m) => RoundRobin m ThreadId
newThreadId = use sNextId <* (sNextId %= incThreadId)
-- | Removes the next thread from the front of the scheduling queue.
-- Returns 'Nothing' (leaving the state untouched) when no thread is active.
--
-- The queue is inspected with the qualified 'S.viewl' rather than a bare
-- @uncons@: this module imports both "Data.List" and "Control.Lens"
-- unqualified, and both export a function named @uncons@ (base >= 4.8),
-- which makes the bare name ambiguous.
nextThread :: (Monad m) => RoundRobin m (Maybe (ThreadContext m))
nextThread = do
  queue <- use sActiveThreads
  case S.viewl queue of
    S.EmptyL -> return Nothing
    ctx S.:< rest -> do
      sActiveThreads .= rest
      return (Just ctx)
-- | Puts a thread at the front of the queue, so it is the next one to run.
scheduleFirst :: Monad m => ThreadContext m -> RoundRobin m ()
scheduleFirst ctx = modify (over sActiveThreads (ctx <|))
-- | Puts a thread at the back of the queue, behind all other active threads.
scheduleLast :: Monad m => ThreadContext m -> RoundRobin m ()
scheduleLast ctx = modify (over sActiveThreads (|> ctx))
-- | Returns the context with its thread computation replaced by the given
-- one; the thread id is kept.
stepContext :: Monad m => ThreadT m () -> ThreadContext m -> ThreadContext m
stepContext next ctx = ctx { _tcThread = next }
-- | The scheduler loop. Repeatedly takes the thread at the front of the
-- active queue, runs it up to its next 'ThreadF' instruction and interprets
-- that instruction.
--
-- The loop ends when the active queue is empty. If threads are still blocked
-- on a send or a receive at that point, it fails with a @"Stalled!"@ error
-- that lists the ids of the blocked receivers and of the blocked senders.
--
-- Message passing is a rendezvous: a 'Send' to a thread that is not currently
-- waiting in 'Recv' parks the sender until the target receives, and a 'Recv'
-- with no parked sender parks the receiver.
myRoundRobin :: Monad m => RoundRobin m ()
myRoundRobin = do
  t <- nextThread
  case t of
    -- no active threads, check if list of waiting threads is empty
    Nothing -> do
      pendingRecvs <- use $ sPendingRecv . to M.keys . traverse . threadIdInt . to pure
      pendingSends <- use $ sPendingSend . traverse . traverse . _1 . tcId . threadIdInt . to pure
      unless (L.null pendingRecvs && L.null pendingSends) $
        throwError $ L.concat
          [ "Stalled! Pending recv in: "
          , L.intercalate " " (fmap show pendingRecvs)
          , ". Pending send in :"
          , L.intercalate " " (fmap show pendingSends)
          ]
    -- process next thread in queue
    Just ctx -> do
      runThread ctx
      myRoundRobin
  where
    -- Splits a sequence into its first element and the rest. Uses the
    -- qualified 'S.viewl' because a bare @uncons@ is ambiguous here: both
    -- "Data.List" and "Control.Lens" are imported unqualified and export it.
    popFront queue = case S.viewl queue of
      S.EmptyL -> Nothing
      x S.:< rest -> Just (x, rest)

    -- Runs the thread in the base monad up to its next instruction.
    runThread ctx = do
      thState <- lift $ runFreeT $ ctx ^. tcThread
      case thState of
        -- thread exited normally
        Pure _ -> return ()
        -- thread performs an action
        Free a -> runAction ctx a

    -- Interprets a single instruction of the thread described by @ctx@.
    runAction ctx a = case a of
      Yield next ->
        scheduleLast $ stepContext next ctx
      Fork cont -> do
        childId <- newThreadId
        let parentId = ctx ^. tcId
            childTh = cont (IsChild parentId)
            parentTh = cont (IsParent childId)
        -- the child waits at the back, the parent continues immediately
        scheduleLast $ ThreadContext childId childTh
        scheduleFirst $ stepContext parentTh ctx
      MyThreadId cont ->
        scheduleFirst $ stepContext (cont $ ctx ^. tcId) ctx
      Send tid msg next -> do
        -- check if target waits for a message
        target <- use $ sPendingRecv . at tid
        case target of
          Just recvCtx -> do
            sPendingRecv . at tid .= Nothing
            scheduleLast (recvCtx (ctx ^. tcId, msg))
            scheduleFirst $ stepContext next ctx
          Nothing ->
            -- nobody is receiving yet: park the sender (already advanced past
            -- the send) in the target's queue of pending sends
            sPendingSend . at tid %= Just . (|> (stepContext next ctx, msg)) . fromMaybe S.empty
      Recv cont -> do
        -- check for pending sends
        sendQueue <- use $ sPendingSend . at (ctx ^. tcId)
        case sendQueue >>= popFront of
          Just ((sendCtx, msg), xs) -> do
            scheduleLast sendCtx
            scheduleFirst $ stepContext (cont (sendCtx ^. tcId, msg)) ctx
            sPendingSend . at (ctx ^. tcId) ?= xs
          Nothing ->
            -- no sender is waiting: park this thread until a message arrives
            sPendingRecv . at (ctx ^. tcId) ?= flip stepContext ctx . cont
      -- the thread is simply not rescheduled
      Ret -> return ()
-- * test programs | |
-- | Test thread: prints two lines with a 'yield' after each, then throws an
-- 'IOError'. The whole sequence runs under 'catchError' with 'errHandler'.
progErr :: ThreadT IO ()
progErr = do
  me <- show <$> myThreadId
  let say suffix = liftIO (putStrLn (me ++ suffix))
      risky = do
        say ": Hello"
        yield
        say ": World"
        yield
        throwError (userError "BADABOOM")
  risky `catchError` errHandler
-- | Error handler used by the test threads: announces the error, yields once,
-- then prints the error value prefixed with the current thread id.
errHandler :: Show a => a -> ThreadT IO ()
errHandler e = do
  me <- show <$> myThreadId
  liftIO (putStrLn (me ++ ": Error"))
  yield
  liftIO $ do
    putStr me
    putStr ": "
    print e
-- | Test thread: prints a greeting in two parts, yielding in between.
prog1 :: ThreadT IO ()
prog1 = do
  me <- show <$> myThreadId
  let say suffix = liftIO (putStrLn (me ++ suffix))
  say ": Hello"
  yield
  say ": World"
-- | Test program: prints the id of the master thread, then ten times spawns
-- one child that itself spawns three 'progErr' threads, followed by three
-- 'prog1' threads spawned directly.
prog2 :: ThreadT IO ()
prog2 = do
  me <- fmap show myThreadId
  liftIO $ putStrLn $ "Master: " ++ me
  replicateM_ 10 $ do
    _ <- spawn (const $ replicateM_ 3 $ void $ spawn $ const progErr)
    replicateM_ 3 $ void $ spawn $ const prog1
-- | Test thread: forever waits for a message, logs it, and sends the very
-- same payload back to its sender. The argument (supplied by 'spawn') is
-- not used.
echoServer :: ThreadId -> ThreadT IO ()
echoServer _ = forever $ do
  me <- show <$> myThreadId
  liftIO (putStrLn (me ++ " waiting for message"))
  (sender, payload) <- recv'
  liftIO (putStrLn (me ++ " received " ++ show payload ++ " from " ++ show sender))
  send' sender payload
-- | Test program: spawns an 'echoServer' and then forever reads a line from
-- stdin, sends it to the server and prints the echoed reply.
--
-- The reply is matched with an explicit @case@ rather than a refutable
-- @Just@ pattern in the bind: a payload that is not a 'String' would
-- otherwise go through 'fail' and abort the whole scheduler run.
commTest :: ThreadT IO ()
commTest = do
  me <- fmap show myThreadId
  server <- spawn echoServer
  forever $ do
    str <- liftIO getLine
    send server str
    (tid, reply) <- recv
    liftIO $ putStrLn $ case reply of
      Just msg -> me ++ " received echo " ++ msg ++ " from " ++ show tid
      Nothing -> me ++ " received a non-String message from " ++ show tid
-- | Test program that stalls on purpose: it forks, and then both resulting
-- threads wait for a message that nobody sends.
deadLock :: ThreadT IO ()
deadLock = fork >> void recv'
-- | Runs 'commTest' under the round-robin scheduler and prints either
-- @Success!@ or the scheduler's error message.
main :: IO ()
main = do
  outcome <- runRoundRobin myRoundRobin (setupRoundRobin commTest)
  either putStrLn (\() -> putStrLn "Success!") outcome
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment