Last active
August 10, 2021 20:03
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"id": "7dc7cf5d-59ec-4a6b-aeb9-246c5e9a4aab", | |
"metadata": {}, | |
"source": [ | |
"Experimental demo, intended to maybe be part of [this larger demo](https://nbviewer.jupyter.org/gist/honno/bf7bd9ecbf6926cb68a193a068a55033) at some point.\n", | |
"\n", | |
"---" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "ed4b0085-c122-4970-8f28-7f19ef6f0f31", | |
"metadata": {}, | |
"source": [ | |
"Firstly let's prepare the notebook for presentation purposes." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "b0d4479c-41be-42ad-9092-455c2ac372e0", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"%%capture\n", | |
"!pip install ipytest" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "1f32d296-e8d6-481a-a080-83296e35671c", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import pytest\n", | |
"import ipytest\n", | |
"ipytest.autoconfig()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "a187f4e9-9260-404b-b2f6-34b3a7cd78e1", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# the warnings hypothesis-array-api raises about Array API noncompliance get distracting for demo purposes\n", | |
"import warnings\n", | |
"warnings.filterwarnings(\"ignore\")" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "353f77a7-48d1-447b-985b-616b9b8ec9b1", | |
"metadata": {}, | |
"source": [ | |
"Now let's install the necessary libraries." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "40692675-6ea3-4314-a509-90c120b6f23a", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"%%capture\n", | |
"!pip install hypothesis hypothesis-array-api==0.0.5 numpy" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "857f6942-b99b-4010-98da-f8d494692af0", | |
"metadata": {}, | |
"source": [ | |
"We're monkey patching NumPy to make it more Array API compatible. Note that `xp` is shorthand for \"an Array API module\".\n", | |
"\n", | |
"In the future we should be able to import an Array API namespace directly:\n", | |
"```python\n", | |
"from numpy import array_api as xp\n", | |
"```\n", | |
"This should be [coming soon](https://github.com/numpy/numpy/pull/18585) to NumPy!" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "8d79e90f-c35e-46e5-99bf-77ed6868b7f1", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"xp = np\n", | |
"xp.bool = np.bool_ # monkey patch correct bool dtype in xp namespace" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "dc238a31-0b54-4ab6-9a89-cfb37088f20e", | |
"metadata": {}, | |
"source": [ | |
"Let's write a simple `closest()` function which only uses functions from the Array API." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "8e4dc67f-393e-4331-a701-ce39ba327055", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def closest(x, val):\n", | |
"    \"\"\"Returns the element in array that is closest to the passed value\"\"\"\n", | |
"    diff_x = abs(x - val)\n", | |
"    i = xp.argmin(diff_x)\n", | |
"    return x[i]" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "7045ea59-310a-4cc4-a0e4-c662fc04dc82", | |
"metadata": {}, | |
"source": [ | |
"Here are some simple examples to test our `closest()` function against. We just specify the `x` and `val` arguments and check whether `closest(x, val)` matches our expected result." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"id": "c958d5e6-858f-4303-8133-534cd71ce631", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"@pytest.mark.parametrize(\n", | |
"    \"array, value, result\",\n", | |
"    [\n", | |
"        (xp.linspace(0, 4, 5), 1, 1),  # [0, 1, 2, 3, 4]\n", | |
"        (xp.linspace(0, 4, 5), 2.4, 2),  # [0, 1, 2, 3, 4]\n", | |
"        (xp.linspace(-5, 5, 5), 2.6, 2.5),  # [-5, -2.5, 0, 2.5, 5]\n", | |
"        (xp.linspace(-5, 5, 5), -2, -2.5),  # [-5, -2.5, 0, 2.5, 5]\n", | |
"        (xp.ones(5), 1.2, 1),  # [1, 1, 1, 1, 1]\n", | |
"        (xp.asarray([2, 3, 1, 4, 0]), 1.2, 1),\n", | |
"    ]\n", | |
")\n", | |
"def test_examples(array, value, result):\n", | |
"    assert closest(array, value) == result" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"id": "e8debe7b-38ca-4ac7-8cef-1d2f841da681", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"\u001b[1m======================================= test session starts ========================================\u001b[0m\n", | |
"platform linux -- Python 3.8.10, pytest-6.2.4, py-1.10.0, pluggy-0.13.1\n", | |
"rootdir: /home/honno/gdrive/GitHub/hypothesis-array-api/fluff/notebooks\n", | |
"plugins: hypothesis-6.14.6, anyio-3.3.0\n", | |
"collected 6 items\n", | |
"\n", | |
"tmpx3hl9a76.py \u001b[32m.\u001b[0m\u001b[32m.\u001b[0m\u001b[32m.\u001b[0m\u001b[32m.\u001b[0m\u001b[32m.\u001b[0m\u001b[32m.\u001b[0m\u001b[32m [100%]\u001b[0m\n", | |
"\n", | |
"\u001b[32m======================================== \u001b[32m\u001b[1m6 passed\u001b[0m\u001b[32m in 0.04s\u001b[0m\u001b[32m =========================================\u001b[0m\n" | |
] | |
} | |
], | |
"source": [ | |
"ipytest.run(\"-vk test_examples\")" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "54500006-bd66-4b42-b8cc-93dea019f844", | |
"metadata": {}, | |
"source": [ | |
"Great, they all pass! Let's use [Hypothesis](https://github.com/HypothesisWorks/hypothesis/) to check that every input allowed by our *specification* results in `closest()` returning a single result.\n", | |
"\n", | |
"Note that a single result in this case means a 0-dimensional array.\n", | |
"```pycon\n", | |
">>> x = xp.linspace(0, 8, 9).reshape((3, 3))\n", | |
">>> x\n", | |
"array([[0, 1, 2], [3, 4, 5], [6, 7, 8]])\n", | |
">>> closest(x, 2.8)\n", | |
"array(3)\n", | |
"```" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"id": "a499cc41-b73d-459b-bd27-f4eb31b7126d", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from hypothesis import given, settings, strategies as st\n", | |
"from hypothesis_array import get_strategies_namespace\n", | |
"\n", | |
"xps = get_strategies_namespace(xp)\n", | |
"\n", | |
"@given(\n", | |
"    array=xps.arrays(dtype=\"int8\", shape=xps.array_shapes()),\n", | |
"    value=xps.from_dtype(\"int8\"),\n", | |
")\n", | |
"def test_0d_results(array, value):\n", | |
"    \"\"\"Result is a 0d array i.e. contains a single value\"\"\"\n", | |
"    result = closest(array, value)\n", | |
"    assert result.shape == ()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"id": "788edca2-cdb1-499b-bdb4-dcc709f1a8bf", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"\u001b[31mF\u001b[0m\u001b[31m [100%]\u001b[0m\n", | |
"============================================= FAILURES =============================================\n", | |
"\u001b[31m\u001b[1m_________________________________________ test_0d_results __________________________________________\u001b[0m\n", | |
"\n", | |
" \u001b[37m@given\u001b[39;49;00m(\n", | |
"> array=xps.arrays(dtype=\u001b[33m\"\u001b[39;49;00m\u001b[33mint8\u001b[39;49;00m\u001b[33m\"\u001b[39;49;00m, shape=xps.array_shapes()),\n", | |
" value=xps.from_dtype(\u001b[33m\"\u001b[39;49;00m\u001b[33mint8\u001b[39;49;00m\u001b[33m\"\u001b[39;49;00m),\n", | |
" )\n", | |
"\n", | |
"\u001b[1m\u001b[31m/tmp/ipykernel_60417/2300622674.py\u001b[0m:7: \n", | |
"_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ \n", | |
"\n", | |
"array = array([[0]], dtype=int8), value = 0\n", | |
"\n", | |
" \u001b[37m@given\u001b[39;49;00m(\n", | |
" array=xps.arrays(dtype=\u001b[33m\"\u001b[39;49;00m\u001b[33mint8\u001b[39;49;00m\u001b[33m\"\u001b[39;49;00m, shape=xps.array_shapes()),\n", | |
" value=xps.from_dtype(\u001b[33m\"\u001b[39;49;00m\u001b[33mint8\u001b[39;49;00m\u001b[33m\"\u001b[39;49;00m),\n", | |
" )\n", | |
" \u001b[94mdef\u001b[39;49;00m \u001b[92mtest_0d_results\u001b[39;49;00m(array, value):\n", | |
" \u001b[33m\"\"\"Result is a 0d array i.e. contains a single value\"\"\"\u001b[39;49;00m\n", | |
" result = closest(array, value)\n", | |
"> \u001b[94massert\u001b[39;49;00m result.shape == ()\n", | |
"\u001b[1m\u001b[31mE assert (1,) == ()\u001b[0m\n", | |
"\u001b[1m\u001b[31mE Left contains one more item: 1\u001b[0m\n", | |
"\u001b[1m\u001b[31mE Full diff:\u001b[0m\n", | |
"\u001b[1m\u001b[31mE - ()\u001b[0m\n", | |
"\u001b[1m\u001b[31mE + (1,)\u001b[0m\n", | |
"\n", | |
"\u001b[1m\u001b[31m/tmp/ipykernel_60417/2300622674.py\u001b[0m:13: AssertionError\n", | |
"-------------------------------------------- Hypothesis --------------------------------------------\n", | |
"Falsifying example: test_0d_results(\n", | |
" array=array([[0]], dtype=int8), value=0,\n", | |
")\n", | |
"===================================== short test summary info ======================================\n", | |
"FAILED tmpvsayf0__.py::test_0d_results - assert (1,) == ()\n", | |
"\u001b[31m\u001b[31m\u001b[1m1 failed\u001b[0m, \u001b[33m6 deselected\u001b[0m\u001b[31m in 0.33s\u001b[0m\u001b[0m\n" | |
] | |
} | |
], | |
"source": [ | |
"ipytest.run(\"-k test_0d_results\")" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "708f32e2-c319-4706-b578-b55c4ee8dbb0", | |
"metadata": {}, | |
"source": [ | |
"Turns out our `closest()` function doesn't return a single, sensible result for multi-dimensional arrays. We forgot to manually test arrays with more than one dimension in our parametrized `test_examples()` function, so it's a good thing our specification of shapes as `array_shapes()` caught this." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"id": "b6236868-c261-4f4f-93ae-073ea65ee9e0", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([6., 7., 8.])" | |
] | |
}, | |
"execution_count": 11, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"x = xp.linspace(0, 8, 9).reshape((3, 3)) # [[0, 1, 2], [3, 4, 5], [6, 7, 8]]\n", | |
"closest(x, 1.8)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "c381f542-0c04-48f7-842e-22860d41314b", | |
"metadata": {}, | |
"source": [ | |
"The `xp.argmin()` we use returns an index `i` into the flattened (1d) array, but `x[i]` indexes along the first axis of `x`, so we get a whole row back. This is also why we sometimes get an out-of-bounds error in our Hypothesis-generated test cases: `i` can exceed the size of the first axis." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"id": "ebc67a60-c559-4933-8fe6-fc69b0127a40", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"ERROR: index 5 is out of bounds for axis 0 with size 3\n" | |
] | |
} | |
], | |
"source": [ | |
"x = xp.linspace(0, 8, 9).reshape((3, 3))\n", | |
"try:\n", | |
"    closest(x, 5.2)\n", | |
"except IndexError as e:\n", | |
"    print(f\"ERROR: {e}\")" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "26535c90-48cc-49cf-9c71-97d85e702d67", | |
"metadata": {}, | |
"source": [ | |
"So let's fix both of these bugs that Hypothesis found by first flattening `x` in our `closest()` function." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"id": "e4529f79-dbc5-416a-8f8c-20cf570ee7ac", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def closest(x, val):\n", | |
"    flat_x = x.reshape((-1,))\n", | |
"    diff_x = abs(flat_x - val)\n", | |
"    i = xp.argmin(diff_x)\n", | |
"    return flat_x[i]" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "eb68130e-b7a5-422b-98d2-1ae64ed59d49", | |
"metadata": {}, | |
"source": [ | |
"To make sure it's fixed, we'll run the tests again." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"id": "b1800afc-3ffc-4ecb-b192-b716a6b3a570", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"\u001b[1m======================================= test session starts ========================================\u001b[0m\n", | |
"platform linux -- Python 3.8.10, pytest-6.2.4, py-1.10.0, pluggy-0.13.1\n", | |
"rootdir: /home/honno/gdrive/GitHub/hypothesis-array-api/fluff/notebooks\n", | |
"plugins: hypothesis-6.14.6, anyio-3.3.0\n", | |
"collected 7 items / 6 deselected / 1 selected\n", | |
"\n", | |
"tmpvjvpdqqp.py \u001b[32m.\u001b[0m\u001b[32m [100%]\u001b[0m\n", | |
"\n", | |
"\u001b[32m================================= \u001b[32m\u001b[1m1 passed\u001b[0m, \u001b[33m6 deselected\u001b[0m\u001b[32m in 0.22s\u001b[0m\u001b[32m ==================================\u001b[0m\n" | |
] | |
} | |
], | |
"source": [ | |
"ipytest.run(\"-vk 0d_results\")" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "1d57490c-b5a3-4ee5-8f82-898362096de7", | |
"metadata": {}, | |
"source": [ | |
"How about other expected behaviour? We should test for every edge case!\n", | |
"\n", | |
"For example, infinities should work just fine in `closest()`, both in the array `x` and as the value `val`." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"id": "83090f59-adec-4db1-8b44-20d778406558", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"@given(xps.arrays(dtype=\"float32\", shape=10))\n", | |
"def test_inf_value(array):\n", | |
"    \"\"\"No errors when val=inf\"\"\"\n", | |
"    closest(array, float(\"inf\"))\n", | |
"\n", | |
"def test_inf_array():\n", | |
"    \"\"\"No errors when array is just infs\"\"\"\n", | |
"    x = xp.full(10, float(\"inf\"))\n", | |
"    closest(x, 42)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"id": "f31db740-9434-4b9a-a30c-88854162b3bd", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"\u001b[1m======================================= test session starts ========================================\u001b[0m\n", | |
"platform linux -- Python 3.8.10, pytest-6.2.4, py-1.10.0, pluggy-0.13.1\n", | |
"rootdir: /home/honno/gdrive/GitHub/hypothesis-array-api/fluff/notebooks\n", | |
"plugins: hypothesis-6.14.6, anyio-3.3.0\n", | |
"collected 9 items / 7 deselected / 2 selected\n", | |
"\n", | |
"tmpbhkvhebt.py \u001b[32m.\u001b[0m\u001b[32m.\u001b[0m\u001b[32m [100%]\u001b[0m\n", | |
"\n", | |
"\u001b[32m================================= \u001b[32m\u001b[1m2 passed\u001b[0m, \u001b[33m7 deselected\u001b[0m\u001b[32m in 0.21s\u001b[0m\u001b[32m ==================================\u001b[0m\n" | |
] | |
} | |
], | |
"source": [ | |
"ipytest.run(\"-vk inf\")" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "e8206ce1-c695-4897-881c-fd03aecc7c3d", | |
"metadata": {}, | |
"source": [ | |
"What about NaN behaviour? I think `closest()` should be raising an error if array `x` is just full of NaNs... we can test for raised errors too!" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"id": "f2c2d88e-4dca-4f66-8d38-9d723740fb2f", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def test_nan_array_raises():\n", | |
"    \"\"\"Error raised when array is just nans\"\"\"\n", | |
"    x = xp.full(10, float(\"nan\"))\n", | |
"    with pytest.raises(ValueError):\n", | |
"        closest(x, 42)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"id": "2b1a0efa-1260-4452-a704-cb35facb8574", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"\u001b[31mF\u001b[0m\u001b[31m [100%]\u001b[0m\n", | |
"============================================= FAILURES =============================================\n", | |
"\u001b[31m\u001b[1m______________________________________ test_nan_array_raises _______________________________________\u001b[0m\n", | |
"\n", | |
" \u001b[94mdef\u001b[39;49;00m \u001b[92mtest_nan_array_raises\u001b[39;49;00m():\n", | |
" \u001b[33m\"\"\"Error raised when array is just nans\"\"\"\u001b[39;49;00m\n", | |
" x = xp.full(\u001b[94m10\u001b[39;49;00m, \u001b[96mfloat\u001b[39;49;00m(\u001b[33m\"\u001b[39;49;00m\u001b[33mnan\u001b[39;49;00m\u001b[33m\"\u001b[39;49;00m))\n", | |
" \u001b[94mwith\u001b[39;49;00m pytest.raises(\u001b[96mValueError\u001b[39;49;00m):\n", | |
"> closest(x, \u001b[94m42\u001b[39;49;00m)\n", | |
"\u001b[1m\u001b[31mE Failed: DID NOT RAISE <class 'ValueError'>\u001b[0m\n", | |
"\n", | |
"\u001b[1m\u001b[31m/tmp/ipykernel_60417/556185350.py\u001b[0m:5: Failed\n", | |
"===================================== short test summary info ======================================\n", | |
"FAILED tmpvadh1c8e.py::test_nan_array_raises - Failed: DID NOT RAISE <class 'ValueError'>\n", | |
"\u001b[31m\u001b[31m\u001b[1m1 failed\u001b[0m, \u001b[33m9 deselected\u001b[0m\u001b[31m in 0.02s\u001b[0m\u001b[0m\n" | |
] | |
} | |
], | |
"source": [ | |
"ipytest.run(\"-k nan\")" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "38b2d54d-d9b6-4b3a-99a1-b5b6fd978060", | |
"metadata": {}, | |
"source": [ | |
"Nope, apparently `closest()` is returning something even if `x` is just a bunch of NaNs." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"id": "40c8ed29-fb1e-4802-8eb3-5068d38f3833", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"nan" | |
] | |
}, | |
"execution_count": 19, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"x = xp.full(10, xp.nan)\n", | |
"closest(x, 42)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "39174420-ece4-446f-a16f-e635b79c6c72", | |
"metadata": {}, | |
"source": [ | |
"Personally I don't like this behaviour—how can the number 42 be at any distance to a NaN?! Let's make sure we tell the user they're doing something nonsensical by raising an error." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"id": "757cc3b3-e081-4b1f-a6b0-5fb11293b55e", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def closest(x, val):\n", | |
"    if xp.all(xp.isnan(x)):\n", | |
"        raise ValueError(\"no closest value when array is just nans\")\n", | |
"    flat_x = x.reshape((-1,))\n", | |
"    diff_x = abs(flat_x - val)\n", | |
"    i = xp.argmin(diff_x)\n", | |
"    return flat_x[i]" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "fd2f3428-0694-4089-92a8-9d40f2fd8606", | |
"metadata": {}, | |
"source": [ | |
"Like before we can see our changes work by running the previously failing test again." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"id": "0daea132-e448-41fc-a65a-0e52bae08245", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"\u001b[1m======================================= test session starts ========================================\u001b[0m\n", | |
"platform linux -- Python 3.8.10, pytest-6.2.4, py-1.10.0, pluggy-0.13.1\n", | |
"rootdir: /home/honno/gdrive/GitHub/hypothesis-array-api/fluff/notebooks\n", | |
"plugins: hypothesis-6.14.6, anyio-3.3.0\n", | |
"collected 10 items / 9 deselected / 1 selected\n", | |
"\n", | |
"tmpo543bxzy.py \u001b[32m.\u001b[0m\u001b[32m [100%]\u001b[0m\n", | |
"\n", | |
"\u001b[32m================================= \u001b[32m\u001b[1m1 passed\u001b[0m, \u001b[33m9 deselected\u001b[0m\u001b[32m in 0.01s\u001b[0m\u001b[32m ==================================\u001b[0m\n" | |
] | |
} | |
], | |
"source": [ | |
"ipytest.run(\"-vk nan\")" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "1da8078b-72a7-4b1c-97fe-2cddf295077c", | |
"metadata": {}, | |
"source": [ | |
"---\n", | |
"\n", | |
"**Ignore everything below!** I wanted to demonstrate Torch but have given up for now, as I need to monkeypatch it to make it compatible enough with the Array API to work with my Hypothesis strategies." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 22, | |
"id": "12b854be-a0cd-4a56-b857-143cb00a692d", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"ename": "KeyboardInterrupt", | |
"evalue": "", | |
"output_type": "error", | |
"traceback": [ | |
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | |
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", | |
"\u001b[0;32m/tmp/ipykernel_60417/2006531705.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mKeyboardInterrupt\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", | |
"\u001b[0;31mKeyboardInterrupt\u001b[0m: " | |
] | |
} | |
], | |
"source": [ | |
"raise KeyboardInterrupt" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "0e75af65-448e-4573-834c-9c41d1799d99", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"%%capture\n", | |
"!pip install torch" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "75025712-018c-4727-96d7-672f18b429d4", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import torch\n", | |
"from copy import deepcopy\n", | |
"\n", | |
"torch_xp = torch\n", | |
"\n", | |
"# accept dtype names given as strings, and shapes given as an int or a sequence\n", | |
"torch_empty = deepcopy(torch.empty)\n", | |
"def torch_xp_empty(shape, dtype=None):\n", | |
"    if isinstance(dtype, str):\n", | |
"        dtype = getattr(torch, dtype)\n", | |
"    if isinstance(shape, int):\n", | |
"        return torch_empty(shape, dtype=dtype)\n", | |
"    else:\n", | |
"        return torch_empty(*shape, dtype=dtype)\n", | |
"torch_xp.empty = torch_xp_empty\n", | |
"\n", | |
"# special-case slice indexing of 0d tensors\n", | |
"class XPTensor(torch.Tensor):\n", | |
"    def __getitem__(self, k):\n", | |
"        if self.ndim == 0 and isinstance(k, slice):\n", | |
"            return self[0]\n", | |
"        else:\n", | |
"            return super().__getitem__(k)\n", | |
"    def __setitem__(self, k, v):\n", | |
"        if self.ndim == 0 and isinstance(k, slice):\n", | |
"            self[0] = v\n", | |
"        else:\n", | |
"            super().__setitem__(k, v)\n", | |
"torch_xp.Tensor = XPTensor\n", | |
"\n", | |
"# distinguish previous Array API module as NumPy's\n", | |
"numpy_xp = xp" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "a5d3a25b-74be-444f-a9bd-be6be3e8c565", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def closest(xp, x, val):\n", | |
"    \"\"\"Returns the element in array that is closest to the passed value\"\"\"\n", | |
"    if xp.all(xp.isnan(x)):\n", | |
"        raise ValueError(\"no closest value when array is just nans\")\n", | |
"    flat_x = x.reshape((-1,))\n", | |
"    diff_x = abs(flat_x - val)\n", | |
"    i = xp.argmin(diff_x)\n", | |
"    return flat_x[i]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "083c98e0-7f70-4e2c-a9a1-0ac668e85ccc", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"%%ipytest\n", | |
"\n", | |
"@pytest.mark.parametrize(\"xp\", [numpy_xp, torch_xp])\n", | |
"@pytest.mark.parametrize(\n", | |
"    \"func_name, func_args, value, result\",\n", | |
"    [\n", | |
"        (\"linspace\", (0, 4, 5), 1, 1),  # [0, 1, 2, 3, 4]\n", | |
"        (\"linspace\", (0, 4, 5), 2.4, 2),  # [0, 1, 2, 3, 4]\n", | |
"        (\"linspace\", (-5, 5, 5), 2.6, 2.5),  # [-5, -2.5, 0, 2.5, 5]\n", | |
"        (\"linspace\", (-5, 5, 5), -2, -2.5),  # [-5, -2.5, 0, 2.5, 5]\n", | |
"        (\"ones\", [5], 1.2, 1),  # [1, 1, 1, 1, 1]\n", | |
"        (\"asarray\", [[2, 3, 1, 4, 0]], 1.2, 1),\n", | |
"    ]\n", | |
")\n", | |
"def test_examples(xp, func_name, func_args, value, result):\n", | |
"    array = getattr(xp, func_name)(*func_args)\n", | |
"    assert closest(xp, array, value) == result\n", | |
"\n", | |
"import hypothesis_array as naive_xps\n", | |
"\n", | |
"@pytest.mark.parametrize(\"xp\", [numpy_xp, torch_xp])\n", | |
"@given(st.data())\n", | |
"def test_0d_results(xp, data):\n", | |
"    \"\"\"Result is a 0d array i.e. contains a single value\"\"\"\n", | |
"    array = data.draw(\n", | |
"        naive_xps.arrays(xp=xp, dtype=\"int8\", shape=xps.array_shapes())\n", | |
"    )\n", | |
"    value = data.draw(\n", | |
"        naive_xps.from_dtype(xp=xp, dtype=\"int8\")\n", | |
"    )\n", | |
"    result = closest(xp, array, value)\n", | |
"    assert result.shape == ()" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.8.10" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
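
As a library-free aside, the logic of the final NumPy-backed `closest()` in the notebook above (flatten first, refuse all-NaN input, then take the argmin of the absolute differences) can be sketched in plain Python. This is only an illustration under stated assumptions: nested lists stand in for arrays, the `flatten()` helper is hypothetical and not part of the notebook, and inputs that mix NaNs with real numbers are not handled the way an array library would handle them.

```python
import math


def flatten(nested):
    """Yield the scalars of an arbitrarily nested list in row-major order.

    Hypothetical helper standing in for `x.reshape((-1,))` in the notebook.
    """
    for item in nested:
        if isinstance(item, list):
            yield from flatten(item)
        else:
            yield item


def closest(x, val):
    """Return the element of the nested list x that is closest to val."""
    flat_x = list(flatten(x))
    # Mirror the notebook's guard: an all-NaN input has no closest value.
    # (An empty input also ends up here, because all() of nothing is True.)
    if all(math.isnan(v) for v in flat_x):
        raise ValueError("no closest value when array is just nans")
    diff_x = [abs(v - val) for v in flat_x]
    # argmin over the flattened differences, so i is a valid index into flat_x
    i = min(range(len(diff_x)), key=diff_x.__getitem__)
    return flat_x[i]


grid = [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]
print(closest(grid, 5.2))  # 5.0
print(closest(grid, 1.8))  # 2.0
```

Because the index is computed and applied on the same flattened sequence, the row-returning and out-of-bounds behaviours seen earlier with the 3x3 input cannot occur here.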