Last active
December 20, 2019 12:48
-
-
Save mforets/c6e9addbcad8a9ba66a2084352cb45e4 to your computer and use it in GitHub Desktop.
System macro
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# originally written by @ueliwechsler
# ------------------------------------
# the contents of this file are meant to go into `MathematicalSystems/src/macros.jl`
using MacroTools # for the @capture macro | |
using InteractiveUtils # for subtypes | |
export @system | |
# Return the dimension(s) declared in `f_dims`:
#   - a tuple `(x, u, w)` gives the vector `[x, u, w]`;
#   - a tuple `(x, u)` gives the vector `[x, u]`;
#   - any non-tuple value (e.g. `2`) is returned unchanged.
# Tuples of any other length throw an `ArgumentError`.
function _capture_dim(f_dims)
    # the tuple patterns must be tried before the catch-all case: a bare `x_`
    # pattern matches every expression, tuples included
    if @capture(f_dims, (x_, u_, w_))
        return [x, u, w]
    elseif @capture(f_dims, (x_, u_))
        return [x, u]
    elseif !Meta.isexpr(f_dims, :tuple)
        return f_dims
    end
    throw(ArgumentError("the dimensions $f_dims could not be parsed; " *
                        "see the documentation for valid examples"))
end
# Report whether `expr` has the shape `lhs = rhs` (as matched by MacroTools'
# `@capture`). Only the presence of the assignment is checked; neither side is
# validated.
is_equation(expr) = @capture(expr, lhs_ = rhs_)
# Classify `expr` as a dynamic equation.
#
# Returns `(true, AT, subject)` when `expr` is an equation whose left-hand side is
#   - `x'` or `E*x'`: then `AT = AbstractContinuousSystem`;
#   - a symbol ending in `⁺` such as `x⁺`, or `E*x⁺`: then `AT = AbstractDiscreteSystem`.
# `subject` is the state symbol (`x` in these examples); the `E` factor is matched
# literally. Returns `(false, nothing, nothing)` otherwise, so that callers can
# always destructure three values. Throws an error when the left-hand side
# mentions `⁺` but its state cannot be determined.
function is_dynamic_equation(expr)
    @capture(expr, lhs_ = rhs_) || return (false, nothing, nothing)

    # continuous time: `x'` or `E*x'`
    if @capture(lhs, (x_') | (E*x_'))
        return (true, AbstractContinuousSystem, x)
    end

    # discrete time: `x⁺` parses as the single identifier `x⁺`, so the state is
    # recovered by dropping the trailing `⁺` from the symbol's name
    if @capture(lhs, E*xplus_)
        target = xplus
    else
        target = lhs
    end
    if target isa Symbol
        name = string(target)
        if endswith(name, "⁺") && length(name) > 1
            return (true, AbstractDiscreteSystem, Symbol(chop(name)))
        end
    end

    occursin("⁺", string(lhs)) &&
        error("couldn't parse subject of the equation in this case")
    return (false, nothing, nothing)
end
# Parse the subexpressions `exprs` of a system definition.
#
# Recognized entries:
#   - the dynamic equation (detected with `is_dynamic_equation`);
#   - set constraints `state ∈ Set` (at most three);
#   - the noise symbol: `noise = w`, `w = noise`, `noise: w` or `w: noise`;
#   - the dimension(s): `dim = …`, `dims = …`, `dim: …` or `dims: …`.
#
# Returns `(equation, x, AT, constraints, noise, dimension)`:
#   - `x` is the state symbol (default `:x`) and `AT` the abstract system type;
#   - `noise` is the noise symbol (default `:w`);
#   - `dimension` is the output of `_capture_dim`, or `nothing` if not given.
# Throws if an entry cannot be parsed, if there is no dynamic equation or more
# than one, or if there are more than three set constraints.
function parse_system(exprs)
    # defaults: no equation yet, state `x`, unknown abstract system type
    equation = nothing
    x = :x
    AT = nothing
    constraints = Vector{Expr}()
    # default noise symbol and (unknown) dimension
    noise = :w
    dimension = nothing

    for ex in exprs
        if is_equation(ex)
            # index the result instead of destructuring it, so that a short
            # "not found" tuple cannot raise a BoundsError
            result = is_dynamic_equation(ex)
            if result[1]
                equation === nothing ||
                    throw(ArgumentError("found more than one dynamic equation: " *
                                        "$equation and $ex"))
                equation = ex
                AT = result[2]
                x = result[3]
            elseif @capture(ex, (dim = (f_dims_)) | (dims = (f_dims_)))
                dimension = _capture_dim(f_dims)
            elseif @capture(ex, (noise = w_) | (w_ = noise))
                noise = w
            else
                throw(ArgumentError("could not properly parse the equation $ex; " *
                                    "see the documentation for valid examples"))
            end
        elseif @capture(ex, state_ ∈ Set_)
            # set constraint
            push!(constraints, ex)
        elseif @capture(ex, (noise: w_) | (w_: noise))
            # noise symbol
            noise = w
        elseif @capture(ex, (dim: (f_dims_)) | (dims: (f_dims_)))
            # dimension(s)
            dimension = _capture_dim(f_dims)
        else
            throw(ArgumentError("the expression $ex could not be parsed; " *
                                "see the documentation for valid examples"))
        end
    end

    equation === nothing && error("the dynamic equation was not found")
    nsets = length(constraints)
    nsets > 3 && error("cannot parse $nsets set constraints")
    return equation, x, AT, constraints, noise, dimension
end
# =========================================================== | |
# =========================================================== | |
# Return the subtype of `abstract_type` whose field names coincide with the
# fields of the system instance `sys_type`; the lookup is delegated to the
# `fields::Tuple` method.
function corresponding_type(abstract_type, sys_type::AbstractSystem)
    # `fieldnames` is defined on types, not on instances
    fields = fieldnames(typeof(sys_type))
    return corresponding_type(abstract_type, fields)
end
# Return the first concrete direct subtype of `abstract_type` whose field names
# are exactly the symbols in `fields`. The names are compared as sets, so the
# order of `fields` does not matter. Throws an error if no subtype matches.
function corresponding_type(abstract_type, fields::Tuple)
    wanted = Set(fields)
    # abstract subtypes have no definite field names (`fieldnames` throws on
    # them), and they cannot be constructed anyway
    candidates = filter(!isabstracttype, subtypes(abstract_type))
    idx = findfirst(T -> Set(fieldnames(T)) == wanted, candidates)
    if idx === nothing
        error("The entry $(fields) does not match a MathematicalSystem structure.")
    end
    return candidates[idx]
end
# Build a system from a dynamic equation plus optional set constraints, a noise
# symbol and dimensions, e.g. `@system(x' = A*x + B*u, x ∈ X, u ∈ U)`.
#
# How the system is built:
#   - `'` on the left-hand side selects a continuous system, `⁺` a discrete one
#     (see `extract_abstract_type`).
#   - The right-hand side, the left-hand side and the sets are turned into
#     `(value, field)` pairs.
#   - The concrete type whose field names are exactly these fields is looked up
#     with `corresponding_type`.
#   - The type is called with the values in the order of its `fieldnames`, since
#     its constructor is assumed to take the fields positionally.
macro system(expr...)
    equ, sets, noise, dim = extract_expr(expr)
    equ_2, abstract_type = extract_abstract_type(equ)
    lhs_extract, _, state = extract_lhs(equ_2)
    rhs_extract = extract_rhs(equ_2, state, noise, dim)
    set_extract = expand_set.(sets, state, noise)

    rhs_fields = [pair[2] for pair in rhs_extract]
    lhs_fields = [pair[2] for pair in lhs_extract]
    set_fields = [pair[2] for pair in set_extract]
    allunique(set_fields) ||
        error("There is some ambiguity in the set definition")

    rhs_var_names = [pair[1] for pair in rhs_extract]
    lhs_var_names = [pair[1] for pair in lhs_extract]
    set_var_names = [pair[1] for pair in set_extract]

    field_names = (rhs_fields..., lhs_fields..., set_fields...)
    var_names = (rhs_var_names..., lhs_var_names..., set_var_names...)
    # a repeated field (e.g. two input terms) cannot be mapped to one argument
    allunique(field_names) ||
        error("the fields $field_names are ambiguous; each system field may " *
              "appear only once")

    sys_type = corresponding_type(abstract_type, field_names)
    # the extraction order (rhs, lhs, sets) need not match the struct's field
    # order, so arrange the arguments by field name
    value_of = Dict(zip(field_names, var_names))
    ctor_args = [value_of[name] for name in fieldnames(sys_type)]
    return esc(Expr(:call, sys_type, ctor_args...))
end
# Remove the time marker from the dynamic equation `expr` and return
# `(equation, AT)`:
#   - `⁺` marks a discrete system (`AT = AbstractDiscreteSystem`);
#   - `'` marks a continuous one (`AT = AbstractContinuousSystem`);
#   - `⁺` is looked for first.
# For an equation `lhs = rhs`, only the left-hand side is rewritten, so adjoints
# on the right-hand side (e.g. `x' = A'*x`) are preserved. Other expressions fall
# back to rewriting their whole printed form. Throws if no marker is found.
function extract_abstract_type(expr)
    markers = ("⁺" => AbstractDiscreteSystem, "'" => AbstractContinuousSystem)
    if Meta.isexpr(expr, :(=), 2)
        lhs_str = string(expr.args[1])
        for (pat, abstract_type) in markers
            if occursin(pat, lhs_str)
                new_lhs = Meta.parse(replace(lhs_str, pat => ""))
                return Expr(:(=), new_lhs, expr.args[2]), abstract_type
            end
        end
    else
        str = string(expr)
        for (pat, abstract_type) in markers
            if occursin(pat, str)
                return Meta.parse(replace(str, pat => "")), abstract_type
            end
        end
    end
    error("there is no distinction between continuous and discrete; " *
          "use ' for continuous and ⁺ for discrete")
end
# Split the left-hand side of the equation `expr`, from which the time marker
# has already been removed (e.g. `E*x = rhs` or `x = rhs`).
#
# Returns `(params, equation, state)`:
#   - for `E*x = rhs`: `params = [(E, :E)]`, `equation` is `x = rhs` and
#     `state` is `x`;
#   - otherwise: `params` is empty, `equation` is `expr` unchanged and `state`
#     is the left-hand side (`nothing` if `expr` is not an equation).
# The state must be a single variable; a factor in front of it must be written
# with an explicit `*`.
function extract_lhs(expr)
    @capture(expr, lhs_ = rhs_)
    if @capture(lhs, E_ * x_)
        # rebuild the equation on the AST instead of deleting the text "E * "
        # from the printed equation, which would also remove it from the rhs
        return [(E, :E)], Expr(:(=), x, rhs), x
    end
    return Tuple{Any,Symbol}[], expr, lhs
end
# Textual substitution inside a symbol or expression:
#   - both symbols of `old_new` are converted to strings;
#   - the substitution is applied to `string(symbol)`;
#   - the result is parsed back with `Meta.parse`.
# NOTE(review): this still extends `Base.replace` for types this file does not
# own. The first argument is restricted to `Symbol`/`Expr` so that `replace` on
# strings, arrays and other collections keeps its standard behaviour.
function Base.replace(symbol::Union{Symbol,Expr}, old_new::Pair{Symbol,Symbol})
    from = string(first(old_new))
    to = string(last(old_new))
    return Meta.parse(replace(string(symbol), from => to))
end
# Textual substitution inside a symbol or expression:
#   - `old_new` is applied to `string(symbol)` with the string method of
#     `replace`, so any pattern accepted there can be used;
#   - the result is parsed back with `Meta.parse`.
# NOTE(review): this still extends `Base.replace` for types this file does not
# own. The first argument is restricted to `Symbol`/`Expr` so that `replace` on
# strings, arrays and other collections keeps its standard behaviour.
function Base.replace(symbol::Union{Symbol,Expr}, old_new::Pair)
    new_str = replace(string(symbol), old_new)
    return Meta.parse(new_str)
end
# Extract the parameters of the right-hand side of the equation `expr`.
#
# `expr` is an equation `lhs = rhs`, or a tuple whose first entry is one.
# Returns a vector of `(value, field)` pairs, where `field` names the system
# field that `value` provides:
#   - a sum such as `A*x + B*u + c` is split term by term (see `extract_sum`);
#     constant terms must be a single variable;
#   - a function call such as `f(x, u)` gives `f` and the dimensions taken from
#     `dim` (see `extract_function`);
#   - `rhs == state` gives `[(dim, :statedim)]`;
#   - a single product with the state, `A*x` (or `Ax`, see `add_asterix`),
#     gives `[(A, :A)]`.
# Any other single term throws an `ArgumentError`.
function extract_rhs(expr, state, noise, dim)
    if Meta.isexpr(expr, :(=))
        equ = expr
    elseif Meta.isexpr(expr, :tuple) && !isempty(expr.args) &&
           Meta.isexpr(expr.args[1], :(=))
        equ = expr.args[1]
    else
        error("Wrong structure, first entry is not a equation.")
    end

    rhs = equ.args[2]
    # `E*x = …` parses like a method definition, whose right-hand side is a
    # code block: keep its last expression
    if Meta.isexpr(rhs, :block)
        rhs = rhs.args[end]
    end

    if @capture(rhs, A_ + B__)
        summands = add_asterix.([A, B...], Ref(state), Ref(noise))
        return extract_sum(summands, state, noise)
    elseif @capture(rhs, f_(a__)) && f != :(*)
        return extract_function(rhs, dim)
    elseif rhs == state
        return [(dim, :statedim)]
    end

    # a single term: it must be a product with the state
    product = add_asterix(rhs, state, noise)
    if @capture(product, array_ * var_) && var == state
        return [(array, :A)]
    end
    throw(ArgumentError("if there is only one term on the right side, it needs to"*
                        " include the state."))
end
# Make the product in a summand explicit, e.g. `Ax` becomes `A*x`.
#
# Returned unchanged: summands that already are a product `A*x`, non-symbols
# (numbers, other expressions) and one-character symbols. For any other symbol:
#   - a trailing `state` name is split off first, then a trailing `noise` name;
#   - otherwise the last character is taken as the variable (`Bu` becomes `B*u`).
function add_asterix(summand, state, noise)
    @capture(summand, A_ * x_) && return summand
    # only identifiers can be split; splitting a number such as `10` would
    # produce the bogus product `1*0`
    summand isa Symbol || return summand
    str = string(summand)
    nchars = length(str)
    nchars == 1 && return summand
    # compare and cut by characters (not byte offsets), so that multi-byte
    # Unicode names are handled correctly
    for var in (string(state), string(noise))
        nvar = length(var)
        if nvar < nchars && endswith(str, var)
            return Meta.parse(chop(str; tail=nvar) * "*" * var)
        end
    end
    return Meta.parse(string(chop(str), "*", last(str)))
end
# Classify each summand of a right-hand side sum, in order:
#   - a product `array * var` with `var == state` gives `(array, :A)`;
#   - a product whose variable is `noise` gives `(array, :D)`;
#   - a product with any other variable gives `(array, :B)` (input term);
#   - any other summand gives `(summand, :c)` (constant term).
function extract_sum(summands, state, noise)
    params = Tuple{Any,Symbol}[]
    for summand in summands
        if @capture(summand, array_ * var_)
            field = var == state ? :A : var == noise ? :D : :B
            push!(params, (array, field))
        else
            push!(params, (summand, :c))
        end
    end
    return params
end
# Parameters of a right-hand side given as a function call:
#   - `f(x)` gives `[(f, :f), (dim, :statedim)]`;
#   - `f(x, u)` gives `f` and `dim[1]`, `dim[2]` as `:statedim`, `:inputdim`;
#   - `f(x, u, w)` additionally gives `dim[3]` as `:noisedim`.
# Throws an `ArgumentError` if `rhs` is not a call with one to three arguments.
function extract_function(rhs, dim)
    @capture(rhs, f_(args__)) ||
        throw(ArgumentError("the right-hand side $rhs is not a function call"))
    nargs = length(args)
    if nargs == 1
        return [(f, :f), (dim, :statedim)]
    elseif nargs == 2
        return [(f, :f), (dim[1], :statedim),
                (dim[2], :inputdim)]
    elseif nargs == 3
        return [(f, :f), (dim[1], :statedim),
                (dim[2], :inputdim),
                (dim[3], :noisedim)]
    end
    throw(ArgumentError("the function call $rhs must have one to three " *
                        "arguments (state, input, noise)"))
end
# Classify the set constraint `expr`, of the form `v ∈ S`, by the variable it
# constrains. Returns `(S, :X)` for the state, `(S, :W)` for the noise and
# `(S, :U)` for any other variable. Errors if `expr` is not a membership.
function expand_set(expr, state, noise=:w)
    @capture(expr, var_ ∈ domain_) ||
        error("The set-entry $(expr) does not have the correct form")
    tag = var == state ? :X : var == noise ? :W : :U
    return domain, tag
end
# Sort the arguments `exprs` of `@system` into their parts.
#
# Returns `(equ, sets, noise, dim)`:
#   - `equ`: the equation `lhs = rhs`; exactly one is required;
#   - `sets`: the constraints `v ∈ S`, at most three;
#   - `noise`: the symbol given as `noise: w` (default `:w`, at most one);
#   - `dim`: the value given as `dim: (…)`, as a vector for a tuple of two or
#     three entries and unchanged otherwise (default `0`, at most one).
# Any other argument raises an error.
function extract_expr(exprs)
    equs = Any[]
    sets = Any[]
    noises = Any[]
    dims = Any[]
    for ex in exprs
        if @capture(ex, lhs_ = rhs_)
            push!(equs, ex)
        elseif @capture(ex, state_ ∈ Set_)
            push!(sets, ex)
        elseif @capture(ex, noise: w_)
            push!(noises, w)
        elseif @capture(ex, dim: (f_dims_))
            # try the longest tuple first; anything else is kept as given
            if @capture(f_dims, (x_, u_, w_))
                push!(dims, [x, u, w])
            elseif @capture(f_dims, (x_, u_))
                push!(dims, [x, u])
            else
                push!(dims, f_dims)
            end
        else
            error("The macro has not the right formula!")
        end
    end

    length(equs) == 1 || error("more than one or no equations")
    length(sets) <= 3 || error("more than three set definitions")
    length(dims) <= 1 || error("more than one dims")
    length(noises) <= 1 || error("more than one noises")

    equ = equs[1]
    dim = isempty(dims) ? 0 : dims[1]
    noise = isempty(noises) ? :w : noises[1]
    return equ, sets, noise, dim
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment