Skip to content

Instantly share code, notes, and snippets.

@benvanik
Last active December 15, 2020 23:53
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save benvanik/a2c1915c71dfb611f8e6d7ddcdc96539 to your computer and use it in GitHub Desktop.
Save benvanik/a2c1915c71dfb611f8e6d7ddcdc96539 to your computer and use it in GitHub Desktop.
tiled dispatch
// Static-shape dispatch: every dimension is known at compile time.
func @staticShapeDispatch(%arg0 : tensor<8x4xf32>) -> tensor<4x8xf32> {
  // Workgroup counts along the 2D dispatch grid (X, Y); backends turn them into 3D XYZ.
  %count_x = constant 100 : index
  %count_y = constant 50 : index
  %result = flow.dispatch.workgroups[%count_x, %count_y](%arg0) : (tensor<8x4xf32>) -> (tensor<4x8xf32>) = (
    // Within the region, I/O is modeled as ref arguments with dedicated ops available on them.
    %in : !flow.dispatch.input<8x4xf32>, %out : !flow.dispatch.output<4x8xf32>
  ) {
    // Read a tensor from the input; this load can be tiled with offsets/sizes/strides.
    %in_tile = flow.dispatch.input.load %in : !flow.dispatch.input<8x4xf32> -> tensor<8x4xf32>
    // The I/O arguments also expose their shapes.
    %in_shape = flow.dispatch.shape %in : !flow.dispatch.input<8x4xf32> -> !shapex.ranked_shape<[8,4]>
    %out_shape = flow.dispatch.shape %out : !flow.dispatch.output<4x8xf32> -> !shapex.ranked_shape<[4,8]>
    // Stand-in for an op that produces a tile from an input plus shape information.
    %out_tile = "test.sink"(%in_tile, %in_shape, %out_shape) : (tensor<8x4xf32>, !shapex.ranked_shape<[8,4]>, !shapex.ranked_shape<[4,8]>) -> (tensor<4x8xf32>)
    // Write the produced tile into the output I/O argument.
    flow.dispatch.output.store %out_tile, %out : tensor<4x8xf32> -> !flow.dispatch.output<4x8xf32>
    flow.return
  }
  return %result : tensor<4x8xf32>
}
// Result of a simple transformation that outlines the region (canonicalize was run afterward).
flow.executable @staticShapeDispatch_dispatch_0 attributes {sym_visibility = "private"} {
  flow.dispatch.entry @staticShapeDispatch_dispatch_0 attributes {
    // Potentially useful for logging/tracing/etc; not otherwise required.
    signature = (tensor<8x4xf32>) -> tensor<4x8xf32>,
    // Rank of the workgroup grid as originally dispatched (XY in this example).
    workgroup_rank = 2 : index
  }
  module {
    // The arguments mirror those of the region body: references to the I/O.
    func @staticShapeDispatch_dispatch_0(%arg0: !flow.dispatch.input<8x4xf32>, %arg1: !flow.dispatch.output<4x8xf32>) {
      // Because the shapes are static, the shape queries canonicalized into constants.
      %in_shape = shapex.const_ranked_shape : !shapex.ranked_shape<[8,4]>
      %out_shape = shapex.const_ranked_shape : !shapex.ranked_shape<[4,8]>
      %in_tile = flow.dispatch.input.load %arg0 : !flow.dispatch.input<8x4xf32> -> tensor<8x4xf32>
      %out_tile = "test.sink"(%in_tile, %in_shape, %out_shape) : (tensor<8x4xf32>, !shapex.ranked_shape<[8,4]>, !shapex.ranked_shape<[4,8]>) -> tensor<4x8xf32>
      flow.dispatch.output.store %out_tile, %arg1 : tensor<4x8xf32> -> !flow.dispatch.output<4x8xf32>
      return
    }
  }
}
// Caller after outlining: the inline region is replaced by a dispatch of the outlined entry point.
func @staticShapeDispatch(%arg0: tensor<8x4xf32>) -> tensor<4x8xf32> {
  // Workgroup counts (X, Y) passed to the dispatch.
  %wg_x = constant 100 : index
  %wg_y = constant 50 : index
  %result = flow.dispatch2 @staticShapeDispatch_dispatch_0::@staticShapeDispatch_dispatch_0[%wg_x, %wg_y] (%arg0) : (tensor<8x4xf32>) -> tensor<4x8xf32>
  return %result : tensor<4x8xf32>
}
// The hal.interface and executables materialized. This is only a partial conversion, so flow ops remain.
hal.executable @static_tiled_dispatch attributes {sym_visibility = "private"} {
  hal.interface @legacy_io {
    hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
    hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write|Discard"
  }
  hal.executable.target @vmla, filter="vmla" {
    hal.executable.entry_point @entry attributes {
      interface = @legacy_io,
      ordinal = 0 : i32,
      signature = (!flow.dispatch.input<8x4xf32>, !flow.dispatch.output<4x8xf32>) -> ()
    }
    module {
      func @entry() {
        // Byte offset into each binding; an arbitrary offset can be provided, as is the case today.
        // A byte length may optionally be provided as well (if it can be generated), which may help
        // with bounds checking.
        %offset = constant 0 : index
        // The subspan op returns AnyType. Here the types are 1:1 with the original I/O arguments, but
        // further lowering can turn them into memref<?xi8>/etc.
        %in = hal.interface.binding.subspan @legacy_io::@arg0[%offset] : !flow.dispatch.input<8x4xf32>
        %out = hal.interface.binding.subspan @legacy_io::@ret0[%offset] : !flow.dispatch.output<4x8xf32>
        %in_shape = shapex.const_ranked_shape : !shapex.ranked_shape<[8,4]>
        %out_shape = shapex.const_ranked_shape : !shapex.ranked_shape<[4,8]>
        %in_tile = flow.dispatch.input.load %in : !flow.dispatch.input<8x4xf32> -> tensor<8x4xf32>
        %out_tile = "test.sink"(%in_tile, %in_shape, %out_shape) : (tensor<8x4xf32>, !shapex.ranked_shape<[8,4]>, !shapex.ranked_shape<[4,8]>) -> tensor<4x8xf32>
        flow.dispatch.output.store %out_tile, %out : tensor<4x8xf32> -> !flow.dispatch.output<4x8xf32>
        return
      }
    }
  }
}
// A more involved example: dispatch with dynamic shapes.
func @dynamicShapeDispatch(%arg0 : tensor<7x?x24x?xf32>) -> tensor<?x?x1024xf32> {
  %idx1 = constant 1 : index
  %idx3 = constant 3 : index
  // Query the dynamic dimensions locally; they end up turned into real values independently of this.
  %d1 = dim %arg0, %idx1 : tensor<7x?x24x?xf32>
  %d3 = dim %arg0, %idx3 : tensor<7x?x24x?xf32>
  // Workgroup counts (X, Y) for the dispatch.
  %count_x = constant 1024 : index
  %count_y = constant 512 : index
  // Shape ties (used as they are today) indicate which shape corresponds to which tensor.
  %operand_shape = shapex.make_ranked_shape %d1, %d3 : (index, index) -> !shapex.ranked_shape<[7,?,24,?]>
  %operand_tied = shapex.tie_shape %arg0, %operand_shape : tensor<7x?x24x?xf32>, !shapex.ranked_shape<[7,?,24,?]>
  %result_shape = shapex.make_ranked_shape %d3, %d1 : (index, index) -> !shapex.ranked_shape<[?,?,1024]>
  %result = flow.dispatch.workgroups[%count_x, %count_y](%operand_tied) : (tensor<7x?x24x?xf32>) -> tensor<?x?x1024xf32> = (
    %in : !flow.dispatch.input<7x?x24x?xf32>, %out : !flow.dispatch.output<?x?x1024xf32>
  ) {
    // Becomes 2 once canonicalization runs.
    %rank = flow.dispatch.workgroup.rank : index
    // The dynamic dimensions of the input can be queried.
    %in_shape = flow.dispatch.shape %in : !flow.dispatch.input<7x?x24x?xf32> -> !shapex.ranked_shape<[7,?,24,?]>
    %in_d1 = shapex.ranked_dim %in_shape[1] : !shapex.ranked_shape<[7,?,24,?]> -> index
    %in_d3 = shapex.ranked_dim %in_shape[3] : !shapex.ranked_shape<[7,?,24,?]> -> index
    "test.sink_shape_arg"(%in_d1, %in_d3) : (index, index) -> ()
    // So can the dynamic dimensions of the output.
    %out_shape = flow.dispatch.shape %out : !flow.dispatch.output<?x?x1024xf32> -> !shapex.ranked_shape<[?,?,1024]>
    %out_d0 = shapex.ranked_dim %out_shape[0] : !shapex.ranked_shape<[?,?,1024]> -> index
    %out_d1 = shapex.ranked_dim %out_shape[1] : !shapex.ranked_shape<[?,?,1024]> -> index
    "test.sink_shape_ret"(%out_d0, %out_d1) : (index, index) -> ()
    // Load a tile and fetch its shape; with offsets/sizes/strides the tile may be smaller than the tensor.
    %in_tile = flow.dispatch.input.load %in : !flow.dispatch.input<7x?x24x?xf32> -> tensor<7x?x24x?xf32>
    %in_tile_shape = shapex.get_ranked_shape %in_tile : tensor<7x?x24x?xf32> -> !shapex.ranked_shape<[7,?,24,?]>
    // Compute the new tile.
    %out_tile = "test.tile_math"(%in_tile, %in_tile_shape, %out_shape) : (tensor<7x?x24x?xf32>, !shapex.ranked_shape<[7,?,24,?]>, !shapex.ranked_shape<[?,?,1024]>) -> (tensor<?x?x1024xf32>)
    // Write the tile back to the output.
    flow.dispatch.output.store %out_tile, %out : tensor<?x?x1024xf32> -> !flow.dispatch.output<?x?x1024xf32>
    flow.return
  }
  // This tie records the result shape so it is known and can be fed into the dispatch op.
  %result_tied = shapex.tie_shape %result, %result_shape : tensor<?x?x1024xf32>, !shapex.ranked_shape<[?,?,1024]>
  return %result_tied : tensor<?x?x1024xf32>
}
// After outlining and canonicalization.
flow.executable @dynamicShapeDispatch_dispatch_0 attributes {sym_visibility = "private"} {
  flow.dispatch.entry @dynamicShapeDispatch_dispatch_0 attributes {
    signature = (tensor<7x?x24x?xf32>) -> tensor<?x?x1024xf32>,
    workgroup_rank = 2 : index
  }
  module {
    func @dynamicShapeDispatch_dispatch_0(
      %arg0: !flow.dispatch.input<7x?x24x?xf32>, %arg1: !flow.dispatch.output<?x?x1024xf32>,
      // The dynamic dimensions of arg/ret, expanded into primitive index values.
      %arg2: index, %arg3: index, %arg4: index, %arg5: index
    ) {
      // Build the arg/ret shapes and tie them to the I/O so that any use of the tied values can fetch
      // the full dynamic shape.
      %in_shape = shapex.make_ranked_shape %arg2, %arg3 : (index, index) -> !shapex.ranked_shape<[7,?,24,?]>
      %in = flow.dispatch.tie_shape %arg0, %in_shape : (!flow.dispatch.input<7x?x24x?xf32>, !shapex.ranked_shape<[7,?,24,?]>) -> !flow.dispatch.input<7x?x24x?xf32>
      %out_shape = shapex.make_ranked_shape %arg4, %arg5 : (index, index) -> !shapex.ranked_shape<[?,?,1024]>
      %out = flow.dispatch.tie_shape %arg1, %out_shape : (!flow.dispatch.output<?x?x1024xf32>, !shapex.ranked_shape<[?,?,1024]>) -> !flow.dispatch.output<?x?x1024xf32>
      "test.sink_shape_arg"(%arg2, %arg3) : (index, index) -> ()
      "test.sink_shape_ret"(%arg4, %arg5) : (index, index) -> ()
      %in_tile = flow.dispatch.input.load %in : !flow.dispatch.input<7x?x24x?xf32> -> tensor<7x?x24x?xf32>
      %in_tile_shape = shapex.get_ranked_shape %in_tile : tensor<7x?x24x?xf32> -> !shapex.ranked_shape<[7,?,24,?]>
      %out_tile = "test.tile_math"(%in_tile, %in_tile_shape, %out_shape) : (tensor<7x?x24x?xf32>, !shapex.ranked_shape<[7,?,24,?]>, !shapex.ranked_shape<[?,?,1024]>) -> tensor<?x?x1024xf32>
      flow.dispatch.output.store %out_tile, %out : tensor<?x?x1024xf32> -> !flow.dispatch.output<?x?x1024xf32>
      return
    }
  }
}
// Caller after outlining: dispatches the outlined entry point instead of holding an inline region.
func @dynamicShapeDispatch(%arg0: tensor<7x?x24x?xf32>) -> tensor<?x?x1024xf32> {
  %idx1 = constant 1 : index
  %idx3 = constant 3 : index
  %wg_x = constant 1024 : index
  %wg_y = constant 512 : index
  %d1 = dim %arg0, %idx1 : tensor<7x?x24x?xf32>
  %d3 = dim %arg0, %idx3 : tensor<7x?x24x?xf32>
  %operand_shape = shapex.make_ranked_shape %d1, %d3 : (index, index) -> !shapex.ranked_shape<[7,?,24,?]>
  %operand_tied = shapex.tie_shape %arg0, %operand_shape : tensor<7x?x24x?xf32>, !shapex.ranked_shape<[7,?,24,?]>
  %result_shape = shapex.make_ranked_shape %d3, %d1 : (index, index) -> !shapex.ranked_shape<[?,?,1024]>
  // The dynamic shape dimensions are passed as extra operands: the two operand dims, then the two result dims.
  %result = flow.dispatch2 @dynamicShapeDispatch_dispatch_0::@dynamicShapeDispatch_dispatch_0[%wg_x, %wg_y] (%operand_tied, %d1, %d3, %d3, %d1) : (tensor<7x?x24x?xf32>, index, index, index, index) -> tensor<?x?x1024xf32>
  %result_tied = shapex.tie_shape %result, %result_shape : tensor<?x?x1024xf32>, !shapex.ranked_shape<[?,?,1024]>
  return %result_tied : tensor<?x?x1024xf32>
}
// With the interface materialized, the dynamic dimensions ended up as push constants.
// NOTE(review): the outlined @dynamicShapeDispatch_dispatch_0 body also contains "test.sink_shape_arg" /
// "test.sink_shape_ret" and passes shapes to "test.tile_math"; none of that appears here. Confirm whether
// this listing was produced from a simplified input.
hal.executable @dynamic_tiled_dispatch attributes {sym_visibility = "private"} {
  hal.interface @legacy_io attributes {push_constants = 4 : i32} {
    hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
    hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write|Discard"
  }
  hal.executable.target @vmla, filter="vmla" {
    hal.executable.entry_point @entry attributes {
      interface = @legacy_io,
      ordinal = 0 : i32,
      signature = (!flow.dispatch.input<7x?x24x?xf32>, !flow.dispatch.output<?x?x1024xf32>, index, index, index, index) -> ()
    }
    module {
      func @entry() {
        %offset = constant 0 : index
        %in_raw = hal.interface.binding.subspan @legacy_io::@arg0[%offset] : !flow.dispatch.input<7x?x24x?xf32>
        %out_raw = hal.interface.binding.subspan @legacy_io::@ret0[%offset] : !flow.dispatch.output<?x?x1024xf32>
        // The 4 (2 + 2) dynamic dimensions that used to be arguments are now fetched through the interface.
        %in_d1 = hal.interface.load.constant offset = 0 : index
        %in_d3 = hal.interface.load.constant offset = 1 : index
        %out_d0 = hal.interface.load.constant offset = 2 : index
        %out_d1 = hal.interface.load.constant offset = 3 : index
        // Build the shapes from the dynamic dimensions and tie them so that the code below has everything.
        %in_shape = shapex.make_ranked_shape %in_d1, %in_d3 : (index, index) -> !shapex.ranked_shape<[7,?,24,?]>
        %in = flow.dispatch.tie_shape %in_raw, %in_shape : (!flow.dispatch.input<7x?x24x?xf32>, !shapex.ranked_shape<[7,?,24,?]>) -> !flow.dispatch.input<7x?x24x?xf32>
        %out_shape = shapex.make_ranked_shape %out_d0, %out_d1 : (index, index) -> !shapex.ranked_shape<[?,?,1024]>
        %out = flow.dispatch.tie_shape %out_raw, %out_shape : (!flow.dispatch.output<?x?x1024xf32>, !shapex.ranked_shape<[?,?,1024]>) -> !flow.dispatch.output<?x?x1024xf32>
        %in_tile = flow.dispatch.input.load %in : !flow.dispatch.input<7x?x24x?xf32> -> tensor<7x?x24x?xf32>
        %out_tile = "test.tile_math"(%in_tile) : (tensor<7x?x24x?xf32>) -> tensor<?x?x1024xf32>
        flow.dispatch.output.store %out_tile, %out : tensor<?x?x1024xf32> -> !flow.dispatch.output<?x?x1024xf32>
        return
      }
    }
  }
}
// Work in progress: canonicalize the shape ties into the loads/stores so the shape ops are not needed.
// This listing was written by hand!
hal.executable @dynamic_tiled_dispatch attributes {sym_visibility = "private"} {
  hal.interface @legacy_io attributes {push_constants = 4 : i32} {
    hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
    hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write|Discard"
  }
  hal.executable.target @vmla, filter="vmla" {
    hal.executable.entry_point @entry attributes {
      interface = @legacy_io,
      ordinal = 0 : i32,
      signature = (!flow.dispatch.input<7x?x24x?xf32>, !flow.dispatch.output<?x?x1024xf32>, index, index, index, index) -> ()
    }
    module {
      func @entry() {
        %offset = constant 0 : index
        %in = hal.interface.binding.subspan @legacy_io::@arg0[%offset] : !flow.dispatch.input<7x?x24x?xf32>
        %out = hal.interface.binding.subspan @legacy_io::@ret0[%offset] : !flow.dispatch.output<?x?x1024xf32>
        // The 4 (2 + 2) dynamic dimensions that used to be arguments are now fetched through the interface.
        %in_d1 = hal.interface.load.constant offset = 0 : index
        %in_d3 = hal.interface.load.constant offset = 1 : index
        %out_d0 = hal.interface.load.constant offset = 2 : index
        %out_d1 = hal.interface.load.constant offset = 3 : index
        // The shape ties are canonicalized away; the shapes appear directly as values/attrs on the load and store.
        %in_tile = flow.dispatch.input.load %in, shape = [7, %in_d1, 24, %in_d3] : !flow.dispatch.input<7x?x24x?xf32> -> tensor<7x?x24x?xf32>
        %out_tile = "test.tile_math"(%in_tile) : (tensor<7x?x24x?xf32>) -> tensor<?x?x1024xf32>
        flow.dispatch.output.store %out_tile, %out, shape = [%out_d0, %out_d1, 1024] : tensor<?x?x1024xf32> -> !flow.dispatch.output<?x?x1024xf32>
        return
      }
    }
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment