Created
October 23, 2022 19:12
-
-
Save richinex/6a2ab066d040c0323aad1fe9dbb00f8e to your computer and use it in GitHub Desktop.
jax_optim.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"provenance": [], | |
"collapsed_sections": [], | |
"authorship_tag": "ABX9TyOjv93L5SgdTOe4LHP+qb7c", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/richinex/6a2ab066d040c0323aad1fe9dbb00f8e/jax_optim.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"!pip install jaxopt" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "sZVjjgwzDlhK", | |
"outputId": "6d2ea006-43cf-4a1c-e775-d05360034c7e" | |
}, | |
"execution_count": 1, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", | |
"Requirement already satisfied: jaxopt in /usr/local/lib/python3.7/dist-packages (0.5.5)\n", | |
"Requirement already satisfied: numpy>=1.18.4 in /usr/local/lib/python3.7/dist-packages (from jaxopt) (1.21.6)\n", | |
"Requirement already satisfied: scipy>=1.0.0 in /usr/local/lib/python3.7/dist-packages (from jaxopt) (1.7.3)\n", | |
"Requirement already satisfied: jax>=0.2.18 in /usr/local/lib/python3.7/dist-packages (from jaxopt) (0.3.23)\n", | |
"Requirement already satisfied: jaxlib>=0.1.69 in /usr/local/lib/python3.7/dist-packages (from jaxopt) (0.3.22+cuda11.cudnn805)\n", | |
"Requirement already satisfied: absl-py>=0.7.0 in /usr/local/lib/python3.7/dist-packages (from jaxopt) (1.3.0)\n", | |
"Requirement already satisfied: matplotlib>=2.0.1 in /usr/local/lib/python3.7/dist-packages (from jaxopt) (3.2.2)\n", | |
"Requirement already satisfied: etils[epath] in /usr/local/lib/python3.7/dist-packages (from jax>=0.2.18->jaxopt) (0.8.0)\n", | |
"Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from jax>=0.2.18->jaxopt) (4.1.1)\n", | |
"Requirement already satisfied: opt-einsum in /usr/local/lib/python3.7/dist-packages (from jax>=0.2.18->jaxopt) (3.3.0)\n", | |
"Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.7/dist-packages (from matplotlib>=2.0.1->jaxopt) (0.11.0)\n", | |
"Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib>=2.0.1->jaxopt) (3.0.9)\n", | |
"Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib>=2.0.1->jaxopt) (2.8.2)\n", | |
"Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib>=2.0.1->jaxopt) (1.4.4)\n", | |
"Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/dist-packages (from python-dateutil>=2.1->matplotlib>=2.0.1->jaxopt) (1.15.0)\n", | |
"Requirement already satisfied: zipp in /usr/local/lib/python3.7/dist-packages (from etils[epath]->jax>=0.2.18->jaxopt) (3.9.0)\n", | |
"Requirement already satisfied: importlib_resources in /usr/local/lib/python3.7/dist-packages (from etils[epath]->jax>=0.2.18->jaxopt) (5.10.0)\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"import os\n", | |
"import numpy as onp\n", | |
"import scipy\n", | |
"import scipy.sparse as sps\n", | |
"import jax\n", | |
"import jax.numpy as jnp \n", | |
"import jaxopt\n", | |
"import jax.scipy.optimize as jsopt\n", | |
"jax.config.update(\"jax_enable_x64\", True)\n", | |
"from typing import Callable, Optional, Dict, Union, Sequence, Tuple\n", | |
"from datetime import datetime\n", | |
"import logging\n", | |
"logger = logging.getLogger(__name__)" | |
], | |
"metadata": { | |
"id": "hcDXZtQyCEB4" | |
}, | |
"execution_count": 2, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"\n", | |
"class Multieis:\n", | |
" \"\"\"\n", | |
" An immittance batch processing class\n", | |
"\n", | |
" :param p0: A 1D or 2D array of initial guess\n", | |
"\n", | |
" :param freq: An (m, ) 1D array containing the frequencies. \\\n", | |
" Where m is the number of frequencies\n", | |
"\n", | |
" :param Z: An (m, n) 2D array of complex immittances. \\\n", | |
" Where m is the number of frequencies and \\\n", | |
" n is the number of spectra\n", | |
"\n", | |
" :param bounds: A sequence of (min, max) pairs for \\\n", | |
" each element in p0. The values must be real\n", | |
"\n", | |
" :param smf: A array of real elements same size as p0. \\\n", | |
" when set to inf, the corresponding parameter is kept constant\n", | |
"\n", | |
" :param func: A model e.g an equivalent circuit model (ECM) or \\\n", | |
" an arbitrary immittance expression composed as python function\n", | |
"\n", | |
" :param weight: A string representing the weighting scheme or \\\n", | |
" an (m,n) 2-D array of real values containing \\\n", | |
" the measurement standard deviation. \\\n", | |
" Defaults to unit weighting if left unspecified.\n", | |
"\n", | |
" :param immittance: A string corresponding to the immittance type\n", | |
"\n", | |
" :returns: A Multieis instance\n", | |
"\n", | |
" \"\"\"\n", | |
"\n", | |
" def __init__(\n", | |
" self,\n", | |
" p0: jnp.ndarray,\n", | |
" freq: jnp.ndarray,\n", | |
" Z: jnp.ndarray,\n", | |
" bounds: Sequence[Union[int, float]],\n", | |
" smf: jnp.ndarray,\n", | |
" func: Callable[[float, float], float],\n", | |
" immittance: str = \"impedance\",\n", | |
" weight: Optional[Union[str, jnp.ndarray]] = None,\n", | |
" ) -> None:\n", | |
"\n", | |
" assert (\n", | |
" p0.ndim > 0 and p0.ndim <= 2\n", | |
" ), (\"Initial guess must be a 1-D array or 2-D \"\n", | |
" \"array with same number of cols as `F`\")\n", | |
" assert (\n", | |
" Z.ndim == 2 and Z.shape[1] >= 5\n", | |
" ), \"The algorithm requires that the number of spectra be >= 5\"\n", | |
" assert freq.ndim == 1, \"The frequencies supplied should be 1-D\"\n", | |
" assert (\n", | |
" len(freq) == Z.shape[0]\n", | |
" ), (\"Length mismatch: The len of F is {} while the rows of Z are {}\"\n", | |
" .format(len(freq), Z.shape[0]))\n", | |
"\n", | |
" # Create the lower and upper bounds\n", | |
" try:\n", | |
" self.lb = self.check_zero_and_negative_values(\n", | |
" jnp.asarray([i[0] for i in bounds])\n", | |
" )\n", | |
" self.ub = self.check_zero_and_negative_values(\n", | |
" jnp.asarray([i[1] for i in bounds])\n", | |
" )\n", | |
" except IndexError:\n", | |
" print(\"Bounds must be a sequence of min-max pairs\")\n", | |
"\n", | |
" if p0.ndim == 1:\n", | |
" self.p0 = self.check_zero_and_negative_values(self.check_nan_values(p0))\n", | |
" self.num_params = len(self.p0)\n", | |
" assert (\n", | |
" len(self.lb) == self.num_params\n", | |
" ), \"Shape mismatch between initial guess and bounds\"\n", | |
" if __debug__:\n", | |
" if not jnp.all(\n", | |
" jnp.logical_and(\n", | |
" jnp.greater(self.p0, self.lb),\n", | |
" jnp.greater(self.ub, self.p0)\n", | |
" )\n", | |
" ):\n", | |
" raise AssertionError(\"\"\"Initial guess can not be\n", | |
" greater than the upper bound\n", | |
" or less than lower bound\"\"\")\n", | |
" elif (p0.ndim == 2) and (1 in p0.shape):\n", | |
" self.p0 = self.check_zero_and_negative_values(self.check_nan_values(p0.flatten()))\n", | |
" self.num_params = len(self.p0)\n", | |
" assert (\n", | |
" len(self.lb) == self.num_params\n", | |
" ), \"Shape mismatch between initial guess and bounds\"\n", | |
" if __debug__:\n", | |
" if not jnp.all(\n", | |
" jnp.logical_and(\n", | |
" jnp.greater(self.p0, self.lb),\n", | |
" jnp.greater(self.ub, self.p0)\n", | |
" )\n", | |
" ):\n", | |
" raise AssertionError(\"\"\"Initial guess can not be\n", | |
" greater than the upper bound\n", | |
" or less than lower bound\"\"\")\n", | |
" else:\n", | |
" assert p0.shape[1] == Z.shape[1], (\"Columns of p0 \"\n", | |
" \"do not match that of Z\")\n", | |
" assert (\n", | |
" len(self.lb) == p0.shape[0]\n", | |
" ), (\"The len of p0 is {} while that of the bounds is {}\"\n", | |
" .format(p0.shape[0], len(self.lb)))\n", | |
" self.p0 = self.check_zero_and_negative_values(self.check_nan_values(p0))\n", | |
" self.num_params = p0.shape[0]\n", | |
"\n", | |
" self.immittance_list = [\"admittance\", \"impedance\"]\n", | |
" assert (\n", | |
" immittance.lower() in self.immittance_list\n", | |
" ), \"Either use 'admittance' or 'impedance'\"\n", | |
"\n", | |
" self.num_freq = len(freq)\n", | |
" self.num_eis = Z.shape[1]\n", | |
" self.F = jnp.asarray(freq, dtype=jnp.float64)\n", | |
" self.Z = self.check_is_complex(Z)\n", | |
" self.Z_exp = self.Z.copy()\n", | |
" self.Y_exp = 1 / self.Z_exp.copy()\n", | |
" self.indices = None\n", | |
" self.n_fits = None\n", | |
"\n", | |
" self.func = func\n", | |
" self.immittance = immittance\n", | |
"\n", | |
" self.smf = smf\n", | |
"\n", | |
" self.kvals = list(jnp.cumsum(jnp.insert(\n", | |
" jnp.where(jnp.isinf(self.smf), 1, self.num_eis), 0, 0)))\n", | |
"\n", | |
" self.d2m = self.get_fd()\n", | |
" self.dof = (2 * self.num_freq * self.num_eis) - \\\n", | |
" (self.num_params * self.num_eis)\n", | |
" self.plot_title1 = \" \".join(\n", | |
" [x.title() for x in self.immittance_list if (x == self.immittance)]\n", | |
" )\n", | |
" self.plot_title2 = \" \".join(\n", | |
" [x.title() for x in self.immittance_list if x != self.immittance]\n", | |
" )\n", | |
"\n", | |
" self.lb_vec, self.ub_vec = self.get_bounds_vector(self.lb, self.ub)\n", | |
"\n", | |
" # Define weighting strategies\n", | |
" if isinstance(weight, jnp.ndarray):\n", | |
" self.weight_name = \"sigma\"\n", | |
" assert (\n", | |
" Z.shape == weight.shape\n", | |
" ), \"Shape mismatch between Z and the weight array\"\n", | |
" self.Zerr_Re = weight\n", | |
" self.Zerr_Im = weight\n", | |
" elif isinstance(weight, str):\n", | |
" assert weight.lower() in [\n", | |
" \"proportional\",\n", | |
" \"modulus\",\n", | |
" ], (\"weight must be one of None, \"\n", | |
" \"proportional', 'modulus' or an 2-D array of weights\")\n", | |
" if weight.lower() == \"proportional\":\n", | |
" self.weight_name = \"proportional\"\n", | |
" self.Zerr_Re = self.Z.real\n", | |
" self.Zerr_Im = self.Z.imag\n", | |
" else:\n", | |
" self.weight_name = \"modulus\"\n", | |
" self.Zerr_Re = jnp.abs(self.Z)\n", | |
" self.Zerr_Im = jnp.abs(self.Z)\n", | |
" elif weight is None:\n", | |
" # if set to None we use \"unit\" weighting\n", | |
" self.weight_name = \"unit\"\n", | |
" self.Zerr_Re = jnp.ones(shape=(self.num_freq, self.num_eis))\n", | |
" self.Zerr_Im = jnp.ones(shape=(self.num_freq, self.num_eis))\n", | |
" else:\n", | |
" raise AttributeError(\n", | |
" (\"weight must be one of 'None', \"\n", | |
" \"proportional', 'modulus' or an 2-D array of weights\")\n", | |
" )\n", | |
"\n", | |
" def __str__(self):\n", | |
" return f\"\"\"Multieis({self.p0},{self.F},{self.Z},{self.Zerr_Re},\\\n", | |
" {self.Zerr_Im}, {list(zip(self.lb, self.ub))},\\\n", | |
" {self.func},{self.immittance},{self.weight_name})\"\"\"\n", | |
"\n", | |
" __repr__ = __str__\n", | |
"\n", | |
" @staticmethod\n", | |
" def check_nan_values(arr):\n", | |
" if jnp.isnan(jnp.sum(arr)):\n", | |
" raise Exception(\"Values must not contain nan\")\n", | |
" else:\n", | |
" return arr\n", | |
"\n", | |
" @staticmethod\n", | |
" def check_zero_and_negative_values(arr):\n", | |
" if jnp.all(arr > 0):\n", | |
" return arr\n", | |
" raise Exception(\"Values must be greater than zero\")\n", | |
"\n", | |
" @staticmethod\n", | |
" def try_convert(x):\n", | |
" try:\n", | |
" return str(x)\n", | |
" except Exception as e:\n", | |
" print(e.__doc__)\n", | |
" print(e.message)\n", | |
" return x\n", | |
"\n", | |
" @staticmethod\n", | |
" def check_is_complex(arr):\n", | |
" if onp.iscomplexobj(arr):\n", | |
" return jnp.asarray(arr, dtype=jnp.complex64)\n", | |
" else:\n", | |
" return jnp.asarray(arr, dtype=jnp.complex64)\n", | |
"\n", | |
" def get_bounds_vector(self,\n", | |
" lb: jnp.ndarray,\n", | |
" ub: jnp.ndarray\n", | |
" ) -> Tuple[jnp.ndarray, jnp.ndarray]:\n", | |
" \"\"\"\n", | |
" Creates vectors for the upper and lower \\\n", | |
" bounds which are the same length \\\n", | |
" as the number of parameters to be fitted.\n", | |
"\n", | |
" :param lb: A 1D array of lower bounds\n", | |
" :param ub: A 1D array of upper bounds\n", | |
"\n", | |
" :returns: A tuple of bounds vectors\n", | |
"\n", | |
" \"\"\"\n", | |
" lb_vec = jnp.zeros(\n", | |
" self.num_params * self.num_eis\n", | |
" - (self.num_eis - 1)\n", | |
" * jnp.sum(jnp.isinf(self.smf))\n", | |
" )\n", | |
" ub_vec = jnp.zeros(\n", | |
" self.num_params * self.num_eis\n", | |
" - (self.num_eis - 1) * jnp.sum(jnp.isinf(self.smf))\n", | |
" )\n", | |
" for i in range(self.num_params):\n", | |
" lb_vec = lb_vec.at[self.kvals[i]:self.kvals[i + 1]].set(lb[i])\n", | |
" ub_vec = ub_vec.at[self.kvals[i]:self.kvals[i + 1]].set(ub[i])\n", | |
" return lb_vec, ub_vec\n", | |
"\n", | |
" def get_fd(self):\n", | |
" \"\"\"\n", | |
" Computes the finite difference stencil \\\n", | |
" for a second order derivative. \\\n", | |
" The derivatives at the boundaries is calculated \\\n", | |
" using special finite difference equations\n", | |
" derived specifically for just these points \\\n", | |
" (aka higher order boundary conditions).\n", | |
" They are used to handle numerical problems \\\n", | |
" that occur at the edge of grids.\n", | |
"\n", | |
" :returns: Finite difference stencil for a second order derivative\n", | |
" \"\"\"\n", | |
" self.d2m = (\n", | |
" sps.diags([1, -2, 1], [-1, 0, 1],\n", | |
" shape=(self.num_eis, self.num_eis))\n", | |
" .tolil()\n", | |
" .toarray()\n", | |
" )\n", | |
" self.d2m[0, :4] = [2, -5, 4, -1]\n", | |
" self.d2m[-1, -4:] = [-1, 4, -5, 2]\n", | |
" return jnp.asarray(self.d2m)\n", | |
"\n", | |
" def convert_to_internal(self,\n", | |
" p: jnp.ndarray\n", | |
" ) -> jnp.ndarray:\n", | |
" \"\"\"\n", | |
" Converts A array of parameters from an external \\\n", | |
" to an internal coordinates (log10 scale)\n", | |
"\n", | |
" :param p: A 1D or 2D array of parameter values\n", | |
"\n", | |
" :returns: Parameters in log10 scale\n", | |
" \"\"\"\n", | |
" assert p.ndim > 0 and p.ndim <= 2\n", | |
" if p.ndim == 1:\n", | |
" par = jnp.broadcast_to(\n", | |
" p[:, None],\n", | |
" (self.num_params, self.num_eis)\n", | |
" )\n", | |
" else:\n", | |
" par = p\n", | |
" self.p0_mat = jnp.zeros(\n", | |
" self.num_params * self.num_eis\n", | |
" - (self.num_eis - 1) * jnp.sum(jnp.isinf(self.smf))\n", | |
" )\n", | |
" for i in range(self.num_params):\n", | |
" self.p0_mat = self.p0_mat.at[self.kvals[i]:self.kvals[i + 1]].set(par[\n", | |
" i, : self.kvals[i + 1] - self.kvals[i]\n", | |
" ])\n", | |
" p_log = jnp.log10(\n", | |
" (self.p0_mat - self.lb_vec) / (1 - self.p0_mat / self.ub_vec)\n", | |
" )\n", | |
" return p_log\n", | |
"\n", | |
" def convert_to_external(self,\n", | |
" P: jnp.ndarray\n", | |
" ) -> jnp.ndarray:\n", | |
"\n", | |
" \"\"\"\n", | |
" Converts A array of parameters from an internal \\\n", | |
" to an external coordinates\n", | |
"\n", | |
" :param p: A 1D array of parameter values\n", | |
"\n", | |
" :returns: Parameters in normal scale\n", | |
" \"\"\"\n", | |
" par_ext = jnp.zeros(shape=(self.num_params, self.num_eis))\n", | |
" for i in range(self.num_params):\n", | |
" par_ext = par_ext.at[i, :].set((\n", | |
" self.lb_vec[self.kvals[i]:self.kvals[i + 1]]\n", | |
" + 10 ** P[self.kvals[i]:self.kvals[i + 1]]\n", | |
" ) / (\n", | |
" 1\n", | |
" + (10 ** P[self.kvals[i]:self.kvals[i + 1]])\n", | |
" / self.ub_vec[self.kvals[i]:self.kvals[i + 1]]\n", | |
" ))\n", | |
" return par_ext\n", | |
"\n", | |
" def compute_wrss(self,\n", | |
" p: jnp.ndarray,\n", | |
" f: jnp.ndarray,\n", | |
" z: jnp.ndarray,\n", | |
" zerr_re: jnp.ndarray,\n", | |
" zerr_im: jnp.ndarray\n", | |
" ) -> jnp.ndarray:\n", | |
"\n", | |
" \"\"\"\n", | |
" Computes the scalar weighted residual sum of squares \\\n", | |
" (aka scaled version of the chisquare or the chisquare itself)\n", | |
"\n", | |
" :param p: A 1D array of parameter values\n", | |
"\n", | |
" :param f: A 1D array of frequency\n", | |
"\n", | |
" :param z: A 1D array of complex immittance\n", | |
"\n", | |
" :param zerr_re: A 1D array of weights for \\\n", | |
" the real part of the immittance\n", | |
"\n", | |
" :param zerr_im: A 1D array of weights for \\\n", | |
" the imaginary part of the immittance\n", | |
"\n", | |
" :returns: A scalar value of the \\\n", | |
" weighted residual sum of squares\n", | |
"\n", | |
" \"\"\"\n", | |
" z_concat = jnp.concatenate([z.real, z.imag], axis=0)\n", | |
" sigma = jnp.concatenate([zerr_re, zerr_im], axis=0)\n", | |
" z_model = self.func(p, f)\n", | |
" wrss = jnp.linalg.norm(((z_concat - z_model) / sigma)) ** 2\n", | |
" return wrss\n", | |
"\n", | |
"\n", | |
" def compute_wrms(self,\n", | |
" p: jnp.ndarray,\n", | |
" f: jnp.ndarray,\n", | |
" z: jnp.ndarray,\n", | |
" zerr_re: jnp.ndarray,\n", | |
" zerr_im: jnp.ndarray\n", | |
" ) -> jnp.ndarray:\n", | |
" \"\"\"\n", | |
" Computes the weighted residual mean square\n", | |
"\n", | |
" :param p: A 1D array of parameter values\n", | |
"\n", | |
" :param f: A 1D array of frequency\n", | |
"\n", | |
" :param z: A 1D array of complex immittance\n", | |
"\n", | |
" :param zerr_re: A 1D array of weights for \\\n", | |
" the real part of the immittance\n", | |
"\n", | |
" :param zerr_im: A 1D array of weights for \\\n", | |
" the imaginary part of the immittance\n", | |
"\n", | |
" :returns: A scalar value of the weighted residual mean square\n", | |
" \"\"\"\n", | |
" z_concat = jnp.concatenate([z.real, z.imag], axis=0)\n", | |
" sigma = jnp.concatenate([zerr_re, zerr_im], axis=0)\n", | |
" z_model = self.func(p, f)\n", | |
" wrss = jnp.linalg.norm(((z_concat - z_model) / sigma)) ** 2\n", | |
" wrms = wrss / (2 * len(f) - len(p))\n", | |
" return wrms\n", | |
"\n", | |
" def compute_total_obj(self,\n", | |
" P: jnp.ndarray,\n", | |
" F: jnp.ndarray,\n", | |
" Z: jnp.ndarray,\n", | |
" Zerr_Re: jnp.ndarray,\n", | |
" Zerr_Im: jnp.ndarray,\n", | |
" LB: jnp.ndarray,\n", | |
" UB: jnp.ndarray,\n", | |
" smf: jnp.ndarray\n", | |
" ) -> jnp.ndarray:\n", | |
" \"\"\"\n", | |
" This function computes the total scalar objective function to minimize\n", | |
" which is a combination of the weighted residual sum of squares\n", | |
" and the smoothing factor divided by the degrees of freedom\n", | |
"\n", | |
" :param P: A 1D array of parameter values\n", | |
"\n", | |
" :param F: A 1D array of frequency\n", | |
"\n", | |
" :param Z: A 2D array of complex immittance\n", | |
"\n", | |
" :param Zerr_Re: A 2D array of weights for \\\n", | |
" the real part of the immittance\n", | |
"\n", | |
" :param Zerr_Im: A 2D array of weights for \\\n", | |
" the imaginary part of the immittance\n", | |
"\n", | |
" :param LB: A 1D array of values for \\\n", | |
" the lower bounds (for the total parameters)\n", | |
"\n", | |
" :param LB: A 1D array of values for \\\n", | |
" the upper bounds (for the total parameters)\n", | |
"\n", | |
" :param smf: An array of real elements same size as p0. \\\n", | |
" when set to inf, the corresponding parameter is kept constant\n", | |
"\n", | |
" :returns: A scalar value of the total objective function\n", | |
"\n", | |
" \"\"\"\n", | |
" P_log = jnp.zeros(shape=(self.num_params, self.num_eis))\n", | |
" P_norm = jnp.zeros(shape=(self.num_params, self.num_eis))\n", | |
" for i in range(self.num_params):\n", | |
" P_log = P_log.at[i, :].set(P[self.kvals[i]:self.kvals[i + 1]])\n", | |
"\n", | |
" P_norm = P_norm.at[i, :].set((\n", | |
" LB[self.kvals[i]:self.kvals[i + 1]]\n", | |
" + jnp.power(10, P[self.kvals[i]:self.kvals[i + 1]])\n", | |
" ) / (\n", | |
" 1\n", | |
" + (jnp.power(10, P[self.kvals[i]:self.kvals[i + 1]]))\n", | |
" / UB[self.kvals[i]:self.kvals[i + 1]]\n", | |
" ))\n", | |
"\n", | |
" smf_1 = jnp.where(jnp.isinf(smf), 0.0, smf)\n", | |
" chi_smf = ((((self.d2m @ P_log.T) * (self.d2m @ P_log.T)))\n", | |
" .sum(0) * smf_1).sum()\n", | |
" wrss_tot = jax.vmap(self.compute_wrss, in_axes=(1, None, 1, 1, 1))(\n", | |
" P_norm, F, Z, Zerr_Re, Zerr_Im\n", | |
" )\n", | |
" return (jnp.sum(wrss_tot) + chi_smf)\n", | |
"\n", | |
" def compute_perr(self,\n", | |
" P: jnp.ndarray,\n", | |
" F: jnp.ndarray,\n", | |
" Z: jnp.ndarray,\n", | |
" Zerr_Re: jnp.ndarray,\n", | |
" Zerr_Im: jnp.ndarray,\n", | |
" LB: jnp.ndarray,\n", | |
" UB: jnp.ndarray,\n", | |
" smf: jnp.ndarray\n", | |
" ) -> jnp.ndarray:\n", | |
"\n", | |
" \"\"\"\n", | |
" Computes the error on the parameters resulting from the batch fit\n", | |
" using the hessian inverse of the parameters at the minimum computed\n", | |
" via automatic differentiation\n", | |
"\n", | |
" :param P: A 2D array of parameter values\n", | |
"\n", | |
" :param F: A 1D array of frequency\n", | |
"\n", | |
" :param Z: A 2D array of complex immittance\n", | |
"\n", | |
" :param Zerr_Re: A 2D array of weights for \\\n", | |
" the real part of the immittance\n", | |
"\n", | |
" :param Zerr_Im: A 2D array of weights for \\\n", | |
" the imaginary part of the immittance\n", | |
"\n", | |
" :param LB: A 1D array of values for \\\n", | |
" the lower bounds (for the total parameters)\n", | |
"\n", | |
" :param LB: A 1D array of values for \\\n", | |
" the upper bounds (for the total parameters)\n", | |
"\n", | |
" :param smf: An array of real elements same size as p0. \\\n", | |
" when set to inf, the corresponding parameter is kept constant\n", | |
"\n", | |
" :returns: A 2D array of the standard error on the parameters\n", | |
"\n", | |
" \"\"\"\n", | |
" P_log = self.convert_to_internal(P)\n", | |
"\n", | |
" chitot = self.compute_total_obj(P_log, F, Z, Zerr_Re, Zerr_Im, LB, UB, smf)/self.dof\n", | |
" hess_mat = jax.hessian(self.compute_total_obj)(P_log, F, Z, Zerr_Re, Zerr_Im, LB, UB, smf)\n", | |
" try:\n", | |
" # Here we check to see if the Hessian matrix is singular \\\n", | |
" # or ill-conditioned since this makes accurate computation of the\n", | |
" # confidence intervals close to impossible.\n", | |
" hess_inv = jnp.linalg.inv(hess_mat)\n", | |
" except Exception as e:\n", | |
" print(e.__doc__)\n", | |
" print(e.message)\n", | |
" hess_inv = jnp.linalg.pinv(hess_mat)\n", | |
"\n", | |
" # The covariance matrix of the parameter estimates\n", | |
" # is (asymptotically) the inverse of the hessian matrix\n", | |
" cov_mat = hess_inv * chitot\n", | |
" perr = jnp.zeros(shape=(self.num_params, self.num_eis))\n", | |
" for i in range(self.num_params):\n", | |
" perr = perr.at[i, :].set((jnp.sqrt(jnp.diag(cov_mat)))[\n", | |
" self.kvals[i]:self.kvals[i + 1]\n", | |
" ])\n", | |
" perr = perr.copy() * P\n", | |
" # if the error is nan, a value of 1 is assigned.\n", | |
" return jnp.nan_to_num(perr, nan=1.0e15)\n", | |
"\n", | |
"\n", | |
" def compute_aic(self,\n", | |
" p: jnp.ndarray,\n", | |
" f: jnp.ndarray,\n", | |
" z: jnp.ndarray,\n", | |
" zerr_re: jnp.ndarray,\n", | |
" zerr_im: jnp.ndarray,\n", | |
" ) -> jnp.ndarray:\n", | |
" \"\"\"\n", | |
" Computes the Akaike Information Criterion according to\n", | |
" `M. Ingdal et al <https://www.sciencedirect.com/science/article/abs/pii/S0013468619311739>`_\n", | |
"\n", | |
" :param p: A 1D array of parameter values\n", | |
"\n", | |
" :param f: A 1D array of frequency\n", | |
"\n", | |
" :param z: A 1D array of complex immittance\n", | |
"\n", | |
" :param zerr_re: A 1D array of weights for \\\n", | |
" the real part of the immittance\n", | |
"\n", | |
" :param zerr_im: A 1D array of weights for \\\n", | |
" the imaginary part of the immittance\n", | |
"\n", | |
"\n", | |
" :returns: A value for the AIC\n", | |
" \"\"\"\n", | |
"\n", | |
" wrss = self.compute_wrss(p, f, z, zerr_re, zerr_im)\n", | |
" if self.weight_name == \"sigma\":\n", | |
" m2lnL = (\n", | |
" (2 * self.num_freq) * jnp.log(2 * jnp.pi)\n", | |
" + jnp.sum(jnp.log(zerr_re**2))\n", | |
" + jnp.sum(jnp.log(zerr_im**2))\n", | |
" + (wrss)\n", | |
" )\n", | |
" aic = m2lnL + 2 * self.num_params\n", | |
"\n", | |
" elif self.weight_name == \"unit\":\n", | |
" m2lnL = (\n", | |
" 2 * self.num_freq * jnp.log(2 * jnp.pi)\n", | |
" - 2 * self.num_freq\n", | |
" * jnp.log(2 * self.num_freq)\n", | |
" + 2 * self.num_freq\n", | |
" + 2 * self.num_freq * jnp.log(wrss)\n", | |
" )\n", | |
" aic = m2lnL + 2 * self.num_params\n", | |
"\n", | |
" else:\n", | |
" wt_re = 1 / zerr_re**2\n", | |
" wt_im = wt_re\n", | |
" m2lnL = (\n", | |
" 2 * self.num_freq * jnp.log(2 * jnp.pi)\n", | |
" - 2 * self.num_freq\n", | |
" * jnp.log(2 * self.num_freq)\n", | |
" + 2 * self.num_freq\n", | |
" - jnp.sum(jnp.log(wt_re))\n", | |
" - jnp.sum(jnp.log(wt_im))\n", | |
" + 2 * self.num_freq * jnp.log(wrss)\n", | |
" ) # log-likelihood calculation\n", | |
" aic = m2lnL + 2 * (self.num_params + 1)\n", | |
" return aic\n", | |
"\n", | |
" def fit_simultaneous(self,\n", | |
" method : str = 'TNC',\n", | |
" n_iter : int = 5000,\n", | |
" ) -> Tuple[\n", | |
" jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray\n", | |
" ]:\n", | |
"\n", | |
" \"\"\"\n", | |
" Simultaneous fitting routine with an arbitrary smoothing factor.\n", | |
"\n", | |
" :params method: Solver to use (must be one of \"'TNC', \\\n", | |
" 'BFGS' or 'L-BFGS-B'\")\n", | |
"\n", | |
" :params n_iter: Number of iterations\n", | |
"\n", | |
"\n", | |
" :returns: A tuple containing the optimal parameters (popt), \\\n", | |
" the standard error of the parameters (perr), \\\n", | |
" the objective function at the minimum (chisqr), \\\n", | |
" the total cost function (chitot) and the AIC\n", | |
" \"\"\"\n", | |
" self.method = method.lower()\n", | |
" assert (self.method in ['tnc', 'bfgs', 'l-bfgs-b']), (\"method must be one of \"\n", | |
" \"'TNC', 'BFGS' or 'L-BFGS-B'\")\n", | |
" if hasattr(self, \"popt\") and self.popt.shape[1] == self.Z.shape[1]:\n", | |
" print(\"\\nUsing prefit\")\n", | |
"\n", | |
" self.par_log = (\n", | |
" self.convert_to_internal(self.popt)\n", | |
" )\n", | |
" else:\n", | |
"\n", | |
" self.par_log = (\n", | |
" self.convert_to_internal(self.p0)\n", | |
" )\n", | |
" print(\"\\nUsing initial\")\n", | |
"\n", | |
" # Optimizer 1 uses the BFGS algorithm\n", | |
" start = datetime.now()\n", | |
"\n", | |
" solver = jaxopt.ScipyMinimize(\n", | |
" method=self.method,\n", | |
" fun=jax.jit(self.compute_total_obj),\n", | |
" dtype='float64',\n", | |
" tol=1e-14,\n", | |
" maxiter=n_iter,\n", | |
" )\n", | |
" self.sol = solver.run(\n", | |
" self.par_log,\n", | |
" self.F,\n", | |
" self.Z,\n", | |
" self.Zerr_Re,\n", | |
" self.Zerr_Im,\n", | |
" self.lb_vec,\n", | |
" self.ub_vec,\n", | |
" self.smf\n", | |
" )\n", | |
"\n", | |
" self.popt = self.convert_to_external(self.sol.params)\n", | |
" self.chitot = self.sol.state.fun_val/self.dof\n", | |
"\n", | |
" self.perr = self.compute_perr(\n", | |
" self.popt,\n", | |
" self.F,\n", | |
" self.Z,\n", | |
" self.Zerr_Re,\n", | |
" self.Zerr_Im,\n", | |
" self.lb_vec,\n", | |
" self.ub_vec,\n", | |
" self.smf,\n", | |
" )\n", | |
"\n", | |
" self.chisqr = (\n", | |
" jnp.mean(\n", | |
" jax.vmap(self.compute_wrms, in_axes=(1, None, 1, 1, 1))(\n", | |
" self.popt, self.F, self.Z, self.Zerr_Re, self.Zerr_Im\n", | |
" )\n", | |
" )\n", | |
" )\n", | |
" self.AIC = (\n", | |
" jnp.mean(\n", | |
" jax.vmap(self.compute_aic, in_axes=(1, None, 1, 1, 1))(\n", | |
" self.popt, self.F, self.Z, self.Zerr_Re, self.Zerr_Im\n", | |
" )\n", | |
" )\n", | |
" )\n", | |
" print(\"\\nOptimization complete\")\n", | |
" end = datetime.now()\n", | |
" print(f\"total time is {end-start}\", end=\" \")\n", | |
" self.Z_exp = self.Z.copy()\n", | |
" self.Y_exp = 1 / self.Z_exp.copy()\n", | |
" self.Z_pred, self.Y_pred = self.model_prediction(self.popt, self.F)\n", | |
" self.indices = [i for i in range(self.Z_exp.shape[1])]\n", | |
" return self.popt, self.perr, self.chisqr, self.chitot, self.AIC\n", | |
"\n", | |
" def real_to_complex(self,\n", | |
" z: jnp.ndarray,\n", | |
" ) -> jnp.ndarray:\n", | |
" \"\"\"\n", | |
" :param z: Takes a real vector of length 2n \\\n", | |
" where n is the number of frequencies\n", | |
"\n", | |
" :returns: Returns a complex vector of length n.\n", | |
" \"\"\"\n", | |
" return z[: len(z) // 2] + 1j * z[len(z) // 2:]\n", | |
"\n", | |
" def complex_to_real(self,\n", | |
" z: jnp.ndarray,\n", | |
" ) -> jnp.ndarray:\n", | |
"\n", | |
" \"\"\"\n", | |
" :param z: Takes a complex vector of length n \\\n", | |
" where n is the number of frequencies\n", | |
"\n", | |
" :returns: Returns a real vector of length 2n\n", | |
" \"\"\"\n", | |
"\n", | |
" return jnp.concatenate((z.real, z.imag), axis=0)\n", | |
"\n", | |
" def model_prediction(self,\n", | |
" P: jnp.ndarray,\n", | |
" F: jnp.ndarray\n", | |
" ) -> Tuple[jnp.ndarray, jnp.ndarray]:\n", | |
" \"\"\"\n", | |
" Computes the predicted immittance and its inverse\n", | |
"\n", | |
" :param P:\n", | |
"\n", | |
" :param Z:\n", | |
"\n", | |
" :returns: The predicted immittance (Z_pred) \\\n", | |
" and its inverse(Y_pred)\n", | |
" \"\"\"\n", | |
" Z_pred = jax.vmap(self.real_to_complex, in_axes=0)(\n", | |
" jax.vmap(self.func, in_axes=(1, None))(P, F)\n", | |
" ).T\n", | |
" Y_pred = 1 / Z_pred.copy()\n", | |
" return Z_pred, Y_pred\n", | |
"\n" | |
], | |
"metadata": { | |
"id": "tzccT-w6AiJN" | |
}, | |
"execution_count": 3, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"def par(a, b):\n", | |
" \"\"\"\n", | |
" Defines the total impedance of two circuit elements in parallel\n", | |
" \"\"\"\n", | |
" return 1/(1/a + 1/b)\n", | |
"\n", | |
"def redox(p, f):\n", | |
" w = 2*jnp.pi*f # Angular frequency\n", | |
" s = 1j*w # Complex variable\n", | |
" Rs = p[0]\n", | |
" Qh = p[1]\n", | |
" nh = p[2]\n", | |
" Rct = p[3]\n", | |
" Wct = p[4]\n", | |
" Rw = p[5]\n", | |
" Zw = Wct/jnp.sqrt(w) * (1-1j) # Planar infinite length Warburg impedance\n", | |
" Zdl = 1/(s**nh*Qh) # admittance of a CPE\n", | |
" Z = Rs + par(Zdl, Rct + par(Zw, Rw))\n", | |
" Y = 1/Z\n", | |
" return jnp.concatenate((Y.real, Y.imag), axis = 0)\n", | |
"\n", | |
"\n", | |
"p0 = jnp.asarray([1.6295e+02, 3.0678e-08, 9.3104e-01, 1.1865e+04, 4.7125e+05, 1.3296e+06])\n", | |
"\n", | |
"bounds = [[1e-15,1e15], [1e-9, 1e2], [1e-1,1e0], [1e-15,1e15], [1e-15,1e15], [1e-15,1e15]]\n", | |
"\n", | |
"smf_modulus = jnp.asarray([1., 1., 1., 1., 1., 1.]) # Smoothing factor used with the modulus\n", | |
"\n", | |
"smf_inf = jnp.asarray([jnp.inf, 1., 1., 1., 1., 1.]) # Smoothing factor with one parameter set to Inf" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "gSkCFxuzCktO", | |
"outputId": "a0b5b568-4b78-45ae-ef9b-045d6d43e357" | |
}, | |
"execution_count": 4, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Measured data. F: 45 frequencies (redox converts them to angular frequency 2*pi*f).\n", | |
"# Y: complex admittance spectra, one row per frequency and one column per spectrum\n", | |
"# (45 x 5), i.e. 5 spectra that are fitted simultaneously below.\n", | |
"F = jnp.asarray([8.00000e+00, 2.50000e+01, 4.20000e+01, 5.90000e+01,\n", | |
"                 7.60000e+01, 9.30000e+01, 1.27000e+02, 1.61000e+02,\n", | |
"                 2.46000e+02, 3.14000e+02, 3.99000e+02, 5.01000e+02,\n", | |
"                 6.37000e+02, 7.88000e+02, 9.93000e+02, 1.25200e+03,\n", | |
"                 1.58500e+03, 1.99500e+03, 2.51200e+03, 3.16200e+03,\n", | |
"                 3.98100e+03, 5.01200e+03, 6.31000e+03, 7.94300e+03,\n", | |
"                 1.00000e+04, 1.25890e+04, 1.58490e+04, 1.99530e+04,\n", | |
"                 2.51190e+04, 3.16230e+04, 3.98110e+04, 5.01190e+04,\n", | |
"                 6.30960e+04, 7.94330e+04, 1.00000e+05, 1.25893e+05,\n", | |
"                 1.58490e+05, 1.99527e+05, 2.51189e+05, 3.16228e+05,\n", | |
"                 3.98108e+05, 5.01188e+05, 6.30958e+05, 7.94329e+05,\n", | |
"                 1.00000e+06])\n", | |
"\n", | |
"Y = jnp.asarray([[3.98043994e-07+1.62846334e-06j,\n", | |
"                  3.83471274e-07+1.60171055e-06j,\n", | |
"                  3.83865881e-07+1.58502473e-06j,\n", | |
"                  4.04896667e-07+1.58470630e-06j,\n", | |
"                  4.51064921e-07+1.60540776e-06j],\n", | |
"                 [9.22443178e-07+4.39969926e-06j,\n", | |
"                  9.06191531e-07+4.38099732e-06j,\n", | |
"                  9.16687611e-07+4.37470180e-06j,\n", | |
"                  9.62976515e-07+4.38477446e-06j,\n", | |
"                  1.05098138e-06+4.41408429e-06j],\n", | |
"                 [1.40957809e-06+6.86928752e-06j,\n", | |
"                  1.38162397e-06+6.83126518e-06j,\n", | |
"                  1.38059693e-06+6.81408665e-06j,\n", | |
"                  1.41883277e-06+6.82275413e-06j,\n", | |
"                  1.50544156e-06+6.86007979e-06j],\n", | |
"                 [1.68711949e-06+9.25850964e-06j,\n", | |
"                  1.66052735e-06+9.20596722e-06j,\n", | |
"                  1.67216979e-06+9.17365924e-06j,\n", | |
"                  1.73414162e-06+9.16654426e-06j,\n", | |
"                  1.85462545e-06+9.18734440e-06j],\n", | |
"                 [2.15576119e-06+1.16992223e-05j,\n", | |
"                  2.10827579e-06+1.16345200e-05j,\n", | |
"                  2.10738881e-06+1.15910179e-05j,\n", | |
"                  2.16757689e-06+1.15751991e-05j,\n", | |
"                  2.29701277e-06+1.15904422e-05j],\n", | |
"                 [2.58288446e-06+1.40525399e-05j,\n", | |
"                  2.51553729e-06+1.39863741e-05j,\n", | |
"                  2.48908032e-06+1.39332806e-05j,\n", | |
"                  2.52246764e-06+1.38992500e-05j,\n", | |
"                  2.62926756e-06+1.38897467e-05j],\n", | |
"                 [3.31642832e-06+1.85012723e-05j,\n", | |
"                  3.29454474e-06+1.83918928e-05j,\n", | |
"                  3.31309980e-06+1.82958465e-05j,\n", | |
"                  3.38095288e-06+1.82216481e-05j,\n", | |
"                  3.50368350e-06+1.81772357e-05j],\n", | |
"                 [4.12509917e-06+2.25237090e-05j,\n", | |
"                  4.10579196e-06+2.24286177e-05j,\n", | |
"                  4.12116287e-06+2.23589195e-05j,\n", | |
"                  4.18102400e-06+2.23195184e-05j,\n", | |
"                  4.29224065e-06+2.23151201e-05j],\n", | |
"                 [5.77924402e-06+3.31866577e-05j,\n", | |
"                  5.69006625e-06+3.30411312e-05j,\n", | |
"                  5.66323070e-06+3.29221439e-05j,\n", | |
"                  5.71128066e-06+3.28385904e-05j,\n", | |
"                  5.84011150e-06+3.27982052e-05j],\n", | |
"                 [7.04483091e-06+4.16688999e-05j,\n", | |
"                  6.98463055e-06+4.14659116e-05j,\n", | |
"                  6.97550695e-06+4.12703739e-05j,\n", | |
"                  7.03040587e-06+4.10956090e-05j,\n", | |
"                  7.15912302e-06+4.09553868e-05j],\n", | |
"                 [8.63866353e-06+5.13732193e-05j,\n", | |
"                  8.50426386e-06+5.11485632e-05j,\n", | |
"                  8.42278678e-06+5.09697711e-05j,\n", | |
"                  8.41987094e-06+5.08404373e-05j,\n", | |
"                  8.51576897e-06+5.07598197e-05j],\n", | |
"                 [1.03729171e-05+6.31615694e-05j,\n", | |
"                  1.02327485e-05+6.28848429e-05j,\n", | |
"                  1.01667838e-05+6.26310721e-05j,\n", | |
"                  1.01907526e-05+6.24130189e-05j,\n", | |
"                  1.03141883e-05+6.22438456e-05j],\n", | |
"                 [1.28909396e-05+7.83684227e-05j,\n", | |
"                  1.27298499e-05+7.80434129e-05j,\n", | |
"                  1.26394616e-05+7.77378664e-05j,\n", | |
"                  1.26342220e-05+7.74690998e-05j,\n", | |
"                  1.27239173e-05+7.72554340e-05j],\n", | |
"                 [1.55678263e-05+9.51368056e-05j,\n", | |
"                  1.53954279e-05+9.47842200e-05j,\n", | |
"                  1.52970006e-05+9.44365966e-05j,\n", | |
"                  1.52850498e-05+9.41093895e-05j,\n", | |
"                  1.53678557e-05+9.38235098e-05j],\n", | |
"                 [1.90616338e-05+1.17447395e-04j,\n", | |
"                  1.88478443e-05+1.17017706e-04j,\n", | |
"                  1.87262340e-05+1.16581010e-04j,\n", | |
"                  1.87125606e-05+1.16152383e-04j,\n", | |
"                  1.88150934e-05+1.15755727e-04j],\n", | |
"                 [2.36724445e-05+1.44870515e-04j,\n", | |
"                  2.34057225e-05+1.44312377e-04j,\n", | |
"                  2.32291331e-05+1.43749552e-04j,\n", | |
"                  2.31617378e-05+1.43208046e-04j,\n", | |
"                  2.32165712e-05+1.42720542e-04j],\n", | |
"                 [2.98377254e-05+1.79296054e-04j,\n", | |
"                  2.95218451e-05+1.78680115e-04j,\n", | |
"                  2.92895420e-05+1.78052374e-04j,\n", | |
"                  2.91592060e-05+1.77445327e-04j,\n", | |
"                  2.91458418e-05+1.76899135e-04j],\n", | |
"                 [3.71368005e-05+2.21134513e-04j,\n", | |
"                  3.67530884e-05+2.20456888e-04j,\n", | |
"                  3.64819971e-05+2.19732538e-04j,\n", | |
"                  3.63438885e-05+2.19003981e-04j,\n", | |
"                  3.63522122e-05+2.18327899e-04j],\n", | |
"                 [4.71590574e-05+2.72656704e-04j,\n", | |
"                  4.66598431e-05+2.71883619e-04j,\n", | |
"                  4.62840071e-05+2.71045399e-04j,\n", | |
"                  4.60535412e-05+2.70186109e-04j,\n", | |
"                  4.59836847e-05+2.69369048e-04j],\n", | |
"                 [6.05509194e-05+3.36161669e-04j,\n", | |
"                  5.98841943e-05+3.35299905e-04j,\n", | |
"                  5.93654004e-05+3.34340759e-04j,\n", | |
"                  5.90322197e-05+3.33344913e-04j,\n", | |
"                  5.89099800e-05+3.32394353e-04j],\n", | |
"                 [7.88562465e-05+4.14546928e-04j,\n", | |
"                  7.81773488e-05+4.13466740e-04j,\n", | |
"                  7.76331726e-05+4.12246969e-04j,\n", | |
"                  7.72511048e-05+4.10973036e-04j,\n", | |
"                  7.70551051e-05+4.09756583e-04j],\n", | |
"                 [1.04056926e-04+5.10462851e-04j,\n", | |
"                  1.03213046e-04+5.09286532e-04j,\n", | |
"                  1.02485879e-04+5.07898745e-04j,\n", | |
"                  1.01920967e-04+5.06399898e-04j,\n", | |
"                  1.01564037e-04+5.04926487e-04j],\n", | |
"                 [1.39663694e-04+6.28142734e-04j,\n", | |
"                  1.38702802e-04+6.26766239e-04j,\n", | |
"                  1.37810479e-04+6.25117798e-04j,\n", | |
"                  1.37032752e-04+6.23326225e-04j,\n", | |
"                  1.36427101e-04+6.21561194e-04j],\n", | |
"                 [1.88594946e-04+7.69385544e-04j,\n", | |
"                  1.87430560e-04+7.67729071e-04j,\n", | |
"                  1.86385252e-04+7.65808218e-04j,\n", | |
"                  1.85518133e-04+7.63777934e-04j,\n", | |
"                  1.84892968e-04+7.61830714e-04j],\n", | |
"                 [2.59671040e-04+9.40661295e-04j,\n", | |
"                  2.58338725e-04+9.38724843e-04j,\n", | |
"                  2.56982952e-04+9.36332683e-04j,\n", | |
"                  2.55680061e-04+9.33679519e-04j,\n", | |
"                  2.54537823e-04+9.31019720e-04j],\n", | |
"                 [3.59062746e-04+1.14189147e-03j,\n", | |
"                  3.57378914e-04+1.14002009e-03j,\n", | |
"                  3.55668395e-04+1.13755930e-03j,\n", | |
"                  3.54003307e-04+1.13472599e-03j,\n", | |
"                  3.52498580e-04+1.13181409e-03j],\n", | |
"                 [4.98924463e-04+1.37415191e-03j,\n", | |
"                  4.96647903e-04+1.37173687e-03j,\n", | |
"                  4.94265929e-04+1.36869797e-03j,\n", | |
"                  4.91948973e-04+1.36530725e-03j,\n", | |
"                  4.89903730e-04+1.36190688e-03j],\n", | |
"                 [6.93496200e-04+1.63521001e-03j,\n", | |
"                  6.90673129e-04+1.63290917e-03j,\n", | |
"                  6.87621825e-04+1.62983825e-03j,\n", | |
"                  6.84574014e-04+1.62626372e-03j,\n", | |
"                  6.81804435e-04+1.62254740e-03j],\n", | |
"                 [9.57296055e-04+1.91309035e-03j,\n", | |
"                  9.53748415e-04+1.91089732e-03j,\n", | |
"                  9.49627138e-04+1.90796948e-03j,\n", | |
"                  9.45268781e-04+1.90456549e-03j,\n", | |
"                  9.41093254e-04+1.90101075e-03j],\n", | |
"                 [1.30123761e-03+2.19320343e-03j,\n", | |
"                  1.29747635e-03+2.19098944e-03j,\n", | |
"                  1.29290798e-03+2.18768464e-03j,\n", | |
"                  1.28791679e-03+2.18362780e-03j,\n", | |
"                  1.28301373e-03+2.17926595e-03j],\n", | |
"                 [1.73074182e-03+2.43863533e-03j,\n", | |
"                  1.72621978e-03+2.43695825e-03j,\n", | |
"                  1.72084465e-03+2.43458990e-03j,\n", | |
"                  1.71503122e-03+2.43169256e-03j,\n", | |
"                  1.70930568e-03+2.42848461e-03j],\n", | |
"                 [2.23910459e-03+2.62897694e-03j,\n", | |
"                  2.23373249e-03+2.62814597e-03j,\n", | |
"                  2.22713780e-03+2.62692641e-03j,\n", | |
"                  2.21994217e-03+2.62538018e-03j,\n", | |
"                  2.21290882e-03+2.62357481e-03j],\n", | |
"                 [2.80288863e-03+2.72976886e-03j,\n", | |
"                  2.79733306e-03+2.72965804e-03j,\n", | |
"                  2.79038353e-03+2.72933533e-03j,\n", | |
"                  2.78272317e-03+2.72875628e-03j,\n", | |
"                  2.77517387e-03+2.72789295e-03j],\n", | |
"                 [3.38658039e-03+2.72387289e-03j,\n", | |
"                  3.38097266e-03+2.72540422e-03j,\n", | |
"                  3.37401521e-03+2.72714160e-03j,\n", | |
"                  3.36621539e-03+2.72886851e-03j,\n", | |
"                  3.35824443e-03+2.73031136e-03j],\n", | |
"                 [3.94340698e-03+2.61988142e-03j,\n", | |
"                  3.93845234e-03+2.62111914e-03j,\n", | |
"                  3.93158384e-03+2.62272917e-03j,\n", | |
"                  3.92330065e-03+2.62457621e-03j,\n", | |
"                  3.91434226e-03+2.62642140e-03j],\n", | |
"                 [4.43841098e-03+2.43232748e-03j,\n", | |
"                  4.43451665e-03+2.43382109e-03j,\n", | |
"                  4.42952756e-03+2.43561435e-03j,\n", | |
"                  4.42374917e-03+2.43758317e-03j,\n", | |
"                  4.41762246e-03+2.43955571e-03j],\n", | |
"                 [4.85721370e-03+2.19639274e-03j,\n", | |
"                  4.85349493e-03+2.19924166e-03j,\n", | |
"                  4.84847510e-03+2.20281631e-03j,\n", | |
"                  4.84257983e-03+2.20671482e-03j,\n", | |
"                  4.83638886e-03+2.21042964e-03j],\n", | |
"                 [5.19651314e-03+1.93761603e-03j,\n", | |
"                  5.19210519e-03+1.94056542e-03j,\n", | |
"                  5.18698012e-03+1.94421073e-03j,\n", | |
"                  5.18155703e-03+1.94818689e-03j,\n", | |
"                  5.17630065e-03+1.95205770e-03j],\n", | |
"                 [5.47554484e-03+1.68772519e-03j,\n", | |
"                  5.47296507e-03+1.68850122e-03j,\n", | |
"                  5.46988007e-03+1.68990286e-03j,\n", | |
"                  5.46627119e-03+1.69175409e-03j,\n", | |
"                  5.46219060e-03+1.69376354e-03j],\n", | |
"                 [5.67758968e-03+1.42167136e-03j,\n", | |
"                  5.67611866e-03+1.42599957e-03j,\n", | |
"                  5.67469001e-03+1.43104268e-03j,\n", | |
"                  5.67328278e-03+1.43618439e-03j,\n", | |
"                  5.67186205e-03+1.44071598e-03j],\n", | |
"                 [5.84336650e-03+1.20199029e-03j,\n", | |
"                  5.84407616e-03+1.20482082e-03j,\n", | |
"                  5.84446732e-03+1.20796752e-03j,\n", | |
"                  5.84423216e-03+1.21105090e-03j,\n", | |
"                  5.84315788e-03+1.21376407e-03j],\n", | |
"                 [5.96777257e-03+1.00386678e-03j,\n", | |
"                  5.96832717e-03+1.00598310e-03j,\n", | |
"                  5.96709363e-03+1.00783817e-03j,\n", | |
"                  5.96424518e-03+1.00927078e-03j,\n", | |
"                  5.96018741e-03+1.01026450e-03j],\n", | |
"                 [6.03622245e-03+8.31301615e-04j,\n", | |
"                  6.03435514e-03+8.33750877e-04j,\n", | |
"                  6.03168830e-03+8.37852131e-04j,\n", | |
"                  6.02866989e-03+8.43248505e-04j,\n", | |
"                  6.02591783e-03+8.49210075e-04j],\n", | |
"                 [6.09621918e-03+6.91661902e-04j,\n", | |
"                  6.09610649e-03+6.95430732e-04j,\n", | |
"                  6.09701313e-03+6.99803990e-04j,\n", | |
"                  6.09867461e-03+7.04136852e-04j,\n", | |
"                  6.10066159e-03+7.07614759e-04j],\n", | |
"                 [6.13996899e-03+5.59578126e-04j,\n", | |
"                  6.13802765e-03+5.60164801e-04j,\n", | |
"                  6.13610540e-03+5.62009984e-04j,\n", | |
"                  6.13400387e-03+5.64746442e-04j,\n", | |
"                  6.13150280e-03+5.67872135e-04j]])" | |
], | |
"metadata": { | |
"id": "M69sCba9HR5f" | |
}, | |
"execution_count": 5, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Fit all spectra in Y simultaneously with the redox model, using modulus weighting;\n", | |
"# the data are treated as admittance (redox returns admittance, not impedance)\n", | |
"eis = Multieis(p0, F, Y, bounds, smf_modulus, redox, weight='modulus', immittance='admittance')\n", | |
"\n", | |
"popt, perr, chisqr, chitot, aic = eis.fit_simultaneous()\n", | |
"# Row 0 of popt corresponds to Rs (p[0]), with one fitted value per spectrum\n", | |
"popt[0, :]" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "Se4hyZu3CoPd", | |
"outputId": "7c0010fe-aceb-4769-ee72-8a0055d95462" | |
}, | |
"execution_count": 6, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"\n", | |
"Using initial\n", | |
"\n", | |
"Optimization complete\n", | |
"total time is 0:00:23.909449 " | |
] | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"DeviceArray([160.50028898, 160.65490235, 160.76659128, 160.7942771 ,\n", | |
" 160.78423369], dtype=float64)" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 6 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# An infinite smoothing factor (smf_inf[0] = inf) forces Rs to take the same value\n", | |
"# in every spectrum. Rs is still optimized rather than frozen: the output shows one\n", | |
"# shared value that differs from the initial guess p0[0].\n", | |
"eis = Multieis(p0, F, Y, bounds, smf_inf, redox, weight='modulus', immittance='admittance')\n", | |
"\n", | |
"popt, perr, chisqr, chitot, aic = eis.fit_simultaneous()\n", | |
"popt[0, :]" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "fCZeu_7xDPf1", | |
"outputId": "ee03857c-7374-4d38-90db-0f06e5ba41c8" | |
}, | |
"execution_count": 7, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"\n", | |
"Using initial\n", | |
"\n", | |
"Optimization complete\n", | |
"total time is 0:00:09.431290 " | |
] | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"DeviceArray([160.69970283, 160.69970283, 160.69970283, 160.69970283,\n", | |
" 160.69970283], dtype=float64)" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 7 | |
} | |
] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment