Skip to content

Instantly share code, notes, and snippets.

@trappmartin
Created February 15, 2023 09:58
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save trappmartin/d3dacd08e060df9052556624afd8f11a to your computer and use it in GitHub Desktop.
# Gather every sum node of the SPN; their weights are the parameters being fitted.
snodes = filter(node -> node isa SumNode, values(spn))
# option 1
# Initial parameter vector: each sum node's log-weights, flattened and
# concatenated in the order of `snodes`.
q1 = mapreduce(vcat, snodes) do node
    node.logweights[:]
end
# Objective for option 1: sum-node weights are parameterised directly by their
# log-weights. Intended to be differentiated with ForwardDiff.
#
# `θ` is the flat parameter vector, laid out exactly like `q1`: the entries of
# each node's `logweights`, concatenated in the order of `snodes`.
#
# Side effect: overwrites the `logweights` of every node in `snodes` with the
# matching slice of `θ`.
#
# Returns the mean of `logpdf(spn.root, x)` minus the logpdf of an all-NaN row.
# Presumably NaN marks a marginalised variable, so the subtracted term acts as a
# log-normaliser — confirm against the SPN library's logpdf semantics.
#
# NOTE(review): reads the globals `snodes`, `spn` and `x`, and needs `mean`
# (Statistics). If ForwardDiff passes Dual numbers in `θ` while `logweights` is a
# Float64 array, the assignment below will likely fail — confirm the eltype.
function f1(θ)
    D = size(x, 2)
    offset = 1
    for node in snodes
        # Take the slice length from the weights themselves, so the layout of θ
        # matches how `q1` was built (concatenated `logweights`).
        K = length(node.logweights)
        # setindex! copies, so a view of θ is enough (no intermediate slice copy).
        node.logweights[:] = view(θ, offset:(offset + K - 1))
        offset += K
    end
    llh = logpdf(spn.root, x)
    # logpdf of a single all-NaN row, subtracted from every entry of `llh`.
    logz = logpdf(spn.root, fill(NaN, 1, D))
    return mean(llh .- logz)
end
# option 2
# Initial parameter vector on the probability scale: each sum node's
# exponentiated log-weights, flattened and concatenated in the order of `snodes`.
q2 = mapreduce(vcat, snodes) do node
    exp.(node.logweights[:])
end
# Objective for option 2: sum-node weights are parameterised on the probability
# scale (laid out like `q2`). Each node's block of θ is passed through
# `projectToPositiveSimplex!` (presumably a projection onto the probability
# simplex — confirm), then stored back as log-weights. Intended for ForwardDiff.
#
# Named `f2` (was a second `f1`): redefining `f1(θ)` with the same signature
# silently replaced the option-1 objective, so option 1 was never callable.
#
# `θ` holds each node's weights, concatenated in the order of `snodes`.
#
# Side effect: overwrites the `logweights` of every node in `snodes`.
#
# Returns `mean(logpdf(spn.root, x))`. Unlike option 1, no normalising term is
# subtracted — presumably because projected weights are already normalised.
#
# NOTE(review): reads the globals `snodes`, `spn` and `x`, and needs `mean`
# (Statistics).
function f2(θ)
    offset = 1
    for node in snodes
        # Take the slice length from the weights themselves, so the layout of θ
        # matches how `q2` was built (concatenated `logweights`).
        K = length(node.logweights)
        # Deliberately a copying slice, not a view: the `!` indicates that
        # projectToPositiveSimplex! mutates its argument, and θ must not change.
        ϕ = θ[offset:(offset + K - 1)]
        node.logweights[:] = log.(projectToPositiveSimplex!(ϕ))
        offset += K
    end
    return mean(logpdf(spn.root, x))
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.