Skip to content

Instantly share code, notes, and snippets.

@gmarkall
Last active August 8, 2022 09:59
Show Gist options
  • Save gmarkall/23c0d5e1e879a117bd84bb95a2d8f1c8 to your computer and use it in GitHub Desktop.
Save gmarkall/23c0d5e1e879a117bd84bb95a2d8f1c8 to your computer and use it in GitHub Desktop.
# Implements unicode equality for the CUDA target
from numba import cuda, types
from numba.core.extending import overload
from numba.core.pythonapi import (PY_UNICODE_1BYTE_KIND,
PY_UNICODE_2BYTE_KIND,
PY_UNICODE_4BYTE_KIND)
from numba.cpython.unicode import deref_uint8, deref_uint16, deref_uint32
import numpy as np
import operator
# Copied / modified from numba/cpython/unicode.py
@overload(len, target='cuda')
def unicode_len(s):
    """Provide ``len()`` for unicode strings on the CUDA target.

    An implementation is returned only when ``s`` is typed as
    ``types.UnicodeType``. For any other type nothing is returned, so this
    overload does not apply.
    """
    if not isinstance(s, types.UnicodeType):
        return None

    def len_impl(s):
        # The length is read straight from the string's ``_length`` field.
        return s._length

    return len_impl
def get_code_point(a, i):
    """Stub with an intentionally empty body; always returns ``None``.

    It exists so that an implementation can be registered against this name
    with ``@overload``.
    """
@overload(get_code_point, target='cuda')
def get_code_point_ol(a, i):
    """CUDA overload of ``get_code_point``.

    The returned implementation reads element ``i`` of ``a._data`` using the
    dereference helper that matches ``a._kind`` (1-, 2- or 4-byte storage).
    """

    def get_code_point_impl(a, i):
        kind = a._kind
        if kind == PY_UNICODE_1BYTE_KIND:
            return deref_uint8(a._data, i)
        if kind == PY_UNICODE_2BYTE_KIND:
            return deref_uint16(a._data, i)
        if kind == PY_UNICODE_4BYTE_KIND:
            return deref_uint32(a._data, i)
        # A wchar kind also exists, but it is always one of the kinds
        # handled above, so it is skipped in this example.
        return 0

    return get_code_point_impl
def cmp_region(a, a_offset, b, b_offset, n):
    """Stub with an intentionally empty body; always returns ``None``.

    It exists so that an implementation can be registered against this name
    with ``@overload``.
    """
@overload(cmp_region, target='cuda')
def cmp_region_ol(a, a_offset, b, b_offset, n):
    """CUDA overload of ``cmp_region``: three-way compare of two regions.

    The returned implementation compares ``n`` code points of ``a`` starting
    at ``a_offset`` with ``n`` code points of ``b`` starting at ``b_offset``.
    It returns -1, 0 or 1. An empty region (``n == 0``) compares equal; a
    region running past the end of ``a`` yields -1, and one running past the
    end of ``b`` yields 1.
    """

    def cmp_region_impl(a, a_offset, b, b_offset, n):
        if n == 0:
            return 0
        if a_offset + n > a._length:
            return -1
        if b_offset + n > b._length:
            return 1

        # The first differing code point decides the result.
        for k in range(n):
            left = get_code_point(a, a_offset + k)
            right = get_code_point(b, b_offset + k)
            if left < right:
                return -1
            if left > right:
                return 1
        return 0

    return cmp_region_impl
@overload(operator.eq, target='cuda')
def unicode_eq(a, b):
    """Provide ``==`` for unicode strings on the CUDA target.

    Applies only when both operand types report ``is_internal``. If both
    operands (after looking through ``Optional``) are string-like, the
    implementation compares lengths and then contents. If exactly one is
    string-like, the implementation always returns False. Otherwise nothing
    is returned and this overload does not apply.
    """
    if not a.is_internal or not b.is_internal:
        return None

    # Look through Optional wrappers to the underlying type.
    inner_a = a.type if isinstance(a, types.Optional) else a
    inner_b = b.type if isinstance(b, types.Optional) else b

    string_types = (
        types.UnicodeType,
        types.StringLiteral,
        types.UnicodeCharSeq,
    )
    a_is_str = isinstance(inner_a, string_types)
    b_is_str = isinstance(inner_b, string_types)

    if a_is_str and b_is_str:
        def eq_impl(a, b):
            # Optionals are resolved at runtime: two Nones are equal, and a
            # None never equals a string.
            a_missing = a is None
            b_missing = b is None
            if a_missing or b_missing:
                return a_missing and b_missing
            # The CPU version converts both operands with str() here for
            # UnicodeCharSeq; that step is left out on CUDA to avoid having
            # to implement str().
            size = len(a)
            if size != len(b):
                return False
            return cmp_region(a, 0, b, 0, size) == 0

        return eq_impl

    if a_is_str != b_is_str:
        # Exactly one operand is a string: the comparison is always False.
        def eq_impl(a, b):
            return False

        return eq_impl

    return None
@cuda.jit(device=True)
def find_fruit(arr, string):
    """Return the index of the first element of ``arr`` equal to ``string``.

    Returns -1 when no element matches.

    Declared with ``device=True`` because the function returns a value and is
    called from inside a kernel; CUDA kernels themselves cannot return values.
    """
    y = 0
    for x in arr:
        if x == string:
            # First match wins. The original had an unreachable ``break``
            # after this return, which has been removed.
            return y
        y += 1
    return -1
@cuda.jit
def kernel(loc):
    """Write the index of 'banana' in a fixed tuple of fruit names to ``loc``.

    The result is stored with ``loc[()]``, i.e. an empty-tuple index.
    """
    candidates = ('apple', 'banana', 'cherry')
    position = find_fruit(candidates, 'banana')
    loc[()] = position
# Zero-dimensional int64 array that receives the kernel's result. np.zeros is
# used instead of the low-level np.ndarray constructor so the buffer starts
# from a defined value rather than uninitialised memory.
c1 = np.zeros((), dtype=np.int64)
# Launch configuration [1, 1]: a single block containing a single thread.
kernel[1, 1](c1)
print(c1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment