Created
July 20, 2020 18:56
-
-
Save WJWH/7867b0726fa6d667bb9c456c52bc303c to your computer and use it in GitHub Desktop.
Not segfaulting program
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE OverloadedStrings, BangPatterns, GeneralizedNewtypeDeriving, CPP #-} | |
module Main where | |
#include <sys/epoll.h> | |
import Foreign.C.Error (eINTR, eNOENT, getErrno, throwErrno,
                        throwErrnoIfMinus1, throwErrnoIfMinus1_)
import Foreign.C.Types
import Foreign.Marshal hiding (void, newArray)
import Foreign.Storable
import Foreign.Ptr
import Foreign.ForeignPtr (ForeignPtr, withForeignPtr)
import GHC.ForeignPtr (mallocPlainForeignPtrBytes, newForeignPtr_)
import System.Posix.IO
import System.Posix.Types
import System.Posix.Internals (c_close)
import System.Posix.Internals (setCloseOnExec)
import Control.Concurrent
import Control.Concurrent.MVar
import qualified Control.Exception as E
import Control.Monad (forever, void, when)
import Data.Bits (Bits, FiniteBits, (.|.), (.&.), shiftL, shiftR)
import Data.IORef (IORef, atomicModifyIORef', newIORef, readIORef, writeIORef)
import Data.Word (Word32)
import qualified Data.Map as M
import Debug.Trace
import Network.Socket
import Network.Socket.ByteString (recv, sendAll)
-- | Canned reply sent to every client: an HTTP/1.0 status line followed by
-- the empty line that terminates the header section. HTTP header lines are
-- terminated by CRLF, so both line ends are written as "\r\n" (the previous
-- bare "\n\n" only worked with lenient clients).
-- There is no type signature: under OverloadedStrings the literal's type is
-- fixed by its use with 'sendAll' in 'talk'.
resp = "HTTP/1.0 200 OK\r\n\r\n"
-- | Entry point. Creates the epoll instance and the lock-protected
-- (epoll fd, callback map) pair, starts the TCP server on port 3000 in a
-- separate thread, then runs the epoll event loop on the main thread.
--
-- Each iteration blocks in epoll_wait (timeout -1) and, for every reported
-- event, fills the wake-up MVar registered for that fd by 'waitFor'.
main :: IO ()
main = do
  epollFd <- epollCreate
  ringLock <- newMVar (epollFd, M.empty)
  evtArray <- newArray 512 -- :: Array Event
  _ <- forkIO $ runTCPServer ringLock Nothing "3000" (talk ringLock)
  forever $ do
    n <- unsafeLoad evtArray $ \es cap -> epollWait epollFd es cap (-1)
    if n < 0
      then do
        -- epoll_wait failed: retry if it was merely interrupted by a signal,
        -- otherwise fail loudly instead of spinning on a persistent error.
        err <- getErrno
        when (err /= eINTR) $ throwErrno "epollWait"
      else when (n > 0) $ do
        -- waitFor holds the lock while it registers an fd, so this snapshot
        -- contains every fd that epoll can report.
        (_, callbackmap) <- readMVar ringLock
        forM_ evtArray $ \e -> case M.lookup (eventFd e) callbackmap of
          Nothing -> error "lost callback"
          -- tryPutMVar never blocks the loop if the MVar is already full.
          Just mv -> void $ tryPutMVar mv ()
-- | Serves one connection: waits for the socket to become readable, reads
-- (and ignores) up to 1024 bytes of request, waits for it to become
-- writable, then sends the canned 'resp'.
talk :: MVar (EPollFd,M.Map Fd (MVar ())) -> Socket -> IO ()
talk lock sock = getFdFromSocket sock >>= \fd -> do
  waitForReadable lock fd
  _request <- recv sock 1024 -- request contents are not inspected
  waitForWritable lock fd
  sendAll sock resp
-- | Adapted from the "network-run" package.
--
-- Resolves a passive address for @port@ (optionally restricted to @mhost@),
-- opens a listening socket on it and accepts connections forever. Before
-- each 'accept' the loop waits for the listening socket to become readable
-- via 'waitForReadable' on the shared epoll lock. Each connection is handled
-- by @server@ in its own thread and closed with 'gracefulClose' (5000 ms
-- timeout) when @server@ returns or throws.
runTCPServer :: MVar (EPollFd,M.Map Fd (MVar ())) -> Maybe HostName -> ServiceName -> (Socket -> IO a) -> IO a
runTCPServer lock mhost port server = withSocketsDo $ do
  addr <- resolve
  E.bracket (open addr) close loop
  where
    resolve = do
      let hints = defaultHints
            { addrFlags = [AI_PASSIVE]
            , addrSocketType = Stream
            }
      -- getAddrInfo throws on failure rather than returning an empty list.
      head <$> getAddrInfo (Just hints) mhost (Just port)
    -- bracketOnError closes the socket if any setup step below throws;
    -- once 'open' returns, the outer 'E.bracket' owns the socket.
    open addr = E.bracketOnError (mkSocket addr) close $ \sock -> do
      setSocketOption sock ReuseAddr 1
      withFdSocket sock setCloseOnExecIfNeeded
      bind sock $ addrAddress addr
      listen sock 1024
      return sock
    mkSocket addr = socket (addrFamily addr) (addrSocketType addr) (addrProtocol addr)
    loop sock = forever $ do
      getFdFromSocket sock >>= waitForReadable lock
      (conn, _peer) <- accept sock
      void $ forkFinally (server conn) (const $ gracefulClose conn 5000)
-- Event and control constants; values match <sys/epoll.h> on Linux.
ePOLLIN :: EventType
ePOLLIN = 1 -- EPOLLIN  = 0x001

ePOLLOUT :: EventType
ePOLLOUT = 4 -- EPOLLOUT = 0x004

-- | EPOLLONESHOT = 1 << 30. After one event is reported, the fd is disabled
-- until it is re-armed with EPOLL_CTL_MOD.
--
-- 'shiftL' takes the value first and the shift amount second. The former
-- @shiftL 30 1@ evaluated to 60 (0x3C). That value sets EPOLLOUT among other
-- bits and never sets the one-shot bit, so "readable" waits also woke on
-- writability, and fds stayed armed in the default level-triggered mode.
ePOLLONESHOT :: EventType
ePOLLONESHOT = shiftL 1 30

ePOLL_CTL_ADD :: ControlOp
ePOLL_CTL_ADD = 1 -- EPOLL_CTL_ADD

ePOLL_CTL_DEL :: ControlOp
ePOLL_CTL_DEL = 2 -- EPOLL_CTL_DEL

ePOLL_CTL_MOD :: ControlOp
ePOLL_CTL_MOD = 3 -- EPOLL_CTL_MOD
-- | Blocks until @fd@ is reported readable (EPOLLIN) by the shared epoll
-- instance.
waitForReadable :: MVar (EPollFd,M.Map Fd (MVar ())) -> Fd -> IO ()
waitForReadable = waitFor ePOLLIN
-- | Blocks until @fd@ is reported writable (EPOLLOUT) by the shared epoll
-- instance.
waitForWritable :: MVar (EPollFd,M.Map Fd (MVar ())) -> Fd -> IO ()
waitForWritable = waitFor ePOLLOUT
-- | Blocks the calling thread until epoll reports @evt@ for @fd@.
--
-- While holding the lock, it arms @fd@ via 'modifyFdOnce' and records a
-- fresh wake-up MVar for @fd@ in the callback map. The event loop in 'main'
-- fills that MVar when epoll reports the fd.
--
-- 'modifyMVar_' puts the original (epoll fd, map) pair back if
-- 'modifyFdOnce' throws (e.g. an epoll_ctl error on a closed fd). Otherwise
-- the lock would stay empty and every thread needing it would deadlock.
waitFor :: EventType -> MVar (EPollFd,M.Map Fd (MVar ())) -> Fd -> IO ()
waitFor evt lock fd = do
  wakeup <- newEmptyMVar
  modifyMVar_ lock $ \(epollFd, callbackmap) -> do
    _ <- modifyFdOnce epollFd fd evt
    return (epollFd, M.insert fd wakeup callbackmap)
  takeMVar wakeup
-- | The socket's underlying file descriptor, wrapped as an 'Fd'.
getFdFromSocket :: Socket -> IO Fd
getFdFromSocket = fmap Fd . unsafeFdSocket
-- | File descriptor of an epoll instance, as produced by 'epollCreate'.
newtype EPollFd = EPollFd
  { fromEPollFd :: CInt
  } deriving (Eq, Show)
-- | Haskell view of @struct epoll_event@: the event bit mask plus the fd
-- stored in the @data.fd@ member.
data Event = Event
  { eventTypes :: EventType
  , eventFd    :: Fd
  } deriving (Show)
-- | Marshals 'Event' to and from @struct epoll_event@. The mask goes in
-- @events@ and the fd goes in the @data.fd@ union member. hsc2hs supplies
-- the size and field offsets from <sys/epoll.h>.
-- NOTE(review): alignment is CInt's. Confirm this is adequate on targets
-- where epoll_event is not packed (e.g. non-x86_64).
instance Storable Event where
  sizeOf _ = #{size struct epoll_event}
  alignment _ = alignment (undefined :: CInt)
  peek p = do
    bits <- #{peek struct epoll_event, events} p
    fdNum <- #{peek struct epoll_event, data.fd} p
    return $! Event (EventType bits) fdNum
  poke p ev = do
    #{poke struct epoll_event, events} p (unEventType (eventTypes ev))
    #{poke struct epoll_event, data.fd} p (eventFd ev)
-- | The @op@ argument of epoll_ctl (see the ePOLL_CTL_* constants).
newtype ControlOp = ControlOp CInt
  deriving (Show, Eq, Num, Bits, FiniteBits)
-- | Bit mask of epoll event flags, i.e. the @events@ field of
-- @struct epoll_event@. Flags are combined with '.|.'.
newtype EventType = EventType
  { unEventType :: Word32
  } deriving (Show, Eq, Num, Bits, FiniteBits)
-- | Arms @fd@ on @ep@ for @evt@, always adding the one-shot flag.
--
-- It first tries EPOLL_CTL_MOD, which re-arms an fd already registered with
-- @ep@. If that fails with ENOENT (fd not registered yet), it falls back to
-- EPOLL_CTL_ADD. Any other failure is raised with 'throwErrno', as is an
-- ADD failure (via 'epollControl'). Returns 'True' whenever it returns at
-- all.
modifyFdOnce :: EPollFd -> Fd -> EventType -> IO Bool
modifyFdOnce ep fd evt = do
  let !oneShot = evt .|. ePOLLONESHOT
  rc <- with (Event oneShot fd) (epollControl_ ep ePOLL_CTL_MOD fd)
  if rc == 0
    then return True
    else do
      err <- getErrno
      when (err /= eNOENT) $ throwErrno "modifyFdOnce"
      with (Event oneShot fd) (epollControl ep ePOLL_CTL_ADD fd)
      return True
-- | Creates a new epoll instance and marks its descriptor close-on-exec.
-- A failing epoll_create is raised as an IOError tagged "epollCreate".
epollCreate :: IO EPollFd
epollCreate = do
  -- The size argument is ignored by the kernel but must be positive.
  rawFd <- throwErrnoIfMinus1 "epollCreate" (c_epoll_create 256)
  setCloseOnExec rawFd
  return $! EPollFd rawFd
-- | Checked epoll_ctl: throws an IOError tagged "epollControl" when the
-- underlying call returns -1.
epollControl :: EPollFd -> ControlOp -> Fd -> Ptr Event -> IO ()
epollControl epfd op fd =
  throwErrnoIfMinus1_ "epollControl" . epollControl_ epfd op fd
-- | Unchecked epoll_ctl: unwraps the newtypes and hands back the raw C
-- result (callers compare it to 0 / -1 and consult errno themselves).
epollControl_ :: EPollFd -> ControlOp -> Fd -> Ptr Event -> IO CInt
epollControl_ (EPollFd epfd) (ControlOp op) (Fd fd) = c_epoll_ctl epfd op fd
-- | Waits for events on the epoll instance, writing at most @numEvents@
-- entries into @events@. @timeout@ is passed straight to epoll_wait, and
-- main passes -1 for it. The C result, including -1 on failure, is returned
-- unchecked as an 'Int'.
epollWait :: EPollFd -> Ptr Event -> Int -> Int -> IO Int
epollWait (EPollFd epfd) events numEvents timeout = do
  ready <- c_epoll_wait epfd events (fromIntegral numEvents) (fromIntegral timeout)
  return (fromIntegral ready)
-- Raw bindings to the Linux epoll system calls.
-- epoll_create and epoll_ctl do not wait for I/O, so they use the cheaper
-- "unsafe" calling convention.
foreign import ccall unsafe "sys/epoll.h epoll_create"
  c_epoll_create :: CInt -> IO CInt

foreign import ccall unsafe "sys/epoll.h epoll_ctl"
  c_epoll_ctl :: CInt -> CInt -> CInt -> Ptr Event -> IO CInt

-- epoll_wait may block indefinitely (main calls it with timeout -1), so it
-- is a "safe" call. With -threaded, the RTS can then keep running other
-- Haskell threads while this OS thread sits in the kernel.
foreign import ccall safe "sys/epoll.h epoll_wait"
  c_epoll_wait :: CInt -> Ptr Event -> CInt -> CInt -> IO CInt
-- Array of Storable elements in a pinned buffer, used here as the epoll
-- event buffer.
-- Invariant (intended): size <= capacity.
newtype Array a = Array (IORef (AC a))

-- | Array contents: the element buffer, the number of valid elements
-- (length) and the maximum number of elements (capacity), both counted in
-- elements rather than bytes.
data AC a = AC
  !(ForeignPtr a) -- ^ element buffer
  !Int            -- ^ number of valid elements (length)
  !Int            -- ^ maximum number of elements (capacity)
-- | Allocates a pinned buffer large enough for @n@ elements of type @a@.
-- The element type is recovered from the result's own type, so 'sizeOf'
-- only ever sees a never-evaluated placeholder.
allocArray :: Storable a => Int -> IO (ForeignPtr a)
allocArray n = result
  where
    result = mallocPlainForeignPtrBytes (n * sizeOf (elemOf result))
    elemOf :: IO (ForeignPtr b) -> b
    elemOf _ = undefined
-- | Creates an empty array whose capacity is the requested element count
-- rounded up to a power of two by 'firstPowerOf2'.
newArray :: Storable a => Int -> IO (Array a)
newArray requested = do
  let capacity = firstPowerOf2 requested
  buf <- allocArray capacity
  Array <$> newIORef (AC buf 0 capacity)
-- | Refills the array in place. @load@ receives the element buffer and the
-- capacity and returns a count. That count becomes the array's new length
-- and is also returned. The count is stored as-is, with no check against
-- 0 or the capacity.
unsafeLoad :: Array a -> (Ptr a -> Int -> IO Int) -> IO Int
unsafeLoad (Array ref) load = do
  AC buf _ capacity <- readIORef ref
  written <- withForeignPtr buf (\ptr -> load ptr capacity)
  writeIORef ref (AC buf written capacity)
  return written
-- | Runs @visit@ on each of the array's current elements (the first
-- @length@ slots), in index order. Nothing is visited if the stored length
-- is zero or negative.
forM_ :: Storable a => Array a -> (a -> IO ()) -> IO ()
forM_ (Array ref) visit = do
  AC buf len _ <- readIORef ref
  withForeignPtr buf $ \base ->
    let step i
          | i >= len = return ()
          | otherwise = peekElemOff base i >>= visit >> step (i + 1)
    in step 0
-- | Smallest power of two that is >= @n@, for positive @n@ (inputs <= 0
-- yield 0). It smears the highest set bit of @n - 1@ into every lower bit
-- and then adds one.
firstPowerOf2 :: Int -> Int
firstPowerOf2 !n = smear (n - 1) [1, 2, 4, 8, 16, 32] + 1
  where
    smear !acc [] = acc
    smear !acc (s : rest) = smear (acc .|. (acc `shiftR` s)) rest
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
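Compile with `hsc2hs segfault_epoll.hsc` and then `ghc -O2 -threaded -debug -rtsopts -with-rtsopts=-N -with-rtsopts=-DS segfault_epoll.hs`. (Note: GHC keeps only the last `-with-rtsopts` flag, so to actually get both RTS options, pass them together, e.g. `"-with-rtsopts=-N -DS"`.) It does not segfault when hit with `ab -n 50000 -c 500`.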