@tupui
Last active August 16, 2021 17:18
MCM 2021 Scipy and Pytorch demo
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "mcm2021_scipy_demo.ipynb",
"provenance": [],
"collapsed_sections": [],
"toc_visible": true,
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/tupui/4310485c9920868c404728cdf8cfe93f/mcm2021_scipy_demo.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "AgylrjFWsC3P"
},
"source": [
"# QMC in SciPy\n",
"\n",
"Tutorial from Pamphile T. Roy and Max Balandat"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZFifc6a33gql"
},
"source": [
"Since SciPy 1.7, we introduced a QMC submodule in `scipy.stats.qmc`.\n",
"Here are some things we can do with it!\n",
"\n",
"A tutorial can be found in the documentation at:\n",
"https://scipy.github.io/devdocs/tutorial/stats.html#quasi-monte-carlo\n",
"\n",
"Let's check our version of SciPy and play with the QMC submodule."
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "czfyTmTVr0gJ",
"outputId": "45234c53-bce6-44c8-e88c-e91b8e1ef3f9"
},
"source": [
"# ensure the most recent version of SciPy is installed\n",
"!pip install scipy==1.7.1"
],
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"text": [
"Requirement already satisfied: scipy==1.7.1 in /usr/local/lib/python3.7/dist-packages (1.7.1)\n",
"Requirement already satisfied: numpy<1.23.0,>=1.16.5 in /usr/local/lib/python3.7/dist-packages (from scipy==1.7.1) (1.19.5)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "ogrNY4q5sPIk"
},
"source": [
"from scipy.stats import qmc"
],
"execution_count": 3,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "5bqzmTYkwJzh"
},
"source": [
"First let's define a `rng` object to make the following reproducible. Please don't use this in production code. See the section bellow."
]
},
{
"cell_type": "code",
"metadata": {
"id": "zagi1rpyqf7P"
},
"source": [
"import numpy as np\n",
"\n",
"# seed = np.random.SeedSequence().entropy\n",
"# print(seed)\n",
"seed = 292114020772849599029278515437886320941\n",
"rng = np.random.default_rng(seed)"
],
"execution_count": 4,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "zzhh2ede1fjL"
},
"source": [
"## Basics"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "GwCVF0rOwhk6"
},
"source": [
"There are multiple `QMCEngine` that can be used to sample points in $\\mathcal{U} \\sim [0, 1)^d$."
]
},
{
"cell_type": "code",
"metadata": {
"id": "d0zhYGo5sk1z"
},
"source": [
"n, d = 64, 2\n",
"lhs_engine = qmc.LatinHypercube(d=d, centered=True, seed=rng) # centered to make nice bins\n",
"sample = lhs_engine.random(n=n)"
],
"execution_count": 14,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 391
},
"id": "CUU5orYwtea9",
"outputId": "f805ffbe-64a6-4a4a-dfea-00277e7aa2bb"
},
"source": [
"import seaborn as sns\n",
"import pandas as pd\n",
"sns.pairplot(pd.DataFrame(sample), diag_kind=\"hist\", corner=True, diag_kws={\"bins\": 4})"
],
"execution_count": 15,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<seaborn.axisgrid.PairGrid at 0x7fc227a38d10>"
]
},
"metadata": {
"tags": []
},
"execution_count": 15
},
{
"output_type": "display_data",
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 360x360 with 5 Axes>"
]
},
"metadata": {
"tags": [],
"needs_background": "light"
}
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_jVyie7LxL8r"
},
"source": [
"With the samples, you can further compute quality metrics such as the discrepancy and also scale to bounds."
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "VmpnCrwstedZ",
"outputId": "a19e2287-ed48-442c-85b8-9a6b9eb1a239"
},
"source": [
"qmc.discrepancy(sample)"
],
"execution_count": 8,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"0.0003283487484500025"
]
},
"metadata": {
"tags": []
},
"execution_count": 8
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 391
},
"id": "WmYpgzuExg4o",
"outputId": "8ffbc5d2-05ca-473a-c3bf-a86b09e1dcf4"
},
"source": [
"scaled_sample = qmc.scale(sample, l_bounds=[-1, 10], u_bounds=[4, 22])\n",
"sns.pairplot(pd.DataFrame(scaled_sample), diag_kind=\"hist\", corner=True, diag_kws={\"bins\": 4})"
],
"execution_count": 16,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<seaborn.axisgrid.PairGrid at 0x7fc227837850>"
]
},
"metadata": {
"tags": []
},
"execution_count": 16
},
{
"output_type": "display_data",
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAWUAAAFlCAYAAAAzhfm7AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAYs0lEQVR4nO3de5BedX3H8c9342oil6oJBM2FNQ5KI7UwXRkGZcQwY1Okxc600Dg6qLShVgWUDlWcKfWPzjjV4qV2qqukoZZB6IiXcRyUSbXgSGIjoohQLxggEMgmXogwEZL99o/nWbLZ7HM/v3O+v3PerxmG5NnL+T3u8vH7u3zPMXcXACCGsaoHAAA4hFAGgEAIZQAIhFAGgEAIZQAIhFAGgECeVfUA+sS5PaRiVQ8AmItKGQACIZQBIBBCGQACIZQBIBBCGQACIZQBIJBcjsR1tGLVaj2y86Gqh6FF48/Rwad/W/UwJMUZS5RxvGjlKj380INVDwPoi2Vy686OgzQzXfipb5c5lgXdeMmZIcYhxRlLpHF0+T3nnDJCYfkCAAIhlAEgEEIZAAIhlAEgEEIZAAIhlAEgEEIZAAIhlAEgEEIZAAIhlAEgEEIZAAIhlAEgEEIZAAIhlAEgEEIZAAIhlAEgEEIZAAIhlAEgEEIZAAIhlAEgEEIZAAIhlAEgEEIZAAIhlAEgEEIZAAIhlAEgEEIZAAIhlAEgEEIZAAIhlAEgEEIZAAIhlAEgEEIZAAIhlAEgEEIZAAIhlAEgEEIZAAIhlAEgEEIZAAIhlAEgEEIZAAIhlAEgEEIZAAIhlAEgEEIZAAIhlAEgEEIZAAIhlAEgEEIZAAIhlAEgEEIZAAIhlAEgEEIZAAIhlAEgEEIZAAIhlAEgEEIZAAIhlAEgEEIZAAIhlAEgEEIZAAIhlAEgEEIZAAIhlAEgEEIZAAIhlAEgEEIZAAIhlAEgEEIZAAIhlAEgEEIZAAIhlAEgEEIZAAIhlAEgEEIZAAIhlAEgEEIZAAIhlAEgEEIZAAIhlAEgEEIZAAIhlAEgEEIZAAIhlAEgEEIZAAIhlAEgEEIZAAIhlAEgEEIZAAIhlAEgEEIZAAIhlAEgEEIZAAIhlAEgEEIZAAIhlAEgEEIZAAIhlAEgEEIZAAIhlAEgEEIZAAIhlAEgEEIZAAIhlAEgEEIZAAIhlAEgEEIZAAIhlAEgEEIZAAIhlAEgEEIZAAIhlAEgEEIZAAIhlAEgEEIZAAIhlAEgEEIZAAIhlAEgEEIZAAIhlAEgEEIZAAIhlAEgEEIZAAIhlAEgEEIZAAIhlAEgEEIZAAIhlAEgEEIZAAIhlAEgEEIZAAIhlAEgEEIZAAIhlAEgEEIZAAIhlAEgEEIZAAIhlAEgEEIZAAIhlAEgEEIZAAIhlAEgEEIZAAIxd696DD2Z2S2SllU9DrXGsKfqQSTUxPe3x93XVzEYYCFZhHIUZrbd3SerHkcqvD+geixfAEAghDIABEIoD2aq6gEkxvsDKsaaMgAEQqUMAIEQygAQyLOqHkA/1q9f77fcckvVw0A9WZePsbaHVDr+3mVRKe/ZU+d+BgA4JFkom9kqM/uGmf3IzO4xs8var3/IzO4zsx+Y2RfM7HmpxgAAuUlZKR+QdIW7r5V0hqR3mNlaSbdKOsXdXyHpx5Lel3AMAJCVZGvK7r5L0q72n/eZ2b2SVrj71+d82lZJf5ZqDOhuZsa1Y+8Teuzx/Vp+7GJNLD1KY2PdllgBpFbKRp+ZTUg6TdK2eR96m6QbyxgDDjcz47rlnkf1npvu0v6nZ7R4fEzXXHCq1r/8BIIZqFDyjT4zO1rS5yVd7u6Pz3n9/WotcVzf4es2mtl2M9s+PT2depilm5lx3T/9G93xsz26f/o3mpkpd6N/x94nnglkSdr/9Izec9Nd2rH3iVLHAeBwSStlMxtXK5Cvd/eb57z+FknnSTrHO7QUuvuU2m2xk5OTtT
qaFKFKfezx/c8E8qz9T89o9779WnPc0aWMIVcrVq3WIzsfqnoYkqRF48/Rwad/W/UwwoxDijOWF61cpYcfenDgr0sWymZmkq6VdK+7XzPn9fWSrpT0Gnd/MtX1I+tUpZ586VmlBeLyYxdr8fjYYcG8eHxMxx+zuJTrj6LqtfBHdj6kCz/17dKu182Nl5wZYixRxiHFGcuNl5w51NelXL54laQ3S1pnZne1/zlX0ickHSPp1vZrn0w4hpC6VallmVh6lK654FQtHm/9CsxW6xNLjyptDMOYnWWc+/HbteHT23Tux2/XLfc8WvryD5BKytMX39LCXStfTXXNXESoUsfGTOtffoJOvvQs7d63X8cfk8fpiwizDCClLDr6Uqlqsy1KlTo2Zlpz3NE6Y80yrTnu6PCBLMWYZQApZXHvixSq3GzLtUqNIMIsA0ipsZVy1UfCcqxS52r6LANIpbGVMkfChscsA0insZXy7DR4rsXjYzrh2MWVNnXkgFkGkE5jQ3mhafAn3niafrRrH8etemCzDUinscsXC02D3aXX/8vtXY9bVd24EAGbbUA6ja2UpSOnwbv3da8AaVxoYbMNSKexlfJCelWAOTYupKjs2WwD0iGU55itAOefKpitAHM7sZHylMTsLCPi+wZyRijP0asCzG0tNcfKHmi6Rq0p99Pw0O24VW5rqZySAPLTmEq5iKl8bmupuVX2ABpUKRfV8DBs40IVbcm5VfYA0t7kfpWk/5C0XJJLmnL3j5nZC9R6Lt+EpB2SLnD3X6Yax6wqN+mqakvOrbIHkLZSPiDpCndfK+kMSe8ws7WS3itpi7ufJGlL++/JdWqrLnIq36karrItmZZkIC/JQtndd7n7ne0/75N0r6QVks6XdF37066T9IZUY5gr9VS+W2MJG24A+lXKRp+ZTUg6TdI2ScvdfVf7Q4+qtbyR3KhT+V5NGN2On7HhBqBfyUPZzI5W64nWl7v7463nqba4u5vZgjteZrZR0kZJWr16dSFjGbbhoZ814W7V8OkTS7s2pQDArKShbGbjagXy9e5+c/vlx8zshe6+y8xeKGn3Ql/r7lOSpiRpcnKy51GFlDcK6qcJo1s1zIYbgH4lW1O2Vkl8raR73f2aOR/6sqSL2n++SNKXRr1W6hsF9bMm3GvNmg03AP1IWSm/StKbJd1tZne1X7tK0gcl3WRmF0t6QNIFo14odTtxP2vCVMOj4ZaoQEuyUHb3b0nq9F/VOUVeK/UZ5F43KprFTXqGU+XjpYBoatFmnfp0A1VwWtw4CTikFm3WZbQTsyY8uk7NNZzjBg6pRaVMJRtftyUKznEDh9SiUpaoZOer4gZI3XRrNefGScAhtaiUcbiIG2e9NmOZ6QAttamUcUjKGyANW4H3uiEUMx2ghVCuoVQbZ6M06bBEAfQn2+ULmg06S7VxNsrRNTZjgf5kGcoR10wj6bfZZVCjNunQXAP0lmUoN73ZoNcsIVVVytE1IL0sQ7nKRztVrd9ZQoqqNFUFDuCQLEO5yRVblbME1oWB9LI8fVGnnfxBj5hV3ZLM0TUgrSwr5bpUbMNsWDZ5lgA0QZaVsjRYxRat5XjWME0edZolADhSskrZzDZJOk/Sbnc/pf3aqZI+KWmxpAOS/sbdv5NqDFLs43PDbFjWZZYAYGEpK+XNktbPe+2fJH3A3U+V9PftvyeVsuV4VL1ajzspel036kwCaKJkoezut0n6xfyXJR3b/vPvSHok1fVnVb0x1k2EpYjUzzcEMJiyN/oul/Q1M/uwWv+HcGbqC466MZaynTvCUkTTG3GAaMre6Hu7pHe7+ypJ71bradcLMrONZrbdzLZPT08PfcFRqtEyqsiqj5hFnkkATVR2pXyRpMvaf/4vSZ/p9InuPiVpSpImJyeHTsFRqtEmVJFVHrHjplLAkcqulB+R9Jr2n9dJ+kkZFx22Gm1CFVnVujZr2cDCUh6Ju0HS2ZKWmdlOSVdL+itJHzOzZ0naL2ljqu
sXoQmNGlWtazdhFgIMI1kou/uGDh/6g1TXLFpTbsAzO5OYWHqUdux9Qtt+vjf5ckKTbyoFdJNlm3VZIpyOKEvZTTZNmIUAw8i2zbofRTRFVH06oixlN9lEOKMNRFTbSjlye3VEZS8nNGkWAgyitpVyle3VObYtD9vyPYqmzEKAQdQ2lKs6zpbrUS+WE4AYart8UdVGUq5HvVIuJ9AkAvSvtqFc1XG2nI96pXiuH2v7wGBqG8pVbSRx1OuQmRnX3Q//KsuZA1CV2q4pS9VsJLE22zJbIW+5b3ftW9WBItW2Uq4KR71aZtfW//KsNcwcgAHUulKuymyFfvrEUknStp/vzeZoXFFm19Y//92dunTdSY2fOQD9olJOpOkbXLNr67t+vV+f3fqALn71Gi0ak845+Xj93ornNeJ/A2AYVMqJRH42YBnmrq3v+vV+Xfut+3XyCceWEsg5Nu8As6iUE8n5aFwRqlpbb/oMBflLVimb2SYz221mP5z3+rvM7D4zu8fMkj/NOqVuFVkVbcvRVHH6pekzFOQv5fLFZknr575gZq+VdL6k33f3l0v6cMLrJ9WrnZqjcdVowtNiUG8pb3J/m5lNzHv57ZI+6O6/bX/O7lTXT61XO3Xdjsbl0ipN8w5yN3SlbGZvHeLLXirpLDPbZmb/Y2avHPb6VeunIqvLXdByuskSMxTkbpRK+QOS/n2I671A0hmSXinpJjNb4+5H/NdtZhvVfobf6tWrRxhmGk2qyKLfZGl+Ff+6312ur9ZkhoLm6RrKZvaDTh+StHyI6+2UdHM7hL9jZjOSlkmanv+J7j4laUqSJicnw5VkTXl+nxT7JEm30xZVjw0YRq9KebmkP5T0y3mvm6RvD3G9L0p6raRvmNlLJT1b0p4hvk/l6rZm3E0Vs4J+17CjV/HAoHqF8lckHe3ud83/gJl9s9sXmtkNks6WtMzMdkq6WtImSZvax+SeknTRQksXuUhxq8uIyp4VDHLWOHIVDwyjayi7+8VdPvbGHl+7ocOH3tTHuBBI2bOCQarfJq3toxlosy5Zri3AZZ4kGeSsMactUDe0WZeIFuD+DFL9NmltH81ApVyilC3AuVbgCxm0+q3LeXBAolIuVapNqbpV4FS/aDIq5RKluklRnW7CM1vxb/v5XknS6RNLqX7RKIRyiVJtStXlJjw5tXMDqbB8UaJU0/J+NsZyuKEQjSAAoVy6FA0nvZo7cllzphEEIJRroVcFnksFOmojSA6zAaAXQrkmulXguVSgo7Rz5zIbAHohlBsgl1bkUdbcc5kNAL1w+iKwohpCcmpFHrYRpC4nUAAq5aCKnI43oRkjl9kA0AuVclBFN4SkbkWuus07p9kA0A2VclC5bM5JMTbZmjAbQDMkq5TNbJOZ7W7f0H7+x64wMzezZamu303VVV0/UrVkpxClzZsbE6EOUi5fbJa0fv6LZrZK0uskPZjw2h3l0sqb03ScTTagOMmWL9z9NjObWOBDH5F0paQvpbp2N7kcncppOp5yk42GEDRNqRt9Zna+pIfd/ft9fO5GM9tuZtunp4942PXQcqrqcpmOp6rqc5nVAEUqbaPPzJ4r6Sq1li56cvcpSVOSNDk5Wdh/hQtVdScuXaIl44t0x8/2UI0NIVVVn8usBihSmZXySyS9WNL3zWyHpJWS7jSzE0ocwxFV3YlLl+hd607ShVNbqcZGkKKqz2lWAxSltErZ3e+WdPzs39vBPOnue8oag3RkVbdkfJEunNpKNRYQDSFoopRH4m6QdIekl5nZTjO7ONW1BjW3qnvyqYO1qsZyOO7Xr5xOoABFSXn6YkOPj0+kuvYg6lSNRWjiKFJOJ1CAojS+zTpKNVZEhRuliaNIuZxAAYrS+DbrCNVYURVuTq3ZABbW+EpZqr4aK6rCzak1uwx1Wl9HcxDKARR19CvKUkwENJ4gV41fvoig02bjkvFFmpnxviv3CEsxUdB4glxRKQewUIV76bqTdOnnvjdwdVf1Uk
wUNJ4gV1TKAcxWuCs2nqEt9+3WwRnps1sf0K5f7w9f3UW9YVCdjjqiWQjlIMbGTE8+dVAf3/LTw16PfHoi8rnoUZ6MDVSJUA4kt+qu6nXbblU66+vIFaEcSG7VXZXnovup0mfX1yPOMoBOCOVAcqvuqqzsq67SgVQ4fdFBVY0HOZ2eqPJcNKcrUFdUyguIvIEVSZWVfW7r70C/qJQXEO3GPpHbhauq7OleRF0lq5TNbJOk8yTtdvdT2q99SNIfS3pK0s8kvdXdf5VqDMOKdGMfqvaF5bb+DvQrZaW8WdL6ea/dKukUd3+FpB9Lel/C6w8t0o19olXtRSiq8s9p/R3oV7JQdvfbJP1i3mtfd/cD7b9uVes5feFEmhrXbUOLGwUB3VW50fc2STcW/U2LaPuNNDWu24YWR9mA7irZ6DOz90s6IOn6Lp+z0cy2m9n26enpvr5vkVVYlKlxpKq9CHWr/IGilV4pm9lb1NoAPMfdO6alu09JmpKkycnJvlK1jlVYpKq9CMNW/lFvfAQUrdRQNrP1kq6U9Bp3f7Lo7x/p1ESR6tQuPEwrOSdQ0CQpj8TdIOlsScvMbKekq9U6bfEcSbeamSRtdfe/LuqadVt/raNhKv86zoCATpKFsrtvWODla1NdT8rvhj6DqssUftDKv64zIGAhtWqzrtv661xNnsIzA0KTZN9mPb8RQVKIUxNFq2MTSb/qdgIF6CbrSrlJ1WOTp/B1ngEB82VdKTepelyo9fvEpUu0ZHxRyBsVFS3KuXEgtaxDuUmNCPOn8CcuXaJ3rTtJF05tpV0ZqJGsly+atAE0fwq/ZHyRLpzamuSYWF1OeQA5yrpSbtoG0Nwp/JNPHUwyS+CGQUC1sq6Um7wBlGqWUESjBpU2MLysQ1mqVwvyIFI1yox6yqNJJ2KAFLIP5aZKNUsYtQKnJRoYDaGcsRSzhLkV+POf+2z9+eRKvfT4Y+TeqoJ7hX6Tz1MDRSCUcZjZCnztZWfpzgd/pau+cPdAyxBNOhEDpJD16YsoIj9tehhjY6YZ1zOBLPXfmNO0EzFA0aiUR1TXja1hlyGafCIGKEKyStnMNpnZbjP74ZzXXmBmt5rZT9r/fn6q65clt1bvfqv6UZ7oTUs0MLyUyxebJa2f99p7JW1x95MkbWn/PWs5tXoP0hjCMgRQjZQ3ub/NzCbmvXy+Wk8jkaTrJH1T0t+lGkMZctrYGuS4GssQQDXK3uhb7u672n9+VNLykq9fuJwqykGrepYhgPJVttHn7m5mHY8pmNlGSRslafXq1aWNa1A5VZQ5VfVAU5VdKT9mZi+UpPa/d3f6RHefcvdJd5887rjjShvgMHKpKHOq6oGmKrtS/rKkiyR9sP3vL5V8/UbLqaoHmipZKJvZDWpt6i0zs52SrlYrjG8ys4slPSDpglTXx8KaegMnIBcpT19s6PChc1JdEwByR5s1AARi7vHv02Bm02otd1RtmaQ9VQ8ioSa+vz3uPr/JSZJkZre0v6bX96irprzXKt5n59+7HEI5CjPb7u6TVY8jFd5fOd8jF015r9HeJ8sXABAIoQwAgRDKg5mqegCJ8f7K+R65aMp7DfU+WVMGgEColAEgEEK5D2a23sz+z8x+ambZ3wN6LjNbZWbfMLMfmdk9ZnZZ1WNKwcwWmdn3zOwrA3xNIx7U0OF9/oOZPWxmd7X/ObfKMRah0+96tJ8podyDmS2S9K+S/kjSWkkbzGxttaMq1AFJV7j7WklnSHpHzd7frMsk3Tvg12xWAx7UoIXfpyR9xN1Pbf/z1ZLHlEKn3/VQP1NCubfTJf3U3e9396ckfU6tm/XXgrvvcvc723/ep1Zwrah2VMUys5WSXi/pM4N8nbvfJukX814+X60HNKj97zeMPMCKdXiftdPldz3Uz5RQ7m2FpIfm/H2nahZas9pPijlN0rZqR1K4j0q6UtJMr0/sQ+0e1NDFO83sB+3ljeyXaeaa97se6mdKKEOSZGZHS/q8pMvd/f
Gqx1MUMztP0m53/27R39tbR5fqenzp3yS9RNKpknZJ+udqh1Ocbr/rEX6mhHJvD0taNefvK9uv1YaZjav1S3q9u99c9XgK9ipJf2JmO9RaelpnZv85wvfr+0ENOXP3x9z9oLvPSPq0Wst42evwux7qZ0oo9/a/kk4ysxeb2bMl/YVaN+uvBTMzSddKutfdr6l6PEVz9/e5+0p3n1DrZ/ff7v6mEb7l7IMapBo/qGE2pNr+VNIPO31uLrr8rof6mdI80of2caCPSlokaZO7/2PFQyqMmb1a0u2S7tahNderarLbfhgzO1vS37r7eX1+/jMPapD0mFoPaviipJskrVb7QQ3unvUmWYf3ebZaSxcuaYekS+asu2ap0++6WuvKYX6mhDIABMLyBQAEQigDQCCEMgAEQigDQCCEMgAEQijXSJ3vZoeYFrrDHEZDKNdEA+5mh5g2a+E7zGFIhHJ91PpudoipKXeYKxOhXB+NuZsdUGeEMgAEQijXR+3vZgc0AaFcH7W+mx3QFIRyTbj7AUnvlPQ1tR5zc5O731PtqFB37TvM3SHpZWa208wurnpMueMucQAQCJUyAARCKANAIIQyAARCKANAIIQyAARCKANAIIQyAARCKANAIP8P1NoXD9WiJi4AAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 360x360 with 5 Axes>"
]
},
"metadata": {
"tags": [],
"needs_background": "light"
}
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8ClgsrjFyuCt"
},
"source": [
"The start of the show is the Sobol' sampler:"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "M8NA1eA51fpk",
"outputId": "ef25eb46-a83e-41a2-a53c-2c34556fbf0e"
},
"source": [
"sobol_engine = qmc.Sobol(d=d, scramble=False, seed=rng)\n",
"sample = sobol_engine.random(n=n)\n",
"qmc.discrepancy(sample)"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"0.00024001962608743987"
]
},
"metadata": {
"tags": []
},
"execution_count": 13
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 391
},
"id": "sxNasrNb134A",
"outputId": "4c6bf5b7-78ba-467f-a778-8bda2ebad4c1"
},
"source": [
"sns.pairplot(pd.DataFrame(sample), diag_kind=\"hist\", corner=True, diag_kws={\"bins\": 4})"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<seaborn.axisgrid.PairGrid at 0x7ff6b6f25990>"
]
},
"metadata": {
"tags": []
},
"execution_count": 14
},
{
"output_type": "display_data",
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 360x360 with 5 Axes>"
]
},
"metadata": {
"tags": [],
"needs_background": "light"
}
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "19rER8C12Ito",
"outputId": "5841785f-4400-4f03-daa1-050abed83ead"
},
"source": [
"sobol_engine = qmc.Sobol(d=d, scramble=True, seed=rng)\n",
"sample = sobol_engine.random(n=n)\n",
"qmc.discrepancy(sample)"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"0.0001551092547196209"
]
},
"metadata": {
"tags": []
},
"execution_count": 15
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hgTaZwB128hH"
},
"source": [
"We spent quite some time on Sobol'. First to ensure it starts with 0, and then to have some warnings if users do not ask for $2^n$ points. Who would dare doing this!?"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ff5Prbx22seV",
"outputId": "0ece6166-3217-46c7-c59c-9b5325ca95bf"
},
"source": [
"sobol_engine = qmc.Sobol(d=3, scramble=False, seed=rng)\n",
"sample = sobol_engine.random(10)\n",
"sample"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"/usr/local/lib/python3.7/dist-packages/scipy/stats/_qmc.py:1078: UserWarning: The balance properties of Sobol' points require n to be a power of 2.\n",
" warnings.warn(\"The balance properties of Sobol' points require\"\n"
],
"name": "stderr"
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([[0. , 0. , 0. ],\n",
" [0.5 , 0.5 , 0.5 ],\n",
" [0.75 , 0.25 , 0.25 ],\n",
" [0.25 , 0.75 , 0.75 ],\n",
" [0.375 , 0.375 , 0.625 ],\n",
" [0.875 , 0.875 , 0.125 ],\n",
" [0.625 , 0.125 , 0.875 ],\n",
" [0.125 , 0.625 , 0.375 ],\n",
" [0.1875, 0.3125, 0.9375],\n",
" [0.6875, 0.8125, 0.4375]])"
]
},
"metadata": {
"tags": []
},
"execution_count": 16
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fGR-VQjO1j1k"
},
"source": [
"## Sample from another distribution\n",
"\n",
"The following is also new to SciPy. We can sample from any arbitrary distribution using a MC or QMC engine."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nVRBvo431PAq"
},
"source": [
"Just defining a helper function to do a nice plot of the PDF"
]
},
{
"cell_type": "code",
"metadata": {
"id": "Y5CZ2NgUtefs"
},
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"\n",
"def plot_pdf(dist, sample):\n",
" fig, ax = plt.subplots()\n",
"\n",
" x = np.linspace(dist.ppf(0.01), dist.ppf(0.99), 100)\n",
" pdf = dist.pdf(x)\n",
" ax.plot(x, pdf, '-', lw=5, label='fisk pdf')\n",
"\n",
" delta = np.max(pdf) * 5e-2\n",
" ax.plot(sample, -delta - delta * np.random.random(99), \"+k\")\n",
"\n",
" ax.hist(sample, density=True, histtype='stepfilled', alpha=0.2)\n",
" ax.legend(loc='best', frameon=False)\n",
" plt.show()"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "6dXW9GrB14qJ"
},
"source": [
"First we sample from an arbitrary distribution-here Fisk-using MC."
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 265
},
"id": "uKbfC9RHteiK",
"outputId": "3d2a5d83-cf47-4368-979d-b99656480540"
},
"source": [
"from scipy.stats import fisk\n",
"\n",
"c = 3.9\n",
"dist = fisk(c)\n",
"\n",
"sample = dist.rvs(99, random_state=rng)\n",
"plot_pdf(dist, sample)"
],
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"tags": [],
"needs_background": "light"
}
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "x5pvmhcZ2C01"
},
"source": [
"And now we can use `NumericalInverseHermite` to use any `QMCEngine`."
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 299
},
"id": "s9ERuEQwuf5L",
"outputId": "09b12cd3-0361-45ce-b24a-3dc262e27092"
},
"source": [
"from scipy.stats import NumericalInverseHermite\n",
"\n",
"fni = NumericalInverseHermite(dist)\n",
"sobol_engine = qmc.Sobol(d=1, scramble=True, seed=rng)\n",
"sample = fni.qrvs(99, qmc_engine=sobol_engine)\n",
"\n",
"\n",
"plot_pdf(dist, sample)"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"/usr/local/lib/python3.7/dist-packages/scipy/stats/_qmc.py:1078: UserWarning: The balance properties of Sobol' points require n to be a power of 2.\n",
" warnings.warn(\"The balance properties of Sobol' points require\"\n"
],
"name": "stderr"
},
{
"output_type": "display_data",
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"tags": [],
"needs_background": "light"
}
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nu2-pSO22NSY"
},
"source": [
"There is also `NumericalInverseHermite.rvs` which uses MC."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-fOq96Ax6B7k"
},
"source": [
"## Integration convergence\n",
"\n",
"We all like convergence plots don't we?\n",
"\n",
"In the following, I evaluate a squared sum in 3 dimension as an example."
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 297
},
"id": "VDPWdjDqwFnl",
"outputId": "30e1e0e5-3ee9-4c2f-d15b-e78f48663282"
},
"source": [
"from collections import namedtuple\n",
"\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from scipy.stats import qmc\n",
"\n",
"n_conv = 99999\n",
"min_m, max_m = 4, 11\n",
"ns_gen = 2 ** np.arange(min_m, max_m)\n",
"\n",
"\n",
"def art_2(sample):\n",
" \"\"\"3-D squared sum.\n",
"\n",
" True value 5/3 + 5*(5 - 1)/4\n",
"\n",
" Art B. Owen. On dropping the first Sobol' point. arXiv 2008.08051, 2020.\n",
" \"\"\"\n",
" return np.sum(sample, axis=1) ** 2\n",
"\n",
"\n",
"functions = namedtuple('functions', ['name', 'func', 'dim', 'ref'])\n",
"case = functions('Art 2', art_2, 5, 5 / 3 + 5 * (5 - 1) / 4)\n",
"\n",
"\n",
"def conv_method(sampler, func, n_samples, n_conv, ref):\n",
" samples = [sampler(n_samples) for _ in range(n_conv)]\n",
" samples = np.array(samples)\n",
"\n",
" evals = [np.sum(func(sample)) / n_samples for sample in samples]\n",
" squared_errors = (ref - np.array(evals)) ** 2\n",
" rmse = (np.sum(squared_errors) / n_conv) ** 0.5\n",
"\n",
" return rmse\n",
"\n",
"\n",
"# Analysis\n",
"sample_mc_rmse = []\n",
"sample_sobol_s_rmse = []\n",
"sample_sobol_rmse = []\n",
"\n",
"for ns in ns_gen:\n",
" # Monte Carlo\n",
" sampler_mc = lambda x: rng.random((x, case.dim))\n",
" conv_res = conv_method(sampler_mc, case.func, ns, n_conv, case.ref)\n",
" sample_mc_rmse.append(conv_res)\n",
"\n",
" # Sobol'\n",
" engine = qmc.Sobol(d=case.dim, scramble=False)\n",
" conv_res = conv_method(engine.random, case.func, ns, 1, case.ref)\n",
" sample_sobol_rmse.append(conv_res)\n",
"\n",
" engine = qmc.Sobol(d=case.dim, scramble=True)\n",
" conv_res = conv_method(engine.random, case.func, ns, n_conv, case.ref)\n",
" sample_sobol_s_rmse.append(conv_res)\n",
"\n",
"sample_mc_rmse = np.array(sample_mc_rmse)\n",
"sample_sobol_rmse = np.array(sample_sobol_rmse)\n",
"sample_sobol_s_rmse = np.array(sample_sobol_s_rmse)\n",
"\n",
"# Plot\n",
"fig, ax = plt.subplots()\n",
"\n",
"\n",
"# MC\n",
"ratio = sample_mc_rmse[0] / ns_gen[0] ** (-1 / 2)\n",
"ax.plot(ns_gen, ns_gen ** (-1 / 2) * ratio, ls='-', c='k')\n",
"\n",
"ax.scatter(ns_gen, sample_mc_rmse, label=\"MC\")\n",
"\n",
"# Sobol'\n",
"ratio = sample_sobol_rmse[0] / ns_gen[0] ** (-2/2)\n",
"ax.plot(ns_gen, ns_gen ** (-2/2) * ratio, ls='-', c='k')\n",
"\n",
"ratio = sample_sobol_s_rmse[0] / ns_gen[0] ** (-4/2)\n",
"ax.plot(ns_gen, ns_gen ** (-4/2) * ratio, ls='-', c='k')\n",
"\n",
"ax.scatter(ns_gen, sample_sobol_rmse, label=\"Sobol' unscrambled\")\n",
"ax.scatter(ns_gen, sample_sobol_s_rmse, label=\"Sobol' scrambled\")\n",
"\n",
"ax.set_xlabel(r'$N_s$')\n",
"ax.set_xscale('log')\n",
"ax.set_xticks(ns_gen)\n",
"ax.set_xticklabels([fr'$2^{{{ns}}}$'\n",
" for ns in np.arange(min_m, max_m)])\n",
"\n",
"ax.set_ylabel(r'$\\log (\\epsilon)$')\n",
"ax.set_yscale('log')\n",
"\n",
"ax.legend(loc='lower left')\n",
"fig.tight_layout()\n",
"plt.show()\n"
],
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"tags": [],
"needs_background": "light"
}
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_0SHgi4g7LGx"
},
"source": [
"## Random number generators and seed\n",
"\n",
"I have a final note on RNG with NumPy. You may have seen or used lot of things like `np.random.seed(123)` or `np.random.RandomState(123)`. Please don't.\n",
"\n",
"There are 3 problems:\n",
"\n",
"* Using a fix seed is in general dangerous in a sense that you can forget it and then instead of having a RNG, well it's not anymore random and hence just a MC draw. Also small values, such as commonly used 0, 123, etc. tend to produce bad entropy for the generator.\n",
"* Using `np.random.seed(...)` fix the global seed. But new code are not relying on this at all. Hence you might only fix a portion of the code.\n",
"* `np.random.RandomState` should not be used for new code because it's slower and has statistical issues. New code should use `np.random.Generator` which is better in every way.\n",
"\n",
"Have a look at the documentation for more details:\n",
"https://scipy.github.io/devdocs/tutorial/stats.html#random-number-generation"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "OwCwpG1Rhhg-"
},
"source": [
"# PyTorch\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VoGIgD4IfDru"
},
"source": [
"### Basics\n",
"\n",
"The implementation of Sobol' sequences in PyTorch is essentially a C++ (PyTorch ATen) port of the `scipy` version, so their behavior is equivalent. There are some slight differences in the API (largely argument names), but otherwise using the PyTorch [`SobolEngine`](https://pytorch.org/docs/stable/generated/torch.quasirandom.SobolEngine.html) is very similar to using `scipy`'s `Sobol`. \n",
"\n",
"*Note:* \n",
"Currently `SobolEngine` is limited to being run on the CPU (of course the generated sample tensors can easily be moved to GPU). [`cuRAND`](https://docs.nvidia.com/cuda/curand/host-api-overview.html#generator-types) supports generating Sobol' sequences on the GPU, but this is currently not exposed in PyTorch (always happy to accept PRs...)"
]
},
{
"cell_type": "code",
"metadata": {
"id": "uFDYUHTJ7N_o",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
},
"outputId": "385b6ffc-767e-4552-cd57-c5efaf5f4693"
},
"source": [
"import torch\n",
"\n",
"torch.__version__ # make sure we're good (1.9+)"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"application/vnd.google.colaboratory.intrinsic+json": {
"type": "string"
},
"text/plain": [
"'1.9.0+cu102'"
]
},
"metadata": {
"tags": []
},
"execution_count": 53
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "bMURUs6s1EHy"
},
"source": [
"t_sobol_engine = torch.quasirandom.SobolEngine(\n",
" dimension=2, scramble=True, seed=29211402,\n",
")"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "usxanizZ1Hr4",
"outputId": "78c31ea0-583e-4528-f3c4-cc5bdd30d7eb"
},
"source": [
"t_sobol_engine.draw(4)"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor([[0.9967, 0.5888],\n",
" [0.0412, 0.4169],\n",
" [0.2833, 0.9479],\n",
" [0.7389, 0.0573]])"
]
},
"metadata": {
"tags": []
},
"execution_count": 55
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "7gRvyyg1xpTB",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "4f387b93-703c-4009-b8a8-a7553e0bcc23"
},
"source": [
"t_sobol_engine.reset()\n",
"t_sobol_engine.draw_base2(2)"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor([[0.9967, 0.5888],\n",
" [0.0412, 0.4169],\n",
" [0.2833, 0.9479],\n",
" [0.7389, 0.0573]])"
]
},
"metadata": {
"tags": []
},
"execution_count": 56
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_CLwohIlYqnO"
},
"source": [
"### Illustrative example: Auto-differentiating through quasi-random samples.\n",
"\n",
"We consider the case where we draw random samples $\\xi$ from some distributon $\\mathcal{D}$, where $\\mathcal{D} = \\mathcal{D}(x)$ depends on some parameter $x$. We'd like to compute gradients of $\\xi$ w.r.t. $x$. We don't want to do that by hand. In various cases we can do this by combining the \"reparameterization trick\" [1] with auto-differentiation capabilities of modern Deep Learning frameworks. \n",
"\n",
"[1] *D. P. Kingma and M. Welling. Auto-Encoding Variational Bayes. arXiv e-prints, 2013.*"
]
},
{
"cell_type": "code",
"metadata": {
"id": "-uTOD3zMYzTb"
},
"source": [
" from torch import Tensor\n",
"\n",
"# create a tensor, let torch know we want to compute gradients\n",
"x = torch.tensor([1., 2., 3.], requires_grad=True)\n",
"\n",
"# do some kind of transformation\n",
"def dummy_transform(z: Tensor) -> Tensor:\n",
" return torch.sin(z) ** 2\n",
"\n",
"# create a random SPD matrix and add the transformed b to its diagonal\n",
"a = torch.rand(3, 3)\n",
"M = a @ a.t() + torch.diag_embed(dummy_transform(x))"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "zo_J_eHyajaP"
},
"source": [
"Let's draw some samples from a Multivariate normal.\n",
"\n",
"Note that if $\\epsilon_i \\sim \\mathcal{N}(0, 1)$ iid and $LL^T = M$, then $\\xi := L \\epsilon \\sim \\mathcal{N}(0, M)$. "
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "kHN_JZWFZDBR",
"outputId": "db977bdb-3515-4924-8c7f-6345b055287e"
},
"source": [
"# compute the Cholesky factor of M\n",
"L = torch.linalg.cholesky(M) \n",
"\n",
"# draw quasirandom samples epsilon\n",
"t_sobol_engine = torch.quasirandom.SobolEngine(\n",
" dimension=3, scramble=True, seed=29211402,\n",
")\n",
"sobol_samples = t_sobol_engine.draw(2)\n",
"epsilon = torch.distributions.Normal(0, 1).icdf(sobol_samples) # 2 x 3\n",
"\n",
"# correlate samples \n",
"xi = epsilon @ L.t() # 2 x 3\n",
"xi"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor([[ 2.8662, 1.9570, 2.1361],\n",
" [-1.4842, -1.0919, -1.4085]], grad_fn=<MmBackward>)"
]
},
"metadata": {
"tags": []
},
"execution_count": 58
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4qEFutgScM7B"
},
"source": [
"We can now compute gradients of these samples w.r.t. our input using PyTorch's auto-differentiation feature"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "fL5sQzh_cd0y",
"outputId": "42267098-7e0c-4360-94cb-51e89a6701c1"
},
"source": [
"# compute gradient of the first element of the first sample w.r.t x\n",
"torch.autograd.grad(xi[0, 0], x)"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(tensor([1.1680, -0.0000, -0.0000]),)"
]
},
"metadata": {
"tags": []
},
"execution_count": 59
}
]
}
]
}
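The seeding advice in the notebook's "Random number generators and seed" section can be sketched as follows. This is a minimal illustration, not the authors' code: the helper name `make_streams` is hypothetical, the entropy value is the one the notebook itself assigns to `seed`, and it assumes NumPy 1.17 or later, where `SeedSequence` and `default_rng` are available.

```python
import numpy as np

# Entropy value reused from the notebook, for illustration only. In real code,
# create np.random.SeedSequence() without arguments and log its `.entropy`
# attribute so that a run can be replayed later.
entropy = 292114020772849599029278515437886320941


def make_streams(entropy, n_streams):
    """Hypothetical helper: independent Generators from one root SeedSequence."""
    root = np.random.SeedSequence(entropy)
    return [np.random.default_rng(child) for child in root.spawn(n_streams)]


# Two independent streams, e.g. one per component of a simulation.
rng_a, rng_b = make_streams(entropy, 2)
draw_a = rng_a.random(3)
draw_b = rng_b.random(3)

# Rebuilding the generators from the same entropy replays the same streams.
rng_a2, rng_b2 = make_streams(entropy, 2)
replay_a = rng_a2.random(3)
```

Passing each `Generator` explicitly to the code that needs it (as the notebook does with `seed=rng`) avoids any reliance on the global state set by `np.random.seed`.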