Below is the script that runs into the symbolic floormod issue.
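In outline, the script builds an IRModule containing a dynamic-shape fused NT-matmul + bias-add PrimFunc, applies a dlight-style Tensor Core (wmma) schedule to it by hand, and then runs a hand-picked TIR lowering pipeline. Since the sequence length `n` is symbolic, the padding step (`pad_einsum`) is, as far as I can tell, what introduces the floordiv/floormod expressions on `n` that give this gist its name.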
import tvm
from tvm import tir
from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R

dtype = "float16"
@I.ir_module
class Module:
    @T.prim_func(private=True)
    def fused_NT_matmul4_add1(p_lv55: T.handle, lv11: T.Buffer((T.int64(2560), T.int64(10240)), "float16"), lv12: T.Buffer((T.int64(2560),), "float16"), p_output0: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        lv55 = T.match_buffer(p_lv55, (T.int64(1), n, T.int64(10240)), "float16")
        var_T_add_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(2560)), "float16")
        # with T.block("root"):
        var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560)), "float16")
        for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(2560), T.int64(10240)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(lv55[v_i0, v_i1, v_k], lv11[v_i2, v_k])
                T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2])
                with T.init():
                    var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
                var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv55[v_i0, v_i1, v_k] * lv11[v_i2, v_k]
        for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)):
            with T.block("T_add"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], lv12[v_ax2])
                T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2])
                var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + lv12[v_ax2]

    @R.function
    def WT_test(A: R.Tensor((1, "n", 10240), dtype=dtype), w_q: R.Tensor((2560, 10240), dtype=dtype), bias_q: R.Tensor((2560,), dtype=dtype)) -> R.Tensor((1, "n", 2560), dtype=dtype):
        n = T.int64()
        cls = Module
        with R.dataflow():
            lv = R.call_tir(cls.fused_NT_matmul4_add1, (A, w_q, bias_q), out_sinfo=R.Tensor((1, n, 2560), dtype=dtype))
            gv: R.Tensor((1, n, 2560), dtype=dtype) = lv
            R.output(gv)
        return gv
target = tvm.target.Target("nvidia/geforce-rtx-3090")
dev = tvm.cuda(7)
mod = Module
func = mod["fused_NT_matmul4_add1"]
# ------------------------- schedule transform ----------------------
from tvm.dlight.base import ScheduleRule, analysis
from tvm.tir.tensor_intrin.cuda import (  # pylint: disable=import-outside-toplevel
    get_wmma_intrin_group,
)
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Optional, Set, Tuple

from tvm import tir
from tvm.ir import Range
from tvm.target import Target
from tvm.tir import IterVar, PrimExpr, Var
from tvm.tir.analysis import undefined_vars
from tvm.tir.schedule.schedule import BlockRV
def _collect_producers(sch: tir.Schedule, block: tir.schedule.BlockRV):
    result = []
    for producer in sch.get_producers(block):
        result.append(producer)
        result.extend(_collect_producers(sch, producer))
    return result

def _collect_consumers(sch: tir.Schedule, block: tir.schedule.BlockRV):
    result = []
    for consumer in sch.get_consumers(block):
        result.append(consumer)
        result.extend(_collect_consumers(sch, consumer))
    return result
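# Note: the two helpers above walk the dependency chain recursively, so they
# return transitive producers/consumers rather than just the immediate ones;
# auto_inline_* below relies on this to keep inlining until a fixed point.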
def auto_inline_producers(
    sch: tir.Schedule,
    block: tir.schedule.BlockRV,
):
    while True:
        inlined_cnt = 0
        producers = _collect_producers(sch, block)
        for producer in producers:
            try:
                sch.compute_inline(producer)
                inlined_cnt += 1
            except:  # pylint: disable=bare-except
                continue
        if inlined_cnt == 0:
            return

def auto_inline_consumers(
    sch: tir.Schedule,
    block: tir.schedule.BlockRV,
):
    while True:
        inlined_cnt = 0
        consumers = _collect_consumers(sch, block)
        for consumer in consumers:
            try:
                sch.compute_inline(consumer)
                inlined_cnt += 1
            except:  # pylint: disable=bare-except
                continue
        for consumer in consumers:
            try:
                sch.reverse_compute_inline(consumer)
                inlined_cnt += 1
            except:  # pylint: disable=bare-except
                continue
        if inlined_cnt == 0:
            return
class IterKind(Enum):
    """Iter kinds for GEMM-like programs.
    We can simplify the computation to C[S, I, J] += A[S, I, K] * B[S, J, K],
    where `I, J, K` are the fundamental GEMM axes and `S` represents all
    other spatial axes (e.g. batches).
    kIter_S: spatial axes
    kIter_I: I axes
    kIter_J: J axes
    kIter_K: K axes
    kIter_T: trivial axes (i.e. with extent 1)
    """

    kIter_S = 0
    kIter_I = 1
    kIter_J = 2
    kIter_K = 3
    kIter_T = 4

@dataclass
class IterTrait:
    kind: IterKind
    extent: PrimExpr

def _is_one(x: PrimExpr) -> bool:
    return isinstance(x, tir.IntImm) and x.value == 1
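# Worked illustration of make_iter_fusion_index_map (hypothetical traits, not
# from this script): two kIter_I iters with extents (a, b) and input vars
# (i0, i1) fuse left-to-right into the single index i0 * b + i1; any kind in
# kind_order that the block lacks maps to the constant 0.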
def make_iter_fusion_index_map(
    traits: List[IterTrait],
    kind_order: List[IterKind],
) -> tir.IndexMap:
    fused_iters: Dict[IterKind, PrimExpr] = {}
    input_iters: List[tir.Var] = []
    for i, trait in enumerate(traits):
        v_i = tir.Var(f"i{i}", "int64")
        input_iters.append(v_i)
        if trait.kind == IterKind.kIter_T:
            continue
        if trait.kind not in kind_order:
            raise ValueError(f"Unknown iter kind {trait.kind}")
        if trait.kind in fused_iters:
            fused_iters[trait.kind] = fused_iters[trait.kind] * trait.extent + v_i
        else:
            fused_iters[trait.kind] = v_i
    final_indices: List[tir.PrimExpr] = [
        fused_iters.get(kind, tir.IntImm("int64", 0)) for kind in kind_order
    ]
    return tir.IndexMap(input_iters, final_indices, None)
def detect_iter_traits(block: tir.Block) -> Optional[Tuple[List[IterTrait]]]:
    """Detect iter traits based on the pattern C[S, I, J] += A[S, I, K] * B[S, J, K]

    Parameters
    ----------
    block : tir.Block
        The block to be analyzed

    Returns
    -------
    traits : Optional[Tuple[List[IterTrait]]]
        The detected iter traits for axes in A, B and C. None if the block
        does not match the pattern.
    """
    if len(block.reads) != 2 or len(block.writes) != 1:
        return None

    def get_access_axes(region: List[Range]) -> Set[Var]:
        axes: Set[Var] = set()
        for r in region:
            if not _is_one(r.extent):
                raise ValueError("Expect elemwise block access")
            axes = axes.union(set(undefined_vars(r.min)))
        return axes

    try:
        A_axes = get_access_axes(block.reads[0].region)
        B_axes = get_access_axes(block.reads[1].region)
        C_axes = get_access_axes(block.writes[0].region)
    except ValueError:
        return None

    traits: Dict[Var, IterTrait] = {}
    for iter_var in block.iter_vars:
        var = iter_var.var
        kind: IterKind
        if _is_one(iter_var.dom.extent):
            kind = IterKind.kIter_T
        elif iter_var.iter_type == iter_var.DataPar:
            if var in A_axes and var in B_axes and var in C_axes:
                kind = IterKind.kIter_S
            elif var in A_axes and var in C_axes:
                kind = IterKind.kIter_I
            elif var in B_axes and var in C_axes:
                kind = IterKind.kIter_J
            else:
                return None
        elif iter_var.iter_type == tir.IterVar.CommReduce:
            if var in A_axes and var in B_axes and var not in C_axes:
                kind = IterKind.kIter_K
            else:
                return None
        else:
            return None
        traits[var] = IterTrait(kind, iter_var.dom.extent)

    # A GEMM kernel requires I, J, and K axes
    gemm_traits = {IterKind.kIter_I, IterKind.kIter_J, IterKind.kIter_K}
    if {x.kind for x in traits.values()}.intersection(gemm_traits) != gemm_traits:
        return None

    A_traits = [traits[iter_var.var] for iter_var in block.iter_vars if iter_var.var in A_axes]
    B_traits = [traits[iter_var.var] for iter_var in block.iter_vars if iter_var.var in B_axes]
    C_traits = [traits[iter_var.var] for iter_var in block.iter_vars if iter_var.var in C_axes]
    block_traits = [traits[i.var] for i in block.iter_vars]
    return A_traits, B_traits, C_traits, block_traits
def get_index_map(block: tir.Block) -> Optional[Tuple[tir.IndexMap, ...]]:
    """Get index maps for the block

    Parameters
    ----------
    block : tir.Block
        The block to be analyzed

    Returns
    -------
    index_maps : Optional[Tuple[tir.IndexMap]]
        The index maps for the block, or None if the block is not a GEMM-like kernel
    """
    traits = detect_iter_traits(block)
    if traits is None:
        return None
    A_traits, B_traits, C_traits, block_traits = traits

    A_index_map = make_iter_fusion_index_map(
        A_traits, [IterKind.kIter_S, IterKind.kIter_I, IterKind.kIter_K]
    )
    B_index_map = make_iter_fusion_index_map(
        B_traits, [IterKind.kIter_S, IterKind.kIter_J, IterKind.kIter_K]
    )
    C_index_map = make_iter_fusion_index_map(
        C_traits, [IterKind.kIter_S, IterKind.kIter_I, IterKind.kIter_J]
    )
    matmul_index_map = make_iter_fusion_index_map(
        block_traits, [IterKind.kIter_S, IterKind.kIter_I, IterKind.kIter_J, IterKind.kIter_K]
    )
    return (
        matmul_index_map,
        A_index_map,
        B_index_map,
        C_index_map,
    )
def get_reduction_blocks(sch, blocks) -> Optional[List[BlockRV]]:
    # Get the main computation block
    def is_reduction(block: BlockRV) -> bool:
        block_stmt = sch.get(block)
        iter_types = {iter_var.iter_type for iter_var in block_stmt.iter_vars}
        return iter_types == {IterVar.CommReduce, IterVar.DataPar}

    def is_spatial(block: BlockRV) -> bool:
        block_stmt = sch.get(block)
        iter_types = {iter_var.iter_type for iter_var in block_stmt.iter_vars}
        return iter_types == {IterVar.DataPar}

    # NOTE: We assume there is only one reduction block in the function;
    # all blocks are required to be spatial or reduction.
    if not all([is_reduction(block) or is_spatial(block) for block in blocks]):
        return None
    # There is only one reduction block
    reduction_blocks = [block for block in blocks if is_reduction(block)]
    if len(reduction_blocks) != 1:
        return None
    return reduction_blocks
sch = tir.Schedule(func)
root_block = analysis.get_root_block(sch)
blocks = sch.get_child_blocks(root_block)
reduction_blocks = get_reduction_blocks(sch, blocks)
main_block = reduction_blocks[0]
block_stmt = sch.get(main_block)
index_maps = get_index_map(block_stmt)
matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps
# Start Schedule
# Step 0. Get schedule config.
# NOTE: we can analyze the config by the hardware spec in the future
# tensor core intrinsic size
micro_size_x = 16
micro_size_y = 16
micro_size_k = 16
warp_size = 32
vector_size = 4
i_factors, j_factors, k_factors = (
    [None, 1, 2, 2, 2],
    [1, None, 2, 2, 2],
    [None, 4],
)
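# My reading of these factors: each thread block computes a 128x128 output
# tile (16 * 2*2*2 per dimension), the K loop is tiled by 64 (16 * 4), and
# 2 * 2 = 4 warps are bound to threadIdx.y.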
# num_ty: 2 * 2 = 4
num_ty = i_factors[3] * j_factors[3]
# x_pad_factor: 2 * 2 = 4
x_pad_factor = i_factors[3] * i_factors[4]
# y_pad_factor: 2 * 2 = 4
y_pad_factor = j_factors[3] * j_factors[4]
# k_pad_factor: 4
k_pad_factor = k_factors[1]
# Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K]
block = sch.reindex(main_block, ("read", 0))
sch.transform_layout(block, ("write", 0), a_index_map)
block = sch.reindex(main_block, ("read", 1))
sch.transform_layout(block, ("write", 0), b_index_map)
block = sch.reindex(main_block, ("write", 0))
sch.transform_layout(block, ("read", 0), c_index_map)
sch.transform_block_layout(main_block, matmul_index_map)

# Step 2. Padding for dynamic shape kernels
sch.pad_einsum(
    main_block,
    [
        1,
        micro_size_x * x_pad_factor,
        micro_size_y * y_pad_factor,
        micro_size_k * k_pad_factor,
    ],
)
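# n is symbolic, so pad_einsum pads the I (= n) dimension up to a multiple of
# 16 * 4 = 64; this padding is presumably where the symbolic floordiv/floormod
# expressions on n enter the lowered TIR.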
# Step 3. Schedule matmul to use tensor core
block = main_block
batch, i, j, k = sch.get_loops(block)
# inner loops for tensor core computation
i, i_inner = sch.split(i, factors=[None, micro_size_x])
j, j_inner = sch.split(j, factors=[None, micro_size_y])
k, k_inner = sch.split(k, factors=[None, micro_size_k])
# loop extents after the splits: i = m/16, j = n/16, k = k/16, then 16, 16, 16
sch.reorder(i, j, k, i_inner, j_inner, k_inner)
block_inner = block
block_outer = sch.blockize(i_inner)
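# blockize splits the 16x16x16 inner nest into its own block: the outer block
# (block_outer) gets tiled and bound below, while the inner block
# (block_inner) is what gets tensorized later.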
# i0, i1, i2, i3, i4 = (m/16)/(2*2*2), 1, 2, 2, 2
i0, i1, i2, i3, i4 = sch.split(i, factors=i_factors)
# j0, j1, j2, j3, j4 = 1, (n/16)/(2*2*2), 2, 2, 2
j0, j1, j2, j3, j4 = sch.split(j, factors=j_factors)
# k0, k1 = (k/16)/4, 4
k0, k1 = sch.split(k, k_factors)
# i0, j0, i1, j1, j2, i2, i3, j3, k0, k1, i4, j4 = m/128, 1, 1, n/128, 2, 2, 2, 2, k/64, 4, 2, 2
sch.reorder(i0, j0, i1, j1, j2, i2, i3, j3, k0, k1, i4, j4)
# block_idx = m/128
block_idx = sch.fuse(i0, j0)
# block_idy = n/128
block_idy = sch.fuse(i1, j1)
# thread_idy = 4
thread_idy = sch.fuse(j2, i2)
sch.bind(batch, "blockIdx.z")
sch.bind(block_idx, "blockIdx.x")
sch.bind(block_idy, "blockIdx.y")
sch.bind(thread_idy, "threadIdx.y")
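# Resulting launch layout (my reading): blockIdx.x walks the i tiles (j0 == 1),
# blockIdx.y walks the j tiles (i1 == 1), and the fused (j2, i2) loop yields
# 4 warps per block on threadIdx.y.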
def fetch_to_shared(block, idx, ndim):
    block_read = sch.cache_read(block, idx, "shared.dyn")
    # for A (m x k): compute_at k0 places the copy loop inside k0
    sch.compute_at(block_read, k0)
    fused = sch.fuse(*sch.get_loops(block_read)[-ndim:])
    _, f_1, f_2, f_3 = sch.split(fused, factors=[None, num_ty, warp_size, vector_size])
    sch.bind(f_2, "threadIdx.x")
    sch.bind(f_1, "threadIdx.y")
    sch.vectorize(f_3)
    sch.storage_align(block_read, 0, axis=-2, factor=16, offset=8)
    return block_read
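# fetch_to_shared stages a cooperative global -> shared.dyn copy: the last
# ndim copy loops are fused, split across threadIdx.y/threadIdx.x with 4-wide
# vectorized accesses, and storage_align pads the second-to-last axis
# (factor 16, offset 8) to reduce shared-memory bank conflicts.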
a_g2s = fetch_to_shared(block_outer, 0, 2)
b_g2s = fetch_to_shared(block_outer, 1, 2)
auto_inline_producers(sch, a_g2s)
auto_inline_producers(sch, b_g2s)

# create read cache to load matrix from shared memory to wmma fragments
A_mat = sch.cache_read(block_outer, 0, "wmma.matrix_a")
B_mat = sch.cache_read(block_outer, 1, "wmma.matrix_b")
sch.compute_at(A_mat, k1)
sch.compute_at(B_mat, k1)

# create write cache to store matrix from wmma fragments to shared memory and global memory
accumulator_shared_to_global = sch.cache_write(block_outer, 0, "shared.dyn")
sch.storage_align(accumulator_shared_to_global, 0, -2, 16, 4)
store = sch.cache_write(block_outer, 0, "wmma.accumulator")
sch.reverse_compute_at(store, thread_idy)
sch.reverse_compute_at(accumulator_shared_to_global, thread_idy)

# split the store loop to match hardware intrinsic pattern
i, j = sch.get_loops(store)[-2:]
i0, i1 = sch.split(i, factors=[None, 16])
j0, j1 = sch.split(j, factors=[None, 16])
sch.reorder(i0, j0, i1, j1)

block_init_c = sch.decompose_reduction(block_outer, k0)
block_init_c_inner = sch.get_child_blocks(block_init_c)[0]
# Tensorization by hardware intrinsics
intrin_group = get_wmma_intrin_group(
    load_scope="shared.dyn",
    store_scope="shared.dyn",
    in_dtype="float16",
    out_dtype="float32",
    trans_b=True,
)
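# get_wmma_intrin_group bundles the matching 16x16x16 wmma intrinsics
# (load_a/load_b from shared.dyn, init, compute, and store to shared.dyn) for
# f16 inputs with f32 accumulation and transposed B, which matches the
# NT-matmul layout above.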
try:
    print("tensorize load_a/load_b: start")
    i, j = sch.get_loops(A_mat)[-2:]
    i0, i1 = sch.split(i, factors=[None, 16])
    j0, j1 = sch.split(j, factors=[None, 16])
    sch.reorder(i0, j0, i1, j1)
    sch.unroll(i0)
    sch.unroll(j0)
    sch.tensorize(i1, intrin_group["load_a"])
    i, j = sch.get_loops(B_mat)[-2:]
    i0, i1 = sch.split(i, factors=[None, 16])
    j0, j1 = sch.split(j, factors=[None, 16])
    sch.reorder(i0, j0, i1, j1)
    sch.unroll(i0)
    sch.unroll(j0)
    sch.tensorize(i1, intrin_group["load_b"])
    print("tensorize load_a/load_b: done")
except:  # pylint: disable=bare-except
    print("tensorize load_a/load_b: failed")
# Try to tensorize the init, store and compute block with f16 or f32 intrinsics
tensorize_success: bool = False

def tensorize_init_store_compute():
    sch.tensorize(sch.get_loops(block_init_c_inner)[-2], intrin_group["init"])
    sch.tensorize(sch.get_loops(store)[-2], intrin_group["store"])
    sch.tensorize(sch.get_loops(block_inner)[-3], intrin_group["compute"])

try:
    print("tensorize init/store/compute: trying f32 accumulate")
    tensorize_init_store_compute()
    tensorize_success = True
except:  # pylint: disable=bare-except
    # fall back to the f16-accumulate intrinsic group
    intrin_group = get_wmma_intrin_group(
        load_scope="shared.dyn",
        store_scope="shared.dyn",
        in_dtype="float16",
        out_dtype="float16",
        trans_b=True,
    )
if not tensorize_success:
    try:
        print("tensorize init/store/compute: trying f16 accumulate")
        tensorize_init_store_compute()
        tensorize_success = True
    except:  # pylint: disable=bare-except
        print("tensorize init/store/compute: failed")
auto_inline_consumers(sch, accumulator_shared_to_global)
fused = sch.fuse(*sch.get_loops(accumulator_shared_to_global)[-2:])
_, f1, f2 = sch.split(fused, factors=[None, warp_size, vector_size])
sch.bind(f1, "threadIdx.x")
sch.vectorize(f2)
mod["fused_NT_matmul4_add1"] = sch.mod["main"]
print(mod)
pass_list = []
pass_list.append(tir.transform.InjectPrefetch())
pass_list.append(tir.transform.TextureFlatten())
pass_list.append(tir.transform.StorageFlatten(64, False))
pass_list.append(tir.transform.LowerInitBlock())
pass_list.append(tir.transform.PlanAndUpdateBufferAllocationLocation())
pass_list.append(tir.transform.ConvertBlocksToOpaque())
pass_list.append(tir.transform.LiftThreadBinding())
pass_list.append(tir.transform.ManifestSharedMemoryLocalStage())
pass_list.append(tir.transform.CompactBufferAllocation())
pass_list.append(tir.transform.LowerAutoCopy())
pass_list.append(tir.transform.UnifyThreadBinding())
pass_list.append(tir.transform.LowerMatchBuffer())
pass_list.append(tir.transform.Simplify())
pass_list.append(tir.transform.InjectPermutedLayout())
pass_list.append(tir.transform.Simplify())
pass_list.append(tir.transform.InjectSoftwarePipeline())
pass_list.append(tir.transform.TransformMmaBufferLayout())
pass_list.append(tir.transform.LowerOpaqueBlock())
seq = tvm.transform.Sequential(pass_list)
mod = seq(mod)
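# Inspection sketch (assuming the pipeline above ran without error): printing
# the lowered module makes the symbolic floordiv/floormod on n visible.
print(mod.script())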