Skip to content

Instantly share code, notes, and snippets.

@leandrolcampos
Created August 3, 2022 00:49
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save leandrolcampos/b19177cac0e4f3f0b962edca265ec8ab to your computer and use it in GitHub Desktop.
Save leandrolcampos/b19177cac0e4f3f0b962edca265ec8ab to your computer and use it in GitHub Desktop.
igammainv.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "igammainv.ipynb",
"provenance": [],
"collapsed_sections": [],
"authorship_tag": "ABX9TyPTVurhP+4pKPWZ0EXB93Zh",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/leandrolcampos/b19177cac0e4f3f0b962edca265ec8ab/igammainv.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"source": [
"!pip install -U -q tensorflow"
],
"metadata": {
"id": "5xAGAPPG8wcL",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "080cd33d-3764-4785-9dae-957d6aa5100f"
},
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"\u001b[K |████████████████████████████████| 511.7 MB 6.1 kB/s \n",
"\u001b[K |████████████████████████████████| 5.8 MB 52.6 MB/s \n",
"\u001b[K |████████████████████████████████| 1.6 MB 34.9 MB/s \n",
"\u001b[K |████████████████████████████████| 438 kB 70.6 MB/s \n",
"\u001b[?25h"
]
}
]
},
{
"cell_type": "code",
"source": [
"!pip install -U -q tensorflow-probability"
],
"metadata": {
"id": "YgA81jppUk2T",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "44d6b590-713b-40d0-ace3-84f7811b4875"
},
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"\u001b[K |████████████████████████████████| 6.5 MB 3.8 MB/s \n",
"\u001b[?25h"
]
}
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"id": "KA-hrzh_7myi"
},
"outputs": [],
"source": [
"import functools\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"import scipy.special as sp_special\n",
"import tensorflow as tf\n",
"import tensorflow_probability as tfp\n",
"\n",
"from tensorflow_probability.python import math as tfp_math\n",
"from tensorflow_probability.python.internal import dtype_util\n",
"from tensorflow_probability.python.internal import prefer_static as ps"
]
},
{
"cell_type": "markdown",
"source": [
"## TFP implementation of `igammainv`"
],
"metadata": {
"id": "SK-JwfDStwBA"
}
},
{
"cell_type": "markdown",
"source": [
"Source of `tfp.math.igammainv` at the referenced commit: https://github.com/tensorflow/probability/blob/21c10c1073082ee592c8f119e9fedc78a704634f/tensorflow_probability/python/math/special.py#L875-L1226"
],
"metadata": {
"id": "urkjhL8nugtT"
}
},
{
"cell_type": "markdown",
"source": [
"## Test settings and auxiliary functions"
],
"metadata": {
"id": "d15s3qJNgkAc"
}
},
{
"cell_type": "code",
"source": [
"SEEDS = [1, 17, 42, 51, 184, 301, 346, 448, 733, 985]"
],
"metadata": {
"id": "bFBYwIZhtV-G"
},
"execution_count": 4,
"outputs": []
},
{
"cell_type": "code",
"source": [
"len(SEEDS)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "bLneGiOieGli",
"outputId": "185c2298-46c6-4eca-86d5-47be18f6d8f8"
},
"execution_count": 5,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"10"
]
},
"metadata": {},
"execution_count": 5
}
]
},
{
"cell_type": "code",
"source": [
"SIZE_PER_SEED = 50_000\n",
"# SIZE = SIZE_PER_SEED * len(SEEDS)"
],
"metadata": {
"id": "d0UUVrpW7xIC"
},
"execution_count": 6,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"`sample_specs` specifies how to simulate values for parameter `a`. We call each specification a domain.\n",
"\n",
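"Values for the parameter `y` are simulated, for each `a`, from a Uniform distribution on the interval (igamma(a, tiny), igamma(a, 1 - eps)), where `tiny` and `eps` are the smallest positive normal number and the machine epsilon of the dtype. Hence the exact solution x = igammainv(a, y) lies in (tiny, 1 - eps), which is why the 'Max Result' column below is about 1. For large `a` these bounds can underflow to 0."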
],
"metadata": {
"id": "O9ynlPQzA2sR"
}
},
{
"cell_type": "code",
"source": [
"sample_specs = {\n",
" np.float64: [\n",
" ('uniform', 1.),\n",
" ('uniform', 5.),\n",
" ('uniform', 10.),\n",
" ('uniform', 30.),\n",
" ('uniform', 100.),\n",
" ('uniform', 1000.),\n",
" ('uniform', 10000.),\n",
" ('uniform', 100000.)],\n",
" np.float32: [\n",
" ('uniform', 1.),\n",
" ('uniform', 5.),\n",
" ('uniform', 10.),\n",
" ('uniform', 30.),\n",
" ('uniform', 100.),\n",
" ('uniform', 1000.),\n",
" ('uniform', 10000.),\n",
" ('uniform', 100000.)]}"
],
"metadata": {
"id": "GEB-pRUz7yHH"
},
"execution_count": 7,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def make_samples(sample_specs, size_per_seed, seeds, dtype):\n",
"  \"\"\"Simulates `(a, y)` inputs for `igammainv`, one pair of arrays per spec.\n",
"\n",
"  Args:\n",
"    sample_specs: Iterable of `(name, scale)` tuples describing how `a` is\n",
"      drawn: 'uniform' gives `scale * U(0, 1)` and 'half_normal' gives\n",
"      `scale * |N(0, 1)|`. Draws are clipped below at `tiny`.\n",
"    size_per_seed: Number of draws per seed.\n",
"    seeds: Seeds for `np.random.RandomState`; the draws of all seeds are\n",
"      concatenated.\n",
"    dtype: NumPy floating dtype of the returned arrays.\n",
"\n",
"  Returns:\n",
"    A list with one `[a, y]` pair per spec, each array of length\n",
"    `size_per_seed * len(seeds)`.\n",
"\n",
"  Raises:\n",
"    ValueError: If a spec names an unsupported distribution.\n",
"  \"\"\"\n",
"  tiny = np.finfo(dtype).tiny\n",
"  eps = np.finfo(dtype).eps\n",
"  # Bounds on the exact solution x = igammainv(a, y); see the `y` draw below.\n",
"  x_lower = tf.constant(tiny, dtype)\n",
"  x_upper = tf.constant(1. - eps, dtype)\n",
"\n",
"  def sample_a(rng, name, scale):\n",
"    if name == 'uniform':\n",
"      draws = rng.uniform(size=size_per_seed)\n",
"    elif name == 'half_normal':\n",
"      draws = np.abs(rng.randn(size_per_seed))\n",
"    else:\n",
"      raise ValueError(f'Unsupported distribution for `a`: {name!r}')\n",
"    # Clip so that every `a` is strictly positive.\n",
"    return np.clip(scale * draws, a_min=tiny, a_max=None)\n",
"\n",
"  samples = []\n",
"  for name, scale in sample_specs:\n",
"    a_parts = []\n",
"    y_parts = []\n",
"    for seed in seeds:\n",
"      rng = np.random.RandomState(seed)\n",
"      a = sample_a(rng, name, scale).astype(dtype)\n",
"      # `y` is uniform on (igamma(a, tiny), igamma(a, 1 - eps)), so the exact\n",
"      # solution x = igammainv(a, y) lies in (tiny, 1 - eps). NOTE: for large\n",
"      # `a` both bounds can underflow to 0.\n",
"      y_lower = tf.math.igamma(a, x_lower).numpy()\n",
"      y_upper = tf.math.igamma(a, x_upper).numpy()\n",
"      y = rng.uniform(size=size_per_seed, low=y_lower, high=y_upper)\n",
"      a_parts.append(a)\n",
"      y_parts.append(y.astype(dtype))\n",
"    samples.append([np.concatenate(a_parts), np.concatenate(y_parts)])\n",
"\n",
"  return samples"
],
"metadata": {
"id": "0EfoCKxF70Sw"
},
"execution_count": 8,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"`format_number` and `get_metrics_dataframe` functions are used to print test results."
],
"metadata": {
"id": "vsIEqUoPQtaE"
}
},
{
"cell_type": "code",
"source": [
"def format_number(num):\n",
"  \"\"\"Formats `num` compactly for the results tables.\n",
"\n",
"  Returns '0' for zero; fixed-point notation for finite values whose\n",
"  magnitude lies in (0.01, 1e6) -- no decimals for integral values ('5'),\n",
"  two otherwise ('0.31'); and scientific notation ('1.8e-04') for\n",
"  everything else, including NaN ('nan') and Inf ('inf').\n",
"  \"\"\"\n",
"  if num == 0:\n",
"    return '0'\n",
"  # `abs` keeps negative numbers eligible for fixed-point notation (log10 of\n",
"  # a negative number is NaN); `isfinite` keeps NaN/Inf out of log10.\n",
"  if np.isfinite(num) and -2 < np.log10(np.abs(num)) < 6:\n",
"    if num == int(num):\n",
"      return f'{num:.0f}'\n",
"    return f'{num:.2f}'\n",
"  return f'{num:.1e}'"
],
"metadata": {
"id": "Biw3tiGVucgS"
},
"execution_count": 9,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def get_metrics_dataframe(\n",
"    sample_specs,\n",
"    sp_results,\n",
"    tf_results,\n",
"    dtype=None,\n",
"):\n",
"  \"\"\"Tabulates the accuracy of TFP results against SciPy reference results.\n",
"\n",
"  Args:\n",
"    sample_specs: Dict mapping a dtype to its list of `(distribution, domain)`\n",
"      specs; row `i` of the table describes `sample_specs[dtype][i]`.\n",
"    sp_results: List of reference result arrays, one per spec.\n",
"    tf_results: List of TFP result arrays or tensors, aligned with\n",
"      `sp_results`.\n",
"    dtype: Dtype whose specs are reported and whose `tiny` sets the threshold\n",
"      below which relative errors fall back to absolute errors. Defaults to\n",
"      the notebook-level `DTYPE`.\n",
"\n",
"  Returns:\n",
"    A `pd.DataFrame` with one row of formatted metrics per spec.\n",
"  \"\"\"\n",
"  if dtype is None:\n",
"    dtype = DTYPE\n",
"  # References smaller than sqrt(tiny) are treated as zero: their relative\n",
"  # error is reported as the absolute error instead.\n",
"  zero_threshold = np.sqrt(np.finfo(dtype).tiny)\n",
"  summary_keys = (\n",
"      'Min Result', 'Max Result', 'Max Abs Error', 'Mean Abs Error',\n",
"      'Max Rel Error', 'Mean Rel Error')\n",
"\n",
"  records = []\n",
"  for idx, (distr, domain) in enumerate(sample_specs[dtype]):\n",
"    expected = np.asarray(sp_results[idx])\n",
"    actual = np.asarray(tf_results[idx])\n",
"    size = expected.shape[0]\n",
"    record = {\n",
"        'Distribution': distr,\n",
"        'Domain': format_number(domain),\n",
"        '#Trials': format_number(size),\n",
"        '%NaN': format_number(np.sum(np.isnan(actual)) / size * 100),\n",
"        '%Inf': format_number(np.sum(np.isinf(actual)) / size * 100),\n",
"    }\n",
"\n",
"    finite = np.isfinite(actual)\n",
"    if not finite.any():\n",
"      # Every result is NaN or Inf. Checked directly because summing the two\n",
"      # percentages and comparing with 100 can miss due to float rounding.\n",
"      record.update(dict.fromkeys(summary_keys, 'NaN'))\n",
"      records.append(record)\n",
"      continue\n",
"\n",
"    record['Min Result'] = format_number(\n",
"        np.min(actual, initial=np.inf, where=finite))\n",
"    record['Max Result'] = format_number(\n",
"        np.max(actual, initial=-np.inf, where=finite))\n",
"\n",
"    aerr = np.abs(expected - actual)\n",
"    aerr_valid = np.isfinite(aerr)\n",
"    record['Max Abs Error'] = format_number(\n",
"        np.max(aerr, initial=0., where=aerr_valid))\n",
"    record['Mean Abs Error'] = format_number(np.mean(aerr, where=aerr_valid))\n",
"\n",
"    scale = np.where(np.abs(expected) < zero_threshold, 1., np.abs(expected))\n",
"    rerr = aerr / scale\n",
"    rerr_valid = np.isfinite(rerr)\n",
"    record['Max Rel Error'] = format_number(\n",
"        np.max(rerr, initial=0., where=rerr_valid))\n",
"    record['Mean Rel Error'] = format_number(np.mean(rerr, where=rerr_valid))\n",
"\n",
"    records.append(record)\n",
"\n",
"  return pd.DataFrame.from_records(records)"
],
"metadata": {
"id": "XRvmj1-9Kukh"
},
"execution_count": 10,
"outputs": []
},
{
"cell_type": "code",
"source": [
"sp_gammaincinv = sp_special.gammaincinv\n",
"tf_gammaincinv = tf.function(\n",
" tfp_math.igammainv, autograph=False, jit_compile=False)"
],
"metadata": {
"id": "fUEgyjLQbBmV"
},
"execution_count": 11,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Test results for `float64`"
],
"metadata": {
"id": "KebKjtrOgwgx"
}
},
{
"cell_type": "code",
"source": [
"DTYPE = np.float64"
],
"metadata": {
"id": "KSS4xfUr7odg"
},
"execution_count": 12,
"outputs": []
},
{
"cell_type": "code",
"source": [
"TINY = np.finfo(DTYPE).tiny\n",
"TINY"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ShOeQ13V7pqG",
"outputId": "435d196c-c456-4c17-d64c-8e298d916814"
},
"execution_count": 13,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"2.2250738585072014e-308"
]
},
"metadata": {},
"execution_count": 13
}
]
},
{
"cell_type": "code",
"source": [
"EPS = np.finfo(DTYPE).eps\n",
"EPS"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "3Vws5VIjhinP",
"outputId": "8eeb81cd-7f36-4377-9413-1b56a2c2337a"
},
"execution_count": 14,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"2.220446049250313e-16"
]
},
"metadata": {},
"execution_count": 14
}
]
},
{
"cell_type": "code",
"source": [
"samples = make_samples(\n",
" sample_specs[DTYPE], size_per_seed=SIZE_PER_SEED, seeds=SEEDS, dtype=DTYPE)"
],
"metadata": {
"id": "umRhJjde8Cp7"
},
"execution_count": 15,
"outputs": []
},
{
"cell_type": "code",
"source": [
"sp_res = [sp_gammaincinv(a, y) for a, y in samples]"
],
"metadata": {
"id": "JUw64yHs-dj8"
},
"execution_count": 16,
"outputs": []
},
{
"cell_type": "code",
"source": [
"tf_res = []\n",
"\n",
"for i, (a, y) in enumerate(samples):\n",
" result = tf_gammaincinv(a, y)\n",
" tf_res.append(result)\n",
" print(f'sample {i} is done.')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Bpohh_aLJZpG",
"outputId": "74b092d0-1496-4852-dd4e-7b1683506a03"
},
"execution_count": 17,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"sample 0 is done.\n",
"sample 1 is done.\n",
"sample 2 is done.\n",
"sample 3 is done.\n",
"sample 4 is done.\n",
"sample 5 is done.\n",
"sample 6 is done.\n",
"sample 7 is done.\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"print(f'Method: TFP implementation of igammainv')\n",
"print(f'Benchmark: Cephes implementation of gammaincinv for float64 [SciPy]')\n",
"print(f'Dtype: {np.dtype(DTYPE).name}')\n",
"get_metrics_dataframe(sample_specs, sp_res, tf_res)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 354
},
"id": "vly0tx24bsUf",
"outputId": "7be9015e-f1bf-4295-d8f5-88cd9d67a16d"
},
"execution_count": 18,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Method: TFP implementation of igammainv\n",
"Benchmark: Cephes implementation of gammaincinv for float64 [SciPy]\n",
"Dtype: float64\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
" Distribution Domain #Trials %NaN %Inf Min Result Max Result Max Abs Error \\\n",
"0 uniform 1 500000 0 0 3.2e-308 1.00 3.0e-13 \n",
"1 uniform 5 500000 0 0 5.6e-307 1.00 2.5e-13 \n",
"2 uniform 10 500000 0 0 7.4e-300 1.00 1.5e-13 \n",
"3 uniform 30 500000 0 0 1.2e-290 1.00 6.8e-14 \n",
"4 uniform 100 500000 0 0 2.9e-274 1.00 3.9e-14 \n",
"5 uniform 1000 500000 0 0 0 1.00 1.00 \n",
"6 uniform 10000 500000 0 0 0 1.00 1.00 \n",
"7 uniform 100000 500000 0 0 0 1.00 0.99 \n",
"\n",
" Mean Abs Error Max Rel Error Mean Rel Error \n",
"0 1.6e-16 1.1e-09 8.0e-15 \n",
"1 7.0e-17 1.3e-10 1.2e-15 \n",
"2 7.1e-17 1.1e-10 8.3e-16 \n",
"3 1.5e-16 2.6e-11 3.6e-16 \n",
"4 2.6e-16 1.4e-11 3.3e-16 \n",
"5 1.8e-04 1 1.8e-04 \n",
"6 1.8e-05 1 1.8e-05 \n",
"7 2.0e-06 1 2.0e-06 "
],
"text/html": [
"\n",
" <div id=\"df-cd85607e-5fc8-4c0a-9d83-2ef61189180a\">\n",
" <div class=\"colab-df-container\">\n",
" <div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Distribution</th>\n",
" <th>Domain</th>\n",
" <th>#Trials</th>\n",
" <th>%NaN</th>\n",
" <th>%Inf</th>\n",
" <th>Min Result</th>\n",
" <th>Max Result</th>\n",
" <th>Max Abs Error</th>\n",
" <th>Mean Abs Error</th>\n",
" <th>Max Rel Error</th>\n",
" <th>Mean Rel Error</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>uniform</td>\n",
" <td>1</td>\n",
" <td>500000</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>3.2e-308</td>\n",
" <td>1.00</td>\n",
" <td>3.0e-13</td>\n",
" <td>1.6e-16</td>\n",
" <td>1.1e-09</td>\n",
" <td>8.0e-15</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>uniform</td>\n",
" <td>5</td>\n",
" <td>500000</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>5.6e-307</td>\n",
" <td>1.00</td>\n",
" <td>2.5e-13</td>\n",
" <td>7.0e-17</td>\n",
" <td>1.3e-10</td>\n",
" <td>1.2e-15</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>uniform</td>\n",
" <td>10</td>\n",
" <td>500000</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>7.4e-300</td>\n",
" <td>1.00</td>\n",
" <td>1.5e-13</td>\n",
" <td>7.1e-17</td>\n",
" <td>1.1e-10</td>\n",
" <td>8.3e-16</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>uniform</td>\n",
" <td>30</td>\n",
" <td>500000</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1.2e-290</td>\n",
" <td>1.00</td>\n",
" <td>6.8e-14</td>\n",
" <td>1.5e-16</td>\n",
" <td>2.6e-11</td>\n",
" <td>3.6e-16</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>uniform</td>\n",
" <td>100</td>\n",
" <td>500000</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>2.9e-274</td>\n",
" <td>1.00</td>\n",
" <td>3.9e-14</td>\n",
" <td>2.6e-16</td>\n",
" <td>1.4e-11</td>\n",
" <td>3.3e-16</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>uniform</td>\n",
" <td>1000</td>\n",
" <td>500000</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1.00</td>\n",
" <td>1.00</td>\n",
" <td>1.8e-04</td>\n",
" <td>1</td>\n",
" <td>1.8e-04</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>uniform</td>\n",
" <td>10000</td>\n",
" <td>500000</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1.00</td>\n",
" <td>1.00</td>\n",
" <td>1.8e-05</td>\n",
" <td>1</td>\n",
" <td>1.8e-05</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>uniform</td>\n",
" <td>100000</td>\n",
" <td>500000</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1.00</td>\n",
" <td>0.99</td>\n",
" <td>2.0e-06</td>\n",
" <td>1</td>\n",
" <td>2.0e-06</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>\n",
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-cd85607e-5fc8-4c0a-9d83-2ef61189180a')\"\n",
" title=\"Convert this dataframe to an interactive table.\"\n",
" style=\"display:none;\">\n",
" \n",
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
" width=\"24px\">\n",
" <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
" <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
" </svg>\n",
" </button>\n",
" \n",
" <style>\n",
" .colab-df-container {\n",
" display:flex;\n",
" flex-wrap:wrap;\n",
" gap: 12px;\n",
" }\n",
"\n",
" .colab-df-convert {\n",
" background-color: #E8F0FE;\n",
" border: none;\n",
" border-radius: 50%;\n",
" cursor: pointer;\n",
" display: none;\n",
" fill: #1967D2;\n",
" height: 32px;\n",
" padding: 0 0 0 0;\n",
" width: 32px;\n",
" }\n",
"\n",
" .colab-df-convert:hover {\n",
" background-color: #E2EBFA;\n",
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
" fill: #174EA6;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert {\n",
" background-color: #3B4455;\n",
" fill: #D2E3FC;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert:hover {\n",
" background-color: #434B5C;\n",
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
" fill: #FFFFFF;\n",
" }\n",
" </style>\n",
"\n",
" <script>\n",
" const buttonEl =\n",
" document.querySelector('#df-cd85607e-5fc8-4c0a-9d83-2ef61189180a button.colab-df-convert');\n",
" buttonEl.style.display =\n",
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
"\n",
" async function convertToInteractive(key) {\n",
" const element = document.querySelector('#df-cd85607e-5fc8-4c0a-9d83-2ef61189180a');\n",
" const dataTable =\n",
" await google.colab.kernel.invokeFunction('convertToInteractive',\n",
" [key], {});\n",
" if (!dataTable) return;\n",
"\n",
" const docLinkHtml = 'Like what you see? Visit the ' +\n",
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
" + ' to learn more about interactive tables.';\n",
" element.innerHTML = '';\n",
" dataTable['output_type'] = 'display_data';\n",
" await google.colab.output.renderOutput(dataTable, element);\n",
" const docLink = document.createElement('div');\n",
" docLink.innerHTML = docLinkHtml;\n",
" element.appendChild(docLink);\n",
" }\n",
" </script>\n",
" </div>\n",
" </div>\n",
" "
]
},
"metadata": {},
"execution_count": 18
}
]
},
{
"cell_type": "code",
"source": [
"ATOL = 1e-12  # Absolute slack subtracted from `aerr` before forming `rtol`.\n",
"cases = {}\n",
"\n",
"for sample_id, (a, y) in enumerate(samples):\n",
"  x_tf = np.asarray(tf_res[sample_id])\n",
"  x_sp = sp_res[sample_id]\n",
"  df = pd.DataFrame(\n",
"      data=np.column_stack([a, y, x_tf, x_sp]),\n",
"      columns=['a', 'y', 'x_tf', 'x_sp'])\n",
"  # NOTE: rows with x_sp == 0 give inf/NaN in `rerr` and `rtol`.\n",
"  df['aerr'] = np.abs(df.x_tf - df.x_sp)\n",
"  df['rerr'] = df.aerr / df.x_sp\n",
"  # Residuals in y-space: how well each solution satisfies igamma(a, x) = y.\n",
"  y_tf = tf.math.igamma(df.a.to_numpy(), df.x_tf.to_numpy()).numpy()\n",
"  df['aerr_y_tf'] = np.abs(y_tf - df.y)\n",
"  df['rerr_y_tf'] = df.aerr_y_tf / df.y\n",
"  df['aerr_y_sp'] = np.abs(sp_special.gammainc(df.a, df.x_sp) - df.y)\n",
"  df['rerr_y_sp'] = df.aerr_y_sp / df.y\n",
"  df['rtol'] = np.maximum(df.aerr - ATOL, 0) / df.x_sp\n",
"  cases[sample_id] = df.sort_values(by='aerr', ascending=False)"
],
"metadata": {
"id": "fx1102ugwLqm"
},
"execution_count": 19,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Iterate over every analysed sample instead of assuming exactly eight.\n",
"for sample_id, case_df in cases.items():\n",
"  print(f'sample {sample_id} | max_rtol: {case_df.rtol.max()}')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Yxy_RxRAwLqn",
"outputId": "dadb0bff-a1b7-4269-9c4b-9d9537ad7a08"
},
"execution_count": 20,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"sample 0 | max_rtol: 0.0\n",
"sample 1 | max_rtol: 0.0\n",
"sample 2 | max_rtol: 0.0\n",
"sample 3 | max_rtol: 0.0\n",
"sample 4 | max_rtol: 0.0\n",
"sample 5 | max_rtol: 0.9999999999989997\n",
"sample 6 | max_rtol: 0.9999999999989991\n",
"sample 7 | max_rtol: 0.9999999999989897\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"## Test results for `float32`"
],
"metadata": {
"id": "nl4hOD0Lg7OJ"
}
},
{
"cell_type": "code",
"source": [
"DTYPE = np.float32"
],
"metadata": {
"id": "eXU_VXWlWDqG"
},
"execution_count": 21,
"outputs": []
},
{
"cell_type": "code",
"source": [
"TINY = np.finfo(DTYPE).tiny\n",
"TINY"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "05ce293e-46fe-486e-d787-7a2d81318891",
"id": "wRbuQNJ7WDqH"
},
"execution_count": 22,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"1.1754944e-38"
]
},
"metadata": {},
"execution_count": 22
}
]
},
{
"cell_type": "code",
"source": [
"EPS = np.finfo(DTYPE).eps\n",
"EPS"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "uci_ntVBhBsn",
"outputId": "14f73591-f6e3-44fd-c3c8-5918d9629d6c"
},
"execution_count": 23,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"1.1920929e-07"
]
},
"metadata": {},
"execution_count": 23
}
]
},
{
"cell_type": "code",
"source": [
"samples = make_samples(\n",
" sample_specs[DTYPE], size_per_seed=SIZE_PER_SEED, seeds=SEEDS, dtype=DTYPE)"
],
"metadata": {
"id": "TocNiKaXWDqH"
},
"execution_count": 24,
"outputs": []
},
{
"cell_type": "code",
"source": [
"sp_res = [sp_gammaincinv(a, y) for a, y in samples]"
],
"metadata": {
"id": "-luHQfz6WDqH"
},
"execution_count": 25,
"outputs": []
},
{
"cell_type": "code",
"source": [
"tf_res = []\n",
"\n",
"for i, (a, y) in enumerate(samples):\n",
" result = tf_gammaincinv(a, y)\n",
" tf_res.append(result)\n",
" print(f'sample {i} is done.')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "rcaLGwOJKQWs",
"outputId": "5c0217a7-221a-4146-9c97-2e0229f2a03e"
},
"execution_count": 26,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"sample 0 is done.\n",
"sample 1 is done.\n",
"sample 2 is done.\n",
"sample 3 is done.\n",
"sample 4 is done.\n",
"sample 5 is done.\n",
"sample 6 is done.\n",
"sample 7 is done.\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"print(f'Method: TFP implementation of igammainv')\n",
"print(f'Benchmark: Cephes implementation of gammaincinv for float64 [SciPy]')\n",
"print(f'Dtype: {np.dtype(DTYPE).name}')\n",
"get_metrics_dataframe(sample_specs, sp_res, tf_res)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 354
},
"id": "Z19Vs8Arb_ZN",
"outputId": "ec774b54-95e0-497d-98bf-c24f4d15355f"
},
"execution_count": 27,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Method: TFP implementation of igammainv\n",
"Benchmark: Cephes implementation of gammaincinv for float64 [SciPy]\n",
"Dtype: float32\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
" Distribution Domain #Trials %NaN %Inf Min Result Max Result Max Abs Error \\\n",
"0 uniform 1 500000 0 0 1.2e-38 1.00 2.2e-03 \n",
"1 uniform 5 500000 0 0 1.2e-38 1.00 6.7e-04 \n",
"2 uniform 10 500000 0 0 1.3e-38 1.00 1.7e-04 \n",
"3 uniform 30 500000 0 0 1.4e-38 1.00 3.4e-05 \n",
"4 uniform 100 500000 0 0 0 1.02 1.00 \n",
"5 uniform 1000 500000 0 0 0 1.00 1.00 \n",
"6 uniform 10000 500000 0 0 0 1.00 0.99 \n",
"7 uniform 100000 500000 0 0 0 1.00 0.96 \n",
"\n",
" Mean Abs Error Max Rel Error Mean Rel Error \n",
"0 9.5e-08 0.31 3.4e-06 \n",
"1 4.5e-08 0.09 6.1e-07 \n",
"2 4.5e-08 0.04 3.6e-07 \n",
"3 7.3e-08 0.02 1.6e-07 \n",
"4 2.6e-03 1 2.8e-03 \n",
"5 2.4e-04 1 2.5e-04 \n",
"6 3.2e-05 1 3.4e-05 \n",
"7 1.1e-05 1 1.2e-05 "
],
"text/html": [
"\n",
" <div id=\"df-9e6f0ec3-5df3-42f7-a0c1-a62ce53f9184\">\n",
" <div class=\"colab-df-container\">\n",
" <div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Distribution</th>\n",
" <th>Domain</th>\n",
" <th>#Trials</th>\n",
" <th>%NaN</th>\n",
" <th>%Inf</th>\n",
" <th>Min Result</th>\n",
" <th>Max Result</th>\n",
" <th>Max Abs Error</th>\n",
" <th>Mean Abs Error</th>\n",
" <th>Max Rel Error</th>\n",
" <th>Mean Rel Error</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>uniform</td>\n",
" <td>1</td>\n",
" <td>500000</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1.2e-38</td>\n",
" <td>1.00</td>\n",
" <td>2.2e-03</td>\n",
" <td>9.5e-08</td>\n",
" <td>0.31</td>\n",
" <td>3.4e-06</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>uniform</td>\n",
" <td>5</td>\n",
" <td>500000</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1.2e-38</td>\n",
" <td>1.00</td>\n",
" <td>6.7e-04</td>\n",
" <td>4.5e-08</td>\n",
" <td>0.09</td>\n",
" <td>6.1e-07</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>uniform</td>\n",
" <td>10</td>\n",
" <td>500000</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1.3e-38</td>\n",
" <td>1.00</td>\n",
" <td>1.7e-04</td>\n",
" <td>4.5e-08</td>\n",
" <td>0.04</td>\n",
" <td>3.6e-07</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>uniform</td>\n",
" <td>30</td>\n",
" <td>500000</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1.4e-38</td>\n",
" <td>1.00</td>\n",
" <td>3.4e-05</td>\n",
" <td>7.3e-08</td>\n",
" <td>0.02</td>\n",
" <td>1.6e-07</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>uniform</td>\n",
" <td>100</td>\n",
" <td>500000</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1.02</td>\n",
" <td>1.00</td>\n",
" <td>2.6e-03</td>\n",
" <td>1</td>\n",
" <td>2.8e-03</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>uniform</td>\n",
" <td>1000</td>\n",
" <td>500000</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1.00</td>\n",
" <td>1.00</td>\n",
" <td>2.4e-04</td>\n",
" <td>1</td>\n",
" <td>2.5e-04</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>uniform</td>\n",
" <td>10000</td>\n",
" <td>500000</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1.00</td>\n",
" <td>0.99</td>\n",
" <td>3.2e-05</td>\n",
" <td>1</td>\n",
" <td>3.4e-05</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>uniform</td>\n",
" <td>100000</td>\n",
" <td>500000</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1.00</td>\n",
" <td>0.96</td>\n",
" <td>1.1e-05</td>\n",
" <td>1</td>\n",
" <td>1.2e-05</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>\n",
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-9e6f0ec3-5df3-42f7-a0c1-a62ce53f9184')\"\n",
" title=\"Convert this dataframe to an interactive table.\"\n",
" style=\"display:none;\">\n",
" \n",
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
" width=\"24px\">\n",
" <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
" <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
" </svg>\n",
" </button>\n",
" \n",
" <style>\n",
" .colab-df-container {\n",
" display:flex;\n",
" flex-wrap:wrap;\n",
" gap: 12px;\n",
" }\n",
"\n",
" .colab-df-convert {\n",
" background-color: #E8F0FE;\n",
" border: none;\n",
" border-radius: 50%;\n",
" cursor: pointer;\n",
" display: none;\n",
" fill: #1967D2;\n",
" height: 32px;\n",
" padding: 0 0 0 0;\n",
" width: 32px;\n",
" }\n",
"\n",
" .colab-df-convert:hover {\n",
" background-color: #E2EBFA;\n",
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
" fill: #174EA6;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert {\n",
" background-color: #3B4455;\n",
" fill: #D2E3FC;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert:hover {\n",
" background-color: #434B5C;\n",
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
" fill: #FFFFFF;\n",
" }\n",
" </style>\n",
"\n",
" <script>\n",
" const buttonEl =\n",
" document.querySelector('#df-9e6f0ec3-5df3-42f7-a0c1-a62ce53f9184 button.colab-df-convert');\n",
" buttonEl.style.display =\n",
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
"\n",
" async function convertToInteractive(key) {\n",
" const element = document.querySelector('#df-9e6f0ec3-5df3-42f7-a0c1-a62ce53f9184');\n",
" const dataTable =\n",
" await google.colab.kernel.invokeFunction('convertToInteractive',\n",
" [key], {});\n",
" if (!dataTable) return;\n",
"\n",
" const docLinkHtml = 'Like what you see? Visit the ' +\n",
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
" + ' to learn more about interactive tables.';\n",
" element.innerHTML = '';\n",
" dataTable['output_type'] = 'display_data';\n",
" await google.colab.output.renderOutput(dataTable, element);\n",
" const docLink = document.createElement('div');\n",
" docLink.innerHTML = docLinkHtml;\n",
" element.appendChild(docLink);\n",
" }\n",
" </script>\n",
" </div>\n",
" </div>\n",
" "
]
},
"metadata": {},
"execution_count": 27
}
]
},
{
"cell_type": "code",
"source": [
"ATOL = 1e-12  # Absolute slack subtracted from `aerr` before forming `rtol`.\n",
"cases = {}\n",
"\n",
"for sample_id, (a, y) in enumerate(samples):\n",
"  x_tf = np.asarray(tf_res[sample_id])\n",
"  x_sp = sp_res[sample_id]\n",
"  df = pd.DataFrame(\n",
"      data=np.column_stack([a, y, x_tf, x_sp]),\n",
"      columns=['a', 'y', 'x_tf', 'x_sp'])\n",
"  # NOTE: rows with x_sp == 0 give inf/NaN in `rerr` and `rtol`.\n",
"  df['aerr'] = np.abs(df.x_tf - df.x_sp)\n",
"  df['rerr'] = df.aerr / df.x_sp\n",
"  # Residuals in y-space: how well each solution satisfies igamma(a, x) = y.\n",
"  y_tf = tf.math.igamma(df.a.to_numpy(), df.x_tf.to_numpy()).numpy()\n",
"  df['aerr_y_tf'] = np.abs(y_tf - df.y)\n",
"  df['rerr_y_tf'] = df.aerr_y_tf / df.y\n",
"  df['aerr_y_sp'] = np.abs(sp_special.gammainc(df.a, df.x_sp) - df.y)\n",
"  df['rerr_y_sp'] = df.aerr_y_sp / df.y\n",
"  df['rtol'] = np.maximum(df.aerr - ATOL, 0) / df.x_sp\n",
"  cases[sample_id] = df.sort_values(by='aerr', ascending=False)"
],
"metadata": {
"id": "Ra2HBhqCZypb"
},
"execution_count": 28,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Iterate over every analysed sample instead of assuming exactly eight.\n",
"for sample_id, case_df in cases.items():\n",
"  print(f'sample {sample_id} | max_rtol: {case_df.rtol.max()}')"
],
"metadata": {
"id": "AwynJH8wZ05b",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "6ef028d3-a47b-4b53-f2e4-29a8041f5ea1"
},
"execution_count": 29,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"sample 0 | max_rtol: 0.30785343050956726\n",
"sample 1 | max_rtol: 0.016323404386639595\n",
"sample 2 | max_rtol: 0.007907157763838768\n",
"sample 3 | max_rtol: 0.0017979817930608988\n",
"sample 4 | max_rtol: 1.0\n",
"sample 5 | max_rtol: 1.0\n",
"sample 6 | max_rtol: 1.0\n",
"sample 7 | max_rtol: 1.0\n"
]
}
]
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment