Last active
June 18, 2021 19:27
-
-
Save saolof/4c4cbf9920c154adfa4c0fd4b96c130f to your computer and use it in GitHub Desktop.
A quick miniKanren implementation that I plan to grow into an rKanren with a DataFrame-based, Datalog-like subset once I finally have the time.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using FunctionalCollections | |
using MacroTools | |
using DataFrames | |
# Data structures designed for this problem | |
include("persistent_unionfind.jl") | |
include("pairingheap.jl") | |
# | |
# (For datalog-like subset) | |
# | |
"""
    DFrameCol(dataframeID, col)

Pairs an integer data-frame identifier with a column name.
NOTE(review): presumably `dataframeID` indexes `SMap.dataframes` — confirm.
"""
struct DFrameCol
    dataframeID::Int
    col::Symbol
end
######### | |
# 1st order Logic and unification | |
####### | |
"""
    LVar(id)

Logic variable identified by the integer `id`.

An `LVar` implements the iteration interface as a zero-dimensional,
single-element collection that yields itself once.
"""
struct LVar
    id::Int
end

# Iteration: produce the variable itself, then stop.
Base.iterate(l::LVar) = (l, nothing)
Base.iterate(::LVar, ::Nothing) = nothing
Base.eltype(::Type{LVar}) = LVar
# `HasShape` is not exported from Base, so it has to be qualified; the
# unqualified name raised `UndefVarError` as soon as this trait was queried.
Base.IteratorSize(::Type{LVar}) = Base.HasShape{0}()
Base.length(::LVar) = 1
Base.size(::LVar) = ()
"""
    SMap{K,V} <: AbstractDict{K,V}

Substitution map wrapper. Author's note: it exists to do union-find in
O(log(N)^2) per operation, whereas the original microKanren is O(N^2) per
operation in the worst case.

# Fields
- `bindings`: values, keyed by the representative that `variablesubs` returns.
- `variablesubs`: disjoint set of variables that have been unified together.
- `dataframes`, `dataframes_index`: carried along unchanged by the methods
  visible in this file (for the datalog-like subset).
"""
struct SMap{K,V} <: AbstractDict{K,V}
    bindings::PersistentHashMap{K,V}
    variablesubs::PersistentDisjointSet{K}
    dataframes::PersistentVector{DataFrame}
    dataframes_index::PersistentDisjointSet{DFrameCol}
end
# Empty substitution map: every component starts out empty.
function SMap{K,V}() where {K,V}
    return SMap{K,V}(
        PersistentHashMap{K,V}(),
        PersistentDisjointSet{K}(),
        PersistentVector{DataFrame}(),
        PersistentDisjointSet{DFrameCol}(),
    )
end

# Build a substitution map from any dictionary by `assoc`-ing each pair in
# turn, so variable-to-variable entries go through the union-find path.
function SMap(d::AbstractDict{K,V}) where {K,V}
    return foldl((acc, kv) -> assoc(acc, first(kv), last(kv)), pairs(d); init=SMap{K,V}())
end

"""
    @SMap dictexpr

Shorthand for `SMap(dictexpr)`.

The argument is evaluated in the caller's scope at run time. The previous
version `eval`-ed it during macro expansion in this module's global scope, so
it could not refer to caller locals and spliced a live object into the AST.
"""
macro SMap(d)
    return :(SMap($(esc(d))))
end
# Bind a variable to a non-variable value: the value is stored under the
# representative that `variablesubs` reports for `key`.
function FunctionalCollections.assoc(smap::SMap, key::LVar, value)
    representative = smap.variablesubs[key]
    updated = assoc(smap.bindings, representative, value)
    return SMap(updated, smap.variablesubs, smap.dataframes, smap.dataframes_index)
end

# Bind a variable to another variable: merge the two in the disjoint set and
# leave the value bindings untouched.
function FunctionalCollections.assoc(smap::SMap, key::LVar, value::LVar)
    merged = union(smap.variablesubs, key, value)
    return SMap(smap.bindings, merged, smap.dataframes, smap.dataframes_index)
end
# Dictionary interface for SMap. Lookups go through `variablesubs` first, so a
# key is resolved to its representative before `bindings` is consulted.
function Base.get(smap::SMap, u, default)
    representative = smap.variablesubs[u]
    return get(smap.bindings, representative, default)
end

# Indexing never fails: an unbound key is returned as-is.
Base.getindex(smap::SMap, key) = get(smap, key, key)

function Base.haskey(smap::SMap, key)
    representative = smap.variablesubs[key]
    return haskey(smap.bindings, representative)
end

function Base.keys(smap::SMap)
    return union(keys(smap.bindings), keys(smap.variablesubs))
end

# Each key paired with what `walk` resolves it to.
function Base.pairs(smap::SMap)
    resolve(k) = k => walk(k, smap)
    return resolve.(keys(smap))
end

function Base.length(smap::SMap)
    return length(smap.bindings) + length(smap.variablesubs)
end

# Iterates the direct bindings followed by entries derived from `variablesubs`.
# The flattened iterator is rebuilt on every call and `args...` carries its state.
function Base.iterate(smap::SMap, args...)
    aliases = map(x -> first(last(x)) => walk(last(last(x)), smap), smap.variablesubs)
    return iterate(Iterators.flatten((smap.bindings, aliases)), args...)
end
# Extend a substitution with `lvar => value`. A failed substitution
# (`nothing`) stays failed.
add_substitution(::Nothing, lvar, value) = nothing
add_substitution(smap, lvar, value) = assoc(smap, lvar, value)
# Resolve a term against a substitution. Anything that is not a logic variable
# is already resolved.
walk(u, smap) = u

# A logic variable resolves to the value bound to its representative, or to
# the representative itself when nothing is bound.
function walk(u::LVar, smap)
    representative = smap.variablesubs[u]
    return get(smap.bindings, representative, representative)
end
# function walk(u,smap) | |
# Some tests: walking follows variable aliases to the bound value and leaves
# plain values untouched.
let fruitmap = @SMap Dict(LVar(1) => LVar(2), LVar(2) => "Banana")
    @assert walk(LVar(1), fruitmap) == "Banana"
    @assert walk(LVar(2), fruitmap) == "Banana"
    @assert walk("mango", fruitmap) == "mango"
end
# A failed substitution cannot be extended.
unify(u, v, ::Nothing) = nothing

"""
    unify(u, v, smap::AbstractDict)

Unify terms `u` and `v` under `smap`. Returns the (possibly extended)
substitution on success and `nothing` on failure.
"""
function unify(u, v, smap::AbstractDict)
    lhs = walk(u, smap)
    rhs = walk(v, smap)

    # Delete rule: already equal, nothing to record.
    lhs == rhs && return smap

    # Swap rule: keep a variable, if there is one, on the left.
    if rhs isa LVar
        lhs, rhs = rhs, lhs
    end

    # Eliminate rule: bind the variable unless that would create a cycle.
    if lhs isa LVar
        occurs_check(lhs, rhs, smap) && return nothing
        return add_substitution(smap, lhs, rhs)
    end

    # Neither side is a variable and they are not trivially equal.
    return unify_terms(lhs, rhs, smap)
end
"""
    occurs_check(u, v, smap)

Return `true` when the variable `u` occurs inside the term `v` under the
substitution `smap`, i.e. when binding `u` to `v` would create a cyclic term.

Variables met inside `v` are resolved through `smap` before comparison. The
previous version only compared `u` with the variable's representative and did
not look at its binding, so a cycle hidden behind a bound variable (for
example `x ≅ (y,)` followed by `y ≅ (x,)`) went undetected.
"""
occurs_check(u, v, smap) = false

function occurs_check(u, v::LVar, smap)
    representative = smap.variablesubs[v]
    bound = get(smap.bindings, representative, representative)
    # Unbound variable: `get` hands back the representative itself.
    bound === representative && return u == representative
    # Bound variable: keep searching inside whatever it is bound to.
    return occurs_check(u, bound, smap)
end

occurs_check(u, v::Tuple, smap) = occurs_destructure(u, v, smap)
occurs_check(u, v::AbstractArray, smap) = occurs_destructure(u, v, smap)

# Search every element of a container term for `u`.
function occurs_destructure(u, v, smap)
    return any(element -> occurs_check(u, element, smap), v)
end
# Conflict rule: the default when no structural rule applies.
unify_terms(u, v, smap) = nothing
# Decompose rule for the two supported container kinds.
unify_terms(u::Tuple, v::Tuple, smap) = unify_containers(u, v, smap)
unify_terms(u::AbstractArray, v::AbstractArray, smap) = unify_containers(u, v, smap)

# Decompose rule: unify element by element, threading the substitution through.
# A shape mismatch is a conflict and yields `nothing`.
function unify_containers(u, v, smap)
    equalshape(u, v) || return nothing
    return foldl((subst, i) -> unify(u[i], v[i], subst), eachindex(u); init=smap)
end

# Arrays must agree in size, tuples in length.
equalshape(u::AbstractArray, v::AbstractArray) = size(u) == size(v)
equalshape(u::Tuple, v::Tuple) = length(u) == length(v)
# Some tests: unify two variables, then unify a pair of tuples, and check the
# resulting substitution against further unifications.
let subst = @SMap Dict()
    subst = unify(LVar(1), LVar(2), subst)
    subst = unify((LVar(0), "mango"), ("banana", LVar(1)), subst)
    @assert !isnothing(subst)
    @assert unify(LVar(0), "mango", subst) === nothing
    @assert unify(LVar(0), "banana", subst) == subst
    @assert unify(LVar(1), "mango", subst) == subst
    @assert unify(LVar(2), "mango", subst) == subst
    @assert unify(LVar(9), "squirrels", subst) == @SMap Dict(LVar(0) => "banana", LVar(1) => "mango", LVar(2) => "mango", LVar(9) => "squirrels")
end
# Datakanren: | |
#include("datakanren.jl") | |
# | |
# Monadic streams | |
# | |
# Stream element: either a finished value or a delayed computation.
abstract type TE end

# A computed result.
struct Evaluated{T} <: TE
    val::T
end

# A delayed computation with an integer priority.
# `f` must always return a heap of Evaluated objects and new thunks.
struct Thunk{F<:Function} <: TE
    priority::Int
    f::F
end

# Thunks default to priority 0.
Thunk(f::Function) = Thunk(0, f)

isevaluated(::Evaluated) = true
isevaluated(::Thunk) = false

# Ordering: evaluated values sort before every thunk and are never less than
# one another; thunks are ordered by priority.
Base.isless(::Evaluated, ::Evaluated) = false
Base.isless(::Evaluated, ::Thunk) = true
Base.isless(::Thunk, ::Evaluated) = false
Base.isless(a::Thunk, b::Thunk) = isless(a.priority, b.priority)
# Empty-stream base cases: merging yields the other stream; binding or
# realizing leaves the stream as it is; its sequence is empty.
function merge_streams(this::EmptyHeap{TE}, other)
    return other
end
function mapcat_stream(this::EmptyHeap{TE}, g)
    return this
end
function realize_stream_head(this::EmptyHeap{TE})
    return this
end
function stream_to_seq(this::EmptyHeap{<:TE})
    return list()
end

# Aka mzero.
function empty_stream()
    return EmptyHeap{TE}()
end

# Aka monadplus.
# Swap/interleaving is performed by the merge_heaps implementation.
function merge_streams(this::PairingTree{TE}, other::PairingHeap{TE})
    return merge_heaps(this, other, TE)
end

# Aka mpure: a plain value becomes a one-element evaluated stream, a function
# becomes a one-element stream holding a thunk.
function make_stream(s)
    return singleton_heap(Evaluated(s), TE)
end
function make_stream(f::Function)
    return singleton_heap(Thunk(f), TE)
end
# Aka monad bind | |
"""
    mapcat_stream(this::PairingTree, g::Function)

Monadic bind for streams: apply `g` to every evaluated element of `this` and
merge the resulting streams. A thunk is replaced by a new thunk of the same
priority that, once forced, binds `g` over the stream it produces.
"""
function mapcat_stream(this::PairingTree, g::Function)
    mapfoldl(merge_streams, this; init=EmptyHeap{TE}()) do val
        # Test the element being visited, not the heap's top. With an evaluated
        # top and thunks further down, testing `this.top` read `.val` on a Thunk.
        if isevaluated(val)
            g(val.val)
        else
            singleton_heap(Thunk(val.priority, () -> mapcat_stream(val.f(), g)), TE)
        end
    end
end
# Force thunks at the head of the stream until the head is an evaluated value
# or the stream is exhausted. Whatever a forced thunk produces is merged back
# into the rest of the stream.
function realize_stream_head(this::PairingTree{TE})
    stream = this
    while true
        isempty(stream) && break
        isevaluated(stream.top) && break
        thunk, remaining = pop_min(stream)
        stream = merge_heaps(thunk.f(), remaining, TE)
    end
    return stream
end
# Adapter exposing a stream heap through Julia's iteration protocol.
struct StreamIterator
    heap::PairingHeap{TE}
end

# Forcing thunks may add or remove elements, so the length is not known upfront.
Base.IteratorSize(::Type{StreamIterator}) = Base.SizeUnknown()

# The iteration state is the remaining heap.
Base.iterate(x::StreamIterator) = iterate(x, x.heap)

function Base.iterate(x::StreamIterator, h)
    realized = realize_stream_head(h)
    isempty(realized) && return nothing
    head, remaining = pop_min(realized)
    return head.val, remaining
end
# | |
# Search | |
# | |
# Search state: the current substitution plus the id that the next fresh
# variable will receive.
struct State
    smap::SMap
    nextid::Int
end

# A state built from just a substitution starts numbering variables at 0.
function State(smap)
    return State(smap, 0)
end

# NOTE(review): this passes a persistent dict, not an SMap; it presumably
# relies on conversion to the `smap::SMap` field — confirm.
function State()
    return State(@Persistent Dict())
end

const empty_state = State()

# Functional updates: each returns a new State with one component replaced.
function with_smap(state, smap)
    return State(smap, state.nextid)
end

function bump_nextid(state, n)
    return State(state.smap, state.nextid + n)
end
# Unification goal constructor: `u ≅ v` returns a goal (a function of a state)
# that yields the empty stream when unification fails, and otherwise a
# one-element stream holding the state with the extended substitution.
function ≅(u, v)
    return function (state)
        extended = unify(u, v, state.smap)
        if isnothing(extended)
            return empty_stream()
        end
        return make_stream(with_smap(state, extended))
    end
end
"""
    get_lambda_arity(l)

Return the largest number of positional arguments accepted by any method of
the callable `l`.

Uses the method's `nargs` field (which counts the callable itself, hence the
`- 1`) rather than `m.sig.parameters`: for methods with type parameters `sig`
is a `UnionAll`, which has no `parameters` field, so the old form threw.
"""
get_lambda_arity(l) = maximum(m -> Int(m.nargs), methods(l)) - 1
# Infer how many fresh variables are needed from the constructor's arity.
function call_fresh(goalconstructor)
    return call_fresh(goalconstructor, get_lambda_arity(goalconstructor))
end

# Build a goal that, given a state, creates `n` fresh logic variables numbered
# from `state.nextid`, passes them to `goalconstructor`, and runs the goal it
# returns on the state with `nextid` advanced by `n`.
function call_fresh(goalconstructor, n::Integer)
    return function (state)
        ids = state.nextid:(state.nextid + n - 1)
        freshvars = LVar.(ids)
        goal = goalconstructor(freshvars...)
        return goal(bump_nextid(state, n))
    end
end
# Normalize whatever a goal returned into a stream heap: heaps pass through,
# functions become thunks, any other value becomes an evaluated element.
function thunk_goal(goal::PairingHeap)
    return goal
end
function thunk_goal(goal::Function)
    return singleton_heap(Thunk(goal), TE)
end
function thunk_goal(goal)
    return singleton_heap(Evaluated(goal), TE)
end

# Logical disjunction: run both goals on the same state and merge the streams.
function ldisj(goalA, goalB)
    return function (state)
        left = thunk_goal(goalA(state))
        right = thunk_goal(goalB(state))
        return merge_streams(left, right)
    end
end

# Logical conjunction: feed every result of `goalA` into `goalB`.
function lconj(goalA, goalB)
    return function (state)
        return mapcat_stream(thunk_goal(goalA(state)), goalB)
    end
end
# Build the expression for a delayed goal: a function of a state whose body
# wraps the actual goal call in a zero-argument closure handed to `thunk_goal`.
#
# The state parameter is a gensym. These expressions are escaped by the macros
# below, so a literal `state` parameter would capture any caller variable named
# `state` that is referenced inside `goalexpr`.
function _delaygoal_expr(goalexpr)
    s = gensym(:state)
    return :($s -> thunk_goal(() -> $(goalexpr)($s)))
end

# Delay a single goal expression.
macro delaygoal(goalexpr)
    return esc(_delaygoal_expr(goalexpr))
end

# Right-nested disjunction of delayed goals; a single goal is just delayed.
_ldisj_expr(goalexpr) = _delaygoal_expr(goalexpr)
function _ldisj_expr(goalexpr, goals...)
    return :(ldisj($(_delaygoal_expr(goalexpr)), $(_ldisj_expr(goals...))))
end

macro ldisj(goalexpr...)
    return esc(_ldisj_expr(goalexpr...))
end

# Right-nested conjunction of delayed goals; a single goal is just delayed.
_lconj_expr(goalexpr) = _delaygoal_expr(goalexpr)
function _lconj_expr(goalexpr, goals...)
    return :(lconj($(_delaygoal_expr(goalexpr)), $(_lconj_expr(goals...))))
end

macro lconj(goalexpr...)
    return esc(_lconj_expr(goalexpr...))
end
# Name used for the n-th reified variable: an underscore followed by `n`.
function reify_name(n)
    return Symbol("_$n")
end
# Build a reification substitution for the term `v`: every variable still
# unbound in `v` is bound to a generated name.
reify_s(v, state::State) = reify_s(v, state.smap)
reify_s(v, smap) = _reify_s(walk(v, smap), smap)

# Anything that is neither a variable nor a supported container adds nothing.
_reify_s(v, smap) = smap

# An unbound variable is named after the current size of the substitution.
function _reify_s(v::LVar, smap)
    name = reify_name(length(smap))
    return add_substitution(smap, v, name)
end

# Containers and lazy iterators are reified element by element.
function _reify_s(v::Union{Tuple,AbstractArray,Base.Generator,Iterators.Flatten}, smap)
    return _reify_s_iterator(v, smap)
end

# Thread the substitution through every element of `v`, in order.
function _reify_s_iterator(v, smap)
    return foldl((subst, element) -> reify_s(element, subst), v; init=smap)
end
# Resolve a term fully: walk it, then resolve inside containers.
function deepwalk(v, smap)
    return _deepwalk(walk(v, smap), smap)
end

# Atoms are returned unchanged. Extend this.
_deepwalk(v, smap) = v

# Tuples are rebuilt from their resolved elements.
function _deepwalk(v::Tuple, smap)
    return Tuple(_deepwalk_iterator(v, smap))
end

# Arrays are resolved elementwise via broadcasting.
function _deepwalk(v::AbstractArray, smap)
    resolve(element) = deepwalk(element, smap)
    return resolve.(v)
end

# Lazy iterators stay lazy.
function _deepwalk(v::Union{Base.Generator,Iterators.Flatten}, smap)
    return _deepwalk_iterator(v, smap)
end

function _deepwalk_iterator(v, smap)
    return (deepwalk(element, smap) for element in v)
end

# Resolve the first variable (`LVar(0)`) in the state, then replace any
# variables left in the answer with reified names.
function reify_state_firstvar(state)
    answer = deepwalk(LVar(0), state.smap)
    names = reify_s(answer, empty_state)
    return deepwalk(answer, names)
end
# `@conde begin ... end`: each statement of the block is one alternative. A
# statement written as a tuple `g1, g2, ...` becomes the conjunction of its
# goals; the alternatives are then combined into a disjunction.
macro conde(blocks)
    @capture(blocks, begin clauses__ end)
    function conjoin(clause)
        @capture(clause, (subclauses__,)) || return clause
        return _lconj_expr(subclauses...)
    end
    return esc(_ldisj_expr(map(conjoin, clauses)...))
end
# `@fresh args -> (g1, g2, ...)` introduces fresh logic variables for the
# lambda's parameters and conjoins the goals in its tuple body. Anything that
# is not a lambda with a tuple body is passed to `call_fresh` unchanged.
#
# Fixes relative to the earlier version:
#   * the pattern uses `args_` (a MacroTools binding) instead of the bare
#     symbol `args`, which only matched a parameter literally named `args`;
#   * the expansion interpolates `$args` and `$lambda` instead of emitting the
#     literal identifiers `args` and `lambda` into the caller's code.
macro fresh(lambda)
    expansion = if @capture(lambda, args_ -> (body__,))
        :(call_fresh($args -> $(_lconj_expr(body...))))
    else
        :(call_fresh($lambda))
    end
    return esc(expansion)
end
# Run a goal starting from the empty state.
function call_empty_state(goal)
    return goal(empty_state)
end

# Lazily run the query `f`: its parameters become fresh variables, and each
# resulting state is reified with respect to the first variable.
function runkanren(f)
    stream = call_empty_state(call_fresh(f))
    return Iterators.map(reify_state_firstvar, StreamIterator(stream))
end

# Eager variant of `runkanren`: collect every answer.
function runcollect(f)
    return collect(runkanren(f))
end
# | |
# Try it out: | |
# | |
# Single unification goal on the query variable.
runcollect() do q
    q ≅ 1
end

# Unifying tuples; `r` is introduced but left unused.
runcollect() do q, p, r
    (p, q) ≅ (1, p)
end

# Three alternatives for `q`.
runcollect() do q
    @conde begin
        q ≅ 1
        q ≅ 3
        q ≅ 7
    end
end

# The middle alternative is a conjunction of three goals.
runcollect() do q, p
    @conde begin
        q ≅ 1
        p ≅ q, p ≅ 1, q ≅ 1
        q ≅ 7
    end
end

# Mix of tuple unification, conjunctions and a bare variable-to-variable goal.
runcollect() do q, p
    @conde begin
        (q, 3, 1) ≅ (p, p, 1)
        p ≅ 7, q ≅ p, q ≅ 1
        p ≅ q, p ≅ 5
        q ≅ p
    end
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment