-
-
Save tehrengruber/ce94dea127b30cf045ff8cd932eb148d to your computer and use it in GitHub Desktop.
Strides check
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import itertools | |
import ctypes | |
import numpy as np | |
import gt4py as gt | |
from gt4py.backend import REGISTRY | |
backend = "gt:cpu_kfirst"  # layout map (0, 1, 2)
# backend = "gt:cpu_ifirst"  # layout map (2, 0, 1)

# Layout map: sort the strides in descending order; the layout map entry of
# each dimension is that dimension's position in the sorted order.
inp_i, inp_j, inp_k = (
    gt.storage.zeros(
        default_origin=(0, 0, 0), shape=(3, 3, 3), dtype=np.int64, backend=backend
    )
    for _ in range(3)
)

# Fill each storage with one of its own coordinates, so that
# inp_i[i, j, k] == i, inp_j[i, j, k] == j and inp_k[i, j, k] == k.
for i in range(3):
    for j in range(3):
        for k in range(3):
            inp_i[i, j, k] = i
            inp_j[i, j, k] = j
            inp_k[i, j, k] = k
def raw_array_access(arr, i, j, k):
    """Read ``arr[i, j, k]`` directly from memory using the array's byte strides.

    Bypasses numpy indexing. The element address is computed by hand from the
    buffer's base pointer and byte strides, then dereferenced through ctypes.
    This checks that the reported strides describe where the data really lives.

    Args:
        arr: A 3-D ndarray (or ndarray subclass / array-interface object).
        i, j, k: Non-negative indices along the three dimensions.

    Returns:
        The element as a Python scalar, decoded with ``arr``'s own dtype.

    Raises:
        ValueError: If ``arr`` is not three-dimensional.
        IndexError: If any index lies outside ``arr``'s shape.
    """
    # Plain-ndarray view over the same memory (no copy for array-interface objects).
    view = np.asarray(arr)
    if view.ndim != 3:
        raise ValueError(f"expected a 3-D array, got {view.ndim} dimension(s)")

    index = (i, j, k)
    # Out-of-range indices would otherwise dereference memory outside the buffer.
    for axis, (idx, size) in enumerate(zip(index, view.shape)):
        if not 0 <= idx < size:
            raise IndexError(
                f"index {idx} is out of bounds for axis {axis} with size {size}"
            )

    # __array_interface__["strides"] is None for C-contiguous arrays, so take
    # the byte strides from the ndarray itself, which always reports them.
    address = view.__array_interface__["data"][0] + sum(
        idx * stride for idx, stride in zip(index, view.strides)
    )
    # Decode with the array's element type instead of assuming int64.
    c_type = np.ctypeslib.as_ctypes_type(view.dtype)
    return ctypes.cast(address, ctypes.POINTER(c_type)).contents.value
# Byte strides of the storage as actually laid out in memory. Use
# ndarray.strides rather than __array_interface__["strides"], which is None
# for C-contiguous arrays.
strides = np.asarray(inp_i).strides
layout_map = REGISTRY[backend].storage_info["layout_map"]((True, True, True))

# Layout-map convention: sort the dimensions by descending stride; each
# dimension's layout entry is its position in that sorted order.
dims_by_stride = sorted(range(len(strides)), key=lambda dim: strides[dim], reverse=True)
observed_layout = tuple(dims_by_stride.index(dim) for dim in range(len(strides)))
assert observed_layout == tuple(layout_map), (
    f"strides {strides} imply layout {observed_layout}, "
    f"but backend {backend!r} reports {tuple(layout_map)}"
)

# Read every element through raw pointer arithmetic (bypassing numpy indexing)
# and check that it holds the coordinate it was filled with.
for i, j, k in itertools.product(range(3), repeat=3):
    assert raw_array_access(inp_i, i, j, k) == i
    assert raw_array_access(inp_j, i, j, k) == j
    assert raw_array_access(inp_k, i, j, k) == k
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.