Skip to content

Instantly share code, notes, and snippets.

@psrivas2
Created November 3, 2022 13:54
Show Gist options
  • Save psrivas2/1a4657e24f29b4f98dfe111191b3b16a to your computer and use it in GitHub Desktop.
Save psrivas2/1a4657e24f29b4f98dfe111191b3b16a to your computer and use it in GitHub Desktop.

Here is how the TIR PrimFunc would look before hoisting:

    # Illustrative TIR PrimFunc: a chain of pad / layout-transform / crop stages
    # wrapped around an elided computation. As written, the stages are:
    #   x[25] -> pad to [28] -> tile to [4, 7] -> pad to [4, 8]
    #         -> (computation producing y, not shown)
    #         -> crop to [4, 7] -> flatten to [28] -> crop to [25] -> y_crop_1
    # NOTE(review): `y` is never declared or written in this snippet; per the
    # comment in the body it stands for the [4, 8] result of the elided computation.
    @T.prim_func
    def fused_primitives(x: T.Buffer[25, "float32"], y_crop_1: T.Buffer[25, "float32"]) -> None:
        # Intermediate buffers; each one is written by exactly one stage below.
        x_pad_0 = T.alloc_buffer([28], dtype="float32")
        x_transform = T.alloc_buffer([4, 7], dtype="float32")
        x_pad_1 = T.alloc_buffer([4, 8], dtype="float32")
        y_crop_0 = T.alloc_buffer([4, 7], dtype="float32")
        y_transform = T.alloc_buffer([28], dtype="float32")

        # x_pad_0 : Tensor[28]= pad(x, paddings = [(0,3)], pad_value = 0)
        for i0 in T.serial(28):
            with T.block("pad_1d"):
                i = T.axis.spatial(28, i0)
                T.reads(x[i])
                T.writes(x_pad_0[i])
                # The loop covers all 28 padded positions; x[i] is selected only
                # when i < 25, and the 3 trailing positions receive 0.
                x_pad_0[i] = T.if_then_else(i < 25, x[i], T.float32(0), dtype="float32")

        # x_transform : Tensor[4, 7] = transform_layout(x_pad_0, index_map = lambda i: [i // 7, i % 7])
        for i0, i1 in T.grid(4, 7):
            with T.block("tile"):
                i, j = T.axis.remap("SS", [i0, i1])
                T.reads(x_pad_0[i * 7 + j])
                T.writes(x_transform[i, j])
                # Element (i, j) of the 4x7 view comes from flat index i * 7 + j.
                x_transform[i, j] = x_pad_0[i * 7 + j]

        # x_pad_1 : Tensor[4, 8] = pad(x_transform, paddings = [(0, 0), (0, 1)], pad_value = 0)
        for i0, i1 in T.grid(4, 8):
            with T.block("pad_2d"):
                i, j = T.axis.remap("SS", [i0, i1])
                T.reads(x_transform[i, j])
                T.writes(x_pad_1[i, j])
                # Column j == 7 is the padding column: x_transform is selected
                # only when j < 7, otherwise 0 is written.
                x_pad_1[i, j] = T.if_then_else(j < 7, x_transform[i, j], T.float32(0), dtype="float32")

        # some computation ....
        # y : Tensor[4, 8] = computation(x_pad_1)
 
        # y_crop_0 : Tensor[4, 7] = crop(y, start_indices = [0, 0], slice_sizes = [4, 7], cropped_value = 0)
        # NOTE(review): `i` and `j` are not bound by any enclosing loop or block at
        # this point, and the loop below only visits j in [0, 7), where `j < 7`
        # always holds. Presumably this is meant to state that the padded column
        # (j == 7) of y is zero -- confirm the intended scope of the assumption.
        T.assume(j < 7 or y[i, j] == 0)
        for i0, i1 in T.grid(4, 7):
            with T.block("crop_2d"):
                i, j = T.axis.remap("SS", [i0, i1])
                T.reads(y[i, j])
                T.writes(y_crop_0[i, j])
                # Copies the leading 4x7 region of y; column 7 is not read here.
                y_crop_0[i, j] = y[i, j]

        # y_transform : Tensor[28] = transform_layout(y_crop_0, index_map = lambda i, j: i * 7 + j)
        for i0 in T.serial(28):
            with T.block("inv_tile"):
                i = T.axis.spatial(28, i0)
                T.reads(y_crop_0[i // 7, i % 7])
                T.writes(y_transform[i])
                # Flat index i reads from (i // 7, i % 7), the inverse of the
                # i * 7 + j mapping used by the "tile" block.
                y_transform[i] = y_crop_0[i // 7, i % 7]

        # y_crop_1 : Tensor[25] = crop(y_transform, start_indices = [0], slice_sizes = [25], cropped_value = 0)
        # NOTE(review): `i` is not bound by any enclosing loop or block at this
        # point, and the loop below only visits i in [0, 25). Presumably this is
        # meant to state that y_transform[25..27] is zero -- confirm.
        T.assume(i < 25 or y_transform[i] == 0)
        for i0 in T.serial(25):
            with T.block("crop_1d"):
                i = T.axis.spatial(25, i0)
                T.reads(y_transform[i])
                T.writes(y_crop_1[i])
                # Copies the first 25 elements into the output buffer.
                y_crop_1[i] = y_transform[i]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment