# Source: GitHub gist samuela/7a639712ae32cbd764d3ce1be5fabec9, created June 8, 2021 20:00.
# GitHub warns that this file contains bidirectional Unicode text, which may be interpreted
# or compiled differently than it appears; review it in an editor that reveals hidden
# Unicode characters.
import Base: +, *, ==, isless, show | |
import InteractiveUtils: @code_typed | |
import Printf: @printf | |
"""
    TracedThing

Abstract supertype of every node in a traced computation graph. Each concrete node in this
file carries a `world` field (the `Constraint` vector that comparisons log into, or
`nothing`) and a concrete `value::Float64`.
"""
abstract type TracedThing end

"""
    TracedConstant(world, value)

A literal number met while tracing, e.g. the plain-float side of a comparison against a
traced value. Whether a dedicated node type is needed for this is still an open question.
"""
struct TracedConstant <: TracedThing
    world::Any
    # TODO: make the value type polymorphic
    value::Float64
end
"""
    TracedLeaf(world, name, value)

A named input ("leaf") of the computation graph. `name` is what gets printed for the leaf,
and `value` is the concrete number the traced computation is evaluated on.
"""
struct TracedLeaf <: TracedThing
    world::Any
    name::Any
    value::Float64
end
# Both operands of a binary node must belong to the same trace ("world"), compared by
# identity. This is validated with an explicit ArgumentError instead of `@assert`, which
# Julia documents as possibly disabled at some optimization levels.
function _check_same_world(lhs::TracedThing, rhs::TracedThing)
    lhs.world === rhs.world ||
        throw(ArgumentError("cannot combine traced values from different worlds"))
    return nothing
end

"""
    TracedMult(lhs::TracedThing, rhs::TracedThing)

Product node: stores both factors and caches `lhs.value * rhs.value` in `value`. The node
joins its operands' shared `world`.

Throws `ArgumentError` if `lhs` and `rhs` belong to different worlds.
"""
struct TracedMult <: TracedThing
    world
    lhs::TracedThing
    rhs::TracedThing
    value::Float64
    function TracedMult(lhs::TracedThing, rhs::TracedThing)
        _check_same_world(lhs, rhs)
        return new(lhs.world, lhs, rhs, lhs.value * rhs.value)
    end
end

"""
    TracedAdd(lhs::TracedThing, rhs::TracedThing)

Sum node: stores both terms and caches `lhs.value + rhs.value` in `value`. The node joins
its operands' shared `world`.

Throws `ArgumentError` if `lhs` and `rhs` belong to different worlds.
"""
struct TracedAdd <: TracedThing
    world
    lhs::TracedThing
    rhs::TracedThing
    value::Float64
    function TracedAdd(lhs::TracedThing, rhs::TracedThing)
        _check_same_world(lhs, rhs)
        return new(lhs.world, lhs, rhs, lhs.value + rhs.value)
    end
end
# Arithmetic on traced values builds graph nodes (each caching its concrete `value`)
# instead of returning plain numbers. A plain `Real` operand is lifted into a
# `TracedConstant` in the traced operand's world, so mixed expressions such as `2.0 * x`
# or `x + 1` are traced too rather than raising a MethodError.
*(lhs::TracedThing, rhs::TracedThing) = TracedMult(lhs, rhs)
*(lhs::Real, rhs::TracedThing) = TracedMult(TracedConstant(rhs.world, lhs), rhs)
*(lhs::TracedThing, rhs::Real) = TracedMult(lhs, TracedConstant(lhs.world, rhs))
+(lhs::TracedThing, rhs::TracedThing) = TracedAdd(lhs, rhs)
+(lhs::Real, rhs::TracedThing) = TracedAdd(TracedConstant(rhs.world, lhs), rhs)
+(lhs::TracedThing, rhs::Real) = TracedAdd(lhs, TracedConstant(lhs.world, rhs))
# Smoke test: squaring a single leaf builds a TracedMult node. The world is `nothing`
# here, since no comparison is performed and nothing is logged.
let
    leaf = TracedLeaf(nothing, "x", -1.5)
    leaf * leaf
end
"""
    is_free1(x, y)

Return `true` when `(x, y)` lies in the square `[-2, 2] × [-2, 2]` and satisfies
`x*x + y*y >= 1.0` (outside the open unit disk). This variant spells out every bound as a
separate one-sided comparison joined by `&&`; on traced inputs each comparison that is
evaluated logs a constraint.
"""
function is_free1(x, y)
    in_box = (-2.0 <= x) && (x <= 2.0) && (-2.0 <= y) && (y <= 2.0)
    return in_box && (x*x + y*y >= 1.0)
end
"""
    is_free2(x, y)

Same predicate as `is_free1`, written with chained comparisons `a <= x <= b` for the box
bounds. Comparisons bind tighter than `&&`, so no parentheses are needed.
"""
is_free2(x, y) =
    -2.0 <= x <= 2.0 &&
    -2.0 <= y <= 2.0 &&
    x*x + y*y >= 1.0
"""
    is_free3(x, y)

Same predicate as `is_free1`, written as statement-level early exits: bail out with
`false` as soon as a coordinate leaves `[-2, 2]`, otherwise return the unit-disk test.
"""
function is_free3(x, y)
    (-2.0 <= x <= 2.0) || return false
    (-2.0 <= y <= 2.0) || return false
    return x*x + y*y >= 1.0
end
"""
    Constraint

Abstract supertype of the binary relations recorded while tracing. Every concrete subtype
stores its two operands, untyped, in `lhs` and `rhs`.
"""
abstract type Constraint end

"Recorded relation `lhs < rhs`."
struct LessThan <: Constraint
    lhs::Any
    rhs::Any
end

"Recorded relation `lhs > rhs` (not emitted by the comparison overloads in this file)."
struct GreaterThan <: Constraint
    lhs::Any
    rhs::Any
end

"Recorded relation `lhs <= rhs` (not emitted by the comparison overloads in this file)."
struct LessThanOrEq <: Constraint
    lhs::Any
    rhs::Any
end

"Recorded relation `lhs >= rhs`."
struct GreaterThanOrEq <: Constraint
    lhs::Any
    rhs::Any
end

"Recorded relation `lhs == rhs`."
struct Equal <: Constraint
    lhs::Any
    rhs::Any
end

"Recorded relation `lhs != rhs`."
struct NotEqual <: Constraint
    lhs::Any
    rhs::Any
end
# Traced ordering. Each method answers using the concrete values and pushes into the traced
# operand's world the constraint that answer implies: `a < b` when true, `a >= b` when
# false. Julia's generic `<` falls back to `isless` (and the generic `<=`, `>`, `>=` are
# built on `<` and `==`), which is how comparisons on traced values reach these methods.
#
# Any `Real` constant is accepted (so `x < 2` traces as well as `x < 2.0`). The decision
# uses the exact constant; the logged `TracedConstant` stores it as a Float64.

# Log `lhs < rhs` if `res` holds, otherwise `lhs >= rhs`, then return `res`.
function _record_isless!(world, lhs::TracedThing, rhs::TracedThing, res::Bool)
    push!(world, (res ? LessThan : GreaterThanOrEq)(lhs, rhs))
    return res
end

function isless(lhs::Real, rhs::TracedThing)
    res = isless(lhs, rhs.value)
    return _record_isless!(rhs.world, TracedConstant(rhs.world, lhs), rhs, res)
end

function isless(lhs::TracedThing, rhs::Real)
    res = isless(lhs.value, rhs)
    return _record_isless!(lhs.world, lhs, TracedConstant(lhs.world, rhs), res)
end
# Traced equality. Each method answers using the concrete values and pushes into the traced
# operand's world either `Equal` or `NotEqual`, depending on the answer. Any `Real` constant
# is accepted (so `x == 0` traces as well as `x == 0.0`); the decision uses the exact
# constant, and the logged `TracedConstant` stores it as a Float64.

# Log `lhs == rhs` if `res` holds, otherwise `lhs != rhs`, then return `res`.
function _record_eq!(world, lhs::TracedThing, rhs::TracedThing, res::Bool)
    push!(world, (res ? Equal : NotEqual)(lhs, rhs))
    return res
end

function ==(lhs::Real, rhs::TracedThing)
    res = lhs == rhs.value
    return _record_eq!(rhs.world, TracedConstant(rhs.world, lhs), rhs, res)
end

function ==(lhs::TracedThing, rhs::Real)
    res = lhs.value == rhs
    return _record_eq!(lhs.world, lhs, TracedConstant(lhs.world, rhs), res)
end
# Pretty-printing: traced expressions print as infix arithmetic and recorded constraints as
# infix relations, so `@show world` reads like the list of conditions that were checked.
# Operands are written straight to `io` (instead of via `repr`) so any IOContext set by
# the caller is propagated.

show(io::IO, x::TracedConstant) = show(io, x.value)
show(io::IO, x::TracedLeaf) = print(io, x.name)

# A sum used as a factor must be parenthesized: otherwise `(x + y) * z` would print as
# `x + y * z`, which reads as `x + (y * z)`.
_show_factor(io::IO, x::TracedThing) = show(io, x)
function _show_factor(io::IO, x::TracedAdd)
    print(io, "(")
    show(io, x)
    print(io, ")")
    return nothing
end

function show(io::IO, x::TracedMult)
    _show_factor(io, x.lhs)
    print(io, " * ")
    _show_factor(io, x.rhs)
    return nothing
end

function show(io::IO, x::TracedAdd)
    show(io, x.lhs)
    print(io, " + ")
    show(io, x.rhs)
    return nothing
end

# Shared printer for the binary constraint types: "<lhs> <op> <rhs>".
function _show_relation(io::IO, c::Constraint, op::AbstractString)
    show(io, c.lhs)
    print(io, " ", op, " ")
    show(io, c.rhs)
    return nothing
end

show(io::IO, x::LessThan) = _show_relation(io, x, "<")
show(io::IO, x::GreaterThan) = _show_relation(io, x, ">")
show(io::IO, x::LessThanOrEq) = _show_relation(io, x, "<=")
show(io::IO, x::GreaterThanOrEq) = _show_relation(io, x, ">=")
show(io::IO, x::Equal) = _show_relation(io, x, "==")
show(io::IO, x::NotEqual) = _show_relation(io, x, "=/=")
# Demos: run each predicate on traced leaves, each in a fresh world, then print the
# constraints that were logged along the way.
let
    world = Vector{Constraint}()
    is_free1(TracedLeaf(world, "x", -1.5), TracedLeaf(world, "y", 1.5))
    @show world
end

let
    world = Vector{Constraint}()
    is_free2(TracedLeaf(world, "x", -1.5), TracedLeaf(world, "y", 1.5))
    @show world
end

# x sits exactly on the box boundary (2.0), which exercises the equality half of `<=`.
let
    world = Vector{Constraint}()
    @assert is_free3(TracedLeaf(world, "x", 2.0), TracedLeaf(world, "y", 1.5))
    @show world
end

let
    world = Vector{Constraint}()
    @show is_free3(TracedLeaf(world, "x", 2.0), TracedLeaf(world, "y", 1.5))
    @show world
end

# A bare constant compared with a number still logs, because it is a TracedThing too.
let
    world = Vector{Constraint}()
    @show 0.0 <= TracedConstant(world, 0.0)
    @show world
end
# let x = TracedConstant(nothing, 3.14) | |
# @show x | |
# end | |
# @show is_free1(2.0, 1.5) |
# GitHub gist page footer (not part of the program): "Sign up for free to join this
# conversation on GitHub. Already have an account? Sign in to comment"