Skip to content

Instantly share code, notes, and snippets.

What would you like to do?
Monad notation for Typed Racket
#lang typed/racket
(require (for-syntax racket syntax/parse) racket/stxparam)
(provide do-impl define-do pure)
(module+ test (require typed/rackunit))
(define-syntax (<- stx) (raise-syntax-error '<- "used outside of do"))
(define-syntax-parameter current-monad
(lambda (stx)
(raise-syntax-error 'current-monad "no current monad")))
(define-syntax-parameter current-pure
(lambda (stx)
(raise-syntax-error 'pure "no current pure op")))
(define-syntax (pure stx)
(let ([pure-op (syntax-parameter-value #'current-pure)])
(syntax-parse stx
[(_ arg)
(quasisyntax/loc stx
(#,pure-op arg))])))
(define-syntax-class bind
#:literals (<-)
#:description "bind expression"
(pattern (<- pattern:expr term:expr)))
(define-syntax-class typed-bind
#:literals (<- :)
#:description "bind expression"
(pattern (<- pattern:expr : type:expr term:expr))))
(define-syntax (do-impl stx)
(syntax-parse stx
[(_ bind-op done:expr)
(syntax/loc done done)]
[(_ bind-op to-bind:bind body:expr ...)
(syntax/loc (syntax to-bind)
(bind-op to-bind.term
[to-bind.pattern (do-impl bind-op body ...)])))]
[(_ bind-op to-bind:typed-bind body:expr ...)
(let ([m (syntax-parameter-value #'current-monad)])
(quasisyntax/loc (syntax to-bind)
(bind-op (ann to-bind.term (#,m to-bind.type))
(lambda ([x : to-bind.type])
(match x
[to-bind.pattern (do-impl bind-op body ...)])))))]
[(_ bind-op first:expr rest:expr ...)
(let ([m (syntax-parameter-value #'current-monad)])
(quasisyntax/loc (syntax first)
(bind-op (ann first (#,m Void))
(lambda ([ignored : Void])
(do-impl bind-op rest ...)))))]))
(define-syntax (define-do stx)
(syntax-parse stx
[(_ kwd:id monad:expr bind:expr pure-op:expr)
#'(define-syntax (kwd inner)
(syntax-parse inner
[(_ x:expr (... ...))
#'(syntax-parameterize ([current-monad
(do-impl bind x (... ...)))]))]))
(module+ test
;;; WARNING: not really a monad, oh well
(: option->>= (All (A B)
(-> (Option A)
(-> A (Option B))
(Option B))))
(define (option->>= val fun)
(if val
(fun val)
(define-do option-do Option option->>= identity)
(check-equal? (option-do
(<- x : Number 40)
(<- y : Number #f)
(+ x y))
(check-equal? (option-do (<- x 40)
(check-equal? (option-do (<- x 40)
(<- y "hello")
(box "hi"))
(box "hi"))
(check-equal? (option-do
(<- (list x y z) : (Listof Natural) (list 1 2 3))
(identity y))
(define-type (State S) (All (A) (-> S (List A S))))
(: get (All (S) (-> ((State S) S))))
(define ((get) s) (list s s))
(: put (All (S) (-> S ((State S) Void))))
(define ((put new) old)
(list (void) new))
(: state-pure (All (S A) (-> A ((State S) A))))
(define ((state-pure x) s) (list x s))
(: skip (All (S) (-> ((State S) Void))))
(define (skip) (state-pure (void)))
(: state->>= (All (S A B)
(-> ((State S) A)
(-> A ((State S) B))
((State S) B))))
(define (state->>= val k)
(lambda (state-1)
(match-let ([(list x state-2) (val state-1)])
((k x) state-2))))
(define-syntax (state-do stx)
(syntax-parse stx
[(_ state-type:expr steps:expr ... last:expr)
(syntax/loc stx (syntax-parameterize ([current-monad #'(State state-type)]
[current-pure #'state-pure])
(do-impl state->>= steps ... last)))]))
(: modify (All (S) (-> (-> S S) ((State S) Void))))
(define (modify f)
(state->>= ((inst get S))
(lambda ([x : S])
((inst put S) (f x)))))
(: modify2 (All (S) (-> (-> S S) ((State S) Void))))
(define (modify2 f)
(state-do S
(<- s : S (get))
(put s)))
(: sum-string-lengths (-> (Rec A (U Null String (Pair A A)))
((State Integer) Void)))
(define (sum-string-lengths expr)
[(null? expr) (skip)]
[(string? expr)
(modify (lambda ([x : Integer]) (+ x (string-length expr))))]
[(pair? expr)
(state-do Integer
(sum-string-lengths (car expr))
(sum-string-lengths (cdr expr)))]))
(check-equal? (cadr ((sum-string-lengths '((("hey" "you") "it's" ("a")) "string")) 0))
(struct (A) success ([value : A]) #:transparent)
(struct (E) failure ([error : E]) #:transparent)
(define-type (Error E) (All (A) (U (failure E) (success A))))
(: error->>= (All (E A B) (-> ((Error E) A)
(-> A ((Error E) B))
((Error E) B))))
(define (error->>= v k)
(if (failure? v)
(k (success-value v))))
(define-syntax (error-do stx)
(syntax-parse stx
[(_ error-type:expr steps:expr ... last:expr)
(syntax/loc stx (syntax-parameterize ([current-monad #'(Error error-type)]
[current-pure #'success])
(do-impl error->>= steps ... last)))]))
(: all-success (All (E A) (-> (Listof ((Error E) A))
((Error E) (Listof A)))))
(define (all-success lst)
(match lst
[(cons x xs)
(error-do E
(<- hd : A x)
(<- tl : (Listof A) (all-success xs))
(pure (cons hd tl)))]
(success '())]))
(check-equal? (all-success (list (success 14) (success "hi") (success empty)))
(success (list 14 "hi" empty)))
(check-equal? (all-success (list (success 14) (failure "no way") (success empty)))
(failure "no way"))
;; Local Variables:
;; eval: (put (quote All) (quote racket-indent-function) 1)
;; End:
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment