@honno
Last active August 10, 2021 20:03
{
"cells": [
{
"cell_type": "markdown",
"id": "7dc7cf5d-59ec-4a6b-aeb9-246c5e9a4aab",
"metadata": {},
"source": [
"Experimental demo that may eventually become part of [this larger demo](https://nbviewer.jupyter.org/gist/honno/bf7bd9ecbf6926cb68a193a068a55033).\n",
"\n",
"---"
]
},
{
"cell_type": "markdown",
"id": "ed4b0085-c122-4970-8f28-7f19ef6f0f31",
"metadata": {},
"source": [
"First, let's prepare the notebook for presentation purposes."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "b0d4479c-41be-42ad-9092-455c2ac372e0",
"metadata": {},
"outputs": [],
"source": [
"%%capture\n",
"!pip install ipytest"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "1f32d296-e8d6-481a-a080-83296e35671c",
"metadata": {},
"outputs": [],
"source": [
"import pytest\n",
"import ipytest\n",
"ipytest.autoconfig()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "a187f4e9-9260-404b-b2f6-34b3a7cd78e1",
"metadata": {},
"outputs": [],
"source": [
"# the warnings hypothesis-array-api raises about Array API noncompliance get distracting for demo purposes\n",
"import warnings\n",
"warnings.filterwarnings(\"ignore\")"
]
},
{
"cell_type": "markdown",
"id": "353f77a7-48d1-447b-985b-616b9b8ec9b1",
"metadata": {},
"source": [
"Now let's install the necessary libraries."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "40692675-6ea3-4314-a509-90c120b6f23a",
"metadata": {},
"outputs": [],
"source": [
"%%capture\n",
"!pip install hypothesis hypothesis-array-api==0.0.5 numpy"
]
},
{
"cell_type": "markdown",
"id": "857f6942-b99b-4010-98da-f8d494692af0",
"metadata": {},
"source": [
"We're monkeypatching NumPy to make it more compatible with the Array API. Note that `xp` is shorthand for \"an Array API module\".\n",
"\n",
"In the future we should be able to just import an Array API namespace:\n",
"```python\n",
"from numpy import array_api as xp\n",
"```\n",
"This should be [coming soon](https://github.com/numpy/numpy/pull/18585) to NumPy!"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "8d79e90f-c35e-46e5-99bf-77ed6868b7f1",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"xp = np\n",
"xp.bool = np.bool_ # monkey patch correct bool dtype in xp namespace"
]
},
{
"cell_type": "markdown",
"id": "dc238a31-0b54-4ab6-9a89-cfb37088f20e",
"metadata": {},
"source": [
"Let's write a simple `closest()` function that only uses Array API functionality."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "8e4dc67f-393e-4331-a701-ce39ba327055",
"metadata": {},
"outputs": [],
"source": [
"def closest(x, val):\n",
"    \"\"\"Returns the element in array that is closest to the passed value\"\"\"\n",
"    diff_x = abs(x - val)\n",
"    i = xp.argmin(diff_x)\n",
"    return x[i]"
]
},
{
"cell_type": "markdown",
"id": "7045ea59-310a-4cc4-a0e4-c662fc04dc82",
"metadata": {},
"source": [
"Here are some simple examples to test our `closest()` function against: we specify the `x` and `val` arguments and check whether `closest(x, val)` matches the expected result."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "c958d5e6-858f-4303-8133-534cd71ce631",
"metadata": {},
"outputs": [],
"source": [
"@pytest.mark.parametrize(\n",
"    \"array, value, result\",\n",
"    [\n",
"        (xp.linspace(0, 4, 5), 1, 1),  # [0, 1, 2, 3, 4]\n",
"        (xp.linspace(0, 4, 5), 2.4, 2),  # [0, 1, 2, 3, 4]\n",
"        (xp.linspace(-5, 5, 5), 2.6, 2.5),  # [-5, -2.5, 0, 2.5, 5]\n",
"        (xp.linspace(-5, 5, 5), -2, -2.5),  # [-5, -2.5, 0, 2.5, 5]\n",
"        (xp.ones(5), 1.2, 1),  # [1, 1, 1, 1, 1]\n",
"        (xp.asarray([2, 3, 1, 4, 0]), 1.2, 1),\n",
"    ]\n",
")\n",
"def test_examples(array, value, result):\n",
"    assert closest(array, value) == result"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "e8debe7b-38ca-4ac7-8cef-1d2f841da681",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[1m======================================= test session starts ========================================\u001b[0m\n",
"platform linux -- Python 3.8.10, pytest-6.2.4, py-1.10.0, pluggy-0.13.1\n",
"rootdir: /home/honno/gdrive/GitHub/hypothesis-array-api/fluff/notebooks\n",
"plugins: hypothesis-6.14.6, anyio-3.3.0\n",
"collected 6 items\n",
"\n",
"tmpx3hl9a76.py \u001b[32m.\u001b[0m\u001b[32m.\u001b[0m\u001b[32m.\u001b[0m\u001b[32m.\u001b[0m\u001b[32m.\u001b[0m\u001b[32m.\u001b[0m\u001b[32m [100%]\u001b[0m\n",
"\n",
"\u001b[32m======================================== \u001b[32m\u001b[1m6 passed\u001b[0m\u001b[32m in 0.04s\u001b[0m\u001b[32m =========================================\u001b[0m\n"
]
}
],
"source": [
"ipytest.run(\"-vk test_examples\")"
]
},
{
"cell_type": "markdown",
"id": "54500006-bd66-4b42-b8cc-93dea019f844",
"metadata": {},
"source": [
"Great, they all pass! Now let's use [Hypothesis](https://github.com/HypothesisWorks/hypothesis/) to check that every input allowed by our *specification* makes `closest()` return a single result.\n",
"\n",
"Here a single result means a 0-dimensional array, so for a 2d input we'd expect something like:\n",
"```pycon\n",
">>> x = xp.linspace(0, 8, 9).reshape((3, 3))\n",
">>> x\n",
"array([[0., 1., 2.],\n",
"       [3., 4., 5.],\n",
"       [6., 7., 8.]])\n",
">>> closest(x, 2.8)\n",
"array(3.)\n",
"```"
]
},
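{
"cell_type": "markdown",
"id": "sketch-numpy-scalar-vs-0d",
"metadata": {},
"source": [
"An aside on what a single result looks like with plain NumPy (a sketch added for illustration, not part of the original demo). Indexing a 1d NumPy array returns a NumPy scalar such as `np.float64` rather than a true 0d array. NumPy scalars also report `shape == ()`, so a `result.shape == ()` check accepts both. `np.asarray()` turns the scalar into an explicit 0d array.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"grid = np.linspace(0, 8, 9).reshape((3, 3))  # [[0., 1., 2.], [3., 4., 5.], [6., 7., 8.]]\n",
"flat = np.reshape(grid, (-1,))\n",
"i = int(np.argmin(np.abs(flat - 2.8)))  # flat index of the closest element\n",
"\n",
"single = flat[i]  # a NumPy scalar (np.float64), not an ndarray\n",
"zero_d = np.asarray(single)  # an explicit 0d array\n",
"\n",
"print(type(single).__name__, single.shape)  # float64 ()\n",
"print(type(zero_d).__name__, zero_d.shape, zero_d.ndim)  # ndarray () 0\n",
"```"
]
},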
{
"cell_type": "code",
"execution_count": 9,
"id": "a499cc41-b73d-459b-bd27-f4eb31b7126d",
"metadata": {},
"outputs": [],
"source": [
"from hypothesis import given, settings, strategies as st\n",
"from hypothesis_array import get_strategies_namespace\n",
"\n",
"xps = get_strategies_namespace(xp)\n",
"\n",
"@given(\n",
"    array=xps.arrays(dtype=\"int8\", shape=xps.array_shapes()),\n",
"    value=xps.from_dtype(\"int8\"),\n",
")\n",
"def test_0d_results(array, value):\n",
"    \"\"\"Result is a 0d array i.e. contains a single value\"\"\"\n",
"    result = closest(array, value)\n",
"    assert result.shape == ()"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "788edca2-cdb1-499b-bdb4-dcc709f1a8bf",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[31mF\u001b[0m\u001b[31m [100%]\u001b[0m\n",
"============================================= FAILURES =============================================\n",
"\u001b[31m\u001b[1m_________________________________________ test_0d_results __________________________________________\u001b[0m\n",
"\n",
" \u001b[37m@given\u001b[39;49;00m(\n",
"> array=xps.arrays(dtype=\u001b[33m\"\u001b[39;49;00m\u001b[33mint8\u001b[39;49;00m\u001b[33m\"\u001b[39;49;00m, shape=xps.array_shapes()),\n",
" value=xps.from_dtype(\u001b[33m\"\u001b[39;49;00m\u001b[33mint8\u001b[39;49;00m\u001b[33m\"\u001b[39;49;00m),\n",
" )\n",
"\n",
"\u001b[1m\u001b[31m/tmp/ipykernel_60417/2300622674.py\u001b[0m:7: \n",
"_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ \n",
"\n",
"array = array([[0]], dtype=int8), value = 0\n",
"\n",
" \u001b[37m@given\u001b[39;49;00m(\n",
" array=xps.arrays(dtype=\u001b[33m\"\u001b[39;49;00m\u001b[33mint8\u001b[39;49;00m\u001b[33m\"\u001b[39;49;00m, shape=xps.array_shapes()),\n",
" value=xps.from_dtype(\u001b[33m\"\u001b[39;49;00m\u001b[33mint8\u001b[39;49;00m\u001b[33m\"\u001b[39;49;00m),\n",
" )\n",
" \u001b[94mdef\u001b[39;49;00m \u001b[92mtest_0d_results\u001b[39;49;00m(array, value):\n",
" \u001b[33m\"\"\"Result is a 0d array i.e. contains a single value\"\"\"\u001b[39;49;00m\n",
" result = closest(array, value)\n",
"> \u001b[94massert\u001b[39;49;00m result.shape == ()\n",
"\u001b[1m\u001b[31mE assert (1,) == ()\u001b[0m\n",
"\u001b[1m\u001b[31mE Left contains one more item: 1\u001b[0m\n",
"\u001b[1m\u001b[31mE Full diff:\u001b[0m\n",
"\u001b[1m\u001b[31mE - ()\u001b[0m\n",
"\u001b[1m\u001b[31mE + (1,)\u001b[0m\n",
"\n",
"\u001b[1m\u001b[31m/tmp/ipykernel_60417/2300622674.py\u001b[0m:13: AssertionError\n",
"-------------------------------------------- Hypothesis --------------------------------------------\n",
"Falsifying example: test_0d_results(\n",
" array=array([[0]], dtype=int8), value=0,\n",
")\n",
"===================================== short test summary info ======================================\n",
"FAILED tmpvsayf0__.py::test_0d_results - assert (1,) == ()\n",
"\u001b[31m\u001b[31m\u001b[1m1 failed\u001b[0m, \u001b[33m6 deselected\u001b[0m\u001b[31m in 0.33s\u001b[0m\u001b[0m\n"
]
}
],
"source": [
"ipytest.run(\"-k test_0d_results\")"
]
},
{
"cell_type": "markdown",
"id": "708f32e2-c319-4706-b578-b55c4ee8dbb0",
"metadata": {},
"source": [
"Turns out our `closest()` function doesn't return a single, sensible result for multidimensional arrays. We forgot to test arrays with more than one dimension in our parametrized `test_examples()`, so it's a good thing our specification of shapes as `xps.array_shapes()` caught this."
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "b6236868-c261-4f4f-93ae-073ea65ee9e0",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([6., 7., 8.])"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x = xp.linspace(0, 8, 9).reshape((3, 3)) # [[0, 1, 2], [3, 4, 5], [6, 7, 8]]\n",
"closest(x, 1.8)"
]
},
{
"cell_type": "markdown",
"id": "c381f542-0c04-48f7-842e-22860d41314b",
"metadata": {},
"source": [
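"The `xp.argmin()` we use returns an index `i` into the *flattened* (1d) array. Indexing a multidimensional `x` with that single integer selects along the first axis instead. We get back a whole row (as above), or an `IndexError` when `i` exceeds the size of the first axis. This is why we also sometimes get an out-of-bounds error in our Hypothesis-generated test cases."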
]
},
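{
"cell_type": "markdown",
"id": "sketch-flat-index",
"metadata": {},
"source": [
"To make the flat-index behaviour concrete, here is a NumPy-only sketch (added for illustration) using the same 3x3 array. `np.unravel_index()` is a NumPy helper that converts a flat index back into per-axis coordinates. I'm not assuming the Array API offers an equivalent, which is why flattening `x` inside `closest()` looks like the simpler fix.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"grid = np.linspace(0, 8, 9).reshape((3, 3))  # [[0, 1, 2], [3, 4, 5], [6, 7, 8]]\n",
"\n",
"i = int(np.argmin(np.abs(grid - 1.8)))  # 2: an index into the flattened array\n",
"row = grid[i]  # one integer index on a 2d array selects a whole row: [6., 7., 8.]\n",
"elem = np.reshape(grid, (-1,))[i]  # flattening first selects the element itself: 2.0\n",
"coords = np.unravel_index(i, grid.shape)  # NumPy-specific: flat index -> (row, col), here (0, 2)\n",
"\n",
"j = int(np.argmin(np.abs(grid - 5.2)))  # 5, but grid only has 3 rows, so grid[j] raises IndexError\n",
"```"
]
},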
{
"cell_type": "code",
"execution_count": 12,
"id": "ebc67a60-c559-4933-8fe6-fc69b0127a40",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ERROR: index 5 is out of bounds for axis 0 with size 3\n"
]
}
],
"source": [
"x = xp.linspace(0, 8, 9).reshape((3, 3))\n",
"try:\n",
"    closest(x, 5.2)\n",
"except IndexError as e:\n",
"    print(f\"ERROR: {e}\")"
]
},
{
"cell_type": "markdown",
"id": "26535c90-48cc-49cf-9c71-97d85e702d67",
"metadata": {},
"source": [
"So let's fix both of these bugs that Hypothesis found by flattening `x` first in our `closest()` function."
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "e4529f79-dbc5-416a-8f8c-20cf570ee7ac",
"metadata": {},
"outputs": [],
"source": [
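"def closest(x, val):\n",
"    \"\"\"Returns the element in array that is closest to the passed value\"\"\"\n",
"    flat_x = xp.reshape(x, (-1,))  # flatten so argmin's index points into flat_x\n",
"    diff_x = abs(flat_x - val)\n",
"    i = xp.argmin(diff_x)\n",
"    return flat_x[i]"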
]
},
{
"cell_type": "markdown",
"id": "eb68130e-b7a5-422b-98d2-1ae64ed59d49",
"metadata": {},
"source": [
"To make sure it's fixed, we'll run the test again."
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "b1800afc-3ffc-4ecb-b192-b716a6b3a570",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[1m======================================= test session starts ========================================\u001b[0m\n",
"platform linux -- Python 3.8.10, pytest-6.2.4, py-1.10.0, pluggy-0.13.1\n",
"rootdir: /home/honno/gdrive/GitHub/hypothesis-array-api/fluff/notebooks\n",
"plugins: hypothesis-6.14.6, anyio-3.3.0\n",
"collected 7 items / 6 deselected / 1 selected\n",
"\n",
"tmpvjvpdqqp.py \u001b[32m.\u001b[0m\u001b[32m [100%]\u001b[0m\n",
"\n",
"\u001b[32m================================= \u001b[32m\u001b[1m1 passed\u001b[0m, \u001b[33m6 deselected\u001b[0m\u001b[32m in 0.22s\u001b[0m\u001b[32m ==================================\u001b[0m\n"
]
}
],
"source": [
"ipytest.run(\"-vk 0d_results\")"
]
},
{
"cell_type": "markdown",
"id": "1d57490c-b5a3-4ee5-8f82-898362096de7",
"metadata": {},
"source": [
"How about other expected behaviour? We should test for every edge case!\n",
"\n",
"For example, infinities should work just fine in `closest()`, whether they appear in the array `x` or as the value `val`."
]
},
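{
"cell_type": "markdown",
"id": "sketch-inf-distances",
"metadata": {},
"source": [
"A NumPy-only sketch (added for illustration) of what the distance step in `closest()` does with infinities under IEEE 754 arithmetic. A finite value minus `inf` is infinite, but `inf - inf` is NaN. NumPy's `argmin()` returns the index of the first NaN it meets. So when `val` and an element of `x` are both `+inf`, that NaN distance wins and `closest()` returns the `inf` element. This relies on NumPy's NaN handling; I'm not assuming other Array API libraries behave the same way.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"vals = np.array([1.0, np.inf, -3.0], dtype=np.float32)\n",
"\n",
"# val = inf: finite elements are infinitely far away, and inf - inf is NaN\n",
"with np.errstate(invalid='ignore'):  # silence the 'invalid value' RuntimeWarning\n",
"    dist_to_inf = np.abs(vals - np.inf)  # [inf, nan, inf]\n",
"i = int(np.argmin(dist_to_inf))  # 1: NumPy's argmin returns the first NaN\n",
"\n",
"# all-inf array with a finite val: every distance is inf, so the tie goes to index 0\n",
"all_inf = np.full(4, np.inf, dtype=np.float32)\n",
"j = int(np.argmin(np.abs(all_inf - 42)))  # 0\n",
"```"
]
},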
{
"cell_type": "code",
"execution_count": 15,
"id": "83090f59-adec-4db1-8b44-20d778406558",
"metadata": {},
"outputs": [],
"source": [
"@given(xps.arrays(dtype=\"float32\", shape=10))\n",
"def test_inf_value(array):\n",
"    \"\"\"No errors when val=inf\"\"\"\n",
"    closest(array, float(\"inf\"))\n",
"\n",
"def test_inf_array():\n",
"    \"\"\"No errors when array is just infs\"\"\"\n",
"    x = xp.full(10, float(\"inf\"))\n",
"    closest(x, 42)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "f31db740-9434-4b9a-a30c-88854162b3bd",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[1m======================================= test session starts ========================================\u001b[0m\n",
"platform linux -- Python 3.8.10, pytest-6.2.4, py-1.10.0, pluggy-0.13.1\n",
"rootdir: /home/honno/gdrive/GitHub/hypothesis-array-api/fluff/notebooks\n",
"plugins: hypothesis-6.14.6, anyio-3.3.0\n",
"collected 9 items / 7 deselected / 2 selected\n",
"\n",
"tmpbhkvhebt.py \u001b[32m.\u001b[0m\u001b[32m.\u001b[0m\u001b[32m [100%]\u001b[0m\n",
"\n",
"\u001b[32m================================= \u001b[32m\u001b[1m2 passed\u001b[0m, \u001b[33m7 deselected\u001b[0m\u001b[32m in 0.21s\u001b[0m\u001b[32m ==================================\u001b[0m\n"
]
}
],
"source": [
"ipytest.run(\"-vk inf\")"
]
},
{
"cell_type": "markdown",
"id": "e8206ce1-c695-4897-881c-fd03aecc7c3d",
"metadata": {},
"source": [
"What about NaN behaviour? I think `closest()` should raise an error if the array `x` is just full of NaNs... and we can test for raised errors too!"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "f2c2d88e-4dca-4f66-8d38-9d723740fb2f",
"metadata": {},
"outputs": [],
"source": [
"def test_nan_array_raises():\n",
"    \"\"\"Error raised when array is just nans\"\"\"\n",
"    x = xp.full(10, float(\"nan\"))\n",
"    with pytest.raises(ValueError):\n",
"        closest(x, 42)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "2b1a0efa-1260-4452-a704-cb35facb8574",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[31mF\u001b[0m\u001b[31m [100%]\u001b[0m\n",
"============================================= FAILURES =============================================\n",
"\u001b[31m\u001b[1m______________________________________ test_nan_array_raises _______________________________________\u001b[0m\n",
"\n",
" \u001b[94mdef\u001b[39;49;00m \u001b[92mtest_nan_array_raises\u001b[39;49;00m():\n",
" \u001b[33m\"\"\"Error raised when array is just nans\"\"\"\u001b[39;49;00m\n",
" x = xp.full(\u001b[94m10\u001b[39;49;00m, \u001b[96mfloat\u001b[39;49;00m(\u001b[33m\"\u001b[39;49;00m\u001b[33mnan\u001b[39;49;00m\u001b[33m\"\u001b[39;49;00m))\n",
" \u001b[94mwith\u001b[39;49;00m pytest.raises(\u001b[96mValueError\u001b[39;49;00m):\n",
"> closest(x, \u001b[94m42\u001b[39;49;00m)\n",
"\u001b[1m\u001b[31mE Failed: DID NOT RAISE <class 'ValueError'>\u001b[0m\n",
"\n",
"\u001b[1m\u001b[31m/tmp/ipykernel_60417/556185350.py\u001b[0m:5: Failed\n",
"===================================== short test summary info ======================================\n",
"FAILED tmpvadh1c8e.py::test_nan_array_raises - Failed: DID NOT RAISE <class 'ValueError'>\n",
"\u001b[31m\u001b[31m\u001b[1m1 failed\u001b[0m, \u001b[33m9 deselected\u001b[0m\u001b[31m in 0.02s\u001b[0m\u001b[0m\n"
]
}
],
"source": [
"ipytest.run(\"-k nan\")"
]
},
{
"cell_type": "markdown",
"id": "38b2d54d-d9b6-4b3a-99a1-b5b6fd978060",
"metadata": {},
"source": [
"Nope: apparently `closest()` returns something even when `x` is just a bunch of NaNs."
]
},
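{
"cell_type": "markdown",
"id": "sketch-nan-argmin",
"metadata": {},
"source": [
"Why? Here is a NumPy-only sketch (added for illustration). Every distance to a NaN is itself NaN, and NumPy's `argmin()` returns the index of the first NaN instead of skipping it. So an all-NaN array yields index 0. An array with only *some* NaNs, a case these tests don't cover, also makes `closest()` return NaN. NumPy's `np.nanargmin()` skips NaNs and raises `ValueError` for an all-NaN input, but it is a NumPy extra that I'm not assuming the Array API provides.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"all_nan = np.full(10, np.nan)\n",
"some_nan = np.array([40.0, np.nan, 41.0])\n",
"\n",
"i_all = int(np.argmin(np.abs(all_nan - 42)))  # 0: every distance is NaN\n",
"i_some = int(np.argmin(np.abs(some_nan - 42)))  # 1: the NaN beats the real distances 2.0 and 1.0\n",
"i_skip = int(np.nanargmin(np.abs(some_nan - 42)))  # 2: NaNs ignored, so 41.0 is closest\n",
"\n",
"try:\n",
"    np.nanargmin(np.abs(all_nan - 42))\n",
"except ValueError as err:  # nanargmin refuses an all-NaN input\n",
"    print('all-NaN input rejected:', err)\n",
"```"
]
},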
{
"cell_type": "code",
"execution_count": 19,
"id": "40c8ed29-fb1e-4802-8eb3-5068d38f3833",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"nan"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x = xp.full(10, xp.nan)\n",
"closest(x, 42)"
]
},
{
"cell_type": "markdown",
"id": "39174420-ece4-446f-a16f-e635b79c6c72",
"metadata": {},
"source": [
"Personally I don't like this behaviour—how can the number 42 be at any distance to a NaN?! Let's make sure we tell the user they're doing something nonsensical by raising an error."
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "757cc3b3-e081-4b1f-a6b0-5fb11293b55e",
"metadata": {},
"outputs": [],
"source": [
"def closest(x, val):\n",
"    \"\"\"Returns the element in array that is closest to the passed value\"\"\"\n",
"    if xp.all(xp.isnan(x)):\n",
"        raise ValueError(\"no closest value when array is just nans\")\n",
"    flat_x = xp.reshape(x, (-1,))\n",
"    diff_x = abs(flat_x - val)\n",
"    i = xp.argmin(diff_x)\n",
"    return flat_x[i]"
]
},
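{
"cell_type": "markdown",
"id": "sketch-closest-ignoring-nan",
"metadata": {},
"source": [
"The guard above only rejects arrays that are *entirely* NaN. With a partly-NaN array, `closest()` can still return NaN. If skipping NaN elements is preferable, here is one possible variant, sketched under assumptions. The name `closest_ignoring_nan()` is hypothetical and not part of the original demo. It drops NaNs with a boolean mask before calling `argmin()`, using only functions NumPy shares with the Array API (`reshape`, `isnan`, `all`, `logical_not`, `abs`, `argmin`). Boolean-mask indexing gives a data-dependent output shape, which I understand not every Array API library supports.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"xp = np  # stand-in for an Array API namespace, as elsewhere in this notebook\n",
"\n",
"def closest_ignoring_nan(x, val):\n",
"    # hypothetical variant of closest() that skips NaN elements\n",
"    flat_x = xp.reshape(x, (-1,))\n",
"    nan_mask = xp.isnan(flat_x)\n",
"    if xp.all(nan_mask):\n",
"        raise ValueError('no closest value when array is just nans')\n",
"    candidates = flat_x[xp.logical_not(nan_mask)]  # boolean-mask indexing drops the NaNs\n",
"    i = xp.argmin(xp.abs(candidates - val))\n",
"    return candidates[i]\n",
"\n",
"print(closest_ignoring_nan(xp.asarray([40.0, float('nan'), 41.0]), 42))  # 41.0\n",
"```"
]
},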
{
"cell_type": "markdown",
"id": "fd2f3428-0694-4089-92a8-9d40f2fd8606",
"metadata": {},
"source": [
"As before, we can check that our changes work by running the previously failing test again."
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "0daea132-e448-41fc-a65a-0e52bae08245",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[1m======================================= test session starts ========================================\u001b[0m\n",
"platform linux -- Python 3.8.10, pytest-6.2.4, py-1.10.0, pluggy-0.13.1\n",
"rootdir: /home/honno/gdrive/GitHub/hypothesis-array-api/fluff/notebooks\n",
"plugins: hypothesis-6.14.6, anyio-3.3.0\n",
"collected 10 items / 9 deselected / 1 selected\n",
"\n",
"tmpo543bxzy.py \u001b[32m.\u001b[0m\u001b[32m [100%]\u001b[0m\n",
"\n",
"\u001b[32m================================= \u001b[32m\u001b[1m1 passed\u001b[0m, \u001b[33m9 deselected\u001b[0m\u001b[32m in 0.01s\u001b[0m\u001b[32m ==================================\u001b[0m\n"
]
}
],
"source": [
"ipytest.run(\"-vk nan\")"
]
},
{
"cell_type": "markdown",
"id": "1da8078b-72a7-4b1c-97fe-2cddf295077c",
"metadata": {},
"source": [
"---\n",
"\n",
"**Ignore everything below!** I wanted to demonstrate Torch too, but have given up for now: I need to monkeypatch it to make it compatible enough with the Array API to work with my Hypothesis strategies."
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "12b854be-a0cd-4a56-b857-143cb00a692d",
"metadata": {},
"outputs": [
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m/tmp/ipykernel_60417/2006531705.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mKeyboardInterrupt\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"raise KeyboardInterrupt"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0e75af65-448e-4573-834c-9c41d1799d99",
"metadata": {},
"outputs": [],
"source": [
"%%capture\n",
"!pip install torch"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "75025712-018c-4727-96d7-672f18b429d4",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"torch_xp = torch\n",
"\n",
"torch_empty = torch.empty  # keep a reference to the original before monkeypatching\n",
"def torch_xp_empty(shape, dtype=None):\n",
"    if isinstance(dtype, str):\n",
"        dtype = getattr(torch, dtype)\n",
"    if isinstance(shape, int):\n",
"        return torch_empty(shape, dtype=dtype)\n",
"    else:\n",
"        return torch_empty(*shape, dtype=dtype)\n",
"torch_xp.empty = torch_xp_empty\n",
"\n",
"class XPTensor(torch.Tensor):\n",
"    def __getitem__(self, k):\n",
"        if self.ndim == 0 and isinstance(k, slice):\n",
"            return self[0]\n",
"        else:\n",
"            return super().__getitem__(k)\n",
"    def __setitem__(self, k, v):\n",
"        if self.ndim == 0 and isinstance(k, slice):\n",
"            self[0] = v\n",
"        else:\n",
"            super().__setitem__(k, v)\n",
"torch_xp.Tensor = XPTensor\n",
"\n",
"# distinguish previous Array API module as Numpy's\n",
"numpy_xp = xp"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a5d3a25b-74be-444f-a9bd-be6be3e8c565",
"metadata": {},
"outputs": [],
"source": [
"def closest(xp, x, val):\n",
"    \"\"\"Returns the element in array that is closest to the passed value\"\"\"\n",
"    if xp.all(xp.isnan(x)):\n",
"        raise ValueError(\"no closest value when array is just nans\")\n",
"    flat_x = xp.reshape(x, (-1,))\n",
"    diff_x = abs(flat_x - val)\n",
"    i = xp.argmin(diff_x)\n",
"    return flat_x[i]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "083c98e0-7f70-4e2c-a9a1-0ac668e85ccc",
"metadata": {},
"outputs": [],
"source": [
"%%ipytest\n",
"\n",
"@pytest.mark.parametrize(\"xp\", [numpy_xp, torch_xp])\n",
"@pytest.mark.parametrize(\n",
"    \"func_name, func_args, value, result\",\n",
"    [\n",
"        (\"linspace\", (0, 4, 5), 1, 1),  # [0, 1, 2, 3, 4]\n",
"        (\"linspace\", (0, 4, 5), 2.4, 2),  # [0, 1, 2, 3, 4]\n",
"        (\"linspace\", (-5, 5, 5), 2.6, 2.5),  # [-5, -2.5, 0, 2.5, 5]\n",
"        (\"linspace\", (-5, 5, 5), -2, -2.5),  # [-5, -2.5, 0, 2.5, 5]\n",
"        (\"ones\", [5], 1.2, 1),  # [1, 1, 1, 1, 1]\n",
"        (\"asarray\", [[2, 3, 1, 4, 0]], 1.2, 1),\n",
"    ]\n",
")\n",
"def test_examples(xp, func_name, func_args, value, result):\n",
"    array = getattr(xp, func_name)(*func_args)\n",
"    assert closest(xp, array, value) == result\n",
"\n",
"import hypothesis_array as naive_xps\n",
"\n",
"@pytest.mark.parametrize(\"xp\", [numpy_xp, torch_xp])\n",
"@given(st.data())\n",
"def test_0d_results(xp, data):\n",
"    \"\"\"Result is a 0d array i.e. contains a single value\"\"\"\n",
"    array = data.draw(\n",
"        naive_xps.arrays(xp=xp, dtype=\"int8\", shape=xps.array_shapes())\n",
"    )\n",
"    value = data.draw(\n",
"        naive_xps.from_dtype(xp=xp, dtype=\"int8\")\n",
"    )\n",
"    result = closest(xp, array, value)\n",
"    assert result.shape == ()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}