Skip to content

Instantly share code, notes, and snippets.

@saolof
Last active June 18, 2021 19:27
Show Gist options
  • Save saolof/4c4cbf9920c154adfa4c0fd4b96c130f to your computer and use it in GitHub Desktop.
Save saolof/4c4cbf9920c154adfa4c0fd4b96c130f to your computer and use it in GitHub Desktop.
Quick miniKanren implementation that I'm planning to grow into an rKanren with a DataFrame-based, Datalog-like subset once I finally have the time.
using FunctionalCollections
using MacroTools
using DataFrames
# Data structures designed for this problem
include("persistent_unionfind.jl")
include("pairingheap.jl")
#
# (For datalog-like subset)
#
# Reference to a single column of a DataFrame, reserved for the datalog-like subset
# (SMap keeps a `PersistentDisjointSet{DFrameCol}`; no visible code builds one yet).
# NOTE(review): `dataframeID` presumably indexes `SMap.dataframes` and `col` names a
# column in that DataFrame — confirm once the datakanren code is available.
struct DFrameCol
dataframeID::Int
col::Symbol
end
#########
# 1st order Logic and unification
#######
"""
    LVar(id)

A logic variable, identified by an integer `id`. Fresh variables are created by
`call_fresh`, which numbers them consecutively from the search state's `nextid`.

`LVar` also implements a zero-dimensional iteration interface: it iterates over
itself exactly once, so `collect(LVar(1))` is a 0-dimensional array holding the
variable.
"""
struct LVar
    id::Int
end

Base.iterate(l::LVar) = (l, nothing)
Base.iterate(::LVar, ::Nothing) = nothing
Base.eltype(::Type{LVar}) = LVar
# `HasShape` is not exported from Base, so it must be qualified. The unqualified name
# raised an UndefVarError whenever this trait was queried (e.g. by `collect`).
Base.IteratorSize(::Type{LVar}) = Base.HasShape{0}()
Base.length(::LVar) = 1
Base.size(::LVar) = ()
"""
    SMap{K,V}

Persistent substitution map used by unification.

- `bindings`: maps a variable's union-find *representative* to the non-variable
  value it is bound to (`assoc` with a non-`LVar` value stores it there).
- `variablesubs`: union-find over variables; binding a variable to another variable
  merges their classes instead of adding a binding.
- `dataframes`, `dataframes_index`: reserved for the datalog-like subset; the
  visible code only passes them through unchanged.
"""
struct SMap{K,V} <: AbstractDict{K,V} # Making a dedicated substitution map wrapper to do union find in O(log(N)^2) per operation.
bindings::PersistentHashMap{K,V} # Original microkanren made horrible use of data structure and is O(N^2) per operation worst case.
variablesubs::PersistentDisjointSet{K}
dataframes::PersistentVector{DataFrame}
dataframes_index::PersistentDisjointSet{DFrameCol}
end
# Empty substitution map: no bindings, no aliased variables, no attached DataFrames.
function SMap{K,V}() where {K,V}
    bindings = PersistentHashMap{K,V}()
    varsets = PersistentDisjointSet{K}()
    frames = PersistentVector{DataFrame}()
    frame_index = PersistentDisjointSet{DFrameCol}()
    return SMap{K,V}(bindings, varsets, frames, frame_index)
end

# Build an SMap by inserting each key => value pair of `d` in iteration order.
function SMap(d::AbstractDict{K,V}) where {K,V}
    acc = SMap{K,V}()
    for p in pairs(d)
        acc = assoc(acc, first(p), last(p))
    end
    return acc
end

# Evaluate `d` at macro-expansion time and splice the resulting SMap in as a literal.
macro SMap(d)
    return SMap(eval(d))
end
# Bind `key` to a non-variable `value`. The binding is stored on the key's
# union-find representative, so every alias of `key` sees it.
function FunctionalCollections.assoc(smap::SMap, key::LVar, value)
    root = smap.variablesubs[key]
    newbindings = assoc(smap.bindings, root, value)
    return SMap(newbindings, smap.variablesubs, smap.dataframes, smap.dataframes_index)
end

# Alias two variables by merging their union-find classes; `bindings` is untouched.
# NOTE(review): if the class whose root is replaced already had a value binding, that
# binding becomes unreachable. `unify` only aliases walked (unbound) variables, but
# `SMap(d)` may hit this depending on Dict iteration order — confirm.
function FunctionalCollections.assoc(smap::SMap, key::LVar, value::LVar)
    merged = union(smap.variablesubs, key, value)
    return SMap(smap.bindings, merged, smap.dataframes, smap.dataframes_index)
end
# --- AbstractDict interface for SMap ---------------------------------------------
# Lookups first map a key to its union-find representative (`smap.variablesubs[...]`),
# then consult `bindings` for that representative.
# NOTE(review): non-LVar keys are also passed to `variablesubs[...]`; how
# PersistentDisjointSet treats unknown keys is defined in persistent_unionfind.jl —
# presumably it returns the key itself; confirm.
Base.get(smap::SMap,u,default) = get(smap.bindings,smap.variablesubs[u],default)
# Keys without a binding resolve to themselves instead of throwing a KeyError.
Base.getindex(smap::SMap,key) = get(smap,key,key)
# True only when the key's representative carries a value binding; variables that are
# merely aliased to other variables are not reported as present.
Base.haskey(smap::SMap,key) = haskey(smap.bindings,smap.variablesubs[key])
# Union of the bound representatives and the keys of the disjoint-set structure.
Base.keys(smap::SMap) = union(keys(smap.bindings),keys(smap.variablesubs))
# Pairs every key with its walked value. This broadcasts over `keys(smap)`, so it
# materializes a collection rather than returning a lazy iterator.
Base.pairs(smap::SMap) = (x -> x => walk(x,smap)).(keys(smap))
# NOTE(review): this adds the two sizes while `keys` de-duplicates via `union`, so an
# element present in both structures may be counted twice. `_reify_s` uses `length`
# to number reified variables — confirm the intended count.
Base.length(smap::SMap) = length(smap.bindings) + length(smap.variablesubs)
# Iterates the raw bindings, then one `first(last(x)) => walk(last(last(x)))` pair per
# entry `x` of `variablesubs` (the entry layout comes from persistent_unionfind.jl).
# NOTE(review): the `map` over `variablesubs` is re-evaluated on every `iterate` call,
# so each step costs O(length) — consider caching the flattened iterator in the state.
Base.iterate(smap::SMap,args...) = iterate( Iterators.flatten((smap.bindings, map(x -> first(last(x)) => walk(last(last(x)),smap), smap.variablesubs) ) ), args...)
# Extend a substitution with `lvar => value`; a failed (`nothing`) substitution stays failed.
add_substitution(::Nothing, lvar, value) = nothing
add_substitution(smap, lvar, value) = assoc(smap, lvar, value)
"""
    walk(u, smap)

Resolve `u` under `smap`. A variable is mapped to its union-find representative and
then to that representative's bound value, if any. Otherwise the representative
itself is returned. Non-variables are returned unchanged.

A single lookup suffices because variable-to-variable bindings are kept in the
union-find rather than in `bindings`.
"""
function walk(u::LVar, smap)
    root = smap.variablesubs[u]
    return get(smap.bindings, root, root)
end
walk(u, smap) = u
# Smoke tests for `walk`: the alias chain LVar(1) -> LVar(2) -> "Banana" resolves from
# either variable, and a non-variable walks to itself. `@SMap` builds the map at
# macro-expansion time.
# NOTE(review): this relies on the binding for LVar(2) surviving its union with
# LVar(1), which depends on Dict iteration order and the union-find's root choice.
let testmap = @SMap Dict(LVar(1) => LVar(2), LVar(2) => "Banana") # Some tests.
@assert walk(LVar(1), testmap) == "Banana"
@assert walk(LVar(2), testmap) == "Banana"
@assert walk("mango", testmap) == "mango"
end
"""
    unify(u, v, smap)

Unify terms `u` and `v` under substitution `smap`. Returns the (possibly extended)
substitution, or `nothing` if they cannot be unified. A `nothing` substitution, from
an earlier failure, propagates unchanged.
"""
unify(u, v, ::Nothing) = nothing
function unify(u, v, smap::AbstractDict)
    a = walk(u, smap)
    b = walk(v, smap)
    a == b && return smap                     # Delete rule: already equal.
    if b isa LVar                             # Swap rule: keep any variable on the left.
        a, b = b, a
    end
    if a isa LVar                             # Eliminate rule, guarded by the occurs check.
        return occurs_check(a, b, smap) ? nothing : add_substitution(smap, a, b)
    end
    # Neither side is a variable and they are not trivially equal.
    return unify_terms(a, b, smap)
end
"""
    occurs_check(u, v, smap)

Return `true` if the (walked, unbound) variable `u` occurs anywhere inside term `v`
under `smap`. Binding `u` to such a term would create a cyclic substitution.

Variables inside `v` are walked before comparison, so occurrences hidden behind
already-bound variables are detected too. Example: with `y ↦ (x,)`, the term `(y,)`
contains `x`. Tuples and arrays are searched element by element; every other term
contains no variables.
"""
occurs_check(u, v, smap) = false
occurs_check(u, v::LVar, smap) = _occurs_in_walked(u, walk(v, smap), smap)
occurs_check(u, v::Tuple, smap) = occurs_destructure(u, v, smap)
occurs_check(u, v::AbstractArray, smap) = occurs_destructure(u, v, smap)

# Compare against an unbound representative; otherwise search the bound term itself.
_occurs_in_walked(u, w::LVar, smap) = u == w
_occurs_in_walked(u, w, smap) = occurs_check(u, w, smap)

# True if `u` occurs in any element of the container `v` (short-circuits).
occurs_destructure(u, v, smap) = any(x -> occurs_check(u, x, smap), v)
"""
    unify_terms(u, v, smap)

Unify two non-variable terms that are not trivially equal. Tuples with tuples and
arrays with arrays are decomposed element-wise. Any other combination is a conflict
and yields `nothing`.
"""
unify_terms(u, v, smap) = nothing  # Conflict rule (default).
unify_terms(u::Tuple, v::Tuple, smap) = unify_containers(u, v, smap)                  # Decompose rule
unify_terms(u::AbstractArray, v::AbstractArray, smap) = unify_containers(u, v, smap)  # Decompose rule

"""
    unify_containers(u, v, smap)

Decompose rule: unify same-shaped containers element by element, pairing elements by
position (iteration order) rather than by index. This keeps arrays with different
index offsets aligned. Returns `nothing` on a shape mismatch or at the first element
that fails to unify.
"""
function unify_containers(u, v, smap)
    equalshape(u, v) || return nothing  # Conflict rule.
    for (a, b) in zip(u, v)
        smap = unify(a, b, smap)
        isnothing(smap) && return nothing  # no need to visit the remaining elements
    end
    return smap
end

# Arrays must agree in every dimension; tuples only in length.
equalshape(u::AbstractArray, v::AbstractArray) = (size(u) == size(v))
equalshape(u::Tuple, v::Tuple) = (length(u) == length(v))
# Unification tests: alias LVar(1) with LVar(2), then unify two tuples so that
# LVar(0) = "banana" and LVar(1) (and therefore LVar(2)) = "mango".
let testmap = @SMap Dict() # Some tests.
testmap = unify(LVar(1), LVar(2),testmap)
testmap = unify((LVar(0),"mango"),("banana",LVar(1)),testmap)
@assert !isnothing(testmap)
# A conflicting value fails; consistent values leave the substitution unchanged.
@assert unify(LVar(0),"mango",testmap) === nothing
@assert unify(LVar(0),"banana",testmap) == testmap
@assert unify(LVar(1),"mango",testmap) == testmap
@assert unify(LVar(2),"mango",testmap) == testmap
# Binding a brand-new variable extends the map; the expected SMap on the right is
# built at macro-expansion time by `@SMap`.
@assert unify(LVar(9),"squirrels",testmap) == @SMap Dict(LVar(0)=> "banana",LVar(1) => "mango",LVar(2) => "mango",LVar(9)=> "squirrels" )
end
# Datakanren:
#include("datakanren.jl")
#
# Monadic streams
#
# Stream element. A stream is a heap of these; the `isless` methods order
# evaluated results before pending thunks.
abstract type TE end
# An already-computed stream element. During search `val` is typically a `State`
# (see `make_stream` and `≅`).
struct Evaluated{T} <: TE
val::T
end
# A suspended computation. `priority` orders thunks among themselves via `isless`
# (presumably lower first, given the heap is popped with `pop_min` — confirm).
struct Thunk{F<:Function} <: TE
priority::Int
f::F # Return type: must always return a heap of Evaluated objects and new thunks.
end
# Thunk with the default priority 0.
Thunk(f::Function) = Thunk(0, f)
# Whether a stream element holds a finished result rather than deferred work.
isevaluated(::Evaluated) = true
isevaluated(::Thunk) = false

# Heap ordering: evaluated results compare below thunks, evaluated results are all
# tied, and thunks compare by priority.
Base.isless(::Evaluated, ::Evaluated) = false
Base.isless(::Evaluated, ::Thunk) = true
Base.isless(::Thunk, ::Evaluated) = false
Base.isless(a::Thunk, b::Thunk) = a.priority < b.priority
# Stream operations on the empty stream: merging with it yields the other stream,
# and binding over it or forcing its head leaves it unchanged.
merge_streams(this::EmptyHeap{TE},other) = other
mapcat_stream(this::EmptyHeap{TE},g) = this
realize_stream_head(this::EmptyHeap{TE}) = this
# NOTE(review): `list` is not defined in the visible code and no non-empty
# `stream_to_seq` method is visible; this looks like a leftover — confirm.
stream_to_seq(this::EmptyHeap{<:TE}) = list()
# Aka mzero: the stream with no results.
empty_stream() = EmptyHeap{TE}()
# Aka mplus (MonadPlus): merge two streams.
merge_streams(this::PairingTree{TE}, other::PairingHeap{TE}) = merge_heaps(this,other, TE) # Swap/interleaving is performed by the merge_heaps implementation.
# Aka mpure (unit): a one-element stream. A Function argument is wrapped as a deferred
# Thunk rather than as an already-evaluated value.
make_stream(s) = singleton_heap(Evaluated(s), TE)
make_stream(f::Function) = singleton_heap(Thunk(f),TE)
# Aka monad bind
"""
    mapcat_stream(stream, g)

Monadic bind over a stream. Every evaluated element `val` is replaced by the stream
`g(val.val)`. Every pending thunk is replaced by a new thunk of the same priority
that, when forced, binds `g` over the stream the original thunk produces. All
resulting streams are merged.
"""
function mapcat_stream(this::PairingTree, g::Function)
    mapfoldl(merge_streams, this; init=EmptyHeap{TE}()) do val
        # Decide per element. The heap's top is only the minimum element, so testing
        # it (as before) treated every element like the top and called `.val` on
        # Thunks whenever evaluated and pending elements were mixed.
        if isevaluated(val)
            return g(val.val)
        else
            return singleton_heap(Thunk(val.priority, () -> mapcat_stream(val.f(), g)), TE)
        end
    end
end
"""
    realize_stream_head(stream)

Force pending thunks at the top of `stream` until the stream is empty or its
minimum element is an `Evaluated` result, and return the resulting stream.
"""
function realize_stream_head(this::PairingTree{TE})
    heap = this
    while !isempty(heap)
        isevaluated(heap.top) && break
        pending, rest = pop_min(heap)
        heap = merge_heaps(pending.f(), rest, TE)
    end
    return heap
end
# Iterator over a result stream. Each step forces pending thunks until an evaluated
# element is at the top of the heap, then yields that element's value.
struct StreamIterator
    heap::PairingHeap{TE}
end
# Results are produced lazily, so the number of elements is not known up front.
Base.IteratorSize(::Type{StreamIterator}) = Base.SizeUnknown()
Base.iterate(it::StreamIterator) = iterate(it, it.heap)
function Base.iterate(::StreamIterator, heap)
    forced = realize_stream_head(heap)
    isempty(forced) && return nothing
    element, remaining = pop_min(forced)
    return (element.val, remaining)
end
#
# Search
#
# Search state: the current substitution plus the counter used to number the next
# fresh logic variable (`call_fresh` creates LVar ids starting at `nextid`).
# NOTE(review): `smap::SMap` is a non-concrete (UnionAll) field type, so field access
# is not type-stable; consider parameterizing State.
struct State
smap::SMap
nextid::Int
end
# Start numbering fresh variables at 0.
State(smap) = State(smap,0)
# NOTE(review): `@Persistent Dict()` builds a FunctionalCollections persistent map,
# not an SMap. It is presumably converted to the `smap::SMap` field via Base's
# AbstractDict `convert`, which calls `SMap(d)`. `State(SMap{Any,Any}())` would state
# this directly — confirm and simplify.
State() = State(@Persistent Dict())
# Shared initial state used for every query.
const empty_state = State()
# Copy of `state` with a new substitution; the variable counter is unchanged.
with_smap(state,smap) = State(smap,state.nextid)
# Copy of `state` with `n` more variable ids reserved.
bump_nextid(state,n) = State(state.smap,state.nextid + n)
"""
    u ≅ v

Goal that succeeds when `u` and `v` unify. Applied to a state, it returns a
one-element stream holding that state with the extended substitution, or the empty
stream when unification fails.
"""
function ≅(u, v)
    function unify_goal(state)
        extended = unify(u, v, state.smap)
        if isnothing(extended)
            return empty_stream()
        else
            return make_stream(with_smap(state, extended))
        end
    end
    return unify_goal
end
"""
    get_lambda_arity(f)

Number of positional parameters of `f`, taken as the maximum over its methods. Each
method signature also contains the function's own type, hence the `- 1`.
"""
get_lambda_arity(l) = maximum(m -> length(m.sig.parameters), methods(l)) - 1

"""
    call_fresh(goalconstructor[, n])

Goal that allocates `n` fresh logic variables and passes them to `goalconstructor`,
then runs the resulting goal. The variables are numbered from the state's `nextid`,
and the counter is then advanced by `n`. By default `n` is the arity of
`goalconstructor`.
"""
call_fresh(goalconstructor) = call_fresh(goalconstructor, get_lambda_arity(goalconstructor))
function call_fresh(goalconstructor, n::Integer)
    function fresh_goal(state)
        first_id = state.nextid
        fresh = [LVar(id) for id in first_id:(first_id + n - 1)]
        goal = goalconstructor(fresh...)
        return goal(bump_nextid(state, n))
    end
    return fresh_goal
end
# Lift a goal's result into a stream: streams pass through unchanged, functions become
# a deferred Thunk, and any other value becomes an already-evaluated element.
thunk_goal(goal::PairingHeap) = goal
thunk_goal(goal::Function) = singleton_heap(Thunk(goal), TE)
thunk_goal(goal) = singleton_heap(Evaluated(goal), TE)

"""
    ldisj(goalA, goalB)

Disjunction: a goal whose answers are the merged answers of `goalA` and `goalB`.
"""
function ldisj(goalA, goalB)
    return function (state)
        left = thunk_goal(goalA(state))
        right = thunk_goal(goalB(state))
        return merge_streams(left, right)
    end
end

"""
    lconj(goalA, goalB)

Conjunction: run `goalA`, then run `goalB` on every state it produces.
"""
function lconj(goalA, goalB)
    return function (state)
        first_results = thunk_goal(goalA(state))
        return mapcat_stream(first_results, goalB)
    end
end
"""
    _delaygoal_expr(goalexpr)

Build the expression for a goal that takes a state and returns a stream containing
a single deferred Thunk. When forced, the thunk evaluates `goalexpr` and applies the
result to the state.
"""
function _delaygoal_expr(goalexpr)
    # Callers `esc` the whole expression, so a literal `state` parameter would capture
    # any `state` variable referenced inside the user's goal expression. Use a gensym.
    st = gensym(:state)
    return :($st -> thunk_goal(() -> $(goalexpr)($st)))
end
# `@delaygoal g`: the goal `g`, deferred until its stream is forced.
macro delaygoal(goalexpr) esc(_delaygoal_expr(goalexpr)) end

# Right-nested `ldisj` over delayed goals (a single goal is just delayed).
_ldisj_expr(goalexpr) = _delaygoal_expr(goalexpr)
_ldisj_expr(goalexpr, goals...) = :(ldisj($(_delaygoal_expr(goalexpr)), $(_ldisj_expr(goals...))))
# `@ldisj g1 g2 ...`: disjunction of the delayed goals.
macro ldisj(goalexpr...) esc(_ldisj_expr(goalexpr...)) end

# Right-nested `lconj` over delayed goals (a single goal is just delayed).
_lconj_expr(goalexpr) = _delaygoal_expr(goalexpr)
_lconj_expr(goalexpr, goals...) = :(lconj($(_delaygoal_expr(goalexpr)), $(_lconj_expr(goals...))))
# `@lconj g1 g2 ...`: conjunction of the delayed goals.
macro lconj(goalexpr...) esc(_lconj_expr(goalexpr...)) end
# Display name for the n-th unbound variable in a reified answer (`_0`, `_1`, ...).
reify_name(n) = Symbol("_$n")

"""
    reify_s(v, smap_or_state)

Extend a reification substitution so that every unbound logic variable reachable from
`v` is bound to a display name. Names are numbered by the substitution's current
`length`. Containers and lazy iterables are visited element by element, in
iteration order.
"""
reify_s(v, state::State) = reify_s(v, state.smap)
reify_s(v, smap) = _reify_s(walk(v, smap), smap)

# Ground terms need no names.
_reify_s(v, smap) = smap
# An unbound variable receives the next display name.
_reify_s(v::LVar, smap) = add_substitution(smap, v, reify_name(length(smap)))
# Containers and lazy iterables: name variables element by element.
_reify_s(v::Union{Tuple,AbstractArray,Base.Generator,Iterators.Flatten}, smap) =
    _reify_s_iterator(v, smap)

_reify_s_iterator(v, smap) = foldl((acc, x) -> reify_s(x, acc), v; init = smap)
"""
    deepwalk(v, smap)

Walk `v` and then recursively walk the elements of any tuple, array, generator, or
flattened iterator it resolves to. Generators and flattened iterators are walked
lazily. Other values are returned as-is.
"""
deepwalk(v, smap) = _deepwalk(walk(v, smap), smap)

_deepwalk(v, smap) = v  # Extend this.
_deepwalk(v::Tuple, smap) = Tuple(_deepwalk_iterator(v, smap))
_deepwalk(v::AbstractArray, smap) = broadcast(x -> deepwalk(x, smap), v)
_deepwalk(v::Union{Base.Generator,Iterators.Flatten}, smap) = _deepwalk_iterator(v, smap)

# Lazy element-wise deep walk.
_deepwalk_iterator(v, smap) = Base.Generator(x -> deepwalk(x, smap), v)
"""
    reify_state_firstvar(state)

Reify the query variable `LVar(0)` in `state`. The variable is resolved completely,
then every logic variable still unbound inside the result is replaced by a display
name (`_0`, `_1`, ...).
"""
function reify_state_firstvar(state)
    query_var = LVar(0)
    resolved = deepwalk(query_var, state.smap)
    naming = reify_s(resolved, empty_state)
    return deepwalk(resolved, naming)
end
"""
    @conde begin
        goal
        goal1, goal2, ...
    end

miniKanren `conde`. Each line of the block is one clause, and the clauses are combined
with `ldisj` (disjunction). A line written as a tuple `g1, g2, ...` is the conjunction
(`lconj`) of its goals. Every goal is wrapped as a delayed goal, so nothing runs until
the stream is forced.

NOTE(review): the result of the first `@capture` is not checked. If the argument is
not a `begin` block, `clauses` is presumably unbound or `nothing` and `map` fails with
an unhelpful error. Whether line-number nodes are skipped depends on MacroTools'
block matching.
"""
macro conde(blocks)
@capture(blocks,begin clauses__ end)
newclauses = map(clauses) do ex
# A tuple line is a conjunction; any other line is a single goal.
if @capture(ex,(subclauses__,))
_lconj_expr(subclauses...)
else
ex
end
end
esc(_ldisj_expr(newclauses...))
end
"""
    @fresh args -> (goal1, goal2, ...)
    @fresh f

Introduce fresh logic variables via `call_fresh`. When given a lambda whose body is
a tuple of goals, the goals are conjoined (and delayed) while the lambda keeps its
own parameter list. Any other expression is passed to `call_fresh` unchanged.
"""
macro fresh(lambda)
    expr = if @capture(lambda, args -> (body__,))
        # Interpolate the captured parameter list. Previously the literal symbol `args`
        # was emitted, so the lambda got one parameter named `args` and the user's own
        # variables were left unbound.
        :(call_fresh($args -> $(_lconj_expr(body...))))
    else
        # Interpolate the user's expression. Previously the literal symbol `lambda`
        # was emitted instead.
        :(call_fresh($lambda))
    end
    return esc(expr)
end
# Run a goal starting from the shared initial state.
call_empty_state(goal) = goal(empty_state)

"""
    runkanren(f)

Lazily enumerate the answers of query `f`. Each parameter of `f` becomes a fresh
logic variable, and each answer is the reified value of the first parameter.
"""
function runkanren(f)
    stream = call_empty_state(call_fresh(f))
    return Iterators.map(reify_state_firstvar, StreamIterator(stream))
end

# Eagerly collect every answer; this will not return if the answer stream never ends.
runcollect(f) = collect(runkanren(f))
#
# Try it out:
#
# Each query's first do-block parameter is the query variable (LVar(0)); the others
# are auxiliary fresh variables.
# A single goal: the only answer for q is 1.
runcollect() do q
q ≅ 1
end
# Unifying the tuples forces p = 1 and then q = p, so the answer for q is 1.
runcollect() do q, p, r
(p,q) ≅ (1 , p)
end
# Three-way disjunction: answers 1, 3 and 7 (order set by the stream interleaving).
runcollect() do q
@conde begin
q ≅ 1
q ≅ 3
q ≅ 7
end
end
# The second clause is a conjunction (comma-separated goals). Expected answers for
# q: 1, 1 and 7, in interleaving order.
runcollect() do q, p
@conde begin
q ≅ 1
p ≅ q, p ≅ 1, q ≅ 1
q ≅ 7
end
end
# Clause 1 aliases q with p and forces p = 3 (so q = 3). Clause 2 fails (q cannot be
# both 7 and 1). Clause 3 gives q = 5. Clause 4 leaves q unbound, so it is reified as
# a placeholder symbol (presumably :_0).
runcollect() do q,p
@conde begin
(q,3,1) ≅ (p,p,1)
p ≅ 7, q ≅ p, q ≅ 1
p ≅ q, p ≅ 5
q ≅ p
end
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment