Skip to content

Instantly share code, notes, and snippets.

@richinex
Created October 31, 2022 10:19
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save richinex/f50f3b8099ca8720677cff674c0ed866 to your computer and use it in GitHub Desktop.
Save richinex/f50f3b8099ca8720677cff674c0ed866 to your computer and use it in GitHub Desktop.
jax_optim_adam.ipynb
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/richinex/f50f3b8099ca8720677cff674c0ed866/jax_optim_adam.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "hcDXZtQyCEB4"
},
"outputs": [],
"source": [
"import numpy as onp\n",
"import scipy\n",
"import scipy.sparse as sps\n",
"import jax\n",
"import jax.numpy as jnp \n",
"from jax.example_libraries import optimizers as jax_opt\n",
"jax.config.update(\"jax_enable_x64\", True)\n",
"from typing import Callable, Optional, Dict, Union, Sequence, Tuple\n",
"from datetime import datetime\n",
"import logging\n",
"logger = logging.getLogger(__name__)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "tzccT-w6AiJN"
},
"outputs": [],
"source": [
"\n",
"class Multieis:\n",
" \"\"\"\n",
" An immittance batch processing class\n",
"\n",
" :param p0: A 1D or 2D array of initial guess\n",
"\n",
" :param freq: An (m, ) 1D array containing the frequencies. \\\n",
" Where m is the number of frequencies\n",
"\n",
" :param Z: An (m, n) 2D array of complex immittances. \\\n",
" Where m is the number of frequencies and \\\n",
" n is the number of spectra\n",
"\n",
" :param bounds: A sequence of (min, max) pairs for \\\n",
" each element in p0. The values must be real\n",
"\n",
" :param smf: A array of real elements same size as p0. \\\n",
" when set to inf, the corresponding parameter is kept constant\n",
"\n",
" :param func: A model e.g an equivalent circuit model (ECM) or \\\n",
" an arbitrary immittance expression composed as python function\n",
"\n",
" :param weight: A string representing the weighting scheme or \\\n",
" an (m,n) 2-D array of real values containing \\\n",
" the measurement standard deviation. \\\n",
" Defaults to unit weighting if left unspecified.\n",
"\n",
" :param immittance: A string corresponding to the immittance type\n",
"\n",
" :returns: A Multieis instance\n",
"\n",
" \"\"\"\n",
"\n",
" def __init__(\n",
" self,\n",
" p0: jnp.ndarray,\n",
" freq: jnp.ndarray,\n",
" Z: jnp.ndarray,\n",
" bounds: Sequence[Union[int, float]],\n",
" smf: jnp.ndarray,\n",
" func: Callable[[float, float], float],\n",
" immittance: str = \"impedance\",\n",
" weight: Optional[Union[str, jnp.ndarray]] = None,\n",
" ) -> None:\n",
"\n",
" assert (\n",
" p0.ndim > 0 and p0.ndim <= 2\n",
" ), (\"Initial guess must be a 1-D array or 2-D \"\n",
" \"array with same number of cols as `F`\")\n",
" assert (\n",
" Z.ndim == 2 and Z.shape[1] >= 5\n",
" ), \"The algorithm requires that the number of spectra be >= 5\"\n",
" assert freq.ndim == 1, \"The frequencies supplied should be 1-D\"\n",
" assert (\n",
" len(freq) == Z.shape[0]\n",
" ), (\"Length mismatch: The len of F is {} while the rows of Z are {}\"\n",
" .format(len(freq), Z.shape[0]))\n",
"\n",
" # Create the lower and upper bounds\n",
" try:\n",
" self.lb = self.check_zero_and_negative_values(\n",
" jnp.asarray([i[0] for i in bounds])\n",
" )\n",
" self.ub = self.check_zero_and_negative_values(\n",
" jnp.asarray([i[1] for i in bounds])\n",
" )\n",
" except IndexError:\n",
" print(\"Bounds must be a sequence of min-max pairs\")\n",
"\n",
" if p0.ndim == 1:\n",
" self.p0 = self.check_zero_and_negative_values(self.check_nan_values(p0))\n",
" self.num_params = len(self.p0)\n",
" assert (\n",
" len(self.lb) == self.num_params\n",
" ), \"Shape mismatch between initial guess and bounds\"\n",
" if __debug__:\n",
" if not jnp.all(\n",
" jnp.logical_and(\n",
" jnp.greater(self.p0, self.lb),\n",
" jnp.greater(self.ub, self.p0)\n",
" )\n",
" ):\n",
" raise AssertionError(\"\"\"Initial guess can not be\n",
" greater than the upper bound\n",
" or less than lower bound\"\"\")\n",
" elif (p0.ndim == 2) and (1 in p0.shape):\n",
" self.p0 = self.check_zero_and_negative_values(self.check_nan_values(p0.flatten()))\n",
" self.num_params = len(self.p0)\n",
" assert (\n",
" len(self.lb) == self.num_params\n",
" ), \"Shape mismatch between initial guess and bounds\"\n",
" if __debug__:\n",
" if not jnp.all(\n",
" jnp.logical_and(\n",
" jnp.greater(self.p0, self.lb),\n",
" jnp.greater(self.ub, self.p0)\n",
" )\n",
" ):\n",
" raise AssertionError(\"\"\"Initial guess can not be\n",
" greater than the upper bound\n",
" or less than lower bound\"\"\")\n",
" else:\n",
" assert p0.shape[1] == Z.shape[1], (\"Columns of p0 \"\n",
" \"do not match that of Z\")\n",
" assert (\n",
" len(self.lb) == p0.shape[0]\n",
" ), (\"The len of p0 is {} while that of the bounds is {}\"\n",
" .format(p0.shape[0], len(self.lb)))\n",
" self.p0 = self.check_zero_and_negative_values(self.check_nan_values(p0))\n",
" self.num_params = p0.shape[0]\n",
"\n",
" self.immittance_list = [\"admittance\", \"impedance\"]\n",
" assert (\n",
" immittance.lower() in self.immittance_list\n",
" ), \"Either use 'admittance' or 'impedance'\"\n",
"\n",
" self.num_freq = len(freq)\n",
" self.num_eis = Z.shape[1]\n",
" self.F = jnp.asarray(freq, dtype=jnp.float64)\n",
" self.Z = self.check_is_complex(Z)\n",
" self.Z_exp = self.Z.copy()\n",
" self.Y_exp = 1 / self.Z_exp.copy()\n",
" self.indices = None\n",
" self.n_fits = None\n",
"\n",
" self.func = func\n",
" self.immittance = immittance\n",
"\n",
" self.smf = smf\n",
"\n",
" self.kvals = list(jnp.cumsum(jnp.insert(\n",
" jnp.where(jnp.isinf(self.smf), 1, self.num_eis), 0, 0)))\n",
"\n",
" self.d2m = self.get_fd()\n",
" self.dof = (2 * self.num_freq * self.num_eis) - \\\n",
" (self.num_params * self.num_eis)\n",
" self.plot_title1 = \" \".join(\n",
" [x.title() for x in self.immittance_list if (x == self.immittance)]\n",
" )\n",
" self.plot_title2 = \" \".join(\n",
" [x.title() for x in self.immittance_list if x != self.immittance]\n",
" )\n",
"\n",
" self.lb_vec, self.ub_vec = self.get_bounds_vector(self.lb, self.ub)\n",
"\n",
" # Define weighting strategies\n",
" if isinstance(weight, jnp.ndarray):\n",
" self.weight_name = \"sigma\"\n",
" assert (\n",
" Z.shape == weight.shape\n",
" ), \"Shape mismatch between Z and the weight array\"\n",
" self.Zerr_Re = weight\n",
" self.Zerr_Im = weight\n",
" elif isinstance(weight, str):\n",
" assert weight.lower() in [\n",
" \"proportional\",\n",
" \"modulus\",\n",
" ], (\"weight must be one of None, \"\n",
" \"proportional', 'modulus' or an 2-D array of weights\")\n",
" if weight.lower() == \"proportional\":\n",
" self.weight_name = \"proportional\"\n",
" self.Zerr_Re = self.Z.real\n",
" self.Zerr_Im = self.Z.imag\n",
" else:\n",
" self.weight_name = \"modulus\"\n",
" self.Zerr_Re = jnp.abs(self.Z)\n",
" self.Zerr_Im = jnp.abs(self.Z)\n",
" elif weight is None:\n",
" # if set to None we use \"unit\" weighting\n",
" self.weight_name = \"unit\"\n",
" self.Zerr_Re = jnp.ones(shape=(self.num_freq, self.num_eis))\n",
" self.Zerr_Im = jnp.ones(shape=(self.num_freq, self.num_eis))\n",
" else:\n",
" raise AttributeError(\n",
" (\"weight must be one of 'None', \"\n",
" \"proportional', 'modulus' or an 2-D array of weights\")\n",
" )\n",
"\n",
" def __str__(self):\n",
" return f\"\"\"Multieis({self.p0},{self.F},{self.Z},{self.Zerr_Re},\\\n",
" {self.Zerr_Im}, {list(zip(self.lb, self.ub))},\\\n",
" {self.func},{self.immittance},{self.weight_name})\"\"\"\n",
"\n",
" __repr__ = __str__\n",
"\n",
" @staticmethod\n",
" def check_nan_values(arr):\n",
" if jnp.isnan(jnp.sum(arr)):\n",
" raise Exception(\"Values must not contain nan\")\n",
" else:\n",
" return arr\n",
"\n",
" @staticmethod\n",
" def check_zero_and_negative_values(arr):\n",
" if jnp.all(arr > 0):\n",
" return arr\n",
" raise Exception(\"Values must be greater than zero\")\n",
"\n",
" @staticmethod\n",
" def try_convert(x):\n",
" try:\n",
" return str(x)\n",
" except Exception as e:\n",
" print(e.__doc__)\n",
" print(e.message)\n",
" return x\n",
"\n",
" @staticmethod\n",
" def check_is_complex(arr):\n",
" if onp.iscomplexobj(arr):\n",
" return jnp.asarray(arr, dtype=jnp.complex64)\n",
" else:\n",
" return jnp.asarray(arr, dtype=jnp.complex64)\n",
"\n",
" def get_bounds_vector(self,\n",
" lb: jnp.ndarray,\n",
" ub: jnp.ndarray\n",
" ) -> Tuple[jnp.ndarray, jnp.ndarray]:\n",
" \"\"\"\n",
" Creates vectors for the upper and lower \\\n",
" bounds which are the same length \\\n",
" as the number of parameters to be fitted.\n",
"\n",
" :param lb: A 1D array of lower bounds\n",
" :param ub: A 1D array of upper bounds\n",
"\n",
" :returns: A tuple of bounds vectors\n",
"\n",
" \"\"\"\n",
" lb_vec = jnp.zeros(\n",
" self.num_params * self.num_eis\n",
" - (self.num_eis - 1)\n",
" * jnp.sum(jnp.isinf(self.smf))\n",
" )\n",
" ub_vec = jnp.zeros(\n",
" self.num_params * self.num_eis\n",
" - (self.num_eis - 1) * jnp.sum(jnp.isinf(self.smf))\n",
" )\n",
" for i in range(self.num_params):\n",
" lb_vec = lb_vec.at[self.kvals[i]:self.kvals[i + 1]].set(lb[i])\n",
" ub_vec = ub_vec.at[self.kvals[i]:self.kvals[i + 1]].set(ub[i])\n",
" return lb_vec, ub_vec\n",
"\n",
" def get_fd(self):\n",
" \"\"\"\n",
" Computes the finite difference stencil \\\n",
" for a second order derivative. \\\n",
" The derivatives at the boundaries is calculated \\\n",
" using special finite difference equations\n",
" derived specifically for just these points \\\n",
" (aka higher order boundary conditions).\n",
" They are used to handle numerical problems \\\n",
" that occur at the edge of grids.\n",
"\n",
" :returns: Finite difference stencil for a second order derivative\n",
" \"\"\"\n",
" self.d2m = (\n",
" sps.diags([1, -2, 1], [-1, 0, 1],\n",
" shape=(self.num_eis, self.num_eis))\n",
" .tolil()\n",
" .toarray()\n",
" )\n",
" self.d2m[0, :4] = [2, -5, 4, -1]\n",
" self.d2m[-1, -4:] = [-1, 4, -5, 2]\n",
" return jnp.asarray(self.d2m)\n",
"\n",
" def convert_to_internal(self,\n",
" p: jnp.ndarray\n",
" ) -> jnp.ndarray:\n",
" \"\"\"\n",
" Converts A array of parameters from an external \\\n",
" to an internal coordinates (log10 scale)\n",
"\n",
" :param p: A 1D or 2D array of parameter values\n",
"\n",
" :returns: Parameters in log10 scale\n",
" \"\"\"\n",
" assert p.ndim > 0 and p.ndim <= 2\n",
" if p.ndim == 1:\n",
" par = jnp.broadcast_to(\n",
" p[:, None],\n",
" (self.num_params, self.num_eis)\n",
" )\n",
" else:\n",
" par = p\n",
" self.p0_mat = jnp.zeros(\n",
" self.num_params * self.num_eis\n",
" - (self.num_eis - 1) * jnp.sum(jnp.isinf(self.smf))\n",
" )\n",
" for i in range(self.num_params):\n",
" self.p0_mat = self.p0_mat.at[self.kvals[i]:self.kvals[i + 1]].set(par[\n",
" i, : self.kvals[i + 1] - self.kvals[i]\n",
" ])\n",
" p_log = jnp.log10(\n",
" (self.p0_mat - self.lb_vec) / (1 - self.p0_mat / self.ub_vec)\n",
" )\n",
" return p_log\n",
"\n",
" def convert_to_external(self,\n",
" P: jnp.ndarray\n",
" ) -> jnp.ndarray:\n",
"\n",
" \"\"\"\n",
" Converts A array of parameters from an internal \\\n",
" to an external coordinates\n",
"\n",
" :param p: A 1D array of parameter values\n",
"\n",
" :returns: Parameters in normal scale\n",
" \"\"\"\n",
" par_ext = jnp.zeros(shape=(self.num_params, self.num_eis))\n",
" for i in range(self.num_params):\n",
" par_ext = par_ext.at[i, :].set((\n",
" self.lb_vec[self.kvals[i]:self.kvals[i + 1]]\n",
" + 10 ** P[self.kvals[i]:self.kvals[i + 1]]\n",
" ) / (\n",
" 1\n",
" + (10 ** P[self.kvals[i]:self.kvals[i + 1]])\n",
" / self.ub_vec[self.kvals[i]:self.kvals[i + 1]]\n",
" ))\n",
" return par_ext\n",
"\n",
" def compute_wrss(self,\n",
" p: jnp.ndarray,\n",
" f: jnp.ndarray,\n",
" z: jnp.ndarray,\n",
" zerr_re: jnp.ndarray,\n",
" zerr_im: jnp.ndarray\n",
" ) -> jnp.ndarray:\n",
"\n",
" \"\"\"\n",
" Computes the scalar weighted residual sum of squares \\\n",
" (aka scaled version of the chisquare or the chisquare itself)\n",
"\n",
" :param p: A 1D array of parameter values\n",
"\n",
" :param f: A 1D array of frequency\n",
"\n",
" :param z: A 1D array of complex immittance\n",
"\n",
" :param zerr_re: A 1D array of weights for \\\n",
" the real part of the immittance\n",
"\n",
" :param zerr_im: A 1D array of weights for \\\n",
" the imaginary part of the immittance\n",
"\n",
" :returns: A scalar value of the \\\n",
" weighted residual sum of squares\n",
"\n",
" \"\"\"\n",
" z_concat = jnp.concatenate([z.real, z.imag], axis=0)\n",
" sigma = jnp.concatenate([zerr_re, zerr_im], axis=0)\n",
" z_model = self.func(p, f)\n",
" wrss = jnp.linalg.norm(((z_concat - z_model) / sigma)) ** 2\n",
" return wrss\n",
"\n",
"\n",
" def compute_wrms(self,\n",
" p: jnp.ndarray,\n",
" f: jnp.ndarray,\n",
" z: jnp.ndarray,\n",
" zerr_re: jnp.ndarray,\n",
" zerr_im: jnp.ndarray\n",
" ) -> jnp.ndarray:\n",
" \"\"\"\n",
" Computes the weighted residual mean square\n",
"\n",
" :param p: A 1D array of parameter values\n",
"\n",
" :param f: A 1D array of frequency\n",
"\n",
" :param z: A 1D array of complex immittance\n",
"\n",
" :param zerr_re: A 1D array of weights for \\\n",
" the real part of the immittance\n",
"\n",
" :param zerr_im: A 1D array of weights for \\\n",
" the imaginary part of the immittance\n",
"\n",
" :returns: A scalar value of the weighted residual mean square\n",
" \"\"\"\n",
" z_concat = jnp.concatenate([z.real, z.imag], axis=0)\n",
" sigma = jnp.concatenate([zerr_re, zerr_im], axis=0)\n",
" z_model = self.func(p, f)\n",
" wrss = jnp.linalg.norm(((z_concat - z_model) / sigma)) ** 2\n",
" wrms = wrss / (2 * len(f) - len(p))\n",
" return wrms\n",
"\n",
" def compute_total_obj(self,\n",
" P: jnp.ndarray,\n",
" F: jnp.ndarray,\n",
" Z: jnp.ndarray,\n",
" Zerr_Re: jnp.ndarray,\n",
" Zerr_Im: jnp.ndarray,\n",
" LB: jnp.ndarray,\n",
" UB: jnp.ndarray,\n",
" smf: jnp.ndarray\n",
" ) -> jnp.ndarray:\n",
" \"\"\"\n",
" This function computes the total scalar objective function to minimize\n",
" which is a combination of the weighted residual sum of squares\n",
" and the smoothing factor divided by the degrees of freedom\n",
"\n",
" :param P: A 1D array of parameter values\n",
"\n",
" :param F: A 1D array of frequency\n",
"\n",
" :param Z: A 2D array of complex immittance\n",
"\n",
" :param Zerr_Re: A 2D array of weights for \\\n",
" the real part of the immittance\n",
"\n",
" :param Zerr_Im: A 2D array of weights for \\\n",
" the imaginary part of the immittance\n",
"\n",
" :param LB: A 1D array of values for \\\n",
" the lower bounds (for the total parameters)\n",
"\n",
" :param LB: A 1D array of values for \\\n",
" the upper bounds (for the total parameters)\n",
"\n",
" :param smf: An array of real elements same size as p0. \\\n",
" when set to inf, the corresponding parameter is kept constant\n",
"\n",
" :returns: A scalar value of the total objective function\n",
"\n",
" \"\"\"\n",
" P_log = jnp.zeros(shape=(self.num_params, self.num_eis))\n",
" P_norm = jnp.zeros(shape=(self.num_params, self.num_eis))\n",
" for i in range(self.num_params):\n",
" P_log = P_log.at[i, :].set(P[self.kvals[i]:self.kvals[i + 1]])\n",
"\n",
" P_norm = P_norm.at[i, :].set((\n",
" LB[self.kvals[i]:self.kvals[i + 1]]\n",
" + jnp.power(10, P[self.kvals[i]:self.kvals[i + 1]])\n",
" ) / (\n",
" 1\n",
" + (jnp.power(10, P[self.kvals[i]:self.kvals[i + 1]]))\n",
" / UB[self.kvals[i]:self.kvals[i + 1]]\n",
" ))\n",
"\n",
" smf_1 = jnp.where(jnp.isinf(smf), 0.0, smf)\n",
" chi_smf = ((((self.d2m @ P_log.T) * (self.d2m @ P_log.T)))\n",
" .sum(0) * smf_1).sum()\n",
" wrss_tot = jax.vmap(self.compute_wrss, in_axes=(1, None, 1, 1, 1))(\n",
" P_norm, F, Z, Zerr_Re, Zerr_Im\n",
" )\n",
" return (jnp.sum(wrss_tot) + chi_smf)\n",
"\n",
" def compute_perr(self,\n",
" P: jnp.ndarray,\n",
" F: jnp.ndarray,\n",
" Z: jnp.ndarray,\n",
" Zerr_Re: jnp.ndarray,\n",
" Zerr_Im: jnp.ndarray,\n",
" LB: jnp.ndarray,\n",
" UB: jnp.ndarray,\n",
" smf: jnp.ndarray\n",
" ) -> jnp.ndarray:\n",
"\n",
" \"\"\"\n",
" Computes the error on the parameters resulting from the batch fit\n",
" using the hessian inverse of the parameters at the minimum computed\n",
" via automatic differentiation\n",
"\n",
" :param P: A 2D array of parameter values\n",
"\n",
" :param F: A 1D array of frequency\n",
"\n",
" :param Z: A 2D array of complex immittance\n",
"\n",
" :param Zerr_Re: A 2D array of weights for \\\n",
" the real part of the immittance\n",
"\n",
" :param Zerr_Im: A 2D array of weights for \\\n",
" the imaginary part of the immittance\n",
"\n",
" :param LB: A 1D array of values for \\\n",
" the lower bounds (for the total parameters)\n",
"\n",
" :param LB: A 1D array of values for \\\n",
" the upper bounds (for the total parameters)\n",
"\n",
" :param smf: An array of real elements same size as p0. \\\n",
" when set to inf, the corresponding parameter is kept constant\n",
"\n",
" :returns: A 2D array of the standard error on the parameters\n",
"\n",
" \"\"\"\n",
" P_log = self.convert_to_internal(P)\n",
"\n",
" chitot = self.compute_total_obj(P_log, F, Z, Zerr_Re, Zerr_Im, LB, UB, smf)/self.dof\n",
" hess_mat = jax.hessian(self.compute_total_obj)(P_log, F, Z, Zerr_Re, Zerr_Im, LB, UB, smf)\n",
" try:\n",
" # Here we check to see if the Hessian matrix is singular \\\n",
" # or ill-conditioned since this makes accurate computation of the\n",
" # confidence intervals close to impossible.\n",
" hess_inv = jnp.linalg.inv(hess_mat)\n",
" except Exception as e:\n",
" print(e.__doc__)\n",
" print(e.message)\n",
" hess_inv = jnp.linalg.pinv(hess_mat)\n",
"\n",
" # The covariance matrix of the parameter estimates\n",
" # is (asymptotically) the inverse of the hessian matrix\n",
" cov_mat = hess_inv * chitot\n",
" perr = jnp.zeros(shape=(self.num_params, self.num_eis))\n",
" for i in range(self.num_params):\n",
" perr = perr.at[i, :].set((jnp.sqrt(jnp.diag(cov_mat)))[\n",
" self.kvals[i]:self.kvals[i + 1]\n",
" ])\n",
" perr = perr.copy() * P\n",
" # if the error is nan, a value of 1 is assigned.\n",
" return jnp.nan_to_num(perr, nan=1.0e15)\n",
"\n",
"\n",
" def compute_aic(self,\n",
" p: jnp.ndarray,\n",
" f: jnp.ndarray,\n",
" z: jnp.ndarray,\n",
" zerr_re: jnp.ndarray,\n",
" zerr_im: jnp.ndarray,\n",
" ) -> jnp.ndarray:\n",
" \"\"\"\n",
" Computes the Akaike Information Criterion according to\n",
" `M. Ingdal et al <https://www.sciencedirect.com/science/article/abs/pii/S0013468619311739>`_\n",
"\n",
" :param p: A 1D array of parameter values\n",
"\n",
" :param f: A 1D array of frequency\n",
"\n",
" :param z: A 1D array of complex immittance\n",
"\n",
" :param zerr_re: A 1D array of weights for \\\n",
" the real part of the immittance\n",
"\n",
" :param zerr_im: A 1D array of weights for \\\n",
" the imaginary part of the immittance\n",
"\n",
"\n",
" :returns: A value for the AIC\n",
" \"\"\"\n",
"\n",
" wrss = self.compute_wrss(p, f, z, zerr_re, zerr_im)\n",
" if self.weight_name == \"sigma\":\n",
" m2lnL = (\n",
" (2 * self.num_freq) * jnp.log(2 * jnp.pi)\n",
" + jnp.sum(jnp.log(zerr_re**2))\n",
" + jnp.sum(jnp.log(zerr_im**2))\n",
" + (wrss)\n",
" )\n",
" aic = m2lnL + 2 * self.num_params\n",
"\n",
" elif self.weight_name == \"unit\":\n",
" m2lnL = (\n",
" 2 * self.num_freq * jnp.log(2 * jnp.pi)\n",
" - 2 * self.num_freq\n",
" * jnp.log(2 * self.num_freq)\n",
" + 2 * self.num_freq\n",
" + 2 * self.num_freq * jnp.log(wrss)\n",
" )\n",
" aic = m2lnL + 2 * self.num_params\n",
"\n",
" else:\n",
" wt_re = 1 / zerr_re**2\n",
" wt_im = wt_re\n",
" m2lnL = (\n",
" 2 * self.num_freq * jnp.log(2 * jnp.pi)\n",
" - 2 * self.num_freq\n",
" * jnp.log(2 * self.num_freq)\n",
" + 2 * self.num_freq\n",
" - jnp.sum(jnp.log(wt_re))\n",
" - jnp.sum(jnp.log(wt_im))\n",
" + 2 * self.num_freq * jnp.log(wrss)\n",
" ) # log-likelihood calculation\n",
" aic = m2lnL + 2 * (self.num_params + 1)\n",
" return aic\n",
"\n",
" def train_step(self,\n",
" step_i,\n",
" opt_state,\n",
" F,\n",
" Z,\n",
" Zerr_Re,\n",
" Zerr_Im,\n",
" LB,\n",
" UB,\n",
" smf\n",
" ):\n",
" net_params = self.get_params(opt_state)\n",
" self.loss, self.grads = jax.value_and_grad(\n",
" self.compute_total_obj, argnums=0\n",
" )(\n",
" net_params,\n",
" F,\n",
" Z,\n",
" Zerr_Re,\n",
" Zerr_Im,\n",
" LB,\n",
" UB,\n",
" smf\n",
" )\n",
" return self.loss, self.opt_update(step_i, self.grads, opt_state)\n",
"\n",
" def fit_stochastic(self,\n",
" lr: float = 1e-3,\n",
" num_epochs: int = 1e5,\n",
" ) -> Tuple[\n",
" jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray\n",
" ]: # Optimal parameters, parameter error,\n",
" # weighted residual mean square, and the AIC\n",
"\n",
" \"\"\"\n",
" Fitting routine which uses the Adam optimizer.\n",
" It is important to note here that even stocahstic search procedures,\n",
" although applicable to large scale problems do not \\\n",
" find the global optimum with certainty (Aster, Richard pg 249)\n",
"\n",
" :param lr: Learning rate\n",
"\n",
" :param num_epochs: Number of epochs\n",
"\n",
" :returns: A tuple containing the optimal parameters (popt), \\\n",
" the standard error of the parameters (perr), \\\n",
" the objective function at the minimum (chisqr), \\\n",
" the total cost function (chitot) and the AIC\n",
" \"\"\"\n",
" self.lr = lr\n",
" self.num_epochs = int(num_epochs)\n",
"\n",
" if hasattr(self, \"popt\") and self.popt.shape[1] == self.Z.shape[1]:\n",
" print(\"\\nUsing prefit\")\n",
"\n",
" self.par_log = (\n",
" self.convert_to_internal(self.popt)\n",
" )\n",
" else:\n",
" print(\"\\nUsing initial\")\n",
" self.par_log = (\n",
" self.convert_to_internal(self.p0)\n",
" )\n",
"\n",
" start = datetime.now()\n",
" self.opt_init, self.opt_update, self.get_params = jax_opt.adam(self.lr)\n",
" self.opt_state = self.opt_init(self.par_log)\n",
" self.losses = []\n",
" for epoch in range(self.num_epochs):\n",
"\n",
" self.loss, self.opt_state = jax.jit(self.train_step)(\n",
" epoch,\n",
" self.opt_state,\n",
" self.F,\n",
" self.Z,\n",
" self.Zerr_Re,\n",
" self.Zerr_Im,\n",
" self.lb_vec,\n",
" self.ub_vec,\n",
" self.smf\n",
" )\n",
" self.losses.append(float(self.loss))\n",
" if epoch % int(self.num_epochs/10) == 0:\n",
" print(\n",
" \"\" + str(epoch) + \": \"\n",
" + \"loss=\" + \"{:5.3e}\".format(self.loss/self.dof)\n",
" )\n",
"\n",
" self.popt = self.convert_to_external(self.get_params(self.opt_state))\n",
" self.chitot = self.losses[-1]\n",
"\n",
" # Computer perr using the fractional covariance matrix\n",
"\n",
" self.perr = self.compute_perr(\n",
" self.popt,\n",
" self.F,\n",
" self.Z,\n",
" self.Zerr_Re,\n",
" self.Zerr_Im,\n",
" self.lb_vec,\n",
" self.ub_vec,\n",
" self.smf,\n",
" )\n",
" self.chisqr = jnp.mean(\n",
" jax.vmap(self.compute_wrms, in_axes=(1, None, 1, 1, 1))(\n",
" self.popt, self.F, self.Z, self.Zerr_Re, self.Zerr_Im\n",
" )\n",
" )\n",
" self.AIC = jnp.mean(\n",
" jax.vmap(self.compute_aic, in_axes=(1, None, 1, 1, 1))(\n",
" self.popt, self.F, self.Z, self.Zerr_Re, self.Zerr_Im\n",
" )\n",
" )\n",
" print(\"Optimization complete\")\n",
" end = datetime.now()\n",
" print(f\"total time is {end-start}\", end=\" \")\n",
" self.Z_exp = self.Z.clone()\n",
" self.Y_exp = 1 / self.Z_exp.clone()\n",
" self.Z_pred, self.Y_pred = self.model_prediction(self.popt, self.F)\n",
" self.indices = [i for i in range(self.Z_exp.shape[1])]\n",
"\n",
" return self.popt, self.perr, self.chisqr, self.chitot, self.AIC\n",
"\n",
" def real_to_complex(self,\n",
" z: jnp.ndarray,\n",
" ) -> jnp.ndarray:\n",
" \"\"\"\n",
" :param z: Takes a real vector of length 2n \\\n",
" where n is the number of frequencies\n",
"\n",
" :returns: Returns a complex vector of length n.\n",
" \"\"\"\n",
" return z[: len(z) // 2] + 1j * z[len(z) // 2:]\n",
"\n",
" def complex_to_real(self,\n",
" z: jnp.ndarray,\n",
" ) -> jnp.ndarray:\n",
"\n",
" \"\"\"\n",
" :param z: Takes a complex vector of length n \\\n",
" where n is the number of frequencies\n",
"\n",
" :returns: Returns a real vector of length 2n\n",
" \"\"\"\n",
"\n",
" return jnp.concatenate((z.real, z.imag), axis=0)\n",
"\n",
" def model_prediction(self,\n",
" P: jnp.ndarray,\n",
" F: jnp.ndarray\n",
" ) -> Tuple[jnp.ndarray, jnp.ndarray]:\n",
" \"\"\"\n",
" Computes the predicted immittance and its inverse\n",
"\n",
" :param P:\n",
"\n",
" :param Z:\n",
"\n",
" :returns: The predicted immittance (Z_pred) \\\n",
" and its inverse(Y_pred)\n",
" \"\"\"\n",
" Z_pred = jax.vmap(self.real_to_complex, in_axes=0)(\n",
" jax.vmap(self.func, in_axes=(1, None))(P, F)\n",
" ).T\n",
" Y_pred = 1 / Z_pred.copy()\n",
" return Z_pred, Y_pred\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "gSkCFxuzCktO",
"outputId": "a0b5b568-4b78-45ae-ef9b-045d6d43e357"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
]
}
],
"source": [
"def par(a, b):\n",
" \"\"\"\n",
" Defines the total impedance of two circuit elements in parallel\n",
" \"\"\"\n",
" return 1/(1/a + 1/b)\n",
"\n",
"def model(p, f):\n",
" w = 2*jnp.pi*f # Angular frequency\n",
" s = 1j*w # Complex variable\n",
" Rs = p[0]\n",
" Qh = p[1]\n",
" nh = p[2]\n",
" Rct = p[3]\n",
" Wct = p[4]\n",
" Rw = p[5]\n",
" Zw = Wct/jnp.sqrt(w) * (1-1j) # Planar infinite length Warburg impedance\n",
" Zdl = 1/(s**nh*Qh) # admittance of a CPE\n",
" Z = Rs + par(Zdl, Rct + par(Zw, Rw))\n",
" Y = 1/Z\n",
" return jnp.concatenate((Y.real, Y.imag), axis = 0)\n",
"\n",
"\n",
"p0 = jnp.asarray([1.6295e+02, 3.0678e-08, 9.3104e-01, 1.1865e+04, 4.7125e+05, 1.3296e+06])\n",
"\n",
"bounds = [[1e-15,1e15], [1e-9, 1e2], [1e-1,1e0], [1e-15,1e15], [1e-15,1e15], [1e-15,1e15]]\n",
"\n",
"smf_modulus = jnp.asarray([1., 1., 1., 1., 1., 1.]) # Smoothing factor used with the modulus\n",
"\n",
"smf_inf = jnp.asarray([jnp.inf, 1., 1., 1., 1., 1.]) # Smoothing factor with one parameter set to Inf"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "M69sCba9HR5f"
},
"outputs": [],
"source": [
"F = jnp.asarray([8.00000e+00, 2.50000e+01, 4.20000e+01, 5.90000e+01,\n",
" 7.60000e+01, 9.30000e+01, 1.27000e+02, 1.61000e+02,\n",
" 2.46000e+02, 3.14000e+02, 3.99000e+02, 5.01000e+02,\n",
" 6.37000e+02, 7.88000e+02, 9.93000e+02, 1.25200e+03,\n",
" 1.58500e+03, 1.99500e+03, 2.51200e+03, 3.16200e+03,\n",
" 3.98100e+03, 5.01200e+03, 6.31000e+03, 7.94300e+03,\n",
" 1.00000e+04, 1.25890e+04, 1.58490e+04, 1.99530e+04,\n",
" 2.51190e+04, 3.16230e+04, 3.98110e+04, 5.01190e+04,\n",
" 6.30960e+04, 7.94330e+04, 1.00000e+05, 1.25893e+05,\n",
" 1.58490e+05, 1.99527e+05, 2.51189e+05, 3.16228e+05,\n",
" 3.98108e+05, 5.01188e+05, 6.30958e+05, 7.94329e+05,\n",
" 1.00000e+06])\n",
"\n",
"Y = jnp.asarray([[3.98043994e-07+1.62846334e-06j,\n",
" 3.83471274e-07+1.60171055e-06j,\n",
" 3.83865881e-07+1.58502473e-06j,\n",
" 4.04896667e-07+1.58470630e-06j,\n",
" 4.51064921e-07+1.60540776e-06j],\n",
" [9.22443178e-07+4.39969926e-06j,\n",
" 9.06191531e-07+4.38099732e-06j,\n",
" 9.16687611e-07+4.37470180e-06j,\n",
" 9.62976515e-07+4.38477446e-06j,\n",
" 1.05098138e-06+4.41408429e-06j],\n",
" [1.40957809e-06+6.86928752e-06j,\n",
" 1.38162397e-06+6.83126518e-06j,\n",
" 1.38059693e-06+6.81408665e-06j,\n",
" 1.41883277e-06+6.82275413e-06j,\n",
" 1.50544156e-06+6.86007979e-06j],\n",
" [1.68711949e-06+9.25850964e-06j,\n",
" 1.66052735e-06+9.20596722e-06j,\n",
" 1.67216979e-06+9.17365924e-06j,\n",
" 1.73414162e-06+9.16654426e-06j,\n",
" 1.85462545e-06+9.18734440e-06j],\n",
" [2.15576119e-06+1.16992223e-05j,\n",
" 2.10827579e-06+1.16345200e-05j,\n",
" 2.10738881e-06+1.15910179e-05j,\n",
" 2.16757689e-06+1.15751991e-05j,\n",
" 2.29701277e-06+1.15904422e-05j],\n",
" [2.58288446e-06+1.40525399e-05j,\n",
" 2.51553729e-06+1.39863741e-05j,\n",
" 2.48908032e-06+1.39332806e-05j,\n",
" 2.52246764e-06+1.38992500e-05j,\n",
" 2.62926756e-06+1.38897467e-05j],\n",
" [3.31642832e-06+1.85012723e-05j,\n",
" 3.29454474e-06+1.83918928e-05j,\n",
" 3.31309980e-06+1.82958465e-05j,\n",
" 3.38095288e-06+1.82216481e-05j,\n",
" 3.50368350e-06+1.81772357e-05j],\n",
" [4.12509917e-06+2.25237090e-05j,\n",
" 4.10579196e-06+2.24286177e-05j,\n",
" 4.12116287e-06+2.23589195e-05j,\n",
" 4.18102400e-06+2.23195184e-05j,\n",
" 4.29224065e-06+2.23151201e-05j],\n",
" [5.77924402e-06+3.31866577e-05j,\n",
" 5.69006625e-06+3.30411312e-05j,\n",
" 5.66323070e-06+3.29221439e-05j,\n",
" 5.71128066e-06+3.28385904e-05j,\n",
" 5.84011150e-06+3.27982052e-05j],\n",
" [7.04483091e-06+4.16688999e-05j,\n",
" 6.98463055e-06+4.14659116e-05j,\n",
" 6.97550695e-06+4.12703739e-05j,\n",
" 7.03040587e-06+4.10956090e-05j,\n",
" 7.15912302e-06+4.09553868e-05j],\n",
" [8.63866353e-06+5.13732193e-05j,\n",
" 8.50426386e-06+5.11485632e-05j,\n",
" 8.42278678e-06+5.09697711e-05j,\n",
" 8.41987094e-06+5.08404373e-05j,\n",
" 8.51576897e-06+5.07598197e-05j],\n",
" [1.03729171e-05+6.31615694e-05j,\n",
" 1.02327485e-05+6.28848429e-05j,\n",
" 1.01667838e-05+6.26310721e-05j,\n",
" 1.01907526e-05+6.24130189e-05j,\n",
" 1.03141883e-05+6.22438456e-05j],\n",
" [1.28909396e-05+7.83684227e-05j,\n",
" 1.27298499e-05+7.80434129e-05j,\n",
" 1.26394616e-05+7.77378664e-05j,\n",
" 1.26342220e-05+7.74690998e-05j,\n",
" 1.27239173e-05+7.72554340e-05j],\n",
" [1.55678263e-05+9.51368056e-05j,\n",
" 1.53954279e-05+9.47842200e-05j,\n",
" 1.52970006e-05+9.44365966e-05j,\n",
" 1.52850498e-05+9.41093895e-05j,\n",
" 1.53678557e-05+9.38235098e-05j],\n",
" [1.90616338e-05+1.17447395e-04j,\n",
" 1.88478443e-05+1.17017706e-04j,\n",
" 1.87262340e-05+1.16581010e-04j,\n",
" 1.87125606e-05+1.16152383e-04j,\n",
" 1.88150934e-05+1.15755727e-04j],\n",
" [2.36724445e-05+1.44870515e-04j,\n",
" 2.34057225e-05+1.44312377e-04j,\n",
" 2.32291331e-05+1.43749552e-04j,\n",
" 2.31617378e-05+1.43208046e-04j,\n",
" 2.32165712e-05+1.42720542e-04j],\n",
" [2.98377254e-05+1.79296054e-04j,\n",
" 2.95218451e-05+1.78680115e-04j,\n",
" 2.92895420e-05+1.78052374e-04j,\n",
" 2.91592060e-05+1.77445327e-04j,\n",
" 2.91458418e-05+1.76899135e-04j],\n",
" [3.71368005e-05+2.21134513e-04j,\n",
" 3.67530884e-05+2.20456888e-04j,\n",
" 3.64819971e-05+2.19732538e-04j,\n",
" 3.63438885e-05+2.19003981e-04j,\n",
" 3.63522122e-05+2.18327899e-04j],\n",
" [4.71590574e-05+2.72656704e-04j,\n",
" 4.66598431e-05+2.71883619e-04j,\n",
" 4.62840071e-05+2.71045399e-04j,\n",
" 4.60535412e-05+2.70186109e-04j,\n",
" 4.59836847e-05+2.69369048e-04j],\n",
" [6.05509194e-05+3.36161669e-04j,\n",
" 5.98841943e-05+3.35299905e-04j,\n",
" 5.93654004e-05+3.34340759e-04j,\n",
" 5.90322197e-05+3.33344913e-04j,\n",
" 5.89099800e-05+3.32394353e-04j],\n",
" [7.88562465e-05+4.14546928e-04j,\n",
" 7.81773488e-05+4.13466740e-04j,\n",
" 7.76331726e-05+4.12246969e-04j,\n",
" 7.72511048e-05+4.10973036e-04j,\n",
" 7.70551051e-05+4.09756583e-04j],\n",
" [1.04056926e-04+5.10462851e-04j,\n",
" 1.03213046e-04+5.09286532e-04j,\n",
" 1.02485879e-04+5.07898745e-04j,\n",
" 1.01920967e-04+5.06399898e-04j,\n",
" 1.01564037e-04+5.04926487e-04j],\n",
" [1.39663694e-04+6.28142734e-04j,\n",
" 1.38702802e-04+6.26766239e-04j,\n",
" 1.37810479e-04+6.25117798e-04j,\n",
" 1.37032752e-04+6.23326225e-04j,\n",
" 1.36427101e-04+6.21561194e-04j],\n",
" [1.88594946e-04+7.69385544e-04j,\n",
" 1.87430560e-04+7.67729071e-04j,\n",
" 1.86385252e-04+7.65808218e-04j,\n",
" 1.85518133e-04+7.63777934e-04j,\n",
" 1.84892968e-04+7.61830714e-04j],\n",
" [2.59671040e-04+9.40661295e-04j,\n",
" 2.58338725e-04+9.38724843e-04j,\n",
" 2.56982952e-04+9.36332683e-04j,\n",
" 2.55680061e-04+9.33679519e-04j,\n",
" 2.54537823e-04+9.31019720e-04j],\n",
" [3.59062746e-04+1.14189147e-03j,\n",
" 3.57378914e-04+1.14002009e-03j,\n",
" 3.55668395e-04+1.13755930e-03j,\n",
" 3.54003307e-04+1.13472599e-03j,\n",
" 3.52498580e-04+1.13181409e-03j],\n",
" [4.98924463e-04+1.37415191e-03j,\n",
" 4.96647903e-04+1.37173687e-03j,\n",
" 4.94265929e-04+1.36869797e-03j,\n",
" 4.91948973e-04+1.36530725e-03j,\n",
" 4.89903730e-04+1.36190688e-03j],\n",
" [6.93496200e-04+1.63521001e-03j,\n",
" 6.90673129e-04+1.63290917e-03j,\n",
" 6.87621825e-04+1.62983825e-03j,\n",
" 6.84574014e-04+1.62626372e-03j,\n",
" 6.81804435e-04+1.62254740e-03j],\n",
" [9.57296055e-04+1.91309035e-03j,\n",
" 9.53748415e-04+1.91089732e-03j,\n",
" 9.49627138e-04+1.90796948e-03j,\n",
" 9.45268781e-04+1.90456549e-03j,\n",
" 9.41093254e-04+1.90101075e-03j],\n",
" [1.30123761e-03+2.19320343e-03j,\n",
" 1.29747635e-03+2.19098944e-03j,\n",
" 1.29290798e-03+2.18768464e-03j,\n",
" 1.28791679e-03+2.18362780e-03j,\n",
" 1.28301373e-03+2.17926595e-03j],\n",
" [1.73074182e-03+2.43863533e-03j,\n",
" 1.72621978e-03+2.43695825e-03j,\n",
" 1.72084465e-03+2.43458990e-03j,\n",
" 1.71503122e-03+2.43169256e-03j,\n",
" 1.70930568e-03+2.42848461e-03j],\n",
" [2.23910459e-03+2.62897694e-03j,\n",
" 2.23373249e-03+2.62814597e-03j,\n",
" 2.22713780e-03+2.62692641e-03j,\n",
" 2.21994217e-03+2.62538018e-03j,\n",
" 2.21290882e-03+2.62357481e-03j],\n",
" [2.80288863e-03+2.72976886e-03j,\n",
" 2.79733306e-03+2.72965804e-03j,\n",
" 2.79038353e-03+2.72933533e-03j,\n",
" 2.78272317e-03+2.72875628e-03j,\n",
" 2.77517387e-03+2.72789295e-03j],\n",
" [3.38658039e-03+2.72387289e-03j,\n",
" 3.38097266e-03+2.72540422e-03j,\n",
" 3.37401521e-03+2.72714160e-03j,\n",
" 3.36621539e-03+2.72886851e-03j,\n",
" 3.35824443e-03+2.73031136e-03j],\n",
" [3.94340698e-03+2.61988142e-03j,\n",
" 3.93845234e-03+2.62111914e-03j,\n",
" 3.93158384e-03+2.62272917e-03j,\n",
" 3.92330065e-03+2.62457621e-03j,\n",
" 3.91434226e-03+2.62642140e-03j],\n",
" [4.43841098e-03+2.43232748e-03j,\n",
" 4.43451665e-03+2.43382109e-03j,\n",
" 4.42952756e-03+2.43561435e-03j,\n",
" 4.42374917e-03+2.43758317e-03j,\n",
" 4.41762246e-03+2.43955571e-03j],\n",
" [4.85721370e-03+2.19639274e-03j,\n",
" 4.85349493e-03+2.19924166e-03j,\n",
" 4.84847510e-03+2.20281631e-03j,\n",
" 4.84257983e-03+2.20671482e-03j,\n",
" 4.83638886e-03+2.21042964e-03j],\n",
" [5.19651314e-03+1.93761603e-03j,\n",
" 5.19210519e-03+1.94056542e-03j,\n",
" 5.18698012e-03+1.94421073e-03j,\n",
" 5.18155703e-03+1.94818689e-03j,\n",
" 5.17630065e-03+1.95205770e-03j],\n",
" [5.47554484e-03+1.68772519e-03j,\n",
" 5.47296507e-03+1.68850122e-03j,\n",
" 5.46988007e-03+1.68990286e-03j,\n",
" 5.46627119e-03+1.69175409e-03j,\n",
" 5.46219060e-03+1.69376354e-03j],\n",
" [5.67758968e-03+1.42167136e-03j,\n",
" 5.67611866e-03+1.42599957e-03j,\n",
" 5.67469001e-03+1.43104268e-03j,\n",
" 5.67328278e-03+1.43618439e-03j,\n",
" 5.67186205e-03+1.44071598e-03j],\n",
" [5.84336650e-03+1.20199029e-03j,\n",
" 5.84407616e-03+1.20482082e-03j,\n",
" 5.84446732e-03+1.20796752e-03j,\n",
" 5.84423216e-03+1.21105090e-03j,\n",
" 5.84315788e-03+1.21376407e-03j],\n",
" [5.96777257e-03+1.00386678e-03j,\n",
" 5.96832717e-03+1.00598310e-03j,\n",
" 5.96709363e-03+1.00783817e-03j,\n",
" 5.96424518e-03+1.00927078e-03j,\n",
" 5.96018741e-03+1.01026450e-03j],\n",
" [6.03622245e-03+8.31301615e-04j,\n",
" 6.03435514e-03+8.33750877e-04j,\n",
" 6.03168830e-03+8.37852131e-04j,\n",
" 6.02866989e-03+8.43248505e-04j,\n",
" 6.02591783e-03+8.49210075e-04j],\n",
" [6.09621918e-03+6.91661902e-04j,\n",
" 6.09610649e-03+6.95430732e-04j,\n",
" 6.09701313e-03+6.99803990e-04j,\n",
" 6.09867461e-03+7.04136852e-04j,\n",
" 6.10066159e-03+7.07614759e-04j],\n",
" [6.13996899e-03+5.59578126e-04j,\n",
" 6.13802765e-03+5.60164801e-04j,\n",
" 6.13610540e-03+5.62009984e-04j,\n",
" 6.13400387e-03+5.64746442e-04j,\n",
" 6.13150280e-03+5.67872135e-04j]])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Se4hyZu3CoPd",
"outputId": "7c0010fe-aceb-4769-ee72-8a0055d95462"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Using initial\n",
"0: loss=8.660e-01\n",
"10000: loss=7.544e-05\n",
"20000: loss=5.847e-05\n",
"30000: loss=5.804e-05\n",
"40000: loss=5.790e-05\n",
"50000: loss=5.784e-05\n",
"60000: loss=5.779e-05\n",
"70000: loss=5.777e-05\n",
"80000: loss=5.775e-05\n",
"90000: loss=5.773e-05\n",
"Optimization complete\n",
"total time is 0:00:47.746234 "
]
}
],
"source": [
"eis = Multieis(p0, F, Y, bounds, smf_modulus, model, weight='modulus', immittance='admittance')\n",
"\n",
"popt, perr, chisqr, chitot, aic = eis.fit_stochastic()\n"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"provenance": [],
"include_colab_link": true
},
"kernelspec": {
"display_name": "Python 3.9.13 ('jax_env')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
},
"vscode": {
"interpreter": {
"hash": "6e71d03fb7a7d280bb66b5f7d14841fb884e15b93df48a7b9ca18ae0c1960e2f"
}
}
},
"nbformat": 4,
"nbformat_minor": 0
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment