Skip to content

Instantly share code, notes, and snippets.

@yzh119
Created April 26, 2022 09:51
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save yzh119/eb5baacdb7b346e272cc82c043e95b16 to your computer and use it in GitHub Desktop.
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from tvm.script import tir as T
@T.prim_func
def csrmm(
a: T.handle,
b: T.handle,
c: T.handle,
indptr: T.handle,
indices: T.handle,
m: T.int32,
n: T.int32,
k: T.int32,
nnz: T.int32,
) -> None:
# function attr dict
T.func_attr({"global_symbol": "main", "tir.noalias": True, "sparse_tir_level": 0})
A_data = T.match_buffer(a, [nnz], dtype="float32", strides=[1])
B_data = T.match_buffer(b, [n * k], dtype="float32", strides=[1])
C_data = T.match_buffer(c, [m * k], dtype="float32", strides=[1])
J_indptr_data = T.match_buffer(indptr, [m + 1], dtype="int32", strides=[1])
J_indices_data = T.match_buffer(indices, [nnz], dtype="int32", strides=[1])
# body
# with T.block("root")
for v_vi, v_vk in T.grid(m, k):
with T.block("csrmm0"):
vi, vk = T.axis.remap("SS", [v_vi, v_vk])
T.reads(
J_indptr_data[0 : m + 1], A_data[0:nnz], B_data[0 : n * k], J_indices_data[0:nnz]
)
T.writes(C_data[vi * k + vk])
T.block_attr({"sparse": True})
for v_vj in T.serial(J_indptr_data[vi + 1] - J_indptr_data[vi]):
with T.block("csrmm1"):
vj = T.axis.reduce(n, v_vj)
T.reads(
A_data[vj + J_indptr_data[vi]],
B_data[J_indices_data[vj + J_indptr_data[vi]] * k + vk],
J_indices_data[vj + J_indptr_data[vi]],
)
T.writes(C_data[vi * k + vk])
T.block_attr({"sparse": True})
with T.init():
C_data[vi * k + vk] = T.float32(0)
C_data[vi * k + vk] = (
C_data[vi * k + vk]
+ A_data[vj + J_indptr_data[vi]]
* B_data[J_indices_data[vj + J_indptr_data[vi]] * k + vk]
)
@T.prim_func
def segment_reduce(a: T.handle, b: T.handle, indptr: T.handle, n: T.int32, nnz: T.int32) -> None:
# function attr dict
T.func_attr({"global_symbol": "main", "tir.noalias": True, "sparse_tir_level": 0})
A_data = T.match_buffer(a, [nnz], dtype="float32", strides=[1])
B_data = T.match_buffer(b, [n], dtype="float32", strides=[1])
J_indptr_data = T.match_buffer(indptr, [n + 1], dtype="int32", strides=[1])
# body
# with T.block("root")
for v_vi in T.serial(n):
with T.block("segment_reduce0"):
vi = T.axis.spatial(n, v_vi)
T.reads(J_indptr_data[0 : n + 1], A_data[0:nnz])
T.writes(B_data[vi])
T.block_attr({"sparse": True})
for v_vj in T.serial(J_indptr_data[vi + 1] - J_indptr_data[vi]):
with T.block("segment_reduce1"):
vj = T.axis.reduce(100, v_vj)
T.reads(A_data[vj + J_indptr_data[vi]])
T.writes(B_data[vi])
T.block_attr({"sparse": True})
with T.init():
B_data[vi] = T.float32(0)
B_data[vi] = B_data[vi] + A_data[vj + J_indptr_data[vi]]
@T.prim_func
def csr_reduce(
a: T.handle,
b: T.handle,
indptr: T.handle,
indices: T.handle,
n: T.int32,
m: T.int32,
nnz: T.int32,
) -> None:
# function attr dict
T.func_attr({"global_symbol": "main", "tir.noalias": True, "sparse_tir_level": 0})
A_data = T.match_buffer(a, [nnz], dtype="float32", strides=[1])
B_data = T.match_buffer(b, [n], dtype="float32", strides=[1])
J_indptr_data = T.match_buffer(indptr, [n + 1], dtype="int32", strides=[1])
J_indices_data = T.match_buffer(indices, [nnz], dtype="int32", strides=[1])
# body
# with T.block("root")
for v_vi in T.serial(n):
with T.block("csr_reduce0"):
vi = T.axis.spatial(n, v_vi)
T.reads(J_indptr_data[0 : n + 1], A_data[0:nnz])
T.writes(B_data[vi])
T.block_attr({"sparse": True})
for v_vj in T.serial(J_indptr_data[vi + 1] - J_indptr_data[vi]):
with T.block("csr_reduce1"):
vj = T.axis.reduce(m, v_vj)
T.reads(A_data[vj + J_indptr_data[vi]])
T.writes(B_data[vi])
T.block_attr({"sparse": True})
with T.init():
B_data[vi] = T.float32(0)
B_data[vi] = B_data[vi] + A_data[vj + J_indptr_data[vi]]
@T.prim_func
def csrmm_dense_iter(
a: T.handle,
b: T.handle,
c: T.handle,
indptr: T.handle,
indices: T.handle,
m: T.int32,
n: T.int32,
k: T.int32,
nnz: T.int32,
) -> None:
# function attr dict
T.func_attr({"global_symbol": "main", "tir.noalias": True, "sparse_tir_level": 0})
A_data = T.match_buffer(a, [nnz], dtype="float32", strides=[1])
B_data = T.match_buffer(b, [n * k], dtype="float32", strides=[1])
C_data = T.match_buffer(c, [m * k], dtype="float32", strides=[1])
J_indptr_data = T.match_buffer(indptr, [m + 1], dtype="int32", strides=[1])
J_indices_data = T.match_buffer(indices, [nnz], dtype="int32", strides=[1])
# body
# with T.block("root")
low = T.alloc_buffer([1], dtype="int32", strides=[1], scope="local")
high = T.alloc_buffer([1], dtype="int32", strides=[1], scope="local")
mid_0 = T.alloc_buffer([1], dtype="int32", strides=[1], scope="local")
mid_0[0] = -1
for v_vi, v_vj, v_vk in T.grid(m, n, k):
with T.block("binary_search_0"):
ax0, ax1, ax2 = T.axis.remap("SSS", [v_vi, v_vj, v_vk])
T.where(mid_0[0] == -1)
T.reads(J_indices_data[0:nnz])
T.writes(mid_0[0])
T.block_attr({"sparse": True})
low[0] = 0
high[0] = J_indptr_data[ax0 + 1] - J_indptr_data[ax0]
while low[0] < high[0]:
mid_0[0] = low[0] + (high[0] - low[0]) // 2
if J_indices_data[mid_0[0] + J_indptr_data[ax0]] < ax1:
low[0] = mid_0[0] + 1
else:
high[0] = mid_0[0]
with T.block("csrmm0"):
vi, vj, vk = T.axis.remap("SRS", [v_vi, v_vj, v_vk])
T.reads(A_data[mid_0[0] + J_indptr_data[vi]], mid_0[0], B_data[vj * k + vk])
T.writes(C_data[vi * k + vk])
T.block_attr({"sparse": True})
with T.init():
C_data[vi * k + vk] = T.float32(0)
C_data[vi * k + vk] = (
C_data[vi * k + vk] + A_data[mid_0[0] + J_indptr_data[vi]] * B_data[vj * k + vk]
)
@T.prim_func
def bsrmm(
a: T.handle,
b: T.handle,
c: T.handle,
indptr: T.handle,
indices: T.handle,
nb: T.int32,
mb: T.int32,
nnzb: T.int32,
blk: T.int32,
feat_size: T.int32,
) -> None:
# function attr dict
T.func_attr({"global_symbol": "main", "tir.noalias": True, "sparse_tir_level": 0})
A_data = T.match_buffer(a, [nnzb * blk * blk], dtype="float32", strides=[1])
B_data = T.match_buffer(b, [mb * blk * feat_size], dtype="float32", strides=[1])
C_data = T.match_buffer(c, [nb * blk * feat_size], dtype="float32", strides=[1])
J_indptr_data = T.match_buffer(indptr, [nb + 1], dtype="int32", strides=[1])
J_indices_data = T.match_buffer(indices, [nnzb], dtype="int32", strides=[1])
# body
# with T.block("root")
for v_vi, v_vbi, v_vbj, v_vf in T.grid(nb, blk, blk, feat_size):
with T.block("bsrmm0"):
vi, vbi, vbj, vf = T.axis.remap("SSRS", [v_vi, v_vbi, v_vbj, v_vf])
T.reads(
J_indptr_data[0 : nb + 1],
A_data[0 : nnzb * blk * blk],
B_data[0 : mb * blk * feat_size],
J_indices_data[0:nnzb],
)
T.writes(C_data[vi * (blk * feat_size) + vbi * feat_size + vf])
T.block_attr({"sparse": True})
with T.init():
C_data[vi * (blk * feat_size) + vbi * feat_size + vf] = T.float32(0)
for v_vj in T.serial(J_indptr_data[vi + 1] - J_indptr_data[vi]):
with T.block("bsrmm1"):
vj = T.axis.reduce(mb, v_vj)
T.reads(
A_data[(vj + J_indptr_data[vi]) * (blk * blk) + vbi * blk + vbj],
B_data[
J_indices_data[vj + J_indptr_data[vi]] * (blk * feat_size)
+ vbj * feat_size
+ vf
],
J_indices_data[vj + J_indptr_data[vi]],
)
T.writes(C_data[vi * (blk * feat_size) + vbi * feat_size + vf])
T.block_attr({"sparse": True})
C_data[vi * (blk * feat_size) + vbi * feat_size + vf] = (
C_data[vi * (blk * feat_size) + vbi * feat_size + vf]
+ A_data[(vj + J_indptr_data[vi]) * (blk * blk) + vbi * blk + vbj]
* B_data[
J_indices_data[vj + J_indptr_data[vi]] * (blk * feat_size)
+ vbj * feat_size
+ vf
]
)
@T.prim_func
def ellmm(
a: T.handle,
b: T.handle,
c: T.handle,
indices: T.handle,
nb: T.int32,
mb: T.int32,
feat_size: T.int32,
col: T.int32,
blk: T.int32,
) -> None:
# function attr dict
T.func_attr({"global_symbol": "main", "tir.noalias": True, "sparse_tir_level": 0})
A_data = T.match_buffer(a, [nb * col * blk * blk], dtype="float32", strides=[1])
B_data = T.match_buffer(b, [mb * blk * feat_size], dtype="float32", strides=[1])
C_data = T.match_buffer(c, [nb * blk * feat_size], dtype="float32", strides=[1])
J_indices_data = T.match_buffer(indices, [nb * col], dtype="int32", strides=[1])
# body
# with T.block("root")
for v_vi, v_vj, v_vbi, v_vbj, v_vf in T.grid(nb, col, blk, blk, feat_size):
with T.block("ellmm0"):
vi = T.axis.spatial(nb, v_vi)
vj = T.axis.reduce(col, v_vj)
vbi, vbj, vf = T.axis.remap("SRS", [v_vbi, v_vbj, v_vf])
T.reads(
A_data[vi * (col * blk * blk) + vj * (blk * blk) + vbi * blk + vbj],
B_data[J_indices_data[vi * col + vj] * (blk * feat_size) + vbj * feat_size + vf],
J_indices_data[vi * col + vj],
)
T.writes(C_data[vi * (blk * feat_size) + vbi * feat_size + vf])
T.block_attr({"sparse": True})
with T.init():
C_data[vi * (blk * feat_size) + vbi * feat_size + vf] = T.float32(0)
C_data[vi * (blk * feat_size) + vbi * feat_size + vf] = (
C_data[vi * (blk * feat_size) + vbi * feat_size + vf]
+ A_data[vi * (col * blk * blk) + vj * (blk * blk) + vbi * blk + vbj]
* B_data[J_indices_data[vi * col + vj] * (blk * feat_size) + vbj * feat_size + vf]
)
@T.prim_func
def csr_element_wise(
a: T.handle,
b: T.handle,
indptr: T.handle,
indices: T.handle,
m: T.int32,
n: T.int32,
nnz: T.int32,
) -> None:
# function attr dict
T.func_attr({"global_symbol": "main", "tir.noalias": True, "sparse_tir_level": 0})
A_data = T.match_buffer(a, [nnz], dtype="float32", strides=[1])
B_data = T.match_buffer(b, [nnz], dtype="float32", strides=[1])
J_indptr_data = T.match_buffer(indptr, [m + 1], dtype="int32", strides=[1])
J_indices_data = T.match_buffer(indices, [nnz], dtype="int32", strides=[1])
# body
# with T.block("root")
for v_vi in T.serial(m):
with T.block("csr_element_wise0"):
vi = T.axis.spatial(m, v_vi)
T.reads(J_indptr_data[0 : m + 1], A_data[0:nnz])
T.writes(B_data[0:nnz])
T.block_attr({"sparse": True})
for v_vj in T.serial(J_indptr_data[vi + 1] - J_indptr_data[vi]):
with T.block("csr_element_wise1"):
vj = T.axis.spatial(n, v_vj)
T.reads(A_data[vj + J_indptr_data[vi]])
T.writes(B_data[vj + J_indptr_data[vi]])
T.block_attr({"sparse": True})
B_data[vj + J_indptr_data[vi]] = A_data[vj + J_indptr_data[vi]] * T.float32(2.5)
@T.prim_func
def hyper_gnn(
x: T.handle,
y: T.handle,
indptr: T.handle,
indices: T.handle,
indptr_T: T.handle,
indices_T: T.handle,
n: T.int32,
m: T.int32,
nnz: T.int32,
feat_size: T.int32,
) -> None:
# function attr dict
T.func_attr({"global_symbol": "main", "tir.noalias": True, "sparse_tir_level": 0})
X_data = T.match_buffer(x, [n * feat_size], dtype="float32", strides=[1])
Y_data = T.match_buffer(y, [n * feat_size], dtype="float32", strides=[1])
J_indptr_data = T.match_buffer(indptr, [n + 1], dtype="int32", strides=[1])
J_indices_data = T.match_buffer(indices, [nnz], dtype="int32", strides=[1])
I_T_indptr_data = T.match_buffer(indptr_T, [m + 1], dtype="int32", strides=[1])
I_T_indices_data = T.match_buffer(indices_T, [nnz], dtype="int32", strides=[1])
# body
# with T.block("root")
for v_vi, v_vf in T.grid(n, feat_size):
with T.block("hyper_gnn0"):
vi, vf = T.axis.remap("SS", [v_vi, v_vf])
T.reads(
J_indptr_data[0 : n + 1],
I_T_indptr_data[0 : m + 1],
J_indices_data[0:nnz],
X_data[0 : n * feat_size],
I_T_indices_data[0:nnz],
)
T.writes(Y_data[vi * feat_size + vf])
T.block_attr({"sparse": True})
for v_vj in T.serial(J_indptr_data[vi + 1] - J_indptr_data[vi]):
with T.block("hyper_gnn1"):
vj = T.axis.reduce(m, v_vj)
T.reads(
I_T_indptr_data[0 : m + 1],
J_indices_data[vj + J_indptr_data[vi]],
X_data[0 : n * feat_size],
I_T_indices_data[0:nnz],
)
T.writes(Y_data[vi * feat_size + vf])
T.block_attr({"sparse": True})
with T.init():
Y_data[vi * feat_size + vf] = T.float32(0)
for v_vi_t in T.serial(
I_T_indptr_data[J_indices_data[vj + J_indptr_data[vi]] + 1]
- I_T_indptr_data[J_indices_data[vj + J_indptr_data[vi]]]
):
with T.block("hyper_gnn2"):
vi_t = T.axis.reduce(n, v_vi_t)
T.reads(
X_data[
I_T_indices_data[vi_t + I_T_indptr_data[vj]] * feat_size + vf
],
I_T_indices_data[vi_t + I_T_indptr_data[vj]],
)
T.writes(Y_data[vi * feat_size + vf])
T.block_attr({"sparse": True})
Y_data[vi * feat_size + vf] = (
Y_data[vi * feat_size + vf]
+ X_data[
I_T_indices_data[vi_t + I_T_indptr_data[vj]] * feat_size + vf
]
)
@T.prim_func
def sddmm(
a: T.handle,
b: T.handle,
c: T.handle,
indptr: T.handle,
indices: T.handle,
m: T.int32,
n: T.int32,
k: T.int32,
nnz: T.int32,
) -> None:
# function attr dict
T.func_attr({"global_symbol": "main", "tir.noalias": True, "sparse_tir_level": 0})
A_data = T.match_buffer(a, [m * k], dtype="float32", strides=[1])
B_data = T.match_buffer(b, [n * k], dtype="float32", strides=[1])
C_data = T.match_buffer(c, [nnz], dtype="float32", strides=[1])
J_indptr_data = T.match_buffer(indptr, [m + 1], dtype="int32", strides=[1])
J_indices_data = T.match_buffer(indices, [nnz], dtype="int32", strides=[1])
# body
# with T.block("root")
for v_vi in T.serial(m):
with T.block("sddmm0"):
vi = T.axis.spatial(m, v_vi)
T.reads(
J_indptr_data[0 : m + 1],
A_data[0 : m * k],
B_data[0 : n * k],
J_indices_data[0:nnz],
)
T.writes(C_data[0:nnz])
T.block_attr({"sparse": True})
for v_vj, v_vk in T.grid(J_indptr_data[vi + 1] - J_indptr_data[vi], k):
with T.block("sddmm1"):
vj = T.axis.spatial(n, v_vj)
vk = T.axis.reduce(k, v_vk)
T.reads(
A_data[vi * k + vk],
B_data[J_indices_data[vj + J_indptr_data[vi]] * k + vk],
J_indices_data[vj + J_indptr_data[vi]],
)
T.writes(C_data[vj + J_indptr_data[vi]])
T.block_attr({"sparse": True})
with T.init():
C_data[vj + J_indptr_data[vi]] = T.float32(0)
C_data[vj + J_indptr_data[vi]] = (
C_data[vj + J_indptr_data[vi]]
+ A_data[vi * k + vk]
* B_data[J_indices_data[vj + J_indptr_data[vi]] * k + vk]
)
@T.prim_func
def fused_sddmm(
a: T.handle,
b: T.handle,
c: T.handle,
indptr: T.handle,
indices: T.handle,
m: T.int32,
n: T.int32,
k: T.int32,
nnz: T.int32,
) -> None:
# function attr dict
T.func_attr({"global_symbol": "main", "tir.noalias": True, "sparse_tir_level": 0})
A_data = T.match_buffer(a, [m * k], dtype="float32", strides=[1])
B_data = T.match_buffer(b, [n * k], dtype="float32", strides=[1])
C_data = T.match_buffer(c, [nnz], dtype="float32", strides=[1])
J_indptr_data = T.match_buffer(indptr, [m + 1], dtype="int32", strides=[1])
J_indices_data = T.match_buffer(indices, [nnz], dtype="int32", strides=[1])
# body
# with T.block("root")
low = T.alloc_buffer([1], dtype="int32", strides=[1], scope="local")
high = T.alloc_buffer([1], dtype="int32", strides=[1], scope="local")
mid_0 = T.alloc_buffer([1], dtype="int32", strides=[1], scope="local")
mid_0[0] = -1
for v_vj, v_vk in T.grid(nnz, k):
with T.block("binary_search_0"):
ax0 = T.axis.spatial(1, 0)
ax1, ax2 = T.axis.remap("SS", [v_vj, v_vk])
T.where(mid_0[0] == -1)
T.reads(J_indptr_data[0 : m + 1])
T.writes(mid_0[0])
T.block_attr({"sparse": True})
low[0] = 0
high[0] = m + 1
while low[0] < high[0]:
mid_0[0] = low[0] + (high[0] - low[0]) // 2
if J_indptr_data[mid_0[0]] > ax1:
high[0] = mid_0[0]
else:
low[0] = mid_0[0] + 1
mid_0[0] = mid_0[0] - 1
with T.block("sddmm0"):
vi = T.axis.spatial(1, 0)
vj, vk = T.axis.remap("SR", [v_vj, v_vk])
T.reads(
A_data[mid_0[0] * k + vk],
mid_0[0],
B_data[J_indices_data[vj + J_indptr_data[vi]] * k + vk],
J_indices_data[vj + J_indptr_data[vi]],
)
T.writes(C_data[vj + J_indptr_data[vi]])
T.block_attr({"sparse": True})
with T.init():
C_data[vj] = T.float32(0)
C_data[vj] = (
C_data[vj] + A_data[mid_0[0] * k + vk] * B_data[J_indices_data[vj] * k + vk]
)
@T.prim_func
def square_sum(
a: T.handle,
b: T.handle,
indptr_j: T.handle,
indices_j: T.handle,
indptr_k: T.handle,
indices_k: T.handle,
nnz_j: T.int32,
nnz_k: T.int32,
M: T.int32,
N1: T.int32,
N2: T.int32,
) -> None:
# function attr dict
T.func_attr({"global_symbol": "main", "tir.noalias": True, "sparse_tir_level": 0})
A_data = T.match_buffer(a, [nnz_k], dtype="float32", strides=[1])
B_data = T.match_buffer(b, [M], dtype="float32", strides=[1])
J_indptr_data = T.match_buffer(indptr_j, [M + 1], dtype="int32", strides=[1])
J_indices_data = T.match_buffer(indices_j, [nnz_j], dtype="int32", strides=[1])
K_indptr_data = T.match_buffer(indptr_k, [nnz_j + 1], dtype="int32", strides=[1])
K_indices_data = T.match_buffer(indices_k, [nnz_k], dtype="int32", strides=[1])
# body
# with T.block("root")
for v_vi in T.serial(M):
with T.block("square_sum0"):
vi = T.axis.spatial(M, v_vi)
T.reads(J_indptr_data[0 : M + 1], K_indptr_data[0 : nnz_j + 1], A_data[0:nnz_k])
T.writes(B_data[vi])
T.block_attr({"sparse": True})
for v_vj in T.serial(J_indptr_data[vi + 1] - J_indptr_data[vi]):
with T.block("square_sum1"):
vj = T.axis.reduce(N1, v_vj)
T.reads(K_indptr_data[0 : nnz_j + 1], A_data[0:nnz_k])
T.writes(B_data[vi])
T.block_attr({"sparse": True})
with T.init():
B_data[vi] = T.float32(0)
for v_vk in T.serial(
K_indptr_data[vj + J_indptr_data[vi] + 1]
- K_indptr_data[vj + J_indptr_data[vi]]
):
with T.block("square_sum2"):
vk = T.axis.reduce(N2, v_vk)
T.reads(A_data[vk + K_indptr_data[vj + J_indptr_data[vi]]])
T.writes(B_data[vi])
T.block_attr({"sparse": True})
B_data[vi] = (
B_data[vi] + A_data[vk + K_indptr_data[vj + J_indptr_data[vi]]]
)
@T.prim_func
def square_sum_two_K(
a: T.handle,
b: T.handle,
indptr_j: T.handle,
indices_j: T.handle,
indptr_k0: T.handle,
indices_k0: T.handle,
indptr_k1: T.handle,
indices_k1: T.handle,
nnz_j: T.int32,
nnz_k: T.int32,
M: T.int32,
N1: T.int32,
N2: T.int32,
) -> None:
# function attr dict
T.func_attr({"global_symbol": "main", "tir.noalias": True, "sparse_tir_level": 0})
A_data = T.match_buffer(a, [nnz_k], dtype="float32", strides=[1])
B_data = T.match_buffer(b, [M], dtype="float32", strides=[1])
J_indptr_data = T.match_buffer(indptr_j, [M + 1], dtype="int32", strides=[1])
J_indices_data = T.match_buffer(indices_j, [nnz_j], dtype="int32", strides=[1])
K0_indptr_data = T.match_buffer(indptr_k0, [nnz_j + 1], dtype="int32", strides=[1])
K0_indices_data = T.match_buffer(indices_k0, [nnz_k], dtype="int32", strides=[1])
K1_indptr_data = T.match_buffer(indptr_k1, [nnz_j + 1], dtype="int32", strides=[1])
K1_indices_data = T.match_buffer(indices_k1, [nnz_k], dtype="int32", strides=[1])
# body
# with T.block("root")
low = T.alloc_buffer([1], dtype="int32", strides=[1], scope="local")
high = T.alloc_buffer([1], dtype="int32", strides=[1], scope="local")
mid_0 = T.alloc_buffer([1], dtype="int32", strides=[1], scope="local")
mid_0[0] = -1
for v_vi in T.serial(M):
with T.block("square_sum0"):
vi = T.axis.spatial(M, v_vi)
T.reads(
J_indptr_data[0 : M + 1],
K1_indptr_data[0 : nnz_j + 1],
K0_indices_data[0:nnz_k],
A_data[0:nnz_k],
)
T.writes(B_data[vi], mid_0[0])
T.block_attr({"sparse": True})
for v_vj in T.serial(J_indptr_data[vi + 1] - J_indptr_data[vi]):
with T.block("square_sum1"):
vj = T.axis.reduce(N1, v_vj)
T.reads(
K1_indptr_data[0 : nnz_j + 1], K0_indices_data[0:nnz_k], A_data[0:nnz_k]
)
T.writes(B_data[vi], mid_0[0])
T.block_attr({"sparse": True})
with T.init():
B_data[vi] = T.float32(0)
for v_vk in T.serial(
K1_indptr_data[vj + J_indptr_data[vi] + 1]
- K1_indptr_data[vj + J_indptr_data[vi]]
):
with T.block("binary_search_0"):
ax0 = T.axis.spatial(N2, v_vk)
T.where(mid_0[0] == -1)
T.reads(K0_indices_data[0:nnz_k])
T.writes(mid_0[0])
T.block_attr({"sparse": True})
low[0] = 0
high[0] = (
K0_indptr_data[vj + J_indptr_data[vi] + 1]
- K0_indptr_data[vj + J_indptr_data[vi]]
)
while low[0] < high[0]:
mid_0[0] = low[0] + (high[0] - low[0]) // 2
if (
K0_indices_data[
mid_0[0] + K0_indptr_data[vj + J_indptr_data[vi]]
]
< K1_indices_data[ax0 + K1_indptr_data[vj + J_indptr_data[vi]]]
):
low[0] = mid_0[0] + 1
else:
high[0] = mid_0[0]
with T.block("square_sum2"):
vk = T.axis.reduce(N2, v_vk)
T.reads(
A_data[mid_0[0] + K0_indptr_data[vj + J_indptr_data[vi]]], mid_0[0]
)
T.writes(B_data[vi])
T.block_attr({"sparse": True})
B_data[vi] = (
B_data[vi]
+ A_data[mid_0[0] + K0_indptr_data[vj + J_indptr_data[vi]]]
)
@T.prim_func
def fused_reduction_4d_2d(
x: T.handle,
y: T.handle,
indptr_j: T.handle,
indptr_k: T.handle,
indptr_l: T.handle,
n: T.int32,
nnz_j: T.int32,
nnz_k: T.int32,
nnz_l: T.int32,
) -> None:
# function attr dict
T.func_attr({"global_symbol": "main", "tir.noalias": True, "sparse_tir_level": 0})
X_data = T.match_buffer(x, [nnz_l], dtype="float32", strides=[1])
Y_data = T.match_buffer(y, [nnz_j], dtype="float32", strides=[1])
J_indptr_data = T.match_buffer(indptr_j, [n + 1], dtype="int32", strides=[1])
K_indptr_data = T.match_buffer(indptr_k, [nnz_j + 1], dtype="int32", strides=[1])
L_indptr_data = T.match_buffer(indptr_l, [nnz_k + 1], dtype="int32", strides=[1])
# body
# with T.block("root")
for v_vj in T.serial(nnz_j):
with T.block("reduction_4d_2d0"):
vi = T.axis.spatial(1, 0)
vj = T.axis.spatial(nnz_j, v_vj)
T.reads(K_indptr_data[0 : nnz_j + 1], L_indptr_data[0 : nnz_k + 1], X_data[0:nnz_l])
T.writes(Y_data[vj + J_indptr_data[vi]])
T.block_attr({"sparse": True})
for v_vk in T.serial(K_indptr_data[vj + 1] - K_indptr_data[vj]):
with T.block("reduction_4d_2d1"):
vk = T.axis.reduce(32768, v_vk)
T.reads(L_indptr_data[0 : nnz_k + 1], X_data[0:nnz_l])
T.writes(Y_data[vj])
T.block_attr({"sparse": True})
with T.init():
Y_data[vj] = T.float32(0)
for v_vl in T.serial(
L_indptr_data[vk + K_indptr_data[vj] + 1]
- L_indptr_data[vk + K_indptr_data[vj]]
):
with T.block("reduction_4d_2d2"):
vl = T.axis.reduce(32768, v_vl)
T.reads(X_data[vl + L_indptr_data[vk + K_indptr_data[vj]]])
T.writes(Y_data[vj])
T.block_attr({"sparse": True})
Y_data[vj] = (
Y_data[vj] + X_data[vl + L_indptr_data[vk + K_indptr_data[vj]]]
)
@T.prim_func
def fused_reduction_4d_3d(
x: T.handle,
y: T.handle,
indptr_j: T.handle,
indptr_k: T.handle,
indptr_l: T.handle,
n: T.int32,
nnz_j: T.int32,
nnz_k: T.int32,
nnz_l: T.int32,
) -> None:
# function attr dict
T.func_attr({"global_symbol": "main", "tir.noalias": True, "sparse_tir_level": 0})
X_data = T.match_buffer(x, [nnz_l], dtype="float32", strides=[1])
Y_data = T.match_buffer(y, [nnz_k], dtype="float32", strides=[1])
J_indptr_data = T.match_buffer(indptr_j, [n + 1], dtype="int32", strides=[1])
K_indptr_data = T.match_buffer(indptr_k, [nnz_j + 1], dtype="int32", strides=[1])
L_indptr_data = T.match_buffer(indptr_l, [nnz_k + 1], dtype="int32", strides=[1])
# body
# with T.block("root")
for v_vk in T.serial(nnz_k):
with T.block("reduction_4d_3d0"):
vi = T.axis.spatial(1, 0)
vj = T.axis.spatial(1, 0)
vk = T.axis.spatial(nnz_k, v_vk)
T.reads(L_indptr_data[0 : nnz_k + 1], X_data[0:nnz_l])
T.writes(Y_data[vk + K_indptr_data[vj + J_indptr_data[vi]]])
T.block_attr({"sparse": True})
for v_vl in T.serial(L_indptr_data[vk + 1] - L_indptr_data[vk]):
with T.block("reduction_4d_3d1"):
vl = T.axis.reduce(32768, v_vl)
T.reads(X_data[vl + L_indptr_data[vk]])
T.writes(Y_data[vk])
T.block_attr({"sparse": True})
with T.init():
Y_data[vk] = T.float32(0)
Y_data[vk] = Y_data[vk] + X_data[vl + L_indptr_data[vk]]
@T.prim_func
def rgcn_forward(
etype: T.handle,
w: T.handle,
x: T.handle,
y: T.handle,
indptr: T.handle,
indices: T.handle,
n: T.int32,
r: T.int32,
feat_size: T.int32,
nnz: T.int32,
) -> None:
# function attr dict
T.func_attr({"global_symbol": "main", "tir.noalias": True, "sparse_tir_level": 0})
E_data = T.match_buffer(etype, [nnz], dtype="int32", strides=[1])
W_data = T.match_buffer(w, [r * feat_size * feat_size], dtype="float32", strides=[1])
X_data = T.match_buffer(x, [n * feat_size], dtype="float32", strides=[1])
Y_data = T.match_buffer(y, [n * feat_size], dtype="float32", strides=[1])
J_indptr_data = T.match_buffer(indptr, [n + 1], dtype="int32", strides=[1])
J_indices_data = T.match_buffer(indices, [nnz], dtype="int32", strides=[1])
# body
# with T.block("root")
for v_vi, v_vout in T.grid(n, feat_size):
with T.block("rgcn-forward0"):
vi, vout = T.axis.remap("SS", [v_vi, v_vout])
T.reads(
J_indptr_data[0 : n + 1],
W_data[0 : r * feat_size * feat_size],
E_data[0:nnz],
X_data[0 : n * feat_size],
J_indices_data[0:nnz],
)
T.writes(Y_data[vi * feat_size + vout])
T.block_attr({"sparse": True})
for v_vj, v_vin in T.grid(J_indptr_data[vi + 1] - J_indptr_data[vi], feat_size):
with T.block("rgcn-forward1"):
vj = T.axis.reduce(n, v_vj)
vin = T.axis.reduce(feat_size, v_vin)
T.reads(
W_data[
E_data[vj + J_indptr_data[vi]] * (feat_size * feat_size)
+ vout * feat_size
+ vin
],
E_data[vj + J_indptr_data[vi]],
X_data[J_indices_data[vj + J_indptr_data[vi]] * feat_size + vin],
J_indices_data[vj + J_indptr_data[vi]],
)
T.writes(Y_data[vi * feat_size + vout])
T.block_attr({"sparse": True})
with T.init():
Y_data[vi * feat_size + vout] = T.float32(0)
Y_data[vi * feat_size + vout] = (
Y_data[vi * feat_size + vout]
+ W_data[
E_data[vj + J_indptr_data[vi]] * (feat_size * feat_size)
+ vout * feat_size
+ vin
]
* X_data[J_indices_data[vj + J_indptr_data[vi]] * feat_size + vin]
)
@T.prim_func
def sparse_softmax(
a: T.handle, b: T.handle, indptr: T.handle, indices: T.handle, n: T.int32, nnz: T.int32
) -> None:
# function attr dict
T.func_attr({"global_symbol": "main", "tir.noalias": True, "sparse_tir_level": 0})
A_data = T.match_buffer(a, [nnz], dtype="float32", strides=[1])
B_data = T.match_buffer(b, [nnz], dtype="float32", strides=[1])
J_indptr_data = T.match_buffer(indptr, [n + 1], dtype="int32", strides=[1])
J_indices_data = T.match_buffer(indices, [nnz], dtype="int32", strides=[1])
# body
# with T.block("root")
TMP_data = T.alloc_buffer([n], dtype="float32", strides=[1])
TMP1_data = T.alloc_buffer([n], dtype="float32", strides=[1])
for v_vi in T.serial(n):
with T.block("sparse_softmax0"):
vi = T.axis.spatial(n, v_vi)
T.reads(J_indptr_data[0 : n + 1], A_data[0:nnz], TMP_data[vi], TMP1_data[vi])
T.writes(TMP_data[vi], TMP1_data[vi], B_data[0:nnz])
T.block_attr({"sparse": True})
for v_vj in T.serial(J_indptr_data[vi + 1] - J_indptr_data[vi]):
with T.block("computer_max0"):
vj = T.axis.reduce(n, v_vj)
T.reads(A_data[vj + J_indptr_data[vi]])
T.writes(TMP_data[vi])
T.block_attr({"sparse": True})
with T.init():
TMP_data[vi] = T.float32(-100000)
TMP_data[vi] = T.max(TMP_data[vi], A_data[vj + J_indptr_data[vi]])
for v_vj in T.serial(J_indptr_data[vi + 1] - J_indptr_data[vi]):
with T.block("exp_and_sum0"):
vj = T.axis.reduce(n, v_vj)
T.reads(A_data[vj + J_indptr_data[vi]], TMP_data[vi])
T.writes(TMP1_data[vi])
T.block_attr({"sparse": True})
with T.init():
TMP1_data[vi] = T.float32(-100000)
TMP1_data[vi] = TMP1_data[vi] + T.exp(
A_data[vj + J_indptr_data[vi]] - TMP_data[vi], dtype="float32"
)
for v_vj in T.serial(J_indptr_data[vi + 1] - J_indptr_data[vi]):
with T.block("div0"):
vj = T.axis.spatial(n, v_vj)
T.reads(A_data[vj + J_indptr_data[vi]], TMP1_data[vi])
T.writes(B_data[vj + J_indptr_data[vi]])
T.block_attr({"sparse": True})
B_data[vj + J_indptr_data[vi]] = (
T.exp(A_data[vj + J_indptr_data[vi]], dtype="float32") / TMP1_data[vi]
)
@T.prim_func
def csr2bsr(
a: T.handle,
b: T.handle,
indptr_in: T.handle,
indices_in: T.handle,
indptr_out: T.handle,
indices_out: T.handle,
m_in: T.int32,
n_in: T.int32,
m_out: T.int32,
n_out: T.int32,
nnz_in: T.int32,
nnz_out: T.int32,
blk_size: T.int32,
) -> None:
# function attr dict
T.func_attr({"global_symbol": "main", "tir.noalias": True, "sparse_tir_level": 0})
A_data = T.match_buffer(a, [nnz_in], dtype="float32", strides=[1])
B_data = T.match_buffer(b, [nnz_out * blk_size * blk_size], dtype="float32", strides=[1])
J_indptr_data = T.match_buffer(indptr_in, [m_in + 1], dtype="int32", strides=[1])
J_indices_data = T.match_buffer(indices_in, [nnz_in], dtype="int32", strides=[1])
J_bsr_indptr_data = T.match_buffer(indptr_out, [m_out + 1], dtype="int32", strides=[1])
J_bsr_indices_data = T.match_buffer(indices_out, [nnz_out], dtype="int32", strides=[1])
# body
# with T.block("root")
low = T.alloc_buffer([1], dtype="int32", strides=[1], scope="local")
high = T.alloc_buffer([1], dtype="int32", strides=[1], scope="local")
mid_0 = T.alloc_buffer([1], dtype="int32", strides=[1], scope="local")
mid_0[0] = -1
for v_vi in T.serial(m_in):
with T.block("csr2bsr0"):
vi = T.axis.spatial(m_in, v_vi)
T.reads(
J_indptr_data[0 : m_in + 1],
J_bsr_indices_data[0:nnz_out],
A_data[0:nnz_in],
mid_0[0],
J_indices_data[0:nnz_in],
)
T.writes(mid_0[0], B_data[0 : nnz_out * blk_size * blk_size])
T.block_attr({"sparse": True})
for v_vj in T.serial(J_indptr_data[vi + 1] - J_indptr_data[vi]):
with T.block("binary_search_0"):
ax0 = T.axis.spatial(n_in, v_vj)
T.where(mid_0[0] == -1)
T.reads(J_bsr_indices_data[0:nnz_out])
T.writes(mid_0[0])
T.block_attr({"sparse": True})
low[0] = 0
high[0] = (
J_bsr_indptr_data[vi // blk_size + 1] - J_bsr_indptr_data[vi // blk_size]
)
while low[0] < high[0]:
mid_0[0] = low[0] + (high[0] - low[0]) // 2
if (
J_bsr_indices_data[mid_0[0] + J_bsr_indptr_data[vi // blk_size]]
< J_indices_data[ax0 + J_indptr_data[vi]] // blk_size
):
low[0] = mid_0[0] + 1
else:
high[0] = mid_0[0]
with T.block("csr2bsr1"):
vj = T.axis.spatial(n_in, v_vj)
T.reads(
A_data[vj + J_indptr_data[vi]],
mid_0[0],
J_indices_data[vj + J_indptr_data[vi]],
)
T.writes(B_data[0 : nnz_out * blk_size * blk_size])
T.block_attr({"sparse": True})
B_data[
(mid_0[0] + J_bsr_indptr_data[vi // blk_size]) * (blk_size * blk_size)
+ vi % blk_size * blk_size
+ J_indices_data[vj + J_indptr_data[vi]] % blk_size
] = A_data[vj + J_indptr_data[vi]]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment