Skip to content

Instantly share code, notes, and snippets.

@ailzhang
Created November 10, 2022 04:56
Show Gist options
Save ailzhang/d207a53c26720cccb7e7c5a9687009c2 to your computer and use it in GitHub Desktop.
import taichi as ti
ti.init(arch=ti.vulkan)

# Vector element types used by the vector-valued ndarrays below.
tp_vec3 = ti.types.vector(3, ti.f32)
tp_ivec2 = ti.types.vector(2, ti.i32)
tp_ivec3 = ti.types.vector(3, ti.i32)

# Scalar-element ndarrays.
x = ti.ndarray(ti.f32, shape=(12, 13))
z = ti.ndarray(ti.i32, shape=(12, 12))
n = ti.ndarray(ti.f32, shape=(12,))

# Vector-element ndarrays.
y = ti.ndarray(tp_ivec3, shape=(12, 4))
m = ti.ndarray(tp_ivec2, shape=(2, 3))
k = ti.ndarray(tp_vec3, shape=(12,))
# test1: the annotation `ti.types.ndarray()` places no dtype or dimensionality
# constraint on `arr`; every element is set to 1.
@ti.kernel
def test1(arr: ti.types.ndarray()):
    for I in ti.grouped(arr):
        arr[I] = 1


test1(x)
print(x.to_numpy())
# The following currently errors out: `y` holds ivec3 elements, and the scalar
# `1` is presumably not expanded to each vector component. Should Taichi expand
# it automatically?
# test1(y)
# print(y.to_numpy())
# test2: `arr` is annotated with dtype=ti.f32 (dimensionality unconstrained);
# every element is set to 2.0.
@ti.kernel
def test2(arr: ti.types.ndarray(dtype=ti.f32)):
    for I in ti.grouped(arr):
        arr[I] = 2.0


test2(x)
print(x.to_numpy())
# This works, but should it error out? `z` is an i32 ndarray, which does not
# match the f32 type annotation.
test2(z)
print(z.to_numpy())
# This should definitely error out (`y` holds ivec3 elements, not f32
# scalars), but the current error message is unhelpful.
# test2(y)
# test3: `arr` is annotated with the 3-component i32 vector type `tp_ivec3`;
# every element is set to the vector (0, 1, 2).
@ti.kernel
def test3(arr: ti.types.ndarray(dtype=tp_ivec3)):
    for I in ti.grouped(arr):
        arr[I] = [0, 1, 2]


test3(y)
print(y.to_numpy())
# The following errors out as expected (`m` holds ivec2 elements, not ivec3),
# but the message can be improved.
# test3(m)
# This should have errored out but didn't: `k` holds f32 vec3 elements, not
# i32 ivec3.
# test3(k)
# test4: INTENTIONALLY broken kernel. It assigns a 2-element list to an
# ivec3 (3-component) element. Kept as-is to reproduce the compile-time error.
@ti.kernel
def test4(arr: ti.types.ndarray(dtype=tp_ivec3)):
    for I in ti.grouped(arr):
        arr[I] = [1, 2]


# test4 is just a bad kernel. Calling it errors out during kernel compilation,
# as expected; could the error message be better?
# test4(y)
# test5: INTENTIONALLY broken kernel. It declares field_dim=1 but unpacks two
# loop indices and indexes `arr` with two subscripts. Kept as-is to reproduce
# the error.
@ti.kernel
def test5(arr: ti.types.ndarray(field_dim=1)):
    for i, j in arr:
        arr[i, j] = 0


# test5 is just a bad kernel; calling it should error out, but the current
# error message is confusing.
# test5(x)
# test6: `arr` is annotated with field_dim=2 and iterated with a two-index
# loop; every element is set to 6.
@ti.kernel
def test6(arr: ti.types.ndarray(field_dim=2)):
    for i, j in arr:
        arr[i, j] = 6


# Works as expected (`x` is a 2-D f32 ndarray).
test6(x)
print(x.to_numpy())
# Errors out, but the message could be more helpful. `m` is 2-D but holds
# ivec2 elements, while the kernel assigns a scalar.
# test6(m)
# Errors out, but the message is confusing. `n` is 1-D while the annotation
# says field_dim=2.
# test6(n)
# The same argument checks should also happen in AOT compilation.
# Use a distinct name: rebinding `m` here would shadow the ivec2 ndarray `m`
# above, which is one of the mismatched inputs those AOT checks would need.
aot_module = ti.aot.Module(ti.vulkan)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment