Created
August 22, 2023 10:38
-
-
Save LeiWang1999/09c9d02d5ce775bebeb9497e58ce3b25 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
Considering a gemm problem (here a conv2d lowered to GEMM through im2col), in this part we try to leverage ldmatrix, mma, and a tensorized fragment store (stmatrix-style) to do the computation.
The ldmatrix step loads operand tiles from shared memory into warp-level fragments, and the store step writes the accumulator fragments back to shared memory.
The mma is used to do the computation.
thread_x will be set to 32, which is the number of threads in a warp.
thread_y and thread_z will be set to the values that describe the array of warps.
"""
import tvm | |
from tvm.script import tir as T | |
from tvm import te, tir, topi | |
import numpy as np | |
import os | |
import math | |
import sys | |
# add path ../.. | |
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) | |
from tvm.tir.function import TensorIntrin | |
from tvm.tir.tensor_intrin.cuda import ( | |
LDMATRIX_16x16_A_INTRIN, | |
LDMATRIX_16x16_B_INTRIN, | |
LDMATRIX_16x16_B_TRANS_INTRIN, | |
LDMATRIX_16x32_A_INTRIN, | |
LDMATRIX_32x16_B_INTRIN, | |
LDMATRIX_16x32_B_TRANS_INTRIN, | |
MMA_f16f16f32_INTRIN, | |
MMA_f16f16f32_TRANS_INTRIN, | |
MMA_f16f16f16_INTRIN, | |
MMA_f16f16f16_TRANS_INTRIN, | |
MMA_i8i8i32_INTRIN, | |
MMA_i8i8i32_TRANS_INTRIN, | |
MMA_fill_16x16_f32_INTRIN, | |
MMA_fill_16x16_f16_INTRIN, | |
MMA_fill_16x16_i32_INTRIN, | |
MMA_store_16x16_f32_global_INTRIN, | |
MMA_store_16x16_f16_global_INTRIN, | |
MMA_store_16x16_i32_global_INTRIN, | |
shared_16x16_to_ldmatrix_32x8_layout, | |
shared_32x16_to_ldmatrix_32x16_layout, | |
shared_16x32_to_ldmatrix_32x16_layout, | |
get_mma_store_intrin | |
) | |
# Name under which the "shared"-scope float16 MMA store intrinsic is
# registered; it is looked up by this name when the accumulator store is
# tensorized further down in the script.
MMA_store_16x16_f16_shared_INTRIN = "mma_store_16x16_f16_shared_"
_shared_store_intrin = get_mma_store_intrin("float16", 8, "shared")
TensorIntrin.register(MMA_store_16x16_f16_shared_INTRIN, *_shared_store_intrin)
import sqlite3 | |
class NNIDatabase(object):
    """Small reader for the SQLite database written by an NNI experiment.

    The class only issues ``SELECT`` statements against the ``MetricData``
    and ``TrialJobEvent`` tables. Call :meth:`connect` before any query and
    :meth:`close` when done.
    """

    def __init__(self, db_path: str):
        self.db_path = db_path
        self.conn = None
        self.cursor = None

    def connect(self):
        """Open ``db_path`` and create the cursor used by the query methods."""
        self.conn = sqlite3.connect(self.db_path)
        self.cursor = self.conn.cursor()

    @staticmethod
    def _metric_value(record):
        """Return the metric stored in the last column of *record* as a float.

        The value is stored as text that may be wrapped in double quotes.
        """
        return float(record[-1].replace('"', ''))

    def _lowest_metric_record(self):
        """Return the ``MetricData`` row with the smallest metric, or ``None``."""
        self.cursor.execute("SELECT * FROM MetricData")
        records = self.cursor.fetchall()
        if not records:
            return None
        # min() returns the first minimal row, matching a stable ascending sort.
        return min(records, key=self._metric_value)

    @staticmethod
    def _parse_event_payload(payload):
        """Decode the text payload of a ``TrialJobEvent`` row into a dict.

        The payload is parsed as data rather than executed. JSON is tried
        first; Python-literal syntax (e.g. single-quoted dicts) is accepted as
        a fallback so payloads the previous ``eval`` handled still load.
        """
        import ast
        import json

        try:
            return json.loads(payload)
        except ValueError:
            return ast.literal_eval(payload)

    def get_best_params(self):
        """Return the parameters of the trial with the lowest metric.

        Returns:
            The ``'parameters'`` entry of the ``WAITING`` event belonging to
            the best trial (the last such event wins), or ``{}`` when there
            are no metric rows or no ``WAITING`` event.
        """
        best_params = {}
        best_record = self._lowest_metric_record()
        if best_record is None:
            print("No metric data records found.")
        else:
            print(f"Lowest metric data record: {best_record}")
            trial_job_id = best_record[1]
            self.cursor.execute("SELECT * FROM TrialJobEvent WHERE trialJobId = ?", (trial_job_id,))
            for record in self.cursor.fetchall():
                if record[2] == "WAITING":
                    best_params = self._parse_event_payload(record[3])['parameters']
        print(f"Best params: {best_params}")
        return best_params

    def get_best_latency(self):
        """Return the full ``MetricData`` row with the lowest metric.

        Returns:
            The row tuple, or ``None`` (after printing a notice) when the
            table is empty.
        """
        best_record = self._lowest_metric_record()
        if best_record is None:
            print("No metric data records found.")
        return best_record

    def close(self):
        """Close the connection; a no-op if :meth:`connect` was never called."""
        if self.conn is not None:
            self.conn.close()
            self.conn = None
            self.cursor = None
# Default log directory: a fixed prefix followed by this script's file name
# with its extension stripped.
fname = os.path.splitext(os.path.basename(__file__))[0]
log_path = f"progress/tensorirscript_imma/{fname}"
# Running index prepended to every dumped file so dumps sort in creation order.
count = 0


def write_code(code, path, fname):
    """Write *code* to ``<path>/<count>.<fname>`` and advance the counter.

    Args:
        code: Text to write.
        path: Target directory; created (with parents) if missing. An empty
            string means the current directory.
        fname: Base file name; the current value of the module-level
            ``count`` and a dot are prepended to it.
    """
    global count
    numbered_name = f"{count}.{fname}"
    count += 1
    if path:
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(path, exist_ok=True)
    with open(os.path.join(path, numbered_name), "w") as f:
        f.write(code)
def write_sch(sch, path, fname):
    """Dump the current state of schedule *sch* as two files under *path*.

    ``<fname>.py`` receives the script form of the module's ``main`` function
    and ``<fname>.cu`` receives ``sch.mod.astext()``. Both go through
    ``write_code``, so each file name also gets a running numeric prefix.
    """
    main_script = sch.mod["main"].script()
    write_code(main_script, path, fname + ".py")
    module_text = sch.mod.astext()
    write_code(module_text, path, fname + ".cu")
# Text post-processing hook registered with TVM through `tvm.register_func`.
# It receives a source string, applies literal string replacements and returns
# the result (presumably the generated CUDA source, judging by the function
# name and the patterns replaced - confirm).
#   * The `TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST` define is switched from 1 to 0
#     (the spelling is the literal text being matched; do not "fix" it).
#   * The two `cp.async` replacements drop every `commit_group` statement and
#     re-insert one directly in front of each `wait_group 0` statement.
# `stage` is read from module globals at call time.
# NOTE(review): indentation was lost in this copy. Presumably both `cp.async`
# replacements belong to the `if stage == 1` branch, and the whitespace that
# originally preceded the second line of the multi-line replacement string is
# unknown - confirm against the original script.
@tvm.register_func | |
def tvm_callback_cuda_postproc(code): | |
code = code.replace("#define TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST 1", "#define TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST 0") | |
if stage == 1: | |
code = code.replace( | |
'''__asm__ __volatile__("cp.async.commit_group;");''', ' ') | |
code = code.replace( | |
'''__asm__ __volatile__("cp.async.wait_group 0;");''', '''__asm__ __volatile__("cp.async.commit_group;"); | |
__asm__ __volatile__("cp.async.wait_group 0;");''') | |
# if the next line is a __syncthreads(), replace it with number | |
return code | |
# Conv2d workloads to generate kernels for. Each entry is
# (N, C, H, W, F, K, S, D, P, G, HO, WO): batch, input channels, input
# height, input width, output channels, kernel size, stride, dilation,
# padding, groups, output height, output width.
resnet_18_conv2d = [
    (128, 3, 224, 224, 64, 7, 2, 1, 3, 1, 112, 112),
    # (128, 16, 224, 224, 64, 7, 2, 1, 0, 1, 112, 112),
]

# When true, the kernel output is checked against a PyTorch reference.
VERIFY = True
for N, C, H, W, F, K, S, D, P, G, HO, WO in resnet_18_conv2d:
    # NOTE(review): indentation was lost in this copy of the script. The
    # statements that follow this setup section (module definition, schedule,
    # build, verification, export) are presumably part of this loop body as
    # well - confirm against the original.

    # Workload identifier; used both for the log directory and to locate the
    # NNI experiment database holding the tuned parameters.
    fname = f"c3_nhwc_nhwc_conv2d_N{N}_C{C}_H{H}_W{W}_F{F}_K{K}_S{S}_D{D}_P{P}_G{G}_HO{HO}_WO{WO}"
    log_path = "/workspace/v-leiwang3/asplos_2024_script/end2end_kernels/" + fname + '/'
    db_path = f"/workspace/v-leiwang3/asplos_2024_script/conv2d/resnet18_nhwc_ladder/nni-experiments-20230807/{fname}/db/nni.sqlite"

    # Fetch the best tuned parameters; always release the connection, even if
    # the lookup raises.
    db = NNIDatabase(db_path)
    db.connect()
    try:
        params = db.get_best_params()
    finally:
        db.close()
    print(params)

    # The sizes of inputs and filters
    batch_size = N
    height = H
    width = W
    in_channels = C
    out_channels = F
    kernel_h = K
    kernel_w = K
    pad_h = P
    pad_w = P
    stride_h = S
    stride_w = S
    dilation_h = D
    dilation_w = D
    groups = G
    output_height = HO
    output_width = WO
    print(height)

    # TensorCore shape
    wmma_m = 16
    wmma_n = 16
    wmma_k = 16
    assert batch_size % wmma_k == 0
    # assert in_channels % wmma_m == 0
    assert out_channels % wmma_n == 0

    import nni

    # tuning params
    block_row_warps = params['block_row_warps']
    block_col_warps = params['block_col_warps']
    warp_row_tiles = params['warp_row_tiles']
    warp_col_tiles = params['warp_col_tiles']
    chunk = params['chunk']
    stage = params['stage']
    raster = params['raster']
    use_async = params['use_async']
    vec = params['vec']
    warp_size = 32

    # Input feature map: (N, H, W, IC)
    data_shape = (
        batch_size,
        height,
        width,
        in_channels,
    )
    # Kernel: (OC, KH, KW, IC)
    kernel_shape = (
        out_channels,
        kernel_h,
        kernel_w,
        in_channels,
    )
    # Output feature map: (N, HO, WO, OC)
    output_shape = (
        batch_size,
        output_height,
        output_width,
        out_channels,
    )

    # GEMM view of the convolution. N and K are deliberately rebound here from
    # the loop variables (batch / kernel size) to the GEMM column and reduction
    # extents, which is what the code below means by N and K.
    M = batch_size * output_height * output_width
    N = out_channels
    K = kernel_h * kernel_w * in_channels
    print(M, N, K)
    # padding KPAD as the multiple chunk * wmma_k
    KPAD = (K + chunk * wmma_k - 1) // (chunk * wmma_k) * (chunk * wmma_k)
    print("KPAD", KPAD)
# Algorithm: NHWC conv2d written as a GEMM.
#   Apad            zero-padded input, NHWC
#   data_im2col     [M, K] im2col view of Apad
#   data_im2col_Pad / W_pad   both operands zero-padded along K up to KPAD
#   Conv            [M, N] accumulation over KPAD
#   Conv_bias       bias add followed by max(x, 0)
#   Conv_shared     copy of the result into the output buffer
@tvm.script.ir_module
class MyModule:
    @T.prim_func
    def main(a: T.handle, w: T.handle, bias: T.handle, conv: T.handle):
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        W = T.match_buffer(w, (N, K), dtype="float16")
        A = T.match_buffer(a, (batch_size, height, width, in_channels), dtype="float16")
        Conv = T.match_buffer(conv, (M, N), dtype="float16")
        Bias = T.match_buffer(bias, (out_channels), dtype="float16")
        data_im2col = T.alloc_buffer([M, K], dtype="float16")
        # Apad is indexed [n, h, w, c] everywhere below, so both its shape and
        # the loop grid that fills it must be NHWC as well.
        Apad = T.alloc_buffer((batch_size, height + 2*pad_h, width + 2*pad_w, in_channels), dtype="float16")
        Conv_shared = T.alloc_buffer((M, N), dtype="float16", scope="shared")
        Conv_bias = T.alloc_buffer((M, N), dtype="float16", scope="local")
        data_im2col_Pad = T.alloc_buffer([M, KPAD], dtype="float16")
        W_pad = T.alloc_buffer([N, KPAD], dtype="float16")
        for n, h, w, c in T.grid(batch_size, height + 2*pad_h, width + 2*pad_w, in_channels):
            with T.block("Apad"):
                vn, vh, vw, vc = T.axis.remap("SSSS", [n, h, w, c])
                Apad[vn, vh, vw, vc] = T.if_then_else(pad_h <= vh and vh < height + pad_h and pad_w <= vw and vw < width + pad_w, A[vn, vh - pad_h, vw - pad_w, vc], T.float16(0), dtype="float16")
        for x, y in T.grid(M, K):
            with T.block("data_im2col"):
                v_x, v_y = T.axis.remap("SS", [x, y])
                # Row v_x enumerates (image, out_h, out_w); column v_y
                # enumerates (kernel_h, kernel_w, in_channel).
                data_im2col[v_x, v_y] = Apad[
                    v_x // (output_height * output_width),
                    stride_h * ((v_x % (output_height * output_width)) // output_width) + dilation_h * ((v_y // (in_channels)) // kernel_w),
                    stride_w * ((v_x % (output_height * output_width)) % output_width) + dilation_w * ((v_y // (in_channels)) % kernel_w),
                    v_y % (in_channels),
                ]
        for i, k in T.grid(M, KPAD):
            with T.block("data_Pad"):
                vi, vk = T.axis.remap("SS", [i, k])
                data_im2col_Pad[vi, vk] = T.if_then_else(vk < K, data_im2col[vi, vk], T.float16(0), dtype="float16")
        for j, k in T.grid(N, KPAD):
            with T.block("weight_Pad"):
                vk, vj = T.axis.remap("SS", [k, j])
                W_pad[vj, vk] = T.if_then_else(vk < K, W[vj, vk], T.float16(0), dtype="float16")
        for y, x, k in T.grid(M, N, KPAD):
            with T.block("Conv"):
                v_x, v_y, v_k = T.axis.remap("SSR", [y, x, k])
                with T.init():
                    Conv_shared[v_x, v_y] = T.float16(0.0)
                Conv_shared[v_x, v_y] = Conv_shared[v_x, v_y] + data_im2col_Pad[v_x, v_k] * W_pad[v_y, v_k]
        for x, y in T.grid(M, N):
            with T.block("Conv_bias"):
                vom, von = T.axis.remap("SS", [x, y])
                Conv_bias[vom, von] = T.if_then_else(Conv_shared[vom, von] + Bias[von] > T.float16(0.0), Conv_shared[vom, von] + Bias[von], T.float16(0), dtype="float16")
        for x, y in T.grid(M, N):
            with T.block("Conv_shared"):
                vom, von = T.axis.remap("SS", [x, y])
                Conv[vom, von] = Conv_bias[vom, von]
ir_module = MyModule
# print(ir_module)
sch = tvm.tir.Schedule(ir_module, debug_mask="all")
write_sch(sch, log_path, "original")

# Handles to the blocks declared in MyModule.
block_pad = sch.get_block("Apad")
block_data_pad = sch.get_block("data_Pad")
block_weight_pad = sch.get_block("weight_Pad")
block_im2col = sch.get_block("data_im2col")
block_conv = sch.get_block("Conv")
block_conv_bias = sch.get_block("Conv_bias")
block_conv_shared = sch.get_block("Conv_shared")

# Staging copies: each of the two operands read by "Conv" goes through a
# "shared" stage and then a "warp" stage; the result gets a "warp" write
# stage, and the bias read by "Conv_bias" gets a "local" copy.
block_conv_input_shared = sch.cache_read(block_conv, 0, "shared")
block_conv_input_frag = sch.cache_read(block_conv, 0, "warp")
block_conv_weight_shared = sch.cache_read(block_conv, 1, "shared")
block_conv_weight_frag = sch.cache_read(block_conv, 1, "warp")
block_conv_output_frag = sch.cache_write(block_conv, 0, "warp")
block_bias_local = sch.cache_read(block_conv_bias, 1, "local")
write_sch(sch, log_path, "cache_related")

# Inline the padding / im2col producers so they are computed on the fly.
sch.compute_inline(block_pad)
write_sch(sch, log_path, "PadInputInline")
sch.compute_inline(block_im2col)
write_sch(sch, log_path, "Im2ColInline")
sch.compute_inline(block_data_pad)
sch.compute_inline(block_weight_pad)

# Tile the GEMM loops: innermost wmma_m x wmma_n x wmma_k tiles, then
# per-warp tiles, then the warp grid of a thread block; the reduction loop is
# split into chunks.
i, j, k = sch.get_loops(block_conv)
i, kernel_i = sch.split(i, factors=[None, wmma_m])
j, kernel_j = sch.split(j, factors=[None, wmma_n])
k, kernel_k = sch.split(k, factors=[None, wmma_k])
block_i, i, ii = sch.split(i, factors=[None, block_row_warps, warp_row_tiles])
block_j, j, jj = sch.split(j, factors=[None, block_col_warps, warp_col_tiles])
ko, ki = sch.split(k, factors=[None, chunk])
sch.reorder(block_i, block_j, i, j, ko, ki, ii, jj, kernel_i, kernel_j, kernel_k)
write_sch(sch, log_path, "block_tile")

sch.bind(block_i, "blockIdx.y")
sch.bind(block_j, "blockIdx.x")
sch.bind(i, "threadIdx.y")
sch.bind(j, "threadIdx.z")
write_sch(sch, log_path, "thread_bind")

# cache read A from global memory to shared_memory
sch.compute_at(block_conv_input_frag, ki, preserve_unit_loops=True)
sch.compute_at(block_conv_input_shared, ko, preserve_unit_loops=True)
sch.compute_at(block_conv_weight_frag, ki, preserve_unit_loops=True)
sch.compute_at(block_conv_weight_shared, ko, preserve_unit_loops=True)
sch.reverse_compute_at(block_conv_output_frag, j, preserve_unit_loops=True)
# Attach the epilogue blocks under the output-fragment loop nest; the anchor
# loop is looked up again before every call.
for epilogue_block in (block_bias_local, block_conv_bias, block_conv_shared):
    sch.reverse_compute_at(
        epilogue_block,
        sch.get_loops(block_conv_output_frag)[-4],
        preserve_unit_loops=True,
    )
write_sch(sch, log_path, "cache_read_compute_at")
def tricky_extract_cache(block, sub_i, sub_j):
    """Split the two innermost loops of *block* into ``sub_i`` x ``sub_j`` tiles.

    Args:
        block: Block whose last two loops are tiled (uses the module-level
            schedule ``sch``).
        sub_i: Inner extent for the second-to-last loop.
        sub_j: Inner extent for the last loop.

    Returns:
        ``(outer_i, outer_j, inner_i, inner_j)`` in their new loop order.
    """
    i, j = sch.get_loops(block)[-2:]
    i, kernel_i = sch.split(i, factors=[None, sub_i])
    j, kernel_j = sch.split(j, factors=[None, sub_j])
    sch.reorder(i, j, kernel_i, kernel_j)
    return (i, j, kernel_i, kernel_j)


block_conv_input_frag_loops = tricky_extract_cache(
    block_conv_input_frag, wmma_m, wmma_k)
# The weight tile is N x K, so its row extent is wmma_n. The result gets its
# own name instead of overwriting the input-fragment loops.
block_conv_weight_frag_loops = tricky_extract_cache(
    block_conv_weight_frag, wmma_n, wmma_k)
write_sch(sch, log_path, "tricky_extract_cache")
# 128x32
# Copy into the "shared" stage of the first operand: fuse the two innermost
# loops, spread them over threadIdx.y / threadIdx.z / threadIdx.x and
# vectorize the innermost `vec` elements.
A_shared_fused = sch.fuse(*sch.get_loops(block_conv_input_shared)[-2:])
(
    A_shared_ty,
    A_shared_tz,
    A_shared_inner,
    A_shared_tx,
    A_shared_vi,
) = sch.split(
    A_shared_fused,
    factors=[block_row_warps, block_col_warps, None, warp_size, vec],
)
sch.vectorize(A_shared_vi)
for a_loop, a_thread_axis in (
    (A_shared_tx, "threadIdx.x"),
    (A_shared_ty, "threadIdx.y"),
    (A_shared_tz, "threadIdx.z"),
):
    sch.bind(a_loop, a_thread_axis)
sch.storage_align(block_conv_input_shared, 0, axis=-2, factor=32, offset=8)
write_sch(sch, log_path, "schedule_A_shared")

# Same treatment for the "shared" stage of the second operand.
B_shared_fused = sch.fuse(*sch.get_loops(block_conv_weight_shared)[-2:])
(
    B_shared_ty,
    B_shared_tz,
    B_shared_inner,
    B_shared_tx,
    B_shared_vi,
) = sch.split(
    B_shared_fused,
    factors=[block_row_warps, block_col_warps, None, warp_size, vec],
)
sch.storage_align(block_conv_weight_shared, 0, axis=-2, factor=32, offset=8)
sch.vectorize(B_shared_vi)
for b_loop, b_thread_axis in (
    (B_shared_tx, "threadIdx.x"),
    (B_shared_ty, "threadIdx.y"),
    (B_shared_tz, "threadIdx.z"),
):
    sch.bind(b_loop, b_thread_axis)
write_sch(sch, log_path, "schedule_B_shared")
def schedule_shared_output(block):
    """Schedule the epilogue around *block*.

    The two innermost loops of *block* are fused and split into
    ``[outer, warp_size, vec]``. The bias copy block is computed at the
    vector loop and the final copy-out block is attached behind it. The vector
    loop is then vectorized and the ``warp_size`` loop is bound to
    ``threadIdx.x``. Operates on the module-level schedule ``sch``.
    """
    fused = sch.fuse(*sch.get_loops(block)[-2:])
    _, lane_loop, vec_loop = sch.split(fused, factors=[None, warp_size, vec])
    sch.compute_at(block_bias_local, vec_loop, preserve_unit_loops=True)
    sch.reverse_compute_at(block_conv_shared, vec_loop, preserve_unit_loops=True)
    sch.vectorize(vec_loop)
    # sch.unroll(vec_loop)
    sch.bind(lane_loop, "threadIdx.x")


# schedule the output shared memory
schedule_shared_output(block_conv_bias)
# decompose reduction: split the accumulator initialisation out of the
# reduction at loop `ko`; remember the loops around the init block for later.
init_block_b = sch.decompose_reduction(block_conv, ko)
write_sch(sch, log_path, "decompose_reduction")
init_block_b_loops = sch.get_loops(init_block_b)


def recover_c(i, j):
    """Map a flat ``(row, col)`` output index to ``(n, h, w, c)``.

    ``row`` enumerates (image, output row, output column) and ``col`` is the
    output channel.
    """
    pixels_per_image = output_width * output_height
    image = i // pixels_per_image
    out_row = (i // output_width) % output_height
    out_col = i % output_width
    return (image, out_row, out_col, j)


sch.transform_layout(block_conv_shared, ("write", 0), recover_c)
write_sch(sch, log_path, "transform_layout")
def _ldmatrix_tile_index(i, j):
    """Split ``(i, j)`` into a 16x16 tile index plus the in-tile ldmatrix layout."""
    tile_row, tile_col = i // 16, j // 16
    return (
        tile_row,
        tile_col,
        *shared_16x16_to_ldmatrix_32x8_layout(i % 16, j % 16),
    )


def index_map_A(i, j):
    """Layout of the first operand's warp fragment."""
    return _ldmatrix_tile_index(i, j)


def index_map_B(i, j):
    """Layout of the second operand's warp fragment."""
    return _ldmatrix_tile_index(i, j)


def index_map_C(i, j):
    """Layout of the accumulator's warp fragment."""
    return _ldmatrix_tile_index(i, j)


# The operand fragments are re-laid-out on their write side; the accumulator
# fragment on the read side of its cache-write block.
sch.transform_layout(block_conv_input_frag, ("write", 0), index_map_A)
sch.transform_layout(block_conv_weight_frag, ("write", 0), index_map_B)
sch.transform_layout(block_conv_output_frag, ("read", 0), index_map_C)
# Accumulator initialisation -> MMA fill intrinsic.
fill_loop = sch.get_loops(init_block_b)[-2]
sch.tensorize(fill_loop, MMA_fill_16x16_f16_INTRIN)
write_sch(sch, log_path, "tensorize_wmma_fill")

# Fragment loads -> ldmatrix intrinsics (the second operand uses the
# transposed variant).
for frag_block, ldmatrix_intrin in (
    (block_conv_input_frag, LDMATRIX_16x16_A_INTRIN),
    (block_conv_weight_frag, LDMATRIX_16x16_B_TRANS_INTRIN),
):
    sch.tensorize(sch.get_loops(frag_block)[-2], ldmatrix_intrin)
write_sch(sch, log_path, "tensorize_ldmatrix")

# Inner wmma_m x wmma_n x wmma_k tile of the reduction -> MMA intrinsic.
sch.tensorize(kernel_i, MMA_f16f16f16_TRANS_INTRIN)
write_sch(sch, log_path, "tensorize_wmma_sync")

# Tile the fragment write-back and map each tile to the shared-scope store
# intrinsic registered under MMA_store_16x16_f16_shared_INTRIN.
out_i, out_j = sch.get_loops(block_conv_output_frag)[-2:]
out_i, ok_i = sch.split(out_i, factors=[None, wmma_m])
out_j, ok_j = sch.split(out_j, factors=[None, wmma_n])
sch.reorder(out_i, out_j, ok_i, ok_j)
sch.tensorize(ok_i, MMA_store_16x16_f16_shared_INTRIN)
write_sch(sch, log_path, "tensorize_store")

# unroll (no unrolling is applied here; only the dump is written)
write_sch(sch, log_path, "do_unroll")
# Annotations driven by the tuned parameters:
#   * stage > 0  -> software-pipeline stage/order annotations on loop `ko`
#   * use_async  -> mark pipeline stage 0 as asynchronous on loop `ko`
#   * raster > 0 -> "thread_rasterization" annotation on one of the loops
#                   around the init block
# NOTE(review): indentation was lost in this copy, so it is unclear whether
# the `use_async` check is nested under `stage > 0` - confirm against the
# original script before reformatting.
if stage > 0: | |
sch.annotate(ko, ann_key="software_pipeline_stage", ann_val=[0, 0, stage - 1]) | |
sch.annotate(ko, ann_key="software_pipeline_order", ann_val=[0, 1, 2]) | |
if use_async: | |
sch.annotate(ko, "software_pipeline_async_stages", [0]) | |
if raster > 0: | |
sch.annotate(init_block_b_loops[-4], ann_key="thread_rasterization", ann_val=raster) | |
# Build the scheduled module for CUDA and dump the generated source.
ctx = tvm.cuda(0)
with tvm.transform.PassContext(config={"tir.use_async_copy": use_async}):
    cuda_mod = tvm.build(sch.mod, target="cuda")
write_code(cuda_mod.imported_modules[0].get_source(), log_path, "tmp.cu")

# Host-side test data: random activations, all-ones weights and bias.
a_np = np.random.uniform(size=data_shape).astype("float16")
b_np = np.ones((N, K)).astype("float16")
# The output is (batch, out_h, out_w, out_channels); the width axis must use
# output_width so non-square outputs are allocated correctly.
c_np = np.zeros((batch_size, output_height, output_width, out_channels)).astype("float16")
bias_np = np.ones((out_channels)).astype("float16")

cuda_a = tvm.nd.array(a_np, ctx)
cuda_b = tvm.nd.array(b_np, ctx)
cuda_c = tvm.nd.array(c_np, ctx)
cuda_bias = tvm.nd.array(bias_np, ctx)
# Argument order follows the prim_func signature: (a, w, bias, conv).
cuda_mod(cuda_a, cuda_b, cuda_bias, cuda_c)

# Launch configuration, read back from the extents of the bound loops.
GridDim_z = 1
GridDim_y = int(sch.get_sref(block_i).stmt.extent)
GridDim_x = int(sch.get_sref(block_j).stmt.extent)
BlockDim_y = int(sch.get_sref(i).stmt.extent)
BlockDim_z = int(sch.get_sref(j).stmt.extent)
BlockDim_x = warp_size
print("GridDim_z", GridDim_z)
print("GridDim_y", GridDim_y)
print("GridDim_x", GridDim_x)
verified = False
if VERIFY:
    # Reference: the same computation the kernel performs - conv2d with
    # padding, stride and dilation, then bias add and max(x, 0) - evaluated
    # by PyTorch in float32.
    import torch

    # Activations: NHWC -> NCHW.
    a_torch = torch.tensor(
        np.transpose(a_np, (0, 3, 1, 2)), device="cuda", dtype=torch.float32)
    # Weights: flat (OC, KH*KW*IC) -> (OC, KH, KW, IC) -> (OC, IC, KH, KW).
    w_ohwi = b_np.reshape((out_channels, kernel_h, kernel_w, in_channels))
    b_torch = torch.tensor(
        np.transpose(w_ohwi, (0, 3, 1, 2)), device="cuda", dtype=torch.float32)
    bias_torch = torch.tensor(bias_np, device="cuda", dtype=torch.float32)
    c_torch = torch.nn.functional.conv2d(
        a_torch,
        b_torch,
        bias=bias_torch,
        stride=(stride_h, stride_w),
        padding=(pad_h, pad_w),
        dilation=(dilation_h, dilation_w),
        groups=1,
    )
    c_torch = torch.relu(c_torch)
    # Bring the reference into the kernel's NHWC output layout.
    c_torch_np = c_torch.permute(0, 2, 3, 1).cpu().numpy()
    print(c_torch_np.shape)
    c_np = cuda_c.numpy().astype("float32").reshape(
        (batch_size, output_height, output_width, out_channels)
    )
    verified = np.allclose(c_torch_np, c_np, atol=1e-1, rtol=1e-1)
    print("torch result: ", c_torch_np[0][0][0][0:10])
    print("tvm result: ", c_np[0][0][0][0:10])

# Timing.
num_runs = 3
timer_cuda_mod = cuda_mod.time_evaluator(
    cuda_mod.entry_name, ctx, number=num_runs)
t = timer_cuda_mod(cuda_a, cuda_b, cuda_bias, cuda_c).mean
# NOTE(review): the throughput figure divides by 1024**4 rather than 1e12 and
# the percentage assumes a peak of 145; both are kept as in the original.
print(f"task conv2d N({batch_size}) C({in_channels}) H({height}) W({width}) R({kernel_h}) S({kernel_w}) K({out_channels}) P({pad_h}) Q({pad_w}) stride({stride_h}), verified({verified}) time cost: {t * 1e3} ms, {((2*M*N*K)/t)/ pow((1024), 4)} tflops, {((2*M*N*K)/t)/ pow((1024), 4) / 145 * 100} %")
nni.report_final_result(t * 1e3)
def export_info_to_json():
    """Write the kernel source and launch metadata to ``<log_path>info.json``.

    The stored source is the text following the first ``extern "C"`` in the
    generated code, with the kernel parameter list rewritten so that ``Bias``
    comes first. Grid/block sizes and the measured latency in milliseconds
    are stored alongside it. Reads the module-level build results
    (``cuda_mod``, ``GridDim_*``, ``BlockDim_*``, ``t``, ``log_path``).
    """
    import json

    code = cuda_mod.imported_modules[0].get_source()
    # delete str before __global__ in code
    code = code.split("extern \"C\"")[1]
    code = code.replace("half* __restrict__ A, half* __restrict__ W, half* __restrict__ Bias, half* __restrict__ Conv", "half* __restrict__ Bias, half* __restrict__ A, half* __restrict__ W, half* __restrict__ Conv")
    infos = {
        "grid_size": [GridDim_x, GridDim_y, GridDim_z],
        "block_size": [BlockDim_x, BlockDim_y, BlockDim_z],
        "code": code,
        "latency": t * 1e3,
    }
    info_path = log_path + "info.json"
    # print(infos)
    # Use a context manager so the file is flushed and closed deterministically.
    with open(info_path, "w") as info_file:
        json.dump(infos, info_file, indent=4)


export_info_to_json()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment