Skip to content

Instantly share code, notes, and snippets.

Created September 21, 2022 12:51
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
Star You must be signed in to star a gist
Save honno/6531d1e8d1acef9b3ef713200c76d91c to your computer and use it in GitHub Desktop.
def make_strategies_namespace(
    xp: Any, *, api_version: Optional[NominalVersion] = None
) -> SimpleNamespace:
    """Resolve the Array API version to use for the array module ``xp``.

    When ``api_version`` is ``None`` the most recent released version that
    ``xp`` supports is inferred, and ``xp`` is rebound to the namespace
    returned for that version.

    Args:
        xp: The array module to inspect.
        api_version: Version to use; ``None`` requests inference.

    Raises:
        InvalidArgument: If inference is requested and ``xp`` has no
            ``zeros`` attribute, if ``xp.zeros(1)`` fails, or if no entry
            of ``RELEASED_VERSIONS`` is accepted by the array's
            ``__array_namespace__()``.
    """
    if api_version is None:
        # When api_version=None, we infer the most recent API version for which
        # the passed xp is valid. We go through the released versions in
        # descending order, passing each to array.__array_namespace__() until
        # no error is raised, thus inferring that specific api_version is
        # supported. If errors are raised for all released versions, we raise
        # our own useful error.
        if not hasattr(xp, "zeros"):
            raise InvalidArgument(
                f"Array module {xp.__name__} has no function zeros(), which is "
                "required when inferring api_version."
            )
        errmsg = (
            f"Could not infer any api_version which module {xp.__name__} "
            f"supports. If you believe {xp.__name__} is indeed an Array API "
            "module, try explicitly passing an api_version."
        )
        try:
            array = xp.zeros(1)
        except Exception as err:
            # Any failure to build a probe array means inference is impossible.
            raise InvalidArgument(errmsg) from err
        for api_version in reversed(RELEASED_VERSIONS):
            with contextlib.suppress(Exception):
                xp = array.__array_namespace__(api_version=api_version)
                break  # i.e. a valid xp and api_version has been inferred
        else:
            # Only reached when the loop never hit `break`, i.e. every
            # released version was rejected.
            raise InvalidArgument(errmsg)
    # NOTE(review): this excerpt ends here; the code that builds and returns
    # the SimpleNamespace is not visible — TODO restore it before use.
# Tests ------------------------------------------------------------------------
def test_raises_on_inferring_with_no_zeros_func():
    """When xp has no zeros(), inferring api_version raises helpful error."""
    xp = make_mock_xp(exclude=("zeros",))
    # No api_version is passed, so inference runs and must reject the module.
    with pytest.raises(InvalidArgument, match="has no function"):
        make_strategies_namespace(xp)
def test_raises_on_erroneous_zeros_func():
    """When xp has erroneous zeros(), inferring api_version raises helpful error."""
    xp = make_mock_xp()
    # The attribute exists but is not callable, so xp.zeros(1) itself fails.
    xp.zeros = None
    with pytest.raises(InvalidArgument):
        make_strategies_namespace(xp)
class MockArray:
    """Test double for an array whose namespace accepts only some versions.

    ``__array_namespace__()`` returns a minimal namespace named ``"foopy"``
    whose ``zeros()`` produces another ``MockArray`` with the same
    supported versions.
    """

    def __init__(self, supported_versions: Tuple[NominalVersion, ...]):
        assert len(set(supported_versions)) == len(supported_versions)  # sanity check
        self.supported_versions = supported_versions

    def __array_namespace__(self, *, api_version: Optional[NominalVersion] = None):
        """Return the mock namespace, rejecting unsupported versions.

        Raises:
            ValueError: If ``api_version`` is given and is not one of
                ``self.supported_versions``.
        """
        if api_version is not None and api_version not in self.supported_versions:
            raise ValueError(f"api_version={api_version!r} is not supported")
        return SimpleNamespace(
            __name__="foopy", zeros=lambda _: MockArray(self.supported_versions)
        )
# Every non-empty leading slice of RELEASED_VERSIONS: the first release alone,
# then the first two, and so on up to the full sequence.
version_permutations: List[Tuple[NominalVersion, ...]] = [
    RELEASED_VERSIONS[:i] for i in range(1, len(RELEASED_VERSIONS) + 1)
]
@pytest.mark.parametrize(
    "supported_versions",
    version_permutations,
    ids=lambda supported_versions: "-".join(supported_versions),
)
def test_version_inferrence(supported_versions):
    """Latest supported api_version is inferred."""
    xp = MockArray(supported_versions).__array_namespace__()
    xps = make_strategies_namespace(xp)
    # The last entry of supported_versions is the one expected to be inferred.
    assert xps.api_version == supported_versions[-1]
def test_raises_on_inferring_with_no_supported_versions():
    """When xp supports no versions, inferring api_version raises helpful error."""
    # An empty tuple means every explicit api_version is rejected by the mock.
    xp = MockArray(()).__array_namespace__()
    with pytest.raises(InvalidArgument):
        make_strategies_namespace(xp)
@pytest.mark.parametrize(
    ("api_version", "supported_versions"),
    # For each permutation, request its last version while the mock supports
    # only the versions before it.
    [pytest.param(p[-1], p[:-1], id=p[-1]) for p in version_permutations],
)
def test_warns_on_specifying_unsupported_version(api_version, supported_versions):
    """Specifying an api_version which xp does not support executes with a warning."""
    xp = MockArray(supported_versions).__array_namespace__()
    xp.zeros = None
    with pytest.warns(HypothesisWarning):
        xps = make_strategies_namespace(xp, api_version=api_version)
    # The explicitly requested version is kept despite the warning.
    assert xps.api_version == api_version
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment