Skip to content

Instantly share code, notes, and snippets.

@richinex
Last active November 5, 2022 10:01
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save richinex/dc7f6677ddb600287ca8cda199b6b872 to your computer and use it in GitHub Desktop.
Save richinex/dc7f6677ddb600287ca8cda199b6b872 to your computer and use it in GitHub Desktop.
julia_optim_adam.ipynb
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/richinex/dc7f6677ddb600287ca8cda199b6b872/julia_optim_adam.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "LQRBEXQ4jRmG"
},
"outputs": [],
"source": [
"using LinearAlgebra, Flux, Zygote, Plots, Parameters, Printf"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "s34kNRD0jRmJ",
"outputId": "30a0ff92-9f98-40fa-f799-025f0909377c"
},
"outputs": [
{
"data": {
"text/plain": [
"get_fd (generic function with 1 method)"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Start offsets of each parameter's block in the packed internal parameter vector.\n",
"# Parameter i occupies kvals[i]:kvals[i+1]-1 -- a single slot when smf[i] is infinite\n",
"# (one value shared by every spectrum), num_eis slots otherwise (one per spectrum).\n",
"function get_kvals(smf::AbstractVector, num_eis::Integer)\n",
"    block_lengths = [isinf(s) ? 1 : num_eis for s in smf]\n",
"    return cumsum(vcat(1, block_lengths))\n",
"end\n",
"\n",
"# Expand per-parameter bounds to the packed internal layout described by kvals:\n",
"# every slot of parameter i (kvals[i]:kvals[i+1]-1) receives lb[i] and ub[i].\n",
"function get_bounds_vector(lb::AbstractVector,ub::AbstractVector, num_eis::Integer, smf::AbstractVector, kvals::AbstractVector)\n",
"    n_packed = length(lb) * num_eis - (num_eis - 1) * count(isinf, smf)\n",
"    lb_vec, ub_vec = zeros(n_packed), zeros(n_packed)\n",
"    for i in 1:length(lb)\n",
"        slots = kvals[i]:kvals[i + 1] - 1\n",
"        lb_vec[slots] .= lb[i]\n",
"        ub_vec[slots] .= ub[i]\n",
"    end\n",
"    return lb_vec, ub_vec\n",
"end\n",
"\n",
"# Second-order finite-difference (curvature) matrix used to smooth parameters across spectra.\n",
"\"\"\"\n",
"    get_fd(m)\n",
"\n",
"Return the `m`-by-`m` second-difference matrix. Interior rows hold the central\n",
"stencil `[1, -2, 1]`; the first and last rows hold the one-sided 4-point stencils\n",
"`[2, -5, 4, -1]` and `[-1, 4, -5, 2]`, which need `m >= 4`.\n",
"\n",
"For `m == 3` every row is the central stencil (the exact curvature of the quadratic\n",
"through three points); for `m <= 2` no second difference can be formed and a zero\n",
"matrix is returned, i.e. no smoothing penalty (e.g. when fitting a single spectrum).\n",
"\"\"\"\n",
"function get_fd(m::Integer)\n",
"    d2m = zeros(m, m)\n",
"    if m <= 2\n",
"        return d2m\n",
"    elseif m == 3\n",
"        for k = 1:3\n",
"            d2m[k, 1:3] .= [1, -2, 1]\n",
"        end\n",
"        return d2m\n",
"    end\n",
"    d2m[1, 1:4] .= [2, -5, 4, -1]\n",
"    for k = 2:m - 1\n",
"        d2m[k, k - 1:k + 1] .= [1, -2, 1]\n",
"    end\n",
"    d2m[end, end-3:end] .= [-1, 4, -5, 2]\n",
"    return d2m\n",
"end\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ER4CNWF7jRmK",
"outputId": "7a167939-9e7d-4681-bd73-c936ac7e442a"
},
"outputs": [
{
"data": {
"text/plain": [
"compute_total_obj (generic function with 1 method)"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Vector method: the same initial guess is used for every spectrum.\n",
"\"\"\"\n",
"    convert_to_internal(p::AbstractVector, lb_vec, ub_vec, num_eis, smf, kvals)\n",
"\n",
"Use the single initial guess `p` for all `num_eis` spectra: it is replicated into a\n",
"`length(p) x num_eis` matrix and converted by the `AbstractMatrix` method, so both\n",
"methods share one packing/log10 implementation.\n",
"\"\"\"\n",
"function convert_to_internal(p::AbstractVector{<:Real}, lb_vec::AbstractVector{<:Real}, ub_vec::AbstractVector{<:Real}, num_eis::Integer, smf::AbstractVector, kvals::AbstractVector)\n",
"    p_matrix = repeat(p, outer = (1, num_eis))\n",
"    return convert_to_internal(p_matrix, lb_vec, ub_vec, num_eis, smf, kvals)\n",
"end\n",
"\n",
"\"\"\"\n",
"    convert_to_internal(p::AbstractMatrix, lb_vec, ub_vec, num_eis, smf, kvals)\n",
"\n",
"Convert a `num_params x num_eis` matrix of external parameter values to the packed\n",
"internal vector in log10 coordinates.\n",
"\n",
"Row `i` fills the slots `kvals[i]:kvals[i+1]-1` from its leading columns (a single\n",
"slot when the parameter is shared across spectra), and every packed value `x` is\n",
"then mapped to `log10((x - lb) / (1 - x / ub))` with the matching entries of\n",
"`lb_vec` and `ub_vec`.\n",
"\"\"\"\n",
"function convert_to_internal(p::AbstractMatrix{<:Real}, lb_vec::AbstractVector{<:Real}, ub_vec::AbstractVector{<:Real}, num_eis::Integer, smf::AbstractVector, kvals::AbstractVector)\n",
"    n_packed = size(p, 1) * num_eis - (num_eis - 1) * count(isinf, smf)\n",
"    packed = zeros(n_packed)\n",
"    for row in 1:size(p, 1)\n",
"        slots = kvals[row]:kvals[row + 1] - 1\n",
"        packed[slots] .= p[row, 1:length(slots)]\n",
"    end\n",
"    return @. log10((packed - lb_vec) / (1 - packed / ub_vec))\n",
"end\n",
"\n",
"\n",
"# Convert to external\n",
"\"\"\"\n",
"    convert_to_external(P, lb_vec, ub_vec, num_eis, num_params, kvals)\n",
"\n",
"Inverse of `convert_to_internal`: map the packed log10 vector `P` back to bounded values\n",
"`(lb + 10^P) / (1 + 10^P / ub)` and unpack them into a `num_params x num_eis` matrix.\n",
"A parameter stored in a single slot is broadcast across all `num_eis` columns.\n",
"\"\"\"\n",
"function convert_to_external(P::AbstractVector, lb_vec::AbstractVector, ub_vec::AbstractVector, num_eis::Integer, num_params::Integer, kvals::AbstractVector)\n",
"    P_external = zeros(num_params, num_eis)\n",
"    for i in 1:num_params\n",
"        r = kvals[i]:kvals[i + 1] - 1\n",
"        t = 10 .^ P[r]\n",
"        P_external[i, :] .= (lb_vec[r] .+ t) ./ (1 .+ t ./ ub_vec[r])\n",
"    end\n",
"    return P_external\n",
"end\n",
"\n",
"\n",
"# Objective function for a single spectrum\n",
"\"\"\"\n",
"    compute_wrss(p, f, z, zerr_re, zerr_im, func)\n",
"\n",
"Weighted residual sum of squares of one spectrum. `func(p, f)` is expected to return\n",
"the model stacked as `[real; imag]`; it is compared with `[real.(z); imag.(z)]` and each\n",
"residual is divided by the matching entry of `[zerr_re; zerr_im]`.\n",
"\"\"\"\n",
"function compute_wrss(p::AbstractVector, f::AbstractVector, z::AbstractVector, zerr_re::AbstractVector,zerr_im::AbstractVector, func::Function)\n",
"    z_concat = vcat(real.(z), imag.(z))\n",
"    sigma = vcat(zerr_re, zerr_im)\n",
"    z_model = func(p, f)\n",
"    # Sum the squared weighted residuals directly instead of squaring a 2-norm\n",
"    # (`norm(r) .^ 2` takes a needless square root and adds rounding).\n",
"    return sum(abs2, (z_concat .- z_model) ./ sigma)\n",
"end\n",
"\n",
"\n",
"# Total objective function (data misfit of every spectrum plus the smoothing penalty)\n",
"\"\"\"\n",
"    compute_total_obj(P, F, Z, Zerr_Re, Zerr_Im, LB, UB, smf, func, num_params, num_eis, kvals, d2m)\n",
"\n",
"Return the sum over spectra of `compute_wrss` plus the smoothing penalty\n",
"`sum_i smf[i] * sum((d2m * P_log[:, i]) .^ 2)`, where `P_log[:, i]` holds parameter `i`\n",
"(internal log10 coordinates) across the `num_eis` spectra. Parameters with an\n",
"infinite smoothing factor are shared (one slot in `P`) and get zero penalty weight.\n",
"\n",
"`P` is the packed internal vector; parameter `i` occupies `P[kvals[i]:kvals[i+1]-1]`.\n",
"Columns of `F`, `Z`, `Zerr_Re` and `Zerr_Im` correspond to spectra.\n",
"\"\"\"\n",
"function compute_total_obj(P::AbstractVector,\n",
"                           F::AbstractArray,\n",
"                           Z::AbstractArray,\n",
"                           Zerr_Re::AbstractArray,\n",
"                           Zerr_Im::AbstractArray,\n",
"                           LB::AbstractVector,\n",
"                           UB::AbstractVector,\n",
"                           smf::AbstractVector,\n",
"                           func::Function,\n",
"                           num_params::Integer,\n",
"                           num_eis::Integer,\n",
"                           kvals::AbstractVector,\n",
"                           d2m::AbstractMatrix)\n",
"\n",
"    # num_eis x num_params matrix of internal (log10) values; a shared parameter is repeated.\n",
"    P_log = reduce(hcat, [\n",
"        let istart=kvals[i], istop=kvals[i + 1] - 1, ps=@view P[istart:istop]\n",
"            istop - istart == 0 ? repeat(ps, num_eis) : ps\n",
"        end\n",
"        for i = 1:num_params])\n",
"\n",
"    # Same layout mapped to bounded external values: (LB + 10^P) / (1 + 10^P / UB).\n",
"    P_norm = reduce(hcat, [\n",
"        let istart=kvals[i], istop=kvals[i + 1] - 1, ps=@view P[istart:istop]\n",
"            ps_norm = @views (LB[istart:istop] .+ (10 .^ ps)) ./ (1 .+ (10 .^ ps) ./ UB[istart:istop])\n",
"            istop - istart == 0 ? repeat(ps_norm, num_eis) : ps_norm\n",
"        end\n",
"        for i = 1:num_params])\n",
"\n",
"    smf_1 = ifelse.(isinf.(smf), 0.0, smf)\n",
"    # Column sums give one curvature penalty per parameter as a 1 x num_params row. `vec`\n",
"    # turns it into a vector so `.* smf_1` weights each parameter by its own factor; the\n",
"    # row broadcast against the smf vector would instead form a num_params x num_params\n",
"    # outer product and scale every penalty by sum(smf_1).\n",
"    chi_smf = sum(vec(sum((d2m * P_log).^2, dims = 1)) .* smf_1)\n",
"    # Rows of P_norm (columns of its transpose) are the per-spectrum parameter vectors.\n",
"    wrss_tot = compute_wrss.(eachcol(transpose(P_norm)), eachcol(F), eachcol(Z), eachcol(Zerr_Re), eachcol(Zerr_Im), func)\n",
"    return (sum(wrss_tot) + chi_smf)\n",
"end\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "GLvrEq4XjRmM",
"outputId": "cf10bf45-ef2b-4274-d002-1f16d7a7a1d4"
},
"outputs": [
{
"data": {
"text/plain": [
"Multieis"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Container for a simultaneous (multi-spectrum) fit, built with keyword arguments.\n",
"abstract type Fitting end\n",
"\n",
"# Fields:\n",
"#   p0     - initial guess, one row per parameter (a matrix gives one column per spectrum)\n",
"#   freq   - frequency vector shared by all spectra (one entry per row of Z)\n",
"#   Z      - data, one spectrum per column\n",
"#   bounds - per-parameter (lower, upper) pairs, given as tuples or 2-element vectors\n",
"#   smf    - per-parameter smoothing factors; Inf keeps a parameter shared across spectra\n",
"#   func   - model func(p, f) returning the prediction stacked as [real; imag]\n",
"#   weight - weighting option interpreted by fit_stochastic (default: the type `Nothing`)\n",
"@with_kw struct Multieis <: Fitting\n",
"    p0::AbstractArray\n",
"    freq::AbstractVector\n",
"    Z::AbstractArray\n",
"    # `<:` element types so e.g. Vector{Tuple{Float64, Float64}} or Vector{Vector{Int}} are\n",
"    # accepted; Array is invariant, so Array{Tuple{Number, Number}} alone rejects them.\n",
"    bounds::Union{AbstractArray{<:Tuple{Number, Number}}, AbstractArray{<:AbstractVector{<:Number}}}\n",
"    smf::AbstractVector\n",
"    func::Function\n",
"    weight = Nothing\n",
"end"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "D4g-HLiqjRmM",
"outputId": "55beff20-9fb0-4829-8d21-2a94c00895f7"
},
"outputs": [
{
"data": {
"text/plain": [
"fit_stochastic (generic function with 1 method)"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Resolve the `weight` option into a label plus the standard deviations used to scale the\n",
"# real and imaginary residuals. Returning them as a tuple lets fit_stochastic assign each\n",
"# exactly once before its objective closure captures them.\n",
"function _select_weights(weight, Z::AbstractArray)\n",
"    num_freq, num_eis = size(Z, 1), size(Z, 2)\n",
"    # The Multieis default is the type `Nothing`; the value `nothing` is accepted too.\n",
"    if weight === nothing || weight === Nothing || weight == \"modulus\"\n",
"        return \"modulus\", abs.(Z), abs.(Z)\n",
"    elseif weight == 1 || weight == \"unit\"\n",
"        return \"unit\", ones(num_freq, num_eis), ones(num_freq, num_eis)\n",
"    elseif weight == 2 || weight == \"proportional\"\n",
"        return \"proportional\", real.(Z), imag.(Z)\n",
"    elseif weight isa AbstractMatrix{<:Number}\n",
"        size(weight) == size(Z) || throw(DimensionMismatch(\"Weight matrix must be of size $num_freq x $num_eis\"))\n",
"        return \"sigma\", weight, weight\n",
"    else\n",
"        throw(ArgumentError(\"Weight must be either Nothing, 1, 2, or a matrix of size $num_freq x $num_eis\"))\n",
"    end\n",
"end\n",
"\n",
"# The fitting function\n",
"\"\"\"\n",
"    fit_stochastic(multieis::Multieis; lr=1e-3, num_epochs=5000)\n",
"\n",
"Fit `multieis.func` to all spectra (columns of `multieis.Z`) at once by minimising\n",
"`compute_total_obj` with Adam in internal log10 coordinates, which keep every\n",
"parameter inside `multieis.bounds`.\n",
"\n",
"# Keywords\n",
"- `lr`: Adam learning rate.\n",
"- `num_epochs`: number of gradient steps (must be positive); progress is printed every\n",
"  `max(1, div(num_epochs, 10))` epochs.\n",
"\n",
"# Returns\n",
"`(popt, chisqr)`: `popt` is a `num_params x num_eis` matrix of fitted parameters (one\n",
"column per spectrum) and `chisqr` is the final objective divided by the degrees of\n",
"freedom (`2 * num_freq * num_eis` data points minus the number of free parameters).\n",
"\n",
"# Throws\n",
"- `DimensionMismatch` if `p0`/`bounds` or `freq`/`Z` sizes disagree, or a weight matrix\n",
"  does not match `size(Z)`.\n",
"- `ArgumentError` for non-positive or non-finite bounds/initial guesses, `lb >= ub`,\n",
"  an initial guess outside the bounds, an unknown `weight`, or `num_epochs < 1`.\n",
"\"\"\"\n",
"function fit_stochastic(multieis::Multieis; lr::Number=1e-3, num_epochs::Integer=5000)\n",
"    p0 = multieis.p0\n",
"    Z = multieis.Z\n",
"    bounds = multieis.bounds  # validate against the struct's bounds, not a global of the same name\n",
"    size(p0, 1) == size(bounds, 1) || throw(DimensionMismatch(\"initial guess has a size of $(size(multieis.p0, 1)) while bounds has size $(size(multieis.bounds, 1))\"))\n",
"    size(multieis.freq, 1) == size(Z, 1) || throw(DimensionMismatch(\"The frequency vector has $(size(multieis.freq, 1)) while Z has length $(size(multieis.Z, 1))\"))\n",
"    num_epochs > 0 || throw(ArgumentError(\"num_epochs must be positive, got $num_epochs\"))\n",
"\n",
"    num_params = size(p0, 1)\n",
"    num_eis = size(Z, 2)\n",
"    num_freq = size(Z, 1)\n",
"    F = repeat(multieis.freq, outer=(1, num_eis))\n",
"    lb = [b[1] for b in bounds]\n",
"    ub = [b[2] for b in bounds]\n",
"\n",
"    all(x -> x > 0 && isfinite(x), lb) || throw(ArgumentError(\"Lower bounds must be strictly positive and finite\"))\n",
"    all(x -> x > 0 && isfinite(x), ub) || throw(ArgumentError(\"Upper bounds must be strictly positive and finite\"))\n",
"    all(x -> x > 0 && isfinite(x), p0) || throw(ArgumentError(\"Initial guess must be strictly positive and finite\"))\n",
"    all(lb .< ub) || throw(ArgumentError(\"Lower bounds must be smaller than upper bounds\"))\n",
"    # Element-wise per column: `lb <= v <= ub` on whole vectors compares lexicographically.\n",
"    for j in axes(p0, 2)\n",
"        all(lb .<= view(p0, :, j) .<= ub) || throw(ArgumentError(\"Initial guess must be within the bounds\"))\n",
"    end\n",
"\n",
"    smf = multieis.smf\n",
"    func = multieis.func\n",
"    weight_type, Zerr_Re, Zerr_Im = _select_weights(multieis.weight, Z)\n",
"    println(\"Using $weight_type\")\n",
"\n",
"    kvals = get_kvals(smf, num_eis)\n",
"    d2m = get_fd(num_eis)\n",
"    lb_vec, ub_vec = get_bounds_vector(lb, ub, num_eis, smf, kvals)\n",
"    par_log = convert_to_internal(p0, lb_vec, ub_vec, num_eis, smf, kvals)\n",
"\n",
"    # Free parameters = internal slots; a parameter shared across spectra (infinite smf)\n",
"    # occupies one slot, not num_eis.\n",
"    dof = 2 * num_freq * num_eis - length(par_log)\n",
"\n",
"    objective(p) = compute_total_obj(p, F, Z, Zerr_Re, Zerr_Im, lb_vec, ub_vec, smf, func, num_params, num_eis, kvals, d2m)\n",
"\n",
"    opt = Adam(lr)\n",
"    losses = Vector{Float64}(undef, num_epochs)  # filled in place, one entry per epoch\n",
"    report_every = max(1, div(num_epochs, 10))\n",
"    for epoch in 1:num_epochs\n",
"        gs = gradient(objective, par_log)[1]\n",
"        Flux.Optimise.update!(opt, par_log, gs)  # updates par_log in place\n",
"        losses[epoch] = objective(par_log)\n",
"        if epoch % report_every == 0\n",
"            @printf \"Epoch: %i %.2e\\n\" epoch losses[epoch]\n",
"        end\n",
"    end\n",
"\n",
"    popt = convert_to_external(par_log, lb_vec, ub_vec, num_eis, num_params, kvals)\n",
"    chisqr = losses[end] / dof\n",
"    return popt, chisqr\n",
"end"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "nQ29uS7bjRmN",
"outputId": "ece888cc-17e9-41eb-8342-8a8a3c75c57f"
},
"outputs": [
{
"data": {
"text/plain": [
"45×5 Matrix{ComplexF64}:\n",
" 3.98044e-7+1.62846e-6im … 4.51065e-7+1.60541e-6im\n",
" 9.22443e-7+4.3997e-6im 1.05098e-6+4.41408e-6im\n",
" 1.40958e-6+6.86929e-6im 1.50544e-6+6.86008e-6im\n",
" 1.68712e-6+9.25851e-6im 1.85463e-6+9.18734e-6im\n",
" 2.15576e-6+1.16992e-5im 2.29701e-6+1.15904e-5im\n",
" 2.58288e-6+1.40525e-5im … 2.62927e-6+1.38897e-5im\n",
" 3.31643e-6+1.85013e-5im 3.50368e-6+1.81772e-5im\n",
" 4.1251e-6+2.25237e-5im 4.29224e-6+2.23151e-5im\n",
" 5.77924e-6+3.31867e-5im 5.84011e-6+3.27982e-5im\n",
" 7.04483e-6+4.16689e-5im 7.15912e-6+4.09554e-5im\n",
" ⋮ ⋱ \n",
" 0.00485721+0.00219639im 0.00483639+0.00221043im\n",
" 0.00519651+0.00193762im 0.0051763+0.00195206im\n",
" 0.00547554+0.00168773im 0.00546219+0.00169376im\n",
" 0.00567759+0.00142167im 0.00567186+0.00144072im\n",
" 0.00584337+0.00120199im … 0.00584316+0.00121376im\n",
" 0.00596777+0.00100387im 0.00596019+0.00101026im\n",
" 0.00603622+0.000831302im 0.00602592+0.00084921im\n",
" 0.00609622+0.000691662im 0.00610066+0.000707615im\n",
" 0.00613997+0.000559578im 0.0061315+0.000567872im"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Measured data: the frequencies F (one per row of Y) and the admittance spectra Y,\n",
"# a 45x5 complex matrix with one spectrum per column (built in the next statement).\n",
"F = ([8.00000e+00, 2.50000e+01, 4.20000e+01, 5.90000e+01,\n",
"7.60000e+01, 9.30000e+01, 1.27000e+02, 1.61000e+02,\n",
"2.46000e+02, 3.14000e+02, 3.99000e+02, 5.01000e+02,\n",
"6.37000e+02, 7.88000e+02, 9.93000e+02, 1.25200e+03,\n",
"1.58500e+03, 1.99500e+03, 2.51200e+03, 3.16200e+03,\n",
"3.98100e+03, 5.01200e+03, 6.31000e+03, 7.94300e+03,\n",
"1.00000e+04, 1.25890e+04, 1.58490e+04, 1.99530e+04,\n",
"2.51190e+04, 3.16230e+04, 3.98110e+04, 5.01190e+04,\n",
"6.30960e+04, 7.94330e+04, 1.00000e+05, 1.25893e+05,\n",
"1.58490e+05, 1.99527e+05, 2.51189e+05, 3.16228e+05,\n",
"3.98108e+05, 5.01188e+05, 6.30958e+05, 7.94329e+05,\n",
"1.00000e+06])\n",
"\n",
"Y = vcat(transpose.([[3.98043994e-07+1.62846334e-06im,\n",
" 3.83471274e-07+1.60171055e-06im,\n",
" 3.83865881e-07+1.58502473e-06im,\n",
" 4.04896667e-07+1.58470630e-06im,\n",
" 4.51064921e-07+1.60540776e-06im],\n",
"[9.22443178e-07+4.39969926e-06im,\n",
" 9.06191531e-07+4.38099732e-06im,\n",
" 9.16687611e-07+4.37470180e-06im,\n",
" 9.62976515e-07+4.38477446e-06im,\n",
" 1.05098138e-06+4.41408429e-06im],\n",
"[1.40957809e-06+6.86928752e-06im,\n",
" 1.38162397e-06+6.83126518e-06im,\n",
" 1.38059693e-06+6.81408665e-06im,\n",
" 1.41883277e-06+6.82275413e-06im,\n",
" 1.50544156e-06+6.86007979e-06im],\n",
"[1.68711949e-06+9.25850964e-06im,\n",
" 1.66052735e-06+9.20596722e-06im,\n",
" 1.67216979e-06+9.17365924e-06im,\n",
" 1.73414162e-06+9.16654426e-06im,\n",
" 1.85462545e-06+9.18734440e-06im],\n",
"[2.15576119e-06+1.16992223e-05im,\n",
" 2.10827579e-06+1.16345200e-05im,\n",
" 2.10738881e-06+1.15910179e-05im,\n",
" 2.16757689e-06+1.15751991e-05im,\n",
" 2.29701277e-06+1.15904422e-05im],\n",
"[2.58288446e-06+1.40525399e-05im,\n",
" 2.51553729e-06+1.39863741e-05im,\n",
" 2.48908032e-06+1.39332806e-05im,\n",
" 2.52246764e-06+1.38992500e-05im,\n",
" 2.62926756e-06+1.38897467e-05im],\n",
"[3.31642832e-06+1.85012723e-05im,\n",
" 3.29454474e-06+1.83918928e-05im,\n",
" 3.31309980e-06+1.82958465e-05im,\n",
" 3.38095288e-06+1.82216481e-05im,\n",
" 3.50368350e-06+1.81772357e-05im],\n",
"[4.12509917e-06+2.25237090e-05im,\n",
" 4.10579196e-06+2.24286177e-05im,\n",
" 4.12116287e-06+2.23589195e-05im,\n",
" 4.18102400e-06+2.23195184e-05im,\n",
" 4.29224065e-06+2.23151201e-05im],\n",
"[5.77924402e-06+3.31866577e-05im,\n",
" 5.69006625e-06+3.30411312e-05im,\n",
" 5.66323070e-06+3.29221439e-05im,\n",
" 5.71128066e-06+3.28385904e-05im,\n",
" 5.84011150e-06+3.27982052e-05im],\n",
"[7.04483091e-06+4.16688999e-05im,\n",
" 6.98463055e-06+4.14659116e-05im,\n",
" 6.97550695e-06+4.12703739e-05im,\n",
" 7.03040587e-06+4.10956090e-05im,\n",
" 7.15912302e-06+4.09553868e-05im],\n",
"[8.63866353e-06+5.13732193e-05im,\n",
" 8.50426386e-06+5.11485632e-05im,\n",
" 8.42278678e-06+5.09697711e-05im,\n",
" 8.41987094e-06+5.08404373e-05im,\n",
" 8.51576897e-06+5.07598197e-05im],\n",
"[1.03729171e-05+6.31615694e-05im,\n",
" 1.02327485e-05+6.28848429e-05im,\n",
" 1.01667838e-05+6.26310721e-05im,\n",
" 1.01907526e-05+6.24130189e-05im,\n",
" 1.03141883e-05+6.22438456e-05im],\n",
"[1.28909396e-05+7.83684227e-05im,\n",
" 1.27298499e-05+7.80434129e-05im,\n",
" 1.26394616e-05+7.77378664e-05im,\n",
" 1.26342220e-05+7.74690998e-05im,\n",
" 1.27239173e-05+7.72554340e-05im],\n",
"[1.55678263e-05+9.51368056e-05im,\n",
" 1.53954279e-05+9.47842200e-05im,\n",
" 1.52970006e-05+9.44365966e-05im,\n",
" 1.52850498e-05+9.41093895e-05im,\n",
" 1.53678557e-05+9.38235098e-05im],\n",
"[1.90616338e-05+1.17447395e-04im,\n",
" 1.88478443e-05+1.17017706e-04im,\n",
" 1.87262340e-05+1.16581010e-04im,\n",
" 1.87125606e-05+1.16152383e-04im,\n",
" 1.88150934e-05+1.15755727e-04im],\n",
"[2.36724445e-05+1.44870515e-04im,\n",
" 2.34057225e-05+1.44312377e-04im,\n",
" 2.32291331e-05+1.43749552e-04im,\n",
" 2.31617378e-05+1.43208046e-04im,\n",
" 2.32165712e-05+1.42720542e-04im],\n",
"[2.98377254e-05+1.79296054e-04im,\n",
" 2.95218451e-05+1.78680115e-04im,\n",
" 2.92895420e-05+1.78052374e-04im,\n",
" 2.91592060e-05+1.77445327e-04im,\n",
" 2.91458418e-05+1.76899135e-04im],\n",
"[3.71368005e-05+2.21134513e-04im,\n",
" 3.67530884e-05+2.20456888e-04im,\n",
" 3.64819971e-05+2.19732538e-04im,\n",
" 3.63438885e-05+2.19003981e-04im,\n",
" 3.63522122e-05+2.18327899e-04im],\n",
"[4.71590574e-05+2.72656704e-04im,\n",
" 4.66598431e-05+2.71883619e-04im,\n",
" 4.62840071e-05+2.71045399e-04im,\n",
" 4.60535412e-05+2.70186109e-04im,\n",
" 4.59836847e-05+2.69369048e-04im],\n",
"[6.05509194e-05+3.36161669e-04im,\n",
" 5.98841943e-05+3.35299905e-04im,\n",
" 5.93654004e-05+3.34340759e-04im,\n",
" 5.90322197e-05+3.33344913e-04im,\n",
" 5.89099800e-05+3.32394353e-04im],\n",
"[7.88562465e-05+4.14546928e-04im,\n",
" 7.81773488e-05+4.13466740e-04im,\n",
" 7.76331726e-05+4.12246969e-04im,\n",
" 7.72511048e-05+4.10973036e-04im,\n",
" 7.70551051e-05+4.09756583e-04im],\n",
"[1.04056926e-04+5.10462851e-04im,\n",
" 1.03213046e-04+5.09286532e-04im,\n",
" 1.02485879e-04+5.07898745e-04im,\n",
" 1.01920967e-04+5.06399898e-04im,\n",
" 1.01564037e-04+5.04926487e-04im],\n",
"[1.39663694e-04+6.28142734e-04im,\n",
" 1.38702802e-04+6.26766239e-04im,\n",
" 1.37810479e-04+6.25117798e-04im,\n",
" 1.37032752e-04+6.23326225e-04im,\n",
" 1.36427101e-04+6.21561194e-04im],\n",
"[1.88594946e-04+7.69385544e-04im,\n",
" 1.87430560e-04+7.67729071e-04im,\n",
" 1.86385252e-04+7.65808218e-04im,\n",
" 1.85518133e-04+7.63777934e-04im,\n",
" 1.84892968e-04+7.61830714e-04im],\n",
"[2.59671040e-04+9.40661295e-04im,\n",
" 2.58338725e-04+9.38724843e-04im,\n",
" 2.56982952e-04+9.36332683e-04im,\n",
" 2.55680061e-04+9.33679519e-04im,\n",
" 2.54537823e-04+9.31019720e-04im],\n",
"[3.59062746e-04+1.14189147e-03im,\n",
" 3.57378914e-04+1.14002009e-03im,\n",
" 3.55668395e-04+1.13755930e-03im,\n",
" 3.54003307e-04+1.13472599e-03im,\n",
" 3.52498580e-04+1.13181409e-03im],\n",
"[4.98924463e-04+1.37415191e-03im,\n",
" 4.96647903e-04+1.37173687e-03im,\n",
" 4.94265929e-04+1.36869797e-03im,\n",
" 4.91948973e-04+1.36530725e-03im,\n",
" 4.89903730e-04+1.36190688e-03im],\n",
"[6.93496200e-04+1.63521001e-03im,\n",
" 6.90673129e-04+1.63290917e-03im,\n",
" 6.87621825e-04+1.62983825e-03im,\n",
" 6.84574014e-04+1.62626372e-03im,\n",
" 6.81804435e-04+1.62254740e-03im],\n",
"[9.57296055e-04+1.91309035e-03im,\n",
" 9.53748415e-04+1.91089732e-03im,\n",
" 9.49627138e-04+1.90796948e-03im,\n",
" 9.45268781e-04+1.90456549e-03im,\n",
" 9.41093254e-04+1.90101075e-03im],\n",
"[1.30123761e-03+2.19320343e-03im,\n",
" 1.29747635e-03+2.19098944e-03im,\n",
" 1.29290798e-03+2.18768464e-03im,\n",
" 1.28791679e-03+2.18362780e-03im,\n",
" 1.28301373e-03+2.17926595e-03im],\n",
"[1.73074182e-03+2.43863533e-03im,\n",
" 1.72621978e-03+2.43695825e-03im,\n",
" 1.72084465e-03+2.43458990e-03im,\n",
" 1.71503122e-03+2.43169256e-03im,\n",
" 1.70930568e-03+2.42848461e-03im],\n",
"[2.23910459e-03+2.62897694e-03im,\n",
" 2.23373249e-03+2.62814597e-03im,\n",
" 2.22713780e-03+2.62692641e-03im,\n",
" 2.21994217e-03+2.62538018e-03im,\n",
" 2.21290882e-03+2.62357481e-03im],\n",
"[2.80288863e-03+2.72976886e-03im,\n",
" 2.79733306e-03+2.72965804e-03im,\n",
" 2.79038353e-03+2.72933533e-03im,\n",
" 2.78272317e-03+2.72875628e-03im,\n",
" 2.77517387e-03+2.72789295e-03im],\n",
"[3.38658039e-03+2.72387289e-03im,\n",
" 3.38097266e-03+2.72540422e-03im,\n",
" 3.37401521e-03+2.72714160e-03im,\n",
" 3.36621539e-03+2.72886851e-03im,\n",
" 3.35824443e-03+2.73031136e-03im],\n",
"[3.94340698e-03+2.61988142e-03im,\n",
" 3.93845234e-03+2.62111914e-03im,\n",
" 3.93158384e-03+2.62272917e-03im,\n",
" 3.92330065e-03+2.62457621e-03im,\n",
" 3.91434226e-03+2.62642140e-03im],\n",
"[4.43841098e-03+2.43232748e-03im,\n",
" 4.43451665e-03+2.43382109e-03im,\n",
" 4.42952756e-03+2.43561435e-03im,\n",
" 4.42374917e-03+2.43758317e-03im,\n",
" 4.41762246e-03+2.43955571e-03im],\n",
"[4.85721370e-03+2.19639274e-03im,\n",
" 4.85349493e-03+2.19924166e-03im,\n",
" 4.84847510e-03+2.20281631e-03im,\n",
" 4.84257983e-03+2.20671482e-03im,\n",
" 4.83638886e-03+2.21042964e-03im],\n",
"[5.19651314e-03+1.93761603e-03im,\n",
" 5.19210519e-03+1.94056542e-03im,\n",
" 5.18698012e-03+1.94421073e-03im,\n",
" 5.18155703e-03+1.94818689e-03im,\n",
" 5.17630065e-03+1.95205770e-03im],\n",
"[5.47554484e-03+1.68772519e-03im,\n",
" 5.47296507e-03+1.68850122e-03im,\n",
" 5.46988007e-03+1.68990286e-03im,\n",
" 5.46627119e-03+1.69175409e-03im,\n",
" 5.46219060e-03+1.69376354e-03im],\n",
"[5.67758968e-03+1.42167136e-03im,\n",
" 5.67611866e-03+1.42599957e-03im,\n",
" 5.67469001e-03+1.43104268e-03im,\n",
" 5.67328278e-03+1.43618439e-03im,\n",
" 5.67186205e-03+1.44071598e-03im],\n",
"[5.84336650e-03+1.20199029e-03im,\n",
" 5.84407616e-03+1.20482082e-03im,\n",
" 5.84446732e-03+1.20796752e-03im,\n",
" 5.84423216e-03+1.21105090e-03im,\n",
" 5.84315788e-03+1.21376407e-03im],\n",
"[5.96777257e-03+1.00386678e-03im,\n",
" 5.96832717e-03+1.00598310e-03im,\n",
" 5.96709363e-03+1.00783817e-03im,\n",
" 5.96424518e-03+1.00927078e-03im,\n",
" 5.96018741e-03+1.01026450e-03im],\n",
"[6.03622245e-03+8.31301615e-04im,\n",
" 6.03435514e-03+8.33750877e-04im,\n",
" 6.03168830e-03+8.37852131e-04im,\n",
" 6.02866989e-03+8.43248505e-04im,\n",
" 6.02591783e-03+8.49210075e-04im],\n",
"[6.09621918e-03+6.91661902e-04im,\n",
" 6.09610649e-03+6.95430732e-04im,\n",
" 6.09701313e-03+6.99803990e-04im,\n",
" 6.09867461e-03+7.04136852e-04im,\n",
" 6.10066159e-03+7.07614759e-04im],\n",
"[6.13996899e-03+5.59578126e-04im,\n",
" 6.13802765e-03+5.60164801e-04im,\n",
" 6.13610540e-03+5.62009984e-04im,\n",
" 6.13400387e-03+5.64746442e-04im,\n",
" 6.13150280e-03+5.67872135e-04im]])...)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "k1QSp_21jRmP",
"outputId": "74235665-c9eb-45f3-e0a9-c419ed3b3044"
},
"outputs": [
{
"data": {
"text/plain": [
"6-element Vector{Float64}:\n",
" 1.0\n",
" 1.0\n",
" 1.0\n",
" 1.0\n",
" 1.0\n",
" 1.0"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Equivalent-circuit admittance model used for the fit.\n",
"# p = [Rs, Qh, nh, Rct, Wct, Rw]; f holds the frequencies.\n",
"# Circuit: Rs in series with [CPE(Qh, nh) in parallel with (Rct in series with (Warburg(Wct) || Rw))].\n",
"# Returns the admittance stacked as [real.(Y); imag.(Y)].\n",
"function model(p, f)\n",
"    omega = 2 .* pi .* f   # angular frequency\n",
"    s = omega .* im        # complex variable\n",
"    Rs, Qh, nh = p[1], p[2], p[3]\n",
"    Rct, Wct, Rw = p[4], p[5], p[6]\n",
"\n",
"    Z_warburg = Wct ./ sqrt.(omega) .* (1-im)  # planar infinite-length Warburg impedance\n",
"    Y_cpe = s .^nh .* Qh                      # constant-phase element admittance\n",
"    Z_diffusion = (1 ./ Z_warburg .+ 1 ./ Rw) .^ -1\n",
"    Y_faradaic = (Rct .+ Z_diffusion) .^ -1\n",
"    Z_total = Rs .+ 1 ./ (Y_cpe .+ Y_faradaic)\n",
"    Y_total = 1 ./ Z_total\n",
"    return vcat(real.(Y_total), imag.(Y_total))\n",
"end\n",
"\n",
"# Initial guess, in the order model() reads it: Rs, Qh, nh, Rct, Wct, Rw\n",
"p0 = [1.6295e+02, 3.0678e-08, 9.3104e-01, 1.1865e+04, 4.7125e+05, 1.3296e+06]\n",
"\n",
"# (lower, upper) bounds per parameter, same order as p0\n",
"bounds = [[1e-15,1e15], [1e-9, 1e2], [1e-1,1e0], [1e-15,1e15], [1e-15,1e15], [1e-15,1e15]]\n",
"\n",
"# NOTE(review): lb/ub are not used by the other visible cells (fit_stochastic derives its\n",
"# own from the Multieis bounds) -- confirm before relying on or removing them\n",
"lb = [i[1] for i in bounds]\n",
"ub = [i[2] for i in bounds]\n",
"\n",
"# One smoothing factor per parameter (an Inf entry would share that parameter across spectra)\n",
"smf_sigma = ([100000., 100000., 100000., 100000., 100000., 100000.]) # Smoothing factor used with the standard deviation\n",
"\n",
"smf_modulus = ([1., 1., 1., 1., 1., 1.]) # Smoothing factor used with the modulus\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "J37kqOqGjRmP",
"outputId": "3307aba0-5d14-40ff-c90b-bd227d34b239"
},
"outputs": [
{
"data": {
"text/plain": [
"Multieis\n",
" p0: Array{Float64}((6,)) [162.95, 3.0678e-8, 0.93104, 11865.0, 471250.0, 1.3296e6]\n",
" freq: Array{Float64}((45,)) [8.0, 25.0, 42.0, 59.0, 76.0, 93.0, 127.0, 161.0, 246.0, 314.0 … 125893.0, 158490.0, 199527.0, 251189.0, 316228.0, 398108.0, 501188.0, 630958.0, 794329.0, 1.0e6]\n",
" Z: Array{ComplexF64}((45, 5)) ComplexF64[3.98043994e-7 + 1.62846334e-6im 3.83471274e-7 + 1.60171055e-6im … 4.04896667e-7 + 1.5847063e-6im 4.51064921e-7 + 1.60540776e-6im; 9.22443178e-7 + 4.39969926e-6im 9.06191531e-7 + 4.38099732e-6im … 9.62976515e-7 + 4.38477446e-6im 1.05098138e-6 + 4.41408429e-6im; … ; 0.00609621918 + 0.000691661902im 0.00609610649 + 0.000695430732im … 0.00609867461 + 0.000704136852im 0.00610066159 + 0.000707614759im; 0.00613996899 + 0.000559578126im 0.00613802765 + 0.000560164801im … 0.00613400387 + 0.000564746442im 0.0061315028 + 0.000567872135im]\n",
" bounds: Array{Vector{Float64}}((6,))\n",
" smf: Array{Float64}((6,)) [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]\n",
" func: model (function of type typeof(model))\n",
" weight: Nothing <: Any\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Bundle the initial guess, data, bounds, smoothing factors and model for the fit;\n",
"# `weight` is left at its default.\n",
"eis = Multieis(p0=p0, freq=F, Z=Y, bounds=bounds, smf=smf_modulus, func=model)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "58Yzq44ojRmQ",
"outputId": "427332eb-275c-4b7c-fc01-6f14940963da"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using modulus\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" \n",
"Epoch: 10000 3.21e-02"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" \n",
"Epoch: 20000 2.64e-02"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" \n",
"Epoch: 30000 2.53e-02"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" \n",
"Epoch: 40000 2.51e-02"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" \n",
"Epoch: 50000 2.51e-02"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" \n",
"Epoch: 60000 2.50e-02"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" \n",
"Epoch: 70000 2.50e-02"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" \n",
"Epoch: 80000 2.49e-02"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" \n",
"Epoch: 90000 2.48e-02"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" \n",
"Epoch: 100000 2.48e-025.9101168689673486e-5"
]
}
],
"source": [
"popt, chisqr = fit_stochastic(eis, num_epochs=100_000)\n",
"# Start on a new line: the last progress message may not end with one.\n",
"println(\"\\nchisqr = \", chisqr)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Julia 1.8.2",
"language": "julia",
"name": "julia-1.8"
},
"language_info": {
"file_extension": ".jl",
"mimetype": "application/julia",
"name": "julia",
"version": "1.8.2"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "37101ec27ccf92a314b44853e9764439c3325994725dd510a611df984a5abe68"
}
},
"colab": {
"provenance": [],
"include_colab_link": true
}
},
"nbformat": 4,
"nbformat_minor": 0
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment