
@XerxesZorgon
Last active June 6, 2022 01:42
Estimates parameters required to fit a sum of Gaussians to a function
#=
Parameters required to fit a sum of Gaussians to a function,
argmin(α,μ,σ) || f(x) - ∑ αᵢ exp(-(x-μᵢ)²/(2 σᵢ²)) ||²
Load with: include("fitGaussParams.jl")
or includet("fitGaussParams.jl") if using Revise.jl
Written by: John Peach 08-Nov-2021
Wild Peaches
=#
using LsqFit, Plots
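# Function to be fit: a ramp on (0, 1]. zero(x) keeps the return type the same as x,
# so the y-values are Float64 for Float64 input rather than a mix of Int and Float64
f(x) = 0 < x <= 1 ? x : zero(x)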
#-----------------------------------------------------------------------------
"""
fitParams
Fit n Gaussians to f(x), returning a single LsqFit.LsqFitResult (not the separate vectors α, μ, σ).
# Parameters
n: Number of Gaussians to use in the sum
f: Function to fit
x: Vector of x-values
# Returns
fit: LsqFit.LsqFitResult
fit.param (equivalently coef(fit)) is the 3n-vector [α; μ; σ], where
α: Amplitudes, i.e. scaling parameters (n x 1)
μ: Means (n x 1)
σ: Standard deviations (n x 1)
# Example
Fit 4 Gaussians to a ramp function, f, over the range [-0.5,1.5]
x = Vector(-0.5:0.01:1.5)
fit = fitParams(4,f,x)
"""
function fitParams(n::Int,f,x::Vector)
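# Initial parameter values; for n = 4 these are
# α₀ = μ₀ = [0.5, 0.75, 0.875, 0.9375] and σ₀ = [0.25, 0.0625, 0.015625, 0.00390625]
α₀ = 1 .- (1/2) .^ (1:n)
μ₀ = 1 .- (1/2) .^ (1:n)
σ₀ = (1/4) .^ (1:n)
p₀ = [α₀; μ₀; σ₀]
m = length(p₀)   # total number of parameters, 3n
# Function values
y = [f(x_i) for x_i in x]
# Bounds: every parameter is constrained to [0, 1]
lb = zeros(m)
ub = ones(m)
# Fit to the function values using Gaussians. The result is an LsqFit.LsqFitResult,
# not the separate vectors α, μ, σ
fit = curve_fit(gaussSum, x, y, p₀, lower = lb, upper = ub)
return fit
end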
#-----------------------------------------------------------------------------
"""
gaussSum
Sum of Gaussians evaluated at each x-value using parameters α, μ, σ
# Parameters
x: Vector of x-values
p: Parameter vector = [α; μ; σ] (length 3n)
# Returns
y_fit: Vector of g(x,α,μ,σ) = ∑ αᵢ exp(-(x-μᵢ)²/(2 σᵢ²)), one entry per x-value
# Example
x = Vector(-0.5:0.01:1.5)
fit = fitParams(4,f,x)
p = coef(fit)
y_fit = gaussSum(x,p)
"""
function gaussSum(x::Vector,p::Vector)
# Extract parameters: reshape the 3n-vector p into an n x 3 matrix whose columns are α, μ, σ
p = reshape(p,:,3)
α = p[:,1]
μ = p[:,2]
σ = p[:,3]
n = length(α)
# Gaussian sum and fitted y-values
g(x,α,μ,σ) = sum( α[i] * exp( - (x-μ[i])^2/(2*σ[i]^2) ) for i = 1:n)
y_fit = [g(x_i,α,μ,σ) for x_i in x]
end
#-----------------------------------------------------------------------------
"""
plotFit
Plots function f(x) and best fit g(x)
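# Parameters
x: Vector (or range) of x-values
fit: LsqFit.LsqFitResult returned by fitParams
# Returns
Plot of f(x) and the sum of Gaussians g(x) fitting f, using the parameters found by fitParams (f is the global function defined above)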
# Example
x = Vector(-0.5:0.01:1.5)
fit = fitParams(4,f,x)
plotFit(x,fit)
"""
function plotFit(x,fit)
# Function values (uses the global function f defined above)
y = [f(x_i) for x_i in x]
# Fitted y-values: reuse gaussSum with the fitted parameters instead of duplicating it;
# collect(x) lets a range be passed as well as a Vector
y_fit = gaussSum(collect(x),coef(fit))
# Plot
plot(x,[y, y_fit], title = "Gaussian Fit", label = ["f(x)" "g(x)"], lw = 3)
end
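The objective in the header comment, ||f(x) - ∑ αᵢ exp(-(x-μᵢ)²/(2 σᵢ²))||², is the sum of squared residuals that curve_fit minimizes over p = [α; μ; σ] (here subject to the [0, 1] bounds). The sketch below evaluates it directly in plain Julia, which can help sanity-check a starting vector such as p₀ before fitting. It is an illustration rather than part of the gist: `ramp`, `gauss_sum`, and `objective` are hypothetical names, and the parameter layout is assumed to match gaussSum.

```julia
# Sketch: evaluate || f(x) - ∑ αᵢ exp(-(x-μᵢ)²/(2 σᵢ²)) ||² for p = [α; μ; σ].
# Plain Julia, no packages. `ramp`, `gauss_sum`, `objective` are hypothetical names.

# Ramp function from the gist, returning a value of the same type as x
ramp(x) = 0 < x <= 1 ? x : zero(x)

# Sum of Gaussians at a single point; p is assumed to be laid out as [α; μ; σ]
function gauss_sum(x::Real, p::AbstractVector)
    n = length(p) ÷ 3
    α, μ, σ = p[1:n], p[n+1:2n], p[2n+1:3n]
    return sum(α[i] * exp(-(x - μ[i])^2 / (2 * σ[i]^2)) for i in 1:n)
end

# Sum of squared residuals between f and the Gaussian sum on the grid x;
# Ref(p) keeps p fixed while broadcasting over the elements of x
objective(f, x, p) = sum(abs2, f.(x) .- gauss_sum.(x, Ref(p)))

x = collect(-0.5:0.01:1.5)
# Same starting values that fitParams uses for n = 4
p₀ = [1 .- (1/2) .^ (1:4); 1 .- (1/2) .^ (1:4); (1/4) .^ (1:4)]
println(objective(ramp, x, p₀))
```

Splitting p by thirds gives the same α, μ, σ as reshape(p,:,3) in gaussSum, because reshape fills the columns in order from the flat vector.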
@stevetyson

I'm a bit new to Julia and I'm struggling with curve_fit.

If I use this code:

include("fitGaussParams.jl")
x = Vector(-0.5:0.01:1.5)
f(x) = 0 < x <= 1 ? x : 0
α,μ,σ = fitParams(4,f,x)

I get this...

MethodError: no method matching iterate(::LsqFit.LsqFitResult{Vector{Float64}, Vector{Float64}, Matrix{Float64}, Vector{Real}})
Closest candidates are:
iterate(::Union{LinRange, StepRangeLen}) at /Applications/Julia-1.7.3.app/Contents/Resources/julia/share/julia/base/range.jl:826
iterate(::Union{LinRange, StepRangeLen}, ::Integer) at /Applications/Julia-1.7.3.app/Contents/Resources/julia/share/julia/base/range.jl:826
iterate(::T) where T<:Union{Base.KeySet{<:Any, <:Dict}, Base.ValueIterator{<:Dict}} at /Applications/Julia-1.7.3.app/Contents/Resources/julia/share/julia/base/dict.jl:695
...

Stacktrace:
[1] indexed_iterate(I::LsqFit.LsqFitResult{Vector{Float64}, Vector{Float64}, Matrix{Float64}, Vector{Real}}, i::Int64)
@ Base ./tuple.jl:92
[2] top-level scope
@ In[13]:4
[3] eval
@ ./boot.jl:373 [inlined]
[4] include_string(mapexpr::typeof(REPL.softscope), mod::Module, code::String, filename::String)
@ Base ./loading.jl:1196

I'm using Julia 1.7.3 and I really don't have any idea how to solve this...

Any ideas?

Thanks, Steve Tyson
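For anyone hitting the same error: a line like α,μ,σ = fitParams(4,f,x) asks Julia to iterate over the returned value, and LsqFit.LsqFitResult has no iterate method, hence the MethodError in the trace. The minimal sketch below reproduces this with a hypothetical stand-in struct, `FakeResult`, and a hypothetical helper, `try_destructure`; neither is part of LsqFit or the gist.

```julia
# Sketch: why `α,μ,σ = fitParams(4,f,x)` throws a MethodError.
# `FakeResult` is a hypothetical stand-in for LsqFit.LsqFitResult: a struct
# that stores its parameters in a field but defines no iterate method.
struct FakeResult
    param::Vector{Float64}
end

# Try to destructure obj into three values; return the exception if it fails
function try_destructure(obj)
    try
        a, b, c = obj          # destructuring calls iterate(obj) under the hood
        return (a, b, c)
    catch e
        return e
    end
end

r = FakeResult([0.5, 0.25, 0.1])
err = try_destructure(r)
println(typeof(err))           # prints MethodError (no iterate(::FakeResult))

# Destructuring the parameter field works, because a Vector is iterable
a, b, c = r.param
```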

@XerxesZorgon
Author

Steve,
Thanks for pointing this out. I'm getting the same error, also with 1.7.3. I'm not sure what's going on yet, but I'll post something when I find it. I'm pretty sure it was working when I first posted it: https://wildpeaches.xyz/blog/curve-fitting-with-julia/
John
https://www.facebook.com/groups/253238143289738

@XerxesZorgon
Author

I see the problem now. The function fitParams actually returns a single value, fit, of type LsqFit.LsqFitResult. To extract the individual variables α, μ, σ, use either p = coef(fit) or p = fit.param; p is 3n long. Split p into three equal-length sections to get each variable. The function plotFit does this to create the plot. I've updated the comments in the functions, so hopefully they're clearer.
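Splitting p as described can be wrapped in a small function. This is a sketch, assuming the [α; μ; σ] layout that gaussSum unpacks with reshape(p,:,3); `unpackParams` is a hypothetical helper, not part of the gist, and the parameter values below are made up for illustration.

```julia
# Sketch: split the 3n-long parameter vector p = [α; μ; σ] into its three parts.
# `unpackParams` is a hypothetical helper; it assumes the same layout that
# gaussSum uses via reshape(p, :, 3).
function unpackParams(p::AbstractVector)
    length(p) % 3 == 0 || throw(ArgumentError("length(p) must be a multiple of 3"))
    P = reshape(p, :, 3)          # n x 3 matrix with columns α, μ, σ
    return P[:, 1], P[:, 2], P[:, 3]
end

# Example with a made-up parameter vector for n = 2 (values are illustrative only)
p = [0.5, 0.75, 0.25, 0.5, 0.1, 0.2]
α, μ, σ = unpackParams(p)
println("α = ", α, ", μ = ", μ, ", σ = ", σ)
```

With the gist loaded, the equivalent call would be α, μ, σ = unpackParams(coef(fitParams(4,f,x))). Returning a Tuple is what makes the three-way assignment work.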

@stevetyson

Thanks for the incredibly swift response, John. If I'd had more experience with Julia, I would have spotted this.
