{
"cells": [
{
"cell_type": "markdown",
"id": "6a4d5159-a519-45db-8198-e4a30bab649c",
"metadata": {},
"source": [
"Over the summer,\n",
"I've been interning at Quansight Labs\n",
"to develop testing tools\n",
"for the developers and users\n",
"of the upcoming [Array API standard](https://data-apis.org/array-api/latest/).\n",
"Specifically,\n",
"I contributed \"strategies\"\n",
"to the testing library [Hypothesis](https://github.com/HypothesisWorks/hypothesis/),\n",
"which I'm excited to announce\n",
"are now available in [`hypothesis.extra.array_api`](https://hypothesis.readthedocs.io/en/latest/numpy.html#array-api).\n",
"Check out the primary [pull request](https://github.com/HypothesisWorks/hypothesis/pull/3065) I made\n",
"for more background.\n",
"\n",
"This blog post is for anyone developing array-consuming methods\n",
"(think SciPy and scikit-learn)\n",
"and is new to property-based testing.\n",
"I demonstrate a typical workflow\n",
"of testing with Hypothesis\n",
"whilst writing an array-consuming function that works for *all* [libraries adopting the Array API](https://data-apis.org/array-api/latest/purpose_and_scope.html#stakeholders),\n",
"catching bugs before your users do.\n",
"\n",
"<!-- TEASER_END -->"
]
},
{
"cell_type": "markdown",
"id": "654396e0",
"metadata": {},
"source": [
"## Before we begin\n",
"\n",
"Hypothesis shipped with its Array API strategies in [version 6.21](https://hypothesis.readthedocs.io/en/latest/changes.html#v6-21-0).\n",
"We also need to use NumPy >= 1.22\n",
"so that we can test with its\n",
"[recently merged](https://github.com/numpy/numpy/pull/18585)\n",
"Array API implementation—this hasn't been released just yet,\n",
"so I would recommend installing a [nightly build](https://anaconda.org/scipy-wheels-nightly/numpy).\n",
"\n",
"I will be using\n",
"the excellent [ipytest](https://github.com/chmp/ipytest/) extension\n",
"to nicely run tests in Jupyter\n",
"as if we were using [pytest](https://github.com/pytest-dev/pytest/) proper.\n",
"For pretty printing I use the superb [Rich](https://github.com/willmcgugan/rich) library,\n",
"where I simply override Python's builtin `print`\n",
"with [`rich.print`](https://rich.readthedocs.io/en/stable/reference/init.html#rich.print).\n",
"I also suppress all warnings\n",
"for convenience's sake."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "c4a3cb4a-c6db-442b-9507-bedd03aabbe0",
"metadata": {},
"outputs": [],
"source": [
"%%capture\n",
"!pip install hypothesis>=6.21\n",
"!pip install -i https://pypi.anaconda.org/scipy-wheels-nightly/simple numpy"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "43dd3731-a8d2-47b3-b88e-10eddc645f34",
"metadata": {},
"outputs": [],
"source": [
"%%capture\n",
"!pip install ipytest\n",
"import ipytest; ipytest.autoconfig(display_columns=80)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "4c7c2561-0671-4763-b5e0-935a1957b68b",
"metadata": {},
"outputs": [],
"source": [
"%%capture\n",
"!pip install rich\n",
"from rich import print"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "2f5d822b",
"metadata": {},
"outputs": [],
"source": [
"import warnings; warnings.filterwarnings(\"ignore\")"
]
},
{
"cell_type": "markdown",
"id": "06545d86-904a-44ca-83c5-d5304a251d55",
"metadata": {},
"source": [
"## What the Array API enables\n",
"\n",
"The [API](https://data-apis.org/array-api/latest/) standardises functionality of array libraries,\n",
"which has [numerous benefits](https://data-apis.org/array-api/latest/use_cases.html)\n",
"for both developers and users.\n",
"I recommend reading the [Data APIs announcement post](https://data-apis.org/blog/announcing_the_consortium/)\n",
"to get a better idea of how the API is being shaped,\n",
"but for our purposes it works an awful lot like NumPy.\n",
"\n",
"The most exciting prospect for me\n",
"is being able to easily write an array-consuming method\n",
"that works with all the adopting libraries.\n",
"Let's try writing this method\n",
"to calculate the cumulative sums of an array:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "893afcdb-66cb-4d54-b9b0-071c3f411e82",
"metadata": {},
"outputs": [],
"source": [
"def cumulative_sums(x):\n",
" \"\"\"Return the cumulative sums of the elements of the input.\"\"\"\n",
" xp = x.__array_namespace__()\n",
" \n",
" result = xp.empty(x.size, dtype=x.dtype)\n",
" result[0] = x[0]\n",
" for i in range(1, x.size):\n",
" result[i] = result[i - 1] + x[i]\n",
" \n",
" return result"
]
},
{
"cell_type": "markdown",
"id": "ed669b65",
"metadata": {},
"source": [
"The all-important\n",
"[`__array_namespace__()`](https://data-apis.org/array-api/latest/API_specification/array_object.html#method-array-namespace) method\n",
"allows array-consuming methods to get the array's respective Array API module.\n",
"Conventionally we assign it to the variable `xp`.\n",
"\n",
"From there you just need to rely on the guarantees of the Array API\n",
"to support NumPy, TensorFlow, PyTorch, CuPy, etc.\n",
"all in one simple method!"
]
},
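{
"cell_type": "markdown",
"id": "b3e1a7c2",
"metadata": {},
"source": [
"Here's a rough REPL sketch of what that namespace lookup gives us\n",
"with NumPy's implementation\n",
"(which we'll import properly in the next section);\n",
"the exact repr and default dtype may differ on your platform:\n",
"\n",
"```python\n",
">>> from numpy import array_api as nxp\n",
">>> x = nxp.asarray([0, 1, 2, 3, 4])\n",
">>> x.__array_namespace__() is nxp\n",
"True\n",
">>> cumulative_sums(x)\n",
"Array([ 0,  1,  3,  6, 10], dtype=int64)\n",
"```"
]
},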
{
"cell_type": "markdown",
"id": "6e824f0f",
"metadata": {},
"source": [
"## Good ol' unit tests\n",
"\n",
"I hope you'd want write some tests at some point 😉\n",
"\n",
"We can import NumPy's Array API implementation\n",
"and test with that for now,\n",
"although in the future it'd be a good idea to try other implementations\n",
"(see [related Hypothesis issue](https://github.com/HypothesisWorks/hypothesis/issues/3085)).\n",
"We don't `import numpy as np`,\n",
"but instead import NumPy's new module `numpy.array_api`,\n",
"which exists to comply with the Array API standard\n",
"where `numpy` proper can not\n",
"(namely so NumPy can keep backwards compatibility)."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "cb1e47fe-863a-4ff6-913a-572da5985d04",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[32m.\u001b[0m\u001b[32m [100%]\u001b[0m\n",
"\u001b[32m\u001b[32m\u001b[1m1 passed\u001b[0m\u001b[32m in 0.02s\u001b[0m\u001b[0m\n"
]
}
],
"source": [
"from numpy import array_api as nxp\n",
"\n",
"def test_cumulative_sums():\n",
" x = nxp.asarray([0, 1, 2, 3, 4])\n",
" assert nxp.all(cumulative_sums(x) == nxp.asarray([0, 1, 3, 6, 10]))\n",
" \n",
"ipytest.run()"
]
},
{
"cell_type": "markdown",
"id": "098c6b37",
"metadata": {},
"source": [
"I would probably write a [parametrized](https://docs.pytest.org/en/stable/parametrize.html) test here\n",
"and write cases to cover all the interesting scenarios I can think of.\n",
"Whatever we do,\n",
"we will definitely miss some edge cases.\n",
"What if we could catch bugs\n",
"we would never think of ourselves?"
]
},
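{
"cell_type": "markdown",
"id": "9c0d4e2f",
"metadata": {},
"source": [
"For instance, a sketch of such a parametrized test\n",
"(the cases here are just illustrative ones I picked by hand):\n",
"\n",
"```python\n",
"import pytest\n",
"\n",
"@pytest.mark.parametrize(\n",
"    \"x, expected\",\n",
"    [\n",
"        (nxp.asarray([0, 1, 2, 3, 4]), nxp.asarray([0, 1, 3, 6, 10])),\n",
"        (nxp.asarray([1]), nxp.asarray([1])),\n",
"        (nxp.asarray([-1, 1, -1]), nxp.asarray([-1, 0, -1])),\n",
"    ],\n",
")\n",
"def test_cumulative_sums_examples(x, expected):\n",
"    assert nxp.all(cumulative_sums(x) == expected)\n",
"```"
]
},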
{
"cell_type": "markdown",
"id": "fd6530d8-3072-464e-8322-a51a877f5b70",
"metadata": {},
"source": [
"## Testing our assumptions with Hypothesis\n",
"\n",
"<!-- I would put this in a quote block, but lists look bad with the blog's style -->\n",
"\n",
"Hypothesis is a property-based testing library. To lift from their excellent [docs](https://hypothesis.readthedocs.io/en/latest/index.html),\n",
"think of a normal unit test as being something like the following:\n",
"1. Set up some data.\n",
"2. Perform some operations on the data.\n",
"3. Assert something about the result.\n",
"\n",
"Hypothesis lets you write tests which instead look like this:\n",
"1. For all data matching some specification.\n",
"2. Perform some operations on the data.\n",
"3. Assert something about the result.\n",
"\n",
"You almost certainly will find new bugs with Hypothesis\n",
"thanks to how it cleverly fuzzes your specifications,\n",
"but the package really shines in how it [\"reduces\" failing test cases](https://drmaciver.github.io/papers/reduction-via-generation-preview.pdf)\n",
"to present only the minimal reproducers that trigger said bugs.\n",
"This demo will showcase both its power and user-friendliness.\n",
"\n",
"Let's try testing a simple assumption that we can make about our `cumulative_sums()` method:\n",
"\n",
"> For an array with positive elements,\n",
"> its cumulative sums should only increment or remain the same per step.\n",
"\n",
"<!--Formally we might express this assumption as $\\forall i \\in \\{1,\\ldots,\\vert x \\vert \\}.f(x)_i - f(x)_{i-1} \\geq 0$.-->\n",
"<!--Formally you might specify this assumption as\n",
"\"if $A$ is a $n$-lengthed ordered set\n",
"containing values $v$ that satisfy $v\\in\\mathbb{R}$ and $v\\geq0$,\n",
"for the cumulative sums function $f$ defined as $f(A)_j = \\sum_{i=1}^j A_i$,\n",
"when $j > 1$ the following is always true: $f(A)_j \\geq f(A)_{j-1}$.\"-->\n",
"\n",
"We can write a simple enough Hypothesis-powered test method for this:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "e16ec818-c666-4951-bf4b-5c49d226e417",
"metadata": {},
"outputs": [],
"source": [
"from hypothesis import given\n",
"from hypothesis.extra.array_api import make_strategies_namespace\n",
"\n",
"xps = make_strategies_namespace(nxp)\n",
"\n",
"@given(xps.arrays(dtype=\"uint8\", shape=10))\n",
"def test_positive_arrays_have_incrementing_sums(x):\n",
" a = cumulative_sums(x)\n",
" assert nxp.all(a[1:] >= a[:-1])"
]
},
{
"cell_type": "markdown",
"id": "8a20a10e-f7ec-48ac-bafb-85e51f570263",
"metadata": {},
"source": [
"As the Array API tools provided by Hypothesis\n",
"are agnostic to the adopting array/tensor libraries,\n",
"we first need to bind an implementation\n",
"via [`make_strategies_namespace()`](https://hypothesis.readthedocs.io/en/latest/numpy.html#hypothesis.extra.array_api.make_strategies_namespace).\n",
"Passing `numpy.array_api` will give us\n",
"a [`SimpleNamespace`](https://docs.python.org/3/library/types.html#types.SimpleNamespace)\n",
"to use these tools for NumPy's Array API implementation.\n",
"\n",
"The [`@given()`](https://hypothesis.readthedocs.io/en/latest/details.html#hypothesis.given) decorator\n",
"tells Hypothesis what values it should generate for our test method.\n",
"In this case\n",
"[`xps.arrays()`](https://hypothesis.readthedocs.io/en/latest/numpy.html#xps.arrays) is a \"search strategy\" \n",
"that specifies Array API-compliant arrays from `numpy.array_api`\n",
"should be generated.\n",
"\n",
"In this case,\n",
"`shape=10` specifies the arrays generated are 1-dimensional and of size 10,\n",
"and `dtype=\"uint8\"` specifies they should contain unsigned integers\n",
"(which is handy for our test method as uints are always positive).\n",
"Let's quickly see a small sample of the arrays Hypothesis can generate:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "d60a6f47-4ef0-41dc-ab17-8506d5203937",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">Array</span><span style=\"font-weight: bold\">([</span><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">239</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">211</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">226</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">129</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">31</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">13</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">80</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">235</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">254</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">163</span><span style=\"font-weight: bold\">]</span>, <span style=\"color: #808000; text-decoration-color: #808000\">dtype</span>=<span style=\"color: #800080; text-decoration-color: #800080\">uint8</span><span style=\"font-weight: bold\">)</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1;35mArray\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m239\u001b[0m, \u001b[1;36m211\u001b[0m, \u001b[1;36m226\u001b[0m, \u001b[1;36m129\u001b[0m, \u001b[1;36m31\u001b[0m, \u001b[1;36m13\u001b[0m, \u001b[1;36m80\u001b[0m, \u001b[1;36m235\u001b[0m, \u001b[1;36m254\u001b[0m, \u001b[1;36m163\u001b[0m\u001b[1m]\u001b[0m, \u001b[33mdtype\u001b[0m=\u001b[35muint8\u001b[0m\u001b[1m)\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">Array</span><span style=\"font-weight: bold\">([</span><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">164</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">175</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">254</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">111</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">63</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">241</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">64</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">201</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">173</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">117</span><span style=\"font-weight: bold\">]</span>, <span style=\"color: #808000; text-decoration-color: #808000\">dtype</span>=<span style=\"color: #800080; text-decoration-color: #800080\">uint8</span><span style=\"font-weight: bold\">)</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1;35mArray\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m164\u001b[0m, \u001b[1;36m175\u001b[0m, \u001b[1;36m254\u001b[0m, \u001b[1;36m111\u001b[0m, \u001b[1;36m63\u001b[0m, \u001b[1;36m241\u001b[0m, \u001b[1;36m64\u001b[0m, \u001b[1;36m201\u001b[0m, \u001b[1;36m173\u001b[0m, \u001b[1;36m117\u001b[0m\u001b[1m]\u001b[0m, \u001b[33mdtype\u001b[0m=\u001b[35muint8\u001b[0m\u001b[1m)\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">Array</span><span style=\"font-weight: bold\">([</span><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">106</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">149</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">210</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">230</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">58</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">37</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">66</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">153</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">203</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">181</span><span style=\"font-weight: bold\">]</span>, <span style=\"color: #808000; text-decoration-color: #808000\">dtype</span>=<span style=\"color: #800080; text-decoration-color: #800080\">uint8</span><span style=\"font-weight: bold\">)</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1;35mArray\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m106\u001b[0m, \u001b[1;36m149\u001b[0m, \u001b[1;36m210\u001b[0m, \u001b[1;36m230\u001b[0m, \u001b[1;36m58\u001b[0m, \u001b[1;36m37\u001b[0m, \u001b[1;36m66\u001b[0m, \u001b[1;36m153\u001b[0m, \u001b[1;36m203\u001b[0m, \u001b[1;36m181\u001b[0m\u001b[1m]\u001b[0m, \u001b[33mdtype\u001b[0m=\u001b[35muint8\u001b[0m\u001b[1m)\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">Array</span><span style=\"font-weight: bold\">([</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">93</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">254</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">253</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">252</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">251</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">250</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">249</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">248</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">247</span><span style=\"font-weight: bold\">]</span>, <span style=\"color: #808000; text-decoration-color: #808000\">dtype</span>=<span style=\"color: #800080; text-decoration-color: #800080\">uint8</span><span style=\"font-weight: bold\">)</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1;35mArray\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m \u001b[1;36m93\u001b[0m, \u001b[1;36m0\u001b[0m, \u001b[1;36m254\u001b[0m, \u001b[1;36m253\u001b[0m, \u001b[1;36m252\u001b[0m, \u001b[1;36m251\u001b[0m, \u001b[1;36m250\u001b[0m, \u001b[1;36m249\u001b[0m, \u001b[1;36m248\u001b[0m, \u001b[1;36m247\u001b[0m\u001b[1m]\u001b[0m, \u001b[33mdtype\u001b[0m=\u001b[35muint8\u001b[0m\u001b[1m)\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">Array</span><span style=\"font-weight: bold\">([</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">16</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">254</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">253</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">252</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">251</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">250</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">249</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">248</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">247</span><span style=\"font-weight: bold\">]</span>, <span style=\"color: #808000; text-decoration-color: #808000\">dtype</span>=<span style=\"color: #800080; text-decoration-color: #800080\">uint8</span><span style=\"font-weight: bold\">)</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1;35mArray\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m \u001b[1;36m16\u001b[0m, \u001b[1;36m0\u001b[0m, \u001b[1;36m254\u001b[0m, \u001b[1;36m253\u001b[0m, \u001b[1;36m252\u001b[0m, \u001b[1;36m251\u001b[0m, \u001b[1;36m250\u001b[0m, \u001b[1;36m249\u001b[0m, \u001b[1;36m248\u001b[0m, \u001b[1;36m247\u001b[0m\u001b[1m]\u001b[0m, \u001b[33mdtype\u001b[0m=\u001b[35muint8\u001b[0m\u001b[1m)\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">Array</span><span style=\"font-weight: bold\">([</span><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">172</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">254</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">253</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">252</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">251</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">250</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">249</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">248</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">247</span><span style=\"font-weight: bold\">]</span>, <span style=\"color: #808000; text-decoration-color: #808000\">dtype</span>=<span style=\"color: #800080; text-decoration-color: #800080\">uint8</span><span style=\"font-weight: bold\">)</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1;35mArray\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m172\u001b[0m, \u001b[1;36m0\u001b[0m, \u001b[1;36m254\u001b[0m, \u001b[1;36m253\u001b[0m, \u001b[1;36m252\u001b[0m, \u001b[1;36m251\u001b[0m, \u001b[1;36m250\u001b[0m, \u001b[1;36m249\u001b[0m, \u001b[1;36m248\u001b[0m, \u001b[1;36m247\u001b[0m\u001b[1m]\u001b[0m, \u001b[33mdtype\u001b[0m=\u001b[35muint8\u001b[0m\u001b[1m)\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">Array</span><span style=\"font-weight: bold\">([</span><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">129</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">254</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">253</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">252</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">251</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">250</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">249</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">248</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">247</span><span style=\"font-weight: bold\">]</span>, <span style=\"color: #808000; text-decoration-color: #808000\">dtype</span>=<span style=\"color: #800080; text-decoration-color: #800080\">uint8</span><span style=\"font-weight: bold\">)</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1;35mArray\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m129\u001b[0m, \u001b[1;36m0\u001b[0m, \u001b[1;36m254\u001b[0m, \u001b[1;36m253\u001b[0m, \u001b[1;36m252\u001b[0m, \u001b[1;36m251\u001b[0m, \u001b[1;36m250\u001b[0m, \u001b[1;36m249\u001b[0m, \u001b[1;36m248\u001b[0m, \u001b[1;36m247\u001b[0m\u001b[1m]\u001b[0m, \u001b[33mdtype\u001b[0m=\u001b[35muint8\u001b[0m\u001b[1m)\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">Array</span><span style=\"font-weight: bold\">([</span><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">111</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">254</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">253</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">252</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">251</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">250</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">249</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">248</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">247</span><span style=\"font-weight: bold\">]</span>, <span style=\"color: #808000; text-decoration-color: #808000\">dtype</span>=<span style=\"color: #800080; text-decoration-color: #800080\">uint8</span><span style=\"font-weight: bold\">)</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1;35mArray\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m111\u001b[0m, \u001b[1;36m0\u001b[0m, \u001b[1;36m254\u001b[0m, \u001b[1;36m253\u001b[0m, \u001b[1;36m252\u001b[0m, \u001b[1;36m251\u001b[0m, \u001b[1;36m250\u001b[0m, \u001b[1;36m249\u001b[0m, \u001b[1;36m248\u001b[0m, \u001b[1;36m247\u001b[0m\u001b[1m]\u001b[0m, \u001b[33mdtype\u001b[0m=\u001b[35muint8\u001b[0m\u001b[1m)\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">Array</span><span style=\"font-weight: bold\">([</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">67</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">254</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">253</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">252</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">251</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">250</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">249</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">248</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">247</span><span style=\"font-weight: bold\">]</span>, <span style=\"color: #808000; text-decoration-color: #808000\">dtype</span>=<span style=\"color: #800080; text-decoration-color: #800080\">uint8</span><span style=\"font-weight: bold\">)</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1;35mArray\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m \u001b[1;36m67\u001b[0m, \u001b[1;36m0\u001b[0m, \u001b[1;36m254\u001b[0m, \u001b[1;36m253\u001b[0m, \u001b[1;36m252\u001b[0m, \u001b[1;36m251\u001b[0m, \u001b[1;36m250\u001b[0m, \u001b[1;36m249\u001b[0m, \u001b[1;36m248\u001b[0m, \u001b[1;36m247\u001b[0m\u001b[1m]\u001b[0m, \u001b[33mdtype\u001b[0m=\u001b[35muint8\u001b[0m\u001b[1m)\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">Array</span><span style=\"font-weight: bold\">([</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">255</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">254</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">253</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">252</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">251</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">250</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">249</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">248</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">247</span><span style=\"font-weight: bold\">]</span>, <span style=\"color: #808000; text-decoration-color: #808000\">dtype</span>=<span style=\"color: #800080; text-decoration-color: #800080\">uint8</span><span style=\"font-weight: bold\">)</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1;35mArray\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m \u001b[1;36m0\u001b[0m, \u001b[1;36m255\u001b[0m, \u001b[1;36m254\u001b[0m, \u001b[1;36m253\u001b[0m, \u001b[1;36m252\u001b[0m, \u001b[1;36m251\u001b[0m, \u001b[1;36m250\u001b[0m, \u001b[1;36m249\u001b[0m, \u001b[1;36m248\u001b[0m, \u001b[1;36m247\u001b[0m\u001b[1m]\u001b[0m, \u001b[33mdtype\u001b[0m=\u001b[35muint8\u001b[0m\u001b[1m)\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #808000; text-decoration-color: #808000\">...</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[33m...\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"for _ in range(10):\n",
" x = xps.arrays(dtype=\"uint8\", shape=10, unique=True).example()\n",
" print(repr(x))\n",
"print(\"...\")"
]
},
{
"cell_type": "markdown",
"id": "3d92b309-6d61-4405-9121-a43d69548ebc",
"metadata": {},
"source": [
"How Hypothesis \"draws\" from its strategies can look rather unremarkable at first.\n",
"A small sample of draws might look fairly uniform\n",
"but trust that strategies will end up covering all kinds of edge cases.\n",
"Importantly it will cover these cases efficiently\n",
"so that Hypothesis-powered tests are *relatively* quick to run on your machine.\n",
"\n",
"All our test method does is get the cumulative sums array `a`\n",
"that is returned from `cumulative_sums(x)`,\n",
"and then check that every element `a[i]`\n",
"is greater than or equal to `a[i-1]`.\n",
"\n",
"Time to run it!"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "3844f19b-32a4-4673-8467-03164942927b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[31mF\u001b[0m\u001b[31m [100%]\u001b[0m\n",
"=================================== FAILURES ===================================\n",
"\u001b[31m\u001b[1m_________________ test_positive_arrays_have_incrementing_sums __________________\u001b[0m\n",
"\n",
" \u001b[37m@given\u001b[39;49;00m(xps.arrays(dtype=\u001b[33m\"\u001b[39;49;00m\u001b[33muint8\u001b[39;49;00m\u001b[33m\"\u001b[39;49;00m, shape=\u001b[94m10\u001b[39;49;00m))\n",
"> \u001b[94mdef\u001b[39;49;00m \u001b[92mtest_positive_arrays_have_incrementing_sums\u001b[39;49;00m(x):\n",
"\n",
"\u001b[1m\u001b[31m<cell>\u001b[0m:7: \n",
"_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ \n",
"\n",
"x = Array([26, 26, 26, 26, 26, 26, 26, 26, 26, 26], dtype=uint8)\n",
"\n",
" \u001b[37m@given\u001b[39;49;00m(xps.arrays(dtype=\u001b[33m\"\u001b[39;49;00m\u001b[33muint8\u001b[39;49;00m\u001b[33m\"\u001b[39;49;00m, shape=\u001b[94m10\u001b[39;49;00m))\n",
" \u001b[94mdef\u001b[39;49;00m \u001b[92mtest_positive_arrays_have_incrementing_sums\u001b[39;49;00m(x):\n",
" a = cumulative_sums(x)\n",
"> \u001b[94massert\u001b[39;49;00m nxp.all(a[\u001b[94m1\u001b[39;49;00m:] >= a[:-\u001b[94m1\u001b[39;49;00m])\n",
"\u001b[1m\u001b[31mE assert Array(False, dtype=bool)\u001b[0m\n",
"\u001b[1m\u001b[31mE + where Array(False, dtype=bool) = <function all at 0x7f2d48cc2430>(Array([ 52, 78, 104, 130, 156, 182, 208, 234, 4], dtype=uint8) >= Array([ 26, 52, 78, 104, 130, 156, 182, 208, 234], dtype=uint8))\u001b[0m\n",
"\u001b[1m\u001b[31mE + where <function all at 0x7f2d48cc2430> = nxp.all\u001b[0m\n",
"\n",
"\u001b[1m\u001b[31m<cell>\u001b[0m:9: AssertionError\n",
"---------------------------------- Hypothesis ----------------------------------\n",
"Falsifying example: test_positive_arrays_have_incrementing_sums(\n",
" x=Array([26, 26, 26, 26, 26, 26, 26, 26, 26, 26], dtype=uint8),\n",
")\n",
"=========================== short test summary info ============================\n",
"FAILED <cell>::test_positive_arrays_have_incrementing_sums - assert A...\n",
"\u001b[31m\u001b[31m\u001b[1m1 failed\u001b[0m, \u001b[33m1 deselected\u001b[0m\u001b[31m in 0.17s\u001b[0m\u001b[0m\n"
]
}
],
"source": [
"ipytest.run(\"-k positive_arrays_have_incrementing_sums\", \"--hypothesis-seed=3\")"
]
},
{
"cell_type": "markdown",
"id": "4d9482e0-33e2-46be-8f43-affd4074818f",
"metadata": {},
"source": [
"Hypothesis has tested our assumption\n",
"and told us we're wrong.\n",
"It provides us with the following falsifying example:\n",
"\n",
"```python\n",
">>> x = xp.full(10, 26, dtype=xp.uint8)\n",
">>> x\n",
"Array([ 26, 26, 26, 26, 26, 26, 26, 26, 26, 26], dtype=uint8)\n",
">>> cumulative_sums(x)\n",
"Array([ 26, 52, 78, 104, 130, 156, 182, 208, 234, 4], dtype=uint8)\n",
"```\n",
"\n",
"You can see that an overflow error has occurred for the final cumulative sum,\n",
"as 234 + 26 (260) cannot be represented in 8-bit unsigned integers.\n",
"\n",
"Let's try promoting the dtype of the cumulative sums array\n",
"so that it can represent larger numbers,\n",
"and then we can run the test again."
]
},
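{
"cell_type": "markdown",
"id": "5a6b7c8d",
"metadata": {},
"source": [
"But first, to confirm the diagnosis:\n",
"the wrap-around is plain modulo-256 arithmetic,\n",
"which we can reproduce directly\n",
"(a quick sketch; output shown approximately):\n",
"\n",
"```python\n",
">>> nxp.asarray(234, dtype=nxp.uint8) + nxp.asarray(26, dtype=nxp.uint8)\n",
"Array(4, dtype=uint8)\n",
"```"
]
},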
{
"cell_type": "code",
"execution_count": 10,
"id": "ab2389bd-4a38-4aa5-a6c2-efb81b78cf2d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\u001b[33m\u001b[33m\u001b[1m2 deselected\u001b[0m\u001b[33m in 0.00s\u001b[0m\u001b[0m\n"
]
}
],
"source": [
"def max_dtype(xp, dtype):\n",
" if dtype in [getattr(xp, name) for name in (\"int8\", \"int16\", \"int32\", \"int64\")]:\n",
" return xp.int64\n",
" elif dtype in [getattr(xp, name) for name in (\"uint8\", \"uint16\", \"uint32\", \"uint64\")]:\n",
" return xp.uint64\n",
" else:\n",
" return xp.float64\n",
"\n",
"def cumulative_sums(x):\n",
" xp = x.__array_namespace__()\n",
" \n",
" result = xp.empty(x.size, dtype=max_dtype(xp, x.dtype))\n",
" result[0] = x[0]\n",
" for i in range(1, x.size):\n",
" result[i] = result[i - 1] + x[i]\n",
" \n",
" return result\n",
"\n",
"ipytest.run(\"-k cumulative_sums_uint8_arrays_accumulate\")"
]
},
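{
"cell_type": "markdown",
"id": "1f2e3d4c",
"metadata": {},
"source": [
"As a quick spot-check\n",
"(a sketch, with the repr formatting approximate),\n",
"the earlier falsifying example should now accumulate in `uint64`\n",
"and no longer wrap around:\n",
"\n",
"```python\n",
">>> x = nxp.full(10, 26, dtype=nxp.uint8)\n",
">>> cumulative_sums(x)\n",
"Array([ 26,  52,  78, 104, 130, 156, 182, 208, 234, 260], dtype=uint64)\n",
"```"
]
},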
{
"cell_type": "markdown",
"id": "2bdbfbc3-0508-40c8-8828-f1822e3dbbdb",
"metadata": {},
"source": [
"You can see another assumption about our code is:\n",
"\n",
"> We can find the cumulative sums of arrays of any scalar dtype.\n",
"\n",
"We should cover this assumption in our test method `test_positive_arrays_have_incrementing_sums`\n",
"by passing child search strategies\n",
"into our [`xps.arrays()`](https://hypothesis.readthedocs.io/en/latest/numpy.html#xps.arrays) parent strategy.\n",
"Specifying `dtype` as [`xps.scalar_dtypes()`](https://hypothesis.readthedocs.io/en/latest/numpy.html#xps.scalar_dtypes)\n",
"will tell Hypothesis to generate arrays of all scalar dtypes.\n",
"To specify that these array values should be positive,\n",
"we can just pass keyword arguments to the underlying\n",
"value generating strategy [`xps.from_dtype()`](https://hypothesis.readthedocs.io/en/latest/numpy.html#xps.from_dtype)\n",
"via `elements={\"min_value\": 0}`.\n",
"\n",
"And while we're at it, let's make sure to cover another assumption:\n",
"\n",
"> We can find the cumulative sums of arrays with multiple dimensions.\n",
"\n",
"Specifying `shape` as [`xps.array_shapes()`](https://hypothesis.readthedocs.io/en/latest/numpy.html#xps.array_shapes)\n",
"will tell Hypothesis to generate arrays of various dimensionality and sizes.\n",
"We can [filter](https://hypothesis.readthedocs.io/en/latest/data.html#filtering)\n",
"this strategy with `lambda s: prod(s) > 1`\n",
"so that always `x.size > 1`,\n",
"allowing our test code to still work."
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "713e4267-2615-4e5f-bb35-0d1cf6f57278",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[31mF\u001b[0m\u001b[31m [100%]\u001b[0m\n",
"=================================== FAILURES ===================================\n",
"\u001b[31m\u001b[1m_________________ test_positive_arrays_have_incrementing_sums __________________\u001b[0m\n",
"\n",
" \u001b[37m@given\u001b[39;49;00m(\n",
"> xps.arrays(\n",
" dtype=xps.scalar_dtypes(),\n",
" shape=xps.array_shapes().filter(\u001b[94mlambda\u001b[39;49;00m s: prod(s) > \u001b[94m1\u001b[39;49;00m),\n",
" elements={\u001b[33m\"\u001b[39;49;00m\u001b[33mmin_value\u001b[39;49;00m\u001b[33m\"\u001b[39;49;00m: \u001b[94m0\u001b[39;49;00m},\n",
" )\n",
" )\n",
"\u001b[1m\u001b[31mE hypothesis.errors.MultipleFailures: Hypothesis found 2 distinct failures.\u001b[0m\n",
"\n",
"\u001b[1m\u001b[31m<cell>\u001b[0m:5: MultipleFailures\n",
"---------------------------------- Hypothesis ----------------------------------\n",
"Falsifying example: test_positive_arrays_have_incrementing_sums(\n",
" x=Array([[False, False]], dtype=bool),\n",
")\n",
"TypeError: only size-1 arrays can be converted to Python scalars\n",
"\n",
"The above exception was the direct cause of the following exception:\n",
"\n",
"Traceback (most recent call last):\n",
" <cell>, line 12, in test_positive_arrays_have_incrementing_sums\n",
" a = cumulative_sums(x)\n",
" <cell>, line 13, in cumulative_sums\n",
" result[0] = x[0]\n",
" File \"<env>/numpy/array_api/_array_object.py\", line 657, in __setitem__\n",
" self._array.__setitem__(key, asarray(value)._array)\n",
"ValueError: setting an array element with a sequence.\n",
"\n",
"Falsifying example: test_positive_arrays_have_incrementing_sums(\n",
" x=Array([False, False], dtype=bool),\n",
")\n",
"Traceback (most recent call last):\n",
" <cell>, line 12, in test_positive_arrays_have_incrementing_sums\n",
" a = cumulative_sums(x)\n",
" <cell>, line 15, in cumulative_sums\n",
" result[i] = result[i - 1] + x[i]\n",
" File \"<env>/numpy/array_api/_array_object.py\", line 362, in __add__\n",
" other = self._check_allowed_dtypes(other, \"numeric\", \"__add__\")\n",
" File \"<env>/numpy/array_api/_array_object.py\", line 125, in _check_allowed_dtypes\n",
" raise TypeError(f\"Only {dtype_category} dtypes are allowed in {op}\")\n",
"TypeError: Only numeric dtypes are allowed in __add__\n",
"=========================== short test summary info ============================\n",
"FAILED <cell>::test_positive_arrays_have_incrementing_sums - hypothes...\n",
"\u001b[31m\u001b[31m\u001b[1m1 failed\u001b[0m, \u001b[33m1 deselected\u001b[0m\u001b[31m in 0.38s\u001b[0m\u001b[0m\n"
]
}
],
"source": [
"from math import prod\n",
"from hypothesis import settings\n",
"\n",
"@given(\n",
" xps.arrays(\n",
" dtype=xps.scalar_dtypes(),\n",
" shape=xps.array_shapes().filter(lambda s: prod(s) > 1),\n",
" elements={\"min_value\": 0},\n",
" )\n",
")\n",
"def test_positive_arrays_have_incrementing_sums(x):\n",
" a = cumulative_sums(x)\n",
" assert nxp.all(a[1:] >= a[:-1])\n",
" \n",
"ipytest.run(\"-k positive_arrays_have_incrementing_sums\", \"--hypothesis-seed=3\")"
]
},
{
"cell_type": "markdown",
"id": "73fd0ff6",
"metadata": {},
"source": [
"Again Hypothesis has proved our assumptions wrong,\n",
"and this time it's found two problems.\n",
"\n",
"Firstly, our `cumulative_sums()` method doesn't adjust for boolean arrays,\n",
"so we get an error when we add two `bool` values together.\n",
"\n",
"```python\n",
">>> x = xp.zeros(2, dtype=xp.bool)\n",
">>> x\n",
"Array([False, False], dtype=bool)\n",
">>> cumulative_sums(x)\n",
"Traceback:\n",
" <cell>, line 15, in cumulative_sums\n",
" result[i] = result[i - 1] + x[i]\n",
" ...\n",
"TypeError: Only numeric dtypes are allowed in __add__\n",
"```\n",
" \n",
"Secondly, our `cumulative_sums()` method is assuming arrays are 1-dimensional,\n",
"so we get an error when we wrongly\n",
"assume `x[0]` will always return a single scalar\n",
"(technically a 0-dimensional array).\n",
"\n",
"```python\n",
">>> x = xp.zeros((1, 2), dtype=xp.bool)\n",
">>> x\n",
"Array([[False, False]], dtype=bool)\n",
">>> cumulative_sums(x)\n",
"Traceback:\n",
" <cell>, line 13, in cumulative_sums\n",
" result[0] = x[0]\n",
" ...\n",
"TypeError: only size-1 arrays can be converted to Python scalars\n",
"```\n",
"\n",
"I'm going to\n",
"flatten input arrays\n",
"and convert the boolean arrays to integer arrays of ones and zeros.\n",
"Of-course we'll run the test again to make sure our updated `cumulative_sums()` method now works."
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "54507222-cd61-45d6-a345-adcee0dc70b1",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[31mF\u001b[0m\u001b[31m [100%]\u001b[0m\n",
"=================================== FAILURES ===================================\n",
"\u001b[31m\u001b[1m_________________ test_positive_arrays_have_incrementing_sums __________________\u001b[0m\n",
"\n",
" \u001b[37m@given\u001b[39;49;00m(\n",
"> xps.arrays(\n",
" dtype=xps.scalar_dtypes(),\n",
" shape=xps.array_shapes().filter(\u001b[94mlambda\u001b[39;49;00m s: prod(s) > \u001b[94m1\u001b[39;49;00m),\n",
" elements={\u001b[33m\"\u001b[39;49;00m\u001b[33mmin_value\u001b[39;49;00m\u001b[33m\"\u001b[39;49;00m: \u001b[94m0\u001b[39;49;00m},\n",
" )\n",
" )\n",
"\n",
"\u001b[1m\u001b[31m<cell>\u001b[0m:5: \n",
"_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ \n",
"\n",
"x = Array([4611686018427387904, 4611686018427387904], dtype=int64)\n",
"\n",
" \u001b[37m@given\u001b[39;49;00m(\n",
" xps.arrays(\n",
" dtype=xps.scalar_dtypes(),\n",
" shape=xps.array_shapes().filter(\u001b[94mlambda\u001b[39;49;00m s: prod(s) > \u001b[94m1\u001b[39;49;00m),\n",
" elements={\u001b[33m\"\u001b[39;49;00m\u001b[33mmin_value\u001b[39;49;00m\u001b[33m\"\u001b[39;49;00m: \u001b[94m0\u001b[39;49;00m},\n",
" )\n",
" )\n",
" \u001b[94mdef\u001b[39;49;00m \u001b[92mtest_positive_arrays_have_incrementing_sums\u001b[39;49;00m(x):\n",
" a = cumulative_sums(x)\n",
"> \u001b[94massert\u001b[39;49;00m nxp.all(a[\u001b[94m1\u001b[39;49;00m:] >= a[:-\u001b[94m1\u001b[39;49;00m])\n",
"\u001b[1m\u001b[31mE assert Array(False, dtype=bool)\u001b[0m\n",
"\u001b[1m\u001b[31mE + where Array(False, dtype=bool) = <function all at 0x7f2d48cc2430>(Array([-9223372036854775808], dtype=int64) >= Array([4611686018427387904], dtype=int64))\u001b[0m\n",
"\u001b[1m\u001b[31mE + where <function all at 0x7f2d48cc2430> = nxp.all\u001b[0m\n",
"\n",
"\u001b[1m\u001b[31m<cell>\u001b[0m:13: AssertionError\n",
"---------------------------------- Hypothesis ----------------------------------\n",
"Falsifying example: test_positive_arrays_have_incrementing_sums(\n",
" x=Array([4611686018427387904, 4611686018427387904], dtype=int64),\n",
")\n",
"=========================== short test summary info ============================\n",
"FAILED <cell>::test_positive_arrays_have_incrementing_sums - assert A...\n",
"\u001b[31m\u001b[31m\u001b[1m1 failed\u001b[0m, \u001b[33m1 deselected\u001b[0m\u001b[31m in 1.24s\u001b[0m\u001b[0m\n"
]
}
],
"source": [
"def cumulative_sums(x):\n",
" xp = x.__array_namespace__()\n",
" \n",
" x = xp.reshape(x, x.size)\n",
" \n",
" if x.dtype == xp.bool:\n",
" mask = x\n",
" dtype = xp.uint64\n",
" x = xp.zeros(x.shape, dtype=xp.uint64)\n",
" x[mask] = 1\n",
" \n",
" result = xp.empty(x.size, dtype=max_dtype(xp, x.dtype))\n",
" result[0] = x[0]\n",
" for i in range(1, x.size):\n",
" result[i] = result[i - 1] + x[i]\n",
" \n",
" return result\n",
"\n",
"ipytest.run(\"-k positive_arrays_have_incrementing_sums\", \"--hypothesis-seed=3\")"
]
},
{
"cell_type": "markdown",
"id": "80eb6df7",
"metadata": {},
"source": [
"We resolved our two previous issues...\n",
"but Hypothesis has found yet another failing scenario 🙃\n",
"\n",
"```python\n",
">>> x = xp.full(2, 4611686018427387904, dtype=xp.int64)\n",
">>> x\n",
"Array([ 4611686018427387904, 4611686018427387904], dtype=int64)\n",
">>> cumulative_sums(x)\n",
"Array([ 4611686018427387904, -9223372036854775808], dtype=int64)\n",
"```\n",
"\n",
"An overflow has occurred again,\n",
"which we can't do much about it this time.\n",
"There's no larger signed integer dtype than `int64` (in the Array API),\n",
"so we'll just have `cumulative_sums()` detect overflows itself."
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "fdeb886e-bb64-40c6-9689-b56bb58d3c0c",
"metadata": {},
"outputs": [],
"source": [
"def cumulative_sums(x):\n",
" xp = x.__array_namespace__()\n",
" \n",
" x = xp.reshape(x, x.size)\n",
" \n",
" if x.dtype == xp.bool:\n",
" mask = x\n",
" dtype = xp.uint64\n",
" x = xp.zeros(x.shape, dtype=xp.uint64)\n",
" x[mask] = 1\n",
" \n",
" result = xp.empty(x.size, dtype=max_dtype(xp, x.dtype))\n",
" result[0] = x[0]\n",
" for i in range(1, x.size):\n",
" result[i] = result[i - 1] + x[i]\n",
" if result[i] < result[i - 1]:\n",
" raise OverflowError(\"Cumulative sum cannot be represented\")\n",
" \n",
" return result"
]
},
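{
"cell_type": "markdown",
"id": "7d8e9f0a",
"metadata": {},
"source": [
"Since we've suppressed warnings,\n",
"NumPy's implementation silently wraps around on `int64` overflow,\n",
"which the new check turns into an explicit error.\n",
"Here's a sketch with the earlier falsifying example (traceback abbreviated):\n",
"\n",
"```python\n",
">>> x = nxp.full(2, 4611686018427387904, dtype=nxp.int64)\n",
">>> cumulative_sums(x)\n",
"Traceback (most recent call last):\n",
"  ...\n",
"OverflowError: Cumulative sum cannot be represented\n",
"```"
]
},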
{
"cell_type": "markdown",
"id": "d3ae1eae-a914-4f76-a8aa-83a648224b25",
"metadata": {},
"source": [
"If Hypothesis generates arrays which raise `OverflowError`,\n",
"we can just catch it\n",
"and use [`assume(False)`](https://hypothesis.readthedocs.io/en/latest/details.html#making-assumptions)\n",
"to ignore testing these arrays on runtime.\n",
"This \"filter-on-runtime\" behaviour\n",
"can be very handy at times,\n",
"although [their docs note `assume()` can be problematic](https://hypothesis.readthedocs.io/en/latest/details.html#how-good-is-assume).\n",
"\n",
"We can also explicitly cover overflows in a separate test."
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "44722cd9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[32m.\u001b[0m\u001b[32m.\u001b[0m\u001b[32m.\u001b[0m\u001b[32m [100%]\u001b[0m\n",
"\u001b[32m\u001b[32m\u001b[1m3 passed\u001b[0m\u001b[32m in 0.27s\u001b[0m\u001b[0m\n"
]
}
],
"source": [
"from hypothesis import assume\n",
"import pytest\n",
"\n",
"@given(\n",
" xps.arrays(\n",
" dtype=xps.scalar_dtypes(),\n",
" shape=xps.array_shapes().filter(lambda s: prod(s) > 1),\n",
" elements={\"min_value\": 0},\n",
" )\n",
")\n",
"def test_positive_arrays_have_incrementing_sums(x):\n",
" try:\n",
" a = cumulative_sums(x)\n",
" assert nxp.all(a[1:] >= a[:-1])\n",
" except OverflowError:\n",
" assume(False)\n",
" \n",
"def test_error_on_overflow():\n",
" x = nxp.asarray([nxp.iinfo(nxp.uint64).max, 1], dtype=nxp.uint64)\n",
" with pytest.raises(OverflowError):\n",
" cumulative_sums(x)\n",
"\n",
"ipytest.run()"
]
},
{
"cell_type": "markdown",
"id": "aba76cab-ccd1-48e9-8372-26c50c6383ae",
"metadata": {},
"source": [
"Our little test suite finally passes 😅\n",
"\n",
"If you're feeling adventurous,\n",
"you might want to get\n",
"[this very notebook](https://github.com/Quansight-Labs/quansight-labs-site/tree/main/posts/2021/09/hypothesis-array-api.ipynb) running\n",
"and see if you can write some test cases yourself—bonus points if they fail!\n",
"For starters,\n",
"how about testing that cumulative sums *decrease*\n",
"with arrays containing negative elements?\n",
"\n",
"When you're developing an Array API array-consuming method,\n",
"and an equivalent method already exists for one of the adopting libraries,\n",
"I highly recommend using Hypothesis to\n",
"compare its results to your own.\n",
"For example,\n",
"we could use the battle-tested [`np.cumsum()`](https://numpy.org/doc/stable/reference/generated/numpy.cumsum.html)\n",
"to see how our `cumulative_sums()` method compares:"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "436bdcb4-85a7-4e31-8015-917ed9ff964e",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=xps.array_shapes()))\n",
"def test_reference_implementation(x):\n",
" a = cumulative_sums(x)\n",
" # In the future NumPy should provide a public API to interop between its\n",
" # ndarray and its Array API array. For now I'm using the private and\n",
" # unstable x._array attribute to retrieve x's underlying ndarray.\n",
" # See https://github.com/HypothesisWorks/hypothesis/issues/3101\n",
" assert np.all(a._array == np.cumsum(x._array))"
]
},
{
"cell_type": "markdown",
"id": "d409875c-4a72-42a8-9a62-36cceab5d216",
"metadata": {},
"source": [
"Comparison testing is a great exercise to really think about what your code does,\n",
"even if ultimately you conclude that you are happy with different results."
]
},
{
"cell_type": "markdown",
"id": "c976c705-1a8c-4308-a343-e172f72898aa",
"metadata": {},
"source": [
"## Watch this space\n",
"\n",
"This year should see\n",
"a first [version](https://data-apis.org/array-api/latest/future_API_evolution.html#versioning)\n",
"of the Array API standard,\n",
"and subsequently\n",
"NumPy shipping `numpy.array_api` out to the world—this means\n",
"array-consuming libraries will be able to\n",
"reliably develop for the Array API quite soon.\n",
"I hope I've demonstrated\n",
"why you should try Hypothesis\n",
"when the time comes 🙂\n",
"\n",
"Good news is that I'm extending my stay at Quansight.\n",
"My job is to help unify Python's fragmented scientific ecosystem,\n",
"so I'm more than happy\n",
"to respond to any inquiries about using Hypothesis for the Array API\n",
"via [email](mailto:quitesimplymatt@gmail.com) or [Twitter](https://twitter.com/whostolehonno).\n",
"\n",
"For now I'm contributing to\n",
"the Hypothesis-powered [Array API compliance suite](https://github.com/data-apis/array-api-tests),\n",
"which is already being used by the NumPy team to ensure `numpy.array_api`\n",
"actually complies with every tiny detail of the [specification](https://data-apis.org/array-api/latest/).\n",
"This process has the added side-effect\n",
"of finding limitations in [`hypothesis.extra.array_api`](https://hypothesis.readthedocs.io/en/latest/numpy.html#array-api),\n",
"so you can expect Hypothesis to only improve from here on out!"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
},
"nikola": {
"author": "Matthew Barber",
"date": "2021-09-17 00:00:00 UTC",
"previewimage": "/images/2021/09/hypothesis-array-api-preview.png",
"slug": "hypothesis-array-api",
"title": "Using Hypothesis to test array-consuming libraries"
}
},
"nbformat": 4,
"nbformat_minor": 5
}