-
-
Save leandrolcampos/118e7b3e129d4a4f32f366d366ecee7d to your computer and use it in GitHub Desktop.
student_t_f32_tiny_t.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"provenance": [], | |
"collapsed_sections": [], | |
"authorship_tag": "ABX9TyOZ4zMPEvVV54LFbi4AYFlG", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/leandrolcampos/118e7b3e129d4a4f32f366d366ecee7d/student_t_f32_tiny_t.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"id": "tl_dwv6lCTMU" | |
}, | |
"outputs": [], | |
"source": [ | |
"# Installs/upgrades the nightly builds of TensorFlow and TensorFlow\n", | |
"# Probability. The outputs saved in this notebook report\n", | |
"# tf 2.12.0-dev20221019 and tfp 0.19.0-dev20221019.\n", | |
"%pip install -U -q tf-nightly tfp-nightly" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"import functools\n", | |
"import itertools\n", | |
"\n", | |
"import matplotlib.pyplot as plt\n", | |
"import numpy as np\n", | |
"import pandas as pd\n", | |
"import scipy.stats as sp_stats\n", | |
"import scipy.special as sp_special\n", | |
"import scipy.integrate as sp_integrate\n", | |
"import tensorflow as tf\n", | |
"import tensorflow_probability as tfp\n", | |
"\n", | |
"from mpmath import mp\n", | |
"from tensorflow_probability.python.distributions import student_t\n", | |
"from tensorflow_probability.python.internal import custom_gradient as tfp_custom_gradient\n", | |
"from tensorflow_probability.python.internal import dtype_util\n", | |
"from tensorflow_probability.python.internal import prefer_static as ps\n", | |
"from tensorflow_probability.python.internal import special_math\n", | |
"from tensorflow_probability.python.internal import test_util\n", | |
"from tensorflow_probability.python.math import generic\n", | |
"from tensorflow_probability.python.math import gradient\n", | |
"from tensorflow_probability.python.math import special\n", | |
"from tensorflow_probability.python.math.numeric import log1psquare" | |
], | |
"metadata": { | |
"id": "0FJ_5e6nCa66" | |
}, | |
"execution_count": 2, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"tf.__version__" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 35 | |
}, | |
"id": "n1v3xnEuTeo9", | |
"outputId": "106418fb-7424-426c-9aff-0968f1a90d93" | |
}, | |
"execution_count": 3, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"'2.12.0-dev20221019'" | |
], | |
"application/vnd.google.colaboratory.intrinsic+json": { | |
"type": "string" | |
} | |
}, | |
"metadata": {}, | |
"execution_count": 3 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"tfp.__version__" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 35 | |
}, | |
"id": "nOTvfuTlT4iS", | |
"outputId": "7c0bc601-a970-45b2-c576-548fd11f04f5" | |
}, | |
"execution_count": 4, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"'0.19.0-dev20221019'" | |
], | |
"application/vnd.google.colaboratory.intrinsic+json": { | |
"type": "string" | |
} | |
}, | |
"metadata": {}, | |
"execution_count": 4 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"mp.dps = 25; mp.pretty = True" | |
], | |
"metadata": { | |
"id": "OUE6yZKHi1xh" | |
}, | |
"execution_count": 5, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# 1. Implementations" | |
], | |
"metadata": { | |
"id": "JmBhtIFvSleU" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"def log_prob(x, df, loc, scale):\n", | |
" \"\"\"Adapted from tfp.distributions.student_t.log_prob.\"\"\"\n", | |
" dtype = dtype_util.common_dtype([x, df, loc, scale], tf.float32)\n", | |
" numpy_dtype = dtype_util.as_numpy_dtype(dtype)\n", | |
" half = numpy_dtype(0.5)\n", | |
"\n", | |
" x, df, loc, scale = [\n", | |
" tf.convert_to_tensor(param, dtype=dtype)\n", | |
" for param in (x, df, loc, scale)]\n", | |
"\n", | |
" # Writing `y` this way reduces XLA mem copies.\n", | |
" y = (x - loc) * (tf.math.rsqrt(df) / scale)\n", | |
" log_unnormalized_prob = -half * (df + numpy_dtype(1.)) * log1psquare(y)\n", | |
" log_normalization = (\n", | |
" tf.math.log(tf.abs(scale)) + half * tf.math.log(df) +\n", | |
" special.lbeta(half, half * df))\n", | |
" return log_unnormalized_prob - log_normalization" | |
], | |
"metadata": { | |
"id": "ou_FdK019YDV" | |
}, | |
"execution_count": 6, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"@tf.function(autograph=False)\n", | |
"def current_cdf(df, t):\n", | |
" \"\"\"Adapted from tfp.distributions.student_t.cdf. Used for comparison.\"\"\"\n", | |
" dtype = dtype_util.common_dtype([df, t], tf.float32)\n", | |
" numpy_dtype = dtype_util.as_numpy_dtype(dtype)\n", | |
" half = numpy_dtype(0.5)\n", | |
" one = numpy_dtype(1.)\n", | |
"\n", | |
" df, t = [tf.convert_to_tensor(param, dtype=dtype) for param in (df, t)]\n", | |
" broadcast_shape = functools.reduce(\n", | |
" ps.broadcast_shape, [ps.shape(df), ps.shape(t)])\n", | |
" df, t = [tf.broadcast_to(param, broadcast_shape) for param in (df, t)]\n", | |
"\n", | |
" x_t = df / (tf.math.square(t) + df)\n", | |
" neg_cdf = half * special.betainc(half * df, half, x_t)\n", | |
" return tf.where(t < 0., neg_cdf, one - neg_cdf)" | |
], | |
"metadata": { | |
"id": "B_buX20TS28E" | |
}, | |
"execution_count": 7, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"@tf.function(autograph=False)\n", | |
"def current_quantile(df, p):\n", | |
" \"\"\"Adapted from tfp.distributions.student_t.quantile. Used for comparison.\"\"\"\n", | |
" dtype = dtype_util.common_dtype([df, p], tf.float32)\n", | |
" numpy_dtype = dtype_util.as_numpy_dtype(dtype)\n", | |
" half = numpy_dtype(0.5)\n", | |
" one = numpy_dtype(1.)\n", | |
" two = numpy_dtype(2.)\n", | |
"\n", | |
" df, p = [tf.convert_to_tensor(param, dtype=dtype) for param in (df, p)]\n", | |
" broadcast_shape = functools.reduce(\n", | |
" ps.broadcast_shape, [ps.shape(df), ps.shape(p)])\n", | |
" df, p = [tf.broadcast_to(param, broadcast_shape) for param in (df, p)]\n", | |
"\n", | |
" p_adjusted = tf.where(p < half, p, one - p)\n", | |
" x = special.betaincinv(half * df, half, two * p_adjusted)\n", | |
"\n", | |
" abs_t = tf.math.exp(\n", | |
" tf.math.xlogy(half, df) + tf.math.xlog1py(half, -x) -\n", | |
" tf.math.xlogy(half, x))\n", | |
"\n", | |
" return tf.math.sign(p - half) * abs_t" | |
], | |
"metadata": { | |
"id": "wy-Nt83t1tux" | |
}, | |
"execution_count": 8, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# The implementations of the Student's t-distribution cumulative distribution\n", | |
"# function and its inverse, respectively stdtr(df, t) and stdtrit(df, p), are\n", | |
"# based on ideas and equations available in the following references:\n", | |
"# [1] Geoffrey W. Hill\n", | |
"# Algorithm 395: Student's t-distribution\n", | |
"# Communications of the ACM, v. 13, n. 10, p. 617-619, 1970\n", | |
"# https://doi.org/10.1145/355598.362775\n", | |
"# [2] Geoffrey W. Hill\n", | |
"# Remark on \"Algorithm 395: Student's t-Distribution [S14]\"\n", | |
"# ACM Transactions on Mathematical Software, v. 7, n. 2, p. 247-249, 1981\n", | |
"# https://doi.org/10.1145/355945.355955\n", | |
"# [3] William Press, Saul Teukolsky, William Vetterling and Brian Flannery\n", | |
"# Numerical Recipes: The Art of Scientific Computing\n", | |
"# Cambridge University Press, 2007 (Third Edition)\n", | |
"# http://numerical.recipes/book/book.html\n", | |
"# [4] Geoffrey W. Hill\n", | |
"# Algorithm 396: Student's t-quantiles\n", | |
"# Communications of the ACM, v. 13, n. 10, p. 619-620, 1970\n", | |
"# https://doi.org/10.1145/355598.355600\n", | |
"# [5] Geoffrey W. Hill\n", | |
"# Remark on \"Algorithm 396: Student's t-Quantiles [S14]\"\n", | |
"# ACM Transactions on Mathematical Software, v. 7, n. 2, p. 250-251, 1981\n", | |
"# https://doi.org/10.1145/355945.355956\n", | |
"# [6] R Core Team, R Foundation and Ross Ihaka\n", | |
"# Mathlib: A C Library of Special Functions\n", | |
"# https://svn.r-project.org/R/tags/R-4-2-1/src/nmath/qt.c" | |
], | |
"metadata": { | |
"id": "lYV8zBZxkbF3" | |
}, | |
"execution_count": 9, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"def _stdtr_asymptotic_expansion(df, t, numpy_dtype):\n", | |
" \"\"\"Computes `stdtr(df, t)` using asymptotic expansion.\"\"\"\n", | |
" # This function provides a fast approximation of stdtr(df, t) for large value\n", | |
" # of df. It is based on an asymptotic normalizing expansion of Cornish-Fisher\n", | |
" # type [1, 2].\n", | |
" one = numpy_dtype(1.)\n", | |
" two = numpy_dtype(2.)\n", | |
"\n", | |
" coeffs1 = [\n", | |
" 1.00000000000000000000E+0, 3.00000000000000000000E+0]\n", | |
"\n", | |
" coeffs2 = [\n", | |
" 4.00000000000000022204E-1, 3.29999999999999982236E+0,\n", | |
" 2.40000000000000000000E+1, 8.55000000000000000000E+1]\n", | |
"\n", | |
" coeffs3 = [\n", | |
" 3.04761904761904789396E-1, 3.75238095238095237249E+0,\n", | |
" 4.66714285714285708195E+1, 4.27500000000000000000E+2,\n", | |
" 2.58750000000000000000E+3, 8.51850000000000000000E+3]\n", | |
"\n", | |
" coeffs4 = [\n", | |
" 2.74285714285714299354E-1, 4.49904761904761940627E+0,\n", | |
" 7.84514285714285648510E+1, 1.11871071428571417528E+3,\n", | |
" 1.23876000000000003638E+4, 1.01024550000000002910E+5,\n", | |
" 5.59494000000000000000E+5, 1.76495962500000000000E+6]\n", | |
"\n", | |
" coeffs5 = [\n", | |
" 2.65974025974025973795E-1, 5.44969696969696926203E+0,\n", | |
" 1.22202943722943729199E+2, 2.35472987012987005073E+3,\n", | |
" 3.76250090259740245529E+4, 4.86996139285714307334E+5,\n", | |
" 4.96087065000000037253E+6, 3.79785955499999970198E+7,\n", | |
" 2.01505390875000000000E+8, 6.22437908625000000000E+8]\n", | |
"\n", | |
" terms_coeffs = [\n", | |
" [numpy_dtype(c) for c in coeffs]\n", | |
" for coeffs in (coeffs1, coeffs2, coeffs3, coeffs4, coeffs5)]\n", | |
"\n", | |
" df_minus_half = df - numpy_dtype(0.5)\n", | |
" squared_z = df_minus_half * log1psquare(t * tf.math.rsqrt(df))\n", | |
" z = tf.math.sqrt(squared_z)\n", | |
" # To avoid overflow when df is huge, we manipulate b and the denominator of\n", | |
" # each term of the expansion in the logarithmic space.\n", | |
" log_b = tf.math.log(numpy_dtype(48.)) + tf.math.xlogy(two, df_minus_half)\n", | |
"\n", | |
" term_sign = one\n", | |
" log_term_denominator = numpy_dtype(0.)\n", | |
" # We initialize the series with its first term.\n", | |
" series_sum = z\n", | |
" last_index = len(terms_coeffs) - 1\n", | |
"\n", | |
" # We evaluate the next five terms using a procedure based on Horner's method.\n", | |
" for index, coeffs in enumerate(terms_coeffs):\n", | |
" if index < last_index:\n", | |
" log_term_denominator = log_term_denominator + log_b\n", | |
" else:\n", | |
" log_term_denominator = log_term_denominator + tf.math.log(\n", | |
" (numpy_dtype(0.43595) * squared_z + two) * squared_z +\n", | |
" tf.math.exp(log_b) + numpy_dtype(537.))\n", | |
"\n", | |
" term_numerator = coeffs[0]\n", | |
" for c in coeffs[1:]:\n", | |
" term_numerator = c + term_numerator * squared_z\n", | |
"\n", | |
" term_numerator = term_numerator * z\n", | |
" series_sum = series_sum + term_sign * tf.math.exp(\n", | |
" tf.math.log(term_numerator) - log_term_denominator)\n", | |
" term_sign = -one * term_sign\n", | |
"\n", | |
" return special_math.ndtr(tf.math.sign(t) * series_sum)" | |
], | |
"metadata": { | |
"id": "-jgH__lAk98H" | |
}, | |
"execution_count": 10, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"def _stdtr_computation(df, t):\n", | |
" \"\"\"Computes Student's t-distribution cumulative distribution function.\"\"\"\n", | |
" dtype = dtype_util.common_dtype([df, t], tf.float32)\n", | |
" numpy_dtype = dtype_util.as_numpy_dtype(dtype)\n", | |
" half = numpy_dtype(0.5)\n", | |
" one = numpy_dtype(1.)\n", | |
"\n", | |
" df, t = [tf.convert_to_tensor(param, dtype=dtype) for param in (df, t)]\n", | |
" broadcast_shape = functools.reduce(\n", | |
" ps.broadcast_shape, [ps.shape(df), ps.shape(t)])\n", | |
" df, t = [tf.broadcast_to(param, broadcast_shape) for param in (df, t)]\n", | |
"\n", | |
" # For moderate df and relatively small t**2, or in case of large df, we use\n", | |
" # asymptotic expansion [1, 2] to compute stdtr(df, t). The condition to use\n", | |
" # it was specified by experimentation for np.float32 and was taken from [2,\n", | |
" # page 249] for np.float64.\n", | |
"\n", | |
" if numpy_dtype == np.float32:\n", | |
" use_asymptotic_expansion = (\n", | |
" (df >= 10.) & (tf.math.square(t) < (16. * df - 5.))) | (df > 30.)\n", | |
" else:\n", | |
" use_asymptotic_expansion = (\n", | |
" (df >= 100.) & (tf.math.square(t) < (0.1 * df - 5.))) | (df > 1000.)\n", | |
"\n", | |
" result = _stdtr_asymptotic_expansion(df, t, numpy_dtype)\n", | |
"\n", | |
" # Otherwise, we evaluate stdtr(df, t) using the regularized incomplete beta\n", | |
" # function [3, page 323, equation 6.14.10]:\n", | |
" # stdtr(df, t) =\n", | |
" # 0.5 * betainc(0.5 * df, 0.5, df / (df + t**2)), when t < 0\n", | |
" # 1. - 0.5 * betainc(0.5 * df, 0.5, df / (df + t**2)), when t >= 0\n", | |
" # To avoid rounding error when t**2 is small compared to df, we compute the\n", | |
" # ratio df / (df + t**2) in the logarithmic space. If ratio > 0.99, we then\n", | |
" # use the following symmetry relation:\n", | |
" # betainc(a, b, x) = 1 - betainc(b, a, 1 - x) .\n", | |
"\n", | |
" raw_ratio = t * tf.math.rsqrt(df)\n", | |
" ratio = tf.math.exp(-log1psquare(raw_ratio))\n", | |
" one_minus_ratio = tf.math.exp(-log1psquare(tf.math.reciprocal(raw_ratio)))\n", | |
"\n", | |
" # The maximum value for the ratio was set by experimentation.\n", | |
" use_symmetry_relation = (ratio > 0.99)\n", | |
" half_df = half * df\n", | |
" a = tf.where(use_symmetry_relation, half, half_df)\n", | |
" b = tf.where(use_symmetry_relation, half_df, half)\n", | |
" x = tf.where(use_symmetry_relation, one_minus_ratio, ratio)\n", | |
"\n", | |
" y = special.betainc(a, b, x)\n", | |
" neg_cdf = half * tf.where(use_symmetry_relation, one - y, y)\n", | |
" result_betainc = tf.where(t < 0., neg_cdf, one - neg_cdf)\n", | |
"\n", | |
" result = tf.where(use_asymptotic_expansion, result, result_betainc)\n", | |
"\n", | |
" # Determine if df is out of range (should return NaN output).\n", | |
" result_is_nan = (df <= 0.)\n", | |
" result = tf.where(result_is_nan, numpy_dtype(np.nan), result)\n", | |
"\n", | |
" return result" | |
], | |
"metadata": { | |
"id": "7CTjxBmTv_Qn" | |
}, | |
"execution_count": 11, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"def _stdtr_partials(df, t):\n", | |
" \"\"\"Returns the partial derivatives of `stdtr(df, t)`.\"\"\"\n", | |
" dtype = dtype_util.common_dtype([df, t], tf.float32)\n", | |
" numpy_dtype = dtype_util.as_numpy_dtype(dtype)\n", | |
" tiny = np.finfo(numpy_dtype).tiny\n", | |
" eps = np.finfo(numpy_dtype).eps\n", | |
" half = numpy_dtype(0.5)\n", | |
" one = numpy_dtype(1.)\n", | |
"\n", | |
" df, t = [tf.convert_to_tensor(param, dtype=dtype) for param in (df, t)]\n", | |
" broadcast_shape = functools.reduce(\n", | |
" ps.broadcast_shape, [ps.shape(df), ps.shape(t)])\n", | |
" df, t = [tf.broadcast_to(param, broadcast_shape) for param in (df, t)]\n", | |
"\n", | |
" # The gradient with respect to t can be computed more accurately and stably\n", | |
" # using Student's t-distribution probability density function.\n", | |
"\n", | |
" stdtr_grad_t = tf.math.exp(log_prob(t, df, numpy_dtype(0.), one))\n", | |
"\n", | |
" # For moderate df and relatively small t**2, or in case of large df, we use\n", | |
" # automatic differentiation of the procedure _stdtr_asymptotic_expansion to\n", | |
" # compute the gradient with respect to df.\n", | |
"\n", | |
" if numpy_dtype == np.float32:\n", | |
" use_asymptotic_expansion = (\n", | |
" (df >= 10.) & (tf.math.square(t) < (16. * df - 5.))) | (df > 30.)\n", | |
" else:\n", | |
" use_asymptotic_expansion = (\n", | |
" (df >= 100.) & (tf.math.square(t) < (0.1 * df - 5.))) | (df > 1000.)\n", | |
"\n", | |
" abs_t = tf.math.abs(t)\n", | |
" min_abs_t = half * df * tf.math.pow(tiny, numpy_dtype(0.25))\n", | |
" t_is_tiny = (abs_t < min_abs_t)\n", | |
" df_asymptotic_expansion = tf.where(use_asymptotic_expansion, df, one)\n", | |
" t_asymptotic_expansion = tf.where(\n", | |
" use_asymptotic_expansion,\n", | |
" # Mask out tiny t so the gradient correctly propagates. When t is tiny,\n", | |
" # second derivative of log1psquare(t * tf.math.rsqrt(df)) can be NaN.\n", | |
" tf.where(t_is_tiny, tf.where(t < 0., -min_abs_t, min_abs_t), t),\n", | |
" one)\n", | |
"\n", | |
" stdtr_grad_df = gradient.value_and_gradient(\n", | |
" lambda z: _stdtr_asymptotic_expansion(\n", | |
" z, t_asymptotic_expansion, numpy_dtype),\n", | |
" df_asymptotic_expansion)[1]\n", | |
" # Handle the case (abs_t < min_abs_t): we use a rough linear approximation.\n", | |
" stdtr_grad_df = tf.where(t_is_tiny, stdtr_grad_df * abs_t, stdtr_grad_df)\n", | |
"\n", | |
" # Otherwise, the gradient with respect to df is evaluated using the partial\n", | |
" # derivatives of betainc(a, b, x). For t < 0, we have:\n", | |
" # stdtr_grad_df = 0.5 * (betainc_grad_a * a_grad_df +\n", | |
" # betainc_grad_x * x_grad_df)\n", | |
" # = 0.5 * (betainc_grad_a * a_grad_df +\n", | |
" # 2. * stdtr_grad_t / x_grad_t * x_grad_df)\n", | |
" # = 0.5 * (betainc_grad_a * a_grad_df -\n", | |
" # stdtr_grad_t * t / df) ,\n", | |
" # where a = 0.5 * df and x = df / (df + t**2). In above equation, the second\n", | |
" # equality follows from the fact that:\n", | |
" # stdtr_grad_t = 0.5 * betainc_grad_x * x_grad_t , for t < 0.\n", | |
" # To avoid rounding error when t**2 is small compared to df, we compute the\n", | |
" # ratio df / (df + t**2) in the logarithmic space. If ratio > 0.99, we then\n", | |
" # use the following symmetry relation:\n", | |
" # betainc(a, b, x) = 1 - betainc(b, a, 1 - x) .\n", | |
"\n", | |
" min_abs_t_betainc = tf.math.sqrt(df * eps * tf.math.reciprocal(one - eps))\n", | |
" abs_t_is_too_small = (abs_t < min_abs_t_betainc)\n", | |
" # Mask out too small abs(t) so the gradient correctly propagates. When abs(t)\n", | |
" # is too small, ratio == 1 and one_minus_ratio < eps.\n", | |
" abs_t_betainc = tf.where(abs_t_is_too_small, min_abs_t_betainc, abs_t)\n", | |
"\n", | |
" raw_ratio = abs_t_betainc * tf.math.rsqrt(df)\n", | |
" ratio = tf.math.exp(-log1psquare(raw_ratio))\n", | |
" one_minus_ratio = tf.math.exp(-log1psquare(tf.math.reciprocal(raw_ratio)))\n", | |
"\n", | |
" # The maximum value for the ratio was set by experimentation.\n", | |
" use_symmetry_relation = (ratio > 0.99)\n", | |
" half_df = half * df\n", | |
" a = tf.where(use_symmetry_relation, half, half_df)\n", | |
" b = tf.where(use_symmetry_relation, half_df, half)\n", | |
" x = tf.where(use_symmetry_relation, one_minus_ratio, ratio)\n", | |
"\n", | |
" # Prepare betainc inputs to make the evaluation of its gradients easier.\n", | |
" use_betainc = ~use_asymptotic_expansion\n", | |
" a = tf.where(use_betainc, a, half)\n", | |
" b = tf.where(use_betainc, b, half)\n", | |
" x = tf.where(use_betainc, x, half)\n", | |
"\n", | |
" betainc_grad_a, betainc_grad_b = gradient.value_and_gradient(\n", | |
" lambda y, z: special.betainc(y, z, x), [a, b])[1]\n", | |
" betainc_grad_a = tf.where(\n", | |
" use_symmetry_relation, -betainc_grad_b, betainc_grad_a)\n", | |
"\n", | |
" stdtr_grad_df_betainc = half * (\n", | |
" betainc_grad_a * half + stdtr_grad_t * abs_t / df)\n", | |
" # Handle the case (t >= 0).\n", | |
" stdtr_grad_df_betainc = tf.where(\n", | |
" t >= 0., -stdtr_grad_df_betainc, stdtr_grad_df_betainc)\n", | |
" # Handle the case (abs_t < min_abs_t_betainc): we use again a rough linear\n", | |
" # approximation.\n", | |
" stdtr_grad_df_betainc = tf.where(\n", | |
" abs_t_is_too_small, stdtr_grad_df_betainc * abs_t, stdtr_grad_df_betainc)\n", | |
"\n", | |
" stdtr_grad_df = tf.where(\n", | |
" use_asymptotic_expansion, stdtr_grad_df, stdtr_grad_df_betainc)\n", | |
"\n", | |
" # Determine if df is out of range (should return NaN output).\n", | |
" result_is_nan = (df <= 0.)\n", | |
" stdtr_grad_df, stdtr_grad_t = [\n", | |
" tf.where(result_is_nan, numpy_dtype(np.nan), grad)\n", | |
" for grad in [stdtr_grad_df, stdtr_grad_t]]\n", | |
"\n", | |
" return stdtr_grad_df, stdtr_grad_t" | |
], | |
"metadata": { | |
"id": "wD4LbJTVgJkK" | |
}, | |
"execution_count": 12, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"def _stdtr_fwd(df, t):\n", | |
" \"\"\"Computes output, aux (collaborates with _stdtr_bwd).\"\"\"\n", | |
" output = _stdtr_computation(df, t)\n", | |
" return output, (df, t)" | |
], | |
"metadata": { | |
"id": "Mr1Q1WfGdt0q" | |
}, | |
"execution_count": 13, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"def _stdtr_bwd(aux, g):\n", | |
" \"\"\"Reverse mode impl for stdtr.\"\"\"\n", | |
" df, t = aux\n", | |
" partial_df, partial_t = _stdtr_partials(df, t)\n", | |
" return generic.fix_gradient_for_broadcasting(\n", | |
" [df, t], [partial_df * g, partial_t * g])" | |
], | |
"metadata": { | |
"id": "ZuZs6ZnseE8k" | |
}, | |
"execution_count": 14, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"def _stdtr_jvp(primals, tangents):\n", | |
" \"\"\"Computes JVP for stdtr (supports JAX custom derivative).\"\"\"\n", | |
" df, t = primals\n", | |
" ddf, dt = tangents\n", | |
"\n", | |
" p = _stdtr_custom_gradient(df, t)\n", | |
" partial_df, partial_t = _stdtr_partials(df, t)\n", | |
" return (p, partial_df * ddf + partial_t * dt)" | |
], | |
"metadata": { | |
"id": "3tbl6WjFaD8N" | |
}, | |
"execution_count": 15, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"@tfp_custom_gradient.custom_gradient(\n", | |
" vjp_fwd=_stdtr_fwd,\n", | |
" vjp_bwd=_stdtr_bwd,\n", | |
" jvp_fn=_stdtr_jvp)\n", | |
"def _stdtr_custom_gradient(df, t):\n", | |
" \"\"\"Computes `stdtr(df, t)` with correct custom gradient.\"\"\"\n", | |
" return _stdtr_computation(df, t)" | |
], | |
"metadata": { | |
"id": "N3PUwyQiaDyT" | |
}, | |
"execution_count": 16, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"def stdtr(df, t, name=None):\n", | |
" \"\"\"Computes the cumulative distribution function of Student's t-distribution.\n", | |
"\n", | |
" Returns the area under the probability density function of this distribution\n", | |
" with `df > 0` degrees of freedom, integrated from minus infinity to `t`.\n", | |
"\n", | |
" Args:\n", | |
" df: A `Tensor`. Must be one of the following types: `float32`, `float64`.\n", | |
" t: A `Tensor`. Must have the same type as `df`.\n", | |
" name: A name for the operation (optional).\n", | |
"\n", | |
" Returns:\n", | |
" A `Tensor` with shape broadcast according to the arguments.\n", | |
"\n", | |
" Raises:\n", | |
" TypeError: if `df` is not one of the following types: `float32`, `float64`.\n", | |
" \"\"\"\n", | |
" with tf.name_scope(name or 'stdtr'):\n", | |
" dtype = dtype_util.common_dtype([df, t], tf.float32)\n", | |
" df = tf.convert_to_tensor(df, dtype=dtype)\n", | |
" t = tf.convert_to_tensor(t, dtype=dtype)\n", | |
"\n", | |
" if dtype_util.as_numpy_dtype(dtype) not in [np.float32, np.float64]:\n", | |
" raise TypeError(f'df.dtype={dtype} is not handled. '\n", | |
" 'See docstring for supported types.')\n", | |
"\n", | |
" return _stdtr_custom_gradient(df, t)" | |
], | |
"metadata": { | |
"id": "5UzA9QTkcFng" | |
}, | |
"execution_count": 17, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Wrap `stdtr` with tf.function (AutoGraph disabled). This rebinds the name,\n", | |
"# so every later use of `stdtr` goes through the wrapped callable.\n", | |
"stdtr = tf.function(stdtr, autograph=False)" | |
], | |
"metadata": { | |
"id": "x2U2tpEvv7I9" | |
}, | |
"execution_count": 18, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"def _stdtrit_betaincinv(df, p, numpy_dtype, use_betaincinv):\n", | |
" \"\"\"Computes `stdtrit(df, p)` using special.betaincinv.\"\"\"\n", | |
" # This function inverts the procedure that computes stdtr(df, t) using the\n", | |
" # regularized incomplete beta function. For details on this procedure, see\n", | |
" # the function _stdtr_computation.\n", | |
" # We assume here that condition (p <= 0.5) is always true.\n", | |
" half = numpy_dtype(0.5)\n", | |
" one = numpy_dtype(1.)\n", | |
"\n", | |
" half_df = half * df\n", | |
" two_p = numpy_dtype(2.) * p\n", | |
"\n", | |
" use_symmetry_relation = (\n", | |
" p > (half * special.betainc(half_df, half, numpy_dtype(0.99))))\n", | |
" a = tf.where(use_symmetry_relation, half, half_df)\n", | |
" b = tf.where(use_symmetry_relation, half_df, half)\n", | |
" y = tf.where(use_symmetry_relation, one - two_p, two_p)\n", | |
"\n", | |
" # Prepare betaincinv inputs to make its evaluation easier.\n", | |
" a = tf.where(use_betaincinv, a, half)\n", | |
" b = tf.where(use_betaincinv, b, half)\n", | |
" y = tf.where(use_betaincinv, y, numpy_dtype(0.))\n", | |
"\n", | |
" x = special.betaincinv(a, b, y)\n", | |
"\n", | |
" log_abs_t = half * (\n", | |
" tf.math.log(df) + tf.where(use_symmetry_relation, one, -one) * (\n", | |
" tf.math.log(x) - tf.math.log1p(-x)))\n", | |
"\n", | |
" return -tf.math.exp(log_abs_t)" | |
], | |
"metadata": { | |
"id": "m6-8-hFlgA9z" | |
}, | |
"execution_count": 19, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"def _stdtrit_series_expansion(df, p, numpy_dtype):\n", | |
" \"\"\"Computes `stdtrit(df, p)` using series expansion.\"\"\"\n", | |
" # This function provides a fast approximation of stdtrit(df, p) for df >= 1.\n", | |
" # It is based on an asymptotic inverse expansion of Cornish-Fisher type about\n", | |
" # normal deviates. But for small p, where t**2 / df is large, a second series\n", | |
" # expansion is used to achieve sufficient accuracy. Both approximations were\n", | |
" # proposed in [4].\n", | |
" # We assume here that condition (p <= 0.5) is always true.\n", | |
" half, one, two, three, four, five, six, seven = [\n", | |
" numpy_dtype(n) for n in (0.5,) + tuple(range(1, 8))]\n", | |
"\n", | |
" a = tf.math.reciprocal(df - half)\n", | |
" b = numpy_dtype(48.) / tf.math.square(a)\n", | |
" c = numpy_dtype(96.36) + a * (\n", | |
" (numpy_dtype(20700.) * a / b - numpy_dtype(98.)) * a - numpy_dtype(16.))\n", | |
" d = df * tf.math.sqrt(a * half * numpy_dtype(np.pi)) * (\n", | |
" (numpy_dtype(94.5) / (b + c) - three) / b + one)\n", | |
"\n", | |
" # First series expansion: asymptotic inverse expansion about normal deviates.\n", | |
" z = tf.math.ndtri(p)\n", | |
" squared_z = tf.math.square(z)\n", | |
" c = b + c + z * (\n", | |
" ((numpy_dtype(0.05) * d * z - five) * z - seven) * z - two)\n", | |
" c_correction = numpy_dtype(0.3) * (\n", | |
" df - numpy_dtype(4.5)) * (z + numpy_dtype(0.6))\n", | |
" c = tf.where(df >= 5., c, c + c_correction)\n", | |
"\n", | |
" squared_t_over_df = numpy_dtype(0.4) * squared_z + numpy_dtype(6.3)\n", | |
" squared_t_over_df = squared_t_over_df * squared_z + numpy_dtype(36.)\n", | |
" squared_t_over_df = squared_t_over_df * squared_z + numpy_dtype(94.5)\n", | |
" squared_t_over_df = z * (\n", | |
" (squared_t_over_df / c - squared_z - three) / b + one)\n", | |
" squared_t_over_df = tf.math.expm1(a * tf.math.square(squared_t_over_df))\n", | |
"\n", | |
" # Second series expansion.\n", | |
" y = tf.math.exp(two / df * (\n", | |
" tf.math.log(d) + tf.math.log(two) + tf.math.log(p)))\n", | |
"\n", | |
" df_plus_2 = df + two\n", | |
" large_squared_t_over_df = (df + six) / (df * y) - numpy_dtype(0.089) * d\n", | |
" large_squared_t_over_df = df_plus_2 * three * (\n", | |
" large_squared_t_over_df - numpy_dtype(0.822))\n", | |
" large_squared_t_over_df = large_squared_t_over_df + half / (df + four)\n", | |
" large_squared_t_over_df = y / large_squared_t_over_df - one\n", | |
" large_squared_t_over_df = tf.math.reciprocal(y) + large_squared_t_over_df * (\n", | |
" (df + one) / df_plus_2)\n", | |
"\n", | |
" p_is_not_small = (y >= (numpy_dtype(0.05) + a))\n", | |
" # The condition to use the first series expansion was improved in [6].\n", | |
" use_first_series_expansion = p_is_not_small | ((df < 2.1) & (p > 0.25))\n", | |
" squared_t_over_df = tf.where(\n", | |
" use_first_series_expansion, squared_t_over_df, large_squared_t_over_df)\n", | |
"\n", | |
" return -tf.math.sqrt(df * squared_t_over_df)" | |
], | |
"metadata": { | |
"id": "lNls1EHFKqz6" | |
}, | |
"execution_count": 20, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"def _stdtrit_computation(df, p):\n", | |
" \"\"\"Returns the inverse of `stdtr(df, t)` with respect to `t`.\"\"\"\n", | |
" # This function increases the accuracy of an initial estimate for t by using\n", | |
" # Taylor series expansion iterations as proposed in [5].\n", | |
" dtype = dtype_util.common_dtype([df, p], tf.float32)\n", | |
" numpy_dtype = dtype_util.as_numpy_dtype(dtype)\n", | |
" zero = numpy_dtype(0.)\n", | |
" eps = np.finfo(numpy_dtype).eps\n", | |
" half = numpy_dtype(0.5)\n", | |
" one = numpy_dtype(1.)\n", | |
"\n", | |
" df, p = [tf.convert_to_tensor(param, dtype=dtype) for param in (df, p)]\n", | |
" broadcast_shape = functools.reduce(\n", | |
" ps.broadcast_shape, [ps.shape(df), ps.shape(p)])\n", | |
" df, p = [tf.broadcast_to(param, broadcast_shape) for param in (df, p)]\n", | |
"\n", | |
" # max_iterations, use_betaincinv, and tolerance were set by experimentation.\n", | |
" max_iterations = 3\n", | |
" if numpy_dtype == np.float32:\n", | |
" use_betaincinv = (df < 2.)\n", | |
" tolerance = numpy_dtype(8.) * eps\n", | |
" else:\n", | |
" use_betaincinv = (df < 1.)\n", | |
" tolerance = numpy_dtype(4096.) * eps\n", | |
"\n", | |
" adjusted_p = tf.where(p < 0.5, p, one - p)\n", | |
" initial_candidate = tf.where(\n", | |
" use_betaincinv,\n", | |
" # Since _stdtrit_betaincinv is expensive, we pass use_betaincinv to it\n", | |
" # to save computation.\n", | |
" _stdtrit_betaincinv(df, adjusted_p, numpy_dtype, use_betaincinv),\n", | |
" _stdtrit_series_expansion(df, adjusted_p, numpy_dtype))\n", | |
"\n", | |
" def taylor_expansion_improvement(should_stop, candidate):\n", | |
" stdtr_grad_t = tf.math.exp(log_prob(candidate, df, zero, one))\n", | |
" should_stop = should_stop | tf.math.equal(stdtr_grad_t, zero)\n", | |
"\n", | |
" first_order_correction = (adjusted_p - stdtr(df, candidate)) / stdtr_grad_t\n", | |
"\n", | |
" candidate_is_zero = tf.math.equal(candidate, zero)\n", | |
" safe_inv_candidate = tf.where(\n", | |
" candidate_is_zero, one, tf.math.reciprocal(candidate))\n", | |
" second_order_correction = half * (df + one) * tf.math.square(\n", | |
" first_order_correction) * safe_inv_candidate * tf.math.reciprocal(\n", | |
" one + (df * safe_inv_candidate) * safe_inv_candidate)\n", | |
" second_order_correction = tf.where(\n", | |
" candidate_is_zero, zero, second_order_correction)\n", | |
"\n", | |
" correction = first_order_correction + second_order_correction\n", | |
" new_candidate = tf.where(should_stop, candidate, candidate + correction)\n", | |
"\n", | |
" adjusted_tolerance = tf.math.abs(tolerance * new_candidate)\n", | |
" should_stop = should_stop | (tf.math.abs(correction) <= adjusted_tolerance)\n", | |
"\n", | |
" return should_stop, new_candidate\n", | |
"\n", | |
" (_, result) = tf.while_loop(\n", | |
" cond=lambda stop, _: tf.reduce_any(~stop),\n", | |
" body=taylor_expansion_improvement,\n", | |
" loop_vars=(\n", | |
" ~tf.math.is_finite(initial_candidate),\n", | |
" initial_candidate),\n", | |
" maximum_iterations=max_iterations)\n", | |
"\n", | |
" # Handle the case (p >= 0.5).\n", | |
" result = tf.math.sign(half - p) * result\n", | |
"\n", | |
" # Determine if the inputs are out of range (should return NaN output).\n", | |
" result_is_nan = (p <= zero) | (p >= one) | (df <= zero)\n", | |
" result = tf.where(result_is_nan, numpy_dtype(np.nan), result)\n", | |
"\n", | |
" return result" | |
], | |
"metadata": { | |
"id": "MRWxtNSwqvm7" | |
}, | |
"execution_count": 21, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"def _stdtrit_partials(df, p, return_value=False):\n", | |
"  \"\"\"Computes the partial derivatives of `stdtrit(df, p)`.\n", | |
"\n", | |
"  Args:\n", | |
"    df: Degrees of freedom; converted to a `Tensor` of the common dtype.\n", | |
"    p: Probability; converted and broadcast against `df`.\n", | |
"    return_value: If `True`, `stdtrit(df, p)` itself is appended to the\n", | |
"      returned list.\n", | |
"\n", | |
"  Returns:\n", | |
"    A list `[partial_df, partial_p]`, or `[partial_df, partial_p, t]` when\n", | |
"    `return_value` is `True`. Entries are NaN wherever `p <= 0`, `p >= 1`,\n", | |
"    or `df <= 0`.\n", | |
"  \"\"\"\n", | |
"  dtype = dtype_util.common_dtype([df, p], tf.float32)\n", | |
"  numpy_dtype = dtype_util.as_numpy_dtype(dtype)\n", | |
"\n", | |
"  df = tf.convert_to_tensor(df, dtype=dtype)\n", | |
"  p = tf.convert_to_tensor(p, dtype=dtype)\n", | |
"  broadcast_shape = ps.broadcast_shape(ps.shape(df), ps.shape(p))\n", | |
"  df = tf.broadcast_to(df, broadcast_shape)\n", | |
"  p = tf.broadcast_to(p, broadcast_shape)\n", | |
"\n", | |
"  # stdtr and stdtrit are inverses of each other in `t`, so the partials of\n", | |
"  # stdtrit follow from those of stdtr evaluated at t = stdtrit(df, p).\n", | |
"  t = _stdtrit_custom_gradient(df, p)\n", | |
"  stdtr_partial_df, stdtr_partial_t = _stdtr_partials(df, t)\n", | |
"\n", | |
"  outputs = [\n", | |
"      -stdtr_partial_df / stdtr_partial_t,\n", | |
"      tf.math.reciprocal(stdtr_partial_t)]\n", | |
"  if return_value:\n", | |
"    outputs.append(t)\n", | |
"\n", | |
"  # Out-of-range inputs yield NaN in every output.\n", | |
"  out_of_range = (p <= 0.) | (p >= 1.) | (df <= 0.)\n", | |
"  nan = numpy_dtype(np.nan)\n", | |
"  return [tf.where(out_of_range, nan, output) for output in outputs]" | |
], | |
"metadata": { | |
"id": "VUQipfTj6xUI" | |
}, | |
"execution_count": 22, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"def _stdtrit_fwd(df, p):\n", | |
"  \"\"\"Forward pass of the custom VJP: returns the output and `(df, p)` as aux.\"\"\"\n", | |
"  return _stdtrit_computation(df, p), (df, p)" | |
], | |
"metadata": { | |
"id": "URcsbziz6DOw" | |
}, | |
"execution_count": 23, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"def _stdtrit_bwd(aux, g):\n", | |
"  \"\"\"Backward pass of the custom VJP for stdtrit.\n", | |
"\n", | |
"  Args:\n", | |
"    aux: The `(df, p)` pair saved by `_stdtrit_fwd`.\n", | |
"    g: Upstream gradient with respect to the output of stdtrit.\n", | |
"\n", | |
"  Returns:\n", | |
"    The gradients with respect to `df` and `p`, passed through\n", | |
"    `generic.fix_gradient_for_broadcasting`.\n", | |
"  \"\"\"\n", | |
"  df, p = aux\n", | |
"  partial_df, partial_p = _stdtrit_partials(df, p)[:2]\n", | |
"  grads = [partial_df * g, partial_p * g]\n", | |
"  return generic.fix_gradient_for_broadcasting([df, p], grads)" | |
], | |
"metadata": { | |
"id": "XiMd364A6DOx" | |
}, | |
"execution_count": 24, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"def _stdtrit_jvp(primals, tangents):\n", | |
"  \"\"\"Computes JVP for stdtrit (supports JAX custom derivative).\n", | |
"\n", | |
"  Args:\n", | |
"    primals: The pair `(df, p)` at which stdtrit is evaluated.\n", | |
"    tangents: The pair `(ddf, dp)` of tangents matching `primals`.\n", | |
"\n", | |
"  Returns:\n", | |
"    A tuple `(t, dt)` holding the value `stdtrit(df, p)` and its tangent.\n", | |
"  \"\"\"\n", | |
"  df, p = primals\n", | |
"  ddf, dp = tangents\n", | |
"\n", | |
"  # return_value=True is required here: without it _stdtrit_partials returns\n", | |
"  # only the two partials and the three-way unpacking below fails.\n", | |
"  partial_df, partial_p, t = _stdtrit_partials(df, p, return_value=True)\n", | |
"  return (t, partial_df * ddf + partial_p * dp)" | |
], | |
"metadata": { | |
"id": "OVeDoY3v6DOx" | |
}, | |
"execution_count": 25, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"@tfp_custom_gradient.custom_gradient(\n", | |
"    vjp_fwd=_stdtrit_fwd, vjp_bwd=_stdtrit_bwd, jvp_fn=_stdtrit_jvp)\n", | |
"def _stdtrit_custom_gradient(df, p):\n", | |
"  \"\"\"Evaluates `stdtrit(df, p)` with the custom VJP/JVP rules attached.\"\"\"\n", | |
"  return _stdtrit_computation(df, p)" | |
], | |
"metadata": { | |
"id": "mPQ4epPt6DOx" | |
}, | |
"execution_count": 26, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"def stdtrit(df, p, name=None):\n", | |
"  \"\"\"Inverts `stdtr` with respect to `t`.\n", | |
"\n", | |
"  The returned value `t` satisfies `p = stdtr(df, t)`.\n", | |
"\n", | |
"  Args:\n", | |
"    df: A `Tensor`. Must be one of the following types: `float32`, `float64`.\n", | |
"    p: A `Tensor`. Must have the same type as `df`.\n", | |
"    name: A name for the operation (optional).\n", | |
"\n", | |
"  Returns:\n", | |
"    A `Tensor` with shape broadcast according to the arguments.\n", | |
"\n", | |
"  Raises:\n", | |
"    TypeError: if `df` is not one of the following types: `float32`, `float64`.\n", | |
"  \"\"\"\n", | |
"  with tf.name_scope(name or 'stdtrit'):\n", | |
"    dtype = dtype_util.common_dtype([df, p], tf.float32)\n", | |
"    df, p = (tf.convert_to_tensor(arg, dtype=dtype) for arg in (df, p))\n", | |
"\n", | |
"    numpy_dtype = dtype_util.as_numpy_dtype(dtype)\n", | |
"    if numpy_dtype not in (np.float32, np.float64):\n", | |
"      raise TypeError(f'df.dtype={dtype} is not handled. '\n", | |
"                      'See docstring for supported types.')\n", | |
"\n", | |
"    return _stdtrit_custom_gradient(df, p)" | |
], | |
"metadata": { | |
"id": "z0sZtbJf6DOx" | |
}, | |
"execution_count": 27, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"stdtrit = tf.function(stdtrit, autograph=False)" | |
], | |
"metadata": { | |
"id": "TsOEFR06jME_" | |
}, | |
"execution_count": 28, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# 2. Test settings and auxiliary functions" | |
], | |
"metadata": { | |
"id": "J-DE_UILC5cM" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"SEEDS = [1, 17, 42, 51, 184, 301, 346, 448, 733, 985]" | |
], | |
"metadata": { | |
"id": "pNFf91IUC4_R" | |
}, | |
"execution_count": 29, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"SIZE_PER_SEED = 100" | |
], | |
"metadata": { | |
"id": "l-Z4RxdHDKFn" | |
}, | |
"execution_count": 30, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Used to test the partial derivative of stdtr with respect to df.\n", | |
"SMALL_SIZE_PER_SEED = 5" | |
], | |
"metadata": { | |
"id": "M4QpX0GzWyOH" | |
}, | |
"execution_count": 31, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"sample_specs = {\n", | |
" np.float64: [\n", | |
" (0., 1.),\n", | |
" (1., 10.),\n", | |
" (10., 30.),\n", | |
" (30., 100.),\n", | |
" (100., 1e+3),\n", | |
" (1e+3, 1e+4),\n", | |
" (1e+4, 1e+5),\n", | |
" (1e+5, 1e+6),\n", | |
" (1e+6, 1e+7),\n", | |
" (1e+7, 1e+8)],\n", | |
" np.float32: [\n", | |
" (0., 1.),\n", | |
" (1., 10.),\n", | |
" (10., 30.),\n", | |
" (30., 100.),\n", | |
" (100., 1e+3),\n", | |
" (1e+3, 1e+4)]}" | |
], | |
"metadata": { | |
"id": "kPvY6ZUQDSLZ" | |
}, | |
"execution_count": 32, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"def make_samples(sample_specs, size_per_seed, seeds, dtype, center=False):\n", | |
"  \"\"\"Builds `[df, p, t]` test samples for each `(df_min, df_max)` range.\n", | |
"\n", | |
"  Args:\n", | |
"    sample_specs: Iterable of `(df_min, df_max)` pairs.\n", | |
"    size_per_seed: Number of `df` values drawn per seed.\n", | |
"    seeds: Seeds for `np.random.RandomState`; `size_per_seed` values of `df`\n", | |
"      are drawn uniformly from `[df_min, df_max)` for each seed.\n", | |
"    dtype: NumPy floating-point type of the samples.\n", | |
"    center: If `True`, `p` is taken within a few `eps` of 0.5 so that `t` is\n", | |
"      near zero; otherwise `p` spans the range bounded by `stdtr(df, -max_t)`\n", | |
"      and `stdtr(df, max_t)`, clipped to `[eps, 1 - eps]`.\n", | |
"\n", | |
"  Returns:\n", | |
"    A list with one `[df, p, t]` entry per spec, where `df` has shape\n", | |
"    `(len(seeds) * size_per_seed, 1)` and `t = scipy.special.stdtrit(df, p)`.\n", | |
"  \"\"\"\n", | |
"  tiny = np.finfo(dtype).tiny\n", | |
"  eps = np.finfo(dtype).eps\n", | |
"  size = len(seeds) * size_per_seed\n", | |
"\n", | |
"  samples = []\n", | |
"  for df_min, df_max in sample_specs:\n", | |
"    dfs = []\n", | |
"\n", | |
"    for seed in seeds:\n", | |
"      rng = np.random.RandomState(seed)\n", | |
"      df = rng.uniform(low=df_min, high=df_max, size=(size_per_seed, 1))\n", | |
"      dfs.append(df.astype(dtype))\n", | |
"\n", | |
"    # np.vstack replaces np.row_stack, an alias deprecated in NumPy 2.0.\n", | |
"    df = np.vstack(dfs)\n", | |
"\n", | |
"    if center:\n", | |
"      # Generates values of t near zero.\n", | |
"      step_size = eps\n", | |
"      middle = size_per_seed // 2\n", | |
"      min_p = 0.5 - (middle) * step_size\n", | |
"      max_p = 0.5 + (size_per_seed - middle) * step_size\n", | |
"      p = np.linspace(min_p, max_p, size).astype(dtype)\n", | |
"      # NOTE(review): `middle` is derived from size_per_seed while `p` has\n", | |
"      # `size` entries, so the forced 0.5 is not at the grid's midpoint --\n", | |
"      # confirm this is intended.\n", | |
"      p[middle] = dtype(0.5)\n", | |
"    else:\n", | |
"      max_t = np.sqrt(df) / np.sqrt(10. * tiny)\n", | |
"      min_t = -max_t\n", | |
"      max_p = np.minimum(1. - eps, sp_special.stdtr(df, max_t)).astype(dtype)\n", | |
"      min_p = np.maximum(eps, sp_special.stdtr(df, min_t)).astype(dtype)\n", | |
"      p = np.linspace(min_p, max_p, size).astype(dtype)\n", | |
"\n", | |
"    t = sp_special.stdtrit(df, p).astype(dtype)\n", | |
"    samples.append([df, p, t])\n", | |
"\n", | |
"  return samples" | |
], | |
"metadata": { | |
"id": "CSkbLyRJHq-O" | |
}, | |
"execution_count": 33, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"The `format_number` and `get_metrics_dataframe` helpers format and tabulate the test results." | |
], | |
"metadata": { | |
"id": "vsIEqUoPQtaE" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"def format_number(num):\n", | |
"  \"\"\"Formats `num` as a compact string for the metrics tables.\n", | |
"\n", | |
"  Zero becomes '0'. Values with -2 < log10(|num|) < 6 are printed without\n", | |
"  decimals when integral and with two decimals otherwise. Everything else\n", | |
"  uses scientific notation with one decimal.\n", | |
"  \"\"\"\n", | |
"  if num == 0:\n", | |
"    return '0'\n", | |
"\n", | |
"  magnitude = np.log10(np.abs(num))\n", | |
"  if not (-2 < magnitude < 6):\n", | |
"    return f'{num:.1e}'\n", | |
"\n", | |
"  return f'{num:.0f}' if num == int(num) else f'{num:.2f}'" | |
], | |
"metadata": { | |
"id": "Biw3tiGVucgS" | |
}, | |
"execution_count": 34, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"def get_metrics_dataframe(\n", | |
"    sample_specs,\n", | |
"    truths,\n", | |
"    results,\n", | |
"    dtype,\n", | |
"):\n", | |
"  \"\"\"Tabulates error metrics of `results` against `truths` per df range.\n", | |
"\n", | |
"  Args:\n", | |
"    sample_specs: Mapping from dtype to a list of `(df_min, df_max)` pairs.\n", | |
"    truths: Sequence of reference values, one entry per pair.\n", | |
"    results: Sequence of values under test, aligned with `truths`.\n", | |
"    dtype: Key selecting which list of pairs in `sample_specs` to report on.\n", | |
"\n", | |
"  Returns:\n", | |
"    A `pd.DataFrame` with one row per `(df_min, df_max)` pair; every cell is\n", | |
"    a string produced by `format_number` (or the literal 'NaN').\n", | |
"  \"\"\"\n", | |
"\n", | |
"  def finite_max_and_mean(errors):\n", | |
"    # Non-finite entries are excluded from both statistics.\n", | |
"    finite = np.isfinite(errors)\n", | |
"    return (np.max(errors, initial=0., where=finite),\n", | |
"            np.mean(errors, where=finite))\n", | |
"\n", | |
"  error_columns = (\n", | |
"      'Max Abs Error', 'Mean Abs Error', 'Max Rel Error', 'Mean Rel Error')\n", | |
"\n", | |
"  records = []\n", | |
"  for idx, (df_min, df_max) in enumerate(sample_specs[dtype]):\n", | |
"    truth = np.array(truths[idx], dtype=np.float64)\n", | |
"    result = np.array(results[idx], dtype=np.float64)\n", | |
"    size = np.prod(truth.shape)\n", | |
"\n", | |
"    perc_nan = np.sum(np.isnan(result)) / size * 100\n", | |
"    perc_inf = np.sum(np.isinf(result)) / size * 100\n", | |
"\n", | |
"    record = {\n", | |
"        'Min Df': format_number(df_min),\n", | |
"        'Max Df': format_number(df_max),\n", | |
"        '#Trials': format_number(size),\n", | |
"        'Min Value': format_number(np.min(truth)),\n", | |
"        'Max Value': format_number(np.max(truth)),\n", | |
"        '%NaN': format_number(perc_nan),\n", | |
"        '%Inf': format_number(perc_inf),\n", | |
"    }\n", | |
"\n", | |
"    if (perc_nan + perc_inf) == 100:\n", | |
"      # No finite result at all: the error statistics are undefined.\n", | |
"      record.update(dict.fromkeys(error_columns, 'NaN'))\n", | |
"    else:\n", | |
"      abserr = np.abs(truth - result)\n", | |
"      # Exact matches count as zero relative error, even when truth is 0.\n", | |
"      relerr = np.where(\n", | |
"          np.equal(result, truth), np.float64(0.), abserr / np.abs(truth))\n", | |
"      metrics = finite_max_and_mean(abserr) + finite_max_and_mean(relerr)\n", | |
"      for column, metric in zip(error_columns, metrics):\n", | |
"        record[column] = format_number(metric)\n", | |
"\n", | |
"    records.append(record)\n", | |
"\n", | |
"  return pd.DataFrame.from_records(records)" | |
], | |
"metadata": { | |
"id": "XRvmj1-9Kukh" | |
}, | |
"execution_count": 35, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# 3. Test results for `float32`. Tiny values of `t`." | |
], | |
"metadata": { | |
"id": "BEseQgBKTp2g" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"DTYPE = np.float32" | |
], | |
"metadata": { | |
"id": "KSS4xfUr7odg" | |
}, | |
"execution_count": 36, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"CENTER = True" | |
], | |
"metadata": { | |
"id": "UiAnA5WW8Abd" | |
}, | |
"execution_count": 37, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"samples = make_samples(\n", | |
" sample_specs[DTYPE],\n", | |
" size_per_seed=SIZE_PER_SEED,\n", | |
" seeds=SEEDS,\n", | |
" dtype=DTYPE,\n", | |
" center=CENTER)" | |
], | |
"metadata": { | |
"id": "1l7QBKQENLff" | |
}, | |
"execution_count": 38, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## 3.1. Test results for CDF" | |
], | |
"metadata": { | |
"id": "bWtClyovXDq4" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"sp_cdf = []\n", | |
"\n", | |
"for i, (df, _, t) in enumerate(samples):\n", | |
" result = sp_special.stdtr(df, t)\n", | |
" sp_cdf.append(result)\n", | |
" print(f'sample {i} is done.')" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"outputId": "bf9dd137-241a-4360-88b5-916f2a77d7f8", | |
"id": "3X9HQQiEXDrA" | |
}, | |
"execution_count": 39, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"sample 0 is done.\n", | |
"sample 1 is done.\n", | |
"sample 2 is done.\n", | |
"sample 3 is done.\n", | |
"sample 4 is done.\n", | |
"sample 5 is done.\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"current_tf_cdf = []\n", | |
"\n", | |
"for i, (df, _, t) in enumerate(samples):\n", | |
" result = current_cdf(df, t)\n", | |
" current_tf_cdf.append(result)\n", | |
" print(f'sample {i} is done.')" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"outputId": "3f999df1-317b-449e-a45b-a3e56a7de7a9", | |
"id": "pU1y8-QOXDrB" | |
}, | |
"execution_count": 40, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"sample 0 is done.\n", | |
"sample 1 is done.\n", | |
"sample 2 is done.\n", | |
"sample 3 is done.\n", | |
"sample 4 is done.\n", | |
"sample 5 is done.\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"new_tf_cdf = []\n", | |
"\n", | |
"for i, (df, _, t) in enumerate(samples):\n", | |
" result = stdtr(df, t)\n", | |
" new_tf_cdf.append(result)\n", | |
" print(f'sample {i} is done.')" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"outputId": "c7d10279-466b-47af-bd56-f7de1db14784", | |
"id": "LyXlE-pJXDrB" | |
}, | |
"execution_count": 41, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"sample 0 is done.\n", | |
"sample 1 is done.\n", | |
"sample 2 is done.\n", | |
"sample 3 is done.\n", | |
"sample 4 is done.\n", | |
"sample 5 is done.\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"print('\\033[1m' + f'Method: Current implementation of CDF' + '\\033[0m')\n", | |
"print(f'Benchmark: Scipy implementation of CDF')\n", | |
"print(f'Dtype: {np.dtype(DTYPE).name}')\n", | |
"get_metrics_dataframe(sample_specs, sp_cdf, current_tf_cdf, DTYPE)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 290 | |
}, | |
"outputId": "701d0e10-ad35-4042-d721-3c3c3cf2d81d", | |
"id": "6lC7Xp-XXDrB" | |
}, | |
"execution_count": 42, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"\u001b[1mMethod: Current implementation of CDF\u001b[0m\n", | |
"Benchmark: Scipy implementation of CDF\n", | |
"Dtype: float32\n" | |
] | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
" Min Df Max Df #Trials Min Value Max Value %NaN %Inf Max Abs Error \\\n", | |
"0 0 1 1.0e+06 0.50 0.50 0 0 6.0e-06 \n", | |
"1 1 10 1.0e+06 0.50 0.50 0 0 6.0e-06 \n", | |
"2 10 30 1.0e+06 0.50 0.50 0 0 6.0e-06 \n", | |
"3 30 100 1.0e+06 0.50 0.50 0 0 6.0e-06 \n", | |
"4 100 1000 1.0e+06 0.50 0.50 0 0 6.0e-06 \n", | |
"5 1000 10000 1.0e+06 0.50 0.50 0 0 6.0e-06 \n", | |
"\n", | |
" Mean Abs Error Max Rel Error Mean Rel Error \n", | |
"0 2.9e-06 1.2e-05 5.7e-06 \n", | |
"1 3.0e-06 1.2e-05 6.0e-06 \n", | |
"2 3.0e-06 1.2e-05 6.0e-06 \n", | |
"3 3.0e-06 1.2e-05 6.0e-06 \n", | |
"4 3.0e-06 1.2e-05 6.0e-06 \n", | |
"5 3.0e-06 1.2e-05 6.0e-06 " | |
], | |
"text/html": [ | |
"\n", | |
" <div id=\"df-482d5f0d-089a-4a00-90fb-fa53207cb631\">\n", | |
" <div class=\"colab-df-container\">\n", | |
" <div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>Min Df</th>\n", | |
" <th>Max Df</th>\n", | |
" <th>#Trials</th>\n", | |
" <th>Min Value</th>\n", | |
" <th>Max Value</th>\n", | |
" <th>%NaN</th>\n", | |
" <th>%Inf</th>\n", | |
" <th>Max Abs Error</th>\n", | |
" <th>Mean Abs Error</th>\n", | |
" <th>Max Rel Error</th>\n", | |
" <th>Mean Rel Error</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>0</td>\n", | |
" <td>1</td>\n", | |
" <td>1.0e+06</td>\n", | |
" <td>0.50</td>\n", | |
" <td>0.50</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>6.0e-06</td>\n", | |
" <td>2.9e-06</td>\n", | |
" <td>1.2e-05</td>\n", | |
" <td>5.7e-06</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>1</td>\n", | |
" <td>10</td>\n", | |
" <td>1.0e+06</td>\n", | |
" <td>0.50</td>\n", | |
" <td>0.50</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>6.0e-06</td>\n", | |
" <td>3.0e-06</td>\n", | |
" <td>1.2e-05</td>\n", | |
" <td>6.0e-06</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>10</td>\n", | |
" <td>30</td>\n", | |
" <td>1.0e+06</td>\n", | |
" <td>0.50</td>\n", | |
" <td>0.50</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>6.0e-06</td>\n", | |
" <td>3.0e-06</td>\n", | |
" <td>1.2e-05</td>\n", | |
" <td>6.0e-06</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>30</td>\n", | |
" <td>100</td>\n", | |
" <td>1.0e+06</td>\n", | |
" <td>0.50</td>\n", | |
" <td>0.50</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>6.0e-06</td>\n", | |
" <td>3.0e-06</td>\n", | |
" <td>1.2e-05</td>\n", | |
" <td>6.0e-06</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>100</td>\n", | |
" <td>1000</td>\n", | |
" <td>1.0e+06</td>\n", | |
" <td>0.50</td>\n", | |
" <td>0.50</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>6.0e-06</td>\n", | |
" <td>3.0e-06</td>\n", | |
" <td>1.2e-05</td>\n", | |
" <td>6.0e-06</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>5</th>\n", | |
" <td>1000</td>\n", | |
" <td>10000</td>\n", | |
" <td>1.0e+06</td>\n", | |
" <td>0.50</td>\n", | |
" <td>0.50</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>6.0e-06</td>\n", | |
" <td>3.0e-06</td>\n", | |
" <td>1.2e-05</td>\n", | |
" <td>6.0e-06</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>\n", | |
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-482d5f0d-089a-4a00-90fb-fa53207cb631')\"\n", | |
" title=\"Convert this dataframe to an interactive table.\"\n", | |
" style=\"display:none;\">\n", | |
" \n", | |
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n", | |
" width=\"24px\">\n", | |
" <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n", | |
" <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n", | |
" </svg>\n", | |
" </button>\n", | |
" \n", | |
" <style>\n", | |
" .colab-df-container {\n", | |
" display:flex;\n", | |
" flex-wrap:wrap;\n", | |
" gap: 12px;\n", | |
" }\n", | |
"\n", | |
" .colab-df-convert {\n", | |
" background-color: #E8F0FE;\n", | |
" border: none;\n", | |
" border-radius: 50%;\n", | |
" cursor: pointer;\n", | |
" display: none;\n", | |
" fill: #1967D2;\n", | |
" height: 32px;\n", | |
" padding: 0 0 0 0;\n", | |
" width: 32px;\n", | |
" }\n", | |
"\n", | |
" .colab-df-convert:hover {\n", | |
" background-color: #E2EBFA;\n", | |
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n", | |
" fill: #174EA6;\n", | |
" }\n", | |
"\n", | |
" [theme=dark] .colab-df-convert {\n", | |
" background-color: #3B4455;\n", | |
" fill: #D2E3FC;\n", | |
" }\n", | |
"\n", | |
" [theme=dark] .colab-df-convert:hover {\n", | |
" background-color: #434B5C;\n", | |
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n", | |
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n", | |
" fill: #FFFFFF;\n", | |
" }\n", | |
" </style>\n", | |
"\n", | |
" <script>\n", | |
" const buttonEl =\n", | |
" document.querySelector('#df-482d5f0d-089a-4a00-90fb-fa53207cb631 button.colab-df-convert');\n", | |
" buttonEl.style.display =\n", | |
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n", | |
"\n", | |
" async function convertToInteractive(key) {\n", | |
" const element = document.querySelector('#df-482d5f0d-089a-4a00-90fb-fa53207cb631');\n", | |
" const dataTable =\n", | |
" await google.colab.kernel.invokeFunction('convertToInteractive',\n", | |
" [key], {});\n", | |
" if (!dataTable) return;\n", | |
"\n", | |
" const docLinkHtml = 'Like what you see? Visit the ' +\n", | |
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n", | |
" + ' to learn more about interactive tables.';\n", | |
" element.innerHTML = '';\n", | |
" dataTable['output_type'] = 'display_data';\n", | |
" await google.colab.output.renderOutput(dataTable, element);\n", | |
" const docLink = document.createElement('div');\n", | |
" docLink.innerHTML = docLinkHtml;\n", | |
" element.appendChild(docLink);\n", | |
" }\n", | |
" </script>\n", | |
" </div>\n", | |
" </div>\n", | |
" " | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 42 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"print('\\033[1m' + f'Method: New implementation of CDF' + '\\033[0m')\n", | |
"print(f'Benchmark: Scipy implementation of CDF')\n", | |
"print(f'Dtype: {np.dtype(DTYPE).name}')\n", | |
"get_metrics_dataframe(sample_specs, sp_cdf, new_tf_cdf, DTYPE)" | |
], | |
"metadata": { | |
"id": "iFDvNEJ47nhN", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 290 | |
}, | |
"outputId": "dfc9daa2-1503-4894-d6c7-4df1eb5cf39e" | |
}, | |
"execution_count": 43, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"\u001b[1mMethod: New implementation of CDF\u001b[0m\n", | |
"Benchmark: Scipy implementation of CDF\n", | |
"Dtype: float32\n" | |
] | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
" Min Df Max Df #Trials Min Value Max Value %NaN %Inf Max Abs Error \\\n", | |
"0 0 1 1.0e+06 0.50 0.50 0 0 6.0e-08 \n", | |
"1 1 10 1.0e+06 0.50 0.50 0 0 0 \n", | |
"2 10 30 1.0e+06 0.50 0.50 0 0 0 \n", | |
"3 30 100 1.0e+06 0.50 0.50 0 0 0 \n", | |
"4 100 1000 1.0e+06 0.50 0.50 0 0 0 \n", | |
"5 1000 10000 1.0e+06 0.50 0.50 0 0 0 \n", | |
"\n", | |
" Mean Abs Error Max Rel Error Mean Rel Error \n", | |
"0 7.2e-12 1.2e-07 1.4e-11 \n", | |
"1 0 0 0 \n", | |
"2 0 0 0 \n", | |
"3 0 0 0 \n", | |
"4 0 0 0 \n", | |
"5 0 0 0 " | |
], | |
"text/html": [ | |
"\n", | |
" <div id=\"df-482f9bed-c650-41f9-9947-54f2634a378d\">\n", | |
" <div class=\"colab-df-container\">\n", | |
" <div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>Min Df</th>\n", | |
" <th>Max Df</th>\n", | |
" <th>#Trials</th>\n", | |
" <th>Min Value</th>\n", | |
" <th>Max Value</th>\n", | |
" <th>%NaN</th>\n", | |
" <th>%Inf</th>\n", | |
" <th>Max Abs Error</th>\n", | |
" <th>Mean Abs Error</th>\n", | |
" <th>Max Rel Error</th>\n", | |
" <th>Mean Rel Error</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>0</td>\n", | |
" <td>1</td>\n", | |
" <td>1.0e+06</td>\n", | |
" <td>0.50</td>\n", | |
" <td>0.50</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>6.0e-08</td>\n", | |
" <td>7.2e-12</td>\n", | |
" <td>1.2e-07</td>\n", | |
" <td>1.4e-11</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>1</td>\n", | |
" <td>10</td>\n", | |
" <td>1.0e+06</td>\n", | |
" <td>0.50</td>\n", | |
" <td>0.50</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>10</td>\n", | |
" <td>30</td>\n", | |
" <td>1.0e+06</td>\n", | |
" <td>0.50</td>\n", | |
" <td>0.50</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>30</td>\n", | |
" <td>100</td>\n", | |
" <td>1.0e+06</td>\n", | |
" <td>0.50</td>\n", | |
" <td>0.50</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>100</td>\n", | |
" <td>1000</td>\n", | |
" <td>1.0e+06</td>\n", | |
" <td>0.50</td>\n", | |
" <td>0.50</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>5</th>\n", | |
" <td>1000</td>\n", | |
" <td>10000</td>\n", | |
" <td>1.0e+06</td>\n", | |
" <td>0.50</td>\n", | |
" <td>0.50</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>\n", | |
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-482f9bed-c650-41f9-9947-54f2634a378d')\"\n", | |
" title=\"Convert this dataframe to an interactive table.\"\n", | |
" style=\"display:none;\">\n", | |
" \n", | |
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n", | |
" width=\"24px\">\n", | |
" <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n", | |
" <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n", | |
" </svg>\n", | |
" </button>\n", | |
" \n", | |
" <style>\n", | |
" .colab-df-container {\n", | |
" display:flex;\n", | |
" flex-wrap:wrap;\n", | |
" gap: 12px;\n", | |
" }\n", | |
"\n", | |
" .colab-df-convert {\n", | |
" background-color: #E8F0FE;\n", | |
" border: none;\n", | |
" border-radius: 50%;\n", | |
" cursor: pointer;\n", | |
" display: none;\n", | |
" fill: #1967D2;\n", | |
" height: 32px;\n", | |
" padding: 0 0 0 0;\n", | |
" width: 32px;\n", | |
" }\n", | |
"\n", | |
" .colab-df-convert:hover {\n", | |
" background-color: #E2EBFA;\n", | |
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n", | |
" fill: #174EA6;\n", | |
" }\n", | |
"\n", | |
" [theme=dark] .colab-df-convert {\n", | |
" background-color: #3B4455;\n", | |
" fill: #D2E3FC;\n", | |
" }\n", | |
"\n", | |
" [theme=dark] .colab-df-convert:hover {\n", | |
" background-color: #434B5C;\n", | |
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n", | |
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n", | |
" fill: #FFFFFF;\n", | |
" }\n", | |
" </style>\n", | |
"\n", | |
" <script>\n", | |
" const buttonEl =\n", | |
" document.querySelector('#df-482f9bed-c650-41f9-9947-54f2634a378d button.colab-df-convert');\n", | |
" buttonEl.style.display =\n", | |
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n", | |
"\n", | |
" async function convertToInteractive(key) {\n", | |
" const element = document.querySelector('#df-482f9bed-c650-41f9-9947-54f2634a378d');\n", | |
" const dataTable =\n", | |
" await google.colab.kernel.invokeFunction('convertToInteractive',\n", | |
" [key], {});\n", | |
" if (!dataTable) return;\n", | |
"\n", | |
" const docLinkHtml = 'Like what you see? Visit the ' +\n", | |
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n", | |
" + ' to learn more about interactive tables.';\n", | |
" element.innerHTML = '';\n", | |
" dataTable['output_type'] = 'display_data';\n", | |
" await google.colab.output.renderOutput(dataTable, element);\n", | |
" const docLink = document.createElement('div');\n", | |
" docLink.innerHTML = docLinkHtml;\n", | |
" element.appendChild(docLink);\n", | |
" }\n", | |
" </script>\n", | |
" </div>\n", | |
" </div>\n", | |
" " | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 43 | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"### 3.1.1. Partial derivative of CDF with respect to `t`" | |
], | |
"metadata": { | |
"id": "sf-sQ9pV-_oD" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"tf_prob = []\n", | |
"\n", | |
"for i, (df, _, t) in enumerate(samples):\n", | |
" result = tf.math.exp(log_prob(t, df, DTYPE(0.), DTYPE(1.)))\n", | |
" tf_prob.append(result)\n", | |
" print(f'sample {i} is done.')" | |
], | |
"metadata": { | |
"id": "LTLS6ZUj-_oJ", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"outputId": "07e1862f-5e36-4a32-a540-8c0964a972c2" | |
}, | |
"execution_count": 44, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"sample 0 is done.\n", | |
"sample 1 is done.\n", | |
"sample 2 is done.\n", | |
"sample 3 is done.\n", | |
"sample 4 is done.\n", | |
"sample 5 is done.\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"def current_cdf_partial_t(df, t):\n", | |
"  \"\"\"Returns the gradient of `current_cdf(df, t)` with respect to `t`.\"\"\"\n", | |
"  dtype = dtype_util.common_dtype([df, t], tf.float32)\n", | |
"  df, t = (tf.convert_to_tensor(arg, dtype=dtype) for arg in (df, t))\n", | |
"\n", | |
"  _, partial_t = gradient.value_and_gradient(\n", | |
"      lambda _t: current_cdf(df, _t), t)\n", | |
"  return partial_t" | |
], | |
"metadata": { | |
"id": "rFxKbxZk_gb4" | |
}, | |
"execution_count": 45, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"current_cdf_partial_t = tf.function(current_cdf_partial_t, autograph=False)" | |
], | |
"metadata": { | |
"id": "iwRnDTarAkXC" | |
}, | |
"execution_count": 46, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"current_tf_partial_t = []\n", | |
"\n", | |
"for i, (df, _, t) in enumerate(samples):\n", | |
" result = current_cdf_partial_t(df, t)\n", | |
" current_tf_partial_t.append(result)\n", | |
" print(f'sample {i} is done.')" | |
], | |
"metadata": { | |
"id": "8CNtST4E-_oJ", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"outputId": "967ee7b6-71db-46e3-de37-c09b3f2a081e" | |
}, | |
"execution_count": 47, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"sample 0 is done.\n", | |
"sample 1 is done.\n", | |
"sample 2 is done.\n", | |
"sample 3 is done.\n", | |
"sample 4 is done.\n", | |
"sample 5 is done.\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"def new_stdtr_partial_t(df, t):\n", | |
"  \"\"\"Returns the gradient of `stdtr(df, t)` with respect to `t`.\"\"\"\n", | |
"  dtype = dtype_util.common_dtype([df, t], tf.float32)\n", | |
"  df, t = (tf.convert_to_tensor(arg, dtype=dtype) for arg in (df, t))\n", | |
"\n", | |
"  _, partial_t = gradient.value_and_gradient(\n", | |
"      lambda _t: stdtr(df, _t), t)\n", | |
"  return partial_t" | |
], | |
"metadata": { | |
"id": "0B5q5jWREX9C" | |
}, | |
"execution_count": 48, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"new_stdtr_partial_t = tf.function(new_stdtr_partial_t, autograph=False)" | |
], | |
"metadata": { | |
"id": "hQAg19VFEX9D" | |
}, | |
"execution_count": 49, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"new_tf_partial_t = []\n", | |
"\n", | |
"for i, (df, _, t) in enumerate(samples):\n", | |
" result = new_stdtr_partial_t(df, t)\n", | |
" new_tf_partial_t.append(result)\n", | |
" print(f'sample {i} is done.')" | |
], | |
"metadata": { | |
"id": "UkMsAlIG-_oJ", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"outputId": "0f651964-12fa-4c1c-88c7-06eb23cc696f" | |
}, | |
"execution_count": 50, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"sample 0 is done.\n", | |
"sample 1 is done.\n", | |
"sample 2 is done.\n", | |
"sample 3 is done.\n", | |
"sample 4 is done.\n", | |
"sample 5 is done.\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"print('\\033[1m' + f'Method: Current implementation of partial derivative of CDF wrt t' + '\\033[0m')\n", | |
"print(f'Benchmark: TFP implementation of prob')\n", | |
"print(f'Dtype: {np.dtype(DTYPE).name}')\n", | |
"get_metrics_dataframe(sample_specs, tf_prob, current_tf_partial_t, DTYPE)" | |
], | |
"metadata": { | |
"id": "2HssqsVs-_oK", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 290 | |
}, | |
"outputId": "12833118-6d53-46c6-fc2e-902e3327bfea" | |
}, | |
"execution_count": 51, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"\u001b[1mMethod: Current implementation of partial derivative of CDF wrt t\u001b[0m\n", | |
"Benchmark: TFP implementation of prob\n", | |
"Dtype: float32\n" | |
] | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
" Min Df Max Df #Trials Min Value Max Value %NaN %Inf Max Abs Error \\\n", | |
"0 0 1 1.0e+06 3.4e-03 0.32 0 96.90 0.04 \n", | |
"1 1 10 1.0e+06 0.32 0.39 0 100 NaN \n", | |
"2 10 30 1.0e+06 0.39 0.40 0 100 NaN \n", | |
"3 30 100 1.0e+06 0.40 0.40 0 100 NaN \n", | |
"4 100 1000 1.0e+06 0.40 0.40 0 100 NaN \n", | |
"5 1000 10000 1.0e+06 0.40 0.40 0 100 NaN \n", | |
"\n", | |
" Mean Abs Error Max Rel Error Mean Rel Error \n", | |
"0 6.5e-03 0.49 0.08 \n", | |
"1 NaN NaN NaN \n", | |
"2 NaN NaN NaN \n", | |
"3 NaN NaN NaN \n", | |
"4 NaN NaN NaN \n", | |
"5 NaN NaN NaN " | |
], | |
"text/html": [ | |
"\n", | |
" <div id=\"df-62067040-5d3d-41b9-9f09-fc2f2fb55203\">\n", | |
" <div class=\"colab-df-container\">\n", | |
" <div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>Min Df</th>\n", | |
" <th>Max Df</th>\n", | |
" <th>#Trials</th>\n", | |
" <th>Min Value</th>\n", | |
" <th>Max Value</th>\n", | |
" <th>%NaN</th>\n", | |
" <th>%Inf</th>\n", | |
" <th>Max Abs Error</th>\n", | |
" <th>Mean Abs Error</th>\n", | |
" <th>Max Rel Error</th>\n", | |
" <th>Mean Rel Error</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>0</td>\n", | |
" <td>1</td>\n", | |
" <td>1.0e+06</td>\n", | |
" <td>3.4e-03</td>\n", | |
" <td>0.32</td>\n", | |
" <td>0</td>\n", | |
" <td>96.90</td>\n", | |
" <td>0.04</td>\n", | |
" <td>6.5e-03</td>\n", | |
" <td>0.49</td>\n", | |
" <td>0.08</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>1</td>\n", | |
" <td>10</td>\n", | |
" <td>1.0e+06</td>\n", | |
" <td>0.32</td>\n", | |
" <td>0.39</td>\n", | |
" <td>0</td>\n", | |
" <td>100</td>\n", | |
" <td>NaN</td>\n", | |
" <td>NaN</td>\n", | |
" <td>NaN</td>\n", | |
" <td>NaN</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>10</td>\n", | |
" <td>30</td>\n", | |
" <td>1.0e+06</td>\n", | |
" <td>0.39</td>\n", | |
" <td>0.40</td>\n", | |
" <td>0</td>\n", | |
" <td>100</td>\n", | |
" <td>NaN</td>\n", | |
" <td>NaN</td>\n", | |
" <td>NaN</td>\n", | |
" <td>NaN</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>30</td>\n", | |
" <td>100</td>\n", | |
" <td>1.0e+06</td>\n", | |
" <td>0.40</td>\n", | |
" <td>0.40</td>\n", | |
" <td>0</td>\n", | |
" <td>100</td>\n", | |
" <td>NaN</td>\n", | |
" <td>NaN</td>\n", | |
" <td>NaN</td>\n", | |
" <td>NaN</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>100</td>\n", | |
" <td>1000</td>\n", | |
" <td>1.0e+06</td>\n", | |
" <td>0.40</td>\n", | |
" <td>0.40</td>\n", | |
" <td>0</td>\n", | |
" <td>100</td>\n", | |
" <td>NaN</td>\n", | |
" <td>NaN</td>\n", | |
" <td>NaN</td>\n", | |
" <td>NaN</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>5</th>\n", | |
" <td>1000</td>\n", | |
" <td>10000</td>\n", | |
" <td>1.0e+06</td>\n", | |
" <td>0.40</td>\n", | |
" <td>0.40</td>\n", | |
" <td>0</td>\n", | |
" <td>100</td>\n", | |
" <td>NaN</td>\n", | |
" <td>NaN</td>\n", | |
" <td>NaN</td>\n", | |
" <td>NaN</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>\n", | |
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-62067040-5d3d-41b9-9f09-fc2f2fb55203')\"\n", | |
" title=\"Convert this dataframe to an interactive table.\"\n", | |
" style=\"display:none;\">\n", | |
" \n", | |
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n", | |
" width=\"24px\">\n", | |
" <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n", | |
" <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n", | |
" </svg>\n", | |
" </button>\n", | |
" \n", | |
" <style>\n", | |
" .colab-df-container {\n", | |
" display:flex;\n", | |
" flex-wrap:wrap;\n", | |
" gap: 12px;\n", | |
" }\n", | |
"\n", | |
" .colab-df-convert {\n", | |
" background-color: #E8F0FE;\n", | |
" border: none;\n", | |
" border-radius: 50%;\n", | |
" cursor: pointer;\n", | |
" display: none;\n", | |
" fill: #1967D2;\n", | |
" height: 32px;\n", | |
" padding: 0 0 0 0;\n", | |
" width: 32px;\n", | |
" }\n", | |
"\n", | |
" .colab-df-convert:hover {\n", | |
" background-color: #E2EBFA;\n", | |
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n", | |
" fill: #174EA6;\n", | |
" }\n", | |
"\n", | |
" [theme=dark] .colab-df-convert {\n", | |
" background-color: #3B4455;\n", | |
" fill: #D2E3FC;\n", | |
" }\n", | |
"\n", | |
" [theme=dark] .colab-df-convert:hover {\n", | |
" background-color: #434B5C;\n", | |
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n", | |
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n", | |
" fill: #FFFFFF;\n", | |
" }\n", | |
" </style>\n", | |
"\n", | |
" <script>\n", | |
" const buttonEl =\n", | |
" document.querySelector('#df-62067040-5d3d-41b9-9f09-fc2f2fb55203 button.colab-df-convert');\n", | |
" buttonEl.style.display =\n", | |
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n", | |
"\n", | |
" async function convertToInteractive(key) {\n", | |
" const element = document.querySelector('#df-62067040-5d3d-41b9-9f09-fc2f2fb55203');\n", | |
" const dataTable =\n", | |
" await google.colab.kernel.invokeFunction('convertToInteractive',\n", | |
" [key], {});\n", | |
" if (!dataTable) return;\n", | |
"\n", | |
" const docLinkHtml = 'Like what you see? Visit the ' +\n", | |
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n", | |
" + ' to learn more about interactive tables.';\n", | |
" element.innerHTML = '';\n", | |
" dataTable['output_type'] = 'display_data';\n", | |
" await google.colab.output.renderOutput(dataTable, element);\n", | |
" const docLink = document.createElement('div');\n", | |
" docLink.innerHTML = docLinkHtml;\n", | |
" element.appendChild(docLink);\n", | |
" }\n", | |
" </script>\n", | |
" </div>\n", | |
" </div>\n", | |
" " | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 51 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"print('\\033[1m' + f'Method: New implementation of partial derivative of CDF wrt t' + '\\033[0m')\n", | |
"print(f'Benchmark: TFP implementation of prob')\n", | |
"print(f'Dtype: {np.dtype(DTYPE).name}')\n", | |
"get_metrics_dataframe(sample_specs, tf_prob, new_tf_partial_t, DTYPE)" | |
], | |
"metadata": { | |
"id": "3LXRWQX2-_oK", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 290 | |
}, | |
"outputId": "895455a9-09e4-4b69-9784-c7538ffcbda9" | |
}, | |
"execution_count": 52, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"\u001b[1mMethod: New implementation of partial derivative of CDF wrt t\u001b[0m\n", | |
"Benchmark: TFP implementation of prob\n", | |
"Dtype: float32\n" | |
] | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
" Min Df Max Df #Trials Min Value Max Value %NaN %Inf Max Abs Error \\\n", | |
"0 0 1 1.0e+06 3.4e-03 0.32 0 0 0 \n", | |
"1 1 10 1.0e+06 0.32 0.39 0 0 0 \n", | |
"2 10 30 1.0e+06 0.39 0.40 0 0 0 \n", | |
"3 30 100 1.0e+06 0.40 0.40 0 0 0 \n", | |
"4 100 1000 1.0e+06 0.40 0.40 0 0 0 \n", | |
"5 1000 10000 1.0e+06 0.40 0.40 0 0 0 \n", | |
"\n", | |
" Mean Abs Error Max Rel Error Mean Rel Error \n", | |
"0 0 0 0 \n", | |
"1 0 0 0 \n", | |
"2 0 0 0 \n", | |
"3 0 0 0 \n", | |
"4 0 0 0 \n", | |
"5 0 0 0 " | |
], | |
"text/html": [ | |
"\n", | |
" <div id=\"df-3cf420a3-a738-49cb-a33a-dce46ae47524\">\n", | |
" <div class=\"colab-df-container\">\n", | |
" <div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>Min Df</th>\n", | |
" <th>Max Df</th>\n", | |
" <th>#Trials</th>\n", | |
" <th>Min Value</th>\n", | |
" <th>Max Value</th>\n", | |
" <th>%NaN</th>\n", | |
" <th>%Inf</th>\n", | |
" <th>Max Abs Error</th>\n", | |
" <th>Mean Abs Error</th>\n", | |
" <th>Max Rel Error</th>\n", | |
" <th>Mean Rel Error</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>0</td>\n", | |
" <td>1</td>\n", | |
" <td>1.0e+06</td>\n", | |
" <td>3.4e-03</td>\n", | |
" <td>0.32</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>1</td>\n", | |
" <td>10</td>\n", | |
" <td>1.0e+06</td>\n", | |
" <td>0.32</td>\n", | |
" <td>0.39</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>10</td>\n", | |
" <td>30</td>\n", | |
" <td>1.0e+06</td>\n", | |
" <td>0.39</td>\n", | |
" <td>0.40</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>30</td>\n", | |
" <td>100</td>\n", | |
" <td>1.0e+06</td>\n", | |
" <td>0.40</td>\n", | |
" <td>0.40</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>100</td>\n", | |
" <td>1000</td>\n", | |
" <td>1.0e+06</td>\n", | |
" <td>0.40</td>\n", | |
" <td>0.40</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>5</th>\n", | |
" <td>1000</td>\n", | |
" <td>10000</td>\n", | |
" <td>1.0e+06</td>\n", | |
" <td>0.40</td>\n", | |
" <td>0.40</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>\n", | |
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-3cf420a3-a738-49cb-a33a-dce46ae47524')\"\n", | |
" title=\"Convert this dataframe to an interactive table.\"\n", | |
" style=\"display:none;\">\n", | |
" \n", | |
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n", | |
" width=\"24px\">\n", | |
" <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n", | |
" <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n", | |
" </svg>\n", | |
" </button>\n", | |
" \n", | |
" <style>\n", | |
" .colab-df-container {\n", | |
" display:flex;\n", | |
" flex-wrap:wrap;\n", | |
" gap: 12px;\n", | |
" }\n", | |
"\n", | |
" .colab-df-convert {\n", | |
" background-color: #E8F0FE;\n", | |
" border: none;\n", | |
" border-radius: 50%;\n", | |
" cursor: pointer;\n", | |
" display: none;\n", | |
" fill: #1967D2;\n", | |
" height: 32px;\n", | |
" padding: 0 0 0 0;\n", | |
" width: 32px;\n", | |
" }\n", | |
"\n", | |
" .colab-df-convert:hover {\n", | |
" background-color: #E2EBFA;\n", | |
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n", | |
" fill: #174EA6;\n", | |
" }\n", | |
"\n", | |
" [theme=dark] .colab-df-convert {\n", | |
" background-color: #3B4455;\n", | |
" fill: #D2E3FC;\n", | |
" }\n", | |
"\n", | |
" [theme=dark] .colab-df-convert:hover {\n", | |
" background-color: #434B5C;\n", | |
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n", | |
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n", | |
" fill: #FFFFFF;\n", | |
" }\n", | |
" </style>\n", | |
"\n", | |
" <script>\n", | |
" const buttonEl =\n", | |
" document.querySelector('#df-3cf420a3-a738-49cb-a33a-dce46ae47524 button.colab-df-convert');\n", | |
" buttonEl.style.display =\n", | |
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n", | |
"\n", | |
" async function convertToInteractive(key) {\n", | |
" const element = document.querySelector('#df-3cf420a3-a738-49cb-a33a-dce46ae47524');\n", | |
" const dataTable =\n", | |
" await google.colab.kernel.invokeFunction('convertToInteractive',\n", | |
" [key], {});\n", | |
" if (!dataTable) return;\n", | |
"\n", | |
" const docLinkHtml = 'Like what you see? Visit the ' +\n", | |
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n", | |
" + ' to learn more about interactive tables.';\n", | |
" element.innerHTML = '';\n", | |
" dataTable['output_type'] = 'display_data';\n", | |
" await google.colab.output.renderOutput(dataTable, element);\n", | |
" const docLink = document.createElement('div');\n", | |
" docLink.innerHTML = docLinkHtml;\n", | |
" element.appendChild(docLink);\n", | |
" }\n", | |
" </script>\n", | |
" </div>\n", | |
" </div>\n", | |
" " | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 52 | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"### 3.1.2. Partial derivative of CDF with respect to `df`" | |
], | |
"metadata": { | |
"id": "uh7KWm3nXDrE" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"small_samples = make_samples(\n", | |
" sample_specs[DTYPE],\n", | |
" size_per_seed=SMALL_SIZE_PER_SEED,\n", | |
" seeds=SEEDS,\n", | |
" dtype=DTYPE,\n", | |
" center=CENTER)" | |
], | |
"metadata": { | |
"id": "TL1Dwp-OXDrE" | |
}, | |
"execution_count": 53, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"df, _, t = small_samples[0]\n", | |
"broadcast_shape = functools.reduce(\n", | |
" ps.broadcast_shape, [ps.shape(df), ps.shape(t)])" | |
], | |
"metadata": { | |
"id": "xSZasJMMnvpV" | |
}, | |
"execution_count": 54, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"small_samples = [\n", | |
" [np.broadcast_to(param, broadcast_shape) for param in (df, p, t)]\n", | |
" for (df, p, t) in small_samples]" | |
], | |
"metadata": { | |
"id": "_OtVkNFWnKfr" | |
}, | |
"execution_count": 55, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"def _mp_cdf(df, t, numpy_dtype):\n", | |
" half = numpy_dtype(0.5)\n", | |
" one = numpy_dtype(1.)\n", | |
"\n", | |
" x_t = df / (t * t + df)\n", | |
" neg_cdf = half * mp.betainc(half * df, half, x1=0, x2=x_t, regularized=True)\n", | |
" return neg_cdf if t < 0. else one - neg_cdf" | |
], | |
"metadata": { | |
"id": "5rJL5ekmLjHb" | |
}, | |
"execution_count": 56, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"def _mp_cdf_partial_df(df, t, numpy_dtype):\n", | |
" df, t = [numpy_dtype(z) for z in (df, t)]\n", | |
" return mp.diff(lambda x: _mp_cdf(x, t, numpy_dtype), df)" | |
], | |
"metadata": { | |
"id": "wPW_mUOuLjHc" | |
}, | |
"execution_count": 57, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"mp_cdf_partial_df = np.frompyfunc(_mp_cdf_partial_df, 3, 1)" | |
], | |
"metadata": { | |
"id": "q0ND3q6GLjHc" | |
}, | |
"execution_count": 58, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"mp_partial_df = []\n", | |
"\n", | |
"for i, (df, _, t) in enumerate(small_samples):\n", | |
" result = mp_cdf_partial_df(df, t, [DTYPE]).astype(DTYPE).astype(DTYPE)\n", | |
" mp_partial_df.append(result)\n", | |
" print(f'sample {i} is done.')" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"outputId": "714e8e76-5c54-458f-9eab-fde74aff0c62", | |
"id": "U_e9W1OVLjHc" | |
}, | |
"execution_count": 59, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"sample 0 is done.\n", | |
"sample 1 is done.\n", | |
"sample 2 is done.\n", | |
"sample 3 is done.\n", | |
"sample 4 is done.\n", | |
"sample 5 is done.\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"def current_cdf_partial_df(df, t):\n", | |
" dtype = dtype_util.common_dtype([df, t], tf.float32)\n", | |
"\n", | |
" df = tf.convert_to_tensor(df, dtype=dtype)\n", | |
" t = tf.convert_to_tensor(t, dtype=dtype)\n", | |
"\n", | |
" return gradient.value_and_gradient(lambda _df: current_cdf(_df, t), df)[1]" | |
], | |
"metadata": { | |
"id": "k7LwONNALjHd" | |
}, | |
"execution_count": 60, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"current_cdf_partial_df = tf.function(current_cdf_partial_df, autograph=False)" | |
], | |
"metadata": { | |
"id": "P1Wowna4LjHd" | |
}, | |
"execution_count": 61, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"current_tf_partial_df = []\n", | |
"\n", | |
"for i, (df, _, t) in enumerate(small_samples):\n", | |
" result = current_cdf_partial_df(df, t)\n", | |
" current_tf_partial_df.append(result)\n", | |
" print(f'sample {i} is done.')" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"outputId": "56be079d-76a0-45ba-f7ee-cd84f66912f1", | |
"id": "tfH3QNqpLjHd" | |
}, | |
"execution_count": 62, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"sample 0 is done.\n", | |
"sample 1 is done.\n", | |
"sample 2 is done.\n", | |
"sample 3 is done.\n", | |
"sample 4 is done.\n", | |
"sample 5 is done.\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"mp_partial_df[0].shape, mp_partial_df[0].dtype" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "GiLM1R35ljf0", | |
"outputId": "4a683d21-e6f1-4b84-85d3-44d177409fee" | |
}, | |
"execution_count": 63, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"((50, 50), dtype('float32'))" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 63 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"current_tf_partial_df[0].shape, current_tf_partial_df[0].dtype" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "r1XTAWj5ldPZ", | |
"outputId": "840bbb32-1b09-400d-b250-d0f2ff7b75d7" | |
}, | |
"execution_count": 64, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"(TensorShape([50, 50]), tf.float32)" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 64 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"def new_stdtr_partial_df(df, t):\n", | |
" dtype = dtype_util.common_dtype([df, t], tf.float32)\n", | |
"\n", | |
" df = tf.convert_to_tensor(df, dtype=dtype)\n", | |
" t = tf.convert_to_tensor(t, dtype=dtype)\n", | |
"\n", | |
" return gradient.value_and_gradient(lambda _df: stdtr(_df, t), df)[1]" | |
], | |
"metadata": { | |
"id": "LJlOKxM_LjHe" | |
}, | |
"execution_count": 65, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"new_stdtr_partial_df = tf.function(new_stdtr_partial_df, autograph=False)" | |
], | |
"metadata": { | |
"id": "fwBPt9N3LjHe" | |
}, | |
"execution_count": 66, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"new_tf_partial_df = []\n", | |
"\n", | |
"for i, (df, _, t) in enumerate(small_samples):\n", | |
" result = new_stdtr_partial_df(df, t)\n", | |
" new_tf_partial_df.append(result)\n", | |
" print(f'sample {i} is done.')" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"outputId": "c966b442-38d1-421d-df70-ebf17abec90d", | |
"id": "4Mvh-AABLjHe" | |
}, | |
"execution_count": 67, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"sample 0 is done.\n", | |
"sample 1 is done.\n", | |
"sample 2 is done.\n", | |
"sample 3 is done.\n", | |
"sample 4 is done.\n", | |
"sample 5 is done.\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"new_tf_partial_df[0].shape, new_tf_partial_df[0].dtype" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "OgZ2k4fHltl2", | |
"outputId": "2ec65a0a-a771-45bf-9039-6707cda874f0" | |
}, | |
"execution_count": 68, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"(TensorShape([50, 50]), tf.float32)" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 68 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"print('\\033[1m' + f'Method: Current implementation of partial derivative of CDF wrt df' + '\\033[0m')\n", | |
"print(f'Benchmark: Mpmath implementation of partial derivative of CDF wrt df')\n", | |
"print(f'Dtype: {np.dtype(DTYPE).name}')\n", | |
"get_metrics_dataframe(sample_specs, mp_partial_df, current_tf_partial_df, DTYPE)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 290 | |
}, | |
"outputId": "c3ee83a9-7650-4249-fcbc-d52df3435c91", | |
"id": "z4wnUL24LjHh" | |
}, | |
"execution_count": 69, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"\u001b[1mMethod: Current implementation of partial derivative of CDF wrt df\u001b[0m\n", | |
"Benchmark: Mpmath implementation of partial derivative of CDF wrt df\n", | |
"Dtype: float32\n" | |
] | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
" Min Df Max Df #Trials Min Value Max Value %NaN %Inf Max Abs Error \\\n", | |
"0 0 1 2500 -1.0e-03 1.6e-03 98.20 0 1.4e-05 \n", | |
"1 1 10 2500 -4.6e-08 6.9e-08 100 0 NaN \n", | |
"2 10 30 2500 -5.9e-10 8.9e-10 100 0 NaN \n", | |
"3 30 100 2500 -6.6e-11 9.9e-11 100 0 NaN \n", | |
"4 100 1000 2500 -5.9e-12 8.9e-12 100 0 NaN \n", | |
"5 1000 10000 2500 -5.9e-14 8.9e-14 100 0 NaN \n", | |
"\n", | |
" Mean Abs Error Max Rel Error Mean Rel Error \n", | |
"0 4.6e-06 0.06 0.01 \n", | |
"1 NaN NaN NaN \n", | |
"2 NaN NaN NaN \n", | |
"3 NaN NaN NaN \n", | |
"4 NaN NaN NaN \n", | |
"5 NaN NaN NaN " | |
], | |
"text/html": [ | |
"\n", | |
" <div id=\"df-78037e6e-c42d-4bbc-8041-bd33c2563e8b\">\n", | |
" <div class=\"colab-df-container\">\n", | |
" <div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>Min Df</th>\n", | |
" <th>Max Df</th>\n", | |
" <th>#Trials</th>\n", | |
" <th>Min Value</th>\n", | |
" <th>Max Value</th>\n", | |
" <th>%NaN</th>\n", | |
" <th>%Inf</th>\n", | |
" <th>Max Abs Error</th>\n", | |
" <th>Mean Abs Error</th>\n", | |
" <th>Max Rel Error</th>\n", | |
" <th>Mean Rel Error</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>0</td>\n", | |
" <td>1</td>\n", | |
" <td>2500</td>\n", | |
" <td>-1.0e-03</td>\n", | |
" <td>1.6e-03</td>\n", | |
" <td>98.20</td>\n", | |
" <td>0</td>\n", | |
" <td>1.4e-05</td>\n", | |
" <td>4.6e-06</td>\n", | |
" <td>0.06</td>\n", | |
" <td>0.01</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>1</td>\n", | |
" <td>10</td>\n", | |
" <td>2500</td>\n", | |
" <td>-4.6e-08</td>\n", | |
" <td>6.9e-08</td>\n", | |
" <td>100</td>\n", | |
" <td>0</td>\n", | |
" <td>NaN</td>\n", | |
" <td>NaN</td>\n", | |
" <td>NaN</td>\n", | |
" <td>NaN</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>10</td>\n", | |
" <td>30</td>\n", | |
" <td>2500</td>\n", | |
" <td>-5.9e-10</td>\n", | |
" <td>8.9e-10</td>\n", | |
" <td>100</td>\n", | |
" <td>0</td>\n", | |
" <td>NaN</td>\n", | |
" <td>NaN</td>\n", | |
" <td>NaN</td>\n", | |
" <td>NaN</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>30</td>\n", | |
" <td>100</td>\n", | |
" <td>2500</td>\n", | |
" <td>-6.6e-11</td>\n", | |
" <td>9.9e-11</td>\n", | |
" <td>100</td>\n", | |
" <td>0</td>\n", | |
" <td>NaN</td>\n", | |
" <td>NaN</td>\n", | |
" <td>NaN</td>\n", | |
" <td>NaN</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>100</td>\n", | |
" <td>1000</td>\n", | |
" <td>2500</td>\n", | |
" <td>-5.9e-12</td>\n", | |
" <td>8.9e-12</td>\n", | |
" <td>100</td>\n", | |
" <td>0</td>\n", | |
" <td>NaN</td>\n", | |
" <td>NaN</td>\n", | |
" <td>NaN</td>\n", | |
" <td>NaN</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>5</th>\n", | |
" <td>1000</td>\n", | |
" <td>10000</td>\n", | |
" <td>2500</td>\n", | |
" <td>-5.9e-14</td>\n", | |
" <td>8.9e-14</td>\n", | |
" <td>100</td>\n", | |
" <td>0</td>\n", | |
" <td>NaN</td>\n", | |
" <td>NaN</td>\n", | |
" <td>NaN</td>\n", | |
" <td>NaN</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>\n", | |
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-78037e6e-c42d-4bbc-8041-bd33c2563e8b')\"\n", | |
" title=\"Convert this dataframe to an interactive table.\"\n", | |
" style=\"display:none;\">\n", | |
" \n", | |
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n", | |
" width=\"24px\">\n", | |
" <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n", | |
" <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n", | |
" </svg>\n", | |
" </button>\n", | |
" \n", | |
" <style>\n", | |
" .colab-df-container {\n", | |
" display:flex;\n", | |
" flex-wrap:wrap;\n", | |
" gap: 12px;\n", | |
" }\n", | |
"\n", | |
" .colab-df-convert {\n", | |
" background-color: #E8F0FE;\n", | |
" border: none;\n", | |
" border-radius: 50%;\n", | |
" cursor: pointer;\n", | |
" display: none;\n", | |
" fill: #1967D2;\n", | |
" height: 32px;\n", | |
" padding: 0 0 0 0;\n", | |
" width: 32px;\n", | |
" }\n", | |
"\n", | |
" .colab-df-convert:hover {\n", | |
" background-color: #E2EBFA;\n", | |
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n", | |
" fill: #174EA6;\n", | |
" }\n", | |
"\n", | |
" [theme=dark] .colab-df-convert {\n", | |
" background-color: #3B4455;\n", | |
" fill: #D2E3FC;\n", | |
" }\n", | |
"\n", | |
" [theme=dark] .colab-df-convert:hover {\n", | |
" background-color: #434B5C;\n", | |
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n", | |
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n", | |
" fill: #FFFFFF;\n", | |
" }\n", | |
" </style>\n", | |
"\n", | |
" <script>\n", | |
" const buttonEl =\n", | |
" document.querySelector('#df-78037e6e-c42d-4bbc-8041-bd33c2563e8b button.colab-df-convert');\n", | |
" buttonEl.style.display =\n", | |
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n", | |
"\n", | |
" async function convertToInteractive(key) {\n", | |
" const element = document.querySelector('#df-78037e6e-c42d-4bbc-8041-bd33c2563e8b');\n", | |
" const dataTable =\n", | |
" await google.colab.kernel.invokeFunction('convertToInteractive',\n", | |
" [key], {});\n", | |
" if (!dataTable) return;\n", | |
"\n", | |
" const docLinkHtml = 'Like what you see? Visit the ' +\n", | |
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n", | |
" + ' to learn more about interactive tables.';\n", | |
" element.innerHTML = '';\n", | |
" dataTable['output_type'] = 'display_data';\n", | |
" await google.colab.output.renderOutput(dataTable, element);\n", | |
" const docLink = document.createElement('div');\n", | |
" docLink.innerHTML = docLinkHtml;\n", | |
" element.appendChild(docLink);\n", | |
" }\n", | |
" </script>\n", | |
" </div>\n", | |
" </div>\n", | |
" " | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 69 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"print('\\033[1m' + f'Method: New implementation of partial derivative of CDF wrt df' + '\\033[0m')\n", | |
"print(f'Benchmark: Mpmath implementation of partial derivative of CDF wrt df')\n", | |
"print(f'Dtype: {np.dtype(DTYPE).name}')\n", | |
"get_metrics_dataframe(sample_specs, mp_partial_df, new_tf_partial_df, DTYPE)" | |
], | |
"metadata": { | |
"id": "0lrXCEK1zGN6", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 290 | |
}, | |
"outputId": "21ca5ded-e0c1-4e71-ba43-c5e42c16ae10" | |
}, | |
"execution_count": 70, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"\u001b[1mMethod: New implementation of partial derivative of CDF wrt df\u001b[0m\n", | |
"Benchmark: Mpmath implementation of partial derivative of CDF wrt df\n", | |
"Dtype: float32\n" | |
] | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
" Min Df Max Df #Trials Min Value Max Value %NaN %Inf Max Abs Error \\\n", | |
"0 0 1 2500 -1.0e-03 1.6e-03 0 0 1.9e-05 \n", | |
"1 1 10 2500 -4.6e-08 6.9e-08 0 0 6.9e-08 \n", | |
"2 10 30 2500 -5.9e-10 8.9e-10 0 0 5.9e-15 \n", | |
"3 30 100 2500 -6.6e-11 9.9e-11 0 0 2.4e-15 \n", | |
"4 100 1000 2500 -5.9e-12 8.9e-12 0 0 3.3e-14 \n", | |
"5 1000 10000 2500 -5.9e-14 8.9e-14 0 0 1.5e-14 \n", | |
"\n", | |
" Mean Abs Error Max Rel Error Mean Rel Error \n", | |
"0 4.1e-07 1.00 0.98 \n", | |
"1 4.3e-09 1.00 0.99 \n", | |
"2 6.4e-16 1.00 0.10 \n", | |
"3 2.0e-16 1.00 0.10 \n", | |
"4 7.0e-16 1.00 0.14 \n", | |
"5 1.2e-15 1.00 0.80 " | |
], | |
"text/html": [ | |
"\n", | |
" <div id=\"df-8b37bc14-1dd5-4553-9e72-eeaa41e15197\">\n", | |
" <div class=\"colab-df-container\">\n", | |
" <div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>Min Df</th>\n", | |
" <th>Max Df</th>\n", | |
" <th>#Trials</th>\n", | |
" <th>Min Value</th>\n", | |
" <th>Max Value</th>\n", | |
" <th>%NaN</th>\n", | |
" <th>%Inf</th>\n", | |
" <th>Max Abs Error</th>\n", | |
" <th>Mean Abs Error</th>\n", | |
" <th>Max Rel Error</th>\n", | |
" <th>Mean Rel Error</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>0</td>\n", | |
" <td>1</td>\n", | |
" <td>2500</td>\n", | |
" <td>-1.0e-03</td>\n", | |
" <td>1.6e-03</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>1.9e-05</td>\n", | |
" <td>4.1e-07</td>\n", | |
" <td>1.00</td>\n", | |
" <td>0.98</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>1</td>\n", | |
" <td>10</td>\n", | |
" <td>2500</td>\n", | |
" <td>-4.6e-08</td>\n", | |
" <td>6.9e-08</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>6.9e-08</td>\n", | |
" <td>4.3e-09</td>\n", | |
" <td>1.00</td>\n", | |
" <td>0.99</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>10</td>\n", | |
" <td>30</td>\n", | |
" <td>2500</td>\n", | |
" <td>-5.9e-10</td>\n", | |
" <td>8.9e-10</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>5.9e-15</td>\n", | |
" <td>6.4e-16</td>\n", | |
" <td>1.00</td>\n", | |
" <td>0.10</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>30</td>\n", | |
" <td>100</td>\n", | |
" <td>2500</td>\n", | |
" <td>-6.6e-11</td>\n", | |
" <td>9.9e-11</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>2.4e-15</td>\n", | |
" <td>2.0e-16</td>\n", | |
" <td>1.00</td>\n", | |
" <td>0.10</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>100</td>\n", | |
" <td>1000</td>\n", | |
" <td>2500</td>\n", | |
" <td>-5.9e-12</td>\n", | |
" <td>8.9e-12</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>3.3e-14</td>\n", | |
" <td>7.0e-16</td>\n", | |
" <td>1.00</td>\n", | |
" <td>0.14</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>5</th>\n", | |
" <td>1000</td>\n", | |
" <td>10000</td>\n", | |
" <td>2500</td>\n", | |
" <td>-5.9e-14</td>\n", | |
" <td>8.9e-14</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>1.5e-14</td>\n", | |
" <td>1.2e-15</td>\n", | |
" <td>1.00</td>\n", | |
" <td>0.80</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>\n", | |
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-8b37bc14-1dd5-4553-9e72-eeaa41e15197')\"\n", | |
" title=\"Convert this dataframe to an interactive table.\"\n", | |
" style=\"display:none;\">\n", | |
" \n", | |
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n", | |
" width=\"24px\">\n", | |
" <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n", | |
" <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n", | |
" </svg>\n", | |
" </button>\n", | |
" \n", | |
" <style>\n", | |
" .colab-df-container {\n", | |
" display:flex;\n", | |
" flex-wrap:wrap;\n", | |
" gap: 12px;\n", | |
" }\n", | |
"\n", | |
" .colab-df-convert {\n", | |
" background-color: #E8F0FE;\n", | |
" border: none;\n", | |
" border-radius: 50%;\n", | |
" cursor: pointer;\n", | |
" display: none;\n", | |
" fill: #1967D2;\n", | |
" height: 32px;\n", | |
" padding: 0 0 0 0;\n", | |
" width: 32px;\n", | |
" }\n", | |
"\n", | |
" .colab-df-convert:hover {\n", | |
" background-color: #E2EBFA;\n", | |
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n", | |
" fill: #174EA6;\n", | |
" }\n", | |
"\n", | |
" [theme=dark] .colab-df-convert {\n", | |
" background-color: #3B4455;\n", | |
" fill: #D2E3FC;\n", | |
" }\n", | |
"\n", | |
" [theme=dark] .colab-df-convert:hover {\n", | |
" background-color: #434B5C;\n", | |
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n", | |
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n", | |
" fill: #FFFFFF;\n", | |
" }\n", | |
" </style>\n", | |
"\n", | |
" <script>\n", | |
" const buttonEl =\n", | |
" document.querySelector('#df-8b37bc14-1dd5-4553-9e72-eeaa41e15197 button.colab-df-convert');\n", | |
" buttonEl.style.display =\n", | |
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n", | |
"\n", | |
" async function convertToInteractive(key) {\n", | |
" const element = document.querySelector('#df-8b37bc14-1dd5-4553-9e72-eeaa41e15197');\n", | |
" const dataTable =\n", | |
" await google.colab.kernel.invokeFunction('convertToInteractive',\n", | |
" [key], {});\n", | |
" if (!dataTable) return;\n", | |
"\n", | |
" const docLinkHtml = 'Like what you see? Visit the ' +\n", | |
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n", | |
" + ' to learn more about interactive tables.';\n", | |
" element.innerHTML = '';\n", | |
" dataTable['output_type'] = 'display_data';\n", | |
" await google.colab.output.renderOutput(dataTable, element);\n", | |
" const docLink = document.createElement('div');\n", | |
" docLink.innerHTML = docLinkHtml;\n", | |
" element.appendChild(docLink);\n", | |
" }\n", | |
" </script>\n", | |
" </div>\n", | |
" </div>\n", | |
" " | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 70 | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"### 3.1.3. Partial derivatives of CDF\n", | |
"\n" | |
], | |
"metadata": { | |
"id": "UJ6y7RpPanoB" | |
} | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"Here we compute the maximum difference between autodiff and numerical gradients. We do not run this test for `float32` type nor specifically for values of `t` near zero." | |
], | |
"metadata": { | |
"id": "YzphmoB_anoC" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"space_df = np.logspace(np.log10(0.5), 8., num=15, base=10).tolist()\n", | |
"space_p = np.linspace(0.01, 0.99, num=15).tolist()\n", | |
"df, p = zip(*list(itertools.product(space_df, space_p)))\n", | |
"\n", | |
"df = np.array(df, dtype=np.float64)\n", | |
"p = np.array(p, dtype=np.float64)\n", | |
"t = sp_special.stdtrit(df, p)" | |
], | |
"metadata": { | |
"id": "dhi-C89qanoC" | |
}, | |
"execution_count": 71, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"t.shape == p.shape, t.dtype == p.dtype" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "P8eQAH4Ca_Z2", | |
"outputId": "29602525-5ed0-465c-8281-e92c6472fa7a" | |
}, | |
"execution_count": 72, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"(True, True)" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 72 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"test_case = test_util.TestCase()" | |
], | |
"metadata": { | |
"id": "LZC87JM-anoD" | |
}, | |
"execution_count": 73, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"**With respect to `df`**" | |
], | |
"metadata": { | |
"id": "PHIHmEgxanoD" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"test_case.compute_max_gradient_error(\n", | |
" lambda _df: stdtr(_df, t), [df], delta=1e-5)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"outputId": "07f6df93-3d46-4f4f-ed68-5edfcffe7c47", | |
"id": "__GZuebCanoE" | |
}, | |
"execution_count": 74, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"1.0521331420302541e-10" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 74 | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"**With respect to `t`**" | |
], | |
"metadata": { | |
"id": "qqnQ7_JBanoE" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"test_case.compute_max_gradient_error(\n", | |
" lambda _t: stdtr(df, _t), [t], delta=1e-5)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"outputId": "3a0185d7-c0ca-44a0-d86b-e44591b60bf6", | |
"id": "JXridSU_anoF" | |
}, | |
"execution_count": 75, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"1.5703188749327524e-11" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 75 | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## 3.2. Test results for Quantile" | |
], | |
"metadata": { | |
"id": "ywzWraJXXOVP" | |
} | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"The SciPy implementation of StudentT quantile function, `sp_special.stdtrit`, is not very accurate. See some examples:\n" | |
], | |
"metadata": { | |
"id": "C6CX1ZpQPZIA" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"numpy_dtype = np.float64" | |
], | |
"metadata": { | |
"id": "OuTjwf6mUqxY" | |
}, | |
"execution_count": 76, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"df = numpy_dtype(0.1)\n", | |
"p = np.finfo(numpy_dtype).eps\n", | |
"t = sp_special.stdtrit(df, p)\n", | |
"relerr = np.abs((p - sp_special.stdtr(df, t)) / p)\n", | |
"relerr" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "5q5e6wc7UtYM", | |
"outputId": "2300f969-2128-4a9c-8c57-650202a7b07e" | |
}, | |
"execution_count": 77, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"187970.38253291088" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 77 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"df = numpy_dtype(9.)\n", | |
"p = np.finfo(numpy_dtype).eps\n", | |
"t = sp_special.stdtrit(df, p)\n", | |
"relerr = np.abs((p - sp_special.stdtr(df, t)) / p)\n", | |
"relerr" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "j8JiSzEOQbvT", | |
"outputId": "68612247-7afd-4ed4-c334-e464e0588fd4" | |
}, | |
"execution_count": 78, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"1.3655008457291729e-08" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 78 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"df = numpy_dtype(534)\n", | |
"p = np.finfo(numpy_dtype).eps\n", | |
"t = sp_special.stdtrit(df, p)\n", | |
"relerr = np.abs((p - sp_special.stdtr(df, t)) / p)\n", | |
"relerr" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "bby547ffX_eJ", | |
"outputId": "2f5133fc-cab6-404e-ccac-ff162b4a966a" | |
}, | |
"execution_count": 79, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"1.5639093375874324e-07" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 79 | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"Since `sp_special.stdtrit(df, p)` is not very accurate, we do not want to compare the current and the new implementation of StudentT quantile function to it. Instead, we are going to compare `sp_special.stdtr(df, t*)` to `p`, where `t*` is the computed quantile." | |
], | |
"metadata": { | |
"id": "DIaCoTfsTBFj" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"truths = [p for (df, p, t) in samples]" | |
], | |
"metadata": { | |
"id": "5zK60Go8mK_r" | |
}, | |
"execution_count": 80, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"sp_quantile = []\n", | |
"\n", | |
"for i, (df, p, _) in enumerate(samples):\n", | |
" sp_t = sp_special.stdtrit(df, p)\n", | |
" result = sp_special.stdtr(df, sp_t)\n", | |
" sp_quantile.append(result)\n", | |
" print(f'sample {i} is done.')" | |
], | |
"metadata": { | |
"id": "566FWGrKV-cG", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"outputId": "ea456e9d-9b34-40db-e814-951f72d5d969" | |
}, | |
"execution_count": 81, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"sample 0 is done.\n", | |
"sample 1 is done.\n", | |
"sample 2 is done.\n", | |
"sample 3 is done.\n", | |
"sample 4 is done.\n", | |
"sample 5 is done.\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"current_tf_quantile = []\n", | |
"\n", | |
"for i, (df, p, _) in enumerate(samples):\n", | |
" current_tf_t = current_quantile(df, p)\n", | |
" result = sp_special.stdtr(df, current_tf_t)\n", | |
" current_tf_quantile.append(result)\n", | |
" print(f'sample {i} is done.')" | |
], | |
"metadata": { | |
"id": "Glc3Cpqq2HT3", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"outputId": "07a29bd0-d8d3-43f9-a59b-f222c6067243" | |
}, | |
"execution_count": 82, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"sample 0 is done.\n", | |
"sample 1 is done.\n", | |
"sample 2 is done.\n", | |
"sample 3 is done.\n", | |
"sample 4 is done.\n", | |
"sample 5 is done.\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"new_tf_quantile = []\n", | |
"\n", | |
"for i, (df, p, _) in enumerate(samples):\n", | |
" new_tf_t = stdtrit(df, p)\n", | |
" result = sp_special.stdtr(df, new_tf_t)\n", | |
" new_tf_quantile.append(result)\n", | |
" print(f'sample {i} is done.')" | |
], | |
"metadata": { | |
"id": "B3zoMKP-KFPw", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"outputId": "f92894ca-806b-4744-d72c-5446e8638b5a" | |
}, | |
"execution_count": 83, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"sample 0 is done.\n", | |
"sample 1 is done.\n", | |
"sample 2 is done.\n", | |
"sample 3 is done.\n", | |
"sample 4 is done.\n", | |
"sample 5 is done.\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"print('\\033[1m' + f'Method: Scipy implementation of Quantile' + '\\033[0m')\n", | |
"print(f'Benchmark: True values of p')\n", | |
"print(f'Dtype: {np.dtype(DTYPE).name}')\n", | |
"get_metrics_dataframe(sample_specs, truths, sp_quantile, DTYPE)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 290 | |
}, | |
"id": "YvbZ6esWmW8c", | |
"outputId": "d276ead8-1f30-4df8-d7f8-c1dbe176d473" | |
}, | |
"execution_count": 84, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"\u001b[1mMethod: Scipy implementation of Quantile\u001b[0m\n", | |
"Benchmark: True values of p\n", | |
"Dtype: float32\n" | |
] | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
" Min Df Max Df #Trials Min Value Max Value %NaN %Inf Max Abs Error \\\n", | |
"0 0 1 1000 0.50 0.50 0 0 0 \n", | |
"1 1 10 1000 0.50 0.50 0 0 0 \n", | |
"2 10 30 1000 0.50 0.50 0 0 0 \n", | |
"3 30 100 1000 0.50 0.50 0 0 0 \n", | |
"4 100 1000 1000 0.50 0.50 0 0 0 \n", | |
"5 1000 10000 1000 0.50 0.50 0 0 0 \n", | |
"\n", | |
" Mean Abs Error Max Rel Error Mean Rel Error \n", | |
"0 0 0 0 \n", | |
"1 0 0 0 \n", | |
"2 0 0 0 \n", | |
"3 0 0 0 \n", | |
"4 0 0 0 \n", | |
"5 0 0 0 " | |
], | |
"text/html": [ | |
"\n", | |
" <div id=\"df-1b8d51c7-edd4-45c2-9fb2-ae4c9e805f88\">\n", | |
" <div class=\"colab-df-container\">\n", | |
" <div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>Min Df</th>\n", | |
" <th>Max Df</th>\n", | |
" <th>#Trials</th>\n", | |
" <th>Min Value</th>\n", | |
" <th>Max Value</th>\n", | |
" <th>%NaN</th>\n", | |
" <th>%Inf</th>\n", | |
" <th>Max Abs Error</th>\n", | |
" <th>Mean Abs Error</th>\n", | |
" <th>Max Rel Error</th>\n", | |
" <th>Mean Rel Error</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>0</td>\n", | |
" <td>1</td>\n", | |
" <td>1000</td>\n", | |
" <td>0.50</td>\n", | |
" <td>0.50</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>1</td>\n", | |
" <td>10</td>\n", | |
" <td>1000</td>\n", | |
" <td>0.50</td>\n", | |
" <td>0.50</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>10</td>\n", | |
" <td>30</td>\n", | |
" <td>1000</td>\n", | |
" <td>0.50</td>\n", | |
" <td>0.50</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>30</td>\n", | |
" <td>100</td>\n", | |
" <td>1000</td>\n", | |
" <td>0.50</td>\n", | |
" <td>0.50</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>100</td>\n", | |
" <td>1000</td>\n", | |
" <td>1000</td>\n", | |
" <td>0.50</td>\n", | |
" <td>0.50</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>5</th>\n", | |
" <td>1000</td>\n", | |
" <td>10000</td>\n", | |
" <td>1000</td>\n", | |
" <td>0.50</td>\n", | |
" <td>0.50</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>\n", | |
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-1b8d51c7-edd4-45c2-9fb2-ae4c9e805f88')\"\n", | |
" title=\"Convert this dataframe to an interactive table.\"\n", | |
" style=\"display:none;\">\n", | |
" \n", | |
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n", | |
" width=\"24px\">\n", | |
" <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n", | |
" <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n", | |
" </svg>\n", | |
" </button>\n", | |
" \n", | |
" <style>\n", | |
" .colab-df-container {\n", | |
" display:flex;\n", | |
" flex-wrap:wrap;\n", | |
" gap: 12px;\n", | |
" }\n", | |
"\n", | |
" .colab-df-convert {\n", | |
" background-color: #E8F0FE;\n", | |
" border: none;\n", | |
" border-radius: 50%;\n", | |
" cursor: pointer;\n", | |
" display: none;\n", | |
" fill: #1967D2;\n", | |
" height: 32px;\n", | |
" padding: 0 0 0 0;\n", | |
" width: 32px;\n", | |
" }\n", | |
"\n", | |
" .colab-df-convert:hover {\n", | |
" background-color: #E2EBFA;\n", | |
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n", | |
" fill: #174EA6;\n", | |
" }\n", | |
"\n", | |
" [theme=dark] .colab-df-convert {\n", | |
" background-color: #3B4455;\n", | |
" fill: #D2E3FC;\n", | |
" }\n", | |
"\n", | |
" [theme=dark] .colab-df-convert:hover {\n", | |
" background-color: #434B5C;\n", | |
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n", | |
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n", | |
" fill: #FFFFFF;\n", | |
" }\n", | |
" </style>\n", | |
"\n", | |
" <script>\n", | |
" const buttonEl =\n", | |
" document.querySelector('#df-1b8d51c7-edd4-45c2-9fb2-ae4c9e805f88 button.colab-df-convert');\n", | |
" buttonEl.style.display =\n", | |
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n", | |
"\n", | |
" async function convertToInteractive(key) {\n", | |
" const element = document.querySelector('#df-1b8d51c7-edd4-45c2-9fb2-ae4c9e805f88');\n", | |
" const dataTable =\n", | |
" await google.colab.kernel.invokeFunction('convertToInteractive',\n", | |
" [key], {});\n", | |
" if (!dataTable) return;\n", | |
"\n", | |
" const docLinkHtml = 'Like what you see? Visit the ' +\n", | |
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n", | |
" + ' to learn more about interactive tables.';\n", | |
" element.innerHTML = '';\n", | |
" dataTable['output_type'] = 'display_data';\n", | |
" await google.colab.output.renderOutput(dataTable, element);\n", | |
" const docLink = document.createElement('div');\n", | |
" docLink.innerHTML = docLinkHtml;\n", | |
" element.appendChild(docLink);\n", | |
" }\n", | |
" </script>\n", | |
" </div>\n", | |
" </div>\n", | |
" " | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 84 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"print('\\033[1m' + f'Method: Current implementation of Quantile' + '\\033[0m')\n", | |
"print(f'Benchmark: True values of p')\n", | |
"print(f'Dtype: {np.dtype(DTYPE).name}')\n", | |
"get_metrics_dataframe(sample_specs, truths, current_tf_quantile, DTYPE)" | |
], | |
"metadata": { | |
"id": "ujwolwbQ2Qmy", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 290 | |
}, | |
"outputId": "f3367af6-6282-4f02-db80-1c84f7eb20d3" | |
}, | |
"execution_count": 85, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"\u001b[1mMethod: Current implementation of Quantile\u001b[0m\n", | |
"Benchmark: True values of p\n", | |
"Dtype: float32\n" | |
] | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
" Min Df Max Df #Trials Min Value Max Value %NaN %Inf Max Abs Error \\\n", | |
"0 0 1 1000 0.50 0.50 0 0 1.1e-04 \n", | |
"1 1 10 1000 0.50 0.50 0 0 3.0e-04 \n", | |
"2 10 30 1000 0.50 0.50 0 0 5.3e-04 \n", | |
"3 30 100 1000 0.50 0.50 0 0 9.7e-04 \n", | |
"4 100 1000 1000 0.50 0.50 0 0 3.1e-03 \n", | |
"5 1000 10000 1000 0.50 0.50 0 0 9.7e-03 \n", | |
"\n", | |
" Mean Abs Error Max Rel Error Mean Rel Error \n", | |
"0 5.9e-05 2.2e-04 1.2e-04 \n", | |
"1 2.1e-04 6.0e-04 4.2e-04 \n", | |
"2 4.2e-04 1.1e-03 8.4e-04 \n", | |
"3 7.7e-04 1.9e-03 1.5e-03 \n", | |
"4 2.2e-03 6.1e-03 4.4e-03 \n", | |
"5 6.9e-03 0.02 0.01 " | |
], | |
"text/html": [ | |
"\n", | |
" <div id=\"df-52e5cabb-4bf5-4a2c-a00e-5a3439221de9\">\n", | |
" <div class=\"colab-df-container\">\n", | |
" <div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>Min Df</th>\n", | |
" <th>Max Df</th>\n", | |
" <th>#Trials</th>\n", | |
" <th>Min Value</th>\n", | |
" <th>Max Value</th>\n", | |
" <th>%NaN</th>\n", | |
" <th>%Inf</th>\n", | |
" <th>Max Abs Error</th>\n", | |
" <th>Mean Abs Error</th>\n", | |
" <th>Max Rel Error</th>\n", | |
" <th>Mean Rel Error</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>0</td>\n", | |
" <td>1</td>\n", | |
" <td>1000</td>\n", | |
" <td>0.50</td>\n", | |
" <td>0.50</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>1.1e-04</td>\n", | |
" <td>5.9e-05</td>\n", | |
" <td>2.2e-04</td>\n", | |
" <td>1.2e-04</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>1</td>\n", | |
" <td>10</td>\n", | |
" <td>1000</td>\n", | |
" <td>0.50</td>\n", | |
" <td>0.50</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>3.0e-04</td>\n", | |
" <td>2.1e-04</td>\n", | |
" <td>6.0e-04</td>\n", | |
" <td>4.2e-04</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>10</td>\n", | |
" <td>30</td>\n", | |
" <td>1000</td>\n", | |
" <td>0.50</td>\n", | |
" <td>0.50</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>5.3e-04</td>\n", | |
" <td>4.2e-04</td>\n", | |
" <td>1.1e-03</td>\n", | |
" <td>8.4e-04</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>30</td>\n", | |
" <td>100</td>\n", | |
" <td>1000</td>\n", | |
" <td>0.50</td>\n", | |
" <td>0.50</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>9.7e-04</td>\n", | |
" <td>7.7e-04</td>\n", | |
" <td>1.9e-03</td>\n", | |
" <td>1.5e-03</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>100</td>\n", | |
" <td>1000</td>\n", | |
" <td>1000</td>\n", | |
" <td>0.50</td>\n", | |
" <td>0.50</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>3.1e-03</td>\n", | |
" <td>2.2e-03</td>\n", | |
" <td>6.1e-03</td>\n", | |
" <td>4.4e-03</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>5</th>\n", | |
" <td>1000</td>\n", | |
" <td>10000</td>\n", | |
" <td>1000</td>\n", | |
" <td>0.50</td>\n", | |
" <td>0.50</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>9.7e-03</td>\n", | |
" <td>6.9e-03</td>\n", | |
" <td>0.02</td>\n", | |
" <td>0.01</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>\n", | |
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-52e5cabb-4bf5-4a2c-a00e-5a3439221de9')\"\n", | |
" title=\"Convert this dataframe to an interactive table.\"\n", | |
" style=\"display:none;\">\n", | |
" \n", | |
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n", | |
" width=\"24px\">\n", | |
" <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n", | |
" <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n", | |
" </svg>\n", | |
" </button>\n", | |
" \n", | |
" <style>\n", | |
" .colab-df-container {\n", | |
" display:flex;\n", | |
" flex-wrap:wrap;\n", | |
" gap: 12px;\n", | |
" }\n", | |
"\n", | |
" .colab-df-convert {\n", | |
" background-color: #E8F0FE;\n", | |
" border: none;\n", | |
" border-radius: 50%;\n", | |
" cursor: pointer;\n", | |
" display: none;\n", | |
" fill: #1967D2;\n", | |
" height: 32px;\n", | |
" padding: 0 0 0 0;\n", | |
" width: 32px;\n", | |
" }\n", | |
"\n", | |
" .colab-df-convert:hover {\n", | |
" background-color: #E2EBFA;\n", | |
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n", | |
" fill: #174EA6;\n", | |
" }\n", | |
"\n", | |
" [theme=dark] .colab-df-convert {\n", | |
" background-color: #3B4455;\n", | |
" fill: #D2E3FC;\n", | |
" }\n", | |
"\n", | |
" [theme=dark] .colab-df-convert:hover {\n", | |
" background-color: #434B5C;\n", | |
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n", | |
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n", | |
" fill: #FFFFFF;\n", | |
" }\n", | |
" </style>\n", | |
"\n", | |
" <script>\n", | |
" const buttonEl =\n", | |
" document.querySelector('#df-52e5cabb-4bf5-4a2c-a00e-5a3439221de9 button.colab-df-convert');\n", | |
" buttonEl.style.display =\n", | |
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n", | |
"\n", | |
" async function convertToInteractive(key) {\n", | |
" const element = document.querySelector('#df-52e5cabb-4bf5-4a2c-a00e-5a3439221de9');\n", | |
" const dataTable =\n", | |
" await google.colab.kernel.invokeFunction('convertToInteractive',\n", | |
" [key], {});\n", | |
" if (!dataTable) return;\n", | |
"\n", | |
" const docLinkHtml = 'Like what you see? Visit the ' +\n", | |
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n", | |
" + ' to learn more about interactive tables.';\n", | |
" element.innerHTML = '';\n", | |
" dataTable['output_type'] = 'display_data';\n", | |
" await google.colab.output.renderOutput(dataTable, element);\n", | |
" const docLink = document.createElement('div');\n", | |
" docLink.innerHTML = docLinkHtml;\n", | |
" element.appendChild(docLink);\n", | |
" }\n", | |
" </script>\n", | |
" </div>\n", | |
" </div>\n", | |
" " | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 85 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"print('\\033[1m' + f'Method: New implementation of Quantile' + '\\033[0m')\n", | |
"print(f'Benchmark: True values of p')\n", | |
"print(f'Dtype: {np.dtype(DTYPE).name}')\n", | |
"get_metrics_dataframe(sample_specs, truths, new_tf_quantile, DTYPE)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 290 | |
}, | |
"id": "7QVVc1e9VQk4", | |
"outputId": "df877738-e234-4d79-f02d-ac68415ad52e" | |
}, | |
"execution_count": 86, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"\u001b[1mMethod: New implementation of Quantile\u001b[0m\n", | |
"Benchmark: True values of p\n", | |
"Dtype: float32\n" | |
] | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
" Min Df Max Df #Trials Min Value Max Value %NaN %Inf Max Abs Error \\\n", | |
"0 0 1 1000 0.50 0.50 0 0 6.0e-08 \n", | |
"1 1 10 1000 0.50 0.50 0 0 0 \n", | |
"2 10 30 1000 0.50 0.50 0 0 0 \n", | |
"3 30 100 1000 0.50 0.50 0 0 0 \n", | |
"4 100 1000 1000 0.50 0.50 0 0 0 \n", | |
"5 1000 10000 1000 0.50 0.50 0 0 0 \n", | |
"\n", | |
" Mean Abs Error Max Rel Error Mean Rel Error \n", | |
"0 4.6e-12 1.2e-07 9.1e-12 \n", | |
"1 0 0 0 \n", | |
"2 0 0 0 \n", | |
"3 0 0 0 \n", | |
"4 0 0 0 \n", | |
"5 0 0 0 " | |
], | |
"text/html": [ | |
"\n", | |
" <div id=\"df-a29e6918-f234-4594-afef-cc959aaced55\">\n", | |
" <div class=\"colab-df-container\">\n", | |
" <div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>Min Df</th>\n", | |
" <th>Max Df</th>\n", | |
" <th>#Trials</th>\n", | |
" <th>Min Value</th>\n", | |
" <th>Max Value</th>\n", | |
" <th>%NaN</th>\n", | |
" <th>%Inf</th>\n", | |
" <th>Max Abs Error</th>\n", | |
" <th>Mean Abs Error</th>\n", | |
" <th>Max Rel Error</th>\n", | |
" <th>Mean Rel Error</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>0</td>\n", | |
" <td>1</td>\n", | |
" <td>1000</td>\n", | |
" <td>0.50</td>\n", | |
" <td>0.50</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>6.0e-08</td>\n", | |
" <td>4.6e-12</td>\n", | |
" <td>1.2e-07</td>\n", | |
" <td>9.1e-12</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>1</td>\n", | |
" <td>10</td>\n", | |
" <td>1000</td>\n", | |
" <td>0.50</td>\n", | |
" <td>0.50</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>10</td>\n", | |
" <td>30</td>\n", | |
" <td>1000</td>\n", | |
" <td>0.50</td>\n", | |
" <td>0.50</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>30</td>\n", | |
" <td>100</td>\n", | |
" <td>1000</td>\n", | |
" <td>0.50</td>\n", | |
" <td>0.50</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>100</td>\n", | |
" <td>1000</td>\n", | |
" <td>1000</td>\n", | |
" <td>0.50</td>\n", | |
" <td>0.50</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>5</th>\n", | |
" <td>1000</td>\n", | |
" <td>10000</td>\n", | |
" <td>1000</td>\n", | |
" <td>0.50</td>\n", | |
" <td>0.50</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>\n", | |
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-a29e6918-f234-4594-afef-cc959aaced55')\"\n", | |
" title=\"Convert this dataframe to an interactive table.\"\n", | |
" style=\"display:none;\">\n", | |
" \n", | |
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n", | |
" width=\"24px\">\n", | |
" <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n", | |
" <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n", | |
" </svg>\n", | |
" </button>\n", | |
" \n", | |
" <style>\n", | |
" .colab-df-container {\n", | |
" display:flex;\n", | |
" flex-wrap:wrap;\n", | |
" gap: 12px;\n", | |
" }\n", | |
"\n", | |
" .colab-df-convert {\n", | |
" background-color: #E8F0FE;\n", | |
" border: none;\n", | |
" border-radius: 50%;\n", | |
" cursor: pointer;\n", | |
" display: none;\n", | |
" fill: #1967D2;\n", | |
" height: 32px;\n", | |
" padding: 0 0 0 0;\n", | |
" width: 32px;\n", | |
" }\n", | |
"\n", | |
" .colab-df-convert:hover {\n", | |
" background-color: #E2EBFA;\n", | |
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n", | |
" fill: #174EA6;\n", | |
" }\n", | |
"\n", | |
" [theme=dark] .colab-df-convert {\n", | |
" background-color: #3B4455;\n", | |
" fill: #D2E3FC;\n", | |
" }\n", | |
"\n", | |
" [theme=dark] .colab-df-convert:hover {\n", | |
" background-color: #434B5C;\n", | |
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n", | |
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n", | |
" fill: #FFFFFF;\n", | |
" }\n", | |
" </style>\n", | |
"\n", | |
" <script>\n", | |
" const buttonEl =\n", | |
" document.querySelector('#df-a29e6918-f234-4594-afef-cc959aaced55 button.colab-df-convert');\n", | |
" buttonEl.style.display =\n", | |
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n", | |
"\n", | |
" async function convertToInteractive(key) {\n", | |
" const element = document.querySelector('#df-a29e6918-f234-4594-afef-cc959aaced55');\n", | |
" const dataTable =\n", | |
" await google.colab.kernel.invokeFunction('convertToInteractive',\n", | |
" [key], {});\n", | |
" if (!dataTable) return;\n", | |
"\n", | |
" const docLinkHtml = 'Like what you see? Visit the ' +\n", | |
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n", | |
" + ' to learn more about interactive tables.';\n", | |
" element.innerHTML = '';\n", | |
" dataTable['output_type'] = 'display_data';\n", | |
" await google.colab.output.renderOutput(dataTable, element);\n", | |
" const docLink = document.createElement('div');\n", | |
" docLink.innerHTML = docLinkHtml;\n", | |
" element.appendChild(docLink);\n", | |
" }\n", | |
" </script>\n", | |
" </div>\n", | |
" </div>\n", | |
" " | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 86 | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"### 3.2.1. Partial derivatives of Quantile\n", | |
"\n" | |
], | |
"metadata": { | |
"id": "x1itZf8107oz" | |
} | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"For `quantile`, we only compute the maximum difference between autodiff and numerical gradients. We do not run this test for `float32` type nor specifically for values of `p` near half.\n", | |
"\n", | |
"Remember that we used the fact that `stdtr` and `stdtrit` are inverses of each other to compute the gradients of the latter from the gradients of the former.\n", | |
"\n" | |
], | |
"metadata": { | |
"id": "nyUTaPPu3325" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"space_df = np.logspace(np.log10(0.5), 8., num=15, base=10).tolist()\n", | |
"space_p = np.linspace(0.01, 0.99, num=15).tolist()\n", | |
"df, p = zip(*list(itertools.product(space_df, space_p)))\n", | |
"\n", | |
"df = np.array(df, dtype=np.float64)\n", | |
"p = np.array(p, dtype=np.float64)" | |
], | |
"metadata": { | |
"id": "QAKKD04Sx6Ok" | |
}, | |
"execution_count": 87, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"test_case = test_util.TestCase()" | |
], | |
"metadata": { | |
"id": "DpiAEsaZyukT" | |
}, | |
"execution_count": 88, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"**With respect to `df`**" | |
], | |
"metadata": { | |
"id": "y4JHk6063W8B" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"test_case.compute_max_gradient_error(\n", | |
" lambda _df: stdtrit(_df, p), [df], delta=1e-6)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "h-ZYzEuKzews", | |
"outputId": "08561623-34cf-4f92-8897-dc4f736acb03" | |
}, | |
"execution_count": 89, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"7.160779205150902e-07" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 89 | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"**With respect to `p`**" | |
], | |
"metadata": { | |
"id": "VrnuLDro3l4k" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"test_case.compute_max_gradient_error(\n", | |
" lambda _p: stdtrit(df, _p), [p], delta=1e-7)" | |
], | |
"metadata": { | |
"id": "b2nnaFStLjHZ", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"outputId": "44ce0f6f-8fd2-466e-ea45-0dd9ab08a86a" | |
}, | |
"execution_count": 90, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"6.231263978406787e-05" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 90 | |
} | |
] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment