@XerxesZorgon
Last active June 6, 2022 01:42
Estimates parameters required to fit sum of Gaussians to a function
#=
Parameters required to fit sum of Gaussians to a function,
argmin(α,μ,σ) || f(x) - ∑ αᵢ exp(-(x-μᵢ)²/(2 σᵢ²)) ||²
Load with: include("fitGaussParams.jl")
or includet("fitGaussParams.jl") if using Revise.jl
Written by: John Peach 08-Nov-2021
Wild Peaches
=#
using LsqFit, Plots
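# Function to be fit: unit ramp on (0,1], zero elsewhere
# zero(x) keeps the return type the same as the input type
f(x) = 0 < x <= 1 ? x : zero(x)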
#-----------------------------------------------------------------------------
"""
fitParams
Fit n Gaussians to f(x) returning parameter vectors α, μ, σ.
# Parameters
n: Number of Gaussians to use in the sum
f: Fit function
x: Vector of x-values
# Returns
fit: LsqFit.LsqFitResult
coef(fit) (or fit.param) is a single vector of length 3n, [α; μ; σ], where
α: Scaling parameters (first n entries)
μ: Means (next n entries)
σ: Standard deviations (last n entries)
# Example
Fit 4 Gaussians to a ramp function, f, over the range [-0.5,1.5]
x = Vector(-0.5:0.01:1.5)
fit = fitParams(4,f,x)
"""
function fitParams(n::Int,f,x::Vector)
    # Initial parameter values
    α₀ = 1 .- (1/2) .^ (1:n)
    μ₀ = 1 .- (1/2) .^ (1:n)
    σ₀ = (1/4) .^ (1:n)
    p₀ = [α₀; μ₀; σ₀]
    # Total number of parameters (3n); kept separate from n, the number of Gaussians
    np = length(p₀)
    # Function values
    y = [f(x_i) for x_i in x]
    # Bounds: every parameter (α, μ and σ) is constrained to [0,1]
    lb = zeros(np)
    ub = ones(np)
    # Fit to the function values using Gaussians
    fit = curve_fit(gaussSum, x, y, p₀, lower = lb, upper = ub)
    return fit
end
#-----------------------------------------------------------------------------
"""
gaussSum
Sum of Gaussians using parameters α, μ, σ
# Parameters
x: Vector of x-values
p: Parameter vector = [α; μ; σ] (length 3n)
# Returns
y_fit: g(x,α,μ,σ) = ∑ αᵢ exp(-(x-μᵢ)²/(2 σᵢ²))
# Example
x = Vector(-0.5:0.01:1.5)
fit = fitParams(4,f,x)
p = coef(fit)
y_fit = gaussSum(x,p)
"""
function gaussSum(x::Vector,p::Vector)
    # Extract parameters: columns of the n x 3 matrix are α, μ, σ
    p = reshape(p,:,3)
    α = p[:,1]
    μ = p[:,2]
    σ = p[:,3]
    n = length(α)
    # Gaussian sum and fitted y-values
    g(x,α,μ,σ) = sum( α[i] * exp( - (x-μ[i])^2/(2*σ[i]^2) ) for i = 1:n)
    y_fit = [g(x_i,α,μ,σ) for x_i in x]
    return y_fit
end
#-----------------------------------------------------------------------------
"""
plotFit
Plots function f(x) and best fit g(x)
# Parameters
x: x-values at which f and the fit are evaluated
fit: LsqFit.LsqFitResult returned by fitParams
# Returns
Plot of f(x) and the sum of Gaussians g(x) fitting f, using the parameters found by fitParams
# Example
x = Vector(-0.5:0.01:1.5)
fit = fitParams(4,f,x)
plotFit(x,fit)
"""
function plotFit(x,fit)
    # Function values (uses the global f defined above)
    y = [f(x_i) for x_i in x]
    # Parameters
    p = reshape(coef(fit),:,3)
    α = p[:,1]
    μ = p[:,2]
    σ = p[:,3]
    n = length(α)
    # Gaussian sum and fitted y-values
    g(x,α,μ,σ) = sum( α[i] * exp( - (x-μ[i])^2/(2*σ[i]^2) ) for i = 1:n)
    y_fit = [g(x_i,α,μ,σ) for x_i in x]
    # Plot
    plot(x,[y, y_fit], title = "Gaussian Fit", label = ["f(x)" "g(x)"], lw = 3)
end
@XerxesZorgon
Author

I see the problem now. The function fitParams actually returns a single value, fit, of type LsqFit.LsqFitResult. To extract the individual variables α, μ, σ, use either p = coef(fit) or p = fit.param; p is a vector of length 3n. Split p into three equal-length sections to get each variable. The function plotFit does this to create the plot. I've updated the comments in the functions, so hopefully they're clearer.
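
For anyone else who runs into this, here is a minimal sketch of the split. The helper name splitParams and the sample vector p are made up for illustration and are not part of the gist; in practice p would come from coef(fit).

```julia
# Hypothetical stand-in for coef(fit) with n = 2 Gaussians, laid out as [α; μ; σ]
p = [0.5, 0.75, 0.5, 0.75, 0.25, 0.0625]

# Hypothetical helper: split a 3n-long parameter vector into α, μ, σ
function splitParams(p::AbstractVector)
    length(p) % 3 == 0 || throw(ArgumentError("length(p) must be a multiple of 3"))
    n = length(p) ÷ 3
    α = p[1:n]         # scaling parameters
    μ = p[n+1:2n]      # means
    σ = p[2n+1:3n]     # standard deviations
    return α, μ, σ
end

α, μ, σ = splitParams(p)
```

This should give the same three vectors as the columns of reshape(p,:,3), which is what the gist's functions use.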

@stevetyson

Thanks for the incredibly swift response John. If I had more experience with Julia I should have spotted this.
