Skip to content

Instantly share code, notes, and snippets.

@andrevdm
Last active May 25, 2022 18:09
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save andrevdm/d564d7261eafdfded27db923c8b28cea to your computer and use it in GitHub Desktop.
Save andrevdm/d564d7261eafdfded27db923c8b28cea to your computer and use it in GitHub Desktop.
Haskell RabbitMQ wrapper. It should handle channel and connection failures, and includes features such as RPC with timeouts.
{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE NumericUnderscores #-}
{-# LANGUAGE RankNTypes #-}
module Rabbit
( defaultOpts
, mkConnection
, runTopicConsumer
, createTopicPublisher
, runFanoutConsumer
, createFanoutPublisher
, callRpc
, handleRpc
, createRpcCaller
, Connection
, Channel
, Opts(..)
, RpcResult(..)
) where
import Verset

import qualified Control.Concurrent.MVar as Mv
import qualified Control.Exception as Ex
import qualified System.Timeout as Timeout

import Control.Concurrent.STM (atomically)
import qualified Control.Concurrent.STM.TVar as Tv
import Control.Exception.Safe (catch, throwM, catches, throwString, fromException, Handler(..))
import qualified Data.ByteString.Lazy as BSL
import qualified Data.Map.Strict as Map
import qualified Data.UUID.V4 as UU
import qualified Network.AMQP as Rb
import qualified Network.AMQP.Types as Rb
-- | Mutable state of a 'Connection', kept in the 'conState' 'MVar'.
data ConnectionState = ConnectionState
  { csConnection :: !Rb.Connection
    -- ^ The amqp connection currently in use; replaced when reconnecting.
  , csTerminate :: !Bool
    -- ^ When 'True', 'onConnectionClosed' does not reconnect and
    -- 'getRabbitConnection' throws. NOTE(review): nothing in this module
    -- ever sets it to 'True'.
  }
-- | A managed broker connection. When the underlying amqp connection closes
-- it is re-established, and the channels registered in 'conChannels' are
-- reopened on the new connection.
data Connection = Connection
  { conState :: !(Mv.MVar ConnectionState)
    -- ^ Current amqp connection state. Taken (empty) while
    -- 'onConnectionClosed' is reconnecting.
  , conOpts :: !Opts
    -- ^ Retry/back-off policy used when (re)connecting.
  , conRabbitOpts :: !Rb.ConnectionOpts
    -- ^ amqp connection options used when (re)connecting.
  , conChannels :: !(Tv.TVar (Map Int Channel))
    -- ^ Channels created with 'mkChannel', keyed by 'chId'.
  , conNextId :: !(Tv.TVar Int)
    -- ^ Last channel id handed out; 'mkChannel' increments it first, so ids
    -- start at 1.
  }
-- | Reconnection policy, used for the initial connect in 'mkConnection' and
-- again when reconnecting after the connection closes. (The field names keep
-- the original "Rety" spelling because they are exported via @Opts(..)@.)
data Opts = Opts
  { opRetyConnectionBackoffMilliseconds :: ![Int]
    -- ^ Milliseconds to wait before each successive retry. Once the list is
    -- exhausted its last entry is reused; 200ms is used if the list is empty.
  , opRetyConnectionMaxAttempts :: !(Maybe Int)
    -- ^ How many times to retry after a failed connection attempt.
    -- 'Nothing' means no retries (a single attempt).
  }
-- | A channel that is reopened after a channel exception or a connection
-- restart. The amqp channel behind it is swapped out, and 'chOnRestart' is
-- run again so the channel can redo its setup.
data Channel = Channel
  { chConnection :: !Connection
    -- ^ The connection this channel belongs to.
  , chOnRestart :: !(Channel -> IO ())
    -- ^ Called with this channel each time it is reopened.
  , chRabbitChan :: !(Tv.TVar Rb.Channel)
    -- ^ The amqp channel currently in use. Read it with 'getRabbitChannel'
    -- each time it is needed, since it is replaced on restart.
  , chId :: !Int
    -- ^ Key of this channel in 'conChannels'.
  }
-- | Default reconnection policy. Retries are delayed by 100ms, 250ms, 475ms,
-- and so on up to 4,926ms; the last delay is reused for every later retry.
-- Connecting is given up after 2,000 retries.
defaultOpts :: Opts
defaultOpts = Opts
  { opRetyConnectionBackoffMilliseconds = backoffScheduleMs
  , opRetyConnectionMaxAttempts = Just retryLimit
  }
  where
    -- Milliseconds to delay between connection attempts
    backoffScheduleMs = [100, 250, 475, 813, 1_319, 2_078, 3_217, 4_926]
    -- How many times to retry the connection before giving up
    retryLimit = 2_000
-- | Try to open an amqp connection, retrying with back-off so the caller can
-- wait for the server to come up.
--
-- @retriesLeft@ is how many more attempts may follow a failed one. Once it
-- reaches 0, the exception from the failed attempt is rethrown. @atRetry@
-- indexes 'opRetyConnectionBackoffMilliseconds' to choose the delay before
-- the next attempt. Past the end of that list its last entry is used, and
-- 200ms is used when the list is empty.
tryNewRabbitConnection :: Rb.ConnectionOpts -> Opts -> Int -> Int -> IO Rb.Connection
tryNewRabbitConnection ropts copts retriesLeft atRetry =
  Rb.openConnection'' ropts `catch` retryAfterFailure
  where
    retryAfterFailure :: SomeException -> IO Rb.Connection
    retryAfterFailure failure
      | retriesLeft <= 0 = throwM failure
      | otherwise = do
          threadDelay (1_000 * pauseMs)
          tryNewRabbitConnection ropts copts (retriesLeft - 1) (atRetry + 1)

    schedule = opRetyConnectionBackoffMilliseconds copts
    -- Delay used once the schedule runs out (200ms if there is no schedule)
    fallbackMs = fromMaybe 200 (lastMay schedule)
    -- Delay before the next attempt, in milliseconds
    pauseMs = fromMaybe fallbackMs (atMay schedule atRetry)
-- | Open a managed connection to the broker.
--
-- The initial connect uses the retry policy in @copts@ (a 'Nothing' retry
-- limit means a single attempt). 'onConnectionClosed' is then registered on
-- the amqp connection, so it is re-established when it closes.
mkConnection :: Rb.ConnectionOpts -> Opts -> IO Connection
mkConnection ropts copts = do
  let retryBudget = fromMaybe 0 (opRetyConnectionMaxAttempts copts)
  rconn <- tryNewRabbitConnection ropts copts retryBudget 0
  stateVar <- Mv.newMVar ConnectionState { csConnection = rconn, csTerminate = False }
  channelRegistry <- Tv.newTVarIO mempty
  idCounter <- Tv.newTVarIO 0
  let managed = Connection
        { conState = stateVar
        , conOpts = copts
        , conRabbitOpts = ropts
        , conChannels = channelRegistry
        , conNextId = idCounter
        }
  Rb.addConnectionClosedHandler rconn True (onConnectionClosed managed)
  pure managed
-- | Connection-closed handler, registered on every underlying amqp connection.
--
-- If the connection was deliberately terminated ('csTerminate'), nothing
-- happens. Otherwise a new amqp connection is opened with the same
-- retry/back-off policy as 'mkConnection', and this handler is registered on
-- it. Every registered 'Channel' is then reopened on the new connection
-- (see 'reopenChannel').
--
-- The state is updated with 'Mv.modifyMVar', so the 'MVar' is always refilled.
-- In the terminated case it is put back unchanged. If reconnecting throws, the
-- old state is restored before the exception propagates. An 'MVar' left empty
-- would block every later 'getRabbitConnection' forever.
onConnectionClosed :: Connection -> IO ()
onConnectionClosed conn = do
  reconnected <- Mv.modifyMVar (conState conn) $ \csOld ->
    if csTerminate csOld
      then pure (csOld, Nothing)
      else do
        let copts = conOpts conn
            ropts = conRabbitOpts conn
        rconn <- tryNewRabbitConnection ropts copts (fromMaybe 0 $ opRetyConnectionMaxAttempts copts) 0
        let csNew =
              ConnectionState
                { csConnection = rconn
                , csTerminate = False
                }
        pure (csNew, Just rconn)
  for_ reconnected $ \rconn -> do
    Rb.addConnectionClosedHandler rconn True (onConnectionClosed conn)
    chans <- Tv.readTVarIO $ conChannels conn
    -- One channel failing to reopen (or its restart callback throwing) must
    -- not prevent the remaining channels from being reopened.
    for_ (Map.elems chans) $ \ch ->
      reopenChannel ch rconn `catch` \(e :: SomeException) -> print e

-- | Open a fresh amqp channel for @ch@ on @rconn@, store it in the channel's
-- 'TVar' (so 'getRabbitChannel' returns it from now on), register the channel
-- exception handler on it, and run the channel's restart callback.
reopenChannel :: Channel -> Rb.Connection -> IO ()
reopenChannel ch rconn = do
  rch <- Rb.openChannel rconn
  atomically $ Tv.writeTVar (chRabbitChan ch) rch
  Rb.addChannelExceptionHandler rch (onChannelException ch)
  chOnRestart ch ch

-- | Channel exception handler, registered by 'mkChannel' and 'reopenChannel'.
--
-- Normal channel closes ('isExpectedChannelCloseEx') are ignored. Closes
-- reported as a 'Rb.ConnectionClosedException' are also ignored here, because
-- 'onConnectionClosed' reopens every channel once it has reconnected.
-- Reopening here as well would open a second channel and run the restart
-- callback twice. Any other exception reopens the channel on the connection's
-- /current/ amqp connection.
onChannelException :: Channel -> SomeException -> IO ()
onChannelException ch ex
  | isExpectedChannelCloseEx ex || isConnectionClosedEx ex = pass
  | otherwise = getRabbitConnection (chConnection ch) >>= reopenChannel ch

-- | True when the exception reports that the whole amqp connection closed
-- (normally or not), rather than just the channel.
isConnectionClosedEx :: SomeException -> Bool
isConnectionClosedEx e =
  case fromException e :: Maybe Rb.AMQPException of
    Just Rb.ConnectionClosedException{} -> True
    _ -> False

-- | Open a new channel on the connection's current amqp connection and
-- register it in 'conChannels' so it is reopened after a reconnect.
--
-- @onChanRestart@ is called with the channel each time the channel is
-- reopened, whether after a channel exception or a connection restart.
-- 'Nothing' means do nothing on restart. Channel ids come from 'conNextId'
-- and start at 1.
mkChannel :: Connection -> Maybe (Channel -> IO ()) -> IO Channel
mkChannel conn onChanRestart = do
  chanId <- atomically $ Tv.stateTVar (conNextId conn) $ \i -> (i + 1, i + 1)
  rconn <- getRabbitConnection conn
  rch <- Rb.openChannel rconn
  rchVar <- Tv.newTVarIO rch
  let ch =
        Channel
          { chConnection = conn
          , chRabbitChan = rchVar
          , chId = chanId
          , chOnRestart = fromMaybe (const pass) onChanRestart
          }
  atomically $ Tv.modifyTVar' (conChannels conn) (Map.insert chanId ch)
  Rb.addChannelExceptionHandler rch (onChannelException ch)
  pure ch
-- | Create a channel whose setup action runs immediately. The same action is
-- registered as the channel's restart callback, so it runs again whenever the
-- channel is reopened.
withRestartableChannel :: Connection -> (Channel -> IO ()) -> IO Channel
withRestartableChannel conn setup =
  mkChannel conn (Just setup) >>= \fresh -> setup fresh >> pure fresh
-- | The amqp channel currently backing @ch@. It is replaced when the channel
-- is reopened, so call this each time rather than keeping the result.
getRabbitChannel :: Channel -> IO Rb.Channel
getRabbitChannel ch = Tv.readTVarIO (chRabbitChan ch)
-- | The amqp connection currently in use. Blocks while the state 'MVar' is
-- empty (for example while 'onConnectionClosed' reconnects). Throws if the
-- connection was marked as terminated.
getRabbitConnection :: Connection -> IO Rb.Connection
getRabbitConnection con = do
  st <- Mv.readMVar (conState con)
  case csTerminate st of
    True -> throwString "Connection was terminated"
    False -> pure (csConnection st)
-- | True only for a normal channel close. The channel exception handler
-- ignores these instead of reopening the channel.
isExpectedChannelCloseEx :: SomeException -> Bool
isExpectedChannelCloseEx = isNormalChannelClose . fromException
  where
    isNormalChannelClose :: Maybe Rb.AMQPException -> Bool
    isNormalChannelClose (Just (Rb.ChannelClosedException Rb.Normal _)) = True
    isNormalChannelClose _ = False
------------------------------------------------------------------------------------------------------------------------------
-- Topic consumer
------------------------------------------------------------------------------------------------------------------------------
-- | Consume messages from a durable topic exchange through a durable,
-- lazy-mode queue.
--
-- Declares the topic exchange @exchangeName@ and the queue @queueName@, binds
-- them with @routingExpr@, and registers @cfn@ as the consumer callback. All
-- of this setup runs again whenever the channel is restarted.
--
-- A message is acked after @cfn@ returns. If @cfn@ (or the ack) throws, the
-- exception is printed and the message is rejected without requeueing.
-- Leaving it unacknowledged would, with the prefetch count of 1 set below,
-- stop the broker from delivering any further messages on this channel.
runTopicConsumer :: Connection -> Text -> Text -> Text -> (Rb.Message -> IO ()) -> IO ()
runTopicConsumer conn exchangeName queueName routingExpr cfn = void . withRestartableChannel conn $ \ch -> do
  let ex = Rb.newExchange { Rb.exchangeName = exchangeName
                          , Rb.exchangeType = "topic"
                          , Rb.exchangeDurable = True
                          , Rb.exchangeAutoDelete = False
                          }
  rch <- getRabbitChannel ch
  -- Prefetch count 1: at most one unacknowledged message in flight
  Rb.qos rch 0 1 True
  Rb.declareExchange rch ex
  (queue, _, _) <- Rb.declareQueue rch Rb.newQueue { Rb.queueName = queueName
                                                    , Rb.queueAutoDelete = False
                                                    , Rb.queueDurable = True
                                                    , Rb.queueHeaders = Rb.FieldTable (Map.singleton "x-queue-mode" $ Rb.FVString "lazy")
                                                    }
  Rb.bindQueue rch queue exchangeName routingExpr
  void $ Rb.consumeMsgs rch queue Rb.Ack safeGet
  where
    safeGet (msg, env) =
      catches
        (cfn msg >> Rb.ackEnv env)
        [ Handler $ \(e :: Rb.ChanThreadKilledException) -> throwM e
        , Handler $ \(e :: SomeException) -> do
            print e -- TODO do something better, but don't rethrow
            rejectFailed env
        ]
    -- Best effort: if the channel is already gone the reject fails too; just
    -- report it, since the channel will be restarted anyway.
    rejectFailed env =
      Rb.rejectEnv env False `catch` \(e :: SomeException) -> print e
------------------------------------------------------------------------------------------------------------------------------
------------------------------------------------------------------------------------------------------------------------------
-- Topic publisher
------------------------------------------------------------------------------------------------------------------------------
-- | Wrapper for an amqp routing key.
-- NOTE(review): not used anywhere in this module and not exported; either
-- remove it or use it in place of the raw 'Text' routing-key parameters.
newtype RoutingKey = RoutingKey Text deriving (Show, Eq)
-- | Declare the durable topic exchange @exchangeName@ and return an action
-- that publishes a persistent message to it with routing key @routingKey@.
--
-- The exchange is declared again whenever the channel restarts. Each publish
-- looks up the channel's current amqp channel.
createTopicPublisher :: Connection -> Text -> Text -> IO (BSL.ByteString -> IO ())
createTopicPublisher conn exchangeName routingKey =
  publishOn <$> withRestartableChannel conn declareOn
  where
    declareOn chan = do
      rch <- getRabbitChannel chan
      Rb.qos rch 0 1 True
      Rb.declareExchange rch Rb.newExchange
        { Rb.exchangeName = exchangeName
        , Rb.exchangeType = "topic"
        , Rb.exchangeDurable = True
        , Rb.exchangeAutoDelete = False
        }

    publishOn chan body = do
      rch <- getRabbitChannel chan
      void $ Rb.publishMsg rch exchangeName routingKey Rb.newMsg
        { Rb.msgBody = body
        , Rb.msgDeliveryMode = Just Rb.Persistent
        }
------------------------------------------------------------------------------------------------------------------------------
------------------------------------------------------------------------------------------------------------------------------
-- Fanout consumer
------------------------------------------------------------------------------------------------------------------------------
-- | Consume every message published to the fanout exchange @exchangeName@
-- through a private queue.
--
-- Declares the durable fanout exchange and a server-named, exclusive,
-- auto-delete, non-durable queue, binds the queue to the exchange, and
-- registers @cfn@ as the consumer callback. All of this setup runs again
-- whenever the channel is restarted.
--
-- A message is acked after @cfn@ returns. If @cfn@ (or the ack) throws, the
-- exception is printed and the message is rejected without requeueing.
-- Leaving it unacknowledged would, with the prefetch count of 1 set below,
-- stop the broker from delivering any further messages on this channel.
runFanoutConsumer :: Connection -> Text -> (Rb.Message -> IO ()) -> IO ()
runFanoutConsumer conn exchangeName cfn = void . withRestartableChannel conn $ \ch -> do
  let ex = Rb.newExchange { Rb.exchangeName = exchangeName
                          , Rb.exchangeType = "fanout"
                          , Rb.exchangeDurable = True
                          , Rb.exchangeAutoDelete = False
                          }
  rch <- getRabbitChannel ch
  -- Prefetch count 1: at most one unacknowledged message in flight
  Rb.qos rch 0 1 True
  Rb.declareExchange rch ex
  (queue, _, _) <- Rb.declareQueue rch Rb.newQueue { Rb.queueName = ""
                                                    , Rb.queueAutoDelete = True
                                                    , Rb.queueDurable = False
                                                    , Rb.queueExclusive = True
                                                    , Rb.queueHeaders = Rb.FieldTable (Map.singleton "x-queue-mode" $ Rb.FVString "lazy")
                                                    }
  Rb.bindQueue rch queue exchangeName ""
  void $ Rb.consumeMsgs rch queue Rb.Ack safeGet
  where
    safeGet (msg, env) =
      catches
        (cfn msg >> Rb.ackEnv env)
        [ Handler $ \(e :: Rb.ChanThreadKilledException) -> throwM e
        , Handler $ \(e :: SomeException) -> do
            print e -- TODO do something better, but don't rethrow
            rejectFailed env
        ]
    -- Best effort: if the channel is already gone the reject fails too; just
    -- report it, since the channel will be restarted anyway.
    rejectFailed env =
      Rb.rejectEnv env False `catch` \(e :: SomeException) -> print e
------------------------------------------------------------------------------------------------------------------------------
------------------------------------------------------------------------------------------------------------------------------
-- Fanout publisher
------------------------------------------------------------------------------------------------------------------------------
-- | Declare the durable fanout exchange @exchangeName@ and return an action
-- that publishes a persistent message to it (with an empty routing key).
--
-- The exchange is declared again whenever the channel restarts. Each publish
-- looks up the channel's current amqp channel.
createFanoutPublisher :: Connection -> Text -> IO (BSL.ByteString -> IO ())
createFanoutPublisher conn exchangeName =
  publishOn <$> withRestartableChannel conn declareOn
  where
    declareOn chan = do
      rch <- getRabbitChannel chan
      Rb.qos rch 0 1 True
      Rb.declareExchange rch Rb.newExchange
        { Rb.exchangeName = exchangeName
        , Rb.exchangeType = "fanout"
        , Rb.exchangeDurable = True
        , Rb.exchangeAutoDelete = False
        }

    publishOn chan body = do
      rch <- getRabbitChannel chan
      void $ Rb.publishMsg rch exchangeName "" Rb.newMsg
        { Rb.msgBody = body
        , Rb.msgDeliveryMode = Just Rb.Persistent
        }
------------------------------------------------------------------------------------------------------------------------------
------------------------------------------------------------------------------------------------------------------------------
-- RPC Handler
------------------------------------------------------------------------------------------------------------------------------
-- | Serve RPC requests that arrive on the queue named @rpcName@.
--
-- The non-durable, auto-delete request queue is declared and consumed with
-- 'Rb.NoAck'. This setup runs again whenever the channel restarts. For each
-- request, @cfn@ computes the reply body. The reply is published via the
-- default exchange to the request's reply-to queue, carrying the request's
-- correlation id. A request without a reply-to is still passed to @cfn@ but
-- gets no reply. Exceptions from @cfn@ or the publish are printed, not
-- rethrown.
handleRpc :: Connection -> Text -> (Rb.Message -> IO BSL.ByteString) -> IO ()
handleRpc conn rpcName cfn = void . withRestartableChannel conn $ \ch -> do
  rch <- getRabbitChannel ch
  Rb.qos rch 0 1 True
  (queue, _, _) <- Rb.declareQueue rch Rb.newQueue { Rb.queueName = rpcName
                                                    , Rb.queueAutoDelete = True
                                                    , Rb.queueDurable = False
                                                    }
  void $ Rb.consumeMsgs rch queue Rb.NoAck (safeReply rch)
  where
    safeReply rch (msg, _env) =
      catches
        (do
           resp <- cfn msg
           let reply = Rb.newMsg { Rb.msgCorrelationID = Rb.msgCorrelationID msg
                                 , Rb.msgBody = resp
                                 }
           -- Without a reply-to there is nowhere to send the answer.
           -- Publishing to a placeholder queue name would only misroute it.
           for_ (Rb.msgReplyTo msg) $ \replyTo ->
             void $ Rb.publishMsg rch "" replyTo reply
        )
        [ Handler $ \(e :: Rb.ChanThreadKilledException) -> throwM e
        , Handler $ \(e :: SomeException) -> print e -- TODO do something, but don't rethrow
        ]
------------------------------------------------------------------------------------------------------------------------------
-- | Outcome of an RPC call made with 'callRpc' or a caller from
-- 'createRpcCaller'.
data RpcResult
  = RpcOk BSL.ByteString
    -- ^ A reply arrived; this is its body.
  | RpcCallerTimeout
    -- ^ No reply arrived within the caller's timeout.
  | RpcHandlerTimeout
    -- ^ NOTE(review): never produced anywhere in this module.
  deriving (Show, Eq)
------------------------------------------------------------------------------------------------------------------------------
-- Call an RPC service.
-- See also createRpcCaller.
------------------------------------------------------------------------------------------------------------------------------
-- | Make a single RPC call to the queue named @rpcName@ and wait for the
-- reply.
--
-- A dedicated channel and a server-named, exclusive, auto-delete reply queue
-- are created for this call only. When the call finishes (reply, timeout or
-- exception), the channel is removed from the connection's registry, so it is
-- not reopened on reconnect, and then closed. For repeated calls prefer
-- 'createRpcCaller', which reuses one reply queue.
--
-- Returns 'RpcOk' with the reply body, or 'RpcCallerTimeout' if no reply
-- arrives within @timeout@. The timeout has microsecond resolution, so
-- sub-second values are honoured.
callRpc :: Connection -> NominalDiffTime -> Text -> BSL.ByteString -> IO RpcResult
callRpc conn timeout rpcName msg =
  Ex.bracket (mkChannel conn Nothing) release call
  where
    call ch' = do
      rch <- getRabbitChannel ch'
      Rb.qos rch 0 1 True
      (queue, _, _) <- Rb.declareQueue rch Rb.newQueue { Rb.queueName = ""
                                                        , Rb.queueAutoDelete = True
                                                        , Rb.queueDurable = False
                                                        , Rb.queueExclusive = True
                                                        }
      wait <- Mv.newEmptyMVar
      void $ Rb.consumeMsgs rch queue Rb.NoAck $ \(reply, _) ->
        void . Mv.tryPutMVar wait . RpcOk $ Rb.msgBody reply
      void $ Rb.publishMsg rch "" rpcName
        (Rb.newMsg { Rb.msgBody = msg
                   , Rb.msgDeliveryMode = Just Rb.NonPersistent
                   , Rb.msgReplyTo = Just queue
                   })
      -- Wait in this thread instead of forking a timer thread per call
      fromMaybe RpcCallerTimeout <$> Timeout.timeout timeoutMicros (Mv.takeMVar wait)

    -- Unregister first so a concurrent reconnect does not reopen the channel,
    -- then close it. Closing can fail if the connection is already gone; that
    -- is reported rather than rethrown.
    release ch' = do
      atomically $ Tv.modifyTVar' (conChannels conn) (Map.delete (chId ch'))
      rch <- getRabbitChannel ch'
      Rb.closeChannel rch `catch` \(e :: SomeException) -> print e

    -- Whole microseconds in @timeout@, never negative ('Timeout.timeout'
    -- would wait forever on a negative value)
    timeoutMicros = max 0 (fst (properFraction (timeout * 1_000_000)))
------------------------------------------------------------------------------------------------------------------------------
------------------------------------------------------------------------------------------------------------------------------
-- Create an RPC caller.
-- This is a long-running caller. It is similar to callRpc but more lightweight, because it
-- does not create a new queue for each call.
-- Use this if you are going to make multiple RPC calls; use callRpc for simple ad-hoc calls.
------------------------------------------------------------------------------------------------------------------------------
-- | Create a reusable RPC caller for the queue named @rpcName@.
--
-- One channel consumes a server-named, exclusive, auto-delete reply queue.
-- That channel is set up again whenever it restarts, which gives the reply
-- queue a new name. A second channel is used for publishing requests. Each
-- call of the returned function publishes @request@ with a fresh random
-- correlation id and waits for the reply carrying that id.
--
-- Returns 'RpcOk' with the reply body, or 'RpcCallerTimeout' if no reply
-- arrives within @timeout@ (microsecond resolution). A call's wait handle is
-- always removed afterwards, so timed-out calls do not accumulate.
createRpcCaller :: Connection -> NominalDiffTime -> Text -> IO (BSL.ByteString -> IO RpcResult)
createRpcCaller conn timeout rpcName = do
  -- Wait handles of in-flight calls, keyed by correlation id
  waits' <- Tv.newTVarIO mempty
  -- Name of the current reply queue; empty until the consumer channel is set up
  queueName' <- Mv.newEmptyMVar
  void . withRestartableChannel conn $ \ch' -> do
    rch <- getRabbitChannel ch'
    Rb.qos rch 0 1 True
    -- Create the reply queue (the server picks its name)
    (queue, _, _) <- Rb.declareQueue rch Rb.newQueue { Rb.queueName = ""
                                                      , Rb.queueAutoDelete = True
                                                      , Rb.queueDurable = False
                                                      , Rb.queueExclusive = True
                                                      }
    -- Save the queue name, this will change on reconnect
    void $ Mv.tryTakeMVar queueName'
    Mv.putMVar queueName' queue
    -- Hand each reply to the caller waiting on its correlation id
    void $ Rb.consumeMsgs rch queue Rb.NoAck $ \(reply, _) -> do
      let corrId = fromMaybe "?" . Rb.msgCorrelationID $ reply
      ackCorrelationId waits' corrId . RpcOk $ Rb.msgBody reply
  -- Channel used to publish requests
  chSend' <- mkChannel conn Nothing
  let
    -- Publish one request and block until its reply arrives or the timeout expires
    sendAndWait corrId wait request = do
      -- Get the current reply queue name
      queueName <- Mv.readMVar queueName'
      -- Look the amqp channel up on every call. It is replaced when the
      -- channel (or the connection) restarts, so a handle captured once
      -- would go stale.
      rchSend <- getRabbitChannel chSend'
      void $ Rb.publishMsg rchSend "" rpcName
        (Rb.newMsg { Rb.msgBody = request
                   , Rb.msgDeliveryMode = Just Rb.NonPersistent
                   , Rb.msgReplyTo = Just queueName
                   , Rb.msgCorrelationID = Just corrId
                   })
      -- Wait in this thread instead of forking a timer thread per call
      fromMaybe RpcCallerTimeout <$> Timeout.timeout timeoutMicros (Mv.takeMVar wait)
  pure $ \request -> do
    -- New correlation id and wait handle for this call
    corrId <- show <$> UU.nextRandom
    wait <- Mv.newEmptyMVar
    atomically $ Tv.modifyTVar' waits' (Map.insert corrId wait)
    -- Remove the wait handle whatever happens; a reply has already removed it,
    -- but a timeout or exception would otherwise leave it in the map forever.
    sendAndWait corrId wait request
      `Ex.finally` atomically (Tv.modifyTVar' waits' (Map.delete corrId))
  where
    -- Whole microseconds in @timeout@, never negative ('Timeout.timeout'
    -- would wait forever on a negative value)
    timeoutMicros = max 0 (fst (properFraction (timeout * 1_000_000)))

    -- Deliver @result@ to the caller waiting on @corrId@ and remove its handle
    ackCorrelationId waits' corrId result = do
      waitRes <- atomically . Tv.stateTVar waits' $ \waits ->
        case Map.lookup corrId waits of
          Nothing -> (Left $ "unknown correlation Id: " <> corrId, waits)
          Just w -> (Right w, Map.delete corrId waits)
      case waitRes of
        Right wait -> void . Mv.tryPutMVar wait $ result
        Left e -> print e --TODO
------------------------------------------------------------------------------------------------------------------------------
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment