
Sean Silva silvasean

silvasean / no-tie-shape.md
Last active May 9, 2020 03:35
A possible trick to avoid tie_shape and better structure lowering - May 8 2020

In IREE we use shapex.tie_shape to tie together a tensor value and its shape. This is an integral part of how we handle dynamic shapes, which can be summarized as:

  1. Before dispatch region formation, we:
     - create shapex.get_ranked_shape (similar to upstream shape.shape_of) on every tensor in the program, producing a !shapex.ranked_shape value for each tensor.
     - create shapex.tie_shape ops tying each tensor to its corresponding shape.
  2. Have a series of patterns that replace shapex.get_ranked_shape with some computation on the operands of the op that defines the operand of shapex.get_ranked_shape.
  3. Apply these patterns iteratively to "bypass" all the get_ranked_shape ops using shape transfer functions so that all shapes (or as many as possible) in the program can be eliminated and replaced with computations on the shapes of the function inputs (which ideally is just sim
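
The idea of "bypassing" a shape query through shape transfer functions can be illustrated with a small toy model. This is only a sketch in Python, not IREE's or MLIR's actual API: the `Arg`, `Op`, `Dim`, and `shape_of` names are hypothetical, the three op kinds are chosen purely for illustration, and it uses plain recursion where the real compiler would apply rewrite patterns iteratively. It shows how the shape of any intermediate tensor can be rewritten as a computation on the shapes of the function inputs alone.

```python
from dataclasses import dataclass
from typing import Tuple, Union


@dataclass(frozen=True)
class Arg:
    """A function input tensor (hypothetical toy IR node)."""
    name: str


@dataclass(frozen=True)
class Op:
    """An op producing one tensor from its operand tensors (hypothetical)."""
    kind: str
    operands: Tuple["Value", ...]


Value = Union[Arg, Op]


@dataclass(frozen=True)
class Dim:
    """Symbolic dimension number `index` of the function input named `arg`."""
    arg: str
    index: int


def shape_of(value, rank_of_arg):
    """Express the shape of `value` purely in terms of function-input dims.

    Each branch plays the role of one shape transfer function: it replaces
    a shape query on an op result with a computation on operand shapes.
    """
    if isinstance(value, Arg):
        # Base case: the shape of a function input is left symbolic.
        return tuple(Dim(value.name, i) for i in range(rank_of_arg[value.name]))
    if value.kind == "add":
        # Elementwise op (assumed same-shape operands): forward operand 0's shape.
        return shape_of(value.operands[0], rank_of_arg)
    if value.kind == "transpose":
        # 2-D transpose: swap the two dimensions.
        rows, cols = shape_of(value.operands[0], rank_of_arg)
        return (cols, rows)
    if value.kind == "matmul":
        # (M x K) @ (K x N) -> (M x N)
        lhs = shape_of(value.operands[0], rank_of_arg)
        rhs = shape_of(value.operands[1], rank_of_arg)
        return (lhs[0], rhs[1])
    raise NotImplementedError(f"no shape transfer function for {value.kind!r}")


# Usage: the shape of matmul(transpose(a), add(b, b)) for rank-2 inputs a, b.
a, b = Arg("a"), Arg("b")
expr = Op("matmul", (Op("transpose", (a,)), Op("add", (b, b))))
result_shape = shape_of(expr, {"a": 2, "b": 2})
print(result_shape)
```

In this toy model the result mentions only dimensions of `a` and `b`, so no shape query on an intermediate tensor survives.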
// configuration: -pass-pipeline='func(xla-legalize-control-flow, iree-legalize-input-ops, canonicalize, affine-loop-fusion), loop-invariant-code-motion, func(memref-dataflow-opt, canonicalize, simplify-affine-structures, cse, canonicalize), convert-from-tuple-calling-convention, func(canonicalize), iree-identify-reduction-regions, func(cse, iree-identify-dispatch-regions, cse, iree-fold-compatible-dispatch-regions, iree-rematerialize-dispatch-constants), iree-outline-dispatch-regions, iree-outline-reduction-regions, func(canonicalize), iree-drop-unreachable-module-functions, iree-drop-unused-executables'
// note: verifyPasses=true
module {
func @calculate(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<f32>, %arg3: tensor<i32>, %arg4: tensor<i32>) -> tensor<?x?xf32> attributes {iree.module.export, iree.reflection = {abi = "sip", abiv = 1 : i32, sip = "I25!S21!k0_0k1_1k2_2k3_3k4_4R3!_0"}, tf._input_shapes = ["tfshape$", "tfshape$", "tfshape$", "tfshape$", "tfshape$"]} {
%0 = xla_hlo.constant den