Skip to content

Instantly share code, notes, and snippets.

@phipsgabler
Last active March 1, 2020 18:16
Show Gist options
  • Save phipsgabler/dd6bd32144fd5f6e0584d8cbeec2fcc0 to your computer and use it in GitHub Desktop.
Save phipsgabler/dd6bd32144fd5f6e0584d8cbeec2fcc0 to your computer and use it in GitHub Desktop.
using IRTools
"""
    AbstractNode

Supertype of every node in the call tree recorded by `track`.
"""
abstract type AbstractNode end

"""
    Node(ir, children)
    Node(ir)

A recorded call to a traced (non-primitive) function.

- `ir`: the IR of the called method, as stored by `transform` (a copy taken before
  rewriting).
- `children`: the nodes of the tracked calls executed inside that method, appended in
  execution order. Defaults to an empty vector.
"""
struct Node <: AbstractNode
    ir::IRTools.IR
    # Concrete element type so field access is type-stable. `transform` pushes both
    # `Node`s and `PrimitiveNode`s here, so a narrower vector could not hold them anyway.
    children::Vector{AbstractNode}
end

Node(ir) = Node(ir, AbstractNode[])

"""
    PrimitiveNode()

Leaf node recorded for calls to primitive functions, which `track` evaluates without
tracing into them.
"""
struct PrimitiveNode <: AbstractNode end
"""
    transform(ir)

Rewrite the IR `ir` of a method so that, besides computing its original result, it
records a call tree:

- on entry, a `Node` holding a copy of the original `ir` is created;
- every call `f(xs...)` is replaced by `track(f, xs...)`. The first element of the
  returned pair takes the place of the original value. The second element (the child
  node) is pushed onto the node's `children`;
- every `return x` becomes `return (x, node)`.

Statements that are not plain calls (e.g. `Expr(:new, ...)`, literals, meta
expressions) are left untouched, because they have no callee to hand to `track`.

Returns the rewritten IR.
"""
function transform(ir)
    # The module this file is evaluated in (`Main` when run as a script). Generated code
    # refers to `Node` and `track` through it, so the tracer also works inside a module.
    mod = @__MODULE__
    original_ir = copy(ir)
    pipe = IRTools.Pipe(ir)
    node = IRTools.pushfirst!(pipe, IRTools.xcall(mod, :Node, QuoteNode(original_ir)))
    children = IRTools.insert!(
        pipe, node, IRTools.xcall(:getproperty, node, QuoteNode(:children)); after=true)
    for (v, stmt) in pipe
        # Only calls have a callee + arguments that can be forwarded to `track`.
        Meta.isexpr(stmt.expr, :call) || continue
        call = IRTools.insert!(pipe, v, IRTools.xcall(mod, :track, stmt.expr.args...))
        childnode = IRTools.insert!(pipe, v, IRTools.xcall(:getindex, call, 2))
        push!(pipe, IRTools.xcall(:push!, children, childnode))
        pipe[v] = IRTools.xcall(:getindex, call, 1)
    end
    new_ir = IRTools.finish(pipe)
    # `node` names a pipe variable; translate it to the finished IR once, not per block.
    new_node = IRTools.substitute(pipe, node)
    for block in IRTools.blocks(new_ir)
        IRTools.isreturn(block) || continue
        result = IRTools.returnvalue(block)
        # Emit the tuple as its own statement so the return branch refers to a variable
        # rather than a nested call expression.
        retval = push!(block, IRTools.xcall(:tuple, result, new_node))
        IRTools.return!(block, retval)
    end
    return new_ir
end
"""
    PrimitiveFunction

Functions that `track` evaluates directly, recording a `PrimitiveNode` instead of
recursing into their IR.

Builtins and intrinsics (`Core.Builtin`, e.g. `getfield`, `tuple`, `typeof`,
`Core.Intrinsics.add_int`) are included because they have no Julia-level IR for the
`track` dynamo to rewrite. Without this method, `track` would return their bare result
rather than the `(result, node)` pair that the code generated by `transform` indexes
into.
"""
const PrimitiveFunction = Union{typeof(>), typeof(+), Core.Builtin}

track(f::PrimitiveFunction, args...) = f(args...), PrimitiveNode()
# `track(f, args...)` for every `f` without a more specific `track` method: an IRTools
# dynamo. Its body runs at generation time with `args` bound to the argument *types* of
# the call `f(args...)`. It returns that method's IR rewritten by `transform`, so the
# call yields `(result, node)`.
IRTools.@dynamo function track(args...)
    ir = IRTools.IR(args...)
    # No IR is available for this call. Returning `nothing` makes the dynamo fall back
    # to calling the function as-is. NOTE(review): in that case `track` yields the bare
    # result instead of a `(result, node)` pair, so such callees should be covered by a
    # `track(::PrimitiveFunction, ...)` method — confirm for new call targets.
    ir === nothing && return
    return transform(ir)
end
"""
    relu(x)

Rectified linear unit: return `x` when `x > 0`, otherwise the integer literal `0`.

Serves as a small non-primitive function in the `track` example. Note that the result
type depends on the value for non-`Int` inputs (e.g. `relu(-1.5) === 0`).
"""
function relu(x)
    if x > 0
        return x
    else
        return 0
    end
end
## `track(f, args...)` evaluates `f(args...)` and returns the result together with a `Node`,
## which contains the IR of the called method and a list of child nodes for the recursively
## tracked execution path (calls to primitive functions are recorded as `PrimitiveNode`s).
## Branches that are not taken leave no node behind.
# julia> track(x -> x > 0 ? x + 1 : relu(x), 3)
# (4, Node(1: (%1, %2)
# %3 = %2 > 0
# br 3 unless %3
# 2:
# %4 = %2 + 1
# return %4
# 3:
# %5 = Main.relu(%2)
# return %5,
# AbstractNode[PrimitiveNode(), PrimitiveNode()]))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment