Skip to content

Instantly share code, notes, and snippets.

@samuela
Created June 8, 2021 20:00
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save samuela/7a639712ae32cbd764d3ce1be5fabec9 to your computer and use it in GitHub Desktop.
Save samuela/7a639712ae32cbd764d3ce1be5fabec9 to your computer and use it in GitHub Desktop.
import Base: +, *, ==, isless, show
import InteractiveUtils: @code_typed
import Printf: @printf
"""
Supertype of every node in a traced computation graph. The node types in this
file each carry a `world` (the constraint log that comparisons push into) and a
`value::Float64` holding the concrete number the node evaluates to.
"""
abstract type TracedThing end

"""
    TracedConstant(world, value)

A plain number met along the way while tracing, e.g. the literal on one side of
a comparison. (Debatable whether this node type is needed at all...)
"""
struct TracedConstant <: TracedThing
    world::Any
    # TODO: make the value type polymorphic
    value::Float64
end
"""
    TracedLeaf(world, name, value)

A named input ("leaf") of the traced computation graph. `name` identifies the
leaf when it is displayed; `value` is the concrete number the leaf stands for.
"""
struct TracedLeaf <: TracedThing
    world::Any
    name::Any
    value::Float64
end
"""
    TracedMult(lhs, rhs)

Interior graph node recording `lhs * rhs`. Both operands must belong to the same
`world` (compared with `===`); the product of their values is cached in `value`.
"""
struct TracedMult <: TracedThing
    world::Any
    lhs::TracedThing
    rhs::TracedThing
    value::Float64

    function TracedMult(lhs::TracedThing, rhs::TracedThing)
        @assert lhs.world === rhs.world
        product = lhs.value * rhs.value
        return new(lhs.world, lhs, rhs, product)
    end
end
"""
    TracedAdd(lhs, rhs)

Interior graph node recording `lhs + rhs`. Both operands must belong to the same
`world` (compared with `===`); the sum of their values is cached in `value`.
"""
struct TracedAdd <: TracedThing
    world::Any
    lhs::TracedThing
    rhs::TracedThing
    value::Float64

    function TracedAdd(lhs::TracedThing, rhs::TracedThing)
        @assert lhs.world === rhs.world
        total = lhs.value + rhs.value
        return new(lhs.world, lhs, rhs, total)
    end
end
# Arithmetic on traced values builds graph nodes instead of bare numbers; each
# node still caches the concrete result in its `value` field.
*(lhs::TracedThing, rhs::TracedThing) = TracedMult(lhs, rhs)
+(lhs::TracedThing, rhs::TracedThing) = TracedAdd(lhs, rhs)

# Mixed traced/plain operands: lift the plain number into a `TracedConstant`
# living in the traced operand's world, so expressions such as `2x + 1.0` trace
# instead of raising a MethodError. TracedThing and Real are disjoint, so these
# methods cannot be ambiguous with each other or with the traced/traced ones.
*(lhs::TracedThing, rhs::Real) = TracedMult(lhs, TracedConstant(lhs.world, rhs))
*(lhs::Real, rhs::TracedThing) = TracedMult(TracedConstant(rhs.world, lhs), rhs)
+(lhs::TracedThing, rhs::Real) = TracedAdd(lhs, TracedConstant(lhs.world, rhs))
+(lhs::Real, rhs::TracedThing) = TracedAdd(TracedConstant(rhs.world, lhs), rhs)
# Little test case: arithmetic alone never pushes into the world (only the
# comparison overloads do), so a leaf whose world is `nothing` is enough here.
let x = TracedLeaf(nothing, "x", -1.5)
    sq = x * x
    @assert sq isa TracedMult
    @assert sq.lhs === x && sq.rhs === x
    @assert sq.value == 2.25  # (-1.5)^2 is exactly representable
    sq
end
# Free-space test spelled as four separate bound checks. The comparisons run in a
# fixed order and stop at the first failure, which determines which constraints
# a traced run records.
function is_free1(x, y)
    (-2.0 <= x) || return false
    (x <= 2.0) || return false
    (-2.0 <= y) || return false
    (y <= 2.0) || return false
    return x*x + y*y >= 1.0
end
# Free-space test using chained comparisons for the box bounds. The `||` below
# short-circuits, so the y-bounds are only checked once the x-bounds hold.
function is_free2(x, y)
    if !(-2.0 <= x <= 2.0) || !(-2.0 <= y <= 2.0)
        return false
    end
    return x*x + y*y >= 1.0
end
# Free-space test written as guard clauses: leave early when a coordinate falls
# outside [-2, 2], otherwise report whether (x, y) lies outside the unit circle.
function is_free3(x, y)
    (-2.0 <= x <= 2.0) || return false
    (-2.0 <= y <= 2.0) || return false
    return x*x + y*y >= 1.0
end
"""
Supertype of the relations recorded in a world while tracing. Every concrete
constraint stores its two sides in untyped `lhs` and `rhs` fields.
"""
abstract type Constraint end

# Orderings (strict and non-strict).
struct LessThan        <: Constraint; lhs; rhs; end
struct GreaterThan     <: Constraint; lhs; rhs; end
struct LessThanOrEq    <: Constraint; lhs; rhs; end
struct GreaterThanOrEq <: Constraint; lhs; rhs; end

# (In)equality.
struct Equal    <: Constraint; lhs; rhs; end
struct NotEqual <: Constraint; lhs; rhs; end
# Compare a plain number against a traced value via `rhs.value`, and record the
# outcome in `rhs.world`: `LessThan(c, rhs)` when `isless` holds, otherwise
# `GreaterThanOrEq(c, rhs)`, where `c` wraps `lhs` as a `TracedConstant`
# (converted to Float64). Returns the Bool result.
# Accepts any `Real` (not just Float64) so comparisons against integer literals,
# e.g. `0 < x`, are traced instead of raising a MethodError.
# NOTE(review): `isless` is a total order and differs from IEEE `<` for NaN and
# signed zeros, so the recorded relation can disagree with `<` in those cases.
function isless(lhs::Real, rhs::TracedThing)
    res = isless(lhs, rhs.value)
    relation = res ? LessThan : GreaterThanOrEq
    push!(rhs.world, relation(TracedConstant(rhs.world, lhs), rhs))
    return res
end
# Compare a traced value against a plain number via `lhs.value`, and record the
# outcome in `lhs.world`: `LessThan(lhs, c)` when `isless` holds, otherwise
# `GreaterThanOrEq(lhs, c)`, where `c` wraps `rhs` as a `TracedConstant`
# (converted to Float64). Returns the Bool result.
# Accepts any `Real` (not just Float64) so comparisons against integer literals,
# e.g. `x < 2`, are traced instead of raising a MethodError.
# NOTE(review): `isless` is a total order and differs from IEEE `<` for NaN and
# signed zeros, so the recorded relation can disagree with `<` in those cases.
function isless(lhs::TracedThing, rhs::Real)
    res = isless(lhs.value, rhs)
    relation = res ? LessThan : GreaterThanOrEq
    push!(lhs.world, relation(lhs, TracedConstant(lhs.world, rhs)))
    return res
end
# Equality between a plain number and a traced value: compare against
# `rhs.value`, record `Equal(c, rhs)` or `NotEqual(c, rhs)` in `rhs.world`
# (`c` wraps `lhs` as a Float64 `TracedConstant`), and return the Bool result.
# Accepts any `Real` so integer literals such as `0 == x` are traced too.
function ==(lhs::Real, rhs::TracedThing)
    res = lhs == rhs.value
    relation = res ? Equal : NotEqual
    push!(rhs.world, relation(TracedConstant(rhs.world, lhs), rhs))
    return res
end
# Equality between a traced value and a plain number: compare `lhs.value`
# against `rhs`, record `Equal(lhs, c)` or `NotEqual(lhs, c)` in `lhs.world`
# (`c` wraps `rhs` as a Float64 `TracedConstant`), and return the Bool result.
# Accepts any `Real` so integer literals such as `x == 2` are traced too.
function ==(lhs::TracedThing, rhs::Real)
    res = lhs.value == rhs
    relation = res ? Equal : NotEqual
    push!(lhs.world, relation(lhs, TracedConstant(lhs.world, rhs)))
    return res
end
# Pretty-printing: render traced expressions and constraints as infix text.
# Operands are written straight to `io` (rather than through `repr`) so any
# IOContext settings of the caller are respected.
show(io::IO, x::TracedConstant) = show(io, x.value)
show(io::IO, x::TracedLeaf) = print(io, x.name)

# `*` binds tighter than `+`, so a sum appearing as a factor must be
# parenthesized; otherwise `(x + y) * x` would print as `x + y * x`.
_show_factor(io::IO, x::TracedThing) = show(io, x)
function _show_factor(io::IO, x::TracedAdd)
    print(io, "(")
    show(io, x)
    print(io, ")")
end

function show(io::IO, x::TracedMult)
    _show_factor(io, x.lhs)
    print(io, " * ")
    _show_factor(io, x.rhs)
end

function show(io::IO, x::TracedAdd)
    show(io, x.lhs)
    print(io, " + ")
    show(io, x.rhs)
end

# Write a binary constraint as `lhs <op> rhs`.
function _show_relation(io::IO, c::Constraint, op::AbstractString)
    show(io, c.lhs)
    print(io, " ", op, " ")
    show(io, c.rhs)
end

# Every Constraint subtype gets a printer, including GreaterThan and
# LessThanOrEq, which previously had none.
show(io::IO, x::LessThan) = _show_relation(io, x, "<")
show(io::IO, x::LessThanOrEq) = _show_relation(io, x, "<=")
show(io::IO, x::GreaterThan) = _show_relation(io, x, ">")
show(io::IO, x::GreaterThanOrEq) = _show_relation(io, x, ">=")
show(io::IO, x::Equal) = _show_relation(io, x, "==")
show(io::IO, x::NotEqual) = _show_relation(io, x, "=/=")
# Demos: run each is_free variant on traced leaves that share a fresh
# constraint log ("world"), then print the constraints the comparisons recorded.
let world = Constraint[]
    is_free1(TracedLeaf(world, "x", -1.5), TracedLeaf(world, "y", 1.5))
    @show world
end
let world = Constraint[]
    is_free2(TracedLeaf(world, "x", -1.5), TracedLeaf(world, "y", 1.5))
    @show world
end
let world = Constraint[]
    # Trace outside `@assert`: assertions may be disabled at some optimization
    # levels, which would skip the traced call and leave `world` empty.
    free = is_free3(TracedLeaf(world, "x", 2.0), TracedLeaf(world, "y", 1.5))
    @assert free
    @show world
end
let world = Constraint[]
    @show is_free3(TracedLeaf(world, "x", 2.0), TracedLeaf(world, "y", 1.5))
    @show world
end
let world = Constraint[]
    # A traced constant on the right-hand side is traced just like a leaf.
    @show 0.0 <= TracedConstant(world, 0.0)
    @show world
end
# let x = TracedConstant(nothing, 3.14)
# @show x
# end
# @show is_free1(2.0, 1.5)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment