Skip to content

Instantly share code, notes, and snippets.

@wilbowma
Created April 1, 2021 05:27
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save wilbowma/b0267e61dc788294adff9635006d196c to your computer and use it in GitHub Desktop.
Save wilbowma/b0267e61dc788294adff9635006d196c to your computer and use it in GitHub Desktop.
#lang racket
;; An introduction to procedure inlining following a simplification of GHC's
;; inliner from:
;; https://www.microsoft.com/en-us/research/wp-content/uploads/2002/07/inline.pdf
;;
;; This version is a simplification in that:
;; 1. It picks loop breakers naively
;; 2. It ONLY inlines procedures (not other bound expressions)
;; 3. All calls are direct, so it's not context sensitive (although the
;; implementation leaves some hooks and comments about where context sensitivity
;; would be added)
;;
;; It's tailored to Exprs-lang-v8 from my CPSC 411 course.
(require
(prefix-in r: graph)
cpsc411/compiler-lib
cpsc411/graph-lib
cpsc411/reference/a8-solution)
(module+ test
(require rackunit))
;; monadic-fold : (X Acc -> (values Y Acc)) Acc (Sequenceof X)
;;                -> (values (Listof Y) Acc)
;; Maps `f` over `ls` while threading an accumulator through every call.
;; `f` receives an element and the current accumulator and returns two
;; values: the mapped element and the next accumulator.
;; Elements are visited right-to-left (the last element sees `init`), but the
;; mapped results come back in the same order as `ls`.
(define (monadic-fold f init ls)
  ;; `state` is a pair: (cons current-accumulator mapped-results-so-far).
  (define (step x state)
    (let-values ([(out acc) (f x (car state))])
      (cons acc (cons out (cdr state)))))
  (let ([final (foldr step (list init) (sequence->list ls))])
    (values (cdr final) (car final))))
;; scc : functional-graph -> (Listof (Listof vertex))
;; Computes the strongly-connected components of `g`, where `g` is an
;; association list keyed by vertex (the representation read by
;; `get-neighbors`). The graph is first copied into a mutable directed graph
;; from the `graph` library (prefixed `r:`), whose `scc` does the real work.
(define (scc g)
  (define vertices (map car g))
  (define mutable-g (r:directed-graph '()))
  ;; Register every vertex first so that isolated vertices are not lost.
  (for ([v vertices])
    (r:add-vertex! mutable-g v))
  ;; Then copy every edge v -> u.
  (for* ([v vertices]
         [u (get-neighbors g v)])
    (r:add-directed-edge! mutable-g v u))
  (r:scc mutable-g))
#|
Exprs-unique-lang:
p ::= (module (define label (lambda (aloc ...) tail)) ... tail)
tail e ::= (let ([aloc e] ...) e) (call label e ...) (primop e ...) v (if e e e)
v ::= aloc integer boolean label
primop ::= - + * zero?
|#
;; Exprs-unique-lang -> Exprs-unique-lang/dependencies
;; Decorates the module, and every procedure definition in it, with an info
;; field whose 'deps entry lists the labels referenced by the corresponding
;; body. A label appears once per reference, so the list may contain
;; duplicates.
(define (analyze-dependencies p)
  ;; walk : e (Listof label) -> (values e (Listof label))
  ;; Rebuilds `e` unchanged while consing every label it references onto
  ;; `acc`.
  (define (walk e acc)
    (match e
      [`(let ([,xs ,rhss] ...) ,body)
       (define-values (new-rhss rhs-acc) (monadic-fold walk acc rhss))
       (define-values (new-body body-acc) (walk body rhs-acc))
       (values `(let ,(map list xs new-rhss) ,new-body)
               body-acc)]
      [`(call ,target ,args ...)
       (define-values (new-args arg-acc) (monadic-fold walk acc args))
       (values `(call ,target ,@new-args)
               (if (label? target) (cons target arg-acc) arg-acc))]
      ;; Any other non-empty form (primop applications, `if`) only
      ;; contributes the dependencies of its operands.
      [`(,op ,args ...)
       (define-values (new-args arg-acc) (monadic-fold walk acc args))
       (values `(,op ,@new-args) arg-acc)]
      ;; Values: a label used as a value is a dependency as well.
      [v (values v (if (label? v) (cons v acc) acc))]))

  ;; annotate : tail -> (values tail info)
  ;; Walks one body starting from an empty dependency list and packages the
  ;; result as an info structure.
  (define (annotate tail)
    (define-values (new-tail labels) (walk tail '()))
    (values new-tail (info-set '() 'deps labels)))

  (define (annotate-def def)
    (match-define `(define ,name (lambda (,params ...) ,body)) def)
    (define-values (new-body info) (annotate body))
    `(define ,name ,info (lambda (,@params) ,new-body)))

  (match-define `(module ,defs ... ,main) p)
  (define-values (new-main main-info) (annotate main))
  `(module ,main-info ,@(map annotate-def defs) ,new-main))
;; Smoke tests for `analyze-dependencies`. The results are only evaluated
;; (no rackunit check is applied to them), so these guard against crashes,
;; not against wrong output.
(module+ test
;; Self-recursive procedure.
(analyze-dependencies
(uniquify
'(module
(define fact
(lambda (x)
(if (call eq? 0 x)
1
(call * x (call fact (call - x 1))))))
(call fact 5))))
;; Non-recursive procedure called once from the module body.
(analyze-dependencies
(uniquify
'(module
(define id
(lambda (x)
x))
(call id 5))))
;; Mutually recursive procedures.
(analyze-dependencies
(uniquify
'(module
(define odd?
(lambda (x)
(if (call eq? x 0)
0
(let ([y (call + x -1)])
(call even? y)))))
(define even?
(lambda (x)
(if (call eq? x 0)
1
(let ([y (call + x -1)])
(call odd? y)))))
(call even? 5))))
)
#|
Exprs-unique-lang/dependencies:
p ::= (module info (define label info (lambda (aloc ...) tail)) ... tail)
info ::= (deps any/c ...)
deps ::= (label ...)
...
|#
(require racket/trace)
;; Exprs-unique-lang/dependencies -> Exprs-unique-lang/o-map
;; Classifies every procedure label by how it is used and records the result
;; in the module's info field (label -> flag). The per-definition info fields
;; are dropped. Flags:
;;   'dead          never referenced
;;   'loop-breaker  chosen to break a cycle in the call graph; never inlined
;;   'once-safe     referenced exactly once, from the module body
;;   'multi-safe    referenced several times, only from the module body
;;   'once-unsafe   referenced exactly once, from inside a procedure body
;;   'multi-unsafe  referenced several times, at least once from inside a
;;                  procedure body
(define (occurance-analysis p)
  ;; Builds the call graph: one vertex per label plus 'main for the module
  ;; body, and an edge v -> u for every label u recorded in v's 'deps.
  (define (build-dependency-graph main-deps defs)
    (match defs
      [`((define ,labels ,deps ,tails) ...)
       (for/fold ([g (add-directed-edges
                      (new-graph (cons 'main labels))
                      'main
                      (info-ref main-deps 'deps))])
                 ([l labels]
                  [d deps])
         (add-directed-edges g l (info-ref d 'deps)))]))

  ;; Mutable map label -> flag; created afresh on every call of the pass.
  (define label-flags (make-hash))

  (define (hash->info h)
    (for/fold ([info '()])
              ([(k v) (in-hash h)])
      (info-set info k v)))

  ;; To break loops, use a naive algorithm: find a strongly-connected
  ;; component that contains a cycle, flag an arbitrary member of it as a
  ;; loop breaker, remove that member from the graph, and repeat until no
  ;; cyclic component is left.
  ;;
  ;; Only ONE component is handled per iteration and the components are then
  ;; recomputed on the reduced graph. Iterating over the component list of
  ;; the unreduced graph while also recursing would revisit every other
  ;; cyclic component once per recursive call, which is needlessly expensive
  ;; and may flag more loop breakers than necessary.
  (define (break-loops! g)
    ;; A component is cyclic when it has several members, or a single member
    ;; with an edge to itself.
    (define (cyclic? component)
      (or (not (equal? 1 (length component)))
          (member (car component) (get-neighbors g (car component)))))
    (define (pick-loop-breaker component)
      (first (remove 'main component)))
    (let ([component (findf cyclic? (scc g))])
      (when component
        (let ([v (pick-loop-breaker component)])
          (hash-set! label-flags v 'loop-breaker)
          (break-loops! (remove-vertex g v))))))

  (define (analyze-p p)
    (match p
      [`(module ,info ,defs ... ,tail)
       (let ([g (build-dependency-graph info defs)]
             [labels (map second defs)])
         ;; Assume all labels are dead, to start with
         (for ([l labels])
           (hash-set! label-flags l 'dead))
         ;; Next, mark loop breakers
         (break-loops! g)
         ;; Then, do occurrence analysis. Definitions are processed before
         ;; the module body; the updater in `analyze-tail` relies on this.
         (let ([new-defs (map analyze-def defs)])
           (analyze-tail tail)
           `(module ,(hash->info label-flags) ,@new-defs ,tail)))]))

  ;; Occurrences inside a procedure body count as "unsafe" (they sit under a
  ;; lambda): the first one turns 'dead into 'once-unsafe, any later one
  ;; yields 'multi-unsafe. Returns the definition without its info field.
  (define (analyze-def def)
    (match def
      [`(define ,label ,_ (lambda (,alocs ...) ,tail))
       (analyze-e tail (lambda (v)
                         (match v
                           ['loop-breaker 'loop-breaker]
                           ['dead 'once-unsafe]
                           [_ 'multi-unsafe])))
       `(define ,label (lambda ,alocs ,tail))]))

  ;; Occurrences in the module body count as "safe". Because definitions were
  ;; analysed first, a label may already be 'once-unsafe here; one more
  ;; occurrence means it is no longer used once, so it becomes 'multi-unsafe.
  (define (analyze-tail tail)
    (analyze-e tail (lambda (v)
                      (match v
                        ['loop-breaker 'loop-breaker]
                        ['dead 'once-safe]
                        ['once-safe 'multi-safe]
                        ['once-unsafe 'multi-unsafe]
                        [_ v]))))

  ;; Walks `e` for effect, applying `updater` to the flag of every label that
  ;; occurs in it (as a call target or as a value).
  (define (analyze-e e updater)
    (let analyze-e ([e e])
      (match e
        [`(let ([,alocs ,es] ...) ,e)
         (for-each analyze-e es)
         (analyze-e e)]
        [`(call ,label ,es ...)
         (analyze-v label updater)
         (for-each analyze-e es)]
        [`(,primop ,es ...)
         (for-each analyze-e es)]
        [_ (analyze-v e updater)])))

  (define (analyze-v v updater)
    (if (label? v)
        (analyze-label v updater)
        (void)))

  (define (analyze-label l updater)
    (hash-update! label-flags l updater))

  (analyze-p p))
;; Smoke test for `occurance-analysis` on two mutually recursive procedures.
;; The result is only evaluated; no rackunit check is applied to it.
(module+ test
(occurance-analysis
(analyze-dependencies
(uniquify
'(module
(define odd?
(lambda (x)
(if (call eq? x 0)
0
(let ([y (call + x -1)])
(call even? y)))))
(define even?
(lambda (x)
(if (call eq? x 0)
1
(let ([y (call + x -1)])
(call odd? y)))))
(call even? 5)))))
)
#|
Exprs-unique-lang/o-map:
p ::= (module info (define label (lambda (aloc ...) tail)) ... tail)
info ::= o-map
o-map ::= ((label flag) ...)
flag ::= 'loop-breaker 'dead 'once-safe 'multi-safe 'once-unsafe 'multi-unsafe
...
|#
;; Parameter holding the size limit used by `simplify`: a procedure whose
;; estimated size is strictly below this value counts as small enough to
;; inline at call sites that are subject to a size check. Defaults to 100.
(define current-inline-threshold (make-parameter 100))
;; Exprs-unique-lang/o-map -> Exprs-unique-lang
;; Performs one round of inlining driven by the occurrence flags in the
;; module's info field:
;;   - definitions flagged 'dead are dropped;
;;   - calls to 'once-safe and 'once-unsafe procedures are always inlined;
;;   - calls to 'multi-safe and 'multi-unsafe procedures are inlined only
;;     when the procedure is smaller than `current-inline-threshold`;
;;   - every other call (loop breakers, unknown targets) is kept.
;; An inlined call becomes a `let` binding the parameters to the simplified
;; arguments around the simplified procedure body. The definition of an
;; inlined procedure is NOT removed in the same round.
;;
;; NOTE(review): an inlined body re-uses the procedure's own alocs, so
;; inlining the same procedure at several call sites yields several bindings
;; of the same aloc. Confirm that downstream passes tolerate this, or rename
;; the alocs freshly on each inlining.
(define (simplify p)
  (define (simplify-p p)
    (match p
      [`(module ,info ,defs ... ,tail)
       (let* ([undead-defs (filter (lambda (d) (not (eq? 'dead (info-ref info (second d)))))
                                   defs)]
              ;; Each entry is (label (lambda ...)), so it can be queried
              ;; with `info-ref` to fetch a procedure by label.
              [defs-map (map cdr undead-defs)])
         `(module ,@(map (curry simplify-def info defs-map) undead-defs)
            ,(simplify-e info defs-map tail)))]))

  (define (simplify-def info defs def)
    (match def
      [`(define ,label (lambda (,alocs ...) ,tail))
       `(define ,label (lambda ,alocs ,(simplify-e info defs tail)))]))

  ;; Heuristic size check: estimates the cost of the procedure body and
  ;; compares it against `current-inline-threshold`.
  (define (small-enough? lam)
    (define call-overhead 10)
    (define if-overhead 2)
    (define (primop-overhead p) 10)
    (define (size-of lam)
      (define (size-of-e e)
        (match e
          [`(let ([,alocs ,es] ...) ,e)
           (apply + (size-of-e e) (map size-of-e es))]
          [`(call ,label ,es ...)
           (apply + call-overhead (map size-of-e es))]
          ;; Only one branch of an `if` runs, so charge the larger one.
          [`(if ,e ,e1 ,e2)
           (+ if-overhead (size-of-e e) (max (size-of-e e1) (size-of-e e2)))]
          [`(,primop ,es ...)
           (apply + (primop-overhead primop) (map size-of-e es))]
          [_ (size-of-v e)]))
      (define (size-of-v v) 0)
      (match lam
        [`(lambda (,alocs ...) ,tail)
         ;; Inlining removes one call, so credit its overhead.
         (- (size-of-e tail)
            call-overhead)]))
    (< (size-of lam) (current-inline-threshold)))

  (define (simplify-e info defs e)
    (match e
      [`(let ([,alocs ,es] ...) ,e)
       `(let ,(for/list ([aloc alocs]
                         [e es])
                `[,aloc ,(simplify-e info defs e)])
          ,(simplify-e info defs e))]
      [`(call ,label ,es ...)
       (let ([flag (info-ref info label #f)]
             [rhs (info-ref defs label #f)])
         ;; Inlining pairs parameters with arguments positionally. If the
         ;; counts differ, arguments would be dropped silently or parameters
         ;; left unbound, so such calls (and calls without a known
         ;; definition) are never inlined.
         (define (arity-matches?)
           (match rhs
             [`(lambda (,alocs ...) ,_) (= (length alocs) (length es))]
             [_ #f]))
         (define (worth-inlining?)
           (match flag
             ;; unconditionally inline
             ['once-safe #t]
             ['multi-safe (small-enough? rhs)]
             ['once-unsafe
              ;; Normally, would guard to ensure it's not a large expression
              ;; We're only inlining functions, so this is fine.
              #t]
             ['multi-unsafe
              ;; Similarly, we would guard to ensure it's not a large expression
              ;; We're only inlining functions, so this is fine.
              (small-enough? rhs)]
             [_ #f]))
         (define (do-inline)
           (match rhs
             [`(lambda (,alocs ...) ,tail)
              `(let ,(for/list ([aloc alocs]
                                [e es])
                       `[,aloc ,(simplify-e info defs e)])
                 ,(simplify-e info defs tail))]))
         (if (and (arity-matches?) (worth-inlining?))
             (do-inline)
             `(call ,label ,@(map (curry simplify-e info defs) es))))]
      [`(,primop ,es ...)
       `(,primop ,@(map (curry simplify-e info defs) es))]
      [_ (simplify-v info defs e)]))

  ;; Values are left untouched; hook for future value-level simplification.
  (define (simplify-v info defs v)
    v)

  (simplify-p p))
;; Smoke test for a single round of `simplify` on a module whose only
;; procedure is called once from the module body. The result is only
;; evaluated; no rackunit check is applied to it.
(module+ test
(simplify
(occurance-analysis
(analyze-dependencies
(uniquify
'(module
(define id
(lambda (x)
x))
(call id 5)))))))
;; until-fix : (X -> X) [Integer] -> (X -> X)
;; Curried: `((until-fix f bound) p)` applies `f` repeatedly, starting from
;; `p`, until the result is `equal?` to its input or `f` has been applied
;; `bound` times (default 5), and returns the last value obtained.
(define ((until-fix f [bound 5]) p)
  (let iterate ([current p]
                [remaining bound])
    (cond
      ;; Budget exhausted: give back whatever we have reached so far.
      [(zero? remaining) current]
      [else
       (define next (f current))
       (if (equal? next current)
           next
           (iterate next (sub1 remaining)))])))
;; Exprs-unique-lang -> Exprs-unique-lang
;; One complete inlining round. `compose` applies right-to-left, so the order
;; is: `analyze-dependencies`, then `occurance-analysis`, then `simplify`.
;; Wrap in `until-fix` to repeat the round until the program stops changing.
(define simplify-exprs (compose simplify occurance-analysis analyze-dependencies))
;; Smoke tests for the iterated inliner. The results are only evaluated; no
;; rackunit check is applied to them.
(module+ test
;; Non-recursive procedure called once.
((until-fix simplify-exprs)
(uniquify
'(module
(define id
(lambda (x)
x))
(call id 5))))
;; Mutually recursive procedures.
((until-fix simplify-exprs)
(uniquify
'(module
(define odd?
(lambda (x)
(if (call eq? x 0)
0
(let ([y (call + x -1)])
(call even? y)))))
(define even?
(lambda (x)
(if (call eq? x 0)
1
(let ([y (call + x -1)])
(call odd? y)))))
(call even? 5))))
;; Baseline: the same program run through a pass list WITHOUT the inliner
;; (`compile` and `current-pass-list` are not defined in this file;
;; presumably they come from cpsc411/compiler-lib).
(parameterize ([current-pass-list
(list
check-exprs-lang
uniquify
implement-safe-primops)])
(compile
'(module
(define odd?
(lambda (x)
(if (call eq? x 0)
0
(let ([y (call + x -1)])
(call even? y)))))
(define even?
(lambda (x)
(if (call eq? x 0)
1
(let ([y (call + x -1)])
(call odd? y)))))
(call even? 5))))
;; Same pass list with the iterated inliner appended as the final pass.
(parameterize ([current-pass-list
(list
check-exprs-lang
uniquify
implement-safe-primops
(until-fix simplify-exprs))])
(compile
'(module
(define odd?
(lambda (x)
(if (call eq? x 0)
0
(let ([y (call + x -1)])
(call even? y)))))
(define even?
(lambda (x)
(if (call eq? x 0)
1
(let ([y (call + x -1)])
(call odd? y)))))
(call even? 5)))))
;; DISABLED (the `#;` prefix comments out the whole form).
;; Draft that builds a graph over the vertices of `og` and, for every edge
;; v -> u of `og`, adds the reversed edge u -> v.
;; NOTE(review): the outer `for/fold` body is the inner `for/fold`, whose
;; accumulator starts from the outer `g`, so the reversed edges do accumulate
;; across vertices.
#;(define (transpose-graph og)
(for/fold ([g (new-graph (map car og))])
([v (map car og)])
(for/fold ([g g])
([u (get-neighbors og v)])
(add-directed-edge g u v))))
;; DISABLED (the `#;` prefix comments out the whole form).
;; Unfinished depth-first-search draft. It would not compile if re-enabled:
;; `vs` is never bound, the `if` has four sub-forms, and the trailing
;; `for/fold` has no body.
#;(define (dfs start g)
(let dfs ([start start]
[visit (list start)])
(if (empty? vs)
visit
(dfs (car vs) (cons (car vs) visit))
(for/fold ([])))))
;; DISABLED (the `#;` prefix comments out the whole form).
;; Earlier, unfinished draft of a stand-alone loop-breaker analysis. It would
;; not run if re-enabled as written:
;;   - `analyze-p` calls `build-depdenency-graph`, but the helper defined
;;     here is spelled `build-dependency-graph`;
;;   - `build-dependency-graph` calls `add-directed-edges` with two arguments
;;     and never passes the graph accumulator `g`;
;;   - `get-subgraph` refers to `h`, which is never bound;
;;   - `pick-loop-breaker` takes two parameters but is called with one;
;;   - `analyze-p` is defined but never invoked, so the outer definition has
;;     no result expression.
;; Exprs-unique-lang/dependencies -> Exprs-unique-lang/loops-breakers
#;(define (analyze-loops p)
(define (fg->g fg)
(let ([g (directed-graph '())])
(for-each (curry add-vertex! g) (map car fg))
(for ([v (map car fg)])
(for ([u (get-neighbors fg v)])
(add-directed-edge! g v u)))
g))
(define (build-dependency-graph main-deps defs)
(match defs
[`((define ,labels ,deps _) ...)
(for/fold ([g (add-directed-edges
(new-graph (cons 'main labels))
; Main depends on itself
'main (cons 'main main-deps))])
([l labels]
[d deps])
(add-directed-edges l d))]))
(define (analyze-p p)
(match p
[`(module ,info ,defs ... ,tail)
;; First, a graph
(let* ([g (build-depdenency-graph info defs)]
;; Next, we need the strongly-connected components
[sccls (scc g)])
;; Start algorithm
(analyze-loop-breakers g sccls))]))
(define (pick-loop-breaker g scc)
(first (sort scc #:key (lambda (k)
(if (eq? 1 (count k (append-map second g)))
3
0)))))
(define (get-subgraph og u)
(for/fold ([g (new-graph)])
([v (map car og)])
(if (member u (get-neighbors og v))
(add-directed-edge (add-vertex (add-vertex g u) v)
v u)
h)))
(define (remove-directed-edges-to g u)
(for/list ([als g])
(cons (car als) (remove (second als) u))))
(define (analyze-loop-breakers g sccls)
(for ([scc sccls])
(unless (and (eq? 1 (length scc))
(not (member (car scc) (get-neighbors g (car scc)))))
(pick-loop-breaker scc)))
g))
#|
Exprs-unique-lang/loops-breakers:
p ::= (module info (define label info (lambda (aloc ...) tail)) ... tail)
info ::= (lb-seq dgraph any/c ...)
dgraph ::= ((label (label ...)) ...)
lb-seq ::= (label ...) ;; ordered sequence
...
|#
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment