Skip to content

Instantly share code, notes, and snippets.

@hiredman
Created February 26, 2021 20:44
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save hiredman/5644dd40f2621b0a783a3231ea29ff1a to your computer and use it in GitHub Desktop.
Save hiredman/5644dd40f2621b0a783a3231ea29ff1a to your computer and use it in GitHub Desktop.
(ns com.manigfeald.yield
(:require [clojure.tools.analyzer :as an]
[clojure.tools.analyzer.ast :as ast]
[clojure.tools.analyzer.env :as env]
[clojure.tools.analyzer.jvm :as an-jvm]
[clojure.tools.analyzer.passes :refer [schedule]]
[clojure.tools.analyzer.passes.jvm.annotate-loops
:refer [annotate-loops]]
[clojure.tools.analyzer.passes.jvm.emit-form :as e]
[clojure.tools.analyzer.passes.jvm.warn-on-reflection
:refer [warn-on-reflection]]
[clojure.tools.analyzer.passes.collect-closed-overs
:refer [collect-closed-overs]]
[clojure.pprint :as pp]
[clojure.set :as set]
[clojure.core.memoize :as memo]))
;; Report reflective interop calls while compiling this namespace; the
;; register helpers below are type-hinted (^objects / ^long) for this reason.
(set! *warn-on-reflection* true)
(defn i
  "Register access on the object array `a`.
  With two arguments, returns the value held in slot `b`.
  With three arguments, stores `c` into slot `b` and returns `c`."
  ([^objects a ^long b]
   (let [slot (int b)]
     (aget a slot)))
  ([^objects a ^long b c]
   (let [slot (int b)]
     (aset a slot c))))
(defn c
  "Reads slot `b` of the register array `a`, sets that slot to nil, and
  returns the value that was read, so the array no longer references it."
  [^objects a ^long b]
  (let [value (aget a b)]
    (aset a b nil)
    value))
(defn make-locals
  "Allocates the register array through clojure.lang.RT/object_array (the
  call `object-array` makes); a number `c` yields an Object[] of `c` nil
  slots."
  [c]
  (clojure.lang.RT/object_array c))
;; TODO: type hints
;; TODO: resets close over values
;; TODO: link to dataflow analysis slides
;; TODO: exception handlers need to copy exception
;; TODO: pull out custom passes
;; TODO: use traversal
;; TODO: build in trampoline
;; reset/shift http://pllab.is.ocha.ac.jp/~asai/cw2011tutorial/main-e.pdf
;; Marker var for the shift operator. It is deliberately left unbound: the
;; reset transformation recognises calls whose :fn is this var (soup* :invoke
;; and postwalk-transforms compare against #'shift) and rewrites them, so
;; calling it at runtime outside of reset fails on the unbound var.
(def shift)
;; Name recorded as the :bind of nodes whose incoming value is never read;
;; map-locals-to-registers never gives it a register.
(def dummy-bind (gensym '_))
;; Well-known register names. map-locals-to-registers pins handler-bind to
;; slot 0 (current exception handler block id), exception-bind to slot 1
;; (pending Throwable) and trampoline-bind to slot 2 (pending shift thunk);
;; runner and emit* access those slots by number.
(def exception-bind (gensym 'ex))
(def handler-bind (gensym 'handler))
(def trampoline-bind (gensym 'trampoline))
;; clojure code is run through tools.analyzer which does the hard work
;; 1. alpha renaming
;; 2. type analysis
;; 3. macro expansion
;; 4. well formed code
;; 5. free variables
;; 6. identify loop bindings
;; 7. taint (expression contains a shift)
;; 8. etc
;; https://llvm.org/devmtg/2016-09/slides/Wingo-PersistentDataStructures.pdf
;; the ast from tools.analyzer is transformed into a CPS soup: a
;; graph of labels to nodes, where each node explicitly knows at least 2
;; possible continuations:
;; 1. normal-k the next instruction in normal flow
;; 2. exception-k the exception handler to jump to
;; some nodes have more/other continuations: branching, exception handling
;; a third continuation recur-k is threaded through the transformation
;; the soup can also be viewed as a degenerate case of basic blocks,
;; where each block is a single instruction, so we can use classic basic
;; block dataflow analysis on it. dataflow analysis tells us when a
;; local is unused so it can be cleared.
;; the graph of labels to nodes is turned into a sequence of
;; instruction blocks by grouping instructions that follow each
;; other into blocks. blocks are split where required:
;; if an instruction is the target of a jmp or a branch it has to be
;; the first instruction in a block.
;; if an instruction is a jmp, branch, shift, or return, it has to be the last
;; instruction in a block
;; if an instruction has a different exception handler than the
;; instruction before it, then it has to be the first instruction in a
;; block
;; the generated code is modeled on a SSA style register machine,
;; where each register is a slot in an array and the array has one
;; slot for every local in the transformed code. on the last read of a
;; register it is nil'ed out.
;; soup (defined below) is the entry point; soup* does the per-op expansion.
(declare soup)
;; Expands one tools.analyzer AST node into soup nodes, dispatching on :op.
;; Arguments: the soup map built so far, the label for this node's first
;; instruction, the AST, the normal / exception / recur continuation labels,
;; and normal-name, recorded as the :bind of the first node generated.
;; Returns the updated soup map.
(defmulti soup* (fn [s label ast normal-k exception-k recur-k normal-name] (:op ast)))
;; values
;; TODO host-interop, instance-call, instance-field, instance?,
;; keyword-invoke, letfn, map, new, primitive-invoke, protocol-invoke,
;; quote, reify, set, set!, static-field, the-var, throw, vector,
;; with-meta
(defmethod soup* :const [s label ast normal-k exception-k recur-k normal-name]
  ;; A literal becomes one node yielding (:val ast), then continuing normally.
  (let [node {:type :const
              :value (:val ast)
              :bind normal-name
              :normal-k normal-k
              :exception-k exception-k
              :ks [:normal-k :exception-k]}]
    (assoc s label node)))
(defmethod soup* :var [s label ast normal-k exception-k recur-k normal-name]
  ;; A var reference is emitted as a :const node whose value is the var's
  ;; source form, which the generated code evaluates.
  (assoc s label (hash-map :type :const
                           :value (:form ast)
                           :bind normal-name
                           :normal-k normal-k
                           :exception-k exception-k
                           :ks [:normal-k :exception-k])))
(defmethod soup* :local [s label ast normal-k exception-k recur-k normal-name]
  ;; Reading a local: record which name is read and keep the analyzer's :tag.
  (let [{local-name :name tag :tag} ast]
    (assoc s label {:type :local
                    :name local-name
                    :tag tag
                    :bind normal-name
                    :normal-k normal-k
                    :exception-k exception-k
                    :ks [:normal-k :exception-k]})))
(defn help-bind
  "Chains a sequence of binding maps (each with :name and :init) into the
  soup `s`. Each :init is souped at its own label and continues to a fresh
  label; that following node records the binding's :name as its :bind, so
  the init's value lands under that name.

  Returns {:s soup :me label :name sym :next label}: :me is the label where
  the code after the bindings starts, :name is the :bind for that node (the
  last binding's name, or normal-name when there are no bindings), and :next
  is a fresh, unused label. Inits are souped with recur-k nil."
  [s bindings normal-name label exception-k]
  (let [initial {:me label
                 :name normal-name
                 :next (gensym 'label)
                 :s s}
        step (fn [acc binding]
               (let [here (:me acc)
                     target (:next acc)]
                 {:me target
                  :next (gensym 'label)
                  :name (:name binding)
                  :s (soup (:s acc) here (:init binding) target
                           exception-k nil (:name acc))}))]
    (reduce step initial bindings)))
(defmethod soup* :static-call [s label ast normal-k exception-k recur-k normal-name]
  ;; Evaluate every argument into its own fresh name (via help-bind), then
  ;; perform the static call on those names.
  (let [arg-syms (repeatedly (count (:args ast)) #(gensym 'arg))
        {:keys [s me name]} (help-bind s
                                       (map (fn [arg-sym init]
                                              {:op :binding
                                               :name arg-sym
                                               :init init})
                                            arg-syms
                                            (:args ast))
                                       normal-name
                                       label
                                       exception-k)]
    (assoc s me {:type :static-call
                 :bind name
                 :class (:class ast)
                 :method (:method ast)
                 :args arg-syms
                 :ks [:normal-k :exception-k]
                 :normal-k normal-k
                 :exception-k exception-k})))
(defmethod soup* :invoke [s label ast normal-k exception-k recur-k normal-name]
  (if (and (= :var (:op (:fn ast)))
           (= #'shift (:var (:fn ast))))
    ;; (shift f): soup the single argument; its value flows into a :shift
    ;; node, which emit* compiles into a call of f with a continuation that
    ;; resumes at normal-k.
    (let [_ (assert (= 1 (count (:args ast))))
          arg-name (gensym 'arg)
          shift-label (gensym 'label)
          s (soup s
                  label
                  (first (:args ast))
                  shift-label
                  exception-k
                  recur-k
                  normal-name)]
      (assoc s shift-label {:type :shift
                            :bind arg-name
                            :arg arg-name
                            :ks [:normal-k :exception-k]
                            :exception-k exception-k
                            :normal-k normal-k}))
    ;; Ordinary call: bind the fn expression and each argument, in order,
    ;; to fresh names, then invoke.
    (let [arg-syms (repeatedly (inc (count (:args ast))) #(gensym 'arg))
          {:keys [s me name]} (help-bind s
                                         (map (fn [arg-sym init]
                                                {:op :binding
                                                 :name arg-sym
                                                 :init init})
                                              arg-syms
                                              (cons (:fn ast) (:args ast)))
                                         normal-name
                                         label
                                         exception-k)]
      (assoc s me {:type :invoke
                   :bind name
                   :fn (first arg-syms)
                   :args (rest arg-syms)
                   :ks [:normal-k :exception-k]
                   :normal-k normal-k
                   :exception-k exception-k}))))
;; control flow
;; TODO case, try
(defmethod soup* :do [s label ast normal-k exception-k recur-k normal-name]
  ;; Statements run only for effect, so each is bound to dummy-bind; the :ret
  ;; expression then continues with the do's own continuations.
  (let [statements (for [stmt (:statements ast)]
                     {:op :binding
                      :name dummy-bind
                      :init stmt})
        {s' :s body-label :me incoming :name}
        (help-bind s statements normal-name label exception-k)]
    (soup s' body-label (:ret ast) normal-k exception-k recur-k incoming)))
(defmethod soup* :if [s label ast normal-k exception-k recur-k normal-name]
  ;; test -> :branch node -> then / else; both branches continue to normal-k.
  ;; The test's value is received by the branch node under test-name.
  (let [test-name (gensym 'test)
        then-label (gensym 'label)
        else-label (gensym 'label)
        branch-label (gensym 'label)]
    (-> s
        (soup label (:test ast) branch-label exception-k recur-k normal-name)
        (soup then-label (:then ast) normal-k exception-k recur-k dummy-bind)
        (soup else-label (:else ast) normal-k exception-k recur-k dummy-bind)
        (assoc branch-label {:type :branch
                             :bind test-name
                             :test-name test-name
                             :ks [:normal-k :exception-k :else-k]
                             :exception-k exception-k
                             :normal-k then-label
                             :else-k else-label}))))
(defmethod soup* :recur [s label ast normal-k exception-k recur-k normal-name]
  ;; recur compiles to: evaluate every new value into a temp, copy the temps
  ;; into the loop's locals, then :jmp to recur-k (the loop body's label).
  ;; The double bind is required because the expressions passed to recur
  ;; could refer to the locals being rebound, so all results must be
  ;; computed before any loop local is overwritten.
  ;; reset attaches ::loop-bindings to every node (nil when there is no
  ;; enclosing loop), so the check must be on the value, not the key.
  (let [_ (assert (some? (::loop-bindings ast)) "recur outside of a loop")
        _ (assert (= (count (:exprs ast))
                     (count (::loop-bindings ast))))
        temps (repeatedly (count (::loop-bindings ast)) gensym)
        {:keys [s me name]} (help-bind s
                                       (map (fn [exp tmp]
                                              {:op :binding
                                               :name tmp
                                               :init exp})
                                            (:exprs ast)
                                            temps)
                                       normal-name
                                       label
                                       exception-k)
        {:keys [s me name]} (help-bind s
                                       (map (fn [tmp loop-binding]
                                              {:op :binding
                                               :name (:name loop-binding)
                                               :init {:op :local
                                                      :name tmp}})
                                            temps
                                            (::loop-bindings ast))
                                       name
                                       me
                                       exception-k)]
    (assoc s me {:type :jmp
                 :normal-k recur-k
                 :ks [:normal-k :exception-k]
                 :exception-k exception-k
                 :bind name})))
;; dataflow
(defmethod soup* :loop [s label ast normal-k exception-k recur-k normal-name]
  ;; Like :let, but the body's recur-k is the body's own first label, so a
  ;; :recur inside it jumps back there after rebinding the loop locals.
  (let [bound (help-bind s (:bindings ast) normal-name label exception-k)
        body-label (:me bound)]
    (soup (:s bound) body-label (:body ast) normal-k exception-k body-label (:name bound))))
(defmethod soup* :let [s label ast normal-k exception-k recur-k normal-name]
  ;; Bind the locals in order; the body then runs with the let's own
  ;; continuations, including any enclosing recur-k.
  (let [{body-label :me bound-s :s incoming :name}
        (help-bind s (:bindings ast) normal-name label exception-k)]
    (soup bound-s body-label (:body ast) normal-k exception-k recur-k incoming)))
(defn soup
  "Adds the instructions for `ast` to the soup map `s`, starting at `label`.
  Tainted sub-trees (see postwalk-transforms), locals, constants and vars
  are expanded by soup*; any other sub-tree is kept whole as an opaque :form
  node holding its emitted Clojure form and its free locals (:closed-overs)."
  [s label ast normal-k exception-k recur-k normal-name]
  (if (or (::tainted? ast)
          (#{:local :const :var} (:op ast)))
    (soup* s label ast normal-k exception-k recur-k normal-name)
    (assoc s label {:type :form
                    :value (e/emit-hygienic-form ast)
                    :closed-overs (::free ast)
                    :bind normal-name
                    :normal-k normal-k
                    :exception-k exception-k
                    :ks [:normal-k :exception-k]})))
(defn successors
  "Continuation labels of node `label` (in the order of its :ks) that are
  themselves nodes of soup `s`; exits such as 'ok-exit are left out.
  Returns nil when `label` is not a node of `s`."
  [s label]
  (when-let [node (get s label)]
    (for [k-key (:ks node)
          :let [target (get node k-key)]
          :when (contains? s target)]
      target)))
(defn calculate-predecessors
  "Annotates every node of soup `s` with :pre, the set of labels of the
  nodes that list it among their successors. Nodes nobody points at get no
  :pre key."
  [s]
  (let [edges (for [from (keys s)
                    to (successors s from)]
                [from to])]
    (reduce (fn [acc [from to]]
              (update-in acc [to :pre] (fnil conj #{}) from))
            s
            edges)))
(defn uses
  "Names read by node `label` of soup `s`, excluding the node's own :bind
  (which is treated as defined on entry to the node). Returns nil when the
  label is not a node of `s`."
  [s label]
  (when-let [node (get s label)]
    (let [read (case (:type node)
                 (:const :jmp) #{}
                 :local #{(:name node)}
                 :shift #{(:arg node)}
                 :branch #{(:test-name node)}
                 :static-call (set (:args node))
                 :invoke (conj (set (:args node)) (:fn node))
                 :form (:closed-overs node #{}))]
      (disj read (:bind node)))))
(defn defs
  "The one-element set holding the :bind of node `label`, or nil when the
  label is not a node of soup `s`."
  [s label]
  (if-let [node (get s label)]
    (hash-set (:bind node))
    nil))
(defn dataflow
  "Liveness analysis over soup `s`. Repeatedly sweeps every node, setting
    out = old out ∪ (:in of each successor)
    in  = old in ∪ uses ∪ (out − defs)
  until a full sweep changes nothing, then returns the annotated soup."
  [s]
  (loop [soup s]
    (let [visit (fn [[g changed?] label]
                  (let [node (get g label)
                        out (apply set/union
                                   (:out node #{})
                                   (map #(get-in g [% :in] #{})
                                        (successors g label)))
                        in (set/union (:in node #{})
                                      (uses g label)
                                      (set/difference out (defs g label)))]
                    (if (and (= out (:out node))
                             (= in (:in node)))
                      [g changed?]
                      [(assoc g label (assoc node :in in :out out)) true])))
          [updated changed?] (reduce visit [soup false] (keys soup))]
      (if changed?
        (recur updated)
        soup))))
;; Instruction types that transfer control and therefore end a block.
(def control-flow? (set [:jmp :shift :branch :case]))
;; Instruction types that produce a value and may fall through to the next
;; instruction of the same block.
(def values? (set [:form :local :const :static-call :invoke]))
(defn traversal
  "Returns a reducible (clojure.lang.IReduceInit) over the nodes of soup `s`
  reachable from `start`, in depth-first order. Each node is handed to the
  reducing function with its label assoc'ed under :label. The pseudo-labels
  'ok-exit and 'exception-exit are never visited. From each node, :normal-k
  is explored first, then the labels named by its :ks in order.

  Honors `reduced`: the walk stops as soon as the reducing function returns
  a reduced value."
  [start s]
  (reify
    clojure.lang.IReduceInit
    (reduce [_ fun init]
      ;; The pending labels are kept in a persistent list (top = next label).
      ;; Pushing with `into` instead of nesting `concat`s keeps each pop O(1)
      ;; and avoids building deep lazy-seq chains on large graphs.
      (loop [acc init
             stack (list start)
             visited #{'ok-exit 'exception-exit}]
        (cond
          (reduced? acc) @acc
          (empty? stack) acc
          :else
          (let [label (peek stack)
                stack (pop stack)]
            (if (visited label)
              (recur acc stack visited)
              (let [node (get s label)
                    succ (cons (:normal-k node) (map node (:ks node)))]
                (recur (fun acc (assoc node :label label))
                       ;; conj onto a list prepends, so reverse first to keep
                       ;; succ's order on top of the stack
                       (into stack (reverse succ))
                       (conj visited label))))))))))
;; TODO: fold this into emit
(defn- close-block
  "Ends a block of value instructions whose last instruction does not fall
  through to the instruction that follows it. Appends an explicit :jmp to
  that instruction's :normal-k, unless the continuation is the 'ok-exit
  pseudo-label: then the block just ends, and its last value is what the
  generated case branch returns."
  [block]
  (let [last-instr (peek block)]
    (if (= 'ok-exit (:normal-k last-instr))
      block
      (conj block {:type :jmp
                   :normal-k (:normal-k last-instr)
                   :exception-k (:exception-k last-instr)}))))

(defn instruction-seqeuences
  "Lazily splits the depth-first instruction order produced by `traversal`
  into blocks (vectors of instructions). A block is a run of value
  instructions (values?) that fall through to one another, optionally
  ended by one control-flow instruction (control-flow?).

  A new block is started before instruction i when:
    - i has more than one predecessor (:pre), so it may be jumped to;
    - the previous instruction's :normal-k is not i, i.e. i is not its
      fall-through (an else branch, or code following an exit / a join
      that was already emitted); or
    - i has a different :exception-k than the previous instruction
      (emit* takes a block's handler from its first instruction).
  A block cut short this way is finished by close-block, with an explicit
  :jmp or with nothing (for 'ok-exit). Instruction types that are neither
  values? nor control-flow? fail an assertion."
  [instructions]
  (lazy-seq
   (when (seq instructions)
     (loop [accum []
            [i & is :as remaining] instructions]
       (let [prev (peek accum)]
         (cond
           ;; input exhausted: the accumulated values form the final block
           (empty? remaining)
           (list (close-block accum))

           (and prev
                (or (> (count (:pre i)) 1)
                    (not= (:normal-k prev) (:label i))
                    (not= (:exception-k prev) (:exception-k i))))
           (cons (close-block accum) (instruction-seqeuences remaining))

           (control-flow? (:type i))
           (cons (conj accum i) (instruction-seqeuences is))

           (values? (:type i))
           (recur (conj accum i) is)

           :else
           (assert nil)))))))
;; we end up with two frames accumulating on the stack per
;; continuation call :(
(defn runner
  "Drives the compiled state machine `f` (a fn of the register array and a
  block id), starting at block `label`.
  - If running a block throws, the Throwable is stored in register 1 and
    execution continues at the block id in register 0 (the current
    exception handler); when register 0 is nil the Throwable is rethrown.
  - Otherwise, if register 2 holds a thunk (left there by a shift), it is
    cleared and called and its result returned; else the block's result is
    returned."
  [locals f ^long label]
  (assert f)
  (let [failed (Object.)]
    (loop [block label]
      (let [result (try
                     (assert block)
                     (f locals block)
                     (catch Throwable t
                       (i locals 1 t)
                       failed))]
        (cond
          (not (identical? result failed))
          (if (i locals 2)
            ((c locals 2))
            result)

          (i locals 0)
          (recur (long (i locals 0)))

          :else
          (throw (i locals 1)))))))
(defn value-production
  "Returns the Clojure form that computes the value of value instruction `k`.
  `locals` maps names to register indices. A register read is emitted as
  (c locals-sym idx) when idx is in `clearable` (the slot is nil'ed on this
  last read), else as (i locals-sym idx)."
  [k locals-sym locals clearable]
  (letfn [(read-reg [idx]
            (if (clearable idx)
              `(c ~locals-sym ~idx)
              `(i ~locals-sym ~idx)))]
    (case (:type k)
      :const
      (:value k)

      :local
      (do
        (assert (get locals (:name k)) (:name k))
        (read-reg (get locals (:name k))))

      :invoke
      (let [f-id (get locals (:fn k))]
        (assert f-id)
        (cons (read-reg f-id)
              (for [arg (:args k)]
                (read-reg (get locals arg)))))

      :static-call
      `(. ~(:class k)
          ~(cons (:method k)
                 (for [arg (:args k)]
                   (read-reg (get locals arg)))))

      :form
      (if-let [names (seq (:closed-overs k))]
        ;; rebind the free locals from their registers around the form
        `(let [~@(mapcat (fn [n] [n (read-reg (get locals n))]) names)]
           ~(:value k))
        (:value k)))))
(defn bind
  "Appends `value` to the block under construction (:block of `y`). When
  `binding-id` is a register index the value is wrapped in a register write,
  (i locals-sym binding-id value); otherwise the bare form is appended.
  `clearable` is accepted but not used."
  [y locals-sym binding-id clearable value]
  (update y :block conj (if binding-id
                          `(i ~locals-sym ~binding-id ~value)
                          value)))
(defn clearable
  "Names whose registers may be cleared at node `k`: those live on entry
  (:in) but not on exit (:out), minus dummy-bind, plus the node's own :bind
  whenever it is not live on exit."
  [k]
  (let [in (:in k #{})
        out (:out k #{})
        own (:bind k)
        dying (disj (set/difference in out) dummy-bind)]
    (if (contains? out own)
      dying
      (conj dying own))))
(defn map-locals-to-registers
  "Assigns a register index to every name defined or used in `soup`.
  handler-bind, exception-bind and trampoline-bind are pinned to 0, 1 and 2;
  other names get consecutive indices in order of first appearance while
  walking (keys soup). dummy-bind never gets a register."
  [soup]
  (let [names (mapcat (fn [label]
                        (into (defs soup label) (uses soup label)))
                      (keys soup))
        assign (fn [registers n]
                 (if (or (= n dummy-bind)
                         (contains? registers n))
                   registers
                   (assoc registers n (count registers))))]
    (reduce assign
            {handler-bind 0
             exception-bind 1
             trampoline-bind 2}
            names)))
(defn emit*
  "Compiles the dataflow-annotated soup `s` into a single Clojure form.

  The form allocates a register array bound to `locals-sym` (one slot per
  name, see map-locals-to-registers), copies the closed-over locals (the :in
  set of the 'start node) into their registers, and passes a state-machine
  fn named `fn-sym` to `runner`, starting at block id 1. The state machine is
  a `case` on `label-sym` with one branch per instruction block (see
  instruction-seqeuences); blocks transfer control by `recur`ing with another
  block id. Block id 0 is the 'exception-exit handler, which rethrows the
  exception held in register 1.

  A :shift stores, in the trampoline register, a thunk that calls the shift's
  argument with a one-shot continuation. Calling that continuation with a
  value re-enters the state machine at the shift's :normal-k, or, when
  :normal-k is not a node of `s` (reset's 'ok-exit), just returns the value."
  [locals-sym fn-sym label-sym s]
  (let [locals (map-locals-to-registers s)
        case-body
        (reduce
         (fn [x block]
           (let [block-id (or (get (:labels x) (:label (first block)))
                              (count (:labels x)))
                 x (assoc-in x [:labels (:label (first block))] block-id)
                 exception-id (or (get (:labels x) (:exception-k (first block)))
                                  (count (:labels x)))
                 x (assoc-in x [:labels (:exception-k (first block))] exception-id)
                 ;; every block starts by installing its exception handler's
                 ;; block id in register 0, which runner jumps to on a throw
                 x (update-in x [:block] conj `(i ~locals-sym 0 ~exception-id))
                 x (reduce
                    (fn [y k]
                      ;; clear-ids: registers read for the last time at k (the
                      ;; shared clearable rule); binding-id: register receiving
                      ;; k's value, i.e. the :bind of its normal continuation
                      ;; (nil for exits, which are not nodes of s)
                      (let [clear-ids (set (map locals (clearable k)))
                            binding-id (get locals (:bind (get s (:normal-k k))))
                            value (gensym 'value)]
                        (case (:type k)
                          ;; values
                          (:form :local :const :static-call :invoke)
                          (bind y locals-sym binding-id clear-ids
                                (value-production k locals-sym locals clear-ids))

                          ;; control flow
                          :shift
                          (let [resume-label (:normal-k k)
                                in-soup? (contains? s resume-label)
                                ;; reuse the id the label may already have (e.g.
                                ;; as the target of an earlier jmp); giving an
                                ;; existing key a fresh (count labels) would
                                ;; orphan its old id and collide with the next
                                ;; new label
                                resume-id (or (get (:labels y) resume-label)
                                              (count (:labels y)))
                                resume-body (if in-soup?
                                              (concat
                                               (when (and binding-id
                                                          (not (clear-ids binding-id)))
                                                 [`(i ~locals-sym ~binding-id ~value)])
                                               [`(runner ~locals-sym ~fn-sym ~resume-id)])
                                              ;; no block is ever emitted for an
                                              ;; exit label: the rest of the reset
                                              ;; is just the resumed value
                                              [value])]
                            (-> y
                                (cond-> in-soup?
                                  (assoc-in [:labels resume-label] resume-id))
                                (update-in [:block] conj
                                           `(i ~locals-sym
                                               ~(get locals trampoline-bind)
                                               (fn []
                                                 ((c ~locals-sym ~(get locals (:arg k)))
                                                  (^:once fn [~value]
                                                   ~@resume-body)))))))

                          :jmp
                          (let [target-block-id (or (get (:labels y) (:normal-k k))
                                                    (count (:labels y)))]
                            (-> y
                                (assoc-in [:labels (:normal-k k)] target-block-id)
                                (update-in [:block] conj `(recur ~locals-sym ~target-block-id))))

                          :branch
                          (let [then-id (or (get (:labels y) (:normal-k k))
                                            (count (:labels y)))
                                y (assoc-in y [:labels (:normal-k k)] then-id)
                                else-id (or (get (:labels y) (:else-k k))
                                            (count (:labels y)))
                                test-id (get locals (:test-name k))
                                test-exp (if (clear-ids test-id)
                                           `(c ~locals-sym ~test-id)
                                           `(i ~locals-sym ~test-id))]
                            (-> y
                                (assoc-in [:labels (:else-k k)] else-id)
                                (update-in [:block] conj `(if ~test-exp
                                                            (recur ~locals-sym ~then-id)
                                                            (recur ~locals-sym ~else-id))))))))
                    x
                    block)]
             (-> x
                 (assoc :block [])
                 (update-in [:blocks] conj block-id (cons 'do (:block x))))))
         {:labels {'exception-exit 0}
          :exception-handler 'exception-exit
          :stack nil
          :block []
          :blocks [0 `(do
                        (i ~locals-sym 0 nil)
                        ;; TODO: why does c fail here?
                        (throw (c ~locals-sym 1)))]}
         (instruction-seqeuences (into [] (traversal 'start s))))]
    `(let [~locals-sym (make-locals ~(count locals))]
       ;; closed overs
       ~@(for [v (:in (get s 'start))]
           `(i ~locals-sym ~(get locals v) ~v))
       (runner
        ~locals-sym
        (^:once fn ~fn-sym [~locals-sym ~(with-meta label-sym {:tag 'long})]
         (case ~label-sym
           ~@(:blocks case-body)))
        1))))
(defn- child-nodes
  "Direct child AST nodes of `ast`, following its :children keys and
  flattening vector-valued children."
  [ast]
  (for [child-key (:children ast)
        :let [x (get ast child-key)]
        child (if (sequential? x) x [x])]
    child))

(def postwalk-transforms
  "Applied bottom-up (ast/postwalk) to every AST node. Runs two annotations:
  first ::free, the set of local names the node reads but does not bind, and
  then ::tainted?, whether the node contains a shift or recur and so must be
  expanded by soup* instead of being emitted as an opaque form. Ops without
  a dedicated rule derive both annotations from their children, so any op
  can appear in a reset body; untainted ones become :form nodes."
  (comp
   (fn [ast]
     (case (:op ast)
       ;; a shift inside a nested fn / method belongs to that function, not
       ;; to the enclosing reset, so these never propagate taint
       (:local :const :var :fn :fn-method :method) ast
       :if (assoc ast ::tainted? (or (::tainted? (:test ast))
                                     (::tainted? (:then ast))
                                     (::tainted? (:else ast))))
       :do (assoc ast ::tainted? (reduce
                                  (fn [a b] (or a (::tainted? b)))
                                  (::tainted? (:ret ast))
                                  (:statements ast)))
       :recur (assoc ast ::tainted? true)
       (:let :loop) (assoc ast ::tainted? (reduce
                                           (fn [a b] (or a (::tainted? b)))
                                           (::tainted? (:body ast))
                                           (:bindings ast)))
       :static-call (assoc ast ::tainted? (reduce
                                           (fn [a b]
                                             (or a (::tainted? b)))
                                           false
                                           (:args ast)))
       :invoke (if (and (= :var (-> ast :fn :op))
                        (= #'shift (:var (:fn ast))))
                 (assoc ast ::tainted? true)
                 (assoc ast ::tainted? (reduce
                                        (fn [a b] (or a (::tainted? b)))
                                        (::tainted? (:fn ast))
                                        (:args ast))))
       :binding (assoc ast ::tainted? (get-in ast [:init ::tainted?]))
       ;; any other op: tainted iff one of its children is
       (assoc ast ::tainted? (boolean (some ::tainted? (child-nodes ast))))))
   (fn [ast]
     (case (:op ast)
       (:binding
        :const
        :var
        :invoke
        :static-call
        :if
        :do
        :recur) (assoc ast ::free (reduce set/union #{} (map ::free (child-nodes ast))))
       (:let :loop) (assoc ast ::free (reduce
                                       (fn [a b] (set/union
                                                  (disj a (:name b))
                                                  (::free b)))
                                       (::free (:body ast))
                                       (reverse (:bindings ast))))
       :fn-method (assoc ast ::free (set/difference (::free (:body ast)) (set (map :name (:params ast)))))
       :local (assoc ast ::free #{(:name ast)})
       ;; any other op: the union of its children's free names, minus names
       ;; bound by direct :binding children (a fn's self-name, a catch local,
       ;; letfn / method bindings)
       (let [children (child-nodes ast)
             bound (set (keep (fn [child]
                                (when (= :binding (:op child))
                                  (:name child)))
                              children))]
         (assoc ast ::free (set/difference
                            (reduce set/union #{} (map ::free children))
                            bound)))))))
;; (reset body...) delimits a region in which (shift f) may be used.
;; At macroexpansion time the body is:
;;   1. analyzed with tools.analyzer.jvm (default passes), with the macro's
;;      &env locals registered as :let bindings so they become closed-overs;
;;   2. walked top-down so every node carries the :bindings of its nearest
;;      enclosing :loop as ::loop-bindings (nil when there is none), which
;;      soup* :recur reads;
;;   3. annotated bottom-up with ::tainted? / ::free (postwalk-transforms);
;;   4. converted to soup starting at 'start, with 'ok-exit as the normal exit
;;      and 'exception-exit as the exception exit;
;;   5. given predecessors and liveness (dataflow), then compiled by emit*.
(defmacro reset [& body]
(emit*
(gensym 'locals)
(gensym 'this-fn)
(gensym 'label)
(dataflow
(calculate-predecessors
(soup
{}
'start
(-> (binding [an-jvm/run-passes (schedule an-jvm/default-passes)]
(an-jvm/analyze
`(do ~@body)
(assoc (an-jvm/empty-env)
;; locals from the macro's lexical environment
:locals (into {} (for [[n _] &env]
[n {:op :binding
:name n
:form n
:local :let}])))
{}))
;; push each node's ::loop-bindings (its own :bindings if it is a :loop)
;; down onto its direct children
(ast/prewalk
(fn [ast]
(let [ast (if (= (:op ast) :loop)
(assoc ast ::loop-bindings (:bindings ast))
ast)]
(reduce
(fn [ast childname]
(if (sequential? (get ast childname))
(assoc ast childname (mapv (fn [cast] (assoc cast ::loop-bindings (::loop-bindings ast))) (get ast childname)))
(assoc ast childname (assoc (get ast childname) ::loop-bindings (::loop-bindings ast)))))
ast
(:children ast)))))
(ast/postwalk postwalk-transforms))
;; normal-k, exception-k, recur-k (none at top level), normal-name
'ok-exit
'exception-exit
nil
dummy-bind)))))
;; (pp/pprint
;; (macroexpand
;; '(reset
;; (loop [n 0
;; n' 1]
;; (recur (shift (fn [k] [n (fn foo []
;; (k n'))]))
;; (+ n n'))))))
#_(pp/with-pprint-dispatch
pp/code-dispatch
(pp/pprint
(macroexpand
'(reset
(loop [n 0
n' 1]
(do
(try
(/ 1 0)
(catch Exception e
(println e)))
(if (> 40 n)
(recur n' (+ n (shift (fn [k] #(k n')))))
n')))))))
;; (prn
;; (trampoline
;; (fn []
;; (reset
;; (loop [n 0
;; n' 1]
;; (do
;; (println "foo")
;; (if (> 40 n)
;; (recur n' (+ n (shift (fn [k] #(k n')))))
;; n')))))))
(defn chan
  "Creates a rendezvous channel: a map of two LinkedBlockingQueues, :readers
  (continuations of parked readers, see g) and :writers (parked
  [value writer-continuation] pairs, see f)."
  []
  (let [queue #(java.util.concurrent.LinkedBlockingQueue.)]
    {:readers (queue)
     :writers (queue)}))
(defn f
  "Builds the function that !> hands to shift: it offers `value` on `chan`.
  Given the writer's continuation k:
  - if a reader is parked, returns a thunk that delivers `value` to the
    reader, runs the reader until it parks again, then resumes the writer
    with true (returning what k returns);
  - otherwise parks the writer by queueing [value k] on :writers and returns
    nil."
  [chan value]
  (fn [k]
    (if-let [reader (.poll ^java.util.concurrent.LinkedBlockingQueue (:readers chan))]
      #(do
         ;; resuming the reader runs it up to its next shift, whose result
         ;; may itself be a thunk of pending work (e.g. another handoff);
         ;; trampoline it so that work is not silently dropped
         (let [pending (reader value)]
           (when (fn? pending)
             (trampoline pending)))
         (k true))
      (.put ^java.util.concurrent.LinkedBlockingQueue (:writers chan) [value k]))))
(defn g
  "Builds the function that <! hands to shift: it takes a value from `chan`.
  Given the reader's continuation k:
  - if a writer is parked, returns a thunk that resumes the writer with
    true, runs it until it parks again, then resumes the reader with the
    writer's value (returning what k returns);
  - otherwise parks the reader by queueing k on :readers and returns nil."
  [chan]
  (fn [k]
    (assert (map? chan))
    (if-let [[v writer] (.poll ^java.util.concurrent.LinkedBlockingQueue (:writers chan))]
      #(do
         ;; resuming the writer runs it up to its next shift, whose result
         ;; may itself be a thunk of pending work; trampoline it so that
         ;; work is not silently dropped
         (let [pending (writer true)]
           (when (fn? pending)
             (trampoline pending)))
         (k v))
      (.put ^java.util.concurrent.LinkedBlockingQueue (:readers chan) k))))
(defmacro !>
  "Inside a reset / go body: offers `value` on `chan`, parking until a reader
  takes it; evaluates to true once resumed."
  [chan value]
  (list `shift (list `f chan value)))
(defmacro <!
  "Inside a reset / go body: takes a value from `chan`, parking until a
  writer offers one."
  [chan]
  (list `shift (list `g chan)))
(defmacro go
  "Runs `body` inside reset on the calling thread, via trampoline, until it
  first parks or finishes, and returns a fresh channel on which the body's
  final value is offered with !>."
  [& body]
  (let [ch (gensym "c")]
    `(let [~ch (chan)]
       (trampoline (fn [] (reset (!> ~ch (do ~@body)))))
       ~ch)))
;; (def ch (chan))
;; (pp/with-pprint-dispatch
;; pp/code-dispatch
;; (pp/pprint
;; (an-jvm/macroexpand-all
;; '(go
;; (loop [n 0
;; n' 1]
;; (!> ch n)
;; (recur n' (+ n n')))))))
;; Demo, evaluated when this namespace is loaded: a producer go block offers
;; the Fibonacci numbers below 1e6 on a channel and a consumer go block prints
;; each value it takes. Both run cooperatively on the loading thread.
(let [numbers (chan)]
  (go
    (loop [a 0
           b 1]
      (when (> 1e6 a)
        (!> numbers a)
        (recur b (+ a b)))))
  (go
    (while true
      (println (<! numbers)))))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment