Last active: July 27, 2018 02:40
Save vinx13/8bb465e948d5f5883c67bc82d56167c9 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tvm
def main(n, factor=4, dtype='int8', target='cuda', data_alignment=16):
    """Build a vectorized element-wise copy kernel of length ``n`` with TVM.

    The compute stage copies ``a[i]`` into ``c[i]``. Its single axis is split
    by ``factor``. The outer loop is bound to ``blockIdx.x`` and the inner loop
    is vectorized, so each CUDA block handles one ``factor``-wide chunk.

    Args:
        n: Number of elements in the input and output tensors. Presumably
            expected to be a multiple of ``factor``, as in the ``__main__``
            example (n=1024, factor=4). Other sizes are not exercised
            here. TODO confirm the tail handling.
        factor: Split factor, which is also the vector width of the inner loop.
        dtype: Element dtype of the input placeholder.
        target: Target string passed to ``tvm.build``. Note that the schedule
            binds ``blockIdx.x``, so it is written for GPU targets.
        data_alignment: Value passed to ``tvm.build_config(data_alignment=...)``
            for the build.

    Returns:
        The object returned by ``tvm.build``. The original discarded it.
    """
    a = tvm.placeholder((n,), dtype=dtype)
    c = tvm.compute((n,), lambda i: a[i])
    s = tvm.create_schedule(c.op)
    # split() yields (outer, inner). Only the outer loop is bound to a GPU
    # block index. The inner loop is vectorized rather than bound to threads.
    outer, inner = s[c].split(s[c].op.axis[0], factor=factor)
    s[c].bind(outer, tvm.thread_axis('blockIdx.x'))
    s[c].vectorize(inner)
    # NOTE(review): presumably requires a TVM build with CUDA enabled when
    # target='cuda' — confirm in the deployment environment.
    with tvm.build_config(data_alignment=data_alignment):
        f = tvm.build(s, [a, c], target)
    return f
if __name__ == '__main__':
    main(1024)

# Generated CUDA code for main(1024). Each block copies 4 int8 elements with
# one 32-bit ``int`` load and store, offset by blockIdx.x * 4 bytes:
#
# extern "C" __global__ void default_function_kernel0( signed char* __restrict__ compute, signed char* __restrict__ placeholder) {
#   (( int*)(compute + (((int)blockIdx.x) * 4)))[0] = (( int*)(placeholder + (((int)blockIdx.x) * 4)))[0];
# }
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.