-
-
Save shifeiwen/3e4185d1cf87f979aec6ba0bf50a17f0 to your computer and use it in GitHub Desktop.
Hexagon Test Script
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sys | |
from typing import Dict, List | |
import numpy as np | |
import tvm.rpc.tracker | |
from tvm import tir | |
from tvm.contrib.hexagon.build import HexagonLauncher, Session | |
from tvm.contrib.hexagon.tools import HEXAGON_SIMULATOR_NAME | |
from tvm.script import tir as T | |
from tvm.contrib.hexagon import allocate_hexagon_array | |
# Build target: Hexagon v73, with the same target also acting as its host.
TARGET = tvm.target.hexagon("v73")
TARGET = tvm.target.Target(TARGET, host=TARGET)

# Problem size shared by every kernel below (other sizes tried: 4096, 1024).
N = K = 2048
# Columns of the packed weight matrix: eight values per uint32 word.
N_W = N // 8
# Columns of the scale matrix: one scale per group of 32 values.
N_S = N // 32
# float16 matrix-vector product: C[0, 0, n] = sum over k of A[0, 0, k] * B[k, n].
@T.prim_func
def gemv(
    A: T.Buffer((1, 1, K), "float16"),
    B: T.Buffer((K, N), "float16"),
    C: T.Buffer((1, 1, N), "float16"),
):
    T.func_attr({"tir.noalias": T.bool(True)})
    for i0, i1, i2, k in T.grid(1, 1, N, K):
        # The schedule in `sch` finds this block by its name.
        with T.block("NT_matmul"):
            # Three spatial axes followed by the reduction axis over K.
            v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
            with T.init():
                C[v_i0, v_i1, v_i2] = T.float16(0)
            C[v_i0, v_i1, v_i2] = (
                C[v_i0, v_i1, v_i2] + A[v_i0, v_i1, v_k] * B[v_k, v_i2]
            )
# Unpack 4-bit weights into float16.
#   A          packed weights, eight 4-bit values per uint32 word
#   B          float16 scales, one per group of 32 consecutive columns
#   dequantize output, (nibble - 7) * scale for every element
# Loop extents come from N so they stay in step with the buffer shapes when the
# problem size at the top of the file is changed.
@T.prim_func
def dequantize(
    A: T.Buffer((N, N_W), "uint32"),
    B: T.Buffer((N, N_S), "float16"),
    dequantize: T.Buffer((N, N), "float16"),
):
    T.func_attr({"tir.noalias": T.bool(True)})
    for i0, i1 in T.grid(T.int64(N), T.int64(N)):
        # The schedule in `sch_dequantize` finds this block by its name.
        with T.block("compute"):
            v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
            T.reads(A[v_i0, v_i1 // T.int64(8)], B[v_i0, v_i1 // T.int64(32)])
            T.writes(dequantize[v_i0, v_i1])
            # Shift the wanted nibble down to bits 0..3, mask it, re-centre by
            # subtracting 7, then apply the group scale.
            dequantize[v_i0, v_i1] = (
                T.Cast(
                    "float16",
                    T.bitwise_and(
                        T.shift_right(
                            A[v_i0, v_i1 // T.int64(8)],
                            T.Cast("uint32", v_i1 % T.int64(8) * T.int64(4)),
                        ),
                        T.uint32(15),
                    ),
                )
                - T.float16(7)
            ) * B[v_i0, v_i1 // T.int64(32)]
# Dequantize-then-GEMV in one function.
#   A      packed weights, eight 4-bit values per uint32 word
#   B      float16 scales, one per group of 32 consecutive columns
#   V_in   input vector
#   V_out  output vector
# Stage 1 expands A/B into a float16 scratch matrix allocated in the
# "global.vtcm" scope; stage 2 multiplies V_in by that matrix.
# The scratch shape and stage-1 loop extents come from N so they follow the
# problem size at the top of the file.
# NOTE(review): the indexing (rows of the scratch matrix addressed by k < K,
# V_out addressed by i2 < N) only lines up when N == K, as it is today.
@T.prim_func
def dequantize_gemv(
    A: T.Buffer((N, N_W), "uint32"),
    B: T.Buffer((N, N_S), "float16"),
    V_in: T.Buffer((1, 1, K), "float16"),
    V_out: T.Buffer((1, 1, K), "float16"),
):
    T.func_attr({"tir.noalias": T.bool(True)})
    dequantize = T.alloc_buffer((N, N), "float16", scope="global.vtcm")
    # Stage 1: unpack every 4-bit weight and scale it.
    for i0, i1 in T.grid(T.int64(N), T.int64(N)):
        # The schedule in `sch_dequantize` finds this block by its name.
        with T.block("compute"):
            v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
            T.reads(A[v_i0, v_i1 // T.int64(8)], B[v_i0, v_i1 // T.int64(32)])
            T.writes(dequantize[v_i0, v_i1])
            # Shift the wanted nibble down to bits 0..3, mask it, re-centre by
            # subtracting 7, then apply the group scale.
            dequantize[v_i0, v_i1] = (
                T.Cast(
                    "float16",
                    T.bitwise_and(
                        T.shift_right(
                            A[v_i0, v_i1 // T.int64(8)],
                            T.Cast("uint32", v_i1 % T.int64(8) * T.int64(4)),
                        ),
                        T.uint32(15),
                    ),
                )
                - T.float16(7)
            ) * B[v_i0, v_i1 // T.int64(32)]
    # Stage 2: V_out[0, 0, n] = sum over k of V_in[0, 0, k] * dequantize[k, n].
    for i0, i1, i2, k in T.grid(1, 1, N, K):
        # The schedule in `sch` finds this block by its name.
        with T.block("NT_matmul"):
            v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
            with T.init():
                V_out[v_i0, v_i1, v_i2] = T.float16(0)
            V_out[v_i0, v_i1, v_i2] = (
                V_out[v_i0, v_i1, v_i2]
                + V_in[v_i0, v_i1, v_k] * dequantize[v_k, v_i2]
            )
def _random_array(shape, dtype: str) -> np.ndarray:
    """Return a random array of ``shape`` whose values suit ``dtype``."""
    np_dtype = np.dtype(dtype)
    if np.issubdtype(np_dtype, np.integer):
        # uniform[0, 1) cast to an integer type is all zeros, which would make
        # every packed weight identical; draw from the full integer range.
        info = np.iinfo(np_dtype)
        return np.random.randint(
            int(info.min), int(info.max) + 1, size=shape, dtype=np_dtype
        )
    return np.random.uniform(size=shape).astype(np_dtype)


def _concrete_shape(buffer, var_dict: Dict[str, int]) -> List[int]:
    """Resolve every dimension of ``buffer.shape`` to a Python int."""
    shape = []
    for dim in buffer.shape:
        if isinstance(dim, tir.IntImm):
            shape.append(dim.value)
        elif isinstance(dim, tir.Var):
            if dim.name not in var_dict:
                raise ValueError(
                    f"No value supplied for symbolic dimension {dim.name!r}"
                )
            shape.append(var_dict[dim.name])
        else:
            raise ValueError(f"Unknown shape: {buffer.shape}")
    return shape


# NOTE: the name shadows the builtin ``eval``; it is kept because callers in
# this file use it.
def eval(
    func: tir.PrimFunc,
    var_dict: Dict[str, int],
    session: Session,
    use_sim: bool = False,
    func_name: str = "func",
):
    """Build ``func`` for ``TARGET``, run it in ``session`` and report timing.

    The compiled module is also dumped to ``vec_func.s`` and ``vec_func.ll``.
    One random input is generated per parameter of ``func``.

    Args:
        func: The PrimFunc to build and run.
        var_dict: Values for symbolic buffer dimensions, keyed by variable name.
        session: Hexagon session used to load the module and allocate arrays.
        use_sim: When true the module is called once without timing; otherwise
            it is timed with ``time_evaluator`` and the statistics are printed.
        func_name: Selects which argument is placed in ``"global.vtcm"``:
            argument 1 for ``"func"``, argument 0 for ``"dequantize"``,
            none for any other name.

    Returns:
        The last argument copied back to the host as a NumPy array.

    Raises:
        ValueError: If a buffer dimension is neither a constant nor a variable
            listed in ``var_dict``.
    """
    lib = tvm.build(func, target=TARGET)
    lib.save("vec_func.s", "s")
    lib.save("vec_func.ll", "ll")
    rt_mod = session.load_module(lib)
    device = session.device

    np_args: List[np.ndarray] = []
    for param in func.params:
        buffer = func.buffer_map[param]
        np_args.append(_random_array(_concrete_shape(buffer, var_dict), buffer.dtype))

    print("Allocating memory")
    # Allocate each input exactly once; at most one of them lives in VTCM.
    vtcm_index = {"func": 1, "dequantize": 0}.get(func_name)
    args = [
        allocate_hexagon_array(device, data=host_arr, mem_scope="global.vtcm")
        if position == vtcm_index
        else tvm.nd.array(host_arr, device)
        for position, host_arr in enumerate(np_args)
    ]
    device.sync()

    print("Start running")
    if not use_sim:
        time_eval = rt_mod.time_evaluator(rt_mod.entry_name, device, number=1, repeat=1)
        la = [host_arr.nbytes for host_arr in np_args]
        print(" la = ", la)
        total_bytes = sum(la)
        time = time_eval(*args).mean * 1e3  # milliseconds
        # Bytes per second expressed in units of 1024**3, matching the
        # 1024**2 unit used for the "MB" figure below.
        bandwidth = total_bytes / (time / 1e3) / (1024**3)
        print(
            f"Time (ms): {time:.4f}",
            f"Total Bytes (MB): {total_bytes / (1024**2):.2f}",
            f"Memory (GB/s): {bandwidth:.2f}",
            sep="\t",
        )
    else:
        rt_mod(*args)
    device.sync()
    return args[-1].numpy()
def sch(func: tir.PrimFunc):
    """Schedule the ``"NT_matmul"`` block of ``func`` and return the new PrimFunc.

    The output loop is split into 4 parallel chunks x a derived factor x
    2 unrolled x 64 vectorized lanes, and the reduction loop into
    outer x 4. The reduction init is then decomposed at the outer reduction
    loop.
    """
    # NOTE(review): 64 float16 lanes is presumably one 128-byte HVX vector — confirm.
    lanes = 64
    schedule = tir.Schedule(func)
    matmul = schedule.get_block(name="NT_matmul", func_name="main")

    # Only the three innermost loops are scheduled; any outer ones are ignored.
    *_, batch, out_col, red = schedule.get_loops(matmul)
    out_par, out_mid, out_unroll, out_vec = schedule.split(
        out_col, factors=[4, None, 2, lanes]
    )
    red_outer, red_inner = schedule.split(red, factors=[None, 4])

    schedule.reorder(batch, out_par, out_mid, red_outer, out_unroll, red_inner, out_vec)
    schedule.parallel(out_par)
    schedule.unroll(out_unroll)
    schedule.vectorize(out_vec)
    schedule.decompose_reduction(matmul, red_outer)
    return schedule.mod["main"]
def sch_switch(func: "tir.PrimFunc", func_name: str):
    """Apply the schedule that matches ``func_name`` to ``func``.

    Args:
        func: The PrimFunc to schedule.
        func_name: Kernel selector. ``"func"`` and ``"gemv"`` use ``sch``;
            ``"dequantize"`` uses ``sch_dequantize``; ``"dequantize_gemv"``
            applies ``sch_dequantize`` followed by ``sch``.

    Returns:
        The scheduled PrimFunc, or ``func`` itself for an unrecognised name.
    """
    # "gemv" is the name the kernel is actually defined under in this file, so
    # it must select the GEMV schedule as well as the legacy "func" key.
    if func_name in ("func", "gemv"):
        return sch(func)
    if func_name == "dequantize":
        return sch_dequantize(func)
    if func_name == "dequantize_gemv":
        return sch(sch_dequantize(func))
    return func
def sch_dequantize(func: tir.PrimFunc):
    """Schedule the ``"compute"`` block of ``func`` and return the new PrimFunc.

    The row loop is split into 4 parallel chunks x a derived factor x 2 x 64,
    with the factor-2 loop unrolled; the column loop is split into outer x 8.
    Nothing is vectorized here.
    """
    row_tile = 64
    schedule = tir.Schedule(func)
    compute = schedule.get_block(name="compute", func_name="main")

    row, col = schedule.get_loops(compute)
    row_par, row_mid, row_unroll, row_inner = schedule.split(
        row, factors=[4, None, 2, row_tile]
    )
    col_outer, col_inner = schedule.split(col, factors=[None, 8])

    schedule.reorder(row_par, row_mid, row_unroll, row_inner, col_outer, col_inner)
    schedule.parallel(row_par)
    schedule.unroll(row_unroll)
    return schedule.mod["main"]
def main():
    """Start an RPC tracker and Hexagon launcher, then benchmark one kernel.

    Pass ``--simulator`` on the command line to run on the Hexagon simulator
    instead of the physical device. The tracker and the launcher's server are
    shut down even when building or running the kernel fails.
    """
    use_sim = "--simulator" in sys.argv[1:]
    tracker = tvm.rpc.tracker.Tracker("0.0.0.0", 9191)
    try:
        rpc_info = {
            "rpc_tracker_host": "0.0.0.0",
            "rpc_tracker_port": 9191,
            "adb_server_socket": "tcp:5037",
        }
        if use_sim:
            launcher = HexagonLauncher(
                serial_number=HEXAGON_SIMULATOR_NAME,
                rpc_info=rpc_info,
                workspace="dist",
            )
        else:
            launcher = HexagonLauncher(
                serial_number="ZT36KM150048",
                rpc_info=rpc_info,
            )
        launcher.start_server()
        try:
            # Kernel to benchmark; "gemv" and "dequantize" are also defined in
            # this file.
            func_name = "dequantize_gemv"
            with launcher.create_session() as session:
                func = sch_switch(globals()[func_name], func_name)
                print(func)
                eval(func, {"n": 256}, session, use_sim=use_sim, func_name=func_name)
        finally:
            launcher.stop_server()
    finally:
        tracker.terminate()


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment