#=
Parameters required to fit a sum of n Gaussians to a function f(x):
    argmin(α,μ,σ) || f(x) - ∑ αᵢ exp(-(x-μᵢ)²/(2 σᵢ²)) ||²
Load with: include("fitGaussParams.jl")
or includet("fitGaussParams.jl") if using Revise.jl
Written by: John Peach 08-Nov-2021
Wild Peaches
=#
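using LsqFit, Plots

# Function to be fit: a ramp equal to x on (0, 1] and zero elsewhere.
# zero(x) returns a zero of the same type as x, so for Float64 input the
# sampled y-values form a Vector{Float64} instead of a mixed Vector{Real}.
f(x) = 0 < x <= 1 ? x : zero(x)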
#-----------------------------------------------------------------------------
"""
    fitParams(n, f, x)

Fit a sum of n Gaussians to f(x) by nonlinear least squares.

# Parameters
n: Number of Gaussians to use in the sum
f: Function to be fit
x: Vector of x-values

# Returns
fit: LsqFit.LsqFitResult (a single value, not three separate vectors).
The fitted parameter vector, coef(fit) or fit.param, has length 3n and is
laid out as [α; μ; σ]:
    α: Scaling parameters (n x 1)
    μ: Means (n x 1)
    σ: Standard deviations (n x 1)

# Example
Fit 4 Gaussians to the ramp function f over the range [-0.5, 1.5]:
    x = Vector(-0.5:0.01:1.5)
    fit = fitParams(4, f, x)
"""
function fitParams(n::Int, f, x::AbstractVector)
    # Initial parameter values: amplitudes and means spread in (0, 1),
    # widths shrinking geometrically
    α₀ = 1 .- (1/2) .^ (1:n)
    μ₀ = 1 .- (1/2) .^ (1:n)
    σ₀ = (1/4) .^ (1:n)
    p₀ = [α₀; μ₀; σ₀]
    np = length(p₀)    # 3n parameters in total
    # Function values
    y = [f(x_i) for x_i in x]
    # Bounds: every parameter is constrained to [0, 1]
    lb = zeros(np)
    ub = ones(np)
    # Fit to the function values using Gaussians
    fit = curve_fit(gaussSum, x, y, p₀, lower = lb, upper = ub)
    return fit
end
#-----------------------------------------------------------------------------
"""
    gaussSum(x, p)

Sum of Gaussians with parameters α, μ, σ, evaluated at each point of x.

# Parameters
x: Vector of x-values
p: Parameter vector [α; μ; σ] of length 3n

# Returns
y_fit: g(x,α,μ,σ) = ∑ αᵢ exp(-(x-μᵢ)²/(2 σᵢ²)), evaluated at each element of x

# Example
    x = Vector(-0.5:0.01:1.5)
    fit = fitParams(4, f, x)
    p = coef(fit)
    y_fit = gaussSum(x, p)
"""
function gaussSum(x::AbstractVector, p::AbstractVector)
    # Extract parameters: p = [α; μ; σ], so an n×3 reshape puts α, μ, σ in columns
    P = reshape(p, :, 3)
    α = P[:, 1]
    μ = P[:, 2]
    σ = P[:, 3]
    n = length(α)
    # Gaussian sum and fitted y-values
    g(x, α, μ, σ) = sum(α[i] * exp(-(x - μ[i])^2 / (2 * σ[i]^2)) for i = 1:n)
    y_fit = [g(x_i, α, μ, σ) for x_i in x]
    return y_fit
end
#-----------------------------------------------------------------------------
"""
    plotFit(x, fit)

Plot the function f(x) and the best-fit sum of Gaussians g(x).

# Parameters
x: Vector of x-values
fit: LsqFit.LsqFitResult returned by fitParams

# Returns
Plot of f(x), the global function defined at the top of this file, and of the
sum of Gaussians fitting f, using the parameters found by fitParams

# Example
    x = Vector(-0.5:0.01:1.5)
    fit = fitParams(4, f, x)
    plotFit(x, fit)
"""
function plotFit(x, fit)
    # Function values (uses the global f defined above)
    y = [f(x_i) for x_i in x]
    # Fitted y-values: the sum of Gaussians evaluated with the fitted parameters
    y_fit = gaussSum(x, coef(fit))
    # Plot
    plot(x, [y, y_fit], title = "Gaussian Fit", label = ["f(x)" "g(x)"], lw = 3)
end
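Before fitting a harder target like the ramp, one way to check the [α; μ; σ] parameter layout and the `curve_fit` call is to fit a single Gaussian to noiseless data generated from known parameters and confirm that those parameters come back. This is a minimal sketch, not part of the original gist. The model name `oneGauss`, the values in `p_true` and `p0`, and the comparison tolerance are all hypothetical choices made for illustration.

```julia
using LsqFit

# Hypothetical single-Gaussian model with the same [α; μ; σ] layout as gaussSum (n = 1)
oneGauss(x, p) = p[1] .* exp.(-(x .- p[2]) .^ 2 ./ (2 * p[3]^2))

x = collect(-3.0:0.1:3.0)
p_true = [2.0, 0.5, 0.8]          # illustrative α, μ, σ
y = oneGauss(x, p_true)           # noiseless samples of the known Gaussian

p0 = [1.0, 0.0, 1.0]              # rough starting guess
fit = curve_fit(oneGauss, x, y, p0)

p_hat = coef(fit)
println("converged: ", fit.converged)
println("recovered α, μ, σ: ", p_hat)

# σ only enters the model as σ², so compare its absolute value
ok = isapprox(p_hat[1:2], p_true[1:2]; atol = 1e-4) &&
     isapprox(abs(p_hat[3]), p_true[3]; atol = 1e-4)
println("matches p_true: ", ok)
```

No bounds are passed here, so σ could in principle converge to a negative value. That is why the check uses `abs(p_hat[3])`. The gist sidesteps the sign ambiguity with its lower bound of 0.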
Steve,
Thanks for pointing this out. I'm getting the same error, also with Julia 1.7.3. I'm not sure what's going on yet, but I'll post an update when I find the cause. I'm pretty sure it was working when I first posted it: https://wildpeaches.xyz/blog/curve-fitting-with-julia/
John
https://www.facebook.com/groups/253238143289738
I see the problem now. The function fitParams returns a single value, fit, of type LsqFit.LsqFitResult, not the three vectors α, μ, σ. To extract the individual variables, use either p = coef(fit) or p = fit.param; p has length 3n. Split p into three equal-length sections to get α, μ, and σ, in that order. The function plotFit does this to create the plot. I've updated the comments in the functions, so hopefully they're clearer.
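Here is a minimal sketch of that split. It assumes p is the 3n-long vector from coef(fit) with layout [α; μ; σ]. The helper name `splitParams` is hypothetical and is not part of the gist or of LsqFit. Destructuring its result works because it returns a tuple. Destructuring the LsqFitResult itself fails, because that type has no `iterate` method, which is what the MethodError below reports.

```julia
# Hypothetical helper: split p = [α; μ; σ] (length 3n) into three n-vectors
function splitParams(p::AbstractVector)
    length(p) % 3 == 0 || throw(ArgumentError("length(p) must be a multiple of 3"))
    n = length(p) ÷ 3
    α = p[1:n]
    μ = p[n+1:2n]
    σ = p[2n+1:3n]
    return α, μ, σ
end

# Stand-in for coef(fit) with n = 2 Gaussians (illustrative values)
p = [0.9, 0.4, 0.75, 0.3, 0.1, 0.05]
α, μ, σ = splitParams(p)
println("α = ", α, ", μ = ", μ, ", σ = ", σ)
```

With a real fit, the call becomes `α, μ, σ = splitParams(coef(fit))`.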
Thanks for the incredibly swift response, John. If I had more experience with Julia, I would have spotted this.
I'm a bit new to Julia and I'm struggling with curve_fit.
If I use this code
I get this:
MethodError: no method matching iterate(::LsqFit.LsqFitResult{Vector{Float64}, Vector{Float64}, Matrix{Float64}, Vector{Real}})
Closest candidates are:
iterate(::Union{LinRange, StepRangeLen}) at /Applications/Julia-1.7.3.app/Contents/Resources/julia/share/julia/base/range.jl:826
iterate(::Union{LinRange, StepRangeLen}, ::Integer) at /Applications/Julia-1.7.3.app/Contents/Resources/julia/share/julia/base/range.jl:826
iterate(::T) where T<:Union{Base.KeySet{<:Any, <:Dict}, Base.ValueIterator{<:Dict}} at /Applications/Julia-1.7.3.app/Contents/Resources/julia/share/julia/base/dict.jl:695
...
Stacktrace:
[1] indexed_iterate(I::LsqFit.LsqFitResult{Vector{Float64}, Vector{Float64}, Matrix{Float64}, Vector{Real}}, i::Int64)
@ Base ./tuple.jl:92
[2] top-level scope
@ In[13]:4
[3] eval
@ ./boot.jl:373 [inlined]
[4] include_string(mapexpr::typeof(REPL.softscope), mod::Module, code::String, filename::String)
@ Base ./loading.jl:1196
I'm using Julia 1.7.3 and I really don't have any idea how to solve this...
Any ideas?
Thanks, Steve Tyson
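There is also a second issue, separate from the destructuring error. The `Vector{Real}` inside the printed `LsqFitResult` type likely traces back to the original definition `f(x) = 0 < x <= 1 ? x : 0`. That function returns the Float64 x on (0, 1] but the integer literal 0 elsewhere, so the comprehension that builds y widens its element type to Real. The sketch below compares that definition with one that returns `zero(x)`. The names `f_orig` and `f_fixed` are hypothetical.

```julia
# Original definition: returns x (Float64) on (0, 1] but the Int literal 0 elsewhere
f_orig(x) = 0 < x <= 1 ? x : 0

# Type-stable variant: zero(x) has the same type as x
f_fixed(x) = 0 < x <= 1 ? x : zero(x)

x = Vector(-0.5:0.01:1.5)
y_orig  = [f_orig(x_i) for x_i in x]
y_fixed = [f_fixed(x_i) for x_i in x]

println(eltype(y_orig))    # Real: widened to hold both Int64 and Float64 values
println(eltype(y_fixed))   # Float64
```

A concrete element type is not what triggers the MethodError. It does keep `y`, and the arrays LsqFit derives from it, free of abstractly typed elements.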