Skip to content

Instantly share code, notes, and snippets.

@vivekkhandelwal1
Last active February 21, 2022 16:49
Show Gist options
  • Save vivekkhandelwal1/4d610bf73a9c9cee1906d42176b92af2 to your computer and use it in GitHub Desktop.
// -----// IR Dump Before Canonicalizer //----- //
func private @_forward(%arg0: tensor<1x512xi64>) -> tensor<?x2xf32> {
%c12_i64 = arith.constant 12 : i64
%c512 = arith.constant 512 : index
%c30522_i64 = arith.constant 30522 : i64
%c0_i64 = arith.constant 0 : i64
%c1_i64 = arith.constant 1 : i64
%cst = arith.constant 0.000000e+00 : f32
%c384_i64 = arith.constant 384 : i64
%c4_i64 = arith.constant 4 : i64
%cst_0 = arith.constant 1.000000e+00 : f64
%cst_1 = arith.constant -3.40282347E+38 : f32
%c32_i64 = arith.constant 32 : i64
%c3_i64 = arith.constant 3 : i64
%cst_2 = arith.constant 5.000000e-01 : f32
%cst_3 = arith.constant 2.000000e+00 : f32
%c512_i64 = arith.constant 512 : i64
%cst_4 = arith.constant 1.000000e+00 : f32
%c9223372036854775807_i64 = arith.constant 9223372036854775807 : i64
%cst_5 = arith.constant 9.9999999999999998E-13 : f64
%c2_i64 = arith.constant 2 : i64
%c-1_i64 = arith.constant -1 : i64
%cst_6 = arith.constant dense<-1.000000e+04> : tensor<f64>
%cst_7 = arith.constant dense<0> : tensor<1x512xi64>
%cst_8 = arith.constant dense<0> : tensor<i64>
%cst_9 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1x512xi64>
%cst_10 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<30522x384xf32>
%cst_11 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<2x384xf32>
%cst_12 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<512x384xf32>
%cst_13 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_14 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_15 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_16 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_17 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_18 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_19 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_20 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_21 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_22 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_23 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_24 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_25 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536xf32>
%cst_26 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536x384xf32>
%cst_27 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_28 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x1536xf32>
%cst_29 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_30 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_31 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_32 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_33 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_34 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_35 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_36 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_37 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_38 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_39 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_40 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_41 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536xf32>
%cst_42 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536x384xf32>
%cst_43 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_44 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x1536xf32>
%cst_45 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_46 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_47 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_48 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_49 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_50 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_51 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_52 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_53 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_54 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_55 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_56 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_57 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536xf32>
%cst_58 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536x384xf32>
%cst_59 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_60 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x1536xf32>
%cst_61 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_62 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_63 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_64 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_65 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_66 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_67 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_68 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_69 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_70 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_71 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_72 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_73 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536xf32>
%cst_74 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536x384xf32>
%cst_75 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_76 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x1536xf32>
%cst_77 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_78 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_79 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_80 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_81 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_82 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_83 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_84 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_85 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_86 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_87 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_88 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_89 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536xf32>
%cst_90 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536x384xf32>
%cst_91 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_92 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x1536xf32>
%cst_93 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_94 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_95 = arith.constant dense<5.6568542494923806> : tensor<f64>
%cst_96 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_97 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_98 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_99 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_100 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_101 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_102 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_103 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_104 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_105 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_106 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536xf32>
%cst_107 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536x384xf32>
%cst_108 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_109 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x1536xf32>
%cst_110 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_111 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_112 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_113 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_114 = arith.constant dense<[-0.00115577725, 0.00115577038]> : tensor<2xf32>
%cst_115 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<2x384xf32>
%true = arith.constant true
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%0 = linalg.init_tensor [] : tensor<i64>
%1 = linalg.fill(%c512_i64, %0) : i64, tensor<i64> -> tensor<i64>
%2 = linalg.generic {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []} ins(%1, %cst_8 : tensor<i64>, tensor<i64>) outs(%0 : tensor<i64>) {
^bb0(%arg1: i64, %arg2: i64, %arg3: i64):
%834 = arith.muli %arg2, %c1_i64 : i64
%835 = arith.addi %arg1, %834 : i64
linalg.yield %835 : i64
} -> tensor<i64>
%3 = tensor.extract %2[] : tensor<i64>
%4 = arith.index_cast %c1_i64 : i64 to index
%5 = arith.index_cast %3 : i64 to index
%6 = linalg.init_tensor [%4, %5] : tensor<?x?xf32>
%7 = linalg.fill(%cst_4, %6) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%8 = arith.addi %c0_i64, %c1_i64 : i64
%9 = arith.cmpi sge, %c0_i64, %c0_i64 : i64
%10 = arith.select %9, %c0_i64, %8 : i64
%11 = arith.cmpi slt, %10, %c0_i64 : i64
%12 = arith.select %11, %c0_i64, %10 : i64
%13 = arith.cmpi sgt, %12, %c1_i64 : i64
%14 = arith.select %13, %c1_i64, %12 : i64
%15 = arith.index_cast %14 : i64 to index
%16 = arith.addi %c9223372036854775807_i64, %c1_i64 : i64
%17 = arith.cmpi sge, %c9223372036854775807_i64, %c0_i64 : i64
%18 = arith.select %17, %c9223372036854775807_i64, %16 : i64
%19 = arith.cmpi slt, %18, %c0_i64 : i64
%20 = arith.select %19, %c0_i64, %18 : i64
%21 = arith.cmpi sgt, %20, %c1_i64 : i64
%22 = arith.select %21, %c1_i64, %20 : i64
%23 = arith.index_cast %22 : i64 to index
%24 = arith.cmpi sge, %23, %15 : index
%25 = arith.select %24, %23, %15 : index
%26 = arith.subi %25, %15 : index
%27 = tensor.extract_slice %cst_7[%15, 0] [%26, 512] [1, 1] : tensor<1x512xi64> to tensor<?x512xi64>
%28 = arith.addi %c0_i64, %c512_i64 : i64
%29 = arith.select %9, %c0_i64, %28 : i64
%30 = arith.cmpi slt, %29, %c0_i64 : i64
%31 = arith.select %30, %c0_i64, %29 : i64
%32 = arith.cmpi sgt, %31, %c512_i64 : i64
%33 = arith.select %32, %c512_i64, %31 : i64
%34 = arith.index_cast %33 : i64 to index
%35 = arith.addi %c512_i64, %c512_i64 : i64
%36 = arith.cmpi sge, %c512_i64, %c0_i64 : i64
%37 = arith.select %36, %c512_i64, %35 : i64
%38 = arith.cmpi slt, %37, %c0_i64 : i64
%39 = arith.select %38, %c0_i64, %37 : i64
%40 = arith.cmpi sgt, %39, %c512_i64 : i64
%41 = arith.select %40, %c512_i64, %39 : i64
%42 = arith.index_cast %41 : i64 to index
%43 = arith.cmpi sge, %42, %34 : index
%44 = arith.select %43, %42, %34 : index
%45 = arith.subi %44, %34 : index
%46 = tensor.extract_slice %27[0, %34] [%26, %45] [1, 1] : tensor<?x512xi64> to tensor<?x?xi64>
%47 = arith.cmpi slt, %c1_i64, %c0_i64 : i64
%48 = arith.index_cast %26 : index to i64
%49 = arith.cmpi eq, %48, %c1_i64 : i64
%50 = arith.ori %47, %49 : i1
cf.assert %50, "only broadcasting singleton dimensions supported"
%51 = arith.cmpi slt, %c512_i64, %c0_i64 : i64
%52 = arith.index_cast %45 : index to i64
%53 = arith.cmpi eq, %52, %c512_i64 : i64
%54 = arith.ori %51, %53 : i1
cf.assert %54, "only broadcasting singleton dimensions supported"
%55 = arith.index_cast %4 : index to i64
%56 = arith.addi %c0_i64, %55 : i64
%57 = arith.select %9, %c0_i64, %56 : i64
%58 = arith.cmpi slt, %57, %c0_i64 : i64
%59 = arith.select %58, %c0_i64, %57 : i64
%60 = arith.cmpi sgt, %59, %55 : i64
%61 = arith.select %60, %55, %59 : i64
%62 = arith.index_cast %61 : i64 to index
%63 = arith.addi %c9223372036854775807_i64, %55 : i64
%64 = arith.select %17, %c9223372036854775807_i64, %63 : i64
%65 = arith.cmpi slt, %64, %c0_i64 : i64
%66 = arith.select %65, %c0_i64, %64 : i64
%67 = arith.cmpi sgt, %66, %55 : i64
%68 = arith.select %67, %55, %66 : i64
%69 = arith.index_cast %68 : i64 to index
%70 = arith.cmpi sge, %69, %62 : index
%71 = arith.select %70, %69, %62 : index
%72 = arith.subi %71, %62 : index
%73 = tensor.extract_slice %7[%62, 0] [%72, %5] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
%74 = tensor.expand_shape %73 [[0], [1, 2, 3]] : tensor<?x?xf32> into tensor<?x1x1x?xf32>
%75 = arith.index_cast %5 : index to i64
%76 = arith.addi %c0_i64, %75 : i64
%77 = arith.select %9, %c0_i64, %76 : i64
%78 = arith.cmpi slt, %77, %c0_i64 : i64
%79 = arith.select %78, %c0_i64, %77 : i64
%80 = arith.cmpi sgt, %79, %75 : i64
%81 = arith.select %80, %75, %79 : i64
%82 = arith.index_cast %81 : i64 to index
%83 = arith.addi %c9223372036854775807_i64, %75 : i64
%84 = arith.select %17, %c9223372036854775807_i64, %83 : i64
%85 = arith.cmpi slt, %84, %c0_i64 : i64
%86 = arith.select %85, %c0_i64, %84 : i64
%87 = arith.cmpi sgt, %86, %75 : i64
%88 = arith.select %87, %75, %86 : i64
%89 = arith.index_cast %88 : i64 to index
%90 = arith.cmpi sge, %89, %82 : index
%91 = arith.select %90, %89, %82 : index
%92 = arith.subi %91, %82 : index
%93 = tensor.extract_slice %74[0, 0, 0, %82] [%72, 1, 1, %92] [1, 1, 1, 1] : tensor<?x1x1x?xf32> to tensor<?x1x1x?xf32>
%94 = linalg.init_tensor [%72, 1, 1, %92] : tensor<?x1x1x?xf32>
%95 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, 0, 0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%93 : tensor<?x1x1x?xf32>) outs(%94 : tensor<?x1x1x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%834 = arith.truncf %cst_0 : f64 to f32
%835 = arith.sitofp %c1_i64 : i64 to f32
%836 = arith.mulf %arg1, %835 : f32
%837 = arith.subf %834, %836 : f32
linalg.yield %837 : f32
} -> tensor<?x1x1x?xf32>
%96 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, 0, 0, d3)>, affine_map<(d0, d1, d2, d3) -> ()>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%95, %cst_6 : tensor<?x1x1x?xf32>, tensor<f64>) outs(%94 : tensor<?x1x1x?xf32>) {
^bb0(%arg1: f32, %arg2: f64, %arg3: f32):
%834 = arith.truncf %arg2 : f64 to f32
%835 = arith.mulf %arg1, %834 : f32
linalg.yield %835 : f32
} -> tensor<?x1x1x?xf32>
%97 = linalg.fill(%c512_i64, %0) : i64, tensor<i64> -> tensor<i64>
%98 = linalg.generic {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []} ins(%97, %cst_8 : tensor<i64>, tensor<i64>) outs(%0 : tensor<i64>) {
^bb0(%arg1: i64, %arg2: i64, %arg3: i64):
%834 = arith.muli %arg2, %c1_i64 : i64
%835 = arith.addi %arg1, %834 : i64
linalg.yield %835 : i64
} -> tensor<i64>
%99 = tensor.extract %98[] : tensor<i64>
%100 = tensor.extract_slice %cst_9[%15, 0] [%26, 512] [1, 1] : tensor<1x512xi64> to tensor<?x512xi64>
%101 = arith.addi %99, %c512_i64 : i64
%102 = arith.cmpi sge, %99, %c0_i64 : i64
%103 = arith.select %102, %99, %101 : i64
%104 = arith.cmpi slt, %103, %c0_i64 : i64
%105 = arith.select %104, %c0_i64, %103 : i64
%106 = arith.cmpi sgt, %105, %c512_i64 : i64
%107 = arith.select %106, %c512_i64, %105 : i64
%108 = arith.index_cast %107 : i64 to index
%109 = arith.cmpi sge, %108, %34 : index
%110 = arith.select %109, %108, %34 : index
%111 = arith.subi %110, %34 : index
%112 = tensor.extract_slice %100[0, %34] [%26, %111] [1, 1] : tensor<?x512xi64> to tensor<?x?xi64>
%113 = linalg.init_tensor [1, 512, 384] : tensor<1x512x384xf32>
%114 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0 : tensor<1x512xi64>) outs(%113 : tensor<1x512x384xf32>) {
^bb0(%arg1: i64, %arg2: f32):
%834 = arith.index_cast %arg1 : i64 to index
%835 = linalg.index 2 : index
%836 = arith.cmpi slt, %arg1, %c30522_i64 : i64
cf.assert %836, "index must be smaller than dim size"
%837 = arith.cmpi sge, %arg1, %c0_i64 : i64
cf.assert %837, "index must be larger or equal to 0"
%838 = tensor.extract %cst_10[%834, %835] : tensor<30522x384xf32>
linalg.yield %838 : f32
} -> tensor<1x512x384xf32>
%115 = linalg.init_tensor [%26, %45, 384] : tensor<?x?x384xf32>
%116 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%46 : tensor<?x?xi64>) outs(%115 : tensor<?x?x384xf32>) {
^bb0(%arg1: i64, %arg2: f32):
%834 = arith.index_cast %arg1 : i64 to index
%835 = linalg.index 2 : index
%836 = arith.cmpi slt, %arg1, %c2_i64 : i64
cf.assert %836, "index must be smaller than dim size"
%837 = arith.cmpi sge, %arg1, %c0_i64 : i64
cf.assert %837, "index must be larger or equal to 0"
%838 = tensor.extract %cst_11[%834, %835] : tensor<2x384xf32>
linalg.yield %838 : f32
} -> tensor<?x?x384xf32>
%117 = arith.cmpi eq, %c512, %45 : index
cf.assert %117, "mismatched size for broadcast"
%118 = linalg.init_tensor [%26, 512, 384] : tensor<?x512x384xf32>
%119 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%114, %116 : tensor<1x512x384xf32>, tensor<?x?x384xf32>) outs(%118 : tensor<?x512x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%834 = arith.sitofp %c1_i64 : i64 to f32
%835 = arith.mulf %arg2, %834 : f32
%836 = arith.addf %arg1, %835 : f32
linalg.yield %836 : f32
} -> tensor<?x512x384xf32>
%120 = linalg.init_tensor [%26, %111, 384] : tensor<?x?x384xf32>
%121 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%112 : tensor<?x?xi64>) outs(%120 : tensor<?x?x384xf32>) {
^bb0(%arg1: i64, %arg2: f32):
%834 = arith.index_cast %arg1 : i64 to index
%835 = linalg.index 2 : index
%836 = arith.cmpi slt, %arg1, %c512_i64 : i64
cf.assert %836, "index must be smaller than dim size"
%837 = arith.cmpi sge, %arg1, %c0_i64 : i64
cf.assert %837, "index must be larger or equal to 0"
%838 = tensor.extract %cst_12[%834, %835] : tensor<512x384xf32>
linalg.yield %838 : f32
} -> tensor<?x?x384xf32>
%122 = arith.cmpi eq, %26, %26 : index
cf.assert %122, "mismatched size for broadcast"
%123 = arith.cmpi eq, %c512, %111 : index
cf.assert %123, "mismatched size for broadcast"
%124 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%119, %121 : tensor<?x512x384xf32>, tensor<?x?x384xf32>) outs(%118 : tensor<?x512x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%834 = arith.sitofp %c1_i64 : i64 to f32
%835 = arith.mulf %arg2, %834 : f32
%836 = arith.addf %arg1, %835 : f32
linalg.yield %836 : f32
} -> tensor<?x512x384xf32>
%125 = arith.cmpi eq, %c384_i64, %c384_i64 : i64
cf.assert %125, "mismatching contracting dimension"
cf.assert %125, "mismatching contracting dimension"
cf.assert %125, "mismatching contracting dimension"
%126 = arith.sitofp %c384_i64 : i64 to f32
%127 = linalg.init_tensor [%26, 512] : tensor<?x512xf32>
%128 = linalg.fill(%cst, %127) : f32, tensor<?x512xf32> -> tensor<?x512xf32>
%129 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%124 : tensor<?x512x384xf32>) outs(%128 : tensor<?x512xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%834 = arith.addf %arg2, %arg1 : f32
linalg.yield %834 : f32
} -> tensor<?x512xf32>
%130 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%129 : tensor<?x512xf32>) outs(%127 : tensor<?x512xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%834 = arith.divf %arg1, %126 : f32
linalg.yield %834 : f32
} -> tensor<?x512xf32>
%131 = linalg.fill(%cst, %127) : f32, tensor<?x512xf32> -> tensor<?x512xf32>
%132 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%124, %130 : tensor<?x512x384xf32>, tensor<?x512xf32>) outs(%131 : tensor<?x512xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%834 = arith.subf %arg1, %arg2 : f32
%835 = arith.mulf %834, %834 : f32
%836 = arith.addf %arg3, %835 : f32
linalg.yield %836 : f32
} -> tensor<?x512xf32>
%133 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%132 : tensor<?x512xf32>) outs(%127 : tensor<?x512xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%834 = arith.divf %arg1, %126 : f32
linalg.yield %834 : f32
} -> tensor<?x512xf32>
%134 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%124, %130, %133, %cst_14, %cst_13 : tensor<?x512x384xf32>, tensor<?x512xf32>, tensor<?x512xf32>, tensor<384xf32>, tensor<384xf32>) outs(%118 : tensor<?x512x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%834 = arith.subf %arg1, %arg2 : f32
%835 = arith.truncf %cst_5 : f64 to f32
%836 = arith.addf %arg3, %835 : f32
%837 = math.rsqrt %836 : f32
%838 = arith.mulf %834, %837 : f32
%839 = arith.mulf %838, %arg4 : f32
%840 = arith.addf %839, %arg5 : f32
linalg.yield %840 : f32
} -> tensor<?x512x384xf32>
%135 = linalg.init_tensor [%26, 384, 384] : tensor<?x384x384xf32>
%136 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_15 : tensor<384xf32>) outs(%118 : tensor<?x512x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x512x384xf32>
%137 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_16 : tensor<384x384xf32>) outs(%135 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%138 = linalg.batch_matmul ins(%134, %137 : tensor<?x512x384xf32>, tensor<?x384x384xf32>) outs(%136 : tensor<?x512x384xf32>) -> tensor<?x512x384xf32>
%139 = tensor.cast %138 : tensor<?x512x384xf32> to tensor<?x?x384xf32>
%140 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_17 : tensor<384xf32>) outs(%118 : tensor<?x512x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x512x384xf32>
%141 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_18 : tensor<384x384xf32>) outs(%135 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%142 = linalg.batch_matmul ins(%134, %141 : tensor<?x512x384xf32>, tensor<?x384x384xf32>) outs(%140 : tensor<?x512x384xf32>) -> tensor<?x512x384xf32>
%143 = tensor.cast %142 : tensor<?x512x384xf32> to tensor<?x?x384xf32>
%144 = arith.addi %c0_i64, %c3_i64 : i64
%145 = arith.select %9, %c0_i64, %144 : i64
%146 = arith.cmpi sge, %145, %c0_i64 : i64
cf.assert %146, "dim must be greater or equal to zero"
%147 = arith.cmpi slt, %145, %c3_i64 : i64
cf.assert %147, "dim must be smaller than inputRank"
%148 = arith.index_cast %145 : i64 to index
%149 = tensor.dim %142, %148 : tensor<?x512x384xf32>
%150 = arith.index_cast %149 : index to i64
%151 = arith.addi %c1_i64, %c3_i64 : i64
%152 = arith.cmpi sge, %c1_i64, %c0_i64 : i64
%153 = arith.select %152, %c1_i64, %151 : i64
%154 = arith.cmpi sge, %153, %c0_i64 : i64
cf.assert %154, "dim must be greater or equal to zero"
%155 = arith.cmpi slt, %153, %c3_i64 : i64
cf.assert %155, "dim must be smaller than inputRank"
%156 = arith.index_cast %153 : i64 to index
%157 = tensor.dim %142, %156 : tensor<?x512x384xf32>
%158 = arith.index_cast %157 : index to i64
%159 = tensor.expand_shape %143 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%160 = linalg.init_tensor [%26, 12, %c512, 32] : tensor<?x12x?x32xf32>
%161 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%159 : tensor<?x?x12x32xf32>) outs(%160 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
%162 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_19 : tensor<384xf32>) outs(%118 : tensor<?x512x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x512x384xf32>
%163 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_20 : tensor<384x384xf32>) outs(%135 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%164 = linalg.batch_matmul ins(%134, %163 : tensor<?x512x384xf32>, tensor<?x384x384xf32>) outs(%162 : tensor<?x512x384xf32>) -> tensor<?x512x384xf32>
%165 = tensor.cast %164 : tensor<?x512x384xf32> to tensor<?x?x384xf32>
cf.assert %146, "dim must be greater or equal to zero"
cf.assert %147, "dim must be smaller than inputRank"
%166 = tensor.dim %164, %148 : tensor<?x512x384xf32>
%167 = arith.index_cast %166 : index to i64
cf.assert %154, "dim must be greater or equal to zero"
cf.assert %155, "dim must be smaller than inputRank"
%168 = tensor.dim %164, %156 : tensor<?x512x384xf32>
%169 = arith.index_cast %168 : index to i64
%170 = tensor.expand_shape %165 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%171 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%170 : tensor<?x?x12x32xf32>) outs(%160 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
cf.assert %146, "dim must be greater or equal to zero"
cf.assert %147, "dim must be smaller than inputRank"
%172 = tensor.dim %138, %148 : tensor<?x512x384xf32>
%173 = arith.index_cast %172 : index to i64
cf.assert %154, "dim must be greater or equal to zero"
cf.assert %155, "dim must be smaller than inputRank"
%174 = tensor.dim %138, %156 : tensor<?x512x384xf32>
%175 = arith.index_cast %174 : index to i64
%176 = tensor.expand_shape %139 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%177 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%176 : tensor<?x?x12x32xf32>) outs(%160 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
%178 = linalg.init_tensor [%26, 12, 32, %c512] : tensor<?x12x32x?xf32>
%179 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d2)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%161 : tensor<?x12x?x32xf32>) outs(%178 : tensor<?x12x32x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x32x?xf32>
%180 = arith.cmpi eq, %48, %48 : i64
cf.assert %180, "mismatching contracting dimension"
%181 = linalg.init_tensor [%26, 12, %c512, %c512] : tensor<?x12x?x?xf32>
%182 = linalg.fill(%cst, %181) : f32, tensor<?x12x?x?xf32> -> tensor<?x12x?x?xf32>
%183 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>], iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel"]} ins(%177, %179 : tensor<?x12x?x32xf32>, tensor<?x12x32x?xf32>) outs(%182 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%834 = arith.mulf %arg1, %arg2 : f32
%835 = arith.addf %834, %arg3 : f32
linalg.yield %835 : f32
} -> tensor<?x12x?x?xf32>
%184 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> ()>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%183, %cst_95 : tensor<?x12x?x?xf32>, tensor<f64>) outs(%181 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f64, %arg3: f32):
%834 = arith.truncf %arg2 : f64 to f32
%835 = arith.divf %arg1, %834 : f32
linalg.yield %835 : f32
} -> tensor<?x12x?x?xf32>
// Broadcast add: %187 = %184 + %96 * 1.0, where %96 (tensor<?x1x1x?xf32>) is
// read through (d0, 0, 0, d3) and therefore broadcast along d1 and d2.
// NOTE(review): %96 is presumably an additive attention mask; it is defined
// outside this chunk, so confirm.
// Runtime checks that the broadcast dims agree (%72 and %92 are defined earlier).
%185 = arith.cmpi eq, %26, %72 : index
cf.assert %185, "mismatched size for broadcast"
%186 = arith.cmpi eq, %c512, %92 : index
cf.assert %186, "mismatched size for broadcast"
%187 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, 0, 0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%184, %96 : tensor<?x12x?x?xf32>, tensor<?x1x1x?xf32>) outs(%181 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
// The multiplier is the integer constant %c1_i64 (1) converted to f32.
%834 = arith.sitofp %c1_i64 : i64 to f32
%835 = arith.mulf %arg2, %834 : f32
%836 = arith.addf %arg1, %835 : f32
linalg.yield %836 : f32
} -> tensor<?x12x?x?xf32>
// Softmax of %187 along the last dimension (d3), computed in five steps:
// row max, subtract max, exp, row sum, divide by row sum.
// Step 1 accumulators: an i64 index tensor filled with 0 and an f32 value tensor
// filled with %cst_1 (-3.40282347E+38), both with a trailing size-1 dim.
%188 = linalg.init_tensor [%26, 12, %c512, 1] : tensor<?x12x?x1xi64>
%189 = linalg.fill(%c0_i64, %188) : i64, tensor<?x12x?x1xi64> -> tensor<?x12x?x1xi64>
%190 = linalg.init_tensor [%26, 12, %c512, 1] : tensor<?x12x?x1xf32>
%191 = linalg.fill(%cst_1, %190) : f32, tensor<?x12x?x1xf32> -> tensor<?x12x?x1xf32>
// Step 1: reduce over d3, keeping the running maximum (%192#0) and the d3 index
// at which it was found (%192#1). Only %192#0 is used in the visible chunk.
%192:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%187 : tensor<?x12x?x?xf32>) outs(%191, %189 : tensor<?x12x?x1xf32>, tensor<?x12x?x1xi64>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: i64):
%834 = linalg.index 3 : index
%835 = arith.index_cast %834 : index to i64
// Strictly-greater ordered compare: ties keep the earlier value and index.
%836 = arith.cmpf ogt, %arg1, %arg2 : f32
%837 = arith.select %836, %arg1, %arg2 : f32
%838 = arith.select %836, %835, %arg3 : i64
linalg.yield %837, %838 : f32, i64
} -> (tensor<?x12x?x1xf32>, tensor<?x12x?x1xi64>)
cf.assert %true, "mismatched size for broadcast"
cf.assert %true, "mismatched size for broadcast"
// Step 2: subtract the row maximum (broadcast along d3) from every element.
// The multiplier is %cst_0 (1.0 as f64) truncated to f32.
%193 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%187, %192#0 : tensor<?x12x?x?xf32>, tensor<?x12x?x1xf32>) outs(%181 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%834 = arith.truncf %cst_0 : f64 to f32
%835 = arith.mulf %arg2, %834 : f32
%836 = arith.subf %arg1, %835 : f32
linalg.yield %836 : f32
} -> tensor<?x12x?x?xf32>
// Step 3: elementwise exponential.
%194 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%193 : tensor<?x12x?x?xf32>) outs(%181 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%834 = math.exp %arg1 : f32
linalg.yield %834 : f32
} -> tensor<?x12x?x?xf32>
// Step 4: sum the exponentials over d3 into a zero-filled accumulator.
%195 = linalg.fill(%cst, %190) : f32, tensor<?x12x?x1xf32> -> tensor<?x12x?x1xf32>
%196 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%194 : tensor<?x12x?x?xf32>) outs(%195 : tensor<?x12x?x1xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%834 = arith.addf %arg1, %arg2 : f32
linalg.yield %834 : f32
} -> tensor<?x12x?x1xf32>
cf.assert %true, "mismatched size for broadcast"
cf.assert %true, "mismatched size for broadcast"
// Step 5: divide each exponential by its row sum (broadcast along d3).
%197 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%194, %196 : tensor<?x12x?x?xf32>, tensor<?x12x?x1xf32>) outs(%181 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%834 = arith.divf %arg1, %arg2 : f32
linalg.yield %834 : f32
} -> tensor<?x12x?x?xf32>
// Batched matrix multiply of the softmax result with %171:
//   %199[d0, d1, d2, d4] = sum over d3 of %197[d0, d1, d2, d3] * %171[d0, d1, d3, d4].
// NOTE(review): %171 is presumably the attention "value" tensor; it is defined
// outside this chunk, so confirm.
cf.assert %180, "mismatching contracting dimension"
cf.assert %true, "mismatching contracting dimension"
// Zero-filled accumulator built on %160 (an earlier tensor<?x12x?x32xf32>).
%198 = linalg.fill(%cst, %160) : f32, tensor<?x12x?x32xf32> -> tensor<?x12x?x32xf32>
%199 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>], iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel"]} ins(%197, %171 : tensor<?x12x?x?xf32>, tensor<?x12x?x32xf32>) outs(%198 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%834 = arith.mulf %arg1, %arg2 : f32
%835 = arith.addf %834, %arg3 : f32
linalg.yield %835 : f32
} -> tensor<?x12x?x32xf32>
// Swap dims 1 and 2 of %199 (tensor<?x12x?x32xf32> to tensor<?x?x12x32xf32>) and
// then collapse the trailing 12x32 dims into one dim of size 384.
%200 = linalg.init_tensor [%26, %c512, 12, 32] : tensor<?x?x12x32xf32>
%201 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%199 : tensor<?x12x?x32xf32>) outs(%200 : tensor<?x?x12x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x12x32xf32>
%202 = tensor.cast %201 : tensor<?x?x12x32xf32> to tensor<?x?x?x?xf32>
// Size query #1: pick %c0_i64 or %c0_i64 + %c4_i64 depending on %9 (defined
// outside this chunk), range-check it against [0, 4), then read that dim of %201.
// The asserts and index (%205, %206, %207) are reused later at L874-876.
%203 = arith.addi %c0_i64, %c4_i64 : i64
%204 = arith.select %9, %c0_i64, %203 : i64
%205 = arith.cmpi sge, %204, %c0_i64 : i64
cf.assert %205, "dim must be greater or equal to zero"
%206 = arith.cmpi slt, %204, %c4_i64 : i64
cf.assert %206, "dim must be smaller than inputRank"
%207 = arith.index_cast %204 : i64 to index
%208 = tensor.dim %201, %207 : tensor<?x?x12x32xf32>
%209 = arith.index_cast %208 : index to i64
// Size query #2: same pattern with %c1_i64 and the predicate %152 (defined
// outside this chunk). %212, %213 and %214 are reused later at L878-880.
%210 = arith.addi %c1_i64, %c4_i64 : i64
%211 = arith.select %152, %c1_i64, %210 : i64
%212 = arith.cmpi sge, %211, %c0_i64 : i64
cf.assert %212, "dim must be greater or equal to zero"
%213 = arith.cmpi slt, %211, %c4_i64 : i64
cf.assert %213, "dim must be smaller than inputRank"
%214 = arith.index_cast %211 : i64 to index
%215 = tensor.dim %201, %214 : tensor<?x?x12x32xf32>
%216 = arith.index_cast %215 : index to i64
// Merge dims 2 and 3, then cast the last dim back to its static size 384.
%217 = tensor.collapse_shape %202 [[0], [1], [2, 3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
%218 = tensor.cast %217 : tensor<?x?x?xf32> to tensor<?x?x384xf32>
// Projection 384 to 384 applied to %218: %222 = %218 x transpose(%cst_22) + %cst_21.
// %220 broadcasts the 384-element vector %cst_21 along d0 and d1; it is used as
// the initial value of the matmul output.
%219 = linalg.init_tensor [%26, %c512, 384] : tensor<?x?x384xf32>
%220 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_21 : tensor<384xf32>) outs(%219 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
// %221 reads %cst_22 through (d2, d1), i.e. transposed, and replicates it along
// the batch dim d0. %135 (defined outside this chunk) is the destination.
%221 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_22 : tensor<384x384xf32>) outs(%135 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
// batch_matmul accumulates into its outs operand, so the %220 values are added.
%222 = linalg.batch_matmul ins(%218, %221 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%220 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
// Elementwise add: %228 = %222 + %134 * 1.0.
// NOTE(review): %134 (tensor<?x512x384xf32>, defined outside this chunk) is
// presumably the residual input of this sub-layer; confirm.
// %223 and %224 are the two leading dims of %218. They are checked against %26 and
// %c512 and reused as the dynamic sizes for most tensors in the rest of this chunk.
%223 = tensor.dim %218, %c0 : tensor<?x?x384xf32>
%224 = tensor.dim %218, %c1 : tensor<?x?x384xf32>
%225 = arith.cmpi eq, %223, %26 : index
cf.assert %225, "mismatched size for broadcast"
%226 = arith.cmpi eq, %224, %c512 : index
cf.assert %226, "mismatched size for broadcast"
%227 = linalg.init_tensor [%223, %224, 384] : tensor<?x?x384xf32>
%228 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%222, %134 : tensor<?x?x384xf32>, tensor<?x512x384xf32>) outs(%227 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
// The multiplier is the integer constant %c1_i64 (1) converted to f32.
%834 = arith.sitofp %c1_i64 : i64 to f32
%835 = arith.mulf %arg2, %834 : f32
%836 = arith.addf %arg1, %835 : f32
linalg.yield %836 : f32
} -> tensor<?x?x384xf32>
// Normalization of %228 over its last dim (size 384), in the shape of a layer norm:
//   mean = sum(x) / 384
//   var  = sum((x - mean)^2) / 384
//   %238 = (x - mean) * rsqrt(var + eps) * %cst_24 + %cst_23
// where eps is %cst_5 (9.9999999999999998E-13 as f64) truncated to f32.
// %229 compares a constant with itself, so the three asserts are trivially true.
%229 = arith.cmpi eq, %c384_i64, %c384_i64 : i64
cf.assert %229, "mismatching contracting dimension"
cf.assert %229, "mismatching contracting dimension"
cf.assert %229, "mismatching contracting dimension"
// Divisor 384.0 used for both the mean and the variance.
%230 = arith.sitofp %c384_i64 : i64 to f32
%231 = linalg.init_tensor [%223, %224] : tensor<?x?xf32>
%232 = linalg.fill(%cst, %231) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
// Sum over d2.
%233 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%228 : tensor<?x?x384xf32>) outs(%232 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%834 = arith.addf %arg2, %arg1 : f32
linalg.yield %834 : f32
} -> tensor<?x?xf32>
// Mean = sum / 384.
%234 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%233 : tensor<?x?xf32>) outs(%231 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%834 = arith.divf %arg1, %230 : f32
linalg.yield %834 : f32
} -> tensor<?x?xf32>
%235 = linalg.fill(%cst, %231) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
// Sum of squared deviations from the mean over d2.
%236 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%228, %234 : tensor<?x?x384xf32>, tensor<?x?xf32>) outs(%235 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%834 = arith.subf %arg1, %arg2 : f32
%835 = arith.mulf %834, %834 : f32
%836 = arith.addf %arg3, %835 : f32
linalg.yield %836 : f32
} -> tensor<?x?xf32>
// Variance = sum of squares / 384 (divides by N, not N - 1).
%237 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%236 : tensor<?x?xf32>) outs(%231 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%834 = arith.divf %arg1, %230 : f32
linalg.yield %834 : f32
} -> tensor<?x?xf32>
// Normalize, then multiply by %cst_24 (%arg4) and add %cst_23 (%arg5), both
// broadcast along d0 and d1.
%238 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%228, %234, %237, %cst_24, %cst_23 : tensor<?x?x384xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<384xf32>, tensor<384xf32>) outs(%227 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%834 = arith.subf %arg1, %arg2 : f32
%835 = arith.truncf %cst_5 : f64 to f32
%836 = arith.addf %arg3, %835 : f32
%837 = math.rsqrt %836 : f32
%838 = arith.mulf %834, %837 : f32
%839 = arith.mulf %838, %arg4 : f32
%840 = arith.addf %839, %arg5 : f32
linalg.yield %840 : f32
} -> tensor<?x?x384xf32>
// Projection 384 to 1536 applied to %238: %243 = %238 x transpose(%cst_26) + %cst_25.
%239 = linalg.init_tensor [%223, %224, 1536] : tensor<?x?x1536xf32>
%240 = linalg.init_tensor [%223, 384, 1536] : tensor<?x384x1536xf32>
// Broadcast the 1536-element vector %cst_25 along d0 and d1 (initial matmul output).
%241 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_25 : tensor<1536xf32>) outs(%239 : tensor<?x?x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x1536xf32>
// Read %cst_26 (1536x384) through (d2, d1), i.e. transposed, and replicate it
// along the batch dim d0.
%242 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_26 : tensor<1536x384xf32>) outs(%240 : tensor<?x384x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x1536xf32>
// batch_matmul accumulates into its outs operand, so the %241 values are added.
%243 = linalg.batch_matmul ins(%238, %242 : tensor<?x?x384xf32>, tensor<?x384x1536xf32>) outs(%241 : tensor<?x?x1536xf32>) -> tensor<?x?x1536xf32>
// Elementwise activation on %243: y = x * (0.5 * (erf(x / sqrt(2.0)) + 1.0)),
// which is the erf-based GELU formula. The constants come from the top of the
// function: %cst_3 = 2.0, %cst_4 = 1.0, %cst_2 = 0.5.
%244 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%243 : tensor<?x?x1536xf32>) outs(%239 : tensor<?x?x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%834 = math.sqrt %cst_3 : f32
%835 = arith.divf %arg1, %834 : f32
%836 = math.erf %835 : f32
%837 = arith.addf %836, %cst_4 : f32
%838 = arith.mulf %837, %cst_2 : f32
%839 = arith.mulf %arg1, %838 : f32
linalg.yield %839 : f32
} -> tensor<?x?x1536xf32>
// Projection 1536 to 384 applied to %244: %248 = %244 x transpose(%cst_28) + %cst_27.
%245 = linalg.init_tensor [%223, 1536, 384] : tensor<?x1536x384xf32>
// Broadcast the 384-element vector %cst_27 along d0 and d1 (initial matmul output).
%246 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_27 : tensor<384xf32>) outs(%227 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
// Read %cst_28 (384x1536) through (d2, d1), i.e. transposed, and replicate it
// along the batch dim d0.
%247 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_28 : tensor<384x1536xf32>) outs(%245 : tensor<?x1536x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x1536x384xf32>
// batch_matmul accumulates into its outs operand, so the %246 values are added.
%248 = linalg.batch_matmul ins(%244, %247 : tensor<?x?x1536xf32>, tensor<?x1536x384xf32>) outs(%246 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
// Elementwise add: %251 = %248 + %238 * 1.0, i.e. the result of the two
// projections above plus their input %238.
// %249 and %250 compare %223 and %224 with themselves, so both asserts are
// trivially true. Both values are reused by later asserts (L830-831, L850-851).
%249 = arith.cmpi eq, %223, %223 : index
cf.assert %249, "mismatched size for broadcast"
%250 = arith.cmpi eq, %224, %224 : index
cf.assert %250, "mismatched size for broadcast"
%251 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%248, %238 : tensor<?x?x384xf32>, tensor<?x?x384xf32>) outs(%227 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
// The multiplier is the integer constant %c1_i64 (1) converted to f32.
%834 = arith.sitofp %c1_i64 : i64 to f32
%835 = arith.mulf %arg2, %834 : f32
%836 = arith.addf %arg1, %835 : f32
linalg.yield %836 : f32
} -> tensor<?x?x384xf32>
// Normalization of %251 over its last dim (size 384), same op sequence as the one
// that produced %238:
//   %260 = (x - mean) * rsqrt(var + eps) * %cst_30 + %cst_29
// with mean and var computed over d2 and eps = %cst_5 truncated to f32.
// %252 compares a constant with itself, so the three asserts are trivially true.
%252 = arith.cmpi eq, %c384_i64, %c384_i64 : i64
cf.assert %252, "mismatching contracting dimension"
cf.assert %252, "mismatching contracting dimension"
cf.assert %252, "mismatching contracting dimension"
// Divisor 384.0 used for both the mean and the variance.
%253 = arith.sitofp %c384_i64 : i64 to f32
// The [%223, %224] init tensor %231 from L598 is reused for the accumulators.
%254 = linalg.fill(%cst, %231) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
// Sum over d2.
%255 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%251 : tensor<?x?x384xf32>) outs(%254 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%834 = arith.addf %arg2, %arg1 : f32
linalg.yield %834 : f32
} -> tensor<?x?xf32>
// Mean = sum / 384.
%256 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%255 : tensor<?x?xf32>) outs(%231 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%834 = arith.divf %arg1, %253 : f32
linalg.yield %834 : f32
} -> tensor<?x?xf32>
%257 = linalg.fill(%cst, %231) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
// Sum of squared deviations from the mean over d2.
%258 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%251, %256 : tensor<?x?x384xf32>, tensor<?x?xf32>) outs(%257 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%834 = arith.subf %arg1, %arg2 : f32
%835 = arith.mulf %834, %834 : f32
%836 = arith.addf %arg3, %835 : f32
linalg.yield %836 : f32
} -> tensor<?x?xf32>
// Variance = sum of squares / 384 (divides by N, not N - 1).
%259 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%258 : tensor<?x?xf32>) outs(%231 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%834 = arith.divf %arg1, %253 : f32
linalg.yield %834 : f32
} -> tensor<?x?xf32>
// Normalize, then multiply by %cst_30 (%arg4) and add %cst_29 (%arg5), both
// broadcast along d0 and d1.
%260 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%251, %256, %259, %cst_30, %cst_29 : tensor<?x?x384xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<384xf32>, tensor<384xf32>) outs(%227 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%834 = arith.subf %arg1, %arg2 : f32
%835 = arith.truncf %cst_5 : f64 to f32
%836 = arith.addf %arg3, %835 : f32
%837 = math.rsqrt %836 : f32
%838 = arith.mulf %834, %837 : f32
%839 = arith.mulf %838, %arg4 : f32
%840 = arith.addf %839, %arg5 : f32
linalg.yield %840 : f32
} -> tensor<?x?x384xf32>
// Two independent 384 to 384 projections of %260:
//   %264 = %260 x transpose(%cst_32) + %cst_31
//   %267 = %260 x transpose(%cst_34) + %cst_33
// NOTE(review): from their later use (%264 is the left matmul operand at L794, %267
// is transposed into the right operand), these are presumably the attention query
// and key projections; confirm against the originating model.
// %261 is the shared [%223, 384, 384] destination for the transposed weights.
%261 = linalg.init_tensor [%223, 384, 384] : tensor<?x384x384xf32>
// Broadcast %cst_31 along d0 and d1 (initial matmul output).
%262 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_31 : tensor<384xf32>) outs(%227 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
// Read %cst_32 through (d2, d1), i.e. transposed, replicated along d0.
%263 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_32 : tensor<384x384xf32>) outs(%261 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%264 = linalg.batch_matmul ins(%260, %263 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%262 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
// Broadcast %cst_33 along d0 and d1 (initial matmul output).
%265 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_33 : tensor<384xf32>) outs(%227 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
// Read %cst_34 through (d2, d1), i.e. transposed, replicated along d0.
%266 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_34 : tensor<384x384xf32>) outs(%261 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%267 = linalg.batch_matmul ins(%260, %266 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%265 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
// Split the last dim of %267 (384) into 12x32 and swap dims 1 and 2, giving
// tensor<?x12x?x32xf32>.
// The asserts and dim indices (%146, %147, %148, %154, %155, %156) are defined
// outside this chunk. The i64 sizes %269 and %271 are not referenced anywhere in
// the visible chunk.
cf.assert %146, "dim must be greater or equal to zero"
cf.assert %147, "dim must be smaller than inputRank"
%268 = tensor.dim %267, %148 : tensor<?x?x384xf32>
%269 = arith.index_cast %268 : index to i64
cf.assert %154, "dim must be greater or equal to zero"
cf.assert %155, "dim must be smaller than inputRank"
%270 = tensor.dim %267, %156 : tensor<?x?x384xf32>
%271 = arith.index_cast %270 : index to i64
%272 = tensor.expand_shape %267 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
// %273 is the [%223, 12, %224, 32] destination; it is reused by the transposes
// that produce %283 and %289 and by the fill at L861.
%273 = linalg.init_tensor [%223, 12, %224, 32] : tensor<?x12x?x32xf32>
// The output map (d0, d2, d1, d3) swaps dims 1 and 2.
%274 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%272 : tensor<?x?x12x32xf32>) outs(%273 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
// Third 384 to 384 projection of %260: %277 = %260 x transpose(%cst_36) + %cst_35.
// NOTE(review): %277 later multiplies the softmax output (L862), so this is
// presumably the attention value projection; confirm.
// Broadcast %cst_35 along d0 and d1 (initial matmul output).
%275 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_35 : tensor<384xf32>) outs(%227 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
// Read %cst_36 through (d2, d1), i.e. transposed, replicated along d0.
%276 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_36 : tensor<384x384xf32>) outs(%261 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%277 = linalg.batch_matmul ins(%260, %276 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%275 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
// Split the last dim of %277 (384) into 12x32 and swap dims 1 and 2, giving
// tensor<?x12x?x32xf32>. Same pattern as the one producing %274.
// The i64 sizes %279 and %281 are not referenced anywhere in the visible chunk.
cf.assert %146, "dim must be greater or equal to zero"
cf.assert %147, "dim must be smaller than inputRank"
%278 = tensor.dim %277, %148 : tensor<?x?x384xf32>
%279 = arith.index_cast %278 : index to i64
cf.assert %154, "dim must be greater or equal to zero"
cf.assert %155, "dim must be smaller than inputRank"
%280 = tensor.dim %277, %156 : tensor<?x?x384xf32>
%281 = arith.index_cast %280 : index to i64
%282 = tensor.expand_shape %277 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
// The output map (d0, d2, d1, d3) swaps dims 1 and 2.
%283 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%282 : tensor<?x?x12x32xf32>) outs(%273 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
// Split the last dim of %264 (384) into 12x32 and swap dims 1 and 2, giving
// tensor<?x12x?x32xf32>. Same pattern as the one producing %274.
// The i64 sizes %285 and %287 are not referenced anywhere in the visible chunk.
cf.assert %146, "dim must be greater or equal to zero"
cf.assert %147, "dim must be smaller than inputRank"
%284 = tensor.dim %264, %148 : tensor<?x?x384xf32>
%285 = arith.index_cast %284 : index to i64
cf.assert %154, "dim must be greater or equal to zero"
cf.assert %155, "dim must be smaller than inputRank"
%286 = tensor.dim %264, %156 : tensor<?x?x384xf32>
%287 = arith.index_cast %286 : index to i64
%288 = tensor.expand_shape %264 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
// The output map (d0, d2, d1, d3) swaps dims 1 and 2.
%289 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%288 : tensor<?x?x12x32xf32>) outs(%273 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
// Transpose the last two dims of %274, then batched-multiply %289 by the result:
//   %296[d0, d1, d2, d4] = sum over d3 of %289[d0, d1, d2, d3] * %291[d0, d1, d3, d4].
// Same structure as the %179 / %183 pair at the top of this chunk.
%290 = linalg.init_tensor [%223, 12, 32, %224] : tensor<?x12x32x?xf32>
// The output map (d0, d1, d3, d2) swaps the last two dims.
%291 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d2)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%274 : tensor<?x12x?x32xf32>) outs(%290 : tensor<?x12x32x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x32x?xf32>
// %293 compares %292 with itself, so the assert is trivially true. %293 is
// asserted again at L857.
%292 = arith.index_cast %223 : index to i64
%293 = arith.cmpi eq, %292, %292 : i64
cf.assert %293, "mismatching contracting dimension"
// Accumulator of shape [%223, 12, %224, %224], filled with %cst (0.0).
%294 = linalg.init_tensor [%223, 12, %224, %224] : tensor<?x12x?x?xf32>
%295 = linalg.fill(%cst, %294) : f32, tensor<?x12x?x?xf32> -> tensor<?x12x?x?xf32>
// d3 is the only reduction iterator; the body is a multiply-accumulate.
%296 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>], iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel"]} ins(%289, %291 : tensor<?x12x?x32xf32>, tensor<?x12x32x?xf32>) outs(%295 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%834 = arith.mulf %arg1, %arg2 : f32
%835 = arith.addf %834, %arg3 : f32
linalg.yield %835 : f32
} -> tensor<?x12x?x?xf32>
// Divide every element of %296 by the rank-0 tensor %cst_95 (f64, truncated to
// f32 inside the body). The value of %cst_95 is defined outside this chunk.
// %294 is passed as outs only to provide the result shape; the body never reads %arg3.
%297 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> ()>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%296, %cst_95 : tensor<?x12x?x?xf32>, tensor<f64>) outs(%294 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f64, %arg3: f32):
%834 = arith.truncf %arg2 : f64 to f32
%835 = arith.divf %arg1, %834 : f32
linalg.yield %835 : f32
} -> tensor<?x12x?x?xf32>
// Broadcast add: %300 = %297 + %96 * 1.0, where %96 (tensor<?x1x1x?xf32>) is
// read through (d0, 0, 0, d3) and therefore broadcast along d1 and d2.
// This is the same %96 operand that is added at L487.
// Runtime checks that the broadcast dims agree (%72 and %92 are defined earlier).
%298 = arith.cmpi eq, %223, %72 : index
cf.assert %298, "mismatched size for broadcast"
%299 = arith.cmpi eq, %224, %92 : index
cf.assert %299, "mismatched size for broadcast"
%300 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, 0, 0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%297, %96 : tensor<?x12x?x?xf32>, tensor<?x1x1x?xf32>) outs(%294 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
// The multiplier is the integer constant %c1_i64 (1) converted to f32.
%834 = arith.sitofp %c1_i64 : i64 to f32
%835 = arith.mulf %arg2, %834 : f32
%836 = arith.addf %arg1, %835 : f32
linalg.yield %836 : f32
} -> tensor<?x12x?x?xf32>
// Softmax of %300 along the last dimension (d3), computed in five steps:
// row max, subtract max, exp, row sum, divide by row sum.
// Step 1 accumulators: an i64 index tensor filled with 0 and an f32 value tensor
// filled with %cst_1 (-3.40282347E+38), both with a trailing size-1 dim.
%301 = linalg.init_tensor [%223, 12, %224, 1] : tensor<?x12x?x1xi64>
%302 = linalg.fill(%c0_i64, %301) : i64, tensor<?x12x?x1xi64> -> tensor<?x12x?x1xi64>
%303 = linalg.init_tensor [%223, 12, %224, 1] : tensor<?x12x?x1xf32>
%304 = linalg.fill(%cst_1, %303) : f32, tensor<?x12x?x1xf32> -> tensor<?x12x?x1xf32>
// Step 1: reduce over d3, keeping the running maximum (%305#0) and the d3 index
// at which it was found (%305#1). Only %305#0 is used in the visible chunk.
%305:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%300 : tensor<?x12x?x?xf32>) outs(%304, %302 : tensor<?x12x?x1xf32>, tensor<?x12x?x1xi64>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: i64):
%834 = linalg.index 3 : index
%835 = arith.index_cast %834 : index to i64
// Strictly-greater ordered compare: ties keep the earlier value and index.
%836 = arith.cmpf ogt, %arg1, %arg2 : f32
%837 = arith.select %836, %arg1, %arg2 : f32
%838 = arith.select %836, %835, %arg3 : i64
linalg.yield %837, %838 : f32, i64
} -> (tensor<?x12x?x1xf32>, tensor<?x12x?x1xi64>)
cf.assert %249, "mismatched size for broadcast"
cf.assert %250, "mismatched size for broadcast"
// Step 2: subtract the row maximum (broadcast along d3) from every element.
// The multiplier is %cst_0 (1.0 as f64) truncated to f32.
%306 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%300, %305#0 : tensor<?x12x?x?xf32>, tensor<?x12x?x1xf32>) outs(%294 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%834 = arith.truncf %cst_0 : f64 to f32
%835 = arith.mulf %arg2, %834 : f32
%836 = arith.subf %arg1, %835 : f32
linalg.yield %836 : f32
} -> tensor<?x12x?x?xf32>
// Step 3: elementwise exponential.
%307 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%306 : tensor<?x12x?x?xf32>) outs(%294 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%834 = math.exp %arg1 : f32
linalg.yield %834 : f32
} -> tensor<?x12x?x?xf32>
// Step 4: sum the exponentials over d3 into a zero-filled accumulator.
%308 = linalg.fill(%cst, %303) : f32, tensor<?x12x?x1xf32> -> tensor<?x12x?x1xf32>
%309 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%307 : tensor<?x12x?x?xf32>) outs(%308 : tensor<?x12x?x1xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%834 = arith.addf %arg1, %arg2 : f32
linalg.yield %834 : f32
} -> tensor<?x12x?x1xf32>
cf.assert %249, "mismatched size for broadcast"
cf.assert %250, "mismatched size for broadcast"
// Step 5: divide each exponential by its row sum (broadcast along d3).
%310 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%307, %309 : tensor<?x12x?x?xf32>, tensor<?x12x?x1xf32>) outs(%294 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%834 = arith.divf %arg1, %arg2 : f32
linalg.yield %834 : f32
} -> tensor<?x12x?x?xf32>
// Batched matrix multiply of the softmax result with %283:
//   %314[d0, d1, d2, d4] = sum over d3 of %310[d0, d1, d2, d3] * %283[d0, d1, d3, d4].
// %312 compares %311 with itself, so that assert is trivially true.
cf.assert %293, "mismatching contracting dimension"
%311 = arith.index_cast %224 : index to i64
%312 = arith.cmpi eq, %311, %311 : i64
cf.assert %312, "mismatching contracting dimension"
// Zero-filled accumulator built on the [%223, 12, %224, 32] init tensor %273.
%313 = linalg.fill(%cst, %273) : f32, tensor<?x12x?x32xf32> -> tensor<?x12x?x32xf32>
%314 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>], iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel"]} ins(%310, %283 : tensor<?x12x?x?xf32>, tensor<?x12x?x32xf32>) outs(%313 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%834 = arith.mulf %arg1, %arg2 : f32
%835 = arith.addf %834, %arg3 : f32
linalg.yield %835 : f32
} -> tensor<?x12x?x32xf32>
// Swap dims 1 and 2 of %314 (tensor<?x12x?x32xf32> to tensor<?x?x12x32xf32>) and
// then collapse the trailing 12x32 dims into one dim of size 384.
%315 = linalg.init_tensor [%223, %224, 12, 32] : tensor<?x?x12x32xf32>
%316 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%314 : tensor<?x12x?x32xf32>) outs(%315 : tensor<?x?x12x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x12x32xf32>
%317 = tensor.cast %316 : tensor<?x?x12x32xf32> to tensor<?x?x?x?xf32>
// Size queries reuse the range checks (%205, %206, %212, %213) and dim indices
// (%207, %214) computed at L551-564. The i64 sizes %319 and %321 are not
// referenced anywhere in the visible chunk.
cf.assert %205, "dim must be greater or equal to zero"
cf.assert %206, "dim must be smaller than inputRank"
%318 = tensor.dim %316, %207 : tensor<?x?x12x32xf32>
%319 = arith.index_cast %318 : index to i64
cf.assert %212, "dim must be greater or equal to zero"
cf.assert %213, "dim must be smaller than inputRank"
%320 = tensor.dim %316, %214 : tensor<?x?x12x32xf32>
%321 = arith.index_cast %320 : index to i64
// Merge dims 2 and 3, then cast the last dim back to its static size 384.
%322 = tensor.collapse_shape %317 [[0], [1], [2, 3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
%323 = tensor.cast %322 : tensor<?x?x?xf32> to tensor<?x?x384xf32>
// Projection 384 to 384 applied to %323: %326 = %323 x transpose(%cst_38) + %cst_37.
// Broadcast %cst_37 along d0 and d1 (initial matmul output).
%324 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_37 : tensor<384xf32>) outs(%227 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
// Read %cst_38 through (d2, d1), i.e. transposed, replicated along d0.
%325 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_38 : tensor<384x384xf32>) outs(%261 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
// batch_matmul accumulates into its outs operand, so the %324 values are added.
%326 = linalg.batch_matmul ins(%323, %325 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%324 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
// Elementwise add: %332 = %326 + %260 * 1.0, i.e. the projected result plus
// %260, the tensor that fed the three projections at L716-757.
// %327 and %328 are the two leading dims of %323. They are checked against %223
// and %224 and used as the dynamic sizes for the tensors that follow.
%327 = tensor.dim %323, %c0 : tensor<?x?x384xf32>
%328 = tensor.dim %323, %c1 : tensor<?x?x384xf32>
%329 = arith.cmpi eq, %327, %223 : index
cf.assert %329, "mismatched size for broadcast"
%330 = arith.cmpi eq, %328, %224 : index
cf.assert %330, "mismatched size for broadcast"
%331 = linalg.init_tensor [%327, %328, 384] : tensor<?x?x384xf32>
%332 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%326, %260 : tensor<?x?x384xf32>, tensor<?x?x384xf32>) outs(%331 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
// The multiplier is the integer constant %c1_i64 (1) converted to f32.
%834 = arith.sitofp %c1_i64 : i64 to f32
%835 = arith.mulf %arg2, %834 : f32
%836 = arith.addf %arg1, %835 : f32
linalg.yield %836 : f32
} -> tensor<?x?x384xf32>
// Normalization of %332 over its last dim (size 384), in the shape of a layer norm:
//   mean = sum(x) / 384
//   var  = sum((x - mean)^2) / 384
//   %342 = (x - mean) * rsqrt(var + eps) * %cst_40 + %cst_39
// where eps is %cst_5 (9.9999999999999998E-13 as f64) truncated to f32.
// %333 compares a constant with itself, so the three asserts are trivially true.
%333 = arith.cmpi eq, %c384_i64, %c384_i64 : i64
cf.assert %333, "mismatching contracting dimension"
cf.assert %333, "mismatching contracting dimension"
cf.assert %333, "mismatching contracting dimension"
// Divisor 384.0 used for both the mean and the variance.
%334 = arith.sitofp %c384_i64 : i64 to f32
%335 = linalg.init_tensor [%327, %328] : tensor<?x?xf32>
%336 = linalg.fill(%cst, %335) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
// Sum over d2.
%337 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%332 : tensor<?x?x384xf32>) outs(%336 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%834 = arith.addf %arg2, %arg1 : f32
linalg.yield %834 : f32
} -> tensor<?x?xf32>
// Mean = sum / 384.
%338 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%337 : tensor<?x?xf32>) outs(%335 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%834 = arith.divf %arg1, %334 : f32
linalg.yield %834 : f32
} -> tensor<?x?xf32>
%339 = linalg.fill(%cst, %335) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
// Sum of squared deviations from the mean over d2.
%340 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%332, %338 : tensor<?x?x384xf32>, tensor<?x?xf32>) outs(%339 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%834 = arith.subf %arg1, %arg2 : f32
%835 = arith.mulf %834, %834 : f32
%836 = arith.addf %arg3, %835 : f32
linalg.yield %836 : f32
} -> tensor<?x?xf32>
// Variance = sum of squares / 384 (divides by N, not N - 1).
%341 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%340 : tensor<?x?xf32>) outs(%335 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%834 = arith.divf %arg1, %334 : f32
linalg.yield %834 : f32
} -> tensor<?x?xf32>
// Normalize, then multiply by %cst_40 (%arg4) and add %cst_39 (%arg5), both
// broadcast along d0 and d1.
%342 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%332, %338, %341, %cst_40, %cst_39 : tensor<?x?x384xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<384xf32>, tensor<384xf32>) outs(%331 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%834 = arith.subf %arg1, %arg2 : f32
%835 = arith.truncf %cst_5 : f64 to f32
%836 = arith.addf %arg3, %835 : f32
%837 = math.rsqrt %836 : f32
%838 = arith.mulf %834, %837 : f32
%839 = arith.mulf %838, %arg4 : f32
%840 = arith.addf %839, %arg5 : f32
linalg.yield %840 : f32
} -> tensor<?x?x384xf32>
%343 = linalg.init_tensor [%327, %328, 1536] : tensor<?x?x1536xf32>
%344 = linalg.init_tensor [%327, 384, 1536] : tensor<?x384x1536xf32>
%345 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_41 : tensor<1536xf32>) outs(%343 : tensor<?x?x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x1536xf32>
%346 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_42 : tensor<1536x384xf32>) outs(%344 : tensor<?x384x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x1536xf32>
%347 = linalg.batch_matmul ins(%342, %346 : tensor<?x?x384xf32>, tensor<?x384x1536xf32>) outs(%345 : tensor<?x?x1536xf32>) -> tensor<?x?x1536xf32>
%348 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%347 : tensor<?x?x1536xf32>) outs(%343 : tensor<?x?x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%834 = math.sqrt %cst_3 : f32
%835 = arith.divf %arg1, %834 : f32
%836 = math.erf %835 : f32
%837 = arith.addf %836, %cst_4 : f32
%838 = arith.mulf %837, %cst_2 : f32
%839 = arith.mulf %arg1, %838 : f32
linalg.yield %839 : f32
} -> tensor<?x?x1536xf32>
%349 = linalg.init_tensor [%327, 1536, 384] : tensor<?x1536x384xf32>
%350 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_43 : tensor<384xf32>) outs(%331 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%351 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_44 : tensor<384x1536xf32>) outs(%349 : tensor<?x1536x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x1536x384xf32>
%352 = linalg.batch_matmul ins(%348, %351 : tensor<?x?x1536xf32>, tensor<?x1536x384xf32>) outs(%350 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%353 = arith.cmpi eq, %327, %327 : index
cf.assert %353, "mismatched size for broadcast"
%354 = arith.cmpi eq, %328, %328 : index
cf.assert %354, "mismatched size for broadcast"
%355 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%352, %342 : tensor<?x?x384xf32>, tensor<?x?x384xf32>) outs(%331 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%834 = arith.sitofp %c1_i64 : i64 to f32
%835 = arith.mulf %arg2, %834 : f32
%836 = arith.addf %arg1, %835 : f32
linalg.yield %836 : f32
} -> tensor<?x?x384xf32>
%356 = arith.cmpi eq, %c384_i64, %c384_i64 : i64
cf.assert %356, "mismatching contracting dimension"
cf.assert %356, "mismatching contracting dimension"
cf.assert %356, "mismatching contracting dimension"
%357 = arith.sitofp %c384_i64 : i64 to f32
%358 = linalg.fill(%cst, %335) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%359 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%355 : tensor<?x?x384xf32>) outs(%358 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%834 = arith.addf %arg2, %arg1 : f32
linalg.yield %834 : f32
} -> tensor<?x?xf32>
%360 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%359 : tensor<?x?xf32>) outs(%335 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%834 = arith.divf %arg1, %357 : f32
linalg.yield %834 : f32
} -> tensor<?x?xf32>
%361 = linalg.fill(%cst, %335) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%362 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%355, %360 : tensor<?x?x384xf32>, tensor<?x?xf32>) outs(%361 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%834 = arith.subf %arg1, %arg2 : f32
%835 = arith.mulf %834, %834 : f32
%836 = arith.addf %arg3, %835 : f32
linalg.yield %836 : f32
} -> tensor<?x?xf32>
%363 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%362 : tensor<?x?xf32>) outs(%335 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%834 = arith.divf %arg1, %357 : f32
linalg.yield %834 : f32
} -> tensor<?x?xf32>
%364 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%355, %360, %363, %cst_46, %cst_45 : tensor<?x?x384xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<384xf32>, tensor<384xf32>) outs(%331 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%834 = arith.subf %arg1, %arg2 : f32
%835 = arith.truncf %cst_5 : f64 to f32
%836 = arith.addf %arg3, %835 : f32
%837 = math.rsqrt %836 : f32
%838 = arith.mulf %834, %837 : f32
%839 = arith.mulf %838, %arg4 : f32
%840 = arith.addf %839, %arg5 : f32
linalg.yield %840 : f32
} -> tensor<?x?x384xf32>
%365 = linalg.init_tensor [%327, 384, 384] : tensor<?x384x384xf32>
%366 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_47 : tensor<384xf32>) outs(%331 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%367 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_48 : tensor<384x384xf32>) outs(%365 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%368 = linalg.batch_matmul ins(%364, %367 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%366 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%369 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_49 : tensor<384xf32>) outs(%331 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%370 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_50 : tensor<384x384xf32>) outs(%365 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%371 = linalg.batch_matmul ins(%364, %370 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%369 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
cf.assert %146, "dim must be greater or equal to zero"
cf.assert %147, "dim must be smaller than inputRank"
%372 = tensor.dim %371, %148 : tensor<?x?x384xf32>
%373 = arith.index_cast %372 : index to i64
cf.assert %154, "dim must be greater or equal to zero"
cf.assert %155, "dim must be smaller than inputRank"
%374 = tensor.dim %371, %156 : tensor<?x?x384xf32>
%375 = arith.index_cast %374 : index to i64
%376 = tensor.expand_shape %371 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%377 = linalg.init_tensor [%327, 12, %328, 32] : tensor<?x12x?x32xf32>
%378 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%376 : tensor<?x?x12x32xf32>) outs(%377 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
%379 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_51 : tensor<384xf32>) outs(%331 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%380 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_52 : tensor<384x384xf32>) outs(%365 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%381 = linalg.batch_matmul ins(%364, %380 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%379 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
cf.assert %146, "dim must be greater or equal to zero"
cf.assert %147, "dim must be smaller than inputRank"
%382 = tensor.dim %381, %148 : tensor<?x?x384xf32>
%383 = arith.index_cast %382 : index to i64
cf.assert %154, "dim must be greater or equal to zero"
cf.assert %155, "dim must be smaller than inputRank"
%384 = tensor.dim %381, %156 : tensor<?x?x384xf32>
%385 = arith.index_cast %384 : index to i64
%386 = tensor.expand_shape %381 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%387 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%386 : tensor<?x?x12x32xf32>) outs(%377 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
cf.assert %146, "dim must be greater or equal to zero"
cf.assert %147, "dim must be smaller than inputRank"
%388 = tensor.dim %368, %148 : tensor<?x?x384xf32>
%389 = arith.index_cast %388 : index to i64
cf.assert %154, "dim must be greater or equal to zero"
cf.assert %155, "dim must be smaller than inputRank"
%390 = tensor.dim %368, %156 : tensor<?x?x384xf32>
%391 = arith.index_cast %390 : index to i64
%392 = tensor.expand_shape %368 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%393 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%392 : tensor<?x?x12x32xf32>) outs(%377 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
%394 = linalg.init_tensor [%327, 12, 32, %328] : tensor<?x12x32x?xf32>
%395 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d2)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%378 : tensor<?x12x?x32xf32>) outs(%394 : tensor<?x12x32x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x32x?xf32>
%396 = arith.index_cast %327 : index to i64
%397 = arith.cmpi eq, %396, %396 : i64
cf.assert %397, "mismatching contracting dimension"
%398 = linalg.init_tensor [%327, 12, %328, %328] : tensor<?x12x?x?xf32>
%399 = linalg.fill(%cst, %398) : f32, tensor<?x12x?x?xf32> -> tensor<?x12x?x?xf32>
%400 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>], iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel"]} ins(%393, %395 : tensor<?x12x?x32xf32>, tensor<?x12x32x?xf32>) outs(%399 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%834 = arith.mulf %arg1, %arg2 : f32
%835 = arith.addf %834, %arg3 : f32
linalg.yield %835 : f32
} -> tensor<?x12x?x?xf32>
%401 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> ()>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%400, %cst_95 : tensor<?x12x?x?xf32>, tensor<f64>) outs(%398 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f64, %arg3: f32):
%834 = arith.truncf %arg2 : f64 to f32
%835 = arith.divf %arg1, %834 : f32
linalg.yield %835 : f32
} -> tensor<?x12x?x?xf32>
%402 = arith.cmpi eq, %327, %72 : index
cf.assert %402, "mismatched size for broadcast"
%403 = arith.cmpi eq, %328, %92 : index
cf.assert %403, "mismatched size for broadcast"
%404 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, 0, 0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%401, %96 : tensor<?x12x?x?xf32>, tensor<?x1x1x?xf32>) outs(%398 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%834 = arith.sitofp %c1_i64 : i64 to f32
%835 = arith.mulf %arg2, %834 : f32
%836 = arith.addf %arg1, %835 : f32
linalg.yield %836 : f32
} -> tensor<?x12x?x?xf32>
%405 = linalg.init_tensor [%327, 12, %328, 1] : tensor<?x12x?x1xi64>
%406 = linalg.fill(%c0_i64, %405) : i64, tensor<?x12x?x1xi64> -> tensor<?x12x?x1xi64>
%407 = linalg.init_tensor [%327, 12, %328, 1] : tensor<?x12x?x1xf32>
%408 = linalg.fill(%cst_1, %407) : f32, tensor<?x12x?x1xf32> -> tensor<?x12x?x1xf32>
%409:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%404 : tensor<?x12x?x?xf32>) outs(%408, %406 : tensor<?x12x?x1xf32>, tensor<?x12x?x1xi64>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: i64):
%834 = linalg.index 3 : index
%835 = arith.index_cast %834 : index to i64
%836 = arith.cmpf ogt, %arg1, %arg2 : f32
%837 = arith.select %836, %arg1, %arg2 : f32
%838 = arith.select %836, %835, %arg3 : i64
linalg.yield %837, %838 : f32, i64
} -> (tensor<?x12x?x1xf32>, tensor<?x12x?x1xi64>)
cf.assert %353, "mismatched size for broadcast"
cf.assert %354, "mismatched size for broadcast"
%410 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%404, %409#0 : tensor<?x12x?x?xf32>, tensor<?x12x?x1xf32>) outs(%398 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%834 = arith.truncf %cst_0 : f64 to f32
%835 = arith.mulf %arg2, %834 : f32
%836 = arith.subf %arg1, %835 : f32
linalg.yield %836 : f32
} -> tensor<?x12x?x?xf32>
%411 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%410 : tensor<?x12x?x?xf32>) outs(%398 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%834 = math.exp %arg1 : f32
linalg.yield %834 : f32
} -> tensor<?x12x?x?xf32>
%412 = linalg.fill(%cst, %407) : f32, tensor<?x12x?x1xf32> -> tensor<?x12x?x1xf32>
%413 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%411 : tensor<?x12x?x?xf32>) outs(%412 : tensor<?x12x?x1xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%834 = arith.addf %arg1, %arg2 : f32
linalg.yield %834 : f32
} -> tensor<?x12x?x1xf32>
cf.assert %353, "mismatched size for broadcast"
cf.assert %354, "mismatched size for broadcast"
%414 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%411, %413 : tensor<?x12x?x?xf32>, tensor<?x12x?x1xf32>) outs(%398 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%834 = arith.divf %arg1, %arg2 : f32
linalg.yield %834 : f32
} -> tensor<?x12x?x?xf32>
cf.assert %397, "mismatching contracting dimension"
%415 = arith.index_cast %328 : index to i64
%416 = arith.cmpi eq, %415, %415 : i64
cf.assert %416, "mismatching contracting dimension"
%417 = linalg.fill(%cst, %377) : f32, tensor<?x12x?x32xf32> -> tensor<?x12x?x32xf32>
%418 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>], iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel"]} ins(%414, %387 : tensor<?x12x?x?xf32>, tensor<?x12x?x32xf32>) outs(%417 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%834 = arith.mulf %arg1, %arg2 : f32
%835 = arith.addf %834, %arg3 : f32
linalg.yield %835 : f32
} -> tensor<?x12x?x32xf32>
%419 = linalg.init_tensor [%327, %328, 12, 32] : tensor<?x?x12x32xf32>
%420 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%418 : tensor<?x12x?x32xf32>) outs(%419 : tensor<?x?x12x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x12x32xf32>
%421 = tensor.cast %420 : tensor<?x?x12x32xf32> to tensor<?x?x?x?xf32>
cf.assert %205, "dim must be greater or equal to zero"
cf.assert %206, "dim must be smaller than inputRank"
%422 = tensor.dim %420, %207 : tensor<?x?x12x32xf32>
%423 = arith.index_cast %422 : index to i64
cf.assert %212, "dim must be greater or equal to zero"
cf.assert %213, "dim must be smaller than inputRank"
%424 = tensor.dim %420, %214 : tensor<?x?x12x32xf32>
%425 = arith.index_cast %424 : index to i64
%426 = tensor.collapse_shape %421 [[0], [1], [2, 3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
%427 = tensor.cast %426 : tensor<?x?x?xf32> to tensor<?x?x384xf32>
%428 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_53 : tensor<384xf32>) outs(%331 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%429 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_54 : tensor<384x384xf32>) outs(%365 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%430 = linalg.batch_matmul ins(%427, %429 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%428 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%431 = tensor.dim %427, %c0 : tensor<?x?x384xf32>
%432 = tensor.dim %427, %c1 : tensor<?x?x384xf32>
%433 = arith.cmpi eq, %431, %327 : index
cf.assert %433, "mismatched size for broadcast"
%434 = arith.cmpi eq, %432, %328 : index
cf.assert %434, "mismatched size for broadcast"
%435 = linalg.init_tensor [%431, %432, 384] : tensor<?x?x384xf32>
%436 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%430, %364 : tensor<?x?x384xf32>, tensor<?x?x384xf32>) outs(%435 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%834 = arith.sitofp %c1_i64 : i64 to f32
%835 = arith.mulf %arg2, %834 : f32
%836 = arith.addf %arg1, %835 : f32
linalg.yield %836 : f32
} -> tensor<?x?x384xf32>
%437 = arith.cmpi eq, %c384_i64, %c384_i64 : i64
cf.assert %437, "mismatching contracting dimension"
cf.assert %437, "mismatching contracting dimension"
cf.assert %437, "mismatching contracting dimension"
%438 = arith.sitofp %c384_i64 : i64 to f32
%439 = linalg.init_tensor [%431, %432] : tensor<?x?xf32>
%440 = linalg.fill(%cst, %439) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%441 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%436 : tensor<?x?x384xf32>) outs(%440 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%834 = arith.addf %arg2, %arg1 : f32
linalg.yield %834 : f32
} -> tensor<?x?xf32>
%442 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%441 : tensor<?x?xf32>) outs(%439 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%834 = arith.divf %arg1, %438 : f32
linalg.yield %834 : f32
} -> tensor<?x?xf32>
%443 = linalg.fill(%cst, %439) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%444 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%436, %442 : tensor<?x?x384xf32>, tensor<?x?xf32>) outs(%443 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%834 = arith.subf %arg1, %arg2 : f32
%835 = arith.mulf %834, %834 : f32
%836 = arith.addf %arg3, %835 : f32
linalg.yield %836 : f32
} -> tensor<?x?xf32>
%445 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%444 : tensor<?x?xf32>) outs(%439 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%834 = arith.divf %arg1, %438 : f32
linalg.yield %834 : f32
} -> tensor<?x?xf32>
%446 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%436, %442, %445, %cst_56, %cst_55 : tensor<?x?x384xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<384xf32>, tensor<384xf32>) outs(%435 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%834 = arith.subf %arg1, %arg2 : f32
%835 = arith.truncf %cst_5 : f64 to f32
%836 = arith.addf %arg3, %835 : f32
%837 = math.rsqrt %836 : f32
%838 = arith.mulf %834, %837 : f32
%839 = arith.mulf %838, %arg4 : f32
%840 = arith.addf %839, %arg5 : f32
linalg.yield %840 : f32
} -> tensor<?x?x384xf32>
%447 = linalg.init_tensor [%431, %432, 1536] : tensor<?x?x1536xf32>
%448 = linalg.init_tensor [%431, 384, 1536] : tensor<?x384x1536xf32>
%449 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_57 : tensor<1536xf32>) outs(%447 : tensor<?x?x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x1536xf32>
%450 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_58 : tensor<1536x384xf32>) outs(%448 : tensor<?x384x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x1536xf32>
%451 = linalg.batch_matmul ins(%446, %450 : tensor<?x?x384xf32>, tensor<?x384x1536xf32>) outs(%449 : tensor<?x?x1536xf32>) -> tensor<?x?x1536xf32>
%452 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%451 : tensor<?x?x1536xf32>) outs(%447 : tensor<?x?x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%834 = math.sqrt %cst_3 : f32
%835 = arith.divf %arg1, %834 : f32
%836 = math.erf %835 : f32
%837 = arith.addf %836, %cst_4 : f32
%838 = arith.mulf %837, %cst_2 : f32
%839 = arith.mulf %arg1, %838 : f32
linalg.yield %839 : f32
} -> tensor<?x?x1536xf32>
%453 = linalg.init_tensor [%431, 1536, 384] : tensor<?x1536x384xf32>
%454 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_59 : tensor<384xf32>) outs(%435 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%455 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_60 : tensor<384x1536xf32>) outs(%453 : tensor<?x1536x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x1536x384xf32>
%456 = linalg.batch_matmul ins(%452, %455 : tensor<?x?x1536xf32>, tensor<?x1536x384xf32>) outs(%454 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%457 = arith.cmpi eq, %431, %431 : index
cf.assert %457, "mismatched size for broadcast"
%458 = arith.cmpi eq, %432, %432 : index
cf.assert %458, "mismatched size for broadcast"
%459 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%456, %446 : tensor<?x?x384xf32>, tensor<?x?x384xf32>) outs(%435 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%834 = arith.sitofp %c1_i64 : i64 to f32
%835 = arith.mulf %arg2, %834 : f32
%836 = arith.addf %arg1, %835 : f32
linalg.yield %836 : f32
} -> tensor<?x?x384xf32>
%460 = arith.cmpi eq, %c384_i64, %c384_i64 : i64
cf.assert %460, "mismatching contracting dimension"
cf.assert %460, "mismatching contracting dimension"
cf.assert %460, "mismatching contracting dimension"
%461 = arith.sitofp %c384_i64 : i64 to f32
%462 = linalg.fill(%cst, %439) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%463 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%459 : tensor<?x?x384xf32>) outs(%462 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%834 = arith.addf %arg2, %arg1 : f32
linalg.yield %834 : f32
} -> tensor<?x?xf32>
%464 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%463 : tensor<?x?xf32>) outs(%439 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%834 = arith.divf %arg1, %461 : f32
linalg.yield %834 : f32
} -> tensor<?x?xf32>
%465 = linalg.fill(%cst, %439) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%466 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%459, %464 : tensor<?x?x384xf32>, tensor<?x?xf32>) outs(%465 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%834 = arith.subf %arg1, %arg2 : f32
%835 = arith.mulf %834, %834 : f32
%836 = arith.addf %arg3, %835 : f32
linalg.yield %836 : f32
} -> tensor<?x?xf32>
%467 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%466 : tensor<?x?xf32>) outs(%439 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%834 = arith.divf %arg1, %461 : f32
linalg.yield %834 : f32
} -> tensor<?x?xf32>
%468 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%459, %464, %467, %cst_62, %cst_61 : tensor<?x?x384xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<384xf32>, tensor<384xf32>) outs(%435 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%834 = arith.subf %arg1, %arg2 : f32
%835 = arith.truncf %cst_5 : f64 to f32
%836 = arith.addf %arg3, %835 : f32
%837 = math.rsqrt %836 : f32
%838 = arith.mulf %834, %837 : f32
%839 = arith.mulf %838, %arg4 : f32
%840 = arith.addf %839, %arg5 : f32
linalg.yield %840 : f32
} -> tensor<?x?x384xf32>
%469 = linalg.init_tensor [%431, 384, 384] : tensor<?x384x384xf32>
%470 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_63 : tensor<384xf32>) outs(%435 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%471 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_64 : tensor<384x384xf32>) outs(%469 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%472 = linalg.batch_matmul ins(%468, %471 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%470 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%473 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_65 : tensor<384xf32>) outs(%435 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%474 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_66 : tensor<384x384xf32>) outs(%469 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%475 = linalg.batch_matmul ins(%468, %474 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%473 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
// Reshape %475 from ?x?x384 into ?x12x?x32.
//
// Runtime checks on dimension indices; %146/%147/%154/%155 and the indices
// %148/%156 are computed before this portion of the function.
cf.assert %146, "dim must be greater or equal to zero"
cf.assert %147, "dim must be smaller than inputRank"
// Dimension sizes of %475 cast to i64. %477 and %479 are not referenced in the
// visible portion of the function.
%476 = tensor.dim %475, %148 : tensor<?x?x384xf32>
%477 = arith.index_cast %476 : index to i64
cf.assert %154, "dim must be greater or equal to zero"
cf.assert %155, "dim must be smaller than inputRank"
%478 = tensor.dim %475, %156 : tensor<?x?x384xf32>
%479 = arith.index_cast %478 : index to i64
// Split the trailing 384 dimension into 12 x 32.
%480 = tensor.expand_shape %475 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
// Uninitialised ?x12x?x32 destination; %481 is reused as the `outs` operand of
// several later ops in this function.
%481 = linalg.init_tensor [%431, 12, %432, 32] : tensor<?x12x?x32xf32>
// Swap dimensions 1 and 2: the output map (d0, d2, d1, d3) writes element
// [d0, d1, d2, d3] of %480 to position [d0, d2, d1, d3].
%482 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%480 : tensor<?x?x12x32xf32>) outs(%481 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
// Third linear projection of %468, followed by the same 12 x 32 split and
// dimension swap applied to the other projections.
//
// %483: broadcast the 1-D constant %cst_67 along the last dimension (shape
// taken from %435); it seeds the matmul accumulator below.
%483 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_67 : tensor<384xf32>) outs(%435 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
// %484: %cst_68 read transposed via (d2, d1) and replicated along d0.
%484 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_68 : tensor<384x384xf32>) outs(%469 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
// %485 = %468 x transpose(%cst_68) + %cst_67.
%485 = linalg.batch_matmul ins(%468, %484 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%483 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
// Dimension-index checks and size queries; %487 and %489 are not referenced in
// the visible portion of the function.
cf.assert %146, "dim must be greater or equal to zero"
cf.assert %147, "dim must be smaller than inputRank"
%486 = tensor.dim %485, %148 : tensor<?x?x384xf32>
%487 = arith.index_cast %486 : index to i64
cf.assert %154, "dim must be greater or equal to zero"
cf.assert %155, "dim must be smaller than inputRank"
%488 = tensor.dim %485, %156 : tensor<?x?x384xf32>
%489 = arith.index_cast %488 : index to i64
// Split 384 into 12 x 32, then swap dimensions 1 and 2 to get ?x12x?x32.
%490 = tensor.expand_shape %485 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%491 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%490 : tensor<?x?x12x32xf32>) outs(%481 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
// Reshape the projection %472 from ?x?x384 into ?x12x?x32.
//
// Dimension-index checks and size queries; %493 and %495 are not referenced in
// the visible portion of the function.
cf.assert %146, "dim must be greater or equal to zero"
cf.assert %147, "dim must be smaller than inputRank"
%492 = tensor.dim %472, %148 : tensor<?x?x384xf32>
%493 = arith.index_cast %492 : index to i64
cf.assert %154, "dim must be greater or equal to zero"
cf.assert %155, "dim must be smaller than inputRank"
%494 = tensor.dim %472, %156 : tensor<?x?x384xf32>
%495 = arith.index_cast %494 : index to i64
// Split the trailing 384 dimension into 12 x 32.
%496 = tensor.expand_shape %472 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
// Swap dimensions 1 and 2 (output map (d0, d2, d1, d3)) to get ?x12x?x32.
%497 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%496 : tensor<?x?x12x32xf32>) outs(%481 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
// Contract %497 with the transpose of %482 over the size-32 dimension, then
// divide by a scalar. NOTE(review): this has the shape of scaled attention
// scores (Q x K^T / scale) — confirm against the originating model.
//
// %499: swap the last two dimensions of %482 (?x12x?x32 -> ?x12x32x?).
%498 = linalg.init_tensor [%431, 12, 32, %432] : tensor<?x12x32x?xf32>
%499 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d2)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%482 : tensor<?x12x?x32xf32>) outs(%498 : tensor<?x12x32x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x32x?xf32>
// %501 compares %500 with itself, so the assertion below always holds.
%500 = arith.index_cast %431 : index to i64
%501 = arith.cmpi eq, %500, %500 : i64
cf.assert %501, "mismatching contracting dimension"
// Zero-filled accumulator (%cst is 0.0 f32) for the contraction.
%502 = linalg.init_tensor [%431, 12, %432, %432] : tensor<?x12x?x?xf32>
%503 = linalg.fill(%cst, %502) : f32, tensor<?x12x?x?xf32> -> tensor<?x12x?x?xf32>
// out[d0, d1, d2, d4] += %497[d0, d1, d2, d3] * %499[d0, d1, d3, d4];
// d3 is the only reduction iterator.
%504 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>], iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel"]} ins(%497, %499 : tensor<?x12x?x32xf32>, tensor<?x12x32x?xf32>) outs(%503 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%834 = arith.mulf %arg1, %arg2 : f32
%835 = arith.addf %834, %arg3 : f32
linalg.yield %835 : f32
} -> tensor<?x12x?x?xf32>
// Element-wise division by the rank-0 f64 constant %cst_95 (truncated to f32).
// The value of %cst_95 is not visible here — presumably sqrt(32); TODO confirm.
%505 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> ()>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%504, %cst_95 : tensor<?x12x?x?xf32>, tensor<f64>) outs(%502 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f64, %arg3: f32):
%834 = arith.truncf %arg2 : f64 to f32
%835 = arith.divf %arg1, %834 : f32
linalg.yield %835 : f32
} -> tensor<?x12x?x?xf32>
// Broadcast-add %96 (?x1x1x?) onto %505 (?x12x?x?).
// NOTE(review): %96 is defined earlier in the function — presumably the
// additive attention mask; confirm against its definition.
//
// Runtime checks that the broadcast dimensions agree: %431 == %72 (dim 0) and
// %432 == %92 (dim 3).
%506 = arith.cmpi eq, %431, %72 : index
cf.assert %506, "mismatched size for broadcast"
%507 = arith.cmpi eq, %432, %92 : index
cf.assert %507, "mismatched size for broadcast"
// The map (d0, 0, 0, d3) reads %96 at index 0 along d1 and d2, i.e. it is
// broadcast over those two dimensions. The body computes %505 + %96 * 1.0
// (%c1_i64 is the integer constant 1, converted to f32).
%508 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, 0, 0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%505, %96 : tensor<?x12x?x?xf32>, tensor<?x1x1x?xf32>) outs(%502 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%834 = arith.sitofp %c1_i64 : i64 to f32
%835 = arith.mulf %arg2, %834 : f32
%836 = arith.addf %arg1, %835 : f32
linalg.yield %836 : f32
} -> tensor<?x12x?x?xf32>
// Softmax of %508 along its last dimension (d3), computed in the
// max-subtracted form: exp(x - max(x)) / sum(exp(x - max(x))).
//
// Accumulators for the max reduction: indices start at 0 (%c0_i64) and values
// start at %cst_1 = -3.40282347E+38.
%509 = linalg.init_tensor [%431, 12, %432, 1] : tensor<?x12x?x1xi64>
%510 = linalg.fill(%c0_i64, %509) : i64, tensor<?x12x?x1xi64> -> tensor<?x12x?x1xi64>
%511 = linalg.init_tensor [%431, 12, %432, 1] : tensor<?x12x?x1xf32>
%512 = linalg.fill(%cst_1, %511) : f32, tensor<?x12x?x1xf32> -> tensor<?x12x?x1xf32>
// Reduction over d3 producing the running maximum (%513#0) and the d3 index at
// which it was found (%513#1). The index is only updated on a strictly greater
// (ordered) comparison. %513#1 is not referenced in the visible portion of the
// function.
%513:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%508 : tensor<?x12x?x?xf32>) outs(%512, %510 : tensor<?x12x?x1xf32>, tensor<?x12x?x1xi64>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: i64):
%834 = linalg.index 3 : index
%835 = arith.index_cast %834 : index to i64
%836 = arith.cmpf ogt, %arg1, %arg2 : f32
%837 = arith.select %836, %arg1, %arg2 : f32
%838 = arith.select %836, %835, %arg3 : i64
linalg.yield %837, %838 : f32, i64
} -> (tensor<?x12x?x1xf32>, tensor<?x12x?x1xi64>)
// %457/%458 are broadcast-size checks computed before this portion of the function.
cf.assert %457, "mismatched size for broadcast"
cf.assert %458, "mismatched size for broadcast"
// %514 = %508 - max * 1.0, with the max broadcast along d3 (%cst_0 is 1.0 f64).
%514 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%508, %513#0 : tensor<?x12x?x?xf32>, tensor<?x12x?x1xf32>) outs(%502 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%834 = arith.truncf %cst_0 : f64 to f32
%835 = arith.mulf %arg2, %834 : f32
%836 = arith.subf %arg1, %835 : f32
linalg.yield %836 : f32
} -> tensor<?x12x?x?xf32>
// %515 = exp(%514), element-wise.
%515 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%514 : tensor<?x12x?x?xf32>) outs(%502 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%834 = math.exp %arg1 : f32
linalg.yield %834 : f32
} -> tensor<?x12x?x?xf32>
// %517 = sum of %515 over d3, starting from a zero-filled accumulator.
%516 = linalg.fill(%cst, %511) : f32, tensor<?x12x?x1xf32> -> tensor<?x12x?x1xf32>
%517 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%515 : tensor<?x12x?x?xf32>) outs(%516 : tensor<?x12x?x1xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%834 = arith.addf %arg1, %arg2 : f32
linalg.yield %834 : f32
} -> tensor<?x12x?x1xf32>
cf.assert %457, "mismatched size for broadcast"
cf.assert %458, "mismatched size for broadcast"
// %518 = %515 / %517, with the sum broadcast along d3.
%518 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%515, %517 : tensor<?x12x?x?xf32>, tensor<?x12x?x1xf32>) outs(%502 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%834 = arith.divf %arg1, %arg2 : f32
linalg.yield %834 : f32
} -> tensor<?x12x?x?xf32>
// Contract the normalised weights %518 with %491, then move the size-12
// dimension back behind the second dynamic dimension.
//
// %501 and %520 each compare a value with itself, so both assertions always hold.
cf.assert %501, "mismatching contracting dimension"
%519 = arith.index_cast %432 : index to i64
%520 = arith.cmpi eq, %519, %519 : i64
cf.assert %520, "mismatching contracting dimension"
// Zero-filled ?x12x?x32 accumulator (%cst is 0.0 f32).
%521 = linalg.fill(%cst, %481) : f32, tensor<?x12x?x32xf32> -> tensor<?x12x?x32xf32>
// out[d0, d1, d2, d4] += %518[d0, d1, d2, d3] * %491[d0, d1, d3, d4];
// d3 is the only reduction iterator.
%522 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>], iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel"]} ins(%518, %491 : tensor<?x12x?x?xf32>, tensor<?x12x?x32xf32>) outs(%521 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%834 = arith.mulf %arg1, %arg2 : f32
%835 = arith.addf %834, %arg3 : f32
linalg.yield %835 : f32
} -> tensor<?x12x?x32xf32>
// Swap dimensions 1 and 2: ?x12x?x32 -> ?x?x12x32.
%523 = linalg.init_tensor [%431, %432, 12, 32] : tensor<?x?x12x32xf32>
%524 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%522 : tensor<?x12x?x32xf32>) outs(%523 : tensor<?x?x12x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x12x32xf32>
// Collapse %524 back to ?x?x384, apply a 384 -> 384 linear projection, and add
// the block input %468 to the result.
//
// The casts go through a fully dynamic type around the collapse and then
// restore the static trailing size 384.
%525 = tensor.cast %524 : tensor<?x?x12x32xf32> to tensor<?x?x?x?xf32>
// Dimension-index checks and size queries; the conditions and indices
// (%205..%214) are computed before this portion of the function, and %527/%529
// are not referenced in the visible portion.
cf.assert %205, "dim must be greater or equal to zero"
cf.assert %206, "dim must be smaller than inputRank"
%526 = tensor.dim %524, %207 : tensor<?x?x12x32xf32>
%527 = arith.index_cast %526 : index to i64
cf.assert %212, "dim must be greater or equal to zero"
cf.assert %213, "dim must be smaller than inputRank"
%528 = tensor.dim %524, %214 : tensor<?x?x12x32xf32>
%529 = arith.index_cast %528 : index to i64
// Merge the last two dimensions (12 x 32) into one.
%530 = tensor.collapse_shape %525 [[0], [1], [2, 3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
%531 = tensor.cast %530 : tensor<?x?x?xf32> to tensor<?x?x384xf32>
// %532: 1-D constant %cst_69 broadcast along the last dimension; seeds the
// matmul accumulator.
%532 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_69 : tensor<384xf32>) outs(%435 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
// %533: %cst_70 read transposed via (d2, d1) and replicated along d0.
%533 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_70 : tensor<384x384xf32>) outs(%469 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
// %534 = %531 x transpose(%cst_70) + %cst_69.
%534 = linalg.batch_matmul ins(%531, %533 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%532 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
// Sizes of %531 along the dimensions selected by %c0 and %c1 (defined outside
// this portion — presumably the index constants 0 and 1), checked against
// %431 and %432 before the element-wise add.
%535 = tensor.dim %531, %c0 : tensor<?x?x384xf32>
%536 = tensor.dim %531, %c1 : tensor<?x?x384xf32>
%537 = arith.cmpi eq, %535, %431 : index
cf.assert %537, "mismatched size for broadcast"
%538 = arith.cmpi eq, %536, %432 : index
cf.assert %538, "mismatched size for broadcast"
%539 = linalg.init_tensor [%535, %536, 384] : tensor<?x?x384xf32>
// %540 = %534 + %468 * 1.0 (%c1_i64 is the integer constant 1, converted to f32).
%540 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%534, %468 : tensor<?x?x384xf32>, tensor<?x?x384xf32>) outs(%539 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%834 = arith.sitofp %c1_i64 : i64 to f32
%835 = arith.mulf %arg2, %834 : f32
%836 = arith.addf %arg1, %835 : f32
linalg.yield %836 : f32
} -> tensor<?x?x384xf32>
// Normalise %540 over its last (size-384) dimension:
//   out = (x - mean) * rsqrt(var + eps) * %cst_72 + %cst_71
// where mean and var are taken over d2 and var divides by 384 (not 383).
//
// %541 compares the constant 384 with itself, so the three assertions below
// always hold.
%541 = arith.cmpi eq, %c384_i64, %c384_i64 : i64
cf.assert %541, "mismatching contracting dimension"
cf.assert %541, "mismatching contracting dimension"
cf.assert %541, "mismatching contracting dimension"
// Element count (384) as f32, used as the divisor for both mean and variance.
%542 = arith.sitofp %c384_i64 : i64 to f32
%543 = linalg.init_tensor [%535, %536] : tensor<?x?xf32>
%544 = linalg.fill(%cst, %543) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
// %545 = sum of %540 over d2.
%545 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%540 : tensor<?x?x384xf32>) outs(%544 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%834 = arith.addf %arg2, %arg1 : f32
linalg.yield %834 : f32
} -> tensor<?x?xf32>
// %546 = %545 / 384: the mean.
%546 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%545 : tensor<?x?xf32>) outs(%543 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%834 = arith.divf %arg1, %542 : f32
linalg.yield %834 : f32
} -> tensor<?x?xf32>
%547 = linalg.fill(%cst, %543) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
// %548 = sum over d2 of (x - mean)^2.
%548 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%540, %546 : tensor<?x?x384xf32>, tensor<?x?xf32>) outs(%547 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%834 = arith.subf %arg1, %arg2 : f32
%835 = arith.mulf %834, %834 : f32
%836 = arith.addf %arg3, %835 : f32
linalg.yield %836 : f32
} -> tensor<?x?xf32>
// %549 = %548 / 384: the variance.
%549 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%548 : tensor<?x?xf32>) outs(%543 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%834 = arith.divf %arg1, %542 : f32
linalg.yield %834 : f32
} -> tensor<?x?xf32>
// Final normalisation. Block arguments: %arg1 = x, %arg2 = mean, %arg3 = var,
// %arg4 = %cst_72 (multiplier), %arg5 = %cst_71 (addend). eps is
// %cst_5 = 9.9999999999999998E-13, truncated from f64 to f32.
%550 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%540, %546, %549, %cst_72, %cst_71 : tensor<?x?x384xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<384xf32>, tensor<384xf32>) outs(%539 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%834 = arith.subf %arg1, %arg2 : f32
%835 = arith.truncf %cst_5 : f64 to f32
%836 = arith.addf %arg3, %835 : f32
%837 = math.rsqrt %836 : f32
%838 = arith.mulf %834, %837 : f32
%839 = arith.mulf %838, %arg4 : f32
%840 = arith.addf %839, %arg5 : f32
linalg.yield %840 : f32
} -> tensor<?x?x384xf32>
// 384 -> 1536 linear projection of %550 followed by the erf-based GELU.
//
// Uninitialised destinations for the broadcast addend and the batched weight.
%551 = linalg.init_tensor [%535, %536, 1536] : tensor<?x?x1536xf32>
%552 = linalg.init_tensor [%535, 384, 1536] : tensor<?x384x1536xf32>
// %553: 1-D constant %cst_73 broadcast along the last dimension; seeds the
// matmul accumulator.
%553 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_73 : tensor<1536xf32>) outs(%551 : tensor<?x?x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x1536xf32>
// %554: the 1536x384 constant %cst_74 read transposed via (d2, d1), giving a
// 384x1536 matrix replicated along d0.
%554 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_74 : tensor<1536x384xf32>) outs(%552 : tensor<?x384x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x1536xf32>
// %555 = %550 x transpose(%cst_74) + %cst_73.
%555 = linalg.batch_matmul ins(%550, %554 : tensor<?x?x384xf32>, tensor<?x384x1536xf32>) outs(%553 : tensor<?x?x1536xf32>) -> tensor<?x?x1536xf32>
// Element-wise x * 0.5 * (1 + erf(x / sqrt(2))), using the constants
// %cst_3 = 2.0, %cst_4 = 1.0 and %cst_2 = 0.5.
%556 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%555 : tensor<?x?x1536xf32>) outs(%551 : tensor<?x?x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%834 = math.sqrt %cst_3 : f32
%835 = arith.divf %arg1, %834 : f32
%836 = math.erf %835 : f32
%837 = arith.addf %836, %cst_4 : f32
%838 = arith.mulf %837, %cst_2 : f32
%839 = arith.mulf %arg1, %838 : f32
linalg.yield %839 : f32
} -> tensor<?x?x1536xf32>
// 1536 -> 384 linear projection of %556, then element-wise add of %550.
//
// Uninitialised destination for the batched weight.
%557 = linalg.init_tensor [%535, 1536, 384] : tensor<?x1536x384xf32>
// %558: 1-D constant %cst_75 broadcast along the last dimension; seeds the
// matmul accumulator.
%558 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_75 : tensor<384xf32>) outs(%539 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
// %559: the 384x1536 constant %cst_76 read transposed via (d2, d1), giving a
// 1536x384 matrix replicated along d0.
%559 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_76 : tensor<384x1536xf32>) outs(%557 : tensor<?x1536x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x1536x384xf32>
// %560 = %556 x transpose(%cst_76) + %cst_75.
%560 = linalg.batch_matmul ins(%556, %559 : tensor<?x?x1536xf32>, tensor<?x1536x384xf32>) outs(%558 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
// %561 and %562 each compare a value with itself, so both assertions always
// hold. The two conditions are asserted again by later ops in this function.
%561 = arith.cmpi eq, %535, %535 : index
cf.assert %561, "mismatched size for broadcast"
%562 = arith.cmpi eq, %536, %536 : index
cf.assert %562, "mismatched size for broadcast"
// %563 = %560 + %550 * 1.0 (%c1_i64 is the integer constant 1, converted to f32).
%563 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%560, %550 : tensor<?x?x384xf32>, tensor<?x?x384xf32>) outs(%539 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%834 = arith.sitofp %c1_i64 : i64 to f32
%835 = arith.mulf %arg2, %834 : f32
%836 = arith.addf %arg1, %835 : f32
linalg.yield %836 : f32
} -> tensor<?x?x384xf32>
%564 = arith.cmpi eq, %c384_i64, %c384_i64 : i64
cf.assert %564, "mismatching contracting dimension"
cf.assert %564, "mismatching contracting dimension"
cf.assert %564, "mismatching contracting dimension"
%565 = arith.sitofp %c384_i64 : i64 to f32
%566 = linalg.fill(%cst, %543) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%567 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%563 : tensor<?x?x384xf32>) outs(%566 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%834 = arith.addf %arg2, %arg1 : f32
linalg.yield %834 : f32
} -> tensor<?x?xf32>
%568 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%567 : tensor<?x?xf32>) outs(%543 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%834 = arith.divf %arg1, %565 : f32
linalg.yield %834 : f32
} -> tensor<?x?xf32>
%569 = linalg.fill(%cst, %543) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%570 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%563, %568 : tensor<?x?x384xf32>, tensor<?x?xf32>) outs(%569 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%834 = arith.subf %arg1, %arg2 : f32
%835 = arith.mulf %834, %834 : f32
%836 = arith.addf %arg3, %835 : f32
linalg.yield %836 : f32
} -> tensor<?x?xf32>
%571 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%570 : tensor<?x?xf32>) outs(%543 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%834 = arith.divf %arg1, %565 : f32
linalg.yield %834 : f32
} -> tensor<?x?xf32>
%572 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%563, %568, %571, %cst_78, %cst_77 : tensor<?x?x384xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<384xf32>, tensor<384xf32>) outs(%539 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%834 = arith.subf %arg1, %arg2 : f32
%835 = arith.truncf %cst_5 : f64 to f32
%836 = arith.addf %arg3, %835 : f32
%837 = math.rsqrt %836 : f32
%838 = arith.mulf %834, %837 : f32
%839 = arith.mulf %838, %arg4 : f32
%840 = arith.addf %839, %arg5 : f32
linalg.yield %840 : f32
} -> tensor<?x?x384xf32>
%573 = linalg.init_tensor [%535, 384, 384] : tensor<?x384x384xf32>
%574 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_79 : tensor<384xf32>) outs(%539 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%575 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_80 : tensor<384x384xf32>) outs(%573 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%576 = linalg.batch_matmul ins(%572, %575 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%574 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%577 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_81 : tensor<384xf32>) outs(%539 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%578 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_82 : tensor<384x384xf32>) outs(%573 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%579 = linalg.batch_matmul ins(%572, %578 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%577 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
cf.assert %146, "dim must be greater or equal to zero"
cf.assert %147, "dim must be smaller than inputRank"
%580 = tensor.dim %579, %148 : tensor<?x?x384xf32>
%581 = arith.index_cast %580 : index to i64
cf.assert %154, "dim must be greater or equal to zero"
cf.assert %155, "dim must be smaller than inputRank"
%582 = tensor.dim %579, %156 : tensor<?x?x384xf32>
%583 = arith.index_cast %582 : index to i64
%584 = tensor.expand_shape %579 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%585 = linalg.init_tensor [%535, 12, %536, 32] : tensor<?x12x?x32xf32>
%586 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%584 : tensor<?x?x12x32xf32>) outs(%585 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
%587 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_83 : tensor<384xf32>) outs(%539 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%588 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_84 : tensor<384x384xf32>) outs(%573 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%589 = linalg.batch_matmul ins(%572, %588 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%587 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
cf.assert %146, "dim must be greater or equal to zero"
cf.assert %147, "dim must be smaller than inputRank"
%590 = tensor.dim %589, %148 : tensor<?x?x384xf32>
%591 = arith.index_cast %590 : index to i64
cf.assert %154, "dim must be greater or equal to zero"
cf.assert %155, "dim must be smaller than inputRank"
%592 = tensor.dim %589, %156 : tensor<?x?x384xf32>
%593 = arith.index_cast %592 : index to i64
%594 = tensor.expand_shape %589 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%595 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%594 : tensor<?x?x12x32xf32>) outs(%585 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
cf.assert %146, "dim must be greater or equal to zero"
cf.assert %147, "dim must be smaller than inputRank"
%596 = tensor.dim %576, %148 : tensor<?x?x384xf32>
%597 = arith.index_cast %596 : index to i64
cf.assert %154, "dim must be greater or equal to zero"
cf.assert %155, "dim must be smaller than inputRank"
%598 = tensor.dim %576, %156 : tensor<?x?x384xf32>
%599 = arith.index_cast %598 : index to i64
%600 = tensor.expand_shape %576 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%601 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%600 : tensor<?x?x12x32xf32>) outs(%585 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
%602 = linalg.init_tensor [%535, 12, 32, %536] : tensor<?x12x32x?xf32>
%603 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d2)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%586 : tensor<?x12x?x32xf32>) outs(%602 : tensor<?x12x32x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x32x?xf32>
%604 = arith.index_cast %535 : index to i64
%605 = arith.cmpi eq, %604, %604 : i64
cf.assert %605, "mismatching contracting dimension"
%606 = linalg.init_tensor [%535, 12, %536, %536] : tensor<?x12x?x?xf32>
%607 = linalg.fill(%cst, %606) : f32, tensor<?x12x?x?xf32> -> tensor<?x12x?x?xf32>
%608 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>], iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel"]} ins(%601, %603 : tensor<?x12x?x32xf32>, tensor<?x12x32x?xf32>) outs(%607 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%834 = arith.mulf %arg1, %arg2 : f32
%835 = arith.addf %834, %arg3 : f32
linalg.yield %835 : f32
} -> tensor<?x12x?x?xf32>
%609 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> ()>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%608, %cst_95 : tensor<?x12x?x?xf32>, tensor<f64>) outs(%606 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f64, %arg3: f32):
%834 = arith.truncf %arg2 : f64 to f32
%835 = arith.divf %arg1, %834 : f32
linalg.yield %835 : f32
} -> tensor<?x12x?x?xf32>
%610 = arith.cmpi eq, %535, %72 : index
cf.assert %610, "mismatched size for broadcast"
%611 = arith.cmpi eq, %536, %92 : index
cf.assert %611, "mismatched size for broadcast"
%612 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, 0, 0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%609, %96 : tensor<?x12x?x?xf32>, tensor<?x1x1x?xf32>) outs(%606 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%834 = arith.sitofp %c1_i64 : i64 to f32
%835 = arith.mulf %arg2, %834 : f32
%836 = arith.addf %arg1, %835 : f32
linalg.yield %836 : f32
} -> tensor<?x12x?x?xf32>
%613 = linalg.init_tensor [%535, 12, %536, 1] : tensor<?x12x?x1xi64>
%614 = linalg.fill(%c0_i64, %613) : i64, tensor<?x12x?x1xi64> -> tensor<?x12x?x1xi64>
%615 = linalg.init_tensor [%535, 12, %536, 1] : tensor<?x12x?x1xf32>
%616 = linalg.fill(%cst_1, %615) : f32, tensor<?x12x?x1xf32> -> tensor<?x12x?x1xf32>
%617:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%612 : tensor<?x12x?x?xf32>) outs(%616, %614 : tensor<?x12x?x1xf32>, tensor<?x12x?x1xi64>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: i64):
%834 = linalg.index 3 : index
%835 = arith.index_cast %834 : index to i64
%836 = arith.cmpf ogt, %arg1, %arg2 : f32
%837 = arith.select %836, %arg1, %arg2 : f32
%838 = arith.select %836, %835, %arg3 : i64
linalg.yield %837, %838 : f32, i64
} -> (tensor<?x12x?x1xf32>, tensor<?x12x?x1xi64>)
cf.assert %561, "mismatched size for broadcast"
cf.assert %562, "mismatched size for broadcast"
%618 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%612, %617#0 : tensor<?x12x?x?xf32>, tensor<?x12x?x1xf32>) outs(%606 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%834 = arith.truncf %cst_0 : f64 to f32
%835 = arith.mulf %arg2, %834 : f32
%836 = arith.subf %arg1, %835 : f32
linalg.yield %836 : f32
} -> tensor<?x12x?x?xf32>
%619 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%618 : tensor<?x12x?x?xf32>) outs(%606 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%834 = math.exp %arg1 : f32
linalg.yield %834 : f32
} -> tensor<?x12x?x?xf32>
%620 = linalg.fill(%cst, %615) : f32, tensor<?x12x?x1xf32> -> tensor<?x12x?x1xf32>
%621 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%619 : tensor<?x12x?x?xf32>) outs(%620 : tensor<?x12x?x1xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%834 = arith.addf %arg1, %arg2 : f32
linalg.yield %834 : f32
} -> tensor<?x12x?x1xf32>
cf.assert %561, "mismatched size for broadcast"
cf.assert %562, "mismatched size for broadcast"
%622 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%619, %621 : tensor<?x12x?x?xf32>, tensor<?x12x?x1xf32>) outs(%606 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%834 = arith.divf %arg1, %arg2 : f32
linalg.yield %834 : f32
} -> tensor<?x12x?x?xf32>
cf.assert %605, "mismatching contracting dimension"
%623 = arith.index_cast %536 : index to i64
%624 = arith.cmpi eq, %623, %623 : i64
cf.assert %624, "mismatching contracting dimension"
%625 = linalg.fill(%cst, %585) : f32, tensor<?x12x?x32xf32> -> tensor<?x12x?x32xf32>
%626 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>], iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel"]} ins(%622, %595 : tensor<?x12x?x?xf32>, tensor<?x12x?x32xf32>) outs(%625 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%834 = arith.mulf %arg1, %arg2 : f32
%835 = arith.addf %834, %arg3 : f32
linalg.yield %835 : f32
} -> tensor<?x12x?x32xf32>
%627 = linalg.init_tensor [%535, %536, 12, 32] : tensor<?x?x12x32xf32>
%628 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%626 : tensor<?x12x?x32xf32>) outs(%627 : tensor<?x?x12x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x12x32xf32>
%629 = tensor.cast %628 : tensor<?x?x12x32xf32> to tensor<?x?x?x?xf32>
cf.assert %205, "dim must be greater or equal to zero"
cf.assert %206, "dim must be smaller than inputRank"
%630 = tensor.dim %628, %207 : tensor<?x?x12x32xf32>
%631 = arith.index_cast %630 : index to i64
cf.assert %212, "dim must be greater or equal to zero"
cf.assert %213, "dim must be smaller than inputRank"
%632 = tensor.dim %628, %214 : tensor<?x?x12x32xf32>
%633 = arith.index_cast %632 : index to i64
%634 = tensor.collapse_shape %629 [[0], [1], [2, 3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
%635 = tensor.cast %634 : tensor<?x?x?xf32> to tensor<?x?x384xf32>
%636 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_85 : tensor<384xf32>) outs(%539 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%637 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_86 : tensor<384x384xf32>) outs(%573 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%638 = linalg.batch_matmul ins(%635, %637 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%636 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%639 = tensor.dim %635, %c0 : tensor<?x?x384xf32>
%640 = tensor.dim %635, %c1 : tensor<?x?x384xf32>
%641 = arith.cmpi eq, %639, %535 : index
cf.assert %641, "mismatched size for broadcast"
%642 = arith.cmpi eq, %640, %536 : index
cf.assert %642, "mismatched size for broadcast"
%643 = linalg.init_tensor [%639, %640, 384] : tensor<?x?x384xf32>
%644 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%638, %572 : tensor<?x?x384xf32>, tensor<?x?x384xf32>) outs(%643 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%834 = arith.sitofp %c1_i64 : i64 to f32
%835 = arith.mulf %arg2, %834 : f32
%836 = arith.addf %arg1, %835 : f32
linalg.yield %836 : f32
} -> tensor<?x?x384xf32>
%645 = arith.cmpi eq, %c384_i64, %c384_i64 : i64
cf.assert %645, "mismatching contracting dimension"
cf.assert %645, "mismatching contracting dimension"
cf.assert %645, "mismatching contracting dimension"
%646 = arith.sitofp %c384_i64 : i64 to f32
%647 = linalg.init_tensor [%639, %640] : tensor<?x?xf32>
%648 = linalg.fill(%cst, %647) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%649 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%644 : tensor<?x?x384xf32>) outs(%648 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%834 = arith.addf %arg2, %arg1 : f32
linalg.yield %834 : f32
} -> tensor<?x?xf32>
%650 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%649 : tensor<?x?xf32>) outs(%647 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%834 = arith.divf %arg1, %646 : f32
linalg.yield %834 : f32
} -> tensor<?x?xf32>
%651 = linalg.fill(%cst, %647) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%652 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%644, %650 : tensor<?x?x384xf32>, tensor<?x?xf32>) outs(%651 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%834 = arith.subf %arg1, %arg2 : f32
%835 = arith.mulf %834, %834 : f32
%836 = arith.addf %arg3, %835 : f32
linalg.yield %836 : f32
} -> tensor<?x?xf32>
%653 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%652 : tensor<?x?xf32>) outs(%647 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%834 = arith.divf %arg1, %646 : f32
linalg.yield %834 : f32
} -> tensor<?x?xf32>
%654 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%644, %650, %653, %cst_88, %cst_87 : tensor<?x?x384xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<384xf32>, tensor<384xf32>) outs(%643 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%834 = arith.subf %arg1, %arg2 : f32
%835 = arith.truncf %cst_5 : f64 to f32
%836 = arith.addf %arg3, %835 : f32
%837 = math.rsqrt %836 : f32
%838 = arith.mulf %834, %837 : f32
%839 = arith.mulf %838, %arg4 : f32
%840 = arith.addf %839, %arg5 : f32
linalg.yield %840 : f32
} -> tensor<?x?x384xf32>
%655 = linalg.init_tensor [%639, %640, 1536] : tensor<?x?x1536xf32>
%656 = linalg.init_tensor [%639, 384, 1536] : tensor<?x384x1536xf32>
%657 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_89 : tensor<1536xf32>) outs(%655 : tensor<?x?x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x1536xf32>
%658 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_90 : tensor<1536x384xf32>) outs(%656 : tensor<?x384x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x1536xf32>
%659 = linalg.batch_matmul ins(%654, %658 : tensor<?x?x384xf32>, tensor<?x384x1536xf32>) outs(%657 : tensor<?x?x1536xf32>) -> tensor<?x?x1536xf32>
%660 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%659 : tensor<?x?x1536xf32>) outs(%655 : tensor<?x?x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%834 = math.sqrt %cst_3 : f32
%835 = arith.divf %arg1, %834 : f32
%836 = math.erf %835 : f32
%837 = arith.addf %836, %cst_4 : f32
%838 = arith.mulf %837, %cst_2 : f32
%839 = arith.mulf %arg1, %838 : f32
linalg.yield %839 : f32
} -> tensor<?x?x1536xf32>
%661 = linalg.init_tensor [%639, 1536, 384] : tensor<?x1536x384xf32>
%662 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_91 : tensor<384xf32>) outs(%643 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%663 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_92 : tensor<384x1536xf32>) outs(%661 : tensor<?x1536x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x1536x384xf32>
%664 = linalg.batch_matmul ins(%660, %663 : tensor<?x?x1536xf32>, tensor<?x1536x384xf32>) outs(%662 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%665 = arith.cmpi eq, %639, %639 : index
cf.assert %665, "mismatched size for broadcast"
%666 = arith.cmpi eq, %640, %640 : index
cf.assert %666, "mismatched size for broadcast"
%667 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%664, %654 : tensor<?x?x384xf32>, tensor<?x?x384xf32>) outs(%643 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%834 = arith.sitofp %c1_i64 : i64 to f32
%835 = arith.mulf %arg2, %834 : f32
%836 = arith.addf %arg1, %835 : f32
linalg.yield %836 : f32
} -> tensor<?x?x384xf32>
%668 = arith.cmpi eq, %c384_i64, %c384_i64 : i64
cf.assert %668, "mismatching contracting dimension"
cf.assert %668, "mismatching contracting dimension"
cf.assert %668, "mismatching contracting dimension"
%669 = arith.sitofp %c384_i64 : i64 to f32
%670 = linalg.fill(%cst, %647) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%671 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%667 : tensor<?x?x384xf32>) outs(%670 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%834 = arith.addf %arg2, %arg1 : f32
linalg.yield %834 : f32
} -> tensor<?x?xf32>
%672 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%671 : tensor<?x?xf32>) outs(%647 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%834 = arith.divf %arg1, %669 : f32
linalg.yield %834 : f32
} -> tensor<?x?xf32>
%673 = linalg.fill(%cst, %647) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%674 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%667, %672 : tensor<?x?x384xf32>, tensor<?x?xf32>) outs(%673 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%834 = arith.subf %arg1, %arg2 : f32
%835 = arith.mulf %834, %834 : f32
%836 = arith.addf %arg3, %835 : f32
linalg.yield %836 : f32
} -> tensor<?x?xf32>
%675 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%674 : tensor<?x?xf32>) outs(%647 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%834 = arith.divf %arg1, %669 : f32
linalg.yield %834 : f32
} -> tensor<?x?xf32>
%676 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%667, %672, %675, %cst_94, %cst_93 : tensor<?x?x384xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<384xf32>, tensor<384xf32>) outs(%643 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%834 = arith.subf %arg1, %arg2 : f32
%835 = arith.truncf %cst_5 : f64 to f32
%836 = arith.addf %arg3, %835 : f32
%837 = math.rsqrt %836 : f32
%838 = arith.mulf %834, %837 : f32
%839 = arith.mulf %838, %arg4 : f32
%840 = arith.addf %839, %arg5 : f32
linalg.yield %840 : f32
} -> tensor<?x?x384xf32>
%677 = linalg.init_tensor [%639, 384, 384] : tensor<?x384x384xf32>
%678 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_96 : tensor<384xf32>) outs(%643 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%679 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_97 : tensor<384x384xf32>) outs(%677 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%680 = linalg.batch_matmul ins(%676, %679 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%678 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%681 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_98 : tensor<384xf32>) outs(%643 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%682 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_99 : tensor<384x384xf32>) outs(%677 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%683 = linalg.batch_matmul ins(%676, %682 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%681 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
cf.assert %146, "dim must be greater or equal to zero"
cf.assert %147, "dim must be smaller than inputRank"
%684 = tensor.dim %683, %148 : tensor<?x?x384xf32>
%685 = arith.index_cast %684 : index to i64
cf.assert %154, "dim must be greater or equal to zero"
cf.assert %155, "dim must be smaller than inputRank"
%686 = tensor.dim %683, %156 : tensor<?x?x384xf32>
%687 = arith.index_cast %686 : index to i64
%688 = tensor.expand_shape %683 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%689 = linalg.init_tensor [%639, 12, %640, 32] : tensor<?x12x?x32xf32>
%690 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%688 : tensor<?x?x12x32xf32>) outs(%689 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
%691 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_100 : tensor<384xf32>) outs(%643 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%692 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_101 : tensor<384x384xf32>) outs(%677 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%693 = linalg.batch_matmul ins(%676, %692 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%691 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
cf.assert %146, "dim must be greater or equal to zero"
cf.assert %147, "dim must be smaller than inputRank"
%694 = tensor.dim %693, %148 : tensor<?x?x384xf32>
%695 = arith.index_cast %694 : index to i64
cf.assert %154, "dim must be greater or equal to zero"
cf.assert %155, "dim must be smaller than inputRank"
%696 = tensor.dim %693, %156 : tensor<?x?x384xf32>
%697 = arith.index_cast %696 : index to i64
%698 = tensor.expand_shape %693 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%699 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%698 : tensor<?x?x12x32xf32>) outs(%689 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
cf.assert %146, "dim must be greater or equal to zero"
cf.assert %147, "dim must be smaller than inputRank"
%700 = tensor.dim %680, %148 : tensor<?x?x384xf32>
%701 = arith.index_cast %700 : index to i64
cf.assert %154, "dim must be greater or equal to zero"
cf.assert %155, "dim must be smaller than inputRank"
%702 = tensor.dim %680, %156 : tensor<?x?x384xf32>
%703 = arith.index_cast %702 : index to i64
%704 = tensor.expand_shape %680 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%705 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%704 : tensor<?x?x12x32xf32>) outs(%689 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
%706 = linalg.init_tensor [%639, 12, 32, %640] : tensor<?x12x32x?xf32>
%707 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d2)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%690 : tensor<?x12x?x32xf32>) outs(%706 : tensor<?x12x32x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x32x?xf32>
%708 = arith.index_cast %639 : index to i64
%709 = arith.cmpi eq, %708, %708 : i64
cf.assert %709, "mismatching contracting dimension"
%710 = linalg.init_tensor [%639, 12, %640, %640] : tensor<?x12x?x?xf32>
%711 = linalg.fill(%cst, %710) : f32, tensor<?x12x?x?xf32> -> tensor<?x12x?x?xf32>
%712 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>], iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel"]} ins(%705, %707 : tensor<?x12x?x32xf32>, tensor<?x12x32x?xf32>) outs(%711 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%834 = arith.mulf %arg1, %arg2 : f32
%835 = arith.addf %834, %arg3 : f32
linalg.yield %835 : f32
} -> tensor<?x12x?x?xf32>
%713 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> ()>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%712, %cst_95 : tensor<?x12x?x?xf32>, tensor<f64>) outs(%710 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f64, %arg3: f32):
%834 = arith.truncf %arg2 : f64 to f32
%835 = arith.divf %arg1, %834 : f32
linalg.yield %835 : f32
} -> tensor<?x12x?x?xf32>
%714 = arith.cmpi eq, %639, %72 : index
cf.assert %714, "mismatched size for broadcast"
%715 = arith.cmpi eq, %640, %92 : index
cf.assert %715, "mismatched size for broadcast"
%716 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, 0, 0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%713, %96 : tensor<?x12x?x?xf32>, tensor<?x1x1x?xf32>) outs(%710 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%834 = arith.sitofp %c1_i64 : i64 to f32
%835 = arith.mulf %arg2, %834 : f32
%836 = arith.addf %arg1, %835 : f32
linalg.yield %836 : f32
} -> tensor<?x12x?x?xf32>
%717 = linalg.init_tensor [%639, 12, %640, 1] : tensor<?x12x?x1xi64>
%718 = linalg.fill(%c0_i64, %717) : i64, tensor<?x12x?x1xi64> -> tensor<?x12x?x1xi64>
%719 = linalg.init_tensor [%639, 12, %640, 1] : tensor<?x12x?x1xf32>
%720 = linalg.fill(%cst_1, %719) : f32, tensor<?x12x?x1xf32> -> tensor<?x12x?x1xf32>
%721:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%716 : tensor<?x12x?x?xf32>) outs(%720, %718 : tensor<?x12x?x1xf32>, tensor<?x12x?x1xi64>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: i64):
%834 = linalg.index 3 : index
%835 = arith.index_cast %834 : index to i64
%836 = arith.cmpf ogt, %arg1, %arg2 : f32
%837 = arith.select %836, %arg1, %arg2 : f32
%838 = arith.select %836, %835, %arg3 : i64
linalg.yield %837, %838 : f32, i64
} -> (tensor<?x12x?x1xf32>, tensor<?x12x?x1xi64>)
cf.assert %665, "mismatched size for broadcast"
cf.assert %666, "mismatched size for broadcast"
%722 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%716, %721#0 : tensor<?x12x?x?xf32>, tensor<?x12x?x1xf32>) outs(%710 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%834 = arith.truncf %cst_0 : f64 to f32
%835 = arith.mulf %arg2, %834 : f32
%836 = arith.subf %arg1, %835 : f32
linalg.yield %836 : f32
} -> tensor<?x12x?x?xf32>
%723 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%722 : tensor<?x12x?x?xf32>) outs(%710 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%834 = math.exp %arg1 : f32
linalg.yield %834 : f32
} -> tensor<?x12x?x?xf32>
%724 = linalg.fill(%cst, %719) : f32, tensor<?x12x?x1xf32> -> tensor<?x12x?x1xf32>
%725 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%723 : tensor<?x12x?x?xf32>) outs(%724 : tensor<?x12x?x1xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%834 = arith.addf %arg1, %arg2 : f32
linalg.yield %834 : f32
} -> tensor<?x12x?x1xf32>
cf.assert %665, "mismatched size for broadcast"
cf.assert %666, "mismatched size for broadcast"
%726 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%723, %725 : tensor<?x12x?x?xf32>, tensor<?x12x?x1xf32>) outs(%710 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%834 = arith.divf %arg1, %arg2 : f32
linalg.yield %834 : f32
} -> tensor<?x12x?x?xf32>
cf.assert %709, "mismatching contracting dimension"
%727 = arith.index_cast %640 : index to i64
%728 = arith.cmpi eq, %727, %727 : i64
cf.assert %728, "mismatching contracting dimension"
%729 = linalg.fill(%cst, %689) : f32, tensor<?x12x?x32xf32> -> tensor<?x12x?x32xf32>
%730 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>], iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel"]} ins(%726, %699 : tensor<?x12x?x?xf32>, tensor<?x12x?x32xf32>) outs(%729 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%834 = arith.mulf %arg1, %arg2 : f32
%835 = arith.addf %834, %arg3 : f32
linalg.yield %835 : f32
} -> tensor<?x12x?x32xf32>
%731 = linalg.init_tensor [%639, %640, 12, 32] : tensor<?x?x12x32xf32>
%732 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%730 : tensor<?x12x?x32xf32>) outs(%731 : tensor<?x?x12x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x12x32xf32>
%733 = tensor.cast %732 : tensor<?x?x12x32xf32> to tensor<?x?x?x?xf32>
cf.assert %205, "dim must be greater or equal to zero"
cf.assert %206, "dim must be smaller than inputRank"
%734 = tensor.dim %732, %207 : tensor<?x?x12x32xf32>
%735 = arith.index_cast %734 : index to i64
cf.assert %212, "dim must be greater or equal to zero"
cf.assert %213, "dim must be smaller than inputRank"
%736 = tensor.dim %732, %214 : tensor<?x?x12x32xf32>
%737 = arith.index_cast %736 : index to i64
%738 = tensor.collapse_shape %733 [[0], [1], [2, 3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
%739 = tensor.cast %738 : tensor<?x?x?xf32> to tensor<?x?x384xf32>
%740 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_102 : tensor<384xf32>) outs(%643 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%741 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_103 : tensor<384x384xf32>) outs(%677 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%742 = linalg.batch_matmul ins(%739, %741 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%740 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%743 = tensor.dim %739, %c0 : tensor<?x?x384xf32>
%744 = tensor.dim %739, %c1 : tensor<?x?x384xf32>
%745 = arith.cmpi eq, %743, %639 : index
cf.assert %745, "mismatched size for broadcast"
%746 = arith.cmpi eq, %744, %640 : index
cf.assert %746, "mismatched size for broadcast"
%747 = linalg.init_tensor [%743, %744, 384] : tensor<?x?x384xf32>
%748 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%742, %676 : tensor<?x?x384xf32>, tensor<?x?x384xf32>) outs(%747 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%834 = arith.sitofp %c1_i64 : i64 to f32
%835 = arith.mulf %arg2, %834 : f32
%836 = arith.addf %arg1, %835 : f32
linalg.yield %836 : f32
} -> tensor<?x?x384xf32>
%749 = arith.cmpi eq, %c384_i64, %c384_i64 : i64
cf.assert %749, "mismatching contracting dimension"
cf.assert %749, "mismatching contracting dimension"
cf.assert %749, "mismatching contracting dimension"
%750 = arith.sitofp %c384_i64 : i64 to f32
%751 = linalg.init_tensor [%743, %744] : tensor<?x?xf32>
%752 = linalg.fill(%cst, %751) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%753 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%748 : tensor<?x?x384xf32>) outs(%752 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%834 = arith.addf %arg2, %arg1 : f32
linalg.yield %834 : f32
} -> tensor<?x?xf32>
%754 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%753 : tensor<?x?xf32>) outs(%751 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%834 = arith.divf %arg1, %750 : f32
linalg.yield %834 : f32
} -> tensor<?x?xf32>
%755 = linalg.fill(%cst, %751) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%756 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%748, %754 : tensor<?x?x384xf32>, tensor<?x?xf32>) outs(%755 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%834 = arith.subf %arg1, %arg2 : f32
%835 = arith.mulf %834, %834 : f32
%836 = arith.addf %arg3, %835 : f32
linalg.yield %836 : f32
} -> tensor<?x?xf32>
%757 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%756 : tensor<?x?xf32>) outs(%751 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%834 = arith.divf %arg1, %750 : f32
linalg.yield %834 : f32
} -> tensor<?x?xf32>
%758 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%748, %754, %757, %cst_105, %cst_104 : tensor<?x?x384xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<384xf32>, tensor<384xf32>) outs(%747 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%834 = arith.subf %arg1, %arg2 : f32
%835 = arith.truncf %cst_5 : f64 to f32
%836 = arith.addf %arg3, %835 : f32
%837 = math.rsqrt %836 : f32
%838 = arith.mulf %834, %837 : f32
%839 = arith.mulf %838, %arg4 : f32
%840 = arith.addf %839, %arg5 : f32
linalg.yield %840 : f32
} -> tensor<?x?x384xf32>
%759 = linalg.init_tensor [%743, %744, 1536] : tensor<?x?x1536xf32>
%760 = linalg.init_tensor [%743, 384, 1536] : tensor<?x384x1536xf32>
%761 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_106 : tensor<1536xf32>) outs(%759 : tensor<?x?x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x1536xf32>
%762 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_107 : tensor<1536x384xf32>) outs(%760 : tensor<?x384x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x1536xf32>
%763 = linalg.batch_matmul ins(%758, %762 : tensor<?x?x384xf32>, tensor<?x384x1536xf32>) outs(%761 : tensor<?x?x1536xf32>) -> tensor<?x?x1536xf32>
%764 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%763 : tensor<?x?x1536xf32>) outs(%759 : tensor<?x?x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%834 = math.sqrt %cst_3 : f32
%835 = arith.divf %arg1, %834 : f32
%836 = math.erf %835 : f32
%837 = arith.addf %836, %cst_4 : f32
%838 = arith.mulf %837, %cst_2 : f32
%839 = arith.mulf %arg1, %838 : f32
linalg.yield %839 : f32
} -> tensor<?x?x1536xf32>
%765 = linalg.init_tensor [%743, 1536, 384] : tensor<?x1536x384xf32>
%766 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_108 : tensor<384xf32>) outs(%747 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%767 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_109 : tensor<384x1536xf32>) outs(%765 : tensor<?x1536x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x1536x384xf32>
%768 = linalg.batch_matmul ins(%764, %767 : tensor<?x?x1536xf32>, tensor<?x1536x384xf32>) outs(%766 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%769 = arith.cmpi eq, %743, %743 : index
cf.assert %769, "mismatched size for broadcast"
%770 = arith.cmpi eq, %744, %744 : index
cf.assert %770, "mismatched size for broadcast"
%771 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%768, %758 : tensor<?x?x384xf32>, tensor<?x?x384xf32>) outs(%747 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%834 = arith.sitofp %c1_i64 : i64 to f32
%835 = arith.mulf %arg2, %834 : f32
%836 = arith.addf %arg1, %835 : f32
linalg.yield %836 : f32
} -> tensor<?x?x384xf32>
%772 = arith.cmpi eq, %c384_i64, %c384_i64 : i64
cf.assert %772, "mismatching contracting dimension"
cf.assert %772, "mismatching contracting dimension"
cf.assert %772, "mismatching contracting dimension"
%773 = arith.sitofp %c384_i64 : i64 to f32
%774 = linalg.fill(%cst, %751) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%775 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%771 : tensor<?x?x384xf32>) outs(%774 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%834 = arith.addf %arg2, %arg1 : f32
linalg.yield %834 : f32
} -> tensor<?x?xf32>
%776 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%775 : tensor<?x?xf32>) outs(%751 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%834 = arith.divf %arg1, %773 : f32
linalg.yield %834 : f32
} -> tensor<?x?xf32>
%777 = linalg.fill(%cst, %751) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%778 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%771, %776 : tensor<?x?x384xf32>, tensor<?x?xf32>) outs(%777 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%834 = arith.subf %arg1, %arg2 : f32
%835 = arith.mulf %834, %834 : f32
%836 = arith.addf %arg3, %835 : f32
linalg.yield %836 : f32
} -> tensor<?x?xf32>
%779 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%778 : tensor<?x?xf32>) outs(%751 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%834 = arith.divf %arg1, %773 : f32
linalg.yield %834 : f32
} -> tensor<?x?xf32>
%780 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%771, %776, %779, %cst_111, %cst_110 : tensor<?x?x384xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<384xf32>, tensor<384xf32>) outs(%747 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%834 = arith.subf %arg1, %arg2 : f32
%835 = arith.truncf %cst_5 : f64 to f32
%836 = arith.addf %arg3, %835 : f32
%837 = math.rsqrt %836 : f32
%838 = arith.mulf %834, %837 : f32
%839 = arith.mulf %838, %arg4 : f32
%840 = arith.addf %839, %arg5 : f32
linalg.yield %840 : f32
} -> tensor<?x?x384xf32>
%781 = arith.index_cast %743 : index to i64
%782 = arith.addi %c0_i64, %781 : i64
%783 = arith.select %9, %c0_i64, %782 : i64
%784 = arith.cmpi slt, %783, %c0_i64 : i64
%785 = arith.select %784, %c0_i64, %783 : i64
%786 = arith.cmpi sgt, %785, %781 : i64
%787 = arith.select %786, %781, %785 : i64
%788 = arith.index_cast %787 : i64 to index
%789 = arith.addi %c9223372036854775807_i64, %781 : i64
%790 = arith.select %17, %c9223372036854775807_i64, %789 : i64
%791 = arith.cmpi slt, %790, %c0_i64 : i64
%792 = arith.select %791, %c0_i64, %790 : i64
%793 = arith.cmpi sgt, %792, %781 : i64
%794 = arith.select %793, %781, %792 : i64
%795 = arith.index_cast %794 : i64 to index
%796 = arith.cmpi sge, %795, %788 : index
%797 = arith.select %796, %795, %788 : index
%798 = arith.subi %797, %788 : index
%799 = tensor.extract_slice %780[%788, 0, 0] [%798, %744, 384] [1, 1, 1] : tensor<?x?x384xf32> to tensor<?x?x384xf32>
%800 = arith.index_cast %744 : index to i64
%801 = arith.addi %c0_i64, %800 : i64
%802 = arith.select %9, %c0_i64, %801 : i64
%803 = arith.cmpi slt, %802, %c0_i64 : i64
%804 = arith.select %803, %c0_i64, %802 : i64
%805 = arith.cmpi sgt, %804, %800 : i64
%806 = arith.select %805, %800, %804 : i64
%807 = arith.index_cast %806 : i64 to index
%808 = arith.addi %c1_i64, %800 : i64
%809 = arith.cmpi sge, %c1_i64, %c0_i64 : i64
%810 = arith.select %809, %c1_i64, %808 : i64
%811 = arith.cmpi slt, %810, %c0_i64 : i64
%812 = arith.select %811, %c0_i64, %810 : i64
%813 = arith.cmpi sgt, %812, %800 : i64
%814 = arith.select %813, %800, %812 : i64
%815 = arith.index_cast %814 : i64 to index
%816 = arith.cmpi sge, %815, %807 : index
%817 = arith.select %816, %815, %807 : index
%818 = arith.subi %817, %807 : index
%819 = tensor.extract_slice %799[0, %807, 0] [%798, %818, 384] [1, 1, 1] : tensor<?x?x384xf32> to tensor<?x?x384xf32>
%820 = tensor.cast %819 : tensor<?x?x384xf32> to tensor<?x1x384xf32>
%821 = tensor.collapse_shape %820 [[0, 1], [2]] : tensor<?x1x384xf32> into tensor<?x384xf32>
%822 = tensor.dim %820, %c0 : tensor<?x1x384xf32>
%823 = linalg.init_tensor [%822, 384] : tensor<?x384xf32>
%824 = linalg.init_tensor [384, 384] : tensor<384x384xf32>
%825 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%cst_112 : tensor<384xf32>) outs(%823 : tensor<?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384xf32>
%826 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%cst_113 : tensor<384x384xf32>) outs(%824 : tensor<384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<384x384xf32>
%827 = linalg.matmul ins(%821, %826 : tensor<?x384xf32>, tensor<384x384xf32>) outs(%825 : tensor<?x384xf32>) -> tensor<?x384xf32>
%828 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%827 : tensor<?x384xf32>) outs(%823 : tensor<?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%834 = math.tanh %arg1 : f32
linalg.yield %834 : f32
} -> tensor<?x384xf32>
%829 = linalg.init_tensor [%822, 2] : tensor<?x2xf32>
%830 = linalg.init_tensor [384, 2] : tensor<384x2xf32>
%831 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%cst_114 : tensor<2xf32>) outs(%829 : tensor<?x2xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x2xf32>
%832 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%cst_115 : tensor<2x384xf32>) outs(%830 : tensor<384x2xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<384x2xf32>
%833 = linalg.matmul ins(%828, %832 : tensor<?x384xf32>, tensor<384x2xf32>) outs(%831 : tensor<?x2xf32>) -> tensor<?x2xf32>
return %833 : tensor<?x2xf32>
}
// -----// IR Dump Before Canonicalizer //----- //
// Public entry point `@forward` (carries the `iree.abi.stub` attribute).
// Takes a single `!hal.buffer_view`, reinterprets it as a static
// `tensor<1x512xi64>`, forwards it to the private `@_forward` function, and
// hands the `tensor<?x2xf32>` result back to the caller as a
// `!hal.buffer_view`.
func @forward(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
  // Index constant selecting the leading (dynamic) dimension of the result.
  %c0 = arith.constant 0 : index
  // Bring the incoming buffer view into tensor form with a fixed 1x512 shape.
  %input = hal.tensor.import %arg0 : !hal.buffer_view -> tensor<1x512xi64>
  // Delegate the actual computation to the private implementation.
  %result = call @_forward(%input) : (tensor<1x512xi64>) -> tensor<?x2xf32>
  // The result's first dimension is dynamic, so its runtime size has to be
  // queried and supplied explicitly when exporting.
  %rows = tensor.dim %result, %c0 : tensor<?x2xf32>
  %view = hal.tensor.export %result : tensor<?x2xf32>{%rows} -> !hal.buffer_view
  return %view : !hal.buffer_view
}
// -----// IR Dump Before Canonicalizer //----- //
func @forward(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
%c0 = arith.constant 0 : index
%0 = hal.tensor.import %arg0 : !hal.buffer_view -> tensor<1x512xi64>
%cst = arith.constant 3.840000e+02 : f32
%c512 = arith.constant 512 : index
%c1 = arith.constant 1 : index
%c0_0 = arith.constant 0 : index
%c1_i64 = arith.constant 1 : i64
%c512_i64 = arith.constant 512 : i64
%cst_1 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<2x384xf32>
%cst_2 = arith.constant dense<[-0.00115577725, 0.00115577038]> : tensor<2xf32>
%cst_3 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_4 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_5 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_6 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_7 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x1536xf32>
%cst_8 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_9 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536x384xf32>
%cst_10 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536xf32>
%cst_11 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_12 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_13 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_14 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_15 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_16 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_17 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_18 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_19 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_20 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_21 = arith.constant dense<5.6568542494923806> : tensor<f64>
%cst_22 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_23 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_24 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x1536xf32>
%cst_25 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_26 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536x384xf32>
%cst_27 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536xf32>
%cst_28 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_29 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_30 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_31 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_32 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_33 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_34 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_35 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_36 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_37 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_38 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_39 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_40 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x1536xf32>
%cst_41 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_42 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536x384xf32>
%cst_43 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536xf32>
%cst_44 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_45 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_46 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_47 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_48 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_49 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_50 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_51 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_52 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_53 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_54 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_55 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_56 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x1536xf32>
%cst_57 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_58 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536x384xf32>
%cst_59 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536xf32>
%cst_60 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_61 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_62 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_63 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_64 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_65 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_66 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_67 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_68 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_69 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_70 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_71 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_72 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x1536xf32>
%cst_73 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_74 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536x384xf32>
%cst_75 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536xf32>
%cst_76 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_77 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_78 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_79 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_80 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_81 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_82 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_83 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_84 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_85 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_86 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_87 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_88 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x1536xf32>
%cst_89 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_90 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536x384xf32>
%cst_91 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536xf32>
%cst_92 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_93 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_94 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_95 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_96 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_97 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_98 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_99 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_100 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_101 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_102 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_103 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_104 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<512x384xf32>
%cst_105 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<2x384xf32>
%cst_106 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<30522x384xf32>
%cst_107 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1x512xi64>
%cst_108 = arith.constant dense<0> : tensor<i64>
%cst_109 = arith.constant dense<0> : tensor<1x512xi64>
%cst_110 = arith.constant dense<-1.000000e+04> : tensor<f64>
%c2_i64 = arith.constant 2 : i64
%cst_111 = arith.constant 9.9999999999999998E-13 : f64
%c9223372036854775807_i64 = arith.constant 9223372036854775807 : i64
%cst_112 = arith.constant 1.000000e+00 : f32
%cst_113 = arith.constant 2.000000e+00 : f32
%cst_114 = arith.constant 5.000000e-01 : f32
%cst_115 = arith.constant -3.40282347E+38 : f32
%cst_116 = arith.constant 0.000000e+00 : f32
%c0_i64 = arith.constant 0 : i64
%c30522_i64 = arith.constant 30522 : i64
%1 = linalg.init_tensor [] : tensor<i64>
%2 = linalg.fill(%c512_i64, %1) : i64, tensor<i64> -> tensor<i64>
%3 = linalg.generic {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []} ins(%2, %cst_108 : tensor<i64>, tensor<i64>) outs(%1 : tensor<i64>) {
^bb0(%arg1: i64, %arg2: i64, %arg3: i64):
%564 = arith.addi %arg1, %arg2 : i64
linalg.yield %564 : i64
} -> tensor<i64>
%4 = tensor.extract %3[] : tensor<i64>
%5 = arith.index_cast %4 : i64 to index
%6 = linalg.init_tensor [1, %5] : tensor<1x?xf32>
%7 = linalg.fill(%cst_112, %6) : f32, tensor<1x?xf32> -> tensor<1x?xf32>
%8 = tensor.extract_slice %7[0, 0] [1, %5] [1, 1] : tensor<1x?xf32> to tensor<1x?xf32>
%9 = tensor.cast %8 : tensor<1x?xf32> to tensor<?x?xf32>
%10 = tensor.expand_shape %9 [[0], [1, 2, 3]] : tensor<?x?xf32> into tensor<?x1x1x?xf32>
%11 = arith.cmpi sgt, %c0_i64, %4 : i64
%12 = arith.select %11, %4, %c0_i64 : i64
%13 = arith.index_cast %12 : i64 to index
%14 = arith.cmpi sgt, %c9223372036854775807_i64, %4 : i64
%15 = arith.select %14, %4, %c9223372036854775807_i64 : i64
%16 = arith.index_cast %15 : i64 to index
%17 = arith.cmpi sge, %16, %13 : index
%18 = arith.select %17, %16, %13 : index
%19 = arith.subi %18, %13 : index
%20 = tensor.extract_slice %10[0, 0, 0, %13] [1, 1, 1, %19] [1, 1, 1, 1] : tensor<?x1x1x?xf32> to tensor<1x1x1x?xf32>
%21 = linalg.init_tensor [1, 1, 1, %19] : tensor<1x1x1x?xf32>
%22 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, 0, 0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%20 : tensor<1x1x1x?xf32>) outs(%21 : tensor<1x1x1x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.subf %cst_112, %arg1 : f32
linalg.yield %564 : f32
} -> tensor<1x1x1x?xf32>
%23 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, 0, 0, d3)>, affine_map<(d0, d1, d2, d3) -> ()>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%22, %cst_110 : tensor<1x1x1x?xf32>, tensor<f64>) outs(%21 : tensor<1x1x1x?xf32>) {
^bb0(%arg1: f32, %arg2: f64, %arg3: f32):
%564 = arith.truncf %arg2 : f64 to f32
%565 = arith.mulf %arg1, %564 : f32
linalg.yield %565 : f32
} -> tensor<1x1x1x?xf32>
%24 = linalg.fill(%c512_i64, %1) : i64, tensor<i64> -> tensor<i64>
%25 = linalg.generic {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []} ins(%24, %cst_108 : tensor<i64>, tensor<i64>) outs(%1 : tensor<i64>) {
^bb0(%arg1: i64, %arg2: i64, %arg3: i64):
%564 = arith.addi %arg1, %arg2 : i64
linalg.yield %564 : i64
} -> tensor<i64>
%26 = tensor.extract %25[] : tensor<i64>
%27 = arith.addi %26, %c512_i64 : i64
%28 = arith.cmpi sge, %26, %c0_i64 : i64
%29 = arith.select %28, %26, %27 : i64
%30 = arith.cmpi slt, %29, %c0_i64 : i64
%31 = arith.select %30, %c0_i64, %29 : i64
%32 = arith.cmpi sgt, %31, %c512_i64 : i64
%33 = arith.select %32, %c512_i64, %31 : i64
%34 = arith.index_cast %33 : i64 to index
%35 = arith.cmpi sge, %34, %c0_0 : index
%36 = arith.select %35, %34, %c0_0 : index
%37 = tensor.extract_slice %cst_107[0, 0] [1, %36] [1, 1] : tensor<1x512xi64> to tensor<1x?xi64>
%38 = linalg.init_tensor [1, 512, 384] : tensor<1x512x384xf32>
%39 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%0 : tensor<1x512xi64>) outs(%38 : tensor<1x512x384xf32>) {
^bb0(%arg1: i64, %arg2: f32):
%564 = arith.index_cast %arg1 : i64 to index
%565 = linalg.index 2 : index
%566 = arith.cmpi slt, %arg1, %c30522_i64 : i64
cf.assert %566, "index must be smaller than dim size"
%567 = arith.cmpi sge, %arg1, %c0_i64 : i64
cf.assert %567, "index must be larger or equal to 0"
%568 = tensor.extract %cst_106[%564, %565] : tensor<30522x384xf32>
linalg.yield %568 : f32
} -> tensor<1x512x384xf32>
%40 = linalg.init_tensor [1, 512, 384] : tensor<1x512x384xf32>
%41 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_109 : tensor<1x512xi64>) outs(%40 : tensor<1x512x384xf32>) {
^bb0(%arg1: i64, %arg2: f32):
%564 = arith.index_cast %arg1 : i64 to index
%565 = linalg.index 2 : index
%566 = arith.cmpi slt, %arg1, %c2_i64 : i64
cf.assert %566, "index must be smaller than dim size"
%567 = arith.cmpi sge, %arg1, %c0_i64 : i64
cf.assert %567, "index must be larger or equal to 0"
%568 = tensor.extract %cst_105[%564, %565] : tensor<2x384xf32>
linalg.yield %568 : f32
} -> tensor<1x512x384xf32>
%42 = linalg.init_tensor [1, 512, 384] : tensor<1x512x384xf32>
%43 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%39, %41 : tensor<1x512x384xf32>, tensor<1x512x384xf32>) outs(%42 : tensor<1x512x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.addf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<1x512x384xf32>
%44 = linalg.init_tensor [1, %36, 384] : tensor<1x?x384xf32>
%45 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%37 : tensor<1x?xi64>) outs(%44 : tensor<1x?x384xf32>) {
^bb0(%arg1: i64, %arg2: f32):
%564 = arith.index_cast %arg1 : i64 to index
%565 = linalg.index 2 : index
%566 = arith.cmpi slt, %arg1, %c512_i64 : i64
cf.assert %566, "index must be smaller than dim size"
%567 = arith.cmpi sge, %arg1, %c0_i64 : i64
cf.assert %567, "index must be larger or equal to 0"
%568 = tensor.extract %cst_104[%564, %565] : tensor<512x384xf32>
linalg.yield %568 : f32
} -> tensor<1x?x384xf32>
%46 = arith.cmpi eq, %c512, %36 : index
cf.assert %46, "mismatched size for broadcast"
%47 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%43, %45 : tensor<1x512x384xf32>, tensor<1x?x384xf32>) outs(%42 : tensor<1x512x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.addf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<1x512x384xf32>
%48 = linalg.init_tensor [1, 512] : tensor<1x512xf32>
%49 = linalg.fill(%cst_116, %48) : f32, tensor<1x512xf32> -> tensor<1x512xf32>
%50 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%47 : tensor<1x512x384xf32>) outs(%49 : tensor<1x512xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.addf %arg2, %arg1 : f32
linalg.yield %564 : f32
} -> tensor<1x512xf32>
%51 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%50 : tensor<1x512xf32>) outs(%48 : tensor<1x512xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.divf %arg1, %cst : f32
linalg.yield %564 : f32
} -> tensor<1x512xf32>
%52 = linalg.fill(%cst_116, %48) : f32, tensor<1x512xf32> -> tensor<1x512xf32>
%53 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%47, %51 : tensor<1x512x384xf32>, tensor<1x512xf32>) outs(%52 : tensor<1x512xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.subf %arg1, %arg2 : f32
%565 = arith.mulf %564, %564 : f32
%566 = arith.addf %arg3, %565 : f32
linalg.yield %566 : f32
} -> tensor<1x512xf32>
%54 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%53 : tensor<1x512xf32>) outs(%48 : tensor<1x512xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.divf %arg1, %cst : f32
linalg.yield %564 : f32
} -> tensor<1x512xf32>
%55 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%47, %51, %54, %cst_102, %cst_103 : tensor<1x512x384xf32>, tensor<1x512xf32>, tensor<1x512xf32>, tensor<384xf32>, tensor<384xf32>) outs(%42 : tensor<1x512x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%564 = arith.subf %arg1, %arg2 : f32
%565 = arith.truncf %cst_111 : f64 to f32
%566 = arith.addf %arg3, %565 : f32
%567 = math.rsqrt %566 : f32
%568 = arith.mulf %564, %567 : f32
%569 = arith.mulf %568, %arg4 : f32
%570 = arith.addf %569, %arg5 : f32
linalg.yield %570 : f32
} -> tensor<1x512x384xf32>
%56 = linalg.init_tensor [1, 384, 384] : tensor<1x384x384xf32>
%57 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_101 : tensor<384xf32>) outs(%42 : tensor<1x512x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<1x512x384xf32>
%58 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_100 : tensor<384x384xf32>) outs(%56 : tensor<1x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<1x384x384xf32>
%59 = linalg.batch_matmul ins(%55, %58 : tensor<1x512x384xf32>, tensor<1x384x384xf32>) outs(%57 : tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%60 = tensor.cast %59 : tensor<1x512x384xf32> to tensor<?x?x384xf32>
%61 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_99 : tensor<384xf32>) outs(%42 : tensor<1x512x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<1x512x384xf32>
%62 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_98 : tensor<384x384xf32>) outs(%56 : tensor<1x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<1x384x384xf32>
%63 = linalg.batch_matmul ins(%55, %62 : tensor<1x512x384xf32>, tensor<1x384x384xf32>) outs(%61 : tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%64 = tensor.cast %63 : tensor<1x512x384xf32> to tensor<?x?x384xf32>
%65 = tensor.expand_shape %64 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%66 = linalg.init_tensor [1, 12, 512, 32] : tensor<1x12x512x32xf32>
%67 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%65 : tensor<?x?x12x32xf32>) outs(%66 : tensor<1x12x512x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<1x12x512x32xf32>
%68 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_97 : tensor<384xf32>) outs(%42 : tensor<1x512x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<1x512x384xf32>
%69 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_96 : tensor<384x384xf32>) outs(%56 : tensor<1x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<1x384x384xf32>
%70 = linalg.batch_matmul ins(%55, %69 : tensor<1x512x384xf32>, tensor<1x384x384xf32>) outs(%68 : tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%71 = tensor.cast %70 : tensor<1x512x384xf32> to tensor<?x?x384xf32>
%72 = tensor.expand_shape %71 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%73 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%72 : tensor<?x?x12x32xf32>) outs(%66 : tensor<1x12x512x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<1x12x512x32xf32>
%74 = tensor.expand_shape %60 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%75 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%74 : tensor<?x?x12x32xf32>) outs(%66 : tensor<1x12x512x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<1x12x512x32xf32>
%76 = linalg.init_tensor [1, 12, 32, 512] : tensor<1x12x32x512xf32>
%77 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d2)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%67 : tensor<1x12x512x32xf32>) outs(%76 : tensor<1x12x32x512xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<1x12x32x512xf32>
%78 = linalg.init_tensor [1, 12, 512, 512] : tensor<1x12x512x512xf32>
%79 = linalg.fill(%cst_116, %78) : f32, tensor<1x12x512x512xf32> -> tensor<1x12x512x512xf32>
%80 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>], iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel"]} ins(%75, %77 : tensor<1x12x512x32xf32>, tensor<1x12x32x512xf32>) outs(%79 : tensor<1x12x512x512xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.mulf %arg1, %arg2 : f32
%565 = arith.addf %564, %arg3 : f32
linalg.yield %565 : f32
} -> tensor<1x12x512x512xf32>
%81 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> ()>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%80, %cst_21 : tensor<1x12x512x512xf32>, tensor<f64>) outs(%78 : tensor<1x12x512x512xf32>) {
^bb0(%arg1: f32, %arg2: f64, %arg3: f32):
%564 = arith.truncf %arg2 : f64 to f32
%565 = arith.divf %arg1, %564 : f32
linalg.yield %565 : f32
} -> tensor<1x12x512x512xf32>
%82 = arith.cmpi eq, %c512, %19 : index
cf.assert %82, "mismatched size for broadcast"
%83 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, 0, 0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%81, %23 : tensor<1x12x512x512xf32>, tensor<1x1x1x?xf32>) outs(%78 : tensor<1x12x512x512xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.addf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<1x12x512x512xf32>
%84 = linalg.init_tensor [1, 12, 512, 1] : tensor<1x12x512x1xi64>
%85 = linalg.fill(%c0_i64, %84) : i64, tensor<1x12x512x1xi64> -> tensor<1x12x512x1xi64>
%86 = linalg.init_tensor [1, 12, 512, 1] : tensor<1x12x512x1xf32>
%87 = linalg.fill(%cst_115, %86) : f32, tensor<1x12x512x1xf32> -> tensor<1x12x512x1xf32>
%88:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%83 : tensor<1x12x512x512xf32>) outs(%87, %85 : tensor<1x12x512x1xf32>, tensor<1x12x512x1xi64>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: i64):
%564 = linalg.index 3 : index
%565 = arith.index_cast %564 : index to i64
%566 = arith.cmpf ogt, %arg1, %arg2 : f32
%567 = arith.select %566, %arg1, %arg2 : f32
%568 = arith.select %566, %565, %arg3 : i64
linalg.yield %567, %568 : f32, i64
} -> (tensor<1x12x512x1xf32>, tensor<1x12x512x1xi64>)
%89 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%83, %88#0 : tensor<1x12x512x512xf32>, tensor<1x12x512x1xf32>) outs(%78 : tensor<1x12x512x512xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.subf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<1x12x512x512xf32>
%90 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%89 : tensor<1x12x512x512xf32>) outs(%78 : tensor<1x12x512x512xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = math.exp %arg1 : f32
linalg.yield %564 : f32
} -> tensor<1x12x512x512xf32>
%91 = linalg.fill(%cst_116, %86) : f32, tensor<1x12x512x1xf32> -> tensor<1x12x512x1xf32>
%92 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%90 : tensor<1x12x512x512xf32>) outs(%91 : tensor<1x12x512x1xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.addf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<1x12x512x1xf32>
%93 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%90, %92 : tensor<1x12x512x512xf32>, tensor<1x12x512x1xf32>) outs(%78 : tensor<1x12x512x512xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.divf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<1x12x512x512xf32>
%94 = linalg.fill(%cst_116, %66) : f32, tensor<1x12x512x32xf32> -> tensor<1x12x512x32xf32>
%95 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>], iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel"]} ins(%93, %73 : tensor<1x12x512x512xf32>, tensor<1x12x512x32xf32>) outs(%94 : tensor<1x12x512x32xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.mulf %arg1, %arg2 : f32
%565 = arith.addf %564, %arg3 : f32
linalg.yield %565 : f32
} -> tensor<1x12x512x32xf32>
%96 = linalg.init_tensor [1, 512, 12, 32] : tensor<1x512x12x32xf32>
%97 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%95 : tensor<1x12x512x32xf32>) outs(%96 : tensor<1x512x12x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<1x512x12x32xf32>
%98 = tensor.cast %97 : tensor<1x512x12x32xf32> to tensor<?x?x?x?xf32>
%99 = tensor.collapse_shape %98 [[0], [1], [2, 3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
%100 = tensor.cast %99 : tensor<?x?x?xf32> to tensor<?x?x384xf32>
%101 = linalg.init_tensor [1, 512, 384] : tensor<1x512x384xf32>
%102 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_95 : tensor<384xf32>) outs(%101 : tensor<1x512x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<1x512x384xf32>
%103 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_94 : tensor<384x384xf32>) outs(%56 : tensor<1x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<1x384x384xf32>
%104 = linalg.batch_matmul ins(%100, %103 : tensor<?x?x384xf32>, tensor<1x384x384xf32>) outs(%102 : tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%105 = tensor.dim %99, %c0_0 : tensor<?x?x?xf32>
%106 = tensor.dim %99, %c1 : tensor<?x?x?xf32>
%107 = arith.cmpi eq, %105, %c1 : index
cf.assert %107, "mismatched size for broadcast"
%108 = arith.cmpi eq, %106, %c512 : index
cf.assert %108, "mismatched size for broadcast"
%109 = linalg.init_tensor [%105, %106, 384] : tensor<?x?x384xf32>
%110 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%104, %55 : tensor<1x512x384xf32>, tensor<1x512x384xf32>) outs(%109 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.addf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<?x?x384xf32>
%111 = linalg.init_tensor [%105, %106] : tensor<?x?xf32>
%112 = linalg.fill(%cst_116, %111) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%113 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%110 : tensor<?x?x384xf32>) outs(%112 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.addf %arg2, %arg1 : f32
linalg.yield %564 : f32
} -> tensor<?x?xf32>
%114 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%113 : tensor<?x?xf32>) outs(%111 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.divf %arg1, %cst : f32
linalg.yield %564 : f32
} -> tensor<?x?xf32>
%115 = linalg.fill(%cst_116, %111) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%116 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%110, %114 : tensor<?x?x384xf32>, tensor<?x?xf32>) outs(%115 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.subf %arg1, %arg2 : f32
%565 = arith.mulf %564, %564 : f32
%566 = arith.addf %arg3, %565 : f32
linalg.yield %566 : f32
} -> tensor<?x?xf32>
%117 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%116 : tensor<?x?xf32>) outs(%111 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.divf %arg1, %cst : f32
linalg.yield %564 : f32
} -> tensor<?x?xf32>
%118 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%110, %114, %117, %cst_92, %cst_93 : tensor<?x?x384xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<384xf32>, tensor<384xf32>) outs(%109 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%564 = arith.subf %arg1, %arg2 : f32
%565 = arith.truncf %cst_111 : f64 to f32
%566 = arith.addf %arg3, %565 : f32
%567 = math.rsqrt %566 : f32
%568 = arith.mulf %564, %567 : f32
%569 = arith.mulf %568, %arg4 : f32
%570 = arith.addf %569, %arg5 : f32
linalg.yield %570 : f32
} -> tensor<?x?x384xf32>
%119 = linalg.init_tensor [%105, %106, 1536] : tensor<?x?x1536xf32>
%120 = linalg.init_tensor [%105, 384, 1536] : tensor<?x384x1536xf32>
%121 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_91 : tensor<1536xf32>) outs(%119 : tensor<?x?x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x1536xf32>
%122 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_90 : tensor<1536x384xf32>) outs(%120 : tensor<?x384x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x1536xf32>
%123 = linalg.batch_matmul ins(%118, %122 : tensor<?x?x384xf32>, tensor<?x384x1536xf32>) outs(%121 : tensor<?x?x1536xf32>) -> tensor<?x?x1536xf32>
%124 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%123 : tensor<?x?x1536xf32>) outs(%119 : tensor<?x?x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = math.sqrt %cst_113 : f32
%565 = arith.divf %arg1, %564 : f32
%566 = math.erf %565 : f32
%567 = arith.addf %566, %cst_112 : f32
%568 = arith.mulf %567, %cst_114 : f32
%569 = arith.mulf %arg1, %568 : f32
linalg.yield %569 : f32
} -> tensor<?x?x1536xf32>
%125 = linalg.init_tensor [%105, 1536, 384] : tensor<?x1536x384xf32>
%126 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_89 : tensor<384xf32>) outs(%109 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%127 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_88 : tensor<384x1536xf32>) outs(%125 : tensor<?x1536x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x1536x384xf32>
%128 = linalg.batch_matmul ins(%124, %127 : tensor<?x?x1536xf32>, tensor<?x1536x384xf32>) outs(%126 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%129 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%128, %118 : tensor<?x?x384xf32>, tensor<?x?x384xf32>) outs(%109 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.addf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<?x?x384xf32>
%130 = linalg.fill(%cst_116, %111) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%131 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%129 : tensor<?x?x384xf32>) outs(%130 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.addf %arg2, %arg1 : f32
linalg.yield %564 : f32
} -> tensor<?x?xf32>
%132 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%131 : tensor<?x?xf32>) outs(%111 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.divf %arg1, %cst : f32
linalg.yield %564 : f32
} -> tensor<?x?xf32>
%133 = linalg.fill(%cst_116, %111) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%134 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%129, %132 : tensor<?x?x384xf32>, tensor<?x?xf32>) outs(%133 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.subf %arg1, %arg2 : f32
%565 = arith.mulf %564, %564 : f32
%566 = arith.addf %arg3, %565 : f32
linalg.yield %566 : f32
} -> tensor<?x?xf32>
%135 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%134 : tensor<?x?xf32>) outs(%111 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.divf %arg1, %cst : f32
linalg.yield %564 : f32
} -> tensor<?x?xf32>
%136 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%129, %132, %135, %cst_86, %cst_87 : tensor<?x?x384xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<384xf32>, tensor<384xf32>) outs(%109 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%564 = arith.subf %arg1, %arg2 : f32
%565 = arith.truncf %cst_111 : f64 to f32
%566 = arith.addf %arg3, %565 : f32
%567 = math.rsqrt %566 : f32
%568 = arith.mulf %564, %567 : f32
%569 = arith.mulf %568, %arg4 : f32
%570 = arith.addf %569, %arg5 : f32
linalg.yield %570 : f32
} -> tensor<?x?x384xf32>
%137 = linalg.init_tensor [%105, 384, 384] : tensor<?x384x384xf32>
%138 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_85 : tensor<384xf32>) outs(%109 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%139 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_84 : tensor<384x384xf32>) outs(%137 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%140 = linalg.batch_matmul ins(%136, %139 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%138 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%141 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_83 : tensor<384xf32>) outs(%109 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%142 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_82 : tensor<384x384xf32>) outs(%137 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%143 = linalg.batch_matmul ins(%136, %142 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%141 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%144 = tensor.expand_shape %143 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%145 = linalg.init_tensor [%105, 12, %106, 32] : tensor<?x12x?x32xf32>
%146 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%144 : tensor<?x?x12x32xf32>) outs(%145 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
%147 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_81 : tensor<384xf32>) outs(%109 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%148 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_80 : tensor<384x384xf32>) outs(%137 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%149 = linalg.batch_matmul ins(%136, %148 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%147 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%150 = tensor.expand_shape %149 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%151 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%150 : tensor<?x?x12x32xf32>) outs(%145 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
%152 = tensor.expand_shape %140 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%153 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%152 : tensor<?x?x12x32xf32>) outs(%145 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
%154 = linalg.init_tensor [%105, 12, 32, %106] : tensor<?x12x32x?xf32>
%155 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d2)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%146 : tensor<?x12x?x32xf32>) outs(%154 : tensor<?x12x32x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x32x?xf32>
%156 = linalg.init_tensor [%105, 12, %106, %106] : tensor<?x12x?x?xf32>
%157 = linalg.fill(%cst_116, %156) : f32, tensor<?x12x?x?xf32> -> tensor<?x12x?x?xf32>
%158 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>], iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel"]} ins(%153, %155 : tensor<?x12x?x32xf32>, tensor<?x12x32x?xf32>) outs(%157 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.mulf %arg1, %arg2 : f32
%565 = arith.addf %564, %arg3 : f32
linalg.yield %565 : f32
} -> tensor<?x12x?x?xf32>
%159 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> ()>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%158, %cst_21 : tensor<?x12x?x?xf32>, tensor<f64>) outs(%156 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f64, %arg3: f32):
%564 = arith.truncf %arg2 : f64 to f32
%565 = arith.divf %arg1, %564 : f32
linalg.yield %565 : f32
} -> tensor<?x12x?x?xf32>
%160 = arith.cmpi eq, %105, %c1 : index
cf.assert %160, "mismatched size for broadcast"
%161 = arith.cmpi eq, %106, %19 : index
cf.assert %161, "mismatched size for broadcast"
%162 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, 0, 0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%159, %23 : tensor<?x12x?x?xf32>, tensor<1x1x1x?xf32>) outs(%156 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.addf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<?x12x?x?xf32>
%163 = linalg.init_tensor [%105, 12, %106, 1] : tensor<?x12x?x1xi64>
%164 = linalg.fill(%c0_i64, %163) : i64, tensor<?x12x?x1xi64> -> tensor<?x12x?x1xi64>
%165 = linalg.init_tensor [%105, 12, %106, 1] : tensor<?x12x?x1xf32>
%166 = linalg.fill(%cst_115, %165) : f32, tensor<?x12x?x1xf32> -> tensor<?x12x?x1xf32>
%167:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%162 : tensor<?x12x?x?xf32>) outs(%166, %164 : tensor<?x12x?x1xf32>, tensor<?x12x?x1xi64>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: i64):
%564 = linalg.index 3 : index
%565 = arith.index_cast %564 : index to i64
%566 = arith.cmpf ogt, %arg1, %arg2 : f32
%567 = arith.select %566, %arg1, %arg2 : f32
%568 = arith.select %566, %565, %arg3 : i64
linalg.yield %567, %568 : f32, i64
} -> (tensor<?x12x?x1xf32>, tensor<?x12x?x1xi64>)
%168 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%162, %167#0 : tensor<?x12x?x?xf32>, tensor<?x12x?x1xf32>) outs(%156 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.subf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<?x12x?x?xf32>
%169 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%168 : tensor<?x12x?x?xf32>) outs(%156 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = math.exp %arg1 : f32
linalg.yield %564 : f32
} -> tensor<?x12x?x?xf32>
%170 = linalg.fill(%cst_116, %165) : f32, tensor<?x12x?x1xf32> -> tensor<?x12x?x1xf32>
%171 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%169 : tensor<?x12x?x?xf32>) outs(%170 : tensor<?x12x?x1xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.addf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<?x12x?x1xf32>
%172 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%169, %171 : tensor<?x12x?x?xf32>, tensor<?x12x?x1xf32>) outs(%156 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.divf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<?x12x?x?xf32>
%173 = linalg.fill(%cst_116, %145) : f32, tensor<?x12x?x32xf32> -> tensor<?x12x?x32xf32>
%174 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>], iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel"]} ins(%172, %151 : tensor<?x12x?x?xf32>, tensor<?x12x?x32xf32>) outs(%173 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.mulf %arg1, %arg2 : f32
%565 = arith.addf %564, %arg3 : f32
linalg.yield %565 : f32
} -> tensor<?x12x?x32xf32>
%175 = linalg.init_tensor [%105, %106, 12, 32] : tensor<?x?x12x32xf32>
%176 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%174 : tensor<?x12x?x32xf32>) outs(%175 : tensor<?x?x12x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x12x32xf32>
%177 = tensor.cast %176 : tensor<?x?x12x32xf32> to tensor<?x?x?x?xf32>
%178 = tensor.collapse_shape %177 [[0], [1], [2, 3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
%179 = tensor.cast %178 : tensor<?x?x?xf32> to tensor<?x?x384xf32>
%180 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_79 : tensor<384xf32>) outs(%109 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%181 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_78 : tensor<384x384xf32>) outs(%137 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%182 = linalg.batch_matmul ins(%179, %181 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%180 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%183 = tensor.dim %178, %c0_0 : tensor<?x?x?xf32>
%184 = tensor.dim %178, %c1 : tensor<?x?x?xf32>
%185 = arith.cmpi eq, %183, %105 : index
cf.assert %185, "mismatched size for broadcast"
%186 = arith.cmpi eq, %184, %106 : index
cf.assert %186, "mismatched size for broadcast"
%187 = linalg.init_tensor [%183, %184, 384] : tensor<?x?x384xf32>
%188 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%182, %136 : tensor<?x?x384xf32>, tensor<?x?x384xf32>) outs(%187 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.addf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<?x?x384xf32>
%189 = linalg.init_tensor [%183, %184] : tensor<?x?xf32>
%190 = linalg.fill(%cst_116, %189) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%191 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%188 : tensor<?x?x384xf32>) outs(%190 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.addf %arg2, %arg1 : f32
linalg.yield %564 : f32
} -> tensor<?x?xf32>
%192 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%191 : tensor<?x?xf32>) outs(%189 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.divf %arg1, %cst : f32
linalg.yield %564 : f32
} -> tensor<?x?xf32>
%193 = linalg.fill(%cst_116, %189) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%194 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%188, %192 : tensor<?x?x384xf32>, tensor<?x?xf32>) outs(%193 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.subf %arg1, %arg2 : f32
%565 = arith.mulf %564, %564 : f32
%566 = arith.addf %arg3, %565 : f32
linalg.yield %566 : f32
} -> tensor<?x?xf32>
%195 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%194 : tensor<?x?xf32>) outs(%189 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.divf %arg1, %cst : f32
linalg.yield %564 : f32
} -> tensor<?x?xf32>
%196 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%188, %192, %195, %cst_76, %cst_77 : tensor<?x?x384xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<384xf32>, tensor<384xf32>) outs(%187 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%564 = arith.subf %arg1, %arg2 : f32
%565 = arith.truncf %cst_111 : f64 to f32
%566 = arith.addf %arg3, %565 : f32
%567 = math.rsqrt %566 : f32
%568 = arith.mulf %564, %567 : f32
%569 = arith.mulf %568, %arg4 : f32
%570 = arith.addf %569, %arg5 : f32
linalg.yield %570 : f32
} -> tensor<?x?x384xf32>
%197 = linalg.init_tensor [%183, %184, 1536] : tensor<?x?x1536xf32>
%198 = linalg.init_tensor [%183, 384, 1536] : tensor<?x384x1536xf32>
%199 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_75 : tensor<1536xf32>) outs(%197 : tensor<?x?x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x1536xf32>
%200 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_74 : tensor<1536x384xf32>) outs(%198 : tensor<?x384x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x1536xf32>
%201 = linalg.batch_matmul ins(%196, %200 : tensor<?x?x384xf32>, tensor<?x384x1536xf32>) outs(%199 : tensor<?x?x1536xf32>) -> tensor<?x?x1536xf32>
%202 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%201 : tensor<?x?x1536xf32>) outs(%197 : tensor<?x?x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = math.sqrt %cst_113 : f32
%565 = arith.divf %arg1, %564 : f32
%566 = math.erf %565 : f32
%567 = arith.addf %566, %cst_112 : f32
%568 = arith.mulf %567, %cst_114 : f32
%569 = arith.mulf %arg1, %568 : f32
linalg.yield %569 : f32
} -> tensor<?x?x1536xf32>
%203 = linalg.init_tensor [%183, 1536, 384] : tensor<?x1536x384xf32>
%204 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_73 : tensor<384xf32>) outs(%187 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%205 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_72 : tensor<384x1536xf32>) outs(%203 : tensor<?x1536x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x1536x384xf32>
%206 = linalg.batch_matmul ins(%202, %205 : tensor<?x?x1536xf32>, tensor<?x1536x384xf32>) outs(%204 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%207 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%206, %196 : tensor<?x?x384xf32>, tensor<?x?x384xf32>) outs(%187 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.addf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<?x?x384xf32>
%208 = linalg.fill(%cst_116, %189) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%209 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%207 : tensor<?x?x384xf32>) outs(%208 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.addf %arg2, %arg1 : f32
linalg.yield %564 : f32
} -> tensor<?x?xf32>
%210 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%209 : tensor<?x?xf32>) outs(%189 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.divf %arg1, %cst : f32
linalg.yield %564 : f32
} -> tensor<?x?xf32>
%211 = linalg.fill(%cst_116, %189) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%212 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%207, %210 : tensor<?x?x384xf32>, tensor<?x?xf32>) outs(%211 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.subf %arg1, %arg2 : f32
%565 = arith.mulf %564, %564 : f32
%566 = arith.addf %arg3, %565 : f32
linalg.yield %566 : f32
} -> tensor<?x?xf32>
%213 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%212 : tensor<?x?xf32>) outs(%189 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.divf %arg1, %cst : f32
linalg.yield %564 : f32
} -> tensor<?x?xf32>
%214 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%207, %210, %213, %cst_70, %cst_71 : tensor<?x?x384xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<384xf32>, tensor<384xf32>) outs(%187 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%564 = arith.subf %arg1, %arg2 : f32
%565 = arith.truncf %cst_111 : f64 to f32
%566 = arith.addf %arg3, %565 : f32
%567 = math.rsqrt %566 : f32
%568 = arith.mulf %564, %567 : f32
%569 = arith.mulf %568, %arg4 : f32
%570 = arith.addf %569, %arg5 : f32
linalg.yield %570 : f32
} -> tensor<?x?x384xf32>
%215 = linalg.init_tensor [%183, 384, 384] : tensor<?x384x384xf32>
%216 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_69 : tensor<384xf32>) outs(%187 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%217 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_68 : tensor<384x384xf32>) outs(%215 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%218 = linalg.batch_matmul ins(%214, %217 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%216 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%219 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_67 : tensor<384xf32>) outs(%187 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%220 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_66 : tensor<384x384xf32>) outs(%215 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%221 = linalg.batch_matmul ins(%214, %220 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%219 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%222 = tensor.expand_shape %221 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%223 = linalg.init_tensor [%183, 12, %184, 32] : tensor<?x12x?x32xf32>
%224 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%222 : tensor<?x?x12x32xf32>) outs(%223 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
%225 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_65 : tensor<384xf32>) outs(%187 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%226 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_64 : tensor<384x384xf32>) outs(%215 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%227 = linalg.batch_matmul ins(%214, %226 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%225 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%228 = tensor.expand_shape %227 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%229 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%228 : tensor<?x?x12x32xf32>) outs(%223 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
%230 = tensor.expand_shape %218 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%231 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%230 : tensor<?x?x12x32xf32>) outs(%223 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
%232 = linalg.init_tensor [%183, 12, 32, %184] : tensor<?x12x32x?xf32>
%233 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d2)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%224 : tensor<?x12x?x32xf32>) outs(%232 : tensor<?x12x32x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x32x?xf32>
%234 = linalg.init_tensor [%183, 12, %184, %184] : tensor<?x12x?x?xf32>
%235 = linalg.fill(%cst_116, %234) : f32, tensor<?x12x?x?xf32> -> tensor<?x12x?x?xf32>
%236 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>], iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel"]} ins(%231, %233 : tensor<?x12x?x32xf32>, tensor<?x12x32x?xf32>) outs(%235 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.mulf %arg1, %arg2 : f32
%565 = arith.addf %564, %arg3 : f32
linalg.yield %565 : f32
} -> tensor<?x12x?x?xf32>
%237 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> ()>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%236, %cst_21 : tensor<?x12x?x?xf32>, tensor<f64>) outs(%234 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f64, %arg3: f32):
%564 = arith.truncf %arg2 : f64 to f32
%565 = arith.divf %arg1, %564 : f32
linalg.yield %565 : f32
} -> tensor<?x12x?x?xf32>
%238 = arith.cmpi eq, %183, %c1 : index
cf.assert %238, "mismatched size for broadcast"
%239 = arith.cmpi eq, %184, %19 : index
cf.assert %239, "mismatched size for broadcast"
%240 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, 0, 0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%237, %23 : tensor<?x12x?x?xf32>, tensor<1x1x1x?xf32>) outs(%234 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.addf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<?x12x?x?xf32>
%241 = linalg.init_tensor [%183, 12, %184, 1] : tensor<?x12x?x1xi64>
%242 = linalg.fill(%c0_i64, %241) : i64, tensor<?x12x?x1xi64> -> tensor<?x12x?x1xi64>
%243 = linalg.init_tensor [%183, 12, %184, 1] : tensor<?x12x?x1xf32>
%244 = linalg.fill(%cst_115, %243) : f32, tensor<?x12x?x1xf32> -> tensor<?x12x?x1xf32>
%245:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%240 : tensor<?x12x?x?xf32>) outs(%244, %242 : tensor<?x12x?x1xf32>, tensor<?x12x?x1xi64>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: i64):
%564 = linalg.index 3 : index
%565 = arith.index_cast %564 : index to i64
%566 = arith.cmpf ogt, %arg1, %arg2 : f32
%567 = arith.select %566, %arg1, %arg2 : f32
%568 = arith.select %566, %565, %arg3 : i64
linalg.yield %567, %568 : f32, i64
} -> (tensor<?x12x?x1xf32>, tensor<?x12x?x1xi64>)
%246 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%240, %245#0 : tensor<?x12x?x?xf32>, tensor<?x12x?x1xf32>) outs(%234 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.subf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<?x12x?x?xf32>
%247 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%246 : tensor<?x12x?x?xf32>) outs(%234 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = math.exp %arg1 : f32
linalg.yield %564 : f32
} -> tensor<?x12x?x?xf32>
%248 = linalg.fill(%cst_116, %243) : f32, tensor<?x12x?x1xf32> -> tensor<?x12x?x1xf32>
%249 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%247 : tensor<?x12x?x?xf32>) outs(%248 : tensor<?x12x?x1xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.addf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<?x12x?x1xf32>
%250 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%247, %249 : tensor<?x12x?x?xf32>, tensor<?x12x?x1xf32>) outs(%234 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.divf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<?x12x?x?xf32>
%251 = linalg.fill(%cst_116, %223) : f32, tensor<?x12x?x32xf32> -> tensor<?x12x?x32xf32>
%252 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>], iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel"]} ins(%250, %229 : tensor<?x12x?x?xf32>, tensor<?x12x?x32xf32>) outs(%251 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.mulf %arg1, %arg2 : f32
%565 = arith.addf %564, %arg3 : f32
linalg.yield %565 : f32
} -> tensor<?x12x?x32xf32>
%253 = linalg.init_tensor [%183, %184, 12, 32] : tensor<?x?x12x32xf32>
%254 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%252 : tensor<?x12x?x32xf32>) outs(%253 : tensor<?x?x12x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x12x32xf32>
%255 = tensor.cast %254 : tensor<?x?x12x32xf32> to tensor<?x?x?x?xf32>
%256 = tensor.collapse_shape %255 [[0], [1], [2, 3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
%257 = tensor.cast %256 : tensor<?x?x?xf32> to tensor<?x?x384xf32>
%258 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_63 : tensor<384xf32>) outs(%187 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%259 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_62 : tensor<384x384xf32>) outs(%215 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%260 = linalg.batch_matmul ins(%257, %259 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%258 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%261 = tensor.dim %256, %c0_0 : tensor<?x?x?xf32>
%262 = tensor.dim %256, %c1 : tensor<?x?x?xf32>
%263 = arith.cmpi eq, %261, %183 : index
cf.assert %263, "mismatched size for broadcast"
%264 = arith.cmpi eq, %262, %184 : index
cf.assert %264, "mismatched size for broadcast"
%265 = linalg.init_tensor [%261, %262, 384] : tensor<?x?x384xf32>
%266 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%260, %214 : tensor<?x?x384xf32>, tensor<?x?x384xf32>) outs(%265 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.addf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<?x?x384xf32>
%267 = linalg.init_tensor [%261, %262] : tensor<?x?xf32>
%268 = linalg.fill(%cst_116, %267) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%269 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%266 : tensor<?x?x384xf32>) outs(%268 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.addf %arg2, %arg1 : f32
linalg.yield %564 : f32
} -> tensor<?x?xf32>
%270 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%269 : tensor<?x?xf32>) outs(%267 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.divf %arg1, %cst : f32
linalg.yield %564 : f32
} -> tensor<?x?xf32>
%271 = linalg.fill(%cst_116, %267) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%272 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%266, %270 : tensor<?x?x384xf32>, tensor<?x?xf32>) outs(%271 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.subf %arg1, %arg2 : f32
%565 = arith.mulf %564, %564 : f32
%566 = arith.addf %arg3, %565 : f32
linalg.yield %566 : f32
} -> tensor<?x?xf32>
%273 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%272 : tensor<?x?xf32>) outs(%267 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.divf %arg1, %cst : f32
linalg.yield %564 : f32
} -> tensor<?x?xf32>
%274 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%266, %270, %273, %cst_60, %cst_61 : tensor<?x?x384xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<384xf32>, tensor<384xf32>) outs(%265 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%564 = arith.subf %arg1, %arg2 : f32
%565 = arith.truncf %cst_111 : f64 to f32
%566 = arith.addf %arg3, %565 : f32
%567 = math.rsqrt %566 : f32
%568 = arith.mulf %564, %567 : f32
%569 = arith.mulf %568, %arg4 : f32
%570 = arith.addf %569, %arg5 : f32
linalg.yield %570 : f32
} -> tensor<?x?x384xf32>
%275 = linalg.init_tensor [%261, %262, 1536] : tensor<?x?x1536xf32>
%276 = linalg.init_tensor [%261, 384, 1536] : tensor<?x384x1536xf32>
%277 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_59 : tensor<1536xf32>) outs(%275 : tensor<?x?x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x1536xf32>
%278 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_58 : tensor<1536x384xf32>) outs(%276 : tensor<?x384x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x1536xf32>
%279 = linalg.batch_matmul ins(%274, %278 : tensor<?x?x384xf32>, tensor<?x384x1536xf32>) outs(%277 : tensor<?x?x1536xf32>) -> tensor<?x?x1536xf32>
%280 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%279 : tensor<?x?x1536xf32>) outs(%275 : tensor<?x?x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = math.sqrt %cst_113 : f32
%565 = arith.divf %arg1, %564 : f32
%566 = math.erf %565 : f32
%567 = arith.addf %566, %cst_112 : f32
%568 = arith.mulf %567, %cst_114 : f32
%569 = arith.mulf %arg1, %568 : f32
linalg.yield %569 : f32
} -> tensor<?x?x1536xf32>
%281 = linalg.init_tensor [%261, 1536, 384] : tensor<?x1536x384xf32>
%282 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_57 : tensor<384xf32>) outs(%265 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%283 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_56 : tensor<384x1536xf32>) outs(%281 : tensor<?x1536x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x1536x384xf32>
%284 = linalg.batch_matmul ins(%280, %283 : tensor<?x?x1536xf32>, tensor<?x1536x384xf32>) outs(%282 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%285 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%284, %274 : tensor<?x?x384xf32>, tensor<?x?x384xf32>) outs(%265 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.addf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<?x?x384xf32>
%286 = linalg.fill(%cst_116, %267) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%287 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%285 : tensor<?x?x384xf32>) outs(%286 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.addf %arg2, %arg1 : f32
linalg.yield %564 : f32
} -> tensor<?x?xf32>
%288 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%287 : tensor<?x?xf32>) outs(%267 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.divf %arg1, %cst : f32
linalg.yield %564 : f32
} -> tensor<?x?xf32>
%289 = linalg.fill(%cst_116, %267) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%290 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%285, %288 : tensor<?x?x384xf32>, tensor<?x?xf32>) outs(%289 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.subf %arg1, %arg2 : f32
%565 = arith.mulf %564, %564 : f32
%566 = arith.addf %arg3, %565 : f32
linalg.yield %566 : f32
} -> tensor<?x?xf32>
%291 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%290 : tensor<?x?xf32>) outs(%267 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.divf %arg1, %cst : f32
linalg.yield %564 : f32
} -> tensor<?x?xf32>
%292 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%285, %288, %291, %cst_54, %cst_55 : tensor<?x?x384xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<384xf32>, tensor<384xf32>) outs(%265 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%564 = arith.subf %arg1, %arg2 : f32
%565 = arith.truncf %cst_111 : f64 to f32
%566 = arith.addf %arg3, %565 : f32
%567 = math.rsqrt %566 : f32
%568 = arith.mulf %564, %567 : f32
%569 = arith.mulf %568, %arg4 : f32
%570 = arith.addf %569, %arg5 : f32
linalg.yield %570 : f32
} -> tensor<?x?x384xf32>
%293 = linalg.init_tensor [%261, 384, 384] : tensor<?x384x384xf32>
%294 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_53 : tensor<384xf32>) outs(%265 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%295 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_52 : tensor<384x384xf32>) outs(%293 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%296 = linalg.batch_matmul ins(%292, %295 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%294 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%297 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_51 : tensor<384xf32>) outs(%265 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%298 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_50 : tensor<384x384xf32>) outs(%293 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%299 = linalg.batch_matmul ins(%292, %298 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%297 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%300 = tensor.expand_shape %299 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%301 = linalg.init_tensor [%261, 12, %262, 32] : tensor<?x12x?x32xf32>
%302 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%300 : tensor<?x?x12x32xf32>) outs(%301 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
%303 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_49 : tensor<384xf32>) outs(%265 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%304 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_48 : tensor<384x384xf32>) outs(%293 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%305 = linalg.batch_matmul ins(%292, %304 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%303 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%306 = tensor.expand_shape %305 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%307 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%306 : tensor<?x?x12x32xf32>) outs(%301 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
%308 = tensor.expand_shape %296 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%309 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%308 : tensor<?x?x12x32xf32>) outs(%301 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
%310 = linalg.init_tensor [%261, 12, 32, %262] : tensor<?x12x32x?xf32>
%311 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d2)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%302 : tensor<?x12x?x32xf32>) outs(%310 : tensor<?x12x32x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x32x?xf32>
%312 = linalg.init_tensor [%261, 12, %262, %262] : tensor<?x12x?x?xf32>
%313 = linalg.fill(%cst_116, %312) : f32, tensor<?x12x?x?xf32> -> tensor<?x12x?x?xf32>
%314 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>], iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel"]} ins(%309, %311 : tensor<?x12x?x32xf32>, tensor<?x12x32x?xf32>) outs(%313 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.mulf %arg1, %arg2 : f32
%565 = arith.addf %564, %arg3 : f32
linalg.yield %565 : f32
} -> tensor<?x12x?x?xf32>
%315 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> ()>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%314, %cst_21 : tensor<?x12x?x?xf32>, tensor<f64>) outs(%312 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f64, %arg3: f32):
%564 = arith.truncf %arg2 : f64 to f32
%565 = arith.divf %arg1, %564 : f32
linalg.yield %565 : f32
} -> tensor<?x12x?x?xf32>
%316 = arith.cmpi eq, %261, %c1 : index
cf.assert %316, "mismatched size for broadcast"
%317 = arith.cmpi eq, %262, %19 : index
cf.assert %317, "mismatched size for broadcast"
%318 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, 0, 0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%315, %23 : tensor<?x12x?x?xf32>, tensor<1x1x1x?xf32>) outs(%312 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.addf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<?x12x?x?xf32>
%319 = linalg.init_tensor [%261, 12, %262, 1] : tensor<?x12x?x1xi64>
%320 = linalg.fill(%c0_i64, %319) : i64, tensor<?x12x?x1xi64> -> tensor<?x12x?x1xi64>
%321 = linalg.init_tensor [%261, 12, %262, 1] : tensor<?x12x?x1xf32>
%322 = linalg.fill(%cst_115, %321) : f32, tensor<?x12x?x1xf32> -> tensor<?x12x?x1xf32>
%323:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%318 : tensor<?x12x?x?xf32>) outs(%322, %320 : tensor<?x12x?x1xf32>, tensor<?x12x?x1xi64>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: i64):
%564 = linalg.index 3 : index
%565 = arith.index_cast %564 : index to i64
%566 = arith.cmpf ogt, %arg1, %arg2 : f32
%567 = arith.select %566, %arg1, %arg2 : f32
%568 = arith.select %566, %565, %arg3 : i64
linalg.yield %567, %568 : f32, i64
} -> (tensor<?x12x?x1xf32>, tensor<?x12x?x1xi64>)
%324 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%318, %323#0 : tensor<?x12x?x?xf32>, tensor<?x12x?x1xf32>) outs(%312 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.subf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<?x12x?x?xf32>
%325 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%324 : tensor<?x12x?x?xf32>) outs(%312 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = math.exp %arg1 : f32
linalg.yield %564 : f32
} -> tensor<?x12x?x?xf32>
%326 = linalg.fill(%cst_116, %321) : f32, tensor<?x12x?x1xf32> -> tensor<?x12x?x1xf32>
%327 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%325 : tensor<?x12x?x?xf32>) outs(%326 : tensor<?x12x?x1xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.addf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<?x12x?x1xf32>
%328 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%325, %327 : tensor<?x12x?x?xf32>, tensor<?x12x?x1xf32>) outs(%312 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.divf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<?x12x?x?xf32>
%329 = linalg.fill(%cst_116, %301) : f32, tensor<?x12x?x32xf32> -> tensor<?x12x?x32xf32>
%330 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>], iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel"]} ins(%328, %307 : tensor<?x12x?x?xf32>, tensor<?x12x?x32xf32>) outs(%329 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.mulf %arg1, %arg2 : f32
%565 = arith.addf %564, %arg3 : f32
linalg.yield %565 : f32
} -> tensor<?x12x?x32xf32>
%331 = linalg.init_tensor [%261, %262, 12, 32] : tensor<?x?x12x32xf32>
%332 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%330 : tensor<?x12x?x32xf32>) outs(%331 : tensor<?x?x12x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x12x32xf32>
%333 = tensor.cast %332 : tensor<?x?x12x32xf32> to tensor<?x?x?x?xf32>
%334 = tensor.collapse_shape %333 [[0], [1], [2, 3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
%335 = tensor.cast %334 : tensor<?x?x?xf32> to tensor<?x?x384xf32>
// Broadcast the 1-D tensor %cst_47 along d2 into a ?x?x384 tensor; it is
// used below as the initial accumulator of the batch_matmul.
%336 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_47 : tensor<384xf32>) outs(%265 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
// Input map (d2, d1): %337[d0, d1, d2] = %cst_46[d2, d1], i.e. the 384x384
// matrix is transposed and replicated along d0.
%337 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_46 : tensor<384x384xf32>) outs(%293 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
// batch_matmul accumulates into its outs operand, so the result is
// %335 x %337 added onto the broadcast %cst_47 values in %336.
%338 = linalg.batch_matmul ins(%335, %337 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%336 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
// Runtime checks that the two leading extents of %334 equal %261 and %262
// before the elementwise add below.
%339 = tensor.dim %334, %c0_0 : tensor<?x?x?xf32>
%340 = tensor.dim %334, %c1 : tensor<?x?x?xf32>
%341 = arith.cmpi eq, %339, %261 : index
cf.assert %341, "mismatched size for broadcast"
%342 = arith.cmpi eq, %340, %262 : index
cf.assert %342, "mismatched size for broadcast"
// Elementwise sum of the projection result %338 and %292 (defined outside
// this chunk). The outs block argument %arg3 is not read.
%343 = linalg.init_tensor [%339, %340, 384] : tensor<?x?x384xf32>
%344 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%338, %292 : tensor<?x?x384xf32>, tensor<?x?x384xf32>) outs(%343 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.addf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<?x?x384xf32>
// Normalization of %344 along its last dimension (d2), built in five steps.
// NOTE(review): %cst, %cst_111 and %cst_116 are defined outside this chunk.
// Presumably %cst is the reduced extent (384.0), %cst_111 a small epsilon and
// %cst_116 is 0.0 -- confirm against the constant definitions.
//
// Step 1: sum %344 over d2 into a ?x? accumulator initialised with %cst_116.
%345 = linalg.init_tensor [%339, %340] : tensor<?x?xf32>
%346 = linalg.fill(%cst_116, %345) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%347 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%344 : tensor<?x?x384xf32>) outs(%346 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.addf %arg2, %arg1 : f32
linalg.yield %564 : f32
} -> tensor<?x?xf32>
// Step 2: divide each sum by %cst.
%348 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%347 : tensor<?x?xf32>) outs(%345 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.divf %arg1, %cst : f32
linalg.yield %564 : f32
} -> tensor<?x?xf32>
// Step 3: sum over d2 of (%344 - %348)^2, with %348 broadcast along d2.
%349 = linalg.fill(%cst_116, %345) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%350 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%344, %348 : tensor<?x?x384xf32>, tensor<?x?xf32>) outs(%349 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.subf %arg1, %arg2 : f32
%565 = arith.mulf %564, %564 : f32
%566 = arith.addf %arg3, %565 : f32
linalg.yield %566 : f32
} -> tensor<?x?xf32>
// Step 4: divide the squared-deviation sums by %cst.
%351 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%350 : tensor<?x?xf32>) outs(%345 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.divf %arg1, %cst : f32
linalg.yield %564 : f32
} -> tensor<?x?xf32>
// Step 5: (%344 - %348) * rsqrt(%351 + trunc(%cst_111)) * %cst_44 + %cst_45,
// where %cst_44 and %cst_45 are indexed by d2 only.
%352 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%344, %348, %351, %cst_44, %cst_45 : tensor<?x?x384xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<384xf32>, tensor<384xf32>) outs(%343 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%564 = arith.subf %arg1, %arg2 : f32
%565 = arith.truncf %cst_111 : f64 to f32
%566 = arith.addf %arg3, %565 : f32
%567 = math.rsqrt %566 : f32
%568 = arith.mulf %564, %567 : f32
%569 = arith.mulf %568, %arg4 : f32
%570 = arith.addf %569, %arg5 : f32
linalg.yield %570 : f32
} -> tensor<?x?x384xf32>
// Two chained projections (384 -> 1536 -> 384) with an elementwise op in
// between, followed by an elementwise add with %352.
//
// First projection: %355 broadcasts %cst_43 along d2 as the accumulator,
// %356[d0, d1, d2] = %cst_42[d2, d1] (transposed, replicated along d0).
%353 = linalg.init_tensor [%339, %340, 1536] : tensor<?x?x1536xf32>
%354 = linalg.init_tensor [%339, 384, 1536] : tensor<?x384x1536xf32>
%355 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_43 : tensor<1536xf32>) outs(%353 : tensor<?x?x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x1536xf32>
%356 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_42 : tensor<1536x384xf32>) outs(%354 : tensor<?x384x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x1536xf32>
%357 = linalg.batch_matmul ins(%352, %356 : tensor<?x?x384xf32>, tensor<?x384x1536xf32>) outs(%355 : tensor<?x?x1536xf32>) -> tensor<?x?x1536xf32>
// Elementwise: x * ((erf(x / sqrt(%cst_113)) + %cst_112) * %cst_114).
// NOTE(review): the constants are defined outside this chunk; if they are
// 2.0, 1.0 and 0.5 this is the erf form of GELU -- confirm.
%358 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%357 : tensor<?x?x1536xf32>) outs(%353 : tensor<?x?x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = math.sqrt %cst_113 : f32
%565 = arith.divf %arg1, %564 : f32
%566 = math.erf %565 : f32
%567 = arith.addf %566, %cst_112 : f32
%568 = arith.mulf %567, %cst_114 : f32
%569 = arith.mulf %arg1, %568 : f32
linalg.yield %569 : f32
} -> tensor<?x?x1536xf32>
// Second projection back to 384: accumulator %360 broadcasts %cst_41,
// %361[d0, d1, d2] = %cst_40[d2, d1].
%359 = linalg.init_tensor [%339, 1536, 384] : tensor<?x1536x384xf32>
%360 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_41 : tensor<384xf32>) outs(%343 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%361 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_40 : tensor<384x1536xf32>) outs(%359 : tensor<?x1536x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x1536x384xf32>
%362 = linalg.batch_matmul ins(%358, %361 : tensor<?x?x1536xf32>, tensor<?x1536x384xf32>) outs(%360 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
// Elementwise sum of the second projection %362 and the block input %352.
%363 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%362, %352 : tensor<?x?x384xf32>, tensor<?x?x384xf32>) outs(%343 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.addf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<?x?x384xf32>
// Normalization of %363 along its last dimension (d2).
// NOTE(review): %cst, %cst_111 and %cst_116 are defined outside this chunk.
// Presumably %cst is the reduced extent (384.0), %cst_111 a small epsilon and
// %cst_116 is 0.0 -- confirm against the constant definitions.
//
// Sum %363 over d2.
%364 = linalg.fill(%cst_116, %345) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%365 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%363 : tensor<?x?x384xf32>) outs(%364 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.addf %arg2, %arg1 : f32
linalg.yield %564 : f32
} -> tensor<?x?xf32>
// Divide the sums by %cst.
%366 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%365 : tensor<?x?xf32>) outs(%345 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.divf %arg1, %cst : f32
linalg.yield %564 : f32
} -> tensor<?x?xf32>
// Sum over d2 of (%363 - %366)^2.
%367 = linalg.fill(%cst_116, %345) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%368 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%363, %366 : tensor<?x?x384xf32>, tensor<?x?xf32>) outs(%367 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.subf %arg1, %arg2 : f32
%565 = arith.mulf %564, %564 : f32
%566 = arith.addf %arg3, %565 : f32
linalg.yield %566 : f32
} -> tensor<?x?xf32>
// Divide the squared-deviation sums by %cst.
%369 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%368 : tensor<?x?xf32>) outs(%345 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.divf %arg1, %cst : f32
linalg.yield %564 : f32
} -> tensor<?x?xf32>
// (%363 - %366) * rsqrt(%369 + trunc(%cst_111)) * %cst_38 + %cst_39, where
// %cst_38 and %cst_39 are indexed by d2 only.
%370 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%363, %366, %369, %cst_38, %cst_39 : tensor<?x?x384xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<384xf32>, tensor<384xf32>) outs(%343 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%564 = arith.subf %arg1, %arg2 : f32
%565 = arith.truncf %cst_111 : f64 to f32
%566 = arith.addf %arg3, %565 : f32
%567 = math.rsqrt %566 : f32
%568 = arith.mulf %564, %567 : f32
%569 = arith.mulf %568, %arg4 : f32
%570 = arith.addf %569, %arg5 : f32
linalg.yield %570 : f32
} -> tensor<?x?x384xf32>
// Three independent 384x384 projections of %370. Each one broadcasts a
// 384-element tensor as the batch_matmul accumulator and uses a transposed,
// d0-replicated copy of a 384x384 matrix as the right-hand operand.
// Further down, %387 (from %374) is the left operand of the score product,
// %380 (from %377) is transposed into its right operand, and %385 (from
// %383) is multiplied by the normalized scores.
%371 = linalg.init_tensor [%339, 384, 384] : tensor<?x384x384xf32>
// Projection 1: %cst_37 (accumulator values) and %cst_36 (matrix).
%372 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_37 : tensor<384xf32>) outs(%343 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%373 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_36 : tensor<384x384xf32>) outs(%371 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%374 = linalg.batch_matmul ins(%370, %373 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%372 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
// Projection 2: %cst_35 (accumulator values) and %cst_34 (matrix).
%375 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_35 : tensor<384xf32>) outs(%343 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%376 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_34 : tensor<384x384xf32>) outs(%371 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%377 = linalg.batch_matmul ins(%370, %376 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%375 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
// Split the last dimension 384 into 12x32, then swap dims 1 and 2 so the
// result is ?x12x?x32.
%378 = tensor.expand_shape %377 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%379 = linalg.init_tensor [%339, 12, %340, 32] : tensor<?x12x?x32xf32>
%380 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%378 : tensor<?x?x12x32xf32>) outs(%379 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
// Projection 3: %cst_33 (accumulator values) and %cst_32 (matrix), followed
// by the same 12x32 split and dim swap.
%381 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_33 : tensor<384xf32>) outs(%343 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%382 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_32 : tensor<384x384xf32>) outs(%371 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%383 = linalg.batch_matmul ins(%370, %382 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%381 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%384 = tensor.expand_shape %383 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%385 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%384 : tensor<?x?x12x32xf32>) outs(%379 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
// Same 12x32 split and dim swap applied to projection 1 (%374).
%386 = tensor.expand_shape %374 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%387 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%386 : tensor<?x?x12x32xf32>) outs(%379 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
// Swap the last two dims of %380 (output map (d0, d1, d3, d2)), giving a
// ?x12x32x? tensor to use as the right operand of the product below.
%388 = linalg.init_tensor [%339, 12, 32, %340] : tensor<?x12x32x?xf32>
%389 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d2)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%380 : tensor<?x12x?x32xf32>) outs(%388 : tensor<?x12x32x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x32x?xf32>
// Batched product over (d0, d1) with d3 (extent 32) reduced:
// %392[d0, d1, d2, d4] = sum over d3 of %387[d0, d1, d2, d3] * %389[d0, d1, d3, d4].
// The accumulator is filled with %cst_116 (presumably 0.0 -- confirm).
%390 = linalg.init_tensor [%339, 12, %340, %340] : tensor<?x12x?x?xf32>
%391 = linalg.fill(%cst_116, %390) : f32, tensor<?x12x?x?xf32> -> tensor<?x12x?x?xf32>
%392 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>], iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel"]} ins(%387, %389 : tensor<?x12x?x32xf32>, tensor<?x12x32x?xf32>) outs(%391 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.mulf %arg1, %arg2 : f32
%565 = arith.addf %564, %arg3 : f32
linalg.yield %565 : f32
} -> tensor<?x12x?x?xf32>
// Divide every element by the rank-0 f64 tensor %cst_21, truncated to f32.
// NOTE(review): the value of %cst_21 is not visible here; presumably the
// score scaling factor -- confirm.
%393 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> ()>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%392, %cst_21 : tensor<?x12x?x?xf32>, tensor<f64>) outs(%390 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f64, %arg3: f32):
%564 = arith.truncf %arg2 : f64 to f32
%565 = arith.divf %arg1, %564 : f32
linalg.yield %565 : f32
} -> tensor<?x12x?x?xf32>
// Runtime checks before adding %23: %339 must equal %c1 (an index constant
// defined outside this chunk, presumably 1 -- confirm) and %340 must equal %19.
%394 = arith.cmpi eq, %339, %c1 : index
cf.assert %394, "mismatched size for broadcast"
%395 = arith.cmpi eq, %340, %19 : index
cf.assert %395, "mismatched size for broadcast"
// Add %23 (1x1x1x?) to %393; the input map (d0, 0, 0, d3) reuses the same
// %23 value for every d1 and d2.
%396 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, 0, 0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%393, %23 : tensor<?x12x?x?xf32>, tensor<1x1x1x?xf32>) outs(%390 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.addf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<?x12x?x?xf32>
// Normalization of %396 along its last dimension (d3) in the form
// exp(x - max) / sum(exp(x - max)).
//
// Accumulators for the reduction below: an i64 tensor filled with %c0_i64
// and an f32 tensor filled with %cst_115 (defined outside this chunk;
// presumably the lowest f32 value so any element compares greater -- confirm).
%397 = linalg.init_tensor [%339, 12, %340, 1] : tensor<?x12x?x1xi64>
%398 = linalg.fill(%c0_i64, %397) : i64, tensor<?x12x?x1xi64> -> tensor<?x12x?x1xi64>
%399 = linalg.init_tensor [%339, 12, %340, 1] : tensor<?x12x?x1xf32>
%400 = linalg.fill(%cst_115, %399) : f32, tensor<?x12x?x1xf32> -> tensor<?x12x?x1xf32>
// Reduction over d3 that keeps the larger of the current element and the
// running value (result #0) together with the d3 index where it was found
// (result #1). Only %401#0 is consumed in the ops that follow in this chunk.
%401:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%396 : tensor<?x12x?x?xf32>) outs(%400, %398 : tensor<?x12x?x1xf32>, tensor<?x12x?x1xi64>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: i64):
%564 = linalg.index 3 : index
%565 = arith.index_cast %564 : index to i64
%566 = arith.cmpf ogt, %arg1, %arg2 : f32
%567 = arith.select %566, %arg1, %arg2 : f32
%568 = arith.select %566, %565, %arg3 : i64
linalg.yield %567, %568 : f32, i64
} -> (tensor<?x12x?x1xf32>, tensor<?x12x?x1xi64>)
// Subtract the per-row maximum %401#0, broadcast along d3 via (d0, d1, d2, 0).
%402 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%396, %401#0 : tensor<?x12x?x?xf32>, tensor<?x12x?x1xf32>) outs(%390 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.subf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<?x12x?x?xf32>
// Elementwise exponential of the shifted values.
%403 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%402 : tensor<?x12x?x?xf32>) outs(%390 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = math.exp %arg1 : f32
linalg.yield %564 : f32
} -> tensor<?x12x?x?xf32>
// Sum the exponentials over d3 into an accumulator filled with %cst_116
// (presumably 0.0 -- confirm).
%404 = linalg.fill(%cst_116, %399) : f32, tensor<?x12x?x1xf32> -> tensor<?x12x?x1xf32>
%405 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%403 : tensor<?x12x?x?xf32>) outs(%404 : tensor<?x12x?x1xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.addf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<?x12x?x1xf32>
// Divide each exponential by its row sum %405.
%406 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%403, %405 : tensor<?x12x?x?xf32>, tensor<?x12x?x1xf32>) outs(%390 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.divf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<?x12x?x?xf32>
// Batched product over (d0, d1) with d3 reduced:
// %408[d0, d1, d2, d4] = sum over d3 of %406[d0, d1, d2, d3] * %385[d0, d1, d3, d4].
// The accumulator %407 is filled with %cst_116 (presumably 0.0 -- confirm).
%407 = linalg.fill(%cst_116, %379) : f32, tensor<?x12x?x32xf32> -> tensor<?x12x?x32xf32>
%408 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>], iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel"]} ins(%406, %385 : tensor<?x12x?x?xf32>, tensor<?x12x?x32xf32>) outs(%407 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.mulf %arg1, %arg2 : f32
%565 = arith.addf %564, %arg3 : f32
linalg.yield %565 : f32
} -> tensor<?x12x?x32xf32>
// Swap dims 1 and 2 (?x12x?x32 -> ?x?x12x32).
%409 = linalg.init_tensor [%339, %340, 12, 32] : tensor<?x?x12x32xf32>
%410 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%408 : tensor<?x12x?x32xf32>) outs(%409 : tensor<?x?x12x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x12x32xf32>
// Merge the trailing 12x32 dims into one: cast to fully dynamic, collapse
// dims [2, 3], then cast back to the static extent 384.
%411 = tensor.cast %410 : tensor<?x?x12x32xf32> to tensor<?x?x?x?xf32>
%412 = tensor.collapse_shape %411 [[0], [1], [2, 3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
%413 = tensor.cast %412 : tensor<?x?x?xf32> to tensor<?x?x384xf32>
// Broadcast %cst_31 along d2 into a ?x?x384 tensor used as the batch_matmul
// accumulator below.
%414 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_31 : tensor<384xf32>) outs(%343 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
// %415[d0, d1, d2] = %cst_30[d2, d1]: transposed matrix replicated along d0.
%415 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_30 : tensor<384x384xf32>) outs(%371 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
// batch_matmul accumulates into its outs operand: %413 x %415 added onto
// the broadcast %cst_31 values.
%416 = linalg.batch_matmul ins(%413, %415 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%414 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
// Runtime checks that the two leading extents of %412 equal %339 and %340.
%417 = tensor.dim %412, %c0_0 : tensor<?x?x?xf32>
%418 = tensor.dim %412, %c1 : tensor<?x?x?xf32>
%419 = arith.cmpi eq, %417, %339 : index
cf.assert %419, "mismatched size for broadcast"
%420 = arith.cmpi eq, %418, %340 : index
cf.assert %420, "mismatched size for broadcast"
// Elementwise sum of the projection result %416 and %370. The outs block
// argument %arg3 is not read.
%421 = linalg.init_tensor [%417, %418, 384] : tensor<?x?x384xf32>
%422 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%416, %370 : tensor<?x?x384xf32>, tensor<?x?x384xf32>) outs(%421 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.addf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<?x?x384xf32>
// Normalization of %422 along its last dimension (d2).
// NOTE(review): %cst, %cst_111 and %cst_116 are defined outside this chunk.
// Presumably %cst is the reduced extent (384.0), %cst_111 a small epsilon and
// %cst_116 is 0.0 -- confirm against the constant definitions.
//
// Sum %422 over d2.
%423 = linalg.init_tensor [%417, %418] : tensor<?x?xf32>
%424 = linalg.fill(%cst_116, %423) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%425 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%422 : tensor<?x?x384xf32>) outs(%424 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.addf %arg2, %arg1 : f32
linalg.yield %564 : f32
} -> tensor<?x?xf32>
// Divide the sums by %cst.
%426 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%425 : tensor<?x?xf32>) outs(%423 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.divf %arg1, %cst : f32
linalg.yield %564 : f32
} -> tensor<?x?xf32>
// Sum over d2 of (%422 - %426)^2.
%427 = linalg.fill(%cst_116, %423) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%428 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%422, %426 : tensor<?x?x384xf32>, tensor<?x?xf32>) outs(%427 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.subf %arg1, %arg2 : f32
%565 = arith.mulf %564, %564 : f32
%566 = arith.addf %arg3, %565 : f32
linalg.yield %566 : f32
} -> tensor<?x?xf32>
// Divide the squared-deviation sums by %cst.
%429 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%428 : tensor<?x?xf32>) outs(%423 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.divf %arg1, %cst : f32
linalg.yield %564 : f32
} -> tensor<?x?xf32>
// (%422 - %426) * rsqrt(%429 + trunc(%cst_111)) * %cst_28 + %cst_29, where
// %cst_28 and %cst_29 are indexed by d2 only.
%430 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%422, %426, %429, %cst_28, %cst_29 : tensor<?x?x384xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<384xf32>, tensor<384xf32>) outs(%421 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%564 = arith.subf %arg1, %arg2 : f32
%565 = arith.truncf %cst_111 : f64 to f32
%566 = arith.addf %arg3, %565 : f32
%567 = math.rsqrt %566 : f32
%568 = arith.mulf %564, %567 : f32
%569 = arith.mulf %568, %arg4 : f32
%570 = arith.addf %569, %arg5 : f32
linalg.yield %570 : f32
} -> tensor<?x?x384xf32>
// Two chained projections (384 -> 1536 -> 384) with an elementwise op in
// between, followed by an elementwise add with %430.
//
// First projection: %433 broadcasts %cst_27 along d2 as the accumulator,
// %434[d0, d1, d2] = %cst_26[d2, d1] (transposed, replicated along d0).
%431 = linalg.init_tensor [%417, %418, 1536] : tensor<?x?x1536xf32>
%432 = linalg.init_tensor [%417, 384, 1536] : tensor<?x384x1536xf32>
%433 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_27 : tensor<1536xf32>) outs(%431 : tensor<?x?x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x1536xf32>
%434 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_26 : tensor<1536x384xf32>) outs(%432 : tensor<?x384x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x1536xf32>
%435 = linalg.batch_matmul ins(%430, %434 : tensor<?x?x384xf32>, tensor<?x384x1536xf32>) outs(%433 : tensor<?x?x1536xf32>) -> tensor<?x?x1536xf32>
// Elementwise: x * ((erf(x / sqrt(%cst_113)) + %cst_112) * %cst_114).
// NOTE(review): the constants are defined outside this chunk; if they are
// 2.0, 1.0 and 0.5 this is the erf form of GELU -- confirm.
%436 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%435 : tensor<?x?x1536xf32>) outs(%431 : tensor<?x?x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = math.sqrt %cst_113 : f32
%565 = arith.divf %arg1, %564 : f32
%566 = math.erf %565 : f32
%567 = arith.addf %566, %cst_112 : f32
%568 = arith.mulf %567, %cst_114 : f32
%569 = arith.mulf %arg1, %568 : f32
linalg.yield %569 : f32
} -> tensor<?x?x1536xf32>
// Second projection back to 384: accumulator %438 broadcasts %cst_25,
// %439[d0, d1, d2] = %cst_24[d2, d1].
%437 = linalg.init_tensor [%417, 1536, 384] : tensor<?x1536x384xf32>
%438 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_25 : tensor<384xf32>) outs(%421 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%439 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_24 : tensor<384x1536xf32>) outs(%437 : tensor<?x1536x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x1536x384xf32>
%440 = linalg.batch_matmul ins(%436, %439 : tensor<?x?x1536xf32>, tensor<?x1536x384xf32>) outs(%438 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
// Elementwise sum of the second projection %440 and the block input %430.
%441 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%440, %430 : tensor<?x?x384xf32>, tensor<?x?x384xf32>) outs(%421 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.addf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<?x?x384xf32>
// Normalization of %441 along its last dimension (d2).
// NOTE(review): %cst, %cst_111 and %cst_116 are defined outside this chunk.
// Presumably %cst is the reduced extent (384.0), %cst_111 a small epsilon and
// %cst_116 is 0.0 -- confirm against the constant definitions.
//
// Sum %441 over d2.
%442 = linalg.fill(%cst_116, %423) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%443 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%441 : tensor<?x?x384xf32>) outs(%442 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.addf %arg2, %arg1 : f32
linalg.yield %564 : f32
} -> tensor<?x?xf32>
// Divide the sums by %cst.
%444 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%443 : tensor<?x?xf32>) outs(%423 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.divf %arg1, %cst : f32
linalg.yield %564 : f32
} -> tensor<?x?xf32>
// Sum over d2 of (%441 - %444)^2.
%445 = linalg.fill(%cst_116, %423) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%446 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%441, %444 : tensor<?x?x384xf32>, tensor<?x?xf32>) outs(%445 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.subf %arg1, %arg2 : f32
%565 = arith.mulf %564, %564 : f32
%566 = arith.addf %arg3, %565 : f32
linalg.yield %566 : f32
} -> tensor<?x?xf32>
// Divide the squared-deviation sums by %cst.
%447 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%446 : tensor<?x?xf32>) outs(%423 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.divf %arg1, %cst : f32
linalg.yield %564 : f32
} -> tensor<?x?xf32>
// (%441 - %444) * rsqrt(%447 + trunc(%cst_111)) * %cst_22 + %cst_23, where
// %cst_22 and %cst_23 are indexed by d2 only.
%448 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%441, %444, %447, %cst_22, %cst_23 : tensor<?x?x384xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<384xf32>, tensor<384xf32>) outs(%421 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%564 = arith.subf %arg1, %arg2 : f32
%565 = arith.truncf %cst_111 : f64 to f32
%566 = arith.addf %arg3, %565 : f32
%567 = math.rsqrt %566 : f32
%568 = arith.mulf %564, %567 : f32
%569 = arith.mulf %568, %arg4 : f32
%570 = arith.addf %569, %arg5 : f32
linalg.yield %570 : f32
} -> tensor<?x?x384xf32>
// Three independent 384x384 projections of %448. Each one broadcasts a
// 384-element tensor as the batch_matmul accumulator and uses a transposed,
// d0-replicated copy of a 384x384 matrix as the right-hand operand.
// Further down, %465 (from %452) is the left operand of the score product
// and %458 (from %455) is transposed into its right operand; the use of
// %463 (from %461) lies beyond the end of this chunk.
%449 = linalg.init_tensor [%417, 384, 384] : tensor<?x384x384xf32>
// Projection 1: %cst_20 (accumulator values) and %cst_19 (matrix).
%450 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_20 : tensor<384xf32>) outs(%421 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%451 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_19 : tensor<384x384xf32>) outs(%449 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%452 = linalg.batch_matmul ins(%448, %451 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%450 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
// Projection 2: %cst_18 (accumulator values) and %cst_17 (matrix).
%453 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_18 : tensor<384xf32>) outs(%421 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%454 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_17 : tensor<384x384xf32>) outs(%449 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%455 = linalg.batch_matmul ins(%448, %454 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%453 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
// Split the last dimension 384 into 12x32, then swap dims 1 and 2 so the
// result is ?x12x?x32.
%456 = tensor.expand_shape %455 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%457 = linalg.init_tensor [%417, 12, %418, 32] : tensor<?x12x?x32xf32>
%458 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%456 : tensor<?x?x12x32xf32>) outs(%457 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
// Projection 3: %cst_16 (accumulator values) and %cst_15 (matrix), followed
// by the same 12x32 split and dim swap.
%459 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_16 : tensor<384xf32>) outs(%421 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%460 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_15 : tensor<384x384xf32>) outs(%449 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%461 = linalg.batch_matmul ins(%448, %460 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%459 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%462 = tensor.expand_shape %461 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%463 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%462 : tensor<?x?x12x32xf32>) outs(%457 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
// Same 12x32 split and dim swap applied to projection 1 (%452).
%464 = tensor.expand_shape %452 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%465 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%464 : tensor<?x?x12x32xf32>) outs(%457 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
%466 = linalg.init_tensor [%417, 12, 32, %418] : tensor<?x12x32x?xf32>
%467 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d2)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%458 : tensor<?x12x?x32xf32>) outs(%466 : tensor<?x12x32x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x32x?xf32>
%468 = linalg.init_tensor [%417, 12, %418, %418] : tensor<?x12x?x?xf32>
%469 = linalg.fill(%cst_116, %468) : f32, tensor<?x12x?x?xf32> -> tensor<?x12x?x?xf32>
%470 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>], iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel"]} ins(%465, %467 : tensor<?x12x?x32xf32>, tensor<?x12x32x?xf32>) outs(%469 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.mulf %arg1, %arg2 : f32
%565 = arith.addf %564, %arg3 : f32
linalg.yield %565 : f32
} -> tensor<?x12x?x?xf32>
%471 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> ()>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%470, %cst_21 : tensor<?x12x?x?xf32>, tensor<f64>) outs(%468 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f64, %arg3: f32):
%564 = arith.truncf %arg2 : f64 to f32
%565 = arith.divf %arg1, %564 : f32
linalg.yield %565 : f32
} -> tensor<?x12x?x?xf32>
%472 = arith.cmpi eq, %417, %c1 : index
cf.assert %472, "mismatched size for broadcast"
%473 = arith.cmpi eq, %418, %19 : index
cf.assert %473, "mismatched size for broadcast"
%474 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, 0, 0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%471, %23 : tensor<?x12x?x?xf32>, tensor<1x1x1x?xf32>) outs(%468 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.addf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<?x12x?x?xf32>
%475 = linalg.init_tensor [%417, 12, %418, 1] : tensor<?x12x?x1xi64>
%476 = linalg.fill(%c0_i64, %475) : i64, tensor<?x12x?x1xi64> -> tensor<?x12x?x1xi64>
%477 = linalg.init_tensor [%417, 12, %418, 1] : tensor<?x12x?x1xf32>
%478 = linalg.fill(%cst_115, %477) : f32, tensor<?x12x?x1xf32> -> tensor<?x12x?x1xf32>
%479:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%474 : tensor<?x12x?x?xf32>) outs(%478, %476 : tensor<?x12x?x1xf32>, tensor<?x12x?x1xi64>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: i64):
%564 = linalg.index 3 : index
%565 = arith.index_cast %564 : index to i64
%566 = arith.cmpf ogt, %arg1, %arg2 : f32
%567 = arith.select %566, %arg1, %arg2 : f32
%568 = arith.select %566, %565, %arg3 : i64
linalg.yield %567, %568 : f32, i64
} -> (tensor<?x12x?x1xf32>, tensor<?x12x?x1xi64>)
%480 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%474, %479#0 : tensor<?x12x?x?xf32>, tensor<?x12x?x1xf32>) outs(%468 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.subf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<?x12x?x?xf32>
%481 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%480 : tensor<?x12x?x?xf32>) outs(%468 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = math.exp %arg1 : f32
linalg.yield %564 : f32
} -> tensor<?x12x?x?xf32>
%482 = linalg.fill(%cst_116, %477) : f32, tensor<?x12x?x1xf32> -> tensor<?x12x?x1xf32>
%483 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%481 : tensor<?x12x?x?xf32>) outs(%482 : tensor<?x12x?x1xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.addf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<?x12x?x1xf32>
%484 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%481, %483 : tensor<?x12x?x?xf32>, tensor<?x12x?x1xf32>) outs(%468 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.divf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<?x12x?x?xf32>
%485 = linalg.fill(%cst_116, %457) : f32, tensor<?x12x?x32xf32> -> tensor<?x12x?x32xf32>
%486 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>], iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel"]} ins(%484, %463 : tensor<?x12x?x?xf32>, tensor<?x12x?x32xf32>) outs(%485 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.mulf %arg1, %arg2 : f32
%565 = arith.addf %564, %arg3 : f32
linalg.yield %565 : f32
} -> tensor<?x12x?x32xf32>
%487 = linalg.init_tensor [%417, %418, 12, 32] : tensor<?x?x12x32xf32>
%488 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%486 : tensor<?x12x?x32xf32>) outs(%487 : tensor<?x?x12x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x12x32xf32>
%489 = tensor.cast %488 : tensor<?x?x12x32xf32> to tensor<?x?x?x?xf32>
%490 = tensor.collapse_shape %489 [[0], [1], [2, 3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
%491 = tensor.cast %490 : tensor<?x?x?xf32> to tensor<?x?x384xf32>
%492 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_14 : tensor<384xf32>) outs(%421 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%493 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_13 : tensor<384x384xf32>) outs(%449 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%494 = linalg.batch_matmul ins(%491, %493 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%492 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%495 = tensor.dim %490, %c0_0 : tensor<?x?x?xf32>
%496 = tensor.dim %490, %c1 : tensor<?x?x?xf32>
%497 = arith.cmpi eq, %495, %417 : index
cf.assert %497, "mismatched size for broadcast"
%498 = arith.cmpi eq, %496, %418 : index
cf.assert %498, "mismatched size for broadcast"
%499 = linalg.init_tensor [%495, %496, 384] : tensor<?x?x384xf32>
%500 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%494, %448 : tensor<?x?x384xf32>, tensor<?x?x384xf32>) outs(%499 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.addf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<?x?x384xf32>
%501 = linalg.init_tensor [%495, %496] : tensor<?x?xf32>
%502 = linalg.fill(%cst_116, %501) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%503 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%500 : tensor<?x?x384xf32>) outs(%502 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.addf %arg2, %arg1 : f32
linalg.yield %564 : f32
} -> tensor<?x?xf32>
%504 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%503 : tensor<?x?xf32>) outs(%501 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.divf %arg1, %cst : f32
linalg.yield %564 : f32
} -> tensor<?x?xf32>
%505 = linalg.fill(%cst_116, %501) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%506 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%500, %504 : tensor<?x?x384xf32>, tensor<?x?xf32>) outs(%505 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.subf %arg1, %arg2 : f32
%565 = arith.mulf %564, %564 : f32
%566 = arith.addf %arg3, %565 : f32
linalg.yield %566 : f32
} -> tensor<?x?xf32>
%507 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%506 : tensor<?x?xf32>) outs(%501 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.divf %arg1, %cst : f32
linalg.yield %564 : f32
} -> tensor<?x?xf32>
%508 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%500, %504, %507, %cst_11, %cst_12 : tensor<?x?x384xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<384xf32>, tensor<384xf32>) outs(%499 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%564 = arith.subf %arg1, %arg2 : f32
%565 = arith.truncf %cst_111 : f64 to f32
%566 = arith.addf %arg3, %565 : f32
%567 = math.rsqrt %566 : f32
%568 = arith.mulf %564, %567 : f32
%569 = arith.mulf %568, %arg4 : f32
%570 = arith.addf %569, %arg5 : f32
linalg.yield %570 : f32
} -> tensor<?x?x384xf32>
%509 = linalg.init_tensor [%495, %496, 1536] : tensor<?x?x1536xf32>
%510 = linalg.init_tensor [%495, 384, 1536] : tensor<?x384x1536xf32>
%511 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_10 : tensor<1536xf32>) outs(%509 : tensor<?x?x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x1536xf32>
%512 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_9 : tensor<1536x384xf32>) outs(%510 : tensor<?x384x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x1536xf32>
%513 = linalg.batch_matmul ins(%508, %512 : tensor<?x?x384xf32>, tensor<?x384x1536xf32>) outs(%511 : tensor<?x?x1536xf32>) -> tensor<?x?x1536xf32>
%514 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%513 : tensor<?x?x1536xf32>) outs(%509 : tensor<?x?x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = math.sqrt %cst_113 : f32
%565 = arith.divf %arg1, %564 : f32
%566 = math.erf %565 : f32
%567 = arith.addf %566, %cst_112 : f32
%568 = arith.mulf %567, %cst_114 : f32
%569 = arith.mulf %arg1, %568 : f32
linalg.yield %569 : f32
} -> tensor<?x?x1536xf32>
%515 = linalg.init_tensor [%495, 1536, 384] : tensor<?x1536x384xf32>
%516 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_8 : tensor<384xf32>) outs(%499 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%517 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_7 : tensor<384x1536xf32>) outs(%515 : tensor<?x1536x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x1536x384xf32>
%518 = linalg.batch_matmul ins(%514, %517 : tensor<?x?x1536xf32>, tensor<?x1536x384xf32>) outs(%516 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%519 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%518, %508 : tensor<?x?x384xf32>, tensor<?x?x384xf32>) outs(%499 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.addf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<?x?x384xf32>
%520 = linalg.fill(%cst_116, %501) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%521 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%519 : tensor<?x?x384xf32>) outs(%520 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.addf %arg2, %arg1 : f32
linalg.yield %564 : f32
} -> tensor<?x?xf32>
%522 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%521 : tensor<?x?xf32>) outs(%501 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.divf %arg1, %cst : f32
linalg.yield %564 : f32
} -> tensor<?x?xf32>
%523 = linalg.fill(%cst_116, %501) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%524 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%519, %522 : tensor<?x?x384xf32>, tensor<?x?xf32>) outs(%523 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.subf %arg1, %arg2 : f32
%565 = arith.mulf %564, %564 : f32
%566 = arith.addf %arg3, %565 : f32
linalg.yield %566 : f32
} -> tensor<?x?xf32>
%525 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%524 : tensor<?x?xf32>) outs(%501 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.divf %arg1, %cst : f32
linalg.yield %564 : f32
} -> tensor<?x?xf32>
%526 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%519, %522, %525, %cst_5, %cst_6 : tensor<?x?x384xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<384xf32>, tensor<384xf32>) outs(%499 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%564 = arith.subf %arg1, %arg2 : f32
%565 = arith.truncf %cst_111 : f64 to f32
%566 = arith.addf %arg3, %565 : f32
%567 = math.rsqrt %566 : f32
%568 = arith.mulf %564, %567 : f32
%569 = arith.mulf %568, %arg4 : f32
%570 = arith.addf %569, %arg5 : f32
linalg.yield %570 : f32
} -> tensor<?x?x384xf32>
%527 = arith.index_cast %495 : index to i64
%528 = arith.cmpi sgt, %c0_i64, %527 : i64
%529 = arith.select %528, %527, %c0_i64 : i64
%530 = arith.index_cast %529 : i64 to index
%531 = arith.cmpi sgt, %c9223372036854775807_i64, %527 : i64
%532 = arith.select %531, %527, %c9223372036854775807_i64 : i64
%533 = arith.index_cast %532 : i64 to index
%534 = arith.cmpi sge, %533, %530 : index
%535 = arith.select %534, %533, %530 : index
%536 = arith.subi %535, %530 : index
%537 = tensor.extract_slice %526[%530, 0, 0] [%536, %496, 384] [1, 1, 1] : tensor<?x?x384xf32> to tensor<?x?x384xf32>
%538 = arith.index_cast %496 : index to i64
%539 = arith.cmpi sgt, %c0_i64, %538 : i64
%540 = arith.select %539, %538, %c0_i64 : i64
%541 = arith.index_cast %540 : i64 to index
%542 = arith.cmpi sgt, %c1_i64, %538 : i64
%543 = arith.select %542, %538, %c1_i64 : i64
%544 = arith.index_cast %543 : i64 to index
%545 = arith.cmpi sge, %544, %541 : index
%546 = arith.select %545, %544, %541 : index
%547 = arith.subi %546, %541 : index
%548 = tensor.extract_slice %537[0, %541, 0] [%536, %547, 384] [1, 1, 1] : tensor<?x?x384xf32> to tensor<?x?x384xf32>
%549 = tensor.cast %548 : tensor<?x?x384xf32> to tensor<?x1x384xf32>
%550 = tensor.collapse_shape %549 [[0, 1], [2]] : tensor<?x1x384xf32> into tensor<?x384xf32>
%551 = linalg.init_tensor [%536, 384] : tensor<?x384xf32>
%552 = linalg.init_tensor [384, 384] : tensor<384x384xf32>
%553 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%cst_4 : tensor<384xf32>) outs(%551 : tensor<?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384xf32>
%554 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%cst_3 : tensor<384x384xf32>) outs(%552 : tensor<384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<384x384xf32>
%555 = linalg.matmul ins(%550, %554 : tensor<?x384xf32>, tensor<384x384xf32>) outs(%553 : tensor<?x384xf32>) -> tensor<?x384xf32>
%556 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%555 : tensor<?x384xf32>) outs(%551 : tensor<?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = math.tanh %arg1 : f32
linalg.yield %564 : f32
} -> tensor<?x384xf32>
%557 = linalg.init_tensor [%536, 2] : tensor<?x2xf32>
%558 = linalg.init_tensor [384, 2] : tensor<384x2xf32>
%559 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%cst_2 : tensor<2xf32>) outs(%557 : tensor<?x2xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x2xf32>
%560 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%cst_1 : tensor<2x384xf32>) outs(%558 : tensor<384x2xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<384x2xf32>
%561 = linalg.matmul ins(%556, %560 : tensor<?x384xf32>, tensor<384x2xf32>) outs(%559 : tensor<?x2xf32>) -> tensor<?x2xf32>
%562 = tensor.dim %561, %c0 : tensor<?x2xf32>
%563 = hal.tensor.export %561 : tensor<?x2xf32>{%562} -> !hal.buffer_view
return %563 : !hal.buffer_view
}
// -----// IR Dump Before Canonicalizer //----- //
func @forward(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
%c30522_i64 = arith.constant 30522 : i64
%c0_i64 = arith.constant 0 : i64
%cst = arith.constant 0.000000e+00 : f32
%cst_0 = arith.constant -3.40282347E+38 : f32
%cst_1 = arith.constant 5.000000e-01 : f32
%cst_2 = arith.constant 2.000000e+00 : f32
%cst_3 = arith.constant 1.000000e+00 : f32
%c9223372036854775807_i64 = arith.constant 9223372036854775807 : i64
%cst_4 = arith.constant 9.9999999999999998E-13 : f64
%c2_i64 = arith.constant 2 : i64
%cst_5 = arith.constant dense<-1.000000e+04> : tensor<f64>
%cst_6 = arith.constant dense<0> : tensor<1x512xi64>
%cst_7 = arith.constant dense<0> : tensor<i64>
%cst_8 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1x512xi64>
%cst_9 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<30522x384xf32>
%cst_10 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<2x384xf32>
%cst_11 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<512x384xf32>
%cst_12 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_13 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_14 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_15 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_16 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_17 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_18 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_19 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_20 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_21 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_22 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_23 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_24 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536xf32>
%cst_25 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536x384xf32>
%cst_26 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_27 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x1536xf32>
%cst_28 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_29 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_30 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_31 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_32 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_33 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_34 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_35 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_36 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_37 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_38 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_39 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_40 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536xf32>
%cst_41 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536x384xf32>
%cst_42 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_43 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x1536xf32>
%cst_44 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_45 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_46 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_47 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_48 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_49 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_50 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_51 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_52 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_53 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_54 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_55 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_56 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536xf32>
%cst_57 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536x384xf32>
%cst_58 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_59 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x1536xf32>
%cst_60 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_61 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_62 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_63 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_64 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_65 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_66 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_67 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_68 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_69 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_70 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_71 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_72 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536xf32>
%cst_73 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536x384xf32>
%cst_74 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_75 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x1536xf32>
%cst_76 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_77 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_78 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_79 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_80 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_81 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_82 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_83 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_84 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_85 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_86 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_87 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_88 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536xf32>
%cst_89 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536x384xf32>
%cst_90 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_91 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x1536xf32>
%cst_92 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_93 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_94 = arith.constant dense<5.6568542494923806> : tensor<f64>
%cst_95 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_96 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_97 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_98 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_99 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_100 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_101 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_102 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_103 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_104 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_105 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536xf32>
%cst_106 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536x384xf32>
%cst_107 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_108 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x1536xf32>
%cst_109 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_110 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_111 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_112 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_113 = arith.constant dense<[-0.00115577725, 0.00115577038]> : tensor<2xf32>
%cst_114 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<2x384xf32>
%c512_i64 = arith.constant 512 : i64
%c1_i64 = arith.constant 1 : i64
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c512 = arith.constant 512 : index
%cst_115 = arith.constant 3.840000e+02 : f32
%0 = hal.tensor.import %arg0 : !hal.buffer_view -> tensor<1x512xi64>
%1 = linalg.init_tensor [] : tensor<i64>
%2 = linalg.fill(%c512_i64, %1) : i64, tensor<i64> -> tensor<i64>
%3 = linalg.generic {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []} ins(%2, %cst_7 : tensor<i64>, tensor<i64>) outs(%1 : tensor<i64>) {
^bb0(%arg1: i64, %arg2: i64, %arg3: i64):
%564 = arith.addi %arg1, %arg2 : i64
linalg.yield %564 : i64
} -> tensor<i64>
%4 = tensor.extract %3[] : tensor<i64>
%5 = arith.index_cast %4 : i64 to index
%6 = linalg.init_tensor [1, %5] : tensor<1x?xf32>
%7 = linalg.fill(%cst_3, %6) : f32, tensor<1x?xf32> -> tensor<1x?xf32>
%8 = tensor.extract_slice %7[0, 0] [1, %5] [1, 1] : tensor<1x?xf32> to tensor<1x?xf32>
%9 = tensor.cast %8 : tensor<1x?xf32> to tensor<?x?xf32>
%10 = tensor.expand_shape %9 [[0], [1, 2, 3]] : tensor<?x?xf32> into tensor<?x1x1x?xf32>
%11 = arith.cmpi sgt, %c0_i64, %4 : i64
%12 = arith.select %11, %4, %c0_i64 : i64
%13 = arith.index_cast %12 : i64 to index
%14 = arith.cmpi sgt, %c9223372036854775807_i64, %4 : i64
%15 = arith.select %14, %4, %c9223372036854775807_i64 : i64
%16 = arith.index_cast %15 : i64 to index
%17 = arith.cmpi sge, %16, %13 : index
%18 = arith.select %17, %16, %13 : index
%19 = arith.subi %18, %13 : index
%20 = tensor.extract_slice %10[0, 0, 0, %13] [1, 1, 1, %19] [1, 1, 1, 1] : tensor<?x1x1x?xf32> to tensor<1x1x1x?xf32>
%21 = linalg.init_tensor [1, 1, 1, %19] : tensor<1x1x1x?xf32>
%22 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, 0, 0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%20 : tensor<1x1x1x?xf32>) outs(%21 : tensor<1x1x1x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.subf %cst_3, %arg1 : f32
linalg.yield %564 : f32
} -> tensor<1x1x1x?xf32>
%23 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, 0, 0, d3)>, affine_map<(d0, d1, d2, d3) -> ()>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%22, %cst_5 : tensor<1x1x1x?xf32>, tensor<f64>) outs(%21 : tensor<1x1x1x?xf32>) {
^bb0(%arg1: f32, %arg2: f64, %arg3: f32):
%564 = arith.truncf %arg2 : f64 to f32
%565 = arith.mulf %arg1, %564 : f32
linalg.yield %565 : f32
} -> tensor<1x1x1x?xf32>
%24 = linalg.fill(%c512_i64, %1) : i64, tensor<i64> -> tensor<i64>
%25 = linalg.generic {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []} ins(%24, %cst_7 : tensor<i64>, tensor<i64>) outs(%1 : tensor<i64>) {
^bb0(%arg1: i64, %arg2: i64, %arg3: i64):
%564 = arith.addi %arg1, %arg2 : i64
linalg.yield %564 : i64
} -> tensor<i64>
%26 = tensor.extract %25[] : tensor<i64>
%27 = arith.addi %26, %c512_i64 : i64
%28 = arith.cmpi sge, %26, %c0_i64 : i64
%29 = arith.select %28, %26, %27 : i64
%30 = arith.cmpi slt, %29, %c0_i64 : i64
%31 = arith.select %30, %c0_i64, %29 : i64
%32 = arith.cmpi sgt, %31, %c512_i64 : i64
%33 = arith.select %32, %c512_i64, %31 : i64
%34 = arith.index_cast %33 : i64 to index
%35 = arith.cmpi sge, %34, %c0 : index
%36 = arith.select %35, %34, %c0 : index
%37 = tensor.extract_slice %cst_8[0, 0] [1, %36] [1, 1] : tensor<1x512xi64> to tensor<1x?xi64>
// Row gather: for every i64 value in %0 (the 1x512 tensor imported from %arg0),
// copy the 384-wide row %cst_9[value, :] into the 1x512x384 result.
// NOTE(review): presumably a token-embedding lookup; %cst_9's contents are not
// visible in this excerpt, so confirm against its definition.
%38 = linalg.init_tensor [1, 512, 384] : tensor<1x512x384xf32>
// The input map (d0, d1) reads one id per (d0, d1) position; d2 walks the 384 columns.
%39 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%0 : tensor<1x512xi64>) outs(%38 : tensor<1x512x384xf32>) {
^bb0(%arg1: i64, %arg2: f32):
// Row index = the id itself; column index = current position along d2.
%564 = arith.index_cast %arg1 : i64 to index
%565 = linalg.index 2 : index
// Runtime bounds checks on the id: id < %c30522_i64 and id >= %c0_i64.
// (%c30522_i64 is defined before this excerpt; it presumably equals the
// 30522-row extent of %cst_9 — TODO confirm.)
%566 = arith.cmpi slt, %arg1, %c30522_i64 : i64
cf.assert %566, "index must be smaller than dim size"
%567 = arith.cmpi sge, %arg1, %c0_i64 : i64
cf.assert %567, "index must be larger or equal to 0"
%568 = tensor.extract %cst_9[%564, %565] : tensor<30522x384xf32>
linalg.yield %568 : f32
} -> tensor<1x512x384xf32>
%40 = linalg.init_tensor [1, 512, 384] : tensor<1x512x384xf32>
%41 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_6 : tensor<1x512xi64>) outs(%40 : tensor<1x512x384xf32>) {
^bb0(%arg1: i64, %arg2: f32):
%564 = arith.index_cast %arg1 : i64 to index
%565 = linalg.index 2 : index
%566 = arith.cmpi slt, %arg1, %c2_i64 : i64
cf.assert %566, "index must be smaller than dim size"
%567 = arith.cmpi sge, %arg1, %c0_i64 : i64
cf.assert %567, "index must be larger or equal to 0"
%568 = tensor.extract %cst_10[%564, %565] : tensor<2x384xf32>
linalg.yield %568 : f32
} -> tensor<1x512x384xf32>
%42 = linalg.init_tensor [1, 512, 384] : tensor<1x512x384xf32>
%43 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%39, %41 : tensor<1x512x384xf32>, tensor<1x512x384xf32>) outs(%42 : tensor<1x512x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.addf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<1x512x384xf32>
%44 = linalg.init_tensor [1, %36, 384] : tensor<1x?x384xf32>
%45 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%37 : tensor<1x?xi64>) outs(%44 : tensor<1x?x384xf32>) {
^bb0(%arg1: i64, %arg2: f32):
%564 = arith.index_cast %arg1 : i64 to index
%565 = linalg.index 2 : index
%566 = arith.cmpi slt, %arg1, %c512_i64 : i64
cf.assert %566, "index must be smaller than dim size"
%567 = arith.cmpi sge, %arg1, %c0_i64 : i64
cf.assert %567, "index must be larger or equal to 0"
%568 = tensor.extract %cst_11[%564, %565] : tensor<512x384xf32>
linalg.yield %568 : f32
} -> tensor<1x?x384xf32>
%46 = arith.cmpi eq, %c512, %36 : index
cf.assert %46, "mismatched size for broadcast"
%47 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%43, %45 : tensor<1x512x384xf32>, tensor<1x?x384xf32>) outs(%42 : tensor<1x512x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.addf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<1x512x384xf32>
// Per-row normalization of %47 over its last (384-wide) dimension, in five steps:
// row sum -> mean -> sum of squared deviations -> variance -> scale and shift.
// The arithmetic has the shape of a layer normalization. %cst, %cst_4, %cst_12 and
// %cst_13 are defined before this excerpt, so the roles given to them below are
// assumptions to confirm against their definitions.
%48 = linalg.init_tensor [1, 512] : tensor<1x512xf32>
// Accumulator seeded with %cst (presumably 0.0 — TODO confirm).
%49 = linalg.fill(%cst, %48) : f32, tensor<1x512xf32> -> tensor<1x512xf32>
// Step 1: reduce over d2, adding every element of a row into the accumulator.
%50 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%47 : tensor<1x512x384xf32>) outs(%49 : tensor<1x512xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.addf %arg2, %arg1 : f32
linalg.yield %564 : f32
} -> tensor<1x512xf32>
// Step 2: divide each row sum by %cst_115 (3.84e+02, the extent of the reduced
// dimension) to obtain the per-row mean.
%51 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%50 : tensor<1x512xf32>) outs(%48 : tensor<1x512xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.divf %arg1, %cst_115 : f32
linalg.yield %564 : f32
} -> tensor<1x512xf32>
// Fresh accumulator for the squared-deviation reduction, again seeded with %cst.
%52 = linalg.fill(%cst, %48) : f32, tensor<1x512xf32> -> tensor<1x512xf32>
// Step 3: accumulate (x - mean)^2 over d2; the mean (%51) is broadcast along d2
// through its (d0, d1) indexing map.
%53 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%47, %51 : tensor<1x512x384xf32>, tensor<1x512xf32>) outs(%52 : tensor<1x512xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.subf %arg1, %arg2 : f32
%565 = arith.mulf %564, %564 : f32
%566 = arith.addf %arg3, %565 : f32
linalg.yield %566 : f32
} -> tensor<1x512xf32>
// Step 4: divide by %cst_115 (384) — i.e. by N, not N-1 — giving the per-row variance.
%54 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%53 : tensor<1x512xf32>) outs(%48 : tensor<1x512xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.divf %arg1, %cst_115 : f32
linalg.yield %564 : f32
} -> tensor<1x512xf32>
// Step 5: out = (x - mean) * rsqrt(variance + trunc(%cst_4)) * %cst_13[d2] + %cst_12[d2].
// %cst_13 and %cst_12 are indexed by d2 only, so they act as a per-column scale and offset.
%55 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%47, %51, %54, %cst_13, %cst_12 : tensor<1x512x384xf32>, tensor<1x512xf32>, tensor<1x512xf32>, tensor<384xf32>, tensor<384xf32>) outs(%42 : tensor<1x512x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%564 = arith.subf %arg1, %arg2 : f32
// %cst_4 is an f64 scalar narrowed to f32 inside the loop body (presumably the
// stabilizing epsilon — TODO confirm its value).
%565 = arith.truncf %cst_4 : f64 to f32
%566 = arith.addf %arg3, %565 : f32
%567 = math.rsqrt %566 : f32
%568 = arith.mulf %564, %567 : f32
%569 = arith.mulf %568, %arg4 : f32
%570 = arith.addf %569, %arg5 : f32
linalg.yield %570 : f32
} -> tensor<1x512x384xf32>
%56 = linalg.init_tensor [1, 384, 384] : tensor<1x384x384xf32>
%57 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_14 : tensor<384xf32>) outs(%42 : tensor<1x512x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<1x512x384xf32>
%58 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_15 : tensor<384x384xf32>) outs(%56 : tensor<1x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<1x384x384xf32>
%59 = linalg.batch_matmul ins(%55, %58 : tensor<1x512x384xf32>, tensor<1x384x384xf32>) outs(%57 : tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%60 = tensor.cast %59 : tensor<1x512x384xf32> to tensor<?x?x384xf32>
%61 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_16 : tensor<384xf32>) outs(%42 : tensor<1x512x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<1x512x384xf32>
%62 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_17 : tensor<384x384xf32>) outs(%56 : tensor<1x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<1x384x384xf32>
%63 = linalg.batch_matmul ins(%55, %62 : tensor<1x512x384xf32>, tensor<1x384x384xf32>) outs(%61 : tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%64 = tensor.cast %63 : tensor<1x512x384xf32> to tensor<?x?x384xf32>
%65 = tensor.expand_shape %64 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%66 = linalg.init_tensor [1, 12, 512, 32] : tensor<1x12x512x32xf32>
%67 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%65 : tensor<?x?x12x32xf32>) outs(%66 : tensor<1x12x512x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<1x12x512x32xf32>
%68 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_18 : tensor<384xf32>) outs(%42 : tensor<1x512x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<1x512x384xf32>
%69 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_19 : tensor<384x384xf32>) outs(%56 : tensor<1x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<1x384x384xf32>
%70 = linalg.batch_matmul ins(%55, %69 : tensor<1x512x384xf32>, tensor<1x384x384xf32>) outs(%68 : tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%71 = tensor.cast %70 : tensor<1x512x384xf32> to tensor<?x?x384xf32>
%72 = tensor.expand_shape %71 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%73 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%72 : tensor<?x?x12x32xf32>) outs(%66 : tensor<1x12x512x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<1x12x512x32xf32>
%74 = tensor.expand_shape %60 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%75 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%74 : tensor<?x?x12x32xf32>) outs(%66 : tensor<1x12x512x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<1x12x512x32xf32>
%76 = linalg.init_tensor [1, 12, 32, 512] : tensor<1x12x32x512xf32>
%77 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d2)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%67 : tensor<1x12x512x32xf32>) outs(%76 : tensor<1x12x32x512xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<1x12x32x512xf32>
// Batched contraction followed by a scalar divide and a broadcast add.
// NOTE(review): this looks like attention-score computation (scaled product plus
// additive mask); the values of %cst and %cst_94 are defined before this excerpt,
// so confirm before relying on that reading.
%78 = linalg.init_tensor [1, 12, 512, 512] : tensor<1x12x512x512xf32>
// Accumulator seeded with %cst (presumably 0.0 — TODO confirm).
%79 = linalg.fill(%cst, %78) : f32, tensor<1x12x512x512xf32> -> tensor<1x12x512x512xf32>
// %80[d0, d1, d2, d4] = sum over d3 of %75[d0, d1, d2, d3] * %77[d0, d1, d3, d4].
// d3 (extent 32 per the operand types) is the only reduction iterator.
%80 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>], iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel"]} ins(%75, %77 : tensor<1x12x512x32xf32>, tensor<1x12x32x512xf32>) outs(%79 : tensor<1x12x512x512xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.mulf %arg1, %arg2 : f32
%565 = arith.addf %564, %arg3 : f32
linalg.yield %565 : f32
} -> tensor<1x12x512x512xf32>
// Divide every element by the 0-d f64 tensor %cst_94, narrowed to f32 per element.
%81 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> ()>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%80, %cst_94 : tensor<1x12x512x512xf32>, tensor<f64>) outs(%78 : tensor<1x12x512x512xf32>) {
^bb0(%arg1: f32, %arg2: f64, %arg3: f32):
%564 = arith.truncf %arg2 : f64 to f32
%565 = arith.divf %arg1, %564 : f32
linalg.yield %565 : f32
} -> tensor<1x12x512x512xf32>
// Runtime guard: the dynamic extent %19 (last dimension of %23) must equal 512
// before %23 is combined with the statically shaped %81.
%82 = arith.cmpi eq, %c512, %19 : index
cf.assert %82, "mismatched size for broadcast"
// Add %23 (1x1x1x?) to %81; the (d0, 0, 0, d3) map reuses the same values for
// every d1 and d2 position.
%83 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, 0, 0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%81, %23 : tensor<1x12x512x512xf32>, tensor<1x1x1x?xf32>) outs(%78 : tensor<1x12x512x512xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.addf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<1x12x512x512xf32>
// Normalization of %83 along its last dimension (d3):
// running max -> subtract max -> exp -> sum -> divide.
// This is the max-subtracted softmax pattern; %cst, %cst_0 and %c0_i64 are defined
// before this excerpt, so their values are assumptions to confirm.
%84 = linalg.init_tensor [1, 12, 512, 1] : tensor<1x12x512x1xi64>
// Index accumulator, seeded with %c0_i64.
%85 = linalg.fill(%c0_i64, %84) : i64, tensor<1x12x512x1xi64> -> tensor<1x12x512x1xi64>
%86 = linalg.init_tensor [1, 12, 512, 1] : tensor<1x12x512x1xf32>
// Value accumulator, seeded with %cst_0 (presumably the lowest finite f32 so that
// any element compares greater — TODO confirm).
%87 = linalg.fill(%cst_0, %86) : f32, tensor<1x12x512x1xf32> -> tensor<1x12x512x1xf32>
// Reduction over d3 yielding two results: #0 the running maximum, #1 the d3 index
// at which it was found. Within this excerpt only %88#0 is consumed.
%88:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%83 : tensor<1x12x512x512xf32>) outs(%87, %85 : tensor<1x12x512x1xf32>, tensor<1x12x512x1xi64>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: i64):
%564 = linalg.index 3 : index
%565 = arith.index_cast %564 : index to i64
// Ordered greater-than: the accumulator is replaced only when the new element is
// strictly larger (the comparison is false if either operand is NaN).
%566 = arith.cmpf ogt, %arg1, %arg2 : f32
%567 = arith.select %566, %arg1, %arg2 : f32
%568 = arith.select %566, %565, %arg3 : i64
linalg.yield %567, %568 : f32, i64
} -> (tensor<1x12x512x1xf32>, tensor<1x12x512x1xi64>)
// Subtract the row maximum (broadcast along d3 via the (d0, d1, d2, 0) map).
%89 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%83, %88#0 : tensor<1x12x512x512xf32>, tensor<1x12x512x1xf32>) outs(%78 : tensor<1x12x512x512xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.subf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<1x12x512x512xf32>
// Elementwise exponential of the shifted values.
%90 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%89 : tensor<1x12x512x512xf32>) outs(%78 : tensor<1x12x512x512xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = math.exp %arg1 : f32
linalg.yield %564 : f32
} -> tensor<1x12x512x512xf32>
// Sum accumulator seeded with %cst (presumably 0.0 — TODO confirm).
%91 = linalg.fill(%cst, %86) : f32, tensor<1x12x512x1xf32> -> tensor<1x12x512x1xf32>
// Sum of the exponentials over d3.
%92 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%90 : tensor<1x12x512x512xf32>) outs(%91 : tensor<1x12x512x1xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.addf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<1x12x512x1xf32>
// Divide each exponential by its row sum (broadcast along d3).
%93 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%90, %92 : tensor<1x12x512x512xf32>, tensor<1x12x512x1xf32>) outs(%78 : tensor<1x12x512x512xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.divf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<1x12x512x512xf32>
%94 = linalg.fill(%cst, %66) : f32, tensor<1x12x512x32xf32> -> tensor<1x12x512x32xf32>
%95 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>], iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel"]} ins(%93, %73 : tensor<1x12x512x512xf32>, tensor<1x12x512x32xf32>) outs(%94 : tensor<1x12x512x32xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.mulf %arg1, %arg2 : f32
%565 = arith.addf %564, %arg3 : f32
linalg.yield %565 : f32
} -> tensor<1x12x512x32xf32>
%96 = linalg.init_tensor [1, 512, 12, 32] : tensor<1x512x12x32xf32>
%97 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%95 : tensor<1x12x512x32xf32>) outs(%96 : tensor<1x512x12x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<1x512x12x32xf32>
%98 = tensor.cast %97 : tensor<1x512x12x32xf32> to tensor<?x?x?x?xf32>
%99 = tensor.collapse_shape %98 [[0], [1], [2, 3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
%100 = tensor.cast %99 : tensor<?x?x?xf32> to tensor<?x?x384xf32>
%101 = linalg.init_tensor [1, 512, 384] : tensor<1x512x384xf32>
%102 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_20 : tensor<384xf32>) outs(%101 : tensor<1x512x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<1x512x384xf32>
%103 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_21 : tensor<384x384xf32>) outs(%56 : tensor<1x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<1x384x384xf32>
%104 = linalg.batch_matmul ins(%100, %103 : tensor<?x?x384xf32>, tensor<1x384x384xf32>) outs(%102 : tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%105 = tensor.dim %99, %c0 : tensor<?x?x?xf32>
%106 = tensor.dim %99, %c1 : tensor<?x?x?xf32>
%107 = arith.cmpi eq, %105, %c1 : index
cf.assert %107, "mismatched size for broadcast"
%108 = arith.cmpi eq, %106, %c512 : index
cf.assert %108, "mismatched size for broadcast"
%109 = linalg.init_tensor [%105, %106, 384] : tensor<?x?x384xf32>
%110 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%104, %55 : tensor<1x512x384xf32>, tensor<1x512x384xf32>) outs(%109 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.addf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<?x?x384xf32>
%111 = linalg.init_tensor [%105, %106] : tensor<?x?xf32>
%112 = linalg.fill(%cst, %111) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%113 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%110 : tensor<?x?x384xf32>) outs(%112 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.addf %arg2, %arg1 : f32
linalg.yield %564 : f32
} -> tensor<?x?xf32>
%114 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%113 : tensor<?x?xf32>) outs(%111 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.divf %arg1, %cst_115 : f32
linalg.yield %564 : f32
} -> tensor<?x?xf32>
%115 = linalg.fill(%cst, %111) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%116 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%110, %114 : tensor<?x?x384xf32>, tensor<?x?xf32>) outs(%115 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.subf %arg1, %arg2 : f32
%565 = arith.mulf %564, %564 : f32
%566 = arith.addf %arg3, %565 : f32
linalg.yield %566 : f32
} -> tensor<?x?xf32>
%117 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%116 : tensor<?x?xf32>) outs(%111 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.divf %arg1, %cst_115 : f32
linalg.yield %564 : f32
} -> tensor<?x?xf32>
%118 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%110, %114, %117, %cst_23, %cst_22 : tensor<?x?x384xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<384xf32>, tensor<384xf32>) outs(%109 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%564 = arith.subf %arg1, %arg2 : f32
%565 = arith.truncf %cst_4 : f64 to f32
%566 = arith.addf %arg3, %565 : f32
%567 = math.rsqrt %566 : f32
%568 = arith.mulf %564, %567 : f32
%569 = arith.mulf %568, %arg4 : f32
%570 = arith.addf %569, %arg5 : f32
linalg.yield %570 : f32
} -> tensor<?x?x384xf32>
%119 = linalg.init_tensor [%105, %106, 1536] : tensor<?x?x1536xf32>
%120 = linalg.init_tensor [%105, 384, 1536] : tensor<?x384x1536xf32>
%121 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_24 : tensor<1536xf32>) outs(%119 : tensor<?x?x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x1536xf32>
%122 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_25 : tensor<1536x384xf32>) outs(%120 : tensor<?x384x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x1536xf32>
%123 = linalg.batch_matmul ins(%118, %122 : tensor<?x?x384xf32>, tensor<?x384x1536xf32>) outs(%121 : tensor<?x?x1536xf32>) -> tensor<?x?x1536xf32>
// Elementwise activation on %123:
//   out = x * (%cst_1 * (erf(x / sqrt(%cst_2)) + %cst_3))
// NOTE(review): %cst_1, %cst_2 and %cst_3 are defined before this excerpt. If they
// are 0.5, 2.0 and 1.0 respectively, this is the exact erf-based GELU — confirm
// against their definitions.
%124 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%123 : tensor<?x?x1536xf32>) outs(%119 : tensor<?x?x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
// sqrt of a constant, written inside the region body; it does not depend on %arg1.
%564 = math.sqrt %cst_2 : f32
%565 = arith.divf %arg1, %564 : f32
%566 = math.erf %565 : f32
%567 = arith.addf %566, %cst_3 : f32
%568 = arith.mulf %567, %cst_1 : f32
%569 = arith.mulf %arg1, %568 : f32
linalg.yield %569 : f32
} -> tensor<?x?x1536xf32>
%125 = linalg.init_tensor [%105, 1536, 384] : tensor<?x1536x384xf32>
%126 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_26 : tensor<384xf32>) outs(%109 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%127 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_27 : tensor<384x1536xf32>) outs(%125 : tensor<?x1536x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x1536x384xf32>
%128 = linalg.batch_matmul ins(%124, %127 : tensor<?x?x1536xf32>, tensor<?x1536x384xf32>) outs(%126 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%129 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%128, %118 : tensor<?x?x384xf32>, tensor<?x?x384xf32>) outs(%109 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.addf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<?x?x384xf32>
%130 = linalg.fill(%cst, %111) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%131 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%129 : tensor<?x?x384xf32>) outs(%130 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.addf %arg2, %arg1 : f32
linalg.yield %564 : f32
} -> tensor<?x?xf32>
%132 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%131 : tensor<?x?xf32>) outs(%111 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.divf %arg1, %cst_115 : f32
linalg.yield %564 : f32
} -> tensor<?x?xf32>
%133 = linalg.fill(%cst, %111) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%134 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%129, %132 : tensor<?x?x384xf32>, tensor<?x?xf32>) outs(%133 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.subf %arg1, %arg2 : f32
%565 = arith.mulf %564, %564 : f32
%566 = arith.addf %arg3, %565 : f32
linalg.yield %566 : f32
} -> tensor<?x?xf32>
%135 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%134 : tensor<?x?xf32>) outs(%111 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.divf %arg1, %cst_115 : f32
linalg.yield %564 : f32
} -> tensor<?x?xf32>
%136 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%129, %132, %135, %cst_29, %cst_28 : tensor<?x?x384xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<384xf32>, tensor<384xf32>) outs(%109 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%564 = arith.subf %arg1, %arg2 : f32
%565 = arith.truncf %cst_4 : f64 to f32
%566 = arith.addf %arg3, %565 : f32
%567 = math.rsqrt %566 : f32
%568 = arith.mulf %564, %567 : f32
%569 = arith.mulf %568, %arg4 : f32
%570 = arith.addf %569, %arg5 : f32
linalg.yield %570 : f32
} -> tensor<?x?x384xf32>
%137 = linalg.init_tensor [%105, 384, 384] : tensor<?x384x384xf32>
%138 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_30 : tensor<384xf32>) outs(%109 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%139 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_31 : tensor<384x384xf32>) outs(%137 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%140 = linalg.batch_matmul ins(%136, %139 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%138 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%141 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_32 : tensor<384xf32>) outs(%109 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%142 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_33 : tensor<384x384xf32>) outs(%137 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%143 = linalg.batch_matmul ins(%136, %142 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%141 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%144 = tensor.expand_shape %143 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%145 = linalg.init_tensor [%105, 12, %106, 32] : tensor<?x12x?x32xf32>
%146 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%144 : tensor<?x?x12x32xf32>) outs(%145 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
%147 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_34 : tensor<384xf32>) outs(%109 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%148 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_35 : tensor<384x384xf32>) outs(%137 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%149 = linalg.batch_matmul ins(%136, %148 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%147 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%150 = tensor.expand_shape %149 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%151 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%150 : tensor<?x?x12x32xf32>) outs(%145 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
%152 = tensor.expand_shape %140 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%153 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%152 : tensor<?x?x12x32xf32>) outs(%145 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
// Pairwise score computation: %159 = (%153 x transpose(%146)) / scalar(%cst_94).
// Swap the last two dims of %146: out[d0, d1, d3, d2] = in[d0, d1, d2, d3].
%154 = linalg.init_tensor [%105, 12, 32, %106] : tensor<?x12x32x?xf32>
%155 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d2)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%146 : tensor<?x12x?x32xf32>) outs(%154 : tensor<?x12x32x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x32x?xf32>
// Accumulator filled with %cst (defined outside this excerpt; presumably 0.0 - confirm).
%156 = linalg.init_tensor [%105, 12, %106, %106] : tensor<?x12x?x?xf32>
%157 = linalg.fill(%cst, %156) : f32, tensor<?x12x?x?xf32> -> tensor<?x12x?x?xf32>
// Matmul batched over (d0, d1); d3 is the reduction axis:
//   out[d0, d1, d2, d4] += lhs[d0, d1, d2, d3] * rhs[d0, d1, d3, d4]
%158 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>], iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel"]} ins(%153, %155 : tensor<?x12x?x32xf32>, tensor<?x12x32x?xf32>) outs(%157 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.mulf %arg1, %arg2 : f32
%565 = arith.addf %564, %arg3 : f32
linalg.yield %565 : f32
} -> tensor<?x12x?x?xf32>
// Element-wise division by the rank-0 f64 tensor %cst_94, truncated to f32.
// NOTE(review): %cst_94 is defined outside this excerpt; presumably the
// sqrt(head_dim) attention scale - confirm.
%159 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> ()>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%158, %cst_94 : tensor<?x12x?x?xf32>, tensor<f64>) outs(%156 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f64, %arg3: f32):
%564 = arith.truncf %arg2 : f64 to f32
%565 = arith.divf %arg1, %564 : f32
linalg.yield %565 : f32
} -> tensor<?x12x?x?xf32>
// Broadcast-add of %23 (tensor<1x1x1x?xf32>) onto the scores %159.
// Runtime shape guards for the broadcast: %105 must equal %c1 and %106 must
// equal %19 (both %c1 and %19 are defined outside this excerpt).
%160 = arith.cmpi eq, %105, %c1 : index
cf.assert %160, "mismatched size for broadcast"
%161 = arith.cmpi eq, %106, %19 : index
cf.assert %161, "mismatched size for broadcast"
// %23 is read at (d0, 0, 0, d3): the same value is reused for every d1 and d2.
// Its leading dim is statically 1 but is still indexed by d0, which is why
// the first assert above is required.
// NOTE(review): %23 is presumably the additive attention mask - confirm.
%162 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, 0, 0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%159, %23 : tensor<?x12x?x?xf32>, tensor<1x1x1x?xf32>) outs(%156 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.addf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<?x12x?x?xf32>
// Softmax of %162 along the last dim (d3), in max-subtracted form:
//   %172 = exp(x - max(x)) / sum(exp(x - max(x)))
// Step 1: running max along d3, together with the index at which it occurs.
// The index accumulator starts at %c0_i64; the value accumulator starts at
// %cst_0 (defined outside this excerpt; presumably the lowest finite f32 -
// confirm).
%163 = linalg.init_tensor [%105, 12, %106, 1] : tensor<?x12x?x1xi64>
%164 = linalg.fill(%c0_i64, %163) : i64, tensor<?x12x?x1xi64> -> tensor<?x12x?x1xi64>
%165 = linalg.init_tensor [%105, 12, %106, 1] : tensor<?x12x?x1xf32>
%166 = linalg.fill(%cst_0, %165) : f32, tensor<?x12x?x1xf32> -> tensor<?x12x?x1xf32>
// Only result #0 (the max values) is consumed in the visible excerpt; the
// index result #1 is not referenced here.
%167:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%162 : tensor<?x12x?x?xf32>) outs(%166, %164 : tensor<?x12x?x1xf32>, tensor<?x12x?x1xi64>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: i64):
%564 = linalg.index 3 : index
%565 = arith.index_cast %564 : index to i64
// Strict "ordered greater than": ties keep the earlier value/index.
%566 = arith.cmpf ogt, %arg1, %arg2 : f32
%567 = arith.select %566, %arg1, %arg2 : f32
%568 = arith.select %566, %565, %arg3 : i64
linalg.yield %567, %568 : f32, i64
} -> (tensor<?x12x?x1xf32>, tensor<?x12x?x1xi64>)
// Step 2: subtract the per-row max (broadcast along d3).
%168 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%162, %167#0 : tensor<?x12x?x?xf32>, tensor<?x12x?x1xf32>) outs(%156 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.subf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<?x12x?x?xf32>
// Step 3: element-wise exponential.
%169 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%168 : tensor<?x12x?x?xf32>) outs(%156 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = math.exp %arg1 : f32
linalg.yield %564 : f32
} -> tensor<?x12x?x?xf32>
// Step 4: sum of exponentials along d3, accumulated from a %cst-filled tensor.
%170 = linalg.fill(%cst, %165) : f32, tensor<?x12x?x1xf32> -> tensor<?x12x?x1xf32>
%171 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%169 : tensor<?x12x?x?xf32>) outs(%170 : tensor<?x12x?x1xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.addf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<?x12x?x1xf32>
// Step 5: normalise each exponential by its row sum.
%172 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%169, %171 : tensor<?x12x?x?xf32>, tensor<?x12x?x1xf32>) outs(%156 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.divf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<?x12x?x?xf32>
// Weighted combination: softmax output %172 times %151, then undo the head
// split so the result is ?x?x384 again.
%173 = linalg.fill(%cst, %145) : f32, tensor<?x12x?x32xf32> -> tensor<?x12x?x32xf32>
// Matmul batched over (d0, d1), reducing over d3:
//   out[d0, d1, d2, d4] += lhs[d0, d1, d2, d3] * rhs[d0, d1, d3, d4]
%174 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>], iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel"]} ins(%172, %151 : tensor<?x12x?x?xf32>, tensor<?x12x?x32xf32>) outs(%173 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.mulf %arg1, %arg2 : f32
%565 = arith.addf %564, %arg3 : f32
linalg.yield %565 : f32
} -> tensor<?x12x?x32xf32>
// Swap dims 1 and 2 back: ?x12x?x32 -> ?x?x12x32.
%175 = linalg.init_tensor [%105, %106, 12, 32] : tensor<?x?x12x32xf32>
%176 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%174 : tensor<?x12x?x32xf32>) outs(%175 : tensor<?x?x12x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x12x32xf32>
// Erase static sizes, merge the trailing 12 x 32 dims into one, then cast
// the merged dim back to the static size 384.
%177 = tensor.cast %176 : tensor<?x?x12x32xf32> to tensor<?x?x?x?xf32>
%178 = tensor.collapse_shape %177 [[0], [1], [2, 3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
%179 = tensor.cast %178 : tensor<?x?x?xf32> to tensor<?x?x384xf32>
// Dense layer on %179 with weights %cst_37 (384x384) and bias %cst_36 (384).
// Bias broadcast: out[d0, d1, d2] = bias[d2]. This becomes the matmul's
// initial accumulator, so the bias is added for free.
%180 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_36 : tensor<384xf32>) outs(%109 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
// Weight broadcast with transpose: out[d0, d1, d2] = w[d2, d1], replicated
// along the batch dim d0 so it can feed linalg.batch_matmul.
%181 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_37 : tensor<384x384xf32>) outs(%137 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
// Accumulates %179 x %181 into the bias-filled %180.
%182 = linalg.batch_matmul ins(%179, %181 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%180 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
// Element-wise add of the dense-layer output %182 and %136 (the tensor that
// also fed the projection producing %149).
// Runtime sizes of the collapsed tensor %178; these (%183, %184) are the
// dynamic dims used for the tensors created in the following ops.
%183 = tensor.dim %178, %c0 : tensor<?x?x?xf32>
%184 = tensor.dim %178, %c1 : tensor<?x?x?xf32>
// Guard that both dims agree with the earlier sizes %105 / %106.
%185 = arith.cmpi eq, %183, %105 : index
cf.assert %185, "mismatched size for broadcast"
%186 = arith.cmpi eq, %184, %106 : index
cf.assert %186, "mismatched size for broadcast"
%187 = linalg.init_tensor [%183, %184, 384] : tensor<?x?x384xf32>
%188 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%182, %136 : tensor<?x?x384xf32>, tensor<?x?x384xf32>) outs(%187 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.addf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<?x?x384xf32>
// Normalisation of %188 over its last (384-wide) dim:
//   %196 = (x - mean) * rsqrt(var + eps) * %cst_39 + %cst_38
// %cst_115 and %cst_4 are defined outside this excerpt.
// NOTE(review): %cst_115 is presumably 384.0 (the reduced dim size) and
// %cst_4 presumably the epsilon - confirm.
%189 = linalg.init_tensor [%183, %184] : tensor<?x?xf32>
%190 = linalg.fill(%cst, %189) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
// Sum over d2.
%191 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%188 : tensor<?x?x384xf32>) outs(%190 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.addf %arg2, %arg1 : f32
linalg.yield %564 : f32
} -> tensor<?x?xf32>
// Mean = sum / %cst_115.
%192 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%191 : tensor<?x?xf32>) outs(%189 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.divf %arg1, %cst_115 : f32
linalg.yield %564 : f32
} -> tensor<?x?xf32>
%193 = linalg.fill(%cst, %189) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
// Sum over d2 of squared deviations from the mean.
%194 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%188, %192 : tensor<?x?x384xf32>, tensor<?x?xf32>) outs(%193 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.subf %arg1, %arg2 : f32
%565 = arith.mulf %564, %564 : f32
%566 = arith.addf %arg3, %565 : f32
linalg.yield %566 : f32
} -> tensor<?x?xf32>
// Variance = sum of squared deviations / %cst_115 (same divisor as the mean).
%195 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%194 : tensor<?x?xf32>) outs(%189 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.divf %arg1, %cst_115 : f32
linalg.yield %564 : f32
} -> tensor<?x?xf32>
// Normalise, then apply the per-feature scale (%cst_39) and shift (%cst_38).
// Block args: %arg1 = x, %arg2 = mean, %arg3 = variance, %arg4 = scale,
// %arg5 = shift, %arg6 = unused output element.
%196 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%188, %192, %195, %cst_39, %cst_38 : tensor<?x?x384xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<384xf32>, tensor<384xf32>) outs(%187 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%564 = arith.subf %arg1, %arg2 : f32
// The f64 constant is narrowed to f32 before being added to the variance.
%565 = arith.truncf %cst_4 : f64 to f32
%566 = arith.addf %arg3, %565 : f32
%567 = math.rsqrt %566 : f32
%568 = arith.mulf %564, %567 : f32
%569 = arith.mulf %568, %arg4 : f32
%570 = arith.addf %569, %arg5 : f32
linalg.yield %570 : f32
} -> tensor<?x?x384xf32>
// Dense layer 384 -> 1536 on %196 (weights %cst_41, bias %cst_40), followed
// by an erf-based element-wise activation.
%197 = linalg.init_tensor [%183, %184, 1536] : tensor<?x?x1536xf32>
%198 = linalg.init_tensor [%183, 384, 1536] : tensor<?x384x1536xf32>
// Bias broadcast: out[d0, d1, d2] = bias[d2]; used as the matmul accumulator.
%199 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_40 : tensor<1536xf32>) outs(%197 : tensor<?x?x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x1536xf32>
// Weight transpose + batch broadcast: out[d0, d1, d2] = w[d2, d1].
%200 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_41 : tensor<1536x384xf32>) outs(%198 : tensor<?x384x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x1536xf32>
%201 = linalg.batch_matmul ins(%196, %200 : tensor<?x?x384xf32>, tensor<?x384x1536xf32>) outs(%199 : tensor<?x?x1536xf32>) -> tensor<?x?x1536xf32>
// Activation: y = x * ((erf(x / sqrt(%cst_2)) + %cst_3) * %cst_1).
// NOTE(review): this matches exact GELU if %cst_2 = 2.0, %cst_3 = 1.0 and
// %cst_1 = 0.5; the constants are defined outside this excerpt - confirm.
%202 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%201 : tensor<?x?x1536xf32>) outs(%197 : tensor<?x?x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
// sqrt of a constant, recomputed for every element at this stage of lowering.
%564 = math.sqrt %cst_2 : f32
%565 = arith.divf %arg1, %564 : f32
%566 = math.erf %565 : f32
%567 = arith.addf %566, %cst_3 : f32
%568 = arith.mulf %567, %cst_1 : f32
%569 = arith.mulf %arg1, %568 : f32
linalg.yield %569 : f32
} -> tensor<?x?x1536xf32>
// Dense layer 1536 -> 384 on %202 (weights %cst_43, bias %cst_42), followed
// by an element-wise add with %196 (the input of the preceding 384 -> 1536
// layer).
%203 = linalg.init_tensor [%183, 1536, 384] : tensor<?x1536x384xf32>
// Bias broadcast: out[d0, d1, d2] = bias[d2]; used as the matmul accumulator.
%204 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_42 : tensor<384xf32>) outs(%187 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
// Weight transpose + batch broadcast: out[d0, d1, d2] = w[d2, d1].
%205 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_43 : tensor<384x1536xf32>) outs(%203 : tensor<?x1536x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x1536x384xf32>
%206 = linalg.batch_matmul ins(%202, %205 : tensor<?x?x1536xf32>, tensor<?x1536x384xf32>) outs(%204 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
// Skip connection: %207 = %206 + %196.
%207 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%206, %196 : tensor<?x?x384xf32>, tensor<?x?x384xf32>) outs(%187 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.addf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<?x?x384xf32>
// Normalisation of %207 over its last (384-wide) dim:
//   %214 = (x - mean) * rsqrt(var + eps) * %cst_45 + %cst_44
// Reuses the ?x? scratch tensor %189. %cst_115 and %cst_4 are defined
// outside this excerpt (presumably 384.0 and the epsilon - confirm).
%208 = linalg.fill(%cst, %189) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
// Sum over d2.
%209 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%207 : tensor<?x?x384xf32>) outs(%208 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.addf %arg2, %arg1 : f32
linalg.yield %564 : f32
} -> tensor<?x?xf32>
// Mean = sum / %cst_115.
%210 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%209 : tensor<?x?xf32>) outs(%189 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.divf %arg1, %cst_115 : f32
linalg.yield %564 : f32
} -> tensor<?x?xf32>
%211 = linalg.fill(%cst, %189) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
// Sum over d2 of squared deviations from the mean.
%212 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%207, %210 : tensor<?x?x384xf32>, tensor<?x?xf32>) outs(%211 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.subf %arg1, %arg2 : f32
%565 = arith.mulf %564, %564 : f32
%566 = arith.addf %arg3, %565 : f32
linalg.yield %566 : f32
} -> tensor<?x?xf32>
// Variance = sum of squared deviations / %cst_115.
%213 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%212 : tensor<?x?xf32>) outs(%189 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.divf %arg1, %cst_115 : f32
linalg.yield %564 : f32
} -> tensor<?x?xf32>
// Normalise, then apply the per-feature scale (%cst_45) and shift (%cst_44).
// Block args: %arg1 = x, %arg2 = mean, %arg3 = variance, %arg4 = scale,
// %arg5 = shift, %arg6 = unused output element.
%214 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%207, %210, %213, %cst_45, %cst_44 : tensor<?x?x384xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<384xf32>, tensor<384xf32>) outs(%187 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%564 = arith.subf %arg1, %arg2 : f32
%565 = arith.truncf %cst_4 : f64 to f32
%566 = arith.addf %arg3, %565 : f32
%567 = math.rsqrt %566 : f32
%568 = arith.mulf %564, %567 : f32
%569 = arith.mulf %568, %arg4 : f32
%570 = arith.addf %569, %arg5 : f32
linalg.yield %570 : f32
} -> tensor<?x?x384xf32>
// Three independent 384 -> 384 dense layers on %214, each reshaped to
// 12 x 32 and permuted to a ?x12x?x32 layout. Roles, as established by how
// the results are consumed below:
//   %218 -> %231 : left operand of the score contraction (%236)
//   %221 -> %224 : transposed (%233) into the right operand of %236
//   %227 -> %229 : right operand of the weighted combination (%252)
// NOTE(review): i.e. presumably query / key / value respectively - confirm.
%215 = linalg.init_tensor [%183, 384, 384] : tensor<?x384x384xf32>
// Projection 1: bias %cst_46, weights %cst_47 (transposed + batch broadcast).
%216 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_46 : tensor<384xf32>) outs(%187 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%217 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_47 : tensor<384x384xf32>) outs(%215 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%218 = linalg.batch_matmul ins(%214, %217 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%216 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
// Projection 2: bias %cst_48, weights %cst_49.
%219 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_48 : tensor<384xf32>) outs(%187 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%220 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_49 : tensor<384x384xf32>) outs(%215 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%221 = linalg.batch_matmul ins(%214, %220 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%219 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
// Head split of projection 2: 384 -> 12 x 32, then swap dims 1 and 2.
%222 = tensor.expand_shape %221 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%223 = linalg.init_tensor [%183, 12, %184, 32] : tensor<?x12x?x32xf32>
%224 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%222 : tensor<?x?x12x32xf32>) outs(%223 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
// Projection 3: bias %cst_50, weights %cst_51.
%225 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_50 : tensor<384xf32>) outs(%187 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%226 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_51 : tensor<384x384xf32>) outs(%215 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%227 = linalg.batch_matmul ins(%214, %226 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%225 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
// Head split of projection 3.
%228 = tensor.expand_shape %227 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%229 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%228 : tensor<?x?x12x32xf32>) outs(%223 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
// Head split of projection 1.
%230 = tensor.expand_shape %218 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%231 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%230 : tensor<?x?x12x32xf32>) outs(%223 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
// Pairwise score computation: %237 = (%231 x transpose(%224)) / scalar(%cst_94).
// Swap the last two dims of %224: out[d0, d1, d3, d2] = in[d0, d1, d2, d3].
%232 = linalg.init_tensor [%183, 12, 32, %184] : tensor<?x12x32x?xf32>
%233 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d2)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%224 : tensor<?x12x?x32xf32>) outs(%232 : tensor<?x12x32x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x32x?xf32>
// Accumulator filled with %cst (defined outside this excerpt; presumably 0.0 - confirm).
%234 = linalg.init_tensor [%183, 12, %184, %184] : tensor<?x12x?x?xf32>
%235 = linalg.fill(%cst, %234) : f32, tensor<?x12x?x?xf32> -> tensor<?x12x?x?xf32>
// Matmul batched over (d0, d1); d3 is the reduction axis:
//   out[d0, d1, d2, d4] += lhs[d0, d1, d2, d3] * rhs[d0, d1, d3, d4]
%236 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>], iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel"]} ins(%231, %233 : tensor<?x12x?x32xf32>, tensor<?x12x32x?xf32>) outs(%235 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.mulf %arg1, %arg2 : f32
%565 = arith.addf %564, %arg3 : f32
linalg.yield %565 : f32
} -> tensor<?x12x?x?xf32>
// Element-wise division by the rank-0 f64 tensor %cst_94, truncated to f32.
// NOTE(review): %cst_94 is defined outside this excerpt; presumably the
// sqrt(head_dim) attention scale - confirm.
%237 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> ()>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%236, %cst_94 : tensor<?x12x?x?xf32>, tensor<f64>) outs(%234 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f64, %arg3: f32):
%564 = arith.truncf %arg2 : f64 to f32
%565 = arith.divf %arg1, %564 : f32
linalg.yield %565 : f32
} -> tensor<?x12x?x?xf32>
// Broadcast-add of %23 (tensor<1x1x1x?xf32>) onto the scores %237.
// Runtime shape guards for the broadcast: %183 must equal %c1 and %184 must
// equal %19 (both %c1 and %19 are defined outside this excerpt).
%238 = arith.cmpi eq, %183, %c1 : index
cf.assert %238, "mismatched size for broadcast"
%239 = arith.cmpi eq, %184, %19 : index
cf.assert %239, "mismatched size for broadcast"
// %23 is read at (d0, 0, 0, d3): the same value is reused for every d1 and d2.
// NOTE(review): %23 is presumably the additive attention mask - confirm.
%240 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, 0, 0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%237, %23 : tensor<?x12x?x?xf32>, tensor<1x1x1x?xf32>) outs(%234 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.addf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<?x12x?x?xf32>
// Softmax of %240 along the last dim (d3), in max-subtracted form:
//   %250 = exp(x - max(x)) / sum(exp(x - max(x)))
// Step 1: running max along d3, together with the index at which it occurs.
// The index accumulator starts at %c0_i64; the value accumulator starts at
// %cst_0 (defined outside this excerpt; presumably the lowest finite f32 -
// confirm).
%241 = linalg.init_tensor [%183, 12, %184, 1] : tensor<?x12x?x1xi64>
%242 = linalg.fill(%c0_i64, %241) : i64, tensor<?x12x?x1xi64> -> tensor<?x12x?x1xi64>
%243 = linalg.init_tensor [%183, 12, %184, 1] : tensor<?x12x?x1xf32>
%244 = linalg.fill(%cst_0, %243) : f32, tensor<?x12x?x1xf32> -> tensor<?x12x?x1xf32>
// Only result #0 (the max values) is consumed in the visible excerpt; the
// index result #1 is not referenced here.
%245:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%240 : tensor<?x12x?x?xf32>) outs(%244, %242 : tensor<?x12x?x1xf32>, tensor<?x12x?x1xi64>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: i64):
%564 = linalg.index 3 : index
%565 = arith.index_cast %564 : index to i64
// Strict "ordered greater than": ties keep the earlier value/index.
%566 = arith.cmpf ogt, %arg1, %arg2 : f32
%567 = arith.select %566, %arg1, %arg2 : f32
%568 = arith.select %566, %565, %arg3 : i64
linalg.yield %567, %568 : f32, i64
} -> (tensor<?x12x?x1xf32>, tensor<?x12x?x1xi64>)
// Step 2: subtract the per-row max (broadcast along d3).
%246 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%240, %245#0 : tensor<?x12x?x?xf32>, tensor<?x12x?x1xf32>) outs(%234 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.subf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<?x12x?x?xf32>
// Step 3: element-wise exponential.
%247 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%246 : tensor<?x12x?x?xf32>) outs(%234 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = math.exp %arg1 : f32
linalg.yield %564 : f32
} -> tensor<?x12x?x?xf32>
// Step 4: sum of exponentials along d3, accumulated from a %cst-filled tensor.
%248 = linalg.fill(%cst, %243) : f32, tensor<?x12x?x1xf32> -> tensor<?x12x?x1xf32>
%249 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%247 : tensor<?x12x?x?xf32>) outs(%248 : tensor<?x12x?x1xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.addf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<?x12x?x1xf32>
// Step 5: normalise each exponential by its row sum.
%250 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%247, %249 : tensor<?x12x?x?xf32>, tensor<?x12x?x1xf32>) outs(%234 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.divf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<?x12x?x?xf32>
// Weighted combination: softmax output %250 times %229, then undo the head
// split so the result is ?x?x384 again.
%251 = linalg.fill(%cst, %223) : f32, tensor<?x12x?x32xf32> -> tensor<?x12x?x32xf32>
// Matmul batched over (d0, d1), reducing over d3:
//   out[d0, d1, d2, d4] += lhs[d0, d1, d2, d3] * rhs[d0, d1, d3, d4]
%252 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>], iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel"]} ins(%250, %229 : tensor<?x12x?x?xf32>, tensor<?x12x?x32xf32>) outs(%251 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.mulf %arg1, %arg2 : f32
%565 = arith.addf %564, %arg3 : f32
linalg.yield %565 : f32
} -> tensor<?x12x?x32xf32>
// Swap dims 1 and 2 back: ?x12x?x32 -> ?x?x12x32.
%253 = linalg.init_tensor [%183, %184, 12, 32] : tensor<?x?x12x32xf32>
%254 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%252 : tensor<?x12x?x32xf32>) outs(%253 : tensor<?x?x12x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x12x32xf32>
// Erase static sizes, merge the trailing 12 x 32 dims into one, then cast
// the merged dim back to the static size 384.
%255 = tensor.cast %254 : tensor<?x?x12x32xf32> to tensor<?x?x?x?xf32>
%256 = tensor.collapse_shape %255 [[0], [1], [2, 3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
%257 = tensor.cast %256 : tensor<?x?x?xf32> to tensor<?x?x384xf32>
// Dense layer on %257 with weights %cst_53 (384x384) and bias %cst_52 (384).
// Bias broadcast: out[d0, d1, d2] = bias[d2]. This becomes the matmul's
// initial accumulator, so the bias is added for free.
%258 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_52 : tensor<384xf32>) outs(%187 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
// Weight broadcast with transpose: out[d0, d1, d2] = w[d2, d1], replicated
// along the batch dim d0 so it can feed linalg.batch_matmul.
%259 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_53 : tensor<384x384xf32>) outs(%215 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
// Accumulates %257 x %259 into the bias-filled %258.
%260 = linalg.batch_matmul ins(%257, %259 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%258 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
// Element-wise add of the dense-layer output %260 and %214 (the tensor that
// fed the three projections %218 / %221 / %227).
// Runtime sizes of the collapsed tensor %256; these (%261, %262) are the
// dynamic dims used for the tensors created in the following ops.
%261 = tensor.dim %256, %c0 : tensor<?x?x?xf32>
%262 = tensor.dim %256, %c1 : tensor<?x?x?xf32>
// Guard that both dims agree with the earlier sizes %183 / %184.
%263 = arith.cmpi eq, %261, %183 : index
cf.assert %263, "mismatched size for broadcast"
%264 = arith.cmpi eq, %262, %184 : index
cf.assert %264, "mismatched size for broadcast"
%265 = linalg.init_tensor [%261, %262, 384] : tensor<?x?x384xf32>
%266 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%260, %214 : tensor<?x?x384xf32>, tensor<?x?x384xf32>) outs(%265 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.addf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<?x?x384xf32>
// Normalisation of %266 over its last (384-wide) dim:
//   %274 = (x - mean) * rsqrt(var + eps) * %cst_55 + %cst_54
// %cst_115 and %cst_4 are defined outside this excerpt.
// NOTE(review): %cst_115 is presumably 384.0 (the reduced dim size) and
// %cst_4 presumably the epsilon - confirm.
%267 = linalg.init_tensor [%261, %262] : tensor<?x?xf32>
%268 = linalg.fill(%cst, %267) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
// Sum over d2.
%269 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%266 : tensor<?x?x384xf32>) outs(%268 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.addf %arg2, %arg1 : f32
linalg.yield %564 : f32
} -> tensor<?x?xf32>
// Mean = sum / %cst_115.
%270 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%269 : tensor<?x?xf32>) outs(%267 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.divf %arg1, %cst_115 : f32
linalg.yield %564 : f32
} -> tensor<?x?xf32>
%271 = linalg.fill(%cst, %267) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
// Sum over d2 of squared deviations from the mean.
%272 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%266, %270 : tensor<?x?x384xf32>, tensor<?x?xf32>) outs(%271 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.subf %arg1, %arg2 : f32
%565 = arith.mulf %564, %564 : f32
%566 = arith.addf %arg3, %565 : f32
linalg.yield %566 : f32
} -> tensor<?x?xf32>
// Variance = sum of squared deviations / %cst_115 (same divisor as the mean).
%273 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%272 : tensor<?x?xf32>) outs(%267 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.divf %arg1, %cst_115 : f32
linalg.yield %564 : f32
} -> tensor<?x?xf32>
// Normalise, then apply the per-feature scale (%cst_55) and shift (%cst_54).
// Block args: %arg1 = x, %arg2 = mean, %arg3 = variance, %arg4 = scale,
// %arg5 = shift, %arg6 = unused output element.
%274 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%266, %270, %273, %cst_55, %cst_54 : tensor<?x?x384xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<384xf32>, tensor<384xf32>) outs(%265 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%564 = arith.subf %arg1, %arg2 : f32
// The f64 constant is narrowed to f32 before being added to the variance.
%565 = arith.truncf %cst_4 : f64 to f32
%566 = arith.addf %arg3, %565 : f32
%567 = math.rsqrt %566 : f32
%568 = arith.mulf %564, %567 : f32
%569 = arith.mulf %568, %arg4 : f32
%570 = arith.addf %569, %arg5 : f32
linalg.yield %570 : f32
} -> tensor<?x?x384xf32>
// Dense layer 384 -> 1536 on %274 (weights %cst_57, bias %cst_56), followed
// by an erf-based element-wise activation.
%275 = linalg.init_tensor [%261, %262, 1536] : tensor<?x?x1536xf32>
%276 = linalg.init_tensor [%261, 384, 1536] : tensor<?x384x1536xf32>
// Bias broadcast: out[d0, d1, d2] = bias[d2]; used as the matmul accumulator.
%277 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_56 : tensor<1536xf32>) outs(%275 : tensor<?x?x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x1536xf32>
// Weight transpose + batch broadcast: out[d0, d1, d2] = w[d2, d1].
%278 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_57 : tensor<1536x384xf32>) outs(%276 : tensor<?x384x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x1536xf32>
%279 = linalg.batch_matmul ins(%274, %278 : tensor<?x?x384xf32>, tensor<?x384x1536xf32>) outs(%277 : tensor<?x?x1536xf32>) -> tensor<?x?x1536xf32>
// Activation: y = x * ((erf(x / sqrt(%cst_2)) + %cst_3) * %cst_1).
// NOTE(review): this matches exact GELU if %cst_2 = 2.0, %cst_3 = 1.0 and
// %cst_1 = 0.5; the constants are defined outside this excerpt - confirm.
%280 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%279 : tensor<?x?x1536xf32>) outs(%275 : tensor<?x?x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
// sqrt of a constant, recomputed for every element at this stage of lowering.
%564 = math.sqrt %cst_2 : f32
%565 = arith.divf %arg1, %564 : f32
%566 = math.erf %565 : f32
%567 = arith.addf %566, %cst_3 : f32
%568 = arith.mulf %567, %cst_1 : f32
%569 = arith.mulf %arg1, %568 : f32
linalg.yield %569 : f32
} -> tensor<?x?x1536xf32>
// Dense layer 1536 -> 384 on %280 (weights %cst_59, bias %cst_58), followed
// by an element-wise add with %274 (the input of the preceding 384 -> 1536
// layer).
%281 = linalg.init_tensor [%261, 1536, 384] : tensor<?x1536x384xf32>
// Bias broadcast: out[d0, d1, d2] = bias[d2]; used as the matmul accumulator.
%282 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_58 : tensor<384xf32>) outs(%265 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
// Weight transpose + batch broadcast: out[d0, d1, d2] = w[d2, d1].
%283 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_59 : tensor<384x1536xf32>) outs(%281 : tensor<?x1536x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x1536x384xf32>
%284 = linalg.batch_matmul ins(%280, %283 : tensor<?x?x1536xf32>, tensor<?x1536x384xf32>) outs(%282 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
// Skip connection: %285 = %284 + %274.
%285 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%284, %274 : tensor<?x?x384xf32>, tensor<?x?x384xf32>) outs(%265 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.addf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<?x?x384xf32>
%286 = linalg.fill(%cst, %267) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%287 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%285 : tensor<?x?x384xf32>) outs(%286 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.addf %arg2, %arg1 : f32
linalg.yield %564 : f32
} -> tensor<?x?xf32>
%288 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%287 : tensor<?x?xf32>) outs(%267 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.divf %arg1, %cst_115 : f32
linalg.yield %564 : f32
} -> tensor<?x?xf32>
%289 = linalg.fill(%cst, %267) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%290 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%285, %288 : tensor<?x?x384xf32>, tensor<?x?xf32>) outs(%289 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.subf %arg1, %arg2 : f32
%565 = arith.mulf %564, %564 : f32
%566 = arith.addf %arg3, %565 : f32
linalg.yield %566 : f32
} -> tensor<?x?xf32>
%291 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%290 : tensor<?x?xf32>) outs(%267 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.divf %arg1, %cst_115 : f32
linalg.yield %564 : f32
} -> tensor<?x?xf32>
%292 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%285, %288, %291, %cst_61, %cst_60 : tensor<?x?x384xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<384xf32>, tensor<384xf32>) outs(%265 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%564 = arith.subf %arg1, %arg2 : f32
%565 = arith.truncf %cst_4 : f64 to f32
%566 = arith.addf %arg3, %565 : f32
%567 = math.rsqrt %566 : f32
%568 = arith.mulf %564, %567 : f32
%569 = arith.mulf %568, %arg4 : f32
%570 = arith.addf %569, %arg5 : f32
linalg.yield %570 : f32
} -> tensor<?x?x384xf32>
%293 = linalg.init_tensor [%261, 384, 384] : tensor<?x384x384xf32>
%294 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_62 : tensor<384xf32>) outs(%265 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%295 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_63 : tensor<384x384xf32>) outs(%293 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%296 = linalg.batch_matmul ins(%292, %295 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%294 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%297 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_64 : tensor<384xf32>) outs(%265 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%298 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_65 : tensor<384x384xf32>) outs(%293 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%299 = linalg.batch_matmul ins(%292, %298 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%297 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%300 = tensor.expand_shape %299 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%301 = linalg.init_tensor [%261, 12, %262, 32] : tensor<?x12x?x32xf32>
%302 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%300 : tensor<?x?x12x32xf32>) outs(%301 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
%303 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_66 : tensor<384xf32>) outs(%265 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%304 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_67 : tensor<384x384xf32>) outs(%293 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%305 = linalg.batch_matmul ins(%292, %304 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%303 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%306 = tensor.expand_shape %305 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%307 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%306 : tensor<?x?x12x32xf32>) outs(%301 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
%308 = tensor.expand_shape %296 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%309 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%308 : tensor<?x?x12x32xf32>) outs(%301 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
%310 = linalg.init_tensor [%261, 12, 32, %262] : tensor<?x12x32x?xf32>
%311 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d2)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%302 : tensor<?x12x?x32xf32>) outs(%310 : tensor<?x12x32x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x32x?xf32>
%312 = linalg.init_tensor [%261, 12, %262, %262] : tensor<?x12x?x?xf32>
%313 = linalg.fill(%cst, %312) : f32, tensor<?x12x?x?xf32> -> tensor<?x12x?x?xf32>
%314 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>], iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel"]} ins(%309, %311 : tensor<?x12x?x32xf32>, tensor<?x12x32x?xf32>) outs(%313 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.mulf %arg1, %arg2 : f32
%565 = arith.addf %564, %arg3 : f32
linalg.yield %565 : f32
} -> tensor<?x12x?x?xf32>
%315 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> ()>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%314, %cst_94 : tensor<?x12x?x?xf32>, tensor<f64>) outs(%312 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f64, %arg3: f32):
%564 = arith.truncf %arg2 : f64 to f32
%565 = arith.divf %arg1, %564 : f32
linalg.yield %565 : f32
} -> tensor<?x12x?x?xf32>
%316 = arith.cmpi eq, %261, %c1 : index
cf.assert %316, "mismatched size for broadcast"
%317 = arith.cmpi eq, %262, %19 : index
cf.assert %317, "mismatched size for broadcast"
%318 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, 0, 0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%315, %23 : tensor<?x12x?x?xf32>, tensor<1x1x1x?xf32>) outs(%312 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.addf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<?x12x?x?xf32>
%319 = linalg.init_tensor [%261, 12, %262, 1] : tensor<?x12x?x1xi64>
%320 = linalg.fill(%c0_i64, %319) : i64, tensor<?x12x?x1xi64> -> tensor<?x12x?x1xi64>
%321 = linalg.init_tensor [%261, 12, %262, 1] : tensor<?x12x?x1xf32>
%322 = linalg.fill(%cst_0, %321) : f32, tensor<?x12x?x1xf32> -> tensor<?x12x?x1xf32>
%323:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%318 : tensor<?x12x?x?xf32>) outs(%322, %320 : tensor<?x12x?x1xf32>, tensor<?x12x?x1xi64>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: i64):
%564 = linalg.index 3 : index
%565 = arith.index_cast %564 : index to i64
%566 = arith.cmpf ogt, %arg1, %arg2 : f32
%567 = arith.select %566, %arg1, %arg2 : f32
%568 = arith.select %566, %565, %arg3 : i64
linalg.yield %567, %568 : f32, i64
} -> (tensor<?x12x?x1xf32>, tensor<?x12x?x1xi64>)
%324 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%318, %323#0 : tensor<?x12x?x?xf32>, tensor<?x12x?x1xf32>) outs(%312 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.subf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<?x12x?x?xf32>
%325 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%324 : tensor<?x12x?x?xf32>) outs(%312 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = math.exp %arg1 : f32
linalg.yield %564 : f32
} -> tensor<?x12x?x?xf32>
%326 = linalg.fill(%cst, %321) : f32, tensor<?x12x?x1xf32> -> tensor<?x12x?x1xf32>
%327 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%325 : tensor<?x12x?x?xf32>) outs(%326 : tensor<?x12x?x1xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.addf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<?x12x?x1xf32>
%328 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%325, %327 : tensor<?x12x?x?xf32>, tensor<?x12x?x1xf32>) outs(%312 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.divf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<?x12x?x?xf32>
%329 = linalg.fill(%cst, %301) : f32, tensor<?x12x?x32xf32> -> tensor<?x12x?x32xf32>
%330 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>], iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel"]} ins(%328, %307 : tensor<?x12x?x?xf32>, tensor<?x12x?x32xf32>) outs(%329 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.mulf %arg1, %arg2 : f32
%565 = arith.addf %564, %arg3 : f32
linalg.yield %565 : f32
} -> tensor<?x12x?x32xf32>
%331 = linalg.init_tensor [%261, %262, 12, 32] : tensor<?x?x12x32xf32>
%332 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%330 : tensor<?x12x?x32xf32>) outs(%331 : tensor<?x?x12x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x12x32xf32>
%333 = tensor.cast %332 : tensor<?x?x12x32xf32> to tensor<?x?x?x?xf32>
%334 = tensor.collapse_shape %333 [[0], [1], [2, 3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
%335 = tensor.cast %334 : tensor<?x?x?xf32> to tensor<?x?x384xf32>
%336 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_68 : tensor<384xf32>) outs(%265 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%337 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_69 : tensor<384x384xf32>) outs(%293 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%338 = linalg.batch_matmul ins(%335, %337 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%336 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%339 = tensor.dim %334, %c0 : tensor<?x?x?xf32>
%340 = tensor.dim %334, %c1 : tensor<?x?x?xf32>
%341 = arith.cmpi eq, %339, %261 : index
cf.assert %341, "mismatched size for broadcast"
%342 = arith.cmpi eq, %340, %262 : index
cf.assert %342, "mismatched size for broadcast"
%343 = linalg.init_tensor [%339, %340, 384] : tensor<?x?x384xf32>
%344 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%338, %292 : tensor<?x?x384xf32>, tensor<?x?x384xf32>) outs(%343 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.addf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<?x?x384xf32>
%345 = linalg.init_tensor [%339, %340] : tensor<?x?xf32>
%346 = linalg.fill(%cst, %345) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%347 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%344 : tensor<?x?x384xf32>) outs(%346 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.addf %arg2, %arg1 : f32
linalg.yield %564 : f32
} -> tensor<?x?xf32>
%348 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%347 : tensor<?x?xf32>) outs(%345 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.divf %arg1, %cst_115 : f32
linalg.yield %564 : f32
} -> tensor<?x?xf32>
%349 = linalg.fill(%cst, %345) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%350 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%344, %348 : tensor<?x?x384xf32>, tensor<?x?xf32>) outs(%349 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.subf %arg1, %arg2 : f32
%565 = arith.mulf %564, %564 : f32
%566 = arith.addf %arg3, %565 : f32
linalg.yield %566 : f32
} -> tensor<?x?xf32>
%351 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%350 : tensor<?x?xf32>) outs(%345 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.divf %arg1, %cst_115 : f32
linalg.yield %564 : f32
} -> tensor<?x?xf32>
%352 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%344, %348, %351, %cst_71, %cst_70 : tensor<?x?x384xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<384xf32>, tensor<384xf32>) outs(%343 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%564 = arith.subf %arg1, %arg2 : f32
%565 = arith.truncf %cst_4 : f64 to f32
%566 = arith.addf %arg3, %565 : f32
%567 = math.rsqrt %566 : f32
%568 = arith.mulf %564, %567 : f32
%569 = arith.mulf %568, %arg4 : f32
%570 = arith.addf %569, %arg5 : f32
linalg.yield %570 : f32
} -> tensor<?x?x384xf32>
%353 = linalg.init_tensor [%339, %340, 1536] : tensor<?x?x1536xf32>
%354 = linalg.init_tensor [%339, 384, 1536] : tensor<?x384x1536xf32>
%355 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_72 : tensor<1536xf32>) outs(%353 : tensor<?x?x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x1536xf32>
%356 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_73 : tensor<1536x384xf32>) outs(%354 : tensor<?x384x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x1536xf32>
%357 = linalg.batch_matmul ins(%352, %356 : tensor<?x?x384xf32>, tensor<?x384x1536xf32>) outs(%355 : tensor<?x?x1536xf32>) -> tensor<?x?x1536xf32>
%358 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%357 : tensor<?x?x1536xf32>) outs(%353 : tensor<?x?x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = math.sqrt %cst_2 : f32
%565 = arith.divf %arg1, %564 : f32
%566 = math.erf %565 : f32
%567 = arith.addf %566, %cst_3 : f32
%568 = arith.mulf %567, %cst_1 : f32
%569 = arith.mulf %arg1, %568 : f32
linalg.yield %569 : f32
} -> tensor<?x?x1536xf32>
%359 = linalg.init_tensor [%339, 1536, 384] : tensor<?x1536x384xf32>
%360 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_74 : tensor<384xf32>) outs(%343 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%361 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_75 : tensor<384x1536xf32>) outs(%359 : tensor<?x1536x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x1536x384xf32>
%362 = linalg.batch_matmul ins(%358, %361 : tensor<?x?x1536xf32>, tensor<?x1536x384xf32>) outs(%360 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%363 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%362, %352 : tensor<?x?x384xf32>, tensor<?x?x384xf32>) outs(%343 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.addf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<?x?x384xf32>
%364 = linalg.fill(%cst, %345) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%365 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%363 : tensor<?x?x384xf32>) outs(%364 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.addf %arg2, %arg1 : f32
linalg.yield %564 : f32
} -> tensor<?x?xf32>
%366 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%365 : tensor<?x?xf32>) outs(%345 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.divf %arg1, %cst_115 : f32
linalg.yield %564 : f32
} -> tensor<?x?xf32>
%367 = linalg.fill(%cst, %345) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%368 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%363, %366 : tensor<?x?x384xf32>, tensor<?x?xf32>) outs(%367 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.subf %arg1, %arg2 : f32
%565 = arith.mulf %564, %564 : f32
%566 = arith.addf %arg3, %565 : f32
linalg.yield %566 : f32
} -> tensor<?x?xf32>
%369 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%368 : tensor<?x?xf32>) outs(%345 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.divf %arg1, %cst_115 : f32
linalg.yield %564 : f32
} -> tensor<?x?xf32>
%370 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%363, %366, %369, %cst_77, %cst_76 : tensor<?x?x384xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<384xf32>, tensor<384xf32>) outs(%343 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%564 = arith.subf %arg1, %arg2 : f32
%565 = arith.truncf %cst_4 : f64 to f32
%566 = arith.addf %arg3, %565 : f32
%567 = math.rsqrt %566 : f32
%568 = arith.mulf %564, %567 : f32
%569 = arith.mulf %568, %arg4 : f32
%570 = arith.addf %569, %arg5 : f32
linalg.yield %570 : f32
} -> tensor<?x?x384xf32>
%371 = linalg.init_tensor [%339, 384, 384] : tensor<?x384x384xf32>
%372 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_78 : tensor<384xf32>) outs(%343 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%373 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_79 : tensor<384x384xf32>) outs(%371 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%374 = linalg.batch_matmul ins(%370, %373 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%372 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%375 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_80 : tensor<384xf32>) outs(%343 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%376 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_81 : tensor<384x384xf32>) outs(%371 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%377 = linalg.batch_matmul ins(%370, %376 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%375 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%378 = tensor.expand_shape %377 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%379 = linalg.init_tensor [%339, 12, %340, 32] : tensor<?x12x?x32xf32>
%380 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%378 : tensor<?x?x12x32xf32>) outs(%379 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
%381 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_82 : tensor<384xf32>) outs(%343 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%382 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_83 : tensor<384x384xf32>) outs(%371 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%383 = linalg.batch_matmul ins(%370, %382 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%381 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%384 = tensor.expand_shape %383 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%385 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%384 : tensor<?x?x12x32xf32>) outs(%379 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
%386 = tensor.expand_shape %374 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%387 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%386 : tensor<?x?x12x32xf32>) outs(%379 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
%388 = linalg.init_tensor [%339, 12, 32, %340] : tensor<?x12x32x?xf32>
%389 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d2)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%380 : tensor<?x12x?x32xf32>) outs(%388 : tensor<?x12x32x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x32x?xf32>
%390 = linalg.init_tensor [%339, 12, %340, %340] : tensor<?x12x?x?xf32>
%391 = linalg.fill(%cst, %390) : f32, tensor<?x12x?x?xf32> -> tensor<?x12x?x?xf32>
%392 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>], iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel"]} ins(%387, %389 : tensor<?x12x?x32xf32>, tensor<?x12x32x?xf32>) outs(%391 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.mulf %arg1, %arg2 : f32
%565 = arith.addf %564, %arg3 : f32
linalg.yield %565 : f32
} -> tensor<?x12x?x?xf32>
%393 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> ()>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%392, %cst_94 : tensor<?x12x?x?xf32>, tensor<f64>) outs(%390 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f64, %arg3: f32):
%564 = arith.truncf %arg2 : f64 to f32
%565 = arith.divf %arg1, %564 : f32
linalg.yield %565 : f32
} -> tensor<?x12x?x?xf32>
%394 = arith.cmpi eq, %339, %c1 : index
cf.assert %394, "mismatched size for broadcast"
%395 = arith.cmpi eq, %340, %19 : index
cf.assert %395, "mismatched size for broadcast"
%396 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, 0, 0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%393, %23 : tensor<?x12x?x?xf32>, tensor<1x1x1x?xf32>) outs(%390 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.addf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<?x12x?x?xf32>
%397 = linalg.init_tensor [%339, 12, %340, 1] : tensor<?x12x?x1xi64>
%398 = linalg.fill(%c0_i64, %397) : i64, tensor<?x12x?x1xi64> -> tensor<?x12x?x1xi64>
%399 = linalg.init_tensor [%339, 12, %340, 1] : tensor<?x12x?x1xf32>
%400 = linalg.fill(%cst_0, %399) : f32, tensor<?x12x?x1xf32> -> tensor<?x12x?x1xf32>
%401:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%396 : tensor<?x12x?x?xf32>) outs(%400, %398 : tensor<?x12x?x1xf32>, tensor<?x12x?x1xi64>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: i64):
%564 = linalg.index 3 : index
%565 = arith.index_cast %564 : index to i64
%566 = arith.cmpf ogt, %arg1, %arg2 : f32
%567 = arith.select %566, %arg1, %arg2 : f32
%568 = arith.select %566, %565, %arg3 : i64
linalg.yield %567, %568 : f32, i64
} -> (tensor<?x12x?x1xf32>, tensor<?x12x?x1xi64>)
%402 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%396, %401#0 : tensor<?x12x?x?xf32>, tensor<?x12x?x1xf32>) outs(%390 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.subf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<?x12x?x?xf32>
%403 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%402 : tensor<?x12x?x?xf32>) outs(%390 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = math.exp %arg1 : f32
linalg.yield %564 : f32
} -> tensor<?x12x?x?xf32>
%404 = linalg.fill(%cst, %399) : f32, tensor<?x12x?x1xf32> -> tensor<?x12x?x1xf32>
%405 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%403 : tensor<?x12x?x?xf32>) outs(%404 : tensor<?x12x?x1xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.addf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<?x12x?x1xf32>
%406 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%403, %405 : tensor<?x12x?x?xf32>, tensor<?x12x?x1xf32>) outs(%390 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.divf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<?x12x?x?xf32>
%407 = linalg.fill(%cst, %379) : f32, tensor<?x12x?x32xf32> -> tensor<?x12x?x32xf32>
%408 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>], iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel"]} ins(%406, %385 : tensor<?x12x?x?xf32>, tensor<?x12x?x32xf32>) outs(%407 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.mulf %arg1, %arg2 : f32
%565 = arith.addf %564, %arg3 : f32
linalg.yield %565 : f32
} -> tensor<?x12x?x32xf32>
%409 = linalg.init_tensor [%339, %340, 12, 32] : tensor<?x?x12x32xf32>
%410 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%408 : tensor<?x12x?x32xf32>) outs(%409 : tensor<?x?x12x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x12x32xf32>
%411 = tensor.cast %410 : tensor<?x?x12x32xf32> to tensor<?x?x?x?xf32>
%412 = tensor.collapse_shape %411 [[0], [1], [2, 3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
%413 = tensor.cast %412 : tensor<?x?x?xf32> to tensor<?x?x384xf32>
%414 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_84 : tensor<384xf32>) outs(%343 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%415 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_85 : tensor<384x384xf32>) outs(%371 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%416 = linalg.batch_matmul ins(%413, %415 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%414 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%417 = tensor.dim %412, %c0 : tensor<?x?x?xf32>
%418 = tensor.dim %412, %c1 : tensor<?x?x?xf32>
%419 = arith.cmpi eq, %417, %339 : index
cf.assert %419, "mismatched size for broadcast"
%420 = arith.cmpi eq, %418, %340 : index
cf.assert %420, "mismatched size for broadcast"
%421 = linalg.init_tensor [%417, %418, 384] : tensor<?x?x384xf32>
%422 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%416, %370 : tensor<?x?x384xf32>, tensor<?x?x384xf32>) outs(%421 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.addf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<?x?x384xf32>
%423 = linalg.init_tensor [%417, %418] : tensor<?x?xf32>
%424 = linalg.fill(%cst, %423) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%425 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%422 : tensor<?x?x384xf32>) outs(%424 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.addf %arg2, %arg1 : f32
linalg.yield %564 : f32
} -> tensor<?x?xf32>
%426 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%425 : tensor<?x?xf32>) outs(%423 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.divf %arg1, %cst_115 : f32
linalg.yield %564 : f32
} -> tensor<?x?xf32>
%427 = linalg.fill(%cst, %423) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%428 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%422, %426 : tensor<?x?x384xf32>, tensor<?x?xf32>) outs(%427 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.subf %arg1, %arg2 : f32
%565 = arith.mulf %564, %564 : f32
%566 = arith.addf %arg3, %565 : f32
linalg.yield %566 : f32
} -> tensor<?x?xf32>
%429 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%428 : tensor<?x?xf32>) outs(%423 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.divf %arg1, %cst_115 : f32
linalg.yield %564 : f32
} -> tensor<?x?xf32>
%430 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%422, %426, %429, %cst_87, %cst_86 : tensor<?x?x384xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<384xf32>, tensor<384xf32>) outs(%421 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%564 = arith.subf %arg1, %arg2 : f32
%565 = arith.truncf %cst_4 : f64 to f32
%566 = arith.addf %arg3, %565 : f32
%567 = math.rsqrt %566 : f32
%568 = arith.mulf %564, %567 : f32
%569 = arith.mulf %568, %arg4 : f32
%570 = arith.addf %569, %arg5 : f32
linalg.yield %570 : f32
} -> tensor<?x?x384xf32>
%431 = linalg.init_tensor [%417, %418, 1536] : tensor<?x?x1536xf32>
%432 = linalg.init_tensor [%417, 384, 1536] : tensor<?x384x1536xf32>
%433 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_88 : tensor<1536xf32>) outs(%431 : tensor<?x?x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x1536xf32>
%434 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_89 : tensor<1536x384xf32>) outs(%432 : tensor<?x384x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x1536xf32>
%435 = linalg.batch_matmul ins(%430, %434 : tensor<?x?x384xf32>, tensor<?x384x1536xf32>) outs(%433 : tensor<?x?x1536xf32>) -> tensor<?x?x1536xf32>
%436 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%435 : tensor<?x?x1536xf32>) outs(%431 : tensor<?x?x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = math.sqrt %cst_2 : f32
%565 = arith.divf %arg1, %564 : f32
%566 = math.erf %565 : f32
%567 = arith.addf %566, %cst_3 : f32
%568 = arith.mulf %567, %cst_1 : f32
%569 = arith.mulf %arg1, %568 : f32
linalg.yield %569 : f32
} -> tensor<?x?x1536xf32>
%437 = linalg.init_tensor [%417, 1536, 384] : tensor<?x1536x384xf32>
%438 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_90 : tensor<384xf32>) outs(%421 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%439 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_91 : tensor<384x1536xf32>) outs(%437 : tensor<?x1536x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x1536x384xf32>
%440 = linalg.batch_matmul ins(%436, %439 : tensor<?x?x1536xf32>, tensor<?x1536x384xf32>) outs(%438 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%441 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%440, %430 : tensor<?x?x384xf32>, tensor<?x?x384xf32>) outs(%421 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.addf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<?x?x384xf32>
%442 = linalg.fill(%cst, %423) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%443 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%441 : tensor<?x?x384xf32>) outs(%442 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.addf %arg2, %arg1 : f32
linalg.yield %564 : f32
} -> tensor<?x?xf32>
%444 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%443 : tensor<?x?xf32>) outs(%423 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.divf %arg1, %cst_115 : f32
linalg.yield %564 : f32
} -> tensor<?x?xf32>
%445 = linalg.fill(%cst, %423) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%446 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%441, %444 : tensor<?x?x384xf32>, tensor<?x?xf32>) outs(%445 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.subf %arg1, %arg2 : f32
%565 = arith.mulf %564, %564 : f32
%566 = arith.addf %arg3, %565 : f32
linalg.yield %566 : f32
} -> tensor<?x?xf32>
%447 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%446 : tensor<?x?xf32>) outs(%423 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.divf %arg1, %cst_115 : f32
linalg.yield %564 : f32
} -> tensor<?x?xf32>
%448 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%441, %444, %447, %cst_93, %cst_92 : tensor<?x?x384xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<384xf32>, tensor<384xf32>) outs(%421 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%564 = arith.subf %arg1, %arg2 : f32
%565 = arith.truncf %cst_4 : f64 to f32
%566 = arith.addf %arg3, %565 : f32
%567 = math.rsqrt %566 : f32
%568 = arith.mulf %564, %567 : f32
%569 = arith.mulf %568, %arg4 : f32
%570 = arith.addf %569, %arg5 : f32
linalg.yield %570 : f32
} -> tensor<?x?x384xf32>
%449 = linalg.init_tensor [%417, 384, 384] : tensor<?x384x384xf32>
%450 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_95 : tensor<384xf32>) outs(%421 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%451 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_96 : tensor<384x384xf32>) outs(%449 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%452 = linalg.batch_matmul ins(%448, %451 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%450 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%453 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_97 : tensor<384xf32>) outs(%421 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%454 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_98 : tensor<384x384xf32>) outs(%449 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%455 = linalg.batch_matmul ins(%448, %454 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%453 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%456 = tensor.expand_shape %455 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%457 = linalg.init_tensor [%417, 12, %418, 32] : tensor<?x12x?x32xf32>
%458 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%456 : tensor<?x?x12x32xf32>) outs(%457 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
%459 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_99 : tensor<384xf32>) outs(%421 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%460 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_100 : tensor<384x384xf32>) outs(%449 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%461 = linalg.batch_matmul ins(%448, %460 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%459 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%462 = tensor.expand_shape %461 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%463 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%462 : tensor<?x?x12x32xf32>) outs(%457 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
%464 = tensor.expand_shape %452 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%465 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%464 : tensor<?x?x12x32xf32>) outs(%457 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
%466 = linalg.init_tensor [%417, 12, 32, %418] : tensor<?x12x32x?xf32>
%467 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d2)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%458 : tensor<?x12x?x32xf32>) outs(%466 : tensor<?x12x32x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x32x?xf32>
%468 = linalg.init_tensor [%417, 12, %418, %418] : tensor<?x12x?x?xf32>
%469 = linalg.fill(%cst, %468) : f32, tensor<?x12x?x?xf32> -> tensor<?x12x?x?xf32>
%470 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>], iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel"]} ins(%465, %467 : tensor<?x12x?x32xf32>, tensor<?x12x32x?xf32>) outs(%469 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.mulf %arg1, %arg2 : f32
%565 = arith.addf %564, %arg3 : f32
linalg.yield %565 : f32
} -> tensor<?x12x?x?xf32>
%471 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> ()>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%470, %cst_94 : tensor<?x12x?x?xf32>, tensor<f64>) outs(%468 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f64, %arg3: f32):
%564 = arith.truncf %arg2 : f64 to f32
%565 = arith.divf %arg1, %564 : f32
linalg.yield %565 : f32
} -> tensor<?x12x?x?xf32>
%472 = arith.cmpi eq, %417, %c1 : index
cf.assert %472, "mismatched size for broadcast"
%473 = arith.cmpi eq, %418, %19 : index
cf.assert %473, "mismatched size for broadcast"
%474 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, 0, 0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%471, %23 : tensor<?x12x?x?xf32>, tensor<1x1x1x?xf32>) outs(%468 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.addf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<?x12x?x?xf32>
%475 = linalg.init_tensor [%417, 12, %418, 1] : tensor<?x12x?x1xi64>
%476 = linalg.fill(%c0_i64, %475) : i64, tensor<?x12x?x1xi64> -> tensor<?x12x?x1xi64>
%477 = linalg.init_tensor [%417, 12, %418, 1] : tensor<?x12x?x1xf32>
%478 = linalg.fill(%cst_0, %477) : f32, tensor<?x12x?x1xf32> -> tensor<?x12x?x1xf32>
%479:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%474 : tensor<?x12x?x?xf32>) outs(%478, %476 : tensor<?x12x?x1xf32>, tensor<?x12x?x1xi64>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: i64):
%564 = linalg.index 3 : index
%565 = arith.index_cast %564 : index to i64
%566 = arith.cmpf ogt, %arg1, %arg2 : f32
%567 = arith.select %566, %arg1, %arg2 : f32
%568 = arith.select %566, %565, %arg3 : i64
linalg.yield %567, %568 : f32, i64
} -> (tensor<?x12x?x1xf32>, tensor<?x12x?x1xi64>)
%480 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%474, %479#0 : tensor<?x12x?x?xf32>, tensor<?x12x?x1xf32>) outs(%468 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.subf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<?x12x?x?xf32>
%481 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%480 : tensor<?x12x?x?xf32>) outs(%468 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = math.exp %arg1 : f32
linalg.yield %564 : f32
} -> tensor<?x12x?x?xf32>
%482 = linalg.fill(%cst, %477) : f32, tensor<?x12x?x1xf32> -> tensor<?x12x?x1xf32>
%483 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%481 : tensor<?x12x?x?xf32>) outs(%482 : tensor<?x12x?x1xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.addf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<?x12x?x1xf32>
%484 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%481, %483 : tensor<?x12x?x?xf32>, tensor<?x12x?x1xf32>) outs(%468 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.divf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<?x12x?x?xf32>
%485 = linalg.fill(%cst, %457) : f32, tensor<?x12x?x32xf32> -> tensor<?x12x?x32xf32>
%486 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>], iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel"]} ins(%484, %463 : tensor<?x12x?x?xf32>, tensor<?x12x?x32xf32>) outs(%485 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.mulf %arg1, %arg2 : f32
%565 = arith.addf %564, %arg3 : f32
linalg.yield %565 : f32
} -> tensor<?x12x?x32xf32>
%487 = linalg.init_tensor [%417, %418, 12, 32] : tensor<?x?x12x32xf32>
%488 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%486 : tensor<?x12x?x32xf32>) outs(%487 : tensor<?x?x12x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x12x32xf32>
%489 = tensor.cast %488 : tensor<?x?x12x32xf32> to tensor<?x?x?x?xf32>
%490 = tensor.collapse_shape %489 [[0], [1], [2, 3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
%491 = tensor.cast %490 : tensor<?x?x?xf32> to tensor<?x?x384xf32>
%492 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_101 : tensor<384xf32>) outs(%421 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%493 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_102 : tensor<384x384xf32>) outs(%449 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%494 = linalg.batch_matmul ins(%491, %493 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%492 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%495 = tensor.dim %490, %c0 : tensor<?x?x?xf32>
%496 = tensor.dim %490, %c1 : tensor<?x?x?xf32>
%497 = arith.cmpi eq, %495, %417 : index
cf.assert %497, "mismatched size for broadcast"
%498 = arith.cmpi eq, %496, %418 : index
cf.assert %498, "mismatched size for broadcast"
%499 = linalg.init_tensor [%495, %496, 384] : tensor<?x?x384xf32>
%500 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%494, %448 : tensor<?x?x384xf32>, tensor<?x?x384xf32>) outs(%499 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.addf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<?x?x384xf32>
%501 = linalg.init_tensor [%495, %496] : tensor<?x?xf32>
%502 = linalg.fill(%cst, %501) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%503 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%500 : tensor<?x?x384xf32>) outs(%502 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.addf %arg2, %arg1 : f32
linalg.yield %564 : f32
} -> tensor<?x?xf32>
%504 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%503 : tensor<?x?xf32>) outs(%501 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.divf %arg1, %cst_115 : f32
linalg.yield %564 : f32
} -> tensor<?x?xf32>
%505 = linalg.fill(%cst, %501) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%506 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%500, %504 : tensor<?x?x384xf32>, tensor<?x?xf32>) outs(%505 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.subf %arg1, %arg2 : f32
%565 = arith.mulf %564, %564 : f32
%566 = arith.addf %arg3, %565 : f32
linalg.yield %566 : f32
} -> tensor<?x?xf32>
%507 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%506 : tensor<?x?xf32>) outs(%501 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.divf %arg1, %cst_115 : f32
linalg.yield %564 : f32
} -> tensor<?x?xf32>
%508 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%500, %504, %507, %cst_104, %cst_103 : tensor<?x?x384xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<384xf32>, tensor<384xf32>) outs(%499 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%564 = arith.subf %arg1, %arg2 : f32
%565 = arith.truncf %cst_4 : f64 to f32
%566 = arith.addf %arg3, %565 : f32
%567 = math.rsqrt %566 : f32
%568 = arith.mulf %564, %567 : f32
%569 = arith.mulf %568, %arg4 : f32
%570 = arith.addf %569, %arg5 : f32
linalg.yield %570 : f32
} -> tensor<?x?x384xf32>
%509 = linalg.init_tensor [%495, %496, 1536] : tensor<?x?x1536xf32>
%510 = linalg.init_tensor [%495, 384, 1536] : tensor<?x384x1536xf32>
%511 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_105 : tensor<1536xf32>) outs(%509 : tensor<?x?x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x1536xf32>
%512 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_106 : tensor<1536x384xf32>) outs(%510 : tensor<?x384x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x1536xf32>
%513 = linalg.batch_matmul ins(%508, %512 : tensor<?x?x384xf32>, tensor<?x384x1536xf32>) outs(%511 : tensor<?x?x1536xf32>) -> tensor<?x?x1536xf32>
%514 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%513 : tensor<?x?x1536xf32>) outs(%509 : tensor<?x?x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = math.sqrt %cst_2 : f32
%565 = arith.divf %arg1, %564 : f32
%566 = math.erf %565 : f32
%567 = arith.addf %566, %cst_3 : f32
%568 = arith.mulf %567, %cst_1 : f32
%569 = arith.mulf %arg1, %568 : f32
linalg.yield %569 : f32
} -> tensor<?x?x1536xf32>
%515 = linalg.init_tensor [%495, 1536, 384] : tensor<?x1536x384xf32>
%516 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_107 : tensor<384xf32>) outs(%499 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%517 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_108 : tensor<384x1536xf32>) outs(%515 : tensor<?x1536x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x1536x384xf32>
%518 = linalg.batch_matmul ins(%514, %517 : tensor<?x?x1536xf32>, tensor<?x1536x384xf32>) outs(%516 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%519 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%518, %508 : tensor<?x?x384xf32>, tensor<?x?x384xf32>) outs(%499 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.addf %arg1, %arg2 : f32
linalg.yield %564 : f32
} -> tensor<?x?x384xf32>
%520 = linalg.fill(%cst, %501) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%521 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%519 : tensor<?x?x384xf32>) outs(%520 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.addf %arg2, %arg1 : f32
linalg.yield %564 : f32
} -> tensor<?x?xf32>
%522 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%521 : tensor<?x?xf32>) outs(%501 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.divf %arg1, %cst_115 : f32
linalg.yield %564 : f32
} -> tensor<?x?xf32>
%523 = linalg.fill(%cst, %501) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%524 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%519, %522 : tensor<?x?x384xf32>, tensor<?x?xf32>) outs(%523 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%564 = arith.subf %arg1, %arg2 : f32
%565 = arith.mulf %564, %564 : f32
%566 = arith.addf %arg3, %565 : f32
linalg.yield %566 : f32
} -> tensor<?x?xf32>
%525 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%524 : tensor<?x?xf32>) outs(%501 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = arith.divf %arg1, %cst_115 : f32
linalg.yield %564 : f32
} -> tensor<?x?xf32>
%526 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%519, %522, %525, %cst_110, %cst_109 : tensor<?x?x384xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<384xf32>, tensor<384xf32>) outs(%499 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%564 = arith.subf %arg1, %arg2 : f32
%565 = arith.truncf %cst_4 : f64 to f32
%566 = arith.addf %arg3, %565 : f32
%567 = math.rsqrt %566 : f32
%568 = arith.mulf %564, %567 : f32
%569 = arith.mulf %568, %arg4 : f32
%570 = arith.addf %569, %arg5 : f32
linalg.yield %570 : f32
} -> tensor<?x?x384xf32>
%527 = arith.index_cast %495 : index to i64
%528 = arith.cmpi sgt, %c0_i64, %527 : i64
%529 = arith.select %528, %527, %c0_i64 : i64
%530 = arith.index_cast %529 : i64 to index
%531 = arith.cmpi sgt, %c9223372036854775807_i64, %527 : i64
%532 = arith.select %531, %527, %c9223372036854775807_i64 : i64
%533 = arith.index_cast %532 : i64 to index
%534 = arith.cmpi sge, %533, %530 : index
%535 = arith.select %534, %533, %530 : index
%536 = arith.subi %535, %530 : index
%537 = tensor.extract_slice %526[%530, 0, 0] [%536, %496, 384] [1, 1, 1] : tensor<?x?x384xf32> to tensor<?x?x384xf32>
%538 = arith.index_cast %496 : index to i64
%539 = arith.cmpi sgt, %c0_i64, %538 : i64
%540 = arith.select %539, %538, %c0_i64 : i64
%541 = arith.index_cast %540 : i64 to index
%542 = arith.cmpi sgt, %c1_i64, %538 : i64
%543 = arith.select %542, %538, %c1_i64 : i64
%544 = arith.index_cast %543 : i64 to index
%545 = arith.cmpi sge, %544, %541 : index
%546 = arith.select %545, %544, %541 : index
%547 = arith.subi %546, %541 : index
%548 = tensor.extract_slice %537[0, %541, 0] [%536, %547, 384] [1, 1, 1] : tensor<?x?x384xf32> to tensor<?x?x384xf32>
%549 = tensor.cast %548 : tensor<?x?x384xf32> to tensor<?x1x384xf32>
%550 = tensor.collapse_shape %549 [[0, 1], [2]] : tensor<?x1x384xf32> into tensor<?x384xf32>
%551 = linalg.init_tensor [%536, 384] : tensor<?x384xf32>
%552 = linalg.init_tensor [384, 384] : tensor<384x384xf32>
%553 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%cst_111 : tensor<384xf32>) outs(%551 : tensor<?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384xf32>
%554 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%cst_112 : tensor<384x384xf32>) outs(%552 : tensor<384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<384x384xf32>
%555 = linalg.matmul ins(%550, %554 : tensor<?x384xf32>, tensor<384x384xf32>) outs(%553 : tensor<?x384xf32>) -> tensor<?x384xf32>
%556 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%555 : tensor<?x384xf32>) outs(%551 : tensor<?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%564 = math.tanh %arg1 : f32
linalg.yield %564 : f32
} -> tensor<?x384xf32>
%557 = linalg.init_tensor [%536, 2] : tensor<?x2xf32>
%558 = linalg.init_tensor [384, 2] : tensor<384x2xf32>
%559 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%cst_113 : tensor<2xf32>) outs(%557 : tensor<?x2xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x2xf32>
%560 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%cst_114 : tensor<2x384xf32>) outs(%558 : tensor<384x2xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<384x2xf32>
%561 = linalg.matmul ins(%556, %560 : tensor<?x384xf32>, tensor<384x2xf32>) outs(%559 : tensor<?x2xf32>) -> tensor<?x2xf32>
%562 = tensor.dim %561, %c0 : tensor<?x2xf32>
%563 = hal.tensor.export %561 : tensor<?x2xf32>{%562} -> !hal.buffer_view
return %563 : !hal.buffer_view
}
// -----// IR Dump Before Canonicalizer //----- //
func @forward(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
%cst = arith.constant 3.840000e+02 : f32
%c512 = arith.constant 512 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c1_i64 = arith.constant 1 : i64
%c512_i64 = arith.constant 512 : i64
%cst_0 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<2x384xf32>
%cst_1 = arith.constant dense<[-0.00115577725, 0.00115577038]> : tensor<2xf32>
%cst_2 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_3 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_4 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_5 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_6 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x1536xf32>
%cst_7 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_8 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536x384xf32>
%cst_9 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536xf32>
%cst_10 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_11 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_12 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_13 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_14 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_15 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_16 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_17 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_18 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_19 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_20 = arith.constant dense<5.6568542494923806> : tensor<f64>
%cst_21 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_22 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_23 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x1536xf32>
%cst_24 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_25 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536x384xf32>
%cst_26 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536xf32>
%cst_27 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_28 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_29 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_30 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_31 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_32 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_33 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_34 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_35 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_36 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_37 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_38 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_39 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x1536xf32>
%cst_40 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_41 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536x384xf32>
%cst_42 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536xf32>
%cst_43 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_44 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_45 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_46 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_47 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_48 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_49 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_50 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_51 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_52 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_53 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_54 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_55 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x1536xf32>
%cst_56 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_57 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536x384xf32>
%cst_58 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536xf32>
%cst_59 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_60 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_61 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_62 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_63 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_64 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_65 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_66 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_67 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_68 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_69 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_70 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_71 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x1536xf32>
%cst_72 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_73 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536x384xf32>
%cst_74 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536xf32>
%cst_75 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_76 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_77 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_78 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_79 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_80 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_81 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_82 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_83 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_84 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_85 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_86 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_87 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x1536xf32>
%cst_88 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_89 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536x384xf32>
%cst_90 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536xf32>
%cst_91 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_92 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_93 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_94 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_95 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_96 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_97 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_98 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_99 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_100 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_101 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_102 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_103 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<512x384xf32>
%cst_104 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<2x384xf32>
%cst_105 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<30522x384xf32>
%cst_106 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1x512xi64>
%cst_107 = arith.constant dense<0> : tensor<i64>
%cst_108 = arith.constant dense<0> : tensor<1x512xi64>
%cst_109 = arith.constant dense<-1.000000e+04> : tensor<f64>
%c2_i64 = arith.constant 2 : i64
%cst_110 = arith.constant 9.9999999999999998E-13 : f64
%c9223372036854775807_i64 = arith.constant 9223372036854775807 : i64
%cst_111 = arith.constant 1.000000e+00 : f32
%cst_112 = arith.constant 2.000000e+00 : f32
%cst_113 = arith.constant 5.000000e-01 : f32
%cst_114 = arith.constant -3.40282347E+38 : f32
%cst_115 = arith.constant 0.000000e+00 : f32
%c0_i64 = arith.constant 0 : i64
%c30522_i64 = arith.constant 30522 : i64
%0 = hal.tensor.import %arg0 : !hal.buffer_view -> tensor<1x512xi64>
%1 = linalg.init_tensor [] : tensor<i64>
%2 = linalg.fill(%c512_i64, %1) : i64, tensor<i64> -> tensor<i64>
%3 = linalg.generic {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []} ins(%2, %cst_107 : tensor<i64>, tensor<i64>) outs(%1 : tensor<i64>) {
^bb0(%arg1: i64, %arg2: i64, %arg3: i64):
%560 = arith.addi %arg1, %arg2 : i64
linalg.yield %560 : i64
} -> tensor<i64>
%4 = tensor.extract %3[] : tensor<i64>
%5 = arith.index_cast %4 : i64 to index
%6 = linalg.init_tensor [1, %5] : tensor<1x?xf32>
%7 = linalg.fill(%cst_111, %6) : f32, tensor<1x?xf32> -> tensor<1x?xf32>
%8 = tensor.extract_slice %7[0, 0] [1, %5] [1, 1] : tensor<1x?xf32> to tensor<1x?xf32>
%9 = tensor.cast %8 : tensor<1x?xf32> to tensor<?x?xf32>
%10 = tensor.expand_shape %9 [[0], [1, 2, 3]] : tensor<?x?xf32> into tensor<?x1x1x?xf32>
%11 = arith.cmpi sgt, %c0_i64, %4 : i64
%12 = arith.select %11, %4, %c0_i64 : i64
%13 = arith.index_cast %12 : i64 to index
%14 = arith.cmpi sgt, %c9223372036854775807_i64, %4 : i64
%15 = arith.select %14, %4, %c9223372036854775807_i64 : i64
%16 = arith.index_cast %15 : i64 to index
%17 = arith.cmpi sge, %16, %13 : index
%18 = arith.select %17, %16, %13 : index
%19 = arith.subi %18, %13 : index
%20 = tensor.extract_slice %10[0, 0, 0, %13] [1, 1, 1, %19] [1, 1, 1, 1] : tensor<?x1x1x?xf32> to tensor<1x1x1x?xf32>
%21 = linalg.init_tensor [1, 1, 1, %19] : tensor<1x1x1x?xf32>
%22 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, 0, 0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%20 : tensor<1x1x1x?xf32>) outs(%21 : tensor<1x1x1x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%560 = arith.subf %cst_111, %arg1 : f32
linalg.yield %560 : f32
} -> tensor<1x1x1x?xf32>
%23 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, 0, 0, d3)>, affine_map<(d0, d1, d2, d3) -> ()>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%22, %cst_109 : tensor<1x1x1x?xf32>, tensor<f64>) outs(%21 : tensor<1x1x1x?xf32>) {
^bb0(%arg1: f32, %arg2: f64, %arg3: f32):
%560 = arith.truncf %arg2 : f64 to f32
%561 = arith.mulf %arg1, %560 : f32
linalg.yield %561 : f32
} -> tensor<1x1x1x?xf32>
%24 = linalg.fill(%c512_i64, %1) : i64, tensor<i64> -> tensor<i64>
%25 = linalg.generic {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []} ins(%24, %cst_107 : tensor<i64>, tensor<i64>) outs(%1 : tensor<i64>) {
^bb0(%arg1: i64, %arg2: i64, %arg3: i64):
%560 = arith.addi %arg1, %arg2 : i64
linalg.yield %560 : i64
} -> tensor<i64>
%26 = tensor.extract %25[] : tensor<i64>
%27 = arith.addi %26, %c512_i64 : i64
%28 = arith.cmpi sge, %26, %c0_i64 : i64
%29 = arith.select %28, %26, %27 : i64
%30 = arith.cmpi slt, %29, %c0_i64 : i64
%31 = arith.select %30, %c0_i64, %29 : i64
%32 = arith.cmpi sgt, %31, %c512_i64 : i64
%33 = arith.select %32, %c512_i64, %31 : i64
%34 = arith.index_cast %33 : i64 to index
%35 = arith.cmpi sge, %34, %c0 : index
%36 = arith.select %35, %34, %c0 : index
%37 = tensor.extract_slice %cst_106[0, 0] [1, %36] [1, 1] : tensor<1x512xi64> to tensor<1x?xi64>
%38 = linalg.init_tensor [1, 512, 384] : tensor<1x512x384xf32>
%39 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%0 : tensor<1x512xi64>) outs(%38 : tensor<1x512x384xf32>) {
^bb0(%arg1: i64, %arg2: f32):
%560 = arith.index_cast %arg1 : i64 to index
%561 = linalg.index 2 : index
%562 = arith.cmpi slt, %arg1, %c30522_i64 : i64
cf.assert %562, "index must be smaller than dim size"
%563 = arith.cmpi sge, %arg1, %c0_i64 : i64
cf.assert %563, "index must be larger or equal to 0"
%564 = tensor.extract %cst_105[%560, %561] : tensor<30522x384xf32>
linalg.yield %564 : f32
} -> tensor<1x512x384xf32>
%40 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_108 : tensor<1x512xi64>) outs(%38 : tensor<1x512x384xf32>) {
^bb0(%arg1: i64, %arg2: f32):
%560 = arith.index_cast %arg1 : i64 to index
%561 = linalg.index 2 : index
%562 = arith.cmpi slt, %arg1, %c2_i64 : i64
cf.assert %562, "index must be smaller than dim size"
%563 = arith.cmpi sge, %arg1, %c0_i64 : i64
cf.assert %563, "index must be larger or equal to 0"
%564 = tensor.extract %cst_104[%560, %561] : tensor<2x384xf32>
linalg.yield %564 : f32
} -> tensor<1x512x384xf32>
%41 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%39, %40 : tensor<1x512x384xf32>, tensor<1x512x384xf32>) outs(%38 : tensor<1x512x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%560 = arith.addf %arg1, %arg2 : f32
linalg.yield %560 : f32
} -> tensor<1x512x384xf32>
%42 = linalg.init_tensor [1, %36, 384] : tensor<1x?x384xf32>
%43 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%37 : tensor<1x?xi64>) outs(%42 : tensor<1x?x384xf32>) {
^bb0(%arg1: i64, %arg2: f32):
%560 = arith.index_cast %arg1 : i64 to index
%561 = linalg.index 2 : index
%562 = arith.cmpi slt, %arg1, %c512_i64 : i64
cf.assert %562, "index must be smaller than dim size"
%563 = arith.cmpi sge, %arg1, %c0_i64 : i64
cf.assert %563, "index must be larger or equal to 0"
%564 = tensor.extract %cst_103[%560, %561] : tensor<512x384xf32>
linalg.yield %564 : f32
} -> tensor<1x?x384xf32>
%44 = arith.cmpi eq, %c512, %36 : index
cf.assert %44, "mismatched size for broadcast"
%45 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%41, %43 : tensor<1x512x384xf32>, tensor<1x?x384xf32>) outs(%38 : tensor<1x512x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%560 = arith.addf %arg1, %arg2 : f32
linalg.yield %560 : f32
} -> tensor<1x512x384xf32>
%46 = linalg.init_tensor [1, 512] : tensor<1x512xf32>
%47 = linalg.fill(%cst_115, %46) : f32, tensor<1x512xf32> -> tensor<1x512xf32>
%48 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%45 : tensor<1x512x384xf32>) outs(%47 : tensor<1x512xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%560 = arith.addf %arg2, %arg1 : f32
linalg.yield %560 : f32
} -> tensor<1x512xf32>
%49 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%48 : tensor<1x512xf32>) outs(%46 : tensor<1x512xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%560 = arith.divf %arg1, %cst : f32
linalg.yield %560 : f32
} -> tensor<1x512xf32>
%50 = linalg.fill(%cst_115, %46) : f32, tensor<1x512xf32> -> tensor<1x512xf32>
%51 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%45, %49 : tensor<1x512x384xf32>, tensor<1x512xf32>) outs(%50 : tensor<1x512xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%560 = arith.subf %arg1, %arg2 : f32
%561 = arith.mulf %560, %560 : f32
%562 = arith.addf %arg3, %561 : f32
linalg.yield %562 : f32
} -> tensor<1x512xf32>
%52 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%51 : tensor<1x512xf32>) outs(%46 : tensor<1x512xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%560 = arith.divf %arg1, %cst : f32
linalg.yield %560 : f32
} -> tensor<1x512xf32>
%53 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%45, %49, %52, %cst_101, %cst_102 : tensor<1x512x384xf32>, tensor<1x512xf32>, tensor<1x512xf32>, tensor<384xf32>, tensor<384xf32>) outs(%38 : tensor<1x512x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%560 = arith.subf %arg1, %arg2 : f32
%561 = arith.truncf %cst_110 : f64 to f32
%562 = arith.addf %arg3, %561 : f32
%563 = math.rsqrt %562 : f32
%564 = arith.mulf %560, %563 : f32
%565 = arith.mulf %564, %arg4 : f32
%566 = arith.addf %565, %arg5 : f32
linalg.yield %566 : f32
} -> tensor<1x512x384xf32>
%54 = linalg.init_tensor [1, 384, 384] : tensor<1x384x384xf32>
%55 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_100 : tensor<384xf32>) outs(%38 : tensor<1x512x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<1x512x384xf32>
%56 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_99 : tensor<384x384xf32>) outs(%54 : tensor<1x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<1x384x384xf32>
%57 = linalg.batch_matmul ins(%53, %56 : tensor<1x512x384xf32>, tensor<1x384x384xf32>) outs(%55 : tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%58 = tensor.cast %57 : tensor<1x512x384xf32> to tensor<?x?x384xf32>
%59 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_98 : tensor<384xf32>) outs(%38 : tensor<1x512x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<1x512x384xf32>
%60 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_97 : tensor<384x384xf32>) outs(%54 : tensor<1x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<1x384x384xf32>
%61 = linalg.batch_matmul ins(%53, %60 : tensor<1x512x384xf32>, tensor<1x384x384xf32>) outs(%59 : tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%62 = tensor.cast %61 : tensor<1x512x384xf32> to tensor<?x?x384xf32>
%63 = tensor.expand_shape %62 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%64 = linalg.init_tensor [1, 12, 512, 32] : tensor<1x12x512x32xf32>
%65 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%63 : tensor<?x?x12x32xf32>) outs(%64 : tensor<1x12x512x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<1x12x512x32xf32>
%66 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_96 : tensor<384xf32>) outs(%38 : tensor<1x512x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<1x512x384xf32>
%67 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_95 : tensor<384x384xf32>) outs(%54 : tensor<1x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<1x384x384xf32>
%68 = linalg.batch_matmul ins(%53, %67 : tensor<1x512x384xf32>, tensor<1x384x384xf32>) outs(%66 : tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%69 = tensor.cast %68 : tensor<1x512x384xf32> to tensor<?x?x384xf32>
%70 = tensor.expand_shape %69 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%71 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%70 : tensor<?x?x12x32xf32>) outs(%64 : tensor<1x12x512x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<1x12x512x32xf32>
%72 = tensor.expand_shape %58 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%73 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%72 : tensor<?x?x12x32xf32>) outs(%64 : tensor<1x12x512x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<1x12x512x32xf32>
%74 = linalg.init_tensor [1, 12, 32, 512] : tensor<1x12x32x512xf32>
%75 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d2)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%65 : tensor<1x12x512x32xf32>) outs(%74 : tensor<1x12x32x512xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<1x12x32x512xf32>
%76 = linalg.init_tensor [1, 12, 512, 512] : tensor<1x12x512x512xf32>
%77 = linalg.fill(%cst_115, %76) : f32, tensor<1x12x512x512xf32> -> tensor<1x12x512x512xf32>
%78 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>], iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel"]} ins(%73, %75 : tensor<1x12x512x32xf32>, tensor<1x12x32x512xf32>) outs(%77 : tensor<1x12x512x512xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%560 = arith.mulf %arg1, %arg2 : f32
%561 = arith.addf %560, %arg3 : f32
linalg.yield %561 : f32
} -> tensor<1x12x512x512xf32>
%79 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> ()>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%78, %cst_20 : tensor<1x12x512x512xf32>, tensor<f64>) outs(%76 : tensor<1x12x512x512xf32>) {
^bb0(%arg1: f32, %arg2: f64, %arg3: f32):
%560 = arith.truncf %arg2 : f64 to f32
%561 = arith.divf %arg1, %560 : f32
linalg.yield %561 : f32
} -> tensor<1x12x512x512xf32>
%80 = arith.cmpi eq, %c512, %19 : index
cf.assert %80, "mismatched size for broadcast"
%81 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, 0, 0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%79, %23 : tensor<1x12x512x512xf32>, tensor<1x1x1x?xf32>) outs(%76 : tensor<1x12x512x512xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%560 = arith.addf %arg1, %arg2 : f32
linalg.yield %560 : f32
} -> tensor<1x12x512x512xf32>
%82 = linalg.init_tensor [1, 12, 512, 1] : tensor<1x12x512x1xi64>
%83 = linalg.fill(%c0_i64, %82) : i64, tensor<1x12x512x1xi64> -> tensor<1x12x512x1xi64>
%84 = linalg.init_tensor [1, 12, 512, 1] : tensor<1x12x512x1xf32>
%85 = linalg.fill(%cst_114, %84) : f32, tensor<1x12x512x1xf32> -> tensor<1x12x512x1xf32>
%86:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%81 : tensor<1x12x512x512xf32>) outs(%85, %83 : tensor<1x12x512x1xf32>, tensor<1x12x512x1xi64>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: i64):
%560 = linalg.index 3 : index
%561 = arith.index_cast %560 : index to i64
%562 = arith.cmpf ogt, %arg1, %arg2 : f32
%563 = arith.select %562, %arg1, %arg2 : f32
%564 = arith.select %562, %561, %arg3 : i64
linalg.yield %563, %564 : f32, i64
} -> (tensor<1x12x512x1xf32>, tensor<1x12x512x1xi64>)
%87 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%81, %86#0 : tensor<1x12x512x512xf32>, tensor<1x12x512x1xf32>) outs(%76 : tensor<1x12x512x512xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%560 = arith.subf %arg1, %arg2 : f32
linalg.yield %560 : f32
} -> tensor<1x12x512x512xf32>
%88 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%87 : tensor<1x12x512x512xf32>) outs(%76 : tensor<1x12x512x512xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%560 = math.exp %arg1 : f32
linalg.yield %560 : f32
} -> tensor<1x12x512x512xf32>
%89 = linalg.fill(%cst_115, %84) : f32, tensor<1x12x512x1xf32> -> tensor<1x12x512x1xf32>
%90 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%88 : tensor<1x12x512x512xf32>) outs(%89 : tensor<1x12x512x1xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%560 = arith.addf %arg1, %arg2 : f32
linalg.yield %560 : f32
} -> tensor<1x12x512x1xf32>
%91 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%88, %90 : tensor<1x12x512x512xf32>, tensor<1x12x512x1xf32>) outs(%76 : tensor<1x12x512x512xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%560 = arith.divf %arg1, %arg2 : f32
linalg.yield %560 : f32
} -> tensor<1x12x512x512xf32>
%92 = linalg.fill(%cst_115, %64) : f32, tensor<1x12x512x32xf32> -> tensor<1x12x512x32xf32>
%93 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>], iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel"]} ins(%91, %71 : tensor<1x12x512x512xf32>, tensor<1x12x512x32xf32>) outs(%92 : tensor<1x12x512x32xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%560 = arith.mulf %arg1, %arg2 : f32
%561 = arith.addf %560, %arg3 : f32
linalg.yield %561 : f32
} -> tensor<1x12x512x32xf32>
%94 = linalg.init_tensor [1, 512, 12, 32] : tensor<1x512x12x32xf32>
%95 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%93 : tensor<1x12x512x32xf32>) outs(%94 : tensor<1x512x12x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<1x512x12x32xf32>
%96 = tensor.cast %95 : tensor<1x512x12x32xf32> to tensor<?x?x?x?xf32>
%97 = tensor.collapse_shape %96 [[0], [1], [2, 3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
%98 = tensor.cast %97 : tensor<?x?x?xf32> to tensor<?x?x384xf32>
%99 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_94 : tensor<384xf32>) outs(%38 : tensor<1x512x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<1x512x384xf32>
%100 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_93 : tensor<384x384xf32>) outs(%54 : tensor<1x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<1x384x384xf32>
%101 = linalg.batch_matmul ins(%98, %100 : tensor<?x?x384xf32>, tensor<1x384x384xf32>) outs(%99 : tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%102 = tensor.dim %97, %c0 : tensor<?x?x?xf32>
%103 = tensor.dim %97, %c1 : tensor<?x?x?xf32>
%104 = arith.cmpi eq, %102, %c1 : index
cf.assert %104, "mismatched size for broadcast"
%105 = arith.cmpi eq, %103, %c512 : index
cf.assert %105, "mismatched size for broadcast"
%106 = linalg.init_tensor [%102, %103, 384] : tensor<?x?x384xf32>
%107 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%101, %53 : tensor<1x512x384xf32>, tensor<1x512x384xf32>) outs(%106 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%560 = arith.addf %arg1, %arg2 : f32
linalg.yield %560 : f32
} -> tensor<?x?x384xf32>
%108 = linalg.init_tensor [%102, %103] : tensor<?x?xf32>
%109 = linalg.fill(%cst_115, %108) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%110 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%107 : tensor<?x?x384xf32>) outs(%109 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%560 = arith.addf %arg2, %arg1 : f32
linalg.yield %560 : f32
} -> tensor<?x?xf32>
%111 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%110 : tensor<?x?xf32>) outs(%108 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%560 = arith.divf %arg1, %cst : f32
linalg.yield %560 : f32
} -> tensor<?x?xf32>
%112 = linalg.fill(%cst_115, %108) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%113 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%107, %111 : tensor<?x?x384xf32>, tensor<?x?xf32>) outs(%112 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%560 = arith.subf %arg1, %arg2 : f32
%561 = arith.mulf %560, %560 : f32
%562 = arith.addf %arg3, %561 : f32
linalg.yield %562 : f32
} -> tensor<?x?xf32>
%114 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%113 : tensor<?x?xf32>) outs(%108 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%560 = arith.divf %arg1, %cst : f32
linalg.yield %560 : f32
} -> tensor<?x?xf32>
%115 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%107, %111, %114, %cst_91, %cst_92 : tensor<?x?x384xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<384xf32>, tensor<384xf32>) outs(%106 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%560 = arith.subf %arg1, %arg2 : f32
%561 = arith.truncf %cst_110 : f64 to f32
%562 = arith.addf %arg3, %561 : f32
%563 = math.rsqrt %562 : f32
%564 = arith.mulf %560, %563 : f32
%565 = arith.mulf %564, %arg4 : f32
%566 = arith.addf %565, %arg5 : f32
linalg.yield %566 : f32
} -> tensor<?x?x384xf32>
%116 = linalg.init_tensor [%102, %103, 1536] : tensor<?x?x1536xf32>
%117 = linalg.init_tensor [%102, 384, 1536] : tensor<?x384x1536xf32>
%118 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_90 : tensor<1536xf32>) outs(%116 : tensor<?x?x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x1536xf32>
%119 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_89 : tensor<1536x384xf32>) outs(%117 : tensor<?x384x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x1536xf32>
%120 = linalg.batch_matmul ins(%115, %119 : tensor<?x?x384xf32>, tensor<?x384x1536xf32>) outs(%118 : tensor<?x?x1536xf32>) -> tensor<?x?x1536xf32>
%121 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%120 : tensor<?x?x1536xf32>) outs(%116 : tensor<?x?x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%560 = math.sqrt %cst_112 : f32
%561 = arith.divf %arg1, %560 : f32
%562 = math.erf %561 : f32
%563 = arith.addf %562, %cst_111 : f32
%564 = arith.mulf %563, %cst_113 : f32
%565 = arith.mulf %arg1, %564 : f32
linalg.yield %565 : f32
} -> tensor<?x?x1536xf32>
%122 = linalg.init_tensor [%102, 1536, 384] : tensor<?x1536x384xf32>
%123 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_88 : tensor<384xf32>) outs(%106 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%124 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_87 : tensor<384x1536xf32>) outs(%122 : tensor<?x1536x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x1536x384xf32>
%125 = linalg.batch_matmul ins(%121, %124 : tensor<?x?x1536xf32>, tensor<?x1536x384xf32>) outs(%123 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%126 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%125, %115 : tensor<?x?x384xf32>, tensor<?x?x384xf32>) outs(%106 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%560 = arith.addf %arg1, %arg2 : f32
linalg.yield %560 : f32
} -> tensor<?x?x384xf32>
%127 = linalg.fill(%cst_115, %108) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%128 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%126 : tensor<?x?x384xf32>) outs(%127 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%560 = arith.addf %arg2, %arg1 : f32
linalg.yield %560 : f32
} -> tensor<?x?xf32>
%129 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%128 : tensor<?x?xf32>) outs(%108 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%560 = arith.divf %arg1, %cst : f32
linalg.yield %560 : f32
} -> tensor<?x?xf32>
%130 = linalg.fill(%cst_115, %108) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%131 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%126, %129 : tensor<?x?x384xf32>, tensor<?x?xf32>) outs(%130 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%560 = arith.subf %arg1, %arg2 : f32
%561 = arith.mulf %560, %560 : f32
%562 = arith.addf %arg3, %561 : f32
linalg.yield %562 : f32
} -> tensor<?x?xf32>
%132 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%131 : tensor<?x?xf32>) outs(%108 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%560 = arith.divf %arg1, %cst : f32
linalg.yield %560 : f32
} -> tensor<?x?xf32>
%133 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%126, %129, %132, %cst_85, %cst_86 : tensor<?x?x384xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<384xf32>, tensor<384xf32>) outs(%106 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%560 = arith.subf %arg1, %arg2 : f32
%561 = arith.truncf %cst_110 : f64 to f32
%562 = arith.addf %arg3, %561 : f32
%563 = math.rsqrt %562 : f32
%564 = arith.mulf %560, %563 : f32
%565 = arith.mulf %564, %arg4 : f32
%566 = arith.addf %565, %arg5 : f32
linalg.yield %566 : f32
} -> tensor<?x?x384xf32>
%134 = linalg.init_tensor [%102, 384, 384] : tensor<?x384x384xf32>
%135 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_84 : tensor<384xf32>) outs(%106 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%136 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_83 : tensor<384x384xf32>) outs(%134 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%137 = linalg.batch_matmul ins(%133, %136 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%135 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%138 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_82 : tensor<384xf32>) outs(%106 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%139 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_81 : tensor<384x384xf32>) outs(%134 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%140 = linalg.batch_matmul ins(%133, %139 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%138 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%141 = tensor.expand_shape %140 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%142 = linalg.init_tensor [%102, 12, %103, 32] : tensor<?x12x?x32xf32>
%143 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%141 : tensor<?x?x12x32xf32>) outs(%142 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
%144 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_80 : tensor<384xf32>) outs(%106 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%145 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_79 : tensor<384x384xf32>) outs(%134 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%146 = linalg.batch_matmul ins(%133, %145 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%144 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%147 = tensor.expand_shape %146 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%148 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%147 : tensor<?x?x12x32xf32>) outs(%142 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
%149 = tensor.expand_shape %137 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%150 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%149 : tensor<?x?x12x32xf32>) outs(%142 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
%151 = linalg.init_tensor [%102, 12, 32, %103] : tensor<?x12x32x?xf32>
%152 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d2)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%143 : tensor<?x12x?x32xf32>) outs(%151 : tensor<?x12x32x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x32x?xf32>
%153 = linalg.init_tensor [%102, 12, %103, %103] : tensor<?x12x?x?xf32>
%154 = linalg.fill(%cst_115, %153) : f32, tensor<?x12x?x?xf32> -> tensor<?x12x?x?xf32>
%155 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>], iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel"]} ins(%150, %152 : tensor<?x12x?x32xf32>, tensor<?x12x32x?xf32>) outs(%154 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%560 = arith.mulf %arg1, %arg2 : f32
%561 = arith.addf %560, %arg3 : f32
linalg.yield %561 : f32
} -> tensor<?x12x?x?xf32>
%156 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> ()>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%155, %cst_20 : tensor<?x12x?x?xf32>, tensor<f64>) outs(%153 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f64, %arg3: f32):
%560 = arith.truncf %arg2 : f64 to f32
%561 = arith.divf %arg1, %560 : f32
linalg.yield %561 : f32
} -> tensor<?x12x?x?xf32>
cf.assert %104, "mismatched size for broadcast"
%157 = arith.cmpi eq, %103, %19 : index
cf.assert %157, "mismatched size for broadcast"
%158 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, 0, 0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%156, %23 : tensor<?x12x?x?xf32>, tensor<1x1x1x?xf32>) outs(%153 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%560 = arith.addf %arg1, %arg2 : f32
linalg.yield %560 : f32
} -> tensor<?x12x?x?xf32>
%159 = linalg.init_tensor [%102, 12, %103, 1] : tensor<?x12x?x1xi64>
%160 = linalg.fill(%c0_i64, %159) : i64, tensor<?x12x?x1xi64> -> tensor<?x12x?x1xi64>
%161 = linalg.init_tensor [%102, 12, %103, 1] : tensor<?x12x?x1xf32>
%162 = linalg.fill(%cst_114, %161) : f32, tensor<?x12x?x1xf32> -> tensor<?x12x?x1xf32>
%163:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%158 : tensor<?x12x?x?xf32>) outs(%162, %160 : tensor<?x12x?x1xf32>, tensor<?x12x?x1xi64>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: i64):
%560 = linalg.index 3 : index
%561 = arith.index_cast %560 : index to i64
%562 = arith.cmpf ogt, %arg1, %arg2 : f32
%563 = arith.select %562, %arg1, %arg2 : f32
%564 = arith.select %562, %561, %arg3 : i64
linalg.yield %563, %564 : f32, i64
} -> (tensor<?x12x?x1xf32>, tensor<?x12x?x1xi64>)
%164 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%158, %163#0 : tensor<?x12x?x?xf32>, tensor<?x12x?x1xf32>) outs(%153 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%560 = arith.subf %arg1, %arg2 : f32
linalg.yield %560 : f32
} -> tensor<?x12x?x?xf32>
%165 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%164 : tensor<?x12x?x?xf32>) outs(%153 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%560 = math.exp %arg1 : f32
linalg.yield %560 : f32
} -> tensor<?x12x?x?xf32>
%166 = linalg.fill(%cst_115, %161) : f32, tensor<?x12x?x1xf32> -> tensor<?x12x?x1xf32>
%167 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%165 : tensor<?x12x?x?xf32>) outs(%166 : tensor<?x12x?x1xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%560 = arith.addf %arg1, %arg2 : f32
linalg.yield %560 : f32
} -> tensor<?x12x?x1xf32>
%168 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%165, %167 : tensor<?x12x?x?xf32>, tensor<?x12x?x1xf32>) outs(%153 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%560 = arith.divf %arg1, %arg2 : f32
linalg.yield %560 : f32
} -> tensor<?x12x?x?xf32>
%169 = linalg.fill(%cst_115, %142) : f32, tensor<?x12x?x32xf32> -> tensor<?x12x?x32xf32>
%170 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>], iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel"]} ins(%168, %148 : tensor<?x12x?x?xf32>, tensor<?x12x?x32xf32>) outs(%169 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%560 = arith.mulf %arg1, %arg2 : f32
%561 = arith.addf %560, %arg3 : f32
linalg.yield %561 : f32
} -> tensor<?x12x?x32xf32>
%171 = linalg.init_tensor [%102, %103, 12, 32] : tensor<?x?x12x32xf32>
%172 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%170 : tensor<?x12x?x32xf32>) outs(%171 : tensor<?x?x12x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x12x32xf32>
%173 = tensor.cast %172 : tensor<?x?x12x32xf32> to tensor<?x?x?x?xf32>
%174 = tensor.collapse_shape %173 [[0], [1], [2, 3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
%175 = tensor.cast %174 : tensor<?x?x?xf32> to tensor<?x?x384xf32>
%176 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_78 : tensor<384xf32>) outs(%106 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%177 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_77 : tensor<384x384xf32>) outs(%134 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%178 = linalg.batch_matmul ins(%175, %177 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%176 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%179 = tensor.dim %174, %c0 : tensor<?x?x?xf32>
%180 = tensor.dim %174, %c1 : tensor<?x?x?xf32>
%181 = arith.cmpi eq, %179, %102 : index
cf.assert %181, "mismatched size for broadcast"
%182 = arith.cmpi eq, %180, %103 : index
cf.assert %182, "mismatched size for broadcast"
%183 = linalg.init_tensor [%179, %180, 384] : tensor<?x?x384xf32>
%184 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%178, %133 : tensor<?x?x384xf32>, tensor<?x?x384xf32>) outs(%183 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%560 = arith.addf %arg1, %arg2 : f32
linalg.yield %560 : f32
} -> tensor<?x?x384xf32>
%185 = linalg.init_tensor [%179, %180] : tensor<?x?xf32>
%186 = linalg.fill(%cst_115, %185) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%187 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%184 : tensor<?x?x384xf32>) outs(%186 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%560 = arith.addf %arg2, %arg1 : f32
linalg.yield %560 : f32
} -> tensor<?x?xf32>
%188 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%187 : tensor<?x?xf32>) outs(%185 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%560 = arith.divf %arg1, %cst : f32
linalg.yield %560 : f32
} -> tensor<?x?xf32>
%189 = linalg.fill(%cst_115, %185) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%190 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%184, %188 : tensor<?x?x384xf32>, tensor<?x?xf32>) outs(%189 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%560 = arith.subf %arg1, %arg2 : f32
%561 = arith.mulf %560, %560 : f32
%562 = arith.addf %arg3, %561 : f32
linalg.yield %562 : f32
} -> tensor<?x?xf32>
%191 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%190 : tensor<?x?xf32>) outs(%185 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%560 = arith.divf %arg1, %cst : f32
linalg.yield %560 : f32
} -> tensor<?x?xf32>
%192 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%184, %188, %191, %cst_75, %cst_76 : tensor<?x?x384xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<384xf32>, tensor<384xf32>) outs(%183 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%560 = arith.subf %arg1, %arg2 : f32
%561 = arith.truncf %cst_110 : f64 to f32
%562 = arith.addf %arg3, %561 : f32
%563 = math.rsqrt %562 : f32
%564 = arith.mulf %560, %563 : f32
%565 = arith.mulf %564, %arg4 : f32
%566 = arith.addf %565, %arg5 : f32
linalg.yield %566 : f32
} -> tensor<?x?x384xf32>
%193 = linalg.init_tensor [%179, %180, 1536] : tensor<?x?x1536xf32>
%194 = linalg.init_tensor [%179, 384, 1536] : tensor<?x384x1536xf32>
%195 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_74 : tensor<1536xf32>) outs(%193 : tensor<?x?x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x1536xf32>
%196 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_73 : tensor<1536x384xf32>) outs(%194 : tensor<?x384x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x1536xf32>
%197 = linalg.batch_matmul ins(%192, %196 : tensor<?x?x384xf32>, tensor<?x384x1536xf32>) outs(%195 : tensor<?x?x1536xf32>) -> tensor<?x?x1536xf32>
%198 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%197 : tensor<?x?x1536xf32>) outs(%193 : tensor<?x?x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%560 = math.sqrt %cst_112 : f32
%561 = arith.divf %arg1, %560 : f32
%562 = math.erf %561 : f32
%563 = arith.addf %562, %cst_111 : f32
%564 = arith.mulf %563, %cst_113 : f32
%565 = arith.mulf %arg1, %564 : f32
linalg.yield %565 : f32
} -> tensor<?x?x1536xf32>
%199 = linalg.init_tensor [%179, 1536, 384] : tensor<?x1536x384xf32>
%200 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_72 : tensor<384xf32>) outs(%183 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%201 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_71 : tensor<384x1536xf32>) outs(%199 : tensor<?x1536x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x1536x384xf32>
%202 = linalg.batch_matmul ins(%198, %201 : tensor<?x?x1536xf32>, tensor<?x1536x384xf32>) outs(%200 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%203 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%202, %192 : tensor<?x?x384xf32>, tensor<?x?x384xf32>) outs(%183 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%560 = arith.addf %arg1, %arg2 : f32
linalg.yield %560 : f32
} -> tensor<?x?x384xf32>
%204 = linalg.fill(%cst_115, %185) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%205 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%203 : tensor<?x?x384xf32>) outs(%204 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%560 = arith.addf %arg2, %arg1 : f32
linalg.yield %560 : f32
} -> tensor<?x?xf32>
%206 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%205 : tensor<?x?xf32>) outs(%185 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%560 = arith.divf %arg1, %cst : f32
linalg.yield %560 : f32
} -> tensor<?x?xf32>
%207 = linalg.fill(%cst_115, %185) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%208 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%203, %206 : tensor<?x?x384xf32>, tensor<?x?xf32>) outs(%207 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%560 = arith.subf %arg1, %arg2 : f32
%561 = arith.mulf %560, %560 : f32
%562 = arith.addf %arg3, %561 : f32
linalg.yield %562 : f32
} -> tensor<?x?xf32>
%209 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%208 : tensor<?x?xf32>) outs(%185 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%560 = arith.divf %arg1, %cst : f32
linalg.yield %560 : f32
} -> tensor<?x?xf32>
%210 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%203, %206, %209, %cst_69, %cst_70 : tensor<?x?x384xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<384xf32>, tensor<384xf32>) outs(%183 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%560 = arith.subf %arg1, %arg2 : f32
%561 = arith.truncf %cst_110 : f64 to f32
%562 = arith.addf %arg3, %561 : f32
%563 = math.rsqrt %562 : f32
%564 = arith.mulf %560, %563 : f32
%565 = arith.mulf %564, %arg4 : f32
%566 = arith.addf %565, %arg5 : f32
linalg.yield %566 : f32
} -> tensor<?x?x384xf32>
%211 = linalg.init_tensor [%179, 384, 384] : tensor<?x384x384xf32>
%212 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_68 : tensor<384xf32>) outs(%183 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%213 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_67 : tensor<384x384xf32>) outs(%211 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%214 = linalg.batch_matmul ins(%210, %213 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%212 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%215 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_66 : tensor<384xf32>) outs(%183 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%216 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_65 : tensor<384x384xf32>) outs(%211 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%217 = linalg.batch_matmul ins(%210, %216 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%215 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%218 = tensor.expand_shape %217 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%219 = linalg.init_tensor [%179, 12, %180, 32] : tensor<?x12x?x32xf32>
%220 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%218 : tensor<?x?x12x32xf32>) outs(%219 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
%221 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_64 : tensor<384xf32>) outs(%183 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%222 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_63 : tensor<384x384xf32>) outs(%211 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%223 = linalg.batch_matmul ins(%210, %222 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%221 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%224 = tensor.expand_shape %223 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%225 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%224 : tensor<?x?x12x32xf32>) outs(%219 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
%226 = tensor.expand_shape %214 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%227 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%226 : tensor<?x?x12x32xf32>) outs(%219 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
%228 = linalg.init_tensor [%179, 12, 32, %180] : tensor<?x12x32x?xf32>
%229 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d2)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%220 : tensor<?x12x?x32xf32>) outs(%228 : tensor<?x12x32x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x32x?xf32>
// Batched matrix product of %227 (?x12x?x32) with %229 (?x12x32x?), followed by a scalar divide.
// This has the shape of the score computation in scaled dot-product attention.
%230 = linalg.init_tensor [%179, 12, %180, %180] : tensor<?x12x?x?xf32>
// Accumulator for the contraction, filled with %cst_115 (defined outside this chunk; presumably 0.0 - TODO confirm).
%231 = linalg.fill(%cst_115, %230) : f32, tensor<?x12x?x?xf32> -> tensor<?x12x?x?xf32>
// d3 is the reduction dimension (size 32): out[d0, d1, d2, d4] += lhs[d0, d1, d2, d3] * rhs[d0, d1, d3, d4].
%232 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>], iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel"]} ins(%227, %229 : tensor<?x12x?x32xf32>, tensor<?x12x32x?xf32>) outs(%231 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%560 = arith.mulf %arg1, %arg2 : f32
%561 = arith.addf %560, %arg3 : f32
linalg.yield %561 : f32
} -> tensor<?x12x?x?xf32>
// Divide every element by the rank-0 f64 tensor %cst_20, truncated to f32.
// %cst_20 is defined outside this chunk; presumably sqrt(32) - TODO confirm.
%233 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> ()>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%232, %cst_20 : tensor<?x12x?x?xf32>, tensor<f64>) outs(%230 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f64, %arg3: f32):
%560 = arith.truncf %arg2 : f64 to f32
%561 = arith.divf %arg1, %560 : f32
linalg.yield %561 : f32
} -> tensor<?x12x?x?xf32>
// Runtime checks before the broadcast add below: %179 must equal %c1 and %180 must equal %19.
// %c1 and %19 are defined outside this chunk (presumably index constant 1 and the dynamic size of %23 - TODO confirm).
%234 = arith.cmpi eq, %179, %c1 : index
cf.assert %234, "mismatched size for broadcast"
%235 = arith.cmpi eq, %180, %19 : index
cf.assert %235, "mismatched size for broadcast"
// Add %23 (1x1x1x?) to %233. The map (d0, 0, 0, d3) broadcasts it over d1 and d2.
// NOTE(review): %23 looks like an additive attention mask - confirm against its definition.
%236 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, 0, 0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%233, %23 : tensor<?x12x?x?xf32>, tensor<1x1x1x?xf32>) outs(%230 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%560 = arith.addf %arg1, %arg2 : f32
linalg.yield %560 : f32
} -> tensor<?x12x?x?xf32>
// Softmax of %236 along the last dimension (d3), computed in the max-subtracted form:
//   max over d3 -> x - max -> exp -> sum over d3 -> exp / sum.
// Index accumulator for the max reduction, filled with %c0_i64.
%237 = linalg.init_tensor [%179, 12, %180, 1] : tensor<?x12x?x1xi64>
%238 = linalg.fill(%c0_i64, %237) : i64, tensor<?x12x?x1xi64> -> tensor<?x12x?x1xi64>
// Value accumulator for the max reduction, filled with %cst_114
// (defined outside this chunk; presumably the most negative finite f32 - TODO confirm).
%239 = linalg.init_tensor [%179, 12, %180, 1] : tensor<?x12x?x1xf32>
%240 = linalg.fill(%cst_114, %239) : f32, tensor<?x12x?x1xf32> -> tensor<?x12x?x1xf32>
// Running max over d3 together with the d3 index at which it was taken.
// Only result #0 (the max value) is consumed in the visible lines.
%241:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%236 : tensor<?x12x?x?xf32>) outs(%240, %238 : tensor<?x12x?x1xf32>, tensor<?x12x?x1xi64>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: i64):
%560 = linalg.index 3 : index
%561 = arith.index_cast %560 : index to i64
// Ordered greater-than: the accumulator and its index are replaced only by a strictly larger value.
%562 = arith.cmpf ogt, %arg1, %arg2 : f32
%563 = arith.select %562, %arg1, %arg2 : f32
%564 = arith.select %562, %561, %arg3 : i64
linalg.yield %563, %564 : f32, i64
} -> (tensor<?x12x?x1xf32>, tensor<?x12x?x1xi64>)
// Subtract the per-row max, broadcast along d3 via the map (d0, d1, d2, 0).
%242 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%236, %241#0 : tensor<?x12x?x?xf32>, tensor<?x12x?x1xf32>) outs(%230 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%560 = arith.subf %arg1, %arg2 : f32
linalg.yield %560 : f32
} -> tensor<?x12x?x?xf32>
// Elementwise exponential.
%243 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%242 : tensor<?x12x?x?xf32>) outs(%230 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%560 = math.exp %arg1 : f32
linalg.yield %560 : f32
} -> tensor<?x12x?x?xf32>
// Sum of the exponentials over d3; accumulator filled with %cst_115 (presumably 0.0 - TODO confirm).
%244 = linalg.fill(%cst_115, %239) : f32, tensor<?x12x?x1xf32> -> tensor<?x12x?x1xf32>
%245 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%243 : tensor<?x12x?x?xf32>) outs(%244 : tensor<?x12x?x1xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%560 = arith.addf %arg1, %arg2 : f32
linalg.yield %560 : f32
} -> tensor<?x12x?x1xf32>
// Normalise: each exponential divided by its row sum.
%246 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%243, %245 : tensor<?x12x?x?xf32>, tensor<?x12x?x1xf32>) outs(%230 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%560 = arith.divf %arg1, %arg2 : f32
linalg.yield %560 : f32
} -> tensor<?x12x?x?xf32>
// Multiply the normalised weights %246 (?x12x?x?) by %225 (?x12x?x32), then undo the earlier
// split/transposition: permute back to ?x?x12x32 and collapse 12 x 32 into 384.
// Accumulator filled with %cst_115 (presumably 0.0 - TODO confirm).
%247 = linalg.fill(%cst_115, %219) : f32, tensor<?x12x?x32xf32> -> tensor<?x12x?x32xf32>
// d3 is the reduction dimension: out[d0, d1, d2, d4] += w[d0, d1, d2, d3] * v[d0, d1, d3, d4].
%248 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>], iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel"]} ins(%246, %225 : tensor<?x12x?x?xf32>, tensor<?x12x?x32xf32>) outs(%247 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%560 = arith.mulf %arg1, %arg2 : f32
%561 = arith.addf %560, %arg3 : f32
linalg.yield %561 : f32
} -> tensor<?x12x?x32xf32>
// Swap dimensions 1 and 2 again: ?x12x?x32 -> ?x?x12x32.
%249 = linalg.init_tensor [%179, %180, 12, 32] : tensor<?x?x12x32xf32>
%250 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%248 : tensor<?x12x?x32xf32>) outs(%249 : tensor<?x?x12x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x12x32xf32>
// The static sizes are erased before the collapse and the 384 extent is restored by a cast afterwards.
%251 = tensor.cast %250 : tensor<?x?x12x32xf32> to tensor<?x?x?x?xf32>
%252 = tensor.collapse_shape %251 [[0], [1], [2, 3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
%253 = tensor.cast %252 : tensor<?x?x?xf32> to tensor<?x?x384xf32>
// Linear layer on %253: %256 = %253 x transpose(%cst_61) + %cst_62, expressed as a batch_matmul
// whose output operand is pre-initialised with the broadcast bias.
// Broadcast the 384-element vector %cst_62 along d0 and d1 (input map (d2)).
%254 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_62 : tensor<384xf32>) outs(%183 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
// Transpose the 384x384 matrix %cst_61 (input map (d2, d1)) and replicate it along the batch dimension d0.
%255 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_61 : tensor<384x384xf32>) outs(%211 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%256 = linalg.batch_matmul ins(%253, %255 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%254 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
// Read the two leading dynamic sizes of %252 and assert that they match %179 and %180,
// then add %256 and %210 elementwise (this looks like a residual connection; %210 is defined outside this chunk).
// %c0 / %c1 are defined outside this chunk (presumably index constants 0 and 1 - TODO confirm).
%257 = tensor.dim %252, %c0 : tensor<?x?x?xf32>
%258 = tensor.dim %252, %c1 : tensor<?x?x?xf32>
%259 = arith.cmpi eq, %257, %179 : index
cf.assert %259, "mismatched size for broadcast"
%260 = arith.cmpi eq, %258, %180 : index
cf.assert %260, "mismatched size for broadcast"
%261 = linalg.init_tensor [%257, %258, 384] : tensor<?x?x384xf32>
%262 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%256, %210 : tensor<?x?x384xf32>, tensor<?x?x384xf32>) outs(%261 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%560 = arith.addf %arg1, %arg2 : f32
linalg.yield %560 : f32
} -> tensor<?x?x384xf32>
// Normalisation of %262 over its last dimension (d2, size 384), in the form of a layer norm:
//   mean = sum(x) / %cst;  var = sum((x - mean)^2) / %cst;
//   out  = (x - mean) * rsqrt(var + eps) * %cst_59 + %cst_60.
// %cst, %cst_110 (eps, f64) and %cst_115 are defined outside this chunk;
// presumably %cst = 384.0 and %cst_115 = 0.0 - TODO confirm.
%263 = linalg.init_tensor [%257, %258] : tensor<?x?xf32>
%264 = linalg.fill(%cst_115, %263) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
// Sum over d2.
%265 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%262 : tensor<?x?x384xf32>) outs(%264 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%560 = arith.addf %arg2, %arg1 : f32
linalg.yield %560 : f32
} -> tensor<?x?xf32>
// Mean: sum divided by %cst.
%266 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%265 : tensor<?x?xf32>) outs(%263 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%560 = arith.divf %arg1, %cst : f32
linalg.yield %560 : f32
} -> tensor<?x?xf32>
%267 = linalg.fill(%cst_115, %263) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
// Sum of squared deviations from the mean over d2.
%268 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%262, %266 : tensor<?x?x384xf32>, tensor<?x?xf32>) outs(%267 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%560 = arith.subf %arg1, %arg2 : f32
%561 = arith.mulf %560, %560 : f32
%562 = arith.addf %arg3, %561 : f32
linalg.yield %562 : f32
} -> tensor<?x?xf32>
// Variance: sum of squared deviations divided by %cst.
%269 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%268 : tensor<?x?xf32>) outs(%263 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%560 = arith.divf %arg1, %cst : f32
linalg.yield %560 : f32
} -> tensor<?x?xf32>
// Region arguments: %arg1 = x, %arg2 = mean, %arg3 = variance, %arg4 = %cst_59 element (scale),
// %arg5 = %cst_60 element (shift); %arg6 (the output operand's element) is unused.
%270 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%262, %266, %269, %cst_59, %cst_60 : tensor<?x?x384xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<384xf32>, tensor<384xf32>) outs(%261 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%560 = arith.subf %arg1, %arg2 : f32
%561 = arith.truncf %cst_110 : f64 to f32
%562 = arith.addf %arg3, %561 : f32
%563 = math.rsqrt %562 : f32
%564 = arith.mulf %560, %563 : f32
%565 = arith.mulf %564, %arg4 : f32
%566 = arith.addf %565, %arg5 : f32
linalg.yield %566 : f32
} -> tensor<?x?x384xf32>
// Linear layer 384 -> 1536 on %270 (%275 = %270 x transpose(%cst_57) + %cst_58),
// followed by an elementwise erf-based activation.
%271 = linalg.init_tensor [%257, %258, 1536] : tensor<?x?x1536xf32>
%272 = linalg.init_tensor [%257, 384, 1536] : tensor<?x384x1536xf32>
// Broadcast the 1536-element vector %cst_58 along d0 and d1; it seeds the matmul accumulator.
%273 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_58 : tensor<1536xf32>) outs(%271 : tensor<?x?x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x1536xf32>
// Transpose the 1536x384 matrix %cst_57 (input map (d2, d1)) and replicate it along d0.
%274 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_57 : tensor<1536x384xf32>) outs(%272 : tensor<?x384x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x1536xf32>
%275 = linalg.batch_matmul ins(%270, %274 : tensor<?x?x384xf32>, tensor<?x384x1536xf32>) outs(%273 : tensor<?x?x1536xf32>) -> tensor<?x?x1536xf32>
// y = x * ((erf(x / sqrt(%cst_112)) + %cst_111) * %cst_113).
// This is the exact (erf) GELU if %cst_112 = 2.0, %cst_111 = 1.0 and %cst_113 = 0.5;
// those constants are defined outside this chunk - TODO confirm.
%276 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%275 : tensor<?x?x1536xf32>) outs(%271 : tensor<?x?x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%560 = math.sqrt %cst_112 : f32
%561 = arith.divf %arg1, %560 : f32
%562 = math.erf %561 : f32
%563 = arith.addf %562, %cst_111 : f32
%564 = arith.mulf %563, %cst_113 : f32
%565 = arith.mulf %arg1, %564 : f32
linalg.yield %565 : f32
} -> tensor<?x?x1536xf32>
// Linear layer 1536 -> 384 on %276 (%280 = %276 x transpose(%cst_55) + %cst_56),
// then an elementwise add with %270 (the input of the preceding 384 -> 1536 layer).
%277 = linalg.init_tensor [%257, 1536, 384] : tensor<?x1536x384xf32>
// Broadcast the 384-element vector %cst_56 along d0 and d1; it seeds the matmul accumulator.
%278 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_56 : tensor<384xf32>) outs(%261 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
// Transpose the 384x1536 matrix %cst_55 (input map (d2, d1)) and replicate it along d0.
%279 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_55 : tensor<384x1536xf32>) outs(%277 : tensor<?x1536x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x1536x384xf32>
%280 = linalg.batch_matmul ins(%276, %279 : tensor<?x?x1536xf32>, tensor<?x1536x384xf32>) outs(%278 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
// Elementwise sum of %280 and %270.
%281 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%280, %270 : tensor<?x?x384xf32>, tensor<?x?x384xf32>) outs(%261 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%560 = arith.addf %arg1, %arg2 : f32
linalg.yield %560 : f32
} -> tensor<?x?x384xf32>
// Normalisation of %281 over its last dimension (d2, size 384), in the form of a layer norm:
//   mean = sum(x) / %cst;  var = sum((x - mean)^2) / %cst;
//   out  = (x - mean) * rsqrt(var + eps) * %cst_53 + %cst_54.
// %cst, %cst_110 (eps, f64) and %cst_115 are defined outside this chunk;
// presumably %cst = 384.0 and %cst_115 = 0.0 - TODO confirm.
%282 = linalg.fill(%cst_115, %263) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
// Sum over d2.
%283 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%281 : tensor<?x?x384xf32>) outs(%282 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%560 = arith.addf %arg2, %arg1 : f32
linalg.yield %560 : f32
} -> tensor<?x?xf32>
// Mean: sum divided by %cst.
%284 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%283 : tensor<?x?xf32>) outs(%263 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%560 = arith.divf %arg1, %cst : f32
linalg.yield %560 : f32
} -> tensor<?x?xf32>
%285 = linalg.fill(%cst_115, %263) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
// Sum of squared deviations from the mean over d2.
%286 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%281, %284 : tensor<?x?x384xf32>, tensor<?x?xf32>) outs(%285 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%560 = arith.subf %arg1, %arg2 : f32
%561 = arith.mulf %560, %560 : f32
%562 = arith.addf %arg3, %561 : f32
linalg.yield %562 : f32
} -> tensor<?x?xf32>
// Variance: sum of squared deviations divided by %cst.
%287 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%286 : tensor<?x?xf32>) outs(%263 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%560 = arith.divf %arg1, %cst : f32
linalg.yield %560 : f32
} -> tensor<?x?xf32>
// Region arguments: %arg1 = x, %arg2 = mean, %arg3 = variance, %arg4 = %cst_53 element (scale),
// %arg5 = %cst_54 element (shift); %arg6 (the output operand's element) is unused.
%288 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%281, %284, %287, %cst_53, %cst_54 : tensor<?x?x384xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<384xf32>, tensor<384xf32>) outs(%261 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%560 = arith.subf %arg1, %arg2 : f32
%561 = arith.truncf %cst_110 : f64 to f32
%562 = arith.addf %arg3, %561 : f32
%563 = math.rsqrt %562 : f32
%564 = arith.mulf %560, %563 : f32
%565 = arith.mulf %564, %arg4 : f32
%566 = arith.addf %565, %arg5 : f32
linalg.yield %566 : f32
} -> tensor<?x?x384xf32>
// Linear layer 384 -> 384 on %288: %292 = %288 x transpose(%cst_51) + %cst_52.
// In the visible lines %292 later becomes the left operand of the score contraction
// (the "query" role in attention terminology).
// %289 is the ?x384x384 destination for this weight transpose; later weight transposes in the visible lines reuse it.
%289 = linalg.init_tensor [%257, 384, 384] : tensor<?x384x384xf32>
// Broadcast the 384-element vector %cst_52 along d0 and d1; it seeds the matmul accumulator.
%290 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_52 : tensor<384xf32>) outs(%261 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
// Transpose the 384x384 matrix %cst_51 (input map (d2, d1)) and replicate it along d0.
%291 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_51 : tensor<384x384xf32>) outs(%289 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%292 = linalg.batch_matmul ins(%288, %291 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%290 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
// Linear layer 384 -> 384 on %288 (%295 = %288 x transpose(%cst_49) + %cst_50), then split the
// 384 dimension into 12 x 32 and move the size-12 dimension to position 1.
// In the visible lines %298 is transposed and used as the right operand of the score contraction
// (the "key" role in attention terminology).
%293 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_50 : tensor<384xf32>) outs(%261 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
// Transpose the 384x384 matrix %cst_49 (input map (d2, d1)) and replicate it along d0.
%294 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_49 : tensor<384x384xf32>) outs(%289 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%295 = linalg.batch_matmul ins(%288, %294 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%293 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%296 = tensor.expand_shape %295 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
// %297 is the ?x12x?x32 destination; other ops in the visible lines reuse it as an output operand.
%297 = linalg.init_tensor [%257, 12, %258, 32] : tensor<?x12x?x32xf32>
// Permute (d0, d1, d2, d3) -> (d0, d2, d1, d3).
%298 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%296 : tensor<?x?x12x32xf32>) outs(%297 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
// Linear layer 384 -> 384 on %288 (%301 = %288 x transpose(%cst_47) + %cst_48), then split the
// 384 dimension into 12 x 32 and move the size-12 dimension to position 1.
// In the visible lines %303 is multiplied by the softmax output (the "value" role in attention terminology).
%299 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_48 : tensor<384xf32>) outs(%261 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
// Transpose the 384x384 matrix %cst_47 (input map (d2, d1)) and replicate it along d0.
%300 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_47 : tensor<384x384xf32>) outs(%289 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%301 = linalg.batch_matmul ins(%288, %300 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%299 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%302 = tensor.expand_shape %301 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
// Permute (d0, d1, d2, d3) -> (d0, d2, d1, d3).
%303 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%302 : tensor<?x?x12x32xf32>) outs(%297 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
// Split the trailing 384 dimension of %292 into 12 x 32, then permute
// (d0, d1, d2, d3) -> (d0, d2, d1, d3) so the size-12 dimension becomes dimension 1.
// The region is a plain copy; %297 only acts as the destination operand.
%304 = tensor.expand_shape %292 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%305 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%304 : tensor<?x?x12x32xf32>) outs(%297 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
// Transpose the last two dimensions of %298, contract %305 with the result over the size-32
// dimension, and divide by a scalar. This has the shape of the score computation in
// scaled dot-product attention.
%306 = linalg.init_tensor [%257, 12, 32, %258] : tensor<?x12x32x?xf32>
// ?x12x?x32 -> ?x12x32x? (output map (d0, d1, d3, d2)).
%307 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d2)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%298 : tensor<?x12x?x32xf32>) outs(%306 : tensor<?x12x32x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x32x?xf32>
%308 = linalg.init_tensor [%257, 12, %258, %258] : tensor<?x12x?x?xf32>
// Accumulator filled with %cst_115 (defined outside this chunk; presumably 0.0 - TODO confirm).
%309 = linalg.fill(%cst_115, %308) : f32, tensor<?x12x?x?xf32> -> tensor<?x12x?x?xf32>
// d3 is the reduction dimension: out[d0, d1, d2, d4] += lhs[d0, d1, d2, d3] * rhs[d0, d1, d3, d4].
%310 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>], iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel"]} ins(%305, %307 : tensor<?x12x?x32xf32>, tensor<?x12x32x?xf32>) outs(%309 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%560 = arith.mulf %arg1, %arg2 : f32
%561 = arith.addf %560, %arg3 : f32
linalg.yield %561 : f32
} -> tensor<?x12x?x?xf32>
// Divide every element by the rank-0 f64 tensor %cst_20, truncated to f32.
// %cst_20 is defined outside this chunk; presumably sqrt(32) - TODO confirm.
%311 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> ()>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%310, %cst_20 : tensor<?x12x?x?xf32>, tensor<f64>) outs(%308 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f64, %arg3: f32):
%560 = arith.truncf %arg2 : f64 to f32
%561 = arith.divf %arg1, %560 : f32
linalg.yield %561 : f32
} -> tensor<?x12x?x?xf32>
// Runtime checks before the broadcast add below: %257 must equal %c1 and %258 must equal %19.
// %c1 and %19 are defined outside this chunk (presumably index constant 1 and the dynamic size of %23 - TODO confirm).
%312 = arith.cmpi eq, %257, %c1 : index
cf.assert %312, "mismatched size for broadcast"
%313 = arith.cmpi eq, %258, %19 : index
cf.assert %313, "mismatched size for broadcast"
// Add %23 (1x1x1x?) to %311. The map (d0, 0, 0, d3) broadcasts it over d1 and d2.
// NOTE(review): %23 looks like an additive attention mask - confirm against its definition.
%314 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, 0, 0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%311, %23 : tensor<?x12x?x?xf32>, tensor<1x1x1x?xf32>) outs(%308 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%560 = arith.addf %arg1, %arg2 : f32
linalg.yield %560 : f32
} -> tensor<?x12x?x?xf32>
// Softmax of %314 along the last dimension (d3), computed in the max-subtracted form:
//   max over d3 -> x - max -> exp -> sum over d3 -> exp / sum.
// Index accumulator for the max reduction, filled with %c0_i64.
%315 = linalg.init_tensor [%257, 12, %258, 1] : tensor<?x12x?x1xi64>
%316 = linalg.fill(%c0_i64, %315) : i64, tensor<?x12x?x1xi64> -> tensor<?x12x?x1xi64>
// Value accumulator for the max reduction, filled with %cst_114
// (defined outside this chunk; presumably the most negative finite f32 - TODO confirm).
%317 = linalg.init_tensor [%257, 12, %258, 1] : tensor<?x12x?x1xf32>
%318 = linalg.fill(%cst_114, %317) : f32, tensor<?x12x?x1xf32> -> tensor<?x12x?x1xf32>
// Running max over d3 together with the d3 index at which it was taken.
// Only result #0 (the max value) is consumed in the visible lines.
%319:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%314 : tensor<?x12x?x?xf32>) outs(%318, %316 : tensor<?x12x?x1xf32>, tensor<?x12x?x1xi64>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: i64):
%560 = linalg.index 3 : index
%561 = arith.index_cast %560 : index to i64
// Ordered greater-than: the accumulator and its index are replaced only by a strictly larger value.
%562 = arith.cmpf ogt, %arg1, %arg2 : f32
%563 = arith.select %562, %arg1, %arg2 : f32
%564 = arith.select %562, %561, %arg3 : i64
linalg.yield %563, %564 : f32, i64
} -> (tensor<?x12x?x1xf32>, tensor<?x12x?x1xi64>)
// Subtract the per-row max, broadcast along d3 via the map (d0, d1, d2, 0).
%320 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%314, %319#0 : tensor<?x12x?x?xf32>, tensor<?x12x?x1xf32>) outs(%308 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%560 = arith.subf %arg1, %arg2 : f32
linalg.yield %560 : f32
} -> tensor<?x12x?x?xf32>
// Elementwise exponential.
%321 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%320 : tensor<?x12x?x?xf32>) outs(%308 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%560 = math.exp %arg1 : f32
linalg.yield %560 : f32
} -> tensor<?x12x?x?xf32>
// Sum of the exponentials over d3; accumulator filled with %cst_115 (presumably 0.0 - TODO confirm).
%322 = linalg.fill(%cst_115, %317) : f32, tensor<?x12x?x1xf32> -> tensor<?x12x?x1xf32>
%323 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%321 : tensor<?x12x?x?xf32>) outs(%322 : tensor<?x12x?x1xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%560 = arith.addf %arg1, %arg2 : f32
linalg.yield %560 : f32
} -> tensor<?x12x?x1xf32>
// Normalise: each exponential divided by its row sum.
%324 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%321, %323 : tensor<?x12x?x?xf32>, tensor<?x12x?x1xf32>) outs(%308 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%560 = arith.divf %arg1, %arg2 : f32
linalg.yield %560 : f32
} -> tensor<?x12x?x?xf32>
// Multiply the normalised weights %324 (?x12x?x?) by %303 (?x12x?x32), then undo the earlier
// split/transposition: permute back to ?x?x12x32 and collapse 12 x 32 into 384.
// Accumulator filled with %cst_115 (presumably 0.0 - TODO confirm).
%325 = linalg.fill(%cst_115, %297) : f32, tensor<?x12x?x32xf32> -> tensor<?x12x?x32xf32>
// d3 is the reduction dimension: out[d0, d1, d2, d4] += w[d0, d1, d2, d3] * v[d0, d1, d3, d4].
%326 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>], iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel"]} ins(%324, %303 : tensor<?x12x?x?xf32>, tensor<?x12x?x32xf32>) outs(%325 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%560 = arith.mulf %arg1, %arg2 : f32
%561 = arith.addf %560, %arg3 : f32
linalg.yield %561 : f32
} -> tensor<?x12x?x32xf32>
// Swap dimensions 1 and 2 again: ?x12x?x32 -> ?x?x12x32.
%327 = linalg.init_tensor [%257, %258, 12, 32] : tensor<?x?x12x32xf32>
%328 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%326 : tensor<?x12x?x32xf32>) outs(%327 : tensor<?x?x12x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x12x32xf32>
// The static sizes are erased before the collapse and the 384 extent is restored by a cast afterwards.
%329 = tensor.cast %328 : tensor<?x?x12x32xf32> to tensor<?x?x?x?xf32>
%330 = tensor.collapse_shape %329 [[0], [1], [2, 3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
%331 = tensor.cast %330 : tensor<?x?x?xf32> to tensor<?x?x384xf32>
// Linear layer on %331: %334 = %331 x transpose(%cst_45) + %cst_46, expressed as a batch_matmul
// whose output operand is pre-initialised with the broadcast bias.
// Broadcast the 384-element vector %cst_46 along d0 and d1 (input map (d2)).
%332 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_46 : tensor<384xf32>) outs(%261 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
// Transpose the 384x384 matrix %cst_45 (input map (d2, d1)) and replicate it along the batch dimension d0.
%333 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_45 : tensor<384x384xf32>) outs(%289 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%334 = linalg.batch_matmul ins(%331, %333 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%332 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
// Read the two leading dynamic sizes of %330 and assert that they match %257 and %258,
// then add %334 and %288 elementwise (this looks like a residual connection around the preceding ops).
// %c0 / %c1 are defined outside this chunk (presumably index constants 0 and 1 - TODO confirm).
%335 = tensor.dim %330, %c0 : tensor<?x?x?xf32>
%336 = tensor.dim %330, %c1 : tensor<?x?x?xf32>
%337 = arith.cmpi eq, %335, %257 : index
cf.assert %337, "mismatched size for broadcast"
%338 = arith.cmpi eq, %336, %258 : index
cf.assert %338, "mismatched size for broadcast"
%339 = linalg.init_tensor [%335, %336, 384] : tensor<?x?x384xf32>
%340 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%334, %288 : tensor<?x?x384xf32>, tensor<?x?x384xf32>) outs(%339 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%560 = arith.addf %arg1, %arg2 : f32
linalg.yield %560 : f32
} -> tensor<?x?x384xf32>
// Normalisation of %340 over its last dimension (d2, size 384), in the form of a layer norm:
//   mean = sum(x) / %cst;  var = sum((x - mean)^2) / %cst;
//   out  = (x - mean) * rsqrt(var + eps) * %cst_43 + %cst_44.
// %cst, %cst_110 (eps, f64) and %cst_115 are defined outside this chunk;
// presumably %cst = 384.0 and %cst_115 = 0.0 - TODO confirm.
%341 = linalg.init_tensor [%335, %336] : tensor<?x?xf32>
%342 = linalg.fill(%cst_115, %341) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
// Sum over d2.
%343 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%340 : tensor<?x?x384xf32>) outs(%342 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%560 = arith.addf %arg2, %arg1 : f32
linalg.yield %560 : f32
} -> tensor<?x?xf32>
// Mean: sum divided by %cst.
%344 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%343 : tensor<?x?xf32>) outs(%341 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%560 = arith.divf %arg1, %cst : f32
linalg.yield %560 : f32
} -> tensor<?x?xf32>
%345 = linalg.fill(%cst_115, %341) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
// Sum of squared deviations from the mean over d2.
%346 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%340, %344 : tensor<?x?x384xf32>, tensor<?x?xf32>) outs(%345 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%560 = arith.subf %arg1, %arg2 : f32
%561 = arith.mulf %560, %560 : f32
%562 = arith.addf %arg3, %561 : f32
linalg.yield %562 : f32
} -> tensor<?x?xf32>
// Variance: sum of squared deviations divided by %cst.
%347 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%346 : tensor<?x?xf32>) outs(%341 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%560 = arith.divf %arg1, %cst : f32
linalg.yield %560 : f32
} -> tensor<?x?xf32>
// Region arguments: %arg1 = x, %arg2 = mean, %arg3 = variance, %arg4 = %cst_43 element (scale),
// %arg5 = %cst_44 element (shift); %arg6 (the output operand's element) is unused.
%348 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%340, %344, %347, %cst_43, %cst_44 : tensor<?x?x384xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<384xf32>, tensor<384xf32>) outs(%339 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%560 = arith.subf %arg1, %arg2 : f32
%561 = arith.truncf %cst_110 : f64 to f32
%562 = arith.addf %arg3, %561 : f32
%563 = math.rsqrt %562 : f32
%564 = arith.mulf %560, %563 : f32
%565 = arith.mulf %564, %arg4 : f32
%566 = arith.addf %565, %arg5 : f32
linalg.yield %566 : f32
} -> tensor<?x?x384xf32>
// Linear layer 384 -> 1536 on %348 (%353 = %348 x transpose(%cst_41) + %cst_42),
// followed by an elementwise erf-based activation.
%349 = linalg.init_tensor [%335, %336, 1536] : tensor<?x?x1536xf32>
%350 = linalg.init_tensor [%335, 384, 1536] : tensor<?x384x1536xf32>
// Broadcast the 1536-element vector %cst_42 along d0 and d1; it seeds the matmul accumulator.
%351 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_42 : tensor<1536xf32>) outs(%349 : tensor<?x?x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x1536xf32>
// Transpose the 1536x384 matrix %cst_41 (input map (d2, d1)) and replicate it along d0.
%352 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_41 : tensor<1536x384xf32>) outs(%350 : tensor<?x384x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x1536xf32>
%353 = linalg.batch_matmul ins(%348, %352 : tensor<?x?x384xf32>, tensor<?x384x1536xf32>) outs(%351 : tensor<?x?x1536xf32>) -> tensor<?x?x1536xf32>
// y = x * ((erf(x / sqrt(%cst_112)) + %cst_111) * %cst_113).
// This is the exact (erf) GELU if %cst_112 = 2.0, %cst_111 = 1.0 and %cst_113 = 0.5;
// those constants are defined outside this chunk - TODO confirm.
%354 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%353 : tensor<?x?x1536xf32>) outs(%349 : tensor<?x?x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%560 = math.sqrt %cst_112 : f32
%561 = arith.divf %arg1, %560 : f32
%562 = math.erf %561 : f32
%563 = arith.addf %562, %cst_111 : f32
%564 = arith.mulf %563, %cst_113 : f32
%565 = arith.mulf %arg1, %564 : f32
linalg.yield %565 : f32
} -> tensor<?x?x1536xf32>
// Linear layer 1536 -> 384 on %354 (%358 = %354 x transpose(%cst_39) + %cst_40),
// then an elementwise add with %348 (the input of the preceding 384 -> 1536 layer).
%355 = linalg.init_tensor [%335, 1536, 384] : tensor<?x1536x384xf32>
// Broadcast the 384-element vector %cst_40 along d0 and d1; it seeds the matmul accumulator.
%356 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_40 : tensor<384xf32>) outs(%339 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
// Transpose the 384x1536 matrix %cst_39 (input map (d2, d1)) and replicate it along d0.
%357 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_39 : tensor<384x1536xf32>) outs(%355 : tensor<?x1536x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x1536x384xf32>
%358 = linalg.batch_matmul ins(%354, %357 : tensor<?x?x1536xf32>, tensor<?x1536x384xf32>) outs(%356 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
// Elementwise sum of %358 and %348.
%359 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%358, %348 : tensor<?x?x384xf32>, tensor<?x?x384xf32>) outs(%339 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%560 = arith.addf %arg1, %arg2 : f32
linalg.yield %560 : f32
} -> tensor<?x?x384xf32>
// Per-row statistics of %359 over its last dimension (d2, size 384):
//   %362 = sum(x) / %cst (mean);  %365 = sum((x - mean)^2) / %cst (variance).
// %cst and %cst_115 are defined outside this chunk; presumably 384.0 and 0.0 - TODO confirm.
%360 = linalg.fill(%cst_115, %341) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
// Sum over d2.
%361 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%359 : tensor<?x?x384xf32>) outs(%360 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%560 = arith.addf %arg2, %arg1 : f32
linalg.yield %560 : f32
} -> tensor<?x?xf32>
// Mean: sum divided by %cst.
%362 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%361 : tensor<?x?xf32>) outs(%341 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%560 = arith.divf %arg1, %cst : f32
linalg.yield %560 : f32
} -> tensor<?x?xf32>
%363 = linalg.fill(%cst_115, %341) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
// Sum of squared deviations from the mean over d2.
%364 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%359, %362 : tensor<?x?x384xf32>, tensor<?x?xf32>) outs(%363 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%560 = arith.subf %arg1, %arg2 : f32
%561 = arith.mulf %560, %560 : f32
%562 = arith.addf %arg3, %561 : f32
linalg.yield %562 : f32
} -> tensor<?x?xf32>
// Variance: sum of squared deviations divided by %cst.
%365 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%364 : tensor<?x?xf32>) outs(%341 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%560 = arith.divf %arg1, %cst : f32
linalg.yield %560 : f32
} -> tensor<?x?xf32>
%366 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%359, %362, %365, %cst_37, %cst_38 : tensor<?x?x384xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<384xf32>, tensor<384xf32>) outs(%339 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%560 = arith.subf %arg1, %arg2 : f32
%561 = arith.truncf %cst_110 : f64 to f32
%562 = arith.addf %arg3, %561 : f32
%563 = math.rsqrt %562 : f32
%564 = arith.mulf %560, %563 : f32
%565 = arith.mulf %564, %arg4 : f32
%566 = arith.addf %565, %arg5 : f32
linalg.yield %566 : f32
} -> tensor<?x?x384xf32>
// Two 384 -> 384 dense projections of %366. Each one broadcasts a 384-element vector as the
// matmul accumulator and uses a transposed, per-batch copy of a 384x384 matrix as the rhs.
// %367 is the shared destination for the transposed weight copies in this stage.
%367 = linalg.init_tensor [%335, 384, 384] : tensor<?x384x384xf32>
// First projection: vector %cst_36, matrix %cst_35, result %370.
%368 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_36 : tensor<384xf32>) outs(%339 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
// out[d0, d1, d2] = %cst_35[d2, d1]
%369 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_35 : tensor<384x384xf32>) outs(%367 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%370 = linalg.batch_matmul ins(%366, %369 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%368 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
// Second projection: vector %cst_34, matrix %cst_33, result %373.
%371 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_34 : tensor<384xf32>) outs(%339 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
// out[d0, d1, d2] = %cst_33[d2, d1]
%372 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_33 : tensor<384x384xf32>) outs(%367 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%373 = linalg.batch_matmul ins(%366, %372 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%371 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
// Splits the 384-wide feature dimension into 12x32 and swaps dims 1 and 2, producing
// ?x12x?x32 tensors from the three projections of %366:
//   %376 from %373, %381 from %379 (computed here with %cst_32 / %cst_31), %383 from %370.
// NOTE(review): judging by later uses, %383 / %376 / %381 play the query / key / value roles
// of a 12-head attention with head size 32 - confirm against the source model.
%374 = tensor.expand_shape %373 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
// %375 is the shared ?x12x?x32 destination for the permutations in this stage.
%375 = linalg.init_tensor [%335, 12, %336, 32] : tensor<?x12x?x32xf32>
// out[d0, d2, d1, d3] = in[d0, d1, d2, d3]
%376 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%374 : tensor<?x?x12x32xf32>) outs(%375 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
// Third 384 -> 384 projection of %366: vector %cst_32, transposed matrix %cst_31, result %379.
%377 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_32 : tensor<384xf32>) outs(%339 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%378 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_31 : tensor<384x384xf32>) outs(%367 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%379 = linalg.batch_matmul ins(%366, %378 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%377 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%380 = tensor.expand_shape %379 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
// Same dim-1/dim-2 swap as %376, applied to %380.
%381 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%380 : tensor<?x?x12x32xf32>) outs(%375 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
%382 = tensor.expand_shape %370 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
// Same dim-1/dim-2 swap as %376, applied to %382.
%383 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%382 : tensor<?x?x12x32xf32>) outs(%375 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
// Pairwise score computation: %383 contracted with the transpose of %376 over the size-32
// dimension, divided by the scalar %cst_20, then %23 is added with broadcasting.
%384 = linalg.init_tensor [%335, 12, 32, %336] : tensor<?x12x32x?xf32>
// Swap the last two dims of %376: out[d0, d1, d3, d2] = in[d0, d1, d2, d3].
%385 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d2)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%376 : tensor<?x12x?x32xf32>) outs(%384 : tensor<?x12x32x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x32x?xf32>
// Accumulator for the contraction, filled with %cst_115.
// NOTE(review): presumably 0.0; the constant is defined outside this chunk.
%386 = linalg.init_tensor [%335, 12, %336, %336] : tensor<?x12x?x?xf32>
%387 = linalg.fill(%cst_115, %386) : f32, tensor<?x12x?x?xf32> -> tensor<?x12x?x?xf32>
// out[d0, d1, d2, d4] += %383[d0, d1, d2, d3] * %385[d0, d1, d3, d4]; d3 is the reduction dim.
%388 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>], iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel"]} ins(%383, %385 : tensor<?x12x?x32xf32>, tensor<?x12x32x?xf32>) outs(%387 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%560 = arith.mulf %arg1, %arg2 : f32
%561 = arith.addf %560, %arg3 : f32
linalg.yield %561 : f32
} -> tensor<?x12x?x?xf32>
// Divide every element by the rank-0 f64 tensor %cst_20 (truncated to f32 per element).
// NOTE(review): presumably the attention scaling factor; value not visible in this chunk.
%389 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> ()>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%388, %cst_20 : tensor<?x12x?x?xf32>, tensor<f64>) outs(%386 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f64, %arg3: f32):
%560 = arith.truncf %arg2 : f64 to f32
%561 = arith.divf %arg1, %560 : f32
linalg.yield %561 : f32
} -> tensor<?x12x?x?xf32>
// Runtime checks before the broadcast add: %335 must equal %c1 and %336 must equal %19.
%390 = arith.cmpi eq, %335, %c1 : index
cf.assert %390, "mismatched size for broadcast"
%391 = arith.cmpi eq, %336, %19 : index
cf.assert %391, "mismatched size for broadcast"
// Add %23 (1x1x1x?) indexed as (d0, 0, 0, d3), i.e. broadcast over d1 and d2.
// NOTE(review): presumably the additive attention mask; %23 is defined outside this chunk.
%392 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, 0, 0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%389, %23 : tensor<?x12x?x?xf32>, tensor<1x1x1x?xf32>) outs(%386 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%560 = arith.addf %arg1, %arg2 : f32
linalg.yield %560 : f32
} -> tensor<?x12x?x?xf32>
// Softmax-shaped sequence over the last dimension (d3) of %392:
//   %397#0 = running maximum over d3, %397#1 = index of that maximum
//   %398   = x - max
//   %399   = exp(%398)
//   %401   = sum over d3 of %399
//   %402   = %399 / %401
// Within this chunk only %397#0 is consumed; %397#1 has no visible use.
// i64 accumulator for the index, filled with %c0_i64.
%393 = linalg.init_tensor [%335, 12, %336, 1] : tensor<?x12x?x1xi64>
%394 = linalg.fill(%c0_i64, %393) : i64, tensor<?x12x?x1xi64> -> tensor<?x12x?x1xi64>
// f32 accumulator for the maximum, filled with %cst_114.
// NOTE(review): presumably the lowest f32 value; the constant is defined outside this chunk.
%395 = linalg.init_tensor [%335, 12, %336, 1] : tensor<?x12x?x1xf32>
%396 = linalg.fill(%cst_114, %395) : f32, tensor<?x12x?x1xf32> -> tensor<?x12x?x1xf32>
// Combined max / arg-max reduction over d3. The comparison is ordered greater-than (ogt), so
// the stored value and index change only when the new element is strictly larger.
%397:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%392 : tensor<?x12x?x?xf32>) outs(%396, %394 : tensor<?x12x?x1xf32>, tensor<?x12x?x1xi64>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: i64):
%560 = linalg.index 3 : index
%561 = arith.index_cast %560 : index to i64
%562 = arith.cmpf ogt, %arg1, %arg2 : f32
%563 = arith.select %562, %arg1, %arg2 : f32
%564 = arith.select %562, %561, %arg3 : i64
linalg.yield %563, %564 : f32, i64
} -> (tensor<?x12x?x1xf32>, tensor<?x12x?x1xi64>)
// Subtract the per-row maximum (broadcast along d3).
%398 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%392, %397#0 : tensor<?x12x?x?xf32>, tensor<?x12x?x1xf32>) outs(%386 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%560 = arith.subf %arg1, %arg2 : f32
linalg.yield %560 : f32
} -> tensor<?x12x?x?xf32>
// Elementwise exponential.
%399 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%398 : tensor<?x12x?x?xf32>) outs(%386 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%560 = math.exp %arg1 : f32
linalg.yield %560 : f32
} -> tensor<?x12x?x?xf32>
// Sum of the exponentials over d3, accumulated onto %395 filled with %cst_115.
%400 = linalg.fill(%cst_115, %395) : f32, tensor<?x12x?x1xf32> -> tensor<?x12x?x1xf32>
%401 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%399 : tensor<?x12x?x?xf32>) outs(%400 : tensor<?x12x?x1xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%560 = arith.addf %arg1, %arg2 : f32
linalg.yield %560 : f32
} -> tensor<?x12x?x1xf32>
// Divide each exponential by its row sum (broadcast along d3).
%402 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%399, %401 : tensor<?x12x?x?xf32>, tensor<?x12x?x1xf32>) outs(%386 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%560 = arith.divf %arg1, %arg2 : f32
linalg.yield %560 : f32
} -> tensor<?x12x?x?xf32>
// Weighted sum of %381 using the normalized scores %402, then the 12x32 split is merged back
// into a single 384-wide dimension.
// Accumulator: %375 filled with %cst_115.
%403 = linalg.fill(%cst_115, %375) : f32, tensor<?x12x?x32xf32> -> tensor<?x12x?x32xf32>
// out[d0, d1, d2, d4] += %402[d0, d1, d2, d3] * %381[d0, d1, d3, d4]; d3 is the reduction dim.
%404 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>], iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel"]} ins(%402, %381 : tensor<?x12x?x?xf32>, tensor<?x12x?x32xf32>) outs(%403 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%560 = arith.mulf %arg1, %arg2 : f32
%561 = arith.addf %560, %arg3 : f32
linalg.yield %561 : f32
} -> tensor<?x12x?x32xf32>
%405 = linalg.init_tensor [%335, %336, 12, 32] : tensor<?x?x12x32xf32>
// Swap dims 1 and 2 back: ?x12x?x32 -> ?x?x12x32.
%406 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%404 : tensor<?x12x?x32xf32>) outs(%405 : tensor<?x?x12x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x12x32xf32>
// Static sizes are erased by a cast, dims 2 and 3 are collapsed into one, and the result is
// cast back to a static trailing size of 384.
%407 = tensor.cast %406 : tensor<?x?x12x32xf32> to tensor<?x?x?x?xf32>
%408 = tensor.collapse_shape %407 [[0], [1], [2, 3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
%409 = tensor.cast %408 : tensor<?x?x?xf32> to tensor<?x?x384xf32>
// 384 -> 384 dense projection of %409 (vector %cst_30, transposed matrix %cst_29), followed
// by an elementwise add with %366.
%410 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_30 : tensor<384xf32>) outs(%339 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
// out[d0, d1, d2] = %cst_29[d2, d1]
%411 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_29 : tensor<384x384xf32>) outs(%367 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%412 = linalg.batch_matmul ins(%409, %411 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%410 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
// Dynamic sizes are re-read from %408 and asserted equal to %335 / %336 before the add.
// From here on %413 / %414 are the sizes used for the ops that follow in this chunk.
%413 = tensor.dim %408, %c0 : tensor<?x?x?xf32>
%414 = tensor.dim %408, %c1 : tensor<?x?x?xf32>
%415 = arith.cmpi eq, %413, %335 : index
cf.assert %415, "mismatched size for broadcast"
%416 = arith.cmpi eq, %414, %336 : index
cf.assert %416, "mismatched size for broadcast"
%417 = linalg.init_tensor [%413, %414, 384] : tensor<?x?x384xf32>
// %418: elementwise %412 + %366.
// NOTE(review): presumably the residual connection around the attention block - confirm.
%418 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%412, %366 : tensor<?x?x384xf32>, tensor<?x?x384xf32>) outs(%417 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%560 = arith.addf %arg1, %arg2 : f32
linalg.yield %560 : f32
} -> tensor<?x?x384xf32>
// Normalization of %418 over its last dimension (d2):
//   %421 = sum over d2 of x
//   %422 = %421 / %cst
//   %424 = sum over d2 of (x - %422)^2
//   %425 = %424 / %cst
//   %426 = (x - %422) * rsqrt(%425 + truncf(%cst_110)) * %cst_27[d2] + %cst_28[d2]
// NOTE(review): this is the layer-norm pattern provided %cst_115 = 0.0, %cst equals the
// reduced extent (384) and %cst_110 is the epsilon; all three are defined outside this
// chunk - confirm.
// %419 is the ?x? tensor reused as fill target / outs operand for the statistics below.
%419 = linalg.init_tensor [%413, %414] : tensor<?x?xf32>
%420 = linalg.fill(%cst_115, %419) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
// Sum over d2.
%421 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%418 : tensor<?x?x384xf32>) outs(%420 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%560 = arith.addf %arg2, %arg1 : f32
linalg.yield %560 : f32
} -> tensor<?x?xf32>
// Divide the sum by %cst.
%422 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%421 : tensor<?x?xf32>) outs(%419 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%560 = arith.divf %arg1, %cst : f32
linalg.yield %560 : f32
} -> tensor<?x?xf32>
%423 = linalg.fill(%cst_115, %419) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
// Sum over d2 of the squared deviation from %422.
%424 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%418, %422 : tensor<?x?x384xf32>, tensor<?x?xf32>) outs(%423 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%560 = arith.subf %arg1, %arg2 : f32
%561 = arith.mulf %560, %560 : f32
%562 = arith.addf %arg3, %561 : f32
linalg.yield %562 : f32
} -> tensor<?x?xf32>
// Divide the squared-deviation sum by %cst.
%425 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%424 : tensor<?x?xf32>) outs(%419 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%560 = arith.divf %arg1, %cst : f32
linalg.yield %560 : f32
} -> tensor<?x?xf32>
// Apply the normalization; %cst_27 scales and %cst_28 shifts, both indexed by d2 only.
%426 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%418, %422, %425, %cst_27, %cst_28 : tensor<?x?x384xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<384xf32>, tensor<384xf32>) outs(%417 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%560 = arith.subf %arg1, %arg2 : f32
%561 = arith.truncf %cst_110 : f64 to f32
%562 = arith.addf %arg3, %561 : f32
%563 = math.rsqrt %562 : f32
%564 = arith.mulf %560, %563 : f32
%565 = arith.mulf %564, %arg4 : f32
%566 = arith.addf %565, %arg5 : f32
linalg.yield %566 : f32
} -> tensor<?x?x384xf32>
// Dense projection of %426 from 384 to 1536 features (vector %cst_26, transposed matrix
// %cst_25), followed by an erf-based elementwise activation.
%427 = linalg.init_tensor [%413, %414, 1536] : tensor<?x?x1536xf32>
%428 = linalg.init_tensor [%413, 384, 1536] : tensor<?x384x1536xf32>
// Broadcast the 1536-element vector %cst_26 along d2; used as the matmul accumulator below.
%429 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_26 : tensor<1536xf32>) outs(%427 : tensor<?x?x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x1536xf32>
// out[d0, d1, d2] = %cst_25[d2, d1]: the 1536x384 matrix transposed and repeated for every d0.
%430 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_25 : tensor<1536x384xf32>) outs(%428 : tensor<?x384x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x1536xf32>
%431 = linalg.batch_matmul ins(%426, %430 : tensor<?x?x384xf32>, tensor<?x384x1536xf32>) outs(%429 : tensor<?x?x1536xf32>) -> tensor<?x?x1536xf32>
// Per element x: x * ((erf(x / sqrt(%cst_112)) + %cst_111) * %cst_113).
// NOTE(review): erf-based GELU if %cst_112 = 2.0, %cst_111 = 1.0, %cst_113 = 0.5; the
// constants are defined outside this chunk - confirm.
%432 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%431 : tensor<?x?x1536xf32>) outs(%427 : tensor<?x?x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%560 = math.sqrt %cst_112 : f32
%561 = arith.divf %arg1, %560 : f32
%562 = math.erf %561 : f32
%563 = arith.addf %562, %cst_111 : f32
%564 = arith.mulf %563, %cst_113 : f32
%565 = arith.mulf %arg1, %564 : f32
linalg.yield %565 : f32
} -> tensor<?x?x1536xf32>
// Dense projection of %432 from 1536 back to 384 features (vector %cst_24, transposed matrix
// %cst_23), followed by an elementwise add with %426.
%433 = linalg.init_tensor [%413, 1536, 384] : tensor<?x1536x384xf32>
// Broadcast the 384-element vector %cst_24 along d2; used as the matmul accumulator below.
%434 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_24 : tensor<384xf32>) outs(%417 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
// out[d0, d1, d2] = %cst_23[d2, d1]
%435 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_23 : tensor<384x1536xf32>) outs(%433 : tensor<?x1536x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x1536x384xf32>
%436 = linalg.batch_matmul ins(%432, %435 : tensor<?x?x1536xf32>, tensor<?x1536x384xf32>) outs(%434 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
// %437: elementwise %436 + %426.
// NOTE(review): presumably the residual connection around the feed-forward block - confirm.
%437 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%436, %426 : tensor<?x?x384xf32>, tensor<?x?x384xf32>) outs(%417 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%560 = arith.addf %arg1, %arg2 : f32
linalg.yield %560 : f32
} -> tensor<?x?x384xf32>
// Normalization of %437 over its last dimension (d2):
//   %439 = sum over d2 of x
//   %440 = %439 / %cst
//   %442 = sum over d2 of (x - %440)^2
//   %443 = %442 / %cst
//   %444 = (x - %440) * rsqrt(%443 + truncf(%cst_110)) * %cst_21[d2] + %cst_22[d2]
// NOTE(review): this is the layer-norm pattern provided %cst_115 = 0.0, %cst equals the
// reduced extent (384) and %cst_110 is the epsilon; all three are defined outside this
// chunk - confirm.
%438 = linalg.fill(%cst_115, %419) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
// Sum over d2.
%439 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%437 : tensor<?x?x384xf32>) outs(%438 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%560 = arith.addf %arg2, %arg1 : f32
linalg.yield %560 : f32
} -> tensor<?x?xf32>
// Divide the sum by %cst.
%440 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%439 : tensor<?x?xf32>) outs(%419 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%560 = arith.divf %arg1, %cst : f32
linalg.yield %560 : f32
} -> tensor<?x?xf32>
%441 = linalg.fill(%cst_115, %419) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
// Sum over d2 of the squared deviation from %440.
%442 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%437, %440 : tensor<?x?x384xf32>, tensor<?x?xf32>) outs(%441 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%560 = arith.subf %arg1, %arg2 : f32
%561 = arith.mulf %560, %560 : f32
%562 = arith.addf %arg3, %561 : f32
linalg.yield %562 : f32
} -> tensor<?x?xf32>
// Divide the squared-deviation sum by %cst.
%443 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%442 : tensor<?x?xf32>) outs(%419 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%560 = arith.divf %arg1, %cst : f32
linalg.yield %560 : f32
} -> tensor<?x?xf32>
// Apply the normalization; %cst_21 scales and %cst_22 shifts, both indexed by d2 only.
%444 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%437, %440, %443, %cst_21, %cst_22 : tensor<?x?x384xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<384xf32>, tensor<384xf32>) outs(%417 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%560 = arith.subf %arg1, %arg2 : f32
%561 = arith.truncf %cst_110 : f64 to f32
%562 = arith.addf %arg3, %561 : f32
%563 = math.rsqrt %562 : f32
%564 = arith.mulf %560, %563 : f32
%565 = arith.mulf %564, %arg4 : f32
%566 = arith.addf %565, %arg5 : f32
linalg.yield %566 : f32
} -> tensor<?x?x384xf32>
// Two 384 -> 384 dense projections of %444. Each one broadcasts a 384-element vector as the
// matmul accumulator and uses a transposed, per-batch copy of a 384x384 matrix as the rhs.
// %445 is the shared destination for the transposed weight copies in this stage.
%445 = linalg.init_tensor [%413, 384, 384] : tensor<?x384x384xf32>
// First projection: vector %cst_19, matrix %cst_18, result %448.
%446 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_19 : tensor<384xf32>) outs(%417 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
// out[d0, d1, d2] = %cst_18[d2, d1]
%447 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_18 : tensor<384x384xf32>) outs(%445 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%448 = linalg.batch_matmul ins(%444, %447 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%446 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
// Second projection: vector %cst_17, matrix %cst_16, result %451.
%449 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_17 : tensor<384xf32>) outs(%417 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
// out[d0, d1, d2] = %cst_16[d2, d1]
%450 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_16 : tensor<384x384xf32>) outs(%445 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%451 = linalg.batch_matmul ins(%444, %450 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%449 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
// Splits the 384-wide feature dimension into 12x32 and swaps dims 1 and 2, producing
// ?x12x?x32 tensors from the three projections of %444:
//   %454 from %451, %459 from %457 (computed here with %cst_15 / %cst_14), %461 from %448.
// NOTE(review): judging by later uses, %461 / %454 / %459 play the query / key / value roles
// of a 12-head attention with head size 32 - confirm against the source model.
%452 = tensor.expand_shape %451 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
// %453 is the shared ?x12x?x32 destination for the permutations in this stage.
%453 = linalg.init_tensor [%413, 12, %414, 32] : tensor<?x12x?x32xf32>
// out[d0, d2, d1, d3] = in[d0, d1, d2, d3]
%454 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%452 : tensor<?x?x12x32xf32>) outs(%453 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
// Third 384 -> 384 projection of %444: vector %cst_15, transposed matrix %cst_14, result %457.
%455 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_15 : tensor<384xf32>) outs(%417 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%456 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_14 : tensor<384x384xf32>) outs(%445 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%457 = linalg.batch_matmul ins(%444, %456 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%455 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%458 = tensor.expand_shape %457 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
// Same dim-1/dim-2 swap as %454, applied to %458.
%459 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%458 : tensor<?x?x12x32xf32>) outs(%453 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
%460 = tensor.expand_shape %448 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
// Same dim-1/dim-2 swap as %454, applied to %460.
%461 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%460 : tensor<?x?x12x32xf32>) outs(%453 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
// Pairwise score computation: %461 contracted with the transpose of %454 over the size-32
// dimension, divided by the scalar %cst_20, then %23 is added with broadcasting.
%462 = linalg.init_tensor [%413, 12, 32, %414] : tensor<?x12x32x?xf32>
// Swap the last two dims of %454: out[d0, d1, d3, d2] = in[d0, d1, d2, d3].
%463 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d2)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%454 : tensor<?x12x?x32xf32>) outs(%462 : tensor<?x12x32x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x32x?xf32>
// Accumulator for the contraction, filled with %cst_115.
// NOTE(review): presumably 0.0; the constant is defined outside this chunk.
%464 = linalg.init_tensor [%413, 12, %414, %414] : tensor<?x12x?x?xf32>
%465 = linalg.fill(%cst_115, %464) : f32, tensor<?x12x?x?xf32> -> tensor<?x12x?x?xf32>
// out[d0, d1, d2, d4] += %461[d0, d1, d2, d3] * %463[d0, d1, d3, d4]; d3 is the reduction dim.
%466 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>], iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel"]} ins(%461, %463 : tensor<?x12x?x32xf32>, tensor<?x12x32x?xf32>) outs(%465 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%560 = arith.mulf %arg1, %arg2 : f32
%561 = arith.addf %560, %arg3 : f32
linalg.yield %561 : f32
} -> tensor<?x12x?x?xf32>
// Divide every element by the rank-0 f64 tensor %cst_20 (truncated to f32 per element).
// NOTE(review): presumably the attention scaling factor; value not visible in this chunk.
%467 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> ()>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%466, %cst_20 : tensor<?x12x?x?xf32>, tensor<f64>) outs(%464 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f64, %arg3: f32):
%560 = arith.truncf %arg2 : f64 to f32
%561 = arith.divf %arg1, %560 : f32
linalg.yield %561 : f32
} -> tensor<?x12x?x?xf32>
// Runtime checks before the broadcast add: %413 must equal %c1 and %414 must equal %19.
%468 = arith.cmpi eq, %413, %c1 : index
cf.assert %468, "mismatched size for broadcast"
%469 = arith.cmpi eq, %414, %19 : index
cf.assert %469, "mismatched size for broadcast"
// Add %23 (1x1x1x?) indexed as (d0, 0, 0, d3), i.e. broadcast over d1 and d2.
// NOTE(review): presumably the additive attention mask; %23 is defined outside this chunk.
%470 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, 0, 0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%467, %23 : tensor<?x12x?x?xf32>, tensor<1x1x1x?xf32>) outs(%464 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%560 = arith.addf %arg1, %arg2 : f32
linalg.yield %560 : f32
} -> tensor<?x12x?x?xf32>
// Softmax-shaped sequence over the last dimension (d3) of %470:
//   %475#0 = running maximum over d3, %475#1 = index of that maximum
//   %476   = x - max
//   %477   = exp(%476)
//   %479   = sum over d3 of %477
//   %480   = %477 / %479
// Within this chunk only %475#0 is consumed; %475#1 has no visible use.
// i64 accumulator for the index, filled with %c0_i64.
%471 = linalg.init_tensor [%413, 12, %414, 1] : tensor<?x12x?x1xi64>
%472 = linalg.fill(%c0_i64, %471) : i64, tensor<?x12x?x1xi64> -> tensor<?x12x?x1xi64>
// f32 accumulator for the maximum, filled with %cst_114.
// NOTE(review): presumably the lowest f32 value; the constant is defined outside this chunk.
%473 = linalg.init_tensor [%413, 12, %414, 1] : tensor<?x12x?x1xf32>
%474 = linalg.fill(%cst_114, %473) : f32, tensor<?x12x?x1xf32> -> tensor<?x12x?x1xf32>
// Combined max / arg-max reduction over d3. The comparison is ordered greater-than (ogt), so
// the stored value and index change only when the new element is strictly larger.
%475:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%470 : tensor<?x12x?x?xf32>) outs(%474, %472 : tensor<?x12x?x1xf32>, tensor<?x12x?x1xi64>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: i64):
%560 = linalg.index 3 : index
%561 = arith.index_cast %560 : index to i64
%562 = arith.cmpf ogt, %arg1, %arg2 : f32
%563 = arith.select %562, %arg1, %arg2 : f32
%564 = arith.select %562, %561, %arg3 : i64
linalg.yield %563, %564 : f32, i64
} -> (tensor<?x12x?x1xf32>, tensor<?x12x?x1xi64>)
// Subtract the per-row maximum (broadcast along d3).
%476 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%470, %475#0 : tensor<?x12x?x?xf32>, tensor<?x12x?x1xf32>) outs(%464 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%560 = arith.subf %arg1, %arg2 : f32
linalg.yield %560 : f32
} -> tensor<?x12x?x?xf32>
// Elementwise exponential.
%477 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%476 : tensor<?x12x?x?xf32>) outs(%464 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%560 = math.exp %arg1 : f32
linalg.yield %560 : f32
} -> tensor<?x12x?x?xf32>
// Sum of the exponentials over d3, accumulated onto %473 filled with %cst_115.
%478 = linalg.fill(%cst_115, %473) : f32, tensor<?x12x?x1xf32> -> tensor<?x12x?x1xf32>
%479 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%477 : tensor<?x12x?x?xf32>) outs(%478 : tensor<?x12x?x1xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%560 = arith.addf %arg1, %arg2 : f32
linalg.yield %560 : f32
} -> tensor<?x12x?x1xf32>
// Divide each exponential by its row sum (broadcast along d3).
%480 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%477, %479 : tensor<?x12x?x?xf32>, tensor<?x12x?x1xf32>) outs(%464 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%560 = arith.divf %arg1, %arg2 : f32
linalg.yield %560 : f32
} -> tensor<?x12x?x?xf32>
// Weighted sum of %459 using the normalized scores %480, then the 12x32 split is merged back
// into a single 384-wide dimension.
// Accumulator: %453 filled with %cst_115.
%481 = linalg.fill(%cst_115, %453) : f32, tensor<?x12x?x32xf32> -> tensor<?x12x?x32xf32>
// out[d0, d1, d2, d4] += %480[d0, d1, d2, d3] * %459[d0, d1, d3, d4]; d3 is the reduction dim.
%482 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>], iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel"]} ins(%480, %459 : tensor<?x12x?x?xf32>, tensor<?x12x?x32xf32>) outs(%481 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%560 = arith.mulf %arg1, %arg2 : f32
%561 = arith.addf %560, %arg3 : f32
linalg.yield %561 : f32
} -> tensor<?x12x?x32xf32>
%483 = linalg.init_tensor [%413, %414, 12, 32] : tensor<?x?x12x32xf32>
// Swap dims 1 and 2 back: ?x12x?x32 -> ?x?x12x32.
%484 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%482 : tensor<?x12x?x32xf32>) outs(%483 : tensor<?x?x12x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x12x32xf32>
// Static sizes are erased by a cast, dims 2 and 3 are collapsed into one, and the result is
// cast back to a static trailing size of 384.
%485 = tensor.cast %484 : tensor<?x?x12x32xf32> to tensor<?x?x?x?xf32>
%486 = tensor.collapse_shape %485 [[0], [1], [2, 3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
%487 = tensor.cast %486 : tensor<?x?x?xf32> to tensor<?x?x384xf32>
%488 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_13 : tensor<384xf32>) outs(%417 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%489 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_12 : tensor<384x384xf32>) outs(%445 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%490 = linalg.batch_matmul ins(%487, %489 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%488 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%491 = tensor.dim %486, %c0 : tensor<?x?x?xf32>
%492 = tensor.dim %486, %c1 : tensor<?x?x?xf32>
%493 = arith.cmpi eq, %491, %413 : index
cf.assert %493, "mismatched size for broadcast"
%494 = arith.cmpi eq, %492, %414 : index
cf.assert %494, "mismatched size for broadcast"
%495 = linalg.init_tensor [%491, %492, 384] : tensor<?x?x384xf32>
%496 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%490, %444 : tensor<?x?x384xf32>, tensor<?x?x384xf32>) outs(%495 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%560 = arith.addf %arg1, %arg2 : f32
linalg.yield %560 : f32
} -> tensor<?x?x384xf32>
%497 = linalg.init_tensor [%491, %492] : tensor<?x?xf32>
%498 = linalg.fill(%cst_115, %497) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%499 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%496 : tensor<?x?x384xf32>) outs(%498 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%560 = arith.addf %arg2, %arg1 : f32
linalg.yield %560 : f32
} -> tensor<?x?xf32>
%500 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%499 : tensor<?x?xf32>) outs(%497 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%560 = arith.divf %arg1, %cst : f32
linalg.yield %560 : f32
} -> tensor<?x?xf32>
%501 = linalg.fill(%cst_115, %497) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%502 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%496, %500 : tensor<?x?x384xf32>, tensor<?x?xf32>) outs(%501 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%560 = arith.subf %arg1, %arg2 : f32
%561 = arith.mulf %560, %560 : f32
%562 = arith.addf %arg3, %561 : f32
linalg.yield %562 : f32
} -> tensor<?x?xf32>
%503 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%502 : tensor<?x?xf32>) outs(%497 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%560 = arith.divf %arg1, %cst : f32
linalg.yield %560 : f32
} -> tensor<?x?xf32>
%504 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%496, %500, %503, %cst_10, %cst_11 : tensor<?x?x384xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<384xf32>, tensor<384xf32>) outs(%495 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%560 = arith.subf %arg1, %arg2 : f32
%561 = arith.truncf %cst_110 : f64 to f32
%562 = arith.addf %arg3, %561 : f32
%563 = math.rsqrt %562 : f32
%564 = arith.mulf %560, %563 : f32
%565 = arith.mulf %564, %arg4 : f32
%566 = arith.addf %565, %arg5 : f32
linalg.yield %566 : f32
} -> tensor<?x?x384xf32>
%505 = linalg.init_tensor [%491, %492, 1536] : tensor<?x?x1536xf32>
%506 = linalg.init_tensor [%491, 384, 1536] : tensor<?x384x1536xf32>
%507 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_9 : tensor<1536xf32>) outs(%505 : tensor<?x?x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x1536xf32>
%508 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_8 : tensor<1536x384xf32>) outs(%506 : tensor<?x384x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x1536xf32>
%509 = linalg.batch_matmul ins(%504, %508 : tensor<?x?x384xf32>, tensor<?x384x1536xf32>) outs(%507 : tensor<?x?x1536xf32>) -> tensor<?x?x1536xf32>
%510 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%509 : tensor<?x?x1536xf32>) outs(%505 : tensor<?x?x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%560 = math.sqrt %cst_112 : f32
%561 = arith.divf %arg1, %560 : f32
%562 = math.erf %561 : f32
%563 = arith.addf %562, %cst_111 : f32
%564 = arith.mulf %563, %cst_113 : f32
%565 = arith.mulf %arg1, %564 : f32
linalg.yield %565 : f32
} -> tensor<?x?x1536xf32>
%511 = linalg.init_tensor [%491, 1536, 384] : tensor<?x1536x384xf32>
%512 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_7 : tensor<384xf32>) outs(%495 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%513 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_6 : tensor<384x1536xf32>) outs(%511 : tensor<?x1536x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x1536x384xf32>
%514 = linalg.batch_matmul ins(%510, %513 : tensor<?x?x1536xf32>, tensor<?x1536x384xf32>) outs(%512 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%515 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%514, %504 : tensor<?x?x384xf32>, tensor<?x?x384xf32>) outs(%495 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%560 = arith.addf %arg1, %arg2 : f32
linalg.yield %560 : f32
} -> tensor<?x?x384xf32>
%516 = linalg.fill(%cst_115, %497) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%517 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%515 : tensor<?x?x384xf32>) outs(%516 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%560 = arith.addf %arg2, %arg1 : f32
linalg.yield %560 : f32
} -> tensor<?x?xf32>
%518 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%517 : tensor<?x?xf32>) outs(%497 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%560 = arith.divf %arg1, %cst : f32
linalg.yield %560 : f32
} -> tensor<?x?xf32>
%519 = linalg.fill(%cst_115, %497) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%520 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%515, %518 : tensor<?x?x384xf32>, tensor<?x?xf32>) outs(%519 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%560 = arith.subf %arg1, %arg2 : f32
%561 = arith.mulf %560, %560 : f32
%562 = arith.addf %arg3, %561 : f32
linalg.yield %562 : f32
} -> tensor<?x?xf32>
%521 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%520 : tensor<?x?xf32>) outs(%497 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%560 = arith.divf %arg1, %cst : f32
linalg.yield %560 : f32
} -> tensor<?x?xf32>
%522 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%515, %518, %521, %cst_4, %cst_5 : tensor<?x?x384xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<384xf32>, tensor<384xf32>) outs(%495 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%560 = arith.subf %arg1, %arg2 : f32
%561 = arith.truncf %cst_110 : f64 to f32
%562 = arith.addf %arg3, %561 : f32
%563 = math.rsqrt %562 : f32
%564 = arith.mulf %560, %563 : f32
%565 = arith.mulf %564, %arg4 : f32
%566 = arith.addf %565, %arg5 : f32
linalg.yield %566 : f32
} -> tensor<?x?x384xf32>
%523 = arith.index_cast %491 : index to i64
%524 = arith.cmpi sgt, %c0_i64, %523 : i64
%525 = arith.select %524, %523, %c0_i64 : i64
%526 = arith.index_cast %525 : i64 to index
%527 = arith.cmpi sgt, %c9223372036854775807_i64, %523 : i64
%528 = arith.select %527, %523, %c9223372036854775807_i64 : i64
%529 = arith.index_cast %528 : i64 to index
%530 = arith.cmpi sge, %529, %526 : index
%531 = arith.select %530, %529, %526 : index
%532 = arith.subi %531, %526 : index
%533 = tensor.extract_slice %522[%526, 0, 0] [%532, %492, 384] [1, 1, 1] : tensor<?x?x384xf32> to tensor<?x?x384xf32>
%534 = arith.index_cast %492 : index to i64
%535 = arith.cmpi sgt, %c0_i64, %534 : i64
%536 = arith.select %535, %534, %c0_i64 : i64
%537 = arith.index_cast %536 : i64 to index
%538 = arith.cmpi sgt, %c1_i64, %534 : i64
%539 = arith.select %538, %534, %c1_i64 : i64
%540 = arith.index_cast %539 : i64 to index
%541 = arith.cmpi sge, %540, %537 : index
%542 = arith.select %541, %540, %537 : index
%543 = arith.subi %542, %537 : index
%544 = tensor.extract_slice %533[0, %537, 0] [%532, %543, 384] [1, 1, 1] : tensor<?x?x384xf32> to tensor<?x?x384xf32>
%545 = tensor.cast %544 : tensor<?x?x384xf32> to tensor<?x1x384xf32>
%546 = tensor.collapse_shape %545 [[0, 1], [2]] : tensor<?x1x384xf32> into tensor<?x384xf32>
%547 = linalg.init_tensor [%532, 384] : tensor<?x384xf32>
%548 = linalg.init_tensor [384, 384] : tensor<384x384xf32>
%549 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%cst_3 : tensor<384xf32>) outs(%547 : tensor<?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384xf32>
%550 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%cst_2 : tensor<384x384xf32>) outs(%548 : tensor<384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<384x384xf32>
%551 = linalg.matmul ins(%546, %550 : tensor<?x384xf32>, tensor<384x384xf32>) outs(%549 : tensor<?x384xf32>) -> tensor<?x384xf32>
%552 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%551 : tensor<?x384xf32>) outs(%547 : tensor<?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%560 = math.tanh %arg1 : f32
linalg.yield %560 : f32
} -> tensor<?x384xf32>
%553 = linalg.init_tensor [%532, 2] : tensor<?x2xf32>
%554 = linalg.init_tensor [384, 2] : tensor<384x2xf32>
%555 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%cst_1 : tensor<2xf32>) outs(%553 : tensor<?x2xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x2xf32>
%556 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%cst_0 : tensor<2x384xf32>) outs(%554 : tensor<384x2xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<384x2xf32>
%557 = linalg.matmul ins(%552, %556 : tensor<?x384xf32>, tensor<384x2xf32>) outs(%555 : tensor<?x2xf32>) -> tensor<?x2xf32>
%558 = tensor.dim %557, %c0 : tensor<?x2xf32>
%559 = hal.tensor.export %557 : tensor<?x2xf32>{%558} -> !hal.buffer_view
return %559 : !hal.buffer_view
}
// -----// IR Dump Before Canonicalizer //----- //
func @forward(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
%c30522_i64 = arith.constant 30522 : i64
%c0_i64 = arith.constant 0 : i64
%cst = arith.constant 0.000000e+00 : f32
%cst_0 = arith.constant -3.40282347E+38 : f32
%cst_1 = arith.constant 5.000000e-01 : f32
%cst_2 = arith.constant 2.000000e+00 : f32
%cst_3 = arith.constant 1.000000e+00 : f32
%c9223372036854775807_i64 = arith.constant 9223372036854775807 : i64
%cst_4 = arith.constant 9.9999999999999998E-13 : f64
%c2_i64 = arith.constant 2 : i64
%cst_5 = arith.constant dense<-1.000000e+04> : tensor<f64>
%cst_6 = arith.constant dense<0> : tensor<512xi64>
%cst_7 = arith.constant dense<0> : tensor<i64>
%cst_8 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1x512xi64>
%cst_9 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<30522x384xf32>
%cst_10 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<2x384xf32>
%cst_11 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<512x384xf32>
%cst_12 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_13 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_14 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_15 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_16 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_17 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_18 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_19 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_20 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_21 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_22 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_23 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_24 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536xf32>
%cst_25 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536x384xf32>
%cst_26 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_27 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x1536xf32>
%cst_28 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_29 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_30 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_31 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_32 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_33 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_34 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_35 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_36 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_37 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_38 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_39 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_40 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536xf32>
%cst_41 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536x384xf32>
%cst_42 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_43 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x1536xf32>
%cst_44 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_45 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_46 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_47 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_48 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_49 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_50 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_51 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_52 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_53 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_54 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_55 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_56 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536xf32>
%cst_57 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536x384xf32>
%cst_58 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_59 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x1536xf32>
%cst_60 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_61 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_62 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_63 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_64 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_65 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_66 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_67 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_68 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_69 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_70 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_71 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_72 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536xf32>
%cst_73 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536x384xf32>
%cst_74 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_75 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x1536xf32>
%cst_76 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_77 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_78 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_79 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_80 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_81 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_82 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_83 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_84 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_85 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_86 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_87 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_88 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536xf32>
%cst_89 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536x384xf32>
%cst_90 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_91 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x1536xf32>
%cst_92 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_93 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_94 = arith.constant dense<5.6568542494923806> : tensor<f64>
%cst_95 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_96 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_97 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_98 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_99 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_100 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_101 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_102 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_103 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_104 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_105 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536xf32>
%cst_106 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536x384xf32>
%cst_107 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_108 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x1536xf32>
%cst_109 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_110 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_111 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_112 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_113 = arith.constant dense<[-0.00115577725, 0.00115577038]> : tensor<2xf32>
%cst_114 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<2x384xf32>
%c512_i64 = arith.constant 512 : i64
%c1_i64 = arith.constant 1 : i64
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c512 = arith.constant 512 : index
%cst_115 = arith.constant 3.840000e+02 : f32
%0 = hal.tensor.import %arg0 : !hal.buffer_view -> tensor<1x512xi64>
%1 = linalg.init_tensor [] : tensor<i64>
%2 = linalg.fill(%c512_i64, %1) : i64, tensor<i64> -> tensor<i64>
%3 = linalg.generic {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []} ins(%2, %cst_7 : tensor<i64>, tensor<i64>) outs(%1 : tensor<i64>) {
^bb0(%arg1: i64, %arg2: i64, %arg3: i64):
%609 = arith.addi %arg1, %arg2 : i64
linalg.yield %609 : i64
} -> tensor<i64>
%4 = tensor.extract %3[] : tensor<i64>
%5 = arith.index_cast %4 : i64 to index
%6 = linalg.init_tensor [1, %5] : tensor<1x?xf32>
%7 = linalg.fill(%cst_3, %6) : f32, tensor<1x?xf32> -> tensor<1x?xf32>
%8 = tensor.extract_slice %7[0, 0] [1, %5] [1, 1] : tensor<1x?xf32> to tensor<?xf32>
%9 = tensor.expand_shape %8 [[0, 1]] : tensor<?xf32> into tensor<1x?xf32>
%10 = tensor.cast %9 : tensor<1x?xf32> to tensor<?x?xf32>
%11 = tensor.expand_shape %10 [[0], [1, 2, 3]] : tensor<?x?xf32> into tensor<?x1x1x?xf32>
%12 = arith.cmpi sgt, %c0_i64, %4 : i64
%13 = arith.select %12, %4, %c0_i64 : i64
%14 = arith.index_cast %13 : i64 to index
%15 = arith.cmpi sgt, %c9223372036854775807_i64, %4 : i64
%16 = arith.select %15, %4, %c9223372036854775807_i64 : i64
%17 = arith.index_cast %16 : i64 to index
%18 = arith.cmpi sge, %17, %14 : index
%19 = arith.select %18, %17, %14 : index
%20 = arith.subi %19, %14 : index
%21 = tensor.extract_slice %11[0, 0, 0, %14] [1, 1, 1, %20] [1, 1, 1, 1] : tensor<?x1x1x?xf32> to tensor<?xf32>
%22 = linalg.init_tensor [%20] : tensor<?xf32>
%23 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%21 : tensor<?xf32>) outs(%22 : tensor<?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%609 = arith.subf %cst_3, %arg1 : f32
linalg.yield %609 : f32
} -> tensor<?xf32>
%24 = linalg.init_tensor [%20] : tensor<?xf32>
%25 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%23, %cst_5 : tensor<?xf32>, tensor<f64>) outs(%24 : tensor<?xf32>) {
^bb0(%arg1: f32, %arg2: f64, %arg3: f32):
%609 = arith.truncf %arg2 : f64 to f32
%610 = arith.mulf %arg1, %609 : f32
linalg.yield %610 : f32
} -> tensor<?xf32>
%26 = tensor.expand_shape %25 [[0, 1, 2, 3]] : tensor<?xf32> into tensor<1x1x1x?xf32>
%27 = linalg.fill(%c512_i64, %1) : i64, tensor<i64> -> tensor<i64>
%28 = linalg.generic {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []} ins(%27, %cst_7 : tensor<i64>, tensor<i64>) outs(%1 : tensor<i64>) {
^bb0(%arg1: i64, %arg2: i64, %arg3: i64):
%609 = arith.addi %arg1, %arg2 : i64
linalg.yield %609 : i64
} -> tensor<i64>
%29 = tensor.extract %28[] : tensor<i64>
%30 = arith.addi %29, %c512_i64 : i64
%31 = arith.cmpi sge, %29, %c0_i64 : i64
%32 = arith.select %31, %29, %30 : i64
%33 = arith.cmpi slt, %32, %c0_i64 : i64
%34 = arith.select %33, %c0_i64, %32 : i64
%35 = arith.cmpi sgt, %34, %c512_i64 : i64
%36 = arith.select %35, %c512_i64, %34 : i64
%37 = arith.index_cast %36 : i64 to index
%38 = arith.cmpi sge, %37, %c0 : index
%39 = arith.select %38, %37, %c0 : index
%40 = tensor.extract_slice %cst_8[0, 0] [1, %39] [1, 1] : tensor<1x512xi64> to tensor<?xi64>
%41 = tensor.collapse_shape %0 [[0, 1]] : tensor<1x512xi64> into tensor<512xi64>
%42 = linalg.init_tensor [512, 384] : tensor<512x384xf32>
%43 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%41 : tensor<512xi64>) outs(%42 : tensor<512x384xf32>) {
^bb0(%arg1: i64, %arg2: f32):
%609 = arith.index_cast %arg1 : i64 to index
%610 = linalg.index 1 : index
%611 = arith.cmpi slt, %arg1, %c30522_i64 : i64
cf.assert %611, "index must be smaller than dim size"
%612 = arith.cmpi sge, %arg1, %c0_i64 : i64
cf.assert %612, "index must be larger or equal to 0"
%613 = tensor.extract %cst_9[%609, %610] : tensor<30522x384xf32>
linalg.yield %613 : f32
} -> tensor<512x384xf32>
%44 = linalg.init_tensor [512, 384] : tensor<512x384xf32>
%45 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%cst_6 : tensor<512xi64>) outs(%44 : tensor<512x384xf32>) {
^bb0(%arg1: i64, %arg2: f32):
%609 = arith.index_cast %arg1 : i64 to index
%610 = linalg.index 1 : index
%611 = arith.cmpi slt, %arg1, %c2_i64 : i64
cf.assert %611, "index must be smaller than dim size"
%612 = arith.cmpi sge, %arg1, %c0_i64 : i64
cf.assert %612, "index must be larger or equal to 0"
%613 = tensor.extract %cst_10[%609, %610] : tensor<2x384xf32>
linalg.yield %613 : f32
} -> tensor<512x384xf32>
%46 = linalg.init_tensor [512, 384] : tensor<512x384xf32>
%47 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%43, %45 : tensor<512x384xf32>, tensor<512x384xf32>) outs(%46 : tensor<512x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%609 = arith.addf %arg1, %arg2 : f32
linalg.yield %609 : f32
} -> tensor<512x384xf32>
%48 = linalg.init_tensor [%39, 384] : tensor<?x384xf32>
%49 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%40 : tensor<?xi64>) outs(%48 : tensor<?x384xf32>) {
^bb0(%arg1: i64, %arg2: f32):
%609 = arith.index_cast %arg1 : i64 to index
%610 = linalg.index 1 : index
%611 = arith.cmpi slt, %arg1, %c512_i64 : i64
cf.assert %611, "index must be smaller than dim size"
%612 = arith.cmpi sge, %arg1, %c0_i64 : i64
cf.assert %612, "index must be larger or equal to 0"
%613 = tensor.extract %cst_11[%609, %610] : tensor<512x384xf32>
linalg.yield %613 : f32
} -> tensor<?x384xf32>
%50 = arith.cmpi eq, %c512, %39 : index
cf.assert %50, "mismatched size for broadcast"
%51 = linalg.init_tensor [512, 384] : tensor<512x384xf32>
%52 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%47, %49 : tensor<512x384xf32>, tensor<?x384xf32>) outs(%51 : tensor<512x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%609 = arith.addf %arg1, %arg2 : f32
linalg.yield %609 : f32
} -> tensor<512x384xf32>
%53 = linalg.init_tensor [512] : tensor<512xf32>
%54 = linalg.fill(%cst, %53) : f32, tensor<512xf32> -> tensor<512xf32>
%55 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%52 : tensor<512x384xf32>) outs(%54 : tensor<512xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%609 = arith.addf %arg2, %arg1 : f32
linalg.yield %609 : f32
} -> tensor<512xf32>
%56 = linalg.init_tensor [512] : tensor<512xf32>
%57 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%55 : tensor<512xf32>) outs(%56 : tensor<512xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%609 = arith.divf %arg1, %cst_115 : f32
linalg.yield %609 : f32
} -> tensor<512xf32>
%58 = linalg.init_tensor [512] : tensor<512xf32>
%59 = linalg.fill(%cst, %58) : f32, tensor<512xf32> -> tensor<512xf32>
%60 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%52, %57 : tensor<512x384xf32>, tensor<512xf32>) outs(%59 : tensor<512xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%609 = arith.subf %arg1, %arg2 : f32
%610 = arith.mulf %609, %609 : f32
%611 = arith.addf %arg3, %610 : f32
linalg.yield %611 : f32
} -> tensor<512xf32>
%61 = linalg.init_tensor [512] : tensor<512xf32>
%62 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%60 : tensor<512xf32>) outs(%61 : tensor<512xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%609 = arith.divf %arg1, %cst_115 : f32
linalg.yield %609 : f32
} -> tensor<512xf32>
%63 = linalg.init_tensor [512, 384] : tensor<512x384xf32>
%64 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%52, %57, %62, %cst_13, %cst_12 : tensor<512x384xf32>, tensor<512xf32>, tensor<512xf32>, tensor<384xf32>, tensor<384xf32>) outs(%63 : tensor<512x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%609 = arith.subf %arg1, %arg2 : f32
%610 = arith.truncf %cst_4 : f64 to f32
%611 = arith.addf %arg3, %610 : f32
%612 = math.rsqrt %611 : f32
%613 = arith.mulf %609, %612 : f32
%614 = arith.mulf %613, %arg4 : f32
%615 = arith.addf %614, %arg5 : f32
linalg.yield %615 : f32
} -> tensor<512x384xf32>
%65 = tensor.expand_shape %64 [[0, 1], [2]] : tensor<512x384xf32> into tensor<1x512x384xf32>
%66 = linalg.init_tensor [512, 384] : tensor<512x384xf32>
%67 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%cst_14 : tensor<384xf32>) outs(%66 : tensor<512x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<512x384xf32>
%68 = tensor.expand_shape %67 [[0, 1], [2]] : tensor<512x384xf32> into tensor<1x512x384xf32>
%69 = linalg.init_tensor [384, 384] : tensor<384x384xf32>
%70 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%cst_15 : tensor<384x384xf32>) outs(%69 : tensor<384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<384x384xf32>
%71 = tensor.expand_shape %70 [[0, 1], [2]] : tensor<384x384xf32> into tensor<1x384x384xf32>
%72 = linalg.batch_matmul ins(%65, %71 : tensor<1x512x384xf32>, tensor<1x384x384xf32>) outs(%68 : tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%73 = tensor.cast %72 : tensor<1x512x384xf32> to tensor<?x?x384xf32>
%74 = linalg.init_tensor [512, 384] : tensor<512x384xf32>
%75 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%cst_16 : tensor<384xf32>) outs(%74 : tensor<512x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<512x384xf32>
%76 = tensor.expand_shape %75 [[0, 1], [2]] : tensor<512x384xf32> into tensor<1x512x384xf32>
%77 = linalg.init_tensor [384, 384] : tensor<384x384xf32>
%78 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%cst_17 : tensor<384x384xf32>) outs(%77 : tensor<384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<384x384xf32>
%79 = tensor.expand_shape %78 [[0, 1], [2]] : tensor<384x384xf32> into tensor<1x384x384xf32>
%80 = linalg.batch_matmul ins(%65, %79 : tensor<1x512x384xf32>, tensor<1x384x384xf32>) outs(%76 : tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%81 = tensor.cast %80 : tensor<1x512x384xf32> to tensor<?x?x384xf32>
%82 = tensor.expand_shape %81 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%83 = linalg.init_tensor [1, 12, 512, 32] : tensor<1x12x512x32xf32>
%84 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%82 : tensor<?x?x12x32xf32>) outs(%83 : tensor<1x12x512x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<1x12x512x32xf32>
%85 = linalg.init_tensor [512, 384] : tensor<512x384xf32>
%86 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%cst_18 : tensor<384xf32>) outs(%85 : tensor<512x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<512x384xf32>
%87 = tensor.expand_shape %86 [[0, 1], [2]] : tensor<512x384xf32> into tensor<1x512x384xf32>
%88 = linalg.init_tensor [384, 384] : tensor<384x384xf32>
%89 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%cst_19 : tensor<384x384xf32>) outs(%88 : tensor<384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<384x384xf32>
%90 = tensor.expand_shape %89 [[0, 1], [2]] : tensor<384x384xf32> into tensor<1x384x384xf32>
%91 = linalg.batch_matmul ins(%65, %90 : tensor<1x512x384xf32>, tensor<1x384x384xf32>) outs(%87 : tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%92 = tensor.cast %91 : tensor<1x512x384xf32> to tensor<?x?x384xf32>
%93 = tensor.expand_shape %92 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%94 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%93 : tensor<?x?x12x32xf32>) outs(%83 : tensor<1x12x512x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<1x12x512x32xf32>
%95 = tensor.expand_shape %73 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%96 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%95 : tensor<?x?x12x32xf32>) outs(%83 : tensor<1x12x512x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<1x12x512x32xf32>
%97 = tensor.collapse_shape %84 [[0, 1], [2], [3]] : tensor<1x12x512x32xf32> into tensor<12x512x32xf32>
%98 = linalg.init_tensor [12, 32, 512] : tensor<12x32x512xf32>
%99 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d2, d1)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%97 : tensor<12x512x32xf32>) outs(%98 : tensor<12x32x512xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<12x32x512xf32>
%100 = tensor.collapse_shape %96 [[0, 1], [2], [3]] : tensor<1x12x512x32xf32> into tensor<12x512x32xf32>
%101 = linalg.init_tensor [12, 512, 512] : tensor<12x512x512xf32>
%102 = linalg.fill(%cst, %101) : f32, tensor<12x512x512xf32> -> tensor<12x512x512xf32>
%103 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%100, %99 : tensor<12x512x32xf32>, tensor<12x32x512xf32>) outs(%102 : tensor<12x512x512xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%609 = arith.mulf %arg1, %arg2 : f32
%610 = arith.addf %609, %arg3 : f32
linalg.yield %610 : f32
} -> tensor<12x512x512xf32>
%104 = linalg.init_tensor [12, 512, 512] : tensor<12x512x512xf32>
%105 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> ()>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%103, %cst_94 : tensor<12x512x512xf32>, tensor<f64>) outs(%104 : tensor<12x512x512xf32>) {
^bb0(%arg1: f32, %arg2: f64, %arg3: f32):
%609 = arith.truncf %arg2 : f64 to f32
%610 = arith.divf %arg1, %609 : f32
linalg.yield %610 : f32
} -> tensor<12x512x512xf32>
%106 = arith.cmpi eq, %c512, %20 : index
cf.assert %106, "mismatched size for broadcast"
%107 = linalg.init_tensor [12, 512, 512] : tensor<12x512x512xf32>
%108 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%105, %25 : tensor<12x512x512xf32>, tensor<?xf32>) outs(%107 : tensor<12x512x512xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%609 = arith.addf %arg1, %arg2 : f32
linalg.yield %609 : f32
} -> tensor<12x512x512xf32>
%109 = linalg.init_tensor [12, 512] : tensor<12x512xf32>
%110 = linalg.fill(%cst_0, %109) : f32, tensor<12x512xf32> -> tensor<12x512xf32>
%111 = linalg.init_tensor [12, 512] : tensor<12x512xi64>
%112 = linalg.fill(%c0_i64, %111) : i64, tensor<12x512xi64> -> tensor<12x512xi64>
%113:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%108 : tensor<12x512x512xf32>) outs(%110, %112 : tensor<12x512xf32>, tensor<12x512xi64>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: i64):
%609 = linalg.index 2 : index
%610 = arith.index_cast %609 : index to i64
%611 = arith.cmpf ogt, %arg1, %arg2 : f32
%612 = arith.select %611, %arg1, %arg2 : f32
%613 = arith.select %611, %610, %arg3 : i64
linalg.yield %612, %613 : f32, i64
} -> (tensor<12x512xf32>, tensor<12x512xi64>)
%114 = linalg.init_tensor [12, 512, 512] : tensor<12x512x512xf32>
%115 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%108, %113#0 : tensor<12x512x512xf32>, tensor<12x512xf32>) outs(%114 : tensor<12x512x512xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%609 = arith.subf %arg1, %arg2 : f32
linalg.yield %609 : f32
} -> tensor<12x512x512xf32>
%116 = linalg.init_tensor [12, 512, 512] : tensor<12x512x512xf32>
%117 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%115 : tensor<12x512x512xf32>) outs(%116 : tensor<12x512x512xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%609 = math.exp %arg1 : f32
linalg.yield %609 : f32
} -> tensor<12x512x512xf32>
%118 = linalg.init_tensor [12, 512] : tensor<12x512xf32>
%119 = linalg.fill(%cst, %118) : f32, tensor<12x512xf32> -> tensor<12x512xf32>
%120 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%117 : tensor<12x512x512xf32>) outs(%119 : tensor<12x512xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%609 = arith.addf %arg1, %arg2 : f32
linalg.yield %609 : f32
} -> tensor<12x512xf32>
%121 = linalg.init_tensor [12, 512, 512] : tensor<12x512x512xf32>
%122 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%117, %120 : tensor<12x512x512xf32>, tensor<12x512xf32>) outs(%121 : tensor<12x512x512xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%609 = arith.divf %arg1, %arg2 : f32
linalg.yield %609 : f32
} -> tensor<12x512x512xf32>
%123 = tensor.collapse_shape %94 [[0, 1], [2], [3]] : tensor<1x12x512x32xf32> into tensor<12x512x32xf32>
%124 = linalg.init_tensor [12, 512, 32] : tensor<12x512x32xf32>
%125 = linalg.fill(%cst, %124) : f32, tensor<12x512x32xf32> -> tensor<12x512x32xf32>
%126 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%122, %123 : tensor<12x512x512xf32>, tensor<12x512x32xf32>) outs(%125 : tensor<12x512x32xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%609 = arith.mulf %arg1, %arg2 : f32
%610 = arith.addf %609, %arg3 : f32
linalg.yield %610 : f32
} -> tensor<12x512x32xf32>
%127 = linalg.init_tensor [512, 12, 32] : tensor<512x12x32xf32>
%128 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d1, d0, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%126 : tensor<12x512x32xf32>) outs(%127 : tensor<512x12x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<512x12x32xf32>
%129 = tensor.expand_shape %128 [[0, 1], [2], [3]] : tensor<512x12x32xf32> into tensor<1x512x12x32xf32>
%130 = tensor.cast %129 : tensor<1x512x12x32xf32> to tensor<?x?x?x?xf32>
%131 = tensor.collapse_shape %130 [[0], [1], [2, 3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
%132 = tensor.cast %131 : tensor<?x?x?xf32> to tensor<?x?x384xf32>
%133 = linalg.init_tensor [512, 384] : tensor<512x384xf32>
%134 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%cst_20 : tensor<384xf32>) outs(%133 : tensor<512x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<512x384xf32>
%135 = tensor.expand_shape %134 [[0, 1], [2]] : tensor<512x384xf32> into tensor<1x512x384xf32>
%136 = linalg.init_tensor [384, 384] : tensor<384x384xf32>
%137 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%cst_21 : tensor<384x384xf32>) outs(%136 : tensor<384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<384x384xf32>
%138 = tensor.expand_shape %137 [[0, 1], [2]] : tensor<384x384xf32> into tensor<1x384x384xf32>
%139 = linalg.batch_matmul ins(%132, %138 : tensor<?x?x384xf32>, tensor<1x384x384xf32>) outs(%135 : tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%140 = tensor.dim %131, %c0 : tensor<?x?x?xf32>
%141 = tensor.dim %131, %c1 : tensor<?x?x?xf32>
%142 = arith.cmpi eq, %140, %c1 : index
cf.assert %142, "mismatched size for broadcast"
%143 = arith.cmpi eq, %141, %c512 : index
cf.assert %143, "mismatched size for broadcast"
%144 = linalg.init_tensor [%140, %141, 384] : tensor<?x?x384xf32>
%145 = tensor.collapse_shape %139 [[0, 1], [2]] : tensor<1x512x384xf32> into tensor<512x384xf32>
%146 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (0, d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%145, %64 : tensor<512x384xf32>, tensor<512x384xf32>) outs(%144 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%609 = arith.addf %arg1, %arg2 : f32
linalg.yield %609 : f32
} -> tensor<?x?x384xf32>
%147 = linalg.init_tensor [%140, %141] : tensor<?x?xf32>
%148 = linalg.fill(%cst, %147) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%149 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%146 : tensor<?x?x384xf32>) outs(%148 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%609 = arith.addf %arg2, %arg1 : f32
linalg.yield %609 : f32
} -> tensor<?x?xf32>
%150 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%149 : tensor<?x?xf32>) outs(%147 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%609 = arith.divf %arg1, %cst_115 : f32
linalg.yield %609 : f32
} -> tensor<?x?xf32>
%151 = linalg.fill(%cst, %147) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%152 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%146, %150 : tensor<?x?x384xf32>, tensor<?x?xf32>) outs(%151 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%609 = arith.subf %arg1, %arg2 : f32
%610 = arith.mulf %609, %609 : f32
%611 = arith.addf %arg3, %610 : f32
linalg.yield %611 : f32
} -> tensor<?x?xf32>
%153 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%152 : tensor<?x?xf32>) outs(%147 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%609 = arith.divf %arg1, %cst_115 : f32
linalg.yield %609 : f32
} -> tensor<?x?xf32>
%154 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%146, %150, %153, %cst_23, %cst_22 : tensor<?x?x384xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<384xf32>, tensor<384xf32>) outs(%144 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%609 = arith.subf %arg1, %arg2 : f32
%610 = arith.truncf %cst_4 : f64 to f32
%611 = arith.addf %arg3, %610 : f32
%612 = math.rsqrt %611 : f32
%613 = arith.mulf %609, %612 : f32
%614 = arith.mulf %613, %arg4 : f32
%615 = arith.addf %614, %arg5 : f32
linalg.yield %615 : f32
} -> tensor<?x?x384xf32>
%155 = linalg.init_tensor [%140, %141, 1536] : tensor<?x?x1536xf32>
%156 = linalg.init_tensor [%140, 384, 1536] : tensor<?x384x1536xf32>
%157 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_24 : tensor<1536xf32>) outs(%155 : tensor<?x?x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x1536xf32>
%158 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_25 : tensor<1536x384xf32>) outs(%156 : tensor<?x384x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x1536xf32>
%159 = linalg.batch_matmul ins(%154, %158 : tensor<?x?x384xf32>, tensor<?x384x1536xf32>) outs(%157 : tensor<?x?x1536xf32>) -> tensor<?x?x1536xf32>
%160 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%159 : tensor<?x?x1536xf32>) outs(%155 : tensor<?x?x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%609 = math.sqrt %cst_2 : f32
%610 = arith.divf %arg1, %609 : f32
%611 = math.erf %610 : f32
%612 = arith.addf %611, %cst_3 : f32
%613 = arith.mulf %612, %cst_1 : f32
%614 = arith.mulf %arg1, %613 : f32
linalg.yield %614 : f32
} -> tensor<?x?x1536xf32>
%161 = linalg.init_tensor [%140, 1536, 384] : tensor<?x1536x384xf32>
%162 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_26 : tensor<384xf32>) outs(%144 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%163 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_27 : tensor<384x1536xf32>) outs(%161 : tensor<?x1536x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x1536x384xf32>
%164 = linalg.batch_matmul ins(%160, %163 : tensor<?x?x1536xf32>, tensor<?x1536x384xf32>) outs(%162 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%165 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%164, %154 : tensor<?x?x384xf32>, tensor<?x?x384xf32>) outs(%144 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%609 = arith.addf %arg1, %arg2 : f32
linalg.yield %609 : f32
} -> tensor<?x?x384xf32>
%166 = linalg.fill(%cst, %147) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%167 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%165 : tensor<?x?x384xf32>) outs(%166 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%609 = arith.addf %arg2, %arg1 : f32
linalg.yield %609 : f32
} -> tensor<?x?xf32>
%168 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%167 : tensor<?x?xf32>) outs(%147 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%609 = arith.divf %arg1, %cst_115 : f32
linalg.yield %609 : f32
} -> tensor<?x?xf32>
%169 = linalg.fill(%cst, %147) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%170 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%165, %168 : tensor<?x?x384xf32>, tensor<?x?xf32>) outs(%169 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%609 = arith.subf %arg1, %arg2 : f32
%610 = arith.mulf %609, %609 : f32
%611 = arith.addf %arg3, %610 : f32
linalg.yield %611 : f32
} -> tensor<?x?xf32>
%171 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%170 : tensor<?x?xf32>) outs(%147 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%609 = arith.divf %arg1, %cst_115 : f32
linalg.yield %609 : f32
} -> tensor<?x?xf32>
%172 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%165, %168, %171, %cst_29, %cst_28 : tensor<?x?x384xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<384xf32>, tensor<384xf32>) outs(%144 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%609 = arith.subf %arg1, %arg2 : f32
%610 = arith.truncf %cst_4 : f64 to f32
%611 = arith.addf %arg3, %610 : f32
%612 = math.rsqrt %611 : f32
%613 = arith.mulf %609, %612 : f32
%614 = arith.mulf %613, %arg4 : f32
%615 = arith.addf %614, %arg5 : f32
linalg.yield %615 : f32
} -> tensor<?x?x384xf32>
%173 = linalg.init_tensor [%140, 384, 384] : tensor<?x384x384xf32>
%174 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_30 : tensor<384xf32>) outs(%144 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%175 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_31 : tensor<384x384xf32>) outs(%173 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%176 = linalg.batch_matmul ins(%172, %175 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%174 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%177 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_32 : tensor<384xf32>) outs(%144 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%178 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_33 : tensor<384x384xf32>) outs(%173 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%179 = linalg.batch_matmul ins(%172, %178 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%177 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%180 = tensor.expand_shape %179 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%181 = linalg.init_tensor [%140, 12, %141, 32] : tensor<?x12x?x32xf32>
%182 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%180 : tensor<?x?x12x32xf32>) outs(%181 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
%183 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_34 : tensor<384xf32>) outs(%144 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%184 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_35 : tensor<384x384xf32>) outs(%173 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%185 = linalg.batch_matmul ins(%172, %184 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%183 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%186 = tensor.expand_shape %185 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%187 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%186 : tensor<?x?x12x32xf32>) outs(%181 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
%188 = tensor.expand_shape %176 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%189 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%188 : tensor<?x?x12x32xf32>) outs(%181 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
%190 = linalg.init_tensor [%140, 12, 32, %141] : tensor<?x12x32x?xf32>
%191 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d2)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%182 : tensor<?x12x?x32xf32>) outs(%190 : tensor<?x12x32x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x32x?xf32>
%192 = linalg.init_tensor [%140, 12, %141, %141] : tensor<?x12x?x?xf32>
%193 = linalg.fill(%cst, %192) : f32, tensor<?x12x?x?xf32> -> tensor<?x12x?x?xf32>
%194 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]} ins(%189, %191 : tensor<?x12x?x32xf32>, tensor<?x12x32x?xf32>) outs(%193 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%609 = arith.mulf %arg1, %arg2 : f32
%610 = arith.addf %609, %arg3 : f32
linalg.yield %610 : f32
} -> tensor<?x12x?x?xf32>
%195 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> ()>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%194, %cst_94 : tensor<?x12x?x?xf32>, tensor<f64>) outs(%192 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f64, %arg3: f32):
%609 = arith.truncf %arg2 : f64 to f32
%610 = arith.divf %arg1, %609 : f32
linalg.yield %610 : f32
} -> tensor<?x12x?x?xf32>
cf.assert %142, "mismatched size for broadcast"
%196 = arith.cmpi eq, %141, %20 : index
cf.assert %196, "mismatched size for broadcast"
%197 = tensor.collapse_shape %26 [[0, 1, 2], [3]] : tensor<1x1x1x?xf32> into tensor<1x?xf32>
%198 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%195, %197 : tensor<?x12x?x?xf32>, tensor<1x?xf32>) outs(%192 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%609 = arith.addf %arg1, %arg2 : f32
linalg.yield %609 : f32
} -> tensor<?x12x?x?xf32>
%199 = linalg.init_tensor [%140, 12, %141] : tensor<?x12x?xf32>
%200 = linalg.fill(%cst_0, %199) : f32, tensor<?x12x?xf32> -> tensor<?x12x?xf32>
%201 = linalg.init_tensor [%140, 12, %141] : tensor<?x12x?xi64>
%202 = linalg.fill(%c0_i64, %201) : i64, tensor<?x12x?xi64> -> tensor<?x12x?xi64>
%203:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%198 : tensor<?x12x?x?xf32>) outs(%200, %202 : tensor<?x12x?xf32>, tensor<?x12x?xi64>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: i64):
%609 = linalg.index 3 : index
%610 = arith.index_cast %609 : index to i64
%611 = arith.cmpf ogt, %arg1, %arg2 : f32
%612 = arith.select %611, %arg1, %arg2 : f32
%613 = arith.select %611, %610, %arg3 : i64
linalg.yield %612, %613 : f32, i64
} -> (tensor<?x12x?xf32>, tensor<?x12x?xi64>)
%204 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%198, %203#0 : tensor<?x12x?x?xf32>, tensor<?x12x?xf32>) outs(%192 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%609 = arith.subf %arg1, %arg2 : f32
linalg.yield %609 : f32
} -> tensor<?x12x?x?xf32>
%205 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%204 : tensor<?x12x?x?xf32>) outs(%192 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%609 = math.exp %arg1 : f32
linalg.yield %609 : f32
} -> tensor<?x12x?x?xf32>
%206 = linalg.init_tensor [%140, 12, %141] : tensor<?x12x?xf32>
%207 = linalg.fill(%cst, %206) : f32, tensor<?x12x?xf32> -> tensor<?x12x?xf32>
%208 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%205 : tensor<?x12x?x?xf32>) outs(%207 : tensor<?x12x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%609 = arith.addf %arg1, %arg2 : f32
linalg.yield %609 : f32
} -> tensor<?x12x?xf32>
%209 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%205, %208 : tensor<?x12x?x?xf32>, tensor<?x12x?xf32>) outs(%192 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%609 = arith.divf %arg1, %arg2 : f32
linalg.yield %609 : f32
} -> tensor<?x12x?x?xf32>
%210 = linalg.fill(%cst, %181) : f32, tensor<?x12x?x32xf32> -> tensor<?x12x?x32xf32>
%211 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]} ins(%209, %187 : tensor<?x12x?x?xf32>, tensor<?x12x?x32xf32>) outs(%210 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%609 = arith.mulf %arg1, %arg2 : f32
%610 = arith.addf %609, %arg3 : f32
linalg.yield %610 : f32
} -> tensor<?x12x?x32xf32>
%212 = linalg.init_tensor [%140, %141, 12, 32] : tensor<?x?x12x32xf32>
%213 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%211 : tensor<?x12x?x32xf32>) outs(%212 : tensor<?x?x12x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x12x32xf32>
%214 = tensor.cast %213 : tensor<?x?x12x32xf32> to tensor<?x?x?x?xf32>
%215 = tensor.collapse_shape %214 [[0], [1], [2, 3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
%216 = tensor.cast %215 : tensor<?x?x?xf32> to tensor<?x?x384xf32>
%217 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_36 : tensor<384xf32>) outs(%144 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%218 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_37 : tensor<384x384xf32>) outs(%173 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%219 = linalg.batch_matmul ins(%216, %218 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%217 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%220 = tensor.dim %215, %c0 : tensor<?x?x?xf32>
%221 = tensor.dim %215, %c1 : tensor<?x?x?xf32>
%222 = arith.cmpi eq, %220, %140 : index
cf.assert %222, "mismatched size for broadcast"
%223 = arith.cmpi eq, %221, %141 : index
cf.assert %223, "mismatched size for broadcast"
%224 = linalg.init_tensor [%220, %221, 384] : tensor<?x?x384xf32>
%225 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%219, %172 : tensor<?x?x384xf32>, tensor<?x?x384xf32>) outs(%224 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%609 = arith.addf %arg1, %arg2 : f32
linalg.yield %609 : f32
} -> tensor<?x?x384xf32>
%226 = linalg.init_tensor [%220, %221] : tensor<?x?xf32>
%227 = linalg.fill(%cst, %226) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%228 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%225 : tensor<?x?x384xf32>) outs(%227 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%609 = arith.addf %arg2, %arg1 : f32
linalg.yield %609 : f32
} -> tensor<?x?xf32>
%229 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%228 : tensor<?x?xf32>) outs(%226 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%609 = arith.divf %arg1, %cst_115 : f32
linalg.yield %609 : f32
} -> tensor<?x?xf32>
%230 = linalg.fill(%cst, %226) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%231 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%225, %229 : tensor<?x?x384xf32>, tensor<?x?xf32>) outs(%230 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%609 = arith.subf %arg1, %arg2 : f32
%610 = arith.mulf %609, %609 : f32
%611 = arith.addf %arg3, %610 : f32
linalg.yield %611 : f32
} -> tensor<?x?xf32>
%232 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%231 : tensor<?x?xf32>) outs(%226 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%609 = arith.divf %arg1, %cst_115 : f32
linalg.yield %609 : f32
} -> tensor<?x?xf32>
%233 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%225, %229, %232, %cst_39, %cst_38 : tensor<?x?x384xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<384xf32>, tensor<384xf32>) outs(%224 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%609 = arith.subf %arg1, %arg2 : f32
%610 = arith.truncf %cst_4 : f64 to f32
%611 = arith.addf %arg3, %610 : f32
%612 = math.rsqrt %611 : f32
%613 = arith.mulf %609, %612 : f32
%614 = arith.mulf %613, %arg4 : f32
%615 = arith.addf %614, %arg5 : f32
linalg.yield %615 : f32
} -> tensor<?x?x384xf32>
%234 = linalg.init_tensor [%220, %221, 1536] : tensor<?x?x1536xf32>
%235 = linalg.init_tensor [%220, 384, 1536] : tensor<?x384x1536xf32>
%236 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_40 : tensor<1536xf32>) outs(%234 : tensor<?x?x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x1536xf32>
%237 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_41 : tensor<1536x384xf32>) outs(%235 : tensor<?x384x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x1536xf32>
%238 = linalg.batch_matmul ins(%233, %237 : tensor<?x?x384xf32>, tensor<?x384x1536xf32>) outs(%236 : tensor<?x?x1536xf32>) -> tensor<?x?x1536xf32>
%239 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%238 : tensor<?x?x1536xf32>) outs(%234 : tensor<?x?x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%609 = math.sqrt %cst_2 : f32
%610 = arith.divf %arg1, %609 : f32
%611 = math.erf %610 : f32
%612 = arith.addf %611, %cst_3 : f32
%613 = arith.mulf %612, %cst_1 : f32
%614 = arith.mulf %arg1, %613 : f32
linalg.yield %614 : f32
} -> tensor<?x?x1536xf32>
%240 = linalg.init_tensor [%220, 1536, 384] : tensor<?x1536x384xf32>
%241 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_42 : tensor<384xf32>) outs(%224 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%242 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_43 : tensor<384x1536xf32>) outs(%240 : tensor<?x1536x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x1536x384xf32>
%243 = linalg.batch_matmul ins(%239, %242 : tensor<?x?x1536xf32>, tensor<?x1536x384xf32>) outs(%241 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%244 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%243, %233 : tensor<?x?x384xf32>, tensor<?x?x384xf32>) outs(%224 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%609 = arith.addf %arg1, %arg2 : f32
linalg.yield %609 : f32
} -> tensor<?x?x384xf32>
%245 = linalg.fill(%cst, %226) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%246 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%244 : tensor<?x?x384xf32>) outs(%245 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%609 = arith.addf %arg2, %arg1 : f32
linalg.yield %609 : f32
} -> tensor<?x?xf32>
%247 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%246 : tensor<?x?xf32>) outs(%226 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%609 = arith.divf %arg1, %cst_115 : f32
linalg.yield %609 : f32
} -> tensor<?x?xf32>
%248 = linalg.fill(%cst, %226) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%249 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%244, %247 : tensor<?x?x384xf32>, tensor<?x?xf32>) outs(%248 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%609 = arith.subf %arg1, %arg2 : f32
%610 = arith.mulf %609, %609 : f32
%611 = arith.addf %arg3, %610 : f32
linalg.yield %611 : f32
} -> tensor<?x?xf32>
%250 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%249 : tensor<?x?xf32>) outs(%226 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%609 = arith.divf %arg1, %cst_115 : f32
linalg.yield %609 : f32
} -> tensor<?x?xf32>
%251 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%244, %247, %250, %cst_45, %cst_44 : tensor<?x?x384xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<384xf32>, tensor<384xf32>) outs(%224 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%609 = arith.subf %arg1, %arg2 : f32
%610 = arith.truncf %cst_4 : f64 to f32
%611 = arith.addf %arg3, %610 : f32
%612 = math.rsqrt %611 : f32
%613 = arith.mulf %609, %612 : f32
%614 = arith.mulf %613, %arg4 : f32
%615 = arith.addf %614, %arg5 : f32
linalg.yield %615 : f32
} -> tensor<?x?x384xf32>
%252 = linalg.init_tensor [%220, 384, 384] : tensor<?x384x384xf32>
%253 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_46 : tensor<384xf32>) outs(%224 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%254 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_47 : tensor<384x384xf32>) outs(%252 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%255 = linalg.batch_matmul ins(%251, %254 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%253 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%256 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_48 : tensor<384xf32>) outs(%224 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%257 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_49 : tensor<384x384xf32>) outs(%252 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%258 = linalg.batch_matmul ins(%251, %257 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%256 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%259 = tensor.expand_shape %258 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%260 = linalg.init_tensor [%220, 12, %221, 32] : tensor<?x12x?x32xf32>
%261 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%259 : tensor<?x?x12x32xf32>) outs(%260 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
%262 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_50 : tensor<384xf32>) outs(%224 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%263 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_51 : tensor<384x384xf32>) outs(%252 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%264 = linalg.batch_matmul ins(%251, %263 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%262 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%265 = tensor.expand_shape %264 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%266 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%265 : tensor<?x?x12x32xf32>) outs(%260 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
%267 = tensor.expand_shape %255 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%268 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%267 : tensor<?x?x12x32xf32>) outs(%260 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
%269 = linalg.init_tensor [%220, 12, 32, %221] : tensor<?x12x32x?xf32>
%270 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d2)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%261 : tensor<?x12x?x32xf32>) outs(%269 : tensor<?x12x32x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x32x?xf32>
%271 = linalg.init_tensor [%220, 12, %221, %221] : tensor<?x12x?x?xf32>
%272 = linalg.fill(%cst, %271) : f32, tensor<?x12x?x?xf32> -> tensor<?x12x?x?xf32>
%273 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]} ins(%268, %270 : tensor<?x12x?x32xf32>, tensor<?x12x32x?xf32>) outs(%272 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%609 = arith.mulf %arg1, %arg2 : f32
%610 = arith.addf %609, %arg3 : f32
linalg.yield %610 : f32
} -> tensor<?x12x?x?xf32>
%274 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> ()>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%273, %cst_94 : tensor<?x12x?x?xf32>, tensor<f64>) outs(%271 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f64, %arg3: f32):
%609 = arith.truncf %arg2 : f64 to f32
%610 = arith.divf %arg1, %609 : f32
linalg.yield %610 : f32
} -> tensor<?x12x?x?xf32>
%275 = arith.cmpi eq, %220, %c1 : index
cf.assert %275, "mismatched size for broadcast"
%276 = arith.cmpi eq, %221, %20 : index
cf.assert %276, "mismatched size for broadcast"
%277 = tensor.collapse_shape %26 [[0, 1, 2], [3]] : tensor<1x1x1x?xf32> into tensor<1x?xf32>
%278 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%274, %277 : tensor<?x12x?x?xf32>, tensor<1x?xf32>) outs(%271 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%609 = arith.addf %arg1, %arg2 : f32
linalg.yield %609 : f32
} -> tensor<?x12x?x?xf32>
%279 = linalg.init_tensor [%220, 12, %221] : tensor<?x12x?xf32>
%280 = linalg.fill(%cst_0, %279) : f32, tensor<?x12x?xf32> -> tensor<?x12x?xf32>
%281 = linalg.init_tensor [%220, 12, %221] : tensor<?x12x?xi64>
%282 = linalg.fill(%c0_i64, %281) : i64, tensor<?x12x?xi64> -> tensor<?x12x?xi64>
%283:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%278 : tensor<?x12x?x?xf32>) outs(%280, %282 : tensor<?x12x?xf32>, tensor<?x12x?xi64>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: i64):
%609 = linalg.index 3 : index
%610 = arith.index_cast %609 : index to i64
%611 = arith.cmpf ogt, %arg1, %arg2 : f32
%612 = arith.select %611, %arg1, %arg2 : f32
%613 = arith.select %611, %610, %arg3 : i64
linalg.yield %612, %613 : f32, i64
} -> (tensor<?x12x?xf32>, tensor<?x12x?xi64>)
%284 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%278, %283#0 : tensor<?x12x?x?xf32>, tensor<?x12x?xf32>) outs(%271 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%609 = arith.subf %arg1, %arg2 : f32
linalg.yield %609 : f32
} -> tensor<?x12x?x?xf32>
%285 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%284 : tensor<?x12x?x?xf32>) outs(%271 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%609 = math.exp %arg1 : f32
linalg.yield %609 : f32
} -> tensor<?x12x?x?xf32>
%286 = linalg.init_tensor [%220, 12, %221] : tensor<?x12x?xf32>
%287 = linalg.fill(%cst, %286) : f32, tensor<?x12x?xf32> -> tensor<?x12x?xf32>
%288 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%285 : tensor<?x12x?x?xf32>) outs(%287 : tensor<?x12x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%609 = arith.addf %arg1, %arg2 : f32
linalg.yield %609 : f32
} -> tensor<?x12x?xf32>
%289 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%285, %288 : tensor<?x12x?x?xf32>, tensor<?x12x?xf32>) outs(%271 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%609 = arith.divf %arg1, %arg2 : f32
linalg.yield %609 : f32
} -> tensor<?x12x?x?xf32>
%290 = linalg.fill(%cst, %260) : f32, tensor<?x12x?x32xf32> -> tensor<?x12x?x32xf32>
%291 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]} ins(%289, %266 : tensor<?x12x?x?xf32>, tensor<?x12x?x32xf32>) outs(%290 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%609 = arith.mulf %arg1, %arg2 : f32
%610 = arith.addf %609, %arg3 : f32
linalg.yield %610 : f32
} -> tensor<?x12x?x32xf32>
%292 = linalg.init_tensor [%220, %221, 12, 32] : tensor<?x?x12x32xf32>
%293 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%291 : tensor<?x12x?x32xf32>) outs(%292 : tensor<?x?x12x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x12x32xf32>
%294 = tensor.cast %293 : tensor<?x?x12x32xf32> to tensor<?x?x?x?xf32>
%295 = tensor.collapse_shape %294 [[0], [1], [2, 3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
%296 = tensor.cast %295 : tensor<?x?x?xf32> to tensor<?x?x384xf32>
%297 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_52 : tensor<384xf32>) outs(%224 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%298 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_53 : tensor<384x384xf32>) outs(%252 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%299 = linalg.batch_matmul ins(%296, %298 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%297 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%300 = tensor.dim %295, %c0 : tensor<?x?x?xf32>
%301 = tensor.dim %295, %c1 : tensor<?x?x?xf32>
%302 = arith.cmpi eq, %300, %220 : index
cf.assert %302, "mismatched size for broadcast"
%303 = arith.cmpi eq, %301, %221 : index
cf.assert %303, "mismatched size for broadcast"
%304 = linalg.init_tensor [%300, %301, 384] : tensor<?x?x384xf32>
%305 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%299, %251 : tensor<?x?x384xf32>, tensor<?x?x384xf32>) outs(%304 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%609 = arith.addf %arg1, %arg2 : f32
linalg.yield %609 : f32
} -> tensor<?x?x384xf32>
%306 = linalg.init_tensor [%300, %301] : tensor<?x?xf32>
%307 = linalg.fill(%cst, %306) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%308 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%305 : tensor<?x?x384xf32>) outs(%307 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%609 = arith.addf %arg2, %arg1 : f32
linalg.yield %609 : f32
} -> tensor<?x?xf32>
%309 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%308 : tensor<?x?xf32>) outs(%306 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%609 = arith.divf %arg1, %cst_115 : f32
linalg.yield %609 : f32
} -> tensor<?x?xf32>
%310 = linalg.fill(%cst, %306) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%311 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%305, %309 : tensor<?x?x384xf32>, tensor<?x?xf32>) outs(%310 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%609 = arith.subf %arg1, %arg2 : f32
%610 = arith.mulf %609, %609 : f32
%611 = arith.addf %arg3, %610 : f32
linalg.yield %611 : f32
} -> tensor<?x?xf32>
%312 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%311 : tensor<?x?xf32>) outs(%306 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%609 = arith.divf %arg1, %cst_115 : f32
linalg.yield %609 : f32
} -> tensor<?x?xf32>
%313 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%305, %309, %312, %cst_55, %cst_54 : tensor<?x?x384xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<384xf32>, tensor<384xf32>) outs(%304 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%609 = arith.subf %arg1, %arg2 : f32
%610 = arith.truncf %cst_4 : f64 to f32
%611 = arith.addf %arg3, %610 : f32
%612 = math.rsqrt %611 : f32
%613 = arith.mulf %609, %612 : f32
%614 = arith.mulf %613, %arg4 : f32
%615 = arith.addf %614, %arg5 : f32
linalg.yield %615 : f32
} -> tensor<?x?x384xf32>
%314 = linalg.init_tensor [%300, %301, 1536] : tensor<?x?x1536xf32>
%315 = linalg.init_tensor [%300, 384, 1536] : tensor<?x384x1536xf32>
%316 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_56 : tensor<1536xf32>) outs(%314 : tensor<?x?x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x1536xf32>
%317 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_57 : tensor<1536x384xf32>) outs(%315 : tensor<?x384x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x1536xf32>
%318 = linalg.batch_matmul ins(%313, %317 : tensor<?x?x384xf32>, tensor<?x384x1536xf32>) outs(%316 : tensor<?x?x1536xf32>) -> tensor<?x?x1536xf32>
%319 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%318 : tensor<?x?x1536xf32>) outs(%314 : tensor<?x?x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%609 = math.sqrt %cst_2 : f32
%610 = arith.divf %arg1, %609 : f32
%611 = math.erf %610 : f32
%612 = arith.addf %611, %cst_3 : f32
%613 = arith.mulf %612, %cst_1 : f32
%614 = arith.mulf %arg1, %613 : f32
linalg.yield %614 : f32
} -> tensor<?x?x1536xf32>
%320 = linalg.init_tensor [%300, 1536, 384] : tensor<?x1536x384xf32>
%321 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_58 : tensor<384xf32>) outs(%304 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%322 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_59 : tensor<384x1536xf32>) outs(%320 : tensor<?x1536x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x1536x384xf32>
%323 = linalg.batch_matmul ins(%319, %322 : tensor<?x?x1536xf32>, tensor<?x1536x384xf32>) outs(%321 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%324 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%323, %313 : tensor<?x?x384xf32>, tensor<?x?x384xf32>) outs(%304 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%609 = arith.addf %arg1, %arg2 : f32
linalg.yield %609 : f32
} -> tensor<?x?x384xf32>
%325 = linalg.fill(%cst, %306) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%326 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%324 : tensor<?x?x384xf32>) outs(%325 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%609 = arith.addf %arg2, %arg1 : f32
linalg.yield %609 : f32
} -> tensor<?x?xf32>
%327 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%326 : tensor<?x?xf32>) outs(%306 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%609 = arith.divf %arg1, %cst_115 : f32
linalg.yield %609 : f32
} -> tensor<?x?xf32>
%328 = linalg.fill(%cst, %306) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%329 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%324, %327 : tensor<?x?x384xf32>, tensor<?x?xf32>) outs(%328 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%609 = arith.subf %arg1, %arg2 : f32
%610 = arith.mulf %609, %609 : f32
%611 = arith.addf %arg3, %610 : f32
linalg.yield %611 : f32
} -> tensor<?x?xf32>
%330 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%329 : tensor<?x?xf32>) outs(%306 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%609 = arith.divf %arg1, %cst_115 : f32
linalg.yield %609 : f32
} -> tensor<?x?xf32>
%331 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%324, %327, %330, %cst_61, %cst_60 : tensor<?x?x384xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<384xf32>, tensor<384xf32>) outs(%304 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%609 = arith.subf %arg1, %arg2 : f32
%610 = arith.truncf %cst_4 : f64 to f32
%611 = arith.addf %arg3, %610 : f32
%612 = math.rsqrt %611 : f32
%613 = arith.mulf %609, %612 : f32
%614 = arith.mulf %613, %arg4 : f32
%615 = arith.addf %614, %arg5 : f32
linalg.yield %615 : f32
} -> tensor<?x?x384xf32>
%332 = linalg.init_tensor [%300, 384, 384] : tensor<?x384x384xf32>
%333 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_62 : tensor<384xf32>) outs(%304 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%334 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_63 : tensor<384x384xf32>) outs(%332 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%335 = linalg.batch_matmul ins(%331, %334 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%333 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%336 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_64 : tensor<384xf32>) outs(%304 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%337 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_65 : tensor<384x384xf32>) outs(%332 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%338 = linalg.batch_matmul ins(%331, %337 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%336 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%339 = tensor.expand_shape %338 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%340 = linalg.init_tensor [%300, 12, %301, 32] : tensor<?x12x?x32xf32>
%341 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%339 : tensor<?x?x12x32xf32>) outs(%340 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
// Third projection of %331: broadcast vector %cst_66 as the accumulator seed,
// transposed %cst_67 (replicated along d0) as the right-hand operand.
%342 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_66 : tensor<384xf32>) outs(%304 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%343 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_67 : tensor<384x384xf32>) outs(%332 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%344 = linalg.batch_matmul ins(%331, %343 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%342 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
// Split 384 into 12 x 32, then swap dims 1 and 2. %346 is the right-hand
// operand of the contraction %371.
%345 = tensor.expand_shape %344 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%346 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%345 : tensor<?x?x12x32xf32>) outs(%340 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
// Same split and dim swap applied to %335; %348 is the left-hand operand of
// the contraction %353.
%347 = tensor.expand_shape %335 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%348 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%347 : tensor<?x?x12x32xf32>) outs(%340 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
// Swap the last two dims of %341: ?x12x?x32 -> ?x12x32x?.
%349 = linalg.init_tensor [%300, 12, 32, %301] : tensor<?x12x32x?xf32>
%350 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d2)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%341 : tensor<?x12x?x32xf32>) outs(%349 : tensor<?x12x32x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x32x?xf32>
// Accumulator for the contraction, filled with %cst
// (NOTE(review): %cst is defined outside this excerpt; presumably 0.0 - confirm).
%351 = linalg.init_tensor [%300, 12, %301, %301] : tensor<?x12x?x?xf32>
%352 = linalg.fill(%cst, %351) : f32, tensor<?x12x?x?xf32> -> tensor<?x12x?x?xf32>
// out[d0, d1, d2, d3] += %348[d0, d1, d2, d4] * %350[d0, d1, d4, d3], reducing over d4.
%353 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]} ins(%348, %350 : tensor<?x12x?x32xf32>, tensor<?x12x32x?xf32>) outs(%352 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%609 = arith.mulf %arg1, %arg2 : f32
%610 = arith.addf %609, %arg3 : f32
linalg.yield %610 : f32
} -> tensor<?x12x?x?xf32>
// Divide every element by the rank-0 f64 tensor %cst_94 (truncated to f32).
// NOTE(review): the value of %cst_94 is not visible in this excerpt.
%354 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> ()>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%353, %cst_94 : tensor<?x12x?x?xf32>, tensor<f64>) outs(%351 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f64, %arg3: f32):
%609 = arith.truncf %arg2 : f64 to f32
%610 = arith.divf %arg1, %609 : f32
linalg.yield %610 : f32
} -> tensor<?x12x?x?xf32>
// Runtime checks that the dynamic sizes are compatible with the broadcast
// below: %300 must equal %c1 and %301 must equal %20.
%355 = arith.cmpi eq, %300, %c1 : index
cf.assert %355, "mismatched size for broadcast"
%356 = arith.cmpi eq, %301, %20 : index
cf.assert %356, "mismatched size for broadcast"
// Drop the unit dims of %26 (1x1x1x? -> 1x?) and add it to %354, indexed by
// (d0, d3) so the same row is reused for every d1 and d2.
%357 = tensor.collapse_shape %26 [[0, 1, 2], [3]] : tensor<1x1x1x?xf32> into tensor<1x?xf32>
%358 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%354, %357 : tensor<?x12x?x?xf32>, tensor<1x?xf32>) outs(%351 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%609 = arith.addf %arg1, %arg2 : f32
linalg.yield %609 : f32
} -> tensor<?x12x?x?xf32>
// Max-subtract / exp / sum / divide along d3 of %358.
// Initial values for the running maximum (%cst_0) and its index (%c0_i64).
// NOTE(review): %cst_0 is defined outside this excerpt; presumably a very
// negative f32 so any input compares greater - confirm.
%359 = linalg.init_tensor [%300, 12, %301] : tensor<?x12x?xf32>
%360 = linalg.fill(%cst_0, %359) : f32, tensor<?x12x?xf32> -> tensor<?x12x?xf32>
%361 = linalg.init_tensor [%300, 12, %301] : tensor<?x12x?xi64>
%362 = linalg.fill(%c0_i64, %361) : i64, tensor<?x12x?xi64> -> tensor<?x12x?xi64>
// Reduction over d3 keeping the larger value (ordered greater-than) and the d3
// index at which it occurred. Only result #0 is used within this excerpt.
%363:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%358 : tensor<?x12x?x?xf32>) outs(%360, %362 : tensor<?x12x?xf32>, tensor<?x12x?xi64>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: i64):
%609 = linalg.index 3 : index
%610 = arith.index_cast %609 : index to i64
%611 = arith.cmpf ogt, %arg1, %arg2 : f32
%612 = arith.select %611, %arg1, %arg2 : f32
%613 = arith.select %611, %610, %arg3 : i64
linalg.yield %612, %613 : f32, i64
} -> (tensor<?x12x?xf32>, tensor<?x12x?xi64>)
// Subtract the per-(d0, d1, d2) maximum from every element along d3.
%364 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%358, %363#0 : tensor<?x12x?x?xf32>, tensor<?x12x?xf32>) outs(%351 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%609 = arith.subf %arg1, %arg2 : f32
linalg.yield %609 : f32
} -> tensor<?x12x?x?xf32>
// Elementwise exponential.
%365 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%364 : tensor<?x12x?x?xf32>) outs(%351 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%609 = math.exp %arg1 : f32
linalg.yield %609 : f32
} -> tensor<?x12x?x?xf32>
// Sum of the exponentials over d3, accumulated onto a %cst-filled tensor.
%366 = linalg.init_tensor [%300, 12, %301] : tensor<?x12x?xf32>
%367 = linalg.fill(%cst, %366) : f32, tensor<?x12x?xf32> -> tensor<?x12x?xf32>
%368 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%365 : tensor<?x12x?x?xf32>) outs(%367 : tensor<?x12x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%609 = arith.addf %arg1, %arg2 : f32
linalg.yield %609 : f32
} -> tensor<?x12x?xf32>
// Divide each exponential by the sum for its (d0, d1, d2) position.
%369 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%365, %368 : tensor<?x12x?x?xf32>, tensor<?x12x?xf32>) outs(%351 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%609 = arith.divf %arg1, %arg2 : f32
linalg.yield %609 : f32
} -> tensor<?x12x?x?xf32>
// Contraction of the normalized weights %369 with %346 over d4:
// out[d0, d1, d2, d3] += %369[d0, d1, d2, d4] * %346[d0, d1, d4, d3].
// The accumulator is %340 filled with %cst.
%370 = linalg.fill(%cst, %340) : f32, tensor<?x12x?x32xf32> -> tensor<?x12x?x32xf32>
%371 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]} ins(%369, %346 : tensor<?x12x?x?xf32>, tensor<?x12x?x32xf32>) outs(%370 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%609 = arith.mulf %arg1, %arg2 : f32
%610 = arith.addf %609, %arg3 : f32
linalg.yield %610 : f32
} -> tensor<?x12x?x32xf32>
// Swap dims 1 and 2 back (?x12x?x32 -> ?x?x12x32).
%372 = linalg.init_tensor [%300, %301, 12, 32] : tensor<?x?x12x32xf32>
%373 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%371 : tensor<?x12x?x32xf32>) outs(%372 : tensor<?x?x12x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x12x32xf32>
// Merge the trailing 12 x 32 dims into one. The casts erase the static sizes
// before the collapse and restore the static 384 afterwards.
%374 = tensor.cast %373 : tensor<?x?x12x32xf32> to tensor<?x?x?x?xf32>
%375 = tensor.collapse_shape %374 [[0], [1], [2, 3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
%376 = tensor.cast %375 : tensor<?x?x?xf32> to tensor<?x?x384xf32>
// Projection of %376: broadcast vector %cst_68 seeds the accumulator,
// transposed %cst_69 (replicated along d0) is the right-hand operand.
%377 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_68 : tensor<384xf32>) outs(%304 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%378 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_69 : tensor<384x384xf32>) outs(%332 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%379 = linalg.batch_matmul ins(%376, %378 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%377 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
// Re-derive the two leading dynamic sizes from %375 and assert that they match
// %300 / %301 before the elementwise add below.
%380 = tensor.dim %375, %c0 : tensor<?x?x?xf32>
%381 = tensor.dim %375, %c1 : tensor<?x?x?xf32>
%382 = arith.cmpi eq, %380, %300 : index
cf.assert %382, "mismatched size for broadcast"
%383 = arith.cmpi eq, %381, %301 : index
cf.assert %383, "mismatched size for broadcast"
// Elementwise sum of the projection output %379 and the block input %331.
%384 = linalg.init_tensor [%380, %381, 384] : tensor<?x?x384xf32>
%385 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%379, %331 : tensor<?x?x384xf32>, tensor<?x?x384xf32>) outs(%384 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%609 = arith.addf %arg1, %arg2 : f32
linalg.yield %609 : f32
} -> tensor<?x?x384xf32>
// Sum of %385 over d2, then division by %cst_115.
// NOTE(review): %cst_115 is defined outside this excerpt; if it equals the d2
// extent (384) then %389 is the mean over d2 - confirm.
%386 = linalg.init_tensor [%380, %381] : tensor<?x?xf32>
%387 = linalg.fill(%cst, %386) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%388 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%385 : tensor<?x?x384xf32>) outs(%387 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%609 = arith.addf %arg2, %arg1 : f32
linalg.yield %609 : f32
} -> tensor<?x?xf32>
%389 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%388 : tensor<?x?xf32>) outs(%386 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%609 = arith.divf %arg1, %cst_115 : f32
linalg.yield %609 : f32
} -> tensor<?x?xf32>
// Sum over d2 of squared deviations from %389, then division by %cst_115.
%390 = linalg.fill(%cst, %386) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%391 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%385, %389 : tensor<?x?x384xf32>, tensor<?x?xf32>) outs(%390 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%609 = arith.subf %arg1, %arg2 : f32
%610 = arith.mulf %609, %609 : f32
%611 = arith.addf %arg3, %610 : f32
linalg.yield %611 : f32
} -> tensor<?x?xf32>
%392 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%391 : tensor<?x?xf32>) outs(%386 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%609 = arith.divf %arg1, %cst_115 : f32
linalg.yield %609 : f32
} -> tensor<?x?xf32>
// Per element: (%385 - %389) * rsqrt(%392 + truncf(%cst_4)) * %cst_71 + %cst_70,
// with %cst_71 / %cst_70 indexed along d2.
// NOTE(review): %cst_4 is an f64 defined outside this excerpt; presumably a
// small epsilon - confirm.
%393 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%385, %389, %392, %cst_71, %cst_70 : tensor<?x?x384xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<384xf32>, tensor<384xf32>) outs(%384 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%609 = arith.subf %arg1, %arg2 : f32
%610 = arith.truncf %cst_4 : f64 to f32
%611 = arith.addf %arg3, %610 : f32
%612 = math.rsqrt %611 : f32
%613 = arith.mulf %609, %612 : f32
%614 = arith.mulf %613, %arg4 : f32
%615 = arith.addf %614, %arg5 : f32
linalg.yield %615 : f32
} -> tensor<?x?x384xf32>
// First projection, 384 -> 1536: broadcast vector %cst_72 seeds the
// accumulator; %cst_73 (1536x384) is read at (d2, d1), i.e. transposed to
// 384x1536 and replicated along d0.
%394 = linalg.init_tensor [%380, %381, 1536] : tensor<?x?x1536xf32>
%395 = linalg.init_tensor [%380, 384, 1536] : tensor<?x384x1536xf32>
%396 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_72 : tensor<1536xf32>) outs(%394 : tensor<?x?x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x1536xf32>
%397 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_73 : tensor<1536x384xf32>) outs(%395 : tensor<?x384x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x1536xf32>
%398 = linalg.batch_matmul ins(%393, %397 : tensor<?x?x384xf32>, tensor<?x384x1536xf32>) outs(%396 : tensor<?x?x1536xf32>) -> tensor<?x?x1536xf32>
// Elementwise activation: x * ((erf(x / sqrt(%cst_2)) + %cst_3) * %cst_1).
// NOTE(review): the three constants are defined outside this excerpt; with
// 2.0, 1.0 and 0.5 respectively this is the erf form of GELU - confirm.
%399 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%398 : tensor<?x?x1536xf32>) outs(%394 : tensor<?x?x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%609 = math.sqrt %cst_2 : f32
%610 = arith.divf %arg1, %609 : f32
%611 = math.erf %610 : f32
%612 = arith.addf %611, %cst_3 : f32
%613 = arith.mulf %612, %cst_1 : f32
%614 = arith.mulf %arg1, %613 : f32
linalg.yield %614 : f32
} -> tensor<?x?x1536xf32>
// Second projection, 1536 -> 384: broadcast vector %cst_74 seeds the
// accumulator; %cst_75 (384x1536) is transposed to 1536x384 and replicated
// along d0.
%400 = linalg.init_tensor [%380, 1536, 384] : tensor<?x1536x384xf32>
%401 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_74 : tensor<384xf32>) outs(%384 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%402 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_75 : tensor<384x1536xf32>) outs(%400 : tensor<?x1536x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x1536x384xf32>
%403 = linalg.batch_matmul ins(%399, %402 : tensor<?x?x1536xf32>, tensor<?x1536x384xf32>) outs(%401 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
// Elementwise sum of the second projection %403 and its block input %393.
%404 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%403, %393 : tensor<?x?x384xf32>, tensor<?x?x384xf32>) outs(%384 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%609 = arith.addf %arg1, %arg2 : f32
linalg.yield %609 : f32
} -> tensor<?x?x384xf32>
// Normalization of %404 along d2.
// Sum over d2 followed by division by %cst_115
// (NOTE(review): presumably the d2 extent, 384 - confirm).
%405 = linalg.fill(%cst, %386) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%406 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%404 : tensor<?x?x384xf32>) outs(%405 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%609 = arith.addf %arg2, %arg1 : f32
linalg.yield %609 : f32
} -> tensor<?x?xf32>
%407 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%406 : tensor<?x?xf32>) outs(%386 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%609 = arith.divf %arg1, %cst_115 : f32
linalg.yield %609 : f32
} -> tensor<?x?xf32>
// Sum over d2 of squared deviations from %407, then division by %cst_115.
%408 = linalg.fill(%cst, %386) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%409 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%404, %407 : tensor<?x?x384xf32>, tensor<?x?xf32>) outs(%408 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%609 = arith.subf %arg1, %arg2 : f32
%610 = arith.mulf %609, %609 : f32
%611 = arith.addf %arg3, %610 : f32
linalg.yield %611 : f32
} -> tensor<?x?xf32>
%410 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%409 : tensor<?x?xf32>) outs(%386 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%609 = arith.divf %arg1, %cst_115 : f32
linalg.yield %609 : f32
} -> tensor<?x?xf32>
// Per element: (%404 - %407) * rsqrt(%410 + truncf(%cst_4)) * %cst_77 + %cst_76,
// with %cst_77 / %cst_76 indexed along d2.
%411 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%404, %407, %410, %cst_77, %cst_76 : tensor<?x?x384xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<384xf32>, tensor<384xf32>) outs(%384 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%609 = arith.subf %arg1, %arg2 : f32
%610 = arith.truncf %cst_4 : f64 to f32
%611 = arith.addf %arg3, %610 : f32
%612 = math.rsqrt %611 : f32
%613 = arith.mulf %609, %612 : f32
%614 = arith.mulf %613, %arg4 : f32
%615 = arith.addf %614, %arg5 : f32
linalg.yield %615 : f32
} -> tensor<?x?x384xf32>
// Two 384 -> 384 projections of %411. In each, a 384-element vector is
// broadcast along d2 to seed the accumulator and a 384x384 matrix is read at
// (d2, d1), i.e. transposed and replicated along d0.
%412 = linalg.init_tensor [%380, 384, 384] : tensor<?x384x384xf32>
// Projection using %cst_78 / %cst_79; %415 is later split and used as the left
// operand of the contraction %433.
%413 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_78 : tensor<384xf32>) outs(%384 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%414 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_79 : tensor<384x384xf32>) outs(%412 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%415 = linalg.batch_matmul ins(%411, %414 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%413 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
// Projection using %cst_80 / %cst_81.
%416 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_80 : tensor<384xf32>) outs(%384 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%417 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_81 : tensor<384x384xf32>) outs(%412 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%418 = linalg.batch_matmul ins(%411, %417 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%416 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
// Split 384 into 12 x 32 and swap dims 1 and 2. %420 is also reused as the
// outs operand of %426 and %428 and as the fill target of %450.
%419 = tensor.expand_shape %418 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%420 = linalg.init_tensor [%380, 12, %381, 32] : tensor<?x12x?x32xf32>
%421 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%419 : tensor<?x?x12x32xf32>) outs(%420 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
// Third projection of %411: broadcast vector %cst_82 as the accumulator seed,
// transposed %cst_83 (replicated along d0) as the right-hand operand.
%422 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_82 : tensor<384xf32>) outs(%384 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%423 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_83 : tensor<384x384xf32>) outs(%412 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%424 = linalg.batch_matmul ins(%411, %423 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%422 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
// Split 384 into 12 x 32, then swap dims 1 and 2. %426 is the right-hand
// operand of the contraction %451.
%425 = tensor.expand_shape %424 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%426 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%425 : tensor<?x?x12x32xf32>) outs(%420 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
// Same split and dim swap applied to %415; %428 is the left-hand operand of
// the contraction %433.
%427 = tensor.expand_shape %415 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%428 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%427 : tensor<?x?x12x32xf32>) outs(%420 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
// Swap the last two dims of %421: ?x12x?x32 -> ?x12x32x?.
%429 = linalg.init_tensor [%380, 12, 32, %381] : tensor<?x12x32x?xf32>
%430 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d2)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%421 : tensor<?x12x?x32xf32>) outs(%429 : tensor<?x12x32x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x32x?xf32>
// Accumulator for the contraction, filled with %cst
// (NOTE(review): %cst is defined outside this excerpt; presumably 0.0 - confirm).
%431 = linalg.init_tensor [%380, 12, %381, %381] : tensor<?x12x?x?xf32>
%432 = linalg.fill(%cst, %431) : f32, tensor<?x12x?x?xf32> -> tensor<?x12x?x?xf32>
// out[d0, d1, d2, d3] += %428[d0, d1, d2, d4] * %430[d0, d1, d4, d3], reducing over d4.
%433 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]} ins(%428, %430 : tensor<?x12x?x32xf32>, tensor<?x12x32x?xf32>) outs(%432 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%609 = arith.mulf %arg1, %arg2 : f32
%610 = arith.addf %609, %arg3 : f32
linalg.yield %610 : f32
} -> tensor<?x12x?x?xf32>
// Divide every element by the rank-0 f64 tensor %cst_94 (truncated to f32).
%434 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> ()>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%433, %cst_94 : tensor<?x12x?x?xf32>, tensor<f64>) outs(%431 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f64, %arg3: f32):
%609 = arith.truncf %arg2 : f64 to f32
%610 = arith.divf %arg1, %609 : f32
linalg.yield %610 : f32
} -> tensor<?x12x?x?xf32>
// Runtime checks before the broadcast add: %380 must equal %c1 and %381 must
// equal %20.
%435 = arith.cmpi eq, %380, %c1 : index
cf.assert %435, "mismatched size for broadcast"
%436 = arith.cmpi eq, %381, %20 : index
cf.assert %436, "mismatched size for broadcast"
// Drop the unit dims of %26 (1x1x1x? -> 1x?) and add it to %434, indexed by
// (d0, d3) so the same row is reused for every d1 and d2.
%437 = tensor.collapse_shape %26 [[0, 1, 2], [3]] : tensor<1x1x1x?xf32> into tensor<1x?xf32>
%438 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%434, %437 : tensor<?x12x?x?xf32>, tensor<1x?xf32>) outs(%431 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%609 = arith.addf %arg1, %arg2 : f32
linalg.yield %609 : f32
} -> tensor<?x12x?x?xf32>
// Max-subtract / exp / sum / divide along d3 of %438.
// Initial values for the running maximum (%cst_0) and its index (%c0_i64).
// NOTE(review): %cst_0 is defined outside this excerpt; presumably a very
// negative f32 so any input compares greater - confirm.
%439 = linalg.init_tensor [%380, 12, %381] : tensor<?x12x?xf32>
%440 = linalg.fill(%cst_0, %439) : f32, tensor<?x12x?xf32> -> tensor<?x12x?xf32>
%441 = linalg.init_tensor [%380, 12, %381] : tensor<?x12x?xi64>
%442 = linalg.fill(%c0_i64, %441) : i64, tensor<?x12x?xi64> -> tensor<?x12x?xi64>
// Reduction over d3 keeping the larger value (ordered greater-than) and the d3
// index at which it occurred. Only result #0 is used within this excerpt.
%443:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%438 : tensor<?x12x?x?xf32>) outs(%440, %442 : tensor<?x12x?xf32>, tensor<?x12x?xi64>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: i64):
%609 = linalg.index 3 : index
%610 = arith.index_cast %609 : index to i64
%611 = arith.cmpf ogt, %arg1, %arg2 : f32
%612 = arith.select %611, %arg1, %arg2 : f32
%613 = arith.select %611, %610, %arg3 : i64
linalg.yield %612, %613 : f32, i64
} -> (tensor<?x12x?xf32>, tensor<?x12x?xi64>)
// Subtract the per-(d0, d1, d2) maximum from every element along d3.
%444 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%438, %443#0 : tensor<?x12x?x?xf32>, tensor<?x12x?xf32>) outs(%431 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%609 = arith.subf %arg1, %arg2 : f32
linalg.yield %609 : f32
} -> tensor<?x12x?x?xf32>
// Elementwise exponential.
%445 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%444 : tensor<?x12x?x?xf32>) outs(%431 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%609 = math.exp %arg1 : f32
linalg.yield %609 : f32
} -> tensor<?x12x?x?xf32>
// Sum of the exponentials over d3, accumulated onto a %cst-filled tensor.
%446 = linalg.init_tensor [%380, 12, %381] : tensor<?x12x?xf32>
%447 = linalg.fill(%cst, %446) : f32, tensor<?x12x?xf32> -> tensor<?x12x?xf32>
%448 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%445 : tensor<?x12x?x?xf32>) outs(%447 : tensor<?x12x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%609 = arith.addf %arg1, %arg2 : f32
linalg.yield %609 : f32
} -> tensor<?x12x?xf32>
// Divide each exponential by the sum for its (d0, d1, d2) position.
%449 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%445, %448 : tensor<?x12x?x?xf32>, tensor<?x12x?xf32>) outs(%431 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%609 = arith.divf %arg1, %arg2 : f32
linalg.yield %609 : f32
} -> tensor<?x12x?x?xf32>
// Contraction of the normalized weights %449 with %426 over d4:
// out[d0, d1, d2, d3] += %449[d0, d1, d2, d4] * %426[d0, d1, d4, d3].
// The accumulator is %420 filled with %cst.
%450 = linalg.fill(%cst, %420) : f32, tensor<?x12x?x32xf32> -> tensor<?x12x?x32xf32>
%451 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]} ins(%449, %426 : tensor<?x12x?x?xf32>, tensor<?x12x?x32xf32>) outs(%450 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%609 = arith.mulf %arg1, %arg2 : f32
%610 = arith.addf %609, %arg3 : f32
linalg.yield %610 : f32
} -> tensor<?x12x?x32xf32>
// Swap dims 1 and 2 back (?x12x?x32 -> ?x?x12x32).
%452 = linalg.init_tensor [%380, %381, 12, 32] : tensor<?x?x12x32xf32>
%453 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%451 : tensor<?x12x?x32xf32>) outs(%452 : tensor<?x?x12x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x12x32xf32>
// Merge the trailing 12 x 32 dims into one. The casts erase the static sizes
// before the collapse and restore the static 384 afterwards.
%454 = tensor.cast %453 : tensor<?x?x12x32xf32> to tensor<?x?x?x?xf32>
%455 = tensor.collapse_shape %454 [[0], [1], [2, 3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
%456 = tensor.cast %455 : tensor<?x?x?xf32> to tensor<?x?x384xf32>
// Projection of %456: broadcast vector %cst_84 seeds the accumulator,
// transposed %cst_85 (replicated along d0) is the right-hand operand.
%457 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_84 : tensor<384xf32>) outs(%384 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%458 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_85 : tensor<384x384xf32>) outs(%412 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%459 = linalg.batch_matmul ins(%456, %458 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%457 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
// Re-derive the two leading dynamic sizes from %455 and assert that they match
// %380 / %381 before the elementwise add below.
%460 = tensor.dim %455, %c0 : tensor<?x?x?xf32>
%461 = tensor.dim %455, %c1 : tensor<?x?x?xf32>
%462 = arith.cmpi eq, %460, %380 : index
cf.assert %462, "mismatched size for broadcast"
%463 = arith.cmpi eq, %461, %381 : index
cf.assert %463, "mismatched size for broadcast"
// Elementwise sum of the projection output %459 and the block input %411.
%464 = linalg.init_tensor [%460, %461, 384] : tensor<?x?x384xf32>
%465 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%459, %411 : tensor<?x?x384xf32>, tensor<?x?x384xf32>) outs(%464 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%609 = arith.addf %arg1, %arg2 : f32
linalg.yield %609 : f32
} -> tensor<?x?x384xf32>
// Sum of %465 over d2, then division by %cst_115.
// NOTE(review): %cst_115 is defined outside this excerpt; if it equals the d2
// extent (384) then %469 is the mean over d2 - confirm.
%466 = linalg.init_tensor [%460, %461] : tensor<?x?xf32>
%467 = linalg.fill(%cst, %466) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%468 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%465 : tensor<?x?x384xf32>) outs(%467 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%609 = arith.addf %arg2, %arg1 : f32
linalg.yield %609 : f32
} -> tensor<?x?xf32>
%469 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%468 : tensor<?x?xf32>) outs(%466 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%609 = arith.divf %arg1, %cst_115 : f32
linalg.yield %609 : f32
} -> tensor<?x?xf32>
// Sum over d2 of squared deviations from %469, then division by %cst_115.
%470 = linalg.fill(%cst, %466) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%471 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%465, %469 : tensor<?x?x384xf32>, tensor<?x?xf32>) outs(%470 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%609 = arith.subf %arg1, %arg2 : f32
%610 = arith.mulf %609, %609 : f32
%611 = arith.addf %arg3, %610 : f32
linalg.yield %611 : f32
} -> tensor<?x?xf32>
%472 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%471 : tensor<?x?xf32>) outs(%466 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%609 = arith.divf %arg1, %cst_115 : f32
linalg.yield %609 : f32
} -> tensor<?x?xf32>
// Per element: (%465 - %469) * rsqrt(%472 + truncf(%cst_4)) * %cst_87 + %cst_86,
// with %cst_87 / %cst_86 indexed along d2.
// NOTE(review): %cst_4 is an f64 defined outside this excerpt; presumably a
// small epsilon - confirm.
%473 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%465, %469, %472, %cst_87, %cst_86 : tensor<?x?x384xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<384xf32>, tensor<384xf32>) outs(%464 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%609 = arith.subf %arg1, %arg2 : f32
%610 = arith.truncf %cst_4 : f64 to f32
%611 = arith.addf %arg3, %610 : f32
%612 = math.rsqrt %611 : f32
%613 = arith.mulf %609, %612 : f32
%614 = arith.mulf %613, %arg4 : f32
%615 = arith.addf %614, %arg5 : f32
linalg.yield %615 : f32
} -> tensor<?x?x384xf32>
// First projection, 384 -> 1536: broadcast vector %cst_88 seeds the
// accumulator; %cst_89 (1536x384) is read at (d2, d1), i.e. transposed to
// 384x1536 and replicated along d0.
%474 = linalg.init_tensor [%460, %461, 1536] : tensor<?x?x1536xf32>
%475 = linalg.init_tensor [%460, 384, 1536] : tensor<?x384x1536xf32>
%476 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_88 : tensor<1536xf32>) outs(%474 : tensor<?x?x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x1536xf32>
%477 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_89 : tensor<1536x384xf32>) outs(%475 : tensor<?x384x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x1536xf32>
%478 = linalg.batch_matmul ins(%473, %477 : tensor<?x?x384xf32>, tensor<?x384x1536xf32>) outs(%476 : tensor<?x?x1536xf32>) -> tensor<?x?x1536xf32>
// Elementwise activation: x * ((erf(x / sqrt(%cst_2)) + %cst_3) * %cst_1).
// NOTE(review): the three constants are defined outside this excerpt; with
// 2.0, 1.0 and 0.5 respectively this is the erf form of GELU - confirm.
%479 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%478 : tensor<?x?x1536xf32>) outs(%474 : tensor<?x?x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%609 = math.sqrt %cst_2 : f32
%610 = arith.divf %arg1, %609 : f32
%611 = math.erf %610 : f32
%612 = arith.addf %611, %cst_3 : f32
%613 = arith.mulf %612, %cst_1 : f32
%614 = arith.mulf %arg1, %613 : f32
linalg.yield %614 : f32
} -> tensor<?x?x1536xf32>
// Second projection, 1536 -> 384: broadcast vector %cst_90 seeds the
// accumulator; %cst_91 (384x1536) is transposed to 1536x384 and replicated
// along d0.
%480 = linalg.init_tensor [%460, 1536, 384] : tensor<?x1536x384xf32>
%481 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_90 : tensor<384xf32>) outs(%464 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%482 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_91 : tensor<384x1536xf32>) outs(%480 : tensor<?x1536x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x1536x384xf32>
%483 = linalg.batch_matmul ins(%479, %482 : tensor<?x?x1536xf32>, tensor<?x1536x384xf32>) outs(%481 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%484 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%483, %473 : tensor<?x?x384xf32>, tensor<?x?x384xf32>) outs(%464 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%609 = arith.addf %arg1, %arg2 : f32
linalg.yield %609 : f32
} -> tensor<?x?x384xf32>
%485 = linalg.fill(%cst, %466) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%486 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%484 : tensor<?x?x384xf32>) outs(%485 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%609 = arith.addf %arg2, %arg1 : f32
linalg.yield %609 : f32
} -> tensor<?x?xf32>
%487 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%486 : tensor<?x?xf32>) outs(%466 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%609 = arith.divf %arg1, %cst_115 : f32
linalg.yield %609 : f32
} -> tensor<?x?xf32>
%488 = linalg.fill(%cst, %466) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%489 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%484, %487 : tensor<?x?x384xf32>, tensor<?x?xf32>) outs(%488 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%609 = arith.subf %arg1, %arg2 : f32
%610 = arith.mulf %609, %609 : f32
%611 = arith.addf %arg3, %610 : f32
linalg.yield %611 : f32
} -> tensor<?x?xf32>
%490 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%489 : tensor<?x?xf32>) outs(%466 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%609 = arith.divf %arg1, %cst_115 : f32
linalg.yield %609 : f32
} -> tensor<?x?xf32>
%491 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%484, %487, %490, %cst_93, %cst_92 : tensor<?x?x384xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<384xf32>, tensor<384xf32>) outs(%464 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%609 = arith.subf %arg1, %arg2 : f32
%610 = arith.truncf %cst_4 : f64 to f32
%611 = arith.addf %arg3, %610 : f32
%612 = math.rsqrt %611 : f32
%613 = arith.mulf %609, %612 : f32
%614 = arith.mulf %613, %arg4 : f32
%615 = arith.addf %614, %arg5 : f32
linalg.yield %615 : f32
} -> tensor<?x?x384xf32>
%492 = linalg.init_tensor [%460, 384, 384] : tensor<?x384x384xf32>
%493 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_95 : tensor<384xf32>) outs(%464 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%494 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_96 : tensor<384x384xf32>) outs(%492 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%495 = linalg.batch_matmul ins(%491, %494 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%493 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%496 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_97 : tensor<384xf32>) outs(%464 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%497 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_98 : tensor<384x384xf32>) outs(%492 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%498 = linalg.batch_matmul ins(%491, %497 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%496 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%499 = tensor.expand_shape %498 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%500 = linalg.init_tensor [%460, 12, %461, 32] : tensor<?x12x?x32xf32>
%501 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%499 : tensor<?x?x12x32xf32>) outs(%500 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
%502 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_99 : tensor<384xf32>) outs(%464 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%503 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_100 : tensor<384x384xf32>) outs(%492 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%504 = linalg.batch_matmul ins(%491, %503 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%502 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%505 = tensor.expand_shape %504 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%506 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%505 : tensor<?x?x12x32xf32>) outs(%500 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
%507 = tensor.expand_shape %495 [[0], [1], [2, 3]] : tensor<?x?x384xf32> into tensor<?x?x12x32xf32>
%508 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%507 : tensor<?x?x12x32xf32>) outs(%500 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x?x32xf32>
%509 = linalg.init_tensor [%460, 12, 32, %461] : tensor<?x12x32x?xf32>
%510 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d2)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%501 : tensor<?x12x?x32xf32>) outs(%509 : tensor<?x12x32x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x12x32x?xf32>
%511 = linalg.init_tensor [%460, 12, %461, %461] : tensor<?x12x?x?xf32>
%512 = linalg.fill(%cst, %511) : f32, tensor<?x12x?x?xf32> -> tensor<?x12x?x?xf32>
%513 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]} ins(%508, %510 : tensor<?x12x?x32xf32>, tensor<?x12x32x?xf32>) outs(%512 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%609 = arith.mulf %arg1, %arg2 : f32
%610 = arith.addf %609, %arg3 : f32
linalg.yield %610 : f32
} -> tensor<?x12x?x?xf32>
%514 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> ()>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%513, %cst_94 : tensor<?x12x?x?xf32>, tensor<f64>) outs(%511 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f64, %arg3: f32):
%609 = arith.truncf %arg2 : f64 to f32
%610 = arith.divf %arg1, %609 : f32
linalg.yield %610 : f32
} -> tensor<?x12x?x?xf32>
%515 = arith.cmpi eq, %460, %c1 : index
cf.assert %515, "mismatched size for broadcast"
%516 = arith.cmpi eq, %461, %20 : index
cf.assert %516, "mismatched size for broadcast"
%517 = tensor.collapse_shape %26 [[0, 1, 2], [3]] : tensor<1x1x1x?xf32> into tensor<1x?xf32>
%518 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%514, %517 : tensor<?x12x?x?xf32>, tensor<1x?xf32>) outs(%511 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%609 = arith.addf %arg1, %arg2 : f32
linalg.yield %609 : f32
} -> tensor<?x12x?x?xf32>
%519 = linalg.init_tensor [%460, 12, %461] : tensor<?x12x?xf32>
%520 = linalg.fill(%cst_0, %519) : f32, tensor<?x12x?xf32> -> tensor<?x12x?xf32>
%521 = linalg.init_tensor [%460, 12, %461] : tensor<?x12x?xi64>
%522 = linalg.fill(%c0_i64, %521) : i64, tensor<?x12x?xi64> -> tensor<?x12x?xi64>
%523:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%518 : tensor<?x12x?x?xf32>) outs(%520, %522 : tensor<?x12x?xf32>, tensor<?x12x?xi64>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: i64):
%609 = linalg.index 3 : index
%610 = arith.index_cast %609 : index to i64
%611 = arith.cmpf ogt, %arg1, %arg2 : f32
%612 = arith.select %611, %arg1, %arg2 : f32
%613 = arith.select %611, %610, %arg3 : i64
linalg.yield %612, %613 : f32, i64
} -> (tensor<?x12x?xf32>, tensor<?x12x?xi64>)
%524 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%518, %523#0 : tensor<?x12x?x?xf32>, tensor<?x12x?xf32>) outs(%511 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%609 = arith.subf %arg1, %arg2 : f32
linalg.yield %609 : f32
} -> tensor<?x12x?x?xf32>
%525 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%524 : tensor<?x12x?x?xf32>) outs(%511 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%609 = math.exp %arg1 : f32
linalg.yield %609 : f32
} -> tensor<?x12x?x?xf32>
%526 = linalg.init_tensor [%460, 12, %461] : tensor<?x12x?xf32>
%527 = linalg.fill(%cst, %526) : f32, tensor<?x12x?xf32> -> tensor<?x12x?xf32>
%528 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%525 : tensor<?x12x?x?xf32>) outs(%527 : tensor<?x12x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%609 = arith.addf %arg1, %arg2 : f32
linalg.yield %609 : f32
} -> tensor<?x12x?xf32>
%529 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%525, %528 : tensor<?x12x?x?xf32>, tensor<?x12x?xf32>) outs(%511 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%609 = arith.divf %arg1, %arg2 : f32
linalg.yield %609 : f32
} -> tensor<?x12x?x?xf32>
%530 = linalg.fill(%cst, %500) : f32, tensor<?x12x?x32xf32> -> tensor<?x12x?x32xf32>
%531 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]} ins(%529, %506 : tensor<?x12x?x?xf32>, tensor<?x12x?x32xf32>) outs(%530 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%609 = arith.mulf %arg1, %arg2 : f32
%610 = arith.addf %609, %arg3 : f32
linalg.yield %610 : f32
} -> tensor<?x12x?x32xf32>
%532 = linalg.init_tensor [%460, %461, 12, 32] : tensor<?x?x12x32xf32>
%533 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%531 : tensor<?x12x?x32xf32>) outs(%532 : tensor<?x?x12x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x12x32xf32>
%534 = tensor.cast %533 : tensor<?x?x12x32xf32> to tensor<?x?x?x?xf32>
%535 = tensor.collapse_shape %534 [[0], [1], [2, 3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
%536 = tensor.cast %535 : tensor<?x?x?xf32> to tensor<?x?x384xf32>
%537 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_101 : tensor<384xf32>) outs(%464 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%538 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_102 : tensor<384x384xf32>) outs(%492 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%539 = linalg.batch_matmul ins(%536, %538 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%537 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%540 = tensor.dim %535, %c0 : tensor<?x?x?xf32>
%541 = tensor.dim %535, %c1 : tensor<?x?x?xf32>
%542 = arith.cmpi eq, %540, %460 : index
cf.assert %542, "mismatched size for broadcast"
%543 = arith.cmpi eq, %541, %461 : index
cf.assert %543, "mismatched size for broadcast"
%544 = linalg.init_tensor [%540, %541, 384] : tensor<?x?x384xf32>
%545 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%539, %491 : tensor<?x?x384xf32>, tensor<?x?x384xf32>) outs(%544 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%609 = arith.addf %arg1, %arg2 : f32
linalg.yield %609 : f32
} -> tensor<?x?x384xf32>
%546 = linalg.init_tensor [%540, %541] : tensor<?x?xf32>
%547 = linalg.fill(%cst, %546) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%548 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%545 : tensor<?x?x384xf32>) outs(%547 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%609 = arith.addf %arg2, %arg1 : f32
linalg.yield %609 : f32
} -> tensor<?x?xf32>
%549 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%548 : tensor<?x?xf32>) outs(%546 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%609 = arith.divf %arg1, %cst_115 : f32
linalg.yield %609 : f32
} -> tensor<?x?xf32>
%550 = linalg.fill(%cst, %546) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%551 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%545, %549 : tensor<?x?x384xf32>, tensor<?x?xf32>) outs(%550 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%609 = arith.subf %arg1, %arg2 : f32
%610 = arith.mulf %609, %609 : f32
%611 = arith.addf %arg3, %610 : f32
linalg.yield %611 : f32
} -> tensor<?x?xf32>
%552 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%551 : tensor<?x?xf32>) outs(%546 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%609 = arith.divf %arg1, %cst_115 : f32
linalg.yield %609 : f32
} -> tensor<?x?xf32>
%553 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%545, %549, %552, %cst_104, %cst_103 : tensor<?x?x384xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<384xf32>, tensor<384xf32>) outs(%544 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%609 = arith.subf %arg1, %arg2 : f32
%610 = arith.truncf %cst_4 : f64 to f32
%611 = arith.addf %arg3, %610 : f32
%612 = math.rsqrt %611 : f32
%613 = arith.mulf %609, %612 : f32
%614 = arith.mulf %613, %arg4 : f32
%615 = arith.addf %614, %arg5 : f32
linalg.yield %615 : f32
} -> tensor<?x?x384xf32>
%554 = linalg.init_tensor [%540, %541, 1536] : tensor<?x?x1536xf32>
%555 = linalg.init_tensor [%540, 384, 1536] : tensor<?x384x1536xf32>
%556 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_105 : tensor<1536xf32>) outs(%554 : tensor<?x?x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x1536xf32>
%557 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_106 : tensor<1536x384xf32>) outs(%555 : tensor<?x384x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x1536xf32>
%558 = linalg.batch_matmul ins(%553, %557 : tensor<?x?x384xf32>, tensor<?x384x1536xf32>) outs(%556 : tensor<?x?x1536xf32>) -> tensor<?x?x1536xf32>
%559 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%558 : tensor<?x?x1536xf32>) outs(%554 : tensor<?x?x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%609 = math.sqrt %cst_2 : f32
%610 = arith.divf %arg1, %609 : f32
%611 = math.erf %610 : f32
%612 = arith.addf %611, %cst_3 : f32
%613 = arith.mulf %612, %cst_1 : f32
%614 = arith.mulf %arg1, %613 : f32
linalg.yield %614 : f32
} -> tensor<?x?x1536xf32>
%560 = linalg.init_tensor [%540, 1536, 384] : tensor<?x1536x384xf32>
%561 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_107 : tensor<384xf32>) outs(%544 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%562 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_108 : tensor<384x1536xf32>) outs(%560 : tensor<?x1536x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x1536x384xf32>
%563 = linalg.batch_matmul ins(%559, %562 : tensor<?x?x1536xf32>, tensor<?x1536x384xf32>) outs(%561 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%564 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%563, %553 : tensor<?x?x384xf32>, tensor<?x?x384xf32>) outs(%544 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%609 = arith.addf %arg1, %arg2 : f32
linalg.yield %609 : f32
} -> tensor<?x?x384xf32>
%565 = linalg.fill(%cst, %546) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%566 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%564 : tensor<?x?x384xf32>) outs(%565 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%609 = arith.addf %arg2, %arg1 : f32
linalg.yield %609 : f32
} -> tensor<?x?xf32>
%567 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%566 : tensor<?x?xf32>) outs(%546 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%609 = arith.divf %arg1, %cst_115 : f32
linalg.yield %609 : f32
} -> tensor<?x?xf32>
%568 = linalg.fill(%cst, %546) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%569 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%564, %567 : tensor<?x?x384xf32>, tensor<?x?xf32>) outs(%568 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%609 = arith.subf %arg1, %arg2 : f32
%610 = arith.mulf %609, %609 : f32
%611 = arith.addf %arg3, %610 : f32
linalg.yield %611 : f32
} -> tensor<?x?xf32>
%570 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%569 : tensor<?x?xf32>) outs(%546 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%609 = arith.divf %arg1, %cst_115 : f32
linalg.yield %609 : f32
} -> tensor<?x?xf32>
%571 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%564, %567, %570, %cst_110, %cst_109 : tensor<?x?x384xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<384xf32>, tensor<384xf32>) outs(%544 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%609 = arith.subf %arg1, %arg2 : f32
%610 = arith.truncf %cst_4 : f64 to f32
%611 = arith.addf %arg3, %610 : f32
%612 = math.rsqrt %611 : f32
%613 = arith.mulf %609, %612 : f32
%614 = arith.mulf %613, %arg4 : f32
%615 = arith.addf %614, %arg5 : f32
linalg.yield %615 : f32
} -> tensor<?x?x384xf32>
%572 = arith.index_cast %540 : index to i64
%573 = arith.cmpi sgt, %c0_i64, %572 : i64
%574 = arith.select %573, %572, %c0_i64 : i64
%575 = arith.index_cast %574 : i64 to index
%576 = arith.cmpi sgt, %c9223372036854775807_i64, %572 : i64
%577 = arith.select %576, %572, %c9223372036854775807_i64 : i64
%578 = arith.index_cast %577 : i64 to index
%579 = arith.cmpi sge, %578, %575 : index
%580 = arith.select %579, %578, %575 : index
%581 = arith.subi %580, %575 : index
%582 = tensor.extract_slice %571[%575, 0, 0] [%581, %541, 384] [1, 1, 1] : tensor<?x?x384xf32> to tensor<?x?x384xf32>
%583 = arith.index_cast %541 : index to i64
%584 = arith.cmpi sgt, %c0_i64, %583 : i64
%585 = arith.select %584, %583, %c0_i64 : i64
%586 = arith.index_cast %585 : i64 to index
%587 = arith.cmpi sgt, %c1_i64, %583 : i64
%588 = arith.select %587, %583, %c1_i64 : i64
%589 = arith.index_cast %588 : i64 to index
%590 = arith.cmpi sge, %589, %586 : index
%591 = arith.select %590, %589, %586 : index
%592 = arith.subi %591, %586 : index
%593 = tensor.extract_slice %582[0, %586, 0] [%581, %592, 384] [1, 1, 1] : tensor<?x?x384xf32> to tensor<?x?x384xf32>
%594 = tensor.cast %593 : tensor<?x?x384xf32> to tensor<?x1x384xf32>
%595 = tensor.collapse_shape %594 [[0, 1], [2]] : tensor<?x1x384xf32> into tensor<?x384xf32>
%596 = linalg.init_tensor [%581, 384] : tensor<?x384xf32>
%597 = linalg.init_tensor [384, 384] : tensor<384x384xf32>
%598 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%cst_111 : tensor<384xf32>) outs(%596 : tensor<?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384xf32>
%599 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%cst_112 : tensor<384x384xf32>) outs(%597 : tensor<384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<384x384xf32>
%600 = linalg.matmul ins(%595, %599 : tensor<?x384xf32>, tensor<384x384xf32>) outs(%598 : tensor<?x384xf32>) -> tensor<?x384xf32>
%601 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%600 : tensor<?x384xf32>) outs(%596 : tensor<?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%609 = math.tanh %arg1 : f32
linalg.yield %609 : f32
} -> tensor<?x384xf32>
%602 = linalg.init_tensor [%581, 2] : tensor<?x2xf32>
%603 = linalg.init_tensor [384, 2] : tensor<384x2xf32>
%604 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%cst_113 : tensor<2xf32>) outs(%602 : tensor<?x2xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x2xf32>
%605 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%cst_114 : tensor<2x384xf32>) outs(%603 : tensor<384x2xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<384x2xf32>
%606 = linalg.matmul ins(%601, %605 : tensor<?x384xf32>, tensor<384x2xf32>) outs(%604 : tensor<?x2xf32>) -> tensor<?x2xf32>
%607 = tensor.dim %606, %c0 : tensor<?x2xf32>
%608 = hal.tensor.export %606 : tensor<?x2xf32>{%607} -> !hal.buffer_view
return %608 : !hal.buffer_view
}
// -----// IR Dump Before Canonicalizer //----- //
func @forward(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
%cst = arith.constant 3.840000e+02 : f32
%c512 = arith.constant 512 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c1_i64 = arith.constant 1 : i64
%c512_i64 = arith.constant 512 : i64
%cst_0 = arith.constant dense<[-0.00115577725, 0.00115577038]> : tensor<2xf32>
%cst_1 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_2 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_3 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_4 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x1536xf32>
%cst_5 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_6 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536x384xf32>
%cst_7 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536xf32>
%cst_8 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_9 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_10 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_11 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_12 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_13 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_14 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_15 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_16 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_17 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_18 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_19 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_20 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x1536xf32>
%cst_21 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_22 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536x384xf32>
%cst_23 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536xf32>
%cst_24 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_25 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_26 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_27 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_28 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_29 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_30 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_31 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_32 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_33 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_34 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_35 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_36 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x1536xf32>
%cst_37 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_38 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536x384xf32>
%cst_39 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536xf32>
%cst_40 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_41 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_42 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_43 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_44 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_45 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_46 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_47 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_48 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_49 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_50 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_51 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_52 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x1536xf32>
%cst_53 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_54 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536x384xf32>
%cst_55 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536xf32>
%cst_56 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_57 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_58 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_59 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_60 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_61 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_62 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_63 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_64 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_65 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_66 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_67 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_68 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x1536xf32>
%cst_69 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_70 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536x384xf32>
%cst_71 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536xf32>
%cst_72 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_73 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_74 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_75 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_76 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_77 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_78 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_79 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_80 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_81 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_82 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_83 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_84 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x1536xf32>
%cst_85 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_86 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536x384xf32>
%cst_87 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536xf32>
%cst_88 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_89 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_90 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_91 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_92 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_93 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_94 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_95 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_96 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<512x384xf32>
%cst_97 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<2x384xf32>
%cst_98 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<30522x384xf32>
%cst_99 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1x512xi64>
%cst_100 = arith.constant 9.9999999999999998E-13 : f64
%c9223372036854775807_i64 = arith.constant 9223372036854775807 : i64
%cst_101 = arith.constant 1.000000e+00 : f32
%cst_102 = arith.constant 2.000000e+00 : f32
%cst_103 = arith.constant 5.000000e-01 : f32
%cst_104 = arith.constant -3.40282347E+38 : f32
%cst_105 = arith.constant 0.000000e+00 : f32
%c0_i64 = arith.constant 0 : i64
%c30522_i64 = arith.constant 30522 : i64
%false = arith.constant false
%true = arith.constant true
%cst_106 = arith.constant -1.000000e+04 : f32
%cst_107 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1x384x384xf32>
%cst_108 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1x384x384xf32>
%cst_109 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1x384x384xf32>
%cst_110 = arith.constant 5.6568542494923806 : f64
%cst_111 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1x384x384xf32>
%cst_112 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_113 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x2xf32>
%0 = hal.tensor.import %arg0 : !hal.buffer_view -> tensor<1x512xi64>
%1 = linalg.init_tensor [] : tensor<i64>
%2 = linalg.fill(%c512_i64, %1) : i64, tensor<i64> -> tensor<i64>
%3 = tensor.extract %2[] : tensor<i64>
%4 = arith.index_cast %3 : i64 to index
%5 = linalg.init_tensor [1, %4] : tensor<1x?xf32>
%6 = linalg.fill(%cst_101, %5) : f32, tensor<1x?xf32> -> tensor<1x?xf32>
%7 = tensor.extract_slice %6[0, 0] [1, %4] [1, 1] : tensor<1x?xf32> to tensor<?xf32>
%8 = tensor.dim %7, %c0 : tensor<?xf32>
%9 = tensor.dim %7, %c0 : tensor<?xf32>
%10 = flow.tensor.reshape %7 : tensor<?xf32>{%9} -> tensor<?x1x1x?xf32>{%c1, %8}
%11 = arith.cmpi sgt, %c0_i64, %3 : i64
%12 = arith.select %11, %3, %c0_i64 : i64
%13 = arith.index_cast %12 : i64 to index
%14 = arith.cmpi sgt, %c9223372036854775807_i64, %3 : i64
%15 = arith.select %14, %3, %c9223372036854775807_i64 : i64
%16 = arith.index_cast %15 : i64 to index
%17 = arith.cmpi sge, %16, %13 : index
%18 = arith.select %17, %16, %13 : index
%19 = arith.subi %18, %13 : index
%20 = tensor.extract_slice %10[0, 0, 0, %13] [1, 1, 1, %19] [1, 1, 1, 1] : tensor<?x1x1x?xf32> to tensor<?xf32>
%21 = linalg.init_tensor [%19] : tensor<?xf32>
%22 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%20 : tensor<?xf32>) outs(%21 : tensor<?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%552 = arith.subf %cst_101, %arg1 : f32
%553 = arith.mulf %552, %cst_106 : f32
linalg.yield %553 : f32
} -> tensor<?xf32>
%23 = tensor.dim %22, %c0 : tensor<?xf32>
%24 = tensor.dim %22, %c0 : tensor<?xf32>
%25 = linalg.fill(%c512_i64, %1) : i64, tensor<i64> -> tensor<i64>
%26 = tensor.extract %25[] : tensor<i64>
%27 = arith.addi %26, %c512_i64 : i64
%28 = arith.cmpi sge, %26, %c0_i64 : i64
%29 = arith.select %28, %26, %27 : i64
%30 = arith.cmpi slt, %29, %c0_i64 : i64
%31 = arith.select %30, %c0_i64, %29 : i64
%32 = arith.cmpi sgt, %31, %c512_i64 : i64
%33 = arith.select %32, %c512_i64, %31 : i64
%34 = arith.index_cast %33 : i64 to index
%35 = arith.cmpi sge, %34, %c0 : index
%36 = arith.select %35, %34, %c0 : index
%37 = tensor.extract_slice %cst_99[0, 0] [1, %36] [1, 1] : tensor<1x512xi64> to tensor<?xi64>
%38 = flow.tensor.reshape %0 : tensor<1x512xi64> -> tensor<512xi64>
%39 = arith.cmpi eq, %c512, %36 : index
cf.assert %39, "mismatched size for broadcast"
%40 = linalg.init_tensor [512, 384] : tensor<512x384xf32>
%41 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%38, %37 : tensor<512xi64>, tensor<?xi64>) outs(%40 : tensor<512x384xf32>) {
^bb0(%arg1: i64, %arg2: i64, %arg3: f32):
%552 = linalg.index 1 : index
cf.assert %true, "index must be smaller than dim size"
cf.assert %true, "index must be larger or equal to 0"
%553 = tensor.extract %cst_97[%c0, %552] : tensor<2x384xf32>
%554 = arith.index_cast %arg1 : i64 to index
%555 = arith.cmpi slt, %arg1, %c30522_i64 : i64
cf.assert %555, "index must be smaller than dim size"
%556 = arith.cmpi sge, %arg1, %c0_i64 : i64
cf.assert %556, "index must be larger or equal to 0"
%557 = tensor.extract %cst_98[%554, %552] : tensor<30522x384xf32>
%558 = arith.index_cast %arg2 : i64 to index
%559 = arith.cmpi slt, %arg2, %c512_i64 : i64
cf.assert %559, "index must be smaller than dim size"
%560 = arith.cmpi sge, %arg2, %c0_i64 : i64
cf.assert %560, "index must be larger or equal to 0"
%561 = tensor.extract %cst_96[%558, %552] : tensor<512x384xf32>
%562 = arith.addf %557, %553 : f32
%563 = arith.addf %562, %561 : f32
linalg.yield %563 : f32
} -> tensor<512x384xf32>
%42 = linalg.init_tensor [512] : tensor<512xf32>
%43 = linalg.fill(%cst_105, %42) : f32, tensor<512xf32> -> tensor<512xf32>
%44 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%41 : tensor<512x384xf32>) outs(%43 : tensor<512xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%552 = arith.addf %arg2, %arg1 : f32
linalg.yield %552 : f32
} -> tensor<512xf32>
%45 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%44 : tensor<512xf32>) outs(%42 : tensor<512xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%552 = arith.divf %arg1, %cst : f32
linalg.yield %552 : f32
} -> tensor<512xf32>
%46 = linalg.fill(%cst_105, %42) : f32, tensor<512xf32> -> tensor<512xf32>
%47 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%41, %45 : tensor<512x384xf32>, tensor<512xf32>) outs(%46 : tensor<512xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%552 = arith.subf %arg1, %arg2 : f32
%553 = arith.mulf %552, %552 : f32
%554 = arith.addf %arg3, %553 : f32
linalg.yield %554 : f32
} -> tensor<512xf32>
%48 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%41, %45, %47, %cst_94, %cst_95 : tensor<512x384xf32>, tensor<512xf32>, tensor<512xf32>, tensor<384xf32>, tensor<384xf32>) outs(%40 : tensor<512x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%552 = arith.divf %arg3, %cst : f32
%553 = arith.subf %arg1, %arg2 : f32
%554 = arith.truncf %cst_100 : f64 to f32
%555 = arith.addf %552, %554 : f32
%556 = math.rsqrt %555 : f32
%557 = arith.mulf %553, %556 : f32
%558 = arith.mulf %557, %arg4 : f32
%559 = arith.addf %558, %arg5 : f32
linalg.yield %559 : f32
} -> tensor<512x384xf32>
%49 = flow.tensor.reshape %48 : tensor<512x384xf32> -> tensor<1x512x384xf32>
%50 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%cst_93 : tensor<384xf32>) outs(%40 : tensor<512x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<512x384xf32>
%51 = flow.tensor.reshape %50 : tensor<512x384xf32> -> tensor<1x512x384xf32>
%52 = linalg.batch_matmul ins(%49, %cst_107 : tensor<1x512x384xf32>, tensor<1x384x384xf32>) outs(%51 : tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%53 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%cst_92 : tensor<384xf32>) outs(%40 : tensor<512x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<512x384xf32>
%54 = flow.tensor.reshape %53 : tensor<512x384xf32> -> tensor<1x512x384xf32>
%55 = linalg.batch_matmul ins(%49, %cst_108 : tensor<1x512x384xf32>, tensor<1x384x384xf32>) outs(%54 : tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%56 = flow.tensor.reshape %55 : tensor<1x512x384xf32> -> tensor<?x?x12x32xf32>{%c1, %c512}
%57 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%cst_91 : tensor<384xf32>) outs(%40 : tensor<512x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<512x384xf32>
%58 = flow.tensor.reshape %57 : tensor<512x384xf32> -> tensor<1x512x384xf32>
%59 = linalg.batch_matmul ins(%49, %cst_109 : tensor<1x512x384xf32>, tensor<1x384x384xf32>) outs(%58 : tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
%60 = flow.tensor.reshape %59 : tensor<1x512x384xf32> -> tensor<?x?x12x32xf32>{%c1, %c512}
%61 = flow.tensor.reshape %52 : tensor<1x512x384xf32> -> tensor<?x?x12x32xf32>{%c1, %c512}
%62 = linalg.init_tensor [12, 512, 32] : tensor<12x512x32xf32>
%63 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%56 : tensor<?x?x12x32xf32>) outs(%62 : tensor<12x512x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<12x512x32xf32>
%64 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%61 : tensor<?x?x12x32xf32>) outs(%62 : tensor<12x512x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<12x512x32xf32>
%65 = linalg.init_tensor [12, 512, 512] : tensor<12x512x512xf32>
%66 = linalg.fill(%cst_105, %65) : f32, tensor<12x512x512xf32> -> tensor<12x512x512xf32>
%67 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%64, %63 : tensor<12x512x32xf32>, tensor<12x512x32xf32>) outs(%66 : tensor<12x512x512xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%552 = arith.mulf %arg1, %arg2 : f32
%553 = arith.addf %552, %arg3 : f32
linalg.yield %553 : f32
} -> tensor<12x512x512xf32>
%68 = arith.cmpi eq, %c512, %19 : index
cf.assert %68, "mismatched size for broadcast"
%69 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%67, %22 : tensor<12x512x512xf32>, tensor<?xf32>) outs(%65 : tensor<12x512x512xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%552 = arith.truncf %cst_110 : f64 to f32
%553 = arith.divf %arg1, %552 : f32
%554 = arith.addf %553, %arg2 : f32
linalg.yield %554 : f32
} -> tensor<12x512x512xf32>
%70 = linalg.init_tensor [12, 512] : tensor<12x512xf32>
%71 = linalg.fill(%cst_104, %70) : f32, tensor<12x512xf32> -> tensor<12x512xf32>
%72 = linalg.init_tensor [12, 512] : tensor<12x512xi64>
%73 = linalg.fill(%c0_i64, %72) : i64, tensor<12x512xi64> -> tensor<12x512xi64>
%74:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%69 : tensor<12x512x512xf32>) outs(%71, %73 : tensor<12x512xf32>, tensor<12x512xi64>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: i64):
%552 = linalg.index 2 : index
%553 = arith.index_cast %552 : index to i64
%554 = arith.cmpf ogt, %arg1, %arg2 : f32
%555 = arith.select %554, %arg1, %arg2 : f32
%556 = arith.select %554, %553, %arg3 : i64
linalg.yield %555, %556 : f32, i64
} -> (tensor<12x512xf32>, tensor<12x512xi64>)
%75 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%69, %74#0 : tensor<12x512x512xf32>, tensor<12x512xf32>) outs(%65 : tensor<12x512x512xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%552 = arith.subf %arg1, %arg2 : f32
%553 = math.exp %552 : f32
linalg.yield %553 : f32
} -> tensor<12x512x512xf32>
%76 = linalg.fill(%cst_105, %70) : f32, tensor<12x512xf32> -> tensor<12x512xf32>
%77 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%75 : tensor<12x512x512xf32>) outs(%76 : tensor<12x512xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%552 = arith.addf %arg1, %arg2 : f32
linalg.yield %552 : f32
} -> tensor<12x512xf32>
%78 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%60 : tensor<?x?x12x32xf32>) outs(%62 : tensor<12x512x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<12x512x32xf32>
%79 = linalg.fill(%cst_105, %62) : f32, tensor<12x512x32xf32> -> tensor<12x512x32xf32>
%80 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%75, %77, %78 : tensor<12x512x512xf32>, tensor<12x512xf32>, tensor<12x512x32xf32>) outs(%79 : tensor<12x512x32xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32):
%552 = arith.divf %arg1, %arg2 : f32
%553 = arith.mulf %552, %arg3 : f32
%554 = arith.addf %553, %arg4 : f32
linalg.yield %554 : f32
} -> tensor<12x512x32xf32>
%81 = linalg.init_tensor [512, 12, 32] : tensor<512x12x32xf32>
%82 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d1, d0, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%80 : tensor<12x512x32xf32>) outs(%81 : tensor<512x12x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<512x12x32xf32>
%83 = flow.tensor.reshape %82 : tensor<512x12x32xf32> -> tensor<?x?x384xf32>{%c1, %c512}
%84 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%cst_90 : tensor<384xf32>) outs(%40 : tensor<512x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<512x384xf32>
%85 = flow.tensor.reshape %84 : tensor<512x384xf32> -> tensor<1x512x384xf32>
%86 = linalg.batch_matmul ins(%83, %cst_111 : tensor<?x?x384xf32>, tensor<1x384x384xf32>) outs(%85 : tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
cf.assert %true, "mismatched size for broadcast"
cf.assert %true, "mismatched size for broadcast"
%87 = linalg.init_tensor [1, 512, 384] : tensor<1x512x384xf32>
%88 = flow.tensor.reshape %87 : tensor<1x512x384xf32> -> tensor<?x?x384xf32>{%c1, %c512}
%89 = flow.tensor.reshape %86 : tensor<1x512x384xf32> -> tensor<512x384xf32>
%90 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (0, d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%89, %48 : tensor<512x384xf32>, tensor<512x384xf32>) outs(%88 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%552 = arith.addf %arg1, %arg2 : f32
linalg.yield %552 : f32
} -> tensor<?x?x384xf32>
%91 = linalg.init_tensor [1, 512] : tensor<1x512xf32>
%92 = flow.tensor.reshape %91 : tensor<1x512xf32> -> tensor<?x?xf32>{%c1, %c512}
%93 = linalg.fill(%cst_105, %92) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%94 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%90 : tensor<?x?x384xf32>) outs(%93 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%552 = arith.addf %arg2, %arg1 : f32
linalg.yield %552 : f32
} -> tensor<?x?xf32>
%95 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%94 : tensor<?x?xf32>) outs(%92 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%552 = arith.divf %arg1, %cst : f32
linalg.yield %552 : f32
} -> tensor<?x?xf32>
%96 = linalg.fill(%cst_105, %92) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%97 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%90, %95 : tensor<?x?x384xf32>, tensor<?x?xf32>) outs(%96 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%552 = arith.subf %arg1, %arg2 : f32
%553 = arith.mulf %552, %552 : f32
%554 = arith.addf %arg3, %553 : f32
linalg.yield %554 : f32
} -> tensor<?x?xf32>
%98 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%90, %95, %97, %cst_88, %cst_89 : tensor<?x?x384xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<384xf32>, tensor<384xf32>) outs(%88 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%552 = arith.divf %arg3, %cst : f32
%553 = arith.subf %arg1, %arg2 : f32
%554 = arith.truncf %cst_100 : f64 to f32
%555 = arith.addf %552, %554 : f32
%556 = math.rsqrt %555 : f32
%557 = arith.mulf %553, %556 : f32
%558 = arith.mulf %557, %arg4 : f32
%559 = arith.addf %558, %arg5 : f32
linalg.yield %559 : f32
} -> tensor<?x?x384xf32>
%99 = linalg.init_tensor [1, 512, 1536] : tensor<1x512x1536xf32>
%100 = flow.tensor.reshape %99 : tensor<1x512x1536xf32> -> tensor<?x?x1536xf32>{%c1, %c512}
%101 = linalg.init_tensor [1, 384, 1536] : tensor<1x384x1536xf32>
%102 = flow.tensor.reshape %101 : tensor<1x384x1536xf32> -> tensor<?x384x1536xf32>{%c1}
%103 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_87 : tensor<1536xf32>) outs(%100 : tensor<?x?x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x1536xf32>
%104 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_86 : tensor<1536x384xf32>) outs(%102 : tensor<?x384x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x1536xf32>
%105 = linalg.batch_matmul ins(%98, %104 : tensor<?x?x384xf32>, tensor<?x384x1536xf32>) outs(%103 : tensor<?x?x1536xf32>) -> tensor<?x?x1536xf32>
%106 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%105 : tensor<?x?x1536xf32>) outs(%100 : tensor<?x?x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%552 = math.sqrt %cst_102 : f32
%553 = arith.divf %arg1, %552 : f32
%554 = math.erf %553 : f32
%555 = arith.addf %554, %cst_101 : f32
%556 = arith.mulf %555, %cst_103 : f32
%557 = arith.mulf %arg1, %556 : f32
linalg.yield %557 : f32
} -> tensor<?x?x1536xf32>
%107 = linalg.init_tensor [1, 1536, 384] : tensor<1x1536x384xf32>
%108 = flow.tensor.reshape %107 : tensor<1x1536x384xf32> -> tensor<?x1536x384xf32>{%c1}
%109 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_85 : tensor<384xf32>) outs(%88 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%110 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_84 : tensor<384x1536xf32>) outs(%108 : tensor<?x1536x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x1536x384xf32>
%111 = linalg.batch_matmul ins(%106, %110 : tensor<?x?x1536xf32>, tensor<?x1536x384xf32>) outs(%109 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%112 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%111, %98 : tensor<?x?x384xf32>, tensor<?x?x384xf32>) outs(%88 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%552 = arith.addf %arg1, %arg2 : f32
linalg.yield %552 : f32
} -> tensor<?x?x384xf32>
%113 = linalg.fill(%cst_105, %92) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%114 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%112 : tensor<?x?x384xf32>) outs(%113 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%552 = arith.addf %arg2, %arg1 : f32
linalg.yield %552 : f32
} -> tensor<?x?xf32>
%115 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%114 : tensor<?x?xf32>) outs(%92 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%552 = arith.divf %arg1, %cst : f32
linalg.yield %552 : f32
} -> tensor<?x?xf32>
%116 = linalg.fill(%cst_105, %92) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%117 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%112, %115 : tensor<?x?x384xf32>, tensor<?x?xf32>) outs(%116 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%552 = arith.subf %arg1, %arg2 : f32
%553 = arith.mulf %552, %552 : f32
%554 = arith.addf %arg3, %553 : f32
linalg.yield %554 : f32
} -> tensor<?x?xf32>
%118 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%112, %115, %117, %cst_82, %cst_83 : tensor<?x?x384xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<384xf32>, tensor<384xf32>) outs(%88 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%552 = arith.divf %arg3, %cst : f32
%553 = arith.subf %arg1, %arg2 : f32
%554 = arith.truncf %cst_100 : f64 to f32
%555 = arith.addf %552, %554 : f32
%556 = math.rsqrt %555 : f32
%557 = arith.mulf %553, %556 : f32
%558 = arith.mulf %557, %arg4 : f32
%559 = arith.addf %558, %arg5 : f32
linalg.yield %559 : f32
} -> tensor<?x?x384xf32>
%119 = linalg.init_tensor [1, 384, 384] : tensor<1x384x384xf32>
%120 = flow.tensor.reshape %119 : tensor<1x384x384xf32> -> tensor<?x384x384xf32>{%c1}
%121 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_81 : tensor<384xf32>) outs(%88 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%122 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_80 : tensor<384x384xf32>) outs(%120 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%123 = linalg.batch_matmul ins(%118, %122 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%121 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%124 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_79 : tensor<384xf32>) outs(%88 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%125 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_78 : tensor<384x384xf32>) outs(%120 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%126 = linalg.batch_matmul ins(%118, %125 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%124 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%127 = tensor.dim %126, %c0 : tensor<?x?x384xf32>
%128 = tensor.dim %126, %c1 : tensor<?x?x384xf32>
%129 = tensor.dim %126, %c0 : tensor<?x?x384xf32>
%130 = tensor.dim %126, %c1 : tensor<?x?x384xf32>
%131 = flow.tensor.reshape %126 : tensor<?x?x384xf32>{%129, %130} -> tensor<?x?x12x32xf32>{%127, %128}
%132 = linalg.init_tensor [1, 12, 512, 32] : tensor<1x12x512x32xf32>
%133 = flow.tensor.reshape %132 : tensor<1x12x512x32xf32> -> tensor<?x12x?x32xf32>{%c1, %c512}
%134 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_77 : tensor<384xf32>) outs(%88 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%135 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_76 : tensor<384x384xf32>) outs(%120 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%136 = linalg.batch_matmul ins(%118, %135 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%134 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%137 = tensor.dim %136, %c0 : tensor<?x?x384xf32>
%138 = tensor.dim %136, %c1 : tensor<?x?x384xf32>
%139 = tensor.dim %136, %c0 : tensor<?x?x384xf32>
%140 = tensor.dim %136, %c1 : tensor<?x?x384xf32>
%141 = flow.tensor.reshape %136 : tensor<?x?x384xf32>{%139, %140} -> tensor<?x?x12x32xf32>{%137, %138}
%142 = tensor.dim %123, %c0 : tensor<?x?x384xf32>
%143 = tensor.dim %123, %c1 : tensor<?x?x384xf32>
%144 = tensor.dim %123, %c0 : tensor<?x?x384xf32>
%145 = tensor.dim %123, %c1 : tensor<?x?x384xf32>
%146 = flow.tensor.reshape %123 : tensor<?x?x384xf32>{%144, %145} -> tensor<?x?x12x32xf32>{%142, %143}
%147 = linalg.init_tensor [1, 12, 512, 512] : tensor<1x12x512x512xf32>
%148 = flow.tensor.reshape %147 : tensor<1x12x512x512xf32> -> tensor<?x12x?x?xf32>{%c1, %c512, %c512}
%149 = linalg.fill(%cst_105, %148) : f32, tensor<?x12x?x?xf32> -> tensor<?x12x?x?xf32>
%150 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d1, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d1, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]} ins(%146, %131 : tensor<?x?x12x32xf32>, tensor<?x?x12x32xf32>) outs(%149 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%552 = arith.mulf %arg1, %arg2 : f32
%553 = arith.addf %552, %arg3 : f32
linalg.yield %553 : f32
} -> tensor<?x12x?x?xf32>
cf.assert %true, "mismatched size for broadcast"
cf.assert %68, "mismatched size for broadcast"
%151 = flow.tensor.reshape %22 : tensor<?xf32>{%24} -> tensor<1x?xf32>{%23}
%152 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%150, %151 : tensor<?x12x?x?xf32>, tensor<1x?xf32>) outs(%148 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%552 = arith.truncf %cst_110 : f64 to f32
%553 = arith.divf %arg1, %552 : f32
%554 = arith.addf %553, %arg2 : f32
linalg.yield %554 : f32
} -> tensor<?x12x?x?xf32>
%153 = linalg.init_tensor [1, 12, 512] : tensor<1x12x512xf32>
%154 = flow.tensor.reshape %153 : tensor<1x12x512xf32> -> tensor<?x12x?xf32>{%c1, %c512}
%155 = linalg.fill(%cst_104, %154) : f32, tensor<?x12x?xf32> -> tensor<?x12x?xf32>
%156 = linalg.init_tensor [1, 12, 512] : tensor<1x12x512xi64>
%157 = flow.tensor.reshape %156 : tensor<1x12x512xi64> -> tensor<?x12x?xi64>{%c1, %c512}
%158 = linalg.fill(%c0_i64, %157) : i64, tensor<?x12x?xi64> -> tensor<?x12x?xi64>
%159:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%152 : tensor<?x12x?x?xf32>) outs(%155, %158 : tensor<?x12x?xf32>, tensor<?x12x?xi64>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: i64):
%552 = linalg.index 3 : index
%553 = arith.index_cast %552 : index to i64
%554 = arith.cmpf ogt, %arg1, %arg2 : f32
%555 = arith.select %554, %arg1, %arg2 : f32
%556 = arith.select %554, %553, %arg3 : i64
linalg.yield %555, %556 : f32, i64
} -> (tensor<?x12x?xf32>, tensor<?x12x?xi64>)
%160 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%152, %159#0 : tensor<?x12x?x?xf32>, tensor<?x12x?xf32>) outs(%148 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%552 = arith.subf %arg1, %arg2 : f32
%553 = math.exp %552 : f32
linalg.yield %553 : f32
} -> tensor<?x12x?x?xf32>
%161 = linalg.fill(%cst_105, %154) : f32, tensor<?x12x?xf32> -> tensor<?x12x?xf32>
%162 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%160 : tensor<?x12x?x?xf32>) outs(%161 : tensor<?x12x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%552 = arith.addf %arg1, %arg2 : f32
linalg.yield %552 : f32
} -> tensor<?x12x?xf32>
%163 = linalg.fill(%cst_105, %133) : f32, tensor<?x12x?x32xf32> -> tensor<?x12x?x32xf32>
%164 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d1, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]} ins(%160, %162, %141 : tensor<?x12x?x?xf32>, tensor<?x12x?xf32>, tensor<?x?x12x32xf32>) outs(%163 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32):
%552 = arith.divf %arg1, %arg2 : f32
%553 = arith.mulf %552, %arg3 : f32
%554 = arith.addf %553, %arg4 : f32
linalg.yield %554 : f32
} -> tensor<?x12x?x32xf32>
%165 = linalg.init_tensor [1, 512, 12, 32] : tensor<1x512x12x32xf32>
%166 = flow.tensor.reshape %165 : tensor<1x512x12x32xf32> -> tensor<?x?x12x32xf32>{%c1, %c512}
%167 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%164 : tensor<?x12x?x32xf32>) outs(%166 : tensor<?x?x12x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x12x32xf32>
%168 = tensor.dim %167, %c0 : tensor<?x?x12x32xf32>
%169 = tensor.dim %167, %c1 : tensor<?x?x12x32xf32>
%170 = flow.tensor.reshape %167 : tensor<?x?x12x32xf32>{%168, %169} -> tensor<?x?x384xf32>{%168, %169}
%171 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_75 : tensor<384xf32>) outs(%88 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%172 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_74 : tensor<384x384xf32>) outs(%120 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%173 = linalg.batch_matmul ins(%170, %172 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%171 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
cf.assert %false, "mismatched size for broadcast"
cf.assert %true, "mismatched size for broadcast"
%174 = linalg.init_tensor [0, 512, 384] : tensor<0x512x384xf32>
%175 = flow.tensor.reshape %174 : tensor<0x512x384xf32> -> tensor<?x?x384xf32>{%c0, %c512}
%176 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%173, %118 : tensor<?x?x384xf32>, tensor<?x?x384xf32>) outs(%175 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%552 = arith.addf %arg1, %arg2 : f32
linalg.yield %552 : f32
} -> tensor<?x?x384xf32>
%177 = linalg.init_tensor [0, 512] : tensor<0x512xf32>
%178 = flow.tensor.reshape %177 : tensor<0x512xf32> -> tensor<?x?xf32>{%c0, %c512}
%179 = linalg.fill(%cst_105, %178) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%180 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%176 : tensor<?x?x384xf32>) outs(%179 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%552 = arith.addf %arg2, %arg1 : f32
linalg.yield %552 : f32
} -> tensor<?x?xf32>
%181 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%180 : tensor<?x?xf32>) outs(%178 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%552 = arith.divf %arg1, %cst : f32
linalg.yield %552 : f32
} -> tensor<?x?xf32>
%182 = linalg.fill(%cst_105, %178) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%183 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%176, %181 : tensor<?x?x384xf32>, tensor<?x?xf32>) outs(%182 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%552 = arith.subf %arg1, %arg2 : f32
%553 = arith.mulf %552, %552 : f32
%554 = arith.addf %arg3, %553 : f32
linalg.yield %554 : f32
} -> tensor<?x?xf32>
%184 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%176, %181, %183, %cst_72, %cst_73 : tensor<?x?x384xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<384xf32>, tensor<384xf32>) outs(%175 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%552 = arith.divf %arg3, %cst : f32
%553 = arith.subf %arg1, %arg2 : f32
%554 = arith.truncf %cst_100 : f64 to f32
%555 = arith.addf %552, %554 : f32
%556 = math.rsqrt %555 : f32
%557 = arith.mulf %553, %556 : f32
%558 = arith.mulf %557, %arg4 : f32
%559 = arith.addf %558, %arg5 : f32
linalg.yield %559 : f32
} -> tensor<?x?x384xf32>
%185 = linalg.init_tensor [0, 512, 1536] : tensor<0x512x1536xf32>
%186 = flow.tensor.reshape %185 : tensor<0x512x1536xf32> -> tensor<?x?x1536xf32>{%c0, %c512}
%187 = linalg.init_tensor [0, 384, 1536] : tensor<0x384x1536xf32>
%188 = flow.tensor.reshape %187 : tensor<0x384x1536xf32> -> tensor<?x384x1536xf32>{%c0}
%189 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_71 : tensor<1536xf32>) outs(%186 : tensor<?x?x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x1536xf32>
%190 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_70 : tensor<1536x384xf32>) outs(%188 : tensor<?x384x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x1536xf32>
%191 = linalg.batch_matmul ins(%184, %190 : tensor<?x?x384xf32>, tensor<?x384x1536xf32>) outs(%189 : tensor<?x?x1536xf32>) -> tensor<?x?x1536xf32>
%192 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%191 : tensor<?x?x1536xf32>) outs(%186 : tensor<?x?x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%552 = math.sqrt %cst_102 : f32
%553 = arith.divf %arg1, %552 : f32
%554 = math.erf %553 : f32
%555 = arith.addf %554, %cst_101 : f32
%556 = arith.mulf %555, %cst_103 : f32
%557 = arith.mulf %arg1, %556 : f32
linalg.yield %557 : f32
} -> tensor<?x?x1536xf32>
%193 = linalg.init_tensor [0, 1536, 384] : tensor<0x1536x384xf32>
%194 = flow.tensor.reshape %193 : tensor<0x1536x384xf32> -> tensor<?x1536x384xf32>{%c0}
%195 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_69 : tensor<384xf32>) outs(%175 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%196 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_68 : tensor<384x1536xf32>) outs(%194 : tensor<?x1536x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x1536x384xf32>
%197 = linalg.batch_matmul ins(%192, %196 : tensor<?x?x1536xf32>, tensor<?x1536x384xf32>) outs(%195 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%198 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%197, %184 : tensor<?x?x384xf32>, tensor<?x?x384xf32>) outs(%175 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%552 = arith.addf %arg1, %arg2 : f32
linalg.yield %552 : f32
} -> tensor<?x?x384xf32>
%199 = linalg.fill(%cst_105, %178) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%200 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%198 : tensor<?x?x384xf32>) outs(%199 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%552 = arith.addf %arg2, %arg1 : f32
linalg.yield %552 : f32
} -> tensor<?x?xf32>
%201 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%200 : tensor<?x?xf32>) outs(%178 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%552 = arith.divf %arg1, %cst : f32
linalg.yield %552 : f32
} -> tensor<?x?xf32>
%202 = linalg.fill(%cst_105, %178) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%203 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%198, %201 : tensor<?x?x384xf32>, tensor<?x?xf32>) outs(%202 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%552 = arith.subf %arg1, %arg2 : f32
%553 = arith.mulf %552, %552 : f32
%554 = arith.addf %arg3, %553 : f32
linalg.yield %554 : f32
} -> tensor<?x?xf32>
%204 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%198, %201, %203, %cst_66, %cst_67 : tensor<?x?x384xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<384xf32>, tensor<384xf32>) outs(%175 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%552 = arith.divf %arg3, %cst : f32
%553 = arith.subf %arg1, %arg2 : f32
%554 = arith.truncf %cst_100 : f64 to f32
%555 = arith.addf %552, %554 : f32
%556 = math.rsqrt %555 : f32
%557 = arith.mulf %553, %556 : f32
%558 = arith.mulf %557, %arg4 : f32
%559 = arith.addf %558, %arg5 : f32
linalg.yield %559 : f32
} -> tensor<?x?x384xf32>
%205 = linalg.init_tensor [0, 384, 384] : tensor<0x384x384xf32>
%206 = flow.tensor.reshape %205 : tensor<0x384x384xf32> -> tensor<?x384x384xf32>{%c0}
%207 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_65 : tensor<384xf32>) outs(%175 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%208 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_64 : tensor<384x384xf32>) outs(%206 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%209 = linalg.batch_matmul ins(%204, %208 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%207 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%210 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_63 : tensor<384xf32>) outs(%175 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%211 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_62 : tensor<384x384xf32>) outs(%206 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%212 = linalg.batch_matmul ins(%204, %211 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%210 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%213 = tensor.dim %212, %c0 : tensor<?x?x384xf32>
%214 = tensor.dim %212, %c1 : tensor<?x?x384xf32>
%215 = tensor.dim %212, %c0 : tensor<?x?x384xf32>
%216 = tensor.dim %212, %c1 : tensor<?x?x384xf32>
%217 = flow.tensor.reshape %212 : tensor<?x?x384xf32>{%215, %216} -> tensor<?x?x12x32xf32>{%213, %214}
%218 = linalg.init_tensor [0, 12, 512, 32] : tensor<0x12x512x32xf32>
%219 = flow.tensor.reshape %218 : tensor<0x12x512x32xf32> -> tensor<?x12x?x32xf32>{%c0, %c512}
%220 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_61 : tensor<384xf32>) outs(%175 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%221 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_60 : tensor<384x384xf32>) outs(%206 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%222 = linalg.batch_matmul ins(%204, %221 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%220 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%223 = tensor.dim %222, %c0 : tensor<?x?x384xf32>
%224 = tensor.dim %222, %c1 : tensor<?x?x384xf32>
%225 = tensor.dim %222, %c0 : tensor<?x?x384xf32>
%226 = tensor.dim %222, %c1 : tensor<?x?x384xf32>
%227 = flow.tensor.reshape %222 : tensor<?x?x384xf32>{%225, %226} -> tensor<?x?x12x32xf32>{%223, %224}
%228 = tensor.dim %209, %c0 : tensor<?x?x384xf32>
%229 = tensor.dim %209, %c1 : tensor<?x?x384xf32>
%230 = tensor.dim %209, %c0 : tensor<?x?x384xf32>
%231 = tensor.dim %209, %c1 : tensor<?x?x384xf32>
%232 = flow.tensor.reshape %209 : tensor<?x?x384xf32>{%230, %231} -> tensor<?x?x12x32xf32>{%228, %229}
%233 = linalg.init_tensor [0, 12, 512, 512] : tensor<0x12x512x512xf32>
%234 = flow.tensor.reshape %233 : tensor<0x12x512x512xf32> -> tensor<?x12x?x?xf32>{%c0, %c512, %c512}
%235 = linalg.fill(%cst_105, %234) : f32, tensor<?x12x?x?xf32> -> tensor<?x12x?x?xf32>
%236 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d1, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d1, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]} ins(%232, %217 : tensor<?x?x12x32xf32>, tensor<?x?x12x32xf32>) outs(%235 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%552 = arith.mulf %arg1, %arg2 : f32
%553 = arith.addf %552, %arg3 : f32
linalg.yield %553 : f32
} -> tensor<?x12x?x?xf32>
cf.assert %false, "mismatched size for broadcast"
cf.assert %68, "mismatched size for broadcast"
%237 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%236, %151 : tensor<?x12x?x?xf32>, tensor<1x?xf32>) outs(%234 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%552 = arith.truncf %cst_110 : f64 to f32
%553 = arith.divf %arg1, %552 : f32
%554 = arith.addf %553, %arg2 : f32
linalg.yield %554 : f32
} -> tensor<?x12x?x?xf32>
%238 = linalg.init_tensor [0, 12, 512] : tensor<0x12x512xf32>
%239 = flow.tensor.reshape %238 : tensor<0x12x512xf32> -> tensor<?x12x?xf32>{%c0, %c512}
%240 = linalg.fill(%cst_104, %239) : f32, tensor<?x12x?xf32> -> tensor<?x12x?xf32>
%241 = linalg.init_tensor [0, 12, 512] : tensor<0x12x512xi64>
%242 = flow.tensor.reshape %241 : tensor<0x12x512xi64> -> tensor<?x12x?xi64>{%c0, %c512}
%243 = linalg.fill(%c0_i64, %242) : i64, tensor<?x12x?xi64> -> tensor<?x12x?xi64>
%244:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%237 : tensor<?x12x?x?xf32>) outs(%240, %243 : tensor<?x12x?xf32>, tensor<?x12x?xi64>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: i64):
%552 = linalg.index 3 : index
%553 = arith.index_cast %552 : index to i64
%554 = arith.cmpf ogt, %arg1, %arg2 : f32
%555 = arith.select %554, %arg1, %arg2 : f32
%556 = arith.select %554, %553, %arg3 : i64
linalg.yield %555, %556 : f32, i64
} -> (tensor<?x12x?xf32>, tensor<?x12x?xi64>)
%245 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%237, %244#0 : tensor<?x12x?x?xf32>, tensor<?x12x?xf32>) outs(%234 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%552 = arith.subf %arg1, %arg2 : f32
%553 = math.exp %552 : f32
linalg.yield %553 : f32
} -> tensor<?x12x?x?xf32>
%246 = linalg.fill(%cst_105, %239) : f32, tensor<?x12x?xf32> -> tensor<?x12x?xf32>
%247 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%245 : tensor<?x12x?x?xf32>) outs(%246 : tensor<?x12x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%552 = arith.addf %arg1, %arg2 : f32
linalg.yield %552 : f32
} -> tensor<?x12x?xf32>
%248 = linalg.fill(%cst_105, %219) : f32, tensor<?x12x?x32xf32> -> tensor<?x12x?x32xf32>
%249 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d1, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]} ins(%245, %247, %227 : tensor<?x12x?x?xf32>, tensor<?x12x?xf32>, tensor<?x?x12x32xf32>) outs(%248 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32):
%552 = arith.divf %arg1, %arg2 : f32
%553 = arith.mulf %552, %arg3 : f32
%554 = arith.addf %553, %arg4 : f32
linalg.yield %554 : f32
} -> tensor<?x12x?x32xf32>
%250 = linalg.init_tensor [0, 512, 12, 32] : tensor<0x512x12x32xf32>
%251 = flow.tensor.reshape %250 : tensor<0x512x12x32xf32> -> tensor<?x?x12x32xf32>{%c0, %c512}
%252 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%249 : tensor<?x12x?x32xf32>) outs(%251 : tensor<?x?x12x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x12x32xf32>
%253 = tensor.dim %252, %c0 : tensor<?x?x12x32xf32>
%254 = tensor.dim %252, %c1 : tensor<?x?x12x32xf32>
%255 = flow.tensor.reshape %252 : tensor<?x?x12x32xf32>{%253, %254} -> tensor<?x?x384xf32>{%253, %254}
%256 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_59 : tensor<384xf32>) outs(%175 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%257 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_58 : tensor<384x384xf32>) outs(%206 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%258 = linalg.batch_matmul ins(%255, %257 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%256 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%259 = arith.cmpi eq, %168, %c0 : index
cf.assert %259, "mismatched size for broadcast"
%260 = arith.cmpi eq, %169, %c512 : index
cf.assert %260, "mismatched size for broadcast"
%261 = linalg.init_tensor [%168, %169, 384] : tensor<?x?x384xf32>
%262 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%258, %204 : tensor<?x?x384xf32>, tensor<?x?x384xf32>) outs(%261 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%552 = arith.addf %arg1, %arg2 : f32
linalg.yield %552 : f32
} -> tensor<?x?x384xf32>
%263 = linalg.init_tensor [%168, %169] : tensor<?x?xf32>
%264 = linalg.fill(%cst_105, %263) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%265 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%262 : tensor<?x?x384xf32>) outs(%264 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%552 = arith.addf %arg2, %arg1 : f32
linalg.yield %552 : f32
} -> tensor<?x?xf32>
%266 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%265 : tensor<?x?xf32>) outs(%263 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%552 = arith.divf %arg1, %cst : f32
linalg.yield %552 : f32
} -> tensor<?x?xf32>
%267 = linalg.fill(%cst_105, %263) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%268 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%262, %266 : tensor<?x?x384xf32>, tensor<?x?xf32>) outs(%267 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%552 = arith.subf %arg1, %arg2 : f32
%553 = arith.mulf %552, %552 : f32
%554 = arith.addf %arg3, %553 : f32
linalg.yield %554 : f32
} -> tensor<?x?xf32>
%269 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%262, %266, %268, %cst_56, %cst_57 : tensor<?x?x384xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<384xf32>, tensor<384xf32>) outs(%261 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%552 = arith.divf %arg3, %cst : f32
%553 = arith.subf %arg1, %arg2 : f32
%554 = arith.truncf %cst_100 : f64 to f32
%555 = arith.addf %552, %554 : f32
%556 = math.rsqrt %555 : f32
%557 = arith.mulf %553, %556 : f32
%558 = arith.mulf %557, %arg4 : f32
%559 = arith.addf %558, %arg5 : f32
linalg.yield %559 : f32
} -> tensor<?x?x384xf32>
%270 = linalg.init_tensor [%168, %169, 1536] : tensor<?x?x1536xf32>
%271 = linalg.init_tensor [%168, 384, 1536] : tensor<?x384x1536xf32>
%272 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_55 : tensor<1536xf32>) outs(%270 : tensor<?x?x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x1536xf32>
%273 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_54 : tensor<1536x384xf32>) outs(%271 : tensor<?x384x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x1536xf32>
%274 = linalg.batch_matmul ins(%269, %273 : tensor<?x?x384xf32>, tensor<?x384x1536xf32>) outs(%272 : tensor<?x?x1536xf32>) -> tensor<?x?x1536xf32>
%275 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%274 : tensor<?x?x1536xf32>) outs(%270 : tensor<?x?x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%552 = math.sqrt %cst_102 : f32
%553 = arith.divf %arg1, %552 : f32
%554 = math.erf %553 : f32
%555 = arith.addf %554, %cst_101 : f32
%556 = arith.mulf %555, %cst_103 : f32
%557 = arith.mulf %arg1, %556 : f32
linalg.yield %557 : f32
} -> tensor<?x?x1536xf32>
%276 = linalg.init_tensor [%168, 1536, 384] : tensor<?x1536x384xf32>
%277 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_53 : tensor<384xf32>) outs(%261 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%278 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_52 : tensor<384x1536xf32>) outs(%276 : tensor<?x1536x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x1536x384xf32>
%279 = linalg.batch_matmul ins(%275, %278 : tensor<?x?x1536xf32>, tensor<?x1536x384xf32>) outs(%277 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%280 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%279, %269 : tensor<?x?x384xf32>, tensor<?x?x384xf32>) outs(%261 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%552 = arith.addf %arg1, %arg2 : f32
linalg.yield %552 : f32
} -> tensor<?x?x384xf32>
// Reduce %280 over its last (size-384) dimension by summation, then divide each sum by %cst.
// %281 fills %263 with %cst_105 and serves as the reduction's initial accumulator.
// NOTE(review): presumably %cst_105 = 0.0 and %cst = 384.0, which would make %283 the mean over
// the last dimension - both constants are defined outside this excerpt; confirm.
%281 = linalg.fill(%cst_105, %263) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
// d2 is the reduction iterator; %arg2 is the running accumulator.
%282 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%280 : tensor<?x?x384xf32>) outs(%281 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%552 = arith.addf %arg2, %arg1 : f32
linalg.yield %552 : f32
} -> tensor<?x?xf32>
// Elementwise divide of the sums by %cst; the `outs` element (%arg2) is not read.
%283 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%282 : tensor<?x?xf32>) outs(%263 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%552 = arith.divf %arg1, %cst : f32
linalg.yield %552 : f32
} -> tensor<?x?xf32>
// Sum of squared deviations: for each (d0, d1), accumulate (%280[d0, d1, d2] - %283[d0, d1])^2 over d2.
// %284 (a fill of %263 with %cst_105) is the initial accumulator.
// NOTE(review): presumably %cst_105 = 0.0 - defined outside this excerpt; confirm.
%284 = linalg.fill(%cst_105, %263) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%285 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%280, %283 : tensor<?x?x384xf32>, tensor<?x?xf32>) outs(%284 : tensor<?x?xf32>) {
// %arg1 = element of %280, %arg2 = matching element of %283, %arg3 = running accumulator.
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%552 = arith.subf %arg1, %arg2 : f32
%553 = arith.mulf %552, %552 : f32
%554 = arith.addf %arg3, %553 : f32
linalg.yield %554 : f32
} -> tensor<?x?xf32>
// Elementwise normalization of %280 using the per-(d0, d1) statistics %283 and %285:
//   y = (x - %283) * rsqrt(%285 / %cst + truncf(%cst_100)) * %cst_50[d2] + %cst_51[d2]
// NOTE(review): this has the form of layer normalization with %cst_100 as epsilon, %cst_50 as the
// scale and %cst_51 as the shift - the constants are defined outside this excerpt; confirm.
%286 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%280, %283, %285, %cst_50, %cst_51 : tensor<?x?x384xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<384xf32>, tensor<384xf32>) outs(%261 : tensor<?x?x384xf32>) {
// %arg1..%arg5 follow the order of `ins`; %arg6 (the `outs` element) is not read.
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%552 = arith.divf %arg3, %cst : f32
%553 = arith.subf %arg1, %arg2 : f32
// %cst_100 is an f64 constant; it is narrowed to f32 before being added.
%554 = arith.truncf %cst_100 : f64 to f32
%555 = arith.addf %552, %554 : f32
%556 = math.rsqrt %555 : f32
%557 = arith.mulf %553, %556 : f32
%558 = arith.mulf %557, %arg4 : f32
%559 = arith.addf %558, %arg5 : f32
linalg.yield %559 : f32
} -> tensor<?x?x384xf32>
%287 = linalg.init_tensor [%168, 384, 384] : tensor<?x384x384xf32>
%288 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_49 : tensor<384xf32>) outs(%261 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%289 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_48 : tensor<384x384xf32>) outs(%287 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%290 = linalg.batch_matmul ins(%286, %289 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%288 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%291 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_47 : tensor<384xf32>) outs(%261 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%292 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_46 : tensor<384x384xf32>) outs(%287 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%293 = linalg.batch_matmul ins(%286, %292 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%291 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%294 = tensor.dim %293, %c0 : tensor<?x?x384xf32>
%295 = tensor.dim %293, %c1 : tensor<?x?x384xf32>
%296 = tensor.dim %293, %c0 : tensor<?x?x384xf32>
%297 = tensor.dim %293, %c1 : tensor<?x?x384xf32>
%298 = flow.tensor.reshape %293 : tensor<?x?x384xf32>{%296, %297} -> tensor<?x?x12x32xf32>{%294, %295}
%299 = linalg.init_tensor [%168, 12, %169, 32] : tensor<?x12x?x32xf32>
%300 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_45 : tensor<384xf32>) outs(%261 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%301 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_44 : tensor<384x384xf32>) outs(%287 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%302 = linalg.batch_matmul ins(%286, %301 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%300 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%303 = tensor.dim %302, %c0 : tensor<?x?x384xf32>
%304 = tensor.dim %302, %c1 : tensor<?x?x384xf32>
%305 = tensor.dim %302, %c0 : tensor<?x?x384xf32>
%306 = tensor.dim %302, %c1 : tensor<?x?x384xf32>
%307 = flow.tensor.reshape %302 : tensor<?x?x384xf32>{%305, %306} -> tensor<?x?x12x32xf32>{%303, %304}
%308 = tensor.dim %290, %c0 : tensor<?x?x384xf32>
%309 = tensor.dim %290, %c1 : tensor<?x?x384xf32>
%310 = tensor.dim %290, %c0 : tensor<?x?x384xf32>
%311 = tensor.dim %290, %c1 : tensor<?x?x384xf32>
%312 = flow.tensor.reshape %290 : tensor<?x?x384xf32>{%310, %311} -> tensor<?x?x12x32xf32>{%308, %309}
%313 = linalg.init_tensor [%168, 12, %169, %169] : tensor<?x12x?x?xf32>
%314 = linalg.fill(%cst_105, %313) : f32, tensor<?x12x?x?xf32> -> tensor<?x12x?x?xf32>
// Contraction over d4 (the size-32 dimension of both inputs):
//   out[d0, d1, d2, d3] += %312[d0, d2, d1, d4] * %298[d0, d3, d1, d4]
// The accumulator starts from %314, a fill of %313 with %cst_105.
// NOTE(review): presumably these are per-head attention scores (query x key^T) with d1 indexing the
// 12 heads - the roles of %312 / %298 are inferred from the shapes only; confirm.
%315 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d1, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d1, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]} ins(%312, %298 : tensor<?x?x12x32xf32>, tensor<?x?x12x32xf32>) outs(%314 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%552 = arith.mulf %arg1, %arg2 : f32
%553 = arith.addf %552, %arg3 : f32
linalg.yield %553 : f32
} -> tensor<?x12x?x?xf32>
// Runtime shape checks before the broadcast below: %168 must equal %c1 and %169 must equal %19.
// The first check matches %151's static leading extent of 1, which is indexed with d0 below.
%316 = arith.cmpi eq, %168, %c1 : index
cf.assert %316, "mismatched size for broadcast"
%317 = arith.cmpi eq, %169, %19 : index
cf.assert %317, "mismatched size for broadcast"
// Elementwise: out[d0, d1, d2, d3] = %315[d0, d1, d2, d3] / truncf(%cst_110) + %151[d0, d3].
// %151 is therefore broadcast across d1 and d2.
// NOTE(review): presumably %cst_110 is the attention scaling divisor and %151 an additive attention
// mask - both are defined outside this excerpt; confirm.
%318 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%315, %151 : tensor<?x12x?x?xf32>, tensor<1x?xf32>) outs(%313 : tensor<?x12x?x?xf32>) {
// %arg3 (the `outs` element) is not read.
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%552 = arith.truncf %cst_110 : f64 to f32
%553 = arith.divf %arg1, %552 : f32
%554 = arith.addf %553, %arg2 : f32
linalg.yield %554 : f32
} -> tensor<?x12x?x?xf32>
// Reduction over d3 of %318 that tracks two results per (d0, d1, d2): the running maximum value
// (#0, f32) and the d3 index at which it was taken (#1, i64).
// Initial accumulators: the value tensor is filled with %cst_104, the index tensor with %c0_i64.
// NOTE(review): presumably %cst_104 is the most negative finite f32 (or -inf) - defined outside this
// excerpt; confirm. Only %323#0 is consumed within this excerpt.
%319 = linalg.init_tensor [%168, 12, %169] : tensor<?x12x?xf32>
%320 = linalg.fill(%cst_104, %319) : f32, tensor<?x12x?xf32> -> tensor<?x12x?xf32>
%321 = linalg.init_tensor [%168, 12, %169] : tensor<?x12x?xi64>
%322 = linalg.fill(%c0_i64, %321) : i64, tensor<?x12x?xi64> -> tensor<?x12x?xi64>
%323:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%318 : tensor<?x12x?x?xf32>) outs(%320, %322 : tensor<?x12x?xf32>, tensor<?x12x?xi64>) {
// %arg1 = input element, %arg2 = running max, %arg3 = running index.
^bb0(%arg1: f32, %arg2: f32, %arg3: i64):
// Current position along the reduction dimension d3, converted to i64.
%552 = linalg.index 3 : index
%553 = arith.index_cast %552 : index to i64
// Ordered greater-than: the accumulators are replaced only when the input is strictly larger.
%554 = arith.cmpf ogt, %arg1, %arg2 : f32
%555 = arith.select %554, %arg1, %arg2 : f32
%556 = arith.select %554, %553, %arg3 : i64
linalg.yield %555, %556 : f32, i64
} -> (tensor<?x12x?xf32>, tensor<?x12x?xi64>)
// Elementwise: %324[d0, d1, d2, d3] = exp(%318[d0, d1, d2, d3] - %323#0[d0, d1, d2]),
// i.e. each value is shifted by the per-(d0, d1, d2) maximum computed above before exponentiation.
%324 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%318, %323#0 : tensor<?x12x?x?xf32>, tensor<?x12x?xf32>) outs(%313 : tensor<?x12x?x?xf32>) {
// %arg3 (the `outs` element) is not read.
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%552 = arith.subf %arg1, %arg2 : f32
%553 = math.exp %552 : f32
linalg.yield %553 : f32
} -> tensor<?x12x?x?xf32>
// Sum of the exponentials over d3, accumulated on top of a fill of %319 with %cst_105.
// NOTE(review): presumably %cst_105 = 0.0, making %326 the softmax denominator - confirm.
%325 = linalg.fill(%cst_105, %319) : f32, tensor<?x12x?xf32> -> tensor<?x12x?xf32>
%326 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%324 : tensor<?x12x?x?xf32>) outs(%325 : tensor<?x12x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%552 = arith.addf %arg1, %arg2 : f32
linalg.yield %552 : f32
} -> tensor<?x12x?xf32>
// Contraction over d4 with the normalization folded in:
//   out[d0, d1, d2, d3] += (%324[d0, d1, d2, d4] / %326[d0, d1, d2]) * %307[d0, d4, d1, d3]
// The accumulator starts from %327, a fill of %299 with %cst_105.
// NOTE(review): presumably this is softmax(scores) x value per attention head - the role of %307 is
// inferred from the shapes only; confirm.
%327 = linalg.fill(%cst_105, %299) : f32, tensor<?x12x?x32xf32> -> tensor<?x12x?x32xf32>
%328 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d1, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]} ins(%324, %326, %307 : tensor<?x12x?x?xf32>, tensor<?x12x?xf32>, tensor<?x?x12x32xf32>) outs(%327 : tensor<?x12x?x32xf32>) {
// %arg1 = exponential, %arg2 = sum of exponentials, %arg3 = element of %307, %arg4 = accumulator.
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32):
%552 = arith.divf %arg1, %arg2 : f32
%553 = arith.mulf %552, %arg3 : f32
%554 = arith.addf %553, %arg4 : f32
linalg.yield %554 : f32
} -> tensor<?x12x?x32xf32>
%329 = linalg.init_tensor [%168, %169, 12, 32] : tensor<?x?x12x32xf32>
%330 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%328 : tensor<?x12x?x32xf32>) outs(%329 : tensor<?x?x12x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x12x32xf32>
%331 = tensor.dim %330, %c0 : tensor<?x?x12x32xf32>
%332 = tensor.dim %330, %c1 : tensor<?x?x12x32xf32>
%333 = flow.tensor.reshape %330 : tensor<?x?x12x32xf32>{%331, %332} -> tensor<?x?x384xf32>{%331, %332}
%334 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_43 : tensor<384xf32>) outs(%261 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%335 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_42 : tensor<384x384xf32>) outs(%287 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%336 = linalg.batch_matmul ins(%333, %335 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%334 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%337 = arith.cmpi eq, %253, %168 : index
cf.assert %337, "mismatched size for broadcast"
%338 = arith.cmpi eq, %254, %169 : index
cf.assert %338, "mismatched size for broadcast"
%339 = linalg.init_tensor [%253, %254, 384] : tensor<?x?x384xf32>
%340 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%336, %286 : tensor<?x?x384xf32>, tensor<?x?x384xf32>) outs(%339 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%552 = arith.addf %arg1, %arg2 : f32
linalg.yield %552 : f32
} -> tensor<?x?x384xf32>
%341 = linalg.init_tensor [%253, %254] : tensor<?x?xf32>
%342 = linalg.fill(%cst_105, %341) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%343 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%340 : tensor<?x?x384xf32>) outs(%342 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%552 = arith.addf %arg2, %arg1 : f32
linalg.yield %552 : f32
} -> tensor<?x?xf32>
%344 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%343 : tensor<?x?xf32>) outs(%341 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%552 = arith.divf %arg1, %cst : f32
linalg.yield %552 : f32
} -> tensor<?x?xf32>
%345 = linalg.fill(%cst_105, %341) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%346 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%340, %344 : tensor<?x?x384xf32>, tensor<?x?xf32>) outs(%345 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%552 = arith.subf %arg1, %arg2 : f32
%553 = arith.mulf %552, %552 : f32
%554 = arith.addf %arg3, %553 : f32
linalg.yield %554 : f32
} -> tensor<?x?xf32>
%347 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%340, %344, %346, %cst_40, %cst_41 : tensor<?x?x384xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<384xf32>, tensor<384xf32>) outs(%339 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%552 = arith.divf %arg3, %cst : f32
%553 = arith.subf %arg1, %arg2 : f32
%554 = arith.truncf %cst_100 : f64 to f32
%555 = arith.addf %552, %554 : f32
%556 = math.rsqrt %555 : f32
%557 = arith.mulf %553, %556 : f32
%558 = arith.mulf %557, %arg4 : f32
%559 = arith.addf %558, %arg5 : f32
linalg.yield %559 : f32
} -> tensor<?x?x384xf32>
%348 = linalg.init_tensor [%253, %254, 1536] : tensor<?x?x1536xf32>
%349 = linalg.init_tensor [%253, 384, 1536] : tensor<?x384x1536xf32>
%350 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_39 : tensor<1536xf32>) outs(%348 : tensor<?x?x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x1536xf32>
%351 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_38 : tensor<1536x384xf32>) outs(%349 : tensor<?x384x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x1536xf32>
%352 = linalg.batch_matmul ins(%347, %351 : tensor<?x?x384xf32>, tensor<?x384x1536xf32>) outs(%350 : tensor<?x?x1536xf32>) -> tensor<?x?x1536xf32>
%353 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%352 : tensor<?x?x1536xf32>) outs(%348 : tensor<?x?x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%552 = math.sqrt %cst_102 : f32
%553 = arith.divf %arg1, %552 : f32
%554 = math.erf %553 : f32
%555 = arith.addf %554, %cst_101 : f32
%556 = arith.mulf %555, %cst_103 : f32
%557 = arith.mulf %arg1, %556 : f32
linalg.yield %557 : f32
} -> tensor<?x?x1536xf32>
%354 = linalg.init_tensor [%253, 1536, 384] : tensor<?x1536x384xf32>
%355 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_37 : tensor<384xf32>) outs(%339 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%356 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_36 : tensor<384x1536xf32>) outs(%354 : tensor<?x1536x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x1536x384xf32>
%357 = linalg.batch_matmul ins(%353, %356 : tensor<?x?x1536xf32>, tensor<?x1536x384xf32>) outs(%355 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%358 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%357, %347 : tensor<?x?x384xf32>, tensor<?x?x384xf32>) outs(%339 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%552 = arith.addf %arg1, %arg2 : f32
linalg.yield %552 : f32
} -> tensor<?x?x384xf32>
%359 = linalg.fill(%cst_105, %341) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%360 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%358 : tensor<?x?x384xf32>) outs(%359 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%552 = arith.addf %arg2, %arg1 : f32
linalg.yield %552 : f32
} -> tensor<?x?xf32>
%361 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%360 : tensor<?x?xf32>) outs(%341 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%552 = arith.divf %arg1, %cst : f32
linalg.yield %552 : f32
} -> tensor<?x?xf32>
%362 = linalg.fill(%cst_105, %341) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%363 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%358, %361 : tensor<?x?x384xf32>, tensor<?x?xf32>) outs(%362 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%552 = arith.subf %arg1, %arg2 : f32
%553 = arith.mulf %552, %552 : f32
%554 = arith.addf %arg3, %553 : f32
linalg.yield %554 : f32
} -> tensor<?x?xf32>
%364 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%358, %361, %363, %cst_34, %cst_35 : tensor<?x?x384xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<384xf32>, tensor<384xf32>) outs(%339 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%552 = arith.divf %arg3, %cst : f32
%553 = arith.subf %arg1, %arg2 : f32
%554 = arith.truncf %cst_100 : f64 to f32
%555 = arith.addf %552, %554 : f32
%556 = math.rsqrt %555 : f32
%557 = arith.mulf %553, %556 : f32
%558 = arith.mulf %557, %arg4 : f32
%559 = arith.addf %558, %arg5 : f32
linalg.yield %559 : f32
} -> tensor<?x?x384xf32>
%365 = linalg.init_tensor [%253, 384, 384] : tensor<?x384x384xf32>
%366 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_33 : tensor<384xf32>) outs(%339 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%367 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_32 : tensor<384x384xf32>) outs(%365 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%368 = linalg.batch_matmul ins(%364, %367 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%366 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%369 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_31 : tensor<384xf32>) outs(%339 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%370 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_30 : tensor<384x384xf32>) outs(%365 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%371 = linalg.batch_matmul ins(%364, %370 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%369 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%372 = tensor.dim %371, %c0 : tensor<?x?x384xf32>
%373 = tensor.dim %371, %c1 : tensor<?x?x384xf32>
%374 = tensor.dim %371, %c0 : tensor<?x?x384xf32>
%375 = tensor.dim %371, %c1 : tensor<?x?x384xf32>
%376 = flow.tensor.reshape %371 : tensor<?x?x384xf32>{%374, %375} -> tensor<?x?x12x32xf32>{%372, %373}
%377 = linalg.init_tensor [%253, 12, %254, 32] : tensor<?x12x?x32xf32>
%378 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_29 : tensor<384xf32>) outs(%339 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%379 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_28 : tensor<384x384xf32>) outs(%365 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%380 = linalg.batch_matmul ins(%364, %379 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%378 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%381 = tensor.dim %380, %c0 : tensor<?x?x384xf32>
%382 = tensor.dim %380, %c1 : tensor<?x?x384xf32>
%383 = tensor.dim %380, %c0 : tensor<?x?x384xf32>
%384 = tensor.dim %380, %c1 : tensor<?x?x384xf32>
%385 = flow.tensor.reshape %380 : tensor<?x?x384xf32>{%383, %384} -> tensor<?x?x12x32xf32>{%381, %382}
%386 = tensor.dim %368, %c0 : tensor<?x?x384xf32>
%387 = tensor.dim %368, %c1 : tensor<?x?x384xf32>
%388 = tensor.dim %368, %c0 : tensor<?x?x384xf32>
%389 = tensor.dim %368, %c1 : tensor<?x?x384xf32>
%390 = flow.tensor.reshape %368 : tensor<?x?x384xf32>{%388, %389} -> tensor<?x?x12x32xf32>{%386, %387}
%391 = linalg.init_tensor [%253, 12, %254, %254] : tensor<?x12x?x?xf32>
%392 = linalg.fill(%cst_105, %391) : f32, tensor<?x12x?x?xf32> -> tensor<?x12x?x?xf32>
%393 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d1, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d1, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]} ins(%390, %376 : tensor<?x?x12x32xf32>, tensor<?x?x12x32xf32>) outs(%392 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%552 = arith.mulf %arg1, %arg2 : f32
%553 = arith.addf %552, %arg3 : f32
linalg.yield %553 : f32
} -> tensor<?x12x?x?xf32>
%394 = arith.cmpi eq, %253, %c1 : index
cf.assert %394, "mismatched size for broadcast"
%395 = arith.cmpi eq, %254, %19 : index
cf.assert %395, "mismatched size for broadcast"
%396 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%393, %151 : tensor<?x12x?x?xf32>, tensor<1x?xf32>) outs(%391 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%552 = arith.truncf %cst_110 : f64 to f32
%553 = arith.divf %arg1, %552 : f32
%554 = arith.addf %553, %arg2 : f32
linalg.yield %554 : f32
} -> tensor<?x12x?x?xf32>
%397 = linalg.init_tensor [%253, 12, %254] : tensor<?x12x?xf32>
%398 = linalg.fill(%cst_104, %397) : f32, tensor<?x12x?xf32> -> tensor<?x12x?xf32>
%399 = linalg.init_tensor [%253, 12, %254] : tensor<?x12x?xi64>
%400 = linalg.fill(%c0_i64, %399) : i64, tensor<?x12x?xi64> -> tensor<?x12x?xi64>
%401:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%396 : tensor<?x12x?x?xf32>) outs(%398, %400 : tensor<?x12x?xf32>, tensor<?x12x?xi64>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: i64):
%552 = linalg.index 3 : index
%553 = arith.index_cast %552 : index to i64
%554 = arith.cmpf ogt, %arg1, %arg2 : f32
%555 = arith.select %554, %arg1, %arg2 : f32
%556 = arith.select %554, %553, %arg3 : i64
linalg.yield %555, %556 : f32, i64
} -> (tensor<?x12x?xf32>, tensor<?x12x?xi64>)
%402 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%396, %401#0 : tensor<?x12x?x?xf32>, tensor<?x12x?xf32>) outs(%391 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%552 = arith.subf %arg1, %arg2 : f32
%553 = math.exp %552 : f32
linalg.yield %553 : f32
} -> tensor<?x12x?x?xf32>
%403 = linalg.fill(%cst_105, %397) : f32, tensor<?x12x?xf32> -> tensor<?x12x?xf32>
%404 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%402 : tensor<?x12x?x?xf32>) outs(%403 : tensor<?x12x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%552 = arith.addf %arg1, %arg2 : f32
linalg.yield %552 : f32
} -> tensor<?x12x?xf32>
%405 = linalg.fill(%cst_105, %377) : f32, tensor<?x12x?x32xf32> -> tensor<?x12x?x32xf32>
%406 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d1, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]} ins(%402, %404, %385 : tensor<?x12x?x?xf32>, tensor<?x12x?xf32>, tensor<?x?x12x32xf32>) outs(%405 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32):
%552 = arith.divf %arg1, %arg2 : f32
%553 = arith.mulf %552, %arg3 : f32
%554 = arith.addf %553, %arg4 : f32
linalg.yield %554 : f32
} -> tensor<?x12x?x32xf32>
%407 = linalg.init_tensor [%253, %254, 12, 32] : tensor<?x?x12x32xf32>
%408 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%406 : tensor<?x12x?x32xf32>) outs(%407 : tensor<?x?x12x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x12x32xf32>
%409 = tensor.dim %408, %c0 : tensor<?x?x12x32xf32>
%410 = tensor.dim %408, %c1 : tensor<?x?x12x32xf32>
%411 = flow.tensor.reshape %408 : tensor<?x?x12x32xf32>{%409, %410} -> tensor<?x?x384xf32>{%409, %410}
%412 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_27 : tensor<384xf32>) outs(%339 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%413 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_26 : tensor<384x384xf32>) outs(%365 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%414 = linalg.batch_matmul ins(%411, %413 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%412 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%415 = arith.cmpi eq, %331, %253 : index
cf.assert %415, "mismatched size for broadcast"
%416 = arith.cmpi eq, %332, %254 : index
cf.assert %416, "mismatched size for broadcast"
%417 = linalg.init_tensor [%331, %332, 384] : tensor<?x?x384xf32>
%418 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%414, %364 : tensor<?x?x384xf32>, tensor<?x?x384xf32>) outs(%417 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%552 = arith.addf %arg1, %arg2 : f32
linalg.yield %552 : f32
} -> tensor<?x?x384xf32>
%419 = linalg.init_tensor [%331, %332] : tensor<?x?xf32>
%420 = linalg.fill(%cst_105, %419) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%421 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%418 : tensor<?x?x384xf32>) outs(%420 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%552 = arith.addf %arg2, %arg1 : f32
linalg.yield %552 : f32
} -> tensor<?x?xf32>
%422 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%421 : tensor<?x?xf32>) outs(%419 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%552 = arith.divf %arg1, %cst : f32
linalg.yield %552 : f32
} -> tensor<?x?xf32>
%423 = linalg.fill(%cst_105, %419) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%424 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%418, %422 : tensor<?x?x384xf32>, tensor<?x?xf32>) outs(%423 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%552 = arith.subf %arg1, %arg2 : f32
%553 = arith.mulf %552, %552 : f32
%554 = arith.addf %arg3, %553 : f32
linalg.yield %554 : f32
} -> tensor<?x?xf32>
%425 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%418, %422, %424, %cst_24, %cst_25 : tensor<?x?x384xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<384xf32>, tensor<384xf32>) outs(%417 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%552 = arith.divf %arg3, %cst : f32
%553 = arith.subf %arg1, %arg2 : f32
%554 = arith.truncf %cst_100 : f64 to f32
%555 = arith.addf %552, %554 : f32
%556 = math.rsqrt %555 : f32
%557 = arith.mulf %553, %556 : f32
%558 = arith.mulf %557, %arg4 : f32
%559 = arith.addf %558, %arg5 : f32
linalg.yield %559 : f32
} -> tensor<?x?x384xf32>
%426 = linalg.init_tensor [%331, %332, 1536] : tensor<?x?x1536xf32>
%427 = linalg.init_tensor [%331, 384, 1536] : tensor<?x384x1536xf32>
%428 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_23 : tensor<1536xf32>) outs(%426 : tensor<?x?x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x1536xf32>
%429 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_22 : tensor<1536x384xf32>) outs(%427 : tensor<?x384x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x1536xf32>
%430 = linalg.batch_matmul ins(%425, %429 : tensor<?x?x384xf32>, tensor<?x384x1536xf32>) outs(%428 : tensor<?x?x1536xf32>) -> tensor<?x?x1536xf32>
%431 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%430 : tensor<?x?x1536xf32>) outs(%426 : tensor<?x?x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%552 = math.sqrt %cst_102 : f32
%553 = arith.divf %arg1, %552 : f32
%554 = math.erf %553 : f32
%555 = arith.addf %554, %cst_101 : f32
%556 = arith.mulf %555, %cst_103 : f32
%557 = arith.mulf %arg1, %556 : f32
linalg.yield %557 : f32
} -> tensor<?x?x1536xf32>
%432 = linalg.init_tensor [%331, 1536, 384] : tensor<?x1536x384xf32>
%433 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_21 : tensor<384xf32>) outs(%417 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%434 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_20 : tensor<384x1536xf32>) outs(%432 : tensor<?x1536x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x1536x384xf32>
%435 = linalg.batch_matmul ins(%431, %434 : tensor<?x?x1536xf32>, tensor<?x1536x384xf32>) outs(%433 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%436 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%435, %425 : tensor<?x?x384xf32>, tensor<?x?x384xf32>) outs(%417 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%552 = arith.addf %arg1, %arg2 : f32
linalg.yield %552 : f32
} -> tensor<?x?x384xf32>
%437 = linalg.fill(%cst_105, %419) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%438 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%436 : tensor<?x?x384xf32>) outs(%437 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%552 = arith.addf %arg2, %arg1 : f32
linalg.yield %552 : f32
} -> tensor<?x?xf32>
%439 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%438 : tensor<?x?xf32>) outs(%419 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%552 = arith.divf %arg1, %cst : f32
linalg.yield %552 : f32
} -> tensor<?x?xf32>
%440 = linalg.fill(%cst_105, %419) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%441 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%436, %439 : tensor<?x?x384xf32>, tensor<?x?xf32>) outs(%440 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%552 = arith.subf %arg1, %arg2 : f32
%553 = arith.mulf %552, %552 : f32
%554 = arith.addf %arg3, %553 : f32
linalg.yield %554 : f32
} -> tensor<?x?xf32>
%442 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%436, %439, %441, %cst_18, %cst_19 : tensor<?x?x384xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<384xf32>, tensor<384xf32>) outs(%417 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%552 = arith.divf %arg3, %cst : f32
%553 = arith.subf %arg1, %arg2 : f32
%554 = arith.truncf %cst_100 : f64 to f32
%555 = arith.addf %552, %554 : f32
%556 = math.rsqrt %555 : f32
%557 = arith.mulf %553, %556 : f32
%558 = arith.mulf %557, %arg4 : f32
%559 = arith.addf %558, %arg5 : f32
linalg.yield %559 : f32
} -> tensor<?x?x384xf32>
%443 = linalg.init_tensor [%331, 384, 384] : tensor<?x384x384xf32>
%444 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_17 : tensor<384xf32>) outs(%417 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%445 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_16 : tensor<384x384xf32>) outs(%443 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%446 = linalg.batch_matmul ins(%442, %445 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%444 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%447 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_15 : tensor<384xf32>) outs(%417 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%448 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_14 : tensor<384x384xf32>) outs(%443 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%449 = linalg.batch_matmul ins(%442, %448 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%447 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%450 = tensor.dim %449, %c0 : tensor<?x?x384xf32>
%451 = tensor.dim %449, %c1 : tensor<?x?x384xf32>
%452 = tensor.dim %449, %c0 : tensor<?x?x384xf32>
%453 = tensor.dim %449, %c1 : tensor<?x?x384xf32>
%454 = flow.tensor.reshape %449 : tensor<?x?x384xf32>{%452, %453} -> tensor<?x?x12x32xf32>{%450, %451}
%455 = linalg.init_tensor [%331, 12, %332, 32] : tensor<?x12x?x32xf32>
%456 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_13 : tensor<384xf32>) outs(%417 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%457 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_12 : tensor<384x384xf32>) outs(%443 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%458 = linalg.batch_matmul ins(%442, %457 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%456 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%459 = tensor.dim %458, %c0 : tensor<?x?x384xf32>
%460 = tensor.dim %458, %c1 : tensor<?x?x384xf32>
%461 = tensor.dim %458, %c0 : tensor<?x?x384xf32>
%462 = tensor.dim %458, %c1 : tensor<?x?x384xf32>
%463 = flow.tensor.reshape %458 : tensor<?x?x384xf32>{%461, %462} -> tensor<?x?x12x32xf32>{%459, %460}
%464 = tensor.dim %446, %c0 : tensor<?x?x384xf32>
%465 = tensor.dim %446, %c1 : tensor<?x?x384xf32>
%466 = tensor.dim %446, %c0 : tensor<?x?x384xf32>
%467 = tensor.dim %446, %c1 : tensor<?x?x384xf32>
%468 = flow.tensor.reshape %446 : tensor<?x?x384xf32>{%466, %467} -> tensor<?x?x12x32xf32>{%464, %465}
%469 = linalg.init_tensor [%331, 12, %332, %332] : tensor<?x12x?x?xf32>
%470 = linalg.fill(%cst_105, %469) : f32, tensor<?x12x?x?xf32> -> tensor<?x12x?x?xf32>
%471 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d1, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d1, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]} ins(%468, %454 : tensor<?x?x12x32xf32>, tensor<?x?x12x32xf32>) outs(%470 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%552 = arith.mulf %arg1, %arg2 : f32
%553 = arith.addf %552, %arg3 : f32
linalg.yield %553 : f32
} -> tensor<?x12x?x?xf32>
%472 = arith.cmpi eq, %331, %c1 : index
cf.assert %472, "mismatched size for broadcast"
%473 = arith.cmpi eq, %332, %19 : index
cf.assert %473, "mismatched size for broadcast"
%474 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%471, %151 : tensor<?x12x?x?xf32>, tensor<1x?xf32>) outs(%469 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%552 = arith.truncf %cst_110 : f64 to f32
%553 = arith.divf %arg1, %552 : f32
%554 = arith.addf %553, %arg2 : f32
linalg.yield %554 : f32
} -> tensor<?x12x?x?xf32>
%475 = linalg.init_tensor [%331, 12, %332] : tensor<?x12x?xf32>
%476 = linalg.fill(%cst_104, %475) : f32, tensor<?x12x?xf32> -> tensor<?x12x?xf32>
%477 = linalg.init_tensor [%331, 12, %332] : tensor<?x12x?xi64>
%478 = linalg.fill(%c0_i64, %477) : i64, tensor<?x12x?xi64> -> tensor<?x12x?xi64>
%479:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%474 : tensor<?x12x?x?xf32>) outs(%476, %478 : tensor<?x12x?xf32>, tensor<?x12x?xi64>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: i64):
%552 = linalg.index 3 : index
%553 = arith.index_cast %552 : index to i64
%554 = arith.cmpf ogt, %arg1, %arg2 : f32
%555 = arith.select %554, %arg1, %arg2 : f32
%556 = arith.select %554, %553, %arg3 : i64
linalg.yield %555, %556 : f32, i64
} -> (tensor<?x12x?xf32>, tensor<?x12x?xi64>)
%480 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%474, %479#0 : tensor<?x12x?x?xf32>, tensor<?x12x?xf32>) outs(%469 : tensor<?x12x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%552 = arith.subf %arg1, %arg2 : f32
%553 = math.exp %552 : f32
linalg.yield %553 : f32
} -> tensor<?x12x?x?xf32>
%481 = linalg.fill(%cst_105, %475) : f32, tensor<?x12x?xf32> -> tensor<?x12x?xf32>
%482 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%480 : tensor<?x12x?x?xf32>) outs(%481 : tensor<?x12x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%552 = arith.addf %arg1, %arg2 : f32
linalg.yield %552 : f32
} -> tensor<?x12x?xf32>
%483 = linalg.fill(%cst_105, %455) : f32, tensor<?x12x?x32xf32> -> tensor<?x12x?x32xf32>
%484 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d1, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]} ins(%480, %482, %463 : tensor<?x12x?x?xf32>, tensor<?x12x?xf32>, tensor<?x?x12x32xf32>) outs(%483 : tensor<?x12x?x32xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32):
%552 = arith.divf %arg1, %arg2 : f32
%553 = arith.mulf %552, %arg3 : f32
%554 = arith.addf %553, %arg4 : f32
linalg.yield %554 : f32
} -> tensor<?x12x?x32xf32>
%485 = linalg.init_tensor [%331, %332, 12, 32] : tensor<?x?x12x32xf32>
%486 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%484 : tensor<?x12x?x32xf32>) outs(%485 : tensor<?x?x12x32xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x12x32xf32>
%487 = tensor.dim %486, %c0 : tensor<?x?x12x32xf32>
%488 = tensor.dim %486, %c1 : tensor<?x?x12x32xf32>
%489 = flow.tensor.reshape %486 : tensor<?x?x12x32xf32>{%487, %488} -> tensor<?x?x384xf32>{%487, %488}
%490 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_11 : tensor<384xf32>) outs(%417 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%491 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_10 : tensor<384x384xf32>) outs(%443 : tensor<?x384x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x384xf32>
%492 = linalg.batch_matmul ins(%489, %491 : tensor<?x?x384xf32>, tensor<?x384x384xf32>) outs(%490 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%493 = arith.cmpi eq, %409, %331 : index
cf.assert %493, "mismatched size for broadcast"
%494 = arith.cmpi eq, %410, %332 : index
cf.assert %494, "mismatched size for broadcast"
%495 = linalg.init_tensor [%409, %410, 384] : tensor<?x?x384xf32>
%496 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%492, %442 : tensor<?x?x384xf32>, tensor<?x?x384xf32>) outs(%495 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%552 = arith.addf %arg1, %arg2 : f32
linalg.yield %552 : f32
} -> tensor<?x?x384xf32>
%497 = linalg.init_tensor [%409, %410] : tensor<?x?xf32>
%498 = linalg.fill(%cst_105, %497) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%499 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%496 : tensor<?x?x384xf32>) outs(%498 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%552 = arith.addf %arg2, %arg1 : f32
linalg.yield %552 : f32
} -> tensor<?x?xf32>
%500 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%499 : tensor<?x?xf32>) outs(%497 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%552 = arith.divf %arg1, %cst : f32
linalg.yield %552 : f32
} -> tensor<?x?xf32>
%501 = linalg.fill(%cst_105, %497) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%502 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%496, %500 : tensor<?x?x384xf32>, tensor<?x?xf32>) outs(%501 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%552 = arith.subf %arg1, %arg2 : f32
%553 = arith.mulf %552, %552 : f32
%554 = arith.addf %arg3, %553 : f32
linalg.yield %554 : f32
} -> tensor<?x?xf32>
%503 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%496, %500, %502, %cst_8, %cst_9 : tensor<?x?x384xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<384xf32>, tensor<384xf32>) outs(%495 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%552 = arith.divf %arg3, %cst : f32
%553 = arith.subf %arg1, %arg2 : f32
%554 = arith.truncf %cst_100 : f64 to f32
%555 = arith.addf %552, %554 : f32
%556 = math.rsqrt %555 : f32
%557 = arith.mulf %553, %556 : f32
%558 = arith.mulf %557, %arg4 : f32
%559 = arith.addf %558, %arg5 : f32
linalg.yield %559 : f32
} -> tensor<?x?x384xf32>
%504 = linalg.init_tensor [%409, %410, 1536] : tensor<?x?x1536xf32>
%505 = linalg.init_tensor [%409, 384, 1536] : tensor<?x384x1536xf32>
%506 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_7 : tensor<1536xf32>) outs(%504 : tensor<?x?x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x1536xf32>
%507 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_6 : tensor<1536x384xf32>) outs(%505 : tensor<?x384x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384x1536xf32>
%508 = linalg.batch_matmul ins(%503, %507 : tensor<?x?x384xf32>, tensor<?x384x1536xf32>) outs(%506 : tensor<?x?x1536xf32>) -> tensor<?x?x1536xf32>
%509 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%508 : tensor<?x?x1536xf32>) outs(%504 : tensor<?x?x1536xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%552 = math.sqrt %cst_102 : f32
%553 = arith.divf %arg1, %552 : f32
%554 = math.erf %553 : f32
%555 = arith.addf %554, %cst_101 : f32
%556 = arith.mulf %555, %cst_103 : f32
%557 = arith.mulf %arg1, %556 : f32
linalg.yield %557 : f32
} -> tensor<?x?x1536xf32>
%510 = linalg.init_tensor [%409, 1536, 384] : tensor<?x1536x384xf32>
%511 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_5 : tensor<384xf32>) outs(%495 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x?x384xf32>
%512 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_4 : tensor<384x1536xf32>) outs(%510 : tensor<?x1536x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x1536x384xf32>
%513 = linalg.batch_matmul ins(%509, %512 : tensor<?x?x1536xf32>, tensor<?x1536x384xf32>) outs(%511 : tensor<?x?x384xf32>) -> tensor<?x?x384xf32>
%514 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%513, %503 : tensor<?x?x384xf32>, tensor<?x?x384xf32>) outs(%495 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%552 = arith.addf %arg1, %arg2 : f32
linalg.yield %552 : f32
} -> tensor<?x?x384xf32>
%515 = linalg.fill(%cst_105, %497) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%516 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%514 : tensor<?x?x384xf32>) outs(%515 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%552 = arith.addf %arg2, %arg1 : f32
linalg.yield %552 : f32
} -> tensor<?x?xf32>
%517 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%516 : tensor<?x?xf32>) outs(%497 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%552 = arith.divf %arg1, %cst : f32
linalg.yield %552 : f32
} -> tensor<?x?xf32>
%518 = linalg.fill(%cst_105, %497) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%519 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%514, %517 : tensor<?x?x384xf32>, tensor<?x?xf32>) outs(%518 : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%552 = arith.subf %arg1, %arg2 : f32
%553 = arith.mulf %552, %552 : f32
%554 = arith.addf %arg3, %553 : f32
linalg.yield %554 : f32
} -> tensor<?x?xf32>
%520 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%514, %517, %519, %cst_2, %cst_3 : tensor<?x?x384xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<384xf32>, tensor<384xf32>) outs(%495 : tensor<?x?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32):
%552 = arith.divf %arg3, %cst : f32
%553 = arith.subf %arg1, %arg2 : f32
%554 = arith.truncf %cst_100 : f64 to f32
%555 = arith.addf %552, %554 : f32
%556 = math.rsqrt %555 : f32
%557 = arith.mulf %553, %556 : f32
%558 = arith.mulf %557, %arg4 : f32
%559 = arith.addf %558, %arg5 : f32
linalg.yield %559 : f32
} -> tensor<?x?x384xf32>
%521 = arith.index_cast %409 : index to i64
%522 = arith.cmpi sgt, %c0_i64, %521 : i64
%523 = arith.select %522, %521, %c0_i64 : i64
%524 = arith.index_cast %523 : i64 to index
%525 = arith.cmpi sgt, %c9223372036854775807_i64, %521 : i64
%526 = arith.select %525, %521, %c9223372036854775807_i64 : i64
%527 = arith.index_cast %526 : i64 to index
%528 = arith.cmpi sge, %527, %524 : index
%529 = arith.select %528, %527, %524 : index
%530 = arith.subi %529, %524 : index
%531 = tensor.extract_slice %520[%524, 0, 0] [%530, %410, 384] [1, 1, 1] : tensor<?x?x384xf32> to tensor<?x?x384xf32>
%532 = arith.index_cast %410 : index to i64
%533 = arith.cmpi sgt, %c0_i64, %532 : i64
%534 = arith.select %533, %532, %c0_i64 : i64
%535 = arith.index_cast %534 : i64 to index
%536 = arith.cmpi sgt, %c1_i64, %532 : i64
%537 = arith.select %536, %532, %c1_i64 : i64
%538 = arith.index_cast %537 : i64 to index
%539 = arith.cmpi sge, %538, %535 : index
%540 = arith.select %539, %538, %535 : index
%541 = arith.subi %540, %535 : index
%542 = tensor.extract_slice %531[0, %535, 0] [%530, %541, 384] [1, 1, 1] : tensor<?x?x384xf32> to tensor<?x?x384xf32>
%543 = flow.tensor.reshape %542 : tensor<?x?x384xf32>{%530, %c1} -> tensor<?x384xf32>{%530}
%544 = linalg.init_tensor [%530, 384] : tensor<?x384xf32>
%545 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%cst_1 : tensor<384xf32>) outs(%544 : tensor<?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x384xf32>
%546 = linalg.matmul ins(%543, %cst_112 : tensor<?x384xf32>, tensor<384x384xf32>) outs(%545 : tensor<?x384xf32>) -> tensor<?x384xf32>
%547 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%546 : tensor<?x384xf32>) outs(%544 : tensor<?x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%552 = math.tanh %arg1 : f32
linalg.yield %552 : f32
} -> tensor<?x384xf32>
%548 = linalg.init_tensor [%530, 2] : tensor<?x2xf32>
%549 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%cst_0 : tensor<2xf32>) outs(%548 : tensor<?x2xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<?x2xf32>
%550 = linalg.matmul ins(%547, %cst_113 : tensor<?x384xf32>, tensor<384x2xf32>) outs(%549 : tensor<?x2xf32>) -> tensor<?x2xf32>
%551 = hal.tensor.export %550 : tensor<?x2xf32>{%530} -> !hal.buffer_view
return %551 : !hal.buffer_view
}
// -----// IR Dump Before Canonicalizer //----- //
func @forward(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
%cst = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x2xf32>
%cst_0 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_1 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1x384x384xf32>
%cst_2 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1x384x384xf32>
%cst_3 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1x384x384xf32>
%cst_4 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1x384x384xf32>
%false = arith.constant false
%c0_i64 = arith.constant 0 : i64
%cst_5 = arith.constant 1.000000e+00 : f32
%c9223372036854775807_i64 = arith.constant 9223372036854775807 : i64
%cst_6 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1x512xi64>
%cst_7 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<30522x384xf32>
%cst_8 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<2x384xf32>
%cst_9 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<512x384xf32>
%cst_10 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_11 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_12 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_13 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_14 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_15 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_16 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_17 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_18 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536xf32>
%cst_19 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536x384xf32>
%cst_20 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_21 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x1536xf32>
%cst_22 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_23 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_24 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_25 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_26 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_27 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_28 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_29 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_30 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_31 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_32 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_33 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_34 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536xf32>
%cst_35 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536x384xf32>
%cst_36 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_37 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x1536xf32>
%cst_38 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_39 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_40 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_41 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_42 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_43 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_44 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_45 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_46 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_47 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_48 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_49 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_50 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536xf32>
%cst_51 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536x384xf32>
%cst_52 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_53 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x1536xf32>
%cst_54 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_55 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_56 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_57 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_58 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_59 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_60 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_61 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_62 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_63 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_64 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_65 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_66 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536xf32>
%cst_67 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536x384xf32>
%cst_68 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_69 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x1536xf32>
%cst_70 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_71 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_72 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_73 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_74 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_75 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_76 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_77 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_78 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_79 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_80 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_81 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_82 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536xf32>
%cst_83 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536x384xf32>
%cst_84 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_85 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x1536xf32>
%cst_86 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_87 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_88 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_89 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_90 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_91 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_92 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_93 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_94 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_95 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x384xf32>
%cst_96 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_97 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_98 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536xf32>
%cst_99 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1536x384xf32>
%cst_100 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_101 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384x1536xf32>
%cst_102 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_103 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
%cst_104 = arith.constant opaque<"elided_large_const", "0xDEADBEEF"> : tensor<384xf32>
// Scalar constants (i64 / index / i1) referenced by the ops that follow.
%c512_i64 = arith.constant 512 : i64
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c512 = arith.constant 512 : index
%c384 = arith.constant 384 : index
%c12 = arith.constant 12 : index
%c32 = arith.constant 32 : index
%c1536 = arith.constant 1536 : index
%c2 = arith.constant 2 : index
%true = arith.constant true
// Bring the function argument in as a 1x512 tensor of i64.
%0 = hal.tensor.import %arg0 : !hal.buffer_view -> tensor<1x512xi64>
// The constant 512 is splatted into a rank-0 tensor (%1), read straight back
// as a scalar (%2) and cast to an index (%3).
%1 = flow.tensor.splat %c512_i64 : tensor<i64>
%2 = flow.tensor.load %1 : tensor<i64>
%3 = arith.index_cast %2 : i64 to index
// %4 only provides a 1x%3 shape; %5 reads its dynamic dim 1 back for the splat.
%4 = linalg.init_tensor [1, %3] : tensor<1x?xf32>
%5 = tensor.dim %4, %c1 : tensor<1x?xf32>
// Fill a 1x%5 f32 tensor with %cst_5. %cst_5 is defined above the visible part
// of this dump, so its value cannot be confirmed from here.
%6 = flow.tensor.splat %cst_5 : tensor<1x?xf32>{%5}
// Dispatch: take a rank-reducing slice of the 1x? input (%6) and write it out
// as a 1-D tensor. Workload is [%3, 1, 1].
// The slice length (%275) is re-derived inside the region from the rank-0 i64
// tensor %arg3. At this call site %arg3 is %1 and %arg2 is %3, which both come
// from the same splat of %c512_i64.
%7 = flow.dispatch.workgroups[%3, %c1, %c1](%6, %3, %1) : (tensor<1x?xf32>{%3}, index, tensor<i64>) -> tensor<?xf32>{%3} =
(%arg1: !flow.dispatch.tensor<readonly:1x?xf32>, %arg2: index, %arg3: !flow.dispatch.tensor<readonly:i64>, %arg4: !flow.dispatch.tensor<writeonly:?xf32>) {
// Attach the dynamic dimension %arg2 to the input and output bindings.
%270 = flow.dispatch.tie_shape %arg1 : !flow.dispatch.tensor<readonly:1x?xf32>{%arg2}
%271 = flow.dispatch.tie_shape %arg4 : !flow.dispatch.tensor<writeonly:?xf32>{%arg2}
%272 = flow.dispatch.tensor.load %270, offsets = [0, 0], sizes = [1, %arg2], strides = [1, 1] : !flow.dispatch.tensor<readonly:1x?xf32>{%arg2} -> tensor<1x?xf32>
// Read the scalar length from the rank-0 tensor and convert it to an index.
%273 = flow.dispatch.tensor.load %arg3, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<readonly:i64> -> tensor<i64>
%274 = tensor.extract %273[] : tensor<i64>
%275 = arith.index_cast %274 : i64 to index
// Drop the unit leading dimension: 1x? -> ? with %275 elements from offset 0.
%276 = tensor.extract_slice %272[0, 0] [1, %275] [1, 1] : tensor<1x?xf32> to tensor<?xf32>
// Note: the store is sized by %275 while %271 was tied to %arg2.
flow.dispatch.tensor.store %276, %271, offsets = [0], sizes = [%275], strides = [1] : tensor<?xf32> -> !flow.dispatch.tensor<writeonly:?xf32>{%275}
flow.return
}
// View the 1-D result as ?x1x1x? with dynamic dims (%c1, %3).
%8 = flow.tensor.reshape %7 : tensor<?xf32>{%3} -> tensor<?x1x1x?xf32>{%c1, %3}
// %10 is the smaller of %2 and 0 (the select picks %2 only when 0 > %2).
// %11 is that value as an index.
%9 = arith.cmpi sgt, %c0_i64, %2 : i64
%10 = arith.select %9, %2, %c0_i64 : i64
%11 = arith.index_cast %10 : i64 to index
// %13 is the smaller of %2 and the largest i64 value; %14 is that as an index.
%12 = arith.cmpi sgt, %c9223372036854775807_i64, %2 : i64
%13 = arith.select %12, %2, %c9223372036854775807_i64 : i64
%14 = arith.index_cast %13 : i64 to index
// %16 is the larger of %14 and %11, so %17 = %16 - %11 is non-negative.
// %17 is both the workload size and the result length of the dispatch below.
%15 = arith.cmpi sge, %14, %11 : index
%16 = arith.select %15, %14, %11 : index
%17 = arith.subi %16, %11 : index
// Dispatch: slice %arg5 elements out of the last dimension of the ?x1x1x?
// input (%8) and map every element x to (1.0 - x) * -1.0e4.
// NOTE(review): presumably this builds an additive attention mask; that
// reading is inferred from the arithmetic, confirm against the model.
%18 = flow.dispatch.workgroups[%17, %c1, %c1](%8, %c1, %3, %1, %17) : (tensor<?x1x1x?xf32>{%c1, %3}, index, index, tensor<i64>, index) -> tensor<?xf32>{%17} =
(%arg1: !flow.dispatch.tensor<readonly:?x1x1x?xf32>, %arg2: index, %arg3: index, %arg4: !flow.dispatch.tensor<readonly:i64>, %arg5: index, %arg6: !flow.dispatch.tensor<writeonly:?xf32>) {
%c0_i64_105 = arith.constant 0 : i64
%cst_106 = arith.constant -1.000000e+04 : f32
%cst_107 = arith.constant 1.000000e+00 : f32
// Attach dynamic dims: (%arg2, %arg3) for the input, %arg5 for the output.
%270 = flow.dispatch.tie_shape %arg1 : !flow.dispatch.tensor<readonly:?x1x1x?xf32>{%arg2, %arg3}
%271 = flow.dispatch.tie_shape %arg6 : !flow.dispatch.tensor<writeonly:?xf32>{%arg5}
%272 = flow.dispatch.tensor.load %270, offsets = [0, 0, 0, 0], sizes = [%arg2, 1, 1, %arg3], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:?x1x1x?xf32>{%arg2, %arg3} -> tensor<?x1x1x?xf32>
%273 = flow.dispatch.tensor.load %arg4, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<readonly:i64> -> tensor<i64>
%274 = linalg.init_tensor [%arg5] : tensor<?xf32>
// Recompute the slice start inside the region: the smaller of 0 and the
// scalar loaded from %arg4, as an index.
%275 = tensor.extract %273[] : tensor<i64>
%276 = arith.cmpi sgt, %c0_i64_105, %275 : i64
%277 = arith.select %276, %275, %c0_i64_105 : i64
%278 = arith.index_cast %277 : i64 to index
// Rank-reducing slice: %arg5 elements of the last dimension starting at %278.
%279 = tensor.extract_slice %272[0, 0, 0, %278] [1, 1, 1, %arg5] [1, 1, 1, 1] : tensor<?x1x1x?xf32> to tensor<?xf32>
%280 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%279 : tensor<?xf32>) outs(%274 : tensor<?xf32>) {
^bb0(%arg7: f32, %arg8: f32):
// (1.0 - x) * -10000.0; the outs element %arg8 is not read.
%281 = arith.subf %cst_107, %arg7 : f32
%282 = arith.mulf %281, %cst_106 : f32
linalg.yield %282 : f32
} -> tensor<?xf32>
flow.dispatch.tensor.store %280, %271, offsets = [0], sizes = [%arg5], strides = [1] : tensor<?xf32> -> !flow.dispatch.tensor<writeonly:?xf32>{%arg5}
flow.return
}
// Second splat/load round trip of the constant 512, giving the scalar %20.
%19 = flow.tensor.splat %c512_i64 : tensor<i64>
%20 = flow.tensor.load %19 : tensor<i64>
// Normalize %20 as a slice bound: add 512 when it is negative (%23), clamp to
// at least 0 (%25), then to at most 512 (%27).
%21 = arith.addi %20, %c512_i64 : i64
%22 = arith.cmpi sge, %20, %c0_i64 : i64
%23 = arith.select %22, %20, %21 : i64
%24 = arith.cmpi slt, %23, %c0_i64 : i64
%25 = arith.select %24, %c0_i64, %23 : i64
%26 = arith.cmpi sgt, %25, %c512_i64 : i64
%27 = arith.select %26, %c512_i64, %25 : i64
// Same bound as an index, clamped once more to be >= 0.
%28 = arith.index_cast %27 : i64 to index
%29 = arith.cmpi sge, %28, %c0 : index
%30 = arith.select %29, %28, %c0 : index
// Flatten the imported input to a 1-D tensor of 512 i64 values.
%31 = flow.tensor.reshape %0 : tensor<1x512xi64> -> tensor<512xi64>
// Runtime check that the computed length %30 equals 512.
%32 = arith.cmpi eq, %c512, %30 : index
cf.assert %32, "mismatched size for broadcast"
// Dispatch: for every (row d0, column d1) sum three table lookups into a
// 512x384 result. Rows are chosen by two i64 index vectors, column is d1.
// NOTE(review): the table shapes (30522x384, 512x384, 2x384) suggest
// word / position / token-type embeddings; this is inferred, not stated here.
%33 = flow.dispatch.workgroups[%c384, %c512, %c1](%cst_8, %cst_7, %cst_9, %31, %cst_6, %30) : (tensor<2x384xf32>, tensor<30522x384xf32>, tensor<512x384xf32>, tensor<512xi64>, tensor<1x512xi64>, index) -> tensor<512x384xf32> =
(%arg1: !flow.dispatch.tensor<readonly:2x384xf32>, %arg2: !flow.dispatch.tensor<readonly:30522x384xf32>, %arg3: !flow.dispatch.tensor<readonly:512x384xf32>, %arg4: !flow.dispatch.tensor<readonly:512xi64>, %arg5: !flow.dispatch.tensor<readonly:1x512xi64>, %arg6: index, %arg7: !flow.dispatch.tensor<writeonly:512x384xf32>) {
%c512_i64_105 = arith.constant 512 : i64
%c0_i64_106 = arith.constant 0 : i64
%c30522_i64 = arith.constant 30522 : i64
%c0_107 = arith.constant 0 : index
// Load the three f32 tables and the two i64 index tensors in full.
%270 = flow.dispatch.tensor.load %arg1, offsets = [0, 0], sizes = [2, 384], strides = [1, 1] : !flow.dispatch.tensor<readonly:2x384xf32> -> tensor<2x384xf32>
%271 = flow.dispatch.tensor.load %arg2, offsets = [0, 0], sizes = [30522, 384], strides = [1, 1] : !flow.dispatch.tensor<readonly:30522x384xf32> -> tensor<30522x384xf32>
%272 = flow.dispatch.tensor.load %arg3, offsets = [0, 0], sizes = [512, 384], strides = [1, 1] : !flow.dispatch.tensor<readonly:512x384xf32> -> tensor<512x384xf32>
%273 = flow.dispatch.tensor.load %arg4, offsets = [0], sizes = [512], strides = [1] : !flow.dispatch.tensor<readonly:512xi64> -> tensor<512xi64>
%274 = flow.dispatch.tensor.load %arg5, offsets = [0, 0], sizes = [1, 512], strides = [1, 1] : !flow.dispatch.tensor<readonly:1x512xi64> -> tensor<1x512xi64>
%275 = linalg.init_tensor [512, 384] : tensor<512x384xf32>
// First %arg6 entries of the 1x512 index tensor, rank-reduced to 1-D.
%276 = tensor.extract_slice %274[0, 0] [1, %arg6] [1, 1] : tensor<1x512xi64> to tensor<?xi64>
// Both index inputs are read along d0; the output is indexed by (d0, d1).
%277 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%273, %276 : tensor<512xi64>, tensor<?xi64>) outs(%275 : tensor<512x384xf32>) {
^bb0(%arg8: i64, %arg9: i64, %arg10: f32):
// %278 is the current column (d1).
%278 = linalg.index 1 : index
// Row 0 of the 2x384 table at this column.
%279 = tensor.extract %270[%c0_107, %278] : tensor<2x384xf32>
// Lookup in the 30522x384 table at row %arg8, asserting 0 <= %arg8 < 30522.
%280 = arith.index_cast %arg8 : i64 to index
%281 = arith.cmpi slt, %arg8, %c30522_i64 : i64
cf.assert %281, "index must be smaller than dim size"
%282 = arith.cmpi sge, %arg8, %c0_i64_106 : i64
cf.assert %282, "index must be larger or equal to 0"
%283 = tensor.extract %271[%280, %278] : tensor<30522x384xf32>
// Lookup in the 512x384 table at row %arg9, asserting 0 <= %arg9 < 512.
%284 = arith.index_cast %arg9 : i64 to index
%285 = arith.cmpi slt, %arg9, %c512_i64_105 : i64
cf.assert %285, "index must be smaller than dim size"
%286 = arith.cmpi sge, %arg9, %c0_i64_106 : i64
cf.assert %286, "index must be larger or equal to 0"
%287 = tensor.extract %272[%284, %278] : tensor<512x384xf32>
// Result element = (%283 + %279) + %287.
%288 = arith.addf %283, %279 : f32
%289 = arith.addf %288, %287 : f32
linalg.yield %289 : f32
} -> tensor<512x384xf32>
flow.dispatch.tensor.store %277, %arg7, offsets = [0, 0], sizes = [512, 384], strides = [1, 1] : tensor<512x384xf32> -> !flow.dispatch.tensor<writeonly:512x384xf32>
flow.return
}
// Dispatch: per-row mean of the 512x384 input (%33). Each row is summed over
// d1, then the sum is divided by 384.0. Result is a 512-element f32 tensor.
%34 = flow.dispatch.workgroups[%c512, %c1, %c1](%33) : (tensor<512x384xf32>) -> tensor<512xf32> =
(%arg1: !flow.dispatch.tensor<readonly:512x384xf32>, %arg2: !flow.dispatch.tensor<writeonly:512xf32>) {
%cst_105 = arith.constant 3.840000e+02 : f32
%cst_106 = arith.constant 0.000000e+00 : f32
%270 = flow.dispatch.tensor.load %arg1, offsets = [0, 0], sizes = [512, 384], strides = [1, 1] : !flow.dispatch.tensor<readonly:512x384xf32> -> tensor<512x384xf32>
// Zero-filled accumulator for the reduction.
%271 = linalg.init_tensor [512] : tensor<512xf32>
%272 = linalg.fill(%cst_106, %271) : f32, tensor<512xf32> -> tensor<512xf32>
// Row sum: d1 is the reduction dimension, %arg4 is the running total.
%273 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%270 : tensor<512x384xf32>) outs(%272 : tensor<512xf32>) {
^bb0(%arg3: f32, %arg4: f32):
%275 = arith.addf %arg4, %arg3 : f32
linalg.yield %275 : f32
} -> tensor<512xf32>
// Divide each row sum by 384.0. The outs operand is the unfilled %271 and its
// element (%arg4) is not read.
%274 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%273 : tensor<512xf32>) outs(%271 : tensor<512xf32>) {
^bb0(%arg3: f32, %arg4: f32):
%275 = arith.divf %arg3, %cst_105 : f32
linalg.yield %275 : f32
} -> tensor<512xf32>
flow.dispatch.tensor.store %274, %arg2, offsets = [0], sizes = [512], strides = [1] : tensor<512xf32> -> !flow.dispatch.tensor<writeonly:512xf32>
flow.return
}
// Dispatch: per-row sum of squared deviations. For each row d0 it accumulates
// (x[d0, d1] - m[d0])^2 over d1, where x is %33 and m is %34.
// The sum is not divided by the row length here.
%35 = flow.dispatch.workgroups[%c512, %c1, %c1](%33, %34) : (tensor<512x384xf32>, tensor<512xf32>) -> tensor<512xf32> =
(%arg1: !flow.dispatch.tensor<readonly:512x384xf32>, %arg2: !flow.dispatch.tensor<readonly:512xf32>, %arg3: !flow.dispatch.tensor<writeonly:512xf32>) {
%cst_105 = arith.constant 0.000000e+00 : f32
%270 = flow.dispatch.tensor.load %arg1, offsets = [0, 0], sizes = [512, 384], strides = [1, 1] : !flow.dispatch.tensor<readonly:512x384xf32> -> tensor<512x384xf32>
%271 = flow.dispatch.tensor.load %arg2, offsets = [0], sizes = [512], strides = [1] : !flow.dispatch.tensor<readonly:512xf32> -> tensor<512xf32>
// Zero-filled accumulator.
%272 = linalg.init_tensor [512] : tensor<512xf32>
%273 = linalg.fill(%cst_105, %272) : f32, tensor<512xf32> -> tensor<512xf32>
// d1 is the reduction dimension; %arg5 is the per-row value broadcast along d1.
%274 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%270, %271 : tensor<512x384xf32>, tensor<512xf32>) outs(%273 : tensor<512xf32>) {
^bb0(%arg4: f32, %arg5: f32, %arg6: f32):
// acc += (x - m) * (x - m)
%275 = arith.subf %arg4, %arg5 : f32
%276 = arith.mulf %275, %275 : f32
%277 = arith.addf %arg6, %276 : f32
linalg.yield %277 : f32
} -> tensor<512xf32>
flow.dispatch.tensor.store %274, %arg3, offsets = [0], sizes = [512], strides = [1] : tensor<512xf32> -> !flow.dispatch.tensor<writeonly:512xf32>
flow.return
}
// Dispatch: element-wise normalization of the 512x384 input (%33) using the
// per-row values %34 and %35 and two 384-element per-column vectors:
//   out[d0, d1] = (x[d0, d1] - a[d0]) * rsqrt(b[d0] / 384 + eps) * s[d1] + t[d1]
// with a = %34, b = %35, s = %cst_11, t = %cst_10, eps = 9.9999999999999998E-13
// truncated to f32.
// NOTE(review): this has the shape of layer normalization (s = scale,
// t = offset); the constants' roles are inferred, confirm against the model.
%36 = flow.dispatch.workgroups[%c384, %c512, %c1](%33, %34, %35, %cst_11, %cst_10) : (tensor<512x384xf32>, tensor<512xf32>, tensor<512xf32>, tensor<384xf32>, tensor<384xf32>) -> tensor<512x384xf32> =
(%arg1: !flow.dispatch.tensor<readonly:512x384xf32>, %arg2: !flow.dispatch.tensor<readonly:512xf32>, %arg3: !flow.dispatch.tensor<readonly:512xf32>, %arg4: !flow.dispatch.tensor<readonly:384xf32>, %arg5: !flow.dispatch.tensor<readonly:384xf32>, %arg6: !flow.dispatch.tensor<writeonly:512x384xf32>) {
%cst_105 = arith.constant 9.9999999999999998E-13 : f64
%cst_106 = arith.constant 3.840000e+02 : f32
%270 = flow.dispatch.tensor.load %arg1, offsets = [0, 0], sizes = [512, 384], strides = [1, 1] : !flow.dispatch.tensor<readonly:512x384xf32> -> tensor<512x384xf32>
%271 = flow.dispatch.tensor.load %arg2, offsets = [0], sizes = [512], strides = [1] : !flow.dispatch.tensor<readonly:512xf32> -> tensor<512xf32>
%272 = flow.dispatch.tensor.load %arg3, offsets = [0], sizes = [512], strides = [1] : !flow.dispatch.tensor<readonly:512xf32> -> tensor<512xf32>
%273 = flow.dispatch.tensor.load %arg4, offsets = [0], sizes = [384], strides = [1] : !flow.dispatch.tensor<readonly:384xf32> -> tensor<384xf32>
%274 = flow.dispatch.tensor.load %arg5, offsets = [0], sizes = [384], strides = [1] : !flow.dispatch.tensor<readonly:384xf32> -> tensor<384xf32>
%275 = linalg.init_tensor [512, 384] : tensor<512x384xf32>
// Per-row inputs are indexed by d0, per-column inputs by d1.
%276 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%270, %271, %272, %273, %274 : tensor<512x384xf32>, tensor<512xf32>, tensor<512xf32>, tensor<384xf32>, tensor<384xf32>) outs(%275 : tensor<512x384xf32>) {
^bb0(%arg7: f32, %arg8: f32, %arg9: f32, %arg10: f32, %arg11: f32, %arg12: f32):
// %277: squared-deviation sum divided by 384.
%277 = arith.divf %arg9, %cst_106 : f32
// %278: input minus the per-row value from %arg2.
%278 = arith.subf %arg7, %arg8 : f32
// The f64 epsilon is truncated to f32 inside the loop body.
%279 = arith.truncf %cst_105 : f64 to f32
%280 = arith.addf %277, %279 : f32
%281 = math.rsqrt %280 : f32
%282 = arith.mulf %278, %281 : f32
// Per-column multiply, then per-column add.
%283 = arith.mulf %282, %arg10 : f32
%284 = arith.addf %283, %arg11 : f32
linalg.yield %284 : f32
} -> tensor<512x384xf32>
flow.dispatch.tensor.store %276, %arg6, offsets = [0, 0], sizes = [512, 384], strides = [1, 1] : tensor<512x384xf32> -> !flow.dispatch.tensor<writeonly:512x384xf32>
flow.return
}
// Add a leading unit dimension: 512x384 -> 1x512x384.
%37 = flow.tensor.reshape %36 : tensor<512x384xf32> -> tensor<1x512x384xf32>
// Dispatch: broadcast the 384-element vector %cst_12 along d0, producing a
// 512x384 tensor whose every row equals the vector.
%38 = flow.dispatch.workgroups[%c384, %c512, %c1](%cst_12) : (tensor<384xf32>) -> tensor<512x384xf32> =
(%arg1: !flow.dispatch.tensor<readonly:384xf32>, %arg2: !flow.dispatch.tensor<writeonly:512x384xf32>) {
%270 = flow.dispatch.tensor.load %arg1, offsets = [0], sizes = [384], strides = [1] : !flow.dispatch.tensor<readonly:384xf32> -> tensor<384xf32>
%271 = linalg.init_tensor [512, 384] : tensor<512x384xf32>
// Input is indexed by d1 only; output by (d0, d1).
%272 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%270 : tensor<384xf32>) outs(%271 : tensor<512x384xf32>) {
^bb0(%arg3: f32, %arg4: f32):
linalg.yield %arg3 : f32
} -> tensor<512x384xf32>
flow.dispatch.tensor.store %272, %arg2, offsets = [0, 0], sizes = [512, 384], strides = [1, 1] : tensor<512x384xf32> -> !flow.dispatch.tensor<writeonly:512x384xf32>
flow.return
}
// Add a leading unit dimension: 512x384 -> 1x512x384.
%39 = flow.tensor.reshape %38 : tensor<512x384xf32> -> tensor<1x512x384xf32>
// Dispatch: batch matmul of %37 (1x512x384) by %cst_4 (1x384x384), using the
// broadcast tensor %39 as the outs operand. The result is tied to %39
// (`-> %39`) and %arg3 is bound readwrite, so the product is written in place.
// linalg.batch_matmul accumulates into outs, so %39 acts as an additive term.
%40 = flow.dispatch.workgroups[%c384, %c512, %c1](%37, %cst_4, %39) : (tensor<1x512x384xf32>, tensor<1x384x384xf32>, tensor<1x512x384xf32>) -> %39 =
(%arg1: !flow.dispatch.tensor<readonly:1x512x384xf32>, %arg2: !flow.dispatch.tensor<readonly:1x384x384xf32>, %arg3: !flow.dispatch.tensor<readwrite:1x512x384xf32>) {
%270 = flow.dispatch.tensor.load %arg1, offsets = [0, 0, 0], sizes = [1, 512, 384], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:1x512x384xf32> -> tensor<1x512x384xf32>
%271 = flow.dispatch.tensor.load %arg2, offsets = [0, 0, 0], sizes = [1, 384, 384], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:1x384x384xf32> -> tensor<1x384x384xf32>
// Current contents of the in-place buffer, used as the initial accumulator.
%272 = flow.dispatch.tensor.load %arg3, offsets = [0, 0, 0], sizes = [1, 512, 384], strides = [1, 1, 1] : !flow.dispatch.tensor<readwrite:1x512x384xf32> -> tensor<1x512x384xf32>
%273 = linalg.batch_matmul ins(%270, %271 : tensor<1x512x384xf32>, tensor<1x384x384xf32>) outs(%272 : tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
flow.dispatch.tensor.store %273, %arg3, offsets = [0, 0, 0], sizes = [1, 512, 384], strides = [1, 1, 1] : tensor<1x512x384xf32> -> !flow.dispatch.tensor<readwrite:1x512x384xf32>
flow.return
}
// Dispatch: broadcast the 384-element vector %cst_13 along d0, producing a
// 512x384 tensor whose every row equals the vector.
%41 = flow.dispatch.workgroups[%c384, %c512, %c1](%cst_13) : (tensor<384xf32>) -> tensor<512x384xf32> =
(%arg1: !flow.dispatch.tensor<readonly:384xf32>, %arg2: !flow.dispatch.tensor<writeonly:512x384xf32>) {
%270 = flow.dispatch.tensor.load %arg1, offsets = [0], sizes = [384], strides = [1] : !flow.dispatch.tensor<readonly:384xf32> -> tensor<384xf32>
%271 = linalg.init_tensor [512, 384] : tensor<512x384xf32>
// Input is indexed by d1 only; output by (d0, d1).
%272 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%270 : tensor<384xf32>) outs(%271 : tensor<512x384xf32>) {
^bb0(%arg3: f32, %arg4: f32):
linalg.yield %arg3 : f32
} -> tensor<512x384xf32>
flow.dispatch.tensor.store %272, %arg2, offsets = [0, 0], sizes = [512, 384], strides = [1, 1] : tensor<512x384xf32> -> !flow.dispatch.tensor<writeonly:512x384xf32>
flow.return
}
// Add a leading unit dimension: 512x384 -> 1x512x384.
%42 = flow.tensor.reshape %41 : tensor<512x384xf32> -> tensor<1x512x384xf32>
// Dispatch: batch matmul of %37 (1x512x384) by %cst_3 (1x384x384), using the
// broadcast tensor %42 as the outs operand. The result is tied to %42
// (`-> %42`) and %arg3 is bound readwrite, so the product is written in place.
// linalg.batch_matmul accumulates into outs, so %42 acts as an additive term.
%43 = flow.dispatch.workgroups[%c384, %c512, %c1](%37, %cst_3, %42) : (tensor<1x512x384xf32>, tensor<1x384x384xf32>, tensor<1x512x384xf32>) -> %42 =
(%arg1: !flow.dispatch.tensor<readonly:1x512x384xf32>, %arg2: !flow.dispatch.tensor<readonly:1x384x384xf32>, %arg3: !flow.dispatch.tensor<readwrite:1x512x384xf32>) {
%270 = flow.dispatch.tensor.load %arg1, offsets = [0, 0, 0], sizes = [1, 512, 384], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:1x512x384xf32> -> tensor<1x512x384xf32>
%271 = flow.dispatch.tensor.load %arg2, offsets = [0, 0, 0], sizes = [1, 384, 384], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:1x384x384xf32> -> tensor<1x384x384xf32>
// Current contents of the in-place buffer, used as the initial accumulator.
%272 = flow.dispatch.tensor.load %arg3, offsets = [0, 0, 0], sizes = [1, 512, 384], strides = [1, 1, 1] : !flow.dispatch.tensor<readwrite:1x512x384xf32> -> tensor<1x512x384xf32>
%273 = linalg.batch_matmul ins(%270, %271 : tensor<1x512x384xf32>, tensor<1x384x384xf32>) outs(%272 : tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
flow.dispatch.tensor.store %273, %arg3, offsets = [0, 0, 0], sizes = [1, 512, 384], strides = [1, 1, 1] : tensor<1x512x384xf32> -> !flow.dispatch.tensor<readwrite:1x512x384xf32>
flow.return
}
// Split the last dimension 384 into 12x32; the two leading dims become dynamic
// with values (%c1, %c512).
%44 = flow.tensor.reshape %43 : tensor<1x512x384xf32> -> tensor<?x?x12x32xf32>{%c1, %c512}
// Dispatch: broadcast the 384-element vector %cst_14 along d0, producing a
// 512x384 tensor whose every row equals the vector.
%45 = flow.dispatch.workgroups[%c384, %c512, %c1](%cst_14) : (tensor<384xf32>) -> tensor<512x384xf32> =
(%arg1: !flow.dispatch.tensor<readonly:384xf32>, %arg2: !flow.dispatch.tensor<writeonly:512x384xf32>) {
%270 = flow.dispatch.tensor.load %arg1, offsets = [0], sizes = [384], strides = [1] : !flow.dispatch.tensor<readonly:384xf32> -> tensor<384xf32>
%271 = linalg.init_tensor [512, 384] : tensor<512x384xf32>
// Input is indexed by d1 only; output by (d0, d1).
%272 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%270 : tensor<384xf32>) outs(%271 : tensor<512x384xf32>) {
^bb0(%arg3: f32, %arg4: f32):
linalg.yield %arg3 : f32
} -> tensor<512x384xf32>
flow.dispatch.tensor.store %272, %arg2, offsets = [0, 0], sizes = [512, 384], strides = [1, 1] : tensor<512x384xf32> -> !flow.dispatch.tensor<writeonly:512x384xf32>
flow.return
}
// Add a leading unit dimension: 512x384 -> 1x512x384.
%46 = flow.tensor.reshape %45 : tensor<512x384xf32> -> tensor<1x512x384xf32>
// Dispatch: batch matmul of %37 (1x512x384) by %cst_2 (1x384x384), using the
// broadcast tensor %46 as the outs operand. The result is tied to %46
// (`-> %46`) and %arg3 is bound readwrite, so the product is written in place.
// linalg.batch_matmul accumulates into outs, so %46 acts as an additive term.
%47 = flow.dispatch.workgroups[%c384, %c512, %c1](%37, %cst_2, %46) : (tensor<1x512x384xf32>, tensor<1x384x384xf32>, tensor<1x512x384xf32>) -> %46 =
(%arg1: !flow.dispatch.tensor<readonly:1x512x384xf32>, %arg2: !flow.dispatch.tensor<readonly:1x384x384xf32>, %arg3: !flow.dispatch.tensor<readwrite:1x512x384xf32>) {
%270 = flow.dispatch.tensor.load %arg1, offsets = [0, 0, 0], sizes = [1, 512, 384], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:1x512x384xf32> -> tensor<1x512x384xf32>
%271 = flow.dispatch.tensor.load %arg2, offsets = [0, 0, 0], sizes = [1, 384, 384], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:1x384x384xf32> -> tensor<1x384x384xf32>
// Current contents of the in-place buffer, used as the initial accumulator.
%272 = flow.dispatch.tensor.load %arg3, offsets = [0, 0, 0], sizes = [1, 512, 384], strides = [1, 1, 1] : !flow.dispatch.tensor<readwrite:1x512x384xf32> -> tensor<1x512x384xf32>
%273 = linalg.batch_matmul ins(%270, %271 : tensor<1x512x384xf32>, tensor<1x384x384xf32>) outs(%272 : tensor<1x512x384xf32>) -> tensor<1x512x384xf32>
flow.dispatch.tensor.store %273, %arg3, offsets = [0, 0, 0], sizes = [1, 512, 384], strides = [1, 1, 1] : tensor<1x512x384xf32> -> !flow.dispatch.tensor<readwrite:1x512x384xf32>
flow.return
}
// Split the last dimension 384 into 12x32 for this result (%47) and for the
// earlier matmul result %40; leading dims become dynamic with (%c1, %c512).
%48 = flow.tensor.reshape %47 : tensor<1x512x384xf32> -> tensor<?x?x12x32xf32>{%c1, %c512}
%49 = flow.tensor.reshape %40 : tensor<1x512x384xf32> -> tensor<?x?x12x32xf32>{%c1, %c512}
// Dispatch: permute %44 (?x?x12x32) into a static 12x512x32 tensor.
// Input element (d0, d1, d2, d3) is copied to output element (d2, d1, d3).
// d0 does not appear in the output map; at this call site its extent %arg2 is
// %c1 and %arg3 is %c512.
%50 = flow.dispatch.workgroups[%c32, %c12, %c512](%44, %c1, %c512) : (tensor<?x?x12x32xf32>{%c1, %c512}, index, index) -> tensor<12x512x32xf32> =
(%arg1: !flow.dispatch.tensor<readonly:?x?x12x32xf32>, %arg2: index, %arg3: index, %arg4: !flow.dispatch.tensor<writeonly:12x512x32xf32>) {
// Attach the two dynamic leading dims to the input binding.
%270 = flow.dispatch.tie_shape %arg1 : !flow.dispatch.tensor<readonly:?x?x12x32xf32>{%arg2, %arg3}
%271 = flow.dispatch.tensor.load %270, offsets = [0, 0, 0, 0], sizes = [%arg2, %arg3, 12, 32], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:?x?x12x32xf32>{%arg2, %arg3} -> tensor<?x?x12x32xf32>
%272 = linalg.init_tensor [12, 512, 32] : tensor<12x512x32xf32>
// Pure copy with a permuted output indexing map.
%273 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%271 : tensor<?x?x12x32xf32>) outs(%272 : tensor<12x512x32xf32>) {
^bb0(%arg5: f32, %arg6: f32):
linalg.yield %arg5 : f32
} -> tensor<12x512x32xf32>
flow.dispatch.tensor.store %273, %arg4, offsets = [0, 0, 0], sizes = [12, 512, 32], strides = [1, 1, 1] : tensor<12x512x32xf32> -> !flow.dispatch.tensor<writeonly:12x512x32xf32>
flow.return
}
// Dispatch: permute %49 (?x?x12x32) into a static 12x512x32 tensor.
// Input element (d0, d1, d2, d3) is copied to output element (d2, d1, d3).
// d0 does not appear in the output map; at this call site its extent %arg2 is
// %c1 and %arg3 is %c512.
%51 = flow.dispatch.workgroups[%c32, %c12, %c512](%49, %c1, %c512) : (tensor<?x?x12x32xf32>{%c1, %c512}, index, index) -> tensor<12x512x32xf32> =
(%arg1: !flow.dispatch.tensor<readonly:?x?x12x32xf32>, %arg2: index, %arg3: index, %arg4: !flow.dispatch.tensor<writeonly:12x512x32xf32>) {
// Attach the two dynamic leading dims to the input binding.
%270 = flow.dispatch.tie_shape %arg1 : !flow.dispatch.tensor<readonly:?x?x12x32xf32>{%arg2, %arg3}
%271 = flow.dispatch.tensor.load %270, offsets = [0, 0, 0, 0], sizes = [%arg2, %arg3, 12, 32], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:?x?x12x32xf32>{%arg2, %arg3} -> tensor<?x?x12x32xf32>
%272 = linalg.init_tensor [12, 512, 32] : tensor<12x512x32xf32>
// Pure copy with a permuted output indexing map.
%273 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%271 : tensor<?x?x12x32xf32>) outs(%272 : tensor<12x512x32xf32>) {
^bb0(%arg5: f32, %arg6: f32):
linalg.yield %arg5 : f32
} -> tensor<12x512x32xf32>
flow.dispatch.tensor.store %273, %arg4, offsets = [0, 0, 0], sizes = [12, 512, 32], strides = [1, 1, 1] : tensor<12x512x32xf32> -> !flow.dispatch.tensor<writeonly:12x512x32xf32>
flow.return
}
// Runtime check that the dynamic length %17 of %18 equals 512, so %18 can be
// read along the last (size-512) dimension below.
%52 = arith.cmpi eq, %c512, %17 : index
cf.assert %52, "mismatched size for broadcast"
// Dispatch, two steps over 12x512x512:
//   1. s[d0, d1, d2] = sum over d3 of a[d0, d1, d3] * b[d0, d2, d3]
//      with a = %51 and b = %50, both 12x512x32.
//   2. out[d0, d1, d2] = s[d0, d1, d2] / 5.6568542494923806 + m[d2]
//      with m = %18 (length %17).
// 5.6568542494923806 is numerically sqrt(32); 32 is the extent of the reduced
// dimension d3.
// NOTE(review): this looks like scaled dot-product attention scores plus an
// additive mask; the query/key roles of %51/%50 are inferred, confirm.
%53 = flow.dispatch.workgroups[%c512, %c512, %c12](%51, %50, %18, %17) : (tensor<12x512x32xf32>, tensor<12x512x32xf32>, tensor<?xf32>{%17}, index) -> tensor<12x512x512xf32> =
(%arg1: !flow.dispatch.tensor<readonly:12x512x32xf32>, %arg2: !flow.dispatch.tensor<readonly:12x512x32xf32>, %arg3: !flow.dispatch.tensor<readonly:?xf32>, %arg4: index, %arg5: !flow.dispatch.tensor<writeonly:12x512x512xf32>) {
%cst_105 = arith.constant 5.6568542494923806 : f64
%cst_106 = arith.constant 0.000000e+00 : f32
// Attach the dynamic length %arg4 to the 1-D input binding.
%270 = flow.dispatch.tie_shape %arg3 : !flow.dispatch.tensor<readonly:?xf32>{%arg4}
%271 = flow.dispatch.tensor.load %arg1, offsets = [0, 0, 0], sizes = [12, 512, 32], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:12x512x32xf32> -> tensor<12x512x32xf32>
%272 = flow.dispatch.tensor.load %arg2, offsets = [0, 0, 0], sizes = [12, 512, 32], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:12x512x32xf32> -> tensor<12x512x32xf32>
%273 = flow.dispatch.tensor.load %270, offsets = [0], sizes = [%arg4], strides = [1] : !flow.dispatch.tensor<readonly:?xf32>{%arg4} -> tensor<?xf32>
// Zero-filled accumulator for the contraction.
%274 = linalg.init_tensor [12, 512, 512] : tensor<12x512x512xf32>
%275 = linalg.fill(%cst_106, %274) : f32, tensor<12x512x512xf32> -> tensor<12x512x512xf32>
// Step 1: contraction over d3 (the only reduction iterator).
%276 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%271, %272 : tensor<12x512x32xf32>, tensor<12x512x32xf32>) outs(%275 : tensor<12x512x512xf32>) {
^bb0(%arg6: f32, %arg7: f32, %arg8: f32):
%278 = arith.mulf %arg6, %arg7 : f32
%279 = arith.addf %278, %arg8 : f32
linalg.yield %279 : f32
} -> tensor<12x512x512xf32>
// Step 2: element-wise scale and add; %273 is indexed by d2 only. The outs
// operand is the unfilled %274 and its element (%arg8) is not read.
%277 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%276, %273 : tensor<12x512x512xf32>, tensor<?xf32>) outs(%274 : tensor<12x512x512xf32>) {
^bb0(%arg6: f32, %arg7: f32, %arg8: f32):
// The f64 divisor is truncated to f32 inside the loop body.
%278 = arith.truncf %cst_105 : f64 to f32
%279 = arith.divf %arg6, %278 : f32
%280 = arith.addf %279, %arg7 : f32
linalg.yield %280 : f32
} -> tensor<12x512x512xf32>
flow.dispatch.tensor.store %277, %arg5, offsets = [0, 0, 0], sizes = [12, 512, 512], strides = [1, 1, 1] : tensor<12x512x512xf32> -> !flow.dispatch.tensor<writeonly:12x512x512xf32>
flow.return
}
%54:2 = flow.dispatch.workgroups[%c512, %c12, %c1](%53) : (tensor<12x512x512xf32>) -
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment