Last active
May 4, 2023 23:46
-
-
Save msakai/5abe8f13cedd1d361a3a57449cb9205e to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# OPTIONS_GHC -Wall #-} | |
----------------------------------------------------------------------------- | |
-- | | |
-- Module : FairLock | |
-- Copyright : (c) Masahiro Sakai 2023 | |
-- License : BSD-3-Clause | |
-- | |
-- Simple Lock implemented using STM. | |
-- When multiple threads are blocked on a Lock, they are woken up in FIFO order. | |
-- | |
-- Note that this is a toy implementation; for practical purposes you should
-- use MVar instead. MVar also guarantees fairness and is faster.
-- | |
-- The concurrent-extra package provides a Control.Concurrent.Lock module with
-- a lock interface implemented using MVar.
-- <https://hackage.haskell.org/package/concurrent-extra> | |
-- | |
----------------------------------------------------------------------------- | |
module FairLock | |
( Lock | |
, newLock | |
, acquireLock | |
, releaseLock | |
, withLock | |
) where | |
import Control.Concurrent.STM | |
import Control.Exception | |
import Control.Monad | |
import Data.Sequence (Seq) | |
import qualified Data.Sequence as Seq | |
-- | A mutual-exclusion lock. Threads blocked on it are woken up in FIFO
-- order.
newtype Lock = FairLock (TVar State)

-- | Internal state of a 'Lock'.
data State
  = Unlocked
    -- ^ Nobody holds the lock.
  | Locked !WaitList
    -- ^ Some thread holds the lock. The field lists the threads blocked
    -- waiting for it, oldest first. The field is strict so that repeated
    -- enqueues under contention do not pile up unevaluated appends
    -- inside the 'TVar'.

-- | List of waiting threads, in arrival order.
--
-- Those threads can be signaled by putting () to the TMVar.
type WaitList = Seq (TMVar ())
-- | Create a fresh lock in the unlocked state.
newLock :: IO Lock
newLock = do
  stateVar <- newTVarIO Unlocked
  return (FairLock stateVar)
-- | Acquire the lock, blocking until it becomes available.
--
-- If the lock is free it is taken immediately. Otherwise a fresh empty
-- 'TMVar' is appended to the end of the wait list. The calling thread then
-- blocks on it until 'releaseLockSTM' hands the lock over, so blocked
-- threads are served in FIFO order.
--
-- The operation runs with asynchronous exceptions masked. The blocking
-- 'takeTMVar' is still interruptible, so it is guarded by 'onException'.
-- That guard makes sure an interrupted waiter either passes on a lock it
-- was already given, or removes itself from the wait list.
acquireLock :: Lock -> IO ()
acquireLock lock@(FairLock tv) = mask_ $
  join $ atomically $ do
    st <- readTVar tv
    case st of
      Unlocked -> do
        writeTVar tv (Locked Seq.empty)
        return (return ())
      Locked ws -> do
        w <- newEmptyTMVar
        writeTVar tv (Locked (ws Seq.|> w))
        return (wait w)
  where
    -- Block until our TMVar is filled by 'releaseLockSTM'.
    wait :: TMVar () -> IO ()
    wait w = atomically (takeTMVar w) `onException` cleanup
      where
        -- Undo the effect of having queued ourselves.
        cleanup :: IO ()
        cleanup = atomically $ do
          m <- tryTakeTMVar w
          case m of
            Just _ ->
              -- This thread was already handed the lock; pass it on so
              -- that it is not leaked.
              releaseLockSTM lock
            Nothing -> do
              -- Not handed the lock yet: remove ourselves from the queue.
              st <- readTVar tv
              case st of
                Unlocked ->
                  -- Unreachable: releaseLockSTM fills a TMVar in the same
                  -- transaction that dequeues it. An empty TMVar is
                  -- therefore still queued, which only happens while Locked.
                  error "FairLock.acquireLock: waiter queued on unlocked lock"
                Locked ws -> writeTVar tv (Locked (Seq.filter (/= w) ws))
-- | Release the lock. If other threads are waiting, ownership passes
-- directly to the one that has waited longest. Otherwise the lock becomes
-- unlocked.
releaseLock :: Lock -> IO ()
releaseLock = atomically . releaseLockSTM
-- | STM core of 'releaseLock'.
--
-- If the wait list is empty, the lock becomes 'Unlocked'. Otherwise the
-- first (longest-waiting) thread is removed from the wait list and signalled
-- through its 'TMVar'. The state stays 'Locked', so ownership is transferred
-- without the lock ever being free in between.
--
-- Calling this on a lock that is not held is a usage error. It throws an
-- 'ErrorCall' with a descriptive message.
releaseLockSTM :: Lock -> STM ()
releaseLockSTM (FairLock tv) = do
  st <- readTVar tv
  case st of
    Unlocked ->
      throwSTM (ErrorCall "FairLock.releaseLock: the lock is not acquired")
    Locked ws ->
      case Seq.viewl ws of
        Seq.EmptyL -> writeTVar tv Unlocked
        w Seq.:< ws' -> do
          -- Each waiter's TMVar is created empty and is filled only here,
          -- right after being dequeued, so this put never blocks.
          putTMVar w ()
          writeTVar tv (Locked ws')
-- | Run an action while holding the lock. The lock is acquired before the
-- action starts and released afterwards, even if the action throws.
withLock :: Lock -> IO a -> IO a
withLock lock action = bracket_ acquire release action
  where
    acquire = acquireLock lock
    release = releaseLock lock
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# OPTIONS_GHC -Wall #-} | |
module Main where | |
import Control.Concurrent
import Data.IORef
import System.Timeout (timeout)

import Test.Tasty
import Test.Tasty.HUnit

import FairLock
-- | Entry point: run the whole test tree with tasty's default driver.
main :: IO ()
main = defaultMain tests

-- | Root of the test tree.
tests :: TestTree
tests = testGroup "Tests" (pure unitTests)

-- | Example-based tests for the FairLock module.
unitTests :: TestTree
unitTests = testGroup "Unit tests" $ map (uncurry testCase)
  [ ("lock", case_lock)
  , ("fifo", case_fifo)
  , ("exception", case_exception)
  ]
-- | While the main thread holds the lock, a second thread trying to take it
-- must stay blocked. Once the lock is released, the second thread runs.
case_lock :: Assertion
case_lock = do
  lock <- newLock
  ref <- newIORef ([] :: [Int])
  events <- withLock lock $ do
    _ <- forkIO $ withLock lock $ modifyIORef ref (1:)
    -- Give the forked thread time to start and block on the lock.
    threadDelay (40*1000)
    readIORef ref
  events @?= []
  -- Wait deterministically instead of sleeping. The lock is FIFO, so
  -- re-acquiring it queues us behind the forked thread. The timeout turns
  -- a lost wake-up into a test failure rather than a hang.
  done <- timeout (1000*1000) (withLock lock (return ()))
  done @?= Just ()
  events2 <- readIORef ref
  events2 @?= [1]
-- | Threads blocked on the lock acquire it in the order in which they started
-- waiting (first in, first out).
case_fifo :: Assertion
case_fifo = do
  lock <- newLock
  ref <- newIORef ([] :: [Int])
  withLock lock $ do
    -- Queue three threads on the held lock, 40ms apart, so that they arrive
    -- in the order 1, 2, 3.
    _ <- forkIO $ withLock lock $ modifyIORef ref (1:)
    _ <- forkIO $ do
      threadDelay (40*1000)
      withLock lock $ modifyIORef ref (2:)
    _ <- forkIO $ do
      threadDelay (80*1000)
      withLock lock $ modifyIORef ref (3:)
    threadDelay (120*1000)
  -- Wait deterministically for all three: re-acquiring the FIFO lock queues
  -- us behind them. The timeout turns a lost wake-up into a failure rather
  -- than a hang.
  done <- timeout (1000*1000) (withLock lock (return ()))
  done @?= Just ()
  -- Events are prepended, so [3,2,1] means thread 1 ran first.
  events <- readIORef ref
  events @?= [3,2,1]
-- | A thread killed while waiting for the lock must neither take the lock nor
-- block the threads queued behind it.
case_exception :: Assertion
case_exception = do
  lock <- newLock
  ref <- newIORef ([] :: [Int])
  acquireLock lock
  -- Queue three waiters, 40ms apart, in the order 1, 2, 3.
  _ <- forkIO $ withLock lock $ modifyIORef ref (1:)
  th2 <- forkIO $ do
    threadDelay (40*1000)
    withLock lock $ modifyIORef ref (2:)
  _ <- forkIO $ do
    threadDelay (80*1000)
    withLock lock $ modifyIORef ref (3:)
  threadDelay (120*1000)
  -- Kill the second waiter while it is still blocked in the queue.
  killThread th2
  releaseLock lock
  -- Wait deterministically for the surviving waiters: re-acquiring the FIFO
  -- lock queues us behind them. This also checks that the killed waiter did
  -- not leak the lock. The timeout turns such a leak into a failure rather
  -- than a hang.
  done <- timeout (1000*1000) (withLock lock (return ()))
  done @?= Just ()
  events <- readIORef ref
  events @?= [3,1]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment