Skip to content

Instantly share code, notes, and snippets.

@cdparks
Last active May 6, 2020 22:34
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save cdparks/6b76be13b2bdcb6c21bd7f4077da3059 to your computer and use it in GitHub Desktop.
Save cdparks/6b76be13b2bdcb6c21bd7f4077da3059 to your computer and use it in GitHub Desktop.
sum(upto(1, 1000)) GRIN example
/* Set to 1 to make store(), fetch() and update() print a trace line on each call. */
#define VERBOSE 0
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
/* Discriminator telling which member of a node's union is in use. */
typedef enum {
tag_int, /* integer value, stored in as_int */
tag_nil, /* empty list, carries no payload */
tag_cons, /* list cell, stored in as_cons */
tag_upto, /* not-yet-evaluated upto(start, stop), stored in as_upto */
tag_sum /* not-yet-evaluated sum(list), stored in as_sum */
} tag;
/* Forward declaration so the payload structs can point at nodes. */
typedef struct node node;
/* Payload of a list cell: the first element and the rest of the list. */
typedef struct {
node *head;
node *tail;
} cons;
/* Payload of a pending upto(start, stop) call: pointers to both bounds. */
typedef struct {
node *start;
node *stop;
} upto;
/* Payload of a pending sum(list) call: pointer to the list argument. */
typedef struct {
node *list;
} sum;
/* A heap cell: a tag plus the payload that tag selects. */
typedef struct node {
tag tag;
/* Anonymous union: members are accessed directly as n.as_int, n.as_cons, ... */
union {
int as_int;
cons as_cons;
upto as_upto;
sum as_sum;
};
} node;
#define HEAP_SIZE 2048
struct node heap[HEAP_SIZE];
node *hp = &heap[0];
node *limit = &heap[HEAP_SIZE- 1];
const node Nil = (node) {
.tag = tag_nil
};
/* Build a list-cell node (by value) from a head pointer and a tail pointer. */
node Cons(node *head, node *tail) {
  node cell = { .tag = tag_cons };
  cell.as_cons.head = head;
  cell.as_cons.tail = tail;
  return cell;
}
node Int(int i) {
return (node) {
.tag = tag_int,
.as_int = i
};
}
/* Build a pending upto(start, stop) node (by value); nothing is evaluated here. */
node Upto(node *start, node *stop) {
  node call = { .tag = tag_upto };
  call.as_upto.start = start;
  call.as_upto.stop = stop;
  return call;
}
/* Build a pending sum(list) node (by value); nothing is evaluated here. */
node Sum(node *list) {
  node call = { .tag = tag_sum };
  call.as_sum.list = list;
  return call;
}
/* Nesting depth past which show_node_deep() prints "..." instead of recursing. */
#define MAX_DEPTH 3
/*
 * ALLOC_FORMAT(b, fmt, ...): format like printf into a freshly malloc'd,
 * exactly-sized string and assign it to b. The caller owns the result and
 * must free() it.
 *
 * The format string is taken as part of __VA_ARGS__, so a call with no extra
 * arguments needs no GNU `##__VA_ARGS__` extension. The arguments are
 * evaluated twice (once to measure, once to write), so they must be free of
 * side effects. A formatting error or allocation failure aborts rather than
 * writing through a bad pointer.
 */
#define ALLOC_FORMAT(b, ...) \
  do { \
    int alloc_format_len_ = snprintf(NULL, 0, __VA_ARGS__); \
    if (alloc_format_len_ < 0) { \
      fprintf(stderr, "ALLOC_FORMAT: formatting failed\n"); \
      abort(); \
    } \
    (b) = malloc((size_t) alloc_format_len_ + 1); \
    if ((b) == NULL) { \
      fprintf(stderr, "ALLOC_FORMAT: out of memory\n"); \
      abort(); \
    } \
    snprintf((b), (size_t) alloc_format_len_ + 1, __VA_ARGS__); \
  } while (0)
/*
 * Render the node at p as a freshly allocated string, following child
 * pointers recursively. Once depth exceeds MAX_DEPTH the subtree is rendered
 * as "...".
 *
 * A NULL pointer is rendered as "<null>" and an unrecognized tag as
 * "<unknown tag N>", so the result is never NULL and is always safe to pass
 * to "%s". The caller must free() the returned string.
 */
char* show_node_deep(node *p, int depth) {
  char *buffer = 0;
  if (p == NULL) {
    ALLOC_FORMAT(buffer, "<null>");
    return buffer;
  }
  if (depth > MAX_DEPTH) {
    ALLOC_FORMAT(buffer, "...");
    return buffer;
  }
  /* Rendered children; left NULL when unused, which free() accepts. */
  char *lhs = 0;
  char *rhs = 0;
  switch (p->tag) {
    case tag_int:
      ALLOC_FORMAT(buffer, "Int(%d)", p->as_int);
      break;
    case tag_nil:
      ALLOC_FORMAT(buffer, "Nil");
      break;
    case tag_cons:
      lhs = show_node_deep(p->as_cons.head, depth + 1);
      rhs = show_node_deep(p->as_cons.tail, depth + 1);
      ALLOC_FORMAT(buffer, "Cons(%s, %s)", lhs, rhs);
      break;
    case tag_upto:
      lhs = show_node_deep(p->as_upto.start, depth + 1);
      rhs = show_node_deep(p->as_upto.stop, depth + 1);
      ALLOC_FORMAT(buffer, "Upto(%s, %s)", lhs, rhs);
      break;
    case tag_sum:
      lhs = show_node_deep(p->as_sum.list, depth + 1);
      ALLOC_FORMAT(buffer, "Sum(%s)", lhs);
      break;
    default:
      /* Previously fell through with buffer still NULL. */
      ALLOC_FORMAT(buffer, "<unknown tag %d>", (int) p->tag);
      break;
  }
  free(lhs);
  free(rhs);
  return buffer;
}
/* Render the node at p starting from nesting depth 0; caller frees the result. */
char* show_node(node *p) {
  const int top_level = 0;
  return show_node_deep(p, top_level);
}
/*
 * Copy n into the next free heap slot and return that slot's address.
 * Aborts when every slot up to and including `limit` is already in use.
 */
node *store(node n) {
  if (hp > limit) {
    /* stderr, not stdout: abort() is not required to flush buffered
       streams, so a message written to stdout could be lost. */
    fprintf(stderr, "HEAP OVERFLOW\n");
    abort();
  }
#if VERBOSE
  char *s = show_node(&n);
  printf("store(%s)\n", s);
  free(s);
#endif
  node *slot = hp;
  *slot = n;
  hp = slot + 1;
  return slot;
}
/* Return a copy of the node stored at p; traces the read when VERBOSE is set. */
node fetch(node *p) {
#if VERBOSE
  char *shown = show_node(p);
  printf("fetch(%s)\n", shown);
  free(shown);
#endif
  node loaded = *p;
  return loaded;
}
/*
 * Overwrite the node at p with n and return n, so callers can write
 * `return update(p, value);`. Traces the write when VERBOSE is set.
 */
node update(node *p, node n) {
#if VERBOSE
  char *s = show_node(&n);
  /* %p requires a void * argument; passing node * directly is undefined. */
  printf("update(%p, %s)\n", (void *) p, s);
  free(s);
#endif
  *p = n;
  return n;
}
/* Declared early because upto_impl() and sum_impl() call eval(), which in turn calls them. */
node eval(node *p);

/*
 * One step of upto(start, stop): evaluate both bounds, then return Nil when
 * start > stop, otherwise Cons(start, <pending upto(start + 1, stop)>).
 * The incremented bound and the pending call are stored as two new heap nodes.
 *
 * Aborts if either bound does not evaluate to an Int, instead of reading
 * as_int out of a node that holds some other payload.
 */
node upto_impl(node *start, node *stop) {
  node lo_node = eval(start);
  node hi_node = eval(stop);
  if (lo_node.tag != tag_int || hi_node.tag != tag_int) {
    fprintf(stderr, "arguments to upto_impl() must be ints\n");
    abort();
  }
  int lo = lo_node.as_int;
  int hi = hi_node.as_int;
  if (lo > hi) {
    return Nil;
  }
  node *x = store(Int(lo + 1));
  node *xs = store(Upto(x, stop));
  return Cons(start, xs);
}
/*
 * Evaluate the list at `list` and return Int(sum of its elements): Int(0) for
 * Nil, otherwise the evaluated head plus the recursively computed sum of the
 * tail.
 *
 * Aborts if `list` does not evaluate to Nil or Cons, or if a head does not
 * evaluate to an Int.
 */
node sum_impl(node *list) {
  node xs = eval(list);
  switch (xs.tag) {
    case tag_nil:
      return Int(0);
    case tag_cons: {
      node head = eval(xs.as_cons.head);
      if (head.tag != tag_int) {
        fprintf(stderr, "elements passed to sum_impl() must be ints\n");
        abort();
      }
      /* sum_impl() only ever returns Int nodes, so as_int is valid here. */
      int rest = sum_impl(xs.as_cons.tail).as_int;
      return Int(head.as_int + rest);
    }
    default:
      /* stderr, not stdout: abort() is not required to flush buffered
         streams, so a message written to stdout could be lost. */
      fprintf(stderr, "argument to sum_impl() must be a list\n");
      abort();
  }
}
/*
 * Evaluate the node at p and return the result by value.
 * Int, Nil and Cons nodes are returned unchanged. Upto and Sum nodes are
 * computed, and the heap slot at p is overwritten with the result, so a later
 * eval(p) returns the stored value without recomputing.
 * Aborts on an unrecognized tag.
 */
node eval(node *p) {
  node n = fetch(p);
  switch (n.tag) {
    case tag_int:
    case tag_nil:
    case tag_cons:
      return n;
    case tag_upto:
      return update(p, upto_impl(n.as_upto.start, n.as_upto.stop));
    case tag_sum:
      return update(p, sum_impl(n.as_sum.list));
  }
  /* stderr, not stdout: abort() is not required to flush buffered streams,
     so a message written to stdout could be lost. */
  fprintf(stderr, "impossible: unrecognized tag in eval; n.tag == %d\n", (int) n.tag);
  abort();
}
/*
 * Print `title`, then one line per heap slot from the start of the heap up to
 * (but not including) lim: slot index, slot address, rendered node. Ends with
 * a blank line.
 */
void print_heap(const char *title, node *lim) {
  printf("%s:\n", title);
  int i = 0;
  for (node *p = &heap[0]; p < lim; p += 1, i += 1) {
    char *s = show_node(p);
    /* %p requires a void * argument; passing node * directly is undefined. */
    printf("%04d [%p]: %s\n", i, (void *) p, s);
    free(s);
  }
  printf("\n");
}
/*
 * Evaluate the node at p, printing the heap (slots up to and including p)
 * before and after, followed by an "Eval: <before> => <result>" summary.
 */
void eval_print(node *p) {
  node *const end = p + 1;

  /* Capture the textual form now; eval() overwrites the slot at p. */
  char *shown_before = show_node(p);
  print_heap("Before", end);

  node result = eval(p);
  print_heap("After", end);

  char *shown_after = show_node(&result);
  printf("Eval: %s => %s\n", shown_before, shown_after);

  free(shown_before);
  free(shown_after);
}
/* Build sum(upto(1, 1000)) on the heap, then evaluate and print it. */
int main() {
  node *lower = store(Int(1));
  node *upper = store(Int(1000));
  node *range = store(Upto(lower, upper));
  node *total = store(Sum(range));

  eval_print(total);
  return 0;
}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE StrictData #-}
{-# LANGUAGE FlexibleContexts #-}
import Control.Monad.State
import Data.Map.Strict (Map, (!))
import qualified Data.Map.Strict as Map
import Prelude hiding (log, sum)
-- | Index of a slot in the simulated 'Heap'.
newtype Address = Address
  { unAddress :: Int
  }
  deriving (Eq)
-- | Addresses render as the slot number followed by @#@, e.g. @3#@.
instance Show Address where
  show address = shows (unAddress address) "#"
-- | Interpreter state: the simulated heap plus logging bookkeeping.
data Heap = Heap
  { memory :: Map Int Node
    -- ^ Stored nodes, keyed by slot number.
  , next :: Int
    -- ^ Slot number that the next 'store' will use.
  , level :: Int
    -- ^ Current nesting depth, used by 'log' to build its prefix.
  }
-- | Place a node in the next free slot, advance the allocation counter, and
-- return the slot's address.
store :: MonadState Heap m => Node -> m Address
store node = state $ \heap ->
  let slot = next heap
      heap' =
        heap
          { memory = Map.insert slot node (memory heap)
          , next = slot + 1
          }
   in (Address slot, heap')
-- | Read the node stored at an address.
--
-- An address with nothing stored at it indicates an interpreter bug. The
-- lookup is performed when 'fetch' runs and fails with a message naming the
-- address, rather than returning a lazily failing @Map.!@ thunk whose error
-- does not say which address was bad.
fetch :: MonadState Heap m => Address -> m Node
fetch address@(Address i) = do
  mem <- gets memory
  case Map.lookup i mem of
    Just node -> pure node
    Nothing -> error $ "fetch: no node stored at address " <> show address
-- | Replace whatever is stored at the given address with a new node.
update :: MonadState Heap m => Address -> Node -> m ()
update (Address i) node = modify overwrite
  where
    overwrite heap = heap { memory = Map.insert i node (memory heap) }
-- | Increase the logging depth by one.
indent :: MonadState Heap m => m ()
indent = modify deeper
  where
    deeper heap = heap { level = level heap + 1 }
-- | Decrease the logging depth by one.
dedent :: MonadState Heap m => m ()
dedent = modify shallower
  where
    shallower heap = heap { level = level heap - 1 }
-- | Print a message on its own line, prefixed with one copy of the indent
-- unit per current 'level'.
log :: (MonadIO m, MonadState Heap m) => String -> m ()
log message = do
  depth <- gets level
  liftIO . putStrLn $ concat (replicate depth " ") <> message
-- | Heap nodes. 'eval' returns the @C@ constructors unchanged and computes
-- the @F@ constructors by calling 'upto' or 'sum'.
data Node
  = CInt Int
    -- ^ Integer value.
  | CNil
    -- ^ Empty list.
  | CCons Address Address
    -- ^ List cell: address of the head and address of the tail.
  | FUpto Address Address
    -- ^ Pending @upto lo hi@ call.
  | FSum Address
    -- ^ Pending @sum xs@ call.
  deriving (Eq, Show)
-- | Evaluate the node at an address, logging each step.
--
-- Value nodes ('CInt', 'CNil', 'CCons') are returned as they are. Pending
-- calls ('FUpto', 'FSum') are run one indentation level deeper, and their
-- result is written back to the same address before being returned.
eval :: (MonadIO m, MonadFail m, MonadState Heap m) => Address -> m Node
eval addr# = do
  node <- fetch addr#
  log $ show node <> " <- fetch " <> show addr#
  case node of
    CInt{} -> finish node
    CNil -> finish node
    CCons{} -> finish node
    FUpto lo# hi# -> force (upto lo# hi#)
    FSum xs# -> force (sum xs#)
  where
    -- Log the result and hand it back.
    finish value = do
      log $ "pure " <> show value
      pure value

    -- Run a pending call one level deeper, then overwrite the slot with
    -- its result.
    force call = do
      indent
      ret <- call
      dedent
      update addr# ret
      finish ret
-- | One step of @upto lo hi@: evaluate both bounds, then produce 'CNil' when
-- @lo > hi@, otherwise a 'CCons' whose head is the original lower-bound
-- address and whose tail is a newly stored pending @upto (lo + 1) hi@.
--
-- Fails through 'MonadFail' if either bound does not evaluate to a 'CInt'.
upto :: (MonadIO m, MonadFail m, MonadState Heap m) => Address -> Address -> m Node
upto lo# hi# = do
  CInt lo <- eval lo#
  CInt hi <- eval hi#
  if lo <= hi
    then do
      from# <- store (CInt (lo + 1))
      rest# <- store (FUpto from# hi#)
      pure (CCons lo# rest#)
    else pure CNil
-- | Evaluate the list at the given address and return the sum of its
-- elements as a 'CInt'.
--
-- Fails through 'MonadFail' if the address does not evaluate to a list or an
-- element does not evaluate to a 'CInt'.
sum :: (MonadIO m, MonadFail m, MonadState Heap m) => Address -> m Node
sum xs# = eval xs# >>= total
  where
    total CNil = pure (CInt 0)
    total (CCons y# ys#) = do
      CInt y <- eval y#
      CInt s <- sum ys#
      pure (CInt (y + s))
    total other = impossible other
-- | Fail with a message naming a node that should not occur at this point.
impossible :: MonadFail m => Node -> m a
impossible node = fail (concat ["Should not encounter ", show node, " here"])
-- | Run an interpreter action against an empty heap (no nodes, first slot 0,
-- logging depth 0) and return its result, discarding the final heap.
runGrin :: StateT Heap IO a -> IO a
runGrin m = evalStateT m emptyHeap
  where
    emptyHeap = Heap { memory = Map.empty, next = 0, level = 0 }
-- | Build @sum (upto 1 1000)@ on the heap, evaluate it, and print the result.
main :: IO ()
main = print =<< runGrin program
  where
    program = do
      lower# <- store (CInt 1)
      upper# <- store (CInt 1000)
      range# <- store (FUpto lower# upper#)
      total# <- store (FSum range#)
      eval total#
Before:
0000 [0x10f0bf050]: Int(1)
0001 [0x10f0bf068]: Int(1000)
0002 [0x10f0bf080]: Upto(Int(1), Int(1000))
0003 [0x10f0bf098]: Sum(Upto(Int(1), Int(1000)))
After:
0000 [0x10f0bf050]: Int(1)
0001 [0x10f0bf068]: Int(1000)
0002 [0x10f0bf080]: Cons(Int(1), Cons(Int(2), Cons(Int(3), Cons(..., ...))))
0003 [0x10f0bf098]: Int(500500)
Eval: Sum(Upto(Int(1), Int(1000))) => Int(500500)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment