-
-
Save vivekkhandelwal1/5b07ce3c403b99dfda8ed64f5174595b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Affine-map aliases referenced by the `indexing_maps` attributes of the
// linalg.generic ops in this module. Each map translates the iteration-space
// indices (d0, d1, ...) of a generic op into the element indices of one
// operand. A literal 0 in the result list pins that operand dimension to
// index 0 regardless of the loop index; a map with an empty result list `()`
// addresses a rank-0 tensor.

// 2-D loop -> (0, d1): first operand dimension is always read at index 0.
#map = affine_map<(d0, d1) -> (0, d1)> | |
// 2-D identity map.
#map1 = affine_map<(d0, d1) -> (d0, d1)> | |
// 2-D loop -> 1-D operand indexed by d0 only (same element for every d1).
#map2 = affine_map<(d0, d1) -> (d0)> | |
// 4-D loop -> rank-0 operand (same scalar at every iteration point).
#map3 = affine_map<(d0, d1, d2, d3) -> ()> | |
// 4-D loop -> (0, 0, d2, d3): two leading operand dimensions pinned to 0.
#map4 = affine_map<(d0, d1, d2, d3) -> (0, 0, d2, d3)> | |
// 4-D identity map.
#map5 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> | |
// 2-D loop -> rank-0 operand.
#map6 = affine_map<(d0, d1) -> ()> | |
// 3-D loop -> (0, d1, d2): leading operand dimension pinned to 0.
#map7 = affine_map<(d0, d1, d2) -> (0, d1, d2)> | |
// 3-D identity map.
#map8 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> | |
// 3-D loop -> (d0, d1, 0): last operand dimension pinned to 0. Used below both
// as the output map of reductions over d2 and as the input map when copying a
// size-1 last dimension across d2.
#map9 = affine_map<(d0, d1, d2) -> (d0, d1, 0)> | |
// 3-D loop -> (0, d1, 0): first and last operand dimensions pinned to 0.
#map10 = affine_map<(d0, d1, d2) -> (0, d1, 0)> | |
// 3-D loop -> 1-D operand indexed by d2 only.
#map11 = affine_map<(d0, d1, d2) -> (d2)> | |
// 3-D loop -> 2-D operand indexed by (d1, d2); d0 is ignored.
#map12 = affine_map<(d0, d1, d2) -> (d1, d2)> | |
// 3-D loop -> rank-0 operand.
#map13 = affine_map<(d0, d1, d2) -> ()> | |
// 4-D permutation swapping d1 and d2.
#map14 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)> | |
// 3-D permutation swapping d1 and d2.
#map15 = affine_map<(d0, d1, d2) -> (d0, d2, d1)> | |
// 4-D loop -> (0, d1, d2, d3): leading operand dimension pinned to 0.
#map16 = affine_map<(d0, d1, d2, d3) -> (0, d1, d2, d3)> | |
// 3-D loop -> 2-D operand indexed by (d0, d1); d2 is ignored.
#map17 = affine_map<(d0, d1, d2) -> (d0, d1)> | |
// 2-D loop -> (d0, 0): last operand dimension pinned to 0.
#map18 = affine_map<(d0, d1) -> (d0, 0)> | |
// 2-D loop -> 1-D operand indexed by d1 only.
#map19 = affine_map<(d0, d1) -> (d1)> | |
// 2-D permutation swapping d0 and d1.
#map20 = affine_map<(d0, d1) -> (d1, d0)> | |
module { | |
ml_program.global private mutable @global_seed(dense<0> : tensor<i64>) : tensor<i64> | |
func.func @main_graph(%arg0: tensor<1x8xi64>) -> (tensor<1x8x50272xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>) { | |
%cst = arith.constant dense<1> : tensor<1x8xi64> | |
%cst_0 = arith.constant 0.000000e+00 : f32 | |
%cst_1 = arith.constant 0xFF800000 : f32 | |
%c2050_i64 = arith.constant 2050 : i64 | |
%c50272_i64 = arith.constant 50272 : i64 | |
%cst_2 = arith.constant 7.680000e+02 : f32 | |
%cst_3 = arith.constant 9.99999974E-6 : f32 | |
%cst_4 = arith.constant dense<1> : tensor<i64> | |
%cst_5 = arith.constant dense<2> : tensor<i64> | |
%cst_6 = arith.constant dense<1.250000e-01> : tensor<f32> | |
%cst_7 = arith.constant dense<1.000000e+00> : tensor<1x1x8x8xf32> | |
%cst_8 = arith.constant dense<-3.40282347E+38> : tensor<f32> | |
%cst_9 = arith.constant dense<1.000000e+00> : tensor<f32> | |
%cst_10 = arith.constant dense_resource<__elided__> : tensor<1x1x8x8xf32> | |
%cst_11 = arith.constant dense_resource<__elided__> : tensor<50272x768xf32> | |
%cst_12 = arith.constant dense_resource<__elided__> : tensor<2050x768xf32> | |
%cst_13 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_14 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_15 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_16 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_17 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_18 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_19 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_20 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_21 = arith.constant dense_resource<__elided__> : tensor<3072x768xf32> | |
%cst_22 = arith.constant dense_resource<__elided__> : tensor<3072xf32> | |
%cst_23 = arith.constant dense_resource<__elided__> : tensor<768x3072xf32> | |
%cst_24 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_25 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_26 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_27 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_28 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_29 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_30 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_31 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_32 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_33 = arith.constant dense_resource<__elided__> : tensor<3072x768xf32> | |
%cst_34 = arith.constant dense_resource<__elided__> : tensor<3072xf32> | |
%cst_35 = arith.constant dense_resource<__elided__> : tensor<768x3072xf32> | |
%cst_36 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_37 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_38 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_39 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_40 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_41 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_42 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_43 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_44 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_45 = arith.constant dense_resource<__elided__> : tensor<3072x768xf32> | |
%cst_46 = arith.constant dense_resource<__elided__> : tensor<3072xf32> | |
%cst_47 = arith.constant dense_resource<__elided__> : tensor<768x3072xf32> | |
%cst_48 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_49 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_50 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_51 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_52 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_53 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_54 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_55 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_56 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_57 = arith.constant dense_resource<__elided__> : tensor<3072x768xf32> | |
%cst_58 = arith.constant dense_resource<__elided__> : tensor<3072xf32> | |
%cst_59 = arith.constant dense_resource<__elided__> : tensor<768x3072xf32> | |
%cst_60 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_61 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_62 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_63 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_64 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_65 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_66 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_67 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_68 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_69 = arith.constant dense_resource<__elided__> : tensor<3072x768xf32> | |
%cst_70 = arith.constant dense_resource<__elided__> : tensor<3072xf32> | |
%cst_71 = arith.constant dense_resource<__elided__> : tensor<768x3072xf32> | |
%cst_72 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_73 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_74 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_75 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_76 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_77 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_78 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_79 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_80 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_81 = arith.constant dense_resource<__elided__> : tensor<3072x768xf32> | |
%cst_82 = arith.constant dense_resource<__elided__> : tensor<3072xf32> | |
%cst_83 = arith.constant dense_resource<__elided__> : tensor<768x3072xf32> | |
%cst_84 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_85 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_86 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_87 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_88 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_89 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_90 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_91 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_92 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_93 = arith.constant dense_resource<__elided__> : tensor<3072x768xf32> | |
%cst_94 = arith.constant dense_resource<__elided__> : tensor<3072xf32> | |
%cst_95 = arith.constant dense_resource<__elided__> : tensor<768x3072xf32> | |
%cst_96 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_97 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_98 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_99 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_100 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_101 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_102 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_103 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_104 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_105 = arith.constant dense_resource<__elided__> : tensor<3072x768xf32> | |
%cst_106 = arith.constant dense_resource<__elided__> : tensor<3072xf32> | |
%cst_107 = arith.constant dense_resource<__elided__> : tensor<768x3072xf32> | |
%cst_108 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_109 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_110 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_111 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_112 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_113 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_114 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_115 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_116 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_117 = arith.constant dense_resource<__elided__> : tensor<3072x768xf32> | |
%cst_118 = arith.constant dense_resource<__elided__> : tensor<3072xf32> | |
%cst_119 = arith.constant dense_resource<__elided__> : tensor<768x3072xf32> | |
%cst_120 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_121 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_122 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_123 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_124 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_125 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_126 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_127 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_128 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_129 = arith.constant dense_resource<__elided__> : tensor<3072x768xf32> | |
%cst_130 = arith.constant dense_resource<__elided__> : tensor<3072xf32> | |
%cst_131 = arith.constant dense_resource<__elided__> : tensor<768x3072xf32> | |
%cst_132 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_133 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_134 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_135 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_136 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_137 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_138 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_139 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_140 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_141 = arith.constant dense_resource<__elided__> : tensor<3072x768xf32> | |
%cst_142 = arith.constant dense_resource<__elided__> : tensor<3072xf32> | |
%cst_143 = arith.constant dense_resource<__elided__> : tensor<768x3072xf32> | |
%cst_144 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_145 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_146 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_147 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_148 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_149 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_150 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_151 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_152 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_153 = arith.constant dense_resource<__elided__> : tensor<3072x768xf32> | |
%cst_154 = arith.constant dense_resource<__elided__> : tensor<3072xf32> | |
%cst_155 = arith.constant dense_resource<__elided__> : tensor<768x3072xf32> | |
%cst_156 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_157 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_158 = arith.constant dense_resource<__elided__> : tensor<768xf32> | |
%cst_159 = arith.constant dense_resource<__elided__> : tensor<768x768xf32> | |
%cst_160 = arith.constant dense_resource<__elided__> : tensor<768x768xf32> | |
%cst_161 = arith.constant dense_resource<__elided__> : tensor<768x768xf32> | |
%cst_162 = arith.constant dense_resource<__elided__> : tensor<768x768xf32> | |
%cst_163 = arith.constant dense_resource<__elided__> : tensor<768x768xf32> | |
%cst_164 = arith.constant dense_resource<__elided__> : tensor<768x768xf32> | |
%cst_165 = arith.constant dense_resource<__elided__> : tensor<768x768xf32> | |
%cst_166 = arith.constant dense_resource<__elided__> : tensor<768x768xf32> | |
%cst_167 = arith.constant dense_resource<__elided__> : tensor<768x768xf32> | |
%cst_168 = arith.constant dense_resource<__elided__> : tensor<768x768xf32> | |
%cst_169 = arith.constant dense_resource<__elided__> : tensor<768x768xf32> | |
%cst_170 = arith.constant dense_resource<__elided__> : tensor<768x768xf32> | |
%cst_171 = arith.constant dense_resource<__elided__> : tensor<768x768xf32> | |
%cst_172 = arith.constant dense_resource<__elided__> : tensor<768x768xf32> | |
%cst_173 = arith.constant dense_resource<__elided__> : tensor<768x768xf32> | |
%cst_174 = arith.constant dense_resource<__elided__> : tensor<768x768xf32> | |
%cst_175 = arith.constant dense_resource<__elided__> : tensor<768x768xf32> | |
%cst_176 = arith.constant dense_resource<__elided__> : tensor<768x768xf32> | |
%cst_177 = arith.constant dense_resource<__elided__> : tensor<768x768xf32> | |
%cst_178 = arith.constant dense_resource<__elided__> : tensor<768x768xf32> | |
%cst_179 = arith.constant dense_resource<__elided__> : tensor<768x768xf32> | |
%cst_180 = arith.constant dense_resource<__elided__> : tensor<768x768xf32> | |
%cst_181 = arith.constant dense_resource<__elided__> : tensor<768x768xf32> | |
%cst_182 = arith.constant dense_resource<__elided__> : tensor<768x768xf32> | |
%cst_183 = arith.constant dense_resource<__elided__> : tensor<768x768xf32> | |
%cst_184 = arith.constant dense_resource<__elided__> : tensor<768x768xf32> | |
%cst_185 = arith.constant dense_resource<__elided__> : tensor<768x768xf32> | |
%cst_186 = arith.constant dense_resource<__elided__> : tensor<768x768xf32> | |
%cst_187 = arith.constant dense_resource<__elided__> : tensor<768x768xf32> | |
%cst_188 = arith.constant dense_resource<__elided__> : tensor<768x768xf32> | |
%cst_189 = arith.constant dense_resource<__elided__> : tensor<768x768xf32> | |
%cst_190 = arith.constant dense_resource<__elided__> : tensor<768x768xf32> | |
%cst_191 = arith.constant dense_resource<__elided__> : tensor<768x768xf32> | |
%cst_192 = arith.constant dense_resource<__elided__> : tensor<768x768xf32> | |
%cst_193 = arith.constant dense_resource<__elided__> : tensor<768x768xf32> | |
%cst_194 = arith.constant dense_resource<__elided__> : tensor<768x768xf32> | |
%cst_195 = arith.constant dense_resource<__elided__> : tensor<768x768xf32> | |
%cst_196 = arith.constant dense_resource<__elided__> : tensor<768x768xf32> | |
%cst_197 = arith.constant dense_resource<__elided__> : tensor<768x768xf32> | |
%cst_198 = arith.constant dense_resource<__elided__> : tensor<768x768xf32> | |
%cst_199 = arith.constant dense_resource<__elided__> : tensor<768x768xf32> | |
%cst_200 = arith.constant dense_resource<__elided__> : tensor<768x768xf32> | |
%cst_201 = arith.constant dense_resource<__elided__> : tensor<768x768xf32> | |
%cst_202 = arith.constant dense_resource<__elided__> : tensor<768x768xf32> | |
%cst_203 = arith.constant dense_resource<__elided__> : tensor<768x768xf32> | |
%cst_204 = arith.constant dense_resource<__elided__> : tensor<768x768xf32> | |
%cst_205 = arith.constant dense_resource<__elided__> : tensor<768x768xf32> | |
%cst_206 = arith.constant dense_resource<__elided__> : tensor<768x768xf32> | |
%cst_207 = arith.constant dense_resource<__elided__> : tensor<768x50272xf32> | |
%c0_i64 = arith.constant 0 : i64 | |
%0 = tensor.empty() : tensor<1x8xi1> | |
%1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<1x8xi64>) outs(%0 : tensor<1x8xi1>) { | |
^bb0(%in: i64, %out: i1): | |
%839 = arith.cmpi slt, %in, %c0_i64 : i64 | |
linalg.yield %839 : i1 | |
} -> tensor<1x8xi1> | |
%2 = tensor.empty() : tensor<1x8xi64> | |
%3 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<1x8xi64>) outs(%2 : tensor<1x8xi64>) { | |
^bb0(%in: i64, %out: i64): | |
%839 = arith.addi %in, %c50272_i64 : i64 | |
linalg.yield %839 : i64 | |
} -> tensor<1x8xi64> | |
%4 = linalg.generic {indexing_maps = [#map, #map, #map, #map1], iterator_types = ["parallel", "parallel"]} ins(%1, %3, %arg0 : tensor<1x8xi1>, tensor<1x8xi64>, tensor<1x8xi64>) outs(%2 : tensor<1x8xi64>) { | |
^bb0(%in: i1, %in_366: i64, %in_367: i64, %out: i64): | |
%839 = arith.select %in, %in_366, %in_367 : i64 | |
linalg.yield %839 : i64 | |
} -> tensor<1x8xi64> | |
%collapsed = tensor.collapse_shape %4 [[0, 1]] : tensor<1x8xi64> into tensor<8xi64> | |
%5 = tensor.empty() : tensor<8x768xf32> | |
%6 = linalg.generic {indexing_maps = [#map2, #map1], iterator_types = ["parallel", "parallel"]} ins(%collapsed : tensor<8xi64>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: i64, %out: f32): | |
%839 = arith.index_cast %in : i64 to index | |
%840 = linalg.index 1 : index | |
%extracted = tensor.extract %cst_11[%839, %840] : tensor<50272x768xf32> | |
linalg.yield %extracted : f32 | |
} -> tensor<8x768xf32> | |
%expanded = tensor.expand_shape %6 [[0, 1], [2]] output_shape [1, 8, 768] : tensor<8x768xf32> into tensor<1x8x768xf32> | |
%7 = tensor.empty() : tensor<1x1x8x8xf32> | |
%8 = linalg.generic {indexing_maps = [#map3, #map4, #map5], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%cst_9, %cst_7 : tensor<f32>, tensor<1x1x8x8xf32>) outs(%7 : tensor<1x1x8x8xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.subf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x1x8x8xf32> | |
%9 = tensor.empty() : tensor<1x1x8x8xi1> | |
%10 = linalg.generic {indexing_maps = [#map4, #map5], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%8 : tensor<1x1x8x8xf32>) outs(%9 : tensor<1x1x8x8xi1>) { | |
^bb0(%in: f32, %out: i1): | |
%839 = arith.cmpf une, %in, %cst_0 : f32 | |
linalg.yield %839 : i1 | |
} -> tensor<1x1x8x8xi1> | |
%11 = linalg.generic {indexing_maps = [#map4, #map3, #map4, #map5], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%10, %cst_8, %8 : tensor<1x1x8x8xi1>, tensor<f32>, tensor<1x1x8x8xf32>) outs(%7 : tensor<1x1x8x8xf32>) { | |
^bb0(%in: i1, %in_366: f32, %in_367: f32, %out: f32): | |
%839 = arith.select %in, %in_366, %in_367 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x1x8x8xf32> | |
%12 = linalg.generic {indexing_maps = [#map4, #map5], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%11 : tensor<1x1x8x8xf32>) outs(%9 : tensor<1x1x8x8xi1>) { | |
^bb0(%in: f32, %out: i1): | |
%839 = arith.cmpf une, %in, %cst_0 : f32 | |
linalg.yield %839 : i1 | |
} -> tensor<1x1x8x8xi1> | |
%13 = linalg.generic {indexing_maps = [#map4, #map3, #map4, #map5], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%12, %cst_8, %cst_10 : tensor<1x1x8x8xi1>, tensor<f32>, tensor<1x1x8x8xf32>) outs(%7 : tensor<1x1x8x8xf32>) { | |
^bb0(%in: i1, %in_366: f32, %in_367: f32, %out: f32): | |
%839 = arith.select %in, %in_366, %in_367 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x1x8x8xf32> | |
%14 = linalg.fill ins(%c0_i64 : i64) outs(%2 : tensor<1x8xi64>) -> tensor<1x8xi64> | |
%15 = tensor.empty() : tensor<1xi64> | |
%16 = linalg.fill ins(%c0_i64 : i64) outs(%15 : tensor<1xi64>) -> tensor<1xi64> | |
%17:2 = tm_tensor.scan dimension(1) inclusive(true) ins(%cst : tensor<1x8xi64>) outs(%14, %16 : tensor<1x8xi64>, tensor<1xi64>) { | |
^bb0(%arg1: i64, %arg2: i64): | |
%839 = arith.addi %arg1, %arg2 : i64 | |
tm_tensor.yield %839 : i64 | |
} -> tensor<1x8xi64>, tensor<1xi64> | |
%18 = linalg.generic {indexing_maps = [#map, #map, #map1], iterator_types = ["parallel", "parallel"]} ins(%17#0, %cst : tensor<1x8xi64>, tensor<1x8xi64>) outs(%2 : tensor<1x8xi64>) { | |
^bb0(%in: i64, %in_366: i64, %out: i64): | |
%839 = arith.muli %in, %in_366 : i64 | |
linalg.yield %839 : i64 | |
} -> tensor<1x8xi64> | |
%19 = linalg.generic {indexing_maps = [#map, #map6, #map1], iterator_types = ["parallel", "parallel"]} ins(%18, %cst_4 : tensor<1x8xi64>, tensor<i64>) outs(%2 : tensor<1x8xi64>) { | |
^bb0(%in: i64, %in_366: i64, %out: i64): | |
%839 = arith.subi %in, %in_366 : i64 | |
linalg.yield %839 : i64 | |
} -> tensor<1x8xi64> | |
%20 = linalg.generic {indexing_maps = [#map, #map6, #map1], iterator_types = ["parallel", "parallel"]} ins(%19, %cst_5 : tensor<1x8xi64>, tensor<i64>) outs(%2 : tensor<1x8xi64>) { | |
^bb0(%in: i64, %in_366: i64, %out: i64): | |
%839 = arith.addi %in, %in_366 : i64 | |
linalg.yield %839 : i64 | |
} -> tensor<1x8xi64> | |
%21 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel"]} ins(%20 : tensor<1x8xi64>) outs(%0 : tensor<1x8xi1>) { | |
^bb0(%in: i64, %out: i1): | |
%839 = arith.cmpi slt, %in, %c0_i64 : i64 | |
linalg.yield %839 : i1 | |
} -> tensor<1x8xi1> | |
%22 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel"]} ins(%20 : tensor<1x8xi64>) outs(%2 : tensor<1x8xi64>) { | |
^bb0(%in: i64, %out: i64): | |
%839 = arith.addi %in, %c2050_i64 : i64 | |
linalg.yield %839 : i64 | |
} -> tensor<1x8xi64> | |
%23 = linalg.generic {indexing_maps = [#map, #map, #map, #map1], iterator_types = ["parallel", "parallel"]} ins(%21, %22, %20 : tensor<1x8xi1>, tensor<1x8xi64>, tensor<1x8xi64>) outs(%2 : tensor<1x8xi64>) { | |
^bb0(%in: i1, %in_366: i64, %in_367: i64, %out: i64): | |
%839 = arith.select %in, %in_366, %in_367 : i64 | |
linalg.yield %839 : i64 | |
} -> tensor<1x8xi64> | |
%collapsed_208 = tensor.collapse_shape %23 [[0, 1]] : tensor<1x8xi64> into tensor<8xi64> | |
%24 = linalg.generic {indexing_maps = [#map2, #map1], iterator_types = ["parallel", "parallel"]} ins(%collapsed_208 : tensor<8xi64>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: i64, %out: f32): | |
%839 = arith.index_cast %in : i64 to index | |
%840 = linalg.index 1 : index | |
%extracted = tensor.extract %cst_12[%839, %840] : tensor<2050x768xf32> | |
linalg.yield %extracted : f32 | |
} -> tensor<8x768xf32> | |
%expanded_209 = tensor.expand_shape %24 [[0, 1], [2]] output_shape [1, 8, 768] : tensor<8x768xf32> into tensor<1x8x768xf32> | |
%25 = tensor.empty() : tensor<1x8x768xf32> | |
%26 = linalg.generic {indexing_maps = [#map7, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded, %expanded_209 : tensor<1x8x768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%27 = tensor.empty() : tensor<1x8x1xf32> | |
%28 = linalg.fill ins(%cst_0 : f32) outs(%27 : tensor<1x8x1xf32>) -> tensor<1x8x1xf32> | |
%29 = linalg.generic {indexing_maps = [#map8, #map9], iterator_types = ["parallel", "parallel", "reduction"]} ins(%26 : tensor<1x8x768xf32>) outs(%28 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %out : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
%30 = linalg.generic {indexing_maps = [#map10, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%29 : tensor<1x8x1xf32>) outs(%27 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.divf %in, %cst_2 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
%31 = linalg.generic {indexing_maps = [#map9, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%30 : tensor<1x8x1xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x8x768xf32> | |
%32 = linalg.generic {indexing_maps = [#map7, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%26, %31 : tensor<1x8x768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.subf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%33 = linalg.generic {indexing_maps = [#map7, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%32, %32 : tensor<1x8x768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%34 = linalg.generic {indexing_maps = [#map8, #map9], iterator_types = ["parallel", "parallel", "reduction"]} ins(%33 : tensor<1x8x768xf32>) outs(%28 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %out : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
%35 = linalg.generic {indexing_maps = [#map10, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%34 : tensor<1x8x1xf32>) outs(%27 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.divf %in, %cst_2 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
%36 = linalg.generic {indexing_maps = [#map10, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%35 : tensor<1x8x1xf32>) outs(%27 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %cst_3 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
%37 = linalg.generic {indexing_maps = [#map10, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%36 : tensor<1x8x1xf32>) outs(%27 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = math.rsqrt %in : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
%38 = linalg.generic {indexing_maps = [#map9, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%37 : tensor<1x8x1xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x8x768xf32> | |
%39 = linalg.generic {indexing_maps = [#map7, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%32, %38 : tensor<1x8x768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%40 = linalg.generic {indexing_maps = [#map7, #map11, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%39, %cst_19 : tensor<1x8x768xf32>, tensor<768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%41 = linalg.generic {indexing_maps = [#map7, #map11, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%40, %cst_20 : tensor<1x8x768xf32>, tensor<768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%42 = linalg.generic {indexing_maps = [#map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%41 : tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x8x768xf32> | |
%43 = tensor.empty() : tensor<1x768x768xf32> | |
%44 = linalg.generic {indexing_maps = [#map12, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_159 : tensor<768x768xf32>) outs(%43 : tensor<1x768x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x768x768xf32> | |
%45 = linalg.fill ins(%cst_0 : f32) outs(%25 : tensor<1x8x768xf32>) -> tensor<1x8x768xf32> | |
%46 = linalg.batch_matmul ins(%42, %44 : tensor<1x8x768xf32>, tensor<1x768x768xf32>) outs(%45 : tensor<1x8x768xf32>) -> tensor<1x8x768xf32> | |
%47 = linalg.generic {indexing_maps = [#map11, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_17, %46 : tensor<768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%48 = linalg.generic {indexing_maps = [#map7, #map13, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%47, %cst_6 : tensor<1x8x768xf32>, tensor<f32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%49 = linalg.generic {indexing_maps = [#map12, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_160 : tensor<768x768xf32>) outs(%43 : tensor<1x768x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x768x768xf32> | |
%50 = linalg.batch_matmul ins(%42, %49 : tensor<1x8x768xf32>, tensor<1x768x768xf32>) outs(%45 : tensor<1x8x768xf32>) -> tensor<1x8x768xf32> | |
%51 = linalg.generic {indexing_maps = [#map11, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_15, %50 : tensor<768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%expanded_210 = tensor.expand_shape %51 [[0], [1], [2, 3]] output_shape [1, 8, 12, 64] : tensor<1x8x768xf32> into tensor<1x8x12x64xf32> | |
%52 = tensor.empty() : tensor<1x12x8x64xf32> | |
%53 = linalg.generic {indexing_maps = [#map5, #map14], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_210 : tensor<1x8x12x64xf32>) outs(%52 : tensor<1x12x8x64xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x12x8x64xf32> | |
%54 = linalg.generic {indexing_maps = [#map12, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_161 : tensor<768x768xf32>) outs(%43 : tensor<1x768x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x768x768xf32> | |
%55 = linalg.batch_matmul ins(%42, %54 : tensor<1x8x768xf32>, tensor<1x768x768xf32>) outs(%45 : tensor<1x8x768xf32>) -> tensor<1x8x768xf32> | |
%56 = linalg.generic {indexing_maps = [#map11, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_16, %55 : tensor<768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%expanded_211 = tensor.expand_shape %56 [[0], [1], [2, 3]] output_shape [1, 8, 12, 64] : tensor<1x8x768xf32> into tensor<1x8x12x64xf32> | |
%57 = linalg.generic {indexing_maps = [#map5, #map14], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_211 : tensor<1x8x12x64xf32>) outs(%52 : tensor<1x12x8x64xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x12x8x64xf32> | |
%expanded_212 = tensor.expand_shape %48 [[0], [1], [2, 3]] output_shape [1, 8, 12, 64] : tensor<1x8x768xf32> into tensor<1x8x12x64xf32> | |
%58 = linalg.generic {indexing_maps = [#map5, #map14], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_212 : tensor<1x8x12x64xf32>) outs(%52 : tensor<1x12x8x64xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x12x8x64xf32> | |
%collapsed_213 = tensor.collapse_shape %58 [[0, 1], [2], [3]] : tensor<1x12x8x64xf32> into tensor<12x8x64xf32> | |
%collapsed_214 = tensor.collapse_shape %53 [[0, 1], [2], [3]] : tensor<1x12x8x64xf32> into tensor<12x8x64xf32> | |
%collapsed_215 = tensor.collapse_shape %57 [[0, 1], [2], [3]] : tensor<1x12x8x64xf32> into tensor<12x8x64xf32> | |
%59 = tensor.empty() : tensor<12x64x8xf32> | |
%60 = linalg.generic {indexing_maps = [#map8, #map15], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_214 : tensor<12x8x64xf32>) outs(%59 : tensor<12x64x8xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<12x64x8xf32> | |
%61 = tensor.empty() : tensor<12x8x8xf32> | |
%62 = linalg.fill ins(%cst_0 : f32) outs(%61 : tensor<12x8x8xf32>) -> tensor<12x8x8xf32> | |
%63 = linalg.batch_matmul ins(%collapsed_213, %60 : tensor<12x8x64xf32>, tensor<12x64x8xf32>) outs(%62 : tensor<12x8x8xf32>) -> tensor<12x8x8xf32> | |
%expanded_216 = tensor.expand_shape %63 [[0, 1], [2], [3]] output_shape [1, 12, 8, 8] : tensor<12x8x8xf32> into tensor<1x12x8x8xf32> | |
%64 = tensor.empty() : tensor<1x12x8x8xf32> | |
%65 = linalg.generic {indexing_maps = [#map16, #map4, #map5], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_216, %13 : tensor<1x12x8x8xf32>, tensor<1x1x8x8xf32>) outs(%64 : tensor<1x12x8x8xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x12x8x8xf32> | |
%66 = linalg.generic {indexing_maps = [#map16, #map3, #map5], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%65, %cst_8 : tensor<1x12x8x8xf32>, tensor<f32>) outs(%64 : tensor<1x12x8x8xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.cmpf ogt, %in, %in_366 : f32 | |
%840 = arith.select %839, %in, %in_366 : f32 | |
linalg.yield %840 : f32 | |
} -> tensor<1x12x8x8xf32> | |
%collapsed_217 = tensor.collapse_shape %66 [[0, 1], [2], [3]] : tensor<1x12x8x8xf32> into tensor<12x8x8xf32> | |
%67 = tensor.empty() : tensor<12x8xi64> | |
%68 = linalg.fill ins(%c0_i64 : i64) outs(%67 : tensor<12x8xi64>) -> tensor<12x8xi64> | |
%69 = tensor.empty() : tensor<12x8xf32> | |
%70 = linalg.fill ins(%cst_1 : f32) outs(%69 : tensor<12x8xf32>) -> tensor<12x8xf32> | |
%71:2 = linalg.generic {indexing_maps = [#map8, #map17, #map17], iterator_types = ["parallel", "parallel", "reduction"]} ins(%collapsed_217 : tensor<12x8x8xf32>) outs(%70, %68 : tensor<12x8xf32>, tensor<12x8xi64>) { | |
^bb0(%in: f32, %out: f32, %out_366: i64): | |
%839 = linalg.index 2 : index | |
%840 = arith.index_cast %839 : index to i64 | |
%841 = arith.maximumf %in, %out : f32 | |
%842 = arith.cmpf ogt, %in, %out : f32 | |
%843 = arith.select %842, %840, %out_366 : i64 | |
linalg.yield %841, %843 : f32, i64 | |
} -> (tensor<12x8xf32>, tensor<12x8xi64>) | |
%expanded_218 = tensor.expand_shape %71#0 [[0], [1, 2]] output_shape [12, 8, 1] : tensor<12x8xf32> into tensor<12x8x1xf32> | |
%72 = linalg.generic {indexing_maps = [#map8, #map9, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_217, %expanded_218 : tensor<12x8x8xf32>, tensor<12x8x1xf32>) outs(%61 : tensor<12x8x8xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.subf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<12x8x8xf32> | |
%73 = linalg.generic {indexing_maps = [#map8, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%72 : tensor<12x8x8xf32>) outs(%61 : tensor<12x8x8xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = math.exp %in : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<12x8x8xf32> | |
%74 = tensor.empty() : tensor<12x8x1xf32> | |
%75 = linalg.fill ins(%cst_0 : f32) outs(%74 : tensor<12x8x1xf32>) -> tensor<12x8x1xf32> | |
%76 = linalg.generic {indexing_maps = [#map8, #map9], iterator_types = ["parallel", "parallel", "reduction"]} ins(%73 : tensor<12x8x8xf32>) outs(%75 : tensor<12x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %out : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<12x8x1xf32> | |
%77 = linalg.generic {indexing_maps = [#map8, #map9, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%73, %76 : tensor<12x8x8xf32>, tensor<12x8x1xf32>) outs(%61 : tensor<12x8x8xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.divf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<12x8x8xf32> | |
%78 = tensor.empty() : tensor<12x8x64xf32> | |
%79 = linalg.fill ins(%cst_0 : f32) outs(%78 : tensor<12x8x64xf32>) -> tensor<12x8x64xf32> | |
%80 = linalg.batch_matmul ins(%77, %collapsed_215 : tensor<12x8x8xf32>, tensor<12x8x64xf32>) outs(%79 : tensor<12x8x64xf32>) -> tensor<12x8x64xf32> | |
%expanded_219 = tensor.expand_shape %80 [[0, 1], [2], [3]] output_shape [1, 12, 8, 64] : tensor<12x8x64xf32> into tensor<1x12x8x64xf32> | |
%81 = tensor.empty() : tensor<1x8x12x64xf32> | |
%82 = linalg.generic {indexing_maps = [#map5, #map14], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_219 : tensor<1x12x8x64xf32>) outs(%81 : tensor<1x8x12x64xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x8x12x64xf32> | |
%collapsed_220 = tensor.collapse_shape %82 [[0], [1], [2, 3]] : tensor<1x8x12x64xf32> into tensor<1x8x768xf32> | |
%83 = linalg.generic {indexing_maps = [#map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_220 : tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x8x768xf32> | |
%84 = linalg.generic {indexing_maps = [#map12, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_162 : tensor<768x768xf32>) outs(%43 : tensor<1x768x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x768x768xf32> | |
%85 = linalg.batch_matmul ins(%83, %84 : tensor<1x8x768xf32>, tensor<1x768x768xf32>) outs(%45 : tensor<1x8x768xf32>) -> tensor<1x8x768xf32> | |
%86 = linalg.generic {indexing_maps = [#map11, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_18, %85 : tensor<768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%87 = linalg.generic {indexing_maps = [#map7, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%26, %86 : tensor<1x8x768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%collapsed_221 = tensor.collapse_shape %87 [[0, 1], [2]] : tensor<1x8x768xf32> into tensor<8x768xf32> | |
%88 = tensor.empty() : tensor<8x1xf32> | |
%89 = linalg.fill ins(%cst_0 : f32) outs(%88 : tensor<8x1xf32>) -> tensor<8x1xf32> | |
%90 = linalg.generic {indexing_maps = [#map1, #map18], iterator_types = ["parallel", "reduction"]} ins(%collapsed_221 : tensor<8x768xf32>) outs(%89 : tensor<8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %out : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x1xf32> | |
%91 = linalg.generic {indexing_maps = [#map18, #map1], iterator_types = ["parallel", "parallel"]} ins(%90 : tensor<8x1xf32>) outs(%88 : tensor<8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.divf %in, %cst_2 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x1xf32> | |
%92 = linalg.generic {indexing_maps = [#map18, #map1], iterator_types = ["parallel", "parallel"]} ins(%91 : tensor<8x1xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<8x768xf32> | |
%93 = linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%collapsed_221, %92 : tensor<8x768xf32>, tensor<8x768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.subf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
%94 = linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%93, %93 : tensor<8x768xf32>, tensor<8x768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
%95 = linalg.generic {indexing_maps = [#map1, #map18], iterator_types = ["parallel", "reduction"]} ins(%94 : tensor<8x768xf32>) outs(%89 : tensor<8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %out : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x1xf32> | |
%96 = linalg.generic {indexing_maps = [#map18, #map1], iterator_types = ["parallel", "parallel"]} ins(%95 : tensor<8x1xf32>) outs(%88 : tensor<8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.divf %in, %cst_2 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x1xf32> | |
%97 = linalg.generic {indexing_maps = [#map18, #map1], iterator_types = ["parallel", "parallel"]} ins(%96 : tensor<8x1xf32>) outs(%88 : tensor<8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %cst_3 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x1xf32> | |
%98 = linalg.generic {indexing_maps = [#map18, #map1], iterator_types = ["parallel", "parallel"]} ins(%97 : tensor<8x1xf32>) outs(%88 : tensor<8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = math.rsqrt %in : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x1xf32> | |
%99 = linalg.generic {indexing_maps = [#map18, #map1], iterator_types = ["parallel", "parallel"]} ins(%98 : tensor<8x1xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<8x768xf32> | |
%100 = linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%93, %99 : tensor<8x768xf32>, tensor<8x768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
%101 = linalg.generic {indexing_maps = [#map1, #map19, #map1], iterator_types = ["parallel", "parallel"]} ins(%100, %cst_25 : tensor<8x768xf32>, tensor<768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
%102 = linalg.generic {indexing_maps = [#map1, #map19, #map1], iterator_types = ["parallel", "parallel"]} ins(%101, %cst_26 : tensor<8x768xf32>, tensor<768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
%103 = tensor.empty() : tensor<768x3072xf32> | |
%104 = linalg.generic {indexing_maps = [#map1, #map20], iterator_types = ["parallel", "parallel"]} ins(%cst_21 : tensor<3072x768xf32>) outs(%103 : tensor<768x3072xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<768x3072xf32> | |
%105 = tensor.empty() : tensor<8x3072xf32> | |
%106 = linalg.fill ins(%cst_0 : f32) outs(%105 : tensor<8x3072xf32>) -> tensor<8x3072xf32> | |
%107 = linalg.matmul ins(%102, %104 : tensor<8x768xf32>, tensor<768x3072xf32>) outs(%106 : tensor<8x3072xf32>) -> tensor<8x3072xf32> | |
%108 = linalg.generic {indexing_maps = [#map1, #map19, #map1], iterator_types = ["parallel", "parallel"]} ins(%107, %cst_22 : tensor<8x3072xf32>, tensor<3072xf32>) outs(%105 : tensor<8x3072xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x3072xf32> | |
%109 = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%108 : tensor<8x3072xf32>) outs(%105 : tensor<8x3072xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.cmpf ugt, %in, %cst_0 : f32 | |
%840 = arith.select %839, %in, %cst_0 : f32 | |
linalg.yield %840 : f32 | |
} -> tensor<8x3072xf32> | |
%110 = tensor.empty() : tensor<3072x768xf32> | |
%111 = linalg.generic {indexing_maps = [#map1, #map20], iterator_types = ["parallel", "parallel"]} ins(%cst_23 : tensor<768x3072xf32>) outs(%110 : tensor<3072x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<3072x768xf32> | |
%112 = linalg.fill ins(%cst_0 : f32) outs(%5 : tensor<8x768xf32>) -> tensor<8x768xf32> | |
%113 = linalg.matmul ins(%109, %111 : tensor<8x3072xf32>, tensor<3072x768xf32>) outs(%112 : tensor<8x768xf32>) -> tensor<8x768xf32> | |
%114 = linalg.generic {indexing_maps = [#map1, #map19, #map1], iterator_types = ["parallel", "parallel"]} ins(%113, %cst_24 : tensor<8x768xf32>, tensor<768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
%115 = linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%collapsed_221, %114 : tensor<8x768xf32>, tensor<8x768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
%expanded_222 = tensor.expand_shape %115 [[0, 1], [2]] output_shape [1, 8, 768] : tensor<8x768xf32> into tensor<1x8x768xf32> | |
%116 = linalg.generic {indexing_maps = [#map8, #map9], iterator_types = ["parallel", "parallel", "reduction"]} ins(%expanded_222 : tensor<1x8x768xf32>) outs(%28 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %out : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
%117 = linalg.generic {indexing_maps = [#map10, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%116 : tensor<1x8x1xf32>) outs(%27 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.divf %in, %cst_2 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
%118 = linalg.generic {indexing_maps = [#map9, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%117 : tensor<1x8x1xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x8x768xf32> | |
%119 = linalg.generic {indexing_maps = [#map7, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_222, %118 : tensor<1x8x768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.subf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%120 = linalg.generic {indexing_maps = [#map7, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%119, %119 : tensor<1x8x768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%121 = linalg.generic {indexing_maps = [#map8, #map9], iterator_types = ["parallel", "parallel", "reduction"]} ins(%120 : tensor<1x8x768xf32>) outs(%28 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %out : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
%122 = linalg.generic {indexing_maps = [#map10, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%121 : tensor<1x8x1xf32>) outs(%27 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.divf %in, %cst_2 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
%123 = linalg.generic {indexing_maps = [#map10, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%122 : tensor<1x8x1xf32>) outs(%27 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %cst_3 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
%124 = linalg.generic {indexing_maps = [#map10, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%123 : tensor<1x8x1xf32>) outs(%27 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = math.rsqrt %in : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
%125 = linalg.generic {indexing_maps = [#map9, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%124 : tensor<1x8x1xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x8x768xf32> | |
%126 = linalg.generic {indexing_maps = [#map7, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%119, %125 : tensor<1x8x768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%127 = linalg.generic {indexing_maps = [#map7, #map11, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%126, %cst_31 : tensor<1x8x768xf32>, tensor<768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%128 = linalg.generic {indexing_maps = [#map7, #map11, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%127, %cst_32 : tensor<1x8x768xf32>, tensor<768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%129 = linalg.generic {indexing_maps = [#map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%128 : tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x8x768xf32> | |
%130 = linalg.generic {indexing_maps = [#map12, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_163 : tensor<768x768xf32>) outs(%43 : tensor<1x768x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x768x768xf32> | |
%131 = linalg.batch_matmul ins(%129, %130 : tensor<1x8x768xf32>, tensor<1x768x768xf32>) outs(%45 : tensor<1x8x768xf32>) -> tensor<1x8x768xf32> | |
%132 = linalg.generic {indexing_maps = [#map11, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_29, %131 : tensor<768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%133 = linalg.generic {indexing_maps = [#map7, #map13, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%132, %cst_6 : tensor<1x8x768xf32>, tensor<f32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%134 = linalg.generic {indexing_maps = [#map12, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_164 : tensor<768x768xf32>) outs(%43 : tensor<1x768x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x768x768xf32> | |
%135 = linalg.batch_matmul ins(%129, %134 : tensor<1x8x768xf32>, tensor<1x768x768xf32>) outs(%45 : tensor<1x8x768xf32>) -> tensor<1x8x768xf32> | |
%136 = linalg.generic {indexing_maps = [#map11, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_27, %135 : tensor<768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%expanded_223 = tensor.expand_shape %136 [[0], [1], [2, 3]] output_shape [1, 8, 12, 64] : tensor<1x8x768xf32> into tensor<1x8x12x64xf32> | |
%137 = linalg.generic {indexing_maps = [#map5, #map14], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_223 : tensor<1x8x12x64xf32>) outs(%52 : tensor<1x12x8x64xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x12x8x64xf32> | |
%138 = linalg.generic {indexing_maps = [#map12, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_165 : tensor<768x768xf32>) outs(%43 : tensor<1x768x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x768x768xf32> | |
%139 = linalg.batch_matmul ins(%129, %138 : tensor<1x8x768xf32>, tensor<1x768x768xf32>) outs(%45 : tensor<1x8x768xf32>) -> tensor<1x8x768xf32> | |
%140 = linalg.generic {indexing_maps = [#map11, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_28, %139 : tensor<768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%expanded_224 = tensor.expand_shape %140 [[0], [1], [2, 3]] output_shape [1, 8, 12, 64] : tensor<1x8x768xf32> into tensor<1x8x12x64xf32> | |
%141 = linalg.generic {indexing_maps = [#map5, #map14], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_224 : tensor<1x8x12x64xf32>) outs(%52 : tensor<1x12x8x64xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x12x8x64xf32> | |
%expanded_225 = tensor.expand_shape %133 [[0], [1], [2, 3]] output_shape [1, 8, 12, 64] : tensor<1x8x768xf32> into tensor<1x8x12x64xf32> | |
%142 = linalg.generic {indexing_maps = [#map5, #map14], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_225 : tensor<1x8x12x64xf32>) outs(%52 : tensor<1x12x8x64xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x12x8x64xf32> | |
%collapsed_226 = tensor.collapse_shape %142 [[0, 1], [2], [3]] : tensor<1x12x8x64xf32> into tensor<12x8x64xf32> | |
%collapsed_227 = tensor.collapse_shape %137 [[0, 1], [2], [3]] : tensor<1x12x8x64xf32> into tensor<12x8x64xf32> | |
%collapsed_228 = tensor.collapse_shape %141 [[0, 1], [2], [3]] : tensor<1x12x8x64xf32> into tensor<12x8x64xf32> | |
%143 = linalg.generic {indexing_maps = [#map8, #map15], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_227 : tensor<12x8x64xf32>) outs(%59 : tensor<12x64x8xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<12x64x8xf32> | |
%144 = linalg.batch_matmul ins(%collapsed_226, %143 : tensor<12x8x64xf32>, tensor<12x64x8xf32>) outs(%62 : tensor<12x8x8xf32>) -> tensor<12x8x8xf32> | |
%expanded_229 = tensor.expand_shape %144 [[0, 1], [2], [3]] output_shape [1, 12, 8, 8] : tensor<12x8x8xf32> into tensor<1x12x8x8xf32> | |
%145 = linalg.generic {indexing_maps = [#map16, #map4, #map5], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_229, %13 : tensor<1x12x8x8xf32>, tensor<1x1x8x8xf32>) outs(%64 : tensor<1x12x8x8xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x12x8x8xf32> | |
%146 = linalg.generic {indexing_maps = [#map16, #map3, #map5], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%145, %cst_8 : tensor<1x12x8x8xf32>, tensor<f32>) outs(%64 : tensor<1x12x8x8xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.cmpf ogt, %in, %in_366 : f32 | |
%840 = arith.select %839, %in, %in_366 : f32 | |
linalg.yield %840 : f32 | |
} -> tensor<1x12x8x8xf32> | |
%collapsed_230 = tensor.collapse_shape %146 [[0, 1], [2], [3]] : tensor<1x12x8x8xf32> into tensor<12x8x8xf32> | |
%147:2 = linalg.generic {indexing_maps = [#map8, #map17, #map17], iterator_types = ["parallel", "parallel", "reduction"]} ins(%collapsed_230 : tensor<12x8x8xf32>) outs(%70, %68 : tensor<12x8xf32>, tensor<12x8xi64>) { | |
^bb0(%in: f32, %out: f32, %out_366: i64): | |
%839 = linalg.index 2 : index | |
%840 = arith.index_cast %839 : index to i64 | |
%841 = arith.maximumf %in, %out : f32 | |
%842 = arith.cmpf ogt, %in, %out : f32 | |
%843 = arith.select %842, %840, %out_366 : i64 | |
linalg.yield %841, %843 : f32, i64 | |
} -> (tensor<12x8xf32>, tensor<12x8xi64>) | |
%expanded_231 = tensor.expand_shape %147#0 [[0], [1, 2]] output_shape [12, 8, 1] : tensor<12x8xf32> into tensor<12x8x1xf32> | |
%148 = linalg.generic {indexing_maps = [#map8, #map9, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_230, %expanded_231 : tensor<12x8x8xf32>, tensor<12x8x1xf32>) outs(%61 : tensor<12x8x8xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.subf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<12x8x8xf32> | |
%149 = linalg.generic {indexing_maps = [#map8, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%148 : tensor<12x8x8xf32>) outs(%61 : tensor<12x8x8xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = math.exp %in : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<12x8x8xf32> | |
%150 = linalg.generic {indexing_maps = [#map8, #map9], iterator_types = ["parallel", "parallel", "reduction"]} ins(%149 : tensor<12x8x8xf32>) outs(%75 : tensor<12x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %out : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<12x8x1xf32> | |
%151 = linalg.generic {indexing_maps = [#map8, #map9, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%149, %150 : tensor<12x8x8xf32>, tensor<12x8x1xf32>) outs(%61 : tensor<12x8x8xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.divf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<12x8x8xf32> | |
%152 = linalg.batch_matmul ins(%151, %collapsed_228 : tensor<12x8x8xf32>, tensor<12x8x64xf32>) outs(%79 : tensor<12x8x64xf32>) -> tensor<12x8x64xf32> | |
%expanded_232 = tensor.expand_shape %152 [[0, 1], [2], [3]] output_shape [1, 12, 8, 64] : tensor<12x8x64xf32> into tensor<1x12x8x64xf32> | |
%153 = linalg.generic {indexing_maps = [#map5, #map14], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_232 : tensor<1x12x8x64xf32>) outs(%81 : tensor<1x8x12x64xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x8x12x64xf32> | |
%collapsed_233 = tensor.collapse_shape %153 [[0], [1], [2, 3]] : tensor<1x8x12x64xf32> into tensor<1x8x768xf32> | |
%154 = linalg.generic {indexing_maps = [#map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_233 : tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x8x768xf32> | |
%155 = linalg.generic {indexing_maps = [#map12, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_166 : tensor<768x768xf32>) outs(%43 : tensor<1x768x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x768x768xf32> | |
%156 = linalg.batch_matmul ins(%154, %155 : tensor<1x8x768xf32>, tensor<1x768x768xf32>) outs(%45 : tensor<1x8x768xf32>) -> tensor<1x8x768xf32> | |
%157 = linalg.generic {indexing_maps = [#map11, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_30, %156 : tensor<768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%158 = linalg.generic {indexing_maps = [#map7, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_222, %157 : tensor<1x8x768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%collapsed_234 = tensor.collapse_shape %158 [[0, 1], [2]] : tensor<1x8x768xf32> into tensor<8x768xf32> | |
%159 = linalg.generic {indexing_maps = [#map1, #map18], iterator_types = ["parallel", "reduction"]} ins(%collapsed_234 : tensor<8x768xf32>) outs(%89 : tensor<8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %out : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x1xf32> | |
// Row-wise normalization of %collapsed_234 (8x768): subtract a per-row statistic,
// scale by rsqrt(per-row mean of squares + %cst_3), then apply the per-column
// factors %cst_37 and offsets %cst_38.
// NOTE(review): %159, %cst_2, %cst_3 and the init tensors %5/%88/%89 are defined
// outside this excerpt. %159 is presumably the row sum of %collapsed_234, %cst_2 the
// row length, %cst_3 an epsilon, and %89 zero-filled - confirm.
// %160: per-row value of %159 divided by %cst_2.
%160 = linalg.generic {indexing_maps = [#map18, #map1], iterator_types = ["parallel", "parallel"]} ins(%159 : tensor<8x1xf32>) outs(%88 : tensor<8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.divf %in, %cst_2 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x1xf32> | |
// %161: broadcast the 8x1 column across all 768 columns (#map18 reads column 0).
%161 = linalg.generic {indexing_maps = [#map18, #map1], iterator_types = ["parallel", "parallel"]} ins(%160 : tensor<8x1xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<8x768xf32> | |
// %162: centered values, %collapsed_234 - %161.
%162 = linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%collapsed_234, %161 : tensor<8x768xf32>, tensor<8x768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.subf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
// %163: elementwise square of the centered values.
%163 = linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%162, %162 : tensor<8x768xf32>, tensor<8x768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
// %164: sum of squares along d1 (reduction), accumulated onto %89.
%164 = linalg.generic {indexing_maps = [#map1, #map18], iterator_types = ["parallel", "reduction"]} ins(%163 : tensor<8x768xf32>) outs(%89 : tensor<8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %out : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x1xf32> | |
// %165: divide the summed squares by %cst_2.
%165 = linalg.generic {indexing_maps = [#map18, #map1], iterator_types = ["parallel", "parallel"]} ins(%164 : tensor<8x1xf32>) outs(%88 : tensor<8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.divf %in, %cst_2 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x1xf32> | |
// %166: add %cst_3 before taking the reciprocal square root.
%166 = linalg.generic {indexing_maps = [#map18, #map1], iterator_types = ["parallel", "parallel"]} ins(%165 : tensor<8x1xf32>) outs(%88 : tensor<8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %cst_3 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x1xf32> | |
// %167: reciprocal square root per row.
%167 = linalg.generic {indexing_maps = [#map18, #map1], iterator_types = ["parallel", "parallel"]} ins(%166 : tensor<8x1xf32>) outs(%88 : tensor<8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = math.rsqrt %in : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x1xf32> | |
// %168: broadcast the per-row rsqrt across all 768 columns.
%168 = linalg.generic {indexing_maps = [#map18, #map1], iterator_types = ["parallel", "parallel"]} ins(%167 : tensor<8x1xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<8x768xf32> | |
// %169: normalized values, centered * rsqrt.
%169 = linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%162, %168 : tensor<8x768xf32>, tensor<8x768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
// %170: multiply by the 768-element vector %cst_37, indexed by column (#map19).
%170 = linalg.generic {indexing_maps = [#map1, #map19, #map1], iterator_types = ["parallel", "parallel"]} ins(%169, %cst_37 : tensor<8x768xf32>, tensor<768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
// %171: add the 768-element vector %cst_38, indexed by column (#map19).
%171 = linalg.generic {indexing_maps = [#map1, #map19, #map1], iterator_types = ["parallel", "parallel"]} ins(%170, %cst_38 : tensor<8x768xf32>, tensor<768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
// Two-matmul feed-forward on %171 (8x768 -> 8x3072 -> 8x768) followed by a residual
// add with %collapsed_234.
// NOTE(review): %cst_0 and the init tensors %103/%105/%106/%110/%112/%5 are defined
// outside this excerpt; %cst_0 is presumably 0.0 (making %175 a ReLU) and %106/%112
// presumably zero-filled matmul accumulators - confirm.
// %172: transpose %cst_33 from 3072x768 to 768x3072 (#map20 swaps d0 and d1).
%172 = linalg.generic {indexing_maps = [#map1, #map20], iterator_types = ["parallel", "parallel"]} ins(%cst_33 : tensor<3072x768xf32>) outs(%103 : tensor<768x3072xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<768x3072xf32> | |
// %173: %171 x %172, accumulated onto %106.
%173 = linalg.matmul ins(%171, %172 : tensor<8x768xf32>, tensor<768x3072xf32>) outs(%106 : tensor<8x3072xf32>) -> tensor<8x3072xf32> | |
// %174: add the 3072-element vector %cst_34, indexed by column.
%174 = linalg.generic {indexing_maps = [#map1, #map19, #map1], iterator_types = ["parallel", "parallel"]} ins(%173, %cst_34 : tensor<8x3072xf32>, tensor<3072xf32>) outs(%105 : tensor<8x3072xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x3072xf32> | |
// %175: keep %in where it compares "ugt" against %cst_0, otherwise yield %cst_0.
// The predicate is unordered, so a NaN input is passed through unchanged.
%175 = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%174 : tensor<8x3072xf32>) outs(%105 : tensor<8x3072xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.cmpf ugt, %in, %cst_0 : f32 | |
%840 = arith.select %839, %in, %cst_0 : f32 | |
linalg.yield %840 : f32 | |
} -> tensor<8x3072xf32> | |
// %176: transpose %cst_35 from 768x3072 to 3072x768.
%176 = linalg.generic {indexing_maps = [#map1, #map20], iterator_types = ["parallel", "parallel"]} ins(%cst_35 : tensor<768x3072xf32>) outs(%110 : tensor<3072x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<3072x768xf32> | |
// %177: %175 x %176, accumulated onto %112.
%177 = linalg.matmul ins(%175, %176 : tensor<8x3072xf32>, tensor<3072x768xf32>) outs(%112 : tensor<8x768xf32>) -> tensor<8x768xf32> | |
// %178: add the 768-element vector %cst_36, indexed by column.
%178 = linalg.generic {indexing_maps = [#map1, #map19, #map1], iterator_types = ["parallel", "parallel"]} ins(%177, %cst_36 : tensor<8x768xf32>, tensor<768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
// %179: residual connection, %collapsed_234 + %178.
%179 = linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%collapsed_234, %178 : tensor<8x768xf32>, tensor<8x768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
// Normalization of %179 over the last dimension, carried out on the 1x8x768 view:
// mean, centering, mean of squares, rsqrt, then scale by %cst_43 and shift by %cst_44.
// NOTE(review): %cst_2, %cst_3 and the init tensors %25/%27/%28 are defined outside
// this excerpt; %cst_2 is presumably the reduced length (768), %cst_3 an epsilon and
// %28 zero-filled - confirm.
// Reinterpret the 8x768 tensor as 1x8x768 (adds a leading unit dimension).
%expanded_235 = tensor.expand_shape %179 [[0, 1], [2]] output_shape [1, 8, 768] : tensor<8x768xf32> into tensor<1x8x768xf32> | |
// %180: sum along d2 (reduction), accumulated onto %28.
%180 = linalg.generic {indexing_maps = [#map8, #map9], iterator_types = ["parallel", "parallel", "reduction"]} ins(%expanded_235 : tensor<1x8x768xf32>) outs(%28 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %out : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
// %181: divide the sum by %cst_2.
%181 = linalg.generic {indexing_maps = [#map10, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%180 : tensor<1x8x1xf32>) outs(%27 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.divf %in, %cst_2 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
// %182: broadcast the 1x8x1 result along d2 to 1x8x768 (#map9 reads index 0 of d2).
%182 = linalg.generic {indexing_maps = [#map9, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%181 : tensor<1x8x1xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x8x768xf32> | |
// %183: centered values, %expanded_235 - %182.
%183 = linalg.generic {indexing_maps = [#map7, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_235, %182 : tensor<1x8x768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.subf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
// %184: elementwise square of the centered values.
%184 = linalg.generic {indexing_maps = [#map7, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%183, %183 : tensor<1x8x768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
// %185: sum of squares along d2 (reduction), accumulated onto %28.
%185 = linalg.generic {indexing_maps = [#map8, #map9], iterator_types = ["parallel", "parallel", "reduction"]} ins(%184 : tensor<1x8x768xf32>) outs(%28 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %out : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
// %186: divide the summed squares by %cst_2.
%186 = linalg.generic {indexing_maps = [#map10, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%185 : tensor<1x8x1xf32>) outs(%27 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.divf %in, %cst_2 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
// %187: add %cst_3 before taking the reciprocal square root.
%187 = linalg.generic {indexing_maps = [#map10, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%186 : tensor<1x8x1xf32>) outs(%27 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %cst_3 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
// %188: reciprocal square root.
%188 = linalg.generic {indexing_maps = [#map10, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%187 : tensor<1x8x1xf32>) outs(%27 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = math.rsqrt %in : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
// %189: broadcast the rsqrt along d2 to 1x8x768.
%189 = linalg.generic {indexing_maps = [#map9, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%188 : tensor<1x8x1xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x8x768xf32> | |
// %190: normalized values, centered * rsqrt.
%190 = linalg.generic {indexing_maps = [#map7, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%183, %189 : tensor<1x8x768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
// %191: multiply by the 768-element vector %cst_43, indexed by d2 (#map11).
%191 = linalg.generic {indexing_maps = [#map7, #map11, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%190, %cst_43 : tensor<1x8x768xf32>, tensor<768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
// %192: add the 768-element vector %cst_44, indexed by d2 (#map11).
%192 = linalg.generic {indexing_maps = [#map7, #map11, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%191, %cst_44 : tensor<1x8x768xf32>, tensor<768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
// %193: elementwise copy of %192 (the region only yields its input).
%193 = linalg.generic {indexing_maps = [#map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%192 : tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x8x768xf32> | |
// Three 768x768 projections of %193, each followed by a per-column vector add, then a
// split of the 768 axis into 12x64 and a swap of dims 1 and 2 to get 1x12x8x64.
// NOTE(review): from how the results are consumed later in this excerpt, %206 looks
// like the (pre-scaled) query, %201 the key and %205 the value - confirm. %cst_6 and
// the init tensors %25/%43/%45/%52 are defined outside this excerpt; %45 is presumably
// a zero-filled accumulator.
// %194: give %cst_167 a leading unit batch dimension (768x768 -> 1x768x768).
%194 = linalg.generic {indexing_maps = [#map12, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_167 : tensor<768x768xf32>) outs(%43 : tensor<1x768x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x768x768xf32> | |
// %195: %193 x %194, accumulated onto %45.
%195 = linalg.batch_matmul ins(%193, %194 : tensor<1x8x768xf32>, tensor<1x768x768xf32>) outs(%45 : tensor<1x8x768xf32>) -> tensor<1x8x768xf32> | |
// %196: add the 768-element vector %cst_41, indexed by d2.
%196 = linalg.generic {indexing_maps = [#map11, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_41, %195 : tensor<768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
// %197: multiply every element by the rank-0 tensor %cst_6 (#map13 maps to a scalar).
%197 = linalg.generic {indexing_maps = [#map7, #map13, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%196, %cst_6 : tensor<1x8x768xf32>, tensor<f32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
// %198..%200: second projection, using %cst_168 and the vector %cst_39.
%198 = linalg.generic {indexing_maps = [#map12, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_168 : tensor<768x768xf32>) outs(%43 : tensor<1x768x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x768x768xf32> | |
%199 = linalg.batch_matmul ins(%193, %198 : tensor<1x8x768xf32>, tensor<1x768x768xf32>) outs(%45 : tensor<1x8x768xf32>) -> tensor<1x8x768xf32> | |
%200 = linalg.generic {indexing_maps = [#map11, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_39, %199 : tensor<768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
// Split 768 into 12x64, then swap dims 1 and 2 (#map14) to get 1x12x8x64.
%expanded_236 = tensor.expand_shape %200 [[0], [1], [2, 3]] output_shape [1, 8, 12, 64] : tensor<1x8x768xf32> into tensor<1x8x12x64xf32> | |
%201 = linalg.generic {indexing_maps = [#map5, #map14], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_236 : tensor<1x8x12x64xf32>) outs(%52 : tensor<1x12x8x64xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x12x8x64xf32> | |
// %202..%204: third projection, using %cst_169 and the vector %cst_40.
%202 = linalg.generic {indexing_maps = [#map12, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_169 : tensor<768x768xf32>) outs(%43 : tensor<1x768x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x768x768xf32> | |
%203 = linalg.batch_matmul ins(%193, %202 : tensor<1x8x768xf32>, tensor<1x768x768xf32>) outs(%45 : tensor<1x8x768xf32>) -> tensor<1x8x768xf32> | |
%204 = linalg.generic {indexing_maps = [#map11, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_40, %203 : tensor<768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
// Same head split and dim swap for the third projection.
%expanded_237 = tensor.expand_shape %204 [[0], [1], [2, 3]] output_shape [1, 8, 12, 64] : tensor<1x8x768xf32> into tensor<1x8x12x64xf32> | |
%205 = linalg.generic {indexing_maps = [#map5, #map14], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_237 : tensor<1x8x12x64xf32>) outs(%52 : tensor<1x12x8x64xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x12x8x64xf32> | |
// Same head split and dim swap for the scaled first projection (%197).
%expanded_238 = tensor.expand_shape %197 [[0], [1], [2, 3]] output_shape [1, 8, 12, 64] : tensor<1x8x768xf32> into tensor<1x8x12x64xf32> | |
%206 = linalg.generic {indexing_maps = [#map5, #map14], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_238 : tensor<1x8x12x64xf32>) outs(%52 : tensor<1x12x8x64xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x12x8x64xf32> | |
// Per-head score matrix: %206 x transpose(%201), plus the additive tensor %13, then
// clamped from below by %cst_8.
// NOTE(review): %13, %cst_8 and the init tensors %59/%62/%64 are defined outside this
// excerpt; %13 is presumably an attention mask and %cst_8 the lowest finite f32 value
// - confirm.
// Drop the leading unit dimension so the 12 heads become the batch dimension.
%collapsed_239 = tensor.collapse_shape %206 [[0, 1], [2], [3]] : tensor<1x12x8x64xf32> into tensor<12x8x64xf32> | |
%collapsed_240 = tensor.collapse_shape %201 [[0, 1], [2], [3]] : tensor<1x12x8x64xf32> into tensor<12x8x64xf32> | |
%collapsed_241 = tensor.collapse_shape %205 [[0, 1], [2], [3]] : tensor<1x12x8x64xf32> into tensor<12x8x64xf32> | |
// %207: transpose the last two dims of %collapsed_240 (12x8x64 -> 12x64x8, #map15).
%207 = linalg.generic {indexing_maps = [#map8, #map15], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_240 : tensor<12x8x64xf32>) outs(%59 : tensor<12x64x8xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<12x64x8xf32> | |
// %208: batched 8x64 by 64x8 product per head, accumulated onto %62.
%208 = linalg.batch_matmul ins(%collapsed_239, %207 : tensor<12x8x64xf32>, tensor<12x64x8xf32>) outs(%62 : tensor<12x8x8xf32>) -> tensor<12x8x8xf32> | |
%expanded_242 = tensor.expand_shape %208 [[0, 1], [2], [3]] output_shape [1, 12, 8, 8] : tensor<12x8x8xf32> into tensor<1x12x8x8xf32> | |
// %209: add %13 (1x1x8x8); #map4 reads index 0 on d1, so it is shared by all 12 heads.
%209 = linalg.generic {indexing_maps = [#map16, #map4, #map5], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_242, %13 : tensor<1x12x8x8xf32>, tensor<1x1x8x8xf32>) outs(%64 : tensor<1x12x8x8xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x12x8x8xf32> | |
// %210: yield %in when it is ordered-greater than the scalar %cst_8, else %cst_8.
%210 = linalg.generic {indexing_maps = [#map16, #map3, #map5], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%209, %cst_8 : tensor<1x12x8x8xf32>, tensor<f32>) outs(%64 : tensor<1x12x8x8xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.cmpf ogt, %in, %in_366 : f32 | |
%840 = arith.select %839, %in, %in_366 : f32 | |
linalg.yield %840 : f32 | |
} -> tensor<1x12x8x8xf32> | |
%collapsed_243 = tensor.collapse_shape %210 [[0, 1], [2], [3]] : tensor<1x12x8x8xf32> into tensor<12x8x8xf32> | |
// Softmax of %collapsed_243 over the last dimension: subtract the running max along
// d2, exponentiate, and divide by the sum of the exponentials.
// NOTE(review): the init tensors %70/%68/%61/%75 are defined outside this excerpt; %70
// is presumably filled with -inf (or the lowest f32) and %75 with zeros - confirm.
// %211: reduction along d2 producing the max (#0) and the index where it was first
// exceeded (#1). Only #0 is consumed in the lines visible here.
%211:2 = linalg.generic {indexing_maps = [#map8, #map17, #map17], iterator_types = ["parallel", "parallel", "reduction"]} ins(%collapsed_243 : tensor<12x8x8xf32>) outs(%70, %68 : tensor<12x8xf32>, tensor<12x8xi64>) { | |
^bb0(%in: f32, %out: f32, %out_366: i64): | |
%839 = linalg.index 2 : index | |
%840 = arith.index_cast %839 : index to i64 | |
%841 = arith.maximumf %in, %out : f32 | |
%842 = arith.cmpf ogt, %in, %out : f32 | |
%843 = arith.select %842, %840, %out_366 : i64 | |
linalg.yield %841, %843 : f32, i64 | |
} -> (tensor<12x8xf32>, tensor<12x8xi64>) | |
// Restore a trailing unit dimension so the max can be broadcast along d2.
%expanded_244 = tensor.expand_shape %211#0 [[0], [1, 2]] output_shape [12, 8, 1] : tensor<12x8xf32> into tensor<12x8x1xf32> | |
// %212: subtract the max (#map9 reads index 0 of d2).
%212 = linalg.generic {indexing_maps = [#map8, #map9, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_243, %expanded_244 : tensor<12x8x8xf32>, tensor<12x8x1xf32>) outs(%61 : tensor<12x8x8xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.subf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<12x8x8xf32> | |
// %213: elementwise exp.
%213 = linalg.generic {indexing_maps = [#map8, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%212 : tensor<12x8x8xf32>) outs(%61 : tensor<12x8x8xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = math.exp %in : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<12x8x8xf32> | |
// %214: sum of the exponentials along d2 (reduction), accumulated onto %75.
%214 = linalg.generic {indexing_maps = [#map8, #map9], iterator_types = ["parallel", "parallel", "reduction"]} ins(%213 : tensor<12x8x8xf32>) outs(%75 : tensor<12x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %out : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<12x8x1xf32> | |
// %215: divide each exponential by the sum for its (d0, d1) position.
%215 = linalg.generic {indexing_maps = [#map8, #map9, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%213, %214 : tensor<12x8x8xf32>, tensor<12x8x1xf32>) outs(%61 : tensor<12x8x8xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.divf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<12x8x8xf32> | |
// Apply the normalized scores %215 to %collapsed_241, merge the 12 heads back into a
// 768-wide axis, project with %cst_170 plus the vector %cst_42, and add the result to
// %expanded_235 as a residual connection.
// NOTE(review): the init tensors %79/%81/%25/%43/%45 are defined outside this excerpt;
// %79 and %45 are presumably zero-filled matmul accumulators - confirm.
// %216: batched 8x8 by 8x64 product per head.
%216 = linalg.batch_matmul ins(%215, %collapsed_241 : tensor<12x8x8xf32>, tensor<12x8x64xf32>) outs(%79 : tensor<12x8x64xf32>) -> tensor<12x8x64xf32> | |
%expanded_245 = tensor.expand_shape %216 [[0, 1], [2], [3]] output_shape [1, 12, 8, 64] : tensor<12x8x64xf32> into tensor<1x12x8x64xf32> | |
// %217: swap dims 1 and 2 (#map14), 1x12x8x64 -> 1x8x12x64.
%217 = linalg.generic {indexing_maps = [#map5, #map14], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_245 : tensor<1x12x8x64xf32>) outs(%81 : tensor<1x8x12x64xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x8x12x64xf32> | |
// Merge the 12x64 trailing dims back into 768.
%collapsed_246 = tensor.collapse_shape %217 [[0], [1], [2, 3]] : tensor<1x8x12x64xf32> into tensor<1x8x768xf32> | |
// %218: elementwise copy of %collapsed_246 (the region only yields its input).
%218 = linalg.generic {indexing_maps = [#map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_246 : tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x8x768xf32> | |
// %219: give %cst_170 a leading unit batch dimension (768x768 -> 1x768x768).
%219 = linalg.generic {indexing_maps = [#map12, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_170 : tensor<768x768xf32>) outs(%43 : tensor<1x768x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x768x768xf32> | |
// %220: %218 x %219, accumulated onto %45.
%220 = linalg.batch_matmul ins(%218, %219 : tensor<1x8x768xf32>, tensor<1x768x768xf32>) outs(%45 : tensor<1x8x768xf32>) -> tensor<1x8x768xf32> | |
// %221: add the 768-element vector %cst_42, indexed by d2.
%221 = linalg.generic {indexing_maps = [#map11, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_42, %220 : tensor<768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
// %222: residual connection, %expanded_235 + %221.
%222 = linalg.generic {indexing_maps = [#map7, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_235, %221 : tensor<1x8x768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
// Drop the leading unit dimension: 1x8x768 -> 8x768.
%collapsed_247 = tensor.collapse_shape %222 [[0, 1], [2]] : tensor<1x8x768xf32> into tensor<8x768xf32> | |
// Row-wise normalization of %collapsed_247 (8x768): mean, centering, mean of squares,
// rsqrt, then scale by %cst_49 and shift by %cst_50.
// NOTE(review): %cst_2, %cst_3 and the init tensors %5/%88/%89 are defined outside
// this excerpt; %cst_2 is presumably the row length (768), %cst_3 an epsilon and %89
// zero-filled - confirm.
// %223: row sum along d1 (reduction), accumulated onto %89.
%223 = linalg.generic {indexing_maps = [#map1, #map18], iterator_types = ["parallel", "reduction"]} ins(%collapsed_247 : tensor<8x768xf32>) outs(%89 : tensor<8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %out : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x1xf32> | |
// %224: divide the row sum by %cst_2.
%224 = linalg.generic {indexing_maps = [#map18, #map1], iterator_types = ["parallel", "parallel"]} ins(%223 : tensor<8x1xf32>) outs(%88 : tensor<8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.divf %in, %cst_2 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x1xf32> | |
// %225: broadcast the 8x1 column across all 768 columns.
%225 = linalg.generic {indexing_maps = [#map18, #map1], iterator_types = ["parallel", "parallel"]} ins(%224 : tensor<8x1xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<8x768xf32> | |
// %226: centered values, %collapsed_247 - %225.
%226 = linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%collapsed_247, %225 : tensor<8x768xf32>, tensor<8x768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.subf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
// %227: elementwise square of the centered values.
%227 = linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%226, %226 : tensor<8x768xf32>, tensor<8x768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
// %228: sum of squares along d1 (reduction), accumulated onto %89.
%228 = linalg.generic {indexing_maps = [#map1, #map18], iterator_types = ["parallel", "reduction"]} ins(%227 : tensor<8x768xf32>) outs(%89 : tensor<8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %out : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x1xf32> | |
// %229: divide the summed squares by %cst_2.
%229 = linalg.generic {indexing_maps = [#map18, #map1], iterator_types = ["parallel", "parallel"]} ins(%228 : tensor<8x1xf32>) outs(%88 : tensor<8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.divf %in, %cst_2 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x1xf32> | |
// %230: add %cst_3 before taking the reciprocal square root.
%230 = linalg.generic {indexing_maps = [#map18, #map1], iterator_types = ["parallel", "parallel"]} ins(%229 : tensor<8x1xf32>) outs(%88 : tensor<8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %cst_3 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x1xf32> | |
// %231: reciprocal square root per row.
%231 = linalg.generic {indexing_maps = [#map18, #map1], iterator_types = ["parallel", "parallel"]} ins(%230 : tensor<8x1xf32>) outs(%88 : tensor<8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = math.rsqrt %in : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x1xf32> | |
// %232: broadcast the per-row rsqrt across all 768 columns.
%232 = linalg.generic {indexing_maps = [#map18, #map1], iterator_types = ["parallel", "parallel"]} ins(%231 : tensor<8x1xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<8x768xf32> | |
// %233: normalized values, centered * rsqrt.
%233 = linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%226, %232 : tensor<8x768xf32>, tensor<8x768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
// %234: multiply by the 768-element vector %cst_49, indexed by column (#map19).
%234 = linalg.generic {indexing_maps = [#map1, #map19, #map1], iterator_types = ["parallel", "parallel"]} ins(%233, %cst_49 : tensor<8x768xf32>, tensor<768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
// %235: add the 768-element vector %cst_50, indexed by column (#map19).
%235 = linalg.generic {indexing_maps = [#map1, #map19, #map1], iterator_types = ["parallel", "parallel"]} ins(%234, %cst_50 : tensor<8x768xf32>, tensor<768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
// Two-matmul feed-forward on %235 (8x768 -> 8x3072 -> 8x768) followed by a residual
// add with %collapsed_247.
// NOTE(review): %cst_0 and the init tensors %103/%105/%106/%110/%112/%5 are defined
// outside this excerpt; %cst_0 is presumably 0.0 (making %239 a ReLU) and %106/%112
// presumably zero-filled matmul accumulators - confirm.
// %236: transpose %cst_45 from 3072x768 to 768x3072 (#map20 swaps d0 and d1).
%236 = linalg.generic {indexing_maps = [#map1, #map20], iterator_types = ["parallel", "parallel"]} ins(%cst_45 : tensor<3072x768xf32>) outs(%103 : tensor<768x3072xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<768x3072xf32> | |
// %237: %235 x %236, accumulated onto %106.
%237 = linalg.matmul ins(%235, %236 : tensor<8x768xf32>, tensor<768x3072xf32>) outs(%106 : tensor<8x3072xf32>) -> tensor<8x3072xf32> | |
// %238: add the 3072-element vector %cst_46, indexed by column.
%238 = linalg.generic {indexing_maps = [#map1, #map19, #map1], iterator_types = ["parallel", "parallel"]} ins(%237, %cst_46 : tensor<8x3072xf32>, tensor<3072xf32>) outs(%105 : tensor<8x3072xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x3072xf32> | |
// %239: keep %in where it compares "ugt" against %cst_0, otherwise yield %cst_0.
// The predicate is unordered, so a NaN input is passed through unchanged.
%239 = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%238 : tensor<8x3072xf32>) outs(%105 : tensor<8x3072xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.cmpf ugt, %in, %cst_0 : f32 | |
%840 = arith.select %839, %in, %cst_0 : f32 | |
linalg.yield %840 : f32 | |
} -> tensor<8x3072xf32> | |
// %240: transpose %cst_47 from 768x3072 to 3072x768.
%240 = linalg.generic {indexing_maps = [#map1, #map20], iterator_types = ["parallel", "parallel"]} ins(%cst_47 : tensor<768x3072xf32>) outs(%110 : tensor<3072x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<3072x768xf32> | |
// %241: %239 x %240, accumulated onto %112.
%241 = linalg.matmul ins(%239, %240 : tensor<8x3072xf32>, tensor<3072x768xf32>) outs(%112 : tensor<8x768xf32>) -> tensor<8x768xf32> | |
// %242: add the 768-element vector %cst_48, indexed by column.
%242 = linalg.generic {indexing_maps = [#map1, #map19, #map1], iterator_types = ["parallel", "parallel"]} ins(%241, %cst_48 : tensor<8x768xf32>, tensor<768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
// %243: residual connection, %collapsed_247 + %242.
%243 = linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%collapsed_247, %242 : tensor<8x768xf32>, tensor<8x768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
// Normalization of %243 over the last dimension, carried out on the 1x8x768 view:
// mean, centering, mean of squares, rsqrt, then scale by %cst_55 and shift by %cst_56.
// NOTE(review): %cst_2, %cst_3 and the init tensors %25/%27/%28 are defined outside
// this excerpt; %cst_2 is presumably the reduced length (768), %cst_3 an epsilon and
// %28 zero-filled - confirm.
// Reinterpret the 8x768 tensor as 1x8x768 (adds a leading unit dimension).
%expanded_248 = tensor.expand_shape %243 [[0, 1], [2]] output_shape [1, 8, 768] : tensor<8x768xf32> into tensor<1x8x768xf32> | |
// %244: sum along d2 (reduction), accumulated onto %28.
%244 = linalg.generic {indexing_maps = [#map8, #map9], iterator_types = ["parallel", "parallel", "reduction"]} ins(%expanded_248 : tensor<1x8x768xf32>) outs(%28 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %out : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
// %245: divide the sum by %cst_2.
%245 = linalg.generic {indexing_maps = [#map10, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%244 : tensor<1x8x1xf32>) outs(%27 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.divf %in, %cst_2 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
// %246: broadcast the 1x8x1 result along d2 to 1x8x768 (#map9 reads index 0 of d2).
%246 = linalg.generic {indexing_maps = [#map9, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%245 : tensor<1x8x1xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x8x768xf32> | |
// %247: centered values, %expanded_248 - %246.
%247 = linalg.generic {indexing_maps = [#map7, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_248, %246 : tensor<1x8x768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.subf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
// %248: elementwise square of the centered values.
%248 = linalg.generic {indexing_maps = [#map7, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%247, %247 : tensor<1x8x768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
// %249: sum of squares along d2 (reduction), accumulated onto %28.
%249 = linalg.generic {indexing_maps = [#map8, #map9], iterator_types = ["parallel", "parallel", "reduction"]} ins(%248 : tensor<1x8x768xf32>) outs(%28 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %out : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
// %250: divide the summed squares by %cst_2.
%250 = linalg.generic {indexing_maps = [#map10, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%249 : tensor<1x8x1xf32>) outs(%27 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.divf %in, %cst_2 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
// %251: add %cst_3 before taking the reciprocal square root.
%251 = linalg.generic {indexing_maps = [#map10, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%250 : tensor<1x8x1xf32>) outs(%27 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %cst_3 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
// %252: reciprocal square root.
%252 = linalg.generic {indexing_maps = [#map10, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%251 : tensor<1x8x1xf32>) outs(%27 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = math.rsqrt %in : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
// %253: broadcast the rsqrt along d2 to 1x8x768.
%253 = linalg.generic {indexing_maps = [#map9, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%252 : tensor<1x8x1xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x8x768xf32> | |
// %254: normalized values, centered * rsqrt.
%254 = linalg.generic {indexing_maps = [#map7, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%247, %253 : tensor<1x8x768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
// %255: multiply by the 768-element vector %cst_55, indexed by d2 (#map11).
%255 = linalg.generic {indexing_maps = [#map7, #map11, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%254, %cst_55 : tensor<1x8x768xf32>, tensor<768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
// %256: add the 768-element vector %cst_56, indexed by d2 (#map11).
%256 = linalg.generic {indexing_maps = [#map7, #map11, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%255, %cst_56 : tensor<1x8x768xf32>, tensor<768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
// %257: elementwise copy of %256 (the region only yields its input).
%257 = linalg.generic {indexing_maps = [#map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%256 : tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x8x768xf32> | |
// Three 768x768 projections of %257, each followed by a per-column vector add, then a
// split of the 768 axis into 12x64 and a swap of dims 1 and 2 to get 1x12x8x64.
// NOTE(review): by analogy with how such tensors are consumed by the score matmul,
// %270 looks like the (pre-scaled) query, %265 the key and %269 the value - confirm.
// %cst_6 and the init tensors %25/%43/%45/%52 are defined outside this excerpt; %45 is
// presumably a zero-filled accumulator.
// %258: give %cst_171 a leading unit batch dimension (768x768 -> 1x768x768).
%258 = linalg.generic {indexing_maps = [#map12, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_171 : tensor<768x768xf32>) outs(%43 : tensor<1x768x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x768x768xf32> | |
// %259: %257 x %258, accumulated onto %45.
%259 = linalg.batch_matmul ins(%257, %258 : tensor<1x8x768xf32>, tensor<1x768x768xf32>) outs(%45 : tensor<1x8x768xf32>) -> tensor<1x8x768xf32> | |
// %260: add the 768-element vector %cst_53, indexed by d2.
%260 = linalg.generic {indexing_maps = [#map11, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_53, %259 : tensor<768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
// %261: multiply every element by the rank-0 tensor %cst_6 (#map13 maps to a scalar).
%261 = linalg.generic {indexing_maps = [#map7, #map13, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%260, %cst_6 : tensor<1x8x768xf32>, tensor<f32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
// %262..%264: second projection, using %cst_172 and the vector %cst_51.
%262 = linalg.generic {indexing_maps = [#map12, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_172 : tensor<768x768xf32>) outs(%43 : tensor<1x768x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x768x768xf32> | |
%263 = linalg.batch_matmul ins(%257, %262 : tensor<1x8x768xf32>, tensor<1x768x768xf32>) outs(%45 : tensor<1x8x768xf32>) -> tensor<1x8x768xf32> | |
%264 = linalg.generic {indexing_maps = [#map11, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_51, %263 : tensor<768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
// Split 768 into 12x64, then swap dims 1 and 2 (#map14) to get 1x12x8x64.
%expanded_249 = tensor.expand_shape %264 [[0], [1], [2, 3]] output_shape [1, 8, 12, 64] : tensor<1x8x768xf32> into tensor<1x8x12x64xf32> | |
%265 = linalg.generic {indexing_maps = [#map5, #map14], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_249 : tensor<1x8x12x64xf32>) outs(%52 : tensor<1x12x8x64xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x12x8x64xf32> | |
// %266..%268: third projection, using %cst_173 and the vector %cst_52.
%266 = linalg.generic {indexing_maps = [#map12, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_173 : tensor<768x768xf32>) outs(%43 : tensor<1x768x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x768x768xf32> | |
%267 = linalg.batch_matmul ins(%257, %266 : tensor<1x8x768xf32>, tensor<1x768x768xf32>) outs(%45 : tensor<1x8x768xf32>) -> tensor<1x8x768xf32> | |
%268 = linalg.generic {indexing_maps = [#map11, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_52, %267 : tensor<768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
// Same head split and dim swap for the third projection.
%expanded_250 = tensor.expand_shape %268 [[0], [1], [2, 3]] output_shape [1, 8, 12, 64] : tensor<1x8x768xf32> into tensor<1x8x12x64xf32> | |
%269 = linalg.generic {indexing_maps = [#map5, #map14], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_250 : tensor<1x8x12x64xf32>) outs(%52 : tensor<1x12x8x64xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x12x8x64xf32> | |
// Same head split and dim swap for the scaled first projection (%261).
%expanded_251 = tensor.expand_shape %261 [[0], [1], [2, 3]] output_shape [1, 8, 12, 64] : tensor<1x8x768xf32> into tensor<1x8x12x64xf32> | |
%270 = linalg.generic {indexing_maps = [#map5, #map14], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_251 : tensor<1x8x12x64xf32>) outs(%52 : tensor<1x12x8x64xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x12x8x64xf32> | |
// Attention core for 12 heads: scores = %270 x transpose(%265), plus the additive
// tensor %13, clamped from below by %cst_8, softmax over the last axis, then
// multiplied by %269. %13 and %cst_8 are defined outside this chunk.
// Fold the unit batch axis into the head axis: 1x12x8x64 -> 12x8x64.
%collapsed_252 = tensor.collapse_shape %270 [[0, 1], [2], [3]] : tensor<1x12x8x64xf32> into tensor<12x8x64xf32> | |
%collapsed_253 = tensor.collapse_shape %265 [[0, 1], [2], [3]] : tensor<1x12x8x64xf32> into tensor<12x8x64xf32> | |
%collapsed_254 = tensor.collapse_shape %269 [[0, 1], [2], [3]] : tensor<1x12x8x64xf32> into tensor<12x8x64xf32> | |
// Transpose the last two axes (12x8x64 -> 12x64x8) so the matmul below contracts over the 64-wide axis.
%271 = linalg.generic {indexing_maps = [#map8, #map15], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_253 : tensor<12x8x64xf32>) outs(%59 : tensor<12x64x8xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<12x64x8xf32> | |
%272 = linalg.batch_matmul ins(%collapsed_252, %271 : tensor<12x8x64xf32>, tensor<12x64x8xf32>) outs(%62 : tensor<12x8x8xf32>) -> tensor<12x8x8xf32> | |
%expanded_255 = tensor.expand_shape %272 [[0, 1], [2], [3]] output_shape [1, 12, 8, 8] : tensor<12x8x8xf32> into tensor<1x12x8x8xf32> | |
// Add the 1x1x8x8 tensor %13, broadcast over the 12-wide axis (#map4 pins d0 and d1 to 0).
// NOTE(review): %13 looks like an additive attention mask - confirm at its definition.
%273 = linalg.generic {indexing_maps = [#map16, #map4, #map5], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_255, %13 : tensor<1x12x8x8xf32>, tensor<1x1x8x8xf32>) outs(%64 : tensor<1x12x8x8xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x12x8x8xf32> | |
// Elementwise lower clamp: keep the value when it is (ordered) greater than the scalar %cst_8,
// otherwise yield %cst_8. Because the compare is ordered, a NaN input also yields %cst_8.
%274 = linalg.generic {indexing_maps = [#map16, #map3, #map5], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%273, %cst_8 : tensor<1x12x8x8xf32>, tensor<f32>) outs(%64 : tensor<1x12x8x8xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.cmpf ogt, %in, %in_366 : f32 | |
%840 = arith.select %839, %in, %in_366 : f32 | |
linalg.yield %840 : f32 | |
} -> tensor<1x12x8x8xf32> | |
%collapsed_256 = tensor.collapse_shape %274 [[0, 1], [2], [3]] : tensor<1x12x8x8xf32> into tensor<12x8x8xf32> | |
// Reduction over the last axis yielding two results: the running maximum (#0) and the
// index along d2 at which a strictly larger value was last seen (#1).
// Only #0 is referenced in the lines of this block.
%275:2 = linalg.generic {indexing_maps = [#map8, #map17, #map17], iterator_types = ["parallel", "parallel", "reduction"]} ins(%collapsed_256 : tensor<12x8x8xf32>) outs(%70, %68 : tensor<12x8xf32>, tensor<12x8xi64>) { | |
^bb0(%in: f32, %out: f32, %out_366: i64): | |
%839 = linalg.index 2 : index | |
%840 = arith.index_cast %839 : index to i64 | |
%841 = arith.maximumf %in, %out : f32 | |
%842 = arith.cmpf ogt, %in, %out : f32 | |
%843 = arith.select %842, %840, %out_366 : i64 | |
linalg.yield %841, %843 : f32, i64 | |
} -> (tensor<12x8xf32>, tensor<12x8xi64>) | |
// Softmax over the last axis: subtract the row maximum, exponentiate, sum, divide.
%expanded_257 = tensor.expand_shape %275#0 [[0], [1, 2]] output_shape [12, 8, 1] : tensor<12x8xf32> into tensor<12x8x1xf32> | |
%276 = linalg.generic {indexing_maps = [#map8, #map9, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_256, %expanded_257 : tensor<12x8x8xf32>, tensor<12x8x1xf32>) outs(%61 : tensor<12x8x8xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.subf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<12x8x8xf32> | |
%277 = linalg.generic {indexing_maps = [#map8, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%276 : tensor<12x8x8xf32>) outs(%61 : tensor<12x8x8xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = math.exp %in : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<12x8x8xf32> | |
// Sum of exponentials accumulated into %75 (presumably zero-initialised; it is defined outside this chunk).
%278 = linalg.generic {indexing_maps = [#map8, #map9], iterator_types = ["parallel", "parallel", "reduction"]} ins(%277 : tensor<12x8x8xf32>) outs(%75 : tensor<12x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %out : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<12x8x1xf32> | |
%279 = linalg.generic {indexing_maps = [#map8, #map9, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%277, %278 : tensor<12x8x8xf32>, tensor<12x8x1xf32>) outs(%61 : tensor<12x8x8xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.divf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<12x8x8xf32> | |
// Weight the 12x8x64 tensor %collapsed_254 by the normalised scores.
%280 = linalg.batch_matmul ins(%279, %collapsed_254 : tensor<12x8x8xf32>, tensor<12x8x64xf32>) outs(%79 : tensor<12x8x64xf32>) -> tensor<12x8x64xf32> | |
// Merge the 12 heads back into a 768-wide axis, apply a 768x768 projection with bias,
// and add the result to %expanded_248 (defined above this chunk) as a residual connection.
// 12x8x64 -> 1x12x8x64 -> (swap axes 1 and 2) 1x8x12x64 -> 1x8x768.
%expanded_258 = tensor.expand_shape %280 [[0, 1], [2], [3]] output_shape [1, 12, 8, 64] : tensor<12x8x64xf32> into tensor<1x12x8x64xf32> | |
%281 = linalg.generic {indexing_maps = [#map5, #map14], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_258 : tensor<1x12x8x64xf32>) outs(%81 : tensor<1x8x12x64xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x8x12x64xf32> | |
%collapsed_259 = tensor.collapse_shape %281 [[0], [1], [2, 3]] : tensor<1x8x12x64xf32> into tensor<1x8x768xf32> | |
// Element-for-element copy; #map7 reads index 0 on the unit leading axis, so no data is rearranged.
%282 = linalg.generic {indexing_maps = [#map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_259 : tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x8x768xf32> | |
// Output projection: weight %cst_174 copied to batch-1 form, batch_matmul, then bias %cst_54.
%283 = linalg.generic {indexing_maps = [#map12, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_174 : tensor<768x768xf32>) outs(%43 : tensor<1x768x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x768x768xf32> | |
%284 = linalg.batch_matmul ins(%282, %283 : tensor<1x8x768xf32>, tensor<1x768x768xf32>) outs(%45 : tensor<1x8x768xf32>) -> tensor<1x8x768xf32> | |
%285 = linalg.generic {indexing_maps = [#map11, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_54, %284 : tensor<768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
// Residual add.
%286 = linalg.generic {indexing_maps = [#map7, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_248, %285 : tensor<1x8x768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
// Row-wise normalisation of the 8x768 view of %286 over the 768-wide axis, followed by
// an elementwise scale (%cst_61) and shift (%cst_62).
// %cst_2 and %cst_3 are defined outside this chunk; presumably %cst_2 is the element
// count (768.0) and %cst_3 a small epsilon - confirm at their definitions.
%collapsed_260 = tensor.collapse_shape %286 [[0, 1], [2]] : tensor<1x8x768xf32> into tensor<8x768xf32> | |
// Row sum accumulated into %89, then divided by %cst_2.
%287 = linalg.generic {indexing_maps = [#map1, #map18], iterator_types = ["parallel", "reduction"]} ins(%collapsed_260 : tensor<8x768xf32>) outs(%89 : tensor<8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %out : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x1xf32> | |
%288 = linalg.generic {indexing_maps = [#map18, #map1], iterator_types = ["parallel", "parallel"]} ins(%287 : tensor<8x1xf32>) outs(%88 : tensor<8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.divf %in, %cst_2 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x1xf32> | |
// Broadcast the per-row value across all 768 columns and subtract it from the input.
%289 = linalg.generic {indexing_maps = [#map18, #map1], iterator_types = ["parallel", "parallel"]} ins(%288 : tensor<8x1xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<8x768xf32> | |
%290 = linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%collapsed_260, %289 : tensor<8x768xf32>, tensor<8x768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.subf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
// Square the centred values, sum per row, divide by %cst_2, add %cst_3, take rsqrt.
%291 = linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%290, %290 : tensor<8x768xf32>, tensor<8x768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
%292 = linalg.generic {indexing_maps = [#map1, #map18], iterator_types = ["parallel", "reduction"]} ins(%291 : tensor<8x768xf32>) outs(%89 : tensor<8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %out : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x1xf32> | |
%293 = linalg.generic {indexing_maps = [#map18, #map1], iterator_types = ["parallel", "parallel"]} ins(%292 : tensor<8x1xf32>) outs(%88 : tensor<8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.divf %in, %cst_2 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x1xf32> | |
%294 = linalg.generic {indexing_maps = [#map18, #map1], iterator_types = ["parallel", "parallel"]} ins(%293 : tensor<8x1xf32>) outs(%88 : tensor<8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %cst_3 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x1xf32> | |
%295 = linalg.generic {indexing_maps = [#map18, #map1], iterator_types = ["parallel", "parallel"]} ins(%294 : tensor<8x1xf32>) outs(%88 : tensor<8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = math.rsqrt %in : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x1xf32> | |
// Broadcast the per-row rsqrt and multiply it into the centred values.
%296 = linalg.generic {indexing_maps = [#map18, #map1], iterator_types = ["parallel", "parallel"]} ins(%295 : tensor<8x1xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<8x768xf32> | |
%297 = linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%290, %296 : tensor<8x768xf32>, tensor<8x768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
// Per-column scale by %cst_61, then per-column shift by %cst_62 (#map19 indexes by d1 only).
%298 = linalg.generic {indexing_maps = [#map1, #map19, #map1], iterator_types = ["parallel", "parallel"]} ins(%297, %cst_61 : tensor<8x768xf32>, tensor<768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
%299 = linalg.generic {indexing_maps = [#map1, #map19, #map1], iterator_types = ["parallel", "parallel"]} ins(%298, %cst_62 : tensor<8x768xf32>, tensor<768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
// Two-layer feed-forward block on %299: 768 -> 3072 with bias, elementwise select against
// %cst_0, 3072 -> 768 with bias, then a residual add with %collapsed_260.
// Transpose the 3072x768 weight %cst_57 to 768x3072 (#map20 swaps the two axes) for linalg.matmul.
%300 = linalg.generic {indexing_maps = [#map1, #map20], iterator_types = ["parallel", "parallel"]} ins(%cst_57 : tensor<3072x768xf32>) outs(%103 : tensor<768x3072xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<768x3072xf32> | |
%301 = linalg.matmul ins(%299, %300 : tensor<8x768xf32>, tensor<768x3072xf32>) outs(%106 : tensor<8x3072xf32>) -> tensor<8x3072xf32> | |
%302 = linalg.generic {indexing_maps = [#map1, #map19, #map1], iterator_types = ["parallel", "parallel"]} ins(%301, %cst_58 : tensor<8x3072xf32>, tensor<3072xf32>) outs(%105 : tensor<8x3072xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x3072xf32> | |
// Keep the value when it is greater than %cst_0, otherwise yield %cst_0. The compare is
// unordered (ugt), so a NaN input is passed through unchanged.
// NOTE(review): this is a ReLU if %cst_0 is 0.0 - %cst_0 is defined outside this chunk; confirm.
%303 = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%302 : tensor<8x3072xf32>) outs(%105 : tensor<8x3072xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.cmpf ugt, %in, %cst_0 : f32 | |
%840 = arith.select %839, %in, %cst_0 : f32 | |
linalg.yield %840 : f32 | |
} -> tensor<8x3072xf32> | |
// Transpose the 768x3072 weight %cst_59 to 3072x768, matmul, add bias %cst_60.
%304 = linalg.generic {indexing_maps = [#map1, #map20], iterator_types = ["parallel", "parallel"]} ins(%cst_59 : tensor<768x3072xf32>) outs(%110 : tensor<3072x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<3072x768xf32> | |
%305 = linalg.matmul ins(%303, %304 : tensor<8x3072xf32>, tensor<3072x768xf32>) outs(%112 : tensor<8x768xf32>) -> tensor<8x768xf32> | |
%306 = linalg.generic {indexing_maps = [#map1, #map19, #map1], iterator_types = ["parallel", "parallel"]} ins(%305, %cst_60 : tensor<8x768xf32>, tensor<768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
// Residual add with the tensor that fed the preceding normalisation (%collapsed_260).
%307 = linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%collapsed_260, %306 : tensor<8x768xf32>, tensor<8x768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
// Normalisation of %307 (reshaped to 1x8x768) over the 768-wide axis, followed by an
// elementwise scale (%cst_67) and shift (%cst_68). Same arithmetic sequence as a
// mean / variance / rsqrt layer normalisation, expressed on 3-D tensors.
// %cst_2 and %cst_3 are defined outside this chunk; presumably the element count (768.0)
// and a small epsilon respectively - confirm at their definitions.
%expanded_261 = tensor.expand_shape %307 [[0, 1], [2]] output_shape [1, 8, 768] : tensor<8x768xf32> into tensor<1x8x768xf32> | |
// Sum over the last axis accumulated into %28, then divided by %cst_2.
%308 = linalg.generic {indexing_maps = [#map8, #map9], iterator_types = ["parallel", "parallel", "reduction"]} ins(%expanded_261 : tensor<1x8x768xf32>) outs(%28 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %out : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
%309 = linalg.generic {indexing_maps = [#map10, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%308 : tensor<1x8x1xf32>) outs(%27 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.divf %in, %cst_2 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
// Broadcast along the 768-wide axis and subtract from the input.
%310 = linalg.generic {indexing_maps = [#map9, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%309 : tensor<1x8x1xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x8x768xf32> | |
%311 = linalg.generic {indexing_maps = [#map7, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_261, %310 : tensor<1x8x768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.subf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
// Square the centred values, sum, divide by %cst_2, add %cst_3, take rsqrt.
%312 = linalg.generic {indexing_maps = [#map7, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%311, %311 : tensor<1x8x768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%313 = linalg.generic {indexing_maps = [#map8, #map9], iterator_types = ["parallel", "parallel", "reduction"]} ins(%312 : tensor<1x8x768xf32>) outs(%28 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %out : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
%314 = linalg.generic {indexing_maps = [#map10, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%313 : tensor<1x8x1xf32>) outs(%27 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.divf %in, %cst_2 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
%315 = linalg.generic {indexing_maps = [#map10, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%314 : tensor<1x8x1xf32>) outs(%27 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %cst_3 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
%316 = linalg.generic {indexing_maps = [#map10, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%315 : tensor<1x8x1xf32>) outs(%27 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = math.rsqrt %in : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
// Broadcast the rsqrt along the 768-wide axis and multiply it into the centred values.
%317 = linalg.generic {indexing_maps = [#map9, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%316 : tensor<1x8x1xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x8x768xf32> | |
%318 = linalg.generic {indexing_maps = [#map7, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%311, %317 : tensor<1x8x768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
// Scale by %cst_67, then shift by %cst_68; both are indexed by the last axis only (#map11).
%319 = linalg.generic {indexing_maps = [#map7, #map11, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%318, %cst_67 : tensor<1x8x768xf32>, tensor<768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%320 = linalg.generic {indexing_maps = [#map7, #map11, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%319, %cst_68 : tensor<1x8x768xf32>, tensor<768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
// Three 768x768 projections of the normalised activations %320, each with a bias add,
// followed by a 12 x 64 head split and a swap of axes 1 and 2.
// Judging by how the results are consumed further down, %334 (from the scaled projection
// %325) acts as the queries, %329 as the keys and %333 as the values of an attention layer.
// Element-for-element copy of %320; #map7 reads index 0 on the unit leading axis.
%321 = linalg.generic {indexing_maps = [#map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%320 : tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x8x768xf32> | |
// First projection: weight %cst_175, bias %cst_65, then multiplied by the scalar %cst_6.
%322 = linalg.generic {indexing_maps = [#map12, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_175 : tensor<768x768xf32>) outs(%43 : tensor<1x768x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x768x768xf32> | |
%323 = linalg.batch_matmul ins(%321, %322 : tensor<1x8x768xf32>, tensor<1x768x768xf32>) outs(%45 : tensor<1x8x768xf32>) -> tensor<1x8x768xf32> | |
%324 = linalg.generic {indexing_maps = [#map11, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_65, %323 : tensor<768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
// NOTE(review): %cst_6 is a rank-0 tensor defined outside this chunk; presumably the
// attention scaling factor (1/sqrt(64)) - confirm at its definition.
%325 = linalg.generic {indexing_maps = [#map7, #map13, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%324, %cst_6 : tensor<1x8x768xf32>, tensor<f32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
// Second projection: weight %cst_176, bias %cst_63, then head split and axis swap.
%326 = linalg.generic {indexing_maps = [#map12, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_176 : tensor<768x768xf32>) outs(%43 : tensor<1x768x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x768x768xf32> | |
%327 = linalg.batch_matmul ins(%321, %326 : tensor<1x8x768xf32>, tensor<1x768x768xf32>) outs(%45 : tensor<1x8x768xf32>) -> tensor<1x8x768xf32> | |
%328 = linalg.generic {indexing_maps = [#map11, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_63, %327 : tensor<768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%expanded_262 = tensor.expand_shape %328 [[0], [1], [2, 3]] output_shape [1, 8, 12, 64] : tensor<1x8x768xf32> into tensor<1x8x12x64xf32> | |
%329 = linalg.generic {indexing_maps = [#map5, #map14], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_262 : tensor<1x8x12x64xf32>) outs(%52 : tensor<1x12x8x64xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x12x8x64xf32> | |
// Third projection: weight %cst_177, bias %cst_64, then head split and axis swap.
%330 = linalg.generic {indexing_maps = [#map12, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_177 : tensor<768x768xf32>) outs(%43 : tensor<1x768x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x768x768xf32> | |
%331 = linalg.batch_matmul ins(%321, %330 : tensor<1x8x768xf32>, tensor<1x768x768xf32>) outs(%45 : tensor<1x8x768xf32>) -> tensor<1x8x768xf32> | |
%332 = linalg.generic {indexing_maps = [#map11, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_64, %331 : tensor<768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%expanded_263 = tensor.expand_shape %332 [[0], [1], [2, 3]] output_shape [1, 8, 12, 64] : tensor<1x8x768xf32> into tensor<1x8x12x64xf32> | |
%333 = linalg.generic {indexing_maps = [#map5, #map14], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_263 : tensor<1x8x12x64xf32>) outs(%52 : tensor<1x12x8x64xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x12x8x64xf32> | |
// Head split and axis swap of the scaled first projection %325.
%expanded_264 = tensor.expand_shape %325 [[0], [1], [2, 3]] output_shape [1, 8, 12, 64] : tensor<1x8x768xf32> into tensor<1x8x12x64xf32> | |
%334 = linalg.generic {indexing_maps = [#map5, #map14], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_264 : tensor<1x8x12x64xf32>) outs(%52 : tensor<1x12x8x64xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x12x8x64xf32> | |
// Attention core for 12 heads: scores = %334 x transpose(%329), plus the additive
// tensor %13, clamped from below by %cst_8, softmax over the last axis, then
// multiplied by %333. %13 and %cst_8 are defined outside this chunk.
// Fold the unit batch axis into the head axis: 1x12x8x64 -> 12x8x64.
%collapsed_265 = tensor.collapse_shape %334 [[0, 1], [2], [3]] : tensor<1x12x8x64xf32> into tensor<12x8x64xf32> | |
%collapsed_266 = tensor.collapse_shape %329 [[0, 1], [2], [3]] : tensor<1x12x8x64xf32> into tensor<12x8x64xf32> | |
%collapsed_267 = tensor.collapse_shape %333 [[0, 1], [2], [3]] : tensor<1x12x8x64xf32> into tensor<12x8x64xf32> | |
// Transpose the last two axes (12x8x64 -> 12x64x8) so the matmul below contracts over the 64-wide axis.
%335 = linalg.generic {indexing_maps = [#map8, #map15], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_266 : tensor<12x8x64xf32>) outs(%59 : tensor<12x64x8xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<12x64x8xf32> | |
%336 = linalg.batch_matmul ins(%collapsed_265, %335 : tensor<12x8x64xf32>, tensor<12x64x8xf32>) outs(%62 : tensor<12x8x8xf32>) -> tensor<12x8x8xf32> | |
%expanded_268 = tensor.expand_shape %336 [[0, 1], [2], [3]] output_shape [1, 12, 8, 8] : tensor<12x8x8xf32> into tensor<1x12x8x8xf32> | |
// Add the 1x1x8x8 tensor %13, broadcast over the 12-wide axis (#map4 pins d0 and d1 to 0).
// NOTE(review): %13 looks like an additive attention mask - confirm at its definition.
%337 = linalg.generic {indexing_maps = [#map16, #map4, #map5], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_268, %13 : tensor<1x12x8x8xf32>, tensor<1x1x8x8xf32>) outs(%64 : tensor<1x12x8x8xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x12x8x8xf32> | |
// Elementwise lower clamp: keep the value when it is (ordered) greater than the scalar %cst_8,
// otherwise yield %cst_8. Because the compare is ordered, a NaN input also yields %cst_8.
%338 = linalg.generic {indexing_maps = [#map16, #map3, #map5], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%337, %cst_8 : tensor<1x12x8x8xf32>, tensor<f32>) outs(%64 : tensor<1x12x8x8xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.cmpf ogt, %in, %in_366 : f32 | |
%840 = arith.select %839, %in, %in_366 : f32 | |
linalg.yield %840 : f32 | |
} -> tensor<1x12x8x8xf32> | |
%collapsed_269 = tensor.collapse_shape %338 [[0, 1], [2], [3]] : tensor<1x12x8x8xf32> into tensor<12x8x8xf32> | |
// Reduction over the last axis yielding two results: the running maximum (#0) and the
// index along d2 at which a strictly larger value was last seen (#1).
// Only #0 is referenced in the lines of this block.
%339:2 = linalg.generic {indexing_maps = [#map8, #map17, #map17], iterator_types = ["parallel", "parallel", "reduction"]} ins(%collapsed_269 : tensor<12x8x8xf32>) outs(%70, %68 : tensor<12x8xf32>, tensor<12x8xi64>) { | |
^bb0(%in: f32, %out: f32, %out_366: i64): | |
%839 = linalg.index 2 : index | |
%840 = arith.index_cast %839 : index to i64 | |
%841 = arith.maximumf %in, %out : f32 | |
%842 = arith.cmpf ogt, %in, %out : f32 | |
%843 = arith.select %842, %840, %out_366 : i64 | |
linalg.yield %841, %843 : f32, i64 | |
} -> (tensor<12x8xf32>, tensor<12x8xi64>) | |
// Softmax over the last axis: subtract the row maximum, exponentiate, sum, divide.
%expanded_270 = tensor.expand_shape %339#0 [[0], [1, 2]] output_shape [12, 8, 1] : tensor<12x8xf32> into tensor<12x8x1xf32> | |
%340 = linalg.generic {indexing_maps = [#map8, #map9, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_269, %expanded_270 : tensor<12x8x8xf32>, tensor<12x8x1xf32>) outs(%61 : tensor<12x8x8xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.subf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<12x8x8xf32> | |
%341 = linalg.generic {indexing_maps = [#map8, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%340 : tensor<12x8x8xf32>) outs(%61 : tensor<12x8x8xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = math.exp %in : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<12x8x8xf32> | |
// Sum of exponentials accumulated into %75 (presumably zero-initialised; it is defined outside this chunk).
%342 = linalg.generic {indexing_maps = [#map8, #map9], iterator_types = ["parallel", "parallel", "reduction"]} ins(%341 : tensor<12x8x8xf32>) outs(%75 : tensor<12x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %out : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<12x8x1xf32> | |
%343 = linalg.generic {indexing_maps = [#map8, #map9, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%341, %342 : tensor<12x8x8xf32>, tensor<12x8x1xf32>) outs(%61 : tensor<12x8x8xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.divf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<12x8x8xf32> | |
// Weight the 12x8x64 tensor %collapsed_267 by the normalised scores.
%344 = linalg.batch_matmul ins(%343, %collapsed_267 : tensor<12x8x8xf32>, tensor<12x8x64xf32>) outs(%79 : tensor<12x8x64xf32>) -> tensor<12x8x64xf32> | |
// Merge the 12 heads back into a 768-wide axis, apply a 768x768 projection with bias,
// and add the result to %expanded_261 as a residual connection.
// 12x8x64 -> 1x12x8x64 -> (swap axes 1 and 2) 1x8x12x64 -> 1x8x768.
%expanded_271 = tensor.expand_shape %344 [[0, 1], [2], [3]] output_shape [1, 12, 8, 64] : tensor<12x8x64xf32> into tensor<1x12x8x64xf32> | |
%345 = linalg.generic {indexing_maps = [#map5, #map14], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_271 : tensor<1x12x8x64xf32>) outs(%81 : tensor<1x8x12x64xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x8x12x64xf32> | |
%collapsed_272 = tensor.collapse_shape %345 [[0], [1], [2, 3]] : tensor<1x8x12x64xf32> into tensor<1x8x768xf32> | |
// Element-for-element copy; #map7 reads index 0 on the unit leading axis, so no data is rearranged.
%346 = linalg.generic {indexing_maps = [#map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_272 : tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x8x768xf32> | |
// Output projection: weight %cst_178 copied to batch-1 form, batch_matmul, then bias %cst_66.
%347 = linalg.generic {indexing_maps = [#map12, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_178 : tensor<768x768xf32>) outs(%43 : tensor<1x768x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x768x768xf32> | |
%348 = linalg.batch_matmul ins(%346, %347 : tensor<1x8x768xf32>, tensor<1x768x768xf32>) outs(%45 : tensor<1x8x768xf32>) -> tensor<1x8x768xf32> | |
%349 = linalg.generic {indexing_maps = [#map11, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_66, %348 : tensor<768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
// Residual add.
%350 = linalg.generic {indexing_maps = [#map7, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_261, %349 : tensor<1x8x768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
// Row-wise normalisation of the 8x768 view of %350 over the 768-wide axis, followed by
// an elementwise scale (%cst_73) and shift (%cst_74).
// %cst_2 and %cst_3 are defined outside this chunk; presumably %cst_2 is the element
// count (768.0) and %cst_3 a small epsilon - confirm at their definitions.
%collapsed_273 = tensor.collapse_shape %350 [[0, 1], [2]] : tensor<1x8x768xf32> into tensor<8x768xf32> | |
// Row sum accumulated into %89, then divided by %cst_2.
%351 = linalg.generic {indexing_maps = [#map1, #map18], iterator_types = ["parallel", "reduction"]} ins(%collapsed_273 : tensor<8x768xf32>) outs(%89 : tensor<8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %out : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x1xf32> | |
%352 = linalg.generic {indexing_maps = [#map18, #map1], iterator_types = ["parallel", "parallel"]} ins(%351 : tensor<8x1xf32>) outs(%88 : tensor<8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.divf %in, %cst_2 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x1xf32> | |
// Broadcast the per-row value across all 768 columns and subtract it from the input.
%353 = linalg.generic {indexing_maps = [#map18, #map1], iterator_types = ["parallel", "parallel"]} ins(%352 : tensor<8x1xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<8x768xf32> | |
%354 = linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%collapsed_273, %353 : tensor<8x768xf32>, tensor<8x768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.subf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
// Square the centred values, sum per row, divide by %cst_2, add %cst_3, take rsqrt.
%355 = linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%354, %354 : tensor<8x768xf32>, tensor<8x768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
%356 = linalg.generic {indexing_maps = [#map1, #map18], iterator_types = ["parallel", "reduction"]} ins(%355 : tensor<8x768xf32>) outs(%89 : tensor<8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %out : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x1xf32> | |
%357 = linalg.generic {indexing_maps = [#map18, #map1], iterator_types = ["parallel", "parallel"]} ins(%356 : tensor<8x1xf32>) outs(%88 : tensor<8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.divf %in, %cst_2 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x1xf32> | |
%358 = linalg.generic {indexing_maps = [#map18, #map1], iterator_types = ["parallel", "parallel"]} ins(%357 : tensor<8x1xf32>) outs(%88 : tensor<8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %cst_3 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x1xf32> | |
%359 = linalg.generic {indexing_maps = [#map18, #map1], iterator_types = ["parallel", "parallel"]} ins(%358 : tensor<8x1xf32>) outs(%88 : tensor<8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = math.rsqrt %in : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x1xf32> | |
// Broadcast the per-row rsqrt and multiply it into the centred values.
%360 = linalg.generic {indexing_maps = [#map18, #map1], iterator_types = ["parallel", "parallel"]} ins(%359 : tensor<8x1xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<8x768xf32> | |
%361 = linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%354, %360 : tensor<8x768xf32>, tensor<8x768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
// Per-column scale by %cst_73, then per-column shift by %cst_74 (#map19 indexes by d1 only).
%362 = linalg.generic {indexing_maps = [#map1, #map19, #map1], iterator_types = ["parallel", "parallel"]} ins(%361, %cst_73 : tensor<8x768xf32>, tensor<768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
%363 = linalg.generic {indexing_maps = [#map1, #map19, #map1], iterator_types = ["parallel", "parallel"]} ins(%362, %cst_74 : tensor<8x768xf32>, tensor<768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
// Two-layer feed-forward block on %363: 768 -> 3072 with bias, elementwise select against
// %cst_0, 3072 -> 768 with bias, then a residual add with %collapsed_273 and a reshape
// of the result back to 1x8x768.
// Transpose the 3072x768 weight %cst_69 to 768x3072 (#map20 swaps the two axes) for linalg.matmul.
%364 = linalg.generic {indexing_maps = [#map1, #map20], iterator_types = ["parallel", "parallel"]} ins(%cst_69 : tensor<3072x768xf32>) outs(%103 : tensor<768x3072xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<768x3072xf32> | |
%365 = linalg.matmul ins(%363, %364 : tensor<8x768xf32>, tensor<768x3072xf32>) outs(%106 : tensor<8x3072xf32>) -> tensor<8x3072xf32> | |
%366 = linalg.generic {indexing_maps = [#map1, #map19, #map1], iterator_types = ["parallel", "parallel"]} ins(%365, %cst_70 : tensor<8x3072xf32>, tensor<3072xf32>) outs(%105 : tensor<8x3072xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x3072xf32> | |
// Keep the value when it is greater than %cst_0, otherwise yield %cst_0. The compare is
// unordered (ugt), so a NaN input is passed through unchanged.
// NOTE(review): this is a ReLU if %cst_0 is 0.0 - %cst_0 is defined outside this chunk; confirm.
%367 = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%366 : tensor<8x3072xf32>) outs(%105 : tensor<8x3072xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.cmpf ugt, %in, %cst_0 : f32 | |
%840 = arith.select %839, %in, %cst_0 : f32 | |
linalg.yield %840 : f32 | |
} -> tensor<8x3072xf32> | |
// Transpose the 768x3072 weight %cst_71 to 3072x768, matmul, add bias %cst_72.
%368 = linalg.generic {indexing_maps = [#map1, #map20], iterator_types = ["parallel", "parallel"]} ins(%cst_71 : tensor<768x3072xf32>) outs(%110 : tensor<3072x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<3072x768xf32> | |
%369 = linalg.matmul ins(%367, %368 : tensor<8x3072xf32>, tensor<3072x768xf32>) outs(%112 : tensor<8x768xf32>) -> tensor<8x768xf32> | |
%370 = linalg.generic {indexing_maps = [#map1, #map19, #map1], iterator_types = ["parallel", "parallel"]} ins(%369, %cst_72 : tensor<8x768xf32>, tensor<768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
// Residual add with the tensor that fed the preceding normalisation (%collapsed_273).
%371 = linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%collapsed_273, %370 : tensor<8x768xf32>, tensor<8x768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
// Restore the unit batch axis: 8x768 -> 1x8x768.
%expanded_274 = tensor.expand_shape %371 [[0, 1], [2]] output_shape [1, 8, 768] : tensor<8x768xf32> into tensor<1x8x768xf32> | |
%372 = linalg.generic {indexing_maps = [#map8, #map9], iterator_types = ["parallel", "parallel", "reduction"]} ins(%expanded_274 : tensor<1x8x768xf32>) outs(%28 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %out : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
%373 = linalg.generic {indexing_maps = [#map10, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%372 : tensor<1x8x1xf32>) outs(%27 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.divf %in, %cst_2 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
%374 = linalg.generic {indexing_maps = [#map9, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%373 : tensor<1x8x1xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x8x768xf32> | |
%375 = linalg.generic {indexing_maps = [#map7, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_274, %374 : tensor<1x8x768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.subf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%376 = linalg.generic {indexing_maps = [#map7, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%375, %375 : tensor<1x8x768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%377 = linalg.generic {indexing_maps = [#map8, #map9], iterator_types = ["parallel", "parallel", "reduction"]} ins(%376 : tensor<1x8x768xf32>) outs(%28 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %out : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
%378 = linalg.generic {indexing_maps = [#map10, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%377 : tensor<1x8x1xf32>) outs(%27 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.divf %in, %cst_2 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
%379 = linalg.generic {indexing_maps = [#map10, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%378 : tensor<1x8x1xf32>) outs(%27 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %cst_3 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
%380 = linalg.generic {indexing_maps = [#map10, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%379 : tensor<1x8x1xf32>) outs(%27 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = math.rsqrt %in : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
%381 = linalg.generic {indexing_maps = [#map9, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%380 : tensor<1x8x1xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x8x768xf32> | |
%382 = linalg.generic {indexing_maps = [#map7, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%375, %381 : tensor<1x8x768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%383 = linalg.generic {indexing_maps = [#map7, #map11, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%382, %cst_79 : tensor<1x8x768xf32>, tensor<768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%384 = linalg.generic {indexing_maps = [#map7, #map11, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%383, %cst_80 : tensor<1x8x768xf32>, tensor<768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%385 = linalg.generic {indexing_maps = [#map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%384 : tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x8x768xf32> | |
%386 = linalg.generic {indexing_maps = [#map12, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_179 : tensor<768x768xf32>) outs(%43 : tensor<1x768x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x768x768xf32> | |
%387 = linalg.batch_matmul ins(%385, %386 : tensor<1x8x768xf32>, tensor<1x768x768xf32>) outs(%45 : tensor<1x8x768xf32>) -> tensor<1x8x768xf32> | |
%388 = linalg.generic {indexing_maps = [#map11, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_77, %387 : tensor<768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%389 = linalg.generic {indexing_maps = [#map7, #map13, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%388, %cst_6 : tensor<1x8x768xf32>, tensor<f32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%390 = linalg.generic {indexing_maps = [#map12, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_180 : tensor<768x768xf32>) outs(%43 : tensor<1x768x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x768x768xf32> | |
%391 = linalg.batch_matmul ins(%385, %390 : tensor<1x8x768xf32>, tensor<1x768x768xf32>) outs(%45 : tensor<1x8x768xf32>) -> tensor<1x8x768xf32> | |
%392 = linalg.generic {indexing_maps = [#map11, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_75, %391 : tensor<768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%expanded_275 = tensor.expand_shape %392 [[0], [1], [2, 3]] output_shape [1, 8, 12, 64] : tensor<1x8x768xf32> into tensor<1x8x12x64xf32> | |
%393 = linalg.generic {indexing_maps = [#map5, #map14], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_275 : tensor<1x8x12x64xf32>) outs(%52 : tensor<1x12x8x64xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x12x8x64xf32> | |
%394 = linalg.generic {indexing_maps = [#map12, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_181 : tensor<768x768xf32>) outs(%43 : tensor<1x768x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x768x768xf32> | |
%395 = linalg.batch_matmul ins(%385, %394 : tensor<1x8x768xf32>, tensor<1x768x768xf32>) outs(%45 : tensor<1x8x768xf32>) -> tensor<1x8x768xf32> | |
%396 = linalg.generic {indexing_maps = [#map11, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_76, %395 : tensor<768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%expanded_276 = tensor.expand_shape %396 [[0], [1], [2, 3]] output_shape [1, 8, 12, 64] : tensor<1x8x768xf32> into tensor<1x8x12x64xf32> | |
%397 = linalg.generic {indexing_maps = [#map5, #map14], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_276 : tensor<1x8x12x64xf32>) outs(%52 : tensor<1x12x8x64xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x12x8x64xf32> | |
%expanded_277 = tensor.expand_shape %389 [[0], [1], [2, 3]] output_shape [1, 8, 12, 64] : tensor<1x8x768xf32> into tensor<1x8x12x64xf32> | |
%398 = linalg.generic {indexing_maps = [#map5, #map14], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_277 : tensor<1x8x12x64xf32>) outs(%52 : tensor<1x12x8x64xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x12x8x64xf32> | |
%collapsed_278 = tensor.collapse_shape %398 [[0, 1], [2], [3]] : tensor<1x12x8x64xf32> into tensor<12x8x64xf32> | |
%collapsed_279 = tensor.collapse_shape %393 [[0, 1], [2], [3]] : tensor<1x12x8x64xf32> into tensor<12x8x64xf32> | |
%collapsed_280 = tensor.collapse_shape %397 [[0, 1], [2], [3]] : tensor<1x12x8x64xf32> into tensor<12x8x64xf32> | |
%399 = linalg.generic {indexing_maps = [#map8, #map15], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_279 : tensor<12x8x64xf32>) outs(%59 : tensor<12x64x8xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<12x64x8xf32> | |
%400 = linalg.batch_matmul ins(%collapsed_278, %399 : tensor<12x8x64xf32>, tensor<12x64x8xf32>) outs(%62 : tensor<12x8x8xf32>) -> tensor<12x8x8xf32> | |
%expanded_281 = tensor.expand_shape %400 [[0, 1], [2], [3]] output_shape [1, 12, 8, 8] : tensor<12x8x8xf32> into tensor<1x12x8x8xf32> | |
%401 = linalg.generic {indexing_maps = [#map16, #map4, #map5], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_281, %13 : tensor<1x12x8x8xf32>, tensor<1x1x8x8xf32>) outs(%64 : tensor<1x12x8x8xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x12x8x8xf32> | |
%402 = linalg.generic {indexing_maps = [#map16, #map3, #map5], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%401, %cst_8 : tensor<1x12x8x8xf32>, tensor<f32>) outs(%64 : tensor<1x12x8x8xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.cmpf ogt, %in, %in_366 : f32 | |
%840 = arith.select %839, %in, %in_366 : f32 | |
linalg.yield %840 : f32 | |
} -> tensor<1x12x8x8xf32> | |
%collapsed_282 = tensor.collapse_shape %402 [[0, 1], [2], [3]] : tensor<1x12x8x8xf32> into tensor<12x8x8xf32> | |
%403:2 = linalg.generic {indexing_maps = [#map8, #map17, #map17], iterator_types = ["parallel", "parallel", "reduction"]} ins(%collapsed_282 : tensor<12x8x8xf32>) outs(%70, %68 : tensor<12x8xf32>, tensor<12x8xi64>) { | |
^bb0(%in: f32, %out: f32, %out_366: i64): | |
%839 = linalg.index 2 : index | |
%840 = arith.index_cast %839 : index to i64 | |
%841 = arith.maximumf %in, %out : f32 | |
%842 = arith.cmpf ogt, %in, %out : f32 | |
%843 = arith.select %842, %840, %out_366 : i64 | |
linalg.yield %841, %843 : f32, i64 | |
} -> (tensor<12x8xf32>, tensor<12x8xi64>) | |
%expanded_283 = tensor.expand_shape %403#0 [[0], [1, 2]] output_shape [12, 8, 1] : tensor<12x8xf32> into tensor<12x8x1xf32> | |
%404 = linalg.generic {indexing_maps = [#map8, #map9, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_282, %expanded_283 : tensor<12x8x8xf32>, tensor<12x8x1xf32>) outs(%61 : tensor<12x8x8xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.subf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<12x8x8xf32> | |
%405 = linalg.generic {indexing_maps = [#map8, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%404 : tensor<12x8x8xf32>) outs(%61 : tensor<12x8x8xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = math.exp %in : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<12x8x8xf32> | |
%406 = linalg.generic {indexing_maps = [#map8, #map9], iterator_types = ["parallel", "parallel", "reduction"]} ins(%405 : tensor<12x8x8xf32>) outs(%75 : tensor<12x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %out : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<12x8x1xf32> | |
%407 = linalg.generic {indexing_maps = [#map8, #map9, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%405, %406 : tensor<12x8x8xf32>, tensor<12x8x1xf32>) outs(%61 : tensor<12x8x8xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.divf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<12x8x8xf32> | |
%408 = linalg.batch_matmul ins(%407, %collapsed_280 : tensor<12x8x8xf32>, tensor<12x8x64xf32>) outs(%79 : tensor<12x8x64xf32>) -> tensor<12x8x64xf32> | |
%expanded_284 = tensor.expand_shape %408 [[0, 1], [2], [3]] output_shape [1, 12, 8, 64] : tensor<12x8x64xf32> into tensor<1x12x8x64xf32> | |
%409 = linalg.generic {indexing_maps = [#map5, #map14], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_284 : tensor<1x12x8x64xf32>) outs(%81 : tensor<1x8x12x64xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x8x12x64xf32> | |
%collapsed_285 = tensor.collapse_shape %409 [[0], [1], [2, 3]] : tensor<1x8x12x64xf32> into tensor<1x8x768xf32> | |
%410 = linalg.generic {indexing_maps = [#map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_285 : tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x8x768xf32> | |
%411 = linalg.generic {indexing_maps = [#map12, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_182 : tensor<768x768xf32>) outs(%43 : tensor<1x768x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x768x768xf32> | |
%412 = linalg.batch_matmul ins(%410, %411 : tensor<1x8x768xf32>, tensor<1x768x768xf32>) outs(%45 : tensor<1x8x768xf32>) -> tensor<1x8x768xf32> | |
%413 = linalg.generic {indexing_maps = [#map11, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_78, %412 : tensor<768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%414 = linalg.generic {indexing_maps = [#map7, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_274, %413 : tensor<1x8x768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%collapsed_286 = tensor.collapse_shape %414 [[0, 1], [2]] : tensor<1x8x768xf32> into tensor<8x768xf32> | |
%415 = linalg.generic {indexing_maps = [#map1, #map18], iterator_types = ["parallel", "reduction"]} ins(%collapsed_286 : tensor<8x768xf32>) outs(%89 : tensor<8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %out : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x1xf32> | |
%416 = linalg.generic {indexing_maps = [#map18, #map1], iterator_types = ["parallel", "parallel"]} ins(%415 : tensor<8x1xf32>) outs(%88 : tensor<8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.divf %in, %cst_2 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x1xf32> | |
%417 = linalg.generic {indexing_maps = [#map18, #map1], iterator_types = ["parallel", "parallel"]} ins(%416 : tensor<8x1xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<8x768xf32> | |
%418 = linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%collapsed_286, %417 : tensor<8x768xf32>, tensor<8x768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.subf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
%419 = linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%418, %418 : tensor<8x768xf32>, tensor<8x768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
%420 = linalg.generic {indexing_maps = [#map1, #map18], iterator_types = ["parallel", "reduction"]} ins(%419 : tensor<8x768xf32>) outs(%89 : tensor<8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %out : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x1xf32> | |
%421 = linalg.generic {indexing_maps = [#map18, #map1], iterator_types = ["parallel", "parallel"]} ins(%420 : tensor<8x1xf32>) outs(%88 : tensor<8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.divf %in, %cst_2 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x1xf32> | |
%422 = linalg.generic {indexing_maps = [#map18, #map1], iterator_types = ["parallel", "parallel"]} ins(%421 : tensor<8x1xf32>) outs(%88 : tensor<8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %cst_3 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x1xf32> | |
%423 = linalg.generic {indexing_maps = [#map18, #map1], iterator_types = ["parallel", "parallel"]} ins(%422 : tensor<8x1xf32>) outs(%88 : tensor<8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = math.rsqrt %in : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x1xf32> | |
%424 = linalg.generic {indexing_maps = [#map18, #map1], iterator_types = ["parallel", "parallel"]} ins(%423 : tensor<8x1xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<8x768xf32> | |
%425 = linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%418, %424 : tensor<8x768xf32>, tensor<8x768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
%426 = linalg.generic {indexing_maps = [#map1, #map19, #map1], iterator_types = ["parallel", "parallel"]} ins(%425, %cst_85 : tensor<8x768xf32>, tensor<768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
%427 = linalg.generic {indexing_maps = [#map1, #map19, #map1], iterator_types = ["parallel", "parallel"]} ins(%426, %cst_86 : tensor<8x768xf32>, tensor<768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
%428 = linalg.generic {indexing_maps = [#map1, #map20], iterator_types = ["parallel", "parallel"]} ins(%cst_81 : tensor<3072x768xf32>) outs(%103 : tensor<768x3072xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<768x3072xf32> | |
%429 = linalg.matmul ins(%427, %428 : tensor<8x768xf32>, tensor<768x3072xf32>) outs(%106 : tensor<8x3072xf32>) -> tensor<8x3072xf32> | |
%430 = linalg.generic {indexing_maps = [#map1, #map19, #map1], iterator_types = ["parallel", "parallel"]} ins(%429, %cst_82 : tensor<8x3072xf32>, tensor<3072xf32>) outs(%105 : tensor<8x3072xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x3072xf32> | |
%431 = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%430 : tensor<8x3072xf32>) outs(%105 : tensor<8x3072xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.cmpf ugt, %in, %cst_0 : f32 | |
%840 = arith.select %839, %in, %cst_0 : f32 | |
linalg.yield %840 : f32 | |
} -> tensor<8x3072xf32> | |
%432 = linalg.generic {indexing_maps = [#map1, #map20], iterator_types = ["parallel", "parallel"]} ins(%cst_83 : tensor<768x3072xf32>) outs(%110 : tensor<3072x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<3072x768xf32> | |
%433 = linalg.matmul ins(%431, %432 : tensor<8x3072xf32>, tensor<3072x768xf32>) outs(%112 : tensor<8x768xf32>) -> tensor<8x768xf32> | |
%434 = linalg.generic {indexing_maps = [#map1, #map19, #map1], iterator_types = ["parallel", "parallel"]} ins(%433, %cst_84 : tensor<8x768xf32>, tensor<768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
%435 = linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%collapsed_286, %434 : tensor<8x768xf32>, tensor<8x768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
%expanded_287 = tensor.expand_shape %435 [[0, 1], [2]] output_shape [1, 8, 768] : tensor<8x768xf32> into tensor<1x8x768xf32> | |
%436 = linalg.generic {indexing_maps = [#map8, #map9], iterator_types = ["parallel", "parallel", "reduction"]} ins(%expanded_287 : tensor<1x8x768xf32>) outs(%28 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %out : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
%437 = linalg.generic {indexing_maps = [#map10, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%436 : tensor<1x8x1xf32>) outs(%27 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.divf %in, %cst_2 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
%438 = linalg.generic {indexing_maps = [#map9, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%437 : tensor<1x8x1xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x8x768xf32> | |
%439 = linalg.generic {indexing_maps = [#map7, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_287, %438 : tensor<1x8x768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.subf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%440 = linalg.generic {indexing_maps = [#map7, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%439, %439 : tensor<1x8x768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%441 = linalg.generic {indexing_maps = [#map8, #map9], iterator_types = ["parallel", "parallel", "reduction"]} ins(%440 : tensor<1x8x768xf32>) outs(%28 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %out : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
%442 = linalg.generic {indexing_maps = [#map10, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%441 : tensor<1x8x1xf32>) outs(%27 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.divf %in, %cst_2 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
%443 = linalg.generic {indexing_maps = [#map10, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%442 : tensor<1x8x1xf32>) outs(%27 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %cst_3 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
%444 = linalg.generic {indexing_maps = [#map10, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%443 : tensor<1x8x1xf32>) outs(%27 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = math.rsqrt %in : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
%445 = linalg.generic {indexing_maps = [#map9, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%444 : tensor<1x8x1xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x8x768xf32> | |
%446 = linalg.generic {indexing_maps = [#map7, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%439, %445 : tensor<1x8x768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%447 = linalg.generic {indexing_maps = [#map7, #map11, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%446, %cst_91 : tensor<1x8x768xf32>, tensor<768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%448 = linalg.generic {indexing_maps = [#map7, #map11, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%447, %cst_92 : tensor<1x8x768xf32>, tensor<768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%449 = linalg.generic {indexing_maps = [#map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%448 : tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x8x768xf32> | |
%450 = linalg.generic {indexing_maps = [#map12, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_183 : tensor<768x768xf32>) outs(%43 : tensor<1x768x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x768x768xf32> | |
%451 = linalg.batch_matmul ins(%449, %450 : tensor<1x8x768xf32>, tensor<1x768x768xf32>) outs(%45 : tensor<1x8x768xf32>) -> tensor<1x8x768xf32> | |
%452 = linalg.generic {indexing_maps = [#map11, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_89, %451 : tensor<768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%453 = linalg.generic {indexing_maps = [#map7, #map13, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%452, %cst_6 : tensor<1x8x768xf32>, tensor<f32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%454 = linalg.generic {indexing_maps = [#map12, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_184 : tensor<768x768xf32>) outs(%43 : tensor<1x768x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x768x768xf32> | |
%455 = linalg.batch_matmul ins(%449, %454 : tensor<1x8x768xf32>, tensor<1x768x768xf32>) outs(%45 : tensor<1x8x768xf32>) -> tensor<1x8x768xf32> | |
%456 = linalg.generic {indexing_maps = [#map11, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_87, %455 : tensor<768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%expanded_288 = tensor.expand_shape %456 [[0], [1], [2, 3]] output_shape [1, 8, 12, 64] : tensor<1x8x768xf32> into tensor<1x8x12x64xf32> | |
%457 = linalg.generic {indexing_maps = [#map5, #map14], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_288 : tensor<1x8x12x64xf32>) outs(%52 : tensor<1x12x8x64xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x12x8x64xf32> | |
%458 = linalg.generic {indexing_maps = [#map12, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_185 : tensor<768x768xf32>) outs(%43 : tensor<1x768x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x768x768xf32> | |
%459 = linalg.batch_matmul ins(%449, %458 : tensor<1x8x768xf32>, tensor<1x768x768xf32>) outs(%45 : tensor<1x8x768xf32>) -> tensor<1x8x768xf32> | |
%460 = linalg.generic {indexing_maps = [#map11, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_88, %459 : tensor<768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
// ---- Attention core (heads = 12, head width = 64, sequence length = 8) ----
// Inputs %460, %453 and %457 are produced before this excerpt. %460 and %453 are
// 1x8x768 and get split into 12 heads here; %457 already has the 1x12x8x64 layout.
// Roles as used below: %453 -> left operand of the score matmul, %457 -> right
// operand (after transposition), %460 -> multiplied by the normalized scores.
// NOTE(review): these presumably correspond to query / key / value - confirm upstream.
//
// Split 768 -> 12x64, then swap dims 1 and 2 (#map14) to get 1x12x8x64.
%expanded_289 = tensor.expand_shape %460 [[0], [1], [2, 3]] output_shape [1, 8, 12, 64] : tensor<1x8x768xf32> into tensor<1x8x12x64xf32> | |
%461 = linalg.generic {indexing_maps = [#map5, #map14], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_289 : tensor<1x8x12x64xf32>) outs(%52 : tensor<1x12x8x64xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x12x8x64xf32> | |
// Same split + dim swap for %453.
%expanded_290 = tensor.expand_shape %453 [[0], [1], [2, 3]] output_shape [1, 8, 12, 64] : tensor<1x8x768xf32> into tensor<1x8x12x64xf32> | |
%462 = linalg.generic {indexing_maps = [#map5, #map14], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_290 : tensor<1x8x12x64xf32>) outs(%52 : tensor<1x12x8x64xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x12x8x64xf32> | |
// Fold the unit batch dim into the head dim so linalg.batch_matmul can batch over 12 heads.
%collapsed_291 = tensor.collapse_shape %462 [[0, 1], [2], [3]] : tensor<1x12x8x64xf32> into tensor<12x8x64xf32> | |
%collapsed_292 = tensor.collapse_shape %457 [[0, 1], [2], [3]] : tensor<1x12x8x64xf32> into tensor<12x8x64xf32> | |
%collapsed_293 = tensor.collapse_shape %461 [[0, 1], [2], [3]] : tensor<1x12x8x64xf32> into tensor<12x8x64xf32> | |
// Transpose the last two dims (#map15): 12x8x64 -> 12x64x8.
%463 = linalg.generic {indexing_maps = [#map8, #map15], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_292 : tensor<12x8x64xf32>) outs(%59 : tensor<12x64x8xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<12x64x8xf32> | |
// Per-head scores: (12x8x64) x (12x64x8) -> 12x8x8, accumulated into %62.
// NOTE(review): %62 is assumed to be zero-initialized - it is defined outside this excerpt.
%464 = linalg.batch_matmul ins(%collapsed_291, %463 : tensor<12x8x64xf32>, tensor<12x64x8xf32>) outs(%62 : tensor<12x8x8xf32>) -> tensor<12x8x8xf32> | |
%expanded_294 = tensor.expand_shape %464 [[0, 1], [2], [3]] output_shape [1, 12, 8, 8] : tensor<12x8x8xf32> into tensor<1x12x8x8xf32> | |
// Add %13 (1x1x8x8), broadcast over the head dim via #map4.
// NOTE(review): %13 is presumably an additive attention mask - confirm at its definition.
%465 = linalg.generic {indexing_maps = [#map16, #map4, #map5], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_294, %13 : tensor<1x12x8x8xf32>, tensor<1x1x8x8xf32>) outs(%64 : tensor<1x12x8x8xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x12x8x8xf32> | |
// Elementwise lower clamp against the scalar tensor %cst_8: yields %in when %in > %cst_8,
// otherwise %cst_8. The predicate is ordered (ogt), so a NaN %in yields %cst_8.
// NOTE(review): %cst_8 is presumably the most negative finite f32 - confirm.
%466 = linalg.generic {indexing_maps = [#map16, #map3, #map5], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%465, %cst_8 : tensor<1x12x8x8xf32>, tensor<f32>) outs(%64 : tensor<1x12x8x8xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.cmpf ogt, %in, %in_366 : f32 | |
%840 = arith.select %839, %in, %in_366 : f32 | |
linalg.yield %840 : f32 | |
} -> tensor<1x12x8x8xf32> | |
%collapsed_295 = tensor.collapse_shape %466 [[0, 1], [2], [3]] : tensor<1x12x8x8xf32> into tensor<12x8x8xf32> | |
// Reduce over the last dim (d2): result #0 is the running maximum, result #1 the index
// (linalg.index 2) at which a strictly larger value was last seen. Only #0 is referenced
// in the lines shown here.
// NOTE(review): initial values %70 / %68 are defined outside this excerpt.
%467:2 = linalg.generic {indexing_maps = [#map8, #map17, #map17], iterator_types = ["parallel", "parallel", "reduction"]} ins(%collapsed_295 : tensor<12x8x8xf32>) outs(%70, %68 : tensor<12x8xf32>, tensor<12x8xi64>) { | |
^bb0(%in: f32, %out: f32, %out_366: i64): | |
%839 = linalg.index 2 : index | |
%840 = arith.index_cast %839 : index to i64 | |
%841 = arith.maximumf %in, %out : f32 | |
%842 = arith.cmpf ogt, %in, %out : f32 | |
%843 = arith.select %842, %840, %out_366 : i64 | |
linalg.yield %841, %843 : f32, i64 | |
} -> (tensor<12x8xf32>, tensor<12x8xi64>) | |
%expanded_296 = tensor.expand_shape %467#0 [[0], [1, 2]] output_shape [12, 8, 1] : tensor<12x8xf32> into tensor<12x8x1xf32> | |
// Softmax over the last dim, step 1: subtract the per-row maximum (broadcast via #map9).
%468 = linalg.generic {indexing_maps = [#map8, #map9, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_295, %expanded_296 : tensor<12x8x8xf32>, tensor<12x8x1xf32>) outs(%61 : tensor<12x8x8xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.subf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<12x8x8xf32> | |
// Step 2: elementwise exp.
%469 = linalg.generic {indexing_maps = [#map8, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%468 : tensor<12x8x8xf32>) outs(%61 : tensor<12x8x8xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = math.exp %in : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<12x8x8xf32> | |
// Step 3: sum the exponentials over the last dim into a 12x8x1 tensor.
// NOTE(review): accumulator %75 is assumed to be zero-initialized - defined outside this excerpt.
%470 = linalg.generic {indexing_maps = [#map8, #map9], iterator_types = ["parallel", "parallel", "reduction"]} ins(%469 : tensor<12x8x8xf32>) outs(%75 : tensor<12x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %out : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<12x8x1xf32> | |
// Step 4: divide each exponential by its row sum.
%471 = linalg.generic {indexing_maps = [#map8, #map9, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%469, %470 : tensor<12x8x8xf32>, tensor<12x8x1xf32>) outs(%61 : tensor<12x8x8xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.divf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<12x8x8xf32> | |
// Weighted sum per head: (12x8x8) x (12x8x64) -> 12x8x64, accumulated into %79.
%472 = linalg.batch_matmul ins(%471, %collapsed_293 : tensor<12x8x8xf32>, tensor<12x8x64xf32>) outs(%79 : tensor<12x8x64xf32>) -> tensor<12x8x64xf32> | |
// Merge heads back: 12x8x64 -> 1x12x8x64 -> (swap dims 1,2) 1x8x12x64 -> 1x8x768.
%expanded_297 = tensor.expand_shape %472 [[0, 1], [2], [3]] output_shape [1, 12, 8, 64] : tensor<12x8x64xf32> into tensor<1x12x8x64xf32> | |
%473 = linalg.generic {indexing_maps = [#map5, #map14], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_297 : tensor<1x12x8x64xf32>) outs(%81 : tensor<1x8x12x64xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x8x12x64xf32> | |
%collapsed_298 = tensor.collapse_shape %473 [[0], [1], [2, 3]] : tensor<1x8x12x64xf32> into tensor<1x8x768xf32> | |
// Elementwise copy; #map7 reads the unit leading dim at constant index 0.
%474 = linalg.generic {indexing_maps = [#map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_298 : tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x8x768xf32> | |
// ---- Output projection + residual add ----
// Broadcast the 768x768 weight %cst_186 to a batch of one (#map12 ignores d0).
%475 = linalg.generic {indexing_maps = [#map12, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_186 : tensor<768x768xf32>) outs(%43 : tensor<1x768x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x768x768xf32> | |
// (1x8x768) x (1x768x768) -> 1x8x768, accumulated into %45.
// NOTE(review): %45 is assumed to be zero-initialized - defined outside this excerpt.
%476 = linalg.batch_matmul ins(%474, %475 : tensor<1x8x768xf32>, tensor<1x768x768xf32>) outs(%45 : tensor<1x8x768xf32>) -> tensor<1x8x768xf32> | |
// Add the 768-element vector %cst_90 along the last dim (#map11 selects d2).
%477 = linalg.generic {indexing_maps = [#map11, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_90, %476 : tensor<768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
// Elementwise add with %expanded_287 (defined before this excerpt).
// NOTE(review): presumably the residual/skip connection around the attention block.
%478 = linalg.generic {indexing_maps = [#map7, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_287, %477 : tensor<1x8x768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
// Drop the unit batch dim: 1x8x768 -> 8x768.
%collapsed_299 = tensor.collapse_shape %478 [[0, 1], [2]] : tensor<1x8x768xf32> into tensor<8x768xf32> | |
// ---- Row-wise normalization of %collapsed_299 (8x768), followed by scale and shift ----
// Computes (x - mean) * rsqrt(var + %cst_3) * %cst_97 + %cst_98 per row, where mean and
// var are sums over the 768 columns divided by %cst_2.
// NOTE(review): %cst_2 is presumably 768.0 and %cst_3 a small epsilon - both are defined
// outside this excerpt; the reduction accumulator %89 is assumed to be zero-initialized.
//
// Row sum over d1.
%479 = linalg.generic {indexing_maps = [#map1, #map18], iterator_types = ["parallel", "reduction"]} ins(%collapsed_299 : tensor<8x768xf32>) outs(%89 : tensor<8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %out : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x1xf32> | |
// Row sum / %cst_2.
%480 = linalg.generic {indexing_maps = [#map18, #map1], iterator_types = ["parallel", "parallel"]} ins(%479 : tensor<8x1xf32>) outs(%88 : tensor<8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.divf %in, %cst_2 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x1xf32> | |
// Broadcast the 8x1 result across all 768 columns (#map18 pins the column index to 0).
%481 = linalg.generic {indexing_maps = [#map18, #map1], iterator_types = ["parallel", "parallel"]} ins(%480 : tensor<8x1xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<8x768xf32> | |
// Centered values: x - mean. Reused below for both the variance and the final product.
%482 = linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%collapsed_299, %481 : tensor<8x768xf32>, tensor<8x768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.subf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
// Square of the centered values.
%483 = linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%482, %482 : tensor<8x768xf32>, tensor<8x768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
// Row sum of the squares.
%484 = linalg.generic {indexing_maps = [#map1, #map18], iterator_types = ["parallel", "reduction"]} ins(%483 : tensor<8x768xf32>) outs(%89 : tensor<8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %out : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x1xf32> | |
// Sum of squares / %cst_2.
%485 = linalg.generic {indexing_maps = [#map18, #map1], iterator_types = ["parallel", "parallel"]} ins(%484 : tensor<8x1xf32>) outs(%88 : tensor<8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.divf %in, %cst_2 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x1xf32> | |
// Add %cst_3 before the reciprocal square root.
%486 = linalg.generic {indexing_maps = [#map18, #map1], iterator_types = ["parallel", "parallel"]} ins(%485 : tensor<8x1xf32>) outs(%88 : tensor<8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %cst_3 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x1xf32> | |
// Reciprocal square root.
%487 = linalg.generic {indexing_maps = [#map18, #map1], iterator_types = ["parallel", "parallel"]} ins(%486 : tensor<8x1xf32>) outs(%88 : tensor<8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = math.rsqrt %in : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x1xf32> | |
// Broadcast the per-row factor across all 768 columns.
%488 = linalg.generic {indexing_maps = [#map18, #map1], iterator_types = ["parallel", "parallel"]} ins(%487 : tensor<8x1xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<8x768xf32> | |
// Normalized values: centered * rsqrt factor.
%489 = linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%482, %488 : tensor<8x768xf32>, tensor<8x768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
// Per-column scale by %cst_97 (#map19 indexes the vector by column).
%490 = linalg.generic {indexing_maps = [#map1, #map19, #map1], iterator_types = ["parallel", "parallel"]} ins(%489, %cst_97 : tensor<8x768xf32>, tensor<768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
// Per-column shift by %cst_98.
%491 = linalg.generic {indexing_maps = [#map1, #map19, #map1], iterator_types = ["parallel", "parallel"]} ins(%490, %cst_98 : tensor<8x768xf32>, tensor<768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
// ---- Feed-forward block: 768 -> 3072 -> 768, with residual add ----
// Transpose weight %cst_93 (3072x768) to 768x3072 via #map20.
%492 = linalg.generic {indexing_maps = [#map1, #map20], iterator_types = ["parallel", "parallel"]} ins(%cst_93 : tensor<3072x768xf32>) outs(%103 : tensor<768x3072xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<768x3072xf32> | |
// (8x768) x (768x3072) -> 8x3072, accumulated into %106.
// NOTE(review): %106 is assumed to be zero-initialized - defined outside this excerpt.
%493 = linalg.matmul ins(%491, %492 : tensor<8x768xf32>, tensor<768x3072xf32>) outs(%106 : tensor<8x3072xf32>) -> tensor<8x3072xf32> | |
// Add the 3072-element vector %cst_94 to every row.
%494 = linalg.generic {indexing_maps = [#map1, #map19, #map1], iterator_types = ["parallel", "parallel"]} ins(%493, %cst_94 : tensor<8x3072xf32>, tensor<3072xf32>) outs(%105 : tensor<8x3072xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x3072xf32> | |
// Elementwise lower clamp at %cst_0: yields %in when the comparison holds, else %cst_0.
// The predicate is ugt (unordered-or-greater), so a NaN %in is passed through unchanged.
// NOTE(review): %cst_0 is presumably 0.0, which would make this a ReLU - confirm.
%495 = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%494 : tensor<8x3072xf32>) outs(%105 : tensor<8x3072xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.cmpf ugt, %in, %cst_0 : f32 | |
%840 = arith.select %839, %in, %cst_0 : f32 | |
linalg.yield %840 : f32 | |
} -> tensor<8x3072xf32> | |
// Transpose weight %cst_95 (768x3072) to 3072x768.
%496 = linalg.generic {indexing_maps = [#map1, #map20], iterator_types = ["parallel", "parallel"]} ins(%cst_95 : tensor<768x3072xf32>) outs(%110 : tensor<3072x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<3072x768xf32> | |
// (8x3072) x (3072x768) -> 8x768, accumulated into %112.
%497 = linalg.matmul ins(%495, %496 : tensor<8x3072xf32>, tensor<3072x768xf32>) outs(%112 : tensor<8x768xf32>) -> tensor<8x768xf32> | |
// Add the 768-element vector %cst_96 to every row.
%498 = linalg.generic {indexing_maps = [#map1, #map19, #map1], iterator_types = ["parallel", "parallel"]} ins(%497, %cst_96 : tensor<8x768xf32>, tensor<768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
// Residual add with %collapsed_299, the un-normalized input of this block.
%499 = linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%collapsed_299, %498 : tensor<8x768xf32>, tensor<8x768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
// Restore the unit batch dim: 8x768 -> 1x8x768.
%expanded_300 = tensor.expand_shape %499 [[0, 1], [2]] output_shape [1, 8, 768] : tensor<8x768xf32> into tensor<1x8x768xf32> | |
// ---- Normalization of %expanded_300 (1x8x768) over the last dim, then scale and shift ----
// Computes (x - mean) * rsqrt(var + %cst_3) * %cst_103 + %cst_104, where mean and var are
// sums over the 768-wide dim divided by %cst_2.
// NOTE(review): %cst_2 is presumably 768.0 and %cst_3 a small epsilon - both are defined
// outside this excerpt; the reduction accumulator %28 is assumed to be zero-initialized.
//
// Sum over d2.
%500 = linalg.generic {indexing_maps = [#map8, #map9], iterator_types = ["parallel", "parallel", "reduction"]} ins(%expanded_300 : tensor<1x8x768xf32>) outs(%28 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %out : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
// Sum / %cst_2.
%501 = linalg.generic {indexing_maps = [#map10, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%500 : tensor<1x8x1xf32>) outs(%27 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.divf %in, %cst_2 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
// Broadcast along the last dim (#map9 pins its index to 0).
%502 = linalg.generic {indexing_maps = [#map9, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%501 : tensor<1x8x1xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x8x768xf32> | |
// Centered values: x - mean. Reused below for both the variance and the final product.
%503 = linalg.generic {indexing_maps = [#map7, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_300, %502 : tensor<1x8x768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.subf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
// Square of the centered values.
%504 = linalg.generic {indexing_maps = [#map7, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%503, %503 : tensor<1x8x768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
// Sum of the squares over d2.
%505 = linalg.generic {indexing_maps = [#map8, #map9], iterator_types = ["parallel", "parallel", "reduction"]} ins(%504 : tensor<1x8x768xf32>) outs(%28 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %out : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
// Sum of squares / %cst_2.
%506 = linalg.generic {indexing_maps = [#map10, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%505 : tensor<1x8x1xf32>) outs(%27 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.divf %in, %cst_2 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
// Add %cst_3 before the reciprocal square root.
%507 = linalg.generic {indexing_maps = [#map10, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%506 : tensor<1x8x1xf32>) outs(%27 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %cst_3 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
// Reciprocal square root.
%508 = linalg.generic {indexing_maps = [#map10, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%507 : tensor<1x8x1xf32>) outs(%27 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = math.rsqrt %in : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
// Broadcast the factor along the last dim.
%509 = linalg.generic {indexing_maps = [#map9, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%508 : tensor<1x8x1xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x8x768xf32> | |
// Normalized values: centered * rsqrt factor.
%510 = linalg.generic {indexing_maps = [#map7, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%503, %509 : tensor<1x8x768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
// Scale by the 768-element vector %cst_103 (#map11 indexes it by d2).
%511 = linalg.generic {indexing_maps = [#map7, #map11, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%510, %cst_103 : tensor<1x8x768xf32>, tensor<768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
// Shift by the 768-element vector %cst_104.
%512 = linalg.generic {indexing_maps = [#map7, #map11, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%511, %cst_104 : tensor<1x8x768xf32>, tensor<768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
// ---- Three 768x768 projections of the normalized activations, then head split ----
// All three matmuls read the same input %513 and accumulate into %45.
// NOTE(review): %45 is assumed to be zero-initialized - defined outside this excerpt.
// NOTE(review): by their later use, the %cst_187 / %cst_188 / %cst_189 projections
// presumably act as query / key / value - confirm against the source model.
//
// Elementwise copy of %512; #map7 reads the unit leading dim at constant index 0.
%513 = linalg.generic {indexing_maps = [#map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%512 : tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x8x768xf32> | |
// Projection 1: weight %cst_187 broadcast to a batch of one, matmul, add vector %cst_101.
%514 = linalg.generic {indexing_maps = [#map12, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_187 : tensor<768x768xf32>) outs(%43 : tensor<1x768x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x768x768xf32> | |
%515 = linalg.batch_matmul ins(%513, %514 : tensor<1x8x768xf32>, tensor<1x768x768xf32>) outs(%45 : tensor<1x8x768xf32>) -> tensor<1x8x768xf32> | |
%516 = linalg.generic {indexing_maps = [#map11, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_101, %515 : tensor<768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
// Multiply projection 1 by the scalar tensor %cst_6 (#map13 is the rank-0 map).
// NOTE(review): %cst_6 is presumably the attention scaling factor - confirm its value.
%517 = linalg.generic {indexing_maps = [#map7, #map13, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%516, %cst_6 : tensor<1x8x768xf32>, tensor<f32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
// Projection 2: weight %cst_188, add vector %cst_99.
%518 = linalg.generic {indexing_maps = [#map12, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_188 : tensor<768x768xf32>) outs(%43 : tensor<1x768x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x768x768xf32> | |
%519 = linalg.batch_matmul ins(%513, %518 : tensor<1x8x768xf32>, tensor<1x768x768xf32>) outs(%45 : tensor<1x8x768xf32>) -> tensor<1x8x768xf32> | |
%520 = linalg.generic {indexing_maps = [#map11, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_99, %519 : tensor<768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
// Split 768 -> 12x64 and swap dims 1 and 2 (#map14): 1x8x768 -> 1x12x8x64.
%expanded_301 = tensor.expand_shape %520 [[0], [1], [2, 3]] output_shape [1, 8, 12, 64] : tensor<1x8x768xf32> into tensor<1x8x12x64xf32> | |
%521 = linalg.generic {indexing_maps = [#map5, #map14], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_301 : tensor<1x8x12x64xf32>) outs(%52 : tensor<1x12x8x64xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x12x8x64xf32> | |
// Projection 3: weight %cst_189, add vector %cst_100.
%522 = linalg.generic {indexing_maps = [#map12, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_189 : tensor<768x768xf32>) outs(%43 : tensor<1x768x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x768x768xf32> | |
%523 = linalg.batch_matmul ins(%513, %522 : tensor<1x8x768xf32>, tensor<1x768x768xf32>) outs(%45 : tensor<1x8x768xf32>) -> tensor<1x8x768xf32> | |
%524 = linalg.generic {indexing_maps = [#map11, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_100, %523 : tensor<768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
// Head split + dim swap for projection 3.
%expanded_302 = tensor.expand_shape %524 [[0], [1], [2, 3]] output_shape [1, 8, 12, 64] : tensor<1x8x768xf32> into tensor<1x8x12x64xf32> | |
%525 = linalg.generic {indexing_maps = [#map5, #map14], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_302 : tensor<1x8x12x64xf32>) outs(%52 : tensor<1x12x8x64xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x12x8x64xf32> | |
// Head split + dim swap for the scaled projection 1.
%expanded_303 = tensor.expand_shape %517 [[0], [1], [2, 3]] output_shape [1, 8, 12, 64] : tensor<1x8x768xf32> into tensor<1x8x12x64xf32> | |
%526 = linalg.generic {indexing_maps = [#map5, #map14], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_303 : tensor<1x8x12x64xf32>) outs(%52 : tensor<1x12x8x64xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x12x8x64xf32> | |
// ---- Attention core (heads = 12, head width = 64, sequence length = 8) ----
// Fold the unit batch dim into the head dim so linalg.batch_matmul can batch over 12 heads.
// %526 (scaled projection) becomes the left score operand, %521 the right one (after
// transposition), and %525 is multiplied by the normalized scores.
%collapsed_304 = tensor.collapse_shape %526 [[0, 1], [2], [3]] : tensor<1x12x8x64xf32> into tensor<12x8x64xf32> | |
%collapsed_305 = tensor.collapse_shape %521 [[0, 1], [2], [3]] : tensor<1x12x8x64xf32> into tensor<12x8x64xf32> | |
%collapsed_306 = tensor.collapse_shape %525 [[0, 1], [2], [3]] : tensor<1x12x8x64xf32> into tensor<12x8x64xf32> | |
// Transpose the last two dims (#map15): 12x8x64 -> 12x64x8.
%527 = linalg.generic {indexing_maps = [#map8, #map15], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_305 : tensor<12x8x64xf32>) outs(%59 : tensor<12x64x8xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<12x64x8xf32> | |
// Per-head scores: (12x8x64) x (12x64x8) -> 12x8x8, accumulated into %62.
// NOTE(review): %62 is assumed to be zero-initialized - it is defined outside this excerpt.
%528 = linalg.batch_matmul ins(%collapsed_304, %527 : tensor<12x8x64xf32>, tensor<12x64x8xf32>) outs(%62 : tensor<12x8x8xf32>) -> tensor<12x8x8xf32> | |
%expanded_307 = tensor.expand_shape %528 [[0, 1], [2], [3]] output_shape [1, 12, 8, 8] : tensor<12x8x8xf32> into tensor<1x12x8x8xf32> | |
// Add %13 (1x1x8x8), broadcast over the head dim via #map4.
// NOTE(review): %13 is presumably an additive attention mask - confirm at its definition.
%529 = linalg.generic {indexing_maps = [#map16, #map4, #map5], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_307, %13 : tensor<1x12x8x8xf32>, tensor<1x1x8x8xf32>) outs(%64 : tensor<1x12x8x8xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x12x8x8xf32> | |
// Elementwise lower clamp against the scalar tensor %cst_8: yields %in when %in > %cst_8,
// otherwise %cst_8. The predicate is ordered (ogt), so a NaN %in yields %cst_8.
// NOTE(review): %cst_8 is presumably the most negative finite f32 - confirm.
%530 = linalg.generic {indexing_maps = [#map16, #map3, #map5], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%529, %cst_8 : tensor<1x12x8x8xf32>, tensor<f32>) outs(%64 : tensor<1x12x8x8xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.cmpf ogt, %in, %in_366 : f32 | |
%840 = arith.select %839, %in, %in_366 : f32 | |
linalg.yield %840 : f32 | |
} -> tensor<1x12x8x8xf32> | |
%collapsed_308 = tensor.collapse_shape %530 [[0, 1], [2], [3]] : tensor<1x12x8x8xf32> into tensor<12x8x8xf32> | |
// Reduce over the last dim (d2): result #0 is the running maximum, result #1 the index
// (linalg.index 2) at which a strictly larger value was last seen. Only #0 is referenced
// in the lines shown here.
// NOTE(review): initial values %70 / %68 are defined outside this excerpt.
%531:2 = linalg.generic {indexing_maps = [#map8, #map17, #map17], iterator_types = ["parallel", "parallel", "reduction"]} ins(%collapsed_308 : tensor<12x8x8xf32>) outs(%70, %68 : tensor<12x8xf32>, tensor<12x8xi64>) { | |
^bb0(%in: f32, %out: f32, %out_366: i64): | |
%839 = linalg.index 2 : index | |
%840 = arith.index_cast %839 : index to i64 | |
%841 = arith.maximumf %in, %out : f32 | |
%842 = arith.cmpf ogt, %in, %out : f32 | |
%843 = arith.select %842, %840, %out_366 : i64 | |
linalg.yield %841, %843 : f32, i64 | |
} -> (tensor<12x8xf32>, tensor<12x8xi64>) | |
%expanded_309 = tensor.expand_shape %531#0 [[0], [1, 2]] output_shape [12, 8, 1] : tensor<12x8xf32> into tensor<12x8x1xf32> | |
// Softmax over the last dim, step 1: subtract the per-row maximum (broadcast via #map9).
%532 = linalg.generic {indexing_maps = [#map8, #map9, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_308, %expanded_309 : tensor<12x8x8xf32>, tensor<12x8x1xf32>) outs(%61 : tensor<12x8x8xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.subf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<12x8x8xf32> | |
// Step 2: elementwise exp.
%533 = linalg.generic {indexing_maps = [#map8, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%532 : tensor<12x8x8xf32>) outs(%61 : tensor<12x8x8xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = math.exp %in : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<12x8x8xf32> | |
// Step 3: sum the exponentials over the last dim into a 12x8x1 tensor.
// NOTE(review): accumulator %75 is assumed to be zero-initialized - defined outside this excerpt.
%534 = linalg.generic {indexing_maps = [#map8, #map9], iterator_types = ["parallel", "parallel", "reduction"]} ins(%533 : tensor<12x8x8xf32>) outs(%75 : tensor<12x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %out : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<12x8x1xf32> | |
// Step 4: divide each exponential by its row sum.
%535 = linalg.generic {indexing_maps = [#map8, #map9, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%533, %534 : tensor<12x8x8xf32>, tensor<12x8x1xf32>) outs(%61 : tensor<12x8x8xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.divf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<12x8x8xf32> | |
// Weighted sum per head: (12x8x8) x (12x8x64) -> 12x8x64, accumulated into %79.
%536 = linalg.batch_matmul ins(%535, %collapsed_306 : tensor<12x8x8xf32>, tensor<12x8x64xf32>) outs(%79 : tensor<12x8x64xf32>) -> tensor<12x8x64xf32> | |
// Merge heads back: 12x8x64 -> 1x12x8x64 -> (swap dims 1,2) 1x8x12x64 -> 1x8x768.
%expanded_310 = tensor.expand_shape %536 [[0, 1], [2], [3]] output_shape [1, 12, 8, 64] : tensor<12x8x64xf32> into tensor<1x12x8x64xf32> | |
%537 = linalg.generic {indexing_maps = [#map5, #map14], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_310 : tensor<1x12x8x64xf32>) outs(%81 : tensor<1x8x12x64xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x8x12x64xf32> | |
%collapsed_311 = tensor.collapse_shape %537 [[0], [1], [2, 3]] : tensor<1x8x12x64xf32> into tensor<1x8x768xf32> | |
// Elementwise copy; #map7 reads the unit leading dim at constant index 0.
%538 = linalg.generic {indexing_maps = [#map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_311 : tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x8x768xf32> | |
// ---- Output projection + residual add ----
// Broadcast the 768x768 weight %cst_190 to a batch of one (#map12 ignores d0).
%539 = linalg.generic {indexing_maps = [#map12, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_190 : tensor<768x768xf32>) outs(%43 : tensor<1x768x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x768x768xf32> | |
// (1x8x768) x (1x768x768) -> 1x8x768, accumulated into %45.
// NOTE(review): %45 is assumed to be zero-initialized - defined outside this excerpt.
%540 = linalg.batch_matmul ins(%538, %539 : tensor<1x8x768xf32>, tensor<1x768x768xf32>) outs(%45 : tensor<1x8x768xf32>) -> tensor<1x8x768xf32> | |
// Add the 768-element vector %cst_102 along the last dim (#map11 selects d2).
%541 = linalg.generic {indexing_maps = [#map11, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_102, %540 : tensor<768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
// Residual add with %expanded_300, the un-normalized input of this attention block.
%542 = linalg.generic {indexing_maps = [#map7, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_300, %541 : tensor<1x8x768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
// Drop the unit batch dim: 1x8x768 -> 8x768.
%collapsed_312 = tensor.collapse_shape %542 [[0, 1], [2]] : tensor<1x8x768xf32> into tensor<8x768xf32> | |
// ---- Row-wise normalization of %collapsed_312 (8x768), followed by scale and shift ----
// Computes (x - mean) * rsqrt(var + %cst_3) * %cst_109 + %cst_110 per row, where mean and
// var are sums over the 768 columns divided by %cst_2.
// NOTE(review): %cst_2 is presumably 768.0 and %cst_3 a small epsilon - both are defined
// outside this excerpt; the reduction accumulator %89 is assumed to be zero-initialized.
//
// Row sum over d1.
%543 = linalg.generic {indexing_maps = [#map1, #map18], iterator_types = ["parallel", "reduction"]} ins(%collapsed_312 : tensor<8x768xf32>) outs(%89 : tensor<8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %out : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x1xf32> | |
// Row sum / %cst_2.
%544 = linalg.generic {indexing_maps = [#map18, #map1], iterator_types = ["parallel", "parallel"]} ins(%543 : tensor<8x1xf32>) outs(%88 : tensor<8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.divf %in, %cst_2 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x1xf32> | |
// Broadcast the 8x1 result across all 768 columns (#map18 pins the column index to 0).
%545 = linalg.generic {indexing_maps = [#map18, #map1], iterator_types = ["parallel", "parallel"]} ins(%544 : tensor<8x1xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<8x768xf32> | |
// Centered values: x - mean. Reused below for both the variance and the final product.
%546 = linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%collapsed_312, %545 : tensor<8x768xf32>, tensor<8x768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.subf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
// Square of the centered values.
%547 = linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%546, %546 : tensor<8x768xf32>, tensor<8x768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
// Row sum of the squares.
%548 = linalg.generic {indexing_maps = [#map1, #map18], iterator_types = ["parallel", "reduction"]} ins(%547 : tensor<8x768xf32>) outs(%89 : tensor<8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %out : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x1xf32> | |
// Sum of squares / %cst_2.
%549 = linalg.generic {indexing_maps = [#map18, #map1], iterator_types = ["parallel", "parallel"]} ins(%548 : tensor<8x1xf32>) outs(%88 : tensor<8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.divf %in, %cst_2 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x1xf32> | |
// Add %cst_3 before the reciprocal square root.
%550 = linalg.generic {indexing_maps = [#map18, #map1], iterator_types = ["parallel", "parallel"]} ins(%549 : tensor<8x1xf32>) outs(%88 : tensor<8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %cst_3 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x1xf32> | |
// Reciprocal square root.
%551 = linalg.generic {indexing_maps = [#map18, #map1], iterator_types = ["parallel", "parallel"]} ins(%550 : tensor<8x1xf32>) outs(%88 : tensor<8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = math.rsqrt %in : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x1xf32> | |
// Broadcast the per-row factor across all 768 columns.
%552 = linalg.generic {indexing_maps = [#map18, #map1], iterator_types = ["parallel", "parallel"]} ins(%551 : tensor<8x1xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<8x768xf32> | |
// Normalized values: centered * rsqrt factor.
%553 = linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%546, %552 : tensor<8x768xf32>, tensor<8x768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
// Per-column scale by %cst_109 (#map19 indexes the vector by column).
%554 = linalg.generic {indexing_maps = [#map1, #map19, #map1], iterator_types = ["parallel", "parallel"]} ins(%553, %cst_109 : tensor<8x768xf32>, tensor<768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
// Per-column shift by %cst_110.
%555 = linalg.generic {indexing_maps = [#map1, #map19, #map1], iterator_types = ["parallel", "parallel"]} ins(%554, %cst_110 : tensor<8x768xf32>, tensor<768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
// ---- Feed-forward block: 768 -> 3072 -> 768, with residual add ----
// Transpose weight %cst_105 (3072x768) to 768x3072 via #map20.
%556 = linalg.generic {indexing_maps = [#map1, #map20], iterator_types = ["parallel", "parallel"]} ins(%cst_105 : tensor<3072x768xf32>) outs(%103 : tensor<768x3072xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<768x3072xf32> | |
// (8x768) x (768x3072) -> 8x3072, accumulated into %106.
// NOTE(review): %106 is assumed to be zero-initialized - defined outside this excerpt.
%557 = linalg.matmul ins(%555, %556 : tensor<8x768xf32>, tensor<768x3072xf32>) outs(%106 : tensor<8x3072xf32>) -> tensor<8x3072xf32> | |
// Add the 3072-element vector %cst_106 to every row.
%558 = linalg.generic {indexing_maps = [#map1, #map19, #map1], iterator_types = ["parallel", "parallel"]} ins(%557, %cst_106 : tensor<8x3072xf32>, tensor<3072xf32>) outs(%105 : tensor<8x3072xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x3072xf32> | |
// Elementwise lower clamp at %cst_0: yields %in when the comparison holds, else %cst_0.
// The predicate is ugt (unordered-or-greater), so a NaN %in is passed through unchanged.
// NOTE(review): %cst_0 is presumably 0.0, which would make this a ReLU - confirm.
%559 = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%558 : tensor<8x3072xf32>) outs(%105 : tensor<8x3072xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.cmpf ugt, %in, %cst_0 : f32 | |
%840 = arith.select %839, %in, %cst_0 : f32 | |
linalg.yield %840 : f32 | |
} -> tensor<8x3072xf32> | |
// Transpose weight %cst_107 (768x3072) to 3072x768.
%560 = linalg.generic {indexing_maps = [#map1, #map20], iterator_types = ["parallel", "parallel"]} ins(%cst_107 : tensor<768x3072xf32>) outs(%110 : tensor<3072x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<3072x768xf32> | |
// (8x3072) x (3072x768) -> 8x768, accumulated into %112.
%561 = linalg.matmul ins(%559, %560 : tensor<8x3072xf32>, tensor<3072x768xf32>) outs(%112 : tensor<8x768xf32>) -> tensor<8x768xf32> | |
// Add the 768-element vector %cst_108 to every row.
%562 = linalg.generic {indexing_maps = [#map1, #map19, #map1], iterator_types = ["parallel", "parallel"]} ins(%561, %cst_108 : tensor<8x768xf32>, tensor<768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
// Residual add with %collapsed_312, the un-normalized input of this block.
%563 = linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%collapsed_312, %562 : tensor<8x768xf32>, tensor<8x768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
// Restore the unit batch dim: 8x768 -> 1x8x768.
%expanded_313 = tensor.expand_shape %563 [[0, 1], [2]] output_shape [1, 8, 768] : tensor<8x768xf32> into tensor<1x8x768xf32> | |
// ---- Mean / centering / variance of %expanded_313 (1x8x768) over the last dim ----
// This is the first half of the same normalization pattern used earlier in the function;
// the ops consuming %570 lie beyond the lines shown here.
// NOTE(review): %cst_2 is presumably 768.0 and accumulator %28 is assumed to be
// zero-initialized - both are defined outside this excerpt.
//
// Sum over d2.
%564 = linalg.generic {indexing_maps = [#map8, #map9], iterator_types = ["parallel", "parallel", "reduction"]} ins(%expanded_313 : tensor<1x8x768xf32>) outs(%28 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %out : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
// Sum / %cst_2.
%565 = linalg.generic {indexing_maps = [#map10, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%564 : tensor<1x8x1xf32>) outs(%27 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.divf %in, %cst_2 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
// Broadcast along the last dim (#map9 pins its index to 0).
%566 = linalg.generic {indexing_maps = [#map9, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%565 : tensor<1x8x1xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x8x768xf32> | |
// Centered values: x - mean.
%567 = linalg.generic {indexing_maps = [#map7, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_313, %566 : tensor<1x8x768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.subf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
// Square of the centered values.
%568 = linalg.generic {indexing_maps = [#map7, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%567, %567 : tensor<1x8x768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
// Sum of the squares over d2.
%569 = linalg.generic {indexing_maps = [#map8, #map9], iterator_types = ["parallel", "parallel", "reduction"]} ins(%568 : tensor<1x8x768xf32>) outs(%28 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %out : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
// Sum of squares / %cst_2.
%570 = linalg.generic {indexing_maps = [#map10, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%569 : tensor<1x8x1xf32>) outs(%27 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.divf %in, %cst_2 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
%571 = linalg.generic {indexing_maps = [#map10, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%570 : tensor<1x8x1xf32>) outs(%27 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %cst_3 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
%572 = linalg.generic {indexing_maps = [#map10, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%571 : tensor<1x8x1xf32>) outs(%27 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = math.rsqrt %in : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
%573 = linalg.generic {indexing_maps = [#map9, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%572 : tensor<1x8x1xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x8x768xf32> | |
%574 = linalg.generic {indexing_maps = [#map7, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%567, %573 : tensor<1x8x768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%575 = linalg.generic {indexing_maps = [#map7, #map11, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%574, %cst_115 : tensor<1x8x768xf32>, tensor<768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%576 = linalg.generic {indexing_maps = [#map7, #map11, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%575, %cst_116 : tensor<1x8x768xf32>, tensor<768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%577 = linalg.generic {indexing_maps = [#map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%576 : tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x8x768xf32> | |
%578 = linalg.generic {indexing_maps = [#map12, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_191 : tensor<768x768xf32>) outs(%43 : tensor<1x768x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x768x768xf32> | |
%579 = linalg.batch_matmul ins(%577, %578 : tensor<1x8x768xf32>, tensor<1x768x768xf32>) outs(%45 : tensor<1x8x768xf32>) -> tensor<1x8x768xf32> | |
%580 = linalg.generic {indexing_maps = [#map11, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_113, %579 : tensor<768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%581 = linalg.generic {indexing_maps = [#map7, #map13, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%580, %cst_6 : tensor<1x8x768xf32>, tensor<f32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%582 = linalg.generic {indexing_maps = [#map12, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_192 : tensor<768x768xf32>) outs(%43 : tensor<1x768x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x768x768xf32> | |
%583 = linalg.batch_matmul ins(%577, %582 : tensor<1x8x768xf32>, tensor<1x768x768xf32>) outs(%45 : tensor<1x8x768xf32>) -> tensor<1x8x768xf32> | |
%584 = linalg.generic {indexing_maps = [#map11, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_111, %583 : tensor<768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%expanded_314 = tensor.expand_shape %584 [[0], [1], [2, 3]] output_shape [1, 8, 12, 64] : tensor<1x8x768xf32> into tensor<1x8x12x64xf32> | |
%585 = linalg.generic {indexing_maps = [#map5, #map14], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_314 : tensor<1x8x12x64xf32>) outs(%52 : tensor<1x12x8x64xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x12x8x64xf32> | |
%586 = linalg.generic {indexing_maps = [#map12, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_193 : tensor<768x768xf32>) outs(%43 : tensor<1x768x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x768x768xf32> | |
%587 = linalg.batch_matmul ins(%577, %586 : tensor<1x8x768xf32>, tensor<1x768x768xf32>) outs(%45 : tensor<1x8x768xf32>) -> tensor<1x8x768xf32> | |
%588 = linalg.generic {indexing_maps = [#map11, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_112, %587 : tensor<768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%expanded_315 = tensor.expand_shape %588 [[0], [1], [2, 3]] output_shape [1, 8, 12, 64] : tensor<1x8x768xf32> into tensor<1x8x12x64xf32> | |
%589 = linalg.generic {indexing_maps = [#map5, #map14], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_315 : tensor<1x8x12x64xf32>) outs(%52 : tensor<1x12x8x64xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x12x8x64xf32> | |
%expanded_316 = tensor.expand_shape %581 [[0], [1], [2, 3]] output_shape [1, 8, 12, 64] : tensor<1x8x768xf32> into tensor<1x8x12x64xf32> | |
%590 = linalg.generic {indexing_maps = [#map5, #map14], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_316 : tensor<1x8x12x64xf32>) outs(%52 : tensor<1x12x8x64xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x12x8x64xf32> | |
%collapsed_317 = tensor.collapse_shape %590 [[0, 1], [2], [3]] : tensor<1x12x8x64xf32> into tensor<12x8x64xf32> | |
%collapsed_318 = tensor.collapse_shape %585 [[0, 1], [2], [3]] : tensor<1x12x8x64xf32> into tensor<12x8x64xf32> | |
%collapsed_319 = tensor.collapse_shape %589 [[0, 1], [2], [3]] : tensor<1x12x8x64xf32> into tensor<12x8x64xf32> | |
%591 = linalg.generic {indexing_maps = [#map8, #map15], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_318 : tensor<12x8x64xf32>) outs(%59 : tensor<12x64x8xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<12x64x8xf32> | |
%592 = linalg.batch_matmul ins(%collapsed_317, %591 : tensor<12x8x64xf32>, tensor<12x64x8xf32>) outs(%62 : tensor<12x8x8xf32>) -> tensor<12x8x8xf32> | |
%expanded_320 = tensor.expand_shape %592 [[0, 1], [2], [3]] output_shape [1, 12, 8, 8] : tensor<12x8x8xf32> into tensor<1x12x8x8xf32> | |
%593 = linalg.generic {indexing_maps = [#map16, #map4, #map5], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_320, %13 : tensor<1x12x8x8xf32>, tensor<1x1x8x8xf32>) outs(%64 : tensor<1x12x8x8xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x12x8x8xf32> | |
%594 = linalg.generic {indexing_maps = [#map16, #map3, #map5], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%593, %cst_8 : tensor<1x12x8x8xf32>, tensor<f32>) outs(%64 : tensor<1x12x8x8xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.cmpf ogt, %in, %in_366 : f32 | |
%840 = arith.select %839, %in, %in_366 : f32 | |
linalg.yield %840 : f32 | |
} -> tensor<1x12x8x8xf32> | |
%collapsed_321 = tensor.collapse_shape %594 [[0, 1], [2], [3]] : tensor<1x12x8x8xf32> into tensor<12x8x8xf32> | |
%595:2 = linalg.generic {indexing_maps = [#map8, #map17, #map17], iterator_types = ["parallel", "parallel", "reduction"]} ins(%collapsed_321 : tensor<12x8x8xf32>) outs(%70, %68 : tensor<12x8xf32>, tensor<12x8xi64>) { | |
^bb0(%in: f32, %out: f32, %out_366: i64): | |
%839 = linalg.index 2 : index | |
%840 = arith.index_cast %839 : index to i64 | |
%841 = arith.maximumf %in, %out : f32 | |
%842 = arith.cmpf ogt, %in, %out : f32 | |
%843 = arith.select %842, %840, %out_366 : i64 | |
linalg.yield %841, %843 : f32, i64 | |
} -> (tensor<12x8xf32>, tensor<12x8xi64>) | |
%expanded_322 = tensor.expand_shape %595#0 [[0], [1, 2]] output_shape [12, 8, 1] : tensor<12x8xf32> into tensor<12x8x1xf32> | |
%596 = linalg.generic {indexing_maps = [#map8, #map9, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_321, %expanded_322 : tensor<12x8x8xf32>, tensor<12x8x1xf32>) outs(%61 : tensor<12x8x8xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.subf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<12x8x8xf32> | |
%597 = linalg.generic {indexing_maps = [#map8, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%596 : tensor<12x8x8xf32>) outs(%61 : tensor<12x8x8xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = math.exp %in : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<12x8x8xf32> | |
%598 = linalg.generic {indexing_maps = [#map8, #map9], iterator_types = ["parallel", "parallel", "reduction"]} ins(%597 : tensor<12x8x8xf32>) outs(%75 : tensor<12x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %out : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<12x8x1xf32> | |
%599 = linalg.generic {indexing_maps = [#map8, #map9, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%597, %598 : tensor<12x8x8xf32>, tensor<12x8x1xf32>) outs(%61 : tensor<12x8x8xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.divf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<12x8x8xf32> | |
%600 = linalg.batch_matmul ins(%599, %collapsed_319 : tensor<12x8x8xf32>, tensor<12x8x64xf32>) outs(%79 : tensor<12x8x64xf32>) -> tensor<12x8x64xf32> | |
%expanded_323 = tensor.expand_shape %600 [[0, 1], [2], [3]] output_shape [1, 12, 8, 64] : tensor<12x8x64xf32> into tensor<1x12x8x64xf32> | |
%601 = linalg.generic {indexing_maps = [#map5, #map14], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_323 : tensor<1x12x8x64xf32>) outs(%81 : tensor<1x8x12x64xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x8x12x64xf32> | |
%collapsed_324 = tensor.collapse_shape %601 [[0], [1], [2, 3]] : tensor<1x8x12x64xf32> into tensor<1x8x768xf32> | |
%602 = linalg.generic {indexing_maps = [#map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_324 : tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x8x768xf32> | |
%603 = linalg.generic {indexing_maps = [#map12, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_194 : tensor<768x768xf32>) outs(%43 : tensor<1x768x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x768x768xf32> | |
%604 = linalg.batch_matmul ins(%602, %603 : tensor<1x8x768xf32>, tensor<1x768x768xf32>) outs(%45 : tensor<1x8x768xf32>) -> tensor<1x8x768xf32> | |
%605 = linalg.generic {indexing_maps = [#map11, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_114, %604 : tensor<768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%606 = linalg.generic {indexing_maps = [#map7, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_313, %605 : tensor<1x8x768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%collapsed_325 = tensor.collapse_shape %606 [[0, 1], [2]] : tensor<1x8x768xf32> into tensor<8x768xf32> | |
%607 = linalg.generic {indexing_maps = [#map1, #map18], iterator_types = ["parallel", "reduction"]} ins(%collapsed_325 : tensor<8x768xf32>) outs(%89 : tensor<8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %out : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x1xf32> | |
%608 = linalg.generic {indexing_maps = [#map18, #map1], iterator_types = ["parallel", "parallel"]} ins(%607 : tensor<8x1xf32>) outs(%88 : tensor<8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.divf %in, %cst_2 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x1xf32> | |
%609 = linalg.generic {indexing_maps = [#map18, #map1], iterator_types = ["parallel", "parallel"]} ins(%608 : tensor<8x1xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<8x768xf32> | |
%610 = linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%collapsed_325, %609 : tensor<8x768xf32>, tensor<8x768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.subf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
%611 = linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%610, %610 : tensor<8x768xf32>, tensor<8x768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
%612 = linalg.generic {indexing_maps = [#map1, #map18], iterator_types = ["parallel", "reduction"]} ins(%611 : tensor<8x768xf32>) outs(%89 : tensor<8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %out : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x1xf32> | |
%613 = linalg.generic {indexing_maps = [#map18, #map1], iterator_types = ["parallel", "parallel"]} ins(%612 : tensor<8x1xf32>) outs(%88 : tensor<8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.divf %in, %cst_2 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x1xf32> | |
%614 = linalg.generic {indexing_maps = [#map18, #map1], iterator_types = ["parallel", "parallel"]} ins(%613 : tensor<8x1xf32>) outs(%88 : tensor<8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %cst_3 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x1xf32> | |
%615 = linalg.generic {indexing_maps = [#map18, #map1], iterator_types = ["parallel", "parallel"]} ins(%614 : tensor<8x1xf32>) outs(%88 : tensor<8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = math.rsqrt %in : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x1xf32> | |
%616 = linalg.generic {indexing_maps = [#map18, #map1], iterator_types = ["parallel", "parallel"]} ins(%615 : tensor<8x1xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<8x768xf32> | |
%617 = linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%610, %616 : tensor<8x768xf32>, tensor<8x768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
%618 = linalg.generic {indexing_maps = [#map1, #map19, #map1], iterator_types = ["parallel", "parallel"]} ins(%617, %cst_121 : tensor<8x768xf32>, tensor<768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
%619 = linalg.generic {indexing_maps = [#map1, #map19, #map1], iterator_types = ["parallel", "parallel"]} ins(%618, %cst_122 : tensor<8x768xf32>, tensor<768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
%620 = linalg.generic {indexing_maps = [#map1, #map20], iterator_types = ["parallel", "parallel"]} ins(%cst_117 : tensor<3072x768xf32>) outs(%103 : tensor<768x3072xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<768x3072xf32> | |
%621 = linalg.matmul ins(%619, %620 : tensor<8x768xf32>, tensor<768x3072xf32>) outs(%106 : tensor<8x3072xf32>) -> tensor<8x3072xf32> | |
%622 = linalg.generic {indexing_maps = [#map1, #map19, #map1], iterator_types = ["parallel", "parallel"]} ins(%621, %cst_118 : tensor<8x3072xf32>, tensor<3072xf32>) outs(%105 : tensor<8x3072xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x3072xf32> | |
%623 = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%622 : tensor<8x3072xf32>) outs(%105 : tensor<8x3072xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.cmpf ugt, %in, %cst_0 : f32 | |
%840 = arith.select %839, %in, %cst_0 : f32 | |
linalg.yield %840 : f32 | |
} -> tensor<8x3072xf32> | |
%624 = linalg.generic {indexing_maps = [#map1, #map20], iterator_types = ["parallel", "parallel"]} ins(%cst_119 : tensor<768x3072xf32>) outs(%110 : tensor<3072x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<3072x768xf32> | |
%625 = linalg.matmul ins(%623, %624 : tensor<8x3072xf32>, tensor<3072x768xf32>) outs(%112 : tensor<8x768xf32>) -> tensor<8x768xf32> | |
%626 = linalg.generic {indexing_maps = [#map1, #map19, #map1], iterator_types = ["parallel", "parallel"]} ins(%625, %cst_120 : tensor<8x768xf32>, tensor<768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
%627 = linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%collapsed_325, %626 : tensor<8x768xf32>, tensor<8x768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
%expanded_326 = tensor.expand_shape %627 [[0, 1], [2]] output_shape [1, 8, 768] : tensor<8x768xf32> into tensor<1x8x768xf32> | |
%628 = linalg.generic {indexing_maps = [#map8, #map9], iterator_types = ["parallel", "parallel", "reduction"]} ins(%expanded_326 : tensor<1x8x768xf32>) outs(%28 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %out : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
%629 = linalg.generic {indexing_maps = [#map10, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%628 : tensor<1x8x1xf32>) outs(%27 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.divf %in, %cst_2 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
%630 = linalg.generic {indexing_maps = [#map9, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%629 : tensor<1x8x1xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x8x768xf32> | |
%631 = linalg.generic {indexing_maps = [#map7, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_326, %630 : tensor<1x8x768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.subf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%632 = linalg.generic {indexing_maps = [#map7, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%631, %631 : tensor<1x8x768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%633 = linalg.generic {indexing_maps = [#map8, #map9], iterator_types = ["parallel", "parallel", "reduction"]} ins(%632 : tensor<1x8x768xf32>) outs(%28 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %out : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
%634 = linalg.generic {indexing_maps = [#map10, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%633 : tensor<1x8x1xf32>) outs(%27 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.divf %in, %cst_2 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
%635 = linalg.generic {indexing_maps = [#map10, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%634 : tensor<1x8x1xf32>) outs(%27 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %cst_3 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
%636 = linalg.generic {indexing_maps = [#map10, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%635 : tensor<1x8x1xf32>) outs(%27 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = math.rsqrt %in : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
%637 = linalg.generic {indexing_maps = [#map9, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%636 : tensor<1x8x1xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x8x768xf32> | |
%638 = linalg.generic {indexing_maps = [#map7, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%631, %637 : tensor<1x8x768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%639 = linalg.generic {indexing_maps = [#map7, #map11, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%638, %cst_127 : tensor<1x8x768xf32>, tensor<768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%640 = linalg.generic {indexing_maps = [#map7, #map11, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%639, %cst_128 : tensor<1x8x768xf32>, tensor<768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%641 = linalg.generic {indexing_maps = [#map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%640 : tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x8x768xf32> | |
%642 = linalg.generic {indexing_maps = [#map12, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_195 : tensor<768x768xf32>) outs(%43 : tensor<1x768x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x768x768xf32> | |
%643 = linalg.batch_matmul ins(%641, %642 : tensor<1x8x768xf32>, tensor<1x768x768xf32>) outs(%45 : tensor<1x8x768xf32>) -> tensor<1x8x768xf32> | |
%644 = linalg.generic {indexing_maps = [#map11, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_125, %643 : tensor<768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%645 = linalg.generic {indexing_maps = [#map7, #map13, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%644, %cst_6 : tensor<1x8x768xf32>, tensor<f32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%646 = linalg.generic {indexing_maps = [#map12, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_196 : tensor<768x768xf32>) outs(%43 : tensor<1x768x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x768x768xf32> | |
%647 = linalg.batch_matmul ins(%641, %646 : tensor<1x8x768xf32>, tensor<1x768x768xf32>) outs(%45 : tensor<1x8x768xf32>) -> tensor<1x8x768xf32> | |
%648 = linalg.generic {indexing_maps = [#map11, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_123, %647 : tensor<768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%expanded_327 = tensor.expand_shape %648 [[0], [1], [2, 3]] output_shape [1, 8, 12, 64] : tensor<1x8x768xf32> into tensor<1x8x12x64xf32> | |
%649 = linalg.generic {indexing_maps = [#map5, #map14], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_327 : tensor<1x8x12x64xf32>) outs(%52 : tensor<1x12x8x64xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x12x8x64xf32> | |
%650 = linalg.generic {indexing_maps = [#map12, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_197 : tensor<768x768xf32>) outs(%43 : tensor<1x768x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x768x768xf32> | |
%651 = linalg.batch_matmul ins(%641, %650 : tensor<1x8x768xf32>, tensor<1x768x768xf32>) outs(%45 : tensor<1x8x768xf32>) -> tensor<1x8x768xf32> | |
%652 = linalg.generic {indexing_maps = [#map11, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_124, %651 : tensor<768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%expanded_328 = tensor.expand_shape %652 [[0], [1], [2, 3]] output_shape [1, 8, 12, 64] : tensor<1x8x768xf32> into tensor<1x8x12x64xf32> | |
%653 = linalg.generic {indexing_maps = [#map5, #map14], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_328 : tensor<1x8x12x64xf32>) outs(%52 : tensor<1x12x8x64xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x12x8x64xf32> | |
%expanded_329 = tensor.expand_shape %645 [[0], [1], [2, 3]] output_shape [1, 8, 12, 64] : tensor<1x8x768xf32> into tensor<1x8x12x64xf32> | |
%654 = linalg.generic {indexing_maps = [#map5, #map14], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_329 : tensor<1x8x12x64xf32>) outs(%52 : tensor<1x12x8x64xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x12x8x64xf32> | |
%collapsed_330 = tensor.collapse_shape %654 [[0, 1], [2], [3]] : tensor<1x12x8x64xf32> into tensor<12x8x64xf32> | |
%collapsed_331 = tensor.collapse_shape %649 [[0, 1], [2], [3]] : tensor<1x12x8x64xf32> into tensor<12x8x64xf32> | |
%collapsed_332 = tensor.collapse_shape %653 [[0, 1], [2], [3]] : tensor<1x12x8x64xf32> into tensor<12x8x64xf32> | |
%655 = linalg.generic {indexing_maps = [#map8, #map15], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_331 : tensor<12x8x64xf32>) outs(%59 : tensor<12x64x8xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<12x64x8xf32> | |
%656 = linalg.batch_matmul ins(%collapsed_330, %655 : tensor<12x8x64xf32>, tensor<12x64x8xf32>) outs(%62 : tensor<12x8x8xf32>) -> tensor<12x8x8xf32> | |
%expanded_333 = tensor.expand_shape %656 [[0, 1], [2], [3]] output_shape [1, 12, 8, 8] : tensor<12x8x8xf32> into tensor<1x12x8x8xf32> | |
%657 = linalg.generic {indexing_maps = [#map16, #map4, #map5], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_333, %13 : tensor<1x12x8x8xf32>, tensor<1x1x8x8xf32>) outs(%64 : tensor<1x12x8x8xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x12x8x8xf32> | |
%658 = linalg.generic {indexing_maps = [#map16, #map3, #map5], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%657, %cst_8 : tensor<1x12x8x8xf32>, tensor<f32>) outs(%64 : tensor<1x12x8x8xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.cmpf ogt, %in, %in_366 : f32 | |
%840 = arith.select %839, %in, %in_366 : f32 | |
linalg.yield %840 : f32 | |
} -> tensor<1x12x8x8xf32> | |
%collapsed_334 = tensor.collapse_shape %658 [[0, 1], [2], [3]] : tensor<1x12x8x8xf32> into tensor<12x8x8xf32> | |
%659:2 = linalg.generic {indexing_maps = [#map8, #map17, #map17], iterator_types = ["parallel", "parallel", "reduction"]} ins(%collapsed_334 : tensor<12x8x8xf32>) outs(%70, %68 : tensor<12x8xf32>, tensor<12x8xi64>) { | |
^bb0(%in: f32, %out: f32, %out_366: i64): | |
%839 = linalg.index 2 : index | |
%840 = arith.index_cast %839 : index to i64 | |
%841 = arith.maximumf %in, %out : f32 | |
%842 = arith.cmpf ogt, %in, %out : f32 | |
%843 = arith.select %842, %840, %out_366 : i64 | |
linalg.yield %841, %843 : f32, i64 | |
} -> (tensor<12x8xf32>, tensor<12x8xi64>) | |
%expanded_335 = tensor.expand_shape %659#0 [[0], [1, 2]] output_shape [12, 8, 1] : tensor<12x8xf32> into tensor<12x8x1xf32> | |
%660 = linalg.generic {indexing_maps = [#map8, #map9, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_334, %expanded_335 : tensor<12x8x8xf32>, tensor<12x8x1xf32>) outs(%61 : tensor<12x8x8xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.subf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<12x8x8xf32> | |
%661 = linalg.generic {indexing_maps = [#map8, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%660 : tensor<12x8x8xf32>) outs(%61 : tensor<12x8x8xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = math.exp %in : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<12x8x8xf32> | |
%662 = linalg.generic {indexing_maps = [#map8, #map9], iterator_types = ["parallel", "parallel", "reduction"]} ins(%661 : tensor<12x8x8xf32>) outs(%75 : tensor<12x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %out : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<12x8x1xf32> | |
%663 = linalg.generic {indexing_maps = [#map8, #map9, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%661, %662 : tensor<12x8x8xf32>, tensor<12x8x1xf32>) outs(%61 : tensor<12x8x8xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.divf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<12x8x8xf32> | |
%664 = linalg.batch_matmul ins(%663, %collapsed_332 : tensor<12x8x8xf32>, tensor<12x8x64xf32>) outs(%79 : tensor<12x8x64xf32>) -> tensor<12x8x64xf32> | |
%expanded_336 = tensor.expand_shape %664 [[0, 1], [2], [3]] output_shape [1, 12, 8, 64] : tensor<12x8x64xf32> into tensor<1x12x8x64xf32> | |
%665 = linalg.generic {indexing_maps = [#map5, #map14], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_336 : tensor<1x12x8x64xf32>) outs(%81 : tensor<1x8x12x64xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x8x12x64xf32> | |
%collapsed_337 = tensor.collapse_shape %665 [[0], [1], [2, 3]] : tensor<1x8x12x64xf32> into tensor<1x8x768xf32> | |
%666 = linalg.generic {indexing_maps = [#map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_337 : tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x8x768xf32> | |
%667 = linalg.generic {indexing_maps = [#map12, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_198 : tensor<768x768xf32>) outs(%43 : tensor<1x768x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x768x768xf32> | |
%668 = linalg.batch_matmul ins(%666, %667 : tensor<1x8x768xf32>, tensor<1x768x768xf32>) outs(%45 : tensor<1x8x768xf32>) -> tensor<1x8x768xf32> | |
%669 = linalg.generic {indexing_maps = [#map11, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_126, %668 : tensor<768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%670 = linalg.generic {indexing_maps = [#map7, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_326, %669 : tensor<1x8x768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%collapsed_338 = tensor.collapse_shape %670 [[0, 1], [2]] : tensor<1x8x768xf32> into tensor<8x768xf32> | |
%671 = linalg.generic {indexing_maps = [#map1, #map18], iterator_types = ["parallel", "reduction"]} ins(%collapsed_338 : tensor<8x768xf32>) outs(%89 : tensor<8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %out : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x1xf32> | |
%672 = linalg.generic {indexing_maps = [#map18, #map1], iterator_types = ["parallel", "parallel"]} ins(%671 : tensor<8x1xf32>) outs(%88 : tensor<8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.divf %in, %cst_2 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x1xf32> | |
%673 = linalg.generic {indexing_maps = [#map18, #map1], iterator_types = ["parallel", "parallel"]} ins(%672 : tensor<8x1xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<8x768xf32> | |
%674 = linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%collapsed_338, %673 : tensor<8x768xf32>, tensor<8x768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.subf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
%675 = linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%674, %674 : tensor<8x768xf32>, tensor<8x768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
%676 = linalg.generic {indexing_maps = [#map1, #map18], iterator_types = ["parallel", "reduction"]} ins(%675 : tensor<8x768xf32>) outs(%89 : tensor<8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %out : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x1xf32> | |
%677 = linalg.generic {indexing_maps = [#map18, #map1], iterator_types = ["parallel", "parallel"]} ins(%676 : tensor<8x1xf32>) outs(%88 : tensor<8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.divf %in, %cst_2 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x1xf32> | |
%678 = linalg.generic {indexing_maps = [#map18, #map1], iterator_types = ["parallel", "parallel"]} ins(%677 : tensor<8x1xf32>) outs(%88 : tensor<8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %cst_3 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x1xf32> | |
%679 = linalg.generic {indexing_maps = [#map18, #map1], iterator_types = ["parallel", "parallel"]} ins(%678 : tensor<8x1xf32>) outs(%88 : tensor<8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = math.rsqrt %in : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x1xf32> | |
%680 = linalg.generic {indexing_maps = [#map18, #map1], iterator_types = ["parallel", "parallel"]} ins(%679 : tensor<8x1xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<8x768xf32> | |
%681 = linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%674, %680 : tensor<8x768xf32>, tensor<8x768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
%682 = linalg.generic {indexing_maps = [#map1, #map19, #map1], iterator_types = ["parallel", "parallel"]} ins(%681, %cst_133 : tensor<8x768xf32>, tensor<768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
%683 = linalg.generic {indexing_maps = [#map1, #map19, #map1], iterator_types = ["parallel", "parallel"]} ins(%682, %cst_134 : tensor<8x768xf32>, tensor<768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
%684 = linalg.generic {indexing_maps = [#map1, #map20], iterator_types = ["parallel", "parallel"]} ins(%cst_129 : tensor<3072x768xf32>) outs(%103 : tensor<768x3072xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<768x3072xf32> | |
%685 = linalg.matmul ins(%683, %684 : tensor<8x768xf32>, tensor<768x3072xf32>) outs(%106 : tensor<8x3072xf32>) -> tensor<8x3072xf32> | |
%686 = linalg.generic {indexing_maps = [#map1, #map19, #map1], iterator_types = ["parallel", "parallel"]} ins(%685, %cst_130 : tensor<8x3072xf32>, tensor<3072xf32>) outs(%105 : tensor<8x3072xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x3072xf32> | |
%687 = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%686 : tensor<8x3072xf32>) outs(%105 : tensor<8x3072xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.cmpf ugt, %in, %cst_0 : f32 | |
%840 = arith.select %839, %in, %cst_0 : f32 | |
linalg.yield %840 : f32 | |
} -> tensor<8x3072xf32> | |
%688 = linalg.generic {indexing_maps = [#map1, #map20], iterator_types = ["parallel", "parallel"]} ins(%cst_131 : tensor<768x3072xf32>) outs(%110 : tensor<3072x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<3072x768xf32> | |
%689 = linalg.matmul ins(%687, %688 : tensor<8x3072xf32>, tensor<3072x768xf32>) outs(%112 : tensor<8x768xf32>) -> tensor<8x768xf32> | |
%690 = linalg.generic {indexing_maps = [#map1, #map19, #map1], iterator_types = ["parallel", "parallel"]} ins(%689, %cst_132 : tensor<8x768xf32>, tensor<768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
%691 = linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%collapsed_338, %690 : tensor<8x768xf32>, tensor<8x768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
%expanded_339 = tensor.expand_shape %691 [[0, 1], [2]] output_shape [1, 8, 768] : tensor<8x768xf32> into tensor<1x8x768xf32> | |
%692 = linalg.generic {indexing_maps = [#map8, #map9], iterator_types = ["parallel", "parallel", "reduction"]} ins(%expanded_339 : tensor<1x8x768xf32>) outs(%28 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %out : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
%693 = linalg.generic {indexing_maps = [#map10, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%692 : tensor<1x8x1xf32>) outs(%27 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.divf %in, %cst_2 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
%694 = linalg.generic {indexing_maps = [#map9, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%693 : tensor<1x8x1xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x8x768xf32> | |
%695 = linalg.generic {indexing_maps = [#map7, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_339, %694 : tensor<1x8x768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.subf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%696 = linalg.generic {indexing_maps = [#map7, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%695, %695 : tensor<1x8x768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%697 = linalg.generic {indexing_maps = [#map8, #map9], iterator_types = ["parallel", "parallel", "reduction"]} ins(%696 : tensor<1x8x768xf32>) outs(%28 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %out : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
%698 = linalg.generic {indexing_maps = [#map10, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%697 : tensor<1x8x1xf32>) outs(%27 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.divf %in, %cst_2 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
%699 = linalg.generic {indexing_maps = [#map10, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%698 : tensor<1x8x1xf32>) outs(%27 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %cst_3 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
%700 = linalg.generic {indexing_maps = [#map10, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%699 : tensor<1x8x1xf32>) outs(%27 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = math.rsqrt %in : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
%701 = linalg.generic {indexing_maps = [#map9, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%700 : tensor<1x8x1xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x8x768xf32> | |
%702 = linalg.generic {indexing_maps = [#map7, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%695, %701 : tensor<1x8x768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%703 = linalg.generic {indexing_maps = [#map7, #map11, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%702, %cst_139 : tensor<1x8x768xf32>, tensor<768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%704 = linalg.generic {indexing_maps = [#map7, #map11, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%703, %cst_140 : tensor<1x8x768xf32>, tensor<768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%705 = linalg.generic {indexing_maps = [#map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%704 : tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x8x768xf32> | |
%706 = linalg.generic {indexing_maps = [#map12, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_199 : tensor<768x768xf32>) outs(%43 : tensor<1x768x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x768x768xf32> | |
%707 = linalg.batch_matmul ins(%705, %706 : tensor<1x8x768xf32>, tensor<1x768x768xf32>) outs(%45 : tensor<1x8x768xf32>) -> tensor<1x8x768xf32> | |
%708 = linalg.generic {indexing_maps = [#map11, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_137, %707 : tensor<768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%709 = linalg.generic {indexing_maps = [#map7, #map13, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%708, %cst_6 : tensor<1x8x768xf32>, tensor<f32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%710 = linalg.generic {indexing_maps = [#map12, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_200 : tensor<768x768xf32>) outs(%43 : tensor<1x768x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x768x768xf32> | |
%711 = linalg.batch_matmul ins(%705, %710 : tensor<1x8x768xf32>, tensor<1x768x768xf32>) outs(%45 : tensor<1x8x768xf32>) -> tensor<1x8x768xf32> | |
%712 = linalg.generic {indexing_maps = [#map11, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_135, %711 : tensor<768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%expanded_340 = tensor.expand_shape %712 [[0], [1], [2, 3]] output_shape [1, 8, 12, 64] : tensor<1x8x768xf32> into tensor<1x8x12x64xf32> | |
%713 = linalg.generic {indexing_maps = [#map5, #map14], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_340 : tensor<1x8x12x64xf32>) outs(%52 : tensor<1x12x8x64xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x12x8x64xf32> | |
%714 = linalg.generic {indexing_maps = [#map12, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_201 : tensor<768x768xf32>) outs(%43 : tensor<1x768x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x768x768xf32> | |
%715 = linalg.batch_matmul ins(%705, %714 : tensor<1x8x768xf32>, tensor<1x768x768xf32>) outs(%45 : tensor<1x8x768xf32>) -> tensor<1x8x768xf32> | |
%716 = linalg.generic {indexing_maps = [#map11, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_136, %715 : tensor<768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%expanded_341 = tensor.expand_shape %716 [[0], [1], [2, 3]] output_shape [1, 8, 12, 64] : tensor<1x8x768xf32> into tensor<1x8x12x64xf32> | |
%717 = linalg.generic {indexing_maps = [#map5, #map14], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_341 : tensor<1x8x12x64xf32>) outs(%52 : tensor<1x12x8x64xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x12x8x64xf32> | |
%expanded_342 = tensor.expand_shape %709 [[0], [1], [2, 3]] output_shape [1, 8, 12, 64] : tensor<1x8x768xf32> into tensor<1x8x12x64xf32> | |
%718 = linalg.generic {indexing_maps = [#map5, #map14], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_342 : tensor<1x8x12x64xf32>) outs(%52 : tensor<1x12x8x64xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x12x8x64xf32> | |
%collapsed_343 = tensor.collapse_shape %718 [[0, 1], [2], [3]] : tensor<1x12x8x64xf32> into tensor<12x8x64xf32> | |
%collapsed_344 = tensor.collapse_shape %713 [[0, 1], [2], [3]] : tensor<1x12x8x64xf32> into tensor<12x8x64xf32> | |
%collapsed_345 = tensor.collapse_shape %717 [[0, 1], [2], [3]] : tensor<1x12x8x64xf32> into tensor<12x8x64xf32> | |
%719 = linalg.generic {indexing_maps = [#map8, #map15], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_344 : tensor<12x8x64xf32>) outs(%59 : tensor<12x64x8xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<12x64x8xf32> | |
%720 = linalg.batch_matmul ins(%collapsed_343, %719 : tensor<12x8x64xf32>, tensor<12x64x8xf32>) outs(%62 : tensor<12x8x8xf32>) -> tensor<12x8x8xf32> | |
%expanded_346 = tensor.expand_shape %720 [[0, 1], [2], [3]] output_shape [1, 12, 8, 8] : tensor<12x8x8xf32> into tensor<1x12x8x8xf32> | |
%721 = linalg.generic {indexing_maps = [#map16, #map4, #map5], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_346, %13 : tensor<1x12x8x8xf32>, tensor<1x1x8x8xf32>) outs(%64 : tensor<1x12x8x8xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x12x8x8xf32> | |
%722 = linalg.generic {indexing_maps = [#map16, #map3, #map5], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%721, %cst_8 : tensor<1x12x8x8xf32>, tensor<f32>) outs(%64 : tensor<1x12x8x8xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.cmpf ogt, %in, %in_366 : f32 | |
%840 = arith.select %839, %in, %in_366 : f32 | |
linalg.yield %840 : f32 | |
} -> tensor<1x12x8x8xf32> | |
%collapsed_347 = tensor.collapse_shape %722 [[0, 1], [2], [3]] : tensor<1x12x8x8xf32> into tensor<12x8x8xf32> | |
%723:2 = linalg.generic {indexing_maps = [#map8, #map17, #map17], iterator_types = ["parallel", "parallel", "reduction"]} ins(%collapsed_347 : tensor<12x8x8xf32>) outs(%70, %68 : tensor<12x8xf32>, tensor<12x8xi64>) { | |
^bb0(%in: f32, %out: f32, %out_366: i64): | |
%839 = linalg.index 2 : index | |
%840 = arith.index_cast %839 : index to i64 | |
%841 = arith.maximumf %in, %out : f32 | |
%842 = arith.cmpf ogt, %in, %out : f32 | |
%843 = arith.select %842, %840, %out_366 : i64 | |
linalg.yield %841, %843 : f32, i64 | |
} -> (tensor<12x8xf32>, tensor<12x8xi64>) | |
%expanded_348 = tensor.expand_shape %723#0 [[0], [1, 2]] output_shape [12, 8, 1] : tensor<12x8xf32> into tensor<12x8x1xf32> | |
%724 = linalg.generic {indexing_maps = [#map8, #map9, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_347, %expanded_348 : tensor<12x8x8xf32>, tensor<12x8x1xf32>) outs(%61 : tensor<12x8x8xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.subf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<12x8x8xf32> | |
%725 = linalg.generic {indexing_maps = [#map8, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%724 : tensor<12x8x8xf32>) outs(%61 : tensor<12x8x8xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = math.exp %in : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<12x8x8xf32> | |
%726 = linalg.generic {indexing_maps = [#map8, #map9], iterator_types = ["parallel", "parallel", "reduction"]} ins(%725 : tensor<12x8x8xf32>) outs(%75 : tensor<12x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %out : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<12x8x1xf32> | |
%727 = linalg.generic {indexing_maps = [#map8, #map9, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%725, %726 : tensor<12x8x8xf32>, tensor<12x8x1xf32>) outs(%61 : tensor<12x8x8xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.divf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<12x8x8xf32> | |
%728 = linalg.batch_matmul ins(%727, %collapsed_345 : tensor<12x8x8xf32>, tensor<12x8x64xf32>) outs(%79 : tensor<12x8x64xf32>) -> tensor<12x8x64xf32> | |
%expanded_349 = tensor.expand_shape %728 [[0, 1], [2], [3]] output_shape [1, 12, 8, 64] : tensor<12x8x64xf32> into tensor<1x12x8x64xf32> | |
%729 = linalg.generic {indexing_maps = [#map5, #map14], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_349 : tensor<1x12x8x64xf32>) outs(%81 : tensor<1x8x12x64xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x8x12x64xf32> | |
%collapsed_350 = tensor.collapse_shape %729 [[0], [1], [2, 3]] : tensor<1x8x12x64xf32> into tensor<1x8x768xf32> | |
%730 = linalg.generic {indexing_maps = [#map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_350 : tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x8x768xf32> | |
%731 = linalg.generic {indexing_maps = [#map12, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_202 : tensor<768x768xf32>) outs(%43 : tensor<1x768x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x768x768xf32> | |
%732 = linalg.batch_matmul ins(%730, %731 : tensor<1x8x768xf32>, tensor<1x768x768xf32>) outs(%45 : tensor<1x8x768xf32>) -> tensor<1x8x768xf32> | |
%733 = linalg.generic {indexing_maps = [#map11, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_138, %732 : tensor<768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%734 = linalg.generic {indexing_maps = [#map7, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_339, %733 : tensor<1x8x768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%collapsed_351 = tensor.collapse_shape %734 [[0, 1], [2]] : tensor<1x8x768xf32> into tensor<8x768xf32> | |
%735 = linalg.generic {indexing_maps = [#map1, #map18], iterator_types = ["parallel", "reduction"]} ins(%collapsed_351 : tensor<8x768xf32>) outs(%89 : tensor<8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %out : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x1xf32> | |
%736 = linalg.generic {indexing_maps = [#map18, #map1], iterator_types = ["parallel", "parallel"]} ins(%735 : tensor<8x1xf32>) outs(%88 : tensor<8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.divf %in, %cst_2 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x1xf32> | |
%737 = linalg.generic {indexing_maps = [#map18, #map1], iterator_types = ["parallel", "parallel"]} ins(%736 : tensor<8x1xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<8x768xf32> | |
%738 = linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%collapsed_351, %737 : tensor<8x768xf32>, tensor<8x768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.subf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
%739 = linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%738, %738 : tensor<8x768xf32>, tensor<8x768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
%740 = linalg.generic {indexing_maps = [#map1, #map18], iterator_types = ["parallel", "reduction"]} ins(%739 : tensor<8x768xf32>) outs(%89 : tensor<8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %out : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x1xf32> | |
%741 = linalg.generic {indexing_maps = [#map18, #map1], iterator_types = ["parallel", "parallel"]} ins(%740 : tensor<8x1xf32>) outs(%88 : tensor<8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.divf %in, %cst_2 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x1xf32> | |
%742 = linalg.generic {indexing_maps = [#map18, #map1], iterator_types = ["parallel", "parallel"]} ins(%741 : tensor<8x1xf32>) outs(%88 : tensor<8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %cst_3 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x1xf32> | |
%743 = linalg.generic {indexing_maps = [#map18, #map1], iterator_types = ["parallel", "parallel"]} ins(%742 : tensor<8x1xf32>) outs(%88 : tensor<8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = math.rsqrt %in : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x1xf32> | |
%744 = linalg.generic {indexing_maps = [#map18, #map1], iterator_types = ["parallel", "parallel"]} ins(%743 : tensor<8x1xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<8x768xf32> | |
%745 = linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%738, %744 : tensor<8x768xf32>, tensor<8x768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
%746 = linalg.generic {indexing_maps = [#map1, #map19, #map1], iterator_types = ["parallel", "parallel"]} ins(%745, %cst_145 : tensor<8x768xf32>, tensor<768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
%747 = linalg.generic {indexing_maps = [#map1, #map19, #map1], iterator_types = ["parallel", "parallel"]} ins(%746, %cst_146 : tensor<8x768xf32>, tensor<768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
%748 = linalg.generic {indexing_maps = [#map1, #map20], iterator_types = ["parallel", "parallel"]} ins(%cst_141 : tensor<3072x768xf32>) outs(%103 : tensor<768x3072xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<768x3072xf32> | |
%749 = linalg.matmul ins(%747, %748 : tensor<8x768xf32>, tensor<768x3072xf32>) outs(%106 : tensor<8x3072xf32>) -> tensor<8x3072xf32> | |
%750 = linalg.generic {indexing_maps = [#map1, #map19, #map1], iterator_types = ["parallel", "parallel"]} ins(%749, %cst_142 : tensor<8x3072xf32>, tensor<3072xf32>) outs(%105 : tensor<8x3072xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x3072xf32> | |
%751 = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%750 : tensor<8x3072xf32>) outs(%105 : tensor<8x3072xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.cmpf ugt, %in, %cst_0 : f32 | |
%840 = arith.select %839, %in, %cst_0 : f32 | |
linalg.yield %840 : f32 | |
} -> tensor<8x3072xf32> | |
%752 = linalg.generic {indexing_maps = [#map1, #map20], iterator_types = ["parallel", "parallel"]} ins(%cst_143 : tensor<768x3072xf32>) outs(%110 : tensor<3072x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<3072x768xf32> | |
%753 = linalg.matmul ins(%751, %752 : tensor<8x3072xf32>, tensor<3072x768xf32>) outs(%112 : tensor<8x768xf32>) -> tensor<8x768xf32> | |
%754 = linalg.generic {indexing_maps = [#map1, #map19, #map1], iterator_types = ["parallel", "parallel"]} ins(%753, %cst_144 : tensor<8x768xf32>, tensor<768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
%755 = linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%collapsed_351, %754 : tensor<8x768xf32>, tensor<8x768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
%expanded_352 = tensor.expand_shape %755 [[0, 1], [2]] output_shape [1, 8, 768] : tensor<8x768xf32> into tensor<1x8x768xf32> | |
%756 = linalg.generic {indexing_maps = [#map8, #map9], iterator_types = ["parallel", "parallel", "reduction"]} ins(%expanded_352 : tensor<1x8x768xf32>) outs(%28 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %out : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
%757 = linalg.generic {indexing_maps = [#map10, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%756 : tensor<1x8x1xf32>) outs(%27 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.divf %in, %cst_2 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
%758 = linalg.generic {indexing_maps = [#map9, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%757 : tensor<1x8x1xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x8x768xf32> | |
%759 = linalg.generic {indexing_maps = [#map7, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_352, %758 : tensor<1x8x768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.subf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%760 = linalg.generic {indexing_maps = [#map7, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%759, %759 : tensor<1x8x768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%761 = linalg.generic {indexing_maps = [#map8, #map9], iterator_types = ["parallel", "parallel", "reduction"]} ins(%760 : tensor<1x8x768xf32>) outs(%28 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %out : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
%762 = linalg.generic {indexing_maps = [#map10, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%761 : tensor<1x8x1xf32>) outs(%27 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.divf %in, %cst_2 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
%763 = linalg.generic {indexing_maps = [#map10, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%762 : tensor<1x8x1xf32>) outs(%27 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %cst_3 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
%764 = linalg.generic {indexing_maps = [#map10, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%763 : tensor<1x8x1xf32>) outs(%27 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = math.rsqrt %in : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
%765 = linalg.generic {indexing_maps = [#map9, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%764 : tensor<1x8x1xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x8x768xf32> | |
%766 = linalg.generic {indexing_maps = [#map7, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%759, %765 : tensor<1x8x768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%767 = linalg.generic {indexing_maps = [#map7, #map11, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%766, %cst_151 : tensor<1x8x768xf32>, tensor<768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%768 = linalg.generic {indexing_maps = [#map7, #map11, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%767, %cst_152 : tensor<1x8x768xf32>, tensor<768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%769 = linalg.generic {indexing_maps = [#map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%768 : tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x8x768xf32> | |
%770 = linalg.generic {indexing_maps = [#map12, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_203 : tensor<768x768xf32>) outs(%43 : tensor<1x768x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x768x768xf32> | |
%771 = linalg.batch_matmul ins(%769, %770 : tensor<1x8x768xf32>, tensor<1x768x768xf32>) outs(%45 : tensor<1x8x768xf32>) -> tensor<1x8x768xf32> | |
%772 = linalg.generic {indexing_maps = [#map11, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_149, %771 : tensor<768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%773 = linalg.generic {indexing_maps = [#map7, #map13, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%772, %cst_6 : tensor<1x8x768xf32>, tensor<f32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%774 = linalg.generic {indexing_maps = [#map12, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_204 : tensor<768x768xf32>) outs(%43 : tensor<1x768x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x768x768xf32> | |
%775 = linalg.batch_matmul ins(%769, %774 : tensor<1x8x768xf32>, tensor<1x768x768xf32>) outs(%45 : tensor<1x8x768xf32>) -> tensor<1x8x768xf32> | |
%776 = linalg.generic {indexing_maps = [#map11, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_147, %775 : tensor<768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%expanded_353 = tensor.expand_shape %776 [[0], [1], [2, 3]] output_shape [1, 8, 12, 64] : tensor<1x8x768xf32> into tensor<1x8x12x64xf32> | |
%777 = linalg.generic {indexing_maps = [#map5, #map14], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_353 : tensor<1x8x12x64xf32>) outs(%52 : tensor<1x12x8x64xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x12x8x64xf32> | |
%778 = linalg.generic {indexing_maps = [#map12, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_205 : tensor<768x768xf32>) outs(%43 : tensor<1x768x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x768x768xf32> | |
%779 = linalg.batch_matmul ins(%769, %778 : tensor<1x8x768xf32>, tensor<1x768x768xf32>) outs(%45 : tensor<1x8x768xf32>) -> tensor<1x8x768xf32> | |
%780 = linalg.generic {indexing_maps = [#map11, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_148, %779 : tensor<768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%expanded_354 = tensor.expand_shape %780 [[0], [1], [2, 3]] output_shape [1, 8, 12, 64] : tensor<1x8x768xf32> into tensor<1x8x12x64xf32> | |
%781 = linalg.generic {indexing_maps = [#map5, #map14], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_354 : tensor<1x8x12x64xf32>) outs(%52 : tensor<1x12x8x64xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x12x8x64xf32> | |
%expanded_355 = tensor.expand_shape %773 [[0], [1], [2, 3]] output_shape [1, 8, 12, 64] : tensor<1x8x768xf32> into tensor<1x8x12x64xf32> | |
%782 = linalg.generic {indexing_maps = [#map5, #map14], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_355 : tensor<1x8x12x64xf32>) outs(%52 : tensor<1x12x8x64xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x12x8x64xf32> | |
%collapsed_356 = tensor.collapse_shape %782 [[0, 1], [2], [3]] : tensor<1x12x8x64xf32> into tensor<12x8x64xf32> | |
%collapsed_357 = tensor.collapse_shape %777 [[0, 1], [2], [3]] : tensor<1x12x8x64xf32> into tensor<12x8x64xf32> | |
%collapsed_358 = tensor.collapse_shape %781 [[0, 1], [2], [3]] : tensor<1x12x8x64xf32> into tensor<12x8x64xf32> | |
%783 = linalg.generic {indexing_maps = [#map8, #map15], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_357 : tensor<12x8x64xf32>) outs(%59 : tensor<12x64x8xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<12x64x8xf32> | |
%784 = linalg.batch_matmul ins(%collapsed_356, %783 : tensor<12x8x64xf32>, tensor<12x64x8xf32>) outs(%62 : tensor<12x8x8xf32>) -> tensor<12x8x8xf32> | |
%expanded_359 = tensor.expand_shape %784 [[0, 1], [2], [3]] output_shape [1, 12, 8, 8] : tensor<12x8x8xf32> into tensor<1x12x8x8xf32> | |
%785 = linalg.generic {indexing_maps = [#map16, #map4, #map5], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_359, %13 : tensor<1x12x8x8xf32>, tensor<1x1x8x8xf32>) outs(%64 : tensor<1x12x8x8xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x12x8x8xf32> | |
%786 = linalg.generic {indexing_maps = [#map16, #map3, #map5], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%785, %cst_8 : tensor<1x12x8x8xf32>, tensor<f32>) outs(%64 : tensor<1x12x8x8xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.cmpf ogt, %in, %in_366 : f32 | |
%840 = arith.select %839, %in, %in_366 : f32 | |
linalg.yield %840 : f32 | |
} -> tensor<1x12x8x8xf32> | |
%collapsed_360 = tensor.collapse_shape %786 [[0, 1], [2], [3]] : tensor<1x12x8x8xf32> into tensor<12x8x8xf32> | |
%787:2 = linalg.generic {indexing_maps = [#map8, #map17, #map17], iterator_types = ["parallel", "parallel", "reduction"]} ins(%collapsed_360 : tensor<12x8x8xf32>) outs(%70, %68 : tensor<12x8xf32>, tensor<12x8xi64>) { | |
^bb0(%in: f32, %out: f32, %out_366: i64): | |
%839 = linalg.index 2 : index | |
%840 = arith.index_cast %839 : index to i64 | |
%841 = arith.maximumf %in, %out : f32 | |
%842 = arith.cmpf ogt, %in, %out : f32 | |
%843 = arith.select %842, %840, %out_366 : i64 | |
linalg.yield %841, %843 : f32, i64 | |
} -> (tensor<12x8xf32>, tensor<12x8xi64>) | |
%expanded_361 = tensor.expand_shape %787#0 [[0], [1, 2]] output_shape [12, 8, 1] : tensor<12x8xf32> into tensor<12x8x1xf32> | |
%788 = linalg.generic {indexing_maps = [#map8, #map9, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_360, %expanded_361 : tensor<12x8x8xf32>, tensor<12x8x1xf32>) outs(%61 : tensor<12x8x8xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.subf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<12x8x8xf32> | |
%789 = linalg.generic {indexing_maps = [#map8, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%788 : tensor<12x8x8xf32>) outs(%61 : tensor<12x8x8xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = math.exp %in : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<12x8x8xf32> | |
%790 = linalg.generic {indexing_maps = [#map8, #map9], iterator_types = ["parallel", "parallel", "reduction"]} ins(%789 : tensor<12x8x8xf32>) outs(%75 : tensor<12x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %out : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<12x8x1xf32> | |
%791 = linalg.generic {indexing_maps = [#map8, #map9, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%789, %790 : tensor<12x8x8xf32>, tensor<12x8x1xf32>) outs(%61 : tensor<12x8x8xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.divf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<12x8x8xf32> | |
%792 = linalg.batch_matmul ins(%791, %collapsed_358 : tensor<12x8x8xf32>, tensor<12x8x64xf32>) outs(%79 : tensor<12x8x64xf32>) -> tensor<12x8x64xf32> | |
%expanded_362 = tensor.expand_shape %792 [[0, 1], [2], [3]] output_shape [1, 12, 8, 64] : tensor<12x8x64xf32> into tensor<1x12x8x64xf32> | |
%793 = linalg.generic {indexing_maps = [#map5, #map14], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_362 : tensor<1x12x8x64xf32>) outs(%81 : tensor<1x8x12x64xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x8x12x64xf32> | |
%collapsed_363 = tensor.collapse_shape %793 [[0], [1], [2, 3]] : tensor<1x8x12x64xf32> into tensor<1x8x768xf32> | |
%794 = linalg.generic {indexing_maps = [#map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_363 : tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x8x768xf32> | |
%795 = linalg.generic {indexing_maps = [#map12, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_206 : tensor<768x768xf32>) outs(%43 : tensor<1x768x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x768x768xf32> | |
%796 = linalg.batch_matmul ins(%794, %795 : tensor<1x8x768xf32>, tensor<1x768x768xf32>) outs(%45 : tensor<1x8x768xf32>) -> tensor<1x8x768xf32> | |
%797 = linalg.generic {indexing_maps = [#map11, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_150, %796 : tensor<768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%798 = linalg.generic {indexing_maps = [#map7, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_352, %797 : tensor<1x8x768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%collapsed_364 = tensor.collapse_shape %798 [[0, 1], [2]] : tensor<1x8x768xf32> into tensor<8x768xf32> | |
%799 = linalg.generic {indexing_maps = [#map1, #map18], iterator_types = ["parallel", "reduction"]} ins(%collapsed_364 : tensor<8x768xf32>) outs(%89 : tensor<8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %out : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x1xf32> | |
%800 = linalg.generic {indexing_maps = [#map18, #map1], iterator_types = ["parallel", "parallel"]} ins(%799 : tensor<8x1xf32>) outs(%88 : tensor<8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.divf %in, %cst_2 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x1xf32> | |
%801 = linalg.generic {indexing_maps = [#map18, #map1], iterator_types = ["parallel", "parallel"]} ins(%800 : tensor<8x1xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<8x768xf32> | |
%802 = linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%collapsed_364, %801 : tensor<8x768xf32>, tensor<8x768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.subf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
%803 = linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%802, %802 : tensor<8x768xf32>, tensor<8x768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
%804 = linalg.generic {indexing_maps = [#map1, #map18], iterator_types = ["parallel", "reduction"]} ins(%803 : tensor<8x768xf32>) outs(%89 : tensor<8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %out : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x1xf32> | |
%805 = linalg.generic {indexing_maps = [#map18, #map1], iterator_types = ["parallel", "parallel"]} ins(%804 : tensor<8x1xf32>) outs(%88 : tensor<8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.divf %in, %cst_2 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x1xf32> | |
%806 = linalg.generic {indexing_maps = [#map18, #map1], iterator_types = ["parallel", "parallel"]} ins(%805 : tensor<8x1xf32>) outs(%88 : tensor<8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %cst_3 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x1xf32> | |
%807 = linalg.generic {indexing_maps = [#map18, #map1], iterator_types = ["parallel", "parallel"]} ins(%806 : tensor<8x1xf32>) outs(%88 : tensor<8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = math.rsqrt %in : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x1xf32> | |
%808 = linalg.generic {indexing_maps = [#map18, #map1], iterator_types = ["parallel", "parallel"]} ins(%807 : tensor<8x1xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<8x768xf32> | |
%809 = linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%802, %808 : tensor<8x768xf32>, tensor<8x768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
%810 = linalg.generic {indexing_maps = [#map1, #map19, #map1], iterator_types = ["parallel", "parallel"]} ins(%809, %cst_157 : tensor<8x768xf32>, tensor<768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
%811 = linalg.generic {indexing_maps = [#map1, #map19, #map1], iterator_types = ["parallel", "parallel"]} ins(%810, %cst_158 : tensor<8x768xf32>, tensor<768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
%812 = linalg.generic {indexing_maps = [#map1, #map20], iterator_types = ["parallel", "parallel"]} ins(%cst_153 : tensor<3072x768xf32>) outs(%103 : tensor<768x3072xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<768x3072xf32> | |
%813 = linalg.matmul ins(%811, %812 : tensor<8x768xf32>, tensor<768x3072xf32>) outs(%106 : tensor<8x3072xf32>) -> tensor<8x3072xf32> | |
%814 = linalg.generic {indexing_maps = [#map1, #map19, #map1], iterator_types = ["parallel", "parallel"]} ins(%813, %cst_154 : tensor<8x3072xf32>, tensor<3072xf32>) outs(%105 : tensor<8x3072xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x3072xf32> | |
%815 = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%814 : tensor<8x3072xf32>) outs(%105 : tensor<8x3072xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.cmpf ugt, %in, %cst_0 : f32 | |
%840 = arith.select %839, %in, %cst_0 : f32 | |
linalg.yield %840 : f32 | |
} -> tensor<8x3072xf32> | |
%816 = linalg.generic {indexing_maps = [#map1, #map20], iterator_types = ["parallel", "parallel"]} ins(%cst_155 : tensor<768x3072xf32>) outs(%110 : tensor<3072x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<3072x768xf32> | |
%817 = linalg.matmul ins(%815, %816 : tensor<8x3072xf32>, tensor<3072x768xf32>) outs(%112 : tensor<8x768xf32>) -> tensor<8x768xf32> | |
%818 = linalg.generic {indexing_maps = [#map1, #map19, #map1], iterator_types = ["parallel", "parallel"]} ins(%817, %cst_156 : tensor<8x768xf32>, tensor<768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
%819 = linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%collapsed_364, %818 : tensor<8x768xf32>, tensor<8x768xf32>) outs(%5 : tensor<8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<8x768xf32> | |
%expanded_365 = tensor.expand_shape %819 [[0, 1], [2]] output_shape [1, 8, 768] : tensor<8x768xf32> into tensor<1x8x768xf32> | |
%820 = linalg.generic {indexing_maps = [#map8, #map9], iterator_types = ["parallel", "parallel", "reduction"]} ins(%expanded_365 : tensor<1x8x768xf32>) outs(%28 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %out : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
%821 = linalg.generic {indexing_maps = [#map10, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%820 : tensor<1x8x1xf32>) outs(%27 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.divf %in, %cst_2 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
%822 = linalg.generic {indexing_maps = [#map9, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%821 : tensor<1x8x1xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x8x768xf32> | |
%823 = linalg.generic {indexing_maps = [#map7, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_365, %822 : tensor<1x8x768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.subf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%824 = linalg.generic {indexing_maps = [#map7, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%823, %823 : tensor<1x8x768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%825 = linalg.generic {indexing_maps = [#map8, #map9], iterator_types = ["parallel", "parallel", "reduction"]} ins(%824 : tensor<1x8x768xf32>) outs(%28 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %out : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
%826 = linalg.generic {indexing_maps = [#map10, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%825 : tensor<1x8x1xf32>) outs(%27 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.divf %in, %cst_2 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
%827 = linalg.generic {indexing_maps = [#map10, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%826 : tensor<1x8x1xf32>) outs(%27 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = arith.addf %in, %cst_3 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
%828 = linalg.generic {indexing_maps = [#map10, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%827 : tensor<1x8x1xf32>) outs(%27 : tensor<1x8x1xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%839 = math.rsqrt %in : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x1xf32> | |
%829 = linalg.generic {indexing_maps = [#map9, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%828 : tensor<1x8x1xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x8x768xf32> | |
%830 = linalg.generic {indexing_maps = [#map7, #map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%823, %829 : tensor<1x8x768xf32>, tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%831 = linalg.generic {indexing_maps = [#map7, #map11, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%830, %cst_13 : tensor<1x8x768xf32>, tensor<768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.mulf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%832 = linalg.generic {indexing_maps = [#map7, #map11, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%831, %cst_14 : tensor<1x8x768xf32>, tensor<768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %in_366: f32, %out: f32): | |
%839 = arith.addf %in, %in_366 : f32 | |
linalg.yield %839 : f32 | |
} -> tensor<1x8x768xf32> | |
%833 = linalg.generic {indexing_maps = [#map7, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%832 : tensor<1x8x768xf32>) outs(%25 : tensor<1x8x768xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x8x768xf32> | |
%834 = tensor.empty() : tensor<1x768x50272xf32> | |
%835 = linalg.generic {indexing_maps = [#map12, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_207 : tensor<768x50272xf32>) outs(%834 : tensor<1x768x50272xf32>) { | |
^bb0(%in: f32, %out: f32): | |
linalg.yield %in : f32 | |
} -> tensor<1x768x50272xf32> | |
%836 = tensor.empty() : tensor<1x8x50272xf32> | |
%837 = linalg.fill ins(%cst_0 : f32) outs(%836 : tensor<1x8x50272xf32>) -> tensor<1x8x50272xf32> | |
%838 = linalg.batch_matmul ins(%833, %835 : tensor<1x8x768xf32>, tensor<1x768x50272xf32>) outs(%837 : tensor<1x8x50272xf32>) -> tensor<1x8x50272xf32> | |
return %838, %53, %57, %137, %141, %201, %205, %265, %269, %329, %333, %393, %397, %457, %461, %521, %525, %585, %589, %649, %653, %713, %717, %777, %781 : tensor<1x8x50272xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32> | |
} | |
} | |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.