@csullivan
Last active November 14, 2022 17:29

Authored-by: Eric Lunderberg

Notes summarizing discussion between @Lunderberg and @csullivan on 2022_10_25

Considerations of Pad/Crop represented separately from bijective transformations

From a previous conversation, there is the possibility of representing pad/crop separately from the layout transform. This would allow algebraic proofs to be done in the simpler coordinate system, before applying the layout transform.

However, not all types of padding can be represented in this manner. As an example, suppose we want to pad a buffer such that every 8th value is padding. This can be represented in one step with a padded transform, but requires three steps when the padding is introduced separately.

# With padded transforms
transform_layout(index_map = lambda i: [i//7, (i%7)%8], pad_value=0)

# With pad/crop and bijective transforms
insert_pad_crop(new_shape = [7*ceildiv(buf.shape[0], 7)])
transform_layout(index_map = lambda i: [i//7, i%7])
insert_pad_crop(new_shape = [buf.shape[0], 8])

Any cancellation of the second pad/crop would need to be done after the layout transform. Therefore, we can't get away from performing algebraic proofs within the transformed layout.

While this is a somewhat contrived example, it could easily occur in practice. Suppose a conv1d with filter size 2 is implemented using vector operations of size 8. The implementation uses a sliding window of size 8, which advances by 7 elements at a time. (Assume alignment restrictions are handled by a cache_read.) Each application of the vector operations produces 8 values, the last of which is junk. If the output of the conv1d is then matrix multiplied by a constant matrix, the above transform could be applied to the constant matrix. This would place a pad value (zero) at every location corresponding to a junk value, which could be used to vectorize the matrix multiplication.
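To make the access pattern concrete, here is a minimal numpy sketch (illustrative only, not TVM code; the input length and filter weights are made up). Each 8-wide vector op consumes inputs [i*7 : i*7+8] and produces outputs [i*8 : i*8+8], so the junk values land at every 8th output position (7, 15, 23), exactly where the [i//7, (i%7)%8] transform places its padding:

import numpy as np

data = np.arange(22, dtype="float32")        # hypothetical input
w0, w1 = np.float32(1.0), np.float32(-1.0)   # hypothetical filter of size 2

out = np.empty(24, dtype="float32")
for i in range(3):
    window = data[i*7 : i*7 + 8]             # sliding window of 8, advancing by 7
    # One 8-wide vector op: lanes 0..6 are valid conv1d outputs,
    # lane 7 wraps around the window and is junk.
    out[i*8 : i*8 + 8] = window*w0 + np.roll(window, -1)*w1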

@psrivas2 commented Nov 1, 2022

Thanks for providing an example! My goal in introducing separate pad, crop, and pure_layout_transform primitives was to ease the burden of algebraic simplification, so any counterexamples are very useful here; if this approach doesn't hold up, we can abandon it in favor of a better one.

But I want to make sure I understand this example completely.

Perhaps I don't follow this statement:

Any cancellation of the second pad/crop would need to be done after the layout transform. Therefore, we can't get away from performing algebraic proofs within the transformed layout.

I would have expected this series of transformations to cancel out with its inverse. In your example, the transformation is a series of 3 operations:

# x: Tensor[25]
x_pad_0 : Tensor[28] = pad(x, paddings = [(0,3)], pad_value = 0)
x_transform : Tensor[4, 7] = transform_layout(x_pad_0, index_map = lambda i: [i // 7, i % 7])
x_pad_1 : Tensor[4, 8] = pad(x_transform, paddings = [(0, 0), (0, 1)], pad_value = 0) 

The inverse of the above would be:

# y: Tensor[4, 8]
y_crop_0 : Tensor[4, 7] = crop(y, start_indices = [0, 0], slice_sizes = [4, 7], cropped_value = 0)
y_transform : Tensor[28] = transform_layout(y_crop_0, index_map = lambda i, j: i * 7 + j)
y_crop_1 : Tensor[25] = crop(y_transform, start_indices = [0], slice_sizes = [25], cropped_value = 0) 

If these transformations occur between two operators, they would look something like below:

y_crop_0 : Tensor[4, 7] = crop(y, start_indices = [0, 0], slice_sizes = [4, 7], cropped_value = 0)
y_transform : Tensor[28] = transform_layout(y_crop_0, index_map = lambda i, j: i * 7 + j)
y_crop_1 : Tensor[25] = crop(y_transform, start_indices = [0], slice_sizes = [25], cropped_value = 0) 

x_pad_0 : Tensor[28] = pad(y_crop_1, paddings = [(0,3)], pad_value = 0)
x_transform : Tensor[4, 7] = transform_layout(x_pad_0, index_map = lambda i: [i // 7, i % 7])
x_pad_1 : Tensor[4, 8] = pad(x_transform, paddings = [(0, 0), (0, 1)], pad_value = 0) 

These could be simplified easily, cancelling out adjacent pairs at each step. Did I misunderstand your example?
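For concreteness, a minimal numpy sketch of the six steps (the pad/crop helpers here are stand-ins for the proposed primitives, not the TVM API), verifying that the sequence composes to the identity on a Tensor[25]:

import numpy as np

def pad(x, paddings, pad_value=0):
    return np.pad(x, paddings, constant_values=pad_value)

def crop(x, start_indices, slice_sizes):
    return x[tuple(slice(s, s + n) for s, n in zip(start_indices, slice_sizes))]

y = pad(np.arange(25, dtype="float32"), [(0, 3)])   # x_pad_0 : Tensor[28]
y = y.reshape(4, 7)                                 # x_transform, lambda i: [i//7, i%7]
y = pad(y, [(0, 0), (0, 1)])                        # x_pad_1 : Tensor[4, 8]

y = crop(y, [0, 0], [4, 7])                         # y_crop_0 : Tensor[4, 7]
y = y.reshape(28)                                   # y_transform, lambda i, j: i*7 + j
y = crop(y, [0], [25])                              # y_crop_1 : Tensor[25]

assert (y == np.arange(25)).all()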

@Lunderberg

For me, the key bit is that crop cannot be hoisted without first proving that the compact representation is a valid representation of the previous TIR. The cancellations listed would indeed cancel out, but writing them in that format assumes that we already made it past the step of extracting the compact representation. When hoisting multiple stages from a TIR function, they must be hoisted from the outside-in, so y_crop_1 must be hoisted first. In order to extract the compact representation of y_crop_1, we need to propagate the known values across the compute function, the TIR stage that will later be hoisted as y_crop_0, and the TIR stage that will later be hoisted as y_transform. Since this requires data-flow analysis across a layout transformation stage, we haven't avoided the propagation of buffer values across a layout transform.

(Chris and I came up with this example last week, which is why I didn't mention it during our earlier conversation about separating pad/crop from bijective transformations.)

@psrivas2 commented Nov 3, 2022

Is your argument the following: hoisting the y_crop_1 TIR block into a compact representation with cropped_value = 0 would require dataflow analysis?

I agree it probably would. But that is not my main argument for splitting layout transformation into pad, crop, and pure_layout_transform primitives. The main argument for splitting is that it is hard to prove two layout transforms are inverses of each other if there is implicit padding and cropping in the compact representation or TIR block representation.

For example, there is no compact representation of the inverse layout transform of the following:

transform_layout(index_map = lambda i: [i//7, (i%7)%8], pad_value = 0)

so we would be left with fusing the underlying PrimFunc blocks (shown below) and proving that the result is the identity.

    @T.prim_func
    def inv_fused_layout_trans(rxplaceholder: T.Buffer[(4, 8), "float32"], compute: T.Buffer[(4, 8), "float32"]) -> None:
        # body
        # with T.block("root")
        compute_1 = T.alloc_buffer([25], dtype="float32")
        # Inverse transform: read the Tensor[4, 8] back into the flat Tensor[25]
        for i0 in T.serial(25):
            with T.block("compute"):
                i = T.axis.spatial(25, i0)
                T.reads(rxplaceholder[i // 7, i % 7])
                T.writes(compute_1[i])
                compute_1[i] = rxplaceholder[i // 7, i % 7]
        # Forward padded transform: rebuild the Tensor[4, 8], writing the pad
        # value wherever j == 7 or i*7 + j >= 25
        for i0, i1 in T.grid(4, 8):
            with T.block("compute_1"):
                i, j = T.axis.remap("SS", [i0, i1])
                T.reads(compute_1[i * 7 + j])
                T.writes(compute[i, j])
                compute[i, j] = T.if_then_else(j < 7 and i * 7 + j < 25, compute_1[i * 7 + j], T.float32(0), dtype="float32")

It seems non-trivial to prove that the above PrimFunc is the identity, but I would be happy to be corrected.
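To spell out the proof obligation, a brute-force sanity check in plain Python (an illustration, not a proof):

# After inlining compute_1, the fused body is:
#   compute[i, j] = if j < 7 and i*7 + j < 25
#                   then rxplaceholder[(i*7 + j) // 7, (i*7 + j) % 7]
#                   else 0
# The data-carrying part reduces to showing the indices round-trip:
for i in range(4):
    for j in range(7):
        k = i*7 + j
        assert (k // 7, k % 7) == (i, j)

The remaining obligation, that the positions with j == 7 or i*7 + j >= 25 already hold the pad value in rxplaceholder, is exactly the part that index algebra alone cannot show.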

Hence I proposed to split pad/crop from layout_transform. This would allow easy cancelling out in the compact representation, as discussed in my previous comment.

When hoisting multiple stages from a TIR function, they must be hoisted from the outside-in, so y_crop_1 must be hoisted first. In order to extract the compact representation of y_crop_1, we need to propagate the known values across the compute function, the TIR stage that will later be hoisted as y_crop_0, and the TIR stage that will later be hoisted as y_transform.

I disagree that we have to hoist them one at a time, outside in. It depends on the implementation. If it wants, it can do all three of them together. It might be cleaner to do it one at a time, but there is no fundamental reason to do it that way.

Even if we do it one at a time, as long as assumptions like T.assume(i < 25 or y_transform[i] == 0) are placed before the y_crop_1 block, via dataflow analysis or otherwise, we should be able to raise it to the compact form.
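As a sketch of what that might look like (illustrative TVMScript; the exact form of the T.assume block and the buffer names are assumptions on my part, not confirmed API usage):

    @T.prim_func
    def crop_with_assume(y_transform: T.Buffer[(28,), "float32"], y_crop_1: T.Buffer[(25,), "float32"]) -> None:
        # Assumption produced by dataflow analysis (or stated up front):
        # everything at index 25 and beyond is known to hold the pad value.
        for i0 in T.serial(28):
            with T.block("assume"):
                i = T.axis.spatial(28, i0)
                T.evaluate(T.assume(i < 25 or y_transform[i] == T.float32(0)))
        # With the assumption local to this function, y_crop_1 can be hoisted
        # as crop(..., cropped_value = 0) without a pass over earlier stages.
        for i0 in T.serial(25):
            with T.block("y_crop_1"):
                i = T.axis.spatial(25, i0)
                y_crop_1[i] = y_transform[i]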

@Lunderberg

Is your argument the following: hoisting the y_crop_1 TIR block into a compact representation with cropped_value = 0 would require dataflow analysis?

Essentially, yes. The difficulties involved in identifying a TIR block and hoisting out a compact representation are roughly the same as the difficulties in proving a memcpy in TIR.

But that is not my main argument for splitting layout transformation into pad, crop, and pure_layout_transform primitives. The main argument for splitting is that it is hard to prove two layout transforms are inverses of each other if there is implicit padding and cropping in the compact representation or TIR block representation.

Hmm. I suppose I'm not seeing the difficulty in proving two layout transforms to be inverses of each other. I would see four different cases for compact representations that could be canceled out (a quick check of case 1 is sketched after the list).

  1. A layout_transform(A) followed by a layout_transform(B). Since the layout_transform can introduce implicit padding, if either layout_transform introduces padding, the sequence of two transforms introduces padding, and is therefore not a no-op. The transformations cancel out if A(B(indices)) == indices and both transformations are bijective.

  2. An inv_layout_transform(A, pad_value=x) followed by a layout_transform(B, pad_value=y). The inv_layout_transform can crop out padding, which is then added back in by the layout_transform. The two compact representations cancel out if A is equivalent to B, and x == y.

  3. A layout_transform(A, pad_value=x) followed by an inv_layout_transform(B, pad_value=y). The layout_transform can introduce implicit padding, which is removed by the inv_layout_transform. The two compact representations cancel out if A is the same as B.

  4. An inv_layout_transform(A) followed by an inv_layout_transform(B). Since the inv_layout_transform can crop out implicit padding, if either inv_layout_transform crops out padding, the sequence of two inverse transforms changes the size of the buffer, and is therefore not a no-op. The transformations cancel out if B(A(indices)) == indices and both transformations are bijective.
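As a quick illustration of the case 1 condition, using the maps from the earlier example on an unpadded Tensor[28] (plain Python as a sanity check, not the TVM IndexMap API):

A = lambda i: (i // 7, i % 7)   # layout_transform(A): Tensor[28] -> Tensor[4, 7]
B = lambda i, j: i*7 + j        # layout_transform(B): Tensor[4, 7] -> Tensor[28]

# Neither map introduces padding on a Tensor[28], and composing them
# fixes every index, so the pair cancels out.
assert all(B(*A(i)) == i for i in range(28))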

@csullivan (Author)

Memory layout for the above-mentioned IndexMap, supposing an input buffer of 16 elements.


  ┌─Physical-index-space───IndexMap:[i//7,i%7]─────────────────┐
  │                                                            │
 ┌▼─┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬─▼┐
 │00│01│02│03│04│05│06│07│08│09│10│11│12│13│14│15│16│17│18│19│20│
 └▲─┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴─▲└──┴──┴──┴──┴──┘
  │                                             │
  └─Logical-index-space─────────────────────────┘



  ┌─────IndexMap:[i//7,i%7]─┐
  │                         │
  │      ┌──┬──┬──┬──┬──┬──┐▼─┐
  │      │00│01│02│03│04│05│06│
  │      ├──┼──┼──┼──┼──┼──┼──┤
  │      │07│08│09│10│11│12│13│
  │      └──┼──┼──┼──┼──┼──┼──┤
  └──────►14│15│16│17│18│19│20│
         └──┴──┘▲─┴──┴──┴──┴─▲┘
                │            │
                │            │
                └─pad-values─┘


  ┌─IndexMap:[i//7,(i%7)%8]─┐
  │                         │
  │      ┌──┬──┬──┬──┬──┬──┐▼─┬──┐
  │      │00│01│02│03│04│05│06│07◄─┐
  │      ├──┼──┼──┼──┼──┼──┼──┼──┐ │
  │      │08│09│10│11│12│13│14│15◄─┤
  │      └──┼──┼──┼──┼──┼──┼──┼──┐ │
  └──────►16│17│18│19│20│21│22│23◄─┤
         └──┴──┘▲─┴──┴──┴──┴─▲└─▲┘ │
                │            │  │  │
                │            │  │  │
                └─pad-values─┴──┴──┘


  ┌─Physical-index-space───IndexMap:[i//7,(i%7)%8]─────────────────────┐
  │                                                                    │
 ┌▼─┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┐▼─┐
 │00│01│02│03│04│05│06│xx│08│09│10│11│12│13│14│xx│16│17│18│19│20│21│22│xx│
 └──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┘



Note the pad values at every 8th position in the physical memory layout in the last figure above.
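A quick plain-Python check of the figure (illustrative only): mapping the 21 logical elements through [i//7, (i%7)%8] and flattening row-major leaves exactly the every-8th positions unwritten:

index_map = lambda i: (i // 7, (i % 7) % 8)

flat = {r*8 + c for r, c in map(index_map, range(21))}
assert set(range(24)) - flat == {7, 15, 23}   # pad values at every 8th slot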

@sunggg commented Nov 7, 2022

@csullivan thank you for the beautiful figure! 😄
One question: this example adds padding at the rightmost side of axis=1.
How does this representation add padding at the leftmost side?

@Lunderberg

@sunggg Padding at the left side would be represented as [i//7, ((i%7) + 1) % 8]. The %8 introduces the same requirement that its left argument be padded out to be divisible by the right argument, and the expression dictates the exact mapping from pre-transformation indices to post-transformation indices.
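A quick plain-Python check of that map on the 21-element padded buffer from the figures above (illustrative only):

index_map = lambda i: (i // 7, ((i % 7) + 1) % 8)

# Columns written by the 21 logical elements; column 0 is never
# written, so it holds the pad values on the left.
cols = {index_map(i)[1] for i in range(21)}
assert cols == set(range(1, 8))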
