# Gist by @LeiWang1999, created August 22, 2023 10:38
"""
Considering a gemm problem, in this part we try to leverage the ldmatrix, mma, and stmatrix to do the computation.
The ldmatrix and stmatrix are used to load and store the data from global memory to shared memory.
The mma is used to do the computation.
thread_x will be set into 32, which represents the number of threads in a warp.
thread_y and thread_z will be set into value which represents the array of warps.
"""
import tvm
from tvm.script import tir as T
from tvm import te, tir, topi
import numpy as np
import os
import math
import sys
# add path ../..
sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
from tvm.tir.function import TensorIntrin
from tvm.tir.tensor_intrin.cuda import (
    LDMATRIX_16x16_A_INTRIN,
    LDMATRIX_16x16_B_INTRIN,
    LDMATRIX_16x16_B_TRANS_INTRIN,
    LDMATRIX_16x32_A_INTRIN,
    LDMATRIX_32x16_B_INTRIN,
    LDMATRIX_16x32_B_TRANS_INTRIN,
    MMA_f16f16f32_INTRIN,
    MMA_f16f16f32_TRANS_INTRIN,
    MMA_f16f16f16_INTRIN,
    MMA_f16f16f16_TRANS_INTRIN,
    MMA_i8i8i32_INTRIN,
    MMA_i8i8i32_TRANS_INTRIN,
    MMA_fill_16x16_f32_INTRIN,
    MMA_fill_16x16_f16_INTRIN,
    MMA_fill_16x16_i32_INTRIN,
    MMA_store_16x16_f32_global_INTRIN,
    MMA_store_16x16_f16_global_INTRIN,
    MMA_store_16x16_i32_global_INTRIN,
    shared_16x16_to_ldmatrix_32x8_layout,
    shared_32x16_to_ldmatrix_32x16_layout,
    shared_16x32_to_ldmatrix_32x16_layout,
    get_mma_store_intrin,
)
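
# The MMA store intrinsics imported above only target global memory; the lines
# below register an additional 16x16 f16 store intrinsic whose destination scope
# is shared memory, built with TVM's get_mma_store_intrin helper (dtype, elements
# per thread, target scope).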
MMA_store_16x16_f16_shared_INTRIN = "mma_store_16x16_f16_shared_"
TensorIntrin.register(
    MMA_store_16x16_f16_shared_INTRIN, *get_mma_store_intrin("float16", 8, "shared")
)
import sqlite3
class NNIDatabase(object):
    def __init__(self, db_path: str):
        self.db_path = db_path
        self.conn = None
        self.cursor = None

    def connect(self):
        self.conn = sqlite3.connect(self.db_path)
        self.cursor = self.conn.cursor()

    def get_best_params(self):
        self.cursor.execute("SELECT * FROM MetricData")
        all_data_record = self.cursor.fetchall()
        # find the minimum data record (sort by the last column)
        all_data_record.sort(key=lambda x: float(x[-1].replace('"', '')))
        lowest_metric_data_record = all_data_record[0]
        best_params = {}
        if lowest_metric_data_record:
            print(f"Lowest metric data record: {lowest_metric_data_record}")
            trial_job_id = lowest_metric_data_record[1]
            self.cursor.execute("SELECT * FROM TrialJobEvent WHERE trialJobId = ?", (trial_job_id,))
            trial_job_event_records = self.cursor.fetchall()
            is_trial_success = False
            for record in trial_job_event_records:
                if record[2] == "SUCCEEDED":
                    is_trial_success = True
                    break
            for record in trial_job_event_records:
                if record[2] == "WAITING":
                    best_params = eval(record[3])['parameters']
        else:
            print("No metric data records found.")
        print(f"Best params: {best_params}")
        return best_params

    def get_best_latency(self):
        self.cursor.execute("SELECT * FROM MetricData")
        all_data_record = self.cursor.fetchall()
        # find the minimum data record (sort by the last column)
        all_data_record.sort(key=lambda x: float(x[-1].replace('"', '')))
        lowest_metric_data_record = all_data_record[0]
        if not lowest_metric_data_record:
            print("No metric data records found.")
        return lowest_metric_data_record

    def close(self):
        self.conn.close()
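
# Note: this assumes NNI's default sqlite schema, where MetricData stores each
# trial's reported latency as a quoted string in its last column and
# TrialJobEvent stores each trial's hyperparameters as a JSON-like string.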
# get file name and remove the suffix
fname = os.path.basename(__file__)
fname = os.path.splitext(fname)[0]
# create log path
log_path = "progress/tensorirscript_imma/" + fname
count = 0
def write_code(code, path, fname):
    global count
    # prefix the file name with a running counter so the dump order is preserved
    fname = str(count) + "." + fname
    count += 1
    # if path not exist, create it
    if not os.path.exists(path):
        os.makedirs(path)
    # join path and fname
    fname = os.path.join(path, fname)
    with open(fname, "w") as f:
        f.write(code)
def write_sch(sch, path, fname):
    py_fname = fname + ".py"
    write_code(sch.mod["main"].script(), path, py_fname)
    cu_fname = fname + ".cu"
    write_code(sch.mod.astext(), path, cu_fname)
@tvm.register_func
def tvm_callback_cuda_postproc(code):
code = code.replace("#define TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST 1", "#define TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST 0")
if stage == 1:
code = code.replace(
'''__asm__ __volatile__("cp.async.commit_group;");''', ' ')
code = code.replace(
'''__asm__ __volatile__("cp.async.wait_group 0;");''', '''__asm__ __volatile__("cp.async.commit_group;");
__asm__ __volatile__("cp.async.wait_group 0;");''')
# if the next line is a __syncthreads(), replace it with number
return code
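
# This postprocessing hook rewrites the emitted CUDA when only a single pipeline
# stage is used (stage == 1): the standalone cp.async.commit_group is dropped and
# re-emitted immediately before cp.async.wait_group 0, so every asynchronous copy
# is committed and waited on in one place. It reads the tuned `stage` global that
# is assigned inside the loop below, before tvm.build invokes this callback.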
resnet_18_conv2d = [
    # N    C   H    W    F   K  S  D  P  G  HO   WO
    (128,  3, 224, 224, 64, 7, 2, 1, 3, 1, 112, 112),
    # (128, 16, 224, 224, 64, 7, 2, 1, 0, 1, 112, 112),
]
VERIFY = True
for N, C, H, W, F, K, S, D, P, G, HO, WO in resnet_18_conv2d:
    # get file name and remove the suffix
    fname = f"c3_nhwc_nhwc_conv2d_N{N}_C{C}_H{H}_W{W}_F{F}_K{K}_S{S}_D{D}_P{P}_G{G}_HO{HO}_WO{WO}"
    # create log path
    log_path = "/workspace/v-leiwang3/asplos_2024_script/end2end_kernels/" + fname + '/'
    id = f"c3_nhwc_nhwc_conv2d_N{N}_C{C}_H{H}_W{W}_F{F}_K{K}_S{S}_D{D}_P{P}_G{G}_HO{HO}_WO{WO}"
    db_path = f"/workspace/v-leiwang3/asplos_2024_script/conv2d/resnet18_nhwc_ladder/nni-experiments-20230807/{id}/db/nni.sqlite"
    db = NNIDatabase(db_path)
    db.connect()
    params = db.get_best_params()
    print(params)
    db.close()

    # The sizes of inputs and filters
    batch_size = N
    height = H
    width = W
    in_channels = C
    out_channels = F
    kernel_h = K
    kernel_w = K
    pad_h = P
    pad_w = P
    stride_h = S
    stride_w = S
    dilation_h = D
    dilation_w = D
    groups = G
    output_height = HO
    output_width = WO
    print(height)

    # TensorCore shape
    wmma_m = 16
    wmma_n = 16
    wmma_k = 16
    assert batch_size % wmma_k == 0
    # assert in_channels % wmma_m == 0
    assert out_channels % wmma_n == 0

    import nni

    # tuning params
    block_row_warps = params['block_row_warps']
    block_col_warps = params['block_col_warps']
    warp_row_tiles = params['warp_row_tiles']
    warp_col_tiles = params['warp_col_tiles']
    chunk = params['chunk']
    stage = params['stage']
    raster = params['raster']
    use_async = params['use_async']
    vec = params['vec']
    warp_size = 32

    # Input feature map: (N, H, W, IC)
    data_shape = (
        batch_size,
        height,
        width,
        in_channels,
    )
    # Kernel: (OC, KH, KW, IC)
    kernel_shape = (
        out_channels,
        kernel_h,
        kernel_w,
        in_channels,
    )
    # Output feature map: (N, HO, WO, OC)
    output_shape = (
        batch_size,
        output_height,
        output_width,
        out_channels,
    )

    # im2col GEMM sizes; note that N and K are rebound from the conv problem sizes here
    M = batch_size * output_height * output_width
    N = out_channels
    K = kernel_h * kernel_w * in_channels
    print(M, N, K)
    # pad K up to KPAD, the nearest multiple of chunk * wmma_k
    KPAD = (K + chunk * wmma_k - 1) // (chunk * wmma_k) * (chunk * wmma_k)
    print("KPAD", KPAD)
    # Algorithm
    @tvm.script.ir_module
    class MyModule:
        @T.prim_func
        def main(a: T.handle, w: T.handle, bias: T.handle, conv: T.handle):
            T.func_attr({"global_symbol": "main", "tir.noalias": True})
            W = T.match_buffer(w, (N, K), dtype="float16")
            A = T.match_buffer(a, (batch_size, height, width, in_channels), dtype="float16")
            Conv = T.match_buffer(conv, (M, N), dtype="float16")
            Bias = T.match_buffer(bias, (out_channels), dtype="float16")
            data_im2col = T.alloc_buffer([M, K], dtype="float16")
            Apad = T.alloc_buffer((batch_size, in_channels, height + 2 * pad_h, width + 2 * pad_w), dtype="float16")
            Conv_shared = T.alloc_buffer((M, N), dtype="float16", scope="shared")
            Conv_bias = T.alloc_buffer((M, N), dtype="float16", scope="local")
            data_im2col_Pad = T.alloc_buffer([M, KPAD], dtype="float16")
            W_pad = T.alloc_buffer([N, KPAD], dtype="float16")

            for n, h, w, c in T.grid(batch_size, in_channels, height + 2 * pad_h, width + 2 * pad_w):
                with T.block("Apad"):
                    vn, vh, vw, vc = T.axis.remap("SSSS", [n, h, w, c])
                    Apad[vn, vh, vw, vc] = T.if_then_else(
                        pad_h <= vh and vh < height + pad_h and pad_w <= vw and vw < width + pad_w,
                        A[vn, vh - pad_h, vw - pad_w, vc],
                        T.float16(0),
                        dtype="float16",
                    )
            for x, y in T.grid(M, K):
                with T.block("data_im2col"):
                    v_x, v_y = T.axis.remap("SS", [x, y])
                    data_im2col[v_x, v_y] = Apad[
                        v_x // (output_height * output_width),
                        stride_h * ((v_x % (output_height * output_width)) // output_width) + dilation_h * ((v_y // (in_channels)) // kernel_w),
                        stride_w * ((v_x % (output_height * output_width)) % output_width) + dilation_w * ((v_y // (in_channels)) % kernel_w),
                        v_y % (in_channels),
                    ]
            for i, k in T.grid(M, KPAD):
                with T.block("data_Pad"):
                    vi, vk = T.axis.remap("SS", [i, k])
                    data_im2col_Pad[vi, vk] = T.if_then_else(vk < K, data_im2col[vi, vk], T.float16(0), dtype="float16")
            for j, k in T.grid(N, KPAD):
                with T.block("weight_Pad"):
                    vk, vj = T.axis.remap("SS", [k, j])
                    W_pad[vj, vk] = T.if_then_else(vk < K, W[vj, vk], T.float16(0), dtype="float16")
            for y, x, k in T.grid(M, N, KPAD):
                with T.block("Conv"):
                    v_x, v_y, v_k = T.axis.remap("SSR", [y, x, k])
                    with T.init():
                        Conv_shared[v_x, v_y] = T.float16(0.0)
                    Conv_shared[v_x, v_y] = Conv_shared[v_x, v_y] + data_im2col_Pad[v_x, v_k] * W_pad[v_y, v_k]
            for x, y in T.grid(M, N):
                with T.block("Conv_bias"):
                    vom, von = T.axis.remap("SS", [x, y])
                    Conv_bias[vom, von] = T.if_then_else(
                        Conv_shared[vom, von] + Bias[von] > T.float16(0.0),
                        Conv_shared[vom, von] + Bias[von],
                        T.float16(0),
                        dtype="float16",
                    )
            for x, y in T.grid(M, N):
                with T.block("Conv_shared"):
                    vom, von = T.axis.remap("SS", [x, y])
                    Conv[vom, von] = Conv_bias[vom, von]
    ir_module = MyModule
    # print(ir_module)
    sch = tvm.tir.Schedule(ir_module, debug_mask="all")
    write_sch(sch, log_path, "original")

    block_pad = sch.get_block("Apad")
    block_data_pad = sch.get_block("data_Pad")
    block_weight_pad = sch.get_block("weight_Pad")
    block_im2col = sch.get_block("data_im2col")
    block_conv = sch.get_block("Conv")
    block_conv_bias = sch.get_block("Conv_bias")
    block_conv_shared = sch.get_block("Conv_shared")

    block_conv_input_shared = sch.cache_read(block_conv, 0, "shared")
    block_conv_input_frag = sch.cache_read(block_conv, 0, "warp")
    block_conv_weight_shared = sch.cache_read(block_conv, 1, "shared")
    block_conv_weight_frag = sch.cache_read(block_conv, 1, "warp")
    block_conv_output_frag = sch.cache_write(block_conv, 0, "warp")
    block_bias_local = sch.cache_read(block_conv_bias, 1, "local")
    write_sch(sch, log_path, "cache_related")

    sch.compute_inline(block_pad)
    write_sch(sch, log_path, "PadInputInline")
    sch.compute_inline(block_im2col)
    write_sch(sch, log_path, "Im2ColInline")
    sch.compute_inline(block_data_pad)
    sch.compute_inline(block_weight_pad)

    (i, j, k) = sch.get_loops(block_conv)
    i, kernel_i = sch.split(i, factors=[None, wmma_m])
    j, kernel_j = sch.split(j, factors=[None, wmma_n])
    k, kernel_k = sch.split(k, factors=[None, wmma_k])
    block_i, i, ii = sch.split(i, factors=[None, block_row_warps, warp_row_tiles])
    block_j, j, jj = sch.split(j, factors=[None, block_col_warps, warp_col_tiles])
    ko, ki = sch.split(k, factors=[None, chunk])
    sch.reorder(block_i, block_j, i, j, ko, ki, ii, jj, kernel_i, kernel_j, kernel_k)
    write_sch(sch, log_path, "block_tile")

    sch.bind(block_i, "blockIdx.y")
    sch.bind(block_j, "blockIdx.x")
    sch.bind(i, "threadIdx.y")
    sch.bind(j, "threadIdx.z")
    write_sch(sch, log_path, "thread_bind")
    # stage the operands: shared-memory tiles at ko, warp register fragments at ki
    sch.compute_at(block_conv_input_frag, ki, preserve_unit_loops=True)
    sch.compute_at(block_conv_input_shared, ko, preserve_unit_loops=True)
    sch.compute_at(block_conv_weight_frag, ki, preserve_unit_loops=True)
    sch.compute_at(block_conv_weight_shared, ko, preserve_unit_loops=True)
    sch.reverse_compute_at(block_conv_output_frag, j, preserve_unit_loops=True)
    sch.reverse_compute_at(block_bias_local,
                           sch.get_loops(block_conv_output_frag)[-4], preserve_unit_loops=True)
    sch.reverse_compute_at(block_conv_bias,
                           sch.get_loops(block_conv_output_frag)[-4], preserve_unit_loops=True)
    sch.reverse_compute_at(block_conv_shared,
                           sch.get_loops(block_conv_output_frag)[-4], preserve_unit_loops=True)
    write_sch(sch, log_path, "cache_read_compute_at")

    def tricky_extract_cache(block, sub_i, sub_j):
        i, j = sch.get_loops(block)[-2:]
        i, kernel_i = sch.split(i, factors=[None, sub_i])
        j, kernel_j = sch.split(j, factors=[None, sub_j])
        sch.reorder(i, j, kernel_i, kernel_j)
        return (i, j, kernel_i, kernel_j)

    block_conv_input_frag_loops = tricky_extract_cache(
        block_conv_input_frag, wmma_m, wmma_k)
    block_conv_weight_frag_loops = tricky_extract_cache(
        block_conv_weight_frag, wmma_m, wmma_k)
    write_sch(sch, log_path, "tricky_extract_cache")
    # 128x32
    A_shared_fused = sch.fuse(*sch.get_loops(block_conv_input_shared)[-2:])
    A_shared_ty, A_shared_tz, A_shared_inner, A_shared_tx, A_shared_vi = sch.split(
        A_shared_fused, factors=[block_row_warps, block_col_warps, None, warp_size, vec])
    sch.vectorize(A_shared_vi)
    sch.bind(A_shared_tx, "threadIdx.x")
    sch.bind(A_shared_ty, "threadIdx.y")
    sch.bind(A_shared_tz, "threadIdx.z")
    sch.storage_align(block_conv_input_shared, 0, axis=-2, factor=32, offset=8)
    write_sch(sch, log_path, "schedule_A_shared")
    B_shared_fused = sch.fuse(*sch.get_loops(block_conv_weight_shared)[-2:])
    B_shared_ty, B_shared_tz, B_shared_inner, B_shared_tx, B_shared_vi = sch.split(
        B_shared_fused, factors=[block_row_warps, block_col_warps, None, warp_size, vec])
    sch.storage_align(block_conv_weight_shared, 0, axis=-2, factor=32, offset=8)
    sch.vectorize(B_shared_vi)
    sch.bind(B_shared_tx, "threadIdx.x")
    sch.bind(B_shared_ty, "threadIdx.y")
    sch.bind(B_shared_tz, "threadIdx.z")
    write_sch(sch, log_path, "schedule_B_shared")
    def schedule_shared_output(block):
        o_shared_fused = sch.fuse(*sch.get_loops(block)[-2:])
        _, o_shared_tx, o_shared_vi = sch.split(
            o_shared_fused, factors=[None, warp_size, vec]
        )
        sch.compute_at(block_bias_local, o_shared_vi, preserve_unit_loops=True)
        sch.reverse_compute_at(block_conv_shared, o_shared_vi, preserve_unit_loops=True)
        sch.vectorize(o_shared_vi)
        # sch.unroll(o_shared_vi)
        sch.bind(o_shared_tx, "threadIdx.x")

    # schedule the output shared memory
    schedule_shared_output(block_conv_bias)

    # decompose reduction
    init_block_b = sch.decompose_reduction(block_conv, ko)
    write_sch(sch, log_path, "decompose_reduction")
    init_block_b_loops = sch.get_loops(init_block_b)
    def recover_c(i, j):
        # recover (n, h, w, c) from the flattened (M, N) GEMM output
        c = j
        n = i // (output_width * output_height)
        h = (i // output_width) % output_height
        w = i % output_width
        return (n, h, w, c)

    sch.transform_layout(block_conv_shared, ("write", 0), recover_c)
    write_sch(sch, log_path, "transform_layout")
    def index_map_A(i, j):
        return (
            i // 16,
            j // 16,
            *shared_16x16_to_ldmatrix_32x8_layout(i % 16, j % 16),
        )

    def index_map_B(i, j):
        return (
            i // 16,
            j // 16,
            *shared_16x16_to_ldmatrix_32x8_layout(i % 16, j % 16),
        )

    def index_map_C(i, j):
        return (
            i // 16,
            j // 16,
            *shared_16x16_to_ldmatrix_32x8_layout(i % 16, j % 16),
        )

    sch.transform_layout(block_conv_input_frag, ("write", 0), index_map_A)
    sch.transform_layout(block_conv_weight_frag, ("write", 0), index_map_B)
    sch.transform_layout(block_conv_output_frag, ("read", 0), index_map_C)
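
    # Each index map splits the buffer into 16x16 tiles and, within a tile,
    # applies shared_16x16_to_ldmatrix_32x8_layout, which maps a (row, col)
    # element to its (thread_id, local_id) position in the 32-thread x 8-element
    # register layout expected by the ldmatrix / mma intrinsics.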
    sch.tensorize(sch.get_loops(init_block_b)[-2], MMA_fill_16x16_f16_INTRIN)
    write_sch(sch, log_path, "tensorize_wmma_fill")
    sch.tensorize(sch.get_loops(block_conv_input_frag)[-2], LDMATRIX_16x16_A_INTRIN)
    sch.tensorize(sch.get_loops(block_conv_weight_frag)[-2], LDMATRIX_16x16_B_TRANS_INTRIN)
    write_sch(sch, log_path, "tensorize_ldmatrix")
    sch.tensorize(kernel_i, MMA_f16f16f16_TRANS_INTRIN)
    write_sch(sch, log_path, "tensorize_wmma_sync")

    out_i, out_j = sch.get_loops(block_conv_output_frag)[-2:]
    out_i, ok_i = sch.split(out_i, factors=[None, wmma_m])
    out_j, ok_j = sch.split(out_j, factors=[None, wmma_n])
    sch.reorder(out_i, out_j, ok_i, ok_j)
    sch.tensorize(ok_i, MMA_store_16x16_f16_shared_INTRIN)
    write_sch(sch, log_path, "tensorize_store")

    # unroll
    write_sch(sch, log_path, "do_unroll")
    if stage > 0:
        sch.annotate(ko, ann_key="software_pipeline_stage", ann_val=[0, 0, stage - 1])
        sch.annotate(ko, ann_key="software_pipeline_order", ann_val=[0, 1, 2])
        if use_async:
            sch.annotate(ko, "software_pipeline_async_stages", [0])
    if raster > 0:
        sch.annotate(init_block_b_loops[-4], ann_key="thread_rasterization", ann_val=raster)
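
    # The software_pipeline_stage / software_pipeline_order annotations ask TVM
    # to software-pipeline the ko loop: the two shared-memory copies (stage 0)
    # are issued `stage - 1` iterations ahead of the compute stage, and with
    # use_async they are additionally lowered to cp.async copies.
    # thread_rasterization appears to be a block-swizzling hint supported by the
    # author's TVM branch rather than a standard upstream annotation.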
    ctx = tvm.cuda(0)
    with tvm.transform.PassContext(config={"tir.use_async_copy": use_async}):
        cuda_mod = tvm.build(sch.mod, target="cuda")
    write_code(cuda_mod.imported_modules[0].get_source(), log_path, "tmp.cu")

    # random init a and b
    a_np = np.random.uniform(size=data_shape).astype("float16")
    # b_np = np.random.uniform(size=(N // wmma_m, K // wmma_k, wmma_n, wmma_k)).astype("float16")
    # a_np = np.ones(data_shape).astype("float16")
    b_np = np.ones((N, K)).astype("float16")
    # arange init a and b (kept for debugging)
    # a_np = np.mod(np.arange(0, batch_size * in_channels * height * width), 10).reshape(data_shape).astype("float16")
    # b_np = np.mod(np.arange(0, kernel_h * kernel_w * in_channels * out_channels), 10).reshape(kernel_shape).astype("float16")
    c_np = np.zeros((batch_size, output_height, output_width, out_channels)).astype("float16")
    bias_np = np.ones((out_channels)).astype("float16")
    cuda_a = tvm.nd.array(a_np, ctx)
    cuda_b = tvm.nd.array(b_np, ctx)
    cuda_c = tvm.nd.array(c_np, ctx)
    cuda_bias = tvm.nd.array(bias_np, ctx)
    cuda_mod(cuda_a, cuda_b, cuda_bias, cuda_c)

    GridDim_z = 1
    GridDim_y = int(sch.get_sref(block_i).stmt.extent)
    GridDim_x = int(sch.get_sref(block_j).stmt.extent)
    BlockDim_y = int(sch.get_sref(i).stmt.extent)
    BlockDim_z = int(sch.get_sref(j).stmt.extent)
    BlockDim_x = warp_size
    print("GridDim_z", GridDim_z)
    print("GridDim_y", GridDim_y)
    print("GridDim_x", GridDim_x)
    verified = False
    if VERIFY:
        # compute a reference conv2d with torch
        import torch

        # convert a from nhwc to nchw
        a_np = np.transpose(a_np, (0, 3, 1, 2))
        a_torch = torch.tensor(a_np, device="cuda", dtype=torch.float32)
        a_torch = torch.nn.functional.pad(a_torch, (pad_h, pad_h, pad_w, pad_w))
        # convert b from (OC, KH*KW*IC) into OIHW
        b_np = b_np.reshape((out_channels, kernel_h, kernel_w, in_channels))
        b_np = np.transpose(b_np, (0, 3, 1, 2))
        b_torch = torch.tensor(b_np, device="cuda", dtype=torch.float32)
        c_torch = torch.nn.functional.conv2d(
            a_torch, b_torch, stride=(stride_h, stride_w), groups=1)
        # c_torch_np = np.transpose(c_torch.cpu().numpy(), (1, 0, 2, 3))
        c_torch_np = c_torch.cpu().numpy()
        print(c_torch_np.shape)
        # reinterpret the tvm output as (N, OC, HO, WO) for the comparison
        c_np = cuda_c.numpy().reshape(
            (batch_size, out_channels, output_height, output_width)
        )
        verified = np.allclose(c_torch_np, c_np, atol=1e-1, rtol=1e-1)
        print("torch result: ", c_torch_np[0][0][0][0:10])
        print("tvm result: ", c_np[0][0][0][0:10])
        if not verified:
            print("torch result: ", c_torch_np[0][0][0][0:10])
            print("tvm result: ", c_np[0][0][0][0:10])
    num_runs = 3
    timer_cuda_mod = cuda_mod.time_evaluator(
        cuda_mod.entry_name, ctx, number=num_runs)
    t = timer_cuda_mod(cuda_a, cuda_b, cuda_bias, cuda_c).mean
    print(f"task conv2d N({batch_size}) C({in_channels}) H({height}) W({width}) R({kernel_h}) S({kernel_w}) K({out_channels}) P({pad_h}) Q({pad_w}) stride({stride_h}), verified({verified}) time cost: {t * 1e3} ms, {((2*M*N*K)/t)/ pow((1024), 4)} tflops, {((2*M*N*K)/t)/ pow((1024), 4) / 145 * 100} %")
    nni.report_final_result(t * 1e3)
    def export_info_to_json():
        code = cuda_mod.imported_modules[0].get_source()
        # delete str before __global__ in code
        code = code.split("extern \"C\"")[1]
        code = code.replace(
            "half* __restrict__ A, half* __restrict__ W, half* __restrict__ Bias, half* __restrict__ Conv",
            "half* __restrict__ Bias, half* __restrict__ A, half* __restrict__ W, half* __restrict__ Conv")
        import json
        infos = {}
        infos["grid_size"] = [GridDim_x, GridDim_y, GridDim_z]
        infos["block_size"] = [BlockDim_x, BlockDim_y, BlockDim_z]
        infos["code"] = code
        infos["latency"] = t * 1e3
        info_path = log_path + "info.json"
        # print(infos)
        json.dump(infos, open(info_path, "w"), indent=4)

    export_info_to_json()