Created
January 1, 2020 00:17
-
-
Save lylek/f30f6bd4d7a6898511cfa60e30fe71c4 to your computer and use it in GitHub Desktop.
My attempt at making a peer-to-peer version of the distributed chat server from the book "Parallel and Concurrent Programming in Haskell".
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE TemplateHaskell, DeriveDataTypeable, DeriveGeneric #-} | |
{-# LANGUAGE RecordWildCards #-} | |
{-# LANGUAGE LambdaCase #-} | |
import Control.Concurrent
import Control.Exception
import Control.Monad
import Control.Monad.IO.Class
import qualified Data.Foldable as F
import Data.Typeable
import GHC.Generics (Generic)
import System.Environment (getArgs)
import System.Exit (exitFailure)
import System.IO
import Text.Printf
import Text.Read (readMaybe)

import Control.Concurrent.Async
import Control.Concurrent.STM
import Control.Distributed.Process
  hiding (Message, mask, finally, handleMessage, proxy)
import Control.Distributed.Process.Backend.SimpleLocalnet (Backend, findPeers, initializeBackend, newLocalNode)
import Control.Distributed.Process.Closure
import Control.Distributed.Process.Node (initRemoteTable, runProcess)
import Data.Binary
import Data.Map (Map)
import qualified Data.Map as Map
import Network

import ConcurrentUtils
-- --------------------------------------------------------------------------- | |
-- Data structures and initialisation | |
-- <<Client | |
-- | The name a user picks when connecting; it keys the server's client map.
type ClientName = String

-- | A chat participant: either connected to this server ('ClientLocal') or
-- to a peer server ('ClientRemote').
data Client
  = ClientLocal LocalClient
  | ClientRemote RemoteClient

-- | A client connected to some other chat server.
data RemoteClient = RemoteClient
  { remoteName :: ClientName
  , clientHome :: ProcessId
    -- ^ process id of the chat server the client is connected to
  }

-- | A client connected to this server through a socket handle.
data LocalClient = LocalClient
  { localName      :: ClientName
  , clientHandle   :: Handle
    -- ^ handle the client's lines are read from and written to
  , clientKicked   :: TVar (Maybe String)
    -- ^ becomes @Just reason@ once the client has been kicked
  , clientSendChan :: TChan Message
    -- ^ messages waiting to be processed by the client's thread
  }
-- | The name of a client, wherever it is connected.
clientName :: Client -> ClientName
clientName client = case client of
  ClientLocal local   -> localName local
  ClientRemote remote -> remoteName remote
-- | Build a 'LocalClient' for the given name and socket handle, with an empty
-- send channel and not yet kicked.
newLocalClient :: ClientName -> Handle -> STM LocalClient
newLocalClient name handle = do
  sendChan <- newTChan
  kicked   <- newTVar Nothing
  -- Positional fields: localName, clientHandle, clientKicked, clientSendChan.
  pure (LocalClient name handle kicked sendChan)
-- | Names of the clients connected directly to this server, in ascending
-- name order (the order of 'Map.elems').
getLocalClientNames :: Server -> STM [ClientName]
getLocalClientNames server = localNames <$> readTVar (clients server)
  where
    localNames clientmap = [localName c | ClientLocal c <- Map.elems clientmap]
-- >> | |
-- <<Message | |
-- | Something placed on a local client's send channel.
--
-- 'Notice', 'Tell' and 'Broadcast' are printed to the client by
-- 'handleMessage'; 'Command' carries a raw line the client typed.
data Message
  = Notice String                -- ^ system notice
  | Tell ClientName String       -- ^ private message from the named client
  | Broadcast ClientName String  -- ^ public message from the named client
  | Command String               -- ^ a line of input from the client itself
  deriving (Typeable, Generic, Show)

-- | Generic-derived serialisation, so messages can be carried between servers
-- inside 'PMessage's.
instance Binary Message
-- >> | |
-- <<PMessage | |
-- | Messages exchanged between peer chat servers.
data PMessage
  = MsgServerInfo Bool ProcessId [ClientName]
    -- ^ introduce a server: whether it just joined (and so expects an
    --   introduction back), its process id, and its local client names
  | MsgSend ClientName Message
    -- ^ deliver a message to the named client
  | MsgBroadcast Message
    -- ^ deliver a message to every client local to the receiver
  | MsgKick ClientName ClientName
    -- ^ kick the first client; the second names who asked for it
  | MsgNewClient ClientName ProcessId
    -- ^ a client connected to the given server
  | MsgClientDisconnected ClientName ProcessId
    -- ^ a client left the given server
  deriving (Typeable, Generic, Show)

-- | Generic-derived serialisation, required to 'send' these between nodes.
instance Binary PMessage
-- >> | |
-- <<Server | |
-- | Shared state of the chat server running on this node.
data Server = Server
  { clients   :: TVar (Map ClientName Client)
    -- ^ every known client, local or remote, keyed by name
  , proxychan :: TChan (Process ())
    -- ^ 'Process' actions queued from STM code, executed by 'proxy'
  , servers   :: TVar [ProcessId]
    -- ^ known peer chat servers
  , spid      :: ProcessId
    -- ^ this server's own process id
  }
-- | Create server state for the calling process, starting with the given peer
-- list, no clients and an empty proxy queue.
newServer :: [ProcessId] -> Process Server
newServer pids = do
  self <- getSelfPid
  liftIO $ do
    peers  <- newTVarIO pids
    roster <- newTVarIO Map.empty
    outbox <- newTChanIO
    return Server
      { clients   = roster
      , proxychan = outbox
      , servers   = peers
      , spid      = self
      }
-- >> | |
-- ----------------------------------------------------------------------------- | |
-- Basic operations | |
-- <<sendLocal | |
-- | Queue a message on a local client's send channel; the client's thread
-- (see 'runClient') picks it up.
sendLocal :: LocalClient -> Message -> STM ()
sendLocal = writeTChan . clientSendChan
-- >> | |
-- <<sendRemote
-- | Arrange for @pmsg@ to be sent to the server @pid@.
--
-- A 'send' cannot run inside STM, so the send action is pushed onto
-- 'proxychan' and executed later by 'proxy'.
sendRemote :: Server -> ProcessId -> PMessage -> STM ()
sendRemote server pid = writeTChan (proxychan server) . send pid
-- >> | |
-- <<sendMessage | |
-- | Deliver a message to a client: directly if it is local, otherwise via a
-- 'MsgSend' to the client's home server.
sendMessage :: Server -> Client -> Message -> STM ()
sendMessage server target msg = case target of
  ClientLocal local ->
    sendLocal local msg
  ClientRemote remote ->
    sendRemote server (clientHome remote) (MsgSend (remoteName remote) msg)
-- >> | |
-- <<sendToName | |
-- | Deliver a message to the client with the given name.
--
-- Returns 'False' (and sends nothing) when no such client is known.
sendToName :: Server -> ClientName -> Message -> STM Bool
sendToName server name msg = do
  found <- Map.lookup name <$> readTVar (clients server)
  case found of
    Just target -> True <$ sendMessage server target msg
    Nothing     -> return False
-- >> | |
-- <<sendRemoteAll | |
-- | Queue the same message for every known peer server.
sendRemoteAll :: Server -> PMessage -> STM ()
sendRemoteAll server pmsg = do
  peers <- readTVar (servers server)
  forM_ peers $ \peer -> sendRemote server peer pmsg
-- >> | |
-- <<broadcastLocal | |
-- | Deliver a message to every client connected directly to this server;
-- remote clients are skipped.
broadcastLocal :: Server -> Message -> STM ()
broadcastLocal server msg = do
  clientmap <- readTVar (clients server)
  forM_ [local | ClientLocal local <- Map.elems clientmap] $ \local ->
    sendLocal local msg
-- >> | |
-- <<broadcast | |
-- | Deliver a message to every client on the network: each peer server gets a
-- 'MsgBroadcast', and local clients get the message directly.
broadcast :: Server -> Message -> STM ()
broadcast server msg =
  sendRemoteAll server (MsgBroadcast msg) >> broadcastLocal server msg
-- >> | |
-- <<tell | |
-- | Send a private message from a local client to the named client, telling
-- the sender on its own handle if the recipient is unknown.
tell :: Server -> LocalClient -> ClientName -> String -> IO ()
tell server LocalClient{..} who msg = do
  delivered <- atomically $ sendToName server who (Tell localName msg)
  unless delivered $
    hPutStrLn clientHandle (who ++ " is not connected.")
-- >> | |
-- <<kick | |
-- | Kick client @who@ on behalf of client @by@.
--
-- A local victim has its 'clientKicked' flag set and the kicker is notified;
-- a remote victim is handled by forwarding a 'MsgKick' to its home server.
-- If @who@ is unknown the kicker is told so.
kick :: Server -> ClientName -> ClientName -> STM ()
kick server who by = do
  target <- Map.lookup who <$> readTVar (clients server)
  case target of
    Nothing ->
      tellKicker (who ++ " is not connected")
    Just (ClientLocal victim) -> do
      writeTVar (clientKicked victim) (Just ("by " ++ by))
      tellKicker ("you kicked " ++ who)
    Just (ClientRemote victim) ->
      sendRemote server (clientHome victim) (MsgKick who by)
  where
    -- The kicker may itself be unknown (e.g. "SYSTEM"); that is ignored.
    tellKicker = void . sendToName server by . Notice
-- >> | |
-- ----------------------------------------------------------------------------- | |
-- Handle a local client | |
-- | Serve one freshly accepted socket: ask for a unique name, register the
-- client locally and with peer servers, then run it until it leaves.
--
-- Registration runs under 'mask' so that once the client has been added,
-- 'disconnectLocalClient' is guaranteed to run when the session ends.
talk :: Server -> Handle -> IO ()
talk server handle = do
  -- Universal newline mode swallows the carriage returns telnet clients send.
  hSetNewlineMode handle universalNewlineMode
  hSetBuffering handle LineBuffering
  askName
  where
    -- Prompt until a non-empty name is entered.
    askName = do
      hPutStrLn handle "What is your name?"
      name <- hGetLine handle
      if null name then askName else claimName name

    -- Try to register the name; on a clash, ask again.
    claimName name = mask $ \restore -> do
      client <- atomically $ newLocalClient name handle
      accepted <- atomically $ checkAddClient server (ClientLocal client)
      if accepted
        then do
          atomically $ sendRemoteAll server (MsgNewClient name (spid server))
          restore (runClient server client)
            `finally` disconnectLocalClient server name
        else restore $ do
          hPrintf handle
            "The name %s is in use, please choose another\n" name
          askName
-- >> | |
-- | Add a client under its name unless that name is already taken.
--
-- On success, local clients are told that the newcomer has connected and
-- 'True' is returned; otherwise nothing changes and 'False' is returned.
checkAddClient :: Server -> Client -> STM Bool
checkAddClient server client = do
  let name = clientName client
  clientmap <- readTVar (clients server)
  if name `Map.notMember` clientmap
    then do
      writeTVar (clients server) (Map.insert name client clientmap)
      broadcastLocal server (Notice (name ++ " has connected"))
      return True
    else return False
-- | Forget a client by name and tell the local clients that it has left.
deleteClient :: Server -> ClientName -> STM ()
deleteClient server name = do
  modifyTVar' (clients server) (Map.delete name)
  broadcastLocal server (Notice (name ++ " has disconnected"))
-- | Remove a local client and, in the same transaction, tell every peer
-- server that it has disconnected from this one.
disconnectLocalClient :: Server -> ClientName -> IO ()
disconnectLocalClient server name =
  atomically $
    deleteClient server name
      >> sendRemoteAll server (MsgClientDisconnected name (spid server))
-- <<runClient | |
-- | Run a connected local client until it quits, is kicked, or its socket
-- fails.
--
-- Two threads race: one turns each input line into a 'Command' on the
-- client's own send channel, the other processes that channel with
-- 'handleMessage'. Whichever finishes first cancels the other.
runClient :: Server -> LocalClient -> IO ()
runClient serv client@LocalClient{..} = void (race serverLoop receiveLoop)
  where
    receiveLoop = forever $ do
      line <- hGetLine clientHandle
      atomically $ sendLocal client (Command line)

    -- The kicked flag is checked before each message so a kick takes effect
    -- even when messages are queued.
    serverLoop = join $ atomically $ do
      kicked <- readTVar clientKicked
      case kicked of
        Just reason -> return (announceKick reason)
        Nothing     -> processNext <$> readTChan clientSendChan

    announceKick reason =
      hPutStrLn clientHandle $ "You have been kicked: " ++ reason

    processNext msg = do
      keepGoing <- handleMessage serv client msg
      when keepGoing serverLoop
-- >> | |
-- <<handleMessage | |
-- | Act on one message from a local client's send channel.
--
-- Incoming chat is printed to the client; 'Command' lines are interpreted
-- (@/kick@, @/tell@, @/quit@, otherwise a broadcast). Returns 'False' only
-- for @/quit@, telling the caller to stop serving the client.
handleMessage :: Server -> LocalClient -> Message -> IO Bool
handleMessage server client@LocalClient{..} message =
  case message of
    Notice msg         -> display ("*** " ++ msg)
    Tell name msg      -> display ("*" ++ name ++ "*: " ++ msg)
    Broadcast name msg -> display ("<" ++ name ++ ">: " ++ msg)
    Command msg        -> runCommand msg
  where
    andContinue action = action >> return True
    display = andContinue . hPutStrLn clientHandle

    runCommand msg = case words msg of
      ["/kick", who] ->
        andContinue $ atomically $ kick server who localName
      "/tell" : who : what ->
        andContinue $ tell server client who (unwords what)
      ["/quit"] ->
        return False
      ('/' : _) : _ ->
        andContinue $ hPutStrLn clientHandle ("Unrecognised command: " ++ msg)
      _ ->
        andContinue $ atomically $ broadcast server (Broadcast localName msg)
-- >> | |
-- ----------------------------------------------------------------------------- | |
-- Main server | |
-- <<socketListener | |
-- | Accept TCP connections on @port@ forever, serving each in its own thread
-- with 'talk' and closing the handle when that thread ends.
socketListener :: Server -> Int -> IO ()
socketListener server port = withSocketsDo $ do
  listener <- listenOn (PortNumber (fromIntegral port))
  printf "Listening on port %d\n" port
  forever $ do
    (conn, peerHost, peerPort) <- accept listener
    printf "Accepted connection from %s: %s\n" peerHost (show peerPort)
    forkFinally (talk server conn) (\_ -> hClose conn)
-- >> | |
-- <<proxy | |
-- | Execute, in order and forever, the 'Process' actions queued on
-- 'proxychan'; this is how STM code gets distributed-process sends done.
proxy :: Server -> Process ()
proxy server = forever $ do
  action <- liftIO $ atomically $ readTChan (proxychan server)
  action
-- >> | |
-- <<chatServer | |
-- | Run the chat server in the calling process.
--
-- Starts the TCP listener for local clients on @port@ in an IO thread and
-- the 'proxy' in a local process, then handles peer-server traffic forever.
chatServer :: Int -> Process ()
chatServer port = do
  server <- newServer []
  let handlers =
        [ match (handleWhereIsReply server)
        , match (handleRemoteMessage server)
        , match (handleProcessMonitorNotification server)
        ]
  _ <- liftIO $ forkIO $ socketListener server port
  _ <- spawnLocal $ proxy server
  forever $ receiveWait handlers
-- >> | |
-- | When a peer node reports where its chat server runs, introduce ourselves
-- to it with a 'MsgServerInfo' marked as just-joined, listing our local
-- clients. Replies for other labels, or with no process, are ignored.
--
-- The label is a literal because 'registeredServerName' is defined after the
-- @remotable@ Template Haskell splice and is therefore not in scope here.
handleWhereIsReply :: Server -> WhereIsReply -> Process ()
handleWhereIsReply server (WhereIsReply label (Just peer))
  | label == "chatServer" = do
      names <- liftIO $ atomically $ getLocalClientNames server
      send peer $ MsgServerInfo True (spid server) names
handleWhereIsReply _ _ = return ()
-- | Handle the death of a monitored peer chat server.
--
-- Peers are monitored once first learned about (the 'MsgServerInfo' case of
-- 'handleRemoteMessage'). When one dies it is dropped from the peer list, so
-- 'sendRemoteAll' stops queueing messages for it. Every remote client whose
-- home was that server is removed, and local clients are told each one has
-- disconnected, matching what 'deleteClient' does for a single departure.
handleProcessMonitorNotification :: Server -> ProcessMonitorNotification -> Process ()
handleProcessMonitorNotification
    server@Server{..}
    (ProcessMonitorNotification monitorRef deadpid _diedReason) = do
  say $ "Got a notification that process " ++ show deadpid ++ " died"
  unmonitor monitorRef
  -- Why it died doesn't matter: the server and all its clients are gone.
  liftIO $ atomically $ do
    modifyTVar' servers (filter (/= deadpid))
    clientmap <- readTVar clients
    let (orphans, survivors) = Map.partition (homedAt deadpid) clientmap
    writeTVar clients survivors
    forM_ (Map.keys orphans) $ \name ->
      broadcastLocal server $ Notice $ name ++ " has disconnected"
  where
    -- Does this client live on the given server?
    homedAt pid (ClientRemote (RemoteClient _ home)) = home == pid
    homedAt _ (ClientLocal _) = False
-- <<handleRemoteMessage | |
-- | Act on a message from a peer chat server.
--
-- 'MsgServerInfo' from a not-yet-known server records it and its clients,
-- answers with our own info if it just joined, and starts monitoring it.
-- The other messages are applied to local state in a single transaction.
handleRemoteMessage :: Server -> PMessage -> Process ()
handleRemoteMessage server@Server{..} pmsg =
  case pmsg of
    MsgServerInfo justJoined rpid rclients -> do
      fresh <- inSTM $ registerPeer justJoined rpid rclients
      when fresh $ do
        void $ monitor rpid
        say $ "Starting to monitor process " ++ show rpid
    MsgSend name msg ->
      inSTM $ void $ sendToName server name msg
    MsgBroadcast msg ->
      inSTM $ broadcastLocal server msg
    MsgKick who by ->
      inSTM $ kick server who by
    MsgNewClient name pid ->
      inSTM $ addRemoteClient pid name
    MsgClientDisconnected name pid ->
      inSTM $ dropRemoteClient name pid
  where
    inSTM :: STM a -> Process a
    inSTM = liftIO . atomically

    -- Record a new peer and its clients; False if the peer was already known.
    registerPeer justJoined rpid rclients = do
      known <- readTVar servers
      if rpid `elem` known
        then return False
        else do
          writeTVar servers (rpid : known)
          mapM_ (addRemoteClient rpid) rclients
          when justJoined $ do
            names <- getLocalClientNames server
            sendRemote server rpid (MsgServerInfo False spid names)
          return True

    -- Add a remote client, kicking it on its home server on a name clash.
    -- TODO: Should this really kick a new remote client if it doesn't
    -- conflict with a local client? Maybe let the server that has that
    -- client do the kicking.
    addRemoteClient pid name = do
      added <- checkAddClient server (ClientRemote (RemoteClient name pid))
      unless added $ sendRemote server pid (MsgKick name "SYSTEM")

    -- Remove a client only if we still think it lives on the reporting server.
    dropRemoteClient name pid = do
      entry <- Map.lookup name <$> readTVar clients
      case entry of
        Just (ClientRemote (RemoteClient _ home)) | home == pid ->
          deleteClient server name
        _ -> return ()
-- >> | |
-- Generates Main.__remoteTable (used in main); no functions are made remotable.
$(remotable [])
-- | Label under which 'master' registers this node's chat server, and which it
-- looks up on peer nodes with 'whereisRemoteAsync'.
registeredServerName :: String
registeredServerName = "chatServer"
-- A pattern synonym would let handleWhereIsReply match on this name, but that
-- function sits before the remotable splice and cannot see it:
--pattern RegisteredServerName :: String
--pattern RegisteredServerName = registeredServerName
-- <<main | |
-- | Main process of a node.
--
-- Validates the chat port, discovers peer nodes for five seconds, registers
-- this process as the node's chat server under 'registeredServerName', asks
-- each peer where its chat server is (replies arrive as 'WhereIsReply'
-- messages, handled by 'handleWhereIsReply'), then runs 'chatServer'.
--
-- An unparsable or out-of-range chat port is reported on stderr and the
-- process fails immediately. Before, 'read' would fail lazily inside the
-- forked listener thread, leaving a server with no listener.
master :: Backend -> String -> Process ()
master backend chat_port_str =
  case parsePort chat_port_str of
    Nothing -> do
      let reason = "invalid chat port: " ++ show chat_port_str
      liftIO $ hPutStrLn stderr reason
      liftIO $ ioError $ userError reason
    Just chat_port -> do
      nodes <- liftIO $ findPeers backend peerDiscoveryTimeout
      mynodeid <- getSelfNode
      let peers = filter (/= mynodeid) nodes
      mypid <- getSelfPid
      register registeredServerName mypid
      forM_ peers $ \node -> whereisRemoteAsync node registeredServerName
      chatServer chat_port
  where
    -- findPeers' timeout is in microseconds: five seconds, not the 5 ms that
    -- a bare 5000 amounts to.
    peerDiscoveryTimeout :: Int
    peerDiscoveryTimeout = 5 * 1000 * 1000

    -- Accept only valid TCP ports; larger values would silently wrap when
    -- converted to a PortNumber in socketListener.
    parsePort :: String -> Maybe Int
    parsePort s = case readMaybe s of
      Just p | p >= 1 && p <= 65535 -> Just p
      _ -> Nothing
-- | Command-line entry point: @<program> <node port> <chat port>@.
--
-- The node port is used by the distributed-process backend on localhost; the
-- chat port is where chat clients connect. Wrong argument counts print a
-- usage message and exit with failure instead of a pattern-match error.
main :: IO ()
main = do
  args <- getArgs
  case args of
    [port, chat_port] -> do
      backend <- initializeBackend "localhost" port
                   (Main.__remoteTable initRemoteTable)
      node <- newLocalNode backend
      runProcess node (master backend chat_port)
    _ -> do
      hPutStrLn stderr "usage: <program> <node port> <chat port>"
      exitFailure
-- >> |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment