Skip to content

Instantly share code, notes, and snippets.

@rsuderman
Created June 17, 2024 22:13
Show Gist options
  • Save rsuderman/db4e923c8e01677e5a6033e8db8e7d0d to your computer and use it in GitHub Desktop.
// Top-level module for a batch-4 decode step. Parameters are not inlined;
// each global is bound to an externally provided parameter archive entry
// (scope "model") via #stream.parameter.named.
// NOTE(review): only blk.0 weights appear below -- presumably a reduced or
// truncated dump of a larger model; confirm against the full export.
module @module {
// Token-embedding table: 32000-entry vocabulary x 3200 model width, f16.
util.global private @__auto.token_embd.weight = #stream.parameter.named<"model"::"token_embd.weight"> : tensor<32000x3200xf16>
// Block-0 attention RMS-norm scale vector (per-feature, f32).
util.global private @__auto.blk.0.attn_norm.weight = #stream.parameter.named<"model"::"blk.0.attn_norm.weight"> : tensor<3200xf32>
// Block-0 attention K-projection weights (3200x3200, f16).
util.global private @__auto.blk.0.attn_k.weight = #stream.parameter.named<"model"::"blk.0.attn_k.weight"> : tensor<3200x3200xf16>
// Block-0 attention V-projection weights (3200x3200, f16).
util.global private @__auto.blk.0.attn_v.weight = #stream.parameter.named<"model"::"blk.0.attn_v.weight"> : tensor<3200x3200xf16>
// One asynchronous decode step at batch size 4 ("coarse-fences" ABI: all
// imports wait on %arg5; completion is signaled on %arg6).
//   %arg0: current token ids, tensor<4x1xi64>
//   %arg1: unused in this body (NOTE(review): dead parameter here)
//   %arg2: per-entry sequence positions, tensor<4xi64>
//   %arg3: per-entry page tables, tensor<4x?xi64> (dynamic page count)
//   %arg4: flat paged KV cache, tensor<?x2662400xf16>, updated IN PLACE
//          via hal.tensor.alias
// Pipeline visible below: embedding gather -> RMS-norm -> K/V projections
// -> scatter the new K/V vectors into the paged cache.
// NOTE(review): this looks like a truncated trace -- only blk.0 weights are
// used, only batch entries 0 and 1 are scattered, and the returned value is
// just the imported token tensor re-exported. Confirm against the source.
util.func public @decode_bs4$async(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view, %arg3: !hal.buffer_view, %arg4: !hal.buffer_view, %arg5: !hal.fence, %arg6: !hal.fence) -> !hal.buffer_view attributes {inlining_policy = #util.inline.never, iree.abi.model = "coarse-fences", iree.abi.stub} {
// Scatter-index pieces: 1 selects the V plane, 0 the layer-0 / K plane of
// the 6-D cache view built below.
%cst = arith.constant dense<1> : tensor<1x1xi64>
%cst_0 = arith.constant dense<0> : tensor<1x1xi64>
%c0_i64 = arith.constant 0 : i64
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
// Vocabulary size, used for the embedding-gather bounds assert.
%c32000 = arith.constant 32000 : index
%cst_1 = arith.constant 0.000000e+00 : f32
// 16 cache slots per page: positions split as page = pos/16, slot = pos%16.
%c16_i64 = arith.constant 16 : i64
%c1_i64 = arith.constant 1 : i64
// RMS-norm constants: exponent 2.0, feature width 3200.0, epsilon ~1e-6.
%cst_2 = arith.constant 2.000000e+00 : f32
%cst_3 = arith.constant 3.200000e+03 : f32
%cst_4 = arith.constant 9.99999997E-7 : f32
// 16.0 (f32) for the floating-point page computation floor(pos / 16).
%cst_5 = arith.constant 1.600000e+01 : f32
// Import inputs once %arg5 is signaled.
%0 = hal.tensor.import wait(%arg5) => %arg0 : !hal.buffer_view -> tensor<4x1xi64>
%1 = hal.tensor.import wait(%arg5) => %arg2 : !hal.buffer_view -> tensor<4xi64>
%2 = hal.buffer_view.dim<%arg3 : !hal.buffer_view>[1] : index
%3 = hal.tensor.import wait(%arg5) => %arg3 : !hal.buffer_view -> tensor<4x?xi64>{%2}
// Load the (block-0) parameters.
%__auto.token_embd.weight = util.global.load @__auto.token_embd.weight : tensor<32000x3200xf16>
%__auto.blk.0.attn_norm.weight = util.global.load @__auto.blk.0.attn_norm.weight : tensor<3200xf32>
%__auto.blk.0.attn_k.weight = util.global.load @__auto.blk.0.attn_k.weight : tensor<3200x3200xf16>
%__auto.blk.0.attn_v.weight = util.global.load @__auto.blk.0.attn_v.weight : tensor<3200x3200xf16>
// Paged KV cache: ? pages x 2662400 f16; 2662400 = 26*2*16*32*100, reshaped
// below to [page, 26, 2, 16, 32, 100] (presumably layers x {K,V} x slots x
// heads x head-dim -- TODO confirm the 26/2/16 meaning against the model).
%4 = hal.buffer_view.dim<%arg4 : !hal.buffer_view>[0] : index
%5 = hal.tensor.import wait(%arg5) => %arg4 : !hal.buffer_view -> tensor<?x2662400xf16>{%4}
// Embedding gather: out[b,0,d] = token_embd[token[b,0], d], with runtime
// bounds asserts 0 <= id < 32000 on each token id.
%6 = tensor.empty() : tensor<4x1x3200xf16>
%7 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%0 : tensor<4x1xi64>) outs(%6 : tensor<4x1x3200xf16>) {
^bb0(%in: i64, %out: f16):
%52 = arith.index_cast %in : i64 to index
%53 = linalg.index 2 : index
%54 = arith.cmpi slt, %52, %c32000 : index
cf.assert %54, "index must be smaller than dim size"
%55 = arith.cmpi sge, %in, %c0_i64 : i64
cf.assert %55, "index must be larger or equal to 0"
%extracted = tensor.extract %__auto.token_embd.weight[%52, %53] : tensor<32000x3200xf16>
linalg.yield %extracted : f16
} -> tensor<4x1x3200xf16>
// Widen the embeddings f16 -> f32 for the norm arithmetic.
%8 = tensor.empty() : tensor<4x1x3200xf32>
%9 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, 0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%7 : tensor<4x1x3200xf16>) outs(%8 : tensor<4x1x3200xf32>) {
^bb0(%in: f16, %out: f32):
%52 = arith.extf %in : f16 to f32
linalg.yield %52 : f32
} -> tensor<4x1x3200xf32>
// RMS-norm, step 1: elementwise square (h^2 via powf(h, 2.0)).
%10 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, 0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%9 : tensor<4x1x3200xf32>) outs(%8 : tensor<4x1x3200xf32>) {
^bb0(%in: f32, %out: f32):
%52 = math.powf %in, %cst_2 : f32
linalg.yield %52 : f32
} -> tensor<4x1x3200xf32>
// Step 2: sum of squares over the 3200-feature axis (reduction into zeros).
%11 = tensor.empty() : tensor<4x1x1xf32>
%12 = linalg.fill ins(%cst_1 : f32) outs(%11 : tensor<4x1x1xf32>) -> tensor<4x1x1xf32>
%13 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, 0)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%10 : tensor<4x1x3200xf32>) outs(%12 : tensor<4x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%52 = arith.addf %in, %out : f32
linalg.yield %52 : f32
} -> tensor<4x1x1xf32>
// Step 3: mean = sum / 3200.
%14 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, 0, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%13 : tensor<4x1x1xf32>) outs(%11 : tensor<4x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%52 = arith.divf %in, %cst_3 : f32
linalg.yield %52 : f32
} -> tensor<4x1x1xf32>
// Step 4: add epsilon for numerical stability.
%15 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, 0, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%14 : tensor<4x1x1xf32>) outs(%11 : tensor<4x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%52 = arith.addf %in, %cst_4 : f32
linalg.yield %52 : f32
} -> tensor<4x1x1xf32>
// Step 5: rsqrt -> per-row 1/RMS factor.
%16 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, 0, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%15 : tensor<4x1x1xf32>) outs(%11 : tensor<4x1x1xf32>) {
^bb0(%in: f32, %out: f32):
%52 = math.rsqrt %in : f32
linalg.yield %52 : f32
} -> tensor<4x1x1xf32>
// Step 6: normalize h * rsqrt(mean(h^2) + eps), broadcasting the factor.
%17 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, 0, d2)>, affine_map<(d0, d1, d2) -> (d0, 0, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%9, %16 : tensor<4x1x3200xf32>, tensor<4x1x1xf32>) outs(%8 : tensor<4x1x3200xf32>) {
^bb0(%in: f32, %in_39: f32, %out: f32):
%52 = arith.mulf %in, %in_39 : f32
linalg.yield %52 : f32
} -> tensor<4x1x3200xf32>
// Step 7: apply the learned per-feature attn_norm scale.
%18 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, 0, d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%17, %__auto.blk.0.attn_norm.weight : tensor<4x1x3200xf32>, tensor<3200xf32>) outs(%8 : tensor<4x1x3200xf32>) {
^bb0(%in: f32, %in_39: f32, %out: f32):
%52 = arith.mulf %in, %in_39 : f32
linalg.yield %52 : f32
} -> tensor<4x1x3200xf32>
// Narrow the normed activations back to f16 for the projection matmuls.
%19 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, 0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%18 : tensor<4x1x3200xf32>) outs(%6 : tensor<4x1x3200xf16>) {
^bb0(%in: f32, %out: f16):
%52 = arith.truncf %in : f32 to f16
linalg.yield %52 : f16
} -> tensor<4x1x3200xf16>
// Explicit transpose of the K weights (output map swaps d0/d1) so the
// subsequent matmul contracts over the stored leading dimension.
%20 = tensor.empty() : tensor<3200x3200xf16>
%21 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} ins(%__auto.blk.0.attn_k.weight : tensor<3200x3200xf16>) outs(%20 : tensor<3200x3200xf16>) {
^bb0(%in: f16, %out: f16):
linalg.yield %in : f16
} -> tensor<3200x3200xf16>
// Flatten [4,1,3200] -> [4,3200]; K projection with an f32 accumulator,
// then truncate the result to f16.
%collapsed = tensor.collapse_shape %19 [[0], [1, 2]] : tensor<4x1x3200xf16> into tensor<4x3200xf16>
%22 = tensor.empty() : tensor<4x3200xf32>
%23 = linalg.fill ins(%cst_1 : f32) outs(%22 : tensor<4x3200xf32>) -> tensor<4x3200xf32>
%24 = linalg.matmul ins(%collapsed, %21 : tensor<4x3200xf16>, tensor<3200x3200xf16>) outs(%23 : tensor<4x3200xf32>) -> tensor<4x3200xf32>
%25 = tensor.empty() : tensor<4x3200xf16>
%26 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%24 : tensor<4x3200xf32>) outs(%25 : tensor<4x3200xf16>) {
^bb0(%in: f32, %out: f16):
%52 = arith.truncf %in : f32 to f16
linalg.yield %52 : f16
} -> tensor<4x3200xf16>
// Same transpose + matmul + truncate pattern for the V projection.
%27 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} ins(%__auto.blk.0.attn_v.weight : tensor<3200x3200xf16>) outs(%20 : tensor<3200x3200xf16>) {
^bb0(%in: f16, %out: f16):
linalg.yield %in : f16
} -> tensor<3200x3200xf16>
%28 = linalg.matmul ins(%collapsed, %27 : tensor<4x3200xf16>, tensor<3200x3200xf16>) outs(%23 : tensor<4x3200xf32>) -> tensor<4x3200xf32>
%29 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%28 : tensor<4x3200xf32>) outs(%25 : tensor<4x3200xf16>) {
^bb0(%in: f32, %out: f16):
%52 = arith.truncf %in : f32 to f16
linalg.yield %52 : f16
} -> tensor<4x3200xf16>
// Split the 3200-wide K/V rows into 32 heads x 100 head-dim.
%expanded = tensor.expand_shape %26 [[0], [1, 2, 3]] output_shape [4, 1, 32, 100] : tensor<4x3200xf16> into tensor<4x1x32x100xf16>
%expanded_6 = tensor.expand_shape %29 [[0], [1, 2, 3]] output_shape [4, 1, 32, 100] : tensor<4x3200xf16> into tensor<4x1x32x100xf16>
// positions + 1: the cache coordinates below are derived from pos+1
// (NOTE(review): assumes inputs are the previous position -- TODO confirm
// the convention against the caller).
%30 = tensor.empty() : tensor<4xi64>
%31 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%1 : tensor<4xi64>) outs(%30 : tensor<4xi64>) {
^bb0(%in: i64, %out: i64):
%52 = arith.addi %in, %c1_i64 : i64
linalg.yield %52 : i64
} -> tensor<4xi64>
// View the flat cache as [page, 26, 2, 16, 32, 100] for scattering.
%dim = tensor.dim %5, %c0 : tensor<?x2662400xf16>
%expanded_7 = tensor.expand_shape %5 [[0], [1, 2, 3, 4, 5]] output_shape [%dim, 26, 2, 16, 32, 100] : tensor<?x2662400xf16> into tensor<?x26x2x16x32x100xf16>
// ---- Batch entry 0 ----
// Scalar position for entry 0 and its page-table row.
%extracted_slice = tensor.extract_slice %31[0] [1] [1] : tensor<4xi64> to tensor<1xi64>
%collapsed_8 = tensor.collapse_shape %extracted_slice [] : tensor<1xi64> into tensor<i64>
%dim_9 = tensor.dim %3, %c1 : tensor<4x?xi64>
%extracted_slice_10 = tensor.extract_slice %3[0, 0] [1, %dim_9] [1, 1] : tensor<4x?xi64> to tensor<1x?xi64>
%collapsed_11 = tensor.collapse_shape %extracted_slice_10 [[0, 1]] : tensor<1x?xi64> into tensor<?xi64>
// Logical page index = floor(pos / 16), computed via f32 round-trip.
%32 = tensor.empty() : tensor<i64>
%33 = linalg.generic {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []} ins(%collapsed_8 : tensor<i64>) outs(%32 : tensor<i64>) {
^bb0(%in: i64, %out: i64):
%52 = arith.sitofp %in : i64 to f32
%53 = arith.divf %52, %cst_5 : f32
%54 = math.floor %53 : f32
%55 = arith.fptosi %54 : f32 to i64
linalg.yield %55 : i64
} -> tensor<i64>
%expanded_12 = tensor.expand_shape %33 [] output_shape [1] : tensor<i64> into tensor<1xi64>
// Translate the logical page index through entry 0's page table to get the
// physical page id.
%34 = tensor.empty() : tensor<1xi64>
%35 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%expanded_12 : tensor<1xi64>) outs(%34 : tensor<1xi64>) {
^bb0(%in: i64, %out: i64):
%52 = arith.index_cast %in : i64 to index
%extracted = tensor.extract %collapsed_11[%52] : tensor<?xi64>
linalg.yield %extracted : i64
} -> tensor<1xi64>
// Slot within the page = pos % 16.
%36 = linalg.generic {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []} ins(%collapsed_8 : tensor<i64>) outs(%32 : tensor<i64>) {
^bb0(%in: i64, %out: i64):
%52 = arith.remsi %in, %c16_i64 : i64
linalg.yield %52 : i64
} -> tensor<i64>
// Entry 0's K vector (32x100) and its 4-D scatter index
// [physical_page, layer=0, plane=0 (K), slot].
%extracted_slice_13 = tensor.extract_slice %expanded[0, 0, 0, 0] [1, 1, 32, 100] [1, 1, 1, 1] : tensor<4x1x32x100xf16> to tensor<1x1x32x100xf16>
%collapsed_14 = tensor.collapse_shape %extracted_slice_13 [[0, 1, 2], [3]] : tensor<1x1x32x100xf16> into tensor<32x100xf16>
%expanded_15 = tensor.expand_shape %35 [[0, 1]] output_shape [1, 1] : tensor<1xi64> into tensor<1x1xi64>
%expanded_16 = tensor.expand_shape %36 [] output_shape [1, 1] : tensor<i64> into tensor<1x1xi64>
%concat = tensor.concat dim(1) %expanded_15, %cst_0, %cst_0, %expanded_16 : (tensor<1x1xi64>, tensor<1x1xi64>, tensor<1x1xi64>, tensor<1x1xi64>) -> tensor<1x4xi64>
%expanded_17 = tensor.expand_shape %collapsed_14 [[0, 1, 2, 3, 4], [5]] output_shape [1, 1, 1, 1, 32, 100] : tensor<32x100xf16> into tensor<1x1x1x1x32x100xf16>
// Scatter indices must be i32; truncate from i64.
%37 = tensor.empty() : tensor<1x4xi32>
%38 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%concat : tensor<1x4xi64>) outs(%37 : tensor<1x4xi32>) {
^bb0(%in: i64, %out: i32):
%52 = arith.trunci %in : i64 to i32
linalg.yield %52 : i32
} -> tensor<1x4xi32>
// Write entry 0's K vector into the cache view (overwrite semantics: the
// region yields the update value unconditionally).
%39 = iree_linalg_ext.scatter dimension_map = [0, 1, 2, 3] unique_indices(false) ins(%expanded_17, %38 : tensor<1x1x1x1x32x100xf16>, tensor<1x4xi32>) outs(%expanded_7 : tensor<?x26x2x16x32x100xf16>) {
^bb0(%arg7: f16, %arg8: f16):
iree_linalg_ext.yield %arg7 : f16
} -> tensor<?x26x2x16x32x100xf16>
// Entry 0's V vector: same page/slot but plane=1 (%cst) in the index.
%extracted_slice_18 = tensor.extract_slice %expanded_6[0, 0, 0, 0] [1, 1, 32, 100] [1, 1, 1, 1] : tensor<4x1x32x100xf16> to tensor<1x1x32x100xf16>
%collapsed_19 = tensor.collapse_shape %extracted_slice_18 [[0, 1, 2], [3]] : tensor<1x1x32x100xf16> into tensor<32x100xf16>
%concat_20 = tensor.concat dim(1) %expanded_15, %cst_0, %cst, %expanded_16 : (tensor<1x1xi64>, tensor<1x1xi64>, tensor<1x1xi64>, tensor<1x1xi64>) -> tensor<1x4xi64>
%expanded_21 = tensor.expand_shape %collapsed_19 [[0, 1, 2, 3, 4], [5]] output_shape [1, 1, 1, 1, 32, 100] : tensor<32x100xf16> into tensor<1x1x1x1x32x100xf16>
%40 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%concat_20 : tensor<1x4xi64>) outs(%37 : tensor<1x4xi32>) {
^bb0(%in: i64, %out: i32):
%52 = arith.trunci %in : i64 to i32
linalg.yield %52 : i32
} -> tensor<1x4xi32>
%41 = iree_linalg_ext.scatter dimension_map = [0, 1, 2, 3] unique_indices(false) ins(%expanded_21, %40 : tensor<1x1x1x1x32x100xf16>, tensor<1x4xi32>) outs(%39 : tensor<?x26x2x16x32x100xf16>) {
^bb0(%arg7: f16, %arg8: f16):
iree_linalg_ext.yield %arg7 : f16
} -> tensor<?x26x2x16x32x100xf16>
// ---- Batch entry 1 ---- (same sequence, using row 1 of %31/%3/%expanded).
%extracted_slice_22 = tensor.extract_slice %31[1] [1] [1] : tensor<4xi64> to tensor<1xi64>
%collapsed_23 = tensor.collapse_shape %extracted_slice_22 [] : tensor<1xi64> into tensor<i64>
%extracted_slice_24 = tensor.extract_slice %3[1, 0] [1, %dim_9] [1, 1] : tensor<4x?xi64> to tensor<1x?xi64>
%collapsed_25 = tensor.collapse_shape %extracted_slice_24 [[0, 1]] : tensor<1x?xi64> into tensor<?xi64>
// page = floor(pos1 / 16).
%42 = linalg.generic {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []} ins(%collapsed_23 : tensor<i64>) outs(%32 : tensor<i64>) {
^bb0(%in: i64, %out: i64):
%52 = arith.sitofp %in : i64 to f32
%53 = arith.divf %52, %cst_5 : f32
%54 = math.floor %53 : f32
%55 = arith.fptosi %54 : f32 to i64
linalg.yield %55 : i64
} -> tensor<i64>
%expanded_26 = tensor.expand_shape %42 [] output_shape [1] : tensor<i64> into tensor<1xi64>
// Physical page id from entry 1's page table.
%43 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%expanded_26 : tensor<1xi64>) outs(%34 : tensor<1xi64>) {
^bb0(%in: i64, %out: i64):
%52 = arith.index_cast %in : i64 to index
%extracted = tensor.extract %collapsed_25[%52] : tensor<?xi64>
linalg.yield %extracted : i64
} -> tensor<1xi64>
// slot = pos1 % 16.
%44 = linalg.generic {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []} ins(%collapsed_23 : tensor<i64>) outs(%32 : tensor<i64>) {
^bb0(%in: i64, %out: i64):
%52 = arith.remsi %in, %c16_i64 : i64
linalg.yield %52 : i64
} -> tensor<i64>
// Entry 1's K write (plane=0).
%extracted_slice_27 = tensor.extract_slice %expanded[1, 0, 0, 0] [1, 1, 32, 100] [1, 1, 1, 1] : tensor<4x1x32x100xf16> to tensor<1x1x32x100xf16>
%collapsed_28 = tensor.collapse_shape %extracted_slice_27 [[0, 1, 2], [3]] : tensor<1x1x32x100xf16> into tensor<32x100xf16>
%expanded_29 = tensor.expand_shape %43 [[0, 1]] output_shape [1, 1] : tensor<1xi64> into tensor<1x1xi64>
%expanded_30 = tensor.expand_shape %44 [] output_shape [1, 1] : tensor<i64> into tensor<1x1xi64>
%concat_31 = tensor.concat dim(1) %expanded_29, %cst_0, %cst_0, %expanded_30 : (tensor<1x1xi64>, tensor<1x1xi64>, tensor<1x1xi64>, tensor<1x1xi64>) -> tensor<1x4xi64>
%expanded_32 = tensor.expand_shape %collapsed_28 [[0, 1, 2, 3, 4], [5]] output_shape [1, 1, 1, 1, 32, 100] : tensor<32x100xf16> into tensor<1x1x1x1x32x100xf16>
%45 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%concat_31 : tensor<1x4xi64>) outs(%37 : tensor<1x4xi32>) {
^bb0(%in: i64, %out: i32):
%52 = arith.trunci %in : i64 to i32
linalg.yield %52 : i32
} -> tensor<1x4xi32>
%46 = iree_linalg_ext.scatter dimension_map = [0, 1, 2, 3] unique_indices(false) ins(%expanded_32, %45 : tensor<1x1x1x1x32x100xf16>, tensor<1x4xi32>) outs(%41 : tensor<?x26x2x16x32x100xf16>) {
^bb0(%arg7: f16, %arg8: f16):
iree_linalg_ext.yield %arg7 : f16
} -> tensor<?x26x2x16x32x100xf16>
// Entry 1's V write (plane=1).
%extracted_slice_33 = tensor.extract_slice %expanded_6[1, 0, 0, 0] [1, 1, 32, 100] [1, 1, 1, 1] : tensor<4x1x32x100xf16> to tensor<1x1x32x100xf16>
%collapsed_34 = tensor.collapse_shape %extracted_slice_33 [[0, 1, 2], [3]] : tensor<1x1x32x100xf16> into tensor<32x100xf16>
%concat_35 = tensor.concat dim(1) %expanded_29, %cst_0, %cst, %expanded_30 : (tensor<1x1xi64>, tensor<1x1xi64>, tensor<1x1xi64>, tensor<1x1xi64>) -> tensor<1x4xi64>
%expanded_36 = tensor.expand_shape %collapsed_34 [[0, 1, 2, 3, 4], [5]] output_shape [1, 1, 1, 1, 32, 100] : tensor<32x100xf16> into tensor<1x1x1x1x32x100xf16>
%47 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%concat_35 : tensor<1x4xi64>) outs(%37 : tensor<1x4xi32>) {
^bb0(%in: i64, %out: i32):
%52 = arith.trunci %in : i64 to i32
linalg.yield %52 : i32
} -> tensor<1x4xi32>
%48 = iree_linalg_ext.scatter dimension_map = [0, 1, 2, 3] unique_indices(false) ins(%expanded_36, %47 : tensor<1x1x1x1x32x100xf16>, tensor<1x4xi32>) outs(%46 : tensor<?x26x2x16x32x100xf16>) {
^bb0(%arg7: f16, %arg8: f16):
iree_linalg_ext.yield %arg7 : f16
} -> tensor<?x26x2x16x32x100xf16>
// NOTE(review): batch entries 2 and 3 are never scattered -- consistent
// with a truncated dump rather than complete decode logic.
// Collapse the updated 6-D view back to the flat cache shape and alias it
// over %arg4, committing the cache update in place.
%collapsed_37 = tensor.collapse_shape %48 [[0], [1, 2, 3, 4, 5]] : tensor<?x26x2x16x32x100xf16> into tensor<?x2662400xf16>
%dim_38 = tensor.dim %collapsed_37, %c0 : tensor<?x2662400xf16>
%49 = hal.tensor.alias wait(%arg5) => %collapsed_37 : tensor<?x2662400xf16>{%dim_38} to %arg4 : !hal.buffer_view
// Gate both the cache write and the token tensor on the signal fence, then
// export and return the (unchanged) token tensor.
%50:2 = hal.tensor.barrier join(%49, %0 : tensor<?x2662400xf16>, tensor<4x1xi64>) => %arg6 : !hal.fence
%51 = hal.tensor.export %50#1 : tensor<4x1xi64> -> !hal.buffer_view
util.return %51 : !hal.buffer_view
}
// Synchronous ABI wrapper around @decode_bs4$async: invokes the async entry
// point with a null wait fence (inputs are immediately ready) and a freshly
// created completion fence on device 0, then blocks until that fence is
// signaled before returning the result buffer view.
util.func public @decode_bs4(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view, %arg3: !hal.buffer_view, %arg4: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
  %device_index = arith.constant 0 : index
  // -1 = infinite timeout for the await below.
  %infinite_timeout = arith.constant -1 : i32
  %device = hal.devices.get %device_index : !hal.device
  // No upstream producers to wait on: the wait fence is null.
  %no_wait = util.null : !hal.fence
  %done = hal.fence.create device(%device : !hal.device) flags("None") : !hal.fence
  %result = util.call @decode_bs4$async(%arg0, %arg1, %arg2, %arg3, %arg4, %no_wait, %done) : (!hal.buffer_view, !hal.buffer_view, !hal.buffer_view, !hal.buffer_view, !hal.buffer_view, !hal.fence, !hal.fence) -> !hal.buffer_view
  // Block until the async work signals completion. The await status is
  // discarded here, matching the original lowering.
  %status = hal.fence.await until([%done]) timeout_millis(%infinite_timeout) : i32
  util.return %result : !hal.buffer_view
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment