Skip to content

Instantly share code, notes, and snippets.

@AyeGill
Last active August 8, 2020 19:09
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save AyeGill/0f5092c356e25fc188bb712bef56c0f7 to your computer and use it in GitHub Desktop.
Save AyeGill/0f5092c356e25fc188bb712bef56c0f7 to your computer and use it in GitHub Desktop.
Para(Euc) from "Backprop as Functor" (https://arxiv.org/abs/1711.10455) in Catlab.jl
using Catlab.GAT
using Catlab.Theories
using Flux
"""
    EucSpc(dim)

A Euclidean space, represented only by its dimension `dim`.
"""
struct EucSpc
    dim::Int
end
"""
    ParaEucFun(domdim, coddim, paramdim, operation)

A parameterized function between Euclidean spaces.

# Fields
- `domdim`: dimension of the input vector.
- `coddim`: dimension of the output vector.
- `paramdim`: length of the parameter vector.
- `operation`: function called as `operation(x, p)` with input `x` and parameters `p`.
"""
struct ParaEucFun
    domdim::Int
    coddim::Int
    paramdim::Int
    operation::Function
end
"""
    test_operation_pef(p::ParaEucFun)

Smoke-test `p.operation` on all-zero inputs of the declared sizes.

Calls `p.operation(zeros(p.domdim), zeros(p.paramdim))` and returns `true` when the
first dimension of the result equals `p.coddim`. Returns `false` when that size differs
or when the call throws a `DimensionMismatch`.

Any other exception raised by the operation is rethrown, so this may crash if your
operation fails with something other than `DimensionMismatch`. Use with care!
"""
function test_operation_pef(p::ParaEucFun)
    try
        v = p.operation(zeros(p.domdim), zeros(p.paramdim))
        return size(v, 1) == p.coddim
    catch e
        # Only a dimension mismatch counts as a failed check; anything else is a genuine
        # error and is propagated with its original backtrace.
        e isa DimensionMismatch || rethrow()
        return false
    end
end
# Operation used for identity morphisms: returns the input `x` unchanged and ignores
# the parameter argument `p`.
function para_id(x, p)
    return x
end
# Build the operation of a composite.
# `pd1` / `pd2` are the parameter dimensions of `f1` / `f2`. The returned closure takes
# `(x, p)`, feeds `p[1:pd1]` to `f1`, and the next `pd2` entries of `p` to `f2`.
function para_compose(pd1, pd2, f1, f2)
    first_range = 1:pd1
    second_range = (pd1 + 1):(pd1 + pd2)
    return (x, p) -> begin
        head = p[first_range]
        tail = p[second_range]
        f2(f1(x, head), tail)
    end
end
# Para(Euc) as a Catlab instance: objects are `EucSpc`, morphisms are `ParaEucFun`.
# Every `operation` is called as `operation(x, p)` (input first, parameters second).
@instance SymmetricMonoidalCategory(EucSpc, ParaEucFun) begin
    dom(P::ParaEucFun) = EucSpc(P.domdim)
    codom(P::ParaEucFun) = EucSpc(P.coddim)

    # Identity morphism: zero parameters, input passed through unchanged.
    id(x::EucSpc) = ParaEucFun(x.dim, x.dim, 0, para_id)

    # Sequential composition (f first, then g). The composite's parameter vector is
    # f's parameters followed by g's.
    compose(f::ParaEucFun, g::ParaEucFun) =
        ParaEucFun(f.domdim, g.coddim,
                   f.paramdim + g.paramdim,
                   para_compose(f.paramdim, g.paramdim, f.operation, g.operation))

    # Monoidal product on objects: dimensions add.
    otimes(a::EucSpc, b::EucSpc) = EucSpc(a.dim + b.dim)

    # Monoidal product on morphisms: run f and g side by side on disjoint slices of the
    # input and of the parameter vector, then concatenate the outputs.
    function otimes(f::ParaEucFun, g::ParaEucFun)
        dd = f.domdim + g.domdim
        cd = f.coddim + g.coddim
        pd = f.paramdim + g.paramdim
        op = function (x, p)
            par1 = p[1:f.paramdim]
            par2 = p[f.paramdim+1:pd]
            x1 = x[1:f.domdim]
            x2 = x[f.domdim+1:dd]
            # Operations take (input, parameters) in that order; the arguments were
            # previously passed the wrong way round.
            y1 = f.operation(x1, par1)
            y2 = g.operation(x2, par2)
            return vcat(y1, y2)
        end
        return ParaEucFun(dd, cd, pd, op)
    end

    # Monoidal unit: the zero-dimensional space.
    munit(::Type{EucSpc}) = EucSpc(0)

    # Symmetry: a parameter-free morphism that swaps the two blocks of the input.
    function braid(a::EucSpc, b::EucSpc)
        dim = a.dim + b.dim
        op = function (x, p)
            x1 = x[1:a.dim]
            x2 = x[a.dim+1:dim]
            return vcat(x2, x1)
        end
        return ParaEucFun(dim, dim, 0, op)
    end
end
##WARNING
#Strictly speaking, to get a symmetric monoidal category we must consider equivalence
#classes of parameterized functions up to reparameterization.
#That is too much of a pain to deal with here.
#In practice this means that when you construct e.g. functors from wiring diagrams
#into ParaEuc, the value of a wiring diagram may depend on the arbitrary choice
#of how to break it up into tensors and compositions.
#(But all the possibilities will be equivalent up to reordering the parameters.)
##Using these to run things
"""
    ParaEucModel(domdim, coddim, paramdim, param, operation)

A parameterized function bundled with a concrete parameter vector.

# Fields
- `domdim`: dimension of the input vector.
- `coddim`: dimension of the output vector.
- `paramdim`: length of the parameter vector.
- `param`: current parameter values.
- `operation`: function called as `operation(x, param)`.
"""
struct ParaEucModel
    domdim::Int
    coddim::Int
    paramdim::Int
    param::Vector{Float64}
    operation::Function
end
# Make models callable: `m(x)` applies the stored operation to `x` using the model's
# current parameter vector.
function (m::ParaEucModel)(x)
    return m.operation(x, m.param)
end
# Register ParaEucModel with the Flux library via `Flux.@functor`.
# Only the `param` field is listed, so it is the only trainable field; the dimension
# fields and `operation` are left out.
Flux.@functor ParaEucModel (param,)
"""
    make_model(p::ParaEucFun)

Create a `ParaEucModel` from `p` with a randomized initial parameter vector of length
`p.paramdim`, drawn with `rand(Float64, p.paramdim)`.
"""
function make_model(p::ParaEucFun)
    initial_param = rand(Float64, p.paramdim)
    return ParaEucModel(p.domdim, p.coddim, p.paramdim, initial_param, p.operation)
end
# Training.
# A basic version for testing and illustration; it should eventually be generalized
# (other losses, optimizers, etc). Runs `Flux.train!` over `data` with a mean-squared-
# error loss between `m(x)` and `y`, using the `Descent` optimizer with its defaults.
function train_model(m::ParaEucModel, data)
    optimiser = Flux.Optimise.Descent()
    trainable = params(m)
    Flux.train!(trainable, data, optimiser) do x, y
        Flux.Losses.mse(m(x), y)
    end
end
# Functional programmers beware! Training is stateful! This mutates the model!
# VERY SMALL BABY EXAMPLE
# NOTE(review): the global `sum` below hides `Base.sum` for the rest of this script;
# the name is kept for compatibility, but consider renaming it.
# `sum` adds the two input coordinates and has no parameters.
sum = ParaEucFun(2, 1, 0, (x, p) -> [x[1] + x[2]])
# `scale` multiplies its 1-dimensional input by a single parameter.
scale = ParaEucFun(1, 1, 1, (x, p) -> x * p[1])
# Scale each coordinate with its own parameter, then add the results.
fun = compose(otimes(scale, scale), sum)
m = make_model(fun)
# (input, target) pairs. Every target is a 1-element vector, matching the model output;
# the third target used to be the bare scalar `0`.
data = [([0, 1], [0]), ([0, 0], [0]), ([1, 0], [0]), ([1, 1], [1])]
train_model(m, data)

Obviously we'd like to implement a version of the category of learners from the paper. The problem I ran into was coherence issues when you take the product of many sets - you'd end up with some monstrous nested tuple type or something like that. We get around this in Para(Euc) by always passing in flat arrays.

Anyway, the obvious next step is to figure out a clean way of handling the coherence, then build a functor from Para(Euc) to Learn using a library like Flux.jl.

Check out:

@jpfairbanks
Copy link

This looks awesome! I think the coherence of the parameters might be solved by https://github.com/SciML/RecursiveArrayTools.jl/blob/master/src/array_partition.jl, which lets you build recursive arrays and then access them with sequential indexing as if it were a single 1D array.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment