Skip to content

Instantly share code, notes, and snippets.

@leandrolcampos
Created October 19, 2022 23:07
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save leandrolcampos/118e7b3e129d4a4f32f366d366ecee7d to your computer and use it in GitHub Desktop.
Save leandrolcampos/118e7b3e129d4a4f32f366d366ecee7d to your computer and use it in GitHub Desktop.
student_t_f32_tiny_t.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"collapsed_sections": [],
"authorship_tag": "ABX9TyOZ4zMPEvVV54LFbi4AYFlG",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/leandrolcampos/118e7b3e129d4a4f32f366d366ecee7d/student_t_f32_tiny_t.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "tl_dwv6lCTMU"
},
"outputs": [],
"source": [
"%pip install -U -q tf-nightly tfp-nightly"
]
},
{
"cell_type": "code",
"source": [
"import functools\n",
"import itertools\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
"import scipy.stats as sp_stats\n",
"import scipy.special as sp_special\n",
"import scipy.integrate as sp_integrate\n",
"import tensorflow as tf\n",
"import tensorflow_probability as tfp\n",
"\n",
"from mpmath import mp\n",
"from tensorflow_probability.python.distributions import student_t\n",
"from tensorflow_probability.python.internal import custom_gradient as tfp_custom_gradient\n",
"from tensorflow_probability.python.internal import dtype_util\n",
"from tensorflow_probability.python.internal import prefer_static as ps\n",
"from tensorflow_probability.python.internal import special_math\n",
"from tensorflow_probability.python.internal import test_util\n",
"from tensorflow_probability.python.math import generic\n",
"from tensorflow_probability.python.math import gradient\n",
"from tensorflow_probability.python.math import special\n",
"from tensorflow_probability.python.math.numeric import log1psquare"
],
"metadata": {
"id": "0FJ_5e6nCa66"
},
"execution_count": 2,
"outputs": []
},
{
"cell_type": "code",
"source": [
"tf.__version__"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
},
"id": "n1v3xnEuTeo9",
"outputId": "106418fb-7424-426c-9aff-0968f1a90d93"
},
"execution_count": 3,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"'2.12.0-dev20221019'"
],
"application/vnd.google.colaboratory.intrinsic+json": {
"type": "string"
}
},
"metadata": {},
"execution_count": 3
}
]
},
{
"cell_type": "code",
"source": [
"tfp.__version__"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
},
"id": "nOTvfuTlT4iS",
"outputId": "7c0bc601-a970-45b2-c576-548fd11f04f5"
},
"execution_count": 4,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"'0.19.0-dev20221019'"
],
"application/vnd.google.colaboratory.intrinsic+json": {
"type": "string"
}
},
"metadata": {},
"execution_count": 4
}
]
},
{
"cell_type": "code",
"source": [
"mp.dps = 25; mp.pretty = True"
],
"metadata": {
"id": "OUE6yZKHi1xh"
},
"execution_count": 5,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# 1. Implementations"
],
"metadata": {
"id": "JmBhtIFvSleU"
}
},
{
"cell_type": "code",
"source": [
"def log_prob(x, df, loc, scale):\n",
"  \"\"\"Evaluates the log-density of a Student's t-distribution at `x`.\n",
"\n",
"  Adapted from tfp.distributions.student_t.log_prob.\n",
"\n",
"  Args:\n",
"    x: Value(s) at which the log-density is evaluated.\n",
"    df: Degrees of freedom.\n",
"    loc: Location parameter.\n",
"    scale: Scale parameter.\n",
"\n",
"  Returns:\n",
"    A `Tensor` in the common dtype of the arguments (`float32` by default).\n",
"  \"\"\"\n",
"  dtype = dtype_util.common_dtype([x, df, loc, scale], tf.float32)\n",
"  as_scalar = dtype_util.as_numpy_dtype(dtype)\n",
"  half = as_scalar(0.5)\n",
"\n",
"  x = tf.convert_to_tensor(x, dtype=dtype)\n",
"  df = tf.convert_to_tensor(df, dtype=dtype)\n",
"  loc = tf.convert_to_tensor(loc, dtype=dtype)\n",
"  scale = tf.convert_to_tensor(scale, dtype=dtype)\n",
"\n",
"  # Writing the scaled residual this way reduces XLA mem copies.\n",
"  scaled_residual = (x - loc) * (tf.math.rsqrt(df) / scale)\n",
"  kernel_term = -half * (df + as_scalar(1.)) * log1psquare(scaled_residual)\n",
"  normalizer_term = (\n",
"      tf.math.log(tf.abs(scale)) + half * tf.math.log(df) +\n",
"      special.lbeta(half, half * df))\n",
"  return kernel_term - normalizer_term"
],
"metadata": {
"id": "ou_FdK019YDV"
},
"execution_count": 6,
"outputs": []
},
{
"cell_type": "code",
"source": [
"@tf.function(autograph=False)\n",
"def current_cdf(df, t):\n",
"  \"\"\"Baseline CDF adapted from tfp.distributions.student_t.cdf.\n",
"\n",
"  Kept only as a reference to compare against `stdtr`.\n",
"\n",
"  Args:\n",
"    df: Degrees of freedom.\n",
"    t: Point(s) at which the CDF is evaluated.\n",
"\n",
"  Returns:\n",
"    A `Tensor` with the broadcast shape of `df` and `t`.\n",
"  \"\"\"\n",
"  dtype = dtype_util.common_dtype([df, t], tf.float32)\n",
"  as_scalar = dtype_util.as_numpy_dtype(dtype)\n",
"  half = as_scalar(0.5)\n",
"  one = as_scalar(1.)\n",
"\n",
"  df = tf.convert_to_tensor(df, dtype=dtype)\n",
"  t = tf.convert_to_tensor(t, dtype=dtype)\n",
"  common_shape = ps.broadcast_shape(ps.shape(df), ps.shape(t))\n",
"  df = tf.broadcast_to(df, common_shape)\n",
"  t = tf.broadcast_to(t, common_shape)\n",
"\n",
"  # Lower-tail mass through the regularized incomplete beta function.\n",
"  beta_argument = df / (tf.math.square(t) + df)\n",
"  lower_tail = half * special.betainc(half * df, half, beta_argument)\n",
"  return tf.where(t < 0., lower_tail, one - lower_tail)"
],
"metadata": {
"id": "B_buX20TS28E"
},
"execution_count": 7,
"outputs": []
},
{
"cell_type": "code",
"source": [
"@tf.function(autograph=False)\n",
"def current_quantile(df, p):\n",
"  \"\"\"Baseline quantile adapted from tfp.distributions.student_t.quantile.\n",
"\n",
"  Kept only as a reference to compare against `stdtrit`.\n",
"\n",
"  Args:\n",
"    df: Degrees of freedom.\n",
"    p: Probability level(s).\n",
"\n",
"  Returns:\n",
"    A `Tensor` with the broadcast shape of `df` and `p`.\n",
"  \"\"\"\n",
"  dtype = dtype_util.common_dtype([df, p], tf.float32)\n",
"  as_scalar = dtype_util.as_numpy_dtype(dtype)\n",
"  half = as_scalar(0.5)\n",
"  one = as_scalar(1.)\n",
"  two = as_scalar(2.)\n",
"\n",
"  df = tf.convert_to_tensor(df, dtype=dtype)\n",
"  p = tf.convert_to_tensor(p, dtype=dtype)\n",
"  common_shape = ps.broadcast_shape(ps.shape(df), ps.shape(p))\n",
"  df = tf.broadcast_to(df, common_shape)\n",
"  p = tf.broadcast_to(p, common_shape)\n",
"\n",
"  # Fold p onto the lower tail; the sign is restored on return.\n",
"  tail_p = tf.where(p < half, p, one - p)\n",
"  x = special.betaincinv(half * df, half, two * tail_p)\n",
"\n",
"  log_abs_t = (\n",
"      tf.math.xlogy(half, df) + tf.math.xlog1py(half, -x) -\n",
"      tf.math.xlogy(half, x))\n",
"  abs_t = tf.math.exp(log_abs_t)\n",
"\n",
"  return tf.math.sign(p - half) * abs_t"
],
"metadata": {
"id": "wy-Nt83t1tux"
},
"execution_count": 8,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# The implementations of the Student's t-distribution cumulative distribution\n",
"# function and its inverse, respectively stdtr(df, t) and stdtrit(df, p), are\n",
"# based on ideas and equations available in the following references:\n",
"# [1] Geoffrey W. Hill\n",
"# Algorithm 395: Student's t-distribution\n",
"# Communications of the ACM, v. 13, n. 10, p. 617-619, 1970\n",
"# https://doi.org/10.1145/355598.362775\n",
"# [2] Geoffrey W. Hill\n",
"# Remark on \"Algorithm 395: Student's t-Distribution [S14]\"\n",
"# ACM Transactions on Mathematical Software, v. 7, n. 2, p. 247-249, 1981\n",
"# https://doi.org/10.1145/355945.355955\n",
"# [3] William Press, Saul Teukolsky, William Vetterling and Brian Flannery\n",
"# Numerical Recipes: The Art of Scientific Computing\n",
"# Cambridge University Press, 2007 (Third Edition)\n",
"# http://numerical.recipes/book/book.html\n",
"# [4] Geoffrey W. Hill\n",
"# Algorithm 396: Student's t-quantiles\n",
"# Communications of the ACM, v. 13, n. 10, p. 619-620, 1970\n",
"# https://doi.org/10.1145/355598.355600\n",
"# [5] Geoffrey W. Hill\n",
"# Remark on \"Algorithm 396: Student's t-Quantiles [S14]\"\n",
"# ACM Transactions on Mathematical Software, v. 7, n. 2, p. 250-251, 1981\n",
"# https://doi.org/10.1145/355945.355956\n",
"# [6] R Core Team, R Foundation and Ross Ihaka\n",
"# Mathlib: A C Library of Special Functions\n",
"# https://svn.r-project.org/R/tags/R-4-2-1/src/nmath/qt.c"
],
"metadata": {
"id": "lYV8zBZxkbF3"
},
"execution_count": 9,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def _stdtr_asymptotic_expansion(df, t, numpy_dtype):\n",
"  \"\"\"Computes `stdtr(df, t)` using asymptotic expansion.\n",
"\n",
"  Args:\n",
"    df: A `Tensor` with the degrees of freedom.\n",
"    t: A `Tensor` with the evaluation points.\n",
"    numpy_dtype: NumPy scalar type used to build the constants.\n",
"\n",
"  Returns:\n",
"    A `Tensor` with the approximation of `stdtr(df, t)`.\n",
"  \"\"\"\n",
"  # This function provides a fast approximation of stdtr(df, t) for large value\n",
"  # of df. It is based on an asymptotic normalizing expansion of Cornish-Fisher\n",
"  # type [1, 2].\n",
"  one = numpy_dtype(1.)\n",
"  two = numpy_dtype(2.)\n",
"\n",
"  coeffs1 = [\n",
"      1.00000000000000000000E+0, 3.00000000000000000000E+0]\n",
"\n",
"  coeffs2 = [\n",
"      4.00000000000000022204E-1, 3.29999999999999982236E+0,\n",
"      2.40000000000000000000E+1, 8.55000000000000000000E+1]\n",
"\n",
"  coeffs3 = [\n",
"      3.04761904761904789396E-1, 3.75238095238095237249E+0,\n",
"      4.66714285714285708195E+1, 4.27500000000000000000E+2,\n",
"      2.58750000000000000000E+3, 8.51850000000000000000E+3]\n",
"\n",
"  coeffs4 = [\n",
"      2.74285714285714299354E-1, 4.49904761904761940627E+0,\n",
"      7.84514285714285648510E+1, 1.11871071428571417528E+3,\n",
"      1.23876000000000003638E+4, 1.01024550000000002910E+5,\n",
"      5.59494000000000000000E+5, 1.76495962500000000000E+6]\n",
"\n",
"  coeffs5 = [\n",
"      2.65974025974025973795E-1, 5.44969696969696926203E+0,\n",
"      1.22202943722943729199E+2, 2.35472987012987005073E+3,\n",
"      3.76250090259740245529E+4, 4.86996139285714307334E+5,\n",
"      4.96087065000000037253E+6, 3.79785955499999970198E+7,\n",
"      2.01505390875000000000E+8, 6.22437908625000000000E+8]\n",
"\n",
"  terms_coeffs = [\n",
"      [numpy_dtype(c) for c in coeffs]\n",
"      for coeffs in (coeffs1, coeffs2, coeffs3, coeffs4, coeffs5)]\n",
"\n",
"  df_minus_half = df - numpy_dtype(0.5)\n",
"  squared_z = df_minus_half * log1psquare(t * tf.math.rsqrt(df))\n",
"  z = tf.math.sqrt(squared_z)\n",
"  # To avoid overflow when df is huge, we manipulate b and the denominator of\n",
"  # each term of the expansion in the logarithmic space.\n",
"  log_b = tf.math.log(numpy_dtype(48.)) + tf.math.xlogy(two, df_minus_half)\n",
"\n",
"  term_sign = one\n",
"  log_term_denominator = numpy_dtype(0.)\n",
"  # We initialize the series with its first term.\n",
"  series_sum = z\n",
"  last_index = len(terms_coeffs) - 1\n",
"\n",
"  # We evaluate the next five terms using a procedure based on Horner's method.\n",
"  for index, coeffs in enumerate(terms_coeffs):\n",
"    # Every term but the last divides by one more power of b; the last term\n",
"    # uses a modified denominator instead.\n",
"    if index < last_index:\n",
"      log_term_denominator = log_term_denominator + log_b\n",
"    else:\n",
"      log_term_denominator = log_term_denominator + tf.math.log(\n",
"          (numpy_dtype(0.43595) * squared_z + two) * squared_z +\n",
"          tf.math.exp(log_b) + numpy_dtype(537.))\n",
"\n",
"    term_numerator = coeffs[0]\n",
"    for c in coeffs[1:]:\n",
"      term_numerator = c + term_numerator * squared_z\n",
"\n",
"    term_numerator = term_numerator * z\n",
"    series_sum = series_sum + term_sign * tf.math.exp(\n",
"        tf.math.log(term_numerator) - log_term_denominator)\n",
"    # Consecutive terms alternate in sign.\n",
"    term_sign = -one * term_sign\n",
"\n",
"  return special_math.ndtr(tf.math.sign(t) * series_sum)"
],
"metadata": {
"id": "-jgH__lAk98H"
},
"execution_count": 10,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def _stdtr_computation(df, t):\n",
"  \"\"\"Computes Student's t-distribution cumulative distribution function.\n",
"\n",
"  Args:\n",
"    df: Degrees of freedom; entries with `df <= 0` yield NaN.\n",
"    t: Upper integration limit(s).\n",
"\n",
"  Returns:\n",
"    A `Tensor` with the broadcast shape of `df` and `t`.\n",
"  \"\"\"\n",
"  dtype = dtype_util.common_dtype([df, t], tf.float32)\n",
"  numpy_dtype = dtype_util.as_numpy_dtype(dtype)\n",
"  half = numpy_dtype(0.5)\n",
"  one = numpy_dtype(1.)\n",
"\n",
"  df, t = [tf.convert_to_tensor(param, dtype=dtype) for param in (df, t)]\n",
"  broadcast_shape = functools.reduce(\n",
"      ps.broadcast_shape, [ps.shape(df), ps.shape(t)])\n",
"  df, t = [tf.broadcast_to(param, broadcast_shape) for param in (df, t)]\n",
"\n",
"  # For moderate df and relatively small t**2, or in case of large df, we use\n",
"  # asymptotic expansion [1, 2] to compute stdtr(df, t). The condition to use\n",
"  # it was specified by experimentation for np.float32 and was taken from [2,\n",
"  # page 249] for np.float64.\n",
"\n",
"  if numpy_dtype == np.float32:\n",
"    use_asymptotic_expansion = (\n",
"        (df >= 10.) & (tf.math.square(t) < (16. * df - 5.))) | (df > 30.)\n",
"  else:\n",
"    use_asymptotic_expansion = (\n",
"        (df >= 100.) & (tf.math.square(t) < (0.1 * df - 5.))) | (df > 1000.)\n",
"\n",
"  result = _stdtr_asymptotic_expansion(df, t, numpy_dtype)\n",
"\n",
"  # Otherwise, we evaluate stdtr(df, t) using the regularized incomplete beta\n",
"  # function [3, page 323, equation 6.14.10]:\n",
"  #   stdtr(df, t) =\n",
"  #       0.5 * betainc(0.5 * df, 0.5, df / (df + t**2)), when t < 0\n",
"  #       1. - 0.5 * betainc(0.5 * df, 0.5, df / (df + t**2)), when t >= 0\n",
"  # To avoid rounding error when t**2 is small compared to df, we compute the\n",
"  # ratio df / (df + t**2) in the logarithmic space. If ratio > 0.99, we then\n",
"  # use the following symmetry relation:\n",
"  #   betainc(a, b, x) = 1 - betainc(b, a, 1 - x) .\n",
"\n",
"  raw_ratio = t * tf.math.rsqrt(df)\n",
"  ratio = tf.math.exp(-log1psquare(raw_ratio))\n",
"  one_minus_ratio = tf.math.exp(-log1psquare(tf.math.reciprocal(raw_ratio)))\n",
"\n",
"  # The maximum value for the ratio was set by experimentation.\n",
"  use_symmetry_relation = (ratio > 0.99)\n",
"  half_df = half * df\n",
"  a = tf.where(use_symmetry_relation, half, half_df)\n",
"  b = tf.where(use_symmetry_relation, half_df, half)\n",
"  x = tf.where(use_symmetry_relation, one_minus_ratio, ratio)\n",
"\n",
"  y = special.betainc(a, b, x)\n",
"  neg_cdf = half * tf.where(use_symmetry_relation, one - y, y)\n",
"  result_betainc = tf.where(t < 0., neg_cdf, one - neg_cdf)\n",
"\n",
"  # Both branches are evaluated for every element; select per element here.\n",
"  result = tf.where(use_asymptotic_expansion, result, result_betainc)\n",
"\n",
"  # Determine if df is out of range (should return NaN output).\n",
"  result_is_nan = (df <= 0.)\n",
"  result = tf.where(result_is_nan, numpy_dtype(np.nan), result)\n",
"\n",
"  return result"
],
"metadata": {
"id": "7CTjxBmTv_Qn"
},
"execution_count": 11,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def _stdtr_partials(df, t):\n",
"  \"\"\"Returns the partial derivatives of `stdtr(df, t)`.\n",
"\n",
"  Args:\n",
"    df: Degrees of freedom; entries with `df <= 0` yield NaN gradients.\n",
"    t: Point(s) at which `stdtr` is differentiated.\n",
"\n",
"  Returns:\n",
"    A tuple `(stdtr_grad_df, stdtr_grad_t)` of `Tensor`s with the broadcast\n",
"    shape of `df` and `t`.\n",
"  \"\"\"\n",
"  dtype = dtype_util.common_dtype([df, t], tf.float32)\n",
"  numpy_dtype = dtype_util.as_numpy_dtype(dtype)\n",
"  tiny = np.finfo(numpy_dtype).tiny\n",
"  eps = np.finfo(numpy_dtype).eps\n",
"  half = numpy_dtype(0.5)\n",
"  one = numpy_dtype(1.)\n",
"\n",
"  df, t = [tf.convert_to_tensor(param, dtype=dtype) for param in (df, t)]\n",
"  broadcast_shape = functools.reduce(\n",
"      ps.broadcast_shape, [ps.shape(df), ps.shape(t)])\n",
"  df, t = [tf.broadcast_to(param, broadcast_shape) for param in (df, t)]\n",
"\n",
"  # The gradient with respect to t can be computed more accurately and stably\n",
"  # using Student's t-distribution probability density function.\n",
"\n",
"  stdtr_grad_t = tf.math.exp(log_prob(t, df, numpy_dtype(0.), one))\n",
"\n",
"  # For moderate df and relatively small t**2, or in case of large df, we use\n",
"  # automatic differentiation of the procedure _stdtr_asymptotic_expansion to\n",
"  # compute the gradient with respect to df.\n",
"\n",
"  if numpy_dtype == np.float32:\n",
"    use_asymptotic_expansion = (\n",
"        (df >= 10.) & (tf.math.square(t) < (16. * df - 5.))) | (df > 30.)\n",
"  else:\n",
"    use_asymptotic_expansion = (\n",
"        (df >= 100.) & (tf.math.square(t) < (0.1 * df - 5.))) | (df > 1000.)\n",
"\n",
"  abs_t = tf.math.abs(t)\n",
"  min_abs_t = half * df * tf.math.pow(tiny, numpy_dtype(0.25))\n",
"  t_is_tiny = (abs_t < min_abs_t)\n",
"  df_asymptotic_expansion = tf.where(use_asymptotic_expansion, df, one)\n",
"  t_asymptotic_expansion = tf.where(\n",
"      use_asymptotic_expansion,\n",
"      # Mask out tiny t so the gradient correctly propagates. When t is tiny,\n",
"      # second derivative of log1psquare(t * tf.math.rsqrt(df)) can be NaN.\n",
"      tf.where(t_is_tiny, tf.where(t < 0., -min_abs_t, min_abs_t), t),\n",
"      one)\n",
"\n",
"  stdtr_grad_df = gradient.value_and_gradient(\n",
"      lambda z: _stdtr_asymptotic_expansion(\n",
"          z, t_asymptotic_expansion, numpy_dtype),\n",
"      df_asymptotic_expansion)[1]\n",
"  # Handle the case (abs_t < min_abs_t): we use a rough linear approximation.\n",
"  stdtr_grad_df = tf.where(t_is_tiny, stdtr_grad_df * abs_t, stdtr_grad_df)\n",
"\n",
"  # Otherwise, the gradient with respect to df is evaluated using the partial\n",
"  # derivatives of betainc(a, b, x). For t < 0, we have:\n",
"  #   stdtr_grad_df = 0.5 * (betainc_grad_a * a_grad_df +\n",
"  #                          betainc_grad_x * x_grad_df)\n",
"  #                 = 0.5 * (betainc_grad_a * a_grad_df +\n",
"  #                          2. * stdtr_grad_t / x_grad_t * x_grad_df)\n",
"  #                 = 0.5 * (betainc_grad_a * a_grad_df -\n",
"  #                          stdtr_grad_t * t / df) ,\n",
"  # where a = 0.5 * df and x = df / (df + t**2). In above equation, the second\n",
"  # equality follows from the fact that:\n",
"  #   stdtr_grad_t = 0.5 * betainc_grad_x * x_grad_t , for t < 0.\n",
"  # To avoid rounding error when t**2 is small compared to df, we compute the\n",
"  # ratio df / (df + t**2) in the logarithmic space. If ratio > 0.99, we then\n",
"  # use the following symmetry relation:\n",
"  #   betainc(a, b, x) = 1 - betainc(b, a, 1 - x) .\n",
"\n",
"  min_abs_t_betainc = tf.math.sqrt(df * eps * tf.math.reciprocal(one - eps))\n",
"  abs_t_is_too_small = (abs_t < min_abs_t_betainc)\n",
"  # Mask out too small abs(t) so the gradient correctly propagates. When abs(t)\n",
"  # is too small, ratio == 1 and one_minus_ratio < eps.\n",
"  abs_t_betainc = tf.where(abs_t_is_too_small, min_abs_t_betainc, abs_t)\n",
"\n",
"  raw_ratio = abs_t_betainc * tf.math.rsqrt(df)\n",
"  ratio = tf.math.exp(-log1psquare(raw_ratio))\n",
"  one_minus_ratio = tf.math.exp(-log1psquare(tf.math.reciprocal(raw_ratio)))\n",
"\n",
"  # The maximum value for the ratio was set by experimentation.\n",
"  use_symmetry_relation = (ratio > 0.99)\n",
"  half_df = half * df\n",
"  a = tf.where(use_symmetry_relation, half, half_df)\n",
"  b = tf.where(use_symmetry_relation, half_df, half)\n",
"  x = tf.where(use_symmetry_relation, one_minus_ratio, ratio)\n",
"\n",
"  # Prepare betainc inputs to make the evaluation of its gradients easier.\n",
"  use_betainc = ~use_asymptotic_expansion\n",
"  a = tf.where(use_betainc, a, half)\n",
"  b = tf.where(use_betainc, b, half)\n",
"  x = tf.where(use_betainc, x, half)\n",
"\n",
"  betainc_grad_a, betainc_grad_b = gradient.value_and_gradient(\n",
"      lambda y, z: special.betainc(y, z, x), [a, b])[1]\n",
"  # Under the symmetry relation the roles of a and b are swapped and the sign\n",
"  # flips, so the derivative we need is minus the one with respect to b.\n",
"  betainc_grad_a = tf.where(\n",
"      use_symmetry_relation, -betainc_grad_b, betainc_grad_a)\n",
"\n",
"  stdtr_grad_df_betainc = half * (\n",
"      betainc_grad_a * half + stdtr_grad_t * abs_t / df)\n",
"  # Handle the case (t >= 0).\n",
"  stdtr_grad_df_betainc = tf.where(\n",
"      t >= 0., -stdtr_grad_df_betainc, stdtr_grad_df_betainc)\n",
"  # Handle the case (abs_t < min_abs_t_betainc): we use again a rough linear\n",
"  # approximation.\n",
"  stdtr_grad_df_betainc = tf.where(\n",
"      abs_t_is_too_small, stdtr_grad_df_betainc * abs_t, stdtr_grad_df_betainc)\n",
"\n",
"  stdtr_grad_df = tf.where(\n",
"      use_asymptotic_expansion, stdtr_grad_df, stdtr_grad_df_betainc)\n",
"\n",
"  # Determine if df is out of range (should return NaN output).\n",
"  result_is_nan = (df <= 0.)\n",
"  stdtr_grad_df, stdtr_grad_t = [\n",
"      tf.where(result_is_nan, numpy_dtype(np.nan), grad)\n",
"      for grad in [stdtr_grad_df, stdtr_grad_t]]\n",
"\n",
"  return stdtr_grad_df, stdtr_grad_t"
],
"metadata": {
"id": "wD4LbJTVgJkK"
},
"execution_count": 12,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def _stdtr_fwd(df, t):\n",
"  \"\"\"Forward pass for the custom VJP of stdtr.\n",
"\n",
"  Returns the primal output together with the inputs, which `_stdtr_bwd`\n",
"  receives as `aux`.\n",
"  \"\"\"\n",
"  return _stdtr_computation(df, t), (df, t)"
],
"metadata": {
"id": "Mr1Q1WfGdt0q"
},
"execution_count": 13,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def _stdtr_bwd(aux, g):\n",
"  \"\"\"Backward pass for the custom VJP of stdtr.\n",
"\n",
"  Args:\n",
"    aux: The `(df, t)` pair saved by `_stdtr_fwd`.\n",
"    g: Upstream gradient with respect to the output of stdtr.\n",
"\n",
"  Returns:\n",
"    Gradients with respect to `df` and `t`, fixed up for broadcasting.\n",
"  \"\"\"\n",
"  df, t = aux\n",
"  grad_df, grad_t = _stdtr_partials(df, t)\n",
"  scaled_grads = [grad_df * g, grad_t * g]\n",
"  return generic.fix_gradient_for_broadcasting([df, t], scaled_grads)"
],
"metadata": {
"id": "ZuZs6ZnseE8k"
},
"execution_count": 14,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def _stdtr_jvp(primals, tangents):\n",
"  \"\"\"Forward-mode rule for stdtr (supports JAX custom derivative).\n",
"\n",
"  Args:\n",
"    primals: The `(df, t)` inputs.\n",
"    tangents: The `(ddf, dt)` tangents matching `primals`.\n",
"\n",
"  Returns:\n",
"    A tuple `(primal_out, tangent_out)`.\n",
"  \"\"\"\n",
"  df, t = primals\n",
"  tangent_df, tangent_t = tangents\n",
"\n",
"  primal_out = _stdtr_custom_gradient(df, t)\n",
"  grad_df, grad_t = _stdtr_partials(df, t)\n",
"  tangent_out = grad_df * tangent_df + grad_t * tangent_t\n",
"  return (primal_out, tangent_out)"
],
"metadata": {
"id": "3tbl6WjFaD8N"
},
"execution_count": 15,
"outputs": []
},
{
"cell_type": "code",
"source": [
"@tfp_custom_gradient.custom_gradient(\n",
"    vjp_fwd=_stdtr_fwd,\n",
"    vjp_bwd=_stdtr_bwd,\n",
"    jvp_fn=_stdtr_jvp)\n",
"def _stdtr_custom_gradient(df, t):\n",
"  \"\"\"Evaluates `stdtr(df, t)` with the hand-written derivative rules attached.\n",
"\n",
"  The value comes from `_stdtr_computation`; gradients come from the\n",
"  `_stdtr_fwd` / `_stdtr_bwd` / `_stdtr_jvp` functions passed to the decorator.\n",
"  \"\"\"\n",
"  primal_out = _stdtr_computation(df, t)\n",
"  return primal_out"
],
"metadata": {
"id": "N3PUwyQiaDyT"
},
"execution_count": 16,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def stdtr(df, t, name=None):\n",
"  \"\"\"Computes the cumulative distribution function of Student's t-distribution.\n",
"\n",
"  Returns the area under the probability density function of this distribution\n",
"  with `df > 0` degrees of freedom, integrated from minus infinity to `t`.\n",
"\n",
"  Args:\n",
"    df: A `Tensor`. Must be one of the following types: `float32`, `float64`.\n",
"    t: A `Tensor`. Must have the same type as `df`.\n",
"    name: A name for the operation (optional).\n",
"\n",
"  Returns:\n",
"    A `Tensor` with shape broadcast according to the arguments.\n",
"\n",
"  Raises:\n",
"    TypeError: if `df` is not one of the following types: `float32`, `float64`.\n",
"  \"\"\"\n",
"  with tf.name_scope(name or 'stdtr'):\n",
"    dtype = dtype_util.common_dtype([df, t], tf.float32)\n",
"    df = tf.convert_to_tensor(df, dtype=dtype)\n",
"    t = tf.convert_to_tensor(t, dtype=dtype)\n",
"\n",
"    # Only the two floating-point widths the implementation was tuned for are\n",
"    # accepted.\n",
"    if dtype_util.as_numpy_dtype(dtype) not in [np.float32, np.float64]:\n",
"      raise TypeError(f'df.dtype={dtype} is not handled. '\n",
"                      'See docstring for supported types.')\n",
"\n",
"    return _stdtr_custom_gradient(df, t)"
],
"metadata": {
"id": "5UzA9QTkcFng"
},
"execution_count": 17,
"outputs": []
},
{
"cell_type": "code",
"source": [
"stdtr = tf.function(stdtr, autograph=False)"
],
"metadata": {
"id": "x2U2tpEvv7I9"
},
"execution_count": 18,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def _stdtrit_betaincinv(df, p, numpy_dtype, use_betaincinv):\n",
"  \"\"\"Computes `stdtrit(df, p)` using special.betaincinv.\n",
"\n",
"  Args:\n",
"    df: A `Tensor` with the degrees of freedom.\n",
"    p: A `Tensor` of probabilities, assumed to satisfy `p <= 0.5`.\n",
"    numpy_dtype: NumPy scalar type used to build the constants.\n",
"    use_betaincinv: Boolean `Tensor`; where it is False the inputs of\n",
"      betaincinv are replaced by fixed placeholder values.\n",
"\n",
"  Returns:\n",
"    A `Tensor` with the (non-positive) candidate values of `t`.\n",
"  \"\"\"\n",
"  # This function inverts the procedure that computes stdtr(df, t) using the\n",
"  # regularized incomplete beta function. For details on this procedure, see\n",
"  # the function _stdtr_computation.\n",
"  # We assume here that condition (p <= 0.5) is always true.\n",
"  half = numpy_dtype(0.5)\n",
"  one = numpy_dtype(1.)\n",
"\n",
"  half_df = half * df\n",
"  two_p = numpy_dtype(2.) * p\n",
"\n",
"  # Probability threshold that corresponds to the ratio threshold 0.99.\n",
"  symmetry_threshold = half * special.betainc(\n",
"      half_df, half, numpy_dtype(0.99))\n",
"  use_symmetry_relation = (p > symmetry_threshold)\n",
"\n",
"  # Pick the betaincinv arguments, then swap in placeholder inputs wherever\n",
"  # this branch is not used, to make its evaluation easier.\n",
"  a = tf.where(\n",
"      use_betaincinv, tf.where(use_symmetry_relation, half, half_df), half)\n",
"  b = tf.where(\n",
"      use_betaincinv, tf.where(use_symmetry_relation, half_df, half), half)\n",
"  y = tf.where(\n",
"      use_betaincinv,\n",
"      tf.where(use_symmetry_relation, one - two_p, two_p),\n",
"      numpy_dtype(0.))\n",
"\n",
"  x = special.betaincinv(a, b, y)\n",
"\n",
"  direction = tf.where(use_symmetry_relation, one, -one)\n",
"  log_odds_x = tf.math.log(x) - tf.math.log1p(-x)\n",
"  log_abs_t = half * (tf.math.log(df) + direction * log_odds_x)\n",
"\n",
"  return -tf.math.exp(log_abs_t)"
],
"metadata": {
"id": "m6-8-hFlgA9z"
},
"execution_count": 19,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def _stdtrit_series_expansion(df, p, numpy_dtype):\n",
"  \"\"\"Computes `stdtrit(df, p)` using series expansion.\n",
"\n",
"  Args:\n",
"    df: A `Tensor` with the degrees of freedom.\n",
"    p: A `Tensor` of probabilities, assumed to satisfy `p <= 0.5`.\n",
"    numpy_dtype: NumPy scalar type used to build the constants.\n",
"\n",
"  Returns:\n",
"    A `Tensor` with the (non-positive) approximated values of `t`.\n",
"  \"\"\"\n",
"  # This function provides a fast approximation of stdtrit(df, p) for df >= 1.\n",
"  # It is based on an asymptotic inverse expansion of Cornish-Fisher type about\n",
"  # normal deviates. But for small p, where t**2 / df is large, a second series\n",
"  # expansion is used to achieve sufficient accuracy. Both approximations were\n",
"  # proposed in [4].\n",
"  # We assume here that condition (p <= 0.5) is always true.\n",
"  half = numpy_dtype(0.5)\n",
"  one, two, three, four, five, six, seven = [\n",
"      numpy_dtype(n) for n in range(1, 8)]\n",
"\n",
"  # The names a, b, c, d follow the variables of Algorithm 396 [4].\n",
"  a = tf.math.reciprocal(df - half)\n",
"  b = numpy_dtype(48.) / tf.math.square(a)\n",
"  c = numpy_dtype(96.36) + a * (\n",
"      (numpy_dtype(20700.) * a / b - numpy_dtype(98.)) * a - numpy_dtype(16.))\n",
"  d = df * tf.math.sqrt(a * half * numpy_dtype(np.pi)) * (\n",
"      (numpy_dtype(94.5) / (b + c) - three) / b + one)\n",
"\n",
"  # First series expansion: asymptotic inverse expansion about normal deviates.\n",
"  z = tf.math.ndtri(p)\n",
"  squared_z = tf.math.square(z)\n",
"  c = b + c + z * (\n",
"      ((numpy_dtype(0.05) * d * z - five) * z - seven) * z - two)\n",
"  small_df_correction = numpy_dtype(0.3) * (\n",
"      df - numpy_dtype(4.5)) * (z + numpy_dtype(0.6))\n",
"  c = tf.where(df >= 5., c, c + small_df_correction)\n",
"\n",
"  # Horner evaluation of the cubic in squared_z.\n",
"  squared_t_over_df = numpy_dtype(0.4)\n",
"  for coefficient in (6.3, 36., 94.5):\n",
"    squared_t_over_df = (\n",
"        squared_t_over_df * squared_z + numpy_dtype(coefficient))\n",
"  squared_t_over_df = z * (\n",
"      (squared_t_over_df / c - squared_z - three) / b + one)\n",
"  squared_t_over_df = tf.math.expm1(a * tf.math.square(squared_t_over_df))\n",
"\n",
"  # Second series expansion.\n",
"  y = tf.math.exp(two / df * (\n",
"      tf.math.log(d) + tf.math.log(two) + tf.math.log(p)))\n",
"\n",
"  df_plus_2 = df + two\n",
"  tail_estimate = (df + six) / (df * y) - numpy_dtype(0.089) * d\n",
"  tail_estimate = df_plus_2 * three * (tail_estimate - numpy_dtype(0.822))\n",
"  # NOTE(review): Algorithm 396 [4] and R's qt.c [6] appear to form\n",
"  # `(1 / x + 0.5 / (df + 4)) * y - 1` at this step, whereas the two lines below\n",
"  # form `y / (x + 0.5 / (df + 4)) - 1`. The value only seeds the refinement in\n",
"  # _stdtrit_computation, but confirm the difference is intended.\n",
"  tail_estimate = tail_estimate + half / (df + four)\n",
"  tail_estimate = y / tail_estimate - one\n",
"  tail_estimate = tf.math.reciprocal(y) + tail_estimate * (\n",
"      (df + one) / df_plus_2)\n",
"\n",
"  p_is_not_small = (y >= (numpy_dtype(0.05) + a))\n",
"  # The condition to use the first series expansion was improved in [6].\n",
"  use_first_series_expansion = p_is_not_small | ((df < 2.1) & (p > 0.25))\n",
"  squared_t_over_df = tf.where(\n",
"      use_first_series_expansion, squared_t_over_df, tail_estimate)\n",
"\n",
"  return -tf.math.sqrt(df * squared_t_over_df)"
],
"metadata": {
"id": "lNls1EHFKqz6"
},
"execution_count": 20,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def _stdtrit_computation(df, p):\n",
"  \"\"\"Returns the inverse of `stdtr(df, t)` with respect to `t`.\n",
"\n",
"  Args:\n",
"    df: Degrees of freedom; entries with `df <= 0` yield NaN.\n",
"    p: Probabilities; entries with `p <= 0` or `p >= 1` yield NaN.\n",
"\n",
"  Returns:\n",
"    A `Tensor` with the broadcast shape of `df` and `p`.\n",
"  \"\"\"\n",
"  # This function increases the accuracy of an initial estimate for t by using\n",
"  # Taylor series expansion iterations as proposed in [5].\n",
"  dtype = dtype_util.common_dtype([df, p], tf.float32)\n",
"  numpy_dtype = dtype_util.as_numpy_dtype(dtype)\n",
"  zero = numpy_dtype(0.)\n",
"  eps = np.finfo(numpy_dtype).eps\n",
"  half = numpy_dtype(0.5)\n",
"  one = numpy_dtype(1.)\n",
"\n",
"  df, p = [tf.convert_to_tensor(param, dtype=dtype) for param in (df, p)]\n",
"  broadcast_shape = functools.reduce(\n",
"      ps.broadcast_shape, [ps.shape(df), ps.shape(p)])\n",
"  df, p = [tf.broadcast_to(param, broadcast_shape) for param in (df, p)]\n",
"\n",
"  # max_iterations, use_betaincinv, and tolerance were set by experimentation.\n",
"  max_iterations = 3\n",
"  if numpy_dtype == np.float32:\n",
"    use_betaincinv = (df < 2.)\n",
"    tolerance = numpy_dtype(8.) * eps\n",
"  else:\n",
"    use_betaincinv = (df < 1.)\n",
"    tolerance = numpy_dtype(4096.) * eps\n",
"\n",
"  # Work on the lower tail only; the sign is restored after the loop.\n",
"  adjusted_p = tf.where(p < 0.5, p, one - p)\n",
"  initial_candidate = tf.where(\n",
"      use_betaincinv,\n",
"      # Since _stdtrit_betaincinv is expensive, we pass use_betaincinv to it\n",
"      # to save computation.\n",
"      _stdtrit_betaincinv(df, adjusted_p, numpy_dtype, use_betaincinv),\n",
"      _stdtrit_series_expansion(df, adjusted_p, numpy_dtype))\n",
"\n",
"  def taylor_expansion_improvement(should_stop, candidate):\n",
"    \"\"\"Applies one second-order Taylor correction to `candidate`.\"\"\"\n",
"    stdtr_grad_t = tf.math.exp(log_prob(candidate, df, zero, one))\n",
"    # A zero density would make the correction below undefined.\n",
"    should_stop = should_stop | tf.math.equal(stdtr_grad_t, zero)\n",
"\n",
"    first_order_correction = (adjusted_p - stdtr(df, candidate)) / stdtr_grad_t\n",
"\n",
"    candidate_is_zero = tf.math.equal(candidate, zero)\n",
"    safe_inv_candidate = tf.where(\n",
"        candidate_is_zero, one, tf.math.reciprocal(candidate))\n",
"    second_order_correction = half * (df + one) * tf.math.square(\n",
"        first_order_correction) * safe_inv_candidate * tf.math.reciprocal(\n",
"            one + (df * safe_inv_candidate) * safe_inv_candidate)\n",
"    second_order_correction = tf.where(\n",
"        candidate_is_zero, zero, second_order_correction)\n",
"\n",
"    correction = first_order_correction + second_order_correction\n",
"    new_candidate = tf.where(should_stop, candidate, candidate + correction)\n",
"\n",
"    adjusted_tolerance = tf.math.abs(tolerance * new_candidate)\n",
"    should_stop = should_stop | (tf.math.abs(correction) <= adjusted_tolerance)\n",
"\n",
"    return should_stop, new_candidate\n",
"\n",
"  # Elements whose initial estimate is not finite are never refined.\n",
"  (_, result) = tf.while_loop(\n",
"      cond=lambda stop, _: tf.reduce_any(~stop),\n",
"      body=taylor_expansion_improvement,\n",
"      loop_vars=(\n",
"          ~tf.math.is_finite(initial_candidate),\n",
"          initial_candidate),\n",
"      maximum_iterations=max_iterations)\n",
"\n",
"  # Handle the case (p >= 0.5).\n",
"  result = tf.math.sign(half - p) * result\n",
"\n",
"  # Determine if the inputs are out of range (should return NaN output).\n",
"  result_is_nan = (p <= zero) | (p >= one) | (df <= zero)\n",
"  result = tf.where(result_is_nan, numpy_dtype(np.nan), result)\n",
"\n",
"  return result"
],
"metadata": {
"id": "MRWxtNSwqvm7"
},
"execution_count": 21,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def _stdtrit_partials(df, p, return_value=False):\n",
"  \"\"\"Returns the partial derivatives of `stdtrit(df, p)`.\n",
"\n",
"  Args:\n",
"    df: Degrees of freedom; entries with `df <= 0` yield NaN.\n",
"    p: Probabilities; entries with `p <= 0` or `p >= 1` yield NaN.\n",
"    return_value: If True, the value `stdtrit(df, p)` is appended to the\n",
"      returned list.\n",
"\n",
"  Returns:\n",
"    A list `[partial_df, partial_p]`, or `[partial_df, partial_p, t]` when\n",
"    `return_value` is True.\n",
"  \"\"\"\n",
"  dtype = dtype_util.common_dtype([df, p], tf.float32)\n",
"  numpy_dtype = dtype_util.as_numpy_dtype(dtype)\n",
"\n",
"  df, p = [tf.convert_to_tensor(param, dtype=dtype) for param in (df, p)]\n",
"  broadcast_shape = functools.reduce(\n",
"      ps.broadcast_shape, [ps.shape(df), ps.shape(p)])\n",
"  df, p = [tf.broadcast_to(param, broadcast_shape) for param in (df, p)]\n",
"\n",
"  # We use the fact that stdtr and stdtrit are inverses of each other to\n",
"  # compute the gradients.\n",
"  t = _stdtrit_custom_gradient(df, p)\n",
"  stdtr_partial_df, stdtr_partial_t = _stdtr_partials(df, t)\n",
"\n",
"  partial_df = -stdtr_partial_df / stdtr_partial_t\n",
"  partial_p = tf.math.reciprocal(stdtr_partial_t)\n",
"\n",
"  if return_value:\n",
"    results = [partial_df, partial_p, t]\n",
"  else:\n",
"    results = [partial_df, partial_p]\n",
"\n",
"  # Determine if the inputs are out of range (should return NaN output).\n",
"  result_is_nan = (p <= 0.) | (p >= 1.) | (df <= 0.)\n",
"  results = [\n",
"      tf.where(result_is_nan, numpy_dtype(np.nan), result)\n",
"      for result in results]\n",
"\n",
"  return results"
],
"metadata": {
"id": "VUQipfTj6xUI"
},
"execution_count": 22,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def _stdtrit_fwd(df, p):\n",
"  \"\"\"Forward pass for the custom VJP of stdtrit.\n",
"\n",
"  Returns the primal output together with the inputs, which `_stdtrit_bwd`\n",
"  receives as `aux`.\n",
"  \"\"\"\n",
"  return _stdtrit_computation(df, p), (df, p)"
],
"metadata": {
"id": "URcsbziz6DOw"
},
"execution_count": 23,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def _stdtrit_bwd(aux, g):\n",
"  \"\"\"Backward pass for the custom VJP of stdtrit.\n",
"\n",
"  Args:\n",
"    aux: The `(df, p)` pair saved by `_stdtrit_fwd`.\n",
"    g: Upstream gradient with respect to the output of stdtrit.\n",
"\n",
"  Returns:\n",
"    Gradients with respect to `df` and `p`, fixed up for broadcasting.\n",
"  \"\"\"\n",
"  df, p = aux\n",
"  grad_df, grad_p = _stdtrit_partials(df, p)[:2]\n",
"  scaled_grads = [grad_df * g, grad_p * g]\n",
"  return generic.fix_gradient_for_broadcasting([df, p], scaled_grads)"
],
"metadata": {
"id": "XiMd364A6DOx"
},
"execution_count": 24,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def _stdtrit_jvp(primals, tangents):\n",
"  \"\"\"Computes JVP for stdtrit (supports JAX custom derivative).\n",
"\n",
"  Args:\n",
"    primals: The `(df, p)` inputs.\n",
"    tangents: The `(ddf, dp)` tangents matching `primals`.\n",
"\n",
"  Returns:\n",
"    A tuple `(t, tangent_out)` with the primal output and its tangent.\n",
"  \"\"\"\n",
"  df, p = primals\n",
"  ddf, dp = tangents\n",
"\n",
"  # return_value=True is required: by default _stdtrit_partials returns only\n",
"  # the two partial derivatives, and unpacking three names would then raise.\n",
"  partial_df, partial_p, t = _stdtrit_partials(df, p, return_value=True)\n",
"  return (t, partial_df * ddf + partial_p * dp)"
],
"metadata": {
"id": "OVeDoY3v6DOx"
},
"execution_count": 25,
"outputs": []
},
{
"cell_type": "code",
"source": [
"@tfp_custom_gradient.custom_gradient(\n",
"    vjp_fwd=_stdtrit_fwd,\n",
"    vjp_bwd=_stdtrit_bwd,\n",
"    jvp_fn=_stdtrit_jvp)\n",
"def _stdtrit_custom_gradient(df, p):\n",
"  \"\"\"Evaluates `stdtrit(df, p)` with the hand-written derivative rules attached.\n",
"\n",
"  The value comes from `_stdtrit_computation`; gradients come from the\n",
"  `_stdtrit_fwd` / `_stdtrit_bwd` / `_stdtrit_jvp` functions passed to the\n",
"  decorator.\n",
"  \"\"\"\n",
"  primal_out = _stdtrit_computation(df, p)\n",
"  return primal_out"
],
"metadata": {
"id": "mPQ4epPt6DOx"
},
"execution_count": 26,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def stdtrit(df, p, name=None):\n",
" \"\"\"Computes the inverse of `stdtr` with respect to `t`.\n",
"\n",
" This function returns a value `t` such that `p = stdtr(df, t)`.\n",
"\n",
" Args:\n",
" df: A `Tensor`. Must be one of the following types: `float32`, `float64`.\n",
" p: A `Tensor`. Must have the same type as `df`.\n",
" name: A name for the operation (optional).\n",
"\n",
" Returns:\n",
" A `Tensor` with shape broadcast according to the arguments.\n",
"\n",
" Raises:\n",
" TypeError: if `df` is not one of the following types: `float32`, `float64`.\n",
" \"\"\"\n",
" with tf.name_scope(name or 'stdtrit'):\n",
" dtype = dtype_util.common_dtype([df, p], tf.float32)\n",
" df = tf.convert_to_tensor(df, dtype=dtype)\n",
" p = tf.convert_to_tensor(p, dtype=dtype)\n",
"\n",
" if dtype_util.as_numpy_dtype(dtype) not in [np.float32, np.float64]:\n",
" raise TypeError(f'df.dtype={dtype} is not handled. '\n",
" 'See docstring for supported types.')\n",
"\n",
" return _stdtrit_custom_gradient(df, p)"
],
"metadata": {
"id": "z0sZtbJf6DOx"
},
"execution_count": 27,
"outputs": []
},
{
"cell_type": "code",
"source": [
"stdtrit = tf.function(stdtrit, autograph=False)"
],
"metadata": {
"id": "TsOEFR06jME_"
},
"execution_count": 28,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# 2. Test settings and auxiliary functions"
],
"metadata": {
"id": "J-DE_UILC5cM"
}
},
{
"cell_type": "code",
"source": [
"SEEDS = [1, 17, 42, 51, 184, 301, 346, 448, 733, 985]"
],
"metadata": {
"id": "pNFf91IUC4_R"
},
"execution_count": 29,
"outputs": []
},
{
"cell_type": "code",
"source": [
"SIZE_PER_SEED = 100"
],
"metadata": {
"id": "l-Z4RxdHDKFn"
},
"execution_count": 30,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Used to test the partial derivative of stdtr with respect to df.\n",
"SMALL_SIZE_PER_SEED = 5"
],
"metadata": {
"id": "M4QpX0GzWyOH"
},
"execution_count": 31,
"outputs": []
},
{
"cell_type": "code",
"source": [
"sample_specs = {\n",
" np.float64: [\n",
" (0., 1.),\n",
" (1., 10.),\n",
" (10., 30.),\n",
" (30., 100.),\n",
" (100., 1e+3),\n",
" (1e+3, 1e+4),\n",
" (1e+4, 1e+5),\n",
" (1e+5, 1e+6),\n",
" (1e+6, 1e+7),\n",
" (1e+7, 1e+8)],\n",
" np.float32: [\n",
" (0., 1.),\n",
" (1., 10.),\n",
" (10., 30.),\n",
" (30., 100.),\n",
" (100., 1e+3),\n",
" (1e+3, 1e+4)]}"
],
"metadata": {
"id": "kPvY6ZUQDSLZ"
},
"execution_count": 32,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def make_samples(sample_specs, size_per_seed, seeds, dtype, center=False):\n",
" tiny = np.finfo(dtype).tiny\n",
" eps = np.finfo(dtype).eps\n",
" size = len(seeds) * size_per_seed\n",
"\n",
" samples = []\n",
" for spec in sample_specs:\n",
" dfs = []\n",
"\n",
" for seed in seeds:\n",
" rng = np.random.RandomState(seed)\n",
" df_min, df_max = spec\n",
" df = rng.uniform(low=df_min, high=df_max, size=(size_per_seed, 1))\n",
" dfs.append(df.astype(dtype))\n",
"\n",
" df = np.row_stack(dfs)\n",
"\n",
" if center:\n",
" # Generates values of t near zero.\n",
" step_size = eps\n",
" middle = size_per_seed // 2\n",
" min_p = 0.5 - (middle) * step_size\n",
" max_p = 0.5 + (size_per_seed - middle) * step_size\n",
" p = np.linspace(min_p, max_p, size).astype(dtype)\n",
" p[middle] = dtype(0.5)\n",
" else:\n",
" max_t = np.sqrt(df) / np.sqrt(10. * tiny)\n",
" min_t = -max_t\n",
" max_p = np.minimum(1. - eps, sp_special.stdtr(df, max_t)).astype(dtype)\n",
" min_p = np.maximum(eps, sp_special.stdtr(df, min_t)).astype(dtype)\n",
" p = np.linspace(min_p, max_p, size).astype(dtype)\n",
"\n",
" t = sp_special.stdtrit(df, p).astype(dtype)\n",
" samples.append([df, p, t])\n",
"\n",
" return samples"
],
"metadata": {
"id": "CSkbLyRJHq-O"
},
"execution_count": 33,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"`format_number` and `get_metrics_dataframe` functions are used to print test results."
],
"metadata": {
"id": "vsIEqUoPQtaE"
}
},
{
"cell_type": "code",
"source": [
"def format_number(num):\n",
" if num == 0:\n",
" return '0'\n",
" elif -2 < np.log10(np.abs(num)) < 6: \n",
" if num == int(num):\n",
" return f'{num:.0f}'\n",
" else:\n",
" return f'{num:.2f}'\n",
" return f'{num:.1e}'"
],
"metadata": {
"id": "Biw3tiGVucgS"
},
"execution_count": 34,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def get_metrics_dataframe(\n",
" sample_specs,\n",
" truths,\n",
" results,\n",
" dtype,\n",
"):\n",
" records = []\n",
"\n",
" for idx, (df_min, df_max) in enumerate(sample_specs[dtype]):\n",
" truth = np.array(truths[idx], dtype=np.float64)\n",
" result = np.array(results[idx], dtype=np.float64)\n",
"\n",
" size = np.prod(truth.shape)\n",
" record = {'Min Df': format_number(df_min)}\n",
" record['Max Df'] = format_number(df_max)\n",
" record['#Trials'] = format_number(size)\n",
"\n",
" min_val = np.min(truth)\n",
" max_val = np.max(truth)\n",
" record['Min Value'] = format_number(min_val)\n",
" record['Max Value'] = format_number(max_val)\n",
"\n",
" perc_nan = np.sum(np.isnan(result)) / size * 100\n",
" perc_inf = np.sum(np.isinf(result)) / size * 100\n",
" record['%NaN'] = format_number(perc_nan)\n",
" record['%Inf'] = format_number(perc_inf)\n",
"\n",
" if (perc_nan + perc_inf) == 100:\n",
" record['Max Abs Error'] = 'NaN'\n",
" record['Mean Abs Error'] = 'NaN'\n",
" record['Max Rel Error'] = 'NaN'\n",
" record['Mean Rel Error'] = 'NaN'\n",
" else:\n",
" abserr = np.abs(truth - result)\n",
" abserr_valid = np.isfinite(abserr)\n",
" max_abserr = np.max(abserr, initial=0., where=abserr_valid)\n",
" mean_abserr = np.mean(abserr, where=abserr_valid)\n",
" record['Max Abs Error'] = format_number(max_abserr)\n",
" record['Mean Abs Error'] = format_number(mean_abserr)\n",
"\n",
" relerr = abserr / np.abs(truth)\n",
" relerr = np.where(np.equal(result, truth), np.float64(0.), relerr)\n",
" relerr_valid = np.isfinite(relerr)\n",
" max_relerr = np.max(relerr, initial=0., where=relerr_valid)\n",
" mean_relerr = np.mean(relerr, where=relerr_valid)\n",
" record['Max Rel Error'] = format_number(max_relerr)\n",
" record['Mean Rel Error'] = format_number(mean_relerr)\n",
"\n",
" records.append(record)\n",
"\n",
" return pd.DataFrame.from_records(records)"
],
"metadata": {
"id": "XRvmj1-9Kukh"
},
"execution_count": 35,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# 3. Test results for `float32`. Tiny values of `t`."
],
"metadata": {
"id": "BEseQgBKTp2g"
}
},
{
"cell_type": "code",
"source": [
"DTYPE = np.float32"
],
"metadata": {
"id": "KSS4xfUr7odg"
},
"execution_count": 36,
"outputs": []
},
{
"cell_type": "code",
"source": [
"CENTER = True"
],
"metadata": {
"id": "UiAnA5WW8Abd"
},
"execution_count": 37,
"outputs": []
},
{
"cell_type": "code",
"source": [
"samples = make_samples(\n",
" sample_specs[DTYPE],\n",
" size_per_seed=SIZE_PER_SEED,\n",
" seeds=SEEDS,\n",
" dtype=DTYPE,\n",
" center=CENTER)"
],
"metadata": {
"id": "1l7QBKQENLff"
},
"execution_count": 38,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## 3.1. Test results for CDF"
],
"metadata": {
"id": "bWtClyovXDq4"
}
},
{
"cell_type": "code",
"source": [
"sp_cdf = []\n",
"\n",
"for i, (df, _, t) in enumerate(samples):\n",
" result = sp_special.stdtr(df, t)\n",
" sp_cdf.append(result)\n",
" print(f'sample {i} is done.')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "bf9dd137-241a-4360-88b5-916f2a77d7f8",
"id": "3X9HQQiEXDrA"
},
"execution_count": 39,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"sample 0 is done.\n",
"sample 1 is done.\n",
"sample 2 is done.\n",
"sample 3 is done.\n",
"sample 4 is done.\n",
"sample 5 is done.\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"current_tf_cdf = []\n",
"\n",
"for i, (df, _, t) in enumerate(samples):\n",
" result = current_cdf(df, t)\n",
" current_tf_cdf.append(result)\n",
" print(f'sample {i} is done.')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "3f999df1-317b-449e-a45b-a3e56a7de7a9",
"id": "pU1y8-QOXDrB"
},
"execution_count": 40,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"sample 0 is done.\n",
"sample 1 is done.\n",
"sample 2 is done.\n",
"sample 3 is done.\n",
"sample 4 is done.\n",
"sample 5 is done.\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"new_tf_cdf = []\n",
"\n",
"for i, (df, _, t) in enumerate(samples):\n",
" result = stdtr(df, t)\n",
" new_tf_cdf.append(result)\n",
" print(f'sample {i} is done.')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "c7d10279-466b-47af-bd56-f7de1db14784",
"id": "LyXlE-pJXDrB"
},
"execution_count": 41,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"sample 0 is done.\n",
"sample 1 is done.\n",
"sample 2 is done.\n",
"sample 3 is done.\n",
"sample 4 is done.\n",
"sample 5 is done.\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"print('\\033[1m' + f'Method: Current implementation of CDF' + '\\033[0m')\n",
"print(f'Benchmark: Scipy implementation of CDF')\n",
"print(f'Dtype: {np.dtype(DTYPE).name}')\n",
"get_metrics_dataframe(sample_specs, sp_cdf, current_tf_cdf, DTYPE)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 290
},
"outputId": "701d0e10-ad35-4042-d721-3c3c3cf2d81d",
"id": "6lC7Xp-XXDrB"
},
"execution_count": 42,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"\u001b[1mMethod: Current implementation of CDF\u001b[0m\n",
"Benchmark: Scipy implementation of CDF\n",
"Dtype: float32\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
" Min Df Max Df #Trials Min Value Max Value %NaN %Inf Max Abs Error \\\n",
"0 0 1 1.0e+06 0.50 0.50 0 0 6.0e-06 \n",
"1 1 10 1.0e+06 0.50 0.50 0 0 6.0e-06 \n",
"2 10 30 1.0e+06 0.50 0.50 0 0 6.0e-06 \n",
"3 30 100 1.0e+06 0.50 0.50 0 0 6.0e-06 \n",
"4 100 1000 1.0e+06 0.50 0.50 0 0 6.0e-06 \n",
"5 1000 10000 1.0e+06 0.50 0.50 0 0 6.0e-06 \n",
"\n",
" Mean Abs Error Max Rel Error Mean Rel Error \n",
"0 2.9e-06 1.2e-05 5.7e-06 \n",
"1 3.0e-06 1.2e-05 6.0e-06 \n",
"2 3.0e-06 1.2e-05 6.0e-06 \n",
"3 3.0e-06 1.2e-05 6.0e-06 \n",
"4 3.0e-06 1.2e-05 6.0e-06 \n",
"5 3.0e-06 1.2e-05 6.0e-06 "
],
"text/html": [
"\n",
" <div id=\"df-482d5f0d-089a-4a00-90fb-fa53207cb631\">\n",
" <div class=\"colab-df-container\">\n",
" <div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Min Df</th>\n",
" <th>Max Df</th>\n",
" <th>#Trials</th>\n",
" <th>Min Value</th>\n",
" <th>Max Value</th>\n",
" <th>%NaN</th>\n",
" <th>%Inf</th>\n",
" <th>Max Abs Error</th>\n",
" <th>Mean Abs Error</th>\n",
" <th>Max Rel Error</th>\n",
" <th>Mean Rel Error</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1.0e+06</td>\n",
" <td>0.50</td>\n",
" <td>0.50</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>6.0e-06</td>\n",
" <td>2.9e-06</td>\n",
" <td>1.2e-05</td>\n",
" <td>5.7e-06</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1</td>\n",
" <td>10</td>\n",
" <td>1.0e+06</td>\n",
" <td>0.50</td>\n",
" <td>0.50</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>6.0e-06</td>\n",
" <td>3.0e-06</td>\n",
" <td>1.2e-05</td>\n",
" <td>6.0e-06</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>10</td>\n",
" <td>30</td>\n",
" <td>1.0e+06</td>\n",
" <td>0.50</td>\n",
" <td>0.50</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>6.0e-06</td>\n",
" <td>3.0e-06</td>\n",
" <td>1.2e-05</td>\n",
" <td>6.0e-06</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>30</td>\n",
" <td>100</td>\n",
" <td>1.0e+06</td>\n",
" <td>0.50</td>\n",
" <td>0.50</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>6.0e-06</td>\n",
" <td>3.0e-06</td>\n",
" <td>1.2e-05</td>\n",
" <td>6.0e-06</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>100</td>\n",
" <td>1000</td>\n",
" <td>1.0e+06</td>\n",
" <td>0.50</td>\n",
" <td>0.50</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>6.0e-06</td>\n",
" <td>3.0e-06</td>\n",
" <td>1.2e-05</td>\n",
" <td>6.0e-06</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>1000</td>\n",
" <td>10000</td>\n",
" <td>1.0e+06</td>\n",
" <td>0.50</td>\n",
" <td>0.50</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>6.0e-06</td>\n",
" <td>3.0e-06</td>\n",
" <td>1.2e-05</td>\n",
" <td>6.0e-06</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>\n",
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-482d5f0d-089a-4a00-90fb-fa53207cb631')\"\n",
" title=\"Convert this dataframe to an interactive table.\"\n",
" style=\"display:none;\">\n",
" \n",
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
" width=\"24px\">\n",
" <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
" <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
" </svg>\n",
" </button>\n",
" \n",
" <style>\n",
" .colab-df-container {\n",
" display:flex;\n",
" flex-wrap:wrap;\n",
" gap: 12px;\n",
" }\n",
"\n",
" .colab-df-convert {\n",
" background-color: #E8F0FE;\n",
" border: none;\n",
" border-radius: 50%;\n",
" cursor: pointer;\n",
" display: none;\n",
" fill: #1967D2;\n",
" height: 32px;\n",
" padding: 0 0 0 0;\n",
" width: 32px;\n",
" }\n",
"\n",
" .colab-df-convert:hover {\n",
" background-color: #E2EBFA;\n",
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
" fill: #174EA6;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert {\n",
" background-color: #3B4455;\n",
" fill: #D2E3FC;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert:hover {\n",
" background-color: #434B5C;\n",
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
" fill: #FFFFFF;\n",
" }\n",
" </style>\n",
"\n",
" <script>\n",
" const buttonEl =\n",
" document.querySelector('#df-482d5f0d-089a-4a00-90fb-fa53207cb631 button.colab-df-convert');\n",
" buttonEl.style.display =\n",
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
"\n",
" async function convertToInteractive(key) {\n",
" const element = document.querySelector('#df-482d5f0d-089a-4a00-90fb-fa53207cb631');\n",
" const dataTable =\n",
" await google.colab.kernel.invokeFunction('convertToInteractive',\n",
" [key], {});\n",
" if (!dataTable) return;\n",
"\n",
" const docLinkHtml = 'Like what you see? Visit the ' +\n",
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
" + ' to learn more about interactive tables.';\n",
" element.innerHTML = '';\n",
" dataTable['output_type'] = 'display_data';\n",
" await google.colab.output.renderOutput(dataTable, element);\n",
" const docLink = document.createElement('div');\n",
" docLink.innerHTML = docLinkHtml;\n",
" element.appendChild(docLink);\n",
" }\n",
" </script>\n",
" </div>\n",
" </div>\n",
" "
]
},
"metadata": {},
"execution_count": 42
}
]
},
{
"cell_type": "code",
"source": [
"print('\\033[1m' + f'Method: New implementation of CDF' + '\\033[0m')\n",
"print(f'Benchmark: Scipy implementation of CDF')\n",
"print(f'Dtype: {np.dtype(DTYPE).name}')\n",
"get_metrics_dataframe(sample_specs, sp_cdf, new_tf_cdf, DTYPE)"
],
"metadata": {
"id": "iFDvNEJ47nhN",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 290
},
"outputId": "dfc9daa2-1503-4894-d6c7-4df1eb5cf39e"
},
"execution_count": 43,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"\u001b[1mMethod: New implementation of CDF\u001b[0m\n",
"Benchmark: Scipy implementation of CDF\n",
"Dtype: float32\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
" Min Df Max Df #Trials Min Value Max Value %NaN %Inf Max Abs Error \\\n",
"0 0 1 1.0e+06 0.50 0.50 0 0 6.0e-08 \n",
"1 1 10 1.0e+06 0.50 0.50 0 0 0 \n",
"2 10 30 1.0e+06 0.50 0.50 0 0 0 \n",
"3 30 100 1.0e+06 0.50 0.50 0 0 0 \n",
"4 100 1000 1.0e+06 0.50 0.50 0 0 0 \n",
"5 1000 10000 1.0e+06 0.50 0.50 0 0 0 \n",
"\n",
" Mean Abs Error Max Rel Error Mean Rel Error \n",
"0 7.2e-12 1.2e-07 1.4e-11 \n",
"1 0 0 0 \n",
"2 0 0 0 \n",
"3 0 0 0 \n",
"4 0 0 0 \n",
"5 0 0 0 "
],
"text/html": [
"\n",
" <div id=\"df-482f9bed-c650-41f9-9947-54f2634a378d\">\n",
" <div class=\"colab-df-container\">\n",
" <div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Min Df</th>\n",
" <th>Max Df</th>\n",
" <th>#Trials</th>\n",
" <th>Min Value</th>\n",
" <th>Max Value</th>\n",
" <th>%NaN</th>\n",
" <th>%Inf</th>\n",
" <th>Max Abs Error</th>\n",
" <th>Mean Abs Error</th>\n",
" <th>Max Rel Error</th>\n",
" <th>Mean Rel Error</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1.0e+06</td>\n",
" <td>0.50</td>\n",
" <td>0.50</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>6.0e-08</td>\n",
" <td>7.2e-12</td>\n",
" <td>1.2e-07</td>\n",
" <td>1.4e-11</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1</td>\n",
" <td>10</td>\n",
" <td>1.0e+06</td>\n",
" <td>0.50</td>\n",
" <td>0.50</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>10</td>\n",
" <td>30</td>\n",
" <td>1.0e+06</td>\n",
" <td>0.50</td>\n",
" <td>0.50</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>30</td>\n",
" <td>100</td>\n",
" <td>1.0e+06</td>\n",
" <td>0.50</td>\n",
" <td>0.50</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>100</td>\n",
" <td>1000</td>\n",
" <td>1.0e+06</td>\n",
" <td>0.50</td>\n",
" <td>0.50</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>1000</td>\n",
" <td>10000</td>\n",
" <td>1.0e+06</td>\n",
" <td>0.50</td>\n",
" <td>0.50</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>\n",
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-482f9bed-c650-41f9-9947-54f2634a378d')\"\n",
" title=\"Convert this dataframe to an interactive table.\"\n",
" style=\"display:none;\">\n",
" \n",
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
" width=\"24px\">\n",
" <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
" <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
" </svg>\n",
" </button>\n",
" \n",
" <style>\n",
" .colab-df-container {\n",
" display:flex;\n",
" flex-wrap:wrap;\n",
" gap: 12px;\n",
" }\n",
"\n",
" .colab-df-convert {\n",
" background-color: #E8F0FE;\n",
" border: none;\n",
" border-radius: 50%;\n",
" cursor: pointer;\n",
" display: none;\n",
" fill: #1967D2;\n",
" height: 32px;\n",
" padding: 0 0 0 0;\n",
" width: 32px;\n",
" }\n",
"\n",
" .colab-df-convert:hover {\n",
" background-color: #E2EBFA;\n",
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
" fill: #174EA6;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert {\n",
" background-color: #3B4455;\n",
" fill: #D2E3FC;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert:hover {\n",
" background-color: #434B5C;\n",
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
" fill: #FFFFFF;\n",
" }\n",
" </style>\n",
"\n",
" <script>\n",
" const buttonEl =\n",
" document.querySelector('#df-482f9bed-c650-41f9-9947-54f2634a378d button.colab-df-convert');\n",
" buttonEl.style.display =\n",
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
"\n",
" async function convertToInteractive(key) {\n",
" const element = document.querySelector('#df-482f9bed-c650-41f9-9947-54f2634a378d');\n",
" const dataTable =\n",
" await google.colab.kernel.invokeFunction('convertToInteractive',\n",
" [key], {});\n",
" if (!dataTable) return;\n",
"\n",
" const docLinkHtml = 'Like what you see? Visit the ' +\n",
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
" + ' to learn more about interactive tables.';\n",
" element.innerHTML = '';\n",
" dataTable['output_type'] = 'display_data';\n",
" await google.colab.output.renderOutput(dataTable, element);\n",
" const docLink = document.createElement('div');\n",
" docLink.innerHTML = docLinkHtml;\n",
" element.appendChild(docLink);\n",
" }\n",
" </script>\n",
" </div>\n",
" </div>\n",
" "
]
},
"metadata": {},
"execution_count": 43
}
]
},
{
"cell_type": "markdown",
"source": [
"### 3.1.1. Partial derivative of CDF with respect to `t`"
],
"metadata": {
"id": "sf-sQ9pV-_oD"
}
},
{
"cell_type": "code",
"source": [
"tf_prob = []\n",
"\n",
"for i, (df, _, t) in enumerate(samples):\n",
" result = tf.math.exp(log_prob(t, df, DTYPE(0.), DTYPE(1.)))\n",
" tf_prob.append(result)\n",
" print(f'sample {i} is done.')"
],
"metadata": {
"id": "LTLS6ZUj-_oJ",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "07e1862f-5e36-4a32-a540-8c0964a972c2"
},
"execution_count": 44,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"sample 0 is done.\n",
"sample 1 is done.\n",
"sample 2 is done.\n",
"sample 3 is done.\n",
"sample 4 is done.\n",
"sample 5 is done.\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"def current_cdf_partial_t(df, t):\n",
" dtype = dtype_util.common_dtype([df, t], tf.float32)\n",
"\n",
" df = tf.convert_to_tensor(df, dtype=dtype)\n",
" t = tf.convert_to_tensor(t, dtype=dtype)\n",
"\n",
" return gradient.value_and_gradient(lambda _t: current_cdf(df, _t), t)[1]"
],
"metadata": {
"id": "rFxKbxZk_gb4"
},
"execution_count": 45,
"outputs": []
},
{
"cell_type": "code",
"source": [
"current_cdf_partial_t = tf.function(current_cdf_partial_t, autograph=False)"
],
"metadata": {
"id": "iwRnDTarAkXC"
},
"execution_count": 46,
"outputs": []
},
{
"cell_type": "code",
"source": [
"current_tf_partial_t = []\n",
"\n",
"for i, (df, _, t) in enumerate(samples):\n",
" result = current_cdf_partial_t(df, t)\n",
" current_tf_partial_t.append(result)\n",
" print(f'sample {i} is done.')"
],
"metadata": {
"id": "8CNtST4E-_oJ",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "967ee7b6-71db-46e3-de37-c09b3f2a081e"
},
"execution_count": 47,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"sample 0 is done.\n",
"sample 1 is done.\n",
"sample 2 is done.\n",
"sample 3 is done.\n",
"sample 4 is done.\n",
"sample 5 is done.\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"def new_stdtr_partial_t(df, t):\n",
" dtype = dtype_util.common_dtype([df, t], tf.float32)\n",
"\n",
" df = tf.convert_to_tensor(df, dtype=dtype)\n",
" t = tf.convert_to_tensor(t, dtype=dtype)\n",
"\n",
" return gradient.value_and_gradient(lambda _t: stdtr(df, _t), t)[1]"
],
"metadata": {
"id": "0B5q5jWREX9C"
},
"execution_count": 48,
"outputs": []
},
{
"cell_type": "code",
"source": [
"new_stdtr_partial_t = tf.function(new_stdtr_partial_t, autograph=False)"
],
"metadata": {
"id": "hQAg19VFEX9D"
},
"execution_count": 49,
"outputs": []
},
{
"cell_type": "code",
"source": [
"new_tf_partial_t = []\n",
"\n",
"for i, (df, _, t) in enumerate(samples):\n",
" result = new_stdtr_partial_t(df, t)\n",
" new_tf_partial_t.append(result)\n",
" print(f'sample {i} is done.')"
],
"metadata": {
"id": "UkMsAlIG-_oJ",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "0f651964-12fa-4c1c-88c7-06eb23cc696f"
},
"execution_count": 50,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"sample 0 is done.\n",
"sample 1 is done.\n",
"sample 2 is done.\n",
"sample 3 is done.\n",
"sample 4 is done.\n",
"sample 5 is done.\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"print('\\033[1m' + f'Method: Current implementation of partial derivative of CDF wrt t' + '\\033[0m')\n",
"print(f'Benchmark: TFP implementation of prob')\n",
"print(f'Dtype: {np.dtype(DTYPE).name}')\n",
"get_metrics_dataframe(sample_specs, tf_prob, current_tf_partial_t, DTYPE)"
],
"metadata": {
"id": "2HssqsVs-_oK",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 290
},
"outputId": "12833118-6d53-46c6-fc2e-902e3327bfea"
},
"execution_count": 51,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"\u001b[1mMethod: Current implementation of partial derivative of CDF wrt t\u001b[0m\n",
"Benchmark: TFP implementation of prob\n",
"Dtype: float32\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
" Min Df Max Df #Trials Min Value Max Value %NaN %Inf Max Abs Error \\\n",
"0 0 1 1.0e+06 3.4e-03 0.32 0 96.90 0.04 \n",
"1 1 10 1.0e+06 0.32 0.39 0 100 NaN \n",
"2 10 30 1.0e+06 0.39 0.40 0 100 NaN \n",
"3 30 100 1.0e+06 0.40 0.40 0 100 NaN \n",
"4 100 1000 1.0e+06 0.40 0.40 0 100 NaN \n",
"5 1000 10000 1.0e+06 0.40 0.40 0 100 NaN \n",
"\n",
" Mean Abs Error Max Rel Error Mean Rel Error \n",
"0 6.5e-03 0.49 0.08 \n",
"1 NaN NaN NaN \n",
"2 NaN NaN NaN \n",
"3 NaN NaN NaN \n",
"4 NaN NaN NaN \n",
"5 NaN NaN NaN "
],
"text/html": [
"\n",
" <div id=\"df-62067040-5d3d-41b9-9f09-fc2f2fb55203\">\n",
" <div class=\"colab-df-container\">\n",
" <div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Min Df</th>\n",
" <th>Max Df</th>\n",
" <th>#Trials</th>\n",
" <th>Min Value</th>\n",
" <th>Max Value</th>\n",
" <th>%NaN</th>\n",
" <th>%Inf</th>\n",
" <th>Max Abs Error</th>\n",
" <th>Mean Abs Error</th>\n",
" <th>Max Rel Error</th>\n",
" <th>Mean Rel Error</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1.0e+06</td>\n",
" <td>3.4e-03</td>\n",
" <td>0.32</td>\n",
" <td>0</td>\n",
" <td>96.90</td>\n",
" <td>0.04</td>\n",
" <td>6.5e-03</td>\n",
" <td>0.49</td>\n",
" <td>0.08</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1</td>\n",
" <td>10</td>\n",
" <td>1.0e+06</td>\n",
" <td>0.32</td>\n",
" <td>0.39</td>\n",
" <td>0</td>\n",
" <td>100</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>10</td>\n",
" <td>30</td>\n",
" <td>1.0e+06</td>\n",
" <td>0.39</td>\n",
" <td>0.40</td>\n",
" <td>0</td>\n",
" <td>100</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>30</td>\n",
" <td>100</td>\n",
" <td>1.0e+06</td>\n",
" <td>0.40</td>\n",
" <td>0.40</td>\n",
" <td>0</td>\n",
" <td>100</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>100</td>\n",
" <td>1000</td>\n",
" <td>1.0e+06</td>\n",
" <td>0.40</td>\n",
" <td>0.40</td>\n",
" <td>0</td>\n",
" <td>100</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>1000</td>\n",
" <td>10000</td>\n",
" <td>1.0e+06</td>\n",
" <td>0.40</td>\n",
" <td>0.40</td>\n",
" <td>0</td>\n",
" <td>100</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>\n",
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-62067040-5d3d-41b9-9f09-fc2f2fb55203')\"\n",
" title=\"Convert this dataframe to an interactive table.\"\n",
" style=\"display:none;\">\n",
" \n",
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
" width=\"24px\">\n",
" <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
" <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
" </svg>\n",
" </button>\n",
" \n",
" <style>\n",
" .colab-df-container {\n",
" display:flex;\n",
" flex-wrap:wrap;\n",
" gap: 12px;\n",
" }\n",
"\n",
" .colab-df-convert {\n",
" background-color: #E8F0FE;\n",
" border: none;\n",
" border-radius: 50%;\n",
" cursor: pointer;\n",
" display: none;\n",
" fill: #1967D2;\n",
" height: 32px;\n",
" padding: 0 0 0 0;\n",
" width: 32px;\n",
" }\n",
"\n",
" .colab-df-convert:hover {\n",
" background-color: #E2EBFA;\n",
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
" fill: #174EA6;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert {\n",
" background-color: #3B4455;\n",
" fill: #D2E3FC;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert:hover {\n",
" background-color: #434B5C;\n",
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
" fill: #FFFFFF;\n",
" }\n",
" </style>\n",
"\n",
" <script>\n",
" const buttonEl =\n",
" document.querySelector('#df-62067040-5d3d-41b9-9f09-fc2f2fb55203 button.colab-df-convert');\n",
" buttonEl.style.display =\n",
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
"\n",
" async function convertToInteractive(key) {\n",
" const element = document.querySelector('#df-62067040-5d3d-41b9-9f09-fc2f2fb55203');\n",
" const dataTable =\n",
" await google.colab.kernel.invokeFunction('convertToInteractive',\n",
" [key], {});\n",
" if (!dataTable) return;\n",
"\n",
" const docLinkHtml = 'Like what you see? Visit the ' +\n",
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
" + ' to learn more about interactive tables.';\n",
" element.innerHTML = '';\n",
" dataTable['output_type'] = 'display_data';\n",
" await google.colab.output.renderOutput(dataTable, element);\n",
" const docLink = document.createElement('div');\n",
" docLink.innerHTML = docLinkHtml;\n",
" element.appendChild(docLink);\n",
" }\n",
" </script>\n",
" </div>\n",
" </div>\n",
" "
]
},
"metadata": {},
"execution_count": 51
}
]
},
{
"cell_type": "code",
"source": [
"print('\\033[1m' + f'Method: New implementation of partial derivative of CDF wrt t' + '\\033[0m')\n",
"print(f'Benchmark: TFP implementation of prob')\n",
"print(f'Dtype: {np.dtype(DTYPE).name}')\n",
"get_metrics_dataframe(sample_specs, tf_prob, new_tf_partial_t, DTYPE)"
],
"metadata": {
"id": "3LXRWQX2-_oK",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 290
},
"outputId": "895455a9-09e4-4b69-9784-c7538ffcbda9"
},
"execution_count": 52,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"\u001b[1mMethod: New implementation of partial derivative of CDF wrt t\u001b[0m\n",
"Benchmark: TFP implementation of prob\n",
"Dtype: float32\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
" Min Df Max Df #Trials Min Value Max Value %NaN %Inf Max Abs Error \\\n",
"0 0 1 1.0e+06 3.4e-03 0.32 0 0 0 \n",
"1 1 10 1.0e+06 0.32 0.39 0 0 0 \n",
"2 10 30 1.0e+06 0.39 0.40 0 0 0 \n",
"3 30 100 1.0e+06 0.40 0.40 0 0 0 \n",
"4 100 1000 1.0e+06 0.40 0.40 0 0 0 \n",
"5 1000 10000 1.0e+06 0.40 0.40 0 0 0 \n",
"\n",
" Mean Abs Error Max Rel Error Mean Rel Error \n",
"0 0 0 0 \n",
"1 0 0 0 \n",
"2 0 0 0 \n",
"3 0 0 0 \n",
"4 0 0 0 \n",
"5 0 0 0 "
],
"text/html": [
"\n",
" <div id=\"df-3cf420a3-a738-49cb-a33a-dce46ae47524\">\n",
" <div class=\"colab-df-container\">\n",
" <div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Min Df</th>\n",
" <th>Max Df</th>\n",
" <th>#Trials</th>\n",
" <th>Min Value</th>\n",
" <th>Max Value</th>\n",
" <th>%NaN</th>\n",
" <th>%Inf</th>\n",
" <th>Max Abs Error</th>\n",
" <th>Mean Abs Error</th>\n",
" <th>Max Rel Error</th>\n",
" <th>Mean Rel Error</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1.0e+06</td>\n",
" <td>3.4e-03</td>\n",
" <td>0.32</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1</td>\n",
" <td>10</td>\n",
" <td>1.0e+06</td>\n",
" <td>0.32</td>\n",
" <td>0.39</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>10</td>\n",
" <td>30</td>\n",
" <td>1.0e+06</td>\n",
" <td>0.39</td>\n",
" <td>0.40</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>30</td>\n",
" <td>100</td>\n",
" <td>1.0e+06</td>\n",
" <td>0.40</td>\n",
" <td>0.40</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>100</td>\n",
" <td>1000</td>\n",
" <td>1.0e+06</td>\n",
" <td>0.40</td>\n",
" <td>0.40</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>1000</td>\n",
" <td>10000</td>\n",
" <td>1.0e+06</td>\n",
" <td>0.40</td>\n",
" <td>0.40</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>\n",
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-3cf420a3-a738-49cb-a33a-dce46ae47524')\"\n",
" title=\"Convert this dataframe to an interactive table.\"\n",
" style=\"display:none;\">\n",
" \n",
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
" width=\"24px\">\n",
" <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
" <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
" </svg>\n",
" </button>\n",
" \n",
" <style>\n",
" .colab-df-container {\n",
" display:flex;\n",
" flex-wrap:wrap;\n",
" gap: 12px;\n",
" }\n",
"\n",
" .colab-df-convert {\n",
" background-color: #E8F0FE;\n",
" border: none;\n",
" border-radius: 50%;\n",
" cursor: pointer;\n",
" display: none;\n",
" fill: #1967D2;\n",
" height: 32px;\n",
" padding: 0 0 0 0;\n",
" width: 32px;\n",
" }\n",
"\n",
" .colab-df-convert:hover {\n",
" background-color: #E2EBFA;\n",
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
" fill: #174EA6;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert {\n",
" background-color: #3B4455;\n",
" fill: #D2E3FC;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert:hover {\n",
" background-color: #434B5C;\n",
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
" fill: #FFFFFF;\n",
" }\n",
" </style>\n",
"\n",
" <script>\n",
" const buttonEl =\n",
" document.querySelector('#df-3cf420a3-a738-49cb-a33a-dce46ae47524 button.colab-df-convert');\n",
" buttonEl.style.display =\n",
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
"\n",
" async function convertToInteractive(key) {\n",
" const element = document.querySelector('#df-3cf420a3-a738-49cb-a33a-dce46ae47524');\n",
" const dataTable =\n",
" await google.colab.kernel.invokeFunction('convertToInteractive',\n",
" [key], {});\n",
" if (!dataTable) return;\n",
"\n",
" const docLinkHtml = 'Like what you see? Visit the ' +\n",
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
" + ' to learn more about interactive tables.';\n",
" element.innerHTML = '';\n",
" dataTable['output_type'] = 'display_data';\n",
" await google.colab.output.renderOutput(dataTable, element);\n",
" const docLink = document.createElement('div');\n",
" docLink.innerHTML = docLinkHtml;\n",
" element.appendChild(docLink);\n",
" }\n",
" </script>\n",
" </div>\n",
" </div>\n",
" "
]
},
"metadata": {},
"execution_count": 52
}
]
},
{
"cell_type": "markdown",
"source": [
"### 3.1.2. Partial derivative of CDF with respect to `df`"
],
"metadata": {
"id": "uh7KWm3nXDrE"
}
},
{
"cell_type": "code",
"source": [
"small_samples = make_samples(\n",
" sample_specs[DTYPE],\n",
" size_per_seed=SMALL_SIZE_PER_SEED,\n",
" seeds=SEEDS,\n",
" dtype=DTYPE,\n",
" center=CENTER)"
],
"metadata": {
"id": "TL1Dwp-OXDrE"
},
"execution_count": 53,
"outputs": []
},
{
"cell_type": "code",
"source": [
"df, _, t = small_samples[0]\n",
"broadcast_shape = functools.reduce(\n",
" ps.broadcast_shape, [ps.shape(df), ps.shape(t)])"
],
"metadata": {
"id": "xSZasJMMnvpV"
},
"execution_count": 54,
"outputs": []
},
{
"cell_type": "code",
"source": [
"small_samples = [\n",
" [np.broadcast_to(param, broadcast_shape) for param in (df, p, t)]\n",
" for (df, p, t) in small_samples]"
],
"metadata": {
"id": "_OtVkNFWnKfr"
},
"execution_count": 55,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def _mp_cdf(df, t, numpy_dtype):\n",
" half = numpy_dtype(0.5)\n",
" one = numpy_dtype(1.)\n",
"\n",
" x_t = df / (t * t + df)\n",
" neg_cdf = half * mp.betainc(half * df, half, x1=0, x2=x_t, regularized=True)\n",
" return neg_cdf if t < 0. else one - neg_cdf"
],
"metadata": {
"id": "5rJL5ekmLjHb"
},
"execution_count": 56,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def _mp_cdf_partial_df(df, t, numpy_dtype):\n",
" df, t = [numpy_dtype(z) for z in (df, t)]\n",
" return mp.diff(lambda x: _mp_cdf(x, t, numpy_dtype), df)"
],
"metadata": {
"id": "wPW_mUOuLjHc"
},
"execution_count": 57,
"outputs": []
},
{
"cell_type": "code",
"source": [
"mp_cdf_partial_df = np.frompyfunc(_mp_cdf_partial_df, 3, 1)"
],
"metadata": {
"id": "q0ND3q6GLjHc"
},
"execution_count": 58,
"outputs": []
},
{
"cell_type": "code",
"source": [
"mp_partial_df = []\n",
"\n",
"for i, (df, _, t) in enumerate(small_samples):\n",
" result = mp_cdf_partial_df(df, t, [DTYPE]).astype(DTYPE).astype(DTYPE)\n",
" mp_partial_df.append(result)\n",
" print(f'sample {i} is done.')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "714e8e76-5c54-458f-9eab-fde74aff0c62",
"id": "U_e9W1OVLjHc"
},
"execution_count": 59,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"sample 0 is done.\n",
"sample 1 is done.\n",
"sample 2 is done.\n",
"sample 3 is done.\n",
"sample 4 is done.\n",
"sample 5 is done.\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"def current_cdf_partial_df(df, t):\n",
" dtype = dtype_util.common_dtype([df, t], tf.float32)\n",
"\n",
" df = tf.convert_to_tensor(df, dtype=dtype)\n",
" t = tf.convert_to_tensor(t, dtype=dtype)\n",
"\n",
" return gradient.value_and_gradient(lambda _df: current_cdf(_df, t), df)[1]"
],
"metadata": {
"id": "k7LwONNALjHd"
},
"execution_count": 60,
"outputs": []
},
{
"cell_type": "code",
"source": [
"current_cdf_partial_df = tf.function(current_cdf_partial_df, autograph=False)"
],
"metadata": {
"id": "P1Wowna4LjHd"
},
"execution_count": 61,
"outputs": []
},
{
"cell_type": "code",
"source": [
"current_tf_partial_df = []\n",
"\n",
"for i, (df, _, t) in enumerate(small_samples):\n",
" result = current_cdf_partial_df(df, t)\n",
" current_tf_partial_df.append(result)\n",
" print(f'sample {i} is done.')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "56be079d-76a0-45ba-f7ee-cd84f66912f1",
"id": "tfH3QNqpLjHd"
},
"execution_count": 62,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"sample 0 is done.\n",
"sample 1 is done.\n",
"sample 2 is done.\n",
"sample 3 is done.\n",
"sample 4 is done.\n",
"sample 5 is done.\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"mp_partial_df[0].shape, mp_partial_df[0].dtype"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "GiLM1R35ljf0",
"outputId": "4a683d21-e6f1-4b84-85d3-44d177409fee"
},
"execution_count": 63,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"((50, 50), dtype('float32'))"
]
},
"metadata": {},
"execution_count": 63
}
]
},
{
"cell_type": "code",
"source": [
"current_tf_partial_df[0].shape, current_tf_partial_df[0].dtype"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "r1XTAWj5ldPZ",
"outputId": "840bbb32-1b09-400d-b250-d0f2ff7b75d7"
},
"execution_count": 64,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(TensorShape([50, 50]), tf.float32)"
]
},
"metadata": {},
"execution_count": 64
}
]
},
{
"cell_type": "code",
"source": [
"def new_stdtr_partial_df(df, t):\n",
" dtype = dtype_util.common_dtype([df, t], tf.float32)\n",
"\n",
" df = tf.convert_to_tensor(df, dtype=dtype)\n",
" t = tf.convert_to_tensor(t, dtype=dtype)\n",
"\n",
" return gradient.value_and_gradient(lambda _df: stdtr(_df, t), df)[1]"
],
"metadata": {
"id": "LJlOKxM_LjHe"
},
"execution_count": 65,
"outputs": []
},
{
"cell_type": "code",
"source": [
"new_stdtr_partial_df = tf.function(new_stdtr_partial_df, autograph=False)"
],
"metadata": {
"id": "fwBPt9N3LjHe"
},
"execution_count": 66,
"outputs": []
},
{
"cell_type": "code",
"source": [
"new_tf_partial_df = []\n",
"\n",
"for i, (df, _, t) in enumerate(small_samples):\n",
" result = new_stdtr_partial_df(df, t)\n",
" new_tf_partial_df.append(result)\n",
" print(f'sample {i} is done.')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "c966b442-38d1-421d-df70-ebf17abec90d",
"id": "4Mvh-AABLjHe"
},
"execution_count": 67,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"sample 0 is done.\n",
"sample 1 is done.\n",
"sample 2 is done.\n",
"sample 3 is done.\n",
"sample 4 is done.\n",
"sample 5 is done.\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"new_tf_partial_df[0].shape, new_tf_partial_df[0].dtype"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "OgZ2k4fHltl2",
"outputId": "2ec65a0a-a771-45bf-9039-6707cda874f0"
},
"execution_count": 68,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(TensorShape([50, 50]), tf.float32)"
]
},
"metadata": {},
"execution_count": 68
}
]
},
{
"cell_type": "code",
"source": [
"print('\\033[1m' + f'Method: Current implementation of partial derivative of CDF wrt df' + '\\033[0m')\n",
"print(f'Benchmark: Mpmath implementation of partial derivative of CDF wrt df')\n",
"print(f'Dtype: {np.dtype(DTYPE).name}')\n",
"get_metrics_dataframe(sample_specs, mp_partial_df, current_tf_partial_df, DTYPE)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 290
},
"outputId": "c3ee83a9-7650-4249-fcbc-d52df3435c91",
"id": "z4wnUL24LjHh"
},
"execution_count": 69,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"\u001b[1mMethod: Current implementation of partial derivative of CDF wrt df\u001b[0m\n",
"Benchmark: Mpmath implementation of partial derivative of CDF wrt df\n",
"Dtype: float32\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
" Min Df Max Df #Trials Min Value Max Value %NaN %Inf Max Abs Error \\\n",
"0 0 1 2500 -1.0e-03 1.6e-03 98.20 0 1.4e-05 \n",
"1 1 10 2500 -4.6e-08 6.9e-08 100 0 NaN \n",
"2 10 30 2500 -5.9e-10 8.9e-10 100 0 NaN \n",
"3 30 100 2500 -6.6e-11 9.9e-11 100 0 NaN \n",
"4 100 1000 2500 -5.9e-12 8.9e-12 100 0 NaN \n",
"5 1000 10000 2500 -5.9e-14 8.9e-14 100 0 NaN \n",
"\n",
" Mean Abs Error Max Rel Error Mean Rel Error \n",
"0 4.6e-06 0.06 0.01 \n",
"1 NaN NaN NaN \n",
"2 NaN NaN NaN \n",
"3 NaN NaN NaN \n",
"4 NaN NaN NaN \n",
"5 NaN NaN NaN "
],
"text/html": [
"\n",
" <div id=\"df-78037e6e-c42d-4bbc-8041-bd33c2563e8b\">\n",
" <div class=\"colab-df-container\">\n",
" <div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Min Df</th>\n",
" <th>Max Df</th>\n",
" <th>#Trials</th>\n",
" <th>Min Value</th>\n",
" <th>Max Value</th>\n",
" <th>%NaN</th>\n",
" <th>%Inf</th>\n",
" <th>Max Abs Error</th>\n",
" <th>Mean Abs Error</th>\n",
" <th>Max Rel Error</th>\n",
" <th>Mean Rel Error</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>2500</td>\n",
" <td>-1.0e-03</td>\n",
" <td>1.6e-03</td>\n",
" <td>98.20</td>\n",
" <td>0</td>\n",
" <td>1.4e-05</td>\n",
" <td>4.6e-06</td>\n",
" <td>0.06</td>\n",
" <td>0.01</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1</td>\n",
" <td>10</td>\n",
" <td>2500</td>\n",
" <td>-4.6e-08</td>\n",
" <td>6.9e-08</td>\n",
" <td>100</td>\n",
" <td>0</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>10</td>\n",
" <td>30</td>\n",
" <td>2500</td>\n",
" <td>-5.9e-10</td>\n",
" <td>8.9e-10</td>\n",
" <td>100</td>\n",
" <td>0</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>30</td>\n",
" <td>100</td>\n",
" <td>2500</td>\n",
" <td>-6.6e-11</td>\n",
" <td>9.9e-11</td>\n",
" <td>100</td>\n",
" <td>0</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>100</td>\n",
" <td>1000</td>\n",
" <td>2500</td>\n",
" <td>-5.9e-12</td>\n",
" <td>8.9e-12</td>\n",
" <td>100</td>\n",
" <td>0</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>1000</td>\n",
" <td>10000</td>\n",
" <td>2500</td>\n",
" <td>-5.9e-14</td>\n",
" <td>8.9e-14</td>\n",
" <td>100</td>\n",
" <td>0</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>\n",
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-78037e6e-c42d-4bbc-8041-bd33c2563e8b')\"\n",
" title=\"Convert this dataframe to an interactive table.\"\n",
" style=\"display:none;\">\n",
" \n",
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
" width=\"24px\">\n",
" <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
" <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
" </svg>\n",
" </button>\n",
" \n",
" <style>\n",
" .colab-df-container {\n",
" display:flex;\n",
" flex-wrap:wrap;\n",
" gap: 12px;\n",
" }\n",
"\n",
" .colab-df-convert {\n",
" background-color: #E8F0FE;\n",
" border: none;\n",
" border-radius: 50%;\n",
" cursor: pointer;\n",
" display: none;\n",
" fill: #1967D2;\n",
" height: 32px;\n",
" padding: 0 0 0 0;\n",
" width: 32px;\n",
" }\n",
"\n",
" .colab-df-convert:hover {\n",
" background-color: #E2EBFA;\n",
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
" fill: #174EA6;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert {\n",
" background-color: #3B4455;\n",
" fill: #D2E3FC;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert:hover {\n",
" background-color: #434B5C;\n",
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
" fill: #FFFFFF;\n",
" }\n",
" </style>\n",
"\n",
" <script>\n",
" const buttonEl =\n",
" document.querySelector('#df-78037e6e-c42d-4bbc-8041-bd33c2563e8b button.colab-df-convert');\n",
" buttonEl.style.display =\n",
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
"\n",
" async function convertToInteractive(key) {\n",
" const element = document.querySelector('#df-78037e6e-c42d-4bbc-8041-bd33c2563e8b');\n",
" const dataTable =\n",
" await google.colab.kernel.invokeFunction('convertToInteractive',\n",
" [key], {});\n",
" if (!dataTable) return;\n",
"\n",
" const docLinkHtml = 'Like what you see? Visit the ' +\n",
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
" + ' to learn more about interactive tables.';\n",
" element.innerHTML = '';\n",
" dataTable['output_type'] = 'display_data';\n",
" await google.colab.output.renderOutput(dataTable, element);\n",
" const docLink = document.createElement('div');\n",
" docLink.innerHTML = docLinkHtml;\n",
" element.appendChild(docLink);\n",
" }\n",
" </script>\n",
" </div>\n",
" </div>\n",
" "
]
},
"metadata": {},
"execution_count": 69
}
]
},
{
"cell_type": "code",
"source": [
"print('\\033[1m' + f'Method: New implementation of partial derivative of CDF wrt df' + '\\033[0m')\n",
"print(f'Benchmark: Mpmath implementation of partial derivative of CDF wrt df')\n",
"print(f'Dtype: {np.dtype(DTYPE).name}')\n",
"get_metrics_dataframe(sample_specs, mp_partial_df, new_tf_partial_df, DTYPE)"
],
"metadata": {
"id": "0lrXCEK1zGN6",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 290
},
"outputId": "21ca5ded-e0c1-4e71-ba43-c5e42c16ae10"
},
"execution_count": 70,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"\u001b[1mMethod: New implementation of partial derivative of CDF wrt df\u001b[0m\n",
"Benchmark: Mpmath implementation of partial derivative of CDF wrt df\n",
"Dtype: float32\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
" Min Df Max Df #Trials Min Value Max Value %NaN %Inf Max Abs Error \\\n",
"0 0 1 2500 -1.0e-03 1.6e-03 0 0 1.9e-05 \n",
"1 1 10 2500 -4.6e-08 6.9e-08 0 0 6.9e-08 \n",
"2 10 30 2500 -5.9e-10 8.9e-10 0 0 5.9e-15 \n",
"3 30 100 2500 -6.6e-11 9.9e-11 0 0 2.4e-15 \n",
"4 100 1000 2500 -5.9e-12 8.9e-12 0 0 3.3e-14 \n",
"5 1000 10000 2500 -5.9e-14 8.9e-14 0 0 1.5e-14 \n",
"\n",
" Mean Abs Error Max Rel Error Mean Rel Error \n",
"0 4.1e-07 1.00 0.98 \n",
"1 4.3e-09 1.00 0.99 \n",
"2 6.4e-16 1.00 0.10 \n",
"3 2.0e-16 1.00 0.10 \n",
"4 7.0e-16 1.00 0.14 \n",
"5 1.2e-15 1.00 0.80 "
],
"text/html": [
"\n",
" <div id=\"df-8b37bc14-1dd5-4553-9e72-eeaa41e15197\">\n",
" <div class=\"colab-df-container\">\n",
" <div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Min Df</th>\n",
" <th>Max Df</th>\n",
" <th>#Trials</th>\n",
" <th>Min Value</th>\n",
" <th>Max Value</th>\n",
" <th>%NaN</th>\n",
" <th>%Inf</th>\n",
" <th>Max Abs Error</th>\n",
" <th>Mean Abs Error</th>\n",
" <th>Max Rel Error</th>\n",
" <th>Mean Rel Error</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>2500</td>\n",
" <td>-1.0e-03</td>\n",
" <td>1.6e-03</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1.9e-05</td>\n",
" <td>4.1e-07</td>\n",
" <td>1.00</td>\n",
" <td>0.98</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1</td>\n",
" <td>10</td>\n",
" <td>2500</td>\n",
" <td>-4.6e-08</td>\n",
" <td>6.9e-08</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>6.9e-08</td>\n",
" <td>4.3e-09</td>\n",
" <td>1.00</td>\n",
" <td>0.99</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>10</td>\n",
" <td>30</td>\n",
" <td>2500</td>\n",
" <td>-5.9e-10</td>\n",
" <td>8.9e-10</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>5.9e-15</td>\n",
" <td>6.4e-16</td>\n",
" <td>1.00</td>\n",
" <td>0.10</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>30</td>\n",
" <td>100</td>\n",
" <td>2500</td>\n",
" <td>-6.6e-11</td>\n",
" <td>9.9e-11</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>2.4e-15</td>\n",
" <td>2.0e-16</td>\n",
" <td>1.00</td>\n",
" <td>0.10</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>100</td>\n",
" <td>1000</td>\n",
" <td>2500</td>\n",
" <td>-5.9e-12</td>\n",
" <td>8.9e-12</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>3.3e-14</td>\n",
" <td>7.0e-16</td>\n",
" <td>1.00</td>\n",
" <td>0.14</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>1000</td>\n",
" <td>10000</td>\n",
" <td>2500</td>\n",
" <td>-5.9e-14</td>\n",
" <td>8.9e-14</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1.5e-14</td>\n",
" <td>1.2e-15</td>\n",
" <td>1.00</td>\n",
" <td>0.80</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>\n",
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-8b37bc14-1dd5-4553-9e72-eeaa41e15197')\"\n",
" title=\"Convert this dataframe to an interactive table.\"\n",
" style=\"display:none;\">\n",
" \n",
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
" width=\"24px\">\n",
" <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
" <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
" </svg>\n",
" </button>\n",
" \n",
" <style>\n",
" .colab-df-container {\n",
" display:flex;\n",
" flex-wrap:wrap;\n",
" gap: 12px;\n",
" }\n",
"\n",
" .colab-df-convert {\n",
" background-color: #E8F0FE;\n",
" border: none;\n",
" border-radius: 50%;\n",
" cursor: pointer;\n",
" display: none;\n",
" fill: #1967D2;\n",
" height: 32px;\n",
" padding: 0 0 0 0;\n",
" width: 32px;\n",
" }\n",
"\n",
" .colab-df-convert:hover {\n",
" background-color: #E2EBFA;\n",
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
" fill: #174EA6;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert {\n",
" background-color: #3B4455;\n",
" fill: #D2E3FC;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert:hover {\n",
" background-color: #434B5C;\n",
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
" fill: #FFFFFF;\n",
" }\n",
" </style>\n",
"\n",
" <script>\n",
" const buttonEl =\n",
" document.querySelector('#df-8b37bc14-1dd5-4553-9e72-eeaa41e15197 button.colab-df-convert');\n",
" buttonEl.style.display =\n",
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
"\n",
" async function convertToInteractive(key) {\n",
" const element = document.querySelector('#df-8b37bc14-1dd5-4553-9e72-eeaa41e15197');\n",
" const dataTable =\n",
" await google.colab.kernel.invokeFunction('convertToInteractive',\n",
" [key], {});\n",
" if (!dataTable) return;\n",
"\n",
" const docLinkHtml = 'Like what you see? Visit the ' +\n",
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
" + ' to learn more about interactive tables.';\n",
" element.innerHTML = '';\n",
" dataTable['output_type'] = 'display_data';\n",
" await google.colab.output.renderOutput(dataTable, element);\n",
" const docLink = document.createElement('div');\n",
" docLink.innerHTML = docLinkHtml;\n",
" element.appendChild(docLink);\n",
" }\n",
" </script>\n",
" </div>\n",
" </div>\n",
" "
]
},
"metadata": {},
"execution_count": 70
}
]
},
{
"cell_type": "markdown",
"source": [
"### 3.1.3. Partial derivatives of CDF\n",
"\n"
],
"metadata": {
"id": "UJ6y7RpPanoB"
}
},
{
"cell_type": "markdown",
"source": [
"Here we compute the maximum difference between the autodiff and the numerical gradients of `stdtr`. This check is run in `float64`: it is not run for the `float32` type, and it does not specifically target values of `t` near zero."
],
"metadata": {
"id": "YzphmoB_anoC"
}
},
{
"cell_type": "code",
"source": [
"space_df = np.logspace(np.log10(0.5), 8., num=15, base=10).tolist()\n",
"space_p = np.linspace(0.01, 0.99, num=15).tolist()\n",
"df, p = zip(*list(itertools.product(space_df, space_p)))\n",
"\n",
"df = np.array(df, dtype=np.float64)\n",
"p = np.array(p, dtype=np.float64)\n",
"t = sp_special.stdtrit(df, p)"
],
"metadata": {
"id": "dhi-C89qanoC"
},
"execution_count": 71,
"outputs": []
},
{
"cell_type": "code",
"source": [
"t.shape == p.shape, t.dtype == p.dtype"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "P8eQAH4Ca_Z2",
"outputId": "29602525-5ed0-465c-8281-e92c6472fa7a"
},
"execution_count": 72,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(True, True)"
]
},
"metadata": {},
"execution_count": 72
}
]
},
{
"cell_type": "code",
"source": [
"test_case = test_util.TestCase()"
],
"metadata": {
"id": "LZC87JM-anoD"
},
"execution_count": 73,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"**With respect to `df`**"
],
"metadata": {
"id": "PHIHmEgxanoD"
}
},
{
"cell_type": "code",
"source": [
"test_case.compute_max_gradient_error(\n",
" lambda _df: stdtr(_df, t), [df], delta=1e-5)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "07f6df93-3d46-4f4f-ed68-5edfcffe7c47",
"id": "__GZuebCanoE"
},
"execution_count": 74,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"1.0521331420302541e-10"
]
},
"metadata": {},
"execution_count": 74
}
]
},
{
"cell_type": "markdown",
"source": [
"**With respect to `t`**"
],
"metadata": {
"id": "qqnQ7_JBanoE"
}
},
{
"cell_type": "code",
"source": [
"test_case.compute_max_gradient_error(\n",
" lambda _t: stdtr(df, _t), [t], delta=1e-5)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "3a0185d7-c0ca-44a0-d86b-e44591b60bf6",
"id": "JXridSU_anoF"
},
"execution_count": 75,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"1.5703188749327524e-11"
]
},
"metadata": {},
"execution_count": 75
}
]
},
{
"cell_type": "markdown",
"source": [
"## 3.2. Test results for Quantile"
],
"metadata": {
"id": "ywzWraJXXOVP"
}
},
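{
"cell_type": "markdown",
"source": [
"The SciPy implementation of the Student's t quantile function, `sp_special.stdtrit`, is not very accurate. The examples below show the relative error between `p` and `sp_special.stdtr(df, sp_special.stdtrit(df, p))` for a few values of `df`:\n"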
],
"metadata": {
"id": "C6CX1ZpQPZIA"
}
},
{
"cell_type": "code",
"source": [
"numpy_dtype = np.float64"
],
"metadata": {
"id": "OuTjwf6mUqxY"
},
"execution_count": 76,
"outputs": []
},
{
"cell_type": "code",
"source": [
"df = numpy_dtype(0.1)\n",
"p = np.finfo(numpy_dtype).eps\n",
"t = sp_special.stdtrit(df, p)\n",
"relerr = np.abs((p - sp_special.stdtr(df, t)) / p)\n",
"relerr"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "5q5e6wc7UtYM",
"outputId": "2300f969-2128-4a9c-8c57-650202a7b07e"
},
"execution_count": 77,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"187970.38253291088"
]
},
"metadata": {},
"execution_count": 77
}
]
},
{
"cell_type": "code",
"source": [
"df = numpy_dtype(9.)\n",
"p = np.finfo(numpy_dtype).eps\n",
"t = sp_special.stdtrit(df, p)\n",
"relerr = np.abs((p - sp_special.stdtr(df, t)) / p)\n",
"relerr"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "j8JiSzEOQbvT",
"outputId": "68612247-7afd-4ed4-c334-e464e0588fd4"
},
"execution_count": 78,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"1.3655008457291729e-08"
]
},
"metadata": {},
"execution_count": 78
}
]
},
{
"cell_type": "code",
"source": [
"df = numpy_dtype(534)\n",
"p = np.finfo(numpy_dtype).eps\n",
"t = sp_special.stdtrit(df, p)\n",
"relerr = np.abs((p - sp_special.stdtr(df, t)) / p)\n",
"relerr"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "bby547ffX_eJ",
"outputId": "2f5133fc-cab6-404e-ccac-ff162b4a966a"
},
"execution_count": 79,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"1.5639093375874324e-07"
]
},
"metadata": {},
"execution_count": 79
}
]
},
{
"cell_type": "markdown",
"source": [
"Since `sp_special.stdtrit(df, p)` is not very accurate, we do not use it as the benchmark for the current and the new implementations of the Student's t quantile function. Instead, we compare `sp_special.stdtr(df, t*)` to `p`, where `t*` is the quantile computed by the implementation under test."
],
"metadata": {
"id": "DIaCoTfsTBFj"
}
},
{
"cell_type": "code",
"source": [
"truths = [p for (df, p, t) in samples]"
],
"metadata": {
"id": "5zK60Go8mK_r"
},
"execution_count": 80,
"outputs": []
},
{
"cell_type": "code",
"source": [
"sp_quantile = []\n",
"\n",
"for i, (df, p, _) in enumerate(samples):\n",
" sp_t = sp_special.stdtrit(df, p)\n",
" result = sp_special.stdtr(df, sp_t)\n",
" sp_quantile.append(result)\n",
" print(f'sample {i} is done.')"
],
"metadata": {
"id": "566FWGrKV-cG",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "ea456e9d-9b34-40db-e814-951f72d5d969"
},
"execution_count": 81,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"sample 0 is done.\n",
"sample 1 is done.\n",
"sample 2 is done.\n",
"sample 3 is done.\n",
"sample 4 is done.\n",
"sample 5 is done.\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"current_tf_quantile = []\n",
"\n",
"for i, (df, p, _) in enumerate(samples):\n",
" current_tf_t = current_quantile(df, p)\n",
" result = sp_special.stdtr(df, current_tf_t)\n",
" current_tf_quantile.append(result)\n",
" print(f'sample {i} is done.')"
],
"metadata": {
"id": "Glc3Cpqq2HT3",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "07a29bd0-d8d3-43f9-a59b-f222c6067243"
},
"execution_count": 82,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"sample 0 is done.\n",
"sample 1 is done.\n",
"sample 2 is done.\n",
"sample 3 is done.\n",
"sample 4 is done.\n",
"sample 5 is done.\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"new_tf_quantile = []\n",
"\n",
"for i, (df, p, _) in enumerate(samples):\n",
" new_tf_t = stdtrit(df, p)\n",
" result = sp_special.stdtr(df, new_tf_t)\n",
" new_tf_quantile.append(result)\n",
" print(f'sample {i} is done.')"
],
"metadata": {
"id": "B3zoMKP-KFPw",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "f92894ca-806b-4744-d72c-5446e8638b5a"
},
"execution_count": 83,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"sample 0 is done.\n",
"sample 1 is done.\n",
"sample 2 is done.\n",
"sample 3 is done.\n",
"sample 4 is done.\n",
"sample 5 is done.\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"print('\\033[1m' + f'Method: Scipy implementation of Quantile' + '\\033[0m')\n",
"print(f'Benchmark: True values of p')\n",
"print(f'Dtype: {np.dtype(DTYPE).name}')\n",
"get_metrics_dataframe(sample_specs, truths, sp_quantile, DTYPE)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 290
},
"id": "YvbZ6esWmW8c",
"outputId": "d276ead8-1f30-4df8-d7f8-c1dbe176d473"
},
"execution_count": 84,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"\u001b[1mMethod: Scipy implementation of Quantile\u001b[0m\n",
"Benchmark: True values of p\n",
"Dtype: float32\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
" Min Df Max Df #Trials Min Value Max Value %NaN %Inf Max Abs Error \\\n",
"0 0 1 1000 0.50 0.50 0 0 0 \n",
"1 1 10 1000 0.50 0.50 0 0 0 \n",
"2 10 30 1000 0.50 0.50 0 0 0 \n",
"3 30 100 1000 0.50 0.50 0 0 0 \n",
"4 100 1000 1000 0.50 0.50 0 0 0 \n",
"5 1000 10000 1000 0.50 0.50 0 0 0 \n",
"\n",
" Mean Abs Error Max Rel Error Mean Rel Error \n",
"0 0 0 0 \n",
"1 0 0 0 \n",
"2 0 0 0 \n",
"3 0 0 0 \n",
"4 0 0 0 \n",
"5 0 0 0 "
],
"text/html": [
"\n",
" <div id=\"df-1b8d51c7-edd4-45c2-9fb2-ae4c9e805f88\">\n",
" <div class=\"colab-df-container\">\n",
" <div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Min Df</th>\n",
" <th>Max Df</th>\n",
" <th>#Trials</th>\n",
" <th>Min Value</th>\n",
" <th>Max Value</th>\n",
" <th>%NaN</th>\n",
" <th>%Inf</th>\n",
" <th>Max Abs Error</th>\n",
" <th>Mean Abs Error</th>\n",
" <th>Max Rel Error</th>\n",
" <th>Mean Rel Error</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1000</td>\n",
" <td>0.50</td>\n",
" <td>0.50</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1</td>\n",
" <td>10</td>\n",
" <td>1000</td>\n",
" <td>0.50</td>\n",
" <td>0.50</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>10</td>\n",
" <td>30</td>\n",
" <td>1000</td>\n",
" <td>0.50</td>\n",
" <td>0.50</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>30</td>\n",
" <td>100</td>\n",
" <td>1000</td>\n",
" <td>0.50</td>\n",
" <td>0.50</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>100</td>\n",
" <td>1000</td>\n",
" <td>1000</td>\n",
" <td>0.50</td>\n",
" <td>0.50</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>1000</td>\n",
" <td>10000</td>\n",
" <td>1000</td>\n",
" <td>0.50</td>\n",
" <td>0.50</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>\n",
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-1b8d51c7-edd4-45c2-9fb2-ae4c9e805f88')\"\n",
" title=\"Convert this dataframe to an interactive table.\"\n",
" style=\"display:none;\">\n",
" \n",
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
" width=\"24px\">\n",
" <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
" <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
" </svg>\n",
" </button>\n",
" \n",
" <style>\n",
" .colab-df-container {\n",
" display:flex;\n",
" flex-wrap:wrap;\n",
" gap: 12px;\n",
" }\n",
"\n",
" .colab-df-convert {\n",
" background-color: #E8F0FE;\n",
" border: none;\n",
" border-radius: 50%;\n",
" cursor: pointer;\n",
" display: none;\n",
" fill: #1967D2;\n",
" height: 32px;\n",
" padding: 0 0 0 0;\n",
" width: 32px;\n",
" }\n",
"\n",
" .colab-df-convert:hover {\n",
" background-color: #E2EBFA;\n",
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
" fill: #174EA6;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert {\n",
" background-color: #3B4455;\n",
" fill: #D2E3FC;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert:hover {\n",
" background-color: #434B5C;\n",
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
" fill: #FFFFFF;\n",
" }\n",
" </style>\n",
"\n",
" <script>\n",
" const buttonEl =\n",
" document.querySelector('#df-1b8d51c7-edd4-45c2-9fb2-ae4c9e805f88 button.colab-df-convert');\n",
" buttonEl.style.display =\n",
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
"\n",
" async function convertToInteractive(key) {\n",
" const element = document.querySelector('#df-1b8d51c7-edd4-45c2-9fb2-ae4c9e805f88');\n",
" const dataTable =\n",
" await google.colab.kernel.invokeFunction('convertToInteractive',\n",
" [key], {});\n",
" if (!dataTable) return;\n",
"\n",
" const docLinkHtml = 'Like what you see? Visit the ' +\n",
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
" + ' to learn more about interactive tables.';\n",
" element.innerHTML = '';\n",
" dataTable['output_type'] = 'display_data';\n",
" await google.colab.output.renderOutput(dataTable, element);\n",
" const docLink = document.createElement('div');\n",
" docLink.innerHTML = docLinkHtml;\n",
" element.appendChild(docLink);\n",
" }\n",
" </script>\n",
" </div>\n",
" </div>\n",
" "
]
},
"metadata": {},
"execution_count": 84
}
]
},
{
"cell_type": "code",
"source": [
"print('\\033[1m' + f'Method: Current implementation of Quantile' + '\\033[0m')\n",
"print(f'Benchmark: True values of p')\n",
"print(f'Dtype: {np.dtype(DTYPE).name}')\n",
"get_metrics_dataframe(sample_specs, truths, current_tf_quantile, DTYPE)"
],
"metadata": {
"id": "ujwolwbQ2Qmy",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 290
},
"outputId": "f3367af6-6282-4f02-db80-1c84f7eb20d3"
},
"execution_count": 85,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"\u001b[1mMethod: Current implementation of Quantile\u001b[0m\n",
"Benchmark: True values of p\n",
"Dtype: float32\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
" Min Df Max Df #Trials Min Value Max Value %NaN %Inf Max Abs Error \\\n",
"0 0 1 1000 0.50 0.50 0 0 1.1e-04 \n",
"1 1 10 1000 0.50 0.50 0 0 3.0e-04 \n",
"2 10 30 1000 0.50 0.50 0 0 5.3e-04 \n",
"3 30 100 1000 0.50 0.50 0 0 9.7e-04 \n",
"4 100 1000 1000 0.50 0.50 0 0 3.1e-03 \n",
"5 1000 10000 1000 0.50 0.50 0 0 9.7e-03 \n",
"\n",
" Mean Abs Error Max Rel Error Mean Rel Error \n",
"0 5.9e-05 2.2e-04 1.2e-04 \n",
"1 2.1e-04 6.0e-04 4.2e-04 \n",
"2 4.2e-04 1.1e-03 8.4e-04 \n",
"3 7.7e-04 1.9e-03 1.5e-03 \n",
"4 2.2e-03 6.1e-03 4.4e-03 \n",
"5 6.9e-03 0.02 0.01 "
],
"text/html": [
"\n",
" <div id=\"df-52e5cabb-4bf5-4a2c-a00e-5a3439221de9\">\n",
" <div class=\"colab-df-container\">\n",
" <div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Min Df</th>\n",
" <th>Max Df</th>\n",
" <th>#Trials</th>\n",
" <th>Min Value</th>\n",
" <th>Max Value</th>\n",
" <th>%NaN</th>\n",
" <th>%Inf</th>\n",
" <th>Max Abs Error</th>\n",
" <th>Mean Abs Error</th>\n",
" <th>Max Rel Error</th>\n",
" <th>Mean Rel Error</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1000</td>\n",
" <td>0.50</td>\n",
" <td>0.50</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1.1e-04</td>\n",
" <td>5.9e-05</td>\n",
" <td>2.2e-04</td>\n",
" <td>1.2e-04</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1</td>\n",
" <td>10</td>\n",
" <td>1000</td>\n",
" <td>0.50</td>\n",
" <td>0.50</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>3.0e-04</td>\n",
" <td>2.1e-04</td>\n",
" <td>6.0e-04</td>\n",
" <td>4.2e-04</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>10</td>\n",
" <td>30</td>\n",
" <td>1000</td>\n",
" <td>0.50</td>\n",
" <td>0.50</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>5.3e-04</td>\n",
" <td>4.2e-04</td>\n",
" <td>1.1e-03</td>\n",
" <td>8.4e-04</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>30</td>\n",
" <td>100</td>\n",
" <td>1000</td>\n",
" <td>0.50</td>\n",
" <td>0.50</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>9.7e-04</td>\n",
" <td>7.7e-04</td>\n",
" <td>1.9e-03</td>\n",
" <td>1.5e-03</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>100</td>\n",
" <td>1000</td>\n",
" <td>1000</td>\n",
" <td>0.50</td>\n",
" <td>0.50</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>3.1e-03</td>\n",
" <td>2.2e-03</td>\n",
" <td>6.1e-03</td>\n",
" <td>4.4e-03</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>1000</td>\n",
" <td>10000</td>\n",
" <td>1000</td>\n",
" <td>0.50</td>\n",
" <td>0.50</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>9.7e-03</td>\n",
" <td>6.9e-03</td>\n",
" <td>0.02</td>\n",
" <td>0.01</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>\n",
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-52e5cabb-4bf5-4a2c-a00e-5a3439221de9')\"\n",
" title=\"Convert this dataframe to an interactive table.\"\n",
" style=\"display:none;\">\n",
" \n",
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
" width=\"24px\">\n",
" <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
" <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
" </svg>\n",
" </button>\n",
" \n",
" <style>\n",
" .colab-df-container {\n",
" display:flex;\n",
" flex-wrap:wrap;\n",
" gap: 12px;\n",
" }\n",
"\n",
" .colab-df-convert {\n",
" background-color: #E8F0FE;\n",
" border: none;\n",
" border-radius: 50%;\n",
" cursor: pointer;\n",
" display: none;\n",
" fill: #1967D2;\n",
" height: 32px;\n",
" padding: 0 0 0 0;\n",
" width: 32px;\n",
" }\n",
"\n",
" .colab-df-convert:hover {\n",
" background-color: #E2EBFA;\n",
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
" fill: #174EA6;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert {\n",
" background-color: #3B4455;\n",
" fill: #D2E3FC;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert:hover {\n",
" background-color: #434B5C;\n",
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
" fill: #FFFFFF;\n",
" }\n",
" </style>\n",
"\n",
" <script>\n",
" const buttonEl =\n",
" document.querySelector('#df-52e5cabb-4bf5-4a2c-a00e-5a3439221de9 button.colab-df-convert');\n",
" buttonEl.style.display =\n",
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
"\n",
" async function convertToInteractive(key) {\n",
" const element = document.querySelector('#df-52e5cabb-4bf5-4a2c-a00e-5a3439221de9');\n",
" const dataTable =\n",
" await google.colab.kernel.invokeFunction('convertToInteractive',\n",
" [key], {});\n",
" if (!dataTable) return;\n",
"\n",
" const docLinkHtml = 'Like what you see? Visit the ' +\n",
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
" + ' to learn more about interactive tables.';\n",
" element.innerHTML = '';\n",
" dataTable['output_type'] = 'display_data';\n",
" await google.colab.output.renderOutput(dataTable, element);\n",
" const docLink = document.createElement('div');\n",
" docLink.innerHTML = docLinkHtml;\n",
" element.appendChild(docLink);\n",
" }\n",
" </script>\n",
" </div>\n",
" </div>\n",
" "
]
},
"metadata": {},
"execution_count": 85
}
]
},
{
"cell_type": "code",
"source": [
"print('\\033[1m' + f'Method: New implementation of Quantile' + '\\033[0m')\n",
"print(f'Benchmark: True values of p')\n",
"print(f'Dtype: {np.dtype(DTYPE).name}')\n",
"get_metrics_dataframe(sample_specs, truths, new_tf_quantile, DTYPE)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 290
},
"id": "7QVVc1e9VQk4",
"outputId": "df877738-e234-4d79-f02d-ac68415ad52e"
},
"execution_count": 86,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"\u001b[1mMethod: New implementation of Quantile\u001b[0m\n",
"Benchmark: True values of p\n",
"Dtype: float32\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
" Min Df Max Df #Trials Min Value Max Value %NaN %Inf Max Abs Error \\\n",
"0 0 1 1000 0.50 0.50 0 0 6.0e-08 \n",
"1 1 10 1000 0.50 0.50 0 0 0 \n",
"2 10 30 1000 0.50 0.50 0 0 0 \n",
"3 30 100 1000 0.50 0.50 0 0 0 \n",
"4 100 1000 1000 0.50 0.50 0 0 0 \n",
"5 1000 10000 1000 0.50 0.50 0 0 0 \n",
"\n",
" Mean Abs Error Max Rel Error Mean Rel Error \n",
"0 4.6e-12 1.2e-07 9.1e-12 \n",
"1 0 0 0 \n",
"2 0 0 0 \n",
"3 0 0 0 \n",
"4 0 0 0 \n",
"5 0 0 0 "
],
"text/html": [
"\n",
" <div id=\"df-a29e6918-f234-4594-afef-cc959aaced55\">\n",
" <div class=\"colab-df-container\">\n",
" <div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Min Df</th>\n",
" <th>Max Df</th>\n",
" <th>#Trials</th>\n",
" <th>Min Value</th>\n",
" <th>Max Value</th>\n",
" <th>%NaN</th>\n",
" <th>%Inf</th>\n",
" <th>Max Abs Error</th>\n",
" <th>Mean Abs Error</th>\n",
" <th>Max Rel Error</th>\n",
" <th>Mean Rel Error</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1000</td>\n",
" <td>0.50</td>\n",
" <td>0.50</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>6.0e-08</td>\n",
" <td>4.6e-12</td>\n",
" <td>1.2e-07</td>\n",
" <td>9.1e-12</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1</td>\n",
" <td>10</td>\n",
" <td>1000</td>\n",
" <td>0.50</td>\n",
" <td>0.50</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>10</td>\n",
" <td>30</td>\n",
" <td>1000</td>\n",
" <td>0.50</td>\n",
" <td>0.50</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>30</td>\n",
" <td>100</td>\n",
" <td>1000</td>\n",
" <td>0.50</td>\n",
" <td>0.50</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>100</td>\n",
" <td>1000</td>\n",
" <td>1000</td>\n",
" <td>0.50</td>\n",
" <td>0.50</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>1000</td>\n",
" <td>10000</td>\n",
" <td>1000</td>\n",
" <td>0.50</td>\n",
" <td>0.50</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>\n",
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-a29e6918-f234-4594-afef-cc959aaced55')\"\n",
" title=\"Convert this dataframe to an interactive table.\"\n",
" style=\"display:none;\">\n",
" \n",
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
" width=\"24px\">\n",
" <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
" <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
" </svg>\n",
" </button>\n",
" \n",
" <style>\n",
" .colab-df-container {\n",
" display:flex;\n",
" flex-wrap:wrap;\n",
" gap: 12px;\n",
" }\n",
"\n",
" .colab-df-convert {\n",
" background-color: #E8F0FE;\n",
" border: none;\n",
" border-radius: 50%;\n",
" cursor: pointer;\n",
" display: none;\n",
" fill: #1967D2;\n",
" height: 32px;\n",
" padding: 0 0 0 0;\n",
" width: 32px;\n",
" }\n",
"\n",
" .colab-df-convert:hover {\n",
" background-color: #E2EBFA;\n",
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
" fill: #174EA6;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert {\n",
" background-color: #3B4455;\n",
" fill: #D2E3FC;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert:hover {\n",
" background-color: #434B5C;\n",
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
" fill: #FFFFFF;\n",
" }\n",
" </style>\n",
"\n",
" <script>\n",
" const buttonEl =\n",
" document.querySelector('#df-a29e6918-f234-4594-afef-cc959aaced55 button.colab-df-convert');\n",
" buttonEl.style.display =\n",
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
"\n",
" async function convertToInteractive(key) {\n",
" const element = document.querySelector('#df-a29e6918-f234-4594-afef-cc959aaced55');\n",
" const dataTable =\n",
" await google.colab.kernel.invokeFunction('convertToInteractive',\n",
" [key], {});\n",
" if (!dataTable) return;\n",
"\n",
" const docLinkHtml = 'Like what you see? Visit the ' +\n",
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
" + ' to learn more about interactive tables.';\n",
" element.innerHTML = '';\n",
" dataTable['output_type'] = 'display_data';\n",
" await google.colab.output.renderOutput(dataTable, element);\n",
" const docLink = document.createElement('div');\n",
" docLink.innerHTML = docLinkHtml;\n",
" element.appendChild(docLink);\n",
" }\n",
" </script>\n",
" </div>\n",
" </div>\n",
" "
]
},
"metadata": {},
"execution_count": 86
}
]
},
{
"cell_type": "markdown",
"source": [
"### 3.2.1. Partial derivatives of Quantile\n",
"\n"
],
"metadata": {
"id": "x1itZf8107oz"
}
},
{
"cell_type": "markdown",
"source": [
"For `quantile`, we only compute the maximum difference between autodiff and numerical gradients. We do not run this test for the `float32` type, nor specifically for values of `p` near 0.5.\n",
"\n",
"Remember that we used the fact that `stdtr` and `stdtrit` are inverses of each other to compute the gradients of the latter from the gradients of the former.\n",
"\n"
],
"metadata": {
"id": "nyUTaPPu3325"
}
},
{
"cell_type": "code",
"source": [
"space_df = np.logspace(np.log10(0.5), 8., num=15, base=10).tolist()\n",
"space_p = np.linspace(0.01, 0.99, num=15).tolist()\n",
"df, p = zip(*list(itertools.product(space_df, space_p)))\n",
"\n",
"df = np.array(df, dtype=np.float64)\n",
"p = np.array(p, dtype=np.float64)"
],
"metadata": {
"id": "QAKKD04Sx6Ok"
},
"execution_count": 87,
"outputs": []
},
{
"cell_type": "code",
"source": [
"test_case = test_util.TestCase()"
],
"metadata": {
"id": "DpiAEsaZyukT"
},
"execution_count": 88,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"**With respect to `df`**"
],
"metadata": {
"id": "y4JHk6063W8B"
}
},
{
"cell_type": "code",
"source": [
"test_case.compute_max_gradient_error(\n",
" lambda _df: stdtrit(_df, p), [df], delta=1e-6)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "h-ZYzEuKzews",
"outputId": "08561623-34cf-4f92-8897-dc4f736acb03"
},
"execution_count": 89,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"7.160779205150902e-07"
]
},
"metadata": {},
"execution_count": 89
}
]
},
{
"cell_type": "markdown",
"source": [
"**With respect to `p`**"
],
"metadata": {
"id": "VrnuLDro3l4k"
}
},
{
"cell_type": "code",
"source": [
"test_case.compute_max_gradient_error(\n",
" lambda _p: stdtrit(df, _p), [p], delta=1e-7)"
],
"metadata": {
"id": "b2nnaFStLjHZ",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "44ce0f6f-8fd2-466e-ea45-0dd9ab08a86a"
},
"execution_count": 90,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"6.231263978406787e-05"
]
},
"metadata": {},
"execution_count": 90
}
]
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment