Skip to content

Instantly share code, notes, and snippets.

@rolph-recto
Created December 26, 2017 07:14
Show Gist options
  • Save rolph-recto/5e8ca67f701675103d651795a0c70a89 to your computer and use it in GitHub Desktop.
Save rolph-recto/5e8ca67f701675103d651795a0c70a89 to your computer and use it in GitHub Desktop.
Go embedded in Haskell
{-# LANGUAGE DeriveFunctor #-}
import Control.Monad
import Control.Monad.Except
import Control.Monad.Free
import Control.Monad.State
import Control.Monad.Writer
import Data.List (nub, partition)
import qualified Data.Map.Strict as M
-- | A value that can be sent over a channel.
data GoExpr
  = GoInt Int
  | GoBool Bool
  | GoString String
  deriving (Show, Eq)
-- | Haskell types whose values can be embedded as a 'GoExpr'.
class GoExprable a where
  -- | Convert a Haskell value to the channel value representation.
  toGoExpr :: a -> GoExpr
-- | Integers travel as 'GoInt'.
instance GoExprable Int where
  toGoExpr n = GoInt n
-- | Booleans travel as 'GoBool'.
instance GoExprable Bool where
  toGoExpr b = GoBool b
-- | A 'GoExpr' is already a channel value and is passed through unchanged.
instance GoExprable GoExpr where
  toGoExpr e = e
-- | Goroutine identifier.
type GoRoutine = Int

-- | A channel port; the interpreter stores at most one pending value per port.
type ChanPort = Int

-- | A channel as an @(in, out)@ pair of ports: values are put through the
-- first port and taken from the second. Letting the two differ is what allows
-- channel combinators to splice one channel's input onto another's output.
type GoChan = (ChanPort, ChanPort)

-- | A Go-style program: the free monad over 'GoCmd'.
type GoProgram a = Free GoCmd a
-- | One arm of a select: the channel to receive from, and the program to run
-- with the received value.
data GoSelectBranch = GoSelectBranch
  { branchChan :: GoChan
  , branchProg :: GoExpr -> GoProgram ()
  }
-- | Primitive instructions of a Go-style program. The last field of each
-- constructor is the continuation (the rest of the program).
data GoCmd next
  = GoRun (GoProgram ()) next          -- ^ spawn a new goroutine
  | GoMakeChan (GoChan -> next)        -- ^ allocate a fresh channel
  | GoPutChan GoChan GoExpr next       -- ^ send a value on a channel
  | GoGetChan GoChan (GoExpr -> next)  -- ^ receive a value from a channel
  | GoSelect [GoSelectBranch] next     -- ^ receive from whichever branch is ready
  | GoPrint String next                -- ^ print a line to stdout
  deriving (Functor)
-- | Lift a single functor layer into the free monad, with 'Pure' as the
-- continuation.
liftFree :: Functor f => f a -> Free f a
liftFree = Free . fmap Pure
-- | Port number used to mark a select's default branch. Real channels are
-- numbered from 1 by the runtime, so this never names a channel.
defaultPort :: ChanPort
defaultPort = 0

-- | A select branch that receives from @c@ and runs @p@ on the value.
selcase :: GoChan -> (GoExpr -> GoProgram ()) -> GoSelectBranch
selcase = GoSelectBranch

-- | A select default branch. Its port is 'defaultPort', which the interpreter
-- treats as always ready.
seldefault :: (GoExpr -> GoProgram ()) -> GoSelectBranch
seldefault = GoSelectBranch (defaultPort, defaultPort)
-- | Spawn @cmd@ as a new goroutine.
go :: GoProgram () -> GoProgram ()
go cmd = liftFree $ GoRun cmd ()

-- | Allocate a fresh channel.
newchan :: GoProgram GoChan
newchan = liftFree $ GoMakeChan id

-- | Send a value on a channel (blocks while the channel already holds one).
putchan :: GoExprable a => GoChan -> a -> GoProgram ()
putchan c v = liftFree $ GoPutChan c (toGoExpr v) ()

-- | Receive a value from a channel (blocks while the channel is empty).
getchan :: GoChan -> GoProgram GoExpr
getchan c = liftFree $ GoGetChan c id

-- | Receive from whichever of the given branches is ready.
select :: [GoSelectBranch] -> GoProgram ()
select bs = liftFree $ GoSelect bs ()

-- | Print a line to stdout.
goprint :: String -> GoProgram ()
goprint s = liftFree $ GoPrint s ()
-- | Scheduler and channel state of the interpreter.
data GoRuntime = GoRuntime
  { curGoroutine  :: GoRoutine                       -- ^ goroutine currently executing
  , nextGoroutine :: Int                             -- ^ id handed to the next spawned goroutine
  , goroutines    :: M.Map GoRoutine (GoProgram ())  -- ^ program to resume per goroutine (set on spawn, replaced on block)
  , nextPort      :: ChanPort                        -- ^ next fresh port number
  , portVals      :: M.Map ChanPort (Maybe GoExpr)   -- ^ pending value per port, if any
  , waitQueues    :: M.Map ChanPort [GoRoutine]      -- ^ goroutines blocked on each port, in arrival order
  , readyQueue    :: [GoRoutine]                     -- ^ runnable goroutines, front runs next
  }

-- | Interpreter monad: 'String' errors on top of 'GoRuntime' state over 'IO'.
type GoInterp = ExceptT String (StateT GoRuntime IO)
-- | Resume the goroutine at the front of the ready queue, making it current.
-- When nothing is ready the interpreter simply returns; goroutines still
-- parked in wait queues are never resumed.
scheduleNextGoroutine :: GoInterp ()
scheduleNextGoroutine = do
  rt <- get
  case readyQueue rt of
    [] -> return ()
    next : rest ->
      case M.lookup next (goroutines rt) of
        Nothing -> throwError $ "invalid goroutine id " ++ show next
        Just resume -> do
          put $ rt { curGoroutine = next, readyQueue = rest }
          interpGo resume
-- | Make every goroutine waiting on @port@ runnable again, in arrival order.
--
-- A blocked goroutine is parked together with the very command it blocked on,
-- so when woken it re-executes that command and re-blocks if the channel is
-- still not ready. Waking only the first waiter is not enough. That waiter may
-- be waiting for the opposite condition (e.g. a second putter woken by a put)
-- and immediately re-block. Another waiter that could proceed would then stay
-- asleep forever, causing a lost-wakeup deadlock. Waking all of them avoids this.
--
-- Because of select, a goroutine may sit in several wait queues (or several
-- times in one). It is removed from all of them and enqueued exactly once.
--
-- Throws if @port@ was never created.
wakeFromWaitQueue :: ChanPort -> GoInterp ()
wakeFromWaitQueue port = do
  st <- get
  let wqs = waitQueues st
  case M.lookup port wqs of
    Nothing -> throwError $ "unknown port " ++ show port
    -- nobody to wake!
    Just [] -> return ()
    Just waiters -> do
      let woken = nub waiters
          wqs' = M.map (filter (`notElem` woken)) wqs
      put $ st { readyQueue = readyQueue st ++ woken, waitQueues = wqs' }
-- | Park goroutine @rid@ at the back of @port@'s wait queue and record @cont@
-- as the program it resumes with once woken. Throws if the port is unknown.
addToWaitQueue :: ChanPort -> GoRoutine -> GoProgram () -> GoInterp ()
addToWaitQueue port rid cont = do
  st <- get
  case M.lookup port (waitQueues st) of
    Nothing -> throwError $ "unknown port " ++ show port
    Just queue ->
      put $ st { waitQueues = M.insert port (queue ++ [rid]) (waitQueues st)
               , goroutines = M.insert rid cont (goroutines st)
               }
-- | Interpret a goroutine's program until it blocks or finishes, then hand
-- control to the next goroutine in the ready queue.
--
-- Blocking is cooperative. A goroutine that cannot proceed is parked in a
-- port's wait queue together with the very command it blocked on (@prog@),
-- so once woken it simply re-executes that command.
--
-- Throws (via 'throwError') when a command names a port that does not exist.
interpGo :: GoProgram () -> GoInterp ()
interpGo prog = case prog of
  -- create a new goroutine and add it to the back of the ready queue
  Free (GoRun goroutine next) -> do
    st <- get
    let rnum = nextGoroutine st
    put $ st { nextGoroutine = rnum + 1
             , goroutines = M.insert rnum goroutine (goroutines st)
             , readyQueue = readyQueue st ++ [rnum]
             }
    interpGo next
  -- create a new channel; both of its ends start out as the same fresh port
  Free (GoMakeChan next) -> do
    st <- get
    let pnum = nextPort st
    put $ st { nextPort = pnum + 1
             , portVals = M.insert pnum Nothing (portVals st)
             , waitQueues = M.insert pnum [] (waitQueues st)
             }
    interpGo $ next (pnum, pnum)
  -- put a value through the channel's input port; block if it is full
  Free (GoPutChan (port, _) v next) -> do
    st <- get
    case M.lookup port (portVals st) of
      Just Nothing -> do
        put $ st { portVals = M.insert port (Just v) (portVals st) }
        wakeFromWaitQueue port
        interpGo next
      Just (Just _) -> block port
      Nothing -> throwError $ "Channel " ++ show port ++ " not found!"
  -- take a value from the channel's output port; block if it is empty
  Free (GoGetChan (_, port) next) -> do
    st <- get
    case M.lookup port (portVals st) of
      Just Nothing -> block port
      Just (Just cval) -> do
        put $ st { portVals = M.insert port Nothing (portVals st) }
        wakeFromWaitQueue port
        interpGo $ next cval
      Nothing -> throwError $ "Channel " ++ show port ++ " not found!"
  -- receive from the first ready channel branch; if none is ready, run a
  -- default branch when there is one, otherwise block on every branch's port
  Free (GoSelect branches next) -> do
    let (defaults, chanBranches) = partition isDefault branches
    ready <- filterM isBranchReady chanBranches
    case (ready, defaults) of
      -- a channel branch is ready: receive from it and run its body. The rest
      -- of the program is appended after the body, otherwise it would never
      -- be executed.
      (b : _, _) ->
        interpGo $ Free (GoGetChan (branchChan b) (\v -> branchProg b v >> next))
      -- nothing is ready but there is a default branch: run it without
      -- reading any channel ('defaultPort' is never allocated as a channel)
      ([], d : _) -> interpGo $ branchProg d noValue >> next
      -- every branch would block: wait on all of their output ports
      ([], []) -> do
        rid <- gets curGoroutine
        forM_ chanBranches $ \b -> addToWaitQueue (snd (branchChan b)) rid prog
        scheduleNextGoroutine
  Free (GoPrint e next) -> do
    liftIO $ putStrLn e
    interpGo next
  -- goroutine is finished: run next one
  Pure _ -> scheduleNextGoroutine
  where
    -- park the current goroutine on @port@; it re-runs @prog@ when woken
    block :: ChanPort -> GoInterp ()
    block port = do
      rid <- gets curGoroutine
      addToWaitQueue port rid prog
      scheduleNextGoroutine

    -- a branch whose output port is 'defaultPort' is a select default
    isDefault :: GoSelectBranch -> Bool
    isDefault b = snd (branchChan b) == defaultPort

    -- a channel branch is ready when its output port holds a value
    isBranchReady :: GoSelectBranch -> GoInterp Bool
    isBranchReady b = do
      let port = snd (branchChan b)
      vals <- gets portVals
      case M.lookup port vals of
        Just Nothing -> return False
        Just (Just _) -> return True
        Nothing -> throwError $ "Channel " ++ show port ++ " not found!"

    -- placeholder handed to a default branch, which receives no real value
    noValue :: GoExpr
    noValue = GoString ""
-- | Run @main@ as goroutine 1 and keep scheduling until no goroutine is ready.
-- An interpreter error is printed to stdout rather than raised.
run_gomain :: GoProgram () -> IO ()
run_gomain main =
  evalStateT (runExceptT (interpGo main)) initialRuntime >>= either putStrLn return
  where
    -- goroutine 1 is main, spawned goroutines are numbered from 2, and
    -- channel ports from 1 (port 0 is 'defaultPort')
    initialRuntime =
      GoRuntime
        { curGoroutine = 1
        , nextGoroutine = 2
        , goroutines = M.empty
        , nextPort = 1
        , portVals = M.empty
        , waitQueues = M.empty
        , readyQueue = []
        }
-- channel combinators
-- | Run @action inchan outchan@ in a new goroutine and return a spliced
-- channel. Putting into it feeds @inchan@'s input port, and getting from it
-- reads @outchan@'s output port.
chan_pipe :: (GoChan -> GoChan -> GoProgram ()) -> GoChan -> GoChan -> GoProgram GoChan
chan_pipe action inchan outchan =
  go (action inchan outchan) >> return (fst inchan, snd outchan)
-- | Like 'chan_pipe', but a fresh output channel is allocated for @action@.
chan_transform :: (GoChan -> GoChan -> GoProgram ()) -> GoChan -> GoProgram GoChan
chan_transform action chan = do
  outchan <- newchan
  chan_pipe action chan outchan
-- | A channel that only passes on the values satisfying @keep@.
chan_filter :: (GoExpr -> Bool) -> GoChan -> GoProgram GoChan
chan_filter keep = chan_transform sifter
  where
    -- forward matching values forever, silently dropping the rest
    sifter inchan outchan = do
      val <- getchan inchan
      if keep val then putchan outchan val else return ()
      sifter inchan outchan
-- | A channel whose values are those of the input channel with @f@ applied.
chan_map :: GoExprable a => (GoExpr -> a) -> GoChan -> GoProgram GoChan
chan_map f = chan_transform mapper
  where
    -- 'putchan' already converts its argument with 'toGoExpr'
    mapper inchan outchan =
      getchan inchan >>= putchan outchan . f >> mapper inchan outchan
-- | A channel of running accumulations: each incoming value is folded into the
-- accumulator (starting at @seed@), and the updated accumulator is emitted.
chan_fold :: GoExprable a => a -> (GoExpr -> a -> a) -> GoChan -> GoProgram GoChan
chan_fold seed step = chan_transform (loop seed)
  where
    loop acc inchan outchan = do
      acc' <- flip step acc <$> getchan inchan
      putchan outchan acc'
      loop acc' inchan outchan
-- | Partition an input channel into two output channels. Values satisfying
-- @keepLeft@ are read from the first returned channel, and all others from the
-- second. Both returned channels share @chan@'s input port, so putting into
-- either one feeds @chan@.
chan_partition :: (GoExpr -> Bool) -> GoChan -> GoProgram (GoChan, GoChan)
chan_partition keepLeft chan = do
  lchan <- newchan
  rchan <- newchan
  go $ partition_goroutine lchan rchan
  -- each side must be read through its OWN channel's output port
  return ((fst chan, snd lchan), (fst chan, snd rchan))
  where
    -- route every value from @chan@ to the matching side, forever
    partition_goroutine lchan rchan = do
      val <- getchan chan
      putchan (if keepLeft val then lchan else rchan) val
      partition_goroutine lchan rchan
-- channel creation functions
-- | Allocate a channel, start @action@ on it in a new goroutine, and return
-- the channel.
chan_produce :: (GoChan -> GoProgram ()) -> GoProgram GoChan
chan_produce action =
  newchan >>= \chan -> go (action chan) >> return chan
-- | A channel that yields @x@ over and over.
chan_repeat :: GoExprable a => a -> GoProgram GoChan
chan_repeat x = chan_produce emitForever
  where
    emitForever chan = putchan chan x >> emitForever chan
-- | A channel yielding @x@, @f x@, @f (f x)@, ... forever.
chan_iterate :: GoExprable a => a -> (a -> a) -> GoProgram GoChan
chan_iterate x f = chan_produce (emitFrom x)
  where
    emitFrom cur chan = putchan chan cur >> emitFrom (f cur) chan
-- | A channel that yields the elements of @lst@ in order, repeating forever.
--
-- For an empty list the producer goroutine finishes immediately, so the
-- channel never yields a value. Without that check the producer would loop
-- forever without ever issuing a command, hanging the whole interpreter.
chan_cycle :: GoExprable a => [a] -> GoProgram GoChan
chan_cycle lst = chan_produce cycle_goroutine
  where
    cycle_goroutine chan
      | null lst = return ()
      -- 'mapM_' over the infinite 'cycle' stays productive in the free monad
      | otherwise = mapM_ (putchan chan) (cycle lst)
-- | A channel driven by a seed. Each step maps the current seed to a value to
-- emit and the next seed.
chan_unfoldr :: GoExprable a => b -> (b -> (a, b)) -> GoProgram GoChan
chan_unfoldr seed step = chan_produce (emitFrom seed)
  where
    emitFrom cur chan =
      let (v, cur') = step cur
       in putchan chan v >> emitFrom cur' chan
-- misc channel functions
-- | Receive @n@ values from @chan@, in order, blocking until all have arrived.
-- A zero or negative count returns @[]@ immediately. Counting down to exactly
-- 0 would never terminate for a negative count.
chan_take :: Int -> GoChan -> GoProgram [GoExpr]
chan_take n chan = replicateM n (getchan chan)
-- example: producer-consumer queue
-- | Print every value received on @chan@, forever.
consume :: GoChan -> GoProgram ()
consume chan = getchan chan >>= goprint . show >> consume chan
-- | Send each number of @range@ on @chan@, in order.
produce :: [Int] -> GoChan -> GoProgram ()
produce range chan = mapM_ (putchan chan) range
-- | Producer/consumer demo: one goroutine sends 1..10 over a channel while
-- another prints whatever it receives.
bufmain :: GoProgram ()
bufmain =
  newchan >>= \c ->
    go (produce [1 .. 10] c)
      >> go (consume c)
      >> goprint "producer-consumer queue running!"
-- example: sieve of Eratosthenes
-- | Sieve input: send the candidates @2 .. limit@ on @c@.
source :: Int -> GoChan -> GoProgram ()
source limit c = mapM_ (putchan c) [2 .. limit]
-- | Sieve stage: forward from @left@ to @right@ every integer not divisible by
-- the prime @p@, forever.
--
-- The value is matched with @case@ rather than a @GoInt v <- getchan left@
-- bind. A refutable pattern in @do@ requires a 'MonadFail' instance, which
-- @Free GoCmd@ does not have, so GHC >= 8.8 rejects it. Non-integer values
-- cannot be tested for divisibility and are dropped.
pfilter :: Int -> GoChan -> GoChan -> GoProgram ()
pfilter p left right = do
  val <- getchan left
  case val of
    GoInt v | v `mod` p /= 0 -> putchan right v
    _ -> return ()
  pfilter p left right
-- | Sieve tail: the first integer arriving on @left@ is prime. Print it, then
-- splice in a 'pfilter' stage for it and keep sinking from the filtered channel.
--
-- Uses @case@ instead of a refutable @GoInt v <- ...@ bind, because
-- @Free GoCmd@ has no 'MonadFail' instance. Non-integer values are ignored.
sink :: GoChan -> GoProgram ()
sink left = do
  val <- getchan left
  case val of
    GoInt v -> do
      goprint $ "got prime: " ++ show v
      right <- newchan
      go $ pfilter v left right
      sink right
    -- not a number: skip it and keep reading the same channel
    _ -> sink left
-- | Prime-sieve demo over the candidates 2..50.
sievemain :: GoProgram ()
sievemain =
  goprint "running prime sieve..." >> newchan >>= \c ->
    go (source 50 c) >> go (sink c)
---
-- | Send every element of the list on @chan@, in order. Works on infinite
-- lists too: the sends are issued lazily, one at a time.
count :: GoExprable a => [a] -> GoChan -> GoProgram ()
count xs chan = mapM_ (putchan chan) xs
-- | Run @n@ selects over @left@ and @right@, printing each received value
-- tagged with the side it arrived on.
printLeftRight :: (Num a, Enum a) => a -> GoChan -> GoChan -> GoProgram ()
printLeftRight n left right =
  forM_ [1 .. n] $ const (select [tagged "left: " left, tagged "right: " right])
  where
    tagged label chan = selcase chan (\x -> goprint (label ++ show x))
-- | Even/odd demo: partition the natural numbers by parity and print ten of
-- them via select.
eomain :: GoProgram ()
eomain = do
  numbers <- newchan
  (evens, odds) <- chan_partition isEven numbers
  go $ count ([1 ..] :: [Int]) numbers
  go $ printLeftRight 10 evens odds
  where
    -- only integers are ever sent on @numbers@
    isEven (GoInt x) = even x
---
-- | Send every element of @lst@ on @chan@, in order.
source2 :: (Foldable t, GoExprable a) => t a -> GoChan -> GoProgram ()
source2 lst chan = mapM_ (putchan chan) lst
-- | Print every value received on @chan@, forever (never returns).
sink2 :: GoChan -> GoProgram a
sink2 chan = getchan chan >>= goprint . show >> sink2 chan
-- | Map demo: send 1..10 through a channel that adds 100, printing the results.
mapmain :: GoProgram ()
mapmain = do
  raw <- newchan
  shifted <- chan_map add100 raw
  go $ source2 ([1 .. 10] :: [Int]) shifted
  go $ sink2 shifted
  where
    -- only integers are sent in this demo
    add100 (GoInt x) = GoInt (x + 100)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment