@Lunderberg
Created June 10, 2022 20:21
Examining the hoisting of cached transforms from TIR into graph level

# Initial end-to-end model
@script.ir_module
class EndToEndModel:
    @R.func
    def main(x: R.Tensor[16]):
        F = R.const(shape=[3])
        Y = call_tir(conv1d_16, x, F)
        Z = call_tir(conv1d_18, Y, F)
        return Z

    @T.prim_func
    def conv1d_16(
        X: T.Buffer[(16,), "float32"],
        F: T.Buffer[(3,), "float32"],
        Y: T.Buffer[(18,), "float32"],
    ):
        for Yi in T.serial(18):
            Y[Yi] = 0.0
            for fi in T.serial(3):
                # Zero-padded: contributions with Xi outside [0, 16) are skipped.
                Xi = Yi - fi
                if 0 <= Xi < 16:
                    Y[Yi] = Y[Yi] + F[fi] * X[Xi]

    @T.prim_func
    def conv1d_18(
        Y: T.Buffer[(18,), "float32"],
        F: T.Buffer[(3,), "float32"],
        Z: T.Buffer[(20,), "float32"],
    ):
        for Zi in T.serial(20):
            Z[Zi] = 0.0
            for fi in T.serial(3):
                Yi = Zi - fi
                if 0 <= Yi < 18:
                    Z[Zi] = Z[Zi] + F[fi] * Y[Yi]

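# As a sanity check (a minimal NumPy sketch; `conv1d_full` is an
# illustrative name, not part of the module above): with Xi = Yi - fi,
# conv1d_16 is a zero-padded "full" 1-D convolution, so it should agree
# with np.convolve in "full" mode.
import numpy as np

def conv1d_full(x, f):
    # Y[Yi] = sum over fi of F[fi] * X[Yi - fi], zero outside X's bounds.
    y = np.zeros(len(x) + len(f) - 1, dtype=x.dtype)
    for yi in range(len(y)):
        for fi in range(len(f)):
            xi = yi - fi
            if 0 <= xi < len(x):
                y[yi] += f[fi] * x[xi]
    return y

x_ref = np.arange(16, dtype="float32")
f_ref = np.array([1.0, 2.0, 3.0], dtype="float32")
assert np.allclose(conv1d_full(x_ref, f_ref), np.convolve(x_ref, f_ref, mode="full"))
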
# After applying the same simplifications as proposed in the RFC, but
# with a cache_read/cache_write stage that contains the transformed
# buffers, rather than treating the input argument as transformed.
@script.ir_module
class EndToEndModel:
    @R.func
    def main(x: R.Tensor[16]):
        F = R.const(shape=[3])
        y = call_tir(conv1d_16, x, F)
        z = call_tir(conv1d_18, y, F)
        return z

    # X_read_cache = sched.cache_read(X)
    # Y_write_cache = sched.cache_write(Y)
    # sched.transform_layout(X_read_cache, index_map = lambda i: [(i+2)//8, (i+2)%8], pad_value = 0.0)
    # sched.transform_layout(Y_write_cache, index_map = lambda i: [(i+2)//8, (i+2)%8], pad_value = 0.0)
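    #
    # A concrete sketch of the same steps with the TIR schedule API (the
    # block name and buffer indices here are hypothetical, and `pad_value`
    # is the parameter proposed in the RFC):
    #
    #     sch = tvm.tir.Schedule(EndToEndModel)
    #     block = sch.get_block("compute", func_name="conv1d_16")
    #     x_cache = sch.cache_read(block, 0, "global")
    #     y_cache = sch.cache_write(block, 0, "global")
    #     sch.transform_layout(x_cache, ("write", 0),
    #                          index_map=lambda i: [(i + 2) // 8, (i + 2) % 8],
    #                          pad_value=0.0)
    #     sch.transform_layout(y_cache, ("read", 0),
    #                          index_map=lambda i: [(i + 2) // 8, (i + 2) % 8],
    #                          pad_value=0.0)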
    @T.prim_func
    def conv1d_16(
        X: T.Buffer[(16,), "float32"],
        F: T.Buffer[(3,), "float32"],
        Y: T.Buffer[(18,), "float32"],
    ):
        # Pack X into the padded (3, 8) layout; pad slots hold 0.0.
        X_read_cache = T.alloc_buffer([3, 8], "float32")
        for io, ii in T.grid(3, 8):
            i = 8 * io + ii - 2
            if 0 <= i < 16:
                X_read_cache[io, ii] = X[i]
            else:
                X_read_cache[io, ii] = 0.0
        Y_write_cache = T.alloc_buffer([3, 8], "float32")
        for io, ii in T.grid(3, 8):
            Y_write_cache[io, ii] = 0.0
        for io in T.serial(3):
            for ii in T.serial(8):
                for fi in T.serial(3):
                    Y_write_cache[(io + (ii + fi) // 8) % 3, (ii + fi) % 8] = (
                        Y_write_cache[(io + (ii + fi) // 8) % 3, (ii + fi) % 8]
                        + F[fi] * X_read_cache[io, ii]
                    )
        # Unpack the padded layout back into Y.
        for io, ii in T.grid(3, 8):
            i = 8 * io + ii - 2
            if 0 <= i < 18:
                Y[i] = Y_write_cache[io, ii]

    # Y_read_cache = sched.cache_read(Y)
    # Z_write_cache = sched.cache_write(Z)
    # sched.transform_layout(Y_read_cache, index_map = lambda i: [(i+2)//8, (i+2)%8], pad_value = 0.0)
    # sched.transform_layout(Z_write_cache, index_map = lambda i: [(i+2)//8, (i+2)%8], pad_value = 0.0)
    @T.prim_func
    def conv1d_18(
        Y: T.Buffer[(18,), "float32"],
        F: T.Buffer[(3,), "float32"],
        Z: T.Buffer[(20,), "float32"],
    ):
        Y_read_cache = T.alloc_buffer([3, 8], "float32")
        for io, ii in T.grid(3, 8):
            i = 8 * io + ii - 2
            if 0 <= i < 18:
                Y_read_cache[io, ii] = Y[i]
            else:
                Y_read_cache[io, ii] = 0.0
        Z_write_cache = T.alloc_buffer([3, 8], "float32")
        for io, ii in T.grid(3, 8):
            Z_write_cache[io, ii] = 0.0
        for io in T.serial(3):
            for ii in T.serial(8):
                for fi in T.serial(3):
                    Z_write_cache[(io + (ii + fi) // 8) % 3, (ii + fi) % 8] = (
                        Z_write_cache[(io + (ii + fi) // 8) % 3, (ii + fi) % 8]
                        + F[fi] * Y_read_cache[io, ii]
                    )
        for io, ii in T.grid(3, 8):
            i = 8 * io + ii - 2
            if 0 <= i < 20:
                Z[i] = Z_write_cache[io, ii]

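# To make the (3, 8) layout concrete: the index map places logical index
# i at physical slot ((i + 2) // 8, (i + 2) % 8), so the 24 physical
# slots cover logical indices -2..21, with pad slots outside the valid
# range.  A minimal NumPy sketch (`pack` is an illustrative helper, not
# part of the modules here):
import numpy as np

def pack(buf, pad_value=0.0):
    cache = np.full((3, 8), pad_value, dtype="float32")
    for i in range(len(buf)):
        cache[(i + 2) // 8, (i + 2) % 8] = buf[i]
    return cache

y_ref = np.arange(1, 19, dtype="float32")  # 18 valid elements
y_cache = pack(y_ref)
assert y_cache[0, 2] == y_ref[0] and y_cache[2, 3] == y_ref[17]
# Pad slots: (0, 0), (0, 1) for i = -2, -1 and (2, 4)..(2, 7) for i = 18..21.
assert y_cache[0, 0] == 0.0 and y_cache[2, 7] == 0.0
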
# Hoist out layout transformations into independent functions
@script.ir_module
class EndToEndModel:
    @R.func
    def main(x: R.Tensor[16]):
        F = R.const(shape=[3])
        X_read_cache = call_tir(transform_X, x)
        Y_write_cache = call_tir(conv1d_16, X_read_cache, F)
        Y = call_tir(inv_transform_Y, Y_write_cache)
        Y_read_cache = call_tir(transform_Y, Y)
        Z_write_cache = call_tir(conv1d_18, Y_read_cache, F)
        Z = call_tir(inv_transform_Z, Z_write_cache)
        return Z

    @T.prim_func
    def transform_X(
        X: T.Buffer[(16,), "float32"],
        X_read_cache: T.Buffer[(3, 8), "float32"],
    ):
        for io, ii in T.grid(3, 8):
            i = 8 * io + ii - 2
            if 0 <= i < 16:
                X_read_cache[io, ii] = X[i]
            else:
                X_read_cache[io, ii] = 0.0

    @T.prim_func
    def conv1d_16(
        X_read_cache: T.Buffer[(3, 8), "float32"],
        F: T.Buffer[(3,), "float32"],
        Y_write_cache: T.Buffer[(3, 8), "float32"],
    ):
        for io, ii in T.grid(3, 8):
            Y_write_cache[io, ii] = 0.0
        for io in T.serial(3):
            for ii in T.serial(8):
                for fi in T.serial(3):
                    Y_write_cache[(io + (ii + fi) // 8) % 3, (ii + fi) % 8] = (
                        Y_write_cache[(io + (ii + fi) // 8) % 3, (ii + fi) % 8]
                        + F[fi] * X_read_cache[io, ii]
                    )

    @T.prim_func
    def inv_transform_Y(
        Y_write_cache: T.Buffer[(3, 8), "float32"],
        Y: T.Buffer[(18,), "float32"],
    ):
        for io, ii in T.grid(3, 8):
            i = 8 * io + ii - 2
            if 0 <= i < 18:
                Y[i] = Y_write_cache[io, ii]

    @T.prim_func
    def transform_Y(
        Y: T.Buffer[(18,), "float32"],
        Y_read_cache: T.Buffer[(3, 8), "float32"],
    ):
        for io, ii in T.grid(3, 8):
            i = 8 * io + ii - 2
            if 0 <= i < 18:
                Y_read_cache[io, ii] = Y[i]
            else:
                Y_read_cache[io, ii] = 0.0

    @T.prim_func
    def conv1d_18(
        Y_read_cache: T.Buffer[(3, 8), "float32"],
        F: T.Buffer[(3,), "float32"],
        Z_write_cache: T.Buffer[(3, 8), "float32"],
    ):
        for io, ii in T.grid(3, 8):
            Z_write_cache[io, ii] = 0.0
        for io in T.serial(3):
            for ii in T.serial(8):
                for fi in T.serial(3):
                    Z_write_cache[(io + (ii + fi) // 8) % 3, (ii + fi) % 8] = (
                        Z_write_cache[(io + (ii + fi) // 8) % 3, (ii + fi) % 8]
                        + F[fi] * Y_read_cache[io, ii]
                    )

    @T.prim_func
    def inv_transform_Z(
        Z_write_cache: T.Buffer[(3, 8), "float32"],
        Z: T.Buffer[(20,), "float32"],
    ):
        for io, ii in T.grid(3, 8):
            i = 8 * io + ii - 2
            if 0 <= i < 20:
                Z[i] = Z_write_cache[io, ii]

# Merging the calls to inv_transform_Y and transform_Y
@script.ir_module
class EndToEndModel:
    @R.func
    def main(x: R.Tensor[16]):
        F = R.const(shape=[3])
        X_read_cache = call_tir(transform_X, x)
        Y_write_cache = call_tir(conv1d_16, X_read_cache, F)
        Y_read_cache = call_tir(fused_inv_transform_Y_transform_Y, Y_write_cache)
        Z_write_cache = call_tir(conv1d_18, Y_read_cache, F)
        Z = call_tir(inv_transform_Z, Z_write_cache)
        return Z

    @T.prim_func
    def transform_X(
        X: T.Buffer[(16,), "float32"],
        X_read_cache: T.Buffer[(3, 8), "float32"],
    ):
        for io, ii in T.grid(3, 8):
            i = 8 * io + ii - 2
            if 0 <= i < 16:
                X_read_cache[io, ii] = X[i]
            else:
                X_read_cache[io, ii] = 0.0

    @T.prim_func
    def conv1d_16(
        X_read_cache: T.Buffer[(3, 8), "float32"],
        F: T.Buffer[(3,), "float32"],
        Y_write_cache: T.Buffer[(3, 8), "float32"],
    ):
        for io, ii in T.grid(3, 8):
            Y_write_cache[io, ii] = 0.0
        for io in T.serial(3):
            for ii in T.serial(8):
                for fi in T.serial(3):
                    Y_write_cache[(io + (ii + fi) // 8) % 3, (ii + fi) % 8] = (
                        Y_write_cache[(io + (ii + fi) // 8) % 3, (ii + fi) % 8]
                        + F[fi] * X_read_cache[io, ii]
                    )
        # Explicitly zero the pad slots of Y_write_cache, so that the
        # pad-value constraint holds on this function's output.
        for io in T.serial(3):
            for ii in T.serial(8):
                i = 8 * io + ii - 2
                if not (0 <= i < 18):
                    Y_write_cache[io, ii] = 0.0

    @T.prim_func
    def fused_inv_transform_Y_transform_Y(
        Y_write_cache: T.Buffer[(3, 8), "float32"],
        Y_read_cache: T.Buffer[(3, 8), "float32"],
    ):
        Y = T.alloc_buffer([18], "float32")
        for io, ii in T.grid(3, 8):
            i = 8 * io + ii - 2
            if 0 <= i < 18:
                Y[i] = Y_write_cache[io, ii]
        for io, ii in T.grid(3, 8):
            i = 8 * io + ii - 2
            if 0 <= i < 18:
                Y_read_cache[io, ii] = Y[i]
            else:
                Y_read_cache[io, ii] = 0.0

    @T.prim_func
    def conv1d_18(
        Y_read_cache: T.Buffer[(3, 8), "float32"],
        F: T.Buffer[(3,), "float32"],
        Z_write_cache: T.Buffer[(3, 8), "float32"],
    ):
        for io, ii in T.grid(3, 8):
            Z_write_cache[io, ii] = 0.0
        for io in T.serial(3):
            for ii in T.serial(8):
                for fi in T.serial(3):
                    Z_write_cache[(io + (ii + fi) // 8) % 3, (ii + fi) % 8] = (
                        Z_write_cache[(io + (ii + fi) // 8) % 3, (ii + fi) % 8]
                        + F[fi] * Y_read_cache[io, ii]
                    )

    @T.prim_func
    def inv_transform_Z(
        Z_write_cache: T.Buffer[(3, 8), "float32"],
        Z: T.Buffer[(20,), "float32"],
    ):
        for io, ii in T.grid(3, 8):
            i = 8 * io + ii - 2
            if 0 <= i < 20:
                Z[i] = Z_write_cache[io, ii]

# Same as the previous module, but considering this function in
# isolation.  If we can prove that it is equivalent to a memcpy, then we
# are justified in removing it from main() and replacing all uses of
# `Y_read_cache` with `Y_write_cache`.
@T.prim_func
def fused_inv_transform_Y_transform_Y(
    Y_write_cache: T.Buffer[(3, 8), "float32"],
    Y_read_cache: T.Buffer[(3, 8), "float32"],
):
    Y = T.alloc_buffer([18], "float32")
    for io, ii in T.grid(3, 8):
        i = 8 * io + ii - 2
        if 0 <= i < 18:
            Y[i] = Y_write_cache[io, ii]
    for io, ii in T.grid(3, 8):
        i = 8 * io + ii - 2
        if 0 <= i < 18:
            Y_read_cache[io, ii] = Y[i]
        else:
            Y_read_cache[io, ii] = 0.0

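# A quick NumPy check of the memcpy claim (an illustrative sketch;
# `fused_round_trip` mirrors the loops above): the round trip reproduces
# its input exactly when the pad slots of `Y_write_cache` already hold
# 0.0, and differs otherwise.
import numpy as np

def fused_round_trip(y_write_cache):
    y = np.zeros(18, dtype="float32")
    for io in range(3):
        for ii in range(8):
            i = 8 * io + ii - 2
            if 0 <= i < 18:
                y[i] = y_write_cache[io, ii]
    y_read_cache = np.zeros((3, 8), dtype="float32")  # zeros model the else branch
    for io in range(3):
        for ii in range(8):
            i = 8 * io + ii - 2
            if 0 <= i < 18:
                y_read_cache[io, ii] = y[i]
    return y_read_cache

cache = np.random.rand(3, 8).astype("float32") + 1.0  # pad slots nonzero
assert not np.array_equal(fused_round_trip(cache), cache)
cache[0, :2] = 0.0  # logical i = -2, -1
cache[2, 4:] = 0.0  # logical i = 18..21
assert np.array_equal(fused_round_trip(cache), cache)
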
# After inlining Y[i].  To prove that this is equivalent to a memcpy, we
# would need to know a priori that `Y_write_cache[io, ii]` contains the
# value 0.0 whenever `i` falls outside the bounds `0 <= i < 18`.  This
# cannot be determined from any local analysis of this function; it
# would require reconstructing the buffer constraint from an analysis of
# the other functions (here, the pad-zeroing epilogue in conv1d_16).
@T.prim_func
def fused_inv_transform_Y_transform_Y(
    Y_write_cache: T.Buffer[(3, 8), "float32"],
    Y_read_cache: T.Buffer[(3, 8), "float32"],
):
    for io, ii in T.grid(3, 8):
        i = 8 * io + ii - 2
        if 0 <= i < 18:
            Y_read_cache[io, ii] = Y_write_cache[io, ii]
        else:
            Y_read_cache[io, ii] = 0.0

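# If that buffer constraint can be established, the fused call drops out
# of main() entirely.  A sketch of the resulting graph (a hypothetical
# end state, assuming the proof goes through):
@script.ir_module
class EndToEndModel:
    @R.func
    def main(x: R.Tensor[16]):
        F = R.const(shape=[3])
        X_read_cache = call_tir(transform_X, x)
        Y_write_cache = call_tir(conv1d_16, X_read_cache, F)
        Z_write_cache = call_tir(conv1d_18, Y_write_cache, F)
        Z = call_tir(inv_transform_Z, Z_write_cache)
        return Z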