Skip to content

Instantly share code, notes, and snippets.

@shifeiwen
Forked from Hzfengsy/run_hexagon.py
Last active March 1, 2024 12:30
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save shifeiwen/3e4185d1cf87f979aec6ba0bf50a17f0 to your computer and use it in GitHub Desktop.
Save shifeiwen/3e4185d1cf87f979aec6ba0bf50a17f0 to your computer and use it in GitHub Desktop.
Hexagon Test Script
import sys
from typing import Dict, List
import numpy as np
import tvm.rpc.tracker
from tvm import tir
from tvm.contrib.hexagon.build import HexagonLauncher, Session
from tvm.contrib.hexagon.tools import HEXAGON_SIMULATOR_NAME
from tvm.script import tir as T
from tvm.contrib.hexagon import allocate_hexagon_array
# Compile for Hexagon v73; the same target object is also used as the host.
TARGET = tvm.target.hexagon("v73")
TARGET = tvm.target.Target(TARGET, host=TARGET)

# Problem size shared by every kernel below.
# Other sizes that have been tried: 4096 x 4096 and 1024 x 1024.
N = K = 2048
# Column count of the packed uint32 weight buffer (eight 4-bit values per word).
N_W = N // 8
# Column count of the float16 scale buffer (one scale per 32 columns).
N_S = N // 32
# float16 matrix-vector product written in TVMScript:
#   C[0, 0, n] = sum over k of A[0, 0, k] * B[k, n]
# Only `#` comments are used for documentation here so that the statements
# handed to the TVMScript parser stay exactly as they were.
@T.prim_func
def gemv(
    A: T.Buffer((1, 1, K), "float16"),
    B: T.Buffer((K, N), "float16"),
    C: T.Buffer((1, 1, N), "float16"),
):
    T.func_attr({"tir.noalias": T.bool(True)})
    for i0, i1, i2, k in T.grid(1, 1, N, K):
        # The block name "NT_matmul" is what sch() looks up when scheduling.
        with T.block("NT_matmul"):
            # i0, i1, i2 are spatial axes ("S"); k is the reduction axis ("R").
            v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
            with T.init():
                # Zero the accumulator before the first reduction step.
                C[v_i0, v_i1, v_i2] = T.float16(0)
            C[v_i0, v_i1, v_i2] = C[v_i0, v_i1, v_i2] + A[v_i0, v_i1, v_k] * B[v_k, v_i2]
# Expand packed 4-bit weights into float16 values (TVMScript kernel).
#   A          : packed weights, eight 4-bit values per uint32 word
#   B          : float16 scales, one per 32 consecutive output columns
#   dequantize : output, (nibble - 7) * scale for every element
# Documentation uses `#` comments only so the TVMScript parser sees nothing new.
@T.prim_func
def dequantize(
    A: T.Buffer((N, N_W), "uint32"),
    B: T.Buffer((N, N_S), "float16"),
    dequantize: T.Buffer((N, N), "float16"),
):
    T.func_attr({"tir.noalias": T.bool(True)})
    # Loop extents are derived from the module-level N (previously a literal
    # 2048) so they always agree with the buffer shapes declared above.
    for i0, i1 in T.grid(T.int64(N), T.int64(N)):
        with T.block("compute"):
            v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
            T.reads(A[v_i0, v_i1 // T.int64(8)], B[v_i0, v_i1 // T.int64(32)])
            T.writes(dequantize[v_i0, v_i1])
            # Pick the word holding this column, shift the wanted nibble down
            # to bit 0, mask it to 4 bits, re-centre by 7 and apply the scale.
            dequantize[v_i0, v_i1] = (
                T.Cast(
                    "float16",
                    T.bitwise_and(
                        T.shift_right(
                            A[v_i0, v_i1 // T.int64(8)],
                            T.Cast("uint32", v_i1 % T.int64(8) * T.int64(4)),
                        ),
                        T.uint32(15),
                    ),
                )
                - T.float16(7)
            ) * B[v_i0, v_i1 // T.int64(32)]
# Fused kernel (TVMScript): dequantize the packed 4-bit weights into a
# temporary float16 matrix, then multiply the input vector by that matrix.
#   A     : packed weights, eight 4-bit values per uint32 word
#   B     : float16 scales, one per 32 consecutive columns
#   V_in  : input vector
#   V_out : output vector
# Documentation uses `#` comments only so the TVMScript parser sees nothing new.
# NOTE(review): V_out is declared with extent K but indexed over N, and the
# temporary is indexed [k, n] while sized (N, N); this is only consistent
# while N == K, as configured at module level.
@T.prim_func
def dequantize_gemv(
    A: T.Buffer((N, N_W), "uint32"),
    B: T.Buffer((N, N_S), "float16"),
    V_in: T.Buffer((1, 1, K), "float16"),
    V_out: T.Buffer((1, 1, K), "float16"),
):
    T.func_attr({"tir.noalias": T.bool(True)})
    # Temporary for the dequantized weights, requested in the "global.vtcm"
    # scope. Its shape and the loop extents below are derived from N
    # (previously literal 2048s) so they track the buffer shapes above.
    dequantize = T.alloc_buffer((N, N), "float16", scope="global.vtcm")
    for i0, i1 in T.grid(T.int64(N), T.int64(N)):
        with T.block("compute"):
            v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
            T.reads(A[v_i0, v_i1 // T.int64(8)], B[v_i0, v_i1 // T.int64(32)])
            T.writes(dequantize[v_i0, v_i1])
            # Shift the wanted nibble to bit 0, mask to 4 bits, re-centre by 7
            # and multiply by the per-group scale.
            dequantize[v_i0, v_i1] = (
                T.Cast(
                    "float16",
                    T.bitwise_and(
                        T.shift_right(
                            A[v_i0, v_i1 // T.int64(8)],
                            T.Cast("uint32", v_i1 % T.int64(8) * T.int64(4)),
                        ),
                        T.uint32(15),
                    ),
                )
                - T.float16(7)
            ) * B[v_i0, v_i1 // T.int64(32)]
    for i0, i1, i2, k in T.grid(1, 1, N, K):
        # The block name "NT_matmul" is what sch() looks up when scheduling.
        with T.block("NT_matmul"):
            v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
            with T.init():
                V_out[v_i0, v_i1, v_i2] = T.float16(0)
            V_out[v_i0, v_i1, v_i2] = V_out[v_i0, v_i1, v_i2] + V_in[v_i0, v_i1, v_k] * dequantize[v_k, v_i2]
def eval(
    func: tir.PrimFunc,
    var_dict: Dict[str, int],
    session: Session,
    use_sim: bool = False,
    func_name: str = "func",
):
    """Build ``func`` for the Hexagon target, run it once and return its output.

    NOTE: the name shadows the builtin ``eval``; it is kept so existing
    callers keep working.

    Args:
        func: The (scheduled) PrimFunc to build and execute.
        var_dict: Values for symbolic buffer dimensions, keyed by the
            dimension variable's name.
        session: Open Hexagon session used to upload the module and to
            allocate device memory.
        use_sim: When true the module is simply invoked once; otherwise it is
            timed with ``time_evaluator`` and a bandwidth summary is printed.
        func_name: Selects which arguments are placed in the ``global.vtcm``
            scope: ``"func"`` puts argument 1 there, ``"dequantize"`` puts
            argument 0 there, every other name uses default device arrays.

    Returns:
        The last argument (the output buffer) copied back as a NumPy array.

    Raises:
        ValueError: If a buffer dimension is neither a constant nor a symbolic
            variable listed in ``var_dict``.
    """
    lib = tvm.build(func, target=TARGET)
    # Dump the generated assembly and LLVM IR next to the script for inspection.
    lib.save("vec_func.s", "s")
    lib.save("vec_func.ll", "ll")
    rt_mod = session.load_module(lib)
    device = session.device

    # Create one random host array per function parameter.
    np_args: List[np.ndarray] = []
    for param in func.params:
        buffer = func.buffer_map[param]
        shape = []
        for dim in buffer.shape:
            if isinstance(dim, tir.IntImm):
                shape.append(dim.value)
            elif isinstance(dim, tir.Var):
                # Explicit check instead of ``assert`` so it survives ``python -O``.
                if dim.name not in var_dict:
                    raise ValueError(
                        f"No value supplied for symbolic dimension {dim.name!r}"
                    )
                shape.append(var_dict[dim.name])
            else:
                raise ValueError(f"Unknown shape: {buffer.shape}")
        dtype = np.dtype(str(buffer.dtype))
        if np.issubdtype(dtype, np.integer):
            # uniform() yields values in [0, 1), which an integer cast turns
            # into all zeros; draw from the full integer range instead.
            limits = np.iinfo(dtype)
            data = np.random.randint(limits.min, limits.max + 1, size=shape, dtype=dtype)
        else:
            data = np.random.uniform(size=shape).astype(dtype)
        np_args.append(data)

    print("Allocating memory")
    # Positions of the arguments that go to VTCM, per kernel name. Each
    # argument is allocated on the device exactly once.
    vtcm_positions = {"func": {1}, "dequantize": {0}}.get(func_name, set())
    args = [
        allocate_hexagon_array(device, data=host_arr, mem_scope="global.vtcm")
        if pos in vtcm_positions
        else tvm.nd.array(host_arr, device)
        for pos, host_arr in enumerate(np_args)
    ]
    device.sync()

    print("Start running")
    if not use_sim:
        time_eval = rt_mod.time_evaluator(rt_mod.entry_name, device, number=1, repeat=1)
        la = [arg.size * arg.itemsize for arg in np_args]
        print(" la = ", la)
        total_bytes = sum(la)
        elapsed_ms = time_eval(*args).mean * 1e3
        # Bytes per second expressed in 1024**3 units, matching the
        # 1024-based "MB" figure printed next to it. (Dividing bytes by
        # milliseconds and 1024**2 gives MiB/ms, which is about 2.4% above
        # the GB/s figure the label promises.)
        bandwidth = total_bytes / (elapsed_ms / 1e3) / (1024**3)
        print(
            f"Time (ms): {elapsed_ms:.4f}",
            f"Total Bytes (MB): {total_bytes / (1024**2):.2f}",
            f"Memory (GB/s): {bandwidth:.2f}",
            sep="\t",
        )
    else:
        rt_mod(*args)
        device.sync()
    return args[-1].numpy()
def sch(func: tir.PrimFunc):
    """Schedule the "NT_matmul" block of ``func`` and return the new PrimFunc.

    The last three loops around the block are taken as (outer, column,
    reduction). The column loop is split four ways and the reduction loop two
    ways; the pieces are then reordered, with the outermost column piece
    parallelised, one piece unrolled and the innermost one vectorised.
    Finally the reduction init is decomposed at the outer reduction loop.
    """
    vector_width = 64
    schedule = tir.Schedule(func)
    matmul = schedule.get_block(name="NT_matmul", func_name="main")

    # Only the three innermost loops matter; any loops before them are ignored.
    *_, outer, col, red = schedule.get_loops(matmul)

    col_par, col_mid, col_unroll, col_vec = schedule.split(
        col, factors=[4, None, 2, vector_width]
    )
    red_outer, red_inner = schedule.split(red, factors=[None, 4])
    schedule.reorder(outer, col_par, col_mid, red_outer, col_unroll, red_inner, col_vec)

    schedule.parallel(col_par)
    schedule.unroll(col_unroll)
    schedule.vectorize(col_vec)
    schedule.decompose_reduction(matmul, red_outer)
    return schedule.mod["main"]
def sch_switch(func: tir.PrimFunc, func_name: str):
    """Apply the schedule that belongs to the kernel called ``func_name``.

    Args:
        func: The unscheduled PrimFunc.
        func_name: Kernel name. ``"gemv"`` (and its older alias ``"func"``)
            gets the matmul schedule, ``"dequantize"`` the dequantize
            schedule, and ``"dequantize_gemv"`` both, dequantize first.

    Returns:
        The scheduled PrimFunc, or ``func`` unchanged for any other name.
    """
    # The matmul kernel in this file is named "gemv"; "func" is kept as an
    # alias so existing callers that still pass the old key behave as before.
    if func_name in ("func", "gemv"):
        return sch(func)
    if func_name == "dequantize":
        return sch_dequantize(func)
    if func_name == "dequantize_gemv":
        return sch(sch_dequantize(func))
    return func
def sch_dequantize(func: tir.PrimFunc):
    """Schedule the "compute" (dequantize) block of ``func``.

    The first loop of the block is split four ways and the second loop two
    ways (inner factor 8). The outermost piece of the first loop is
    parallelised and its third piece is unrolled. Returns the scheduled
    PrimFunc.
    """
    inner_rows = 64
    schedule = tir.Schedule(func)
    compute = schedule.get_block(name="compute", func_name="main")

    row, col = schedule.get_loops(compute)
    row_par, row_mid, row_unroll, row_inner = schedule.split(
        row, factors=[4, None, 2, inner_rows]
    )
    col_outer, col_inner = schedule.split(col, factors=[None, 8])

    # Row pieces first, then column pieces.
    schedule.reorder(row_par, row_mid, row_unroll, row_inner, col_outer, col_inner)
    schedule.parallel(row_par)
    schedule.unroll(row_unroll)
    return schedule.mod["main"]
def main():
    """Start the RPC tracker and Hexagon launcher, then run one kernel.

    Passing ``--simulator`` on the command line targets the Hexagon simulator
    instead of the physical device. The tracker and the launcher's server are
    always shut down, even when building or running the kernel fails.
    """
    use_sim = "--simulator" in sys.argv[1:]
    tracker = tvm.rpc.tracker.Tracker("0.0.0.0", 9191)
    try:
        rpc_info = {
            "rpc_tracker_host": "0.0.0.0",
            "rpc_tracker_port": 9191,
            "adb_server_socket": "tcp:5037",
        }
        if use_sim:
            launcher = HexagonLauncher(
                serial_number=HEXAGON_SIMULATOR_NAME,
                rpc_info=rpc_info,
                workspace="dist",
            )
        else:
            launcher = HexagonLauncher(
                serial_number="ZT36KM150048",
                rpc_info=rpc_info,
            )
        launcher.start_server()
        try:
            # Kernel to run; "gemv" and "dequantize" are the other options.
            func_name = "dequantize_gemv"
            with launcher.create_session() as session:
                # Kernels are module-level PrimFuncs, looked up by name.
                func = sch_switch(globals()[func_name], func_name)
                print(func)
                eval(func, {"n": 256}, session, use_sim=use_sim, func_name=func_name)
        finally:
            # Stop the server even if the build or the run raised.
            launcher.stop_server()
    finally:
        # Never leave the tracker running behind a failed run.
        tracker.terminate()
# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment