using DynamicPPL

# utilities for working with Turing model parameter names using only the DynamicPPL API

"""
    flattened_varnames_list(model::DynamicPPL.Model) -> Vector{Symbol}

Get a vector of varnames as `Symbol`s with one-to-one correspondence to the
flattened parameter vector.
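The docstring describes a one-to-one map between names and entries of the flattened parameter vector. A minimal sketch of that correspondence can be written with Base alone, assuming column-major flattening (as Julia's `vec` does) and an `x[i,j]` naming style; the exact strings DynamicPPL produces may differ. `flat_names` is a hypothetical helper that takes `name => size` pairs instead of a `DynamicPPL.Model`.

```julia
# Hypothetical helper (not DynamicPPL API): one Symbol per scalar entry,
# in column-major order, so names line up with vec() of each parameter.
function flat_names(specs)
    out = Symbol[]
    for (name, dims) in specs
        if isempty(dims)
            push!(out, Symbol(name))                 # scalar parameter: bare name
        else
            for I in CartesianIndices(dims)          # (1,1), (2,1), (1,2), ... column-major
                push!(out, Symbol(name, "[", join(Tuple(I), ","), "]"))
            end
        end
    end
    return out
end

flat_names([:s => (), :x => (2, 2)])
# returns [:s, Symbol("x[1,1]"), Symbol("x[2,1]"), Symbol("x[1,2]"), Symbol("x[2,2]")]
```

Column-major order is chosen so that position `k` in the name list matches position `k` in a vector built by concatenating `vec` of each parameter.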

using Distances, Distributions, OptimalTransport, LinearAlgebra, Random

struct PowEuclidean{T} <: Distances.PreMetric
    p::T
end

(m::PowEuclidean)(a, b) = Distances.Euclidean()(a, b)^m.p

# for measures μ and ν with support ℝᵈ, approximate the empirical p-Wasserstein distance
# between matrices x and y of random points whose columns are respectively drawn from
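`PowEuclidean(p)` evaluates the ground cost ‖a − b‖ᵖ for a pair of points. As a stdlib-only sketch, the cost matrix it induces between two sample matrices (columns are points) can be built directly, and in the special case of two equal-size one-dimensional samples with uniform weights the exact p-Wasserstein distance is available in closed form, because for p ≥ 1 pairing the sorted samples is an optimal coupling. This gives a small reference value; both function names are hypothetical, and the gist itself imports OptimalTransport for the general case.

```julia
using LinearAlgebra

# C[i, j] = ‖x[:, i] - y[:, j]‖^p: the ground cost PowEuclidean(p) evaluates,
# for sample matrices whose columns are points.
function pow_cost_matrix(x::AbstractMatrix, y::AbstractMatrix, p::Real)
    return [norm(x[:, i] - y[:, j])^p for i in axes(x, 2), j in axes(y, 2)]
end

# Exact W_p between two equal-size 1-D samples with uniform weights:
# for p ≥ 1, pairing the sorted samples is an optimal coupling.
function wasserstein_1d(x::AbstractVector, y::AbstractVector, p::Real)
    length(x) == length(y) || throw(ArgumentError("samples must have equal size"))
    return (sum(abs.(sort(x) .- sort(y)) .^ p) / length(x))^(1 / p)
end

x = [0.0 1.0]                       # 1×2 sample matrix: two points in ℝ¹
y = [3.0 2.0]
C = pow_cost_matrix(x, y, 2)
wasserstein_1d(vec(x), vec(y), 2)   # 2.0
```

For these two points per sample, brute force over both assignments gives min(9 + 1, 4 + 4) / 2 = 4 for W₂², matching the sorted-matching value 2.0 for W₂.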

# Using Manifolds.jl, follow-up to https://github.com/JuliaManifolds/Manifolds.jl/pull/249
using LinearAlgebra
using Manifolds
using Test

# parallel transport of a point (p, X) ∈ T GL(n) along the
# geodesic curve γ on GL(n) with the left-GL(n)-invariant metric
# from section A.2 of https://arxiv.org/abs/1603.05868v1
function geodesic_flow!(M::GeneralLinear, q, Y, p, X, t::Real = 1)
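The body of `geodesic_flow!` is truncated here. As a hedged sketch of the curve it transports along, the geodesic through p with initial velocity X for the left-invariant metric ⟨X, Y⟩_p = tr((p⁻¹X)ᵀ p⁻¹Y) is commonly written γ(t) = p exp(tΞᵀ) exp(t(Ξ − Ξᵀ)) with Ξ = p⁻¹X. We assume that closed form here and take X as a plain matrix tangent at p, which may differ from how Manifolds.jl stores tangent vectors; `gl_geodesic` is a hypothetical name, and parallel transport itself is not implemented.

```julia
using LinearAlgebra

# Hypothetical helper: geodesic on GL(n) for the left-invariant metric
# ⟨X, Y⟩_p = tr((p⁻¹X)ᵀ p⁻¹Y), assuming the closed form
# γ(t) = p * exp(t * Ξᵀ) * exp(t * (Ξ - Ξᵀ)) with Ξ = p⁻¹X.
function gl_geodesic(p::AbstractMatrix, X::AbstractMatrix, t::Real)
    Xi = p \ X                          # left-translate X back to the identity
    XiT = Matrix(transpose(Xi))
    return p * exp(t * XiT) * exp(t * (Xi - XiT))
end

p = [2.0 0.5; 0.1 1.0]                  # invertible base point
X = [0.3 -0.2; 0.4 0.1]                 # tangent vector at p
q = gl_geodesic(p, X, 1.0)              # endpoint γ(1)
```

Two quick sanity checks follow from the formula: γ(0) = p and γ′(0) = pΞ = X, and replacing (p, X) by (g·p, g·X) gives g·γ(t), which is the left-invariance the metric is built around.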

# rrule for expv(t, A, v) = exp(t * A) * v
# Use OrdinaryDiffEq to solve the adjoint system for the pullback.
# Since OrdinaryDiffEq depends on ExponentialUtilities, it doesn't make sense
# to include this code there.
using Pkg
Pkg.add(["ChainRulesCore", "ChainRulesTestUtils", "ExponentialUtilities", "FiniteDifferences", "OrdinaryDiffEq", "Test", "LinearAlgebra"])
using ChainRulesCore, ChainRulesTestUtils, FiniteDifferences, OrdinaryDiffEq, Test, LinearAlgebra, Random
using FiniteDifferences: rand_tangent
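The gist computes the pullback of `expv` by solving an adjoint ODE with OrdinaryDiffEq. For small dense matrices, a reference pullback that such an rrule could be checked against can be sketched with LinearAlgebra alone. This swaps in a different technique: for y = exp(tA)v and output cotangent ȳ, it uses v̄ = exp(tA)ᵀȳ, t̄ = ȳᵀAy, and Ā = t·L(tAᵀ, ȳvᵀ), where the Fréchet derivative L(X, B) of the matrix exponential is read off the upper-right block of exp([X B; 0 X]). Real inputs are assumed, and `expv_dense_pullback` is a hypothetical name.

```julia
using LinearAlgebra

# Hypothetical dense reference pullback for y = exp(t*A)*v (real inputs assumed).
# Returns y and the cotangents (t̄, Ā, v̄) for an output cotangent ybar.
function expv_dense_pullback(t::Real, A::AbstractMatrix, v::AbstractVector, ybar::AbstractVector)
    n = size(A, 1)
    E = exp(t * A)                      # dense exp(tA); only sensible for small n
    y = E * v
    tbar = dot(ybar, A * y)             # d/dt exp(tA)v = A*exp(tA)*v
    vbar = E' * ybar                    # transpose of the linear map v ↦ exp(tA)v
    # Ā = t * L(t*Aᵀ, ybar*vᵀ); L(X, B) is the upper-right n×n block of exp([X B; 0 X])
    At = Matrix(A')
    block = [t*At ybar*v'; zeros(n, n) t*At]
    Abar = t * exp(block)[1:n, n+1:2n]
    return y, tbar, Abar, vbar
end

A = [0.1 0.2 0.0; -0.3 0.05 0.4; 0.2 -0.1 -0.2]
v = [1.0, -0.5, 0.25]
ybar = [0.3, 0.7, -0.2]
y, tbar, Abar, vbar = expv_dense_pullback(0.8, A, v, ybar)
```

Each cotangent can be compared against a central finite difference of the scalar ⟨ȳ, exp(tA)v⟩, which is the same kind of check the gist runs with FiniteDifferences.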

using FiniteDifferences, LinearAlgebra, Zygote, Random, Test

# adapted from ChainRulesTestUtils.rrule_test
function pullback_test(
    f,
    ȳ,
    xx̄s::Tuple{Any,Any}...;
    rtol = 1e-9,
    atol = 1e-9,
    fkwargs = NamedTuple(),
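`pullback_test` is adapted from `ChainRulesTestUtils.rrule_test`, and its body is truncated here. The core check such a test rests on can be sketched without Zygote: for a pullback x̄ = Jᵀȳ, the identity ⟨ȳ, J·dx⟩ = ⟨x̄, dx⟩ must hold for every direction dx, and J·dx can be estimated with a central finite difference. `check_pullback`, its keyword defaults, and the example function `g` are hypothetical, not the gist's signature.

```julia
using LinearAlgebra

# Hypothetical helper: checks a hand-written pullback x̄ = Jᵀȳ using the
# identity ⟨ȳ, J*dx⟩ = ⟨x̄, dx⟩, with J*dx estimated by central differences.
function check_pullback(f, pullback, x::AbstractVector, ybar::AbstractVector,
                        dx::AbstractVector; h = 1e-6, rtol = 1e-6, atol = 1e-9)
    xbar = pullback(x, ybar)
    jvp = (f(x + h * dx) - f(x - h * dx)) / (2h)   # finite-difference J*dx
    return isapprox(dot(ybar, jvp), dot(xbar, dx); rtol = rtol, atol = atol)
end

g(x) = [x[1] * x[2], sin(x[1])]
# Jᵀȳ for J = [x2 x1; cos(x1) 0]
g_pullback(x, ybar) = [ybar[1] * x[2] + ybar[2] * cos(x[1]), ybar[1] * x[1]]

check_pullback(g, g_pullback, [0.5, -1.5], [1.0, 2.0], [0.3, 0.7])   # true
```

A single probe direction can miss errors orthogonal to it, so a fuller test would repeat the check over several random `dx`, much as `rrule_test` draws random tangents.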