@JackWeiw
Created October 19, 2023 07:57
Below is the script that runs into the symbolic floormod issue.
import tvm
from tvm import tir
from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R

dtype = "float16"
@I.ir_module
class Module:
    @T.prim_func(private=True)
    def fused_NT_matmul4_add1(p_lv55: T.handle, lv11: T.Buffer((T.int64(2560), T.int64(10240)), "float16"), lv12: T.Buffer((T.int64(2560),), "float16"), p_output0: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        lv55 = T.match_buffer(p_lv55, (T.int64(1), n, T.int64(10240)), "float16")
        var_T_add_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(2560)), "float16")
        # with T.block("root"):
        var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560)), "float16")
        for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(2560), T.int64(10240)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(lv55[v_i0, v_i1, v_k], lv11[v_i2, v_k])
                T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2])
                with T.init():
                    var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
                var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv55[v_i0, v_i1, v_k] * lv11[v_i2, v_k]
        for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)):
            with T.block("T_add"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], lv12[v_ax2])
                T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2])
                var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + lv12[v_ax2]

    @R.function
    def WT_test(A: R.Tensor((1, "n", 10240), dtype=dtype), w_q: R.Tensor((2560, 10240), dtype=dtype), bias_q: R.Tensor((2560,), dtype=dtype)) -> R.Tensor((1, "n", 2560), dtype=dtype):
        n = T.int64()
        cls = Module
        with R.dataflow():
            lv = R.call_tir(cls.fused_NT_matmul4_add1, (A, w_q, bias_q), out_sinfo=R.Tensor((1, n, 2560), dtype=dtype))
            gv: R.Tensor((1, n, 2560), dtype=dtype) = lv
            R.output(gv)
        return gv
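
# The Relax entry point WT_test simply dispatches to the TIR PrimFunc via call_tir;
# the sequence length n stays symbolic throughout, which is presumably where the
# symbolic floormod mentioned in the description comes from.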
target = tvm.target.Target("nvidia/geforce-rtx-3090")
dtype = "float16"
dev = tvm.cuda(7)
mod = Module
func = mod["fused_NT_matmul4_add1"]
#-------------------------schedule transform----------------------
from tvm.dlight.base import ScheduleRule, analysis
from tvm.tir.tensor_intrin.cuda import (  # pylint: disable=import-outside-toplevel
    get_wmma_intrin_group,
)
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Optional, Set, Tuple
from tvm import tir
from tvm.ir import Range
from tvm.target import Target
from tvm.tir import IterVar, PrimExpr, Var
from tvm.tir.analysis import undefined_vars
from tvm.tir.schedule.schedule import BlockRV
def _collect_producers(sch: tir.Schedule, block: tir.schedule.BlockRV):
    result = []
    for producer in sch.get_producers(block):
        result.append(producer)
        result.extend(_collect_producers(sch, producer))
    return result


def _collect_consumers(sch: tir.Schedule, block: tir.schedule.BlockRV):
    result = []
    for consumer in sch.get_consumers(block):
        result.append(consumer)
        result.extend(_collect_consumers(sch, consumer))
    return result


def auto_inline_producers(
    sch: tir.Schedule,
    block: tir.schedule.BlockRV,
):
    while True:
        inlined_cnt = 0
        producers = _collect_producers(sch, block)
        for producer in producers:
            try:
                sch.compute_inline(producer)
                inlined_cnt += 1
            except:  # pylint: disable=bare-except
                continue
        if inlined_cnt == 0:
            return


def auto_inline_consumers(
    sch: tir.Schedule,
    block: tir.schedule.BlockRV,
):
    while True:
        inlined_cnt = 0
        consumers = _collect_consumers(sch, block)
        for consumer in consumers:
            try:
                sch.compute_inline(consumer)
                inlined_cnt += 1
            except:  # pylint: disable=bare-except
                continue
        for consumer in consumers:
            try:
                sch.reverse_compute_inline(consumer)
                inlined_cnt += 1
            except:  # pylint: disable=bare-except
                continue
        if inlined_cnt == 0:
            return
class IterKind(Enum):
    """Iter kinds for GEMM-like programs.
    We can simplify the computation to C[S, I, J] += A[S, I, K] * B[S, J, K],
    where `I, J, K` are fundamental axes for gemm and `S` represents all
    other spatial axes (e.g. batches)
    kIter_S: spatial axes
    kIter_I: I axes
    kIter_J: J axes
    kIter_K: K axes
    kIter_T: trivial axes (i.e. with extent 1)
    """

    kIter_S = 0
    kIter_I = 1
    kIter_J = 2
    kIter_K = 3
    kIter_T = 4


@dataclass
class IterTrait:
    kind: IterKind
    extent: PrimExpr


def _is_one(x: PrimExpr) -> bool:
    return isinstance(x, tir.IntImm) and x.value == 1
def make_iter_fusion_index_map(
    traits: List[IterTrait],
    kind_order: List[IterKind],
) -> tir.IndexMap:
    fused_iters: Dict[IterKind, PrimExpr] = {}
    input_iters: List[tir.Var] = []
    for i, trait in enumerate(traits):
        v_i = tir.Var(f"i{i}", "int64")
        input_iters.append(v_i)
        if trait.kind == IterKind.kIter_T:
            continue
        if trait.kind not in kind_order:
            raise ValueError(f"Unknown iter kind {trait.kind}")
        if trait.kind in fused_iters:
            fused_iters[trait.kind] = fused_iters[trait.kind] * trait.extent + v_i
        else:
            fused_iters[trait.kind] = v_i
    final_indices: List[tir.PrimExpr] = [
        fused_iters.get(kind, tir.IntImm("int64", 0)) for kind in kind_order
    ]
    return tir.IndexMap(input_iters, final_indices, None)
def detect_iter_traits(block: tir.Block) -> Optional[Tuple[List[IterTrait]]]:
    """Detect iter traits based on the pattern C[S, I, J] += A[S, I, K] * B[S, J, K]

    Parameters
    ----------
    block : tir.Block
        The block to be analyzed

    Returns
    -------
    traits : Optional[Tuple[List[IterTrait]]]
        The detected iter traits for axes in A, B and C. None if the block
        does not match the pattern.
    """
    if len(block.reads) != 2 or len(block.writes) != 1:
        return None

    def get_access_axes(region: List[Range]) -> Set[Var]:
        axes: Set[Var] = set()
        for r in region:
            if not _is_one(r.extent):
                raise ValueError("Expect elemwise block access")
            axes = axes.union(set(undefined_vars(r.min)))
        return axes

    try:
        A_axes = get_access_axes(block.reads[0].region)
        B_axes = get_access_axes(block.reads[1].region)
        C_axes = get_access_axes(block.writes[0].region)
    except ValueError:
        return None

    traits: Dict[Var, IterTrait] = {}
    for iter_var in block.iter_vars:
        var = iter_var.var
        kind: IterKind
        if _is_one(iter_var.dom.extent):
            kind = IterKind.kIter_T
        elif iter_var.iter_type == iter_var.DataPar:
            if var in A_axes and var in B_axes and var in C_axes:
                kind = IterKind.kIter_S
            elif var in A_axes and var in C_axes:
                kind = IterKind.kIter_I
            elif var in B_axes and var in C_axes:
                kind = IterKind.kIter_J
            else:
                return None
        elif iter_var.iter_type == tir.IterVar.CommReduce:
            if var in A_axes and var in B_axes and var not in C_axes:
                kind = IterKind.kIter_K
            else:
                return None
        else:
            return None
        traits[var] = IterTrait(kind, iter_var.dom.extent)

    # A GEMM kernel requires I, J and K axes
    gemm_traits = {IterKind.kIter_I, IterKind.kIter_J, IterKind.kIter_K}
    if {x.kind for x in traits.values()}.intersection(gemm_traits) != gemm_traits:
        return None
    A_traits = [traits[iter_var.var] for iter_var in block.iter_vars if iter_var.var in A_axes]
    B_traits = [traits[iter_var.var] for iter_var in block.iter_vars if iter_var.var in B_axes]
    C_traits = [traits[iter_var.var] for iter_var in block.iter_vars if iter_var.var in C_axes]
    block_traits = [traits[i.var] for i in block.iter_vars]
    return A_traits, B_traits, C_traits, block_traits
def get_index_map(block: tir.Block) -> Optional[Tuple[tir.IndexMap, ...]]:
    """Get index maps for the block

    Parameters
    ----------
    block : tir.Block
        The block to be analyzed

    Returns
    -------
    index_maps : Optional[Tuple[tir.IndexMap]]
        The index maps for the block, or None if the block is not a GEMM-like kernel
    """
    traits = detect_iter_traits(block)
    if traits is None:
        return None
    A_traits, B_traits, C_traits, block_traits = traits

    A_index_map = make_iter_fusion_index_map(
        A_traits, [IterKind.kIter_S, IterKind.kIter_I, IterKind.kIter_K]
    )
    B_index_map = make_iter_fusion_index_map(
        B_traits, [IterKind.kIter_S, IterKind.kIter_J, IterKind.kIter_K]
    )
    C_index_map = make_iter_fusion_index_map(
        C_traits, [IterKind.kIter_S, IterKind.kIter_I, IterKind.kIter_J]
    )
    matmul_index_map = make_iter_fusion_index_map(
        block_traits, [IterKind.kIter_S, IterKind.kIter_I, IterKind.kIter_J, IterKind.kIter_K]
    )
    return (
        matmul_index_map,
        A_index_map,
        B_index_map,
        C_index_map,
    )
def get_reduction_blocks(sch, blocks) -> Optional[List[BlockRV]]:
    # Get the main computation block
    def is_reduction(block: BlockRV) -> bool:
        block_stmt = sch.get(block)
        iter_types = {iter_var.iter_type for iter_var in block_stmt.iter_vars}
        return iter_types == {IterVar.CommReduce, IterVar.DataPar}

    def is_spatial(block: BlockRV) -> bool:
        block_stmt = sch.get(block)
        iter_types = {iter_var.iter_type for iter_var in block_stmt.iter_vars}
        return iter_types == {IterVar.DataPar}

    # NOTE: We assume there is only one reduction block in the function,
    # and all blocks are required to be spatial or reduction
    if not all([is_reduction(block) or is_spatial(block) for block in blocks]):
        return None

    # There is only one reduction block
    reduction_blocks = [block for block in blocks if is_reduction(block)]
    if len(reduction_blocks) != 1:
        return None
    return reduction_blocks
sch = tir.Schedule(func)
root_block = analysis.get_root_block(sch)
blocks = sch.get_child_blocks(root_block)
reduction_blocks = get_reduction_blocks(sch, blocks)
main_block = reduction_blocks[0]
block_stmt = sch.get(main_block)
index_maps = get_index_map(block_stmt)
matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps
# Start Schedule
# Step 0. Get schedule config.
# NOTE: we can analyze the config by the hardware spec in the future
# tensor core intrinsic size
micro_size_x = 16
micro_size_y = 16
micro_size_k = 16
warp_size = 32
vector_size = 4
i_factors, j_factors, k_factors = (
    [None, 1, 2, 2, 2],
    [1, None, 2, 2, 2],
    [None, 4],
)
# num_ty: 2 * 2 = 4 (extent of threadIdx.y, i.e. warps per threadblock)
num_ty = i_factors[3] * j_factors[3]
#x_pad_factor:2*2=4
x_pad_factor = i_factors[3] * i_factors[4]
#y_pad_factor:2*2=4
y_pad_factor = j_factors[3] * j_factors[4]
#k_pad_factor:4
k_pad_factor = k_factors[1]
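# Tile sizes implied by the factors above:
#   threadblock tile: (2*2*2*16) x (2*2*2*16) = 128 x 128
#   warp tile (per threadIdx.y): (2*2*16) x (2*2*16) = 64 x 64
#   k tile per k0 iteration: 4 * 16 = 64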
# Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K]
block = sch.reindex(main_block, ("read", 0))
sch.transform_layout(block, ("write", 0), a_index_map)
block = sch.reindex(main_block, ("read", 1))
sch.transform_layout(block, ("write", 0), b_index_map)
block = sch.reindex(main_block, ("write", 0))
sch.transform_layout(block, ("read", 0), c_index_map)
sch.transform_block_layout(main_block, matmul_index_map)
# Step 2. Padding for dynamic shape kernels
sch.pad_einsum(
    main_block,
    [
        1,
        micro_size_x * x_pad_factor,
        micro_size_y * y_pad_factor,
        micro_size_k * k_pad_factor,
    ],
)
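# pad_einsum pads the iteration extents of the matmul so that each dimension becomes a
# multiple of micro_size * pad_factor; with a symbolic n this is what makes the exact
# splits by 16 below legal, at the cost of extra padding blocks around the main computation.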
# Step 3. Schedule matmul to use tensor core
block = main_block
batch, i, j, k = sch.get_loops(block)
# inner loops for tensor core computation
i, i_inner = sch.split(i, factors=[None, micro_size_x])
j, j_inner = sch.split(j, factors=[None, micro_size_y])
k, k_inner = sch.split(k, factors=[None, micro_size_k])
# after the splits: i = m/16, j = n/16, k = k/16, with inner extents 16, 16, 16
sch.reorder(i, j, k, i_inner, j_inner, k_inner)
block_inner = block
block_outer = sch.blockize(i_inner)
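# blockize turns the 16x16x16 inner loop nest into its own block: block_inner is the
# wmma-sized micro-kernel to be tensorized later, block_outer iterates over the tiles.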
#i0,i1,i2,i3,i4 = (m/16)/(2*2*2), 1, 2, 2, 2
i0, i1, i2, i3, i4 = sch.split(i, factors=i_factors)
#j0,j1,j2,j3,j4 = 1, (n/16)/(2*2*2), 2, 2, 2
j0, j1, j2, j3, j4 = sch.split(j, factors=j_factors)
#k0,k1=(k/16)/4,4
k0, k1 = sch.split(k, k_factors)
#i0, j0, i1, j1, j2, i2, i3, j3, k0, k1, i4, j4 = m/128,1, 1,n/128, 2,2, 2,2, k/64,4, 2,2
sch.reorder(i0, j0, i1, j1, j2, i2, i3, j3, k0, k1, i4, j4)
#block_idx=m/128
block_idx = sch.fuse(i0, j0)
#block_idy=n/128
block_idy = sch.fuse(i1, j1)
#thread_idy=4
thread_idy = sch.fuse(j2, i2)
sch.bind(batch, "blockIdx.z")
sch.bind(block_idx, "blockIdx.x")
sch.bind(block_idy, "blockIdx.y")
sch.bind(thread_idy, "threadIdx.y")
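# Resulting launch configuration (following the comments above):
#   blockIdx.x = fuse(i0, j0) ~ m/128, blockIdx.y = fuse(i1, j1) ~ n/128,
#   blockIdx.z = batch, threadIdx.y = fuse(j2, i2) = 4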
def fetch_to_shared(block, idx, ndim):
    block_read = sch.cache_read(block, idx, "shared.dyn")
    # compute_at k0: the global->shared copy for this operand is placed inside loop k0
    sch.compute_at(block_read, k0)
    fused = sch.fuse(*sch.get_loops(block_read)[-ndim:])
    _, f_1, f_2, f_3 = sch.split(fused, factors=[None, num_ty, warp_size, vector_size])
    sch.bind(f_2, "threadIdx.x")
    sch.bind(f_1, "threadIdx.y")
    sch.vectorize(f_3)
    sch.storage_align(block_read, 0, axis=-2, factor=16, offset=8)
    return block_read
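# fetch_to_shared stages one k0-tile of an operand into dynamic shared memory with a
# cooperative copy: the copy loops are fused, split across threadIdx.y / threadIdx.x /
# 4-wide vector lanes, and storage_align pads the row stride (factor 16, offset 8)
# to mitigate shared-memory bank conflicts.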
a_g2s = fetch_to_shared(block_outer, 0, 2)
b_g2s = fetch_to_shared(block_outer, 1, 2)
auto_inline_producers(sch, a_g2s)
auto_inline_producers(sch, b_g2s)
# create read cache to load matrix from shared memory to wmma fragments
A_mat = sch.cache_read(block_outer, 0, "wmma.matrix_a")
B_mat = sch.cache_read(block_outer, 1, "wmma.matrix_b")
sch.compute_at(A_mat, k1)
sch.compute_at(B_mat, k1)
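# A_mat/B_mat re-stage the shared-memory tiles into wmma matrix_a / matrix_b fragments
# inside the inner k1 loop, right before each 16x16x16 tensor-core step.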
# create write cache to store matrix from wmma fragments to shared memory and global memory
accumulator_shared_to_global = sch.cache_write(block_outer, 0, "shared.dyn")
sch.storage_align(accumulator_shared_to_global, 0, -2, 16, 4)
store = sch.cache_write(block_outer, 0, "wmma.accumulator")
sch.reverse_compute_at(store, thread_idy)
sch.reverse_compute_at(accumulator_shared_to_global, thread_idy)
# split the store loop to match hardware intrinsic pattern
i, j = sch.get_loops(store)[-2:]
i0, i1 = sch.split(i, factors=[None, 16])
j0, j1 = sch.split(j, factors=[None, 16])
sch.reorder(i0, j0, i1, j1)
block_init_c = sch.decompose_reduction(block_outer, k0)
block_init_c_inner = sch.get_child_blocks(block_init_c)[0]
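# decompose_reduction hoists the T.init() of the accumulator out of block_outer to above
# loop k0, so that init, compute and store can each be tensorized with their own intrinsic.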
# Tensorization by hardware intrinsics
intrin_group = get_wmma_intrin_group(
    load_scope="shared.dyn",
    store_scope="shared.dyn",
    in_dtype="float16",
    out_dtype="float32",
    trans_b=True,
)
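# get_wmma_intrin_group returns the named wmma intrinsics ("init", "load_a", "load_b",
# "compute", "store") for shared.dyn <-> fragment transfers, fp16 inputs with fp32
# accumulation, and a transposed B operand (matching the NT matmul above).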
try:
    print("here")
    i, j = sch.get_loops(A_mat)[-2:]
    i0, i1 = sch.split(i, factors=[None, 16])
    j0, j1 = sch.split(j, factors=[None, 16])
    sch.reorder(i0, j0, i1, j1)
    sch.unroll(i0)
    sch.unroll(j0)
    sch.tensorize(i1, intrin_group["load_a"])
    i, j = sch.get_loops(B_mat)[-2:]
    i0, i1 = sch.split(i, factors=[None, 16])
    j0, j1 = sch.split(j, factors=[None, 16])
    sch.reorder(i0, j0, i1, j1)
    sch.unroll(i0)
    sch.unroll(j0)
    sch.tensorize(i1, intrin_group["load_b"])
    print("here")
except:  # pylint: disable=bare-except
    print("failed")
# Try to tensorize the init, store and compute block with f16 or f32 intrinsics
tensorize_success: bool = False
def tensorize_init_store_compute():
    sch.tensorize(sch.get_loops(block_init_c_inner)[-2], intrin_group["init"])
    sch.tensorize(sch.get_loops(store)[-2], intrin_group["store"])
    sch.tensorize(sch.get_loops(block_inner)[-3], intrin_group["compute"])
try:
    print("try1")
    tensorize_init_store_compute()
    tensorize_success = True
except:  # pylint: disable=bare-except
    intrin_group = get_wmma_intrin_group(
        load_scope="shared.dyn",
        store_scope="shared.dyn",
        in_dtype="float16",
        out_dtype="float16",
        trans_b=True,
    )

if not tensorize_success:
    try:
        print("try2")
        tensorize_init_store_compute()
        tensorize_success = True
    except:  # pylint: disable=bare-except
        print("failed2")
auto_inline_consumers(sch, accumulator_shared_to_global)
fused = sch.fuse(*sch.get_loops(accumulator_shared_to_global)[-2:])
_, f1, f2 = sch.split(fused, factors=[None, warp_size, vector_size])
sch.bind(f1, "threadIdx.x")
sch.vectorize(f2)
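# Epilogue: the bias add is inlined into the shared->global store, whose last two loops
# are fused and split over threadIdx.x with 4-element vectorized writes.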
mod["fused_NT_matmul4_add1"] = sch.mod["main"]
print(mod)
pass_list = []
pass_list.append(tir.transform.InjectPrefetch())
pass_list.append(tir.transform.TextureFlatten())
pass_list.append(tir.transform.StorageFlatten(64, False))
pass_list.append(tir.transform.LowerInitBlock())
pass_list.append(tir.transform.PlanAndUpdateBufferAllocationLocation())
pass_list.append(tir.transform.ConvertBlocksToOpaque())
pass_list.append(tir.transform.LiftThreadBinding())
pass_list.append(tir.transform.ManifestSharedMemoryLocalStage())
pass_list.append(tir.transform.CompactBufferAllocation())
pass_list.append(tir.transform.LowerAutoCopy())
pass_list.append(tir.transform.UnifyThreadBinding())
pass_list.append(tir.transform.LowerMatchBuffer())
pass_list.append(tir.transform.Simplify())
pass_list.append(tir.transform.InjectPermutedLayout())
pass_list.append(tir.transform.Simplify())
pass_list.append(tir.transform.InjectSoftwarePipeline())
pass_list.append(tir.transform.TransformMmaBufferLayout())
pass_list.append(tir.transform.LowerOpaqueBlock())
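# A hand-picked prefix of TVM's TIR lowering passes, run manually so the intermediate IR
# can be inspected stage by stage (e.g. to see where the symbolic floormod appears).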
seq = tvm.transform.Sequential(pass_list)
mod = seq(mod)
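# A possible next step: print the lowered module and search for floormod/floordiv on the
# symbolic n, e.g.
# print(mod.script())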