@baggepinnen
Last active May 29, 2019 10:50
Compiler pass to translate branching code into a map over a function
#I would like to use Cassette.jl to implement a compiler pass. The task is to translate code that contains a branch into a `map` over a function that contains the branch, conditioned on the argument types. An example of the transformation I would like to do is from
code = Meta.@lower if x > 0
    return x^2
else
    return -x^2
end
#to
code2 = Meta.@lower map(x.particles) do x
    if x > 0
        return x^2
    else
        return -x^2
    end
end
#where the translation should only be done if `x` is of a special type defined below
struct Particles{T} <: Real
    particles::Vector{T}
end
#Without this transformation, code like the following fails predictably
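# `r::Integer` avoids a method ambiguity with Base's `^(::Number, ::Integer)`
Base.:(^)(p::Particles, r::Integer) = Particles(p.particles .^ r)
# Broadcast the comparison; `map(>, p.particles, r)` would stop after one element when `r` is a scalar
Base.:(>)(p::Particles, r) = Particles(p.particles .> r)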
p = Particles(randn(10))
function negsquare(x)
    if x > 0
        return x^2
    else
        return -x^2
    end
end
#julia> negsquare(p)
#ERROR: TypeError: non-boolean (Particles) used in boolean context
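# A minimal, self-contained sketch reproducing the failure above. `Particles` and
# the broadcasting `>` mirror the definitions earlier in this file. The names
# `negsquare_demo` and `p_demo` are hypothetical and used only for illustration.
struct Particles{T} <: Real
    particles::Vector{T}
end
Base.:(>)(p::Particles, r::Real) = Particles(p.particles .> r)

function negsquare_demo(x)
    if x > 0          # for Particles input, `x > 0` is a Particles{Bool}, not a Bool
        return x^2
    else
        return -x^2
    end
end

p_demo = Particles(randn(10))
err = try
    negsquare_demo(p_demo)
    nothing
catch e
    e
end
# `if` requires a Bool condition, so Julia throws a TypeError before `^` is ever called
println(typeof(err))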
#If the code were translated to
#function negsquare(x)
#    Particles(map(x.particles) do x
#        if x > 0
#            return x^2
#        else
#            return -x^2
#        end
#    end)
#end
#julia> negsquare(p)
#Particles([0.404953, 0.210984, -1.00176, 0.253796, -0.00620389, 0.831144, -0.0240916, -1.90169, 0.875192, 1.4788])
#I would get the desired result.
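# A runnable sketch of the hand-translated version described above. It assumes the
# goal is simply to apply the scalar branch to every sample. The names
# `negsquare_scalar`, `negsquare_mapped` and `p_demo` are hypothetical.
struct Particles{T} <: Real
    particles::Vector{T}
end

# The original branching code, written for a single scalar sample
function negsquare_scalar(x)
    if x > 0
        return x^2
    else
        return -x^2
    end
end

# The "translated" method: map the function containing the branch over the samples
negsquare_mapped(x::Particles) = Particles(map(negsquare_scalar, x.particles))
negsquare_mapped(x::Real) = negsquare_scalar(x)  # plain numbers need no translation

p_demo = Particles([-2.0, -0.5, 0.5, 2.0])
result = negsquare_mapped(p_demo)
println(result.particles)  # [-4.0, -0.25, 0.25, 4.0]
# Dispatch picks the Particles method because Particles <: Real is more specific;
# the pass being asked about would produce this shape automatically.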
#I have made some progress with Cassette so far: I can filter out code that does not operate on `Particles`, and I can identify branches in the code. However, I cannot figure out how to transform the relevant code into the `map` statement above. My attempt follows, with three open issues marked by "QUESTION" and "TODO" comments:
using Cassette
Cassette.@context Ctx # Defines the context type `Ctx` used below
contains_branch(ir::Core.CodeInfo) = any(contains_branch, ir.code)
contains_branch(ex::Expr) = ex.head == :gotoifnot # Unfortunately, lowered IR seems to have more ways of branching than this
contains_branch(_) = false
branch_target(ex) = ex.args[2] # Return the goto statement target index
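# A sketch of a version-tolerant conditional-branch detector, building on the two
# helpers above. In the Julia 1.x lowered IR of this gist's era, a conditional jump
# is `Expr(:gotoifnot, cond, dest)`; later Julia versions use a `Core.GotoIfNot`
# node instead, and unconditional jumps are `Core.GotoNode`. The names
# `is_condbranch`, `condbranch_target`, `branchy` and `branch_idx` are hypothetical.
is_condbranch(stmt) =
    (stmt isa Expr && stmt.head === :gotoifnot) ||
    (isdefined(Core, :GotoIfNot) && stmt isa Core.GotoIfNot)

# Statement index that the branch jumps to when the condition is false
condbranch_target(stmt) = stmt isa Expr ? stmt.args[2] : stmt.dest

function branchy(x)
    if x > 0
        return x^2
    else
        return -x^2
    end
end

ci = first(code_lowered(branchy, Tuple{Float64}))
branch_idx = findfirst(is_condbranch, ci.code)
println("conditional branch at statement ", branch_idx,
        ", jumping to statement ", condbranch_target(ci.code[branch_idx]))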
"My custom compiler pass"
function mapif(::Type{<:Ctx}, reflection::Cassette.Reflection)
    ir = reflection.code_info
    any(x -> x <: Particles, reflection.signature.parameters) || (return ir) # No Particles involved in this call
    contains_branch(ir) || (return ir) # If there is no branch, we leave the code alone
    stmtcount = function (stmt, i)
        contains_branch(stmt) || (return nothing)
        return 1 # QUESTION: One function call replaces the branch
    end
    newstmts = function (stmt, i)
        # The branch body starts one statement after the `gotoifnot` and ends one statement before the branch target
        @show branch_body = ir.code[i+1:branch_target(stmt)-1]
        # TODO: put the branch body into a function that is mapped over, and somehow remove all statements that were moved into that function
        [stmt] # Must have length 1, matching the count returned by `stmtcount`
    end
    # QUESTION: Is passing the entire `ir.code` the right way to make sure all `SSAValue`s are updated?
    Cassette.insert_statements!(ir.code, ir.codelocs, stmtcount, newstmts)
    ir
end
mapifpass = Cassette.@pass mapif
ctx = Ctx(pass=mapifpass)
Cassette.overdub(ctx, negsquare, p)
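# An alternative that avoids IR surgery altogether, sketched under the assumption
# that translating the whole function call (rather than only the branch) is
# acceptable: if any argument is a `Particles`, call the function once per sample
# and collect the results. This is not what the Cassette pass above does.
# `particle_call`, `sample_at` and `negsquare_lifted` are hypothetical helpers, and
# all `Particles` arguments are assumed to have the same number of samples.
struct Particles{T} <: Real
    particles::Vector{T}
end

# i-th sample of an argument; non-Particles arguments are passed through unchanged
sample_at(x, i) = x
sample_at(x::Particles, i) = x.particles[i]

function particle_call(f, args...)
    ps = [a for a in args if a isa Particles]
    isempty(ps) && return f(args...)              # nothing to translate
    n = length(first(ps).particles)
    all(q -> length(q.particles) == n, ps) ||
        throw(DimensionMismatch("all Particles arguments must have the same length"))
    # The branch inside `f` now only ever sees scalar samples
    return Particles([f((sample_at(a, i) for a in args)...) for i in 1:n])
end

function negsquare_lifted(x)
    if x > 0
        return x^2
    else
        return -x^2
    end
end

p_demo = Particles([-1.0, 3.0])
println(particle_call(negsquare_lifted, p_demo).particles)  # [-1.0, 9.0]
println(particle_call(negsquare_lifted, 2.0))              # 4.0
# The trade-off: every call on Particles runs the full function n times, even when
# only one branch actually depends on the sample values.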