Last active
January 18, 2022 09:42
-
-
Save honno/70623811b88c0f5b224c668ec4ed86b7 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Wrapper of dask for use with github.com/data-apis/array-api-tests.

Tested with dask version 2022.01.0.

How to use:

1. Place this file in ``array_api_tests/_dask.py``.
2. In ``array_api_tests/_array_module.py`` replace ``array_module = None``
   with ``from ._dask import array_module``.

Importing this module patches the ``dask.array`` module object in place: it
adds one attribute per dtype name below, plus ``iinfo`` and ``finfo``. The
patch is therefore visible to every other user of ``dask.array`` in the same
process.
"""
from dask import array as da
import numpy as np

try:
    from numpy import array_api as nxp
except ImportError as err:
    # numpy.array_api was an experimental submodule that was removed in
    # NumPy 2.0; fail with an actionable message instead of a bare ImportError.
    raise ImportError(
        "this dask wrapper requires numpy.array_api, which was removed in "
        "NumPy 2.0; install numpy<2 to use it"
    ) from err

# The namespace handed to array-api-tests.
array_module = da

# dask needs namespaced dtypes: the test suite looks dtypes up as attributes
# of the array namespace (e.g. ``xp.int8``).
uint_names = ("uint8", "uint16", "uint32", "uint64")
int_names = ("int8", "int16", "int32", "int64")
float_names = ("float32", "float64")
dtype_names = ("bool",) + uint_names + int_names + float_names

for _name in dtype_names:
    # Note: dask uses np.dtype() instances, not NumPy's namespaced scalar
    # types such as np.int64, so attach dtype objects.
    setattr(array_module, _name, np.dtype(_name))
# Don't leave the loop variable behind as a module-level name.
del _name

# dask's info objects don't hold Python scalars, so borrow the
# numpy.array_api implementations instead.
array_module.iinfo = nxp.iinfo
array_module.finfo = nxp.finfo
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment