Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save vivekkhandelwal1/1fcc3e79ef8bf94a1988f75ff9b88a03 to your computer and use it in GitHub Desktop.
Save vivekkhandelwal1/1fcc3e79ef8bf94a1988f75ff9b88a03 to your computer and use it in GitHub Desktop.
bert_input_ir.mlir:1133:12: error: 'linalg.generic' op unexpected result less than 0 at expression #0 in (d0, d1, d2, d3) -> (d0, d3)
%404 = linalg.generic {indexing_maps = [#map2, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%401, %96 : tensor<?x12x?x?xf32>, tensor<?x1x1x?xf32>) outs(%398 : tensor<?x12x?x?xf32>) {
^
bert_input_ir.mlir:20:3: note: called from
func @forward(%arg0: tensor<1x512xi64>) -> tensor<?x2xf32> {
^
bert_input_ir.mlir:1133:12: note: see current operation: %333 = "linalg.generic"(%332, %263, %330) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.truncf"(%9) : (f64) -> f32
%601 = "arith.divf"(%arg1, %600) : (f32, f32) -> f32
%602 = "arith.addf"(%601, %arg2) : (f32, f32) -> f32
"linalg.yield"(%602) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<0x12x512x512xf32>, tensor<1x?xf32>, tensor<0x12x512x512xf32>) -> tensor<0x12x512x512xf32>
%404 = linalg.generic {indexing_maps = [#map2, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%401, %96 : tensor<?x12x?x?xf32>, tensor<?x1x1x?xf32>) outs(%398 : tensor<?x12x?x?xf32>) {
^
// -----// IR Dump After Canonicalizer Failed //----- //
"builtin.func"() ({
^bb0(%arg0: !hal.buffer_view):
%0 = "arith.constant"() {value = 1 : index} : () -> index
%1 = "arith.constant"() {value = false} : () -> i1
%2 = "arith.constant"() {value = 0 : index} : () -> index
%3 = "arith.constant"() {value = 512 : i64} : () -> i64
%4 = "arith.constant"() {value = 0 : i64} : () -> i64
%5 = "arith.constant"() {value = 512 : index} : () -> index
%6 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x2xf32>} : () -> tensor<384x2xf32>
%7 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>} : () -> tensor<384x384xf32>
%8 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1x384x384xf32>} : () -> tensor<1x384x384xf32>
%9 = "arith.constant"() {value = 5.6568542494923806 : f64} : () -> f64
%10 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1x384x384xf32>} : () -> tensor<1x384x384xf32>
%11 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1x384x384xf32>} : () -> tensor<1x384x384xf32>
%12 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1x384x384xf32>} : () -> tensor<1x384x384xf32>
%13 = "arith.constant"() {value = -1.000000e+04 : f32} : () -> f32
%14 = "arith.constant"() {value = 30522 : i64} : () -> i64
%15 = "arith.constant"() {value = 0.000000e+00 : f32} : () -> f32
%16 = "arith.constant"() {value = -3.40282347E+38 : f32} : () -> f32
%17 = "arith.constant"() {value = 5.000000e-01 : f32} : () -> f32
%18 = "arith.constant"() {value = 2.000000e+00 : f32} : () -> f32
%19 = "arith.constant"() {value = 1.000000e+00 : f32} : () -> f32
%20 = "arith.constant"() {value = 9223372036854775807 : i64} : () -> i64
%21 = "arith.constant"() {value = 9.9999999999999998E-13 : f64} : () -> f64
%22 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1x512xi64>} : () -> tensor<1x512xi64>
%23 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<30522x384xf32>} : () -> tensor<30522x384xf32>
%24 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<2x384xf32>} : () -> tensor<2x384xf32>
%25 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<512x384xf32>} : () -> tensor<512x384xf32>
%26 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%27 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%28 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%29 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%30 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%31 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%32 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%33 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%34 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536xf32>} : () -> tensor<1536xf32>
%35 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536x384xf32>} : () -> tensor<1536x384xf32>
%36 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%37 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x1536xf32>} : () -> tensor<384x1536xf32>
%38 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%39 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%40 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%41 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>} : () -> tensor<384x384xf32>
%42 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%43 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>} : () -> tensor<384x384xf32>
%44 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%45 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>} : () -> tensor<384x384xf32>
%46 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%47 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>} : () -> tensor<384x384xf32>
%48 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%49 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%50 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536xf32>} : () -> tensor<1536xf32>
%51 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536x384xf32>} : () -> tensor<1536x384xf32>
%52 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%53 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x1536xf32>} : () -> tensor<384x1536xf32>
%54 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%55 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%56 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%57 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>} : () -> tensor<384x384xf32>
%58 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%59 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>} : () -> tensor<384x384xf32>
%60 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%61 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>} : () -> tensor<384x384xf32>
%62 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%63 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>} : () -> tensor<384x384xf32>
%64 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%65 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%66 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536xf32>} : () -> tensor<1536xf32>
%67 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536x384xf32>} : () -> tensor<1536x384xf32>
%68 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%69 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x1536xf32>} : () -> tensor<384x1536xf32>
%70 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%71 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%72 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%73 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>} : () -> tensor<384x384xf32>
%74 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%75 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>} : () -> tensor<384x384xf32>
%76 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%77 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>} : () -> tensor<384x384xf32>
%78 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%79 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>} : () -> tensor<384x384xf32>
%80 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%81 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%82 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536xf32>} : () -> tensor<1536xf32>
%83 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536x384xf32>} : () -> tensor<1536x384xf32>
%84 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%85 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x1536xf32>} : () -> tensor<384x1536xf32>
%86 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%87 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%88 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%89 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>} : () -> tensor<384x384xf32>
%90 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%91 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>} : () -> tensor<384x384xf32>
%92 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%93 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>} : () -> tensor<384x384xf32>
%94 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%95 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>} : () -> tensor<384x384xf32>
%96 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%97 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%98 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536xf32>} : () -> tensor<1536xf32>
%99 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536x384xf32>} : () -> tensor<1536x384xf32>
%100 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%101 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x1536xf32>} : () -> tensor<384x1536xf32>
%102 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%103 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%104 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%105 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>} : () -> tensor<384x384xf32>
%106 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%107 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>} : () -> tensor<384x384xf32>
%108 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%109 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>} : () -> tensor<384x384xf32>
%110 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%111 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>} : () -> tensor<384x384xf32>
%112 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%113 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%114 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536xf32>} : () -> tensor<1536xf32>
%115 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536x384xf32>} : () -> tensor<1536x384xf32>
%116 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%117 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x1536xf32>} : () -> tensor<384x1536xf32>
%118 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%119 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%120 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%121 = "arith.constant"() {value = dense<[-0.00115577725, 0.00115577038]> : tensor<2xf32>} : () -> tensor<2xf32>
%122 = "arith.constant"() {value = 3.840000e+02 : f32} : () -> f32
%123 = "hal.tensor.import"(%arg0) {target_encoding = tensor<1x512xi64>} : (!hal.buffer_view) -> tensor<1x512xi64>
%124 = "linalg.init_tensor"() {static_sizes = []} : () -> tensor<i64>
%125 = "linalg.fill"(%3, %124) ({
^bb0(%arg1: i64, %arg2: i64):
"linalg.yield"(%arg1) : (i64) -> ()
}) : (i64, tensor<i64>) -> tensor<i64>
%126 = "tensor.extract"(%125) : (tensor<i64>) -> i64
%127 = "arith.index_cast"(%126) : (i64) -> index
%128 = "linalg.init_tensor"(%127) {static_sizes = [1, -1]} : (index) -> tensor<1x?xf32>
%129 = "linalg.fill"(%19, %128) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<1x?xf32>) -> tensor<1x?xf32>
%130 = "tensor.extract_slice"(%129, %127) {operand_segment_sizes = dense<[1, 0, 1, 0]> : vector<4xi32>, static_offsets = [0, 0], static_sizes = [1, -1], static_strides = [1, 1]} : (tensor<1x?xf32>, index) -> tensor<?xf32>
%131 = "tensor.dim"(%130, %2) : (tensor<?xf32>, index) -> index
%132 = "tensor.dim"(%130, %2) : (tensor<?xf32>, index) -> index
%133 = "flow.tensor.reshape"(%130, %132, %0, %131) {operand_segment_sizes = dense<[1, 1, 2]> : vector<3xi32>} : (tensor<?xf32>, index, index, index) -> tensor<?x1x1x?xf32>
%134 = "arith.cmpi"(%4, %126) {predicate = 4 : i64} : (i64, i64) -> i1
%135 = "std.select"(%134, %126, %4) : (i1, i64, i64) -> i64
%136 = "arith.index_cast"(%135) : (i64) -> index
%137 = "arith.cmpi"(%20, %126) {predicate = 4 : i64} : (i64, i64) -> i1
%138 = "std.select"(%137, %126, %20) : (i1, i64, i64) -> i64
%139 = "arith.index_cast"(%138) : (i64) -> index
%140 = "arith.cmpi"(%139, %136) {predicate = 5 : i64} : (index, index) -> i1
%141 = "std.select"(%140, %139, %136) : (i1, index, index) -> index
%142 = "arith.subi"(%141, %136) : (index, index) -> index
%143 = "tensor.extract_slice"(%133, %136, %142) {operand_segment_sizes = dense<[1, 1, 1, 0]> : vector<4xi32>, static_offsets = [0, 0, 0, -9223372036854775808], static_sizes = [1, 1, 1, -1], static_strides = [1, 1, 1, 1]} : (tensor<?x1x1x?xf32>, index, index) -> tensor<?xf32>
%144 = "linalg.init_tensor"(%142) {static_sizes = [-1]} : (index) -> tensor<?xf32>
%145 = "linalg.generic"(%143, %144) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "arith.subf"(%19, %arg1) : (f32, f32) -> f32
%601 = "arith.mulf"(%600, %13) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%146 = "tensor.dim"(%145, %2) : (tensor<?xf32>, index) -> index
%147 = "tensor.dim"(%145, %2) : (tensor<?xf32>, index) -> index
%148 = "linalg.fill"(%3, %124) ({
^bb0(%arg1: i64, %arg2: i64):
"linalg.yield"(%arg1) : (i64) -> ()
}) : (i64, tensor<i64>) -> tensor<i64>
%149 = "tensor.extract"(%148) : (tensor<i64>) -> i64
%150 = "arith.addi"(%149, %3) : (i64, i64) -> i64
%151 = "arith.cmpi"(%149, %4) {predicate = 5 : i64} : (i64, i64) -> i1
%152 = "std.select"(%151, %149, %150) : (i1, i64, i64) -> i64
%153 = "arith.cmpi"(%152, %4) {predicate = 2 : i64} : (i64, i64) -> i1
%154 = "std.select"(%153, %4, %152) : (i1, i64, i64) -> i64
%155 = "arith.cmpi"(%154, %3) {predicate = 4 : i64} : (i64, i64) -> i1
%156 = "std.select"(%155, %3, %154) : (i1, i64, i64) -> i64
%157 = "arith.index_cast"(%156) : (i64) -> index
%158 = "arith.cmpi"(%157, %2) {predicate = 5 : i64} : (index, index) -> i1
%159 = "std.select"(%158, %157, %2) : (i1, index, index) -> index
%160 = "tensor.extract_slice"(%22, %159) {operand_segment_sizes = dense<[1, 0, 1, 0]> : vector<4xi32>, static_offsets = [0, 0], static_sizes = [1, -1], static_strides = [1, 1]} : (tensor<1x512xi64>, index) -> tensor<?xi64>
%161 = "flow.tensor.reshape"(%123) {operand_segment_sizes = dense<[1, 0, 0]> : vector<3xi32>} : (tensor<1x512xi64>) -> tensor<512xi64>
%162 = "arith.cmpi"(%5, %159) {predicate = 0 : i64} : (index, index) -> i1
"std.assert"(%162) {msg = "mismatched size for broadcast"} : (i1) -> ()
%163 = "linalg.init_tensor"() {static_sizes = [512, 384]} : () -> tensor<512x384xf32>
%164 = "linalg.generic"(%161, %160, %163) ({
^bb0(%arg1: i64, %arg2: i64, %arg3: f32):
%600 = "linalg.index"() {dim = 1 : i64} : () -> index
%601 = "tensor.extract"(%24, %2, %600) : (tensor<2x384xf32>, index, index) -> f32
%602 = "arith.index_cast"(%arg1) : (i64) -> index
%603 = "arith.cmpi"(%arg1, %14) {predicate = 2 : i64} : (i64, i64) -> i1
"std.assert"(%603) {msg = "index must be smaller than dim size"} : (i1) -> ()
%604 = "arith.cmpi"(%arg1, %4) {predicate = 5 : i64} : (i64, i64) -> i1
"std.assert"(%604) {msg = "index must be larger or equal to 0"} : (i1) -> ()
%605 = "tensor.extract"(%23, %602, %600) : (tensor<30522x384xf32>, index, index) -> f32
%606 = "arith.index_cast"(%arg2) : (i64) -> index
%607 = "arith.cmpi"(%arg2, %3) {predicate = 2 : i64} : (i64, i64) -> i1
"std.assert"(%607) {msg = "index must be smaller than dim size"} : (i1) -> ()
%608 = "arith.cmpi"(%arg2, %4) {predicate = 5 : i64} : (i64, i64) -> i1
"std.assert"(%608) {msg = "index must be larger or equal to 0"} : (i1) -> ()
%609 = "tensor.extract"(%25, %606, %600) : (tensor<512x384xf32>, index, index) -> f32
%610 = "arith.addf"(%605, %601) : (f32, f32) -> f32
%611 = "arith.addf"(%610, %609) : (f32, f32) -> f32
"linalg.yield"(%611) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<512xi64>, tensor<?xi64>, tensor<512x384xf32>) -> tensor<512x384xf32>
%165 = "linalg.init_tensor"() {static_sizes = [512]} : () -> tensor<512xf32>
%166 = "linalg.fill"(%15, %165) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<512xf32>) -> tensor<512xf32>
%167 = "linalg.generic"(%164, %166) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "arith.addf"(%arg2, %arg1) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<512x384xf32>, tensor<512xf32>) -> tensor<512xf32>
%168 = "linalg.generic"(%167, %165) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "arith.divf"(%arg1, %122) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<512xf32>, tensor<512xf32>) -> tensor<512xf32>
%169 = "linalg.fill"(%15, %165) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<512xf32>) -> tensor<512xf32>
%170 = "linalg.generic"(%164, %168, %169) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.subf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.mulf"(%600, %600) : (f32, f32) -> f32
%602 = "arith.addf"(%arg3, %601) : (f32, f32) -> f32
"linalg.yield"(%602) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<512x384xf32>, tensor<512xf32>, tensor<512xf32>) -> tensor<512xf32>
%171 = "linalg.generic"(%164, %168, %170, %27, %26, %163) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%600 = "arith.divf"(%arg3, %122) : (f32, f32) -> f32
%601 = "arith.subf"(%arg1, %arg2) : (f32, f32) -> f32
%602 = "arith.truncf"(%21) : (f64) -> f32
%603 = "arith.addf"(%600, %602) : (f32, f32) -> f32
%604 = "math.rsqrt"(%603) : (f32) -> f32
%605 = "arith.mulf"(%601, %604) : (f32, f32) -> f32
%606 = "arith.mulf"(%605, %arg4) : (f32, f32) -> f32
%607 = "arith.addf"(%606, %arg5) : (f32, f32) -> f32
"linalg.yield"(%607) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"], operand_segment_sizes = dense<[5, 1]> : vector<2xi32>} : (tensor<512x384xf32>, tensor<512xf32>, tensor<512xf32>, tensor<384xf32>, tensor<384xf32>, tensor<512x384xf32>) -> tensor<512x384xf32>
%172 = "flow.tensor.reshape"(%171) {operand_segment_sizes = dense<[1, 0, 0]> : vector<3xi32>} : (tensor<512x384xf32>) -> tensor<1x512x384xf32>
%173 = "linalg.generic"(%28, %163) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384xf32>, tensor<512x384xf32>) -> tensor<512x384xf32>
%174 = "flow.tensor.reshape"(%173) {operand_segment_sizes = dense<[1, 0, 0]> : vector<3xi32>} : (tensor<512x384xf32>) -> tensor<1x512x384xf32>
%175 = "linalg.batch_matmul"(%172, %12, %174) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x384x384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%176 = "linalg.generic"(%29, %163) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384xf32>, tensor<512x384xf32>) -> tensor<512x384xf32>
%177 = "flow.tensor.reshape"(%176) {operand_segment_sizes = dense<[1, 0, 0]> : vector<3xi32>} : (tensor<512x384xf32>) -> tensor<1x512x384xf32>
%178 = "linalg.batch_matmul"(%172, %11, %177) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x384x384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%179 = "flow.tensor.reshape"(%178, %0, %5) {operand_segment_sizes = dense<[1, 0, 2]> : vector<3xi32>} : (tensor<1x512x384xf32>, index, index) -> tensor<?x?x12x32xf32>
%180 = "linalg.generic"(%30, %163) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384xf32>, tensor<512x384xf32>) -> tensor<512x384xf32>
%181 = "flow.tensor.reshape"(%180) {operand_segment_sizes = dense<[1, 0, 0]> : vector<3xi32>} : (tensor<512x384xf32>) -> tensor<1x512x384xf32>
%182 = "linalg.batch_matmul"(%172, %10, %181) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x384x384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%183 = "flow.tensor.reshape"(%182, %0, %5) {operand_segment_sizes = dense<[1, 0, 2]> : vector<3xi32>} : (tensor<1x512x384xf32>, index, index) -> tensor<?x?x12x32xf32>
%184 = "flow.tensor.reshape"(%175, %0, %5) {operand_segment_sizes = dense<[1, 0, 2]> : vector<3xi32>} : (tensor<1x512x384xf32>, index, index) -> tensor<?x?x12x32xf32>
%185 = "linalg.init_tensor"() {static_sizes = [12, 32, 512]} : () -> tensor<12x32x512xf32>
%186 = "linalg.generic"(%179, %185) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d1, d3, d2)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<?x?x12x32xf32>, tensor<12x32x512xf32>) -> tensor<12x32x512xf32>
%187 = "linalg.init_tensor"() {static_sizes = [12, 512, 32]} : () -> tensor<12x512x32xf32>
%188 = "linalg.generic"(%184, %187) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<?x?x12x32xf32>, tensor<12x512x32xf32>) -> tensor<12x512x32xf32>
%189 = "linalg.init_tensor"() {static_sizes = [12, 512, 512]} : () -> tensor<12x512x512xf32>
%190 = "linalg.fill"(%15, %189) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<12x512x512xf32>) -> tensor<12x512x512xf32>
%191 = "linalg.generic"(%188, %186, %190) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%600, %arg3) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<12x512x32xf32>, tensor<12x32x512xf32>, tensor<12x512x512xf32>) -> tensor<12x512x512xf32>
%192 = "arith.cmpi"(%5, %142) {predicate = 0 : i64} : (index, index) -> i1
"std.assert"(%192) {msg = "mismatched size for broadcast"} : (i1) -> ()
%193 = "linalg.generic"(%191, %145, %189) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.truncf"(%9) : (f64) -> f32
%601 = "arith.divf"(%arg1, %600) : (f32, f32) -> f32
%602 = "arith.addf"(%601, %arg2) : (f32, f32) -> f32
"linalg.yield"(%602) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<12x512x512xf32>, tensor<?xf32>, tensor<12x512x512xf32>) -> tensor<12x512x512xf32>
%194 = "linalg.init_tensor"() {static_sizes = [12, 512]} : () -> tensor<12x512xf32>
%195 = "linalg.fill"(%16, %194) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<12x512xf32>) -> tensor<12x512xf32>
%196 = "linalg.init_tensor"() {static_sizes = [12, 512]} : () -> tensor<12x512xi64>
%197 = "linalg.fill"(%4, %196) ({
^bb0(%arg1: i64, %arg2: i64):
"linalg.yield"(%arg1) : (i64) -> ()
}) : (i64, tensor<12x512xi64>) -> tensor<12x512xi64>
%198:2 = "linalg.generic"(%193, %195, %197) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: i64):
%600 = "linalg.index"() {dim = 2 : i64} : () -> index
%601 = "arith.index_cast"(%600) : (index) -> i64
%602 = "arith.cmpf"(%arg1, %arg2) {predicate = 2 : i64} : (f32, f32) -> i1
%603 = "std.select"(%602, %arg1, %arg2) : (i1, f32, f32) -> f32
%604 = "std.select"(%602, %601, %arg3) : (i1, i64, i64) -> i64
"linalg.yield"(%603, %604) : (f32, i64) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], operand_segment_sizes = dense<[1, 2]> : vector<2xi32>} : (tensor<12x512x512xf32>, tensor<12x512xf32>, tensor<12x512xi64>) -> (tensor<12x512xf32>, tensor<12x512xi64>)
%199 = "linalg.generic"(%193, %198#0, %189) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.subf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "math.exp"(%600) : (f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<12x512x512xf32>, tensor<12x512xf32>, tensor<12x512x512xf32>) -> tensor<12x512x512xf32>
%200 = "linalg.fill"(%15, %194) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<12x512xf32>) -> tensor<12x512xf32>
%201 = "linalg.generic"(%199, %200) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "arith.addf"(%arg1, %arg2) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<12x512x512xf32>, tensor<12x512xf32>) -> tensor<12x512xf32>
%202 = "linalg.generic"(%199, %201, %189) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.divf"(%arg1, %arg2) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<12x512x512xf32>, tensor<12x512xf32>, tensor<12x512x512xf32>) -> tensor<12x512x512xf32>
%203 = "linalg.generic"(%183, %187) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<?x?x12x32xf32>, tensor<12x512x32xf32>) -> tensor<12x512x32xf32>
%204 = "linalg.fill"(%15, %187) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<12x512x32xf32>) -> tensor<12x512x32xf32>
%205 = "linalg.generic"(%202, %203, %204) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%600, %arg3) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<12x512x512xf32>, tensor<12x512x32xf32>, tensor<12x512x32xf32>) -> tensor<12x512x32xf32>
%206 = "linalg.init_tensor"() {static_sizes = [512, 12, 32]} : () -> tensor<512x12x32xf32>
%207 = "linalg.generic"(%205, %206) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d1, d0, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<12x512x32xf32>, tensor<512x12x32xf32>) -> tensor<512x12x32xf32>
%208 = "flow.tensor.reshape"(%207, %0, %5) {operand_segment_sizes = dense<[1, 0, 2]> : vector<3xi32>} : (tensor<512x12x32xf32>, index, index) -> tensor<?x?x384xf32>
%209 = "linalg.generic"(%31, %163) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384xf32>, tensor<512x384xf32>) -> tensor<512x384xf32>
%210 = "flow.tensor.reshape"(%209) {operand_segment_sizes = dense<[1, 0, 0]> : vector<3xi32>} : (tensor<512x384xf32>) -> tensor<1x512x384xf32>
%211 = "linalg.batch_matmul"(%208, %8, %210) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<?x?x384xf32>, tensor<1x384x384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%212 = "linalg.init_tensor"() {static_sizes = [1, 512, 384]} : () -> tensor<1x512x384xf32>
%213 = "flow.tensor.reshape"(%211) {operand_segment_sizes = dense<[1, 0, 0]> : vector<3xi32>} : (tensor<1x512x384xf32>) -> tensor<512x384xf32>
%214 = "linalg.generic"(%213, %171, %212) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.addf"(%arg1, %arg2) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (0, d0, d1)>], iterator_types = ["parallel", "parallel"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<512x384xf32>, tensor<512x384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%215 = "linalg.init_tensor"() {static_sizes = [1, 512]} : () -> tensor<1x512xf32>
%216 = "linalg.fill"(%15, %215) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<1x512xf32>) -> tensor<1x512xf32>
%217 = "linalg.generic"(%214, %216) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "arith.addf"(%arg2, %arg1) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x512xf32>) -> tensor<1x512xf32>
%218 = "linalg.generic"(%217, %215) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "arith.divf"(%arg1, %122) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<1x512xf32>, tensor<1x512xf32>) -> tensor<1x512xf32>
%219 = "linalg.fill"(%15, %215) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<1x512xf32>) -> tensor<1x512xf32>
%220 = "linalg.generic"(%214, %218, %219) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.subf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.mulf"(%600, %600) : (f32, f32) -> f32
%602 = "arith.addf"(%arg3, %601) : (f32, f32) -> f32
"linalg.yield"(%602) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x512xf32>, tensor<1x512xf32>) -> tensor<1x512xf32>
%221 = "linalg.generic"(%214, %218, %220, %33, %32, %212) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%600 = "arith.divf"(%arg3, %122) : (f32, f32) -> f32
%601 = "arith.subf"(%arg1, %arg2) : (f32, f32) -> f32
%602 = "arith.truncf"(%21) : (f64) -> f32
%603 = "arith.addf"(%600, %602) : (f32, f32) -> f32
%604 = "math.rsqrt"(%603) : (f32) -> f32
%605 = "arith.mulf"(%601, %604) : (f32, f32) -> f32
%606 = "arith.mulf"(%605, %arg4) : (f32, f32) -> f32
%607 = "arith.addf"(%606, %arg5) : (f32, f32) -> f32
"linalg.yield"(%607) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<[5, 1]> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x512xf32>, tensor<1x512xf32>, tensor<384xf32>, tensor<384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%222 = "linalg.init_tensor"() {static_sizes = [1, 512, 1536]} : () -> tensor<1x512x1536xf32>
%223 = "linalg.init_tensor"() {static_sizes = [1, 384, 1536]} : () -> tensor<1x384x1536xf32>
%224 = "linalg.generic"(%34, %222) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<1536xf32>, tensor<1x512x1536xf32>) -> tensor<1x512x1536xf32>
%225 = "linalg.generic"(%35, %223) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<1536x384xf32>, tensor<1x384x1536xf32>) -> tensor<1x384x1536xf32>
%226 = "linalg.batch_matmul"(%221, %225, %224) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x384x1536xf32>, tensor<1x512x1536xf32>) -> tensor<1x512x1536xf32>
%227 = "linalg.generic"(%226, %222) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "math.sqrt"(%18) : (f32) -> f32
%601 = "arith.divf"(%arg1, %600) : (f32, f32) -> f32
%602 = "math.erf"(%601) : (f32) -> f32
%603 = "arith.addf"(%602, %19) : (f32, f32) -> f32
%604 = "arith.mulf"(%603, %17) : (f32, f32) -> f32
%605 = "arith.mulf"(%arg1, %604) : (f32, f32) -> f32
"linalg.yield"(%605) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<1x512x1536xf32>, tensor<1x512x1536xf32>) -> tensor<1x512x1536xf32>
%228 = "linalg.init_tensor"() {static_sizes = [1, 1536, 384]} : () -> tensor<1x1536x384xf32>
%229 = "linalg.generic"(%36, %212) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%230 = "linalg.generic"(%37, %228) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384x1536xf32>, tensor<1x1536x384xf32>) -> tensor<1x1536x384xf32>
%231 = "linalg.batch_matmul"(%227, %230, %229) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x512x1536xf32>, tensor<1x1536x384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%232 = "linalg.generic"(%231, %221, %212) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.addf"(%arg1, %arg2) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x512x384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%233 = "linalg.fill"(%15, %215) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<1x512xf32>) -> tensor<1x512xf32>
%234 = "linalg.generic"(%232, %233) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "arith.addf"(%arg2, %arg1) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x512xf32>) -> tensor<1x512xf32>
%235 = "linalg.generic"(%234, %215) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "arith.divf"(%arg1, %122) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<1x512xf32>, tensor<1x512xf32>) -> tensor<1x512xf32>
%236 = "linalg.fill"(%15, %215) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<1x512xf32>) -> tensor<1x512xf32>
%237 = "linalg.generic"(%232, %235, %236) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.subf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.mulf"(%600, %600) : (f32, f32) -> f32
%602 = "arith.addf"(%arg3, %601) : (f32, f32) -> f32
"linalg.yield"(%602) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x512xf32>, tensor<1x512xf32>) -> tensor<1x512xf32>
%238 = "linalg.generic"(%232, %235, %237, %39, %38, %212) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%600 = "arith.divf"(%arg3, %122) : (f32, f32) -> f32
%601 = "arith.subf"(%arg1, %arg2) : (f32, f32) -> f32
%602 = "arith.truncf"(%21) : (f64) -> f32
%603 = "arith.addf"(%600, %602) : (f32, f32) -> f32
%604 = "math.rsqrt"(%603) : (f32) -> f32
%605 = "arith.mulf"(%601, %604) : (f32, f32) -> f32
%606 = "arith.mulf"(%605, %arg4) : (f32, f32) -> f32
%607 = "arith.addf"(%606, %arg5) : (f32, f32) -> f32
"linalg.yield"(%607) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<[5, 1]> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x512xf32>, tensor<1x512xf32>, tensor<384xf32>, tensor<384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%239 = "linalg.init_tensor"() {static_sizes = [1, 384, 384]} : () -> tensor<1x384x384xf32>
%240 = "linalg.generic"(%40, %212) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%241 = "linalg.generic"(%41, %239) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384x384xf32>, tensor<1x384x384xf32>) -> tensor<1x384x384xf32>
%242 = "linalg.batch_matmul"(%238, %241, %240) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x384x384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%243 = "tensor.cast"(%242) : (tensor<1x512x384xf32>) -> tensor<?x?x384xf32>
%244 = "linalg.generic"(%42, %212) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%245 = "linalg.generic"(%43, %239) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384x384xf32>, tensor<1x384x384xf32>) -> tensor<1x384x384xf32>
%246 = "linalg.batch_matmul"(%238, %245, %244) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x384x384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%247 = "tensor.cast"(%246) : (tensor<1x512x384xf32>) -> tensor<?x?x384xf32>
%248 = "flow.tensor.reshape"(%247, %0, %5, %0, %5) {operand_segment_sizes = dense<[1, 2, 2]> : vector<3xi32>} : (tensor<?x?x384xf32>, index, index, index, index) -> tensor<?x?x12x32xf32>
%249 = "linalg.init_tensor"() {static_sizes = [1, 12, 512, 32]} : () -> tensor<1x12x512x32xf32>
%250 = "linalg.generic"(%44, %212) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%251 = "linalg.generic"(%45, %239) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384x384xf32>, tensor<1x384x384xf32>) -> tensor<1x384x384xf32>
%252 = "linalg.batch_matmul"(%238, %251, %250) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x384x384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%253 = "tensor.cast"(%252) : (tensor<1x512x384xf32>) -> tensor<?x?x384xf32>
%254 = "flow.tensor.reshape"(%253, %0, %5, %0, %5) {operand_segment_sizes = dense<[1, 2, 2]> : vector<3xi32>} : (tensor<?x?x384xf32>, index, index, index, index) -> tensor<?x?x12x32xf32>
%255 = "linalg.generic"(%254, %249) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<?x?x12x32xf32>, tensor<1x12x512x32xf32>) -> tensor<1x12x512x32xf32>
%256 = "flow.tensor.reshape"(%243, %0, %5, %0, %5) {operand_segment_sizes = dense<[1, 2, 2]> : vector<3xi32>} : (tensor<?x?x384xf32>, index, index, index, index) -> tensor<?x?x12x32xf32>
%257 = "linalg.generic"(%256, %249) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<?x?x12x32xf32>, tensor<1x12x512x32xf32>) -> tensor<1x12x512x32xf32>
%258 = "linalg.init_tensor"() {static_sizes = [1, 12, 32, 512]} : () -> tensor<1x12x32x512xf32>
%259 = "linalg.generic"(%248, %258) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d2)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<?x?x12x32xf32>, tensor<1x12x32x512xf32>) -> tensor<1x12x32x512xf32>
%260 = "linalg.init_tensor"() {static_sizes = [1, 12, 512, 512]} : () -> tensor<1x12x512x512xf32>
%261 = "linalg.fill"(%15, %260) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<1x12x512x512xf32>) -> tensor<1x12x512x512xf32>
%262 = "linalg.generic"(%257, %259, %261) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%600, %arg3) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x12x512x32xf32>, tensor<1x12x32x512xf32>, tensor<1x12x512x512xf32>) -> tensor<1x12x512x512xf32>
"std.assert"(%192) {msg = "mismatched size for broadcast"} : (i1) -> ()
%263 = "flow.tensor.reshape"(%145, %147, %146) {operand_segment_sizes = dense<1> : vector<3xi32>} : (tensor<?xf32>, index, index) -> tensor<1x?xf32>
%264 = "linalg.generic"(%262, %263, %260) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.truncf"(%9) : (f64) -> f32
%601 = "arith.divf"(%arg1, %600) : (f32, f32) -> f32
%602 = "arith.addf"(%601, %arg2) : (f32, f32) -> f32
"linalg.yield"(%602) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x12x512x512xf32>, tensor<1x?xf32>, tensor<1x12x512x512xf32>) -> tensor<1x12x512x512xf32>
%265 = "linalg.init_tensor"() {static_sizes = [1, 12, 512]} : () -> tensor<1x12x512xf32>
%266 = "linalg.fill"(%16, %265) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<1x12x512xf32>) -> tensor<1x12x512xf32>
%267 = "linalg.init_tensor"() {static_sizes = [1, 12, 512]} : () -> tensor<1x12x512xi64>
%268 = "linalg.fill"(%4, %267) ({
^bb0(%arg1: i64, %arg2: i64):
"linalg.yield"(%arg1) : (i64) -> ()
}) : (i64, tensor<1x12x512xi64>) -> tensor<1x12x512xi64>
%269:2 = "linalg.generic"(%264, %266, %268) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: i64):
%600 = "linalg.index"() {dim = 3 : i64} : () -> index
%601 = "arith.index_cast"(%600) : (index) -> i64
%602 = "arith.cmpf"(%arg1, %arg2) {predicate = 2 : i64} : (f32, f32) -> i1
%603 = "std.select"(%602, %arg1, %arg2) : (i1, f32, f32) -> f32
%604 = "std.select"(%602, %601, %arg3) : (i1, i64, i64) -> i64
"linalg.yield"(%603, %604) : (f32, i64) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"], operand_segment_sizes = dense<[1, 2]> : vector<2xi32>} : (tensor<1x12x512x512xf32>, tensor<1x12x512xf32>, tensor<1x12x512xi64>) -> (tensor<1x12x512xf32>, tensor<1x12x512xi64>)
%270 = "linalg.generic"(%264, %269#0, %260) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.subf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "math.exp"(%600) : (f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x12x512x512xf32>, tensor<1x12x512xf32>, tensor<1x12x512x512xf32>) -> tensor<1x12x512x512xf32>
%271 = "linalg.fill"(%15, %265) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<1x12x512xf32>) -> tensor<1x12x512xf32>
%272 = "linalg.generic"(%270, %271) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "arith.addf"(%arg1, %arg2) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<1x12x512x512xf32>, tensor<1x12x512xf32>) -> tensor<1x12x512xf32>
%273 = "linalg.generic"(%270, %272, %260) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.divf"(%arg1, %arg2) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x12x512x512xf32>, tensor<1x12x512xf32>, tensor<1x12x512x512xf32>) -> tensor<1x12x512x512xf32>
%274 = "linalg.fill"(%15, %249) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<1x12x512x32xf32>) -> tensor<1x12x512x32xf32>
%275 = "linalg.generic"(%273, %255, %274) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%600, %arg3) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x12x512x512xf32>, tensor<1x12x512x32xf32>, tensor<1x12x512x32xf32>) -> tensor<1x12x512x32xf32>
%276 = "linalg.init_tensor"() {static_sizes = [1, 512, 12, 32]} : () -> tensor<1x512x12x32xf32>
%277 = "linalg.generic"(%275, %276) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<1x12x512x32xf32>, tensor<1x512x12x32xf32>) -> tensor<1x512x12x32xf32>
%278 = "tensor.cast"(%277) : (tensor<1x512x12x32xf32>) -> tensor<?x?x12x32xf32>
%279 = "flow.tensor.reshape"(%278, %0, %5, %0, %5) {operand_segment_sizes = dense<[1, 2, 2]> : vector<3xi32>} : (tensor<?x?x12x32xf32>, index, index, index, index) -> tensor<?x?x384xf32>
%280 = "linalg.generic"(%46, %212) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%281 = "linalg.generic"(%47, %239) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384x384xf32>, tensor<1x384x384xf32>) -> tensor<1x384x384xf32>
%282 = "linalg.batch_matmul"(%279, %281, %280) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<?x?x384xf32>, tensor<1x384x384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
"std.assert"(%1) {msg = "mismatched size for broadcast"} : (i1) -> ()
%283 = "linalg.init_tensor"() {static_sizes = [0, 512, 384]} : () -> tensor<0x512x384xf32>
%284 = "linalg.generic"(%282, %238, %283) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.addf"(%arg1, %arg2) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x512x384xf32>, tensor<0x512x384xf32>) -> tensor<0x512x384xf32>
%285 = "linalg.init_tensor"() {static_sizes = [0, 512]} : () -> tensor<0x512xf32>
%286 = "linalg.fill"(%15, %285) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<0x512xf32>) -> tensor<0x512xf32>
%287 = "linalg.generic"(%284, %286) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "arith.addf"(%arg2, %arg1) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<0x512x384xf32>, tensor<0x512xf32>) -> tensor<0x512xf32>
%288 = "linalg.generic"(%287, %285) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "arith.divf"(%arg1, %122) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<0x512xf32>, tensor<0x512xf32>) -> tensor<0x512xf32>
%289 = "linalg.fill"(%15, %285) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<0x512xf32>) -> tensor<0x512xf32>
%290 = "linalg.generic"(%284, %288, %289) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.subf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.mulf"(%600, %600) : (f32, f32) -> f32
%602 = "arith.addf"(%arg3, %601) : (f32, f32) -> f32
"linalg.yield"(%602) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<0x512x384xf32>, tensor<0x512xf32>, tensor<0x512xf32>) -> tensor<0x512xf32>
%291 = "linalg.generic"(%284, %288, %290, %49, %48, %283) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%600 = "arith.divf"(%arg3, %122) : (f32, f32) -> f32
%601 = "arith.subf"(%arg1, %arg2) : (f32, f32) -> f32
%602 = "arith.truncf"(%21) : (f64) -> f32
%603 = "arith.addf"(%600, %602) : (f32, f32) -> f32
%604 = "math.rsqrt"(%603) : (f32) -> f32
%605 = "arith.mulf"(%601, %604) : (f32, f32) -> f32
%606 = "arith.mulf"(%605, %arg4) : (f32, f32) -> f32
%607 = "arith.addf"(%606, %arg5) : (f32, f32) -> f32
"linalg.yield"(%607) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<[5, 1]> : vector<2xi32>} : (tensor<0x512x384xf32>, tensor<0x512xf32>, tensor<0x512xf32>, tensor<384xf32>, tensor<384xf32>, tensor<0x512x384xf32>) -> tensor<0x512x384xf32>
%292 = "linalg.init_tensor"() {static_sizes = [0, 512, 1536]} : () -> tensor<0x512x1536xf32>
%293 = "linalg.init_tensor"() {static_sizes = [0, 384, 1536]} : () -> tensor<0x384x1536xf32>
%294 = "linalg.generic"(%50, %292) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<1536xf32>, tensor<0x512x1536xf32>) -> tensor<0x512x1536xf32>
%295 = "linalg.generic"(%51, %293) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<1536x384xf32>, tensor<0x384x1536xf32>) -> tensor<0x384x1536xf32>
%296 = "linalg.batch_matmul"(%291, %295, %294) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<0x512x384xf32>, tensor<0x384x1536xf32>, tensor<0x512x1536xf32>) -> tensor<0x512x1536xf32>
%297 = "linalg.generic"(%296, %292) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "math.sqrt"(%18) : (f32) -> f32
%601 = "arith.divf"(%arg1, %600) : (f32, f32) -> f32
%602 = "math.erf"(%601) : (f32) -> f32
%603 = "arith.addf"(%602, %19) : (f32, f32) -> f32
%604 = "arith.mulf"(%603, %17) : (f32, f32) -> f32
%605 = "arith.mulf"(%arg1, %604) : (f32, f32) -> f32
"linalg.yield"(%605) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<0x512x1536xf32>, tensor<0x512x1536xf32>) -> tensor<0x512x1536xf32>
%298 = "linalg.init_tensor"() {static_sizes = [0, 1536, 384]} : () -> tensor<0x1536x384xf32>
%299 = "linalg.generic"(%52, %283) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384xf32>, tensor<0x512x384xf32>) -> tensor<0x512x384xf32>
%300 = "linalg.generic"(%53, %298) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384x1536xf32>, tensor<0x1536x384xf32>) -> tensor<0x1536x384xf32>
%301 = "linalg.batch_matmul"(%297, %300, %299) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<0x512x1536xf32>, tensor<0x1536x384xf32>, tensor<0x512x384xf32>) -> tensor<0x512x384xf32>
%302 = "linalg.generic"(%301, %291, %283) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.addf"(%arg1, %arg2) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<0x512x384xf32>, tensor<0x512x384xf32>, tensor<0x512x384xf32>) -> tensor<0x512x384xf32>
%303 = "linalg.fill"(%15, %285) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<0x512xf32>) -> tensor<0x512xf32>
%304 = "linalg.generic"(%302, %303) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "arith.addf"(%arg2, %arg1) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<0x512x384xf32>, tensor<0x512xf32>) -> tensor<0x512xf32>
%305 = "linalg.generic"(%304, %285) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "arith.divf"(%arg1, %122) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<0x512xf32>, tensor<0x512xf32>) -> tensor<0x512xf32>
%306 = "linalg.fill"(%15, %285) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<0x512xf32>) -> tensor<0x512xf32>
%307 = "linalg.generic"(%302, %305, %306) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.subf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.mulf"(%600, %600) : (f32, f32) -> f32
%602 = "arith.addf"(%arg3, %601) : (f32, f32) -> f32
"linalg.yield"(%602) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<0x512x384xf32>, tensor<0x512xf32>, tensor<0x512xf32>) -> tensor<0x512xf32>
%308 = "linalg.generic"(%302, %305, %307, %55, %54, %283) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%600 = "arith.divf"(%arg3, %122) : (f32, f32) -> f32
%601 = "arith.subf"(%arg1, %arg2) : (f32, f32) -> f32
%602 = "arith.truncf"(%21) : (f64) -> f32
%603 = "arith.addf"(%600, %602) : (f32, f32) -> f32
%604 = "math.rsqrt"(%603) : (f32) -> f32
%605 = "arith.mulf"(%601, %604) : (f32, f32) -> f32
%606 = "arith.mulf"(%605, %arg4) : (f32, f32) -> f32
%607 = "arith.addf"(%606, %arg5) : (f32, f32) -> f32
"linalg.yield"(%607) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<[5, 1]> : vector<2xi32>} : (tensor<0x512x384xf32>, tensor<0x512xf32>, tensor<0x512xf32>, tensor<384xf32>, tensor<384xf32>, tensor<0x512x384xf32>) -> tensor<0x512x384xf32>
%309 = "linalg.init_tensor"() {static_sizes = [0, 384, 384]} : () -> tensor<0x384x384xf32>
%310 = "linalg.generic"(%56, %283) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384xf32>, tensor<0x512x384xf32>) -> tensor<0x512x384xf32>
%311 = "linalg.generic"(%57, %309) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384x384xf32>, tensor<0x384x384xf32>) -> tensor<0x384x384xf32>
%312 = "linalg.batch_matmul"(%308, %311, %310) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<0x512x384xf32>, tensor<0x384x384xf32>, tensor<0x512x384xf32>) -> tensor<0x512x384xf32>
%313 = "tensor.cast"(%312) : (tensor<0x512x384xf32>) -> tensor<?x?x384xf32>
%314 = "linalg.generic"(%58, %283) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384xf32>, tensor<0x512x384xf32>) -> tensor<0x512x384xf32>
%315 = "linalg.generic"(%59, %309) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384x384xf32>, tensor<0x384x384xf32>) -> tensor<0x384x384xf32>
%316 = "linalg.batch_matmul"(%308, %315, %314) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<0x512x384xf32>, tensor<0x384x384xf32>, tensor<0x512x384xf32>) -> tensor<0x512x384xf32>
%317 = "tensor.cast"(%316) : (tensor<0x512x384xf32>) -> tensor<?x?x384xf32>
%318 = "flow.tensor.reshape"(%317, %2, %5, %2, %5) {operand_segment_sizes = dense<[1, 2, 2]> : vector<3xi32>} : (tensor<?x?x384xf32>, index, index, index, index) -> tensor<?x?x12x32xf32>
%319 = "linalg.init_tensor"() {static_sizes = [0, 12, 512, 32]} : () -> tensor<0x12x512x32xf32>
%320 = "linalg.generic"(%60, %283) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384xf32>, tensor<0x512x384xf32>) -> tensor<0x512x384xf32>
%321 = "linalg.generic"(%61, %309) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384x384xf32>, tensor<0x384x384xf32>) -> tensor<0x384x384xf32>
%322 = "linalg.batch_matmul"(%308, %321, %320) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<0x512x384xf32>, tensor<0x384x384xf32>, tensor<0x512x384xf32>) -> tensor<0x512x384xf32>
%323 = "tensor.cast"(%322) : (tensor<0x512x384xf32>) -> tensor<?x?x384xf32>
%324 = "flow.tensor.reshape"(%323, %2, %5, %2, %5) {operand_segment_sizes = dense<[1, 2, 2]> : vector<3xi32>} : (tensor<?x?x384xf32>, index, index, index, index) -> tensor<?x?x12x32xf32>
%325 = "linalg.generic"(%324, %319) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<?x?x12x32xf32>, tensor<0x12x512x32xf32>) -> tensor<0x12x512x32xf32>
%326 = "flow.tensor.reshape"(%313, %2, %5, %2, %5) {operand_segment_sizes = dense<[1, 2, 2]> : vector<3xi32>} : (tensor<?x?x384xf32>, index, index, index, index) -> tensor<?x?x12x32xf32>
%327 = "linalg.generic"(%326, %319) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<?x?x12x32xf32>, tensor<0x12x512x32xf32>) -> tensor<0x12x512x32xf32>
%328 = "linalg.init_tensor"() {static_sizes = [0, 12, 32, 512]} : () -> tensor<0x12x32x512xf32>
%329 = "linalg.generic"(%318, %328) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d2)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<?x?x12x32xf32>, tensor<0x12x32x512xf32>) -> tensor<0x12x32x512xf32>
%330 = "linalg.init_tensor"() {static_sizes = [0, 12, 512, 512]} : () -> tensor<0x12x512x512xf32>
%331 = "linalg.fill"(%15, %330) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<0x12x512x512xf32>) -> tensor<0x12x512x512xf32>
%332 = "linalg.generic"(%327, %329, %331) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%600, %arg3) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<0x12x512x32xf32>, tensor<0x12x32x512xf32>, tensor<0x12x512x512xf32>) -> tensor<0x12x512x512xf32>
"std.assert"(%1) {msg = "mismatched size for broadcast"} : (i1) -> ()
"std.assert"(%192) {msg = "mismatched size for broadcast"} : (i1) -> ()
%333 = "linalg.generic"(%332, %263, %330) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.truncf"(%9) : (f64) -> f32
%601 = "arith.divf"(%arg1, %600) : (f32, f32) -> f32
%602 = "arith.addf"(%601, %arg2) : (f32, f32) -> f32
"linalg.yield"(%602) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<0x12x512x512xf32>, tensor<1x?xf32>, tensor<0x12x512x512xf32>) -> tensor<0x12x512x512xf32>
%334 = "linalg.init_tensor"() {static_sizes = [0, 12, 512]} : () -> tensor<0x12x512xf32>
%335 = "linalg.fill"(%16, %334) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<0x12x512xf32>) -> tensor<0x12x512xf32>
%336 = "linalg.init_tensor"() {static_sizes = [0, 12, 512]} : () -> tensor<0x12x512xi64>
%337 = "linalg.fill"(%4, %336) ({
^bb0(%arg1: i64, %arg2: i64):
"linalg.yield"(%arg1) : (i64) -> ()
}) : (i64, tensor<0x12x512xi64>) -> tensor<0x12x512xi64>
%338:2 = "linalg.generic"(%333, %335, %337) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: i64):
%600 = "linalg.index"() {dim = 3 : i64} : () -> index
%601 = "arith.index_cast"(%600) : (index) -> i64
%602 = "arith.cmpf"(%arg1, %arg2) {predicate = 2 : i64} : (f32, f32) -> i1
%603 = "std.select"(%602, %arg1, %arg2) : (i1, f32, f32) -> f32
%604 = "std.select"(%602, %601, %arg3) : (i1, i64, i64) -> i64
"linalg.yield"(%603, %604) : (f32, i64) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"], operand_segment_sizes = dense<[1, 2]> : vector<2xi32>} : (tensor<0x12x512x512xf32>, tensor<0x12x512xf32>, tensor<0x12x512xi64>) -> (tensor<0x12x512xf32>, tensor<0x12x512xi64>)
%339 = "linalg.generic"(%333, %338#0, %330) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.subf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "math.exp"(%600) : (f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<0x12x512x512xf32>, tensor<0x12x512xf32>, tensor<0x12x512x512xf32>) -> tensor<0x12x512x512xf32>
%340 = "linalg.fill"(%15, %334) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<0x12x512xf32>) -> tensor<0x12x512xf32>
%341 = "linalg.generic"(%339, %340) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "arith.addf"(%arg1, %arg2) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<0x12x512x512xf32>, tensor<0x12x512xf32>) -> tensor<0x12x512xf32>
%342 = "linalg.generic"(%339, %341, %330) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.divf"(%arg1, %arg2) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<0x12x512x512xf32>, tensor<0x12x512xf32>, tensor<0x12x512x512xf32>) -> tensor<0x12x512x512xf32>
%343 = "linalg.fill"(%15, %319) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<0x12x512x32xf32>) -> tensor<0x12x512x32xf32>
%344 = "linalg.generic"(%342, %325, %343) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%600, %arg3) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<0x12x512x512xf32>, tensor<0x12x512x32xf32>, tensor<0x12x512x32xf32>) -> tensor<0x12x512x32xf32>
%345 = "linalg.init_tensor"() {static_sizes = [0, 512, 12, 32]} : () -> tensor<0x512x12x32xf32>
%346 = "linalg.generic"(%344, %345) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<0x12x512x32xf32>, tensor<0x512x12x32xf32>) -> tensor<0x512x12x32xf32>
%347 = "tensor.cast"(%346) : (tensor<0x512x12x32xf32>) -> tensor<?x?x12x32xf32>
%348 = "flow.tensor.reshape"(%347, %2, %5, %2, %5) {operand_segment_sizes = dense<[1, 2, 2]> : vector<3xi32>} : (tensor<?x?x12x32xf32>, index, index, index, index) -> tensor<?x?x384xf32>
%349 = "linalg.generic"(%62, %283) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384xf32>, tensor<0x512x384xf32>) -> tensor<0x512x384xf32>
%350 = "linalg.generic"(%63, %309) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384x384xf32>, tensor<0x384x384xf32>) -> tensor<0x384x384xf32>
%351 = "linalg.batch_matmul"(%348, %350, %349) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<?x?x384xf32>, tensor<0x384x384xf32>, tensor<0x512x384xf32>) -> tensor<0x512x384xf32>
"std.assert"(%1) {msg = "mismatched size for broadcast"} : (i1) -> ()
%352 = "linalg.init_tensor"() {static_sizes = [1, 512, 384]} : () -> tensor<1x512x384xf32>
%353 = "linalg.generic"(%351, %308, %352) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.addf"(%arg1, %arg2) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<0x512x384xf32>, tensor<0x512x384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%354 = "linalg.init_tensor"() {static_sizes = [1, 512]} : () -> tensor<1x512xf32>
%355 = "linalg.fill"(%15, %354) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<1x512xf32>) -> tensor<1x512xf32>
%356 = "linalg.generic"(%353, %355) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "arith.addf"(%arg2, %arg1) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x512xf32>) -> tensor<1x512xf32>
%357 = "linalg.generic"(%356, %354) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "arith.divf"(%arg1, %122) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<1x512xf32>, tensor<1x512xf32>) -> tensor<1x512xf32>
%358 = "linalg.fill"(%15, %354) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<1x512xf32>) -> tensor<1x512xf32>
%359 = "linalg.generic"(%353, %357, %358) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.subf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.mulf"(%600, %600) : (f32, f32) -> f32
%602 = "arith.addf"(%arg3, %601) : (f32, f32) -> f32
"linalg.yield"(%602) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x512xf32>, tensor<1x512xf32>) -> tensor<1x512xf32>
%360 = "linalg.generic"(%353, %357, %359, %65, %64, %352) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%600 = "arith.divf"(%arg3, %122) : (f32, f32) -> f32
%601 = "arith.subf"(%arg1, %arg2) : (f32, f32) -> f32
%602 = "arith.truncf"(%21) : (f64) -> f32
%603 = "arith.addf"(%600, %602) : (f32, f32) -> f32
%604 = "math.rsqrt"(%603) : (f32) -> f32
%605 = "arith.mulf"(%601, %604) : (f32, f32) -> f32
%606 = "arith.mulf"(%605, %arg4) : (f32, f32) -> f32
%607 = "arith.addf"(%606, %arg5) : (f32, f32) -> f32
"linalg.yield"(%607) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<[5, 1]> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x512xf32>, tensor<1x512xf32>, tensor<384xf32>, tensor<384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%361 = "linalg.init_tensor"() {static_sizes = [1, 512, 1536]} : () -> tensor<1x512x1536xf32>
%362 = "linalg.init_tensor"() {static_sizes = [1, 384, 1536]} : () -> tensor<1x384x1536xf32>
%363 = "linalg.generic"(%66, %361) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<1536xf32>, tensor<1x512x1536xf32>) -> tensor<1x512x1536xf32>
%364 = "linalg.generic"(%67, %362) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<1536x384xf32>, tensor<1x384x1536xf32>) -> tensor<1x384x1536xf32>
%365 = "linalg.batch_matmul"(%360, %364, %363) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x384x1536xf32>, tensor<1x512x1536xf32>) -> tensor<1x512x1536xf32>
%366 = "linalg.generic"(%365, %361) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "math.sqrt"(%18) : (f32) -> f32
%601 = "arith.divf"(%arg1, %600) : (f32, f32) -> f32
%602 = "math.erf"(%601) : (f32) -> f32
%603 = "arith.addf"(%602, %19) : (f32, f32) -> f32
%604 = "arith.mulf"(%603, %17) : (f32, f32) -> f32
%605 = "arith.mulf"(%arg1, %604) : (f32, f32) -> f32
"linalg.yield"(%605) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<1x512x1536xf32>, tensor<1x512x1536xf32>) -> tensor<1x512x1536xf32>
%367 = "linalg.init_tensor"() {static_sizes = [1, 1536, 384]} : () -> tensor<1x1536x384xf32>
%368 = "linalg.generic"(%68, %352) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%369 = "linalg.generic"(%69, %367) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384x1536xf32>, tensor<1x1536x384xf32>) -> tensor<1x1536x384xf32>
%370 = "linalg.batch_matmul"(%366, %369, %368) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x512x1536xf32>, tensor<1x1536x384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%371 = "linalg.generic"(%370, %360, %352) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.addf"(%arg1, %arg2) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x512x384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%372 = "linalg.fill"(%15, %354) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<1x512xf32>) -> tensor<1x512xf32>
%373 = "linalg.generic"(%371, %372) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "arith.addf"(%arg2, %arg1) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x512xf32>) -> tensor<1x512xf32>
%374 = "linalg.generic"(%373, %354) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "arith.divf"(%arg1, %122) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<1x512xf32>, tensor<1x512xf32>) -> tensor<1x512xf32>
%375 = "linalg.fill"(%15, %354) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<1x512xf32>) -> tensor<1x512xf32>
%376 = "linalg.generic"(%371, %374, %375) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.subf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.mulf"(%600, %600) : (f32, f32) -> f32
%602 = "arith.addf"(%arg3, %601) : (f32, f32) -> f32
"linalg.yield"(%602) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x512xf32>, tensor<1x512xf32>) -> tensor<1x512xf32>
%377 = "linalg.generic"(%371, %374, %376, %71, %70, %352) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%600 = "arith.divf"(%arg3, %122) : (f32, f32) -> f32
%601 = "arith.subf"(%arg1, %arg2) : (f32, f32) -> f32
%602 = "arith.truncf"(%21) : (f64) -> f32
%603 = "arith.addf"(%600, %602) : (f32, f32) -> f32
%604 = "math.rsqrt"(%603) : (f32) -> f32
%605 = "arith.mulf"(%601, %604) : (f32, f32) -> f32
%606 = "arith.mulf"(%605, %arg4) : (f32, f32) -> f32
%607 = "arith.addf"(%606, %arg5) : (f32, f32) -> f32
"linalg.yield"(%607) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<[5, 1]> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x512xf32>, tensor<1x512xf32>, tensor<384xf32>, tensor<384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%378 = "linalg.init_tensor"() {static_sizes = [1, 384, 384]} : () -> tensor<1x384x384xf32>
%379 = "linalg.generic"(%72, %352) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%380 = "linalg.generic"(%73, %378) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384x384xf32>, tensor<1x384x384xf32>) -> tensor<1x384x384xf32>
%381 = "linalg.batch_matmul"(%377, %380, %379) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x384x384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%382 = "tensor.cast"(%381) : (tensor<1x512x384xf32>) -> tensor<?x?x384xf32>
%383 = "linalg.generic"(%74, %352) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%384 = "linalg.generic"(%75, %378) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384x384xf32>, tensor<1x384x384xf32>) -> tensor<1x384x384xf32>
%385 = "linalg.batch_matmul"(%377, %384, %383) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x384x384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%386 = "tensor.cast"(%385) : (tensor<1x512x384xf32>) -> tensor<?x?x384xf32>
%387 = "flow.tensor.reshape"(%386, %0, %5, %0, %5) {operand_segment_sizes = dense<[1, 2, 2]> : vector<3xi32>} : (tensor<?x?x384xf32>, index, index, index, index) -> tensor<?x?x12x32xf32>
%388 = "linalg.init_tensor"() {static_sizes = [1, 12, 512, 32]} : () -> tensor<1x12x512x32xf32>
%389 = "linalg.generic"(%76, %352) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%390 = "linalg.generic"(%77, %378) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384x384xf32>, tensor<1x384x384xf32>) -> tensor<1x384x384xf32>
%391 = "linalg.batch_matmul"(%377, %390, %389) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x384x384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%392 = "tensor.cast"(%391) : (tensor<1x512x384xf32>) -> tensor<?x?x384xf32>
%393 = "flow.tensor.reshape"(%392, %0, %5, %0, %5) {operand_segment_sizes = dense<[1, 2, 2]> : vector<3xi32>} : (tensor<?x?x384xf32>, index, index, index, index) -> tensor<?x?x12x32xf32>
%394 = "linalg.generic"(%393, %388) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<?x?x12x32xf32>, tensor<1x12x512x32xf32>) -> tensor<1x12x512x32xf32>
%395 = "flow.tensor.reshape"(%382, %0, %5, %0, %5) {operand_segment_sizes = dense<[1, 2, 2]> : vector<3xi32>} : (tensor<?x?x384xf32>, index, index, index, index) -> tensor<?x?x12x32xf32>
%396 = "linalg.generic"(%395, %388) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<?x?x12x32xf32>, tensor<1x12x512x32xf32>) -> tensor<1x12x512x32xf32>
%397 = "linalg.init_tensor"() {static_sizes = [1, 12, 32, 512]} : () -> tensor<1x12x32x512xf32>
%398 = "linalg.generic"(%387, %397) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d2)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<?x?x12x32xf32>, tensor<1x12x32x512xf32>) -> tensor<1x12x32x512xf32>
%399 = "linalg.init_tensor"() {static_sizes = [1, 12, 512, 512]} : () -> tensor<1x12x512x512xf32>
%400 = "linalg.fill"(%15, %399) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<1x12x512x512xf32>) -> tensor<1x12x512x512xf32>
%401 = "linalg.generic"(%396, %398, %400) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%600, %arg3) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x12x512x32xf32>, tensor<1x12x32x512xf32>, tensor<1x12x512x512xf32>) -> tensor<1x12x512x512xf32>
%402 = "arith.cmpi"(%5, %142) {predicate = 0 : i64} : (index, index) -> i1
"std.assert"(%402) {msg = "mismatched size for broadcast"} : (i1) -> ()
%403 = "linalg.generic"(%401, %263, %399) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.truncf"(%9) : (f64) -> f32
%601 = "arith.divf"(%arg1, %600) : (f32, f32) -> f32
%602 = "arith.addf"(%601, %arg2) : (f32, f32) -> f32
"linalg.yield"(%602) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x12x512x512xf32>, tensor<1x?xf32>, tensor<1x12x512x512xf32>) -> tensor<1x12x512x512xf32>
%404 = "linalg.init_tensor"() {static_sizes = [1, 12, 512]} : () -> tensor<1x12x512xf32>
%405 = "linalg.fill"(%16, %404) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<1x12x512xf32>) -> tensor<1x12x512xf32>
%406 = "linalg.init_tensor"() {static_sizes = [1, 12, 512]} : () -> tensor<1x12x512xi64>
%407 = "linalg.fill"(%4, %406) ({
^bb0(%arg1: i64, %arg2: i64):
"linalg.yield"(%arg1) : (i64) -> ()
}) : (i64, tensor<1x12x512xi64>) -> tensor<1x12x512xi64>
%408:2 = "linalg.generic"(%403, %405, %407) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: i64):
%600 = "linalg.index"() {dim = 3 : i64} : () -> index
%601 = "arith.index_cast"(%600) : (index) -> i64
%602 = "arith.cmpf"(%arg1, %arg2) {predicate = 2 : i64} : (f32, f32) -> i1
%603 = "std.select"(%602, %arg1, %arg2) : (i1, f32, f32) -> f32
%604 = "std.select"(%602, %601, %arg3) : (i1, i64, i64) -> i64
"linalg.yield"(%603, %604) : (f32, i64) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"], operand_segment_sizes = dense<[1, 2]> : vector<2xi32>} : (tensor<1x12x512x512xf32>, tensor<1x12x512xf32>, tensor<1x12x512xi64>) -> (tensor<1x12x512xf32>, tensor<1x12x512xi64>)
%409 = "linalg.generic"(%403, %408#0, %399) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.subf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "math.exp"(%600) : (f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x12x512x512xf32>, tensor<1x12x512xf32>, tensor<1x12x512x512xf32>) -> tensor<1x12x512x512xf32>
%410 = "linalg.fill"(%15, %404) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<1x12x512xf32>) -> tensor<1x12x512xf32>
%411 = "linalg.generic"(%409, %410) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "arith.addf"(%arg1, %arg2) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<1x12x512x512xf32>, tensor<1x12x512xf32>) -> tensor<1x12x512xf32>
%412 = "linalg.generic"(%409, %411, %399) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.divf"(%arg1, %arg2) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x12x512x512xf32>, tensor<1x12x512xf32>, tensor<1x12x512x512xf32>) -> tensor<1x12x512x512xf32>
%413 = "linalg.fill"(%15, %388) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<1x12x512x32xf32>) -> tensor<1x12x512x32xf32>
%414 = "linalg.generic"(%412, %394, %413) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%600, %arg3) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x12x512x512xf32>, tensor<1x12x512x32xf32>, tensor<1x12x512x32xf32>) -> tensor<1x12x512x32xf32>
%415 = "linalg.init_tensor"() {static_sizes = [1, 512, 12, 32]} : () -> tensor<1x512x12x32xf32>
%416 = "linalg.generic"(%414, %415) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<1x12x512x32xf32>, tensor<1x512x12x32xf32>) -> tensor<1x512x12x32xf32>
%417 = "tensor.cast"(%416) : (tensor<1x512x12x32xf32>) -> tensor<?x?x12x32xf32>
%418 = "flow.tensor.reshape"(%417, %0, %5, %0, %5) {operand_segment_sizes = dense<[1, 2, 2]> : vector<3xi32>} : (tensor<?x?x12x32xf32>, index, index, index, index) -> tensor<?x?x384xf32>
%419 = "linalg.generic"(%78, %352) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%420 = "linalg.generic"(%79, %378) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384x384xf32>, tensor<1x384x384xf32>) -> tensor<1x384x384xf32>
%421 = "linalg.batch_matmul"(%418, %420, %419) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<?x?x384xf32>, tensor<1x384x384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
"std.assert"(%1) {msg = "mismatched size for broadcast"} : (i1) -> ()
%422 = "linalg.init_tensor"() {static_sizes = [0, 512, 384]} : () -> tensor<0x512x384xf32>
%423 = "linalg.generic"(%421, %377, %422) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.addf"(%arg1, %arg2) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x512x384xf32>, tensor<0x512x384xf32>) -> tensor<0x512x384xf32>
%424 = "linalg.init_tensor"() {static_sizes = [0, 512]} : () -> tensor<0x512xf32>
%425 = "linalg.fill"(%15, %424) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<0x512xf32>) -> tensor<0x512xf32>
%426 = "linalg.generic"(%423, %425) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "arith.addf"(%arg2, %arg1) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<0x512x384xf32>, tensor<0x512xf32>) -> tensor<0x512xf32>
%427 = "linalg.generic"(%426, %424) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "arith.divf"(%arg1, %122) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<0x512xf32>, tensor<0x512xf32>) -> tensor<0x512xf32>
%428 = "linalg.fill"(%15, %424) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<0x512xf32>) -> tensor<0x512xf32>
%429 = "linalg.generic"(%423, %427, %428) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.subf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.mulf"(%600, %600) : (f32, f32) -> f32
%602 = "arith.addf"(%arg3, %601) : (f32, f32) -> f32
"linalg.yield"(%602) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<0x512x384xf32>, tensor<0x512xf32>, tensor<0x512xf32>) -> tensor<0x512xf32>
%430 = "linalg.generic"(%423, %427, %429, %81, %80, %422) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%600 = "arith.divf"(%arg3, %122) : (f32, f32) -> f32
%601 = "arith.subf"(%arg1, %arg2) : (f32, f32) -> f32
%602 = "arith.truncf"(%21) : (f64) -> f32
%603 = "arith.addf"(%600, %602) : (f32, f32) -> f32
%604 = "math.rsqrt"(%603) : (f32) -> f32
%605 = "arith.mulf"(%601, %604) : (f32, f32) -> f32
%606 = "arith.mulf"(%605, %arg4) : (f32, f32) -> f32
%607 = "arith.addf"(%606, %arg5) : (f32, f32) -> f32
"linalg.yield"(%607) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<[5, 1]> : vector<2xi32>} : (tensor<0x512x384xf32>, tensor<0x512xf32>, tensor<0x512xf32>, tensor<384xf32>, tensor<384xf32>, tensor<0x512x384xf32>) -> tensor<0x512x384xf32>
%431 = "linalg.init_tensor"() {static_sizes = [0, 512, 1536]} : () -> tensor<0x512x1536xf32>
%432 = "linalg.init_tensor"() {static_sizes = [0, 384, 1536]} : () -> tensor<0x384x1536xf32>
%433 = "linalg.generic"(%82, %431) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<1536xf32>, tensor<0x512x1536xf32>) -> tensor<0x512x1536xf32>
%434 = "linalg.generic"(%83, %432) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<1536x384xf32>, tensor<0x384x1536xf32>) -> tensor<0x384x1536xf32>
%435 = "linalg.batch_matmul"(%430, %434, %433) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<0x512x384xf32>, tensor<0x384x1536xf32>, tensor<0x512x1536xf32>) -> tensor<0x512x1536xf32>
%436 = "linalg.generic"(%435, %431) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "math.sqrt"(%18) : (f32) -> f32
%601 = "arith.divf"(%arg1, %600) : (f32, f32) -> f32
%602 = "math.erf"(%601) : (f32) -> f32
%603 = "arith.addf"(%602, %19) : (f32, f32) -> f32
%604 = "arith.mulf"(%603, %17) : (f32, f32) -> f32
%605 = "arith.mulf"(%arg1, %604) : (f32, f32) -> f32
"linalg.yield"(%605) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<0x512x1536xf32>, tensor<0x512x1536xf32>) -> tensor<0x512x1536xf32>
%437 = "linalg.init_tensor"() {static_sizes = [0, 1536, 384]} : () -> tensor<0x1536x384xf32>
%438 = "linalg.generic"(%84, %422) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384xf32>, tensor<0x512x384xf32>) -> tensor<0x512x384xf32>
%439 = "linalg.generic"(%85, %437) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384x1536xf32>, tensor<0x1536x384xf32>) -> tensor<0x1536x384xf32>
%440 = "linalg.batch_matmul"(%436, %439, %438) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<0x512x1536xf32>, tensor<0x1536x384xf32>, tensor<0x512x384xf32>) -> tensor<0x512x384xf32>
%441 = "linalg.generic"(%440, %430, %422) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.addf"(%arg1, %arg2) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<0x512x384xf32>, tensor<0x512x384xf32>, tensor<0x512x384xf32>) -> tensor<0x512x384xf32>
%442 = "linalg.fill"(%15, %424) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<0x512xf32>) -> tensor<0x512xf32>
%443 = "linalg.generic"(%441, %442) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "arith.addf"(%arg2, %arg1) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<0x512x384xf32>, tensor<0x512xf32>) -> tensor<0x512xf32>
%444 = "linalg.generic"(%443, %424) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "arith.divf"(%arg1, %122) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<0x512xf32>, tensor<0x512xf32>) -> tensor<0x512xf32>
%445 = "linalg.fill"(%15, %424) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<0x512xf32>) -> tensor<0x512xf32>
%446 = "linalg.generic"(%441, %444, %445) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.subf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.mulf"(%600, %600) : (f32, f32) -> f32
%602 = "arith.addf"(%arg3, %601) : (f32, f32) -> f32
"linalg.yield"(%602) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<0x512x384xf32>, tensor<0x512xf32>, tensor<0x512xf32>) -> tensor<0x512xf32>
%447 = "linalg.generic"(%441, %444, %446, %87, %86, %422) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%600 = "arith.divf"(%arg3, %122) : (f32, f32) -> f32
%601 = "arith.subf"(%arg1, %arg2) : (f32, f32) -> f32
%602 = "arith.truncf"(%21) : (f64) -> f32
%603 = "arith.addf"(%600, %602) : (f32, f32) -> f32
%604 = "math.rsqrt"(%603) : (f32) -> f32
%605 = "arith.mulf"(%601, %604) : (f32, f32) -> f32
%606 = "arith.mulf"(%605, %arg4) : (f32, f32) -> f32
%607 = "arith.addf"(%606, %arg5) : (f32, f32) -> f32
"linalg.yield"(%607) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<[5, 1]> : vector<2xi32>} : (tensor<0x512x384xf32>, tensor<0x512xf32>, tensor<0x512xf32>, tensor<384xf32>, tensor<384xf32>, tensor<0x512x384xf32>) -> tensor<0x512x384xf32>
%448 = "linalg.init_tensor"() {static_sizes = [0, 384, 384]} : () -> tensor<0x384x384xf32>
%449 = "linalg.generic"(%88, %422) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384xf32>, tensor<0x512x384xf32>) -> tensor<0x512x384xf32>
%450 = "linalg.generic"(%89, %448) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384x384xf32>, tensor<0x384x384xf32>) -> tensor<0x384x384xf32>
%451 = "linalg.batch_matmul"(%447, %450, %449) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<0x512x384xf32>, tensor<0x384x384xf32>, tensor<0x512x384xf32>) -> tensor<0x512x384xf32>
%452 = "tensor.cast"(%451) : (tensor<0x512x384xf32>) -> tensor<?x?x384xf32>
%453 = "linalg.generic"(%90, %422) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384xf32>, tensor<0x512x384xf32>) -> tensor<0x512x384xf32>
%454 = "linalg.generic"(%91, %448) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384x384xf32>, tensor<0x384x384xf32>) -> tensor<0x384x384xf32>
%455 = "linalg.batch_matmul"(%447, %454, %453) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<0x512x384xf32>, tensor<0x384x384xf32>, tensor<0x512x384xf32>) -> tensor<0x512x384xf32>
%456 = "tensor.cast"(%455) : (tensor<0x512x384xf32>) -> tensor<?x?x384xf32>
%457 = "flow.tensor.reshape"(%456, %2, %5, %2, %5) {operand_segment_sizes = dense<[1, 2, 2]> : vector<3xi32>} : (tensor<?x?x384xf32>, index, index, index, index) -> tensor<?x?x12x32xf32>
%458 = "linalg.init_tensor"() {static_sizes = [0, 12, 512, 32]} : () -> tensor<0x12x512x32xf32>
%459 = "linalg.generic"(%92, %422) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384xf32>, tensor<0x512x384xf32>) -> tensor<0x512x384xf32>
%460 = "linalg.generic"(%93, %448) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384x384xf32>, tensor<0x384x384xf32>) -> tensor<0x384x384xf32>
%461 = "linalg.batch_matmul"(%447, %460, %459) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<0x512x384xf32>, tensor<0x384x384xf32>, tensor<0x512x384xf32>) -> tensor<0x512x384xf32>
%462 = "tensor.cast"(%461) : (tensor<0x512x384xf32>) -> tensor<?x?x384xf32>
%463 = "flow.tensor.reshape"(%462, %2, %5, %2, %5) {operand_segment_sizes = dense<[1, 2, 2]> : vector<3xi32>} : (tensor<?x?x384xf32>, index, index, index, index) -> tensor<?x?x12x32xf32>
%464 = "linalg.generic"(%463, %458) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<?x?x12x32xf32>, tensor<0x12x512x32xf32>) -> tensor<0x12x512x32xf32>
%465 = "flow.tensor.reshape"(%452, %2, %5, %2, %5) {operand_segment_sizes = dense<[1, 2, 2]> : vector<3xi32>} : (tensor<?x?x384xf32>, index, index, index, index) -> tensor<?x?x12x32xf32>
%466 = "linalg.generic"(%465, %458) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<?x?x12x32xf32>, tensor<0x12x512x32xf32>) -> tensor<0x12x512x32xf32>
%467 = "linalg.init_tensor"() {static_sizes = [0, 12, 32, 512]} : () -> tensor<0x12x32x512xf32>
%468 = "linalg.generic"(%457, %467) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d2)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<?x?x12x32xf32>, tensor<0x12x32x512xf32>) -> tensor<0x12x32x512xf32>
%469 = "linalg.init_tensor"() {static_sizes = [0, 12, 512, 512]} : () -> tensor<0x12x512x512xf32>
%470 = "linalg.fill"(%15, %469) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<0x12x512x512xf32>) -> tensor<0x12x512x512xf32>
%471 = "linalg.generic"(%466, %468, %470) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%600, %arg3) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<0x12x512x32xf32>, tensor<0x12x32x512xf32>, tensor<0x12x512x512xf32>) -> tensor<0x12x512x512xf32>
"std.assert"(%1) {msg = "mismatched size for broadcast"} : (i1) -> ()
%472 = "arith.cmpi"(%5, %142) {predicate = 0 : i64} : (index, index) -> i1
"std.assert"(%472) {msg = "mismatched size for broadcast"} : (i1) -> ()
%473 = "linalg.generic"(%471, %263, %469) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.truncf"(%9) : (f64) -> f32
%601 = "arith.divf"(%arg1, %600) : (f32, f32) -> f32
%602 = "arith.addf"(%601, %arg2) : (f32, f32) -> f32
"linalg.yield"(%602) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<0x12x512x512xf32>, tensor<1x?xf32>, tensor<0x12x512x512xf32>) -> tensor<0x12x512x512xf32>
%474 = "linalg.init_tensor"() {static_sizes = [0, 12, 512]} : () -> tensor<0x12x512xf32>
%475 = "linalg.fill"(%16, %474) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<0x12x512xf32>) -> tensor<0x12x512xf32>
%476 = "linalg.init_tensor"() {static_sizes = [0, 12, 512]} : () -> tensor<0x12x512xi64>
%477 = "linalg.fill"(%4, %476) ({
^bb0(%arg1: i64, %arg2: i64):
"linalg.yield"(%arg1) : (i64) -> ()
}) : (i64, tensor<0x12x512xi64>) -> tensor<0x12x512xi64>
%478:2 = "linalg.generic"(%473, %475, %477) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: i64):
%600 = "linalg.index"() {dim = 3 : i64} : () -> index
%601 = "arith.index_cast"(%600) : (index) -> i64
%602 = "arith.cmpf"(%arg1, %arg2) {predicate = 2 : i64} : (f32, f32) -> i1
%603 = "std.select"(%602, %arg1, %arg2) : (i1, f32, f32) -> f32
%604 = "std.select"(%602, %601, %arg3) : (i1, i64, i64) -> i64
"linalg.yield"(%603, %604) : (f32, i64) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"], operand_segment_sizes = dense<[1, 2]> : vector<2xi32>} : (tensor<0x12x512x512xf32>, tensor<0x12x512xf32>, tensor<0x12x512xi64>) -> (tensor<0x12x512xf32>, tensor<0x12x512xi64>)
%479 = "linalg.generic"(%473, %478#0, %469) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.subf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "math.exp"(%600) : (f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<0x12x512x512xf32>, tensor<0x12x512xf32>, tensor<0x12x512x512xf32>) -> tensor<0x12x512x512xf32>
%480 = "linalg.fill"(%15, %474) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<0x12x512xf32>) -> tensor<0x12x512xf32>
%481 = "linalg.generic"(%479, %480) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "arith.addf"(%arg1, %arg2) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<0x12x512x512xf32>, tensor<0x12x512xf32>) -> tensor<0x12x512xf32>
%482 = "linalg.generic"(%479, %481, %469) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.divf"(%arg1, %arg2) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<0x12x512x512xf32>, tensor<0x12x512xf32>, tensor<0x12x512x512xf32>) -> tensor<0x12x512x512xf32>
%483 = "linalg.fill"(%15, %458) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<0x12x512x32xf32>) -> tensor<0x12x512x32xf32>
%484 = "linalg.generic"(%482, %464, %483) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%600, %arg3) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<0x12x512x512xf32>, tensor<0x12x512x32xf32>, tensor<0x12x512x32xf32>) -> tensor<0x12x512x32xf32>
%485 = "linalg.init_tensor"() {static_sizes = [0, 512, 12, 32]} : () -> tensor<0x512x12x32xf32>
%486 = "linalg.generic"(%484, %485) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<0x12x512x32xf32>, tensor<0x512x12x32xf32>) -> tensor<0x512x12x32xf32>
%487 = "tensor.cast"(%486) : (tensor<0x512x12x32xf32>) -> tensor<?x?x12x32xf32>
%488 = "flow.tensor.reshape"(%487, %2, %5, %2, %5) {operand_segment_sizes = dense<[1, 2, 2]> : vector<3xi32>} : (tensor<?x?x12x32xf32>, index, index, index, index) -> tensor<?x?x384xf32>
%489 = "linalg.generic"(%94, %422) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384xf32>, tensor<0x512x384xf32>) -> tensor<0x512x384xf32>
%490 = "linalg.generic"(%95, %448) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384x384xf32>, tensor<0x384x384xf32>) -> tensor<0x384x384xf32>
%491 = "linalg.batch_matmul"(%488, %490, %489) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<?x?x384xf32>, tensor<0x384x384xf32>, tensor<0x512x384xf32>) -> tensor<0x512x384xf32>
"std.assert"(%1) {msg = "mismatched size for broadcast"} : (i1) -> ()
%492 = "linalg.init_tensor"() {static_sizes = [1, 512, 384]} : () -> tensor<1x512x384xf32>
%493 = "linalg.generic"(%491, %447, %492) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.addf"(%arg1, %arg2) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<0x512x384xf32>, tensor<0x512x384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%494 = "linalg.init_tensor"() {static_sizes = [1, 512]} : () -> tensor<1x512xf32>
%495 = "linalg.fill"(%15, %494) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<1x512xf32>) -> tensor<1x512xf32>
%496 = "linalg.generic"(%493, %495) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "arith.addf"(%arg2, %arg1) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x512xf32>) -> tensor<1x512xf32>
%497 = "linalg.generic"(%496, %494) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "arith.divf"(%arg1, %122) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<1x512xf32>, tensor<1x512xf32>) -> tensor<1x512xf32>
%498 = "linalg.fill"(%15, %494) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<1x512xf32>) -> tensor<1x512xf32>
%499 = "linalg.generic"(%493, %497, %498) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.subf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.mulf"(%600, %600) : (f32, f32) -> f32
%602 = "arith.addf"(%arg3, %601) : (f32, f32) -> f32
"linalg.yield"(%602) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x512xf32>, tensor<1x512xf32>) -> tensor<1x512xf32>
%500 = "linalg.generic"(%493, %497, %499, %97, %96, %492) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%600 = "arith.divf"(%arg3, %122) : (f32, f32) -> f32
%601 = "arith.subf"(%arg1, %arg2) : (f32, f32) -> f32
%602 = "arith.truncf"(%21) : (f64) -> f32
%603 = "arith.addf"(%600, %602) : (f32, f32) -> f32
%604 = "math.rsqrt"(%603) : (f32) -> f32
%605 = "arith.mulf"(%601, %604) : (f32, f32) -> f32
%606 = "arith.mulf"(%605, %arg4) : (f32, f32) -> f32
%607 = "arith.addf"(%606, %arg5) : (f32, f32) -> f32
"linalg.yield"(%607) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<[5, 1]> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x512xf32>, tensor<1x512xf32>, tensor<384xf32>, tensor<384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%501 = "linalg.init_tensor"() {static_sizes = [1, 512, 1536]} : () -> tensor<1x512x1536xf32>
%502 = "linalg.init_tensor"() {static_sizes = [1, 384, 1536]} : () -> tensor<1x384x1536xf32>
%503 = "linalg.generic"(%98, %501) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<1536xf32>, tensor<1x512x1536xf32>) -> tensor<1x512x1536xf32>
%504 = "linalg.generic"(%99, %502) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<1536x384xf32>, tensor<1x384x1536xf32>) -> tensor<1x384x1536xf32>
%505 = "linalg.batch_matmul"(%500, %504, %503) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x384x1536xf32>, tensor<1x512x1536xf32>) -> tensor<1x512x1536xf32>
%506 = "linalg.generic"(%505, %501) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "math.sqrt"(%18) : (f32) -> f32
%601 = "arith.divf"(%arg1, %600) : (f32, f32) -> f32
%602 = "math.erf"(%601) : (f32) -> f32
%603 = "arith.addf"(%602, %19) : (f32, f32) -> f32
%604 = "arith.mulf"(%603, %17) : (f32, f32) -> f32
%605 = "arith.mulf"(%arg1, %604) : (f32, f32) -> f32
"linalg.yield"(%605) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<1x512x1536xf32>, tensor<1x512x1536xf32>) -> tensor<1x512x1536xf32>
%507 = "linalg.init_tensor"() {static_sizes = [1, 1536, 384]} : () -> tensor<1x1536x384xf32>
%508 = "linalg.generic"(%100, %492) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%509 = "linalg.generic"(%101, %507) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384x1536xf32>, tensor<1x1536x384xf32>) -> tensor<1x1536x384xf32>
%510 = "linalg.batch_matmul"(%506, %509, %508) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x512x1536xf32>, tensor<1x1536x384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%511 = "linalg.generic"(%510, %500, %492) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.addf"(%arg1, %arg2) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x512x384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%512 = "linalg.fill"(%15, %494) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<1x512xf32>) -> tensor<1x512xf32>
%513 = "linalg.generic"(%511, %512) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "arith.addf"(%arg2, %arg1) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x512xf32>) -> tensor<1x512xf32>
%514 = "linalg.generic"(%513, %494) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "arith.divf"(%arg1, %122) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<1x512xf32>, tensor<1x512xf32>) -> tensor<1x512xf32>
%515 = "linalg.fill"(%15, %494) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<1x512xf32>) -> tensor<1x512xf32>
%516 = "linalg.generic"(%511, %514, %515) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.subf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.mulf"(%600, %600) : (f32, f32) -> f32
%602 = "arith.addf"(%arg3, %601) : (f32, f32) -> f32
"linalg.yield"(%602) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x512xf32>, tensor<1x512xf32>) -> tensor<1x512xf32>
%517 = "linalg.generic"(%511, %514, %516, %103, %102, %492) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%600 = "arith.divf"(%arg3, %122) : (f32, f32) -> f32
%601 = "arith.subf"(%arg1, %arg2) : (f32, f32) -> f32
%602 = "arith.truncf"(%21) : (f64) -> f32
%603 = "arith.addf"(%600, %602) : (f32, f32) -> f32
%604 = "math.rsqrt"(%603) : (f32) -> f32
%605 = "arith.mulf"(%601, %604) : (f32, f32) -> f32
%606 = "arith.mulf"(%605, %arg4) : (f32, f32) -> f32
%607 = "arith.addf"(%606, %arg5) : (f32, f32) -> f32
"linalg.yield"(%607) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<[5, 1]> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x512xf32>, tensor<1x512xf32>, tensor<384xf32>, tensor<384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%518 = "linalg.init_tensor"() {static_sizes = [1, 384, 384]} : () -> tensor<1x384x384xf32>
%519 = "linalg.generic"(%104, %492) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%520 = "linalg.generic"(%105, %518) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384x384xf32>, tensor<1x384x384xf32>) -> tensor<1x384x384xf32>
%521 = "linalg.batch_matmul"(%517, %520, %519) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x384x384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%522 = "tensor.cast"(%521) : (tensor<1x512x384xf32>) -> tensor<?x?x384xf32>
%523 = "linalg.generic"(%106, %492) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%524 = "linalg.generic"(%107, %518) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384x384xf32>, tensor<1x384x384xf32>) -> tensor<1x384x384xf32>
%525 = "linalg.batch_matmul"(%517, %524, %523) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x384x384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%526 = "tensor.cast"(%525) : (tensor<1x512x384xf32>) -> tensor<?x?x384xf32>
%527 = "flow.tensor.reshape"(%526, %0, %5, %0, %5) {operand_segment_sizes = dense<[1, 2, 2]> : vector<3xi32>} : (tensor<?x?x384xf32>, index, index, index, index) -> tensor<?x?x12x32xf32>
%528 = "linalg.init_tensor"() {static_sizes = [1, 12, 512, 32]} : () -> tensor<1x12x512x32xf32>
%529 = "linalg.generic"(%108, %492) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%530 = "linalg.generic"(%109, %518) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384x384xf32>, tensor<1x384x384xf32>) -> tensor<1x384x384xf32>
%531 = "linalg.batch_matmul"(%517, %530, %529) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x384x384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%532 = "tensor.cast"(%531) : (tensor<1x512x384xf32>) -> tensor<?x?x384xf32>
%533 = "flow.tensor.reshape"(%532, %0, %5, %0, %5) {operand_segment_sizes = dense<[1, 2, 2]> : vector<3xi32>} : (tensor<?x?x384xf32>, index, index, index, index) -> tensor<?x?x12x32xf32>
%534 = "linalg.generic"(%533, %528) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<?x?x12x32xf32>, tensor<1x12x512x32xf32>) -> tensor<1x12x512x32xf32>
%535 = "flow.tensor.reshape"(%522, %0, %5, %0, %5) {operand_segment_sizes = dense<[1, 2, 2]> : vector<3xi32>} : (tensor<?x?x384xf32>, index, index, index, index) -> tensor<?x?x12x32xf32>
%536 = "linalg.generic"(%535, %528) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<?x?x12x32xf32>, tensor<1x12x512x32xf32>) -> tensor<1x12x512x32xf32>
%537 = "linalg.init_tensor"() {static_sizes = [1, 12, 32, 512]} : () -> tensor<1x12x32x512xf32>
%538 = "linalg.generic"(%527, %537) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d2)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<?x?x12x32xf32>, tensor<1x12x32x512xf32>) -> tensor<1x12x32x512xf32>
%539 = "linalg.init_tensor"() {static_sizes = [1, 12, 512, 512]} : () -> tensor<1x12x512x512xf32>
%540 = "linalg.fill"(%15, %539) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<1x12x512x512xf32>) -> tensor<1x12x512x512xf32>
%541 = "linalg.generic"(%536, %538, %540) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%600, %arg3) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x12x512x32xf32>, tensor<1x12x32x512xf32>, tensor<1x12x512x512xf32>) -> tensor<1x12x512x512xf32>
%542 = "arith.cmpi"(%5, %142) {predicate = 0 : i64} : (index, index) -> i1
"std.assert"(%542) {msg = "mismatched size for broadcast"} : (i1) -> ()
%543 = "linalg.generic"(%541, %263, %539) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.truncf"(%9) : (f64) -> f32
%601 = "arith.divf"(%arg1, %600) : (f32, f32) -> f32
%602 = "arith.addf"(%601, %arg2) : (f32, f32) -> f32
"linalg.yield"(%602) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x12x512x512xf32>, tensor<1x?xf32>, tensor<1x12x512x512xf32>) -> tensor<1x12x512x512xf32>
%544 = "linalg.init_tensor"() {static_sizes = [1, 12, 512]} : () -> tensor<1x12x512xf32>
%545 = "linalg.fill"(%16, %544) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<1x12x512xf32>) -> tensor<1x12x512xf32>
%546 = "linalg.init_tensor"() {static_sizes = [1, 12, 512]} : () -> tensor<1x12x512xi64>
%547 = "linalg.fill"(%4, %546) ({
^bb0(%arg1: i64, %arg2: i64):
"linalg.yield"(%arg1) : (i64) -> ()
}) : (i64, tensor<1x12x512xi64>) -> tensor<1x12x512xi64>
%548:2 = "linalg.generic"(%543, %545, %547) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: i64):
%600 = "linalg.index"() {dim = 3 : i64} : () -> index
%601 = "arith.index_cast"(%600) : (index) -> i64
%602 = "arith.cmpf"(%arg1, %arg2) {predicate = 2 : i64} : (f32, f32) -> i1
%603 = "std.select"(%602, %arg1, %arg2) : (i1, f32, f32) -> f32
%604 = "std.select"(%602, %601, %arg3) : (i1, i64, i64) -> i64
"linalg.yield"(%603, %604) : (f32, i64) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"], operand_segment_sizes = dense<[1, 2]> : vector<2xi32>} : (tensor<1x12x512x512xf32>, tensor<1x12x512xf32>, tensor<1x12x512xi64>) -> (tensor<1x12x512xf32>, tensor<1x12x512xi64>)
%549 = "linalg.generic"(%543, %548#0, %539) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.subf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "math.exp"(%600) : (f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x12x512x512xf32>, tensor<1x12x512xf32>, tensor<1x12x512x512xf32>) -> tensor<1x12x512x512xf32>
%550 = "linalg.fill"(%15, %544) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<1x12x512xf32>) -> tensor<1x12x512xf32>
%551 = "linalg.generic"(%549, %550) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "arith.addf"(%arg1, %arg2) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<1x12x512x512xf32>, tensor<1x12x512xf32>) -> tensor<1x12x512xf32>
%552 = "linalg.generic"(%549, %551, %539) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.divf"(%arg1, %arg2) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x12x512x512xf32>, tensor<1x12x512xf32>, tensor<1x12x512x512xf32>) -> tensor<1x12x512x512xf32>
%553 = "linalg.fill"(%15, %528) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<1x12x512x32xf32>) -> tensor<1x12x512x32xf32>
%554 = "linalg.generic"(%552, %534, %553) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%600, %arg3) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x12x512x512xf32>, tensor<1x12x512x32xf32>, tensor<1x12x512x32xf32>) -> tensor<1x12x512x32xf32>
%555 = "linalg.init_tensor"() {static_sizes = [1, 512, 12, 32]} : () -> tensor<1x512x12x32xf32>
%556 = "linalg.generic"(%554, %555) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<1x12x512x32xf32>, tensor<1x512x12x32xf32>) -> tensor<1x512x12x32xf32>
%557 = "tensor.cast"(%556) : (tensor<1x512x12x32xf32>) -> tensor<?x?x12x32xf32>
%558 = "flow.tensor.reshape"(%557, %0, %5, %0, %5) {operand_segment_sizes = dense<[1, 2, 2]> : vector<3xi32>} : (tensor<?x?x12x32xf32>, index, index, index, index) -> tensor<?x?x384xf32>
%559 = "linalg.generic"(%110, %492) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%560 = "linalg.generic"(%111, %518) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384x384xf32>, tensor<1x384x384xf32>) -> tensor<1x384x384xf32>
%561 = "linalg.batch_matmul"(%558, %560, %559) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<?x?x384xf32>, tensor<1x384x384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
"std.assert"(%1) {msg = "mismatched size for broadcast"} : (i1) -> ()
%562 = "linalg.init_tensor"() {static_sizes = [0, 512, 384]} : () -> tensor<0x512x384xf32>
%563 = "linalg.generic"(%561, %517, %562) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.addf"(%arg1, %arg2) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x512x384xf32>, tensor<0x512x384xf32>) -> tensor<0x512x384xf32>
%564 = "linalg.init_tensor"() {static_sizes = [0, 512]} : () -> tensor<0x512xf32>
%565 = "linalg.fill"(%15, %564) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<0x512xf32>) -> tensor<0x512xf32>
%566 = "linalg.generic"(%563, %565) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "arith.addf"(%arg2, %arg1) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<0x512x384xf32>, tensor<0x512xf32>) -> tensor<0x512xf32>
%567 = "linalg.generic"(%566, %564) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "arith.divf"(%arg1, %122) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<0x512xf32>, tensor<0x512xf32>) -> tensor<0x512xf32>
%568 = "linalg.fill"(%15, %564) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<0x512xf32>) -> tensor<0x512xf32>
%569 = "linalg.generic"(%563, %567, %568) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.subf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.mulf"(%600, %600) : (f32, f32) -> f32
%602 = "arith.addf"(%arg3, %601) : (f32, f32) -> f32
"linalg.yield"(%602) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<0x512x384xf32>, tensor<0x512xf32>, tensor<0x512xf32>) -> tensor<0x512xf32>
%570 = "linalg.generic"(%563, %567, %569, %113, %112, %562) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%600 = "arith.divf"(%arg3, %122) : (f32, f32) -> f32
%601 = "arith.subf"(%arg1, %arg2) : (f32, f32) -> f32
%602 = "arith.truncf"(%21) : (f64) -> f32
%603 = "arith.addf"(%600, %602) : (f32, f32) -> f32
%604 = "math.rsqrt"(%603) : (f32) -> f32
%605 = "arith.mulf"(%601, %604) : (f32, f32) -> f32
%606 = "arith.mulf"(%605, %arg4) : (f32, f32) -> f32
%607 = "arith.addf"(%606, %arg5) : (f32, f32) -> f32
"linalg.yield"(%607) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<[5, 1]> : vector<2xi32>} : (tensor<0x512x384xf32>, tensor<0x512xf32>, tensor<0x512xf32>, tensor<384xf32>, tensor<384xf32>, tensor<0x512x384xf32>) -> tensor<0x512x384xf32>
%571 = "linalg.init_tensor"() {static_sizes = [0, 512, 1536]} : () -> tensor<0x512x1536xf32>
%572 = "linalg.init_tensor"() {static_sizes = [0, 384, 1536]} : () -> tensor<0x384x1536xf32>
%573 = "linalg.generic"(%114, %571) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<1536xf32>, tensor<0x512x1536xf32>) -> tensor<0x512x1536xf32>
%574 = "linalg.generic"(%115, %572) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<1536x384xf32>, tensor<0x384x1536xf32>) -> tensor<0x384x1536xf32>
%575 = "linalg.batch_matmul"(%570, %574, %573) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<0x512x384xf32>, tensor<0x384x1536xf32>, tensor<0x512x1536xf32>) -> tensor<0x512x1536xf32>
%576 = "linalg.generic"(%575, %571) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "math.sqrt"(%18) : (f32) -> f32
%601 = "arith.divf"(%arg1, %600) : (f32, f32) -> f32
%602 = "math.erf"(%601) : (f32) -> f32
%603 = "arith.addf"(%602, %19) : (f32, f32) -> f32
%604 = "arith.mulf"(%603, %17) : (f32, f32) -> f32
%605 = "arith.mulf"(%arg1, %604) : (f32, f32) -> f32
"linalg.yield"(%605) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<0x512x1536xf32>, tensor<0x512x1536xf32>) -> tensor<0x512x1536xf32>
%577 = "linalg.init_tensor"() {static_sizes = [0, 1536, 384]} : () -> tensor<0x1536x384xf32>
%578 = "linalg.generic"(%116, %562) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384xf32>, tensor<0x512x384xf32>) -> tensor<0x512x384xf32>
%579 = "linalg.generic"(%117, %577) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384x1536xf32>, tensor<0x1536x384xf32>) -> tensor<0x1536x384xf32>
%580 = "linalg.batch_matmul"(%576, %579, %578) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<0x512x1536xf32>, tensor<0x1536x384xf32>, tensor<0x512x384xf32>) -> tensor<0x512x384xf32>
%581 = "linalg.generic"(%580, %570, %562) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.addf"(%arg1, %arg2) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<0x512x384xf32>, tensor<0x512x384xf32>, tensor<0x512x384xf32>) -> tensor<0x512x384xf32>
%582 = "linalg.fill"(%15, %564) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<0x512xf32>) -> tensor<0x512xf32>
%583 = "linalg.generic"(%581, %582) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "arith.addf"(%arg2, %arg1) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<0x512x384xf32>, tensor<0x512xf32>) -> tensor<0x512xf32>
%584 = "linalg.generic"(%583, %564) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "arith.divf"(%arg1, %122) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<0x512xf32>, tensor<0x512xf32>) -> tensor<0x512xf32>
%585 = "linalg.fill"(%15, %564) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<0x512xf32>) -> tensor<0x512xf32>
%586 = "linalg.generic"(%581, %584, %585) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.subf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.mulf"(%600, %600) : (f32, f32) -> f32
%602 = "arith.addf"(%arg3, %601) : (f32, f32) -> f32
"linalg.yield"(%602) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<0x512x384xf32>, tensor<0x512xf32>, tensor<0x512xf32>) -> tensor<0x512xf32>
%587 = "linalg.generic"(%581, %584, %586, %119, %118, %562) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%600 = "arith.divf"(%arg3, %122) : (f32, f32) -> f32
%601 = "arith.subf"(%arg1, %arg2) : (f32, f32) -> f32
%602 = "arith.truncf"(%21) : (f64) -> f32
%603 = "arith.addf"(%600, %602) : (f32, f32) -> f32
%604 = "math.rsqrt"(%603) : (f32) -> f32
%605 = "arith.mulf"(%601, %604) : (f32, f32) -> f32
%606 = "arith.mulf"(%605, %arg4) : (f32, f32) -> f32
%607 = "arith.addf"(%606, %arg5) : (f32, f32) -> f32
"linalg.yield"(%607) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<[5, 1]> : vector<2xi32>} : (tensor<0x512x384xf32>, tensor<0x512xf32>, tensor<0x512xf32>, tensor<384xf32>, tensor<384xf32>, tensor<0x512x384xf32>) -> tensor<0x512x384xf32>
%588 = "tensor.extract_slice"(%587) {operand_segment_sizes = dense<[1, 0, 0, 0]> : vector<4xi32>, static_offsets = [0, 0, 0], static_sizes = [0, 1, 384], static_strides = [1, 1, 1]} : (tensor<0x512x384xf32>) -> tensor<0x1x384xf32>
%589 = "tensor.cast"(%588) : (tensor<0x1x384xf32>) -> tensor<?x?x384xf32>
%590 = "flow.tensor.reshape"(%589, %2, %0, %2) {operand_segment_sizes = dense<[1, 2, 1]> : vector<3xi32>} : (tensor<?x?x384xf32>, index, index, index) -> tensor<?x384xf32>
%591 = "linalg.init_tensor"() {static_sizes = [0, 384]} : () -> tensor<0x384xf32>
%592 = "linalg.generic"(%120, %591) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384xf32>, tensor<0x384xf32>) -> tensor<0x384xf32>
%593 = "linalg.matmul"(%590, %7, %592) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<?x384xf32>, tensor<384x384xf32>, tensor<0x384xf32>) -> tensor<0x384xf32>
%594 = "linalg.generic"(%593, %591) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "math.tanh"(%arg1) : (f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<0x384xf32>, tensor<0x384xf32>) -> tensor<0x384xf32>
%595 = "linalg.init_tensor"() {static_sizes = [0, 2]} : () -> tensor<0x2xf32>
%596 = "linalg.generic"(%121, %595) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<2xf32>, tensor<0x2xf32>) -> tensor<0x2xf32>
%597 = "linalg.matmul"(%594, %6, %596) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<0x384xf32>, tensor<384x2xf32>, tensor<0x2xf32>) -> tensor<0x2xf32>
%598 = "tensor.cast"(%597) : (tensor<0x2xf32>) -> tensor<?x2xf32>
%599 = "hal.tensor.export"(%598, %2) {operand_segment_sizes = dense<[1, 1, 0]> : vector<3xi32>, source_encoding = tensor<?x2xf32>} : (tensor<?x2xf32>, index) -> !hal.buffer_view
"std.return"(%599) : (!hal.buffer_view) -> ()
}) {iree.abi.stub, sym_name = "forward", type = (!hal.buffer_view) -> !hal.buffer_view} : () -> ()
bert_input_ir.mlir:19:1: error: conversion from source -> vm failed
module attributes {torch.debug_module_name = "MiniLMSequenceClassification"} {
^
bert_input_ir.mlir:19:1: note: see current operation: "builtin.module"() ({
"builtin.func"() ({
^bb0(%arg0: !hal.buffer_view):
%0 = "arith.constant"() {value = 1 : index} : () -> index
%1 = "arith.constant"() {value = false} : () -> i1
%2 = "arith.constant"() {value = 0 : index} : () -> index
%3 = "arith.constant"() {value = 512 : i64} : () -> i64
%4 = "arith.constant"() {value = 0 : i64} : () -> i64
%5 = "arith.constant"() {value = 512 : index} : () -> index
%6 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x2xf32>} : () -> tensor<384x2xf32>
%7 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>} : () -> tensor<384x384xf32>
%8 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1x384x384xf32>} : () -> tensor<1x384x384xf32>
%9 = "arith.constant"() {value = 5.6568542494923806 : f64} : () -> f64
%10 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1x384x384xf32>} : () -> tensor<1x384x384xf32>
%11 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1x384x384xf32>} : () -> tensor<1x384x384xf32>
%12 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1x384x384xf32>} : () -> tensor<1x384x384xf32>
%13 = "arith.constant"() {value = -1.000000e+04 : f32} : () -> f32
%14 = "arith.constant"() {value = 30522 : i64} : () -> i64
%15 = "arith.constant"() {value = 0.000000e+00 : f32} : () -> f32
%16 = "arith.constant"() {value = -3.40282347E+38 : f32} : () -> f32
%17 = "arith.constant"() {value = 5.000000e-01 : f32} : () -> f32
%18 = "arith.constant"() {value = 2.000000e+00 : f32} : () -> f32
%19 = "arith.constant"() {value = 1.000000e+00 : f32} : () -> f32
%20 = "arith.constant"() {value = 9223372036854775807 : i64} : () -> i64
%21 = "arith.constant"() {value = 9.9999999999999998E-13 : f64} : () -> f64
%22 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1x512xi64>} : () -> tensor<1x512xi64>
%23 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<30522x384xf32>} : () -> tensor<30522x384xf32>
%24 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<2x384xf32>} : () -> tensor<2x384xf32>
%25 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<512x384xf32>} : () -> tensor<512x384xf32>
%26 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%27 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%28 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%29 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%30 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%31 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%32 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%33 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%34 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536xf32>} : () -> tensor<1536xf32>
%35 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536x384xf32>} : () -> tensor<1536x384xf32>
%36 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%37 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x1536xf32>} : () -> tensor<384x1536xf32>
%38 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%39 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%40 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%41 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>} : () -> tensor<384x384xf32>
%42 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%43 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>} : () -> tensor<384x384xf32>
%44 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%45 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>} : () -> tensor<384x384xf32>
%46 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%47 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>} : () -> tensor<384x384xf32>
%48 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%49 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%50 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536xf32>} : () -> tensor<1536xf32>
%51 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536x384xf32>} : () -> tensor<1536x384xf32>
%52 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%53 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x1536xf32>} : () -> tensor<384x1536xf32>
%54 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%55 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%56 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%57 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>} : () -> tensor<384x384xf32>
%58 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%59 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>} : () -> tensor<384x384xf32>
%60 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%61 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>} : () -> tensor<384x384xf32>
%62 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%63 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>} : () -> tensor<384x384xf32>
%64 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%65 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%66 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536xf32>} : () -> tensor<1536xf32>
%67 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536x384xf32>} : () -> tensor<1536x384xf32>
%68 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%69 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x1536xf32>} : () -> tensor<384x1536xf32>
%70 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%71 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%72 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%73 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>} : () -> tensor<384x384xf32>
%74 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%75 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>} : () -> tensor<384x384xf32>
%76 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%77 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>} : () -> tensor<384x384xf32>
%78 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%79 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>} : () -> tensor<384x384xf32>
%80 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%81 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%82 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536xf32>} : () -> tensor<1536xf32>
%83 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536x384xf32>} : () -> tensor<1536x384xf32>
%84 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%85 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x1536xf32>} : () -> tensor<384x1536xf32>
%86 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%87 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%88 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%89 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>} : () -> tensor<384x384xf32>
%90 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%91 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>} : () -> tensor<384x384xf32>
%92 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%93 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>} : () -> tensor<384x384xf32>
%94 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%95 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>} : () -> tensor<384x384xf32>
%96 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%97 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%98 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536xf32>} : () -> tensor<1536xf32>
%99 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536x384xf32>} : () -> tensor<1536x384xf32>
%100 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%101 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x1536xf32>} : () -> tensor<384x1536xf32>
%102 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%103 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%104 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%105 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>} : () -> tensor<384x384xf32>
%106 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%107 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>} : () -> tensor<384x384xf32>
%108 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%109 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>} : () -> tensor<384x384xf32>
%110 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%111 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>} : () -> tensor<384x384xf32>
%112 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%113 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%114 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536xf32>} : () -> tensor<1536xf32>
%115 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536x384xf32>} : () -> tensor<1536x384xf32>
%116 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%117 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x1536xf32>} : () -> tensor<384x1536xf32>
%118 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%119 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%120 = "arith.constant"() {value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>} : () -> tensor<384xf32>
%121 = "arith.constant"() {value = dense<[-0.00115577725, 0.00115577038]> : tensor<2xf32>} : () -> tensor<2xf32>
%122 = "arith.constant"() {value = 3.840000e+02 : f32} : () -> f32
%123 = "hal.tensor.import"(%arg0) {target_encoding = tensor<1x512xi64>} : (!hal.buffer_view) -> tensor<1x512xi64>
%124 = "linalg.init_tensor"() {static_sizes = []} : () -> tensor<i64>
%125 = "linalg.fill"(%3, %124) ({
^bb0(%arg1: i64, %arg2: i64):
"linalg.yield"(%arg1) : (i64) -> ()
}) : (i64, tensor<i64>) -> tensor<i64>
%126 = "tensor.extract"(%125) : (tensor<i64>) -> i64
%127 = "arith.index_cast"(%126) : (i64) -> index
%128 = "linalg.init_tensor"(%127) {static_sizes = [1, -1]} : (index) -> tensor<1x?xf32>
%129 = "linalg.fill"(%19, %128) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<1x?xf32>) -> tensor<1x?xf32>
%130 = "tensor.extract_slice"(%129, %127) {operand_segment_sizes = dense<[1, 0, 1, 0]> : vector<4xi32>, static_offsets = [0, 0], static_sizes = [1, -1], static_strides = [1, 1]} : (tensor<1x?xf32>, index) -> tensor<?xf32>
%131 = "tensor.dim"(%130, %2) : (tensor<?xf32>, index) -> index
%132 = "tensor.dim"(%130, %2) : (tensor<?xf32>, index) -> index
%133 = "flow.tensor.reshape"(%130, %132, %0, %131) {operand_segment_sizes = dense<[1, 1, 2]> : vector<3xi32>} : (tensor<?xf32>, index, index, index) -> tensor<?x1x1x?xf32>
%134 = "arith.cmpi"(%4, %126) {predicate = 4 : i64} : (i64, i64) -> i1
%135 = "std.select"(%134, %126, %4) : (i1, i64, i64) -> i64
%136 = "arith.index_cast"(%135) : (i64) -> index
%137 = "arith.cmpi"(%20, %126) {predicate = 4 : i64} : (i64, i64) -> i1
%138 = "std.select"(%137, %126, %20) : (i1, i64, i64) -> i64
%139 = "arith.index_cast"(%138) : (i64) -> index
%140 = "arith.cmpi"(%139, %136) {predicate = 5 : i64} : (index, index) -> i1
%141 = "std.select"(%140, %139, %136) : (i1, index, index) -> index
%142 = "arith.subi"(%141, %136) : (index, index) -> index
%143 = "tensor.extract_slice"(%133, %136, %142) {operand_segment_sizes = dense<[1, 1, 1, 0]> : vector<4xi32>, static_offsets = [0, 0, 0, -9223372036854775808], static_sizes = [1, 1, 1, -1], static_strides = [1, 1, 1, 1]} : (tensor<?x1x1x?xf32>, index, index) -> tensor<?xf32>
%144 = "linalg.init_tensor"(%142) {static_sizes = [-1]} : (index) -> tensor<?xf32>
%145 = "linalg.generic"(%143, %144) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "arith.subf"(%19, %arg1) : (f32, f32) -> f32
%601 = "arith.mulf"(%600, %13) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%146 = "tensor.dim"(%145, %2) : (tensor<?xf32>, index) -> index
%147 = "tensor.dim"(%145, %2) : (tensor<?xf32>, index) -> index
%148 = "linalg.fill"(%3, %124) ({
^bb0(%arg1: i64, %arg2: i64):
"linalg.yield"(%arg1) : (i64) -> ()
}) : (i64, tensor<i64>) -> tensor<i64>
%149 = "tensor.extract"(%148) : (tensor<i64>) -> i64
%150 = "arith.addi"(%149, %3) : (i64, i64) -> i64
%151 = "arith.cmpi"(%149, %4) {predicate = 5 : i64} : (i64, i64) -> i1
%152 = "std.select"(%151, %149, %150) : (i1, i64, i64) -> i64
%153 = "arith.cmpi"(%152, %4) {predicate = 2 : i64} : (i64, i64) -> i1
%154 = "std.select"(%153, %4, %152) : (i1, i64, i64) -> i64
%155 = "arith.cmpi"(%154, %3) {predicate = 4 : i64} : (i64, i64) -> i1
%156 = "std.select"(%155, %3, %154) : (i1, i64, i64) -> i64
%157 = "arith.index_cast"(%156) : (i64) -> index
%158 = "arith.cmpi"(%157, %2) {predicate = 5 : i64} : (index, index) -> i1
%159 = "std.select"(%158, %157, %2) : (i1, index, index) -> index
%160 = "tensor.extract_slice"(%22, %159) {operand_segment_sizes = dense<[1, 0, 1, 0]> : vector<4xi32>, static_offsets = [0, 0], static_sizes = [1, -1], static_strides = [1, 1]} : (tensor<1x512xi64>, index) -> tensor<?xi64>
%161 = "flow.tensor.reshape"(%123) {operand_segment_sizes = dense<[1, 0, 0]> : vector<3xi32>} : (tensor<1x512xi64>) -> tensor<512xi64>
%162 = "arith.cmpi"(%5, %159) {predicate = 0 : i64} : (index, index) -> i1
"std.assert"(%162) {msg = "mismatched size for broadcast"} : (i1) -> ()
%163 = "linalg.init_tensor"() {static_sizes = [512, 384]} : () -> tensor<512x384xf32>
%164 = "linalg.generic"(%161, %160, %163) ({
^bb0(%arg1: i64, %arg2: i64, %arg3: f32):
%600 = "linalg.index"() {dim = 1 : i64} : () -> index
%601 = "tensor.extract"(%24, %2, %600) : (tensor<2x384xf32>, index, index) -> f32
%602 = "arith.index_cast"(%arg1) : (i64) -> index
%603 = "arith.cmpi"(%arg1, %14) {predicate = 2 : i64} : (i64, i64) -> i1
"std.assert"(%603) {msg = "index must be smaller than dim size"} : (i1) -> ()
%604 = "arith.cmpi"(%arg1, %4) {predicate = 5 : i64} : (i64, i64) -> i1
"std.assert"(%604) {msg = "index must be larger or equal to 0"} : (i1) -> ()
%605 = "tensor.extract"(%23, %602, %600) : (tensor<30522x384xf32>, index, index) -> f32
%606 = "arith.index_cast"(%arg2) : (i64) -> index
%607 = "arith.cmpi"(%arg2, %3) {predicate = 2 : i64} : (i64, i64) -> i1
"std.assert"(%607) {msg = "index must be smaller than dim size"} : (i1) -> ()
%608 = "arith.cmpi"(%arg2, %4) {predicate = 5 : i64} : (i64, i64) -> i1
"std.assert"(%608) {msg = "index must be larger or equal to 0"} : (i1) -> ()
%609 = "tensor.extract"(%25, %606, %600) : (tensor<512x384xf32>, index, index) -> f32
%610 = "arith.addf"(%605, %601) : (f32, f32) -> f32
%611 = "arith.addf"(%610, %609) : (f32, f32) -> f32
"linalg.yield"(%611) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<512xi64>, tensor<?xi64>, tensor<512x384xf32>) -> tensor<512x384xf32>
%165 = "linalg.init_tensor"() {static_sizes = [512]} : () -> tensor<512xf32>
%166 = "linalg.fill"(%15, %165) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<512xf32>) -> tensor<512xf32>
%167 = "linalg.generic"(%164, %166) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "arith.addf"(%arg2, %arg1) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<512x384xf32>, tensor<512xf32>) -> tensor<512xf32>
%168 = "linalg.generic"(%167, %165) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "arith.divf"(%arg1, %122) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<512xf32>, tensor<512xf32>) -> tensor<512xf32>
%169 = "linalg.fill"(%15, %165) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<512xf32>) -> tensor<512xf32>
%170 = "linalg.generic"(%164, %168, %169) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.subf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.mulf"(%600, %600) : (f32, f32) -> f32
%602 = "arith.addf"(%arg3, %601) : (f32, f32) -> f32
"linalg.yield"(%602) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<512x384xf32>, tensor<512xf32>, tensor<512xf32>) -> tensor<512xf32>
%171 = "linalg.generic"(%164, %168, %170, %27, %26, %163) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%600 = "arith.divf"(%arg3, %122) : (f32, f32) -> f32
%601 = "arith.subf"(%arg1, %arg2) : (f32, f32) -> f32
%602 = "arith.truncf"(%21) : (f64) -> f32
%603 = "arith.addf"(%600, %602) : (f32, f32) -> f32
%604 = "math.rsqrt"(%603) : (f32) -> f32
%605 = "arith.mulf"(%601, %604) : (f32, f32) -> f32
%606 = "arith.mulf"(%605, %arg4) : (f32, f32) -> f32
%607 = "arith.addf"(%606, %arg5) : (f32, f32) -> f32
"linalg.yield"(%607) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"], operand_segment_sizes = dense<[5, 1]> : vector<2xi32>} : (tensor<512x384xf32>, tensor<512xf32>, tensor<512xf32>, tensor<384xf32>, tensor<384xf32>, tensor<512x384xf32>) -> tensor<512x384xf32>
%172 = "flow.tensor.reshape"(%171) {operand_segment_sizes = dense<[1, 0, 0]> : vector<3xi32>} : (tensor<512x384xf32>) -> tensor<1x512x384xf32>
%173 = "linalg.generic"(%28, %163) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384xf32>, tensor<512x384xf32>) -> tensor<512x384xf32>
%174 = "flow.tensor.reshape"(%173) {operand_segment_sizes = dense<[1, 0, 0]> : vector<3xi32>} : (tensor<512x384xf32>) -> tensor<1x512x384xf32>
%175 = "linalg.batch_matmul"(%172, %12, %174) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x384x384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%176 = "linalg.generic"(%29, %163) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384xf32>, tensor<512x384xf32>) -> tensor<512x384xf32>
%177 = "flow.tensor.reshape"(%176) {operand_segment_sizes = dense<[1, 0, 0]> : vector<3xi32>} : (tensor<512x384xf32>) -> tensor<1x512x384xf32>
%178 = "linalg.batch_matmul"(%172, %11, %177) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x384x384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%179 = "flow.tensor.reshape"(%178, %0, %5) {operand_segment_sizes = dense<[1, 0, 2]> : vector<3xi32>} : (tensor<1x512x384xf32>, index, index) -> tensor<?x?x12x32xf32>
%180 = "linalg.generic"(%30, %163) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384xf32>, tensor<512x384xf32>) -> tensor<512x384xf32>
%181 = "flow.tensor.reshape"(%180) {operand_segment_sizes = dense<[1, 0, 0]> : vector<3xi32>} : (tensor<512x384xf32>) -> tensor<1x512x384xf32>
%182 = "linalg.batch_matmul"(%172, %10, %181) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x384x384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%183 = "flow.tensor.reshape"(%182, %0, %5) {operand_segment_sizes = dense<[1, 0, 2]> : vector<3xi32>} : (tensor<1x512x384xf32>, index, index) -> tensor<?x?x12x32xf32>
%184 = "flow.tensor.reshape"(%175, %0, %5) {operand_segment_sizes = dense<[1, 0, 2]> : vector<3xi32>} : (tensor<1x512x384xf32>, index, index) -> tensor<?x?x12x32xf32>
%185 = "linalg.init_tensor"() {static_sizes = [12, 32, 512]} : () -> tensor<12x32x512xf32>
%186 = "linalg.generic"(%179, %185) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d1, d3, d2)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<?x?x12x32xf32>, tensor<12x32x512xf32>) -> tensor<12x32x512xf32>
%187 = "linalg.init_tensor"() {static_sizes = [12, 512, 32]} : () -> tensor<12x512x32xf32>
%188 = "linalg.generic"(%184, %187) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<?x?x12x32xf32>, tensor<12x512x32xf32>) -> tensor<12x512x32xf32>
%189 = "linalg.init_tensor"() {static_sizes = [12, 512, 512]} : () -> tensor<12x512x512xf32>
%190 = "linalg.fill"(%15, %189) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<12x512x512xf32>) -> tensor<12x512x512xf32>
%191 = "linalg.generic"(%188, %186, %190) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%600, %arg3) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<12x512x32xf32>, tensor<12x32x512xf32>, tensor<12x512x512xf32>) -> tensor<12x512x512xf32>
%192 = "arith.cmpi"(%5, %142) {predicate = 0 : i64} : (index, index) -> i1
"std.assert"(%192) {msg = "mismatched size for broadcast"} : (i1) -> ()
%193 = "linalg.generic"(%191, %145, %189) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.truncf"(%9) : (f64) -> f32
%601 = "arith.divf"(%arg1, %600) : (f32, f32) -> f32
%602 = "arith.addf"(%601, %arg2) : (f32, f32) -> f32
"linalg.yield"(%602) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<12x512x512xf32>, tensor<?xf32>, tensor<12x512x512xf32>) -> tensor<12x512x512xf32>
%194 = "linalg.init_tensor"() {static_sizes = [12, 512]} : () -> tensor<12x512xf32>
%195 = "linalg.fill"(%16, %194) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<12x512xf32>) -> tensor<12x512xf32>
%196 = "linalg.init_tensor"() {static_sizes = [12, 512]} : () -> tensor<12x512xi64>
%197 = "linalg.fill"(%4, %196) ({
^bb0(%arg1: i64, %arg2: i64):
"linalg.yield"(%arg1) : (i64) -> ()
}) : (i64, tensor<12x512xi64>) -> tensor<12x512xi64>
%198:2 = "linalg.generic"(%193, %195, %197) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: i64):
%600 = "linalg.index"() {dim = 2 : i64} : () -> index
%601 = "arith.index_cast"(%600) : (index) -> i64
%602 = "arith.cmpf"(%arg1, %arg2) {predicate = 2 : i64} : (f32, f32) -> i1
%603 = "std.select"(%602, %arg1, %arg2) : (i1, f32, f32) -> f32
%604 = "std.select"(%602, %601, %arg3) : (i1, i64, i64) -> i64
"linalg.yield"(%603, %604) : (f32, i64) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], operand_segment_sizes = dense<[1, 2]> : vector<2xi32>} : (tensor<12x512x512xf32>, tensor<12x512xf32>, tensor<12x512xi64>) -> (tensor<12x512xf32>, tensor<12x512xi64>)
%199 = "linalg.generic"(%193, %198#0, %189) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.subf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "math.exp"(%600) : (f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<12x512x512xf32>, tensor<12x512xf32>, tensor<12x512x512xf32>) -> tensor<12x512x512xf32>
%200 = "linalg.fill"(%15, %194) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<12x512xf32>) -> tensor<12x512xf32>
%201 = "linalg.generic"(%199, %200) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "arith.addf"(%arg1, %arg2) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<12x512x512xf32>, tensor<12x512xf32>) -> tensor<12x512xf32>
%202 = "linalg.generic"(%199, %201, %189) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.divf"(%arg1, %arg2) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<12x512x512xf32>, tensor<12x512xf32>, tensor<12x512x512xf32>) -> tensor<12x512x512xf32>
%203 = "linalg.generic"(%183, %187) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<?x?x12x32xf32>, tensor<12x512x32xf32>) -> tensor<12x512x32xf32>
%204 = "linalg.fill"(%15, %187) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<12x512x32xf32>) -> tensor<12x512x32xf32>
%205 = "linalg.generic"(%202, %203, %204) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%600, %arg3) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<12x512x512xf32>, tensor<12x512x32xf32>, tensor<12x512x32xf32>) -> tensor<12x512x32xf32>
%206 = "linalg.init_tensor"() {static_sizes = [512, 12, 32]} : () -> tensor<512x12x32xf32>
%207 = "linalg.generic"(%205, %206) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d1, d0, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<12x512x32xf32>, tensor<512x12x32xf32>) -> tensor<512x12x32xf32>
%208 = "flow.tensor.reshape"(%207, %0, %5) {operand_segment_sizes = dense<[1, 0, 2]> : vector<3xi32>} : (tensor<512x12x32xf32>, index, index) -> tensor<?x?x384xf32>
%209 = "linalg.generic"(%31, %163) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384xf32>, tensor<512x384xf32>) -> tensor<512x384xf32>
%210 = "flow.tensor.reshape"(%209) {operand_segment_sizes = dense<[1, 0, 0]> : vector<3xi32>} : (tensor<512x384xf32>) -> tensor<1x512x384xf32>
%211 = "linalg.batch_matmul"(%208, %8, %210) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<?x?x384xf32>, tensor<1x384x384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%212 = "linalg.init_tensor"() {static_sizes = [1, 512, 384]} : () -> tensor<1x512x384xf32>
%213 = "flow.tensor.reshape"(%211) {operand_segment_sizes = dense<[1, 0, 0]> : vector<3xi32>} : (tensor<1x512x384xf32>) -> tensor<512x384xf32>
%214 = "linalg.generic"(%213, %171, %212) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.addf"(%arg1, %arg2) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (0, d0, d1)>], iterator_types = ["parallel", "parallel"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<512x384xf32>, tensor<512x384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%215 = "linalg.init_tensor"() {static_sizes = [1, 512]} : () -> tensor<1x512xf32>
%216 = "linalg.fill"(%15, %215) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<1x512xf32>) -> tensor<1x512xf32>
%217 = "linalg.generic"(%214, %216) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "arith.addf"(%arg2, %arg1) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x512xf32>) -> tensor<1x512xf32>
%218 = "linalg.generic"(%217, %215) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "arith.divf"(%arg1, %122) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<1x512xf32>, tensor<1x512xf32>) -> tensor<1x512xf32>
%219 = "linalg.fill"(%15, %215) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<1x512xf32>) -> tensor<1x512xf32>
%220 = "linalg.generic"(%214, %218, %219) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.subf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.mulf"(%600, %600) : (f32, f32) -> f32
%602 = "arith.addf"(%arg3, %601) : (f32, f32) -> f32
"linalg.yield"(%602) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x512xf32>, tensor<1x512xf32>) -> tensor<1x512xf32>
%221 = "linalg.generic"(%214, %218, %220, %33, %32, %212) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%600 = "arith.divf"(%arg3, %122) : (f32, f32) -> f32
%601 = "arith.subf"(%arg1, %arg2) : (f32, f32) -> f32
%602 = "arith.truncf"(%21) : (f64) -> f32
%603 = "arith.addf"(%600, %602) : (f32, f32) -> f32
%604 = "math.rsqrt"(%603) : (f32) -> f32
%605 = "arith.mulf"(%601, %604) : (f32, f32) -> f32
%606 = "arith.mulf"(%605, %arg4) : (f32, f32) -> f32
%607 = "arith.addf"(%606, %arg5) : (f32, f32) -> f32
"linalg.yield"(%607) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<[5, 1]> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x512xf32>, tensor<1x512xf32>, tensor<384xf32>, tensor<384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%222 = "linalg.init_tensor"() {static_sizes = [1, 512, 1536]} : () -> tensor<1x512x1536xf32>
%223 = "linalg.init_tensor"() {static_sizes = [1, 384, 1536]} : () -> tensor<1x384x1536xf32>
%224 = "linalg.generic"(%34, %222) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<1536xf32>, tensor<1x512x1536xf32>) -> tensor<1x512x1536xf32>
%225 = "linalg.generic"(%35, %223) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<1536x384xf32>, tensor<1x384x1536xf32>) -> tensor<1x384x1536xf32>
%226 = "linalg.batch_matmul"(%221, %225, %224) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x384x1536xf32>, tensor<1x512x1536xf32>) -> tensor<1x512x1536xf32>
%227 = "linalg.generic"(%226, %222) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "math.sqrt"(%18) : (f32) -> f32
%601 = "arith.divf"(%arg1, %600) : (f32, f32) -> f32
%602 = "math.erf"(%601) : (f32) -> f32
%603 = "arith.addf"(%602, %19) : (f32, f32) -> f32
%604 = "arith.mulf"(%603, %17) : (f32, f32) -> f32
%605 = "arith.mulf"(%arg1, %604) : (f32, f32) -> f32
"linalg.yield"(%605) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<1x512x1536xf32>, tensor<1x512x1536xf32>) -> tensor<1x512x1536xf32>
%228 = "linalg.init_tensor"() {static_sizes = [1, 1536, 384]} : () -> tensor<1x1536x384xf32>
%229 = "linalg.generic"(%36, %212) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%230 = "linalg.generic"(%37, %228) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384x1536xf32>, tensor<1x1536x384xf32>) -> tensor<1x1536x384xf32>
%231 = "linalg.batch_matmul"(%227, %230, %229) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x512x1536xf32>, tensor<1x1536x384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%232 = "linalg.generic"(%231, %221, %212) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.addf"(%arg1, %arg2) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x512x384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%233 = "linalg.fill"(%15, %215) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<1x512xf32>) -> tensor<1x512xf32>
%234 = "linalg.generic"(%232, %233) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "arith.addf"(%arg2, %arg1) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x512xf32>) -> tensor<1x512xf32>
%235 = "linalg.generic"(%234, %215) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "arith.divf"(%arg1, %122) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<1x512xf32>, tensor<1x512xf32>) -> tensor<1x512xf32>
%236 = "linalg.fill"(%15, %215) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<1x512xf32>) -> tensor<1x512xf32>
%237 = "linalg.generic"(%232, %235, %236) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.subf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.mulf"(%600, %600) : (f32, f32) -> f32
%602 = "arith.addf"(%arg3, %601) : (f32, f32) -> f32
"linalg.yield"(%602) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x512xf32>, tensor<1x512xf32>) -> tensor<1x512xf32>
%238 = "linalg.generic"(%232, %235, %237, %39, %38, %212) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%600 = "arith.divf"(%arg3, %122) : (f32, f32) -> f32
%601 = "arith.subf"(%arg1, %arg2) : (f32, f32) -> f32
%602 = "arith.truncf"(%21) : (f64) -> f32
%603 = "arith.addf"(%600, %602) : (f32, f32) -> f32
%604 = "math.rsqrt"(%603) : (f32) -> f32
%605 = "arith.mulf"(%601, %604) : (f32, f32) -> f32
%606 = "arith.mulf"(%605, %arg4) : (f32, f32) -> f32
%607 = "arith.addf"(%606, %arg5) : (f32, f32) -> f32
"linalg.yield"(%607) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<[5, 1]> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x512xf32>, tensor<1x512xf32>, tensor<384xf32>, tensor<384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%239 = "linalg.init_tensor"() {static_sizes = [1, 384, 384]} : () -> tensor<1x384x384xf32>
%240 = "linalg.generic"(%40, %212) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%241 = "linalg.generic"(%41, %239) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384x384xf32>, tensor<1x384x384xf32>) -> tensor<1x384x384xf32>
%242 = "linalg.batch_matmul"(%238, %241, %240) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x384x384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%243 = "tensor.cast"(%242) : (tensor<1x512x384xf32>) -> tensor<?x?x384xf32>
%244 = "linalg.generic"(%42, %212) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%245 = "linalg.generic"(%43, %239) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384x384xf32>, tensor<1x384x384xf32>) -> tensor<1x384x384xf32>
%246 = "linalg.batch_matmul"(%238, %245, %244) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x384x384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%247 = "tensor.cast"(%246) : (tensor<1x512x384xf32>) -> tensor<?x?x384xf32>
%248 = "flow.tensor.reshape"(%247, %0, %5, %0, %5) {operand_segment_sizes = dense<[1, 2, 2]> : vector<3xi32>} : (tensor<?x?x384xf32>, index, index, index, index) -> tensor<?x?x12x32xf32>
%249 = "linalg.init_tensor"() {static_sizes = [1, 12, 512, 32]} : () -> tensor<1x12x512x32xf32>
%250 = "linalg.generic"(%44, %212) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%251 = "linalg.generic"(%45, %239) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384x384xf32>, tensor<1x384x384xf32>) -> tensor<1x384x384xf32>
%252 = "linalg.batch_matmul"(%238, %251, %250) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x384x384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%253 = "tensor.cast"(%252) : (tensor<1x512x384xf32>) -> tensor<?x?x384xf32>
%254 = "flow.tensor.reshape"(%253, %0, %5, %0, %5) {operand_segment_sizes = dense<[1, 2, 2]> : vector<3xi32>} : (tensor<?x?x384xf32>, index, index, index, index) -> tensor<?x?x12x32xf32>
%255 = "linalg.generic"(%254, %249) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<?x?x12x32xf32>, tensor<1x12x512x32xf32>) -> tensor<1x12x512x32xf32>
%256 = "flow.tensor.reshape"(%243, %0, %5, %0, %5) {operand_segment_sizes = dense<[1, 2, 2]> : vector<3xi32>} : (tensor<?x?x384xf32>, index, index, index, index) -> tensor<?x?x12x32xf32>
%257 = "linalg.generic"(%256, %249) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<?x?x12x32xf32>, tensor<1x12x512x32xf32>) -> tensor<1x12x512x32xf32>
%258 = "linalg.init_tensor"() {static_sizes = [1, 12, 32, 512]} : () -> tensor<1x12x32x512xf32>
%259 = "linalg.generic"(%248, %258) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d2)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<?x?x12x32xf32>, tensor<1x12x32x512xf32>) -> tensor<1x12x32x512xf32>
%260 = "linalg.init_tensor"() {static_sizes = [1, 12, 512, 512]} : () -> tensor<1x12x512x512xf32>
%261 = "linalg.fill"(%15, %260) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<1x12x512x512xf32>) -> tensor<1x12x512x512xf32>
%262 = "linalg.generic"(%257, %259, %261) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%600, %arg3) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x12x512x32xf32>, tensor<1x12x32x512xf32>, tensor<1x12x512x512xf32>) -> tensor<1x12x512x512xf32>
"std.assert"(%192) {msg = "mismatched size for broadcast"} : (i1) -> ()
%263 = "flow.tensor.reshape"(%145, %147, %146) {operand_segment_sizes = dense<1> : vector<3xi32>} : (tensor<?xf32>, index, index) -> tensor<1x?xf32>
%264 = "linalg.generic"(%262, %263, %260) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.truncf"(%9) : (f64) -> f32
%601 = "arith.divf"(%arg1, %600) : (f32, f32) -> f32
%602 = "arith.addf"(%601, %arg2) : (f32, f32) -> f32
"linalg.yield"(%602) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x12x512x512xf32>, tensor<1x?xf32>, tensor<1x12x512x512xf32>) -> tensor<1x12x512x512xf32>
%265 = "linalg.init_tensor"() {static_sizes = [1, 12, 512]} : () -> tensor<1x12x512xf32>
%266 = "linalg.fill"(%16, %265) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<1x12x512xf32>) -> tensor<1x12x512xf32>
%267 = "linalg.init_tensor"() {static_sizes = [1, 12, 512]} : () -> tensor<1x12x512xi64>
%268 = "linalg.fill"(%4, %267) ({
^bb0(%arg1: i64, %arg2: i64):
"linalg.yield"(%arg1) : (i64) -> ()
}) : (i64, tensor<1x12x512xi64>) -> tensor<1x12x512xi64>
%269:2 = "linalg.generic"(%264, %266, %268) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: i64):
%600 = "linalg.index"() {dim = 3 : i64} : () -> index
%601 = "arith.index_cast"(%600) : (index) -> i64
%602 = "arith.cmpf"(%arg1, %arg2) {predicate = 2 : i64} : (f32, f32) -> i1
%603 = "std.select"(%602, %arg1, %arg2) : (i1, f32, f32) -> f32
%604 = "std.select"(%602, %601, %arg3) : (i1, i64, i64) -> i64
"linalg.yield"(%603, %604) : (f32, i64) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"], operand_segment_sizes = dense<[1, 2]> : vector<2xi32>} : (tensor<1x12x512x512xf32>, tensor<1x12x512xf32>, tensor<1x12x512xi64>) -> (tensor<1x12x512xf32>, tensor<1x12x512xi64>)
%270 = "linalg.generic"(%264, %269#0, %260) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.subf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "math.exp"(%600) : (f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x12x512x512xf32>, tensor<1x12x512xf32>, tensor<1x12x512x512xf32>) -> tensor<1x12x512x512xf32>
%271 = "linalg.fill"(%15, %265) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<1x12x512xf32>) -> tensor<1x12x512xf32>
%272 = "linalg.generic"(%270, %271) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "arith.addf"(%arg1, %arg2) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<1x12x512x512xf32>, tensor<1x12x512xf32>) -> tensor<1x12x512xf32>
%273 = "linalg.generic"(%270, %272, %260) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.divf"(%arg1, %arg2) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x12x512x512xf32>, tensor<1x12x512xf32>, tensor<1x12x512x512xf32>) -> tensor<1x12x512x512xf32>
%274 = "linalg.fill"(%15, %249) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<1x12x512x32xf32>) -> tensor<1x12x512x32xf32>
%275 = "linalg.generic"(%273, %255, %274) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%600, %arg3) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x12x512x512xf32>, tensor<1x12x512x32xf32>, tensor<1x12x512x32xf32>) -> tensor<1x12x512x32xf32>
%276 = "linalg.init_tensor"() {static_sizes = [1, 512, 12, 32]} : () -> tensor<1x512x12x32xf32>
%277 = "linalg.generic"(%275, %276) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<1x12x512x32xf32>, tensor<1x512x12x32xf32>) -> tensor<1x512x12x32xf32>
%278 = "tensor.cast"(%277) : (tensor<1x512x12x32xf32>) -> tensor<?x?x12x32xf32>
%279 = "flow.tensor.reshape"(%278, %0, %5, %0, %5) {operand_segment_sizes = dense<[1, 2, 2]> : vector<3xi32>} : (tensor<?x?x12x32xf32>, index, index, index, index) -> tensor<?x?x384xf32>
%280 = "linalg.generic"(%46, %212) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%281 = "linalg.generic"(%47, %239) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384x384xf32>, tensor<1x384x384xf32>) -> tensor<1x384x384xf32>
%282 = "linalg.batch_matmul"(%279, %281, %280) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<?x?x384xf32>, tensor<1x384x384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
"std.assert"(%1) {msg = "mismatched size for broadcast"} : (i1) -> ()
%283 = "linalg.init_tensor"() {static_sizes = [0, 512, 384]} : () -> tensor<0x512x384xf32>
%284 = "linalg.generic"(%282, %238, %283) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.addf"(%arg1, %arg2) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x512x384xf32>, tensor<0x512x384xf32>) -> tensor<0x512x384xf32>
%285 = "linalg.init_tensor"() {static_sizes = [0, 512]} : () -> tensor<0x512xf32>
%286 = "linalg.fill"(%15, %285) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<0x512xf32>) -> tensor<0x512xf32>
%287 = "linalg.generic"(%284, %286) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "arith.addf"(%arg2, %arg1) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<0x512x384xf32>, tensor<0x512xf32>) -> tensor<0x512xf32>
%288 = "linalg.generic"(%287, %285) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "arith.divf"(%arg1, %122) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<0x512xf32>, tensor<0x512xf32>) -> tensor<0x512xf32>
%289 = "linalg.fill"(%15, %285) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<0x512xf32>) -> tensor<0x512xf32>
%290 = "linalg.generic"(%284, %288, %289) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.subf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.mulf"(%600, %600) : (f32, f32) -> f32
%602 = "arith.addf"(%arg3, %601) : (f32, f32) -> f32
"linalg.yield"(%602) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<0x512x384xf32>, tensor<0x512xf32>, tensor<0x512xf32>) -> tensor<0x512xf32>
%291 = "linalg.generic"(%284, %288, %290, %49, %48, %283) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%600 = "arith.divf"(%arg3, %122) : (f32, f32) -> f32
%601 = "arith.subf"(%arg1, %arg2) : (f32, f32) -> f32
%602 = "arith.truncf"(%21) : (f64) -> f32
%603 = "arith.addf"(%600, %602) : (f32, f32) -> f32
%604 = "math.rsqrt"(%603) : (f32) -> f32
%605 = "arith.mulf"(%601, %604) : (f32, f32) -> f32
%606 = "arith.mulf"(%605, %arg4) : (f32, f32) -> f32
%607 = "arith.addf"(%606, %arg5) : (f32, f32) -> f32
"linalg.yield"(%607) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<[5, 1]> : vector<2xi32>} : (tensor<0x512x384xf32>, tensor<0x512xf32>, tensor<0x512xf32>, tensor<384xf32>, tensor<384xf32>, tensor<0x512x384xf32>) -> tensor<0x512x384xf32>
%292 = "linalg.init_tensor"() {static_sizes = [0, 512, 1536]} : () -> tensor<0x512x1536xf32>
%293 = "linalg.init_tensor"() {static_sizes = [0, 384, 1536]} : () -> tensor<0x384x1536xf32>
%294 = "linalg.generic"(%50, %292) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<1536xf32>, tensor<0x512x1536xf32>) -> tensor<0x512x1536xf32>
%295 = "linalg.generic"(%51, %293) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<1536x384xf32>, tensor<0x384x1536xf32>) -> tensor<0x384x1536xf32>
%296 = "linalg.batch_matmul"(%291, %295, %294) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<0x512x384xf32>, tensor<0x384x1536xf32>, tensor<0x512x1536xf32>) -> tensor<0x512x1536xf32>
%297 = "linalg.generic"(%296, %292) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "math.sqrt"(%18) : (f32) -> f32
%601 = "arith.divf"(%arg1, %600) : (f32, f32) -> f32
%602 = "math.erf"(%601) : (f32) -> f32
%603 = "arith.addf"(%602, %19) : (f32, f32) -> f32
%604 = "arith.mulf"(%603, %17) : (f32, f32) -> f32
%605 = "arith.mulf"(%arg1, %604) : (f32, f32) -> f32
"linalg.yield"(%605) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<0x512x1536xf32>, tensor<0x512x1536xf32>) -> tensor<0x512x1536xf32>
%298 = "linalg.init_tensor"() {static_sizes = [0, 1536, 384]} : () -> tensor<0x1536x384xf32>
%299 = "linalg.generic"(%52, %283) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384xf32>, tensor<0x512x384xf32>) -> tensor<0x512x384xf32>
%300 = "linalg.generic"(%53, %298) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384x1536xf32>, tensor<0x1536x384xf32>) -> tensor<0x1536x384xf32>
%301 = "linalg.batch_matmul"(%297, %300, %299) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<0x512x1536xf32>, tensor<0x1536x384xf32>, tensor<0x512x384xf32>) -> tensor<0x512x384xf32>
%302 = "linalg.generic"(%301, %291, %283) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.addf"(%arg1, %arg2) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<0x512x384xf32>, tensor<0x512x384xf32>, tensor<0x512x384xf32>) -> tensor<0x512x384xf32>
%303 = "linalg.fill"(%15, %285) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<0x512xf32>) -> tensor<0x512xf32>
%304 = "linalg.generic"(%302, %303) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "arith.addf"(%arg2, %arg1) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<0x512x384xf32>, tensor<0x512xf32>) -> tensor<0x512xf32>
%305 = "linalg.generic"(%304, %285) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "arith.divf"(%arg1, %122) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<0x512xf32>, tensor<0x512xf32>) -> tensor<0x512xf32>
%306 = "linalg.fill"(%15, %285) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<0x512xf32>) -> tensor<0x512xf32>
%307 = "linalg.generic"(%302, %305, %306) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.subf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.mulf"(%600, %600) : (f32, f32) -> f32
%602 = "arith.addf"(%arg3, %601) : (f32, f32) -> f32
"linalg.yield"(%602) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<0x512x384xf32>, tensor<0x512xf32>, tensor<0x512xf32>) -> tensor<0x512xf32>
%308 = "linalg.generic"(%302, %305, %307, %55, %54, %283) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%600 = "arith.divf"(%arg3, %122) : (f32, f32) -> f32
%601 = "arith.subf"(%arg1, %arg2) : (f32, f32) -> f32
%602 = "arith.truncf"(%21) : (f64) -> f32
%603 = "arith.addf"(%600, %602) : (f32, f32) -> f32
%604 = "math.rsqrt"(%603) : (f32) -> f32
%605 = "arith.mulf"(%601, %604) : (f32, f32) -> f32
%606 = "arith.mulf"(%605, %arg4) : (f32, f32) -> f32
%607 = "arith.addf"(%606, %arg5) : (f32, f32) -> f32
"linalg.yield"(%607) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<[5, 1]> : vector<2xi32>} : (tensor<0x512x384xf32>, tensor<0x512xf32>, tensor<0x512xf32>, tensor<384xf32>, tensor<384xf32>, tensor<0x512x384xf32>) -> tensor<0x512x384xf32>
%309 = "linalg.init_tensor"() {static_sizes = [0, 384, 384]} : () -> tensor<0x384x384xf32>
%310 = "linalg.generic"(%56, %283) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384xf32>, tensor<0x512x384xf32>) -> tensor<0x512x384xf32>
%311 = "linalg.generic"(%57, %309) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384x384xf32>, tensor<0x384x384xf32>) -> tensor<0x384x384xf32>
%312 = "linalg.batch_matmul"(%308, %311, %310) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<0x512x384xf32>, tensor<0x384x384xf32>, tensor<0x512x384xf32>) -> tensor<0x512x384xf32>
%313 = "tensor.cast"(%312) : (tensor<0x512x384xf32>) -> tensor<?x?x384xf32>
%314 = "linalg.generic"(%58, %283) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384xf32>, tensor<0x512x384xf32>) -> tensor<0x512x384xf32>
%315 = "linalg.generic"(%59, %309) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384x384xf32>, tensor<0x384x384xf32>) -> tensor<0x384x384xf32>
%316 = "linalg.batch_matmul"(%308, %315, %314) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<0x512x384xf32>, tensor<0x384x384xf32>, tensor<0x512x384xf32>) -> tensor<0x512x384xf32>
%317 = "tensor.cast"(%316) : (tensor<0x512x384xf32>) -> tensor<?x?x384xf32>
%318 = "flow.tensor.reshape"(%317, %2, %5, %2, %5) {operand_segment_sizes = dense<[1, 2, 2]> : vector<3xi32>} : (tensor<?x?x384xf32>, index, index, index, index) -> tensor<?x?x12x32xf32>
%319 = "linalg.init_tensor"() {static_sizes = [0, 12, 512, 32]} : () -> tensor<0x12x512x32xf32>
%320 = "linalg.generic"(%60, %283) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384xf32>, tensor<0x512x384xf32>) -> tensor<0x512x384xf32>
%321 = "linalg.generic"(%61, %309) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384x384xf32>, tensor<0x384x384xf32>) -> tensor<0x384x384xf32>
%322 = "linalg.batch_matmul"(%308, %321, %320) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<0x512x384xf32>, tensor<0x384x384xf32>, tensor<0x512x384xf32>) -> tensor<0x512x384xf32>
%323 = "tensor.cast"(%322) : (tensor<0x512x384xf32>) -> tensor<?x?x384xf32>
%324 = "flow.tensor.reshape"(%323, %2, %5, %2, %5) {operand_segment_sizes = dense<[1, 2, 2]> : vector<3xi32>} : (tensor<?x?x384xf32>, index, index, index, index) -> tensor<?x?x12x32xf32>
%325 = "linalg.generic"(%324, %319) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<?x?x12x32xf32>, tensor<0x12x512x32xf32>) -> tensor<0x12x512x32xf32>
%326 = "flow.tensor.reshape"(%313, %2, %5, %2, %5) {operand_segment_sizes = dense<[1, 2, 2]> : vector<3xi32>} : (tensor<?x?x384xf32>, index, index, index, index) -> tensor<?x?x12x32xf32>
%327 = "linalg.generic"(%326, %319) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<?x?x12x32xf32>, tensor<0x12x512x32xf32>) -> tensor<0x12x512x32xf32>
%328 = "linalg.init_tensor"() {static_sizes = [0, 12, 32, 512]} : () -> tensor<0x12x32x512xf32>
%329 = "linalg.generic"(%318, %328) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d2)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<?x?x12x32xf32>, tensor<0x12x32x512xf32>) -> tensor<0x12x32x512xf32>
%330 = "linalg.init_tensor"() {static_sizes = [0, 12, 512, 512]} : () -> tensor<0x12x512x512xf32>
%331 = "linalg.fill"(%15, %330) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<0x12x512x512xf32>) -> tensor<0x12x512x512xf32>
%332 = "linalg.generic"(%327, %329, %331) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%600, %arg3) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<0x12x512x32xf32>, tensor<0x12x32x512xf32>, tensor<0x12x512x512xf32>) -> tensor<0x12x512x512xf32>
"std.assert"(%1) {msg = "mismatched size for broadcast"} : (i1) -> ()
"std.assert"(%192) {msg = "mismatched size for broadcast"} : (i1) -> ()
%333 = "linalg.generic"(%332, %263, %330) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.truncf"(%9) : (f64) -> f32
%601 = "arith.divf"(%arg1, %600) : (f32, f32) -> f32
%602 = "arith.addf"(%601, %arg2) : (f32, f32) -> f32
"linalg.yield"(%602) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<0x12x512x512xf32>, tensor<1x?xf32>, tensor<0x12x512x512xf32>) -> tensor<0x12x512x512xf32>
%334 = "linalg.init_tensor"() {static_sizes = [0, 12, 512]} : () -> tensor<0x12x512xf32>
%335 = "linalg.fill"(%16, %334) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<0x12x512xf32>) -> tensor<0x12x512xf32>
%336 = "linalg.init_tensor"() {static_sizes = [0, 12, 512]} : () -> tensor<0x12x512xi64>
%337 = "linalg.fill"(%4, %336) ({
^bb0(%arg1: i64, %arg2: i64):
"linalg.yield"(%arg1) : (i64) -> ()
}) : (i64, tensor<0x12x512xi64>) -> tensor<0x12x512xi64>
%338:2 = "linalg.generic"(%333, %335, %337) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: i64):
%600 = "linalg.index"() {dim = 3 : i64} : () -> index
%601 = "arith.index_cast"(%600) : (index) -> i64
%602 = "arith.cmpf"(%arg1, %arg2) {predicate = 2 : i64} : (f32, f32) -> i1
%603 = "std.select"(%602, %arg1, %arg2) : (i1, f32, f32) -> f32
%604 = "std.select"(%602, %601, %arg3) : (i1, i64, i64) -> i64
"linalg.yield"(%603, %604) : (f32, i64) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"], operand_segment_sizes = dense<[1, 2]> : vector<2xi32>} : (tensor<0x12x512x512xf32>, tensor<0x12x512xf32>, tensor<0x12x512xi64>) -> (tensor<0x12x512xf32>, tensor<0x12x512xi64>)
%339 = "linalg.generic"(%333, %338#0, %330) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.subf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "math.exp"(%600) : (f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<0x12x512x512xf32>, tensor<0x12x512xf32>, tensor<0x12x512x512xf32>) -> tensor<0x12x512x512xf32>
%340 = "linalg.fill"(%15, %334) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<0x12x512xf32>) -> tensor<0x12x512xf32>
%341 = "linalg.generic"(%339, %340) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "arith.addf"(%arg1, %arg2) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<0x12x512x512xf32>, tensor<0x12x512xf32>) -> tensor<0x12x512xf32>
%342 = "linalg.generic"(%339, %341, %330) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.divf"(%arg1, %arg2) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<0x12x512x512xf32>, tensor<0x12x512xf32>, tensor<0x12x512x512xf32>) -> tensor<0x12x512x512xf32>
%343 = "linalg.fill"(%15, %319) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<0x12x512x32xf32>) -> tensor<0x12x512x32xf32>
%344 = "linalg.generic"(%342, %325, %343) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%600, %arg3) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<0x12x512x512xf32>, tensor<0x12x512x32xf32>, tensor<0x12x512x32xf32>) -> tensor<0x12x512x32xf32>
%345 = "linalg.init_tensor"() {static_sizes = [0, 512, 12, 32]} : () -> tensor<0x512x12x32xf32>
%346 = "linalg.generic"(%344, %345) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<0x12x512x32xf32>, tensor<0x512x12x32xf32>) -> tensor<0x512x12x32xf32>
%347 = "tensor.cast"(%346) : (tensor<0x512x12x32xf32>) -> tensor<?x?x12x32xf32>
%348 = "flow.tensor.reshape"(%347, %2, %5, %2, %5) {operand_segment_sizes = dense<[1, 2, 2]> : vector<3xi32>} : (tensor<?x?x12x32xf32>, index, index, index, index) -> tensor<?x?x384xf32>
%349 = "linalg.generic"(%62, %283) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384xf32>, tensor<0x512x384xf32>) -> tensor<0x512x384xf32>
%350 = "linalg.generic"(%63, %309) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384x384xf32>, tensor<0x384x384xf32>) -> tensor<0x384x384xf32>
%351 = "linalg.batch_matmul"(%348, %350, %349) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<?x?x384xf32>, tensor<0x384x384xf32>, tensor<0x512x384xf32>) -> tensor<0x512x384xf32>
"std.assert"(%1) {msg = "mismatched size for broadcast"} : (i1) -> ()
%352 = "linalg.init_tensor"() {static_sizes = [1, 512, 384]} : () -> tensor<1x512x384xf32>
%353 = "linalg.generic"(%351, %308, %352) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.addf"(%arg1, %arg2) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<0x512x384xf32>, tensor<0x512x384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%354 = "linalg.init_tensor"() {static_sizes = [1, 512]} : () -> tensor<1x512xf32>
%355 = "linalg.fill"(%15, %354) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<1x512xf32>) -> tensor<1x512xf32>
%356 = "linalg.generic"(%353, %355) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "arith.addf"(%arg2, %arg1) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x512xf32>) -> tensor<1x512xf32>
%357 = "linalg.generic"(%356, %354) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "arith.divf"(%arg1, %122) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<1x512xf32>, tensor<1x512xf32>) -> tensor<1x512xf32>
%358 = "linalg.fill"(%15, %354) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<1x512xf32>) -> tensor<1x512xf32>
%359 = "linalg.generic"(%353, %357, %358) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.subf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.mulf"(%600, %600) : (f32, f32) -> f32
%602 = "arith.addf"(%arg3, %601) : (f32, f32) -> f32
"linalg.yield"(%602) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x512xf32>, tensor<1x512xf32>) -> tensor<1x512xf32>
%360 = "linalg.generic"(%353, %357, %359, %65, %64, %352) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%600 = "arith.divf"(%arg3, %122) : (f32, f32) -> f32
%601 = "arith.subf"(%arg1, %arg2) : (f32, f32) -> f32
%602 = "arith.truncf"(%21) : (f64) -> f32
%603 = "arith.addf"(%600, %602) : (f32, f32) -> f32
%604 = "math.rsqrt"(%603) : (f32) -> f32
%605 = "arith.mulf"(%601, %604) : (f32, f32) -> f32
%606 = "arith.mulf"(%605, %arg4) : (f32, f32) -> f32
%607 = "arith.addf"(%606, %arg5) : (f32, f32) -> f32
"linalg.yield"(%607) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<[5, 1]> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x512xf32>, tensor<1x512xf32>, tensor<384xf32>, tensor<384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%361 = "linalg.init_tensor"() {static_sizes = [1, 512, 1536]} : () -> tensor<1x512x1536xf32>
%362 = "linalg.init_tensor"() {static_sizes = [1, 384, 1536]} : () -> tensor<1x384x1536xf32>
%363 = "linalg.generic"(%66, %361) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<1536xf32>, tensor<1x512x1536xf32>) -> tensor<1x512x1536xf32>
%364 = "linalg.generic"(%67, %362) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<1536x384xf32>, tensor<1x384x1536xf32>) -> tensor<1x384x1536xf32>
%365 = "linalg.batch_matmul"(%360, %364, %363) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x384x1536xf32>, tensor<1x512x1536xf32>) -> tensor<1x512x1536xf32>
%366 = "linalg.generic"(%365, %361) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "math.sqrt"(%18) : (f32) -> f32
%601 = "arith.divf"(%arg1, %600) : (f32, f32) -> f32
%602 = "math.erf"(%601) : (f32) -> f32
%603 = "arith.addf"(%602, %19) : (f32, f32) -> f32
%604 = "arith.mulf"(%603, %17) : (f32, f32) -> f32
%605 = "arith.mulf"(%arg1, %604) : (f32, f32) -> f32
"linalg.yield"(%605) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<1x512x1536xf32>, tensor<1x512x1536xf32>) -> tensor<1x512x1536xf32>
%367 = "linalg.init_tensor"() {static_sizes = [1, 1536, 384]} : () -> tensor<1x1536x384xf32>
%368 = "linalg.generic"(%68, %352) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%369 = "linalg.generic"(%69, %367) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384x1536xf32>, tensor<1x1536x384xf32>) -> tensor<1x1536x384xf32>
%370 = "linalg.batch_matmul"(%366, %369, %368) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x512x1536xf32>, tensor<1x1536x384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%371 = "linalg.generic"(%370, %360, %352) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.addf"(%arg1, %arg2) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x512x384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%372 = "linalg.fill"(%15, %354) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<1x512xf32>) -> tensor<1x512xf32>
%373 = "linalg.generic"(%371, %372) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "arith.addf"(%arg2, %arg1) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x512xf32>) -> tensor<1x512xf32>
%374 = "linalg.generic"(%373, %354) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "arith.divf"(%arg1, %122) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<1x512xf32>, tensor<1x512xf32>) -> tensor<1x512xf32>
%375 = "linalg.fill"(%15, %354) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<1x512xf32>) -> tensor<1x512xf32>
%376 = "linalg.generic"(%371, %374, %375) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.subf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.mulf"(%600, %600) : (f32, f32) -> f32
%602 = "arith.addf"(%arg3, %601) : (f32, f32) -> f32
"linalg.yield"(%602) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x512xf32>, tensor<1x512xf32>) -> tensor<1x512xf32>
%377 = "linalg.generic"(%371, %374, %376, %71, %70, %352) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%600 = "arith.divf"(%arg3, %122) : (f32, f32) -> f32
%601 = "arith.subf"(%arg1, %arg2) : (f32, f32) -> f32
%602 = "arith.truncf"(%21) : (f64) -> f32
%603 = "arith.addf"(%600, %602) : (f32, f32) -> f32
%604 = "math.rsqrt"(%603) : (f32) -> f32
%605 = "arith.mulf"(%601, %604) : (f32, f32) -> f32
%606 = "arith.mulf"(%605, %arg4) : (f32, f32) -> f32
%607 = "arith.addf"(%606, %arg5) : (f32, f32) -> f32
"linalg.yield"(%607) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<[5, 1]> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x512xf32>, tensor<1x512xf32>, tensor<384xf32>, tensor<384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%378 = "linalg.init_tensor"() {static_sizes = [1, 384, 384]} : () -> tensor<1x384x384xf32>
%379 = "linalg.generic"(%72, %352) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%380 = "linalg.generic"(%73, %378) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384x384xf32>, tensor<1x384x384xf32>) -> tensor<1x384x384xf32>
%381 = "linalg.batch_matmul"(%377, %380, %379) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x384x384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%382 = "tensor.cast"(%381) : (tensor<1x512x384xf32>) -> tensor<?x?x384xf32>
%383 = "linalg.generic"(%74, %352) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%384 = "linalg.generic"(%75, %378) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384x384xf32>, tensor<1x384x384xf32>) -> tensor<1x384x384xf32>
%385 = "linalg.batch_matmul"(%377, %384, %383) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x384x384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%386 = "tensor.cast"(%385) : (tensor<1x512x384xf32>) -> tensor<?x?x384xf32>
%387 = "flow.tensor.reshape"(%386, %0, %5, %0, %5) {operand_segment_sizes = dense<[1, 2, 2]> : vector<3xi32>} : (tensor<?x?x384xf32>, index, index, index, index) -> tensor<?x?x12x32xf32>
%388 = "linalg.init_tensor"() {static_sizes = [1, 12, 512, 32]} : () -> tensor<1x12x512x32xf32>
%389 = "linalg.generic"(%76, %352) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%390 = "linalg.generic"(%77, %378) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384x384xf32>, tensor<1x384x384xf32>) -> tensor<1x384x384xf32>
%391 = "linalg.batch_matmul"(%377, %390, %389) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x384x384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%392 = "tensor.cast"(%391) : (tensor<1x512x384xf32>) -> tensor<?x?x384xf32>
%393 = "flow.tensor.reshape"(%392, %0, %5, %0, %5) {operand_segment_sizes = dense<[1, 2, 2]> : vector<3xi32>} : (tensor<?x?x384xf32>, index, index, index, index) -> tensor<?x?x12x32xf32>
%394 = "linalg.generic"(%393, %388) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<?x?x12x32xf32>, tensor<1x12x512x32xf32>) -> tensor<1x12x512x32xf32>
%395 = "flow.tensor.reshape"(%382, %0, %5, %0, %5) {operand_segment_sizes = dense<[1, 2, 2]> : vector<3xi32>} : (tensor<?x?x384xf32>, index, index, index, index) -> tensor<?x?x12x32xf32>
%396 = "linalg.generic"(%395, %388) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<?x?x12x32xf32>, tensor<1x12x512x32xf32>) -> tensor<1x12x512x32xf32>
%397 = "linalg.init_tensor"() {static_sizes = [1, 12, 32, 512]} : () -> tensor<1x12x32x512xf32>
%398 = "linalg.generic"(%387, %397) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d2)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<?x?x12x32xf32>, tensor<1x12x32x512xf32>) -> tensor<1x12x32x512xf32>
%399 = "linalg.init_tensor"() {static_sizes = [1, 12, 512, 512]} : () -> tensor<1x12x512x512xf32>
%400 = "linalg.fill"(%15, %399) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<1x12x512x512xf32>) -> tensor<1x12x512x512xf32>
%401 = "linalg.generic"(%396, %398, %400) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%600, %arg3) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x12x512x32xf32>, tensor<1x12x32x512xf32>, tensor<1x12x512x512xf32>) -> tensor<1x12x512x512xf32>
%402 = "arith.cmpi"(%5, %142) {predicate = 0 : i64} : (index, index) -> i1
"std.assert"(%402) {msg = "mismatched size for broadcast"} : (i1) -> ()
%403 = "linalg.generic"(%401, %263, %399) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.truncf"(%9) : (f64) -> f32
%601 = "arith.divf"(%arg1, %600) : (f32, f32) -> f32
%602 = "arith.addf"(%601, %arg2) : (f32, f32) -> f32
"linalg.yield"(%602) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x12x512x512xf32>, tensor<1x?xf32>, tensor<1x12x512x512xf32>) -> tensor<1x12x512x512xf32>
%404 = "linalg.init_tensor"() {static_sizes = [1, 12, 512]} : () -> tensor<1x12x512xf32>
%405 = "linalg.fill"(%16, %404) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<1x12x512xf32>) -> tensor<1x12x512xf32>
%406 = "linalg.init_tensor"() {static_sizes = [1, 12, 512]} : () -> tensor<1x12x512xi64>
%407 = "linalg.fill"(%4, %406) ({
^bb0(%arg1: i64, %arg2: i64):
"linalg.yield"(%arg1) : (i64) -> ()
}) : (i64, tensor<1x12x512xi64>) -> tensor<1x12x512xi64>
%408:2 = "linalg.generic"(%403, %405, %407) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: i64):
%600 = "linalg.index"() {dim = 3 : i64} : () -> index
%601 = "arith.index_cast"(%600) : (index) -> i64
%602 = "arith.cmpf"(%arg1, %arg2) {predicate = 2 : i64} : (f32, f32) -> i1
%603 = "std.select"(%602, %arg1, %arg2) : (i1, f32, f32) -> f32
%604 = "std.select"(%602, %601, %arg3) : (i1, i64, i64) -> i64
"linalg.yield"(%603, %604) : (f32, i64) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"], operand_segment_sizes = dense<[1, 2]> : vector<2xi32>} : (tensor<1x12x512x512xf32>, tensor<1x12x512xf32>, tensor<1x12x512xi64>) -> (tensor<1x12x512xf32>, tensor<1x12x512xi64>)
%409 = "linalg.generic"(%403, %408#0, %399) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.subf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "math.exp"(%600) : (f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x12x512x512xf32>, tensor<1x12x512xf32>, tensor<1x12x512x512xf32>) -> tensor<1x12x512x512xf32>
%410 = "linalg.fill"(%15, %404) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<1x12x512xf32>) -> tensor<1x12x512xf32>
%411 = "linalg.generic"(%409, %410) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "arith.addf"(%arg1, %arg2) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<1x12x512x512xf32>, tensor<1x12x512xf32>) -> tensor<1x12x512xf32>
%412 = "linalg.generic"(%409, %411, %399) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.divf"(%arg1, %arg2) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x12x512x512xf32>, tensor<1x12x512xf32>, tensor<1x12x512x512xf32>) -> tensor<1x12x512x512xf32>
%413 = "linalg.fill"(%15, %388) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<1x12x512x32xf32>) -> tensor<1x12x512x32xf32>
%414 = "linalg.generic"(%412, %394, %413) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%600, %arg3) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x12x512x512xf32>, tensor<1x12x512x32xf32>, tensor<1x12x512x32xf32>) -> tensor<1x12x512x32xf32>
%415 = "linalg.init_tensor"() {static_sizes = [1, 512, 12, 32]} : () -> tensor<1x512x12x32xf32>
%416 = "linalg.generic"(%414, %415) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<1x12x512x32xf32>, tensor<1x512x12x32xf32>) -> tensor<1x512x12x32xf32>
%417 = "tensor.cast"(%416) : (tensor<1x512x12x32xf32>) -> tensor<?x?x12x32xf32>
%418 = "flow.tensor.reshape"(%417, %0, %5, %0, %5) {operand_segment_sizes = dense<[1, 2, 2]> : vector<3xi32>} : (tensor<?x?x12x32xf32>, index, index, index, index) -> tensor<?x?x384xf32>
%419 = "linalg.generic"(%78, %352) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%420 = "linalg.generic"(%79, %378) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384x384xf32>, tensor<1x384x384xf32>) -> tensor<1x384x384xf32>
%421 = "linalg.batch_matmul"(%418, %420, %419) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<?x?x384xf32>, tensor<1x384x384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
"std.assert"(%1) {msg = "mismatched size for broadcast"} : (i1) -> ()
%422 = "linalg.init_tensor"() {static_sizes = [0, 512, 384]} : () -> tensor<0x512x384xf32>
%423 = "linalg.generic"(%421, %377, %422) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.addf"(%arg1, %arg2) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x512x384xf32>, tensor<0x512x384xf32>) -> tensor<0x512x384xf32>
%424 = "linalg.init_tensor"() {static_sizes = [0, 512]} : () -> tensor<0x512xf32>
%425 = "linalg.fill"(%15, %424) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<0x512xf32>) -> tensor<0x512xf32>
%426 = "linalg.generic"(%423, %425) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "arith.addf"(%arg2, %arg1) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<0x512x384xf32>, tensor<0x512xf32>) -> tensor<0x512xf32>
%427 = "linalg.generic"(%426, %424) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "arith.divf"(%arg1, %122) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<0x512xf32>, tensor<0x512xf32>) -> tensor<0x512xf32>
%428 = "linalg.fill"(%15, %424) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<0x512xf32>) -> tensor<0x512xf32>
%429 = "linalg.generic"(%423, %427, %428) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.subf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.mulf"(%600, %600) : (f32, f32) -> f32
%602 = "arith.addf"(%arg3, %601) : (f32, f32) -> f32
"linalg.yield"(%602) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<0x512x384xf32>, tensor<0x512xf32>, tensor<0x512xf32>) -> tensor<0x512xf32>
%430 = "linalg.generic"(%423, %427, %429, %81, %80, %422) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%600 = "arith.divf"(%arg3, %122) : (f32, f32) -> f32
%601 = "arith.subf"(%arg1, %arg2) : (f32, f32) -> f32
%602 = "arith.truncf"(%21) : (f64) -> f32
%603 = "arith.addf"(%600, %602) : (f32, f32) -> f32
%604 = "math.rsqrt"(%603) : (f32) -> f32
%605 = "arith.mulf"(%601, %604) : (f32, f32) -> f32
%606 = "arith.mulf"(%605, %arg4) : (f32, f32) -> f32
%607 = "arith.addf"(%606, %arg5) : (f32, f32) -> f32
"linalg.yield"(%607) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<[5, 1]> : vector<2xi32>} : (tensor<0x512x384xf32>, tensor<0x512xf32>, tensor<0x512xf32>, tensor<384xf32>, tensor<384xf32>, tensor<0x512x384xf32>) -> tensor<0x512x384xf32>
%431 = "linalg.init_tensor"() {static_sizes = [0, 512, 1536]} : () -> tensor<0x512x1536xf32>
%432 = "linalg.init_tensor"() {static_sizes = [0, 384, 1536]} : () -> tensor<0x384x1536xf32>
%433 = "linalg.generic"(%82, %431) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<1536xf32>, tensor<0x512x1536xf32>) -> tensor<0x512x1536xf32>
%434 = "linalg.generic"(%83, %432) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<1536x384xf32>, tensor<0x384x1536xf32>) -> tensor<0x384x1536xf32>
%435 = "linalg.batch_matmul"(%430, %434, %433) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<0x512x384xf32>, tensor<0x384x1536xf32>, tensor<0x512x1536xf32>) -> tensor<0x512x1536xf32>
%436 = "linalg.generic"(%435, %431) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "math.sqrt"(%18) : (f32) -> f32
%601 = "arith.divf"(%arg1, %600) : (f32, f32) -> f32
%602 = "math.erf"(%601) : (f32) -> f32
%603 = "arith.addf"(%602, %19) : (f32, f32) -> f32
%604 = "arith.mulf"(%603, %17) : (f32, f32) -> f32
%605 = "arith.mulf"(%arg1, %604) : (f32, f32) -> f32
"linalg.yield"(%605) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<0x512x1536xf32>, tensor<0x512x1536xf32>) -> tensor<0x512x1536xf32>
%437 = "linalg.init_tensor"() {static_sizes = [0, 1536, 384]} : () -> tensor<0x1536x384xf32>
%438 = "linalg.generic"(%84, %422) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384xf32>, tensor<0x512x384xf32>) -> tensor<0x512x384xf32>
%439 = "linalg.generic"(%85, %437) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384x1536xf32>, tensor<0x1536x384xf32>) -> tensor<0x1536x384xf32>
%440 = "linalg.batch_matmul"(%436, %439, %438) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<0x512x1536xf32>, tensor<0x1536x384xf32>, tensor<0x512x384xf32>) -> tensor<0x512x384xf32>
%441 = "linalg.generic"(%440, %430, %422) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.addf"(%arg1, %arg2) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<0x512x384xf32>, tensor<0x512x384xf32>, tensor<0x512x384xf32>) -> tensor<0x512x384xf32>
%442 = "linalg.fill"(%15, %424) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<0x512xf32>) -> tensor<0x512xf32>
%443 = "linalg.generic"(%441, %442) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "arith.addf"(%arg2, %arg1) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<0x512x384xf32>, tensor<0x512xf32>) -> tensor<0x512xf32>
%444 = "linalg.generic"(%443, %424) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "arith.divf"(%arg1, %122) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<0x512xf32>, tensor<0x512xf32>) -> tensor<0x512xf32>
%445 = "linalg.fill"(%15, %424) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<0x512xf32>) -> tensor<0x512xf32>
%446 = "linalg.generic"(%441, %444, %445) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.subf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.mulf"(%600, %600) : (f32, f32) -> f32
%602 = "arith.addf"(%arg3, %601) : (f32, f32) -> f32
"linalg.yield"(%602) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<0x512x384xf32>, tensor<0x512xf32>, tensor<0x512xf32>) -> tensor<0x512xf32>
%447 = "linalg.generic"(%441, %444, %446, %87, %86, %422) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%600 = "arith.divf"(%arg3, %122) : (f32, f32) -> f32
%601 = "arith.subf"(%arg1, %arg2) : (f32, f32) -> f32
%602 = "arith.truncf"(%21) : (f64) -> f32
%603 = "arith.addf"(%600, %602) : (f32, f32) -> f32
%604 = "math.rsqrt"(%603) : (f32) -> f32
%605 = "arith.mulf"(%601, %604) : (f32, f32) -> f32
%606 = "arith.mulf"(%605, %arg4) : (f32, f32) -> f32
%607 = "arith.addf"(%606, %arg5) : (f32, f32) -> f32
"linalg.yield"(%607) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<[5, 1]> : vector<2xi32>} : (tensor<0x512x384xf32>, tensor<0x512xf32>, tensor<0x512xf32>, tensor<384xf32>, tensor<384xf32>, tensor<0x512x384xf32>) -> tensor<0x512x384xf32>
%448 = "linalg.init_tensor"() {static_sizes = [0, 384, 384]} : () -> tensor<0x384x384xf32>
%449 = "linalg.generic"(%88, %422) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384xf32>, tensor<0x512x384xf32>) -> tensor<0x512x384xf32>
%450 = "linalg.generic"(%89, %448) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384x384xf32>, tensor<0x384x384xf32>) -> tensor<0x384x384xf32>
%451 = "linalg.batch_matmul"(%447, %450, %449) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<0x512x384xf32>, tensor<0x384x384xf32>, tensor<0x512x384xf32>) -> tensor<0x512x384xf32>
%452 = "tensor.cast"(%451) : (tensor<0x512x384xf32>) -> tensor<?x?x384xf32>
%453 = "linalg.generic"(%90, %422) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384xf32>, tensor<0x512x384xf32>) -> tensor<0x512x384xf32>
%454 = "linalg.generic"(%91, %448) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384x384xf32>, tensor<0x384x384xf32>) -> tensor<0x384x384xf32>
%455 = "linalg.batch_matmul"(%447, %454, %453) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<0x512x384xf32>, tensor<0x384x384xf32>, tensor<0x512x384xf32>) -> tensor<0x512x384xf32>
%456 = "tensor.cast"(%455) : (tensor<0x512x384xf32>) -> tensor<?x?x384xf32>
%457 = "flow.tensor.reshape"(%456, %2, %5, %2, %5) {operand_segment_sizes = dense<[1, 2, 2]> : vector<3xi32>} : (tensor<?x?x384xf32>, index, index, index, index) -> tensor<?x?x12x32xf32>
%458 = "linalg.init_tensor"() {static_sizes = [0, 12, 512, 32]} : () -> tensor<0x12x512x32xf32>
%459 = "linalg.generic"(%92, %422) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384xf32>, tensor<0x512x384xf32>) -> tensor<0x512x384xf32>
%460 = "linalg.generic"(%93, %448) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384x384xf32>, tensor<0x384x384xf32>) -> tensor<0x384x384xf32>
%461 = "linalg.batch_matmul"(%447, %460, %459) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<0x512x384xf32>, tensor<0x384x384xf32>, tensor<0x512x384xf32>) -> tensor<0x512x384xf32>
%462 = "tensor.cast"(%461) : (tensor<0x512x384xf32>) -> tensor<?x?x384xf32>
%463 = "flow.tensor.reshape"(%462, %2, %5, %2, %5) {operand_segment_sizes = dense<[1, 2, 2]> : vector<3xi32>} : (tensor<?x?x384xf32>, index, index, index, index) -> tensor<?x?x12x32xf32>
%464 = "linalg.generic"(%463, %458) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<?x?x12x32xf32>, tensor<0x12x512x32xf32>) -> tensor<0x12x512x32xf32>
%465 = "flow.tensor.reshape"(%452, %2, %5, %2, %5) {operand_segment_sizes = dense<[1, 2, 2]> : vector<3xi32>} : (tensor<?x?x384xf32>, index, index, index, index) -> tensor<?x?x12x32xf32>
%466 = "linalg.generic"(%465, %458) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<?x?x12x32xf32>, tensor<0x12x512x32xf32>) -> tensor<0x12x512x32xf32>
%467 = "linalg.init_tensor"() {static_sizes = [0, 12, 32, 512]} : () -> tensor<0x12x32x512xf32>
%468 = "linalg.generic"(%457, %467) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d2)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<?x?x12x32xf32>, tensor<0x12x32x512xf32>) -> tensor<0x12x32x512xf32>
%469 = "linalg.init_tensor"() {static_sizes = [0, 12, 512, 512]} : () -> tensor<0x12x512x512xf32>
%470 = "linalg.fill"(%15, %469) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<0x12x512x512xf32>) -> tensor<0x12x512x512xf32>
%471 = "linalg.generic"(%466, %468, %470) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%600, %arg3) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<0x12x512x32xf32>, tensor<0x12x32x512xf32>, tensor<0x12x512x512xf32>) -> tensor<0x12x512x512xf32>
"std.assert"(%1) {msg = "mismatched size for broadcast"} : (i1) -> ()
%472 = "arith.cmpi"(%5, %142) {predicate = 0 : i64} : (index, index) -> i1
"std.assert"(%472) {msg = "mismatched size for broadcast"} : (i1) -> ()
%473 = "linalg.generic"(%471, %263, %469) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.truncf"(%9) : (f64) -> f32
%601 = "arith.divf"(%arg1, %600) : (f32, f32) -> f32
%602 = "arith.addf"(%601, %arg2) : (f32, f32) -> f32
"linalg.yield"(%602) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<0x12x512x512xf32>, tensor<1x?xf32>, tensor<0x12x512x512xf32>) -> tensor<0x12x512x512xf32>
%474 = "linalg.init_tensor"() {static_sizes = [0, 12, 512]} : () -> tensor<0x12x512xf32>
%475 = "linalg.fill"(%16, %474) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<0x12x512xf32>) -> tensor<0x12x512xf32>
%476 = "linalg.init_tensor"() {static_sizes = [0, 12, 512]} : () -> tensor<0x12x512xi64>
%477 = "linalg.fill"(%4, %476) ({
^bb0(%arg1: i64, %arg2: i64):
"linalg.yield"(%arg1) : (i64) -> ()
}) : (i64, tensor<0x12x512xi64>) -> tensor<0x12x512xi64>
%478:2 = "linalg.generic"(%473, %475, %477) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: i64):
%600 = "linalg.index"() {dim = 3 : i64} : () -> index
%601 = "arith.index_cast"(%600) : (index) -> i64
%602 = "arith.cmpf"(%arg1, %arg2) {predicate = 2 : i64} : (f32, f32) -> i1
%603 = "std.select"(%602, %arg1, %arg2) : (i1, f32, f32) -> f32
%604 = "std.select"(%602, %601, %arg3) : (i1, i64, i64) -> i64
"linalg.yield"(%603, %604) : (f32, i64) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"], operand_segment_sizes = dense<[1, 2]> : vector<2xi32>} : (tensor<0x12x512x512xf32>, tensor<0x12x512xf32>, tensor<0x12x512xi64>) -> (tensor<0x12x512xf32>, tensor<0x12x512xi64>)
%479 = "linalg.generic"(%473, %478#0, %469) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.subf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "math.exp"(%600) : (f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<0x12x512x512xf32>, tensor<0x12x512xf32>, tensor<0x12x512x512xf32>) -> tensor<0x12x512x512xf32>
%480 = "linalg.fill"(%15, %474) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<0x12x512xf32>) -> tensor<0x12x512xf32>
%481 = "linalg.generic"(%479, %480) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "arith.addf"(%arg1, %arg2) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<0x12x512x512xf32>, tensor<0x12x512xf32>) -> tensor<0x12x512xf32>
%482 = "linalg.generic"(%479, %481, %469) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.divf"(%arg1, %arg2) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<0x12x512x512xf32>, tensor<0x12x512xf32>, tensor<0x12x512x512xf32>) -> tensor<0x12x512x512xf32>
%483 = "linalg.fill"(%15, %458) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<0x12x512x32xf32>) -> tensor<0x12x512x32xf32>
%484 = "linalg.generic"(%482, %464, %483) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%600, %arg3) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<0x12x512x512xf32>, tensor<0x12x512x32xf32>, tensor<0x12x512x32xf32>) -> tensor<0x12x512x32xf32>
%485 = "linalg.init_tensor"() {static_sizes = [0, 512, 12, 32]} : () -> tensor<0x512x12x32xf32>
%486 = "linalg.generic"(%484, %485) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<0x12x512x32xf32>, tensor<0x512x12x32xf32>) -> tensor<0x512x12x32xf32>
%487 = "tensor.cast"(%486) : (tensor<0x512x12x32xf32>) -> tensor<?x?x12x32xf32>
%488 = "flow.tensor.reshape"(%487, %2, %5, %2, %5) {operand_segment_sizes = dense<[1, 2, 2]> : vector<3xi32>} : (tensor<?x?x12x32xf32>, index, index, index, index) -> tensor<?x?x384xf32>
%489 = "linalg.generic"(%94, %422) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384xf32>, tensor<0x512x384xf32>) -> tensor<0x512x384xf32>
%490 = "linalg.generic"(%95, %448) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384x384xf32>, tensor<0x384x384xf32>) -> tensor<0x384x384xf32>
%491 = "linalg.batch_matmul"(%488, %490, %489) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<?x?x384xf32>, tensor<0x384x384xf32>, tensor<0x512x384xf32>) -> tensor<0x512x384xf32>
"std.assert"(%1) {msg = "mismatched size for broadcast"} : (i1) -> ()
%492 = "linalg.init_tensor"() {static_sizes = [1, 512, 384]} : () -> tensor<1x512x384xf32>
%493 = "linalg.generic"(%491, %447, %492) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.addf"(%arg1, %arg2) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<0x512x384xf32>, tensor<0x512x384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%494 = "linalg.init_tensor"() {static_sizes = [1, 512]} : () -> tensor<1x512xf32>
%495 = "linalg.fill"(%15, %494) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<1x512xf32>) -> tensor<1x512xf32>
%496 = "linalg.generic"(%493, %495) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "arith.addf"(%arg2, %arg1) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x512xf32>) -> tensor<1x512xf32>
%497 = "linalg.generic"(%496, %494) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "arith.divf"(%arg1, %122) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<1x512xf32>, tensor<1x512xf32>) -> tensor<1x512xf32>
%498 = "linalg.fill"(%15, %494) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<1x512xf32>) -> tensor<1x512xf32>
%499 = "linalg.generic"(%493, %497, %498) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.subf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.mulf"(%600, %600) : (f32, f32) -> f32
%602 = "arith.addf"(%arg3, %601) : (f32, f32) -> f32
"linalg.yield"(%602) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x512xf32>, tensor<1x512xf32>) -> tensor<1x512xf32>
%500 = "linalg.generic"(%493, %497, %499, %97, %96, %492) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%600 = "arith.divf"(%arg3, %122) : (f32, f32) -> f32
%601 = "arith.subf"(%arg1, %arg2) : (f32, f32) -> f32
%602 = "arith.truncf"(%21) : (f64) -> f32
%603 = "arith.addf"(%600, %602) : (f32, f32) -> f32
%604 = "math.rsqrt"(%603) : (f32) -> f32
%605 = "arith.mulf"(%601, %604) : (f32, f32) -> f32
%606 = "arith.mulf"(%605, %arg4) : (f32, f32) -> f32
%607 = "arith.addf"(%606, %arg5) : (f32, f32) -> f32
"linalg.yield"(%607) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<[5, 1]> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x512xf32>, tensor<1x512xf32>, tensor<384xf32>, tensor<384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%501 = "linalg.init_tensor"() {static_sizes = [1, 512, 1536]} : () -> tensor<1x512x1536xf32>
%502 = "linalg.init_tensor"() {static_sizes = [1, 384, 1536]} : () -> tensor<1x384x1536xf32>
%503 = "linalg.generic"(%98, %501) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<1536xf32>, tensor<1x512x1536xf32>) -> tensor<1x512x1536xf32>
%504 = "linalg.generic"(%99, %502) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<1536x384xf32>, tensor<1x384x1536xf32>) -> tensor<1x384x1536xf32>
%505 = "linalg.batch_matmul"(%500, %504, %503) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x384x1536xf32>, tensor<1x512x1536xf32>) -> tensor<1x512x1536xf32>
%506 = "linalg.generic"(%505, %501) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "math.sqrt"(%18) : (f32) -> f32
%601 = "arith.divf"(%arg1, %600) : (f32, f32) -> f32
%602 = "math.erf"(%601) : (f32) -> f32
%603 = "arith.addf"(%602, %19) : (f32, f32) -> f32
%604 = "arith.mulf"(%603, %17) : (f32, f32) -> f32
%605 = "arith.mulf"(%arg1, %604) : (f32, f32) -> f32
"linalg.yield"(%605) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<1x512x1536xf32>, tensor<1x512x1536xf32>) -> tensor<1x512x1536xf32>
%507 = "linalg.init_tensor"() {static_sizes = [1, 1536, 384]} : () -> tensor<1x1536x384xf32>
%508 = "linalg.generic"(%100, %492) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%509 = "linalg.generic"(%101, %507) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384x1536xf32>, tensor<1x1536x384xf32>) -> tensor<1x1536x384xf32>
%510 = "linalg.batch_matmul"(%506, %509, %508) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x512x1536xf32>, tensor<1x1536x384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%511 = "linalg.generic"(%510, %500, %492) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.addf"(%arg1, %arg2) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x512x384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%512 = "linalg.fill"(%15, %494) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<1x512xf32>) -> tensor<1x512xf32>
%513 = "linalg.generic"(%511, %512) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "arith.addf"(%arg2, %arg1) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x512xf32>) -> tensor<1x512xf32>
%514 = "linalg.generic"(%513, %494) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "arith.divf"(%arg1, %122) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<1x512xf32>, tensor<1x512xf32>) -> tensor<1x512xf32>
%515 = "linalg.fill"(%15, %494) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<1x512xf32>) -> tensor<1x512xf32>
%516 = "linalg.generic"(%511, %514, %515) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.subf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.mulf"(%600, %600) : (f32, f32) -> f32
%602 = "arith.addf"(%arg3, %601) : (f32, f32) -> f32
"linalg.yield"(%602) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x512xf32>, tensor<1x512xf32>) -> tensor<1x512xf32>
%517 = "linalg.generic"(%511, %514, %516, %103, %102, %492) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%600 = "arith.divf"(%arg3, %122) : (f32, f32) -> f32
%601 = "arith.subf"(%arg1, %arg2) : (f32, f32) -> f32
%602 = "arith.truncf"(%21) : (f64) -> f32
%603 = "arith.addf"(%600, %602) : (f32, f32) -> f32
%604 = "math.rsqrt"(%603) : (f32) -> f32
%605 = "arith.mulf"(%601, %604) : (f32, f32) -> f32
%606 = "arith.mulf"(%605, %arg4) : (f32, f32) -> f32
%607 = "arith.addf"(%606, %arg5) : (f32, f32) -> f32
"linalg.yield"(%607) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<[5, 1]> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x512xf32>, tensor<1x512xf32>, tensor<384xf32>, tensor<384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%518 = "linalg.init_tensor"() {static_sizes = [1, 384, 384]} : () -> tensor<1x384x384xf32>
%519 = "linalg.generic"(%104, %492) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%520 = "linalg.generic"(%105, %518) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384x384xf32>, tensor<1x384x384xf32>) -> tensor<1x384x384xf32>
%521 = "linalg.batch_matmul"(%517, %520, %519) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x384x384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%522 = "tensor.cast"(%521) : (tensor<1x512x384xf32>) -> tensor<?x?x384xf32>
%523 = "linalg.generic"(%106, %492) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%524 = "linalg.generic"(%107, %518) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384x384xf32>, tensor<1x384x384xf32>) -> tensor<1x384x384xf32>
%525 = "linalg.batch_matmul"(%517, %524, %523) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x384x384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%526 = "tensor.cast"(%525) : (tensor<1x512x384xf32>) -> tensor<?x?x384xf32>
%527 = "flow.tensor.reshape"(%526, %0, %5, %0, %5) {operand_segment_sizes = dense<[1, 2, 2]> : vector<3xi32>} : (tensor<?x?x384xf32>, index, index, index, index) -> tensor<?x?x12x32xf32>
%528 = "linalg.init_tensor"() {static_sizes = [1, 12, 512, 32]} : () -> tensor<1x12x512x32xf32>
%529 = "linalg.generic"(%108, %492) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%530 = "linalg.generic"(%109, %518) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384x384xf32>, tensor<1x384x384xf32>) -> tensor<1x384x384xf32>
%531 = "linalg.batch_matmul"(%517, %530, %529) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x384x384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%532 = "tensor.cast"(%531) : (tensor<1x512x384xf32>) -> tensor<?x?x384xf32>
%533 = "flow.tensor.reshape"(%532, %0, %5, %0, %5) {operand_segment_sizes = dense<[1, 2, 2]> : vector<3xi32>} : (tensor<?x?x384xf32>, index, index, index, index) -> tensor<?x?x12x32xf32>
%534 = "linalg.generic"(%533, %528) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<?x?x12x32xf32>, tensor<1x12x512x32xf32>) -> tensor<1x12x512x32xf32>
%535 = "flow.tensor.reshape"(%522, %0, %5, %0, %5) {operand_segment_sizes = dense<[1, 2, 2]> : vector<3xi32>} : (tensor<?x?x384xf32>, index, index, index, index) -> tensor<?x?x12x32xf32>
%536 = "linalg.generic"(%535, %528) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<?x?x12x32xf32>, tensor<1x12x512x32xf32>) -> tensor<1x12x512x32xf32>
%537 = "linalg.init_tensor"() {static_sizes = [1, 12, 32, 512]} : () -> tensor<1x12x32x512xf32>
%538 = "linalg.generic"(%527, %537) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d2)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<?x?x12x32xf32>, tensor<1x12x32x512xf32>) -> tensor<1x12x32x512xf32>
%539 = "linalg.init_tensor"() {static_sizes = [1, 12, 512, 512]} : () -> tensor<1x12x512x512xf32>
%540 = "linalg.fill"(%15, %539) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<1x12x512x512xf32>) -> tensor<1x12x512x512xf32>
%541 = "linalg.generic"(%536, %538, %540) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%600, %arg3) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x12x512x32xf32>, tensor<1x12x32x512xf32>, tensor<1x12x512x512xf32>) -> tensor<1x12x512x512xf32>
%542 = "arith.cmpi"(%5, %142) {predicate = 0 : i64} : (index, index) -> i1
"std.assert"(%542) {msg = "mismatched size for broadcast"} : (i1) -> ()
%543 = "linalg.generic"(%541, %263, %539) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.truncf"(%9) : (f64) -> f32
%601 = "arith.divf"(%arg1, %600) : (f32, f32) -> f32
%602 = "arith.addf"(%601, %arg2) : (f32, f32) -> f32
"linalg.yield"(%602) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x12x512x512xf32>, tensor<1x?xf32>, tensor<1x12x512x512xf32>) -> tensor<1x12x512x512xf32>
%544 = "linalg.init_tensor"() {static_sizes = [1, 12, 512]} : () -> tensor<1x12x512xf32>
%545 = "linalg.fill"(%16, %544) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<1x12x512xf32>) -> tensor<1x12x512xf32>
%546 = "linalg.init_tensor"() {static_sizes = [1, 12, 512]} : () -> tensor<1x12x512xi64>
%547 = "linalg.fill"(%4, %546) ({
^bb0(%arg1: i64, %arg2: i64):
"linalg.yield"(%arg1) : (i64) -> ()
}) : (i64, tensor<1x12x512xi64>) -> tensor<1x12x512xi64>
%548:2 = "linalg.generic"(%543, %545, %547) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: i64):
%600 = "linalg.index"() {dim = 3 : i64} : () -> index
%601 = "arith.index_cast"(%600) : (index) -> i64
%602 = "arith.cmpf"(%arg1, %arg2) {predicate = 2 : i64} : (f32, f32) -> i1
%603 = "std.select"(%602, %arg1, %arg2) : (i1, f32, f32) -> f32
%604 = "std.select"(%602, %601, %arg3) : (i1, i64, i64) -> i64
"linalg.yield"(%603, %604) : (f32, i64) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"], operand_segment_sizes = dense<[1, 2]> : vector<2xi32>} : (tensor<1x12x512x512xf32>, tensor<1x12x512xf32>, tensor<1x12x512xi64>) -> (tensor<1x12x512xf32>, tensor<1x12x512xi64>)
%549 = "linalg.generic"(%543, %548#0, %539) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.subf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "math.exp"(%600) : (f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x12x512x512xf32>, tensor<1x12x512xf32>, tensor<1x12x512x512xf32>) -> tensor<1x12x512x512xf32>
%550 = "linalg.fill"(%15, %544) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<1x12x512xf32>) -> tensor<1x12x512xf32>
%551 = "linalg.generic"(%549, %550) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "arith.addf"(%arg1, %arg2) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<1x12x512x512xf32>, tensor<1x12x512xf32>) -> tensor<1x12x512xf32>
%552 = "linalg.generic"(%549, %551, %539) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.divf"(%arg1, %arg2) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x12x512x512xf32>, tensor<1x12x512xf32>, tensor<1x12x512x512xf32>) -> tensor<1x12x512x512xf32>
%553 = "linalg.fill"(%15, %528) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<1x12x512x32xf32>) -> tensor<1x12x512x32xf32>
%554 = "linalg.generic"(%552, %534, %553) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%600, %arg3) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x12x512x512xf32>, tensor<1x12x512x32xf32>, tensor<1x12x512x32xf32>) -> tensor<1x12x512x32xf32>
%555 = "linalg.init_tensor"() {static_sizes = [1, 512, 12, 32]} : () -> tensor<1x512x12x32xf32>
%556 = "linalg.generic"(%554, %555) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<1x12x512x32xf32>, tensor<1x512x12x32xf32>) -> tensor<1x512x12x32xf32>
%557 = "tensor.cast"(%556) : (tensor<1x512x12x32xf32>) -> tensor<?x?x12x32xf32>
%558 = "flow.tensor.reshape"(%557, %0, %5, %0, %5) {operand_segment_sizes = dense<[1, 2, 2]> : vector<3xi32>} : (tensor<?x?x12x32xf32>, index, index, index, index) -> tensor<?x?x384xf32>
%559 = "linalg.generic"(%110, %492) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%560 = "linalg.generic"(%111, %518) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384x384xf32>, tensor<1x384x384xf32>) -> tensor<1x384x384xf32>
%561 = "linalg.batch_matmul"(%558, %560, %559) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<?x?x384xf32>, tensor<1x384x384xf32>, tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
"std.assert"(%1) {msg = "mismatched size for broadcast"} : (i1) -> ()
%562 = "linalg.init_tensor"() {static_sizes = [0, 512, 384]} : () -> tensor<0x512x384xf32>
%563 = "linalg.generic"(%561, %517, %562) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.addf"(%arg1, %arg2) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<1x512x384xf32>, tensor<1x512x384xf32>, tensor<0x512x384xf32>) -> tensor<0x512x384xf32>
%564 = "linalg.init_tensor"() {static_sizes = [0, 512]} : () -> tensor<0x512xf32>
%565 = "linalg.fill"(%15, %564) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<0x512xf32>) -> tensor<0x512xf32>
%566 = "linalg.generic"(%563, %565) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "arith.addf"(%arg2, %arg1) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<0x512x384xf32>, tensor<0x512xf32>) -> tensor<0x512xf32>
%567 = "linalg.generic"(%566, %564) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "arith.divf"(%arg1, %122) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<0x512xf32>, tensor<0x512xf32>) -> tensor<0x512xf32>
%568 = "linalg.fill"(%15, %564) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<0x512xf32>) -> tensor<0x512xf32>
%569 = "linalg.generic"(%563, %567, %568) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.subf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.mulf"(%600, %600) : (f32, f32) -> f32
%602 = "arith.addf"(%arg3, %601) : (f32, f32) -> f32
"linalg.yield"(%602) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<0x512x384xf32>, tensor<0x512xf32>, tensor<0x512xf32>) -> tensor<0x512xf32>
%570 = "linalg.generic"(%563, %567, %569, %113, %112, %562) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%600 = "arith.divf"(%arg3, %122) : (f32, f32) -> f32
%601 = "arith.subf"(%arg1, %arg2) : (f32, f32) -> f32
%602 = "arith.truncf"(%21) : (f64) -> f32
%603 = "arith.addf"(%600, %602) : (f32, f32) -> f32
%604 = "math.rsqrt"(%603) : (f32) -> f32
%605 = "arith.mulf"(%601, %604) : (f32, f32) -> f32
%606 = "arith.mulf"(%605, %arg4) : (f32, f32) -> f32
%607 = "arith.addf"(%606, %arg5) : (f32, f32) -> f32
"linalg.yield"(%607) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<[5, 1]> : vector<2xi32>} : (tensor<0x512x384xf32>, tensor<0x512xf32>, tensor<0x512xf32>, tensor<384xf32>, tensor<384xf32>, tensor<0x512x384xf32>) -> tensor<0x512x384xf32>
%571 = "linalg.init_tensor"() {static_sizes = [0, 512, 1536]} : () -> tensor<0x512x1536xf32>
%572 = "linalg.init_tensor"() {static_sizes = [0, 384, 1536]} : () -> tensor<0x384x1536xf32>
%573 = "linalg.generic"(%114, %571) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<1536xf32>, tensor<0x512x1536xf32>) -> tensor<0x512x1536xf32>
%574 = "linalg.generic"(%115, %572) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<1536x384xf32>, tensor<0x384x1536xf32>) -> tensor<0x384x1536xf32>
%575 = "linalg.batch_matmul"(%570, %574, %573) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<0x512x384xf32>, tensor<0x384x1536xf32>, tensor<0x512x1536xf32>) -> tensor<0x512x1536xf32>
%576 = "linalg.generic"(%575, %571) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "math.sqrt"(%18) : (f32) -> f32
%601 = "arith.divf"(%arg1, %600) : (f32, f32) -> f32
%602 = "math.erf"(%601) : (f32) -> f32
%603 = "arith.addf"(%602, %19) : (f32, f32) -> f32
%604 = "arith.mulf"(%603, %17) : (f32, f32) -> f32
%605 = "arith.mulf"(%arg1, %604) : (f32, f32) -> f32
"linalg.yield"(%605) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<0x512x1536xf32>, tensor<0x512x1536xf32>) -> tensor<0x512x1536xf32>
%577 = "linalg.init_tensor"() {static_sizes = [0, 1536, 384]} : () -> tensor<0x1536x384xf32>
%578 = "linalg.generic"(%116, %562) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384xf32>, tensor<0x512x384xf32>) -> tensor<0x512x384xf32>
%579 = "linalg.generic"(%117, %577) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384x1536xf32>, tensor<0x1536x384xf32>) -> tensor<0x1536x384xf32>
%580 = "linalg.batch_matmul"(%576, %579, %578) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<0x512x1536xf32>, tensor<0x1536x384xf32>, tensor<0x512x384xf32>) -> tensor<0x512x384xf32>
%581 = "linalg.generic"(%580, %570, %562) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.addf"(%arg1, %arg2) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<0x512x384xf32>, tensor<0x512x384xf32>, tensor<0x512x384xf32>) -> tensor<0x512x384xf32>
%582 = "linalg.fill"(%15, %564) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<0x512xf32>) -> tensor<0x512xf32>
%583 = "linalg.generic"(%581, %582) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "arith.addf"(%arg2, %arg1) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<0x512x384xf32>, tensor<0x512xf32>) -> tensor<0x512xf32>
%584 = "linalg.generic"(%583, %564) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "arith.divf"(%arg1, %122) : (f32, f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<0x512xf32>, tensor<0x512xf32>) -> tensor<0x512xf32>
%585 = "linalg.fill"(%15, %564) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) : (f32, tensor<0x512xf32>) -> tensor<0x512xf32>
%586 = "linalg.generic"(%581, %584, %585) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.subf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.mulf"(%600, %600) : (f32, f32) -> f32
%602 = "arith.addf"(%arg3, %601) : (f32, f32) -> f32
"linalg.yield"(%602) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<0x512x384xf32>, tensor<0x512xf32>, tensor<0x512xf32>) -> tensor<0x512xf32>
%587 = "linalg.generic"(%581, %584, %586, %119, %118, %562) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%600 = "arith.divf"(%arg3, %122) : (f32, f32) -> f32
%601 = "arith.subf"(%arg1, %arg2) : (f32, f32) -> f32
%602 = "arith.truncf"(%21) : (f64) -> f32
%603 = "arith.addf"(%600, %602) : (f32, f32) -> f32
%604 = "math.rsqrt"(%603) : (f32) -> f32
%605 = "arith.mulf"(%601, %604) : (f32, f32) -> f32
%606 = "arith.mulf"(%605, %arg4) : (f32, f32) -> f32
%607 = "arith.addf"(%606, %arg5) : (f32, f32) -> f32
"linalg.yield"(%607) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"], operand_segment_sizes = dense<[5, 1]> : vector<2xi32>} : (tensor<0x512x384xf32>, tensor<0x512xf32>, tensor<0x512xf32>, tensor<384xf32>, tensor<384xf32>, tensor<0x512x384xf32>) -> tensor<0x512x384xf32>
%588 = "tensor.extract_slice"(%587) {operand_segment_sizes = dense<[1, 0, 0, 0]> : vector<4xi32>, static_offsets = [0, 0, 0], static_sizes = [0, 1, 384], static_strides = [1, 1, 1]} : (tensor<0x512x384xf32>) -> tensor<0x1x384xf32>
%589 = "tensor.cast"(%588) : (tensor<0x1x384xf32>) -> tensor<?x?x384xf32>
%590 = "flow.tensor.reshape"(%589, %2, %0, %2) {operand_segment_sizes = dense<[1, 2, 1]> : vector<3xi32>} : (tensor<?x?x384xf32>, index, index, index) -> tensor<?x384xf32>
%591 = "linalg.init_tensor"() {static_sizes = [0, 384]} : () -> tensor<0x384xf32>
%592 = "linalg.generic"(%120, %591) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<384xf32>, tensor<0x384xf32>) -> tensor<0x384xf32>
%593 = "linalg.matmul"(%590, %7, %592) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<?x384xf32>, tensor<384x384xf32>, tensor<0x384xf32>) -> tensor<0x384xf32>
%594 = "linalg.generic"(%593, %591) ({
^bb0(%arg1: f32, %arg2: f32):
%600 = "math.tanh"(%arg1) : (f32) -> f32
"linalg.yield"(%600) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<0x384xf32>, tensor<0x384xf32>) -> tensor<0x384xf32>
%595 = "linalg.init_tensor"() {static_sizes = [0, 2]} : () -> tensor<0x2xf32>
%596 = "linalg.generic"(%121, %595) ({
^bb0(%arg1: f32, %arg2: f32):
"linalg.yield"(%arg1) : (f32) -> ()
}) {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"], operand_segment_sizes = dense<1> : vector<2xi32>} : (tensor<2xf32>, tensor<0x2xf32>) -> tensor<0x2xf32>
%597 = "linalg.matmul"(%594, %6, %596) ({
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%600 = "arith.mulf"(%arg1, %arg2) : (f32, f32) -> f32
%601 = "arith.addf"(%arg3, %600) : (f32, f32) -> f32
"linalg.yield"(%601) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<0x384xf32>, tensor<384x2xf32>, tensor<0x2xf32>) -> tensor<0x2xf32>
%598 = "tensor.cast"(%597) : (tensor<0x2xf32>) -> tensor<?x2xf32>
%599 = "hal.tensor.export"(%598, %2) {operand_segment_sizes = dense<[1, 1, 0]> : vector<3xi32>, source_encoding = tensor<?x2xf32>} : (tensor<?x2xf32>, index) -> !hal.buffer_view
"std.return"(%599) : (!hal.buffer_view) -> ()
}) {iree.abi.stub, sym_name = "forward", type = (!hal.buffer_view) -> !hal.buffer_view} : () -> ()
}) {torch.debug_module_name = "MiniLMSequenceClassification"} : () -> ()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment