Skip to content

Instantly share code, notes, and snippets.

@maartenvd
Last active June 16, 2020 10:56
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save maartenvd/7fd7cab87be53b1f41d908bb0aba9588 to your computer and use it in GitHub Desktop.
Save maartenvd/7fd7cab87be53b1f41d908bb0aba9588 to your computer and use it in GitHub Desktop.
"""
    betterind(tree, indices, usedind=0)

Walk the binary contraction `tree` in post-order (left subtree, right subtree,
then the node) and give the indices contracted at each pairwise step fresh
consecutive integer labels, starting at `usedind + 1`.

`tree` is either an `Int` leaf, which is a position into `indices`, or an
indexable whose elements `[1]` and `[2]` are subtrees.

Note: `indices` is mutated. For every internal node, the `symdiff` of its two
operands' indices is appended to `indices`, and that node is afterwards
identified by its position there.

Returns `(node, indices, usedind, legmap)`:
- `node`: the position of this (sub)tree's result in `indices`;
- `indices`: the extended `indices`;
- `usedind`: the last label handed out;
- `legmap`: a `Dict` mapping each contracted index to its new label.
"""
function betterind(tree, indices, usedind=0)
    # Leaf: its indices are already stored at position `tree`.
    tree isa Int && return tree, indices, usedind, Dict()

    left, indices, usedind, leftmap = betterind(tree[1], indices, usedind)
    right, indices, usedind, rightmap = betterind(tree[2], indices, usedind)

    leftidx = indices[left]
    rightidx = indices[right]
    shared = intersect(leftidx, rightidx)
    newlabels = Dict(zip(shared, (usedind + 1):(usedind + length(shared))))
    usedind += length(shared)

    node = length(indices) + 1
    push!(indices, symdiff(leftidx, rightidx))
    # Later maps win on key collisions: left, then right, then this node.
    return node, indices, usedind, merge(leftmap, rightmap, newlabels)
end
"""
    calc_curcost(tree, indices, costs)

Estimate the cost of contracting a tensor network in the order given by the
binary contraction `tree`.

# Arguments
- `tree`: either an `Int` leaf, which is a position into `indices`, or an
  indexable whose elements `[1]` and `[2]` are subtrees.
- `indices`: `indices[k]` is the collection of index labels of tensor `k`.
- `costs`: a mapping from index label to its dimension; it must support
  `costs[label]`.

# Returns
`(cost, open)`, where:
- `cost` is the sum, over every pairwise contraction in `tree`, of the product
  of the dimensions of all distinct indices involved in that step (open and
  contracted). A leaf costs `0`.
- `open` is the list of indices remaining on the result, i.e. the `symdiff` of
  the two operands' indices.
"""
function calc_curcost(tree, indices, costs)
    if tree isa Int
        return 0, indices[tree]
    end
    leftcost, leftinds = calc_curcost(tree[1], indices, costs)
    rightcost, rightinds = calc_curcost(tree[2], indices, costs)

    openinds = symdiff(leftinds, rightinds)
    contracted = intersect(leftinds, rightinds)
    # `init=1` makes an empty index set (e.g. the final step of a full
    # contraction to a scalar) contribute a factor 1 whatever the value type
    # of `costs`. `prod` over an empty `Any[]` throws instead. It also avoids
    # allocating a temporary array at every step.
    opencost = mapreduce(i -> costs[i], *, openinds; init=1)
    contractcost = mapreduce(i -> costs[i], *, contracted; init=1)

    return leftcost + rightcost + opencost * contractcost, openinds
end
"""
    @otensor expr

Evaluate the tensor contraction `expr` with the code produced by
`defaultparser`. Before that, at run time, compare two contraction costs:
- the cost of the order implied by `expr` (via `ncontree`);
- the cost of the order returned by `optimaltree`.

If the optimal order is cheaper, print the following, then run the contraction
as written:
- the optimal cost;
- the ratio current/optimal;
- the call site (`file:line`);
- a suggested relabelling `old -> new` of the contracted indices.

Only a product of tensors (`A[...] * B[...] * ...`) is supported. It may appear
as the right-hand side of an assignment or definition. Any other form raises an
error at macro-expansion time.
"""
macro otensor(ex::Expr)
    # Build the actual contraction code from the full expression (including any
    # left-hand side) before `ex` is narrowed to its right-hand side below.
    contraction_worker = esc(defaultparser(ex))
    if isassignment(ex) || isdefinition(ex)
        ex = getrhs(ex)
    end
    if !(ex.head == :call && ex.args[1] == :*)
        error("cannot compute optimal contraction tree for this expression")
    end
    tensors = gettensorobjects(ex)
    # One index list per factor of the product, in source order.
    network = [getindices(ex.args[k]) for k = 2:length(ex.args)]
    current_tree = ncontree(network)

    # Build an expression `Dict(label => size(tensor, position), ...)`. It is
    # evaluated at run time, with the tensors resolved in the caller's scope.
    # Each label is spliced in as a literal (`QuoteNode`), so that a `Symbol`
    # label is not evaluated as a variable reference. `Int` labels evaluate
    # the same either way.
    costexpr = Expr(:call, :Dict)
    for (tensor, labels) in zip(tensors, network)
        for (position, label) in enumerate(labels)
            push!(costexpr.args, :($(QuoteNode(label)) => size($(esc(tensor)), $position)))
        end
    end

    return quote
        cost_map = $(costexpr)
        (current_cost, _) = calc_curcost($(current_tree), $(network), cost_map)
        (optimal_tree, optimal_cost) = optimaltree($(network), cost_map)
        if optimal_cost < current_cost
            sourcefile = $(string(__source__.file))
            sourceline = $(__source__.line)
            println("suboptimal contraction $(optimal_cost) $(current_cost/optimal_cost) $(sourcefile):$(sourceline)")
            # `betterind` appends to the index list it is given, so copy it at
            # run time. A copy made at expansion time would splice one array
            # object into the generated code, and that array would grow on
            # every execution.
            legmap = betterind(optimal_tree, deepcopy($(network)))[4]
            for pair in legmap
                println("$(pair[1]) -> $(pair[2])")
            end
        end
        $(contraction_worker)
    end
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment