# Hexagon Test Script
#
# Originally published as GitHub gist Hzfengsy/dee4d278f0890938d93a7f66ead56bd5
# (the gist can be saved locally and opened with GitHub Desktop).
#
# NOTE: GitHub flagged this file as containing bidirectional Unicode text, which
# may be interpreted or compiled differently from how it appears on screen.
# Review it in an editor that reveals hidden Unicode characters; see GitHub's
# documentation on bidirectional Unicode characters to learn more.
import sys
from typing import Dict, List

import numpy as np
import tvm
import tvm.rpc.tracker
from tvm import tir
from tvm.contrib.hexagon.build import HexagonLauncher, Session
from tvm.contrib.hexagon.tools import HEXAGON_SIMULATOR_NAME
from tvm.script import tir as T
# Compile for Hexagon v73, and use that same Hexagon target object as the host
# target of the combined Target.
_HEXAGON_V73 = tvm.target.hexagon("v73")
TARGET = tvm.target.Target(_HEXAGON_V73, host=_HEXAGON_V73)

# Problem size: C[1, 1, N] = A[1, 1, K] @ B[K, N].
# A smaller configuration that also works with the schedule below is N = K = 512.
N = K = 1024
# Matrix-vector product written in TVMScript, with float16 inputs and output:
#   C[0, 0, n] = sum over k of A[0, 0, k] * B[k, n],  for n in [0, N), k in [0, K)
# A and C carry two leading size-1 axes; the matching T.grid extents are 1.
# NOTE(review): the running sum is stored directly in the float16 buffer C, so it
# is rounded to fp16 after every one of the K steps. Confirm that this precision
# is acceptable (this kernel is used for timing here, not for checking results).
@T.prim_func
def func(
    A: T.Buffer((1, 1, K), "float16"),
    B: T.Buffer((K, N), "float16"),
    C: T.Buffer((1, 1, N), "float16"),
):
    # Sets the `tir.noalias` function attribute, which declares that the buffer
    # arguments do not alias one another.
    T.func_attr({"tir.noalias": T.bool(True)})
    for i0, i1, i2, k in T.grid(1, 1, N, K):
        with T.block("NT_matmul"):
            # Three spatial axes ("S") and one reduction axis ("R") over k.
            v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
            with T.init():
                # Zero the output element before the first reduction step.
                C[v_i0, v_i1, v_i2] = T.float16(0)
            C[v_i0, v_i1, v_i2] = C[v_i0, v_i1, v_i2] + A[v_i0, v_i1, v_k] * B[v_k, v_i2]
def eval(
    func: tir.PrimFunc,
    var_dict: Dict[str, int],
    session: Session,
    use_sim: bool = False,
):
    """Build ``func`` for ``TARGET``, run it on the Hexagon session and return its output.

    NOTE: the name shadows the built-in ``eval``. It is kept for existing callers.

    Every buffer parameter of ``func`` gets a new array filled with
    ``np.random.uniform`` values (in [0, 1)), cast to the buffer's dtype and
    copied to the session device.

    Parameters
    ----------
    func : tir.PrimFunc
        Kernel to build and run.
    var_dict : Dict[str, int]
        Concrete values for symbolic (``tir.Var``) buffer dimensions, keyed by
        variable name. Ignored for buffers whose dimensions are all static.
    session : Session
        Open Hexagon session used to load the module and allocate device memory.
    use_sim : bool
        If True, call the kernel once without timing (for simulator runs).
        Otherwise, time a single call with ``time_evaluator`` and print the
        time, the total bytes and the bandwidth.

    Returns
    -------
    np.ndarray
        Contents of the last buffer argument after the run.

    Raises
    ------
    ValueError
        If a buffer dimension is neither a ``tir.IntImm`` nor a ``tir.Var``,
        or if a ``tir.Var`` dimension has no entry in ``var_dict``.
    """
    lib = tvm.build(func, target=TARGET)
    rt_mod = session.load_module(lib)
    device = session.device

    np_args: List[np.ndarray] = []
    for param in func.params:
        buffer = func.buffer_map[param]
        shape = []
        for dim in buffer.shape:
            if isinstance(dim, tir.IntImm):
                shape.append(dim.value)
            elif isinstance(dim, tir.Var):
                # Raise instead of assert so the check survives `python -O`.
                if dim.name not in var_dict:
                    raise ValueError(
                        f"No value given for symbolic dimension {dim.name!r} "
                        f"of buffer {buffer.name!r}"
                    )
                shape.append(var_dict[dim.name])
            else:
                raise ValueError(f"Unknown shape: {buffer.shape}")
        np_args.append(np.random.uniform(size=shape).astype(buffer.dtype))

    print("Allocating memory")
    args = [tvm.nd.array(arg, device) for arg in np_args]
    device.sync()

    print("Start running")
    if not use_sim:
        # number=1, repeat=1: one measured call, so any first-run effects are
        # included in the figure.
        time_eval = rt_mod.time_evaluator(rt_mod.entry_name, device, number=1, repeat=1)
        total_bytes = sum(arg.size * arg.itemsize for arg in np_args)
        time_s = time_eval(*args).mean  # seconds
        time_ms = time_s * 1e3
        # Divide bytes by *seconds* and by 1024**3 to get GiB/s. This matches the
        # binary MB (1024**2) used for "Total Bytes".
        bandwidth = total_bytes / time_s / (1024**3)
        print(
            f"Time (ms): {time_ms:.4f}",
            f"Total Bytes (MB): {total_bytes / (1024**2):.2f}",
            f"Memory (GB/s): {bandwidth:.2f}",
            sep="\t",
        )
    else:
        rt_mod(*args)
    device.sync()
    return args[-1].numpy()
def sch(
    func: tir.PrimFunc,
    *,
    vec_len: int = 64,
    num_parallel: int = 4,
    unroll_factor: int = 2,
    k_inner: int = 4,
):
    """Apply a parallel, unrolled and vectorized schedule to the single-block kernel ``func``.

    Loop transformations (defaults in parentheses):

    * the spatial loop ``j`` is split into
      ``[num_parallel (4), ?, unroll_factor (2), vec_len (64)]``;
    * the reduction loop ``k`` is split into ``[?, k_inner (4)]``;
    * loops are reordered to ``i, j0, j1, k0, j2, k1, j3``, so the unrolled
      ``j2`` and the vectorized ``j3`` sit inside the outer reduction loop;
    * ``j0`` is parallelized, ``j2`` unrolled and ``j3`` vectorized;
    * the reduction init is split into its own block, placed before ``k0``.

    NOTE(review): 64 float16 lanes is 128 bytes, presumably chosen to match
    the HVX vector width. Confirm this before changing ``vec_len``.

    Returns:
        The scheduled PrimFunc (``mod["main"]`` of the schedule).
    """
    # Named `schedule` so it no longer shadows this function's own name.
    schedule = tir.Schedule(func)
    # Unpacking fails loudly unless there is exactly one block under the root.
    (main_block,) = schedule.get_child_blocks(schedule.get_block("root"))
    *_, i, j, k = schedule.get_loops(main_block)
    j0, j1, j2, j3 = schedule.split(
        j, factors=[num_parallel, None, unroll_factor, vec_len]
    )
    k0, k1 = schedule.split(k, factors=[None, k_inner])
    schedule.reorder(i, j0, j1, k0, j2, k1, j3)
    schedule.parallel(j0)
    schedule.unroll(j2)
    schedule.vectorize(j3)
    # Move the T.init() zeroing into a separate block in front of loop k0.
    schedule.decompose_reduction(main_block, k0)
    return schedule.mod["main"]
def main():
    """Start an RPC tracker and a Hexagon launcher, then build, schedule and run ``func``.

    Pass ``--simulator`` on the command line to use the Hexagon simulator
    instead of the device at the hard-coded serial number. The launcher's
    server and the tracker are always shut down, even if building or running
    the kernel raises.
    """
    use_sim = "--simulator" in sys.argv[1:]
    tracker = tvm.rpc.tracker.Tracker("0.0.0.0", 9191)
    try:
        rpc_info = {
            "rpc_tracker_host": "0.0.0.0",
            "rpc_tracker_port": 9191,
            "adb_server_socket": "tcp:5037",
        }
        if use_sim:
            launcher = HexagonLauncher(
                serial_number=HEXAGON_SIMULATOR_NAME,
                rpc_info=rpc_info,
                workspace="dist",
            )
        else:
            launcher = HexagonLauncher(
                serial_number="172.16.3.155:41835",
                rpc_info=rpc_info,
            )
        launcher.start_server()
        try:
            func_name = "func"
            with launcher.create_session() as session:
                func = sch(globals()[func_name])
                # {"n": 256} would bind a symbolic dimension named "n".
                # `func` in this file only has static dimensions.
                eval(func, {"n": 256}, session, use_sim=use_sim)
        finally:
            launcher.stop_server()
    finally:
        tracker.terminate()
if __name__ == "__main__":
    # Run the benchmark only when executed as a script, not when imported.
    main()
# Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.