Skip to content

Instantly share code, notes, and snippets.

@pmeier
Created May 21, 2021 09:51
Show Gist options
  • Save pmeier/ea35bdffb597b35f4f6592c5ac201cd4 to your computer and use it in GitHub Desktop.
Check intra-category type promotion behavior of 0d tensors (within the integer dtypes and within the floating-point dtypes) for array API compatibility
import itertools
from typing import Collection
import networkx as nx
# overwrite this with the array API that you want to test
import numpy as array_api
def maybe_add_dtype(
    graph: nx.Graph, name: str, promotes_to_names: Collection[str] = ()
) -> None:
    """Add dtype ``name`` and its direct promotion edges to ``graph``.

    Dtypes are resolved by attribute name on the ``array_api`` module, so a
    dtype the library under test does not define is skipped silently.

    Args:
        graph: Graph mutated in place. Nodes are the dtype objects returned by
            ``getattr(array_api, ...)``.
        name: Attribute name of the source dtype on ``array_api``
            (e.g. ``"int8"``). If missing, the graph is left untouched.
        promotes_to_names: Attribute names of dtypes that ``name`` promotes to
            directly. Each one present on ``array_api`` becomes an edge
            ``dtype -> promoted_dtype``; missing ones are skipped.
    """
    try:
        dtype = getattr(array_api, name)
    except AttributeError:
        # The library under test does not provide this dtype at all.
        return
    graph.add_node(dtype)
    # Use a distinct loop variable so the ``name`` parameter is not clobbered.
    for target_name in promotes_to_names:
        try:
            promoted_dtype = getattr(array_api, target_name)
        except AttributeError:
            continue
        graph.add_edge(dtype, promoted_dtype)
# Integer promotion lattice: an edge ``a -> b`` means ``a`` promotes directly
# to ``b``. Each unsigned type also feeds the next-wider signed type. uint64
# has no outgoing edge because no signed type here can hold all of its values.
integral_graph = nx.DiGraph()
# Signed chain.
maybe_add_dtype(integral_graph, "int8", promotes_to_names=["int16"])
maybe_add_dtype(integral_graph, "int16", promotes_to_names=["int32"])
maybe_add_dtype(integral_graph, "int32", promotes_to_names=["int64"])
maybe_add_dtype(integral_graph, "int64", promotes_to_names=[])
# Unsigned chain, each step also promoting into the next-wider signed type.
maybe_add_dtype(integral_graph, "uint8", promotes_to_names=["uint16", "int16"])
maybe_add_dtype(integral_graph, "uint16", promotes_to_names=["uint32", "int32"])
maybe_add_dtype(integral_graph, "uint32", promotes_to_names=["uint64", "int64"])
maybe_add_dtype(integral_graph, "uint64", promotes_to_names=[])
# Floating-point promotion lattice, kept separate from the integers so that
# only same-kind promotion is checked.
floating_graph = nx.DiGraph()
maybe_add_dtype(floating_graph, "float32", promotes_to_names=["float64"])
maybe_add_dtype(floating_graph, "float64", promotes_to_names=[])
# For every ordered (0d dtype, nd dtype) pair within one category, compare the
# library's result_type against the expected promotion. Reversing the graph
# flips the "promotes to" edges, so the lowest common ancestor of two dtypes in
# the reversed graph is the smallest dtype that both promote to.
for graph in (integral_graph, floating_graph):
    reverse_graph = graph.reverse()
    for dtype_0d, dtype_nd in itertools.product(graph.nodes, repeat=2):
        dtype_expected = nx.lowest_common_ancestor(reverse_graph, dtype_0d, dtype_nd)
        if dtype_expected is None:
            # No common promotion target in the lattice (e.g. int64 and
            # uint64), so there is no expected result to compare against.
            continue
        # Use zeros, not empty. If a library applies value-based casting to 0d
        # arrays (as legacy NumPy does), an uninitialized value would make the
        # outcome depend on garbage memory. Zero fits every dtype, so any
        # value-based shortcut shows up deterministically as a mismatch.
        a = array_api.zeros((), dtype=dtype_0d)
        b = array_api.zeros((1,), dtype=dtype_nd)
        dtype_actual = array_api.result_type(a, b)
        if dtype_actual != dtype_expected:
            print(f"0d {dtype_0d} + nd {dtype_nd} = {dtype_actual} != {dtype_expected}")
@krshrimali
Copy link

This is cool! Thanks.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment