Last active
July 21, 2024 20:43
-
-
Save SouthEndMusic/06f9642095546d5e9c117667957ff47a to your computer and use it in GitHub Desktop.
Proof of concept of deriving Gauss-Turán Quadratures
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using Integrals | |
using Optim | |
using PreallocationTools | |
using ForwardDiff | |
using Base.Threads | |
################################################# | |
## Automatic higher order derivatives | |
"""
    @generate_derivs(max_order)

Generate a function `derivs!(out, f, x)` which computes the 0th up to the `max_order`th
derivative of the scalar-to-scalar function `f` at `x` and stores them in ascending
derivative order in `out`, i.e. `out[k + 1]` holds the `k`th derivative. Hence `out` must
have at least `max_order + 1` elements.

`max_order` must be a non-negative integer *literal* (e.g. `@generate_derivs(4)`): it is
consumed at macro-expansion time, so an expression such as `2s` is rejected with an
`ArgumentError`.

The derivatives are obtained by nesting `max_order` layers of `ForwardDiff.Dual` numbers,
each tagged with its own `Val{i}`. The function `derivs!` is defined in the module in
which the macro is invoked.
"""
macro generate_derivs(max_order)
    if !(max_order isa Integer) || max_order < 0
        throw(
            ArgumentError(
                "@generate_derivs expects a non-negative integer literal, got $(repr(max_order))",
            ),
        )
    end

    # Nested dual numbers arg_0 = x, arg_i = Dual{Val{i}}(arg_{i-1}, 1): every layer seeds
    # one more (distinctly tagged) derivative direction.
    arg_assignments = Expr[:(arg_0 = x)]
    for i = 1:max_order
        arg_name = Symbol("arg_$i")
        prev_arg_name = Symbol("arg_$(i-1)")
        push!(
            arg_assignments,
            :($arg_name = ForwardDiff.Dual{Val{$i}}($prev_arg_name, one($prev_arg_name))),
        )
    end

    # Evaluate f once on the fully nested argument, then peel off one partial per layer:
    # res_i holds the ith derivative, still wrapped in the remaining max_order - i layers.
    arg_max = Symbol("arg_$max_order")
    res_unpacks = Expr[:(res_0 = f($arg_max))]
    for i = 1:max_order
        res_name = Symbol("res_$i")
        res_prev_name = Symbol("res_$(i-1)")
        push!(res_unpacks, :($res_name = only($res_prev_name.partials)))
    end

    # out[i + 1] = res_i.value.value… (max_order - i times), i.e. the plain ith derivative
    out_assignments = Expr[]
    for i = 0:max_order
        unwrapped = Symbol("res_$i")
        for _ = 1:(max_order-i)
            unwrapped = :($unwrapped.value)
        end
        push!(out_assignments, :(out[$(i + 1)] = $unwrapped))
    end

    # Escape only the function name so `derivs!` lands in the caller's module; all other
    # identifiers (arguments, temporaries, the type parameter) stay hygienic.
    fname = esc(:derivs!)
    return quote
        function $fname(out, f, x::T)::Nothing where {T<:Number}
            $(arg_assignments...)
            $(res_unpacks...)
            $(out_assignments...)
            return nothing
        end
    end
end
#################################################
## Gauss-Turán quadrature rule computation

# Overload LinearAlgebra.generic_trimatdiv! with the patched versions from
# https://github.com/JuliaLang/julia/pull/54201; without them Optim.jl crashes
# on BigFloat/Double element types.
# NOTE(review): "Double" presumably refers to an extended-precision float type
# (e.g. DoubleFloats.jl) — confirm. fix.jl itself is not part of this file; the
# path is resolved relative to this source file by `include`.
include("fix.jl")
"""
    DEFAULT_w(x)

Default integration weight `w(x) = 1`, returned as `one(T)` for an argument of type `T`.
"""
function DEFAULT_w(x::T)::T where {T}
    unit_weight = one(T)
    return unit_weight
end
# Column-major linear position of entry (i, m) in an n-row layout, where
#   1 ≤ i ≤ n        is the node index and
#   1 ≤ m ≤ 2s + 1   is the derivative order + 1,
# i.e. the same value as LinearIndices((n, 2s + 1))[i, m].
flat_index(i, m, n) = i + n * (m - 1)
"""
    GaussTuranCache(n, s, N, a, b, ε, rhs_upper, rhs_lower)

Problem data and scratch storage shared by all `GaussTuranLoss!` evaluations.

# Fields
- `n`, `s`, `N`: number of nodes, derivative parameter, and number of unknown weights
- `a`, `b`: integration interval
- `ε`: minimum node distance
- `rhs_upper`, `rhs_lower`: right-hand sides of the square and the remaining equations
- `M_upper_buffer`, `M_lower_buffer`, `X_buffer`, `A_buffer`: lazily allocated scratch
  arrays (`PreallocationTools.LazyBufferCache`), so evaluations with different element
  types (e.g. dual numbers during automatic differentiation) get their own storage
"""
struct GaussTuranCache{T}
    # Problem dimensions
    n::Int
    s::Int
    N::Int
    # Interval and minimum node spacing
    a::T
    b::T
    ε::T
    # Right-hand sides of the moment equations
    rhs_upper::Vector{T}
    rhs_lower::Vector{T}
    # Scratch buffers
    M_upper_buffer::LazyBufferCache{typeof(identity)}
    M_lower_buffer::LazyBufferCache{typeof(identity)}
    X_buffer::LazyBufferCache{typeof(identity)}
    A_buffer::LazyBufferCache{typeof(identity)}

    function GaussTuranCache(
        n,
        s,
        N,
        a::T,
        b::T,
        ε::T,
        rhs_upper::Vector{T},
        rhs_lower::Vector{T},
    ) where {T}
        # Four separate caches so that the buffers never alias one another
        m_upper, m_lower, x_buf, a_buf = (LazyBufferCache() for _ in 1:4)
        return new{T}(n, s, N, a, b, ε, rhs_upper, rhs_lower, m_upper, m_lower, x_buf, a_buf)
    end
end
"""
    GaussTuranLoss!(f!, ΔX, cache)

Loss whose root defines the quadrature rule.

The nodes are `X = a .+ cumsum(ΔX)`. For every node `i`, `f!(out, X[i], j)` fills `out`,
a view of the columns `i:n:N` of row `j` of `M_upper` (for `j = 1:N`) or of row `j - N`
of `M_lower` (for `j = N+1:N+n`). The weights `A` are then solved from
`M_upper * A = rhs_upper`, and `sqrt(∑ᵢ (M_lower[i, :] ⋅ A - rhs_lower[i])²)` is returned.

Rows are filled with `Threads.@threads`, so `f!` may be called concurrently. The scratch
buffers in `cache` are overwritten, hence the `!`.
"""
function GaussTuranLoss!(f!, ΔX, cache)
    (; n, N, a, rhs_upper, rhs_lower) = cache
    # Scratch arrays matching the element type of ΔX (plain floats or dual numbers)
    M_upper = cache.M_upper_buffer[ΔX, (N, N)]
    M_lower = cache.M_lower_buffer[ΔX, (n, N)]
    A = cache.A_buffer[ΔX, N]
    X = cache.X_buffer[ΔX, n]

    # Nodes: a plus the running sum of the gaps
    cumsum!(X, ΔX)
    X .+= a

    # Assemble both matrices one node at a time; rows are independent
    for node in eachindex(X)
        xnode = X[node]
        cols = node:n:N
        Threads.@threads for row = 1:N
            f!(view(M_upper, row, cols), xnode, row)
        end
        Threads.@threads for row = 1:n
            f!(view(M_lower, row, cols), xnode, N + row)
        end
    end

    # Weights from the square upper system
    A .= M_upper \ rhs_upper

    # Residual of the remaining n equations, each accumulated left to right
    sq_norm = zero(eltype(ΔX))
    for row in eachindex(X)
        residual = foldl(eachindex(A); init = -rhs_lower[row]) do acc, col
            acc + A[col] * M_lower[row, col]
        end
        sq_norm += residual^2
    end
    return sqrt(sq_norm)
end
"""
    GaussTuranResult(res, cache, df, deriv!)

Callable result of the Gauss-Turán quadrature rule computation.

Stores the nodes `X` (length `n`) and the weights `A`, an `n × (2s + 1)` matrix where
`A[i, m]` multiplies the `(m - 1)`th derivative of the integrand at `X[i]`. It also keeps
the optimization result `res`, the loss `cache`, the objective `df`, the derivative
function `deriv!` and a scratch vector `f_tmp` of length `2s + 1`.

The weights are copied out of the cache, so later loss evaluations do not alter them.
"""
struct GaussTuranResult{T,RType,dfType,deriv!Type}
    X::Vector{T}
    A::Matrix{T}
    res::RType
    cache::GaussTuranCache{T}
    df::dfType
    deriv!::deriv!Type
    f_tmp::Vector{T}
    function GaussTuranResult(res, cache::GaussTuranCache{T}, df, deriv!) where {T}
        # `a` must come from the cache; previously it silently referred to a global `a`
        (; A_buffer, a, s, n, N) = cache
        ΔX = res.minimizer
        # Nodes from the optimal gaps: xᵢ = a + Δx₁ + … + Δxᵢ
        X = cumsum(ΔX) .+ a
        # Re-evaluate the loss at the minimizer so that the weight buffer for element
        # type T holds the weights that belong to these nodes
        df.f(ΔX)
        # Copy, since the buffer is reused (and overwritten) by every loss evaluation
        A = reshape(copy(A_buffer[T[], N]), (n, 2s + 1))
        f_tmp = zeros(T, 2s + 1)
        return new{T,typeof(res),typeof(df),typeof(deriv!)}(
            X,
            A,
            res,
            cache,
            df,
            deriv!,
            f_tmp,
        )
    end
end
"""
    (I::GaussTuranResult)(integrand)

Apply the quadrature rule to `integrand`, a scalar function `x -> g(x)`.

The derivatives of `g` of order `0:2s` at each node are computed automatically by the
stored `deriv!` function and written into the scratch vector `f_tmp`. The return value is
`∑ᵢ∑ₘ A[i, m] * g⁽ᵐ⁻¹⁾(X[i])`.

Because the shared `f_tmp` buffer is overwritten on every call, a single result object
must not be called concurrently from several tasks.
"""
function (I::GaussTuranResult)(integrand)
    (; X, A, deriv!, f_tmp) = I
    total = zero(eltype(X))
    for (i, x) in enumerate(X)
        # f_tmp[m] = (m - 1)th derivative of the integrand at x
        deriv!(f_tmp, integrand, x)
        for m in axes(A, 2)
            total += A[i, m] * f_tmp[m]
        end
    end
    return total
end
"""
    GaussTuran(f, deriv!, a, b, n, s; w = DEFAULT_w, ε = nothing, X₀ = nothing,
               integration_kwargs = (; reltol = 1e-120),
               optimization_options = Optim.Options())

Determine a quadrature rule

    I(g) = ∑ₘ∑ᵢ Aₘᵢ * ∂ᵐ⁻¹g(xᵢ)    (m = 1, …, 2s + 1, i = 1, …, n)

that gives the exact weighted integral ₐ∫ᵇ w(x) fⱼ(x) dx for the given linearly independent
functions f₁, f₂, …, f₂₍ₛ₊₁₎ₙ.

# Method
The equations

    ∑ₘ∑ᵢ Aₘᵢ * ∂ᵐ⁻¹fⱼ(xᵢ) = ₐ∫ᵇ w(x) fⱼ(x) dx,    j = 1, …, 2(s + 1)n

define an overdetermined linear system M(X)A = b in the weights Aₘᵢ for given nodes
X = (x₁, x₂, …, xₙ). M is split into a square upper part M_upper of size
(2s + 1)n × (2s + 1)n and a lower part M_lower of size n × (2s + 1)n. The right-hand side
b is split analogously, and A = M_upper⁻¹ b_upper. The quality of X is assessed by how
well the remaining n equations hold, which yields the loss

    loss(X) = ‖M_lower A - b_lower‖₂ = ‖M_lower M_upper⁻¹ b_upper - b_lower‖₂.

The nodes must be ordered and lie in (a, b). To achieve this, the loss is expressed in
terms of the gaps ΔX = (Δx₁, Δx₂, …, Δxₙ) = (x₁ - a, x₂ - x₁, …, xₙ - xₙ₋₁), subject to

    ε ≤ Δxᵢ ≤ b - a - 2ε,    i = 1, …, n
    a + nε ≤ a + ∑ΔX ≤ b - ε    (i.e. the last node satisfies xₙ ≤ b - ε)

Here ε is an enforced minimum distance between nodes. It prevents consecutive nodes from
converging towards each other, which would make M_upper singular.

# Arguments
- `f`: function with signature `f(x::T, j)::T` that returns fⱼ(x)
- `deriv!`: function generated with `@generate_derivs(k)`, where `k` is the integer
  literal value of `2s`. It computes the derivatives of the fⱼ automatically.
- `a`: integration lower bound
- `b`: integration upper bound (`a < b`)
- `n`: number of nodes in the quadrature rule (`n ≥ 1`)
- `s`: the rule uses derivatives up to order 2s (`s ≥ 0`). Because the optimizer also
  needs a Hessian, derivatives of the fⱼ up to order 2s + 2 are evaluated.

# Keywords
- `w`: integration weight with signature `w(x::Number)::Number`; defaults to `w(x) = 1`
- `ε`: minimum distance between nodes, `0 < ε ≤ (b - a) / (n + 1)`;
  defaults to `1e-3 * (b - a) / (n + 1)`
- `X₀`: initial guess for the `n` nodes; defaults to `n` equispaced interior points of (a, b)
- `integration_kwargs`: keyword arguments passed to `solve` when integrating w * fⱼ
- `optimization_options`: `Optim.Options` for the minimization problem that finds X

# Returns
A callable `GaussTuranResult`.

# Throws
`ArgumentError` for invalid `a`, `b`, `n`, `s`, `ε` or a wrongly sized `X₀`.
"""
function GaussTuran(
    f,
    deriv!,
    a::T,
    b::T,
    n,
    s;
    w = DEFAULT_w,
    ε = nothing,
    X₀ = nothing,
    integration_kwargs::NamedTuple = (; reltol = 1e-120),
    optimization_options::Optim.Options = Optim.Options(),
) where {T<:AbstractFloat}
    a < b || throw(ArgumentError("GaussTuran requires a < b, got a = $a and b = $b"))
    n ≥ 1 || throw(ArgumentError("GaussTuran requires at least one node, got n = $n"))
    s ≥ 0 || throw(ArgumentError("GaussTuran requires s ≥ 0, got s = $s"))

    # Initial guess for the nodes. Everything is converted to T up front so that the
    # autodiff caches inside `df` and the optimizer work with the same element type.
    X_init = collect(T, isnothing(X₀) ? range(a, b, length = n + 2)[2:(end-1)] : X₀)
    length(X_init) == n ||
        throw(ArgumentError("X₀ must contain n = $n nodes, got $(length(X_init))"))
    ΔX₀ = T[X_init[1] - a; diff(X_init)]

    # Minimum distance between nodes
    if isnothing(ε)
        ε_min = T(1e-3 * (b - a) / (n + 1))
    elseif 0 < ε ≤ (b - a) / (n + 1)
        ε_min = T(ε)
    else
        throw(ArgumentError("ε must satisfy 0 < ε ≤ (b - a) / (n + 1), got ε = $ε"))
    end

    # Right-hand sides ₐ∫ᵇ w(x) fⱼ(x) dx for all 2(s + 1)n functions.
    # NOTE(review): the default reltol = 1e-120 is far below eps(Float64); it appears to
    # target extended-precision T — confirm it is sensible for T = Float64.
    integrand = (out, x, j) -> out[] = w(x) * f(x, j)
    function integrate(j)
        prob = IntegralProblem{true}(integrand, (a, b), j)
        # Named `sol` (not `res`) so this closure does not capture and box the outer `res`
        sol = solve(prob, QuadGKJL(); integration_kwargs...)
        return sol.u[]
    end
    N = (2s + 1) * n
    rhs_upper = [integrate(j) for j = 1:N]
    rhs_lower = [integrate(j) for j = (N+1):(N+n)]

    # The cache for evaluating GaussTuranLoss!
    cache = GaussTuranCache(n, s, N, a, b, ε_min, rhs_upper, rhs_lower)

    # In-place derivative computation of fⱼ at x
    function f!(out, x, j::Int)::Nothing
        deriv!(out, x -> f(x, j), x)
    end

    # The function whose root defines the quadrature rule.
    # Note: the optimization method requires a Hessian, which brings the highest
    # derivative order required to 2s + 2.
    func(ΔX) = GaussTuranLoss!(f!, ΔX, cache)
    df = TwiceDifferentiable(func, ΔX₀; autodiff = :forward)
    constraints = _gauss_turan_constraints(a, b, n, ε_min)

    # Solve the constrained nonlinear problem for ΔX, see
    # https://julianlsolvers.github.io/Optim.jl/stable/examples/generated/ipnewton_basics/
    res = Optim.optimize(df, constraints, copy(ΔX₀), IPNewton(), optimization_options)
    return GaussTuranResult(res, cache, df, deriv!)
end

# Constraints on the gaps ΔX in the form expected by IPNewton: box bounds
# ε ≤ Δxᵢ ≤ b - a - 2ε, plus bounds a + nε ≤ a + ∑ΔX ≤ b - ε on the last node xₙ.
function _gauss_turan_constraints(a::T, b::T, n, ε::T) where {T}
    ΔX_lb = fill(ε, n)
    ΔX_ub = fill(b - a - 2ε, n)
    last_node!(c, ΔX) = (c[1] = a + sum(ΔX); c)
    last_node_jacobian!(J, ΔX) = (J[1, :] .= one(eltype(ΔX)); J)
    # The constraint is linear in ΔX, so it adds nothing to the Lagrangian Hessian
    last_node_hessian!(H, ΔX, λ) = nothing
    return TwiceDifferentiableConstraints(
        last_node!,
        last_node_jacobian!,
        last_node_hessian!,
        ΔX_lb,
        ΔX_ub,
        [a + n * ε],
        [b - ε],
    )
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment