Skip to content

Instantly share code, notes, and snippets.

@devmotion
Last active August 2, 2018 23:49
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save devmotion/5fef9f5a80398c5dc89511e06afc2a03 to your computer and use it in GitHub Desktop.
Save devmotion/5fef9f5a80398c5dc89511e06afc2a03 to your computer and use it in GitHub Desktop.
Julia macros
using MacroTools: namify
using Base.Meta: isexpr
## Utilities
# Split struct definition (neglecting constructors)
"""
    splitstruct(structdef)

Split a `struct`/`mutable struct` expression into a `Dict{Symbol,Any}` holding
the entries produced by `splitstructhead` (`:name`, plus `:params` and
`:supertype` when present), the `:mutable` flag and the `:fields` collected from
the struct body by `gatherfields`.

Throws an `ErrorException` if `structdef` is not a struct definition.
"""
function splitstruct(structdef)
    # `Expr(:struct, mutable::Bool, head, body)`
    isexpr(structdef, :struct, 3) || error("Not a type definition: ", structdef)
    mut, structhead, structbody = structdef.args
    dict = splitstructhead(structhead)
    dict[:mutable] = mut
    dict[:fields] = gatherfields(structbody)
    return dict
end
# Split struct head
"""
    splitstructhead(structhead)

Split the head of a struct definition, e.g. `Name`, `Name{P...}`,
`Name <: Super` or `Name{P...} <: Super`, into a `Dict{Symbol,Any}` with key
`:name`. The key `:params` (vector of parameter expressions, bounds included) is
present only if the head has parameters. The key `:supertype` is present only if
the head has a supertype annotation.

Throws an `ErrorException` if the type name is not a plain symbol.
"""
function splitstructhead(structhead)
    # Peel off an optional supertype annotation `Head <: Super`.
    if isexpr(structhead, :<:, 2)
        name_param, stype = structhead.args
    else
        name_param, stype = structhead, nothing
    end
    # Peel off optional type parameters `Name{P1, P2, ...}`.
    if isexpr(name_param, :curly) && length(name_param.args) > 1
        name, params = name_param.args[1], name_param.args[2:end]
    else
        name, params = name_param, nothing
    end
    isa(name, Symbol) || error("Not a head of a type definition: ", structhead)
    dict = Dict{Symbol,Any}(:name => name)
    if params !== nothing
        dict[:params] = params
    end
    if stype !== nothing
        dict[:supertype] = stype
    end
    return dict
end
# Collect all fields
"""
    gatherfields(ex)

Collect the field declarations of the struct body `ex`, in order: untyped fields
(`x`, a `Symbol`) and typed fields (`x::T`). Definitions inside the body, such as
inner constructors, are skipped. Returns a `Vector{Any}`.
"""
gatherfields(ex) = _gatherfields!(Any[], ex)

# Anything that is neither a symbol nor an expression (e.g. `LineNumberNode`s or
# string literals) is not a field.
_gatherfields!(fields, ex) = fields

# An untyped field such as `x`.
_gatherfields!(fields, ex::Symbol) = push!(fields, ex)

function _gatherfields!(fields, ex::Expr)
    if isexpr(ex, Symbol("::"), 2)
        # A typed field such as `x::T`.
        push!(fields, ex)
    elseif isexpr(ex, :block) || isexpr(ex, :const)
        # Descend only into grouping constructs. Recursing into `function ... end`
        # or `f(x) = ...` would collect constructor names, argument declarations
        # and body symbols as bogus fields.
        for arg in ex.args
            _gatherfields!(fields, arg)
        end
    end
    return fields
end
# Combine struct definition
"""
    combinestruct(dict)

Rebuild a struct definition from its split form:

- `dict[:mutable]` selects `mutable struct` vs `struct`;
- the head comes from `combinestructhead(dict)`;
- the body lists `dict[:fields]`, followed by the inner constructor
  `dict[:inner]` if that key is present.
"""
function combinestruct(dict::Dict)
    structhead = combinestructhead(dict)
    # Build `Expr(:struct, mutable::Bool, head, body)` directly. This is the layout
    # `splitstruct` destructures. It avoids two near-identical quoted templates
    # and keeps this file's line numbers out of the generated struct body.
    structbody = Expr(:block, dict[:fields]...)
    if haskey(dict, :inner)
        push!(structbody.args, dict[:inner])
    end
    return Expr(:struct, dict[:mutable], structhead, structbody)
end
# Combine struct head
"""
    combinestructhead(dict)

Build the head `Name{P...} <: Super` of a struct definition from `dict[:name]`
and the optional entries `dict[:params]` and `dict[:supertype]`. The supertype
defaults to `Any`. The curly braces are omitted when there are no parameters.
"""
function combinestructhead(dict::Dict)
    typename = dict[:name]
    typeparams = get(dict, :params, [])
    parenttype = get(dict, :supertype, :Any)
    if isempty(typeparams)
        return :($typename <: $parenttype)
    end
    return :($typename{$(typeparams...)} <: $parenttype)
end
# Add inner constructor
"""
    inner!(dict, n)

Store in `dict[:inner]` an inner constructor that takes the leading fields of
`dict[:fields]` as arguments and passes them to `new`:

- `n >= 0`: the first `n` fields (all of them if `n` exceeds their number);
- `n < 0`: all fields except the last `-n` (none if `-n` exceeds their number).

Fields not taken as arguments are left to `new`'s incomplete initialization.
For a parametric type the constructor has the form
`Name{P...}(args...) where {P...}`. The parameter names are obtained via
`namify`, so bounds are dropped there; they remain on the struct's own
parameter list. Returns `dict`.
"""
function inner!(dict::Dict, n::Int)
    allfields = dict[:fields]
    # Number of leading fields the constructor accepts, clamped to 0:length(allfields).
    nargs = n < 0 ? max(length(allfields) + n, 0) : min(n, length(allfields))
    fields = view(allfields, 1:nargs)
    argnames = map(namify, fields)
    paramnames = [namify(p) for p in get(dict, :params, [])]
    name = dict[:name]
    if isempty(paramnames)
        dict[:inner] =
            :(function $name($(fields...))
                  new($(argnames...))
              end)
    else
        dict[:inner] =
            :(function $name{$(paramnames...)}($(fields...)) where {$(paramnames...)}
                  new{$(paramnames...)}($(argnames...))
              end)
    end
    return dict
end
## Extend struct definition
"""
    extend!(dict, template)

Merge the split struct definition `template` into `dict` (both in the form
produced by `splitstruct`) and return `dict`:

- `:params`: the template's parameters are appended after `dict`'s own;
- `:supertype`: taken from the template only if `dict` has none;
- `:fields`: the template's fields are appended after `dict`'s own.

A template without `:fields` (such as an empty default template) contributes no
fields. The template's parameter and field vectors are not aliased into `dict`.
"""
function extend!(dict::Dict, template::Dict)
    # Merge parameters
    if haskey(template, :params)
        if haskey(dict, :params)
            # TODO: do not copy existing parameters?
            append!(dict[:params], template[:params])
        else
            # Copy: a template registered via `@base` is one shared object that
            # `_template` returns on every lookup, so it must not be aliased here.
            dict[:params] = copy(template[:params])
        end
    end
    # Merge supertypes: the struct's own annotation wins.
    if !haskey(dict, :supertype) && haskey(template, :supertype)
        dict[:supertype] = template[:supertype]
    end
    # Merge fields. Appending in place stores the result in `dict` even when the
    # struct declares no fields of its own. Previously the template fields were
    # then only bound to a local variable and lost.
    tfields = get(template, :fields, [])
    if !isempty(tfields)
        # TODO: do not copy existing fields?
        append!(dict[:fields], tfields)
    end
    return dict
end
# MACROS
## Add inner constructor
"""
    @add_inner n struct ... end

Define the struct with an additional inner constructor over its leading fields
(see `inner!` for how `n` selects them).
"""
macro add_inner(n::Int, structdef::Expr)
    return esc(add_inner(n, structdef))
end

# Split `structdef`, attach the inner constructor for `n` and reassemble it.
function add_inner(n::Int, structdef::Expr)
    return combinestruct(inner!(splitstruct(structdef), n))
end
## Struct template
# Default template
# Default template: an empty dictionary, for names that no `@base`/`@base_inner`
# definition has registered.
_template(::Val) = Dict()
# `extend` and `extend_inner` look templates up by the *type* `Val{name}`, not by
# an instance. Registered templates are likewise methods on `Type{Val{name}}`, so
# the default must also exist for types to ever be reached by those callers.
_template(::Type{<:Val}) = Dict()
"""
    @base struct ... end

Define the struct exactly as written and register its split form as a template
under the struct's name for use with `@extend`/`@extend_inner`.
"""
macro base(structdef::Expr)
    return esc(base(structdef))
end

# Emit the unchanged definition followed by a `_template` method that returns the
# split form of `structdef` when called with `Val{name}`.
# NOTE(review): the whole expression is escaped, so `_template` is resolved in the
# caller's module, while `extend` looks templates up where `extend` is defined.
# These coincide when everything lives in one module (as in this script); confirm
# before using the macros from another module.
function base(structdef::Expr)
    parts = splitstruct(structdef)
    key = Type{Val{parts[:name]}}
    return :($structdef; _template(::$key) = $parts)
end
"""
    @base_inner n struct ... end

Like `@base`, but also add an inner constructor over the struct's leading fields
(see `inner!` for how `n` selects them). The registered template includes the
`:inner` entry.
"""
macro base_inner(n::Int, structdef::Expr)
    return esc(base_inner(n, structdef))
end

# Emit the definition with the inner constructor, followed by the `_template`
# method returning its split form for `Val{name}`.
function base_inner(n::Int, structdef::Expr)
    parts = inner!(splitstruct(structdef), n)
    key = Type{Val{parts[:name]}}
    newdef = combinestruct(parts)
    return :($newdef; _template(::$key) = $parts)
end
## Extend struct definition
"""
    @extend Template struct ... end

Define the struct merged with the template registered under `Template` (see
`extend!` for the merge rules).
"""
macro extend(template::Symbol, structdef::Expr)
    return esc(extend(template, structdef))
end

# Split `structdef`, merge in the template looked up by `Val{template}` and
# reassemble the result.
function extend(template::Symbol, structdef::Expr)
    parts = splitstruct(structdef)
    extend!(parts, _template(Val{template}))
    return combinestruct(parts)
end
"""
    @extend_inner Template n struct ... end

Like `@extend`, but also add an inner constructor over the leading fields of the
*merged* field list (see `inner!` for how `n` selects them).
"""
macro extend_inner(template::Symbol, n::Int, structdef::Expr)
    return esc(extend_inner(template, n, structdef))
end

# Split, merge with the template, attach the inner constructor, then reassemble.
function extend_inner(template::Symbol, n::Int, structdef::Expr)
    merged = extend!(splitstruct(structdef), _template(Val{template}))
    return combinestruct(inner!(merged, n))
end
# EXAMPLE
# Example: register ODEIntegrator as a template and derive DDEIntegrator from it.
using DiffEqBase
using OrdinaryDiffEq: OrdinaryDiffEqAlgorithm
# Define ODEIntegrator and register it as the template `ODEIntegrator`. With n = -2
# the generated inner constructor takes every field except the last two
# (`fsalfirst`, `fsallast`), which `new` leaves uninitialized.
@base_inner -2 mutable struct ODEIntegrator{algType<:OrdinaryDiffEqAlgorithm,uType,tType,pType,eigenType,QT,tdirType,ksEltype,SolType,F,CacheType,O,FSALType} <: DiffEqBase.AbstractODEIntegrator
sol::SolType
u::uType
k::ksEltype
t::tType
dt::tType
f::F
p::pType
uprev::uType
uprev2::uType
tprev::tType
alg::algType
dtcache::tType
dtchangeable::Bool
dtpropose::tType
tdir::tdirType
eigen_est::eigenType
EEst::QT
qold::QT
q11::QT
erracc::QT
dtacc::tType
success_iter::Int
iter::Int
saveiter::Int
saveiter_dense::Int
cache::CacheType
kshortsize::Int
force_stepfail::Bool
last_stepfail::Bool
just_hit_tstop::Bool
event_last_time::Bool
accept_step::Bool
isout::Bool
reeval_fsal::Bool
u_modified::Bool
opts::O
fsalfirst::FSALType
fsallast::FSALType
end
# Only print the expansion of the derived struct; it is not evaluated, so
# `AbstractDDEIntegrator` and `Discontinuity` (not defined in this file) are not
# needed here. The expansion lists DDEIntegrator's own parameters and fields
# first, followed by ODEIntegrator's; the own supertype annotation is kept. With
# n = -2 the inner constructor omits the last two merged fields, i.e.
# ODEIntegrator's `fsalfirst` and `fsallast`.
@show @macroexpand(@extend_inner ODEIntegrator -2 mutable struct DDEIntegrator{absType,relType,residType,IType,NType,tstopsType} <: AbstractDDEIntegrator
prev_idx::Int
prev2_idx::Int
fixedpoint_abstol::absType
fixedpoint_reltol::relType
resid::residType # This would have to resize for resizing DDE to work
fixedpoint_norm::NType
max_fixedpoint_iters::Int
saveat::tstopsType
# `tType` is not declared in this head; it comes from the ODEIntegrator template's
# parameters, which are appended during the merge.
tracked_discontinuities::Vector{Discontinuity{tType}}
integrator::IType
end)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment