Skip to content

Instantly share code, notes, and snippets.

@jpivarski
Created May 26, 2021 21:34
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jpivarski/e6c48a9153b9e472fa0ce5fba5859357 to your computer and use it in GitHub Desktop.
Save jpivarski/e6c48a9153b9e472fa0ce5fba5859357 to your computer and use it in GitHub Desktop.
Assigning ctypes argument types to all Awkward Array kernels from the YAML specification
import yaml
import awkward as ak
# Map each array-typed argument spelling from kernel-specification.yml to the
# NumPy dtype (as source text) that fills its slot in a kernel's dispatch key.
# Only "[...]" (array) spellings appear here: the generator builds keys from
# array arguments alone.  "Const[...]" variants map to the same dtype as their
# non-const counterparts.
arg2key = {
    "List[bool]": "np.bool_",
    "Const[List[bool]]": "np.bool_",
    "List[int8_t]": "np.int8",
    "Const[List[int8_t]]": "np.int8",
    "List[uint8_t]": "np.uint8",
    "Const[List[uint8_t]]": "np.uint8",
    "List[int16_t]": "np.int16",
    "Const[List[int16_t]]": "np.int16",
    "List[uint16_t]": "np.uint16",
    "Const[List[uint16_t]]": "np.uint16",
    "List[int32_t]": "np.int32",
    "Const[List[int32_t]]": "np.int32",
    "List[uint32_t]": "np.uint32",
    "Const[List[uint32_t]]": "np.uint32",
    "List[int64_t]": "np.int64",
    "Const[List[int64_t]]": "np.int64",
    "List[uint64_t]": "np.uint64",
    "Const[List[uint64_t]]": "np.uint64",
    "List[float]": "np.float32",
    "Const[List[float]]": "np.float32",
    "List[double]": "np.float64",
    "Const[List[double]]": "np.float64",
    # Nested lists are passed as pointer-to-pointer, so they are keyed by a
    # pointer-sized integer dtype rather than by their element type.
    "List[List[int8_t]]": "np.intp",
    "Const[List[List[uint8_t]]]": "np.intp",
    # NOTE(review): this key used to appear twice in the literal ("np.intp"
    # then "np.uintp").  Python keeps only the last duplicate, so "np.uintp"
    # is what has always been emitted and is preserved here.  It differs from
    # the "np.intp" used for the other nested lists -- confirm which dtype
    # callers use when building dispatch keys.
    "List[List[int64_t]]": "np.uintp",
}
def _build_arg2ctypes():
    """Build the argument-type -> ctypes-expression lookup table.

    Keys are argument type spellings as they appear in
    kernel-specification.yml; values are ctypes expressions, as source text,
    that the generated module writes into each kernel's ``argtypes`` list.
    Entries are inserted scalars first, then List/Const[List] pairs per
    scalar, then nested lists, then the legacy "const int64_t" spelling.
    """
    scalar_names = (
        "bool",
        "int8_t", "uint8_t",
        "int16_t", "uint16_t",
        "int32_t", "uint32_t",
        "int64_t", "uint64_t",
        "float", "double",
    )
    table = {}

    # By-value scalars: the ctypes name is the C name minus any "_t" suffix
    # (int8_t -> ctypes.c_int8, double -> ctypes.c_double).
    for c_name in scalar_names:
        stem = c_name[:-2] if c_name.endswith("_t") else c_name
        table[c_name] = "ctypes.c_" + stem

    # One-level arrays are pointers to the scalar; the Const[...] spelling
    # produces the same ctypes declaration.
    for c_name in scalar_names:
        pointer = "ctypes.POINTER(" + table[c_name] + ")"
        table["List[" + c_name + "]"] = pointer
        table["Const[List[" + c_name + "]]"] = pointer

    # Nested lists are pointers to pointers; only these spellings are mapped.
    for spelled, inner in (
        ("List[List[int8_t]]", "int8_t"),
        ("Const[List[List[uint8_t]]]", "uint8_t"),
        ("List[List[int64_t]]", "int64_t"),
    ):
        table[spelled] = "ctypes.POINTER(ctypes.POINTER(" + table[inner] + "))"

    # NOTE(review): the original author marked this spelling "not correct";
    # presumably the spec should spell it differently -- confirm.  It maps to
    # the by-value int64 declaration.
    table["const int64_t"] = table["int64_t"]
    return table


arg2ctypes = _build_arg2ctypes()
# Source text emitted at the top of the generated module: the imports it
# needs, a start timestamp (the generator later emits a line printing the
# elapsed time), and a ctypes Structure named ERROR that each kernel's
# ``restype`` is set to.  ctypes lays a Structure out by the order and types
# in ``_fields_``; the names only control Python attribute access.
# NOTE(review): field order/types must match the C-side error struct --
# confirm against the kernel headers.
prelude = """import ctypes
import awkward as ak
import numpy as np
import time
start_time = time.time()
class ERROR(ctypes.Structure):
    _fields_ = [
        ("str", ctypes.c_char_p),
        ("filename", ctypes.c_char_p),
        ("id", ctypes.c_int64),
        ("attempt", ctypes.c_int64),
        ("pass_through", ctypes.c_bool)
    ]
"""
print(prelude)
def _lookup(table, typename, kernel_name):
    """Return ``table[typename]``, naming the kernel if the type is unmapped.

    Raises:
        KeyError: if ``typename`` has no entry in ``table``; the message names
            the specialization so the offending YAML entry is easy to find.
    """
    try:
        return table[typename]
    except KeyError:
        raise KeyError(
            "no mapping for argument type {0!r} in kernel {1}".format(
                typename, kernel_name
            )
        ) from None


# For every kernel in the YAML specification, emit a dispatch dict keyed by
# the dtypes of its array arguments, and declare ctypes argtypes/restype for
# each specialization.  safe_load is used so the spec cannot construct
# arbitrary Python objects; the file is closed before generation starts.
with open("kernel-specification.yml", encoding="utf-8") as spec_file:
    kernel_specs = yaml.safe_load(spec_file)["kernels"]

for spec in kernel_specs:
    print("{0} = {{}}".format(spec["name"]))
    for special in spec["specializations"]:
        kernel = special["name"]
        # Only array ("[...]") arguments contribute to the dispatch key;
        # by-value scalars do not select a specialization.
        key = ", ".join(
            _lookup(arg2key, arg["type"], kernel)
            for arg in special["args"]
            if "[" in arg["type"]
        )
        # An empty key would emit "name[] = ...", a syntax error; "()" is the
        # matching empty-tuple key (multi-argument keys are tuples too).
        print("{0}[{1}] = ak._cpu_kernels.lib.{2}".format(
            spec["name"], key or "()", kernel)
        )
        print("ak._cpu_kernels.lib.{0}.argtypes = [".format(kernel))
        for arg in special["args"]:
            print(" {0}, # {1}".format(
                _lookup(arg2ctypes, arg["type"], kernel), arg["name"])
            )
        print("]")
        print("ak._cpu_kernels.lib.{0}.restype = ERROR".format(kernel))
    print("")

print("print(time.time() - start_time)")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment