Skip to content

Instantly share code, notes, and snippets.

@benvanik
Last active December 11, 2020 20:44
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save benvanik/454d9638eedf6bd4056a2e7fba52ec24 to your computer and use it in GitHub Desktop.
Save benvanik/454d9638eedf6bd4056a2e7fba52ec24 to your computer and use it in GitHub Desktop.
// RUN: iree-opt -allow-unregistered-dialect -split-input-file %s | iree-opt -allow-unregistered-dialect -split-input-file | IreeFileCheck %s
// Round-trip (parse -> print -> parse) test for flow.dispatch.workgroups and
// the ops used inside its region: workgroup id/count/size queries, shape
// queries on the dispatch IO arguments, and tensor load/store through them.
// The CHECK lines only match op names so they hold whether an op is printed in
// custom or generic ("op.name"(...)) form.
// CHECK-LABEL: func @workgroups
func @workgroups(%arg0 : tensor<?x4xf32>, %arg1 : index) -> tensor<4x?xf32> {
  // Operands for the [...] list of flow.dispatch.workgroups below
  // (presumably the workgroup count per dimension -- confirm against the op
  // definition).
  %x = constant 100 : index
  %y = constant 50 : index
  // CHECK: flow.dispatch.workgroups
  %0 = flow.dispatch.workgroups[%x, %y](%arg0, %arg1) : (tensor<?x4xf32>, index) -> (tensor<4x?xf32>) =
      (%arg0_capture : !flow.dispatch.input<?x4xf32>, %arg1_capture : index, %ret0 : !flow.dispatch.output<4x?xf32>) {
    // NOTE: %arg1_capture is captured but not used inside this region.

    // Query symbolic workgroup info:
    // CHECK: flow.dispatch.workgroup.id
    %id_x = flow.dispatch.workgroup.id[0] : index
    %id_y = flow.dispatch.workgroup.id[1] : index
    // CHECK: flow.dispatch.workgroup.count
    %count_x = flow.dispatch.workgroup.count[0] : index
    %count_y = flow.dispatch.workgroup.count[1] : index
    // CHECK: flow.dispatch.workgroup.size
    %size_x = flow.dispatch.workgroup.size[0] : index
    %size_y = flow.dispatch.workgroup.size[1] : index

    // Query shapes directly from IO. Static dims would fold under
    // canonicalization; the RUN pipeline here only parses and prints, so
    // every ranked_dim below is kept as written.
    // CHECK: flow.dispatch.shape
    %arg0_shape = flow.dispatch.shape %arg0_capture : !flow.dispatch.input<?x4xf32> -> !shapex.ranked_shape<[?,4]>
    %arg0_dim0 = shapex.ranked_dim %arg0_shape[0] : !shapex.ranked_shape<[?,4]> -> index
    %arg0_dim1 = shapex.ranked_dim %arg0_shape[1] : !shapex.ranked_shape<[?,4]> -> index
    "test.sink"(%arg0_dim0, %arg0_dim1) : (index, index) -> ()
    // CHECK: flow.dispatch.shape
    %ret0_shape = flow.dispatch.shape %ret0 : !flow.dispatch.output<4x?xf32> -> !shapex.ranked_shape<[4,?]>
    %ret0_dim0 = shapex.ranked_dim %ret0_shape[0] : !shapex.ranked_shape<[4,?]> -> index
    %ret0_dim1 = shapex.ranked_dim %ret0_shape[1] : !shapex.ranked_shape<[4,?]> -> index
    "test.sink"(%ret0_dim0, %ret0_dim1) : (index, index) -> ()

    // Load tensors (optional offsets/sizes/strides):
    // CHECK: flow.dispatch.input.load
    %arg0_value = flow.dispatch.input.load %arg0_capture : !flow.dispatch.input<?x4xf32> -> tensor<?x4xf32>
    // Shapes can also be queried on the loaded tensor value:
    // CHECK: shapex.get_ranked_shape
    %arg0_shape_indirect = shapex.get_ranked_shape %arg0_value : tensor<?x4xf32> -> !shapex.ranked_shape<[?,4]>

    // Operate on tensors with full IO shapes. Named %result (not %0) so it
    // does not shadow the enclosing op's result name.
    // CHECK: test.math
    %result = "test.math"(%arg0_value, %arg0_shape_indirect, %ret0_shape) : (tensor<?x4xf32>, !shapex.ranked_shape<[?,4]>, !shapex.ranked_shape<[4,?]>) -> (tensor<4x?xf32>)

    // Store tensors (optional offsets/sizes/strides):
    // CHECK: flow.dispatch.output.store
    flow.dispatch.output.store %result, %ret0 : tensor<4x?xf32> -> !flow.dispatch.output<4x?xf32>
    // CHECK: flow.return
    flow.return
  }
  // CHECK: return
  return %0 : tensor<4x?xf32>
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment