Skip to content

Instantly share code, notes, and snippets.

@Hzfengsy
Last active March 1, 2024 08:25
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save Hzfengsy/dee4d278f0890938d93a7f66ead56bd5 to your computer and use it in GitHub Desktop.
Hexagon Test Script
import sys
from typing import Dict, List
import numpy as np
import tvm.rpc.tracker
from tvm import tir
from tvm.contrib.hexagon.build import HexagonLauncher, Session
from tvm.contrib.hexagon.tools import HEXAGON_SIMULATOR_NAME
from tvm.script import tir as T
# Hexagon "v73" target; the same target object is also passed as its own host.
_HEXAGON_V73 = tvm.target.hexagon("v73")
TARGET = tvm.target.Target(_HEXAGON_V73, host=_HEXAGON_V73)

# Extents used in the buffer shapes of `func` below.
N = 1024
K = 1024
# Alternative smaller problem size: N, K = 512, 512
# TIR kernel computing C[0, 0, n] = sum_k A[0, 0, k] * B[k, n] in float16,
# with shapes fixed by the module-level constants N and K.
@T.prim_func
def func(
    A: T.Buffer((1, 1, K), "float16"),
    B: T.Buffer((K, N), "float16"),
    C: T.Buffer((1, 1, N), "float16"),
):
    T.func_attr({"tir.noalias": T.bool(True)})
    # The first two loops have extent 1; the last two iterate over N and K.
    for i0, i1, i2, k in T.grid(1, 1, N, K):
        with T.block("NT_matmul"):
            # "SSSR": the first three axes are spatial, the last is a reduction.
            v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
            with T.init():
                # Reduction initial value for each output element.
                C[v_i0, v_i1, v_i2] = T.float16(0)
            C[v_i0, v_i1, v_i2] = C[v_i0, v_i1, v_i2] + A[v_i0, v_i1, v_k] * B[v_k, v_i2]
def eval(
    func: tir.PrimFunc,
    var_dict: Dict[str, int],
    session: Session,
    use_sim: bool = False,
):
    """Build ``func`` for ``TARGET``, run it through ``session`` and return the last argument.

    Every parameter of ``func`` is backed by a randomly initialised array whose
    shape is taken from the parameter's buffer. When ``use_sim`` is false the
    function is run through the module's time evaluator and the elapsed time,
    total argument size and resulting bandwidth are printed; otherwise the
    module is simply invoked once.

    NOTE: the name shadows the ``eval`` builtin; it is kept because callers
    in this file use it.

    Args:
        func: The TIR function to build and run.
        var_dict: Values for symbolic shape variables, keyed by variable name.
        session: Session used to load the built module and provide the device.
        use_sim: If true, skip timing and only execute the module once.

    Returns:
        The last argument of ``func`` after execution, as a NumPy array.

    Raises:
        KeyError: A symbolic shape variable has no entry in ``var_dict``.
        ValueError: A shape dimension is neither a constant nor a variable.
    """
    lib = tvm.build(func, target=TARGET)
    rt_mod = session.load_module(lib)
    device = session.device

    np_args: List[np.ndarray] = []
    for param in func.params:
        buffer = func.buffer_map[param]
        shape = []
        for dim in buffer.shape:
            if isinstance(dim, tir.IntImm):
                shape.append(dim.value)
            elif isinstance(dim, tir.Var):
                # Explicit check instead of `assert`, which is stripped under -O.
                if dim.name not in var_dict:
                    raise KeyError(
                        f"No value provided for symbolic dimension {dim.name!r}"
                    )
                shape.append(var_dict[dim.name])
            else:
                raise ValueError(f"Unknown shape: {buffer.shape}")
        np_args.append(np.random.uniform(size=shape).astype(buffer.dtype))

    print("Allocating memory")
    args = [tvm.nd.array(arg, device) for arg in np_args]
    device.sync()

    print("Start running")
    if not use_sim:
        time_eval = rt_mod.time_evaluator(rt_mod.entry_name, device, number=1, repeat=1)
        total_bytes = sum(arg.nbytes for arg in np_args)
        # The evaluator reports seconds; keep that unit for the bandwidth so the
        # printed GB/s figure is bytes / second / 1024**3 rather than MB per ms.
        elapsed_s = time_eval(*args).mean
        bandwidth = total_bytes / elapsed_s / (1024**3)
        print(
            f"Time (ms): {elapsed_s * 1e3:.4f}",
            f"Total Bytes (MB): {total_bytes / (1024**2):.2f}",
            f"Memory (GB/s): {bandwidth:.2f}",
            sep="\t",
        )
    else:
        rt_mod(*args)
        device.sync()
    return args[-1].numpy()
def sch(func: tir.PrimFunc):
    """Apply a fixed tiling/parallel/vectorize schedule and return the scheduled function.

    The single child block of the root block is scheduled through its last
    three loops: the middle one is split four ways (4 x ? x 2 x 64) and the
    last one two ways (? x 4). The resulting loops are reordered, then the
    outermost split loop is parallelized, the size-2 loop unrolled and the
    size-64 loop vectorized. Finally the reduction is decomposed at the outer
    part of the last loop.
    """
    vector_width = 64
    schedule = tir.Schedule(func)

    # Exactly one block is expected under the root.
    root = schedule.get_block("root")
    (block,) = schedule.get_child_blocks(root)

    # Only the three innermost loops are scheduled; outer ones are ignored.
    *_, outer, mid, inner = schedule.get_loops(block)

    mid_par, mid_rest, mid_unroll, mid_vec = schedule.split(
        mid, factors=[4, None, 2, vector_width]
    )
    inner_outer, inner_tile = schedule.split(inner, factors=[None, 4])

    schedule.reorder(
        outer, mid_par, mid_rest, inner_outer, mid_unroll, inner_tile, mid_vec
    )
    schedule.parallel(mid_par)
    schedule.unroll(mid_unroll)
    schedule.vectorize(mid_vec)
    schedule.decompose_reduction(block, inner_outer)

    return schedule.mod["main"]
def main():
    """Start an RPC tracker and Hexagon launcher, then schedule and run ``func``.

    Passing ``--simulator`` on the command line selects the Hexagon simulator
    launcher; otherwise the launcher is created for the hard-coded device
    serial number. The launcher's server and the tracker are always shut down,
    even when building or running the kernel fails.
    """
    use_sim = "--simulator" in sys.argv[1:]

    tracker = tvm.rpc.tracker.Tracker("0.0.0.0", 9191)
    try:
        rpc_info = {
            "rpc_tracker_host": "0.0.0.0",
            "rpc_tracker_port": 9191,
            "adb_server_socket": "tcp:5037",
        }
        if use_sim:
            launcher = HexagonLauncher(
                serial_number=HEXAGON_SIMULATOR_NAME,
                rpc_info=rpc_info,
                workspace="dist",
            )
        else:
            launcher = HexagonLauncher(
                serial_number="172.16.3.155:41835",
                rpc_info=rpc_info,
            )

        launcher.start_server()
        try:
            with launcher.create_session() as session:
                # Use a distinct local name so the module-level `func` stays visible.
                scheduled = sch(func)
                eval(scheduled, {"n": 256}, session, use_sim=use_sim)
        finally:
            # Stop the server even if scheduling, building or running raised.
            launcher.stop_server()
    finally:
        tracker.terminate()
if __name__ == "__main__":
main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment