Skip to content

Instantly share code, notes, and snippets.

@nirvedhmeshram
Created February 9, 2024 21:35
Show Gist options
  • Save nirvedhmeshram/4797d0505788b2e19bcc3fd7b9d0f1a5 to your computer and use it in GitHub Desktop.
Save nirvedhmeshram/4797d0505788b2e19bcc3fd7b9d0f1a5 to your computer and use it in GitHub Desktop.
// -----// IR Dump After AssignTargetDevicesPass (iree-hal-assign-target-devices) //----- //
#executable_target_embedded_elf_x86_64_ = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {cpu = "generic", cpu_features = "", data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128", native_vector_size = 16 : index, target_triple = "x86_64-unknown-unknown-eabi-elf", ukernels = "default"}>
#map = affine_map<(d0, d1, d2) -> (d0, d1)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map2 = affine_map<() -> ()>
#map3 = affine_map<(d0, d1) -> ()>
#map4 = affine_map<(d0, d1) -> (d0, d1)>
#map5 = affine_map<(d0) -> (d0)>
#map6 = affine_map<(d0, d1) -> (d1)>
#map7 = affine_map<(d0, d1) -> (d0, 0)>
#map8 = affine_map<(d0, d1, d2, d3) -> (0, 0, d2, d3)>
#map9 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#map10 = affine_map<(d0, d1, d2, d3) -> (0, 0, 0, d3)>
#map11 = affine_map<(d0, d1, d2, d3) -> ()>
#map12 = affine_map<(d0, d1) -> (0, d1)>
#map13 = affine_map<(d0, d1, d2) -> (0, d1, d2)>
#map14 = affine_map<(d0, d1, d2) -> (d0, d1, 0)>
#map15 = affine_map<(d0, d1, d2) -> (0, d1, 0)>
#map16 = affine_map<(d0, d1, d2) -> (d2)>
#map17 = affine_map<(d0, d1) -> (d1, d0)>
#map18 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>
#map19 = affine_map<(d0, d1, d2, d3) -> (0, d1, d2, d3)>
#map20 = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
#device_target_llvm_cpu = #hal.device.target<"llvm-cpu", {executable_targets = [#executable_target_embedded_elf_x86_64_]}>
module attributes {hal.device.targets = [#device_target_llvm_cpu], torch.debug_module_name = "_lambda"} {
ml_program.global private mutable @global_seed(dense<0> : tensor<i64>) : tensor<i64>
func.func @forward(%arg0: tensor<1x8xi64>, %arg1: tensor<1x8xi64>) -> tensor<1x8x50272xf32> {
%cst = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_0 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_1 = arith.constant dense_resource<__elided__> : tensor<2048x8192xf32>
%cst_2 = arith.constant dense_resource<__elided__> : tensor<8192xf32>
%cst_3 = arith.constant dense_resource<__elided__> : tensor<8192x2048xf32>
%cst_4 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_5 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_6 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_7 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_8 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_9 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_10 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_11 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_12 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_13 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_14 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_15 = arith.constant dense_resource<__elided__> : tensor<2048x8192xf32>
%cst_16 = arith.constant dense_resource<__elided__> : tensor<8192xf32>
%cst_17 = arith.constant dense_resource<__elided__> : tensor<8192x2048xf32>
%cst_18 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_19 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_20 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_21 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_22 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_23 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_24 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_25 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_26 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_27 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_28 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_29 = arith.constant dense_resource<__elided__> : tensor<2048x8192xf32>
%cst_30 = arith.constant dense_resource<__elided__> : tensor<8192xf32>
%cst_31 = arith.constant dense_resource<__elided__> : tensor<8192x2048xf32>
%cst_32 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_33 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_34 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_35 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_36 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_37 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_38 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_39 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_40 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_41 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_42 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_43 = arith.constant dense_resource<__elided__> : tensor<2048x8192xf32>
%cst_44 = arith.constant dense_resource<__elided__> : tensor<8192xf32>
%cst_45 = arith.constant dense_resource<__elided__> : tensor<8192x2048xf32>
%cst_46 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_47 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_48 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_49 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_50 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_51 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_52 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_53 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_54 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_55 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_56 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_57 = arith.constant dense_resource<__elided__> : tensor<2048x8192xf32>
%cst_58 = arith.constant dense_resource<__elided__> : tensor<8192xf32>
%cst_59 = arith.constant dense_resource<__elided__> : tensor<8192x2048xf32>
%cst_60 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_61 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_62 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_63 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_64 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_65 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_66 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_67 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_68 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_69 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_70 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_71 = arith.constant dense_resource<__elided__> : tensor<2048x8192xf32>
%cst_72 = arith.constant dense_resource<__elided__> : tensor<8192xf32>
%cst_73 = arith.constant dense_resource<__elided__> : tensor<8192x2048xf32>
%cst_74 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_75 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_76 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_77 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_78 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_79 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_80 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_81 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_82 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_83 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_84 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_85 = arith.constant dense_resource<__elided__> : tensor<2048x8192xf32>
%cst_86 = arith.constant dense_resource<__elided__> : tensor<8192xf32>
%cst_87 = arith.constant dense_resource<__elided__> : tensor<8192x2048xf32>
%cst_88 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_89 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_90 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_91 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_92 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_93 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_94 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_95 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_96 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_97 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_98 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_99 = arith.constant dense_resource<__elided__> : tensor<2048x8192xf32>
%cst_100 = arith.constant dense_resource<__elided__> : tensor<8192xf32>
%cst_101 = arith.constant dense_resource<__elided__> : tensor<8192x2048xf32>
%cst_102 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_103 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_104 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_105 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_106 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_107 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_108 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_109 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_110 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_111 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_112 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_113 = arith.constant dense_resource<__elided__> : tensor<2048x8192xf32>
%cst_114 = arith.constant dense_resource<__elided__> : tensor<8192xf32>
%cst_115 = arith.constant dense_resource<__elided__> : tensor<8192x2048xf32>
%cst_116 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_117 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_118 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_119 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_120 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_121 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_122 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_123 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_124 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_125 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_126 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_127 = arith.constant dense_resource<__elided__> : tensor<2048x8192xf32>
%cst_128 = arith.constant dense_resource<__elided__> : tensor<8192xf32>
%cst_129 = arith.constant dense_resource<__elided__> : tensor<8192x2048xf32>
%cst_130 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_131 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_132 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_133 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_134 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_135 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_136 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_137 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_138 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_139 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_140 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_141 = arith.constant dense_resource<__elided__> : tensor<2048x8192xf32>
%cst_142 = arith.constant dense_resource<__elided__> : tensor<8192xf32>
%cst_143 = arith.constant dense_resource<__elided__> : tensor<8192x2048xf32>
%cst_144 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_145 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_146 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_147 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_148 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_149 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_150 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_151 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_152 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_153 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_154 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_155 = arith.constant dense_resource<__elided__> : tensor<2048x8192xf32>
%cst_156 = arith.constant dense_resource<__elided__> : tensor<8192xf32>
%cst_157 = arith.constant dense_resource<__elided__> : tensor<8192x2048xf32>
%cst_158 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_159 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_160 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_161 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_162 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_163 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_164 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_165 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_166 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_167 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_168 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_169 = arith.constant dense_resource<__elided__> : tensor<2048x8192xf32>
%cst_170 = arith.constant dense_resource<__elided__> : tensor<8192xf32>
%cst_171 = arith.constant dense_resource<__elided__> : tensor<8192x2048xf32>
%cst_172 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_173 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_174 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_175 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_176 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_177 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_178 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_179 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_180 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_181 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_182 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_183 = arith.constant dense_resource<__elided__> : tensor<2048x8192xf32>
%cst_184 = arith.constant dense_resource<__elided__> : tensor<8192xf32>
%cst_185 = arith.constant dense_resource<__elided__> : tensor<8192x2048xf32>
%cst_186 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_187 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_188 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_189 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_190 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_191 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_192 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_193 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_194 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_195 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_196 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_197 = arith.constant dense_resource<__elided__> : tensor<2048x8192xf32>
%cst_198 = arith.constant dense_resource<__elided__> : tensor<8192xf32>
%cst_199 = arith.constant dense_resource<__elided__> : tensor<8192x2048xf32>
%cst_200 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_201 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_202 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_203 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_204 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_205 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_206 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_207 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_208 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_209 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_210 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_211 = arith.constant dense_resource<__elided__> : tensor<2048x8192xf32>
%cst_212 = arith.constant dense_resource<__elided__> : tensor<8192xf32>
%cst_213 = arith.constant dense_resource<__elided__> : tensor<8192x2048xf32>
%cst_214 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_215 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_216 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_217 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_218 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_219 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_220 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_221 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_222 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_223 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_224 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_225 = arith.constant dense_resource<__elided__> : tensor<2048x8192xf32>
%cst_226 = arith.constant dense_resource<__elided__> : tensor<8192xf32>
%cst_227 = arith.constant dense_resource<__elided__> : tensor<8192x2048xf32>
%cst_228 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_229 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_230 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_231 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_232 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_233 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_234 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_235 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_236 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_237 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_238 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_239 = arith.constant dense_resource<__elided__> : tensor<2048x8192xf32>
%cst_240 = arith.constant dense_resource<__elided__> : tensor<8192xf32>
%cst_241 = arith.constant dense_resource<__elided__> : tensor<8192x2048xf32>
%cst_242 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_243 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_244 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_245 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_246 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_247 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_248 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_249 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_250 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_251 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_252 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_253 = arith.constant dense_resource<__elided__> : tensor<2048x8192xf32>
%cst_254 = arith.constant dense_resource<__elided__> : tensor<8192xf32>
%cst_255 = arith.constant dense_resource<__elided__> : tensor<8192x2048xf32>
%cst_256 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_257 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_258 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_259 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_260 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_261 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_262 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_263 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_264 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_265 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_266 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_267 = arith.constant dense_resource<__elided__> : tensor<2048x8192xf32>
%cst_268 = arith.constant dense_resource<__elided__> : tensor<8192xf32>
%cst_269 = arith.constant dense_resource<__elided__> : tensor<8192x2048xf32>
%cst_270 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_271 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_272 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_273 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_274 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_275 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_276 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_277 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_278 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_279 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_280 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_281 = arith.constant dense_resource<__elided__> : tensor<2048x8192xf32>
%cst_282 = arith.constant dense_resource<__elided__> : tensor<8192xf32>
%cst_283 = arith.constant dense_resource<__elided__> : tensor<8192x2048xf32>
%cst_284 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_285 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_286 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_287 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_288 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_289 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_290 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_291 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_292 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_293 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_294 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_295 = arith.constant dense_resource<__elided__> : tensor<2048x8192xf32>
%cst_296 = arith.constant dense_resource<__elided__> : tensor<8192xf32>
%cst_297 = arith.constant dense_resource<__elided__> : tensor<8192x2048xf32>
%cst_298 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_299 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_300 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_301 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_302 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_303 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_304 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_305 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_306 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_307 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_308 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_309 = arith.constant dense_resource<__elided__> : tensor<2048x8192xf32>
%cst_310 = arith.constant dense_resource<__elided__> : tensor<8192xf32>
%cst_311 = arith.constant dense_resource<__elided__> : tensor<8192x2048xf32>
%cst_312 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_313 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_314 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_315 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_316 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_317 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_318 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_319 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_320 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_321 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_322 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_323 = arith.constant dense_resource<__elided__> : tensor<2048x8192xf32>
%cst_324 = arith.constant dense_resource<__elided__> : tensor<8192xf32>
%cst_325 = arith.constant dense_resource<__elided__> : tensor<8192x2048xf32>
%cst_326 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_327 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_328 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_329 = arith.constant dense<-3.40282347E+38> : tensor<f32>
%cst_330 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_331 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_332 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_333 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_334 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_335 = arith.constant dense_resource<__elided__> : tensor<2048x2048xf32>
%cst_336 = arith.constant dense_resource<__elided__> : tensor<2048xf32>
%cst_337 = arith.constant dense<1.000000e+00> : tensor<2048xf32>
%cst_338 = arith.constant dense_resource<__elided__> : tensor<2050x2048xf32>
%cst_339 = arith.constant dense_resource<__elided__> : tensor<50272x2048xf32>
%c50272 = arith.constant 50272 : index
%cst_340 = arith.constant 0.000000e+00 : f32
%c2050 = arith.constant 2050 : index
%cst_341 = arith.constant 0xFF800000 : f32
%c0_i64 = arith.constant 0 : i64
%cst_342 = arith.constant 0.000000e+00 : f64
%cst_343 = arith.constant -3.4028234663852886E+38 : f64
%c2_i64 = arith.constant 2 : i64
%cst_344 = arith.constant 1.000000e-05 : f64
%cst_345 = arith.constant 1.000000e+00 : f32
%c1_i64 = arith.constant 1 : i64
%cst_346 = arith.constant 2.048000e+03 : f64
%cst_347 = arith.constant 2.048000e+03 : f32
%cst_348 = arith.constant 1.250000e-01 : f32
%0 = tensor.empty() : tensor<1x8x2048xf32>
%1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0 : tensor<1x8xi64>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: i64, %out: f32):
%1818 = arith.index_cast %in : i64 to index
%1819 = linalg.index 2 : index
%1820 = arith.cmpi slt, %1818, %c50272 : index
cf.assert %1820, "index must be smaller than dim size"
%1821 = arith.cmpi sge, %in, %c0_i64 : i64
cf.assert %1821, "index must be larger or equal to 0"
%extracted = tensor.extract %cst_339[%1818, %1819] : tensor<50272x2048xf32>
linalg.yield %extracted : f32
} -> tensor<1x8x2048xf32>
%2 = tensor.empty() : tensor<f64>
%3 = linalg.fill ins(%cst_343 : f64) outs(%2 : tensor<f64>) -> tensor<f64>
%4 = tensor.empty() : tensor<f32>
%5 = linalg.generic {indexing_maps = [#map2, #map2], iterator_types = []} ins(%3 : tensor<f64>) outs(%4 : tensor<f32>) {
^bb0(%in: f64, %out: f32):
%1818 = arith.truncf %in : f64 to f32
linalg.yield %1818 : f32
} -> tensor<f32>
%6 = tensor.empty() : tensor<8x8xf32>
%7 = linalg.generic {indexing_maps = [#map3, #map4], iterator_types = ["parallel", "parallel"]} ins(%5 : tensor<f32>) outs(%6 : tensor<8x8xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<8x8xf32>
%8 = tensor.empty() : tensor<8xi64>
%9 = linalg.generic {indexing_maps = [#map5], iterator_types = ["parallel"]} outs(%8 : tensor<8xi64>) {
^bb0(%out: i64):
%1818 = linalg.index 0 : index
%1819 = arith.index_cast %1818 : index to i64
linalg.yield %1819 : i64
} -> tensor<8xi64>
%10 = linalg.generic {indexing_maps = [#map5, #map5], iterator_types = ["parallel"]} ins(%9 : tensor<8xi64>) outs(%8 : tensor<8xi64>) {
^bb0(%in: i64, %out: i64):
%1818 = arith.addi %in, %c1_i64 : i64
linalg.yield %1818 : i64
} -> tensor<8xi64>
%expanded = tensor.expand_shape %10 [[0, 1]] : tensor<8xi64> into tensor<8x1xi64>
%11 = tensor.empty() : tensor<8x8xi1>
%12 = linalg.generic {indexing_maps = [#map6, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%9, %expanded : tensor<8xi64>, tensor<8x1xi64>) outs(%11 : tensor<8x8xi1>) {
^bb0(%in: i64, %in_712: i64, %out: i1):
%1818 = arith.cmpi slt, %in, %in_712 : i64
linalg.yield %1818 : i1
} -> tensor<8x8xi1>
%13 = tensor.empty() : tensor<i64>
%14 = linalg.fill ins(%c0_i64 : i64) outs(%13 : tensor<i64>) -> tensor<i64>
%15 = linalg.generic {indexing_maps = [#map2, #map2], iterator_types = []} ins(%14 : tensor<i64>) outs(%4 : tensor<f32>) {
^bb0(%in: i64, %out: f32):
%1818 = arith.sitofp %in : i64 to f32
linalg.yield %1818 : f32
} -> tensor<f32>
%16 = linalg.generic {indexing_maps = [#map4, #map3, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%12, %15, %7 : tensor<8x8xi1>, tensor<f32>, tensor<8x8xf32>) outs(%6 : tensor<8x8xf32>) {
^bb0(%in: i1, %in_712: f32, %in_713: f32, %out: f32):
%1818 = arith.select %in, %in_712, %in_713 : f32
linalg.yield %1818 : f32
} -> tensor<8x8xf32>
%expanded_349 = tensor.expand_shape %16 [[0, 1, 2], [3]] : tensor<8x8xf32> into tensor<1x1x8x8xf32>
%17 = tensor.empty() : tensor<1x1x8x8xf32>
%18 = linalg.generic {indexing_maps = [#map8, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_349 : tensor<1x1x8x8xf32>) outs(%17 : tensor<1x1x8x8xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x1x8x8xf32>
%expanded_350 = tensor.expand_shape %arg1 [[0], [1, 2, 3]] : tensor<1x8xi64> into tensor<1x1x1x8xi64>
%19 = tensor.empty() : tensor<1x1x8x8xi64>
%20 = linalg.generic {indexing_maps = [#map10, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_350 : tensor<1x1x1x8xi64>) outs(%19 : tensor<1x1x8x8xi64>) {
^bb0(%in: i64, %out: i64):
linalg.yield %in : i64
} -> tensor<1x1x8x8xi64>
%21 = linalg.generic {indexing_maps = [#map8, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%20 : tensor<1x1x8x8xi64>) outs(%17 : tensor<1x1x8x8xf32>) {
^bb0(%in: i64, %out: f32):
%1818 = arith.sitofp %in : i64 to f32
linalg.yield %1818 : f32
} -> tensor<1x1x8x8xf32>
%22 = linalg.generic {indexing_maps = [#map8, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%21 : tensor<1x1x8x8xf32>) outs(%17 : tensor<1x1x8x8xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x1x8x8xf32>
%23 = linalg.generic {indexing_maps = [#map8, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%22 : tensor<1x1x8x8xf32>) outs(%17 : tensor<1x1x8x8xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.subf %cst_345, %in : f32
linalg.yield %1818 : f32
} -> tensor<1x1x8x8xf32>
%24 = tensor.empty() : tensor<1x1x8x8xi1>
%25 = linalg.generic {indexing_maps = [#map8, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%23 : tensor<1x1x8x8xf32>) outs(%24 : tensor<1x1x8x8xi1>) {
^bb0(%in: f32, %out: i1):
%1818 = arith.cmpf une, %in, %cst_340 : f32
linalg.yield %1818 : i1
} -> tensor<1x1x8x8xi1>
%26 = linalg.generic {indexing_maps = [#map8, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%25 : tensor<1x1x8x8xi1>) outs(%24 : tensor<1x1x8x8xi1>) {
^bb0(%in: i1, %out: i1):
linalg.yield %in : i1
} -> tensor<1x1x8x8xi1>
%27 = linalg.generic {indexing_maps = [#map8, #map11, #map8, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%26, %5, %23 : tensor<1x1x8x8xi1>, tensor<f32>, tensor<1x1x8x8xf32>) outs(%17 : tensor<1x1x8x8xf32>) {
^bb0(%in: i1, %in_712: f32, %in_713: f32, %out: f32):
%1818 = arith.select %in, %in_712, %in_713 : f32
linalg.yield %1818 : f32
} -> tensor<1x1x8x8xf32>
%28 = linalg.generic {indexing_maps = [#map8, #map8, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%27, %18 : tensor<1x1x8x8xf32>, tensor<1x1x8x8xf32>) outs(%17 : tensor<1x1x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x1x8x8xf32>
%29 = tensor.empty() : tensor<1x8xi64>
%30 = linalg.fill ins(%c0_i64 : i64) outs(%29 : tensor<1x8xi64>) -> tensor<1x8xi64>
%31 = tensor.empty() : tensor<1xi64>
%32 = linalg.fill ins(%c0_i64 : i64) outs(%31 : tensor<1xi64>) -> tensor<1xi64>
%33:2 = tm_tensor.scan dimension(1) inclusive(true) ins(%arg1 : tensor<1x8xi64>) outs(%30, %32 : tensor<1x8xi64>, tensor<1xi64>) {
^bb0(%arg2: i64, %arg3: i64):
%1818 = arith.addi %arg2, %arg3 : i64
tm_tensor.yield %1818 : i64
} -> tensor<1x8xi64>, tensor<1xi64>
%34 = linalg.generic {indexing_maps = [#map12, #map12, #map4], iterator_types = ["parallel", "parallel"]} ins(%33#0, %arg1 : tensor<1x8xi64>, tensor<1x8xi64>) outs(%29 : tensor<1x8xi64>) {
^bb0(%in: i64, %in_712: i64, %out: i64):
%1818 = arith.muli %in, %in_712 : i64
linalg.yield %1818 : i64
} -> tensor<1x8xi64>
%35 = linalg.generic {indexing_maps = [#map12, #map4], iterator_types = ["parallel", "parallel"]} ins(%34 : tensor<1x8xi64>) outs(%29 : tensor<1x8xi64>) {
^bb0(%in: i64, %out: i64):
%1818 = arith.subi %in, %c1_i64 : i64
linalg.yield %1818 : i64
} -> tensor<1x8xi64>
%36 = linalg.generic {indexing_maps = [#map12, #map4], iterator_types = ["parallel", "parallel"]} ins(%35 : tensor<1x8xi64>) outs(%29 : tensor<1x8xi64>) {
^bb0(%in: i64, %out: i64):
%1818 = arith.addi %in, %c2_i64 : i64
linalg.yield %1818 : i64
} -> tensor<1x8xi64>
%37 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%36 : tensor<1x8xi64>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: i64, %out: f32):
%1818 = arith.index_cast %in : i64 to index
%1819 = linalg.index 2 : index
%1820 = arith.cmpi slt, %1818, %c2050 : index
cf.assert %1820, "index must be smaller than dim size"
%1821 = arith.cmpi sge, %in, %c0_i64 : i64
cf.assert %1821, "index must be larger or equal to 0"
%extracted = tensor.extract %cst_338[%1818, %1819] : tensor<2050x2048xf32>
linalg.yield %extracted : f32
} -> tensor<1x8x2048xf32>
%38 = linalg.generic {indexing_maps = [#map13, #map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1, %37 : tensor<1x8x2048xf32>, tensor<1x8x2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%39 = tensor.empty() : tensor<1x8x2048xf64>
%40 = linalg.generic {indexing_maps = [#map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%38 : tensor<1x8x2048xf32>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f32, %out: f64):
%1818 = arith.extf %in : f32 to f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
%41 = tensor.empty() : tensor<1x8x1xf64>
%42 = linalg.fill ins(%cst_342 : f64) outs(%41 : tensor<1x8x1xf64>) -> tensor<1x8x1xf64>
%43 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%40 : tensor<1x8x2048xf64>) outs(%42 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%44 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%43 : tensor<1x8x1xf64>) outs(%41 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%45 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%40, %44 : tensor<1x8x2048xf64>, tensor<1x8x1xf64>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.subf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
%46 = linalg.generic {indexing_maps = [#map13, #map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%45, %45 : tensor<1x8x2048xf64>, tensor<1x8x2048xf64>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.mulf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
%47 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%46 : tensor<1x8x2048xf64>) outs(%42 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%48 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%47 : tensor<1x8x1xf64>) outs(%41 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%49 = tensor.empty() : tensor<1x8x1xf32>
%50 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%48 : tensor<1x8x1xf64>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f64, %out: f32):
%1818 = arith.truncf %in : f64 to f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%51 = linalg.fill ins(%cst_340 : f32) outs(%49 : tensor<1x8x1xf32>) -> tensor<1x8x1xf32>
%52 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%38 : tensor<1x8x2048xf32>) outs(%51 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%53 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%52 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.divf %in, %cst_347 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%54 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%50 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.truncf %cst_344 : f64 to f32
%1819 = arith.addf %in, %1818 : f32
linalg.yield %1819 : f32
} -> tensor<1x8x1xf32>
%55 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%54 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.rsqrt %in : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%56 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%38, %53 : tensor<1x8x2048xf32>, tensor<1x8x1xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%57 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%56, %55 : tensor<1x8x2048xf32>, tensor<1x8x1xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%58 = linalg.generic {indexing_maps = [#map13, #map16, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%57, %cst_337 : tensor<1x8x2048xf32>, tensor<2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%59 = linalg.generic {indexing_maps = [#map13, #map16, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%58, %cst_336 : tensor<1x8x2048xf32>, tensor<2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%60 = tensor.empty() : tensor<2048x2048xf32>
%61 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_335 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%collapsed = tensor.collapse_shape %59 [[0, 1], [2]] : tensor<1x8x2048xf32> into tensor<8x2048xf32>
%62 = tensor.empty() : tensor<8x2048xf32>
%63 = linalg.fill ins(%cst_340 : f32) outs(%62 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%64 = linalg.matmul ins(%collapsed, %61 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%65 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_334, %64 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_351 = tensor.expand_shape %65 [[0, 1], [2]] : tensor<8x2048xf32> into tensor<1x8x2048xf32>
%66 = linalg.generic {indexing_maps = [#map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_351 : tensor<1x8x2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.mulf %in, %cst_348 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%67 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_333 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%68 = linalg.matmul ins(%collapsed, %67 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%69 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_332, %68 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_352 = tensor.expand_shape %69 [[0, 1], [2, 3]] : tensor<8x2048xf32> into tensor<1x8x32x64xf32>
%70 = tensor.empty() : tensor<1x32x8x64xf32>
%71 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_352 : tensor<1x8x32x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%72 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%71 : tensor<1x32x8x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%73 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_331 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%74 = linalg.matmul ins(%collapsed, %73 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%75 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_330, %74 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_353 = tensor.expand_shape %75 [[0, 1], [2, 3]] : tensor<8x2048xf32> into tensor<1x8x32x64xf32>
%76 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_353 : tensor<1x8x32x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%77 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%76 : tensor<1x32x8x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%expanded_354 = tensor.expand_shape %66 [[0], [1], [2, 3]] : tensor<1x8x2048xf32> into tensor<1x8x32x64xf32>
%78 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_354 : tensor<1x8x32x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%79 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%78 : tensor<1x32x8x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%collapsed_355 = tensor.collapse_shape %79 [[0, 1], [2], [3]] : tensor<1x32x8x64xf32> into tensor<32x8x64xf32>
%collapsed_356 = tensor.collapse_shape %72 [[0, 1], [2], [3]] : tensor<1x32x8x64xf32> into tensor<32x8x64xf32>
%collapsed_357 = tensor.collapse_shape %77 [[0, 1], [2], [3]] : tensor<1x32x8x64xf32> into tensor<32x8x64xf32>
%80 = tensor.empty() : tensor<32x64x8xf32>
%81 = linalg.generic {indexing_maps = [#map1, #map20], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_356 : tensor<32x8x64xf32>) outs(%80 : tensor<32x64x8xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<32x64x8xf32>
%82 = tensor.empty() : tensor<32x8x8xf32>
%83 = linalg.fill ins(%cst_340 : f32) outs(%82 : tensor<32x8x8xf32>) -> tensor<32x8x8xf32>
%84 = linalg.batch_matmul ins(%collapsed_355, %81 : tensor<32x8x64xf32>, tensor<32x64x8xf32>) outs(%83 : tensor<32x8x8xf32>) -> tensor<32x8x8xf32>
%expanded_358 = tensor.expand_shape %84 [[0, 1], [2], [3]] : tensor<32x8x8xf32> into tensor<1x32x8x8xf32>
%85 = tensor.empty() : tensor<1x32x8x8xf32>
%86 = linalg.generic {indexing_maps = [#map19, #map8, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_358, %28 : tensor<1x32x8x8xf32>, tensor<1x1x8x8xf32>) outs(%85 : tensor<1x32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x32x8x8xf32>
%87 = linalg.generic {indexing_maps = [#map19, #map11, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%86, %cst_329 : tensor<1x32x8x8xf32>, tensor<f32>) outs(%85 : tensor<1x32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.cmpf ugt, %in, %in_712 : f32
%1819 = arith.select %1818, %in, %in_712 : f32
linalg.yield %1819 : f32
} -> tensor<1x32x8x8xf32>
%collapsed_359 = tensor.collapse_shape %87 [[0, 1], [2], [3]] : tensor<1x32x8x8xf32> into tensor<32x8x8xf32>
%88 = tensor.empty() : tensor<32x8x1xi64>
%89 = linalg.fill ins(%c0_i64 : i64) outs(%88 : tensor<32x8x1xi64>) -> tensor<32x8x1xi64>
%90 = tensor.empty() : tensor<32x8x1xf32>
%91 = linalg.fill ins(%cst_341 : f32) outs(%90 : tensor<32x8x1xf32>) -> tensor<32x8x1xf32>
%92:2 = linalg.generic {indexing_maps = [#map1, #map14, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%collapsed_359 : tensor<32x8x8xf32>) outs(%91, %89 : tensor<32x8x1xf32>, tensor<32x8x1xi64>) {
^bb0(%in: f32, %out: f32, %out_712: i64):
%1818 = linalg.index 2 : index
%1819 = arith.index_cast %1818 : index to i64
%1820 = arith.maximumf %in, %out : f32
%1821 = arith.cmpf ogt, %in, %out : f32
%1822 = arith.select %1821, %1819, %out_712 : i64
linalg.yield %1820, %1822 : f32, i64
} -> (tensor<32x8x1xf32>, tensor<32x8x1xi64>)
%93 = linalg.generic {indexing_maps = [#map1, #map14, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_359, %92#0 : tensor<32x8x8xf32>, tensor<32x8x1xf32>) outs(%82 : tensor<32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<32x8x8xf32>
%94 = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%93 : tensor<32x8x8xf32>) outs(%82 : tensor<32x8x8xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.exp %in : f32
linalg.yield %1818 : f32
} -> tensor<32x8x8xf32>
%95 = linalg.fill ins(%cst_340 : f32) outs(%90 : tensor<32x8x1xf32>) -> tensor<32x8x1xf32>
%96 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%94 : tensor<32x8x8xf32>) outs(%95 : tensor<32x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<32x8x1xf32>
%97 = linalg.generic {indexing_maps = [#map1, #map14, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%94, %96 : tensor<32x8x8xf32>, tensor<32x8x1xf32>) outs(%82 : tensor<32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.divf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<32x8x8xf32>
%98 = tensor.empty() : tensor<32x8x64xf32>
%99 = linalg.fill ins(%cst_340 : f32) outs(%98 : tensor<32x8x64xf32>) -> tensor<32x8x64xf32>
%100 = linalg.batch_matmul ins(%97, %collapsed_357 : tensor<32x8x8xf32>, tensor<32x8x64xf32>) outs(%99 : tensor<32x8x64xf32>) -> tensor<32x8x64xf32>
%expanded_360 = tensor.expand_shape %100 [[0, 1], [2], [3]] : tensor<32x8x64xf32> into tensor<1x32x8x64xf32>
%101 = tensor.empty() : tensor<1x8x32x64xf32>
%102 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_360 : tensor<1x32x8x64xf32>) outs(%101 : tensor<1x8x32x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x32x64xf32>
%103 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%102 : tensor<1x8x32x64xf32>) outs(%101 : tensor<1x8x32x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x32x64xf32>
%104 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_328 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%collapsed_361 = tensor.collapse_shape %103 [[0, 1], [2, 3]] : tensor<1x8x32x64xf32> into tensor<8x2048xf32>
%105 = linalg.matmul ins(%collapsed_361, %104 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%106 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_327, %105 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_362 = tensor.expand_shape %106 [[0, 1], [2]] : tensor<8x2048xf32> into tensor<1x8x2048xf32>
%107 = linalg.generic {indexing_maps = [#map13, #map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%38, %expanded_362 : tensor<1x8x2048xf32>, tensor<1x8x2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%collapsed_363 = tensor.collapse_shape %107 [[0, 1], [2]] : tensor<1x8x2048xf32> into tensor<8x2048xf32>
%108 = tensor.empty() : tensor<8x2048xf64>
%109 = linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%collapsed_363 : tensor<8x2048xf32>) outs(%108 : tensor<8x2048xf64>) {
^bb0(%in: f32, %out: f64):
%1818 = arith.extf %in : f32 to f64
linalg.yield %1818 : f64
} -> tensor<8x2048xf64>
%110 = tensor.empty() : tensor<8x1xf64>
%111 = linalg.fill ins(%cst_342 : f64) outs(%110 : tensor<8x1xf64>) -> tensor<8x1xf64>
%112 = linalg.generic {indexing_maps = [#map4, #map7], iterator_types = ["parallel", "reduction"]} ins(%109 : tensor<8x2048xf64>) outs(%111 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%113 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%112 : tensor<8x1xf64>) outs(%110 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%114 = linalg.generic {indexing_maps = [#map4, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%109, %113 : tensor<8x2048xf64>, tensor<8x1xf64>) outs(%108 : tensor<8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.subf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<8x2048xf64>
%115 = linalg.generic {indexing_maps = [#map4, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%114, %114 : tensor<8x2048xf64>, tensor<8x2048xf64>) outs(%108 : tensor<8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.mulf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<8x2048xf64>
%116 = linalg.generic {indexing_maps = [#map4, #map7], iterator_types = ["parallel", "reduction"]} ins(%115 : tensor<8x2048xf64>) outs(%111 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%117 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%116 : tensor<8x1xf64>) outs(%110 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%118 = tensor.empty() : tensor<8x1xf32>
%119 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%117 : tensor<8x1xf64>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f64, %out: f32):
%1818 = arith.truncf %in : f64 to f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%120 = linalg.fill ins(%cst_340 : f32) outs(%118 : tensor<8x1xf32>) -> tensor<8x1xf32>
%121 = linalg.generic {indexing_maps = [#map4, #map7], iterator_types = ["parallel", "reduction"]} ins(%collapsed_363 : tensor<8x2048xf32>) outs(%120 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%122 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%121 : tensor<8x1xf32>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.divf %in, %cst_347 : f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%123 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%119 : tensor<8x1xf32>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.truncf %cst_344 : f64 to f32
%1819 = arith.addf %in, %1818 : f32
linalg.yield %1819 : f32
} -> tensor<8x1xf32>
%124 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%123 : tensor<8x1xf32>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.rsqrt %in : f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%125 = linalg.generic {indexing_maps = [#map4, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%collapsed_363, %122 : tensor<8x2048xf32>, tensor<8x1xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%126 = linalg.generic {indexing_maps = [#map4, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%125, %124 : tensor<8x2048xf32>, tensor<8x1xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%127 = linalg.generic {indexing_maps = [#map4, #map6, #map4], iterator_types = ["parallel", "parallel"]} ins(%126, %cst_337 : tensor<8x2048xf32>, tensor<2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%128 = linalg.generic {indexing_maps = [#map4, #map6, #map4], iterator_types = ["parallel", "parallel"]} ins(%127, %cst_326 : tensor<8x2048xf32>, tensor<2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%129 = tensor.empty() : tensor<2048x8192xf32>
%130 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_325 : tensor<8192x2048xf32>) outs(%129 : tensor<2048x8192xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x8192xf32>
%131 = tensor.empty() : tensor<8x8192xf32>
%132 = linalg.fill ins(%cst_340 : f32) outs(%131 : tensor<8x8192xf32>) -> tensor<8x8192xf32>
%133 = linalg.matmul ins(%128, %130 : tensor<8x2048xf32>, tensor<2048x8192xf32>) outs(%132 : tensor<8x8192xf32>) -> tensor<8x8192xf32>
%134 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_324, %133 : tensor<8192xf32>, tensor<8x8192xf32>) outs(%131 : tensor<8x8192xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x8192xf32>
%135 = linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%134 : tensor<8x8192xf32>) outs(%131 : tensor<8x8192xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.cmpf ugt, %in, %cst_340 : f32
%1819 = arith.select %1818, %in, %cst_340 : f32
linalg.yield %1819 : f32
} -> tensor<8x8192xf32>
%136 = tensor.empty() : tensor<8192x2048xf32>
%137 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_323 : tensor<2048x8192xf32>) outs(%136 : tensor<8192x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<8192x2048xf32>
%138 = linalg.matmul ins(%135, %137 : tensor<8x8192xf32>, tensor<8192x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%139 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_322, %138 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%140 = linalg.generic {indexing_maps = [#map4, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%collapsed_363, %139 : tensor<8x2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_364 = tensor.expand_shape %140 [[0, 1], [2]] : tensor<8x2048xf32> into tensor<1x8x2048xf32>
%141 = linalg.generic {indexing_maps = [#map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_364 : tensor<1x8x2048xf32>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f32, %out: f64):
%1818 = arith.extf %in : f32 to f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
%142 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%141 : tensor<1x8x2048xf64>) outs(%42 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%143 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%142 : tensor<1x8x1xf64>) outs(%41 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%144 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%141, %143 : tensor<1x8x2048xf64>, tensor<1x8x1xf64>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.subf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
%145 = linalg.generic {indexing_maps = [#map13, #map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%144, %144 : tensor<1x8x2048xf64>, tensor<1x8x2048xf64>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.mulf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
%146 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%145 : tensor<1x8x2048xf64>) outs(%42 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%147 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%146 : tensor<1x8x1xf64>) outs(%41 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%148 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%147 : tensor<1x8x1xf64>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f64, %out: f32):
%1818 = arith.truncf %in : f64 to f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%149 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%expanded_364 : tensor<1x8x2048xf32>) outs(%51 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%150 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%149 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.divf %in, %cst_347 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%151 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%148 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.truncf %cst_344 : f64 to f32
%1819 = arith.addf %in, %1818 : f32
linalg.yield %1819 : f32
} -> tensor<1x8x1xf32>
%152 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%151 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.rsqrt %in : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%153 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_364, %150 : tensor<1x8x2048xf32>, tensor<1x8x1xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%154 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%153, %152 : tensor<1x8x2048xf32>, tensor<1x8x1xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%155 = linalg.generic {indexing_maps = [#map13, #map16, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%154, %cst_337 : tensor<1x8x2048xf32>, tensor<2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%156 = linalg.generic {indexing_maps = [#map13, #map16, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%155, %cst_321 : tensor<1x8x2048xf32>, tensor<2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%157 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_320 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%collapsed_365 = tensor.collapse_shape %156 [[0, 1], [2]] : tensor<1x8x2048xf32> into tensor<8x2048xf32>
%158 = linalg.matmul ins(%collapsed_365, %157 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%159 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_319, %158 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_366 = tensor.expand_shape %159 [[0, 1], [2]] : tensor<8x2048xf32> into tensor<1x8x2048xf32>
%160 = linalg.generic {indexing_maps = [#map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_366 : tensor<1x8x2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.mulf %in, %cst_348 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%161 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_318 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%162 = linalg.matmul ins(%collapsed_365, %161 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%163 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_317, %162 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_367 = tensor.expand_shape %163 [[0, 1], [2, 3]] : tensor<8x2048xf32> into tensor<1x8x32x64xf32>
%164 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_367 : tensor<1x8x32x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%165 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%164 : tensor<1x32x8x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%166 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_316 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%167 = linalg.matmul ins(%collapsed_365, %166 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%168 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_315, %167 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_368 = tensor.expand_shape %168 [[0, 1], [2, 3]] : tensor<8x2048xf32> into tensor<1x8x32x64xf32>
%169 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_368 : tensor<1x8x32x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%170 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%169 : tensor<1x32x8x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%expanded_369 = tensor.expand_shape %160 [[0], [1], [2, 3]] : tensor<1x8x2048xf32> into tensor<1x8x32x64xf32>
%171 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_369 : tensor<1x8x32x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%172 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%171 : tensor<1x32x8x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%collapsed_370 = tensor.collapse_shape %172 [[0, 1], [2], [3]] : tensor<1x32x8x64xf32> into tensor<32x8x64xf32>
%collapsed_371 = tensor.collapse_shape %165 [[0, 1], [2], [3]] : tensor<1x32x8x64xf32> into tensor<32x8x64xf32>
%collapsed_372 = tensor.collapse_shape %170 [[0, 1], [2], [3]] : tensor<1x32x8x64xf32> into tensor<32x8x64xf32>
%173 = linalg.generic {indexing_maps = [#map1, #map20], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_371 : tensor<32x8x64xf32>) outs(%80 : tensor<32x64x8xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<32x64x8xf32>
%174 = linalg.batch_matmul ins(%collapsed_370, %173 : tensor<32x8x64xf32>, tensor<32x64x8xf32>) outs(%83 : tensor<32x8x8xf32>) -> tensor<32x8x8xf32>
%expanded_373 = tensor.expand_shape %174 [[0, 1], [2], [3]] : tensor<32x8x8xf32> into tensor<1x32x8x8xf32>
%175 = linalg.generic {indexing_maps = [#map19, #map8, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_373, %28 : tensor<1x32x8x8xf32>, tensor<1x1x8x8xf32>) outs(%85 : tensor<1x32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x32x8x8xf32>
%176 = linalg.generic {indexing_maps = [#map19, #map11, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%175, %cst_329 : tensor<1x32x8x8xf32>, tensor<f32>) outs(%85 : tensor<1x32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.cmpf ugt, %in, %in_712 : f32
%1819 = arith.select %1818, %in, %in_712 : f32
linalg.yield %1819 : f32
} -> tensor<1x32x8x8xf32>
%collapsed_374 = tensor.collapse_shape %176 [[0, 1], [2], [3]] : tensor<1x32x8x8xf32> into tensor<32x8x8xf32>
%177:2 = linalg.generic {indexing_maps = [#map1, #map14, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%collapsed_374 : tensor<32x8x8xf32>) outs(%91, %89 : tensor<32x8x1xf32>, tensor<32x8x1xi64>) {
^bb0(%in: f32, %out: f32, %out_712: i64):
%1818 = linalg.index 2 : index
%1819 = arith.index_cast %1818 : index to i64
%1820 = arith.maximumf %in, %out : f32
%1821 = arith.cmpf ogt, %in, %out : f32
%1822 = arith.select %1821, %1819, %out_712 : i64
linalg.yield %1820, %1822 : f32, i64
} -> (tensor<32x8x1xf32>, tensor<32x8x1xi64>)
%178 = linalg.generic {indexing_maps = [#map1, #map14, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_374, %177#0 : tensor<32x8x8xf32>, tensor<32x8x1xf32>) outs(%82 : tensor<32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<32x8x8xf32>
%179 = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%178 : tensor<32x8x8xf32>) outs(%82 : tensor<32x8x8xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.exp %in : f32
linalg.yield %1818 : f32
} -> tensor<32x8x8xf32>
%180 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%179 : tensor<32x8x8xf32>) outs(%95 : tensor<32x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<32x8x1xf32>
%181 = linalg.generic {indexing_maps = [#map1, #map14, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%179, %180 : tensor<32x8x8xf32>, tensor<32x8x1xf32>) outs(%82 : tensor<32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.divf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<32x8x8xf32>
%182 = linalg.batch_matmul ins(%181, %collapsed_372 : tensor<32x8x8xf32>, tensor<32x8x64xf32>) outs(%99 : tensor<32x8x64xf32>) -> tensor<32x8x64xf32>
%expanded_375 = tensor.expand_shape %182 [[0, 1], [2], [3]] : tensor<32x8x64xf32> into tensor<1x32x8x64xf32>
%183 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_375 : tensor<1x32x8x64xf32>) outs(%101 : tensor<1x8x32x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x32x64xf32>
%184 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%183 : tensor<1x8x32x64xf32>) outs(%101 : tensor<1x8x32x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x32x64xf32>
%185 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_314 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%collapsed_376 = tensor.collapse_shape %184 [[0, 1], [2, 3]] : tensor<1x8x32x64xf32> into tensor<8x2048xf32>
%186 = linalg.matmul ins(%collapsed_376, %185 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%187 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_313, %186 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_377 = tensor.expand_shape %187 [[0, 1], [2]] : tensor<8x2048xf32> into tensor<1x8x2048xf32>
%188 = linalg.generic {indexing_maps = [#map13, #map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_364, %expanded_377 : tensor<1x8x2048xf32>, tensor<1x8x2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%collapsed_378 = tensor.collapse_shape %188 [[0, 1], [2]] : tensor<1x8x2048xf32> into tensor<8x2048xf32>
%189 = linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%collapsed_378 : tensor<8x2048xf32>) outs(%108 : tensor<8x2048xf64>) {
^bb0(%in: f32, %out: f64):
%1818 = arith.extf %in : f32 to f64
linalg.yield %1818 : f64
} -> tensor<8x2048xf64>
%190 = linalg.generic {indexing_maps = [#map4, #map7], iterator_types = ["parallel", "reduction"]} ins(%189 : tensor<8x2048xf64>) outs(%111 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%191 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%190 : tensor<8x1xf64>) outs(%110 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%192 = linalg.generic {indexing_maps = [#map4, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%189, %191 : tensor<8x2048xf64>, tensor<8x1xf64>) outs(%108 : tensor<8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.subf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<8x2048xf64>
%193 = linalg.generic {indexing_maps = [#map4, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%192, %192 : tensor<8x2048xf64>, tensor<8x2048xf64>) outs(%108 : tensor<8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.mulf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<8x2048xf64>
%194 = linalg.generic {indexing_maps = [#map4, #map7], iterator_types = ["parallel", "reduction"]} ins(%193 : tensor<8x2048xf64>) outs(%111 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%195 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%194 : tensor<8x1xf64>) outs(%110 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%196 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%195 : tensor<8x1xf64>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f64, %out: f32):
%1818 = arith.truncf %in : f64 to f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%197 = linalg.generic {indexing_maps = [#map4, #map7], iterator_types = ["parallel", "reduction"]} ins(%collapsed_378 : tensor<8x2048xf32>) outs(%120 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%198 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%197 : tensor<8x1xf32>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.divf %in, %cst_347 : f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%199 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%196 : tensor<8x1xf32>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.truncf %cst_344 : f64 to f32
%1819 = arith.addf %in, %1818 : f32
linalg.yield %1819 : f32
} -> tensor<8x1xf32>
%200 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%199 : tensor<8x1xf32>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.rsqrt %in : f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%201 = linalg.generic {indexing_maps = [#map4, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%collapsed_378, %198 : tensor<8x2048xf32>, tensor<8x1xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%202 = linalg.generic {indexing_maps = [#map4, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%201, %200 : tensor<8x2048xf32>, tensor<8x1xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%203 = linalg.generic {indexing_maps = [#map4, #map6, #map4], iterator_types = ["parallel", "parallel"]} ins(%202, %cst_337 : tensor<8x2048xf32>, tensor<2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%204 = linalg.generic {indexing_maps = [#map4, #map6, #map4], iterator_types = ["parallel", "parallel"]} ins(%203, %cst_312 : tensor<8x2048xf32>, tensor<2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%205 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_311 : tensor<8192x2048xf32>) outs(%129 : tensor<2048x8192xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x8192xf32>
%206 = linalg.matmul ins(%204, %205 : tensor<8x2048xf32>, tensor<2048x8192xf32>) outs(%132 : tensor<8x8192xf32>) -> tensor<8x8192xf32>
%207 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_310, %206 : tensor<8192xf32>, tensor<8x8192xf32>) outs(%131 : tensor<8x8192xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x8192xf32>
%208 = linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%207 : tensor<8x8192xf32>) outs(%131 : tensor<8x8192xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.cmpf ugt, %in, %cst_340 : f32
%1819 = arith.select %1818, %in, %cst_340 : f32
linalg.yield %1819 : f32
} -> tensor<8x8192xf32>
%209 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_309 : tensor<2048x8192xf32>) outs(%136 : tensor<8192x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<8192x2048xf32>
%210 = linalg.matmul ins(%208, %209 : tensor<8x8192xf32>, tensor<8192x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%211 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_308, %210 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%212 = linalg.generic {indexing_maps = [#map4, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%collapsed_378, %211 : tensor<8x2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_379 = tensor.expand_shape %212 [[0, 1], [2]] : tensor<8x2048xf32> into tensor<1x8x2048xf32>
%213 = linalg.generic {indexing_maps = [#map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_379 : tensor<1x8x2048xf32>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f32, %out: f64):
%1818 = arith.extf %in : f32 to f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
%214 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%213 : tensor<1x8x2048xf64>) outs(%42 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%215 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%214 : tensor<1x8x1xf64>) outs(%41 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%216 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%213, %215 : tensor<1x8x2048xf64>, tensor<1x8x1xf64>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.subf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
%217 = linalg.generic {indexing_maps = [#map13, #map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%216, %216 : tensor<1x8x2048xf64>, tensor<1x8x2048xf64>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.mulf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
%218 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%217 : tensor<1x8x2048xf64>) outs(%42 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%219 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%218 : tensor<1x8x1xf64>) outs(%41 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%220 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%219 : tensor<1x8x1xf64>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f64, %out: f32):
%1818 = arith.truncf %in : f64 to f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%221 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%expanded_379 : tensor<1x8x2048xf32>) outs(%51 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%222 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%221 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.divf %in, %cst_347 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%223 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%220 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.truncf %cst_344 : f64 to f32
%1819 = arith.addf %in, %1818 : f32
linalg.yield %1819 : f32
} -> tensor<1x8x1xf32>
%224 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%223 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.rsqrt %in : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%225 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_379, %222 : tensor<1x8x2048xf32>, tensor<1x8x1xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%226 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%225, %224 : tensor<1x8x2048xf32>, tensor<1x8x1xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%227 = linalg.generic {indexing_maps = [#map13, #map16, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%226, %cst_337 : tensor<1x8x2048xf32>, tensor<2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%228 = linalg.generic {indexing_maps = [#map13, #map16, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%227, %cst_307 : tensor<1x8x2048xf32>, tensor<2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%229 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_306 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%collapsed_380 = tensor.collapse_shape %228 [[0, 1], [2]] : tensor<1x8x2048xf32> into tensor<8x2048xf32>
%230 = linalg.matmul ins(%collapsed_380, %229 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%231 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_305, %230 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_381 = tensor.expand_shape %231 [[0, 1], [2]] : tensor<8x2048xf32> into tensor<1x8x2048xf32>
%232 = linalg.generic {indexing_maps = [#map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_381 : tensor<1x8x2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.mulf %in, %cst_348 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%233 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_304 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%234 = linalg.matmul ins(%collapsed_380, %233 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%235 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_303, %234 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_382 = tensor.expand_shape %235 [[0, 1], [2, 3]] : tensor<8x2048xf32> into tensor<1x8x32x64xf32>
%236 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_382 : tensor<1x8x32x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%237 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%236 : tensor<1x32x8x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%238 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_302 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%239 = linalg.matmul ins(%collapsed_380, %238 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%240 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_301, %239 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_383 = tensor.expand_shape %240 [[0, 1], [2, 3]] : tensor<8x2048xf32> into tensor<1x8x32x64xf32>
%241 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_383 : tensor<1x8x32x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%242 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%241 : tensor<1x32x8x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%expanded_384 = tensor.expand_shape %232 [[0], [1], [2, 3]] : tensor<1x8x2048xf32> into tensor<1x8x32x64xf32>
%243 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_384 : tensor<1x8x32x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%244 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%243 : tensor<1x32x8x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%collapsed_385 = tensor.collapse_shape %244 [[0, 1], [2], [3]] : tensor<1x32x8x64xf32> into tensor<32x8x64xf32>
%collapsed_386 = tensor.collapse_shape %237 [[0, 1], [2], [3]] : tensor<1x32x8x64xf32> into tensor<32x8x64xf32>
%collapsed_387 = tensor.collapse_shape %242 [[0, 1], [2], [3]] : tensor<1x32x8x64xf32> into tensor<32x8x64xf32>
%245 = linalg.generic {indexing_maps = [#map1, #map20], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_386 : tensor<32x8x64xf32>) outs(%80 : tensor<32x64x8xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<32x64x8xf32>
%246 = linalg.batch_matmul ins(%collapsed_385, %245 : tensor<32x8x64xf32>, tensor<32x64x8xf32>) outs(%83 : tensor<32x8x8xf32>) -> tensor<32x8x8xf32>
%expanded_388 = tensor.expand_shape %246 [[0, 1], [2], [3]] : tensor<32x8x8xf32> into tensor<1x32x8x8xf32>
%247 = linalg.generic {indexing_maps = [#map19, #map8, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_388, %28 : tensor<1x32x8x8xf32>, tensor<1x1x8x8xf32>) outs(%85 : tensor<1x32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x32x8x8xf32>
%248 = linalg.generic {indexing_maps = [#map19, #map11, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%247, %cst_329 : tensor<1x32x8x8xf32>, tensor<f32>) outs(%85 : tensor<1x32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.cmpf ugt, %in, %in_712 : f32
%1819 = arith.select %1818, %in, %in_712 : f32
linalg.yield %1819 : f32
} -> tensor<1x32x8x8xf32>
%collapsed_389 = tensor.collapse_shape %248 [[0, 1], [2], [3]] : tensor<1x32x8x8xf32> into tensor<32x8x8xf32>
%249:2 = linalg.generic {indexing_maps = [#map1, #map14, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%collapsed_389 : tensor<32x8x8xf32>) outs(%91, %89 : tensor<32x8x1xf32>, tensor<32x8x1xi64>) {
^bb0(%in: f32, %out: f32, %out_712: i64):
%1818 = linalg.index 2 : index
%1819 = arith.index_cast %1818 : index to i64
%1820 = arith.maximumf %in, %out : f32
%1821 = arith.cmpf ogt, %in, %out : f32
%1822 = arith.select %1821, %1819, %out_712 : i64
linalg.yield %1820, %1822 : f32, i64
} -> (tensor<32x8x1xf32>, tensor<32x8x1xi64>)
%250 = linalg.generic {indexing_maps = [#map1, #map14, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_389, %249#0 : tensor<32x8x8xf32>, tensor<32x8x1xf32>) outs(%82 : tensor<32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<32x8x8xf32>
%251 = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%250 : tensor<32x8x8xf32>) outs(%82 : tensor<32x8x8xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.exp %in : f32
linalg.yield %1818 : f32
} -> tensor<32x8x8xf32>
%252 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%251 : tensor<32x8x8xf32>) outs(%95 : tensor<32x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<32x8x1xf32>
%253 = linalg.generic {indexing_maps = [#map1, #map14, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%251, %252 : tensor<32x8x8xf32>, tensor<32x8x1xf32>) outs(%82 : tensor<32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.divf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<32x8x8xf32>
%254 = linalg.batch_matmul ins(%253, %collapsed_387 : tensor<32x8x8xf32>, tensor<32x8x64xf32>) outs(%99 : tensor<32x8x64xf32>) -> tensor<32x8x64xf32>
%expanded_390 = tensor.expand_shape %254 [[0, 1], [2], [3]] : tensor<32x8x64xf32> into tensor<1x32x8x64xf32>
%255 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_390 : tensor<1x32x8x64xf32>) outs(%101 : tensor<1x8x32x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x32x64xf32>
%256 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%255 : tensor<1x8x32x64xf32>) outs(%101 : tensor<1x8x32x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x32x64xf32>
%257 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_300 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%collapsed_391 = tensor.collapse_shape %256 [[0, 1], [2, 3]] : tensor<1x8x32x64xf32> into tensor<8x2048xf32>
%258 = linalg.matmul ins(%collapsed_391, %257 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%259 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_299, %258 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_392 = tensor.expand_shape %259 [[0, 1], [2]] : tensor<8x2048xf32> into tensor<1x8x2048xf32>
%260 = linalg.generic {indexing_maps = [#map13, #map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_379, %expanded_392 : tensor<1x8x2048xf32>, tensor<1x8x2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%collapsed_393 = tensor.collapse_shape %260 [[0, 1], [2]] : tensor<1x8x2048xf32> into tensor<8x2048xf32>
%261 = linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%collapsed_393 : tensor<8x2048xf32>) outs(%108 : tensor<8x2048xf64>) {
^bb0(%in: f32, %out: f64):
%1818 = arith.extf %in : f32 to f64
linalg.yield %1818 : f64
} -> tensor<8x2048xf64>
%262 = linalg.generic {indexing_maps = [#map4, #map7], iterator_types = ["parallel", "reduction"]} ins(%261 : tensor<8x2048xf64>) outs(%111 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%263 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%262 : tensor<8x1xf64>) outs(%110 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%264 = linalg.generic {indexing_maps = [#map4, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%261, %263 : tensor<8x2048xf64>, tensor<8x1xf64>) outs(%108 : tensor<8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.subf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<8x2048xf64>
%265 = linalg.generic {indexing_maps = [#map4, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%264, %264 : tensor<8x2048xf64>, tensor<8x2048xf64>) outs(%108 : tensor<8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.mulf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<8x2048xf64>
%266 = linalg.generic {indexing_maps = [#map4, #map7], iterator_types = ["parallel", "reduction"]} ins(%265 : tensor<8x2048xf64>) outs(%111 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%267 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%266 : tensor<8x1xf64>) outs(%110 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%268 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%267 : tensor<8x1xf64>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f64, %out: f32):
%1818 = arith.truncf %in : f64 to f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%269 = linalg.generic {indexing_maps = [#map4, #map7], iterator_types = ["parallel", "reduction"]} ins(%collapsed_393 : tensor<8x2048xf32>) outs(%120 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%270 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%269 : tensor<8x1xf32>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.divf %in, %cst_347 : f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%271 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%268 : tensor<8x1xf32>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.truncf %cst_344 : f64 to f32
%1819 = arith.addf %in, %1818 : f32
linalg.yield %1819 : f32
} -> tensor<8x1xf32>
%272 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%271 : tensor<8x1xf32>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.rsqrt %in : f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%273 = linalg.generic {indexing_maps = [#map4, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%collapsed_393, %270 : tensor<8x2048xf32>, tensor<8x1xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%274 = linalg.generic {indexing_maps = [#map4, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%273, %272 : tensor<8x2048xf32>, tensor<8x1xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%275 = linalg.generic {indexing_maps = [#map4, #map6, #map4], iterator_types = ["parallel", "parallel"]} ins(%274, %cst_337 : tensor<8x2048xf32>, tensor<2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%276 = linalg.generic {indexing_maps = [#map4, #map6, #map4], iterator_types = ["parallel", "parallel"]} ins(%275, %cst_298 : tensor<8x2048xf32>, tensor<2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%277 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_297 : tensor<8192x2048xf32>) outs(%129 : tensor<2048x8192xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x8192xf32>
%278 = linalg.matmul ins(%276, %277 : tensor<8x2048xf32>, tensor<2048x8192xf32>) outs(%132 : tensor<8x8192xf32>) -> tensor<8x8192xf32>
%279 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_296, %278 : tensor<8192xf32>, tensor<8x8192xf32>) outs(%131 : tensor<8x8192xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x8192xf32>
%280 = linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%279 : tensor<8x8192xf32>) outs(%131 : tensor<8x8192xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.cmpf ugt, %in, %cst_340 : f32
%1819 = arith.select %1818, %in, %cst_340 : f32
linalg.yield %1819 : f32
} -> tensor<8x8192xf32>
%281 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_295 : tensor<2048x8192xf32>) outs(%136 : tensor<8192x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<8192x2048xf32>
%282 = linalg.matmul ins(%280, %281 : tensor<8x8192xf32>, tensor<8192x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%283 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_294, %282 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%284 = linalg.generic {indexing_maps = [#map4, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%collapsed_393, %283 : tensor<8x2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_394 = tensor.expand_shape %284 [[0, 1], [2]] : tensor<8x2048xf32> into tensor<1x8x2048xf32>
%285 = linalg.generic {indexing_maps = [#map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_394 : tensor<1x8x2048xf32>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f32, %out: f64):
%1818 = arith.extf %in : f32 to f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
%286 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%285 : tensor<1x8x2048xf64>) outs(%42 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%287 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%286 : tensor<1x8x1xf64>) outs(%41 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%288 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%285, %287 : tensor<1x8x2048xf64>, tensor<1x8x1xf64>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.subf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
%289 = linalg.generic {indexing_maps = [#map13, #map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%288, %288 : tensor<1x8x2048xf64>, tensor<1x8x2048xf64>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.mulf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
%290 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%289 : tensor<1x8x2048xf64>) outs(%42 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%291 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%290 : tensor<1x8x1xf64>) outs(%41 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%292 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%291 : tensor<1x8x1xf64>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f64, %out: f32):
%1818 = arith.truncf %in : f64 to f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%293 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%expanded_394 : tensor<1x8x2048xf32>) outs(%51 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%294 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%293 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.divf %in, %cst_347 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%295 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%292 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.truncf %cst_344 : f64 to f32
%1819 = arith.addf %in, %1818 : f32
linalg.yield %1819 : f32
} -> tensor<1x8x1xf32>
%296 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%295 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.rsqrt %in : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%297 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_394, %294 : tensor<1x8x2048xf32>, tensor<1x8x1xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%298 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%297, %296 : tensor<1x8x2048xf32>, tensor<1x8x1xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%299 = linalg.generic {indexing_maps = [#map13, #map16, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%298, %cst_337 : tensor<1x8x2048xf32>, tensor<2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%300 = linalg.generic {indexing_maps = [#map13, #map16, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%299, %cst_293 : tensor<1x8x2048xf32>, tensor<2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%301 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_292 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%collapsed_395 = tensor.collapse_shape %300 [[0, 1], [2]] : tensor<1x8x2048xf32> into tensor<8x2048xf32>
%302 = linalg.matmul ins(%collapsed_395, %301 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%303 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_291, %302 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_396 = tensor.expand_shape %303 [[0, 1], [2]] : tensor<8x2048xf32> into tensor<1x8x2048xf32>
%304 = linalg.generic {indexing_maps = [#map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_396 : tensor<1x8x2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.mulf %in, %cst_348 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%305 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_290 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%306 = linalg.matmul ins(%collapsed_395, %305 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%307 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_289, %306 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_397 = tensor.expand_shape %307 [[0, 1], [2, 3]] : tensor<8x2048xf32> into tensor<1x8x32x64xf32>
%308 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_397 : tensor<1x8x32x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%309 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%308 : tensor<1x32x8x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%310 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_288 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%311 = linalg.matmul ins(%collapsed_395, %310 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%312 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_287, %311 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_398 = tensor.expand_shape %312 [[0, 1], [2, 3]] : tensor<8x2048xf32> into tensor<1x8x32x64xf32>
%313 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_398 : tensor<1x8x32x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%314 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%313 : tensor<1x32x8x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%expanded_399 = tensor.expand_shape %304 [[0], [1], [2, 3]] : tensor<1x8x2048xf32> into tensor<1x8x32x64xf32>
%315 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_399 : tensor<1x8x32x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%316 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%315 : tensor<1x32x8x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%collapsed_400 = tensor.collapse_shape %316 [[0, 1], [2], [3]] : tensor<1x32x8x64xf32> into tensor<32x8x64xf32>
%collapsed_401 = tensor.collapse_shape %309 [[0, 1], [2], [3]] : tensor<1x32x8x64xf32> into tensor<32x8x64xf32>
%collapsed_402 = tensor.collapse_shape %314 [[0, 1], [2], [3]] : tensor<1x32x8x64xf32> into tensor<32x8x64xf32>
%317 = linalg.generic {indexing_maps = [#map1, #map20], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_401 : tensor<32x8x64xf32>) outs(%80 : tensor<32x64x8xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<32x64x8xf32>
%318 = linalg.batch_matmul ins(%collapsed_400, %317 : tensor<32x8x64xf32>, tensor<32x64x8xf32>) outs(%83 : tensor<32x8x8xf32>) -> tensor<32x8x8xf32>
%expanded_403 = tensor.expand_shape %318 [[0, 1], [2], [3]] : tensor<32x8x8xf32> into tensor<1x32x8x8xf32>
%319 = linalg.generic {indexing_maps = [#map19, #map8, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_403, %28 : tensor<1x32x8x8xf32>, tensor<1x1x8x8xf32>) outs(%85 : tensor<1x32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x32x8x8xf32>
%320 = linalg.generic {indexing_maps = [#map19, #map11, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%319, %cst_329 : tensor<1x32x8x8xf32>, tensor<f32>) outs(%85 : tensor<1x32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.cmpf ugt, %in, %in_712 : f32
%1819 = arith.select %1818, %in, %in_712 : f32
linalg.yield %1819 : f32
} -> tensor<1x32x8x8xf32>
%collapsed_404 = tensor.collapse_shape %320 [[0, 1], [2], [3]] : tensor<1x32x8x8xf32> into tensor<32x8x8xf32>
%321:2 = linalg.generic {indexing_maps = [#map1, #map14, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%collapsed_404 : tensor<32x8x8xf32>) outs(%91, %89 : tensor<32x8x1xf32>, tensor<32x8x1xi64>) {
^bb0(%in: f32, %out: f32, %out_712: i64):
%1818 = linalg.index 2 : index
%1819 = arith.index_cast %1818 : index to i64
%1820 = arith.maximumf %in, %out : f32
%1821 = arith.cmpf ogt, %in, %out : f32
%1822 = arith.select %1821, %1819, %out_712 : i64
linalg.yield %1820, %1822 : f32, i64
} -> (tensor<32x8x1xf32>, tensor<32x8x1xi64>)
%322 = linalg.generic {indexing_maps = [#map1, #map14, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_404, %321#0 : tensor<32x8x8xf32>, tensor<32x8x1xf32>) outs(%82 : tensor<32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<32x8x8xf32>
%323 = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%322 : tensor<32x8x8xf32>) outs(%82 : tensor<32x8x8xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.exp %in : f32
linalg.yield %1818 : f32
} -> tensor<32x8x8xf32>
%324 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%323 : tensor<32x8x8xf32>) outs(%95 : tensor<32x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<32x8x1xf32>
%325 = linalg.generic {indexing_maps = [#map1, #map14, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%323, %324 : tensor<32x8x8xf32>, tensor<32x8x1xf32>) outs(%82 : tensor<32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.divf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<32x8x8xf32>
%326 = linalg.batch_matmul ins(%325, %collapsed_402 : tensor<32x8x8xf32>, tensor<32x8x64xf32>) outs(%99 : tensor<32x8x64xf32>) -> tensor<32x8x64xf32>
%expanded_405 = tensor.expand_shape %326 [[0, 1], [2], [3]] : tensor<32x8x64xf32> into tensor<1x32x8x64xf32>
%327 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_405 : tensor<1x32x8x64xf32>) outs(%101 : tensor<1x8x32x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x32x64xf32>
%328 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%327 : tensor<1x8x32x64xf32>) outs(%101 : tensor<1x8x32x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x32x64xf32>
%329 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_286 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%collapsed_406 = tensor.collapse_shape %328 [[0, 1], [2, 3]] : tensor<1x8x32x64xf32> into tensor<8x2048xf32>
%330 = linalg.matmul ins(%collapsed_406, %329 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%331 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_285, %330 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_407 = tensor.expand_shape %331 [[0, 1], [2]] : tensor<8x2048xf32> into tensor<1x8x2048xf32>
%332 = linalg.generic {indexing_maps = [#map13, #map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_394, %expanded_407 : tensor<1x8x2048xf32>, tensor<1x8x2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%collapsed_408 = tensor.collapse_shape %332 [[0, 1], [2]] : tensor<1x8x2048xf32> into tensor<8x2048xf32>
%333 = linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%collapsed_408 : tensor<8x2048xf32>) outs(%108 : tensor<8x2048xf64>) {
^bb0(%in: f32, %out: f64):
%1818 = arith.extf %in : f32 to f64
linalg.yield %1818 : f64
} -> tensor<8x2048xf64>
%334 = linalg.generic {indexing_maps = [#map4, #map7], iterator_types = ["parallel", "reduction"]} ins(%333 : tensor<8x2048xf64>) outs(%111 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%335 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%334 : tensor<8x1xf64>) outs(%110 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%336 = linalg.generic {indexing_maps = [#map4, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%333, %335 : tensor<8x2048xf64>, tensor<8x1xf64>) outs(%108 : tensor<8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.subf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<8x2048xf64>
%337 = linalg.generic {indexing_maps = [#map4, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%336, %336 : tensor<8x2048xf64>, tensor<8x2048xf64>) outs(%108 : tensor<8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.mulf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<8x2048xf64>
%338 = linalg.generic {indexing_maps = [#map4, #map7], iterator_types = ["parallel", "reduction"]} ins(%337 : tensor<8x2048xf64>) outs(%111 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%339 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%338 : tensor<8x1xf64>) outs(%110 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%340 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%339 : tensor<8x1xf64>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f64, %out: f32):
%1818 = arith.truncf %in : f64 to f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%341 = linalg.generic {indexing_maps = [#map4, #map7], iterator_types = ["parallel", "reduction"]} ins(%collapsed_408 : tensor<8x2048xf32>) outs(%120 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%342 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%341 : tensor<8x1xf32>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.divf %in, %cst_347 : f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%343 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%340 : tensor<8x1xf32>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.truncf %cst_344 : f64 to f32
%1819 = arith.addf %in, %1818 : f32
linalg.yield %1819 : f32
} -> tensor<8x1xf32>
%344 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%343 : tensor<8x1xf32>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.rsqrt %in : f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%345 = linalg.generic {indexing_maps = [#map4, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%collapsed_408, %342 : tensor<8x2048xf32>, tensor<8x1xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%346 = linalg.generic {indexing_maps = [#map4, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%345, %344 : tensor<8x2048xf32>, tensor<8x1xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%347 = linalg.generic {indexing_maps = [#map4, #map6, #map4], iterator_types = ["parallel", "parallel"]} ins(%346, %cst_337 : tensor<8x2048xf32>, tensor<2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%348 = linalg.generic {indexing_maps = [#map4, #map6, #map4], iterator_types = ["parallel", "parallel"]} ins(%347, %cst_284 : tensor<8x2048xf32>, tensor<2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%349 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_283 : tensor<8192x2048xf32>) outs(%129 : tensor<2048x8192xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x8192xf32>
%350 = linalg.matmul ins(%348, %349 : tensor<8x2048xf32>, tensor<2048x8192xf32>) outs(%132 : tensor<8x8192xf32>) -> tensor<8x8192xf32>
%351 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_282, %350 : tensor<8192xf32>, tensor<8x8192xf32>) outs(%131 : tensor<8x8192xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x8192xf32>
%352 = linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%351 : tensor<8x8192xf32>) outs(%131 : tensor<8x8192xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.cmpf ugt, %in, %cst_340 : f32
%1819 = arith.select %1818, %in, %cst_340 : f32
linalg.yield %1819 : f32
} -> tensor<8x8192xf32>
%353 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_281 : tensor<2048x8192xf32>) outs(%136 : tensor<8192x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<8192x2048xf32>
%354 = linalg.matmul ins(%352, %353 : tensor<8x8192xf32>, tensor<8192x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%355 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_280, %354 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%356 = linalg.generic {indexing_maps = [#map4, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%collapsed_408, %355 : tensor<8x2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_409 = tensor.expand_shape %356 [[0, 1], [2]] : tensor<8x2048xf32> into tensor<1x8x2048xf32>
%357 = linalg.generic {indexing_maps = [#map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_409 : tensor<1x8x2048xf32>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f32, %out: f64):
%1818 = arith.extf %in : f32 to f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
%358 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%357 : tensor<1x8x2048xf64>) outs(%42 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%359 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%358 : tensor<1x8x1xf64>) outs(%41 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%360 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%357, %359 : tensor<1x8x2048xf64>, tensor<1x8x1xf64>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.subf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
%361 = linalg.generic {indexing_maps = [#map13, #map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%360, %360 : tensor<1x8x2048xf64>, tensor<1x8x2048xf64>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.mulf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
%362 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%361 : tensor<1x8x2048xf64>) outs(%42 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%363 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%362 : tensor<1x8x1xf64>) outs(%41 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%364 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%363 : tensor<1x8x1xf64>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f64, %out: f32):
%1818 = arith.truncf %in : f64 to f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%365 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%expanded_409 : tensor<1x8x2048xf32>) outs(%51 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%366 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%365 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.divf %in, %cst_347 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%367 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%364 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.truncf %cst_344 : f64 to f32
%1819 = arith.addf %in, %1818 : f32
linalg.yield %1819 : f32
} -> tensor<1x8x1xf32>
%368 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%367 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.rsqrt %in : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%369 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_409, %366 : tensor<1x8x2048xf32>, tensor<1x8x1xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%370 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%369, %368 : tensor<1x8x2048xf32>, tensor<1x8x1xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%371 = linalg.generic {indexing_maps = [#map13, #map16, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%370, %cst_337 : tensor<1x8x2048xf32>, tensor<2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%372 = linalg.generic {indexing_maps = [#map13, #map16, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%371, %cst_279 : tensor<1x8x2048xf32>, tensor<2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%373 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_278 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%collapsed_410 = tensor.collapse_shape %372 [[0, 1], [2]] : tensor<1x8x2048xf32> into tensor<8x2048xf32>
%374 = linalg.matmul ins(%collapsed_410, %373 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%375 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_277, %374 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_411 = tensor.expand_shape %375 [[0, 1], [2]] : tensor<8x2048xf32> into tensor<1x8x2048xf32>
%376 = linalg.generic {indexing_maps = [#map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_411 : tensor<1x8x2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.mulf %in, %cst_348 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%377 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_276 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%378 = linalg.matmul ins(%collapsed_410, %377 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%379 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_275, %378 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_412 = tensor.expand_shape %379 [[0, 1], [2, 3]] : tensor<8x2048xf32> into tensor<1x8x32x64xf32>
%380 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_412 : tensor<1x8x32x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%381 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%380 : tensor<1x32x8x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%382 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_274 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%383 = linalg.matmul ins(%collapsed_410, %382 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%384 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_273, %383 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_413 = tensor.expand_shape %384 [[0, 1], [2, 3]] : tensor<8x2048xf32> into tensor<1x8x32x64xf32>
%385 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_413 : tensor<1x8x32x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%386 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%385 : tensor<1x32x8x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%expanded_414 = tensor.expand_shape %376 [[0], [1], [2, 3]] : tensor<1x8x2048xf32> into tensor<1x8x32x64xf32>
%387 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_414 : tensor<1x8x32x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%388 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%387 : tensor<1x32x8x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%collapsed_415 = tensor.collapse_shape %388 [[0, 1], [2], [3]] : tensor<1x32x8x64xf32> into tensor<32x8x64xf32>
%collapsed_416 = tensor.collapse_shape %381 [[0, 1], [2], [3]] : tensor<1x32x8x64xf32> into tensor<32x8x64xf32>
%collapsed_417 = tensor.collapse_shape %386 [[0, 1], [2], [3]] : tensor<1x32x8x64xf32> into tensor<32x8x64xf32>
%389 = linalg.generic {indexing_maps = [#map1, #map20], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_416 : tensor<32x8x64xf32>) outs(%80 : tensor<32x64x8xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<32x64x8xf32>
%390 = linalg.batch_matmul ins(%collapsed_415, %389 : tensor<32x8x64xf32>, tensor<32x64x8xf32>) outs(%83 : tensor<32x8x8xf32>) -> tensor<32x8x8xf32>
%expanded_418 = tensor.expand_shape %390 [[0, 1], [2], [3]] : tensor<32x8x8xf32> into tensor<1x32x8x8xf32>
%391 = linalg.generic {indexing_maps = [#map19, #map8, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_418, %28 : tensor<1x32x8x8xf32>, tensor<1x1x8x8xf32>) outs(%85 : tensor<1x32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x32x8x8xf32>
%392 = linalg.generic {indexing_maps = [#map19, #map11, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%391, %cst_329 : tensor<1x32x8x8xf32>, tensor<f32>) outs(%85 : tensor<1x32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.cmpf ugt, %in, %in_712 : f32
%1819 = arith.select %1818, %in, %in_712 : f32
linalg.yield %1819 : f32
} -> tensor<1x32x8x8xf32>
%collapsed_419 = tensor.collapse_shape %392 [[0, 1], [2], [3]] : tensor<1x32x8x8xf32> into tensor<32x8x8xf32>
%393:2 = linalg.generic {indexing_maps = [#map1, #map14, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%collapsed_419 : tensor<32x8x8xf32>) outs(%91, %89 : tensor<32x8x1xf32>, tensor<32x8x1xi64>) {
^bb0(%in: f32, %out: f32, %out_712: i64):
%1818 = linalg.index 2 : index
%1819 = arith.index_cast %1818 : index to i64
%1820 = arith.maximumf %in, %out : f32
%1821 = arith.cmpf ogt, %in, %out : f32
%1822 = arith.select %1821, %1819, %out_712 : i64
linalg.yield %1820, %1822 : f32, i64
} -> (tensor<32x8x1xf32>, tensor<32x8x1xi64>)
%394 = linalg.generic {indexing_maps = [#map1, #map14, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_419, %393#0 : tensor<32x8x8xf32>, tensor<32x8x1xf32>) outs(%82 : tensor<32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<32x8x8xf32>
%395 = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%394 : tensor<32x8x8xf32>) outs(%82 : tensor<32x8x8xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.exp %in : f32
linalg.yield %1818 : f32
} -> tensor<32x8x8xf32>
%396 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%395 : tensor<32x8x8xf32>) outs(%95 : tensor<32x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<32x8x1xf32>
%397 = linalg.generic {indexing_maps = [#map1, #map14, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%395, %396 : tensor<32x8x8xf32>, tensor<32x8x1xf32>) outs(%82 : tensor<32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.divf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<32x8x8xf32>
%398 = linalg.batch_matmul ins(%397, %collapsed_417 : tensor<32x8x8xf32>, tensor<32x8x64xf32>) outs(%99 : tensor<32x8x64xf32>) -> tensor<32x8x64xf32>
%expanded_420 = tensor.expand_shape %398 [[0, 1], [2], [3]] : tensor<32x8x64xf32> into tensor<1x32x8x64xf32>
%399 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_420 : tensor<1x32x8x64xf32>) outs(%101 : tensor<1x8x32x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x32x64xf32>
%400 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%399 : tensor<1x8x32x64xf32>) outs(%101 : tensor<1x8x32x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x32x64xf32>
%401 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_272 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%collapsed_421 = tensor.collapse_shape %400 [[0, 1], [2, 3]] : tensor<1x8x32x64xf32> into tensor<8x2048xf32>
%402 = linalg.matmul ins(%collapsed_421, %401 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%403 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_271, %402 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_422 = tensor.expand_shape %403 [[0, 1], [2]] : tensor<8x2048xf32> into tensor<1x8x2048xf32>
%404 = linalg.generic {indexing_maps = [#map13, #map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_409, %expanded_422 : tensor<1x8x2048xf32>, tensor<1x8x2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%collapsed_423 = tensor.collapse_shape %404 [[0, 1], [2]] : tensor<1x8x2048xf32> into tensor<8x2048xf32>
%405 = linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%collapsed_423 : tensor<8x2048xf32>) outs(%108 : tensor<8x2048xf64>) {
^bb0(%in: f32, %out: f64):
%1818 = arith.extf %in : f32 to f64
linalg.yield %1818 : f64
} -> tensor<8x2048xf64>
%406 = linalg.generic {indexing_maps = [#map4, #map7], iterator_types = ["parallel", "reduction"]} ins(%405 : tensor<8x2048xf64>) outs(%111 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%407 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%406 : tensor<8x1xf64>) outs(%110 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%408 = linalg.generic {indexing_maps = [#map4, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%405, %407 : tensor<8x2048xf64>, tensor<8x1xf64>) outs(%108 : tensor<8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.subf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<8x2048xf64>
%409 = linalg.generic {indexing_maps = [#map4, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%408, %408 : tensor<8x2048xf64>, tensor<8x2048xf64>) outs(%108 : tensor<8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.mulf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<8x2048xf64>
%410 = linalg.generic {indexing_maps = [#map4, #map7], iterator_types = ["parallel", "reduction"]} ins(%409 : tensor<8x2048xf64>) outs(%111 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%411 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%410 : tensor<8x1xf64>) outs(%110 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%412 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%411 : tensor<8x1xf64>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f64, %out: f32):
%1818 = arith.truncf %in : f64 to f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%413 = linalg.generic {indexing_maps = [#map4, #map7], iterator_types = ["parallel", "reduction"]} ins(%collapsed_423 : tensor<8x2048xf32>) outs(%120 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%414 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%413 : tensor<8x1xf32>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.divf %in, %cst_347 : f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%415 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%412 : tensor<8x1xf32>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.truncf %cst_344 : f64 to f32
%1819 = arith.addf %in, %1818 : f32
linalg.yield %1819 : f32
} -> tensor<8x1xf32>
%416 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%415 : tensor<8x1xf32>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.rsqrt %in : f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%417 = linalg.generic {indexing_maps = [#map4, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%collapsed_423, %414 : tensor<8x2048xf32>, tensor<8x1xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%418 = linalg.generic {indexing_maps = [#map4, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%417, %416 : tensor<8x2048xf32>, tensor<8x1xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%419 = linalg.generic {indexing_maps = [#map4, #map6, #map4], iterator_types = ["parallel", "parallel"]} ins(%418, %cst_337 : tensor<8x2048xf32>, tensor<2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%420 = linalg.generic {indexing_maps = [#map4, #map6, #map4], iterator_types = ["parallel", "parallel"]} ins(%419, %cst_270 : tensor<8x2048xf32>, tensor<2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%421 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_269 : tensor<8192x2048xf32>) outs(%129 : tensor<2048x8192xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x8192xf32>
%422 = linalg.matmul ins(%420, %421 : tensor<8x2048xf32>, tensor<2048x8192xf32>) outs(%132 : tensor<8x8192xf32>) -> tensor<8x8192xf32>
%423 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_268, %422 : tensor<8192xf32>, tensor<8x8192xf32>) outs(%131 : tensor<8x8192xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x8192xf32>
%424 = linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%423 : tensor<8x8192xf32>) outs(%131 : tensor<8x8192xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.cmpf ugt, %in, %cst_340 : f32
%1819 = arith.select %1818, %in, %cst_340 : f32
linalg.yield %1819 : f32
} -> tensor<8x8192xf32>
%425 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_267 : tensor<2048x8192xf32>) outs(%136 : tensor<8192x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<8192x2048xf32>
%426 = linalg.matmul ins(%424, %425 : tensor<8x8192xf32>, tensor<8192x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%427 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_266, %426 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%428 = linalg.generic {indexing_maps = [#map4, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%collapsed_423, %427 : tensor<8x2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_424 = tensor.expand_shape %428 [[0, 1], [2]] : tensor<8x2048xf32> into tensor<1x8x2048xf32>
%429 = linalg.generic {indexing_maps = [#map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_424 : tensor<1x8x2048xf32>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f32, %out: f64):
%1818 = arith.extf %in : f32 to f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
%430 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%429 : tensor<1x8x2048xf64>) outs(%42 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%431 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%430 : tensor<1x8x1xf64>) outs(%41 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%432 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%429, %431 : tensor<1x8x2048xf64>, tensor<1x8x1xf64>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.subf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
%433 = linalg.generic {indexing_maps = [#map13, #map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%432, %432 : tensor<1x8x2048xf64>, tensor<1x8x2048xf64>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.mulf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
%434 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%433 : tensor<1x8x2048xf64>) outs(%42 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%435 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%434 : tensor<1x8x1xf64>) outs(%41 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%436 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%435 : tensor<1x8x1xf64>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f64, %out: f32):
%1818 = arith.truncf %in : f64 to f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%437 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%expanded_424 : tensor<1x8x2048xf32>) outs(%51 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%438 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%437 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.divf %in, %cst_347 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%439 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%436 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.truncf %cst_344 : f64 to f32
%1819 = arith.addf %in, %1818 : f32
linalg.yield %1819 : f32
} -> tensor<1x8x1xf32>
%440 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%439 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.rsqrt %in : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%441 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_424, %438 : tensor<1x8x2048xf32>, tensor<1x8x1xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%442 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%441, %440 : tensor<1x8x2048xf32>, tensor<1x8x1xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%443 = linalg.generic {indexing_maps = [#map13, #map16, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%442, %cst_337 : tensor<1x8x2048xf32>, tensor<2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%444 = linalg.generic {indexing_maps = [#map13, #map16, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%443, %cst_265 : tensor<1x8x2048xf32>, tensor<2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%445 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_264 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%collapsed_425 = tensor.collapse_shape %444 [[0, 1], [2]] : tensor<1x8x2048xf32> into tensor<8x2048xf32>
%446 = linalg.matmul ins(%collapsed_425, %445 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%447 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_263, %446 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_426 = tensor.expand_shape %447 [[0, 1], [2]] : tensor<8x2048xf32> into tensor<1x8x2048xf32>
%448 = linalg.generic {indexing_maps = [#map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_426 : tensor<1x8x2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.mulf %in, %cst_348 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%449 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_262 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%450 = linalg.matmul ins(%collapsed_425, %449 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%451 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_261, %450 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_427 = tensor.expand_shape %451 [[0, 1], [2, 3]] : tensor<8x2048xf32> into tensor<1x8x32x64xf32>
%452 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_427 : tensor<1x8x32x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%453 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%452 : tensor<1x32x8x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%454 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_260 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%455 = linalg.matmul ins(%collapsed_425, %454 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%456 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_259, %455 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_428 = tensor.expand_shape %456 [[0, 1], [2, 3]] : tensor<8x2048xf32> into tensor<1x8x32x64xf32>
%457 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_428 : tensor<1x8x32x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%458 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%457 : tensor<1x32x8x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%expanded_429 = tensor.expand_shape %448 [[0], [1], [2, 3]] : tensor<1x8x2048xf32> into tensor<1x8x32x64xf32>
%459 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_429 : tensor<1x8x32x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%460 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%459 : tensor<1x32x8x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%collapsed_430 = tensor.collapse_shape %460 [[0, 1], [2], [3]] : tensor<1x32x8x64xf32> into tensor<32x8x64xf32>
%collapsed_431 = tensor.collapse_shape %453 [[0, 1], [2], [3]] : tensor<1x32x8x64xf32> into tensor<32x8x64xf32>
%collapsed_432 = tensor.collapse_shape %458 [[0, 1], [2], [3]] : tensor<1x32x8x64xf32> into tensor<32x8x64xf32>
%461 = linalg.generic {indexing_maps = [#map1, #map20], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_431 : tensor<32x8x64xf32>) outs(%80 : tensor<32x64x8xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<32x64x8xf32>
%462 = linalg.batch_matmul ins(%collapsed_430, %461 : tensor<32x8x64xf32>, tensor<32x64x8xf32>) outs(%83 : tensor<32x8x8xf32>) -> tensor<32x8x8xf32>
%expanded_433 = tensor.expand_shape %462 [[0, 1], [2], [3]] : tensor<32x8x8xf32> into tensor<1x32x8x8xf32>
%463 = linalg.generic {indexing_maps = [#map19, #map8, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_433, %28 : tensor<1x32x8x8xf32>, tensor<1x1x8x8xf32>) outs(%85 : tensor<1x32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x32x8x8xf32>
%464 = linalg.generic {indexing_maps = [#map19, #map11, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%463, %cst_329 : tensor<1x32x8x8xf32>, tensor<f32>) outs(%85 : tensor<1x32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.cmpf ugt, %in, %in_712 : f32
%1819 = arith.select %1818, %in, %in_712 : f32
linalg.yield %1819 : f32
} -> tensor<1x32x8x8xf32>
%collapsed_434 = tensor.collapse_shape %464 [[0, 1], [2], [3]] : tensor<1x32x8x8xf32> into tensor<32x8x8xf32>
%465:2 = linalg.generic {indexing_maps = [#map1, #map14, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%collapsed_434 : tensor<32x8x8xf32>) outs(%91, %89 : tensor<32x8x1xf32>, tensor<32x8x1xi64>) {
^bb0(%in: f32, %out: f32, %out_712: i64):
%1818 = linalg.index 2 : index
%1819 = arith.index_cast %1818 : index to i64
%1820 = arith.maximumf %in, %out : f32
%1821 = arith.cmpf ogt, %in, %out : f32
%1822 = arith.select %1821, %1819, %out_712 : i64
linalg.yield %1820, %1822 : f32, i64
} -> (tensor<32x8x1xf32>, tensor<32x8x1xi64>)
%466 = linalg.generic {indexing_maps = [#map1, #map14, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_434, %465#0 : tensor<32x8x8xf32>, tensor<32x8x1xf32>) outs(%82 : tensor<32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<32x8x8xf32>
%467 = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%466 : tensor<32x8x8xf32>) outs(%82 : tensor<32x8x8xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.exp %in : f32
linalg.yield %1818 : f32
} -> tensor<32x8x8xf32>
%468 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%467 : tensor<32x8x8xf32>) outs(%95 : tensor<32x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<32x8x1xf32>
%469 = linalg.generic {indexing_maps = [#map1, #map14, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%467, %468 : tensor<32x8x8xf32>, tensor<32x8x1xf32>) outs(%82 : tensor<32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.divf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<32x8x8xf32>
%470 = linalg.batch_matmul ins(%469, %collapsed_432 : tensor<32x8x8xf32>, tensor<32x8x64xf32>) outs(%99 : tensor<32x8x64xf32>) -> tensor<32x8x64xf32>
%expanded_435 = tensor.expand_shape %470 [[0, 1], [2], [3]] : tensor<32x8x64xf32> into tensor<1x32x8x64xf32>
%471 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_435 : tensor<1x32x8x64xf32>) outs(%101 : tensor<1x8x32x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x32x64xf32>
%472 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%471 : tensor<1x8x32x64xf32>) outs(%101 : tensor<1x8x32x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x32x64xf32>
%473 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_258 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%collapsed_436 = tensor.collapse_shape %472 [[0, 1], [2, 3]] : tensor<1x8x32x64xf32> into tensor<8x2048xf32>
%474 = linalg.matmul ins(%collapsed_436, %473 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%475 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_257, %474 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_437 = tensor.expand_shape %475 [[0, 1], [2]] : tensor<8x2048xf32> into tensor<1x8x2048xf32>
%476 = linalg.generic {indexing_maps = [#map13, #map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_424, %expanded_437 : tensor<1x8x2048xf32>, tensor<1x8x2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%collapsed_438 = tensor.collapse_shape %476 [[0, 1], [2]] : tensor<1x8x2048xf32> into tensor<8x2048xf32>
%477 = linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%collapsed_438 : tensor<8x2048xf32>) outs(%108 : tensor<8x2048xf64>) {
^bb0(%in: f32, %out: f64):
%1818 = arith.extf %in : f32 to f64
linalg.yield %1818 : f64
} -> tensor<8x2048xf64>
%478 = linalg.generic {indexing_maps = [#map4, #map7], iterator_types = ["parallel", "reduction"]} ins(%477 : tensor<8x2048xf64>) outs(%111 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%479 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%478 : tensor<8x1xf64>) outs(%110 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%480 = linalg.generic {indexing_maps = [#map4, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%477, %479 : tensor<8x2048xf64>, tensor<8x1xf64>) outs(%108 : tensor<8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.subf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<8x2048xf64>
%481 = linalg.generic {indexing_maps = [#map4, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%480, %480 : tensor<8x2048xf64>, tensor<8x2048xf64>) outs(%108 : tensor<8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.mulf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<8x2048xf64>
%482 = linalg.generic {indexing_maps = [#map4, #map7], iterator_types = ["parallel", "reduction"]} ins(%481 : tensor<8x2048xf64>) outs(%111 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%483 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%482 : tensor<8x1xf64>) outs(%110 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%484 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%483 : tensor<8x1xf64>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f64, %out: f32):
%1818 = arith.truncf %in : f64 to f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%485 = linalg.generic {indexing_maps = [#map4, #map7], iterator_types = ["parallel", "reduction"]} ins(%collapsed_438 : tensor<8x2048xf32>) outs(%120 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%486 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%485 : tensor<8x1xf32>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.divf %in, %cst_347 : f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%487 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%484 : tensor<8x1xf32>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.truncf %cst_344 : f64 to f32
%1819 = arith.addf %in, %1818 : f32
linalg.yield %1819 : f32
} -> tensor<8x1xf32>
%488 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%487 : tensor<8x1xf32>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.rsqrt %in : f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%489 = linalg.generic {indexing_maps = [#map4, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%collapsed_438, %486 : tensor<8x2048xf32>, tensor<8x1xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%490 = linalg.generic {indexing_maps = [#map4, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%489, %488 : tensor<8x2048xf32>, tensor<8x1xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%491 = linalg.generic {indexing_maps = [#map4, #map6, #map4], iterator_types = ["parallel", "parallel"]} ins(%490, %cst_337 : tensor<8x2048xf32>, tensor<2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%492 = linalg.generic {indexing_maps = [#map4, #map6, #map4], iterator_types = ["parallel", "parallel"]} ins(%491, %cst_256 : tensor<8x2048xf32>, tensor<2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%493 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_255 : tensor<8192x2048xf32>) outs(%129 : tensor<2048x8192xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x8192xf32>
%494 = linalg.matmul ins(%492, %493 : tensor<8x2048xf32>, tensor<2048x8192xf32>) outs(%132 : tensor<8x8192xf32>) -> tensor<8x8192xf32>
%495 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_254, %494 : tensor<8192xf32>, tensor<8x8192xf32>) outs(%131 : tensor<8x8192xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x8192xf32>
%496 = linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%495 : tensor<8x8192xf32>) outs(%131 : tensor<8x8192xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.cmpf ugt, %in, %cst_340 : f32
%1819 = arith.select %1818, %in, %cst_340 : f32
linalg.yield %1819 : f32
} -> tensor<8x8192xf32>
%497 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_253 : tensor<2048x8192xf32>) outs(%136 : tensor<8192x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<8192x2048xf32>
%498 = linalg.matmul ins(%496, %497 : tensor<8x8192xf32>, tensor<8192x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%499 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_252, %498 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%500 = linalg.generic {indexing_maps = [#map4, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%collapsed_438, %499 : tensor<8x2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_439 = tensor.expand_shape %500 [[0, 1], [2]] : tensor<8x2048xf32> into tensor<1x8x2048xf32>
%501 = linalg.generic {indexing_maps = [#map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_439 : tensor<1x8x2048xf32>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f32, %out: f64):
%1818 = arith.extf %in : f32 to f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
%502 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%501 : tensor<1x8x2048xf64>) outs(%42 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%503 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%502 : tensor<1x8x1xf64>) outs(%41 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%504 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%501, %503 : tensor<1x8x2048xf64>, tensor<1x8x1xf64>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.subf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
%505 = linalg.generic {indexing_maps = [#map13, #map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%504, %504 : tensor<1x8x2048xf64>, tensor<1x8x2048xf64>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.mulf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
%506 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%505 : tensor<1x8x2048xf64>) outs(%42 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%507 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%506 : tensor<1x8x1xf64>) outs(%41 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%508 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%507 : tensor<1x8x1xf64>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f64, %out: f32):
%1818 = arith.truncf %in : f64 to f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%509 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%expanded_439 : tensor<1x8x2048xf32>) outs(%51 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%510 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%509 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.divf %in, %cst_347 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%511 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%508 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.truncf %cst_344 : f64 to f32
%1819 = arith.addf %in, %1818 : f32
linalg.yield %1819 : f32
} -> tensor<1x8x1xf32>
%512 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%511 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.rsqrt %in : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%513 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_439, %510 : tensor<1x8x2048xf32>, tensor<1x8x1xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%514 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%513, %512 : tensor<1x8x2048xf32>, tensor<1x8x1xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%515 = linalg.generic {indexing_maps = [#map13, #map16, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%514, %cst_337 : tensor<1x8x2048xf32>, tensor<2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%516 = linalg.generic {indexing_maps = [#map13, #map16, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%515, %cst_251 : tensor<1x8x2048xf32>, tensor<2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%517 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_250 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%collapsed_440 = tensor.collapse_shape %516 [[0, 1], [2]] : tensor<1x8x2048xf32> into tensor<8x2048xf32>
%518 = linalg.matmul ins(%collapsed_440, %517 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%519 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_249, %518 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_441 = tensor.expand_shape %519 [[0, 1], [2]] : tensor<8x2048xf32> into tensor<1x8x2048xf32>
%520 = linalg.generic {indexing_maps = [#map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_441 : tensor<1x8x2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.mulf %in, %cst_348 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%521 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_248 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%522 = linalg.matmul ins(%collapsed_440, %521 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%523 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_247, %522 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_442 = tensor.expand_shape %523 [[0, 1], [2, 3]] : tensor<8x2048xf32> into tensor<1x8x32x64xf32>
%524 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_442 : tensor<1x8x32x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%525 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%524 : tensor<1x32x8x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%526 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_246 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%527 = linalg.matmul ins(%collapsed_440, %526 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%528 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_245, %527 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_443 = tensor.expand_shape %528 [[0, 1], [2, 3]] : tensor<8x2048xf32> into tensor<1x8x32x64xf32>
%529 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_443 : tensor<1x8x32x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%530 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%529 : tensor<1x32x8x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%expanded_444 = tensor.expand_shape %520 [[0], [1], [2, 3]] : tensor<1x8x2048xf32> into tensor<1x8x32x64xf32>
%531 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_444 : tensor<1x8x32x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%532 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%531 : tensor<1x32x8x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%collapsed_445 = tensor.collapse_shape %532 [[0, 1], [2], [3]] : tensor<1x32x8x64xf32> into tensor<32x8x64xf32>
%collapsed_446 = tensor.collapse_shape %525 [[0, 1], [2], [3]] : tensor<1x32x8x64xf32> into tensor<32x8x64xf32>
%collapsed_447 = tensor.collapse_shape %530 [[0, 1], [2], [3]] : tensor<1x32x8x64xf32> into tensor<32x8x64xf32>
%533 = linalg.generic {indexing_maps = [#map1, #map20], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_446 : tensor<32x8x64xf32>) outs(%80 : tensor<32x64x8xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<32x64x8xf32>
%534 = linalg.batch_matmul ins(%collapsed_445, %533 : tensor<32x8x64xf32>, tensor<32x64x8xf32>) outs(%83 : tensor<32x8x8xf32>) -> tensor<32x8x8xf32>
%expanded_448 = tensor.expand_shape %534 [[0, 1], [2], [3]] : tensor<32x8x8xf32> into tensor<1x32x8x8xf32>
%535 = linalg.generic {indexing_maps = [#map19, #map8, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_448, %28 : tensor<1x32x8x8xf32>, tensor<1x1x8x8xf32>) outs(%85 : tensor<1x32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x32x8x8xf32>
%536 = linalg.generic {indexing_maps = [#map19, #map11, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%535, %cst_329 : tensor<1x32x8x8xf32>, tensor<f32>) outs(%85 : tensor<1x32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.cmpf ugt, %in, %in_712 : f32
%1819 = arith.select %1818, %in, %in_712 : f32
linalg.yield %1819 : f32
} -> tensor<1x32x8x8xf32>
%collapsed_449 = tensor.collapse_shape %536 [[0, 1], [2], [3]] : tensor<1x32x8x8xf32> into tensor<32x8x8xf32>
%537:2 = linalg.generic {indexing_maps = [#map1, #map14, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%collapsed_449 : tensor<32x8x8xf32>) outs(%91, %89 : tensor<32x8x1xf32>, tensor<32x8x1xi64>) {
^bb0(%in: f32, %out: f32, %out_712: i64):
%1818 = linalg.index 2 : index
%1819 = arith.index_cast %1818 : index to i64
%1820 = arith.maximumf %in, %out : f32
%1821 = arith.cmpf ogt, %in, %out : f32
%1822 = arith.select %1821, %1819, %out_712 : i64
linalg.yield %1820, %1822 : f32, i64
} -> (tensor<32x8x1xf32>, tensor<32x8x1xi64>)
%538 = linalg.generic {indexing_maps = [#map1, #map14, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_449, %537#0 : tensor<32x8x8xf32>, tensor<32x8x1xf32>) outs(%82 : tensor<32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<32x8x8xf32>
%539 = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%538 : tensor<32x8x8xf32>) outs(%82 : tensor<32x8x8xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.exp %in : f32
linalg.yield %1818 : f32
} -> tensor<32x8x8xf32>
%540 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%539 : tensor<32x8x8xf32>) outs(%95 : tensor<32x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<32x8x1xf32>
%541 = linalg.generic {indexing_maps = [#map1, #map14, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%539, %540 : tensor<32x8x8xf32>, tensor<32x8x1xf32>) outs(%82 : tensor<32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.divf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<32x8x8xf32>
%542 = linalg.batch_matmul ins(%541, %collapsed_447 : tensor<32x8x8xf32>, tensor<32x8x64xf32>) outs(%99 : tensor<32x8x64xf32>) -> tensor<32x8x64xf32>
%expanded_450 = tensor.expand_shape %542 [[0, 1], [2], [3]] : tensor<32x8x64xf32> into tensor<1x32x8x64xf32>
%543 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_450 : tensor<1x32x8x64xf32>) outs(%101 : tensor<1x8x32x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x32x64xf32>
%544 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%543 : tensor<1x8x32x64xf32>) outs(%101 : tensor<1x8x32x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x32x64xf32>
%545 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_244 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%collapsed_451 = tensor.collapse_shape %544 [[0, 1], [2, 3]] : tensor<1x8x32x64xf32> into tensor<8x2048xf32>
%546 = linalg.matmul ins(%collapsed_451, %545 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%547 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_243, %546 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_452 = tensor.expand_shape %547 [[0, 1], [2]] : tensor<8x2048xf32> into tensor<1x8x2048xf32>
%548 = linalg.generic {indexing_maps = [#map13, #map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_439, %expanded_452 : tensor<1x8x2048xf32>, tensor<1x8x2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%collapsed_453 = tensor.collapse_shape %548 [[0, 1], [2]] : tensor<1x8x2048xf32> into tensor<8x2048xf32>
%549 = linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%collapsed_453 : tensor<8x2048xf32>) outs(%108 : tensor<8x2048xf64>) {
^bb0(%in: f32, %out: f64):
%1818 = arith.extf %in : f32 to f64
linalg.yield %1818 : f64
} -> tensor<8x2048xf64>
%550 = linalg.generic {indexing_maps = [#map4, #map7], iterator_types = ["parallel", "reduction"]} ins(%549 : tensor<8x2048xf64>) outs(%111 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%551 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%550 : tensor<8x1xf64>) outs(%110 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%552 = linalg.generic {indexing_maps = [#map4, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%549, %551 : tensor<8x2048xf64>, tensor<8x1xf64>) outs(%108 : tensor<8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.subf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<8x2048xf64>
%553 = linalg.generic {indexing_maps = [#map4, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%552, %552 : tensor<8x2048xf64>, tensor<8x2048xf64>) outs(%108 : tensor<8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.mulf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<8x2048xf64>
%554 = linalg.generic {indexing_maps = [#map4, #map7], iterator_types = ["parallel", "reduction"]} ins(%553 : tensor<8x2048xf64>) outs(%111 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%555 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%554 : tensor<8x1xf64>) outs(%110 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%556 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%555 : tensor<8x1xf64>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f64, %out: f32):
%1818 = arith.truncf %in : f64 to f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%557 = linalg.generic {indexing_maps = [#map4, #map7], iterator_types = ["parallel", "reduction"]} ins(%collapsed_453 : tensor<8x2048xf32>) outs(%120 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%558 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%557 : tensor<8x1xf32>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.divf %in, %cst_347 : f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%559 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%556 : tensor<8x1xf32>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.truncf %cst_344 : f64 to f32
%1819 = arith.addf %in, %1818 : f32
linalg.yield %1819 : f32
} -> tensor<8x1xf32>
%560 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%559 : tensor<8x1xf32>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.rsqrt %in : f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%561 = linalg.generic {indexing_maps = [#map4, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%collapsed_453, %558 : tensor<8x2048xf32>, tensor<8x1xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%562 = linalg.generic {indexing_maps = [#map4, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%561, %560 : tensor<8x2048xf32>, tensor<8x1xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%563 = linalg.generic {indexing_maps = [#map4, #map6, #map4], iterator_types = ["parallel", "parallel"]} ins(%562, %cst_337 : tensor<8x2048xf32>, tensor<2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%564 = linalg.generic {indexing_maps = [#map4, #map6, #map4], iterator_types = ["parallel", "parallel"]} ins(%563, %cst_242 : tensor<8x2048xf32>, tensor<2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%565 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_241 : tensor<8192x2048xf32>) outs(%129 : tensor<2048x8192xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x8192xf32>
%566 = linalg.matmul ins(%564, %565 : tensor<8x2048xf32>, tensor<2048x8192xf32>) outs(%132 : tensor<8x8192xf32>) -> tensor<8x8192xf32>
%567 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_240, %566 : tensor<8192xf32>, tensor<8x8192xf32>) outs(%131 : tensor<8x8192xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x8192xf32>
%568 = linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%567 : tensor<8x8192xf32>) outs(%131 : tensor<8x8192xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.cmpf ugt, %in, %cst_340 : f32
%1819 = arith.select %1818, %in, %cst_340 : f32
linalg.yield %1819 : f32
} -> tensor<8x8192xf32>
%569 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_239 : tensor<2048x8192xf32>) outs(%136 : tensor<8192x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<8192x2048xf32>
%570 = linalg.matmul ins(%568, %569 : tensor<8x8192xf32>, tensor<8192x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%571 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_238, %570 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%572 = linalg.generic {indexing_maps = [#map4, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%collapsed_453, %571 : tensor<8x2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_454 = tensor.expand_shape %572 [[0, 1], [2]] : tensor<8x2048xf32> into tensor<1x8x2048xf32>
%573 = linalg.generic {indexing_maps = [#map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_454 : tensor<1x8x2048xf32>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f32, %out: f64):
%1818 = arith.extf %in : f32 to f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
%574 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%573 : tensor<1x8x2048xf64>) outs(%42 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%575 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%574 : tensor<1x8x1xf64>) outs(%41 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%576 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%573, %575 : tensor<1x8x2048xf64>, tensor<1x8x1xf64>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.subf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
%577 = linalg.generic {indexing_maps = [#map13, #map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%576, %576 : tensor<1x8x2048xf64>, tensor<1x8x2048xf64>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.mulf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
%578 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%577 : tensor<1x8x2048xf64>) outs(%42 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%579 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%578 : tensor<1x8x1xf64>) outs(%41 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%580 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%579 : tensor<1x8x1xf64>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f64, %out: f32):
%1818 = arith.truncf %in : f64 to f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%581 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%expanded_454 : tensor<1x8x2048xf32>) outs(%51 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%582 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%581 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.divf %in, %cst_347 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%583 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%580 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.truncf %cst_344 : f64 to f32
%1819 = arith.addf %in, %1818 : f32
linalg.yield %1819 : f32
} -> tensor<1x8x1xf32>
%584 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%583 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.rsqrt %in : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%585 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_454, %582 : tensor<1x8x2048xf32>, tensor<1x8x1xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%586 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%585, %584 : tensor<1x8x2048xf32>, tensor<1x8x1xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%587 = linalg.generic {indexing_maps = [#map13, #map16, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%586, %cst_337 : tensor<1x8x2048xf32>, tensor<2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%588 = linalg.generic {indexing_maps = [#map13, #map16, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%587, %cst_237 : tensor<1x8x2048xf32>, tensor<2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%589 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_236 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%collapsed_455 = tensor.collapse_shape %588 [[0, 1], [2]] : tensor<1x8x2048xf32> into tensor<8x2048xf32>
%590 = linalg.matmul ins(%collapsed_455, %589 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%591 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_235, %590 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_456 = tensor.expand_shape %591 [[0, 1], [2]] : tensor<8x2048xf32> into tensor<1x8x2048xf32>
%592 = linalg.generic {indexing_maps = [#map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_456 : tensor<1x8x2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.mulf %in, %cst_348 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%593 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_234 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%594 = linalg.matmul ins(%collapsed_455, %593 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%595 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_233, %594 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_457 = tensor.expand_shape %595 [[0, 1], [2, 3]] : tensor<8x2048xf32> into tensor<1x8x32x64xf32>
%596 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_457 : tensor<1x8x32x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%597 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%596 : tensor<1x32x8x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%598 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_232 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%599 = linalg.matmul ins(%collapsed_455, %598 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%600 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_231, %599 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_458 = tensor.expand_shape %600 [[0, 1], [2, 3]] : tensor<8x2048xf32> into tensor<1x8x32x64xf32>
%601 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_458 : tensor<1x8x32x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%602 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%601 : tensor<1x32x8x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%expanded_459 = tensor.expand_shape %592 [[0], [1], [2, 3]] : tensor<1x8x2048xf32> into tensor<1x8x32x64xf32>
%603 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_459 : tensor<1x8x32x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%604 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%603 : tensor<1x32x8x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%collapsed_460 = tensor.collapse_shape %604 [[0, 1], [2], [3]] : tensor<1x32x8x64xf32> into tensor<32x8x64xf32>
%collapsed_461 = tensor.collapse_shape %597 [[0, 1], [2], [3]] : tensor<1x32x8x64xf32> into tensor<32x8x64xf32>
%collapsed_462 = tensor.collapse_shape %602 [[0, 1], [2], [3]] : tensor<1x32x8x64xf32> into tensor<32x8x64xf32>
%605 = linalg.generic {indexing_maps = [#map1, #map20], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_461 : tensor<32x8x64xf32>) outs(%80 : tensor<32x64x8xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<32x64x8xf32>
%606 = linalg.batch_matmul ins(%collapsed_460, %605 : tensor<32x8x64xf32>, tensor<32x64x8xf32>) outs(%83 : tensor<32x8x8xf32>) -> tensor<32x8x8xf32>
%expanded_463 = tensor.expand_shape %606 [[0, 1], [2], [3]] : tensor<32x8x8xf32> into tensor<1x32x8x8xf32>
%607 = linalg.generic {indexing_maps = [#map19, #map8, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_463, %28 : tensor<1x32x8x8xf32>, tensor<1x1x8x8xf32>) outs(%85 : tensor<1x32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x32x8x8xf32>
%608 = linalg.generic {indexing_maps = [#map19, #map11, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%607, %cst_329 : tensor<1x32x8x8xf32>, tensor<f32>) outs(%85 : tensor<1x32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.cmpf ugt, %in, %in_712 : f32
%1819 = arith.select %1818, %in, %in_712 : f32
linalg.yield %1819 : f32
} -> tensor<1x32x8x8xf32>
%collapsed_464 = tensor.collapse_shape %608 [[0, 1], [2], [3]] : tensor<1x32x8x8xf32> into tensor<32x8x8xf32>
%609:2 = linalg.generic {indexing_maps = [#map1, #map14, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%collapsed_464 : tensor<32x8x8xf32>) outs(%91, %89 : tensor<32x8x1xf32>, tensor<32x8x1xi64>) {
^bb0(%in: f32, %out: f32, %out_712: i64):
%1818 = linalg.index 2 : index
%1819 = arith.index_cast %1818 : index to i64
%1820 = arith.maximumf %in, %out : f32
%1821 = arith.cmpf ogt, %in, %out : f32
%1822 = arith.select %1821, %1819, %out_712 : i64
linalg.yield %1820, %1822 : f32, i64
} -> (tensor<32x8x1xf32>, tensor<32x8x1xi64>)
%610 = linalg.generic {indexing_maps = [#map1, #map14, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_464, %609#0 : tensor<32x8x8xf32>, tensor<32x8x1xf32>) outs(%82 : tensor<32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<32x8x8xf32>
%611 = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%610 : tensor<32x8x8xf32>) outs(%82 : tensor<32x8x8xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.exp %in : f32
linalg.yield %1818 : f32
} -> tensor<32x8x8xf32>
%612 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%611 : tensor<32x8x8xf32>) outs(%95 : tensor<32x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<32x8x1xf32>
%613 = linalg.generic {indexing_maps = [#map1, #map14, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%611, %612 : tensor<32x8x8xf32>, tensor<32x8x1xf32>) outs(%82 : tensor<32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.divf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<32x8x8xf32>
%614 = linalg.batch_matmul ins(%613, %collapsed_462 : tensor<32x8x8xf32>, tensor<32x8x64xf32>) outs(%99 : tensor<32x8x64xf32>) -> tensor<32x8x64xf32>
%expanded_465 = tensor.expand_shape %614 [[0, 1], [2], [3]] : tensor<32x8x64xf32> into tensor<1x32x8x64xf32>
%615 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_465 : tensor<1x32x8x64xf32>) outs(%101 : tensor<1x8x32x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x32x64xf32>
%616 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%615 : tensor<1x8x32x64xf32>) outs(%101 : tensor<1x8x32x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x32x64xf32>
%617 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_230 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%collapsed_466 = tensor.collapse_shape %616 [[0, 1], [2, 3]] : tensor<1x8x32x64xf32> into tensor<8x2048xf32>
%618 = linalg.matmul ins(%collapsed_466, %617 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%619 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_229, %618 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_467 = tensor.expand_shape %619 [[0, 1], [2]] : tensor<8x2048xf32> into tensor<1x8x2048xf32>
%620 = linalg.generic {indexing_maps = [#map13, #map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_454, %expanded_467 : tensor<1x8x2048xf32>, tensor<1x8x2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%collapsed_468 = tensor.collapse_shape %620 [[0, 1], [2]] : tensor<1x8x2048xf32> into tensor<8x2048xf32>
%621 = linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%collapsed_468 : tensor<8x2048xf32>) outs(%108 : tensor<8x2048xf64>) {
^bb0(%in: f32, %out: f64):
%1818 = arith.extf %in : f32 to f64
linalg.yield %1818 : f64
} -> tensor<8x2048xf64>
%622 = linalg.generic {indexing_maps = [#map4, #map7], iterator_types = ["parallel", "reduction"]} ins(%621 : tensor<8x2048xf64>) outs(%111 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%623 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%622 : tensor<8x1xf64>) outs(%110 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%624 = linalg.generic {indexing_maps = [#map4, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%621, %623 : tensor<8x2048xf64>, tensor<8x1xf64>) outs(%108 : tensor<8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.subf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<8x2048xf64>
%625 = linalg.generic {indexing_maps = [#map4, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%624, %624 : tensor<8x2048xf64>, tensor<8x2048xf64>) outs(%108 : tensor<8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.mulf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<8x2048xf64>
%626 = linalg.generic {indexing_maps = [#map4, #map7], iterator_types = ["parallel", "reduction"]} ins(%625 : tensor<8x2048xf64>) outs(%111 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%627 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%626 : tensor<8x1xf64>) outs(%110 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%628 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%627 : tensor<8x1xf64>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f64, %out: f32):
%1818 = arith.truncf %in : f64 to f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%629 = linalg.generic {indexing_maps = [#map4, #map7], iterator_types = ["parallel", "reduction"]} ins(%collapsed_468 : tensor<8x2048xf32>) outs(%120 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%630 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%629 : tensor<8x1xf32>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.divf %in, %cst_347 : f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%631 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%628 : tensor<8x1xf32>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.truncf %cst_344 : f64 to f32
%1819 = arith.addf %in, %1818 : f32
linalg.yield %1819 : f32
} -> tensor<8x1xf32>
%632 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%631 : tensor<8x1xf32>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.rsqrt %in : f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%633 = linalg.generic {indexing_maps = [#map4, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%collapsed_468, %630 : tensor<8x2048xf32>, tensor<8x1xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%634 = linalg.generic {indexing_maps = [#map4, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%633, %632 : tensor<8x2048xf32>, tensor<8x1xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%635 = linalg.generic {indexing_maps = [#map4, #map6, #map4], iterator_types = ["parallel", "parallel"]} ins(%634, %cst_337 : tensor<8x2048xf32>, tensor<2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%636 = linalg.generic {indexing_maps = [#map4, #map6, #map4], iterator_types = ["parallel", "parallel"]} ins(%635, %cst_228 : tensor<8x2048xf32>, tensor<2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%637 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_227 : tensor<8192x2048xf32>) outs(%129 : tensor<2048x8192xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x8192xf32>
%638 = linalg.matmul ins(%636, %637 : tensor<8x2048xf32>, tensor<2048x8192xf32>) outs(%132 : tensor<8x8192xf32>) -> tensor<8x8192xf32>
%639 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_226, %638 : tensor<8192xf32>, tensor<8x8192xf32>) outs(%131 : tensor<8x8192xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x8192xf32>
%640 = linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%639 : tensor<8x8192xf32>) outs(%131 : tensor<8x8192xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.cmpf ugt, %in, %cst_340 : f32
%1819 = arith.select %1818, %in, %cst_340 : f32
linalg.yield %1819 : f32
} -> tensor<8x8192xf32>
%641 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_225 : tensor<2048x8192xf32>) outs(%136 : tensor<8192x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<8192x2048xf32>
%642 = linalg.matmul ins(%640, %641 : tensor<8x8192xf32>, tensor<8192x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%643 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_224, %642 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%644 = linalg.generic {indexing_maps = [#map4, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%collapsed_468, %643 : tensor<8x2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_469 = tensor.expand_shape %644 [[0, 1], [2]] : tensor<8x2048xf32> into tensor<1x8x2048xf32>
%645 = linalg.generic {indexing_maps = [#map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_469 : tensor<1x8x2048xf32>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f32, %out: f64):
%1818 = arith.extf %in : f32 to f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
%646 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%645 : tensor<1x8x2048xf64>) outs(%42 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%647 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%646 : tensor<1x8x1xf64>) outs(%41 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%648 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%645, %647 : tensor<1x8x2048xf64>, tensor<1x8x1xf64>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.subf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
%649 = linalg.generic {indexing_maps = [#map13, #map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%648, %648 : tensor<1x8x2048xf64>, tensor<1x8x2048xf64>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.mulf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
%650 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%649 : tensor<1x8x2048xf64>) outs(%42 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%651 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%650 : tensor<1x8x1xf64>) outs(%41 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%652 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%651 : tensor<1x8x1xf64>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f64, %out: f32):
%1818 = arith.truncf %in : f64 to f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%653 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%expanded_469 : tensor<1x8x2048xf32>) outs(%51 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%654 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%653 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.divf %in, %cst_347 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%655 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%652 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.truncf %cst_344 : f64 to f32
%1819 = arith.addf %in, %1818 : f32
linalg.yield %1819 : f32
} -> tensor<1x8x1xf32>
%656 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%655 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.rsqrt %in : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%657 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_469, %654 : tensor<1x8x2048xf32>, tensor<1x8x1xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%658 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%657, %656 : tensor<1x8x2048xf32>, tensor<1x8x1xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%659 = linalg.generic {indexing_maps = [#map13, #map16, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%658, %cst_337 : tensor<1x8x2048xf32>, tensor<2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%660 = linalg.generic {indexing_maps = [#map13, #map16, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%659, %cst_223 : tensor<1x8x2048xf32>, tensor<2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%661 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_222 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%collapsed_470 = tensor.collapse_shape %660 [[0, 1], [2]] : tensor<1x8x2048xf32> into tensor<8x2048xf32>
%662 = linalg.matmul ins(%collapsed_470, %661 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%663 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_221, %662 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_471 = tensor.expand_shape %663 [[0, 1], [2]] : tensor<8x2048xf32> into tensor<1x8x2048xf32>
%664 = linalg.generic {indexing_maps = [#map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_471 : tensor<1x8x2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.mulf %in, %cst_348 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%665 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_220 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%666 = linalg.matmul ins(%collapsed_470, %665 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
// Bias add: #map6 reads %cst_219 only at d1, so the 2048-element vector is
// broadcast across every row of %666 (defined before this excerpt).
%667 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_219, %666 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
// Split 2048 features into 32 x 64 (presumably 32 attention heads of width
// 64 -- confirm), then swap dims 1 and 2 (#map18 is the output map):
// 1x8x32x64 -> 1x32x8x64. Downstream in this excerpt this tensor becomes
// %collapsed_476, the transposed RHS of the score batch_matmul %678
// (i.e. the "key" role).
%expanded_472 = tensor.expand_shape %667 [[0, 1], [2, 3]] : tensor<8x2048xf32> into tensor<1x8x32x64xf32>
%668 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_472 : tensor<1x8x32x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
// Plain copy: #map19 pins d0 to 0, which is the only index because d0 has
// extent 1.
%669 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%668 : tensor<1x32x8x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
// Transpose the 2048x2048 weight %cst_218 (#map17 output map writes
// in[d0,d1] to out[d1,d0]) so the matmul below computes x * W^T.
%670 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_218 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
// %collapsed_470 is defined before this excerpt. %63 is the accumulator init
// (presumably zero-filled -- confirm).
%671 = linalg.matmul ins(%collapsed_470, %670 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
// Broadcast bias %cst_217 across rows.
%672 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_217, %671 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
// Same 32x64 split + dim-1/dim-2 swap + copy as above. This result becomes
// %collapsed_477, the RHS of the probabilities x values batch_matmul %686
// (i.e. the "value" role).
%expanded_473 = tensor.expand_shape %672 [[0, 1], [2, 3]] : tensor<8x2048xf32> into tensor<1x8x32x64xf32>
%673 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_473 : tensor<1x8x32x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%674 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%673 : tensor<1x32x8x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
// Split %664 (defined before this excerpt; presumably the scaled query
// projection -- confirm) into 1x8x32x64, swap dims 1 and 2, and copy.
// It becomes %collapsed_475, the LHS of the score batch_matmul %678.
%expanded_474 = tensor.expand_shape %664 [[0], [1], [2, 3]] : tensor<1x8x2048xf32> into tensor<1x8x32x64xf32>
%675 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_474 : tensor<1x8x32x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%676 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%675 : tensor<1x32x8x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
// Fold the unit leading dim into the 32-way dim so batch_matmul can batch
// over it.
%collapsed_475 = tensor.collapse_shape %676 [[0, 1], [2], [3]] : tensor<1x32x8x64xf32> into tensor<32x8x64xf32>
%collapsed_476 = tensor.collapse_shape %669 [[0, 1], [2], [3]] : tensor<1x32x8x64xf32> into tensor<32x8x64xf32>
%collapsed_477 = tensor.collapse_shape %674 [[0, 1], [2], [3]] : tensor<1x32x8x64xf32> into tensor<32x8x64xf32>
// Transpose the last two dims (#map20 output map): 32x8x64 -> 32x64x8.
%677 = linalg.generic {indexing_maps = [#map1, #map20], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_476 : tensor<32x8x64xf32>) outs(%80 : tensor<32x64x8xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<32x64x8xf32>
// Scores: for each of the 32 slices, (8x64) x (64x8) -> 8x8.
%678 = linalg.batch_matmul ins(%collapsed_475, %677 : tensor<32x8x64xf32>, tensor<32x64x8xf32>) outs(%83 : tensor<32x8x8xf32>) -> tensor<32x8x8xf32>
%expanded_478 = tensor.expand_shape %678 [[0, 1], [2], [3]] : tensor<32x8x8xf32> into tensor<1x32x8x8xf32>
// Add %28 (1x1x8x8), broadcast over dims 0 and 1 via #map8. %28 is defined
// before this excerpt; presumably the additive attention mask -- confirm.
%679 = linalg.generic {indexing_maps = [#map19, #map8, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_478, %28 : tensor<1x32x8x8xf32>, tensor<1x1x8x8xf32>) outs(%85 : tensor<1x32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x32x8x8xf32>
// Elementwise max(score, %cst_329) via `cmpf ugt` + select. Because ugt is
// true for unordered operands, a NaN score is passed through unchanged.
// %cst_329 is a scalar tensor defined outside this excerpt; presumably the
// lowest finite f32 -- confirm.
%680 = linalg.generic {indexing_maps = [#map19, #map11, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%679, %cst_329 : tensor<1x32x8x8xf32>, tensor<f32>) outs(%85 : tensor<1x32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.cmpf ugt, %in, %in_712 : f32
%1819 = arith.select %1818, %in, %in_712 : f32
linalg.yield %1819 : f32
} -> tensor<1x32x8x8xf32>
%collapsed_479 = tensor.collapse_shape %680 [[0, 1], [2], [3]] : tensor<1x32x8x8xf32> into tensor<32x8x8xf32>
// Reduce the last dim to its max (result #0) and the i64 index of that max
// (result #1). Only #0 is consumed in this excerpt. The init tensors %91/%89
// come from before the excerpt (presumably -inf and 0 -- confirm).
%681:2 = linalg.generic {indexing_maps = [#map1, #map14, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%collapsed_479 : tensor<32x8x8xf32>) outs(%91, %89 : tensor<32x8x1xf32>, tensor<32x8x1xi64>) {
^bb0(%in: f32, %out: f32, %out_712: i64):
%1818 = linalg.index 2 : index
%1819 = arith.index_cast %1818 : index to i64
%1820 = arith.maximumf %in, %out : f32
%1821 = arith.cmpf ogt, %in, %out : f32
%1822 = arith.select %1821, %1819, %out_712 : i64
linalg.yield %1820, %1822 : f32, i64
} -> (tensor<32x8x1xf32>, tensor<32x8x1xi64>)
// Softmax over the last dim, max-subtracted for stability:
// subtract the row max, exp, sum over the row, divide by the sum.
%682 = linalg.generic {indexing_maps = [#map1, #map14, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_479, %681#0 : tensor<32x8x8xf32>, tensor<32x8x1xf32>) outs(%82 : tensor<32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<32x8x8xf32>
%683 = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%682 : tensor<32x8x8xf32>) outs(%82 : tensor<32x8x8xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.exp %in : f32
linalg.yield %1818 : f32
} -> tensor<32x8x8xf32>
// Row sum; init %95 is defined before this excerpt (presumably zeros -- confirm).
%684 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%683 : tensor<32x8x8xf32>) outs(%95 : tensor<32x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<32x8x1xf32>
%685 = linalg.generic {indexing_maps = [#map1, #map14, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%683, %684 : tensor<32x8x8xf32>, tensor<32x8x1xf32>) outs(%82 : tensor<32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.divf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<32x8x8xf32>
// Probabilities (32x8x8) times the value-role tensor (32x8x64) -> 32x8x64.
%686 = linalg.batch_matmul ins(%685, %collapsed_477 : tensor<32x8x8xf32>, tensor<32x8x64xf32>) outs(%99 : tensor<32x8x64xf32>) -> tensor<32x8x64xf32>
// Undo the split: swap dims 1 and 2 back (1x32x8x64 -> 1x8x32x64), copy,
// and later flatten to 8x2048.
%expanded_480 = tensor.expand_shape %686 [[0, 1], [2], [3]] : tensor<32x8x64xf32> into tensor<1x32x8x64xf32>
%687 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_480 : tensor<1x32x8x64xf32>) outs(%101 : tensor<1x8x32x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x32x64xf32>
%688 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%687 : tensor<1x8x32x64xf32>) outs(%101 : tensor<1x8x32x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x32x64xf32>
// Output projection: transpose weight %cst_216, matmul, add bias %cst_215.
%689 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_216 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%collapsed_481 = tensor.collapse_shape %688 [[0, 1], [2, 3]] : tensor<1x8x32x64xf32> into tensor<8x2048xf32>
%690 = linalg.matmul ins(%collapsed_481, %689 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%691 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_215, %690 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
// Residual add: %expanded_469 (defined before this excerpt) + projection output.
%expanded_482 = tensor.expand_shape %691 [[0, 1], [2]] : tensor<8x2048xf32> into tensor<1x8x2048xf32>
%692 = linalg.generic {indexing_maps = [#map13, #map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_469, %expanded_482 : tensor<1x8x2048xf32>, tensor<1x8x2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
// Per-row normalization of %collapsed_483 (8x2048), LayerNorm-shaped:
//   y = (x - mean) * rsqrt(var + truncf(%cst_344)) * %cst_337 + %cst_214
// %cst_346 (f64) and %cst_347 (f32) are the divisors (presumably 2048.0 --
// confirm). truncf(%cst_344) is the additive term (presumably epsilon --
// confirm).
// The variance path runs in f64 (%693-%700). The mean used for centering is
// recomputed separately in f32 (%701/%702), so each row is reduced twice.
// NOTE(review): %cst_337 is also the scale of the other normalizations in
// this excerpt (L3659, L3882), while the shift differs per instance; confirm
// the sharing is intended.
%collapsed_483 = tensor.collapse_shape %692 [[0, 1], [2]] : tensor<1x8x2048xf32> into tensor<8x2048xf32>
// Widen to f64.
%693 = linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%collapsed_483 : tensor<8x2048xf32>) outs(%108 : tensor<8x2048xf64>) {
^bb0(%in: f32, %out: f64):
%1818 = arith.extf %in : f32 to f64
linalg.yield %1818 : f64
} -> tensor<8x2048xf64>
// f64 row sum (init %111 from before this excerpt), then divide -> f64 mean.
%694 = linalg.generic {indexing_maps = [#map4, #map7], iterator_types = ["parallel", "reduction"]} ins(%693 : tensor<8x2048xf64>) outs(%111 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%695 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%694 : tensor<8x1xf64>) outs(%110 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
// Center, square, row-sum, and divide by the same %cst_346 as the mean.
// This is the biased (population) variance if %cst_346 is the row length.
%696 = linalg.generic {indexing_maps = [#map4, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%693, %695 : tensor<8x2048xf64>, tensor<8x1xf64>) outs(%108 : tensor<8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.subf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<8x2048xf64>
%697 = linalg.generic {indexing_maps = [#map4, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%696, %696 : tensor<8x2048xf64>, tensor<8x2048xf64>) outs(%108 : tensor<8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.mulf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<8x2048xf64>
%698 = linalg.generic {indexing_maps = [#map4, #map7], iterator_types = ["parallel", "reduction"]} ins(%697 : tensor<8x2048xf64>) outs(%111 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%699 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%698 : tensor<8x1xf64>) outs(%110 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
// Narrow the variance back to f32.
%700 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%699 : tensor<8x1xf64>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f64, %out: f32):
%1818 = arith.truncf %in : f64 to f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
// f32 row mean used for centering.
%701 = linalg.generic {indexing_maps = [#map4, #map7], iterator_types = ["parallel", "reduction"]} ins(%collapsed_483 : tensor<8x2048xf32>) outs(%120 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%702 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%701 : tensor<8x1xf32>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.divf %in, %cst_347 : f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
// var + truncf(%cst_344), then reciprocal square root.
%703 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%700 : tensor<8x1xf32>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.truncf %cst_344 : f64 to f32
%1819 = arith.addf %in, %1818 : f32
linalg.yield %1819 : f32
} -> tensor<8x1xf32>
%704 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%703 : tensor<8x1xf32>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.rsqrt %in : f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
// (x - mean) * rstd, then per-column scale %cst_337 and shift %cst_214.
%705 = linalg.generic {indexing_maps = [#map4, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%collapsed_483, %702 : tensor<8x2048xf32>, tensor<8x1xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%706 = linalg.generic {indexing_maps = [#map4, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%705, %704 : tensor<8x2048xf32>, tensor<8x1xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%707 = linalg.generic {indexing_maps = [#map4, #map6, #map4], iterator_types = ["parallel", "parallel"]} ins(%706, %cst_337 : tensor<8x2048xf32>, tensor<2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%708 = linalg.generic {indexing_maps = [#map4, #map6, #map4], iterator_types = ["parallel", "parallel"]} ins(%707, %cst_214 : tensor<8x2048xf32>, tensor<2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
// Feed-forward, up-projection: transpose %cst_213 (8192x2048 -> 2048x8192),
// matmul 8x2048 -> 8x8192, add bias %cst_212.
%709 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_213 : tensor<8192x2048xf32>) outs(%129 : tensor<2048x8192xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x8192xf32>
%710 = linalg.matmul ins(%708, %709 : tensor<8x2048xf32>, tensor<2048x8192xf32>) outs(%132 : tensor<8x8192xf32>) -> tensor<8x8192xf32>
%711 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_212, %710 : tensor<8192xf32>, tensor<8x8192xf32>) outs(%131 : tensor<8x8192xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x8192xf32>
// Elementwise max(x, %cst_340) via `cmpf ugt` + select (NaN x passes
// through). Presumably %cst_340 == 0.0, making this a ReLU -- confirm.
%712 = linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%711 : tensor<8x8192xf32>) outs(%131 : tensor<8x8192xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.cmpf ugt, %in, %cst_340 : f32
%1819 = arith.select %1818, %in, %cst_340 : f32
linalg.yield %1819 : f32
} -> tensor<8x8192xf32>
// Down-projection: transpose %cst_211 (2048x8192 -> 8192x2048), matmul
// 8x8192 -> 8x2048, add bias %cst_210.
%713 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_211 : tensor<2048x8192xf32>) outs(%136 : tensor<8192x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<8192x2048xf32>
%714 = linalg.matmul ins(%712, %713 : tensor<8x8192xf32>, tensor<8192x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%715 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_210, %714 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
// Residual add with the un-normalized input %collapsed_483.
%716 = linalg.generic {indexing_maps = [#map4, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%collapsed_483, %715 : tensor<8x2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
// Per-row normalization of %expanded_484 (1x8x2048) over the last dim,
// LayerNorm-shaped:
//   y = (x - mean) * rsqrt(var + truncf(%cst_344)) * %cst_337 + %cst_209
// Same structure as the 8x2048 variant, but on the 3-D shape with #map13/
// #map15 broadcasts. The variance is computed in f64; the centering mean is
// recomputed in f32. %cst_346/%cst_347 are presumably 2048.0 and %cst_344
// presumably epsilon -- confirm.
%expanded_484 = tensor.expand_shape %716 [[0, 1], [2]] : tensor<8x2048xf32> into tensor<1x8x2048xf32>
// Widen to f64.
%717 = linalg.generic {indexing_maps = [#map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_484 : tensor<1x8x2048xf32>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f32, %out: f64):
%1818 = arith.extf %in : f32 to f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
// f64 mean over the last dim.
%718 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%717 : tensor<1x8x2048xf64>) outs(%42 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%719 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%718 : tensor<1x8x1xf64>) outs(%41 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
// Center, square, sum, divide by %cst_346 -> f64 variance.
%720 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%717, %719 : tensor<1x8x2048xf64>, tensor<1x8x1xf64>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.subf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
%721 = linalg.generic {indexing_maps = [#map13, #map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%720, %720 : tensor<1x8x2048xf64>, tensor<1x8x2048xf64>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.mulf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
%722 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%721 : tensor<1x8x2048xf64>) outs(%42 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%723 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%722 : tensor<1x8x1xf64>) outs(%41 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
// Narrow the variance to f32.
%724 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%723 : tensor<1x8x1xf64>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f64, %out: f32):
%1818 = arith.truncf %in : f64 to f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
// f32 mean used for centering.
%725 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%expanded_484 : tensor<1x8x2048xf32>) outs(%51 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%726 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%725 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.divf %in, %cst_347 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
// var + truncf(%cst_344), then reciprocal square root.
%727 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%724 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.truncf %cst_344 : f64 to f32
%1819 = arith.addf %in, %1818 : f32
linalg.yield %1819 : f32
} -> tensor<1x8x1xf32>
%728 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%727 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.rsqrt %in : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
// (x - mean) * rstd, then scale %cst_337 and shift %cst_209 (#map16 reads d2).
%729 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_484, %726 : tensor<1x8x2048xf32>, tensor<1x8x1xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%730 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%729, %728 : tensor<1x8x2048xf32>, tensor<1x8x1xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%731 = linalg.generic {indexing_maps = [#map13, #map16, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%730, %cst_337 : tensor<1x8x2048xf32>, tensor<2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%732 = linalg.generic {indexing_maps = [#map13, #map16, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%731, %cst_209 : tensor<1x8x2048xf32>, tensor<2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
// Three projections of the same normalized input %collapsed_485, each as
// transpose(weight) -> matmul -> bias add:
//   %cst_208/%cst_207 -> %736 (also scaled by %cst_348): LHS of scores %750
//   %cst_206/%cst_205 -> %741: transposed RHS of scores %750 ("key" role)
//   %cst_204/%cst_203 -> %746: RHS of probs x values %758 ("value" role)
%733 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_208 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%collapsed_485 = tensor.collapse_shape %732 [[0, 1], [2]] : tensor<1x8x2048xf32> into tensor<8x2048xf32>
%734 = linalg.matmul ins(%collapsed_485, %733 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%735 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_207, %734 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_486 = tensor.expand_shape %735 [[0, 1], [2]] : tensor<8x2048xf32> into tensor<1x8x2048xf32>
// Scale by scalar %cst_348 (defined outside this excerpt; presumably
// 1/sqrt(64) = 0.125 -- confirm).
%736 = linalg.generic {indexing_maps = [#map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_486 : tensor<1x8x2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.mulf %in, %cst_348 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
// Key-role projection, then 32x64 split, swap of dims 1 and 2, and copy.
%737 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_206 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%738 = linalg.matmul ins(%collapsed_485, %737 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%739 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_205, %738 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_487 = tensor.expand_shape %739 [[0, 1], [2, 3]] : tensor<8x2048xf32> into tensor<1x8x32x64xf32>
%740 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_487 : tensor<1x8x32x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%741 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%740 : tensor<1x32x8x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
// Value-role projection, then the same split/swap/copy.
%742 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_204 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%743 = linalg.matmul ins(%collapsed_485, %742 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%744 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_203, %743 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_488 = tensor.expand_shape %744 [[0, 1], [2, 3]] : tensor<8x2048xf32> into tensor<1x8x32x64xf32>
%745 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_488 : tensor<1x8x32x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%746 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%745 : tensor<1x32x8x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
// Split the scaled %736 into 1x8x32x64, swap dims 1 and 2, and copy. It
// becomes %collapsed_490, the LHS of the score batch_matmul %750.
%expanded_489 = tensor.expand_shape %736 [[0], [1], [2, 3]] : tensor<1x8x2048xf32> into tensor<1x8x32x64xf32>
%747 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_489 : tensor<1x8x32x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%748 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%747 : tensor<1x32x8x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
// Fold the unit leading dim so batch_matmul batches over the 32-way dim.
%collapsed_490 = tensor.collapse_shape %748 [[0, 1], [2], [3]] : tensor<1x32x8x64xf32> into tensor<32x8x64xf32>
%collapsed_491 = tensor.collapse_shape %741 [[0, 1], [2], [3]] : tensor<1x32x8x64xf32> into tensor<32x8x64xf32>
%collapsed_492 = tensor.collapse_shape %746 [[0, 1], [2], [3]] : tensor<1x32x8x64xf32> into tensor<32x8x64xf32>
// Transpose the last two dims of the key-role tensor: 32x8x64 -> 32x64x8.
%749 = linalg.generic {indexing_maps = [#map1, #map20], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_491 : tensor<32x8x64xf32>) outs(%80 : tensor<32x64x8xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<32x64x8xf32>
// Scores: (8x64) x (64x8) -> 8x8 per slice.
%750 = linalg.batch_matmul ins(%collapsed_490, %749 : tensor<32x8x64xf32>, tensor<32x64x8xf32>) outs(%83 : tensor<32x8x8xf32>) -> tensor<32x8x8xf32>
%expanded_493 = tensor.expand_shape %750 [[0, 1], [2], [3]] : tensor<32x8x8xf32> into tensor<1x32x8x8xf32>
// Add %28 (1x1x8x8, broadcast via #map8); presumably the attention mask --
// confirm.
%751 = linalg.generic {indexing_maps = [#map19, #map8, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_493, %28 : tensor<1x32x8x8xf32>, tensor<1x1x8x8xf32>) outs(%85 : tensor<1x32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x32x8x8xf32>
// max(score, %cst_329) via ugt + select (NaN scores pass through);
// %cst_329 is presumably the lowest finite f32 -- confirm.
%752 = linalg.generic {indexing_maps = [#map19, #map11, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%751, %cst_329 : tensor<1x32x8x8xf32>, tensor<f32>) outs(%85 : tensor<1x32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.cmpf ugt, %in, %in_712 : f32
%1819 = arith.select %1818, %in, %in_712 : f32
linalg.yield %1819 : f32
} -> tensor<1x32x8x8xf32>
%collapsed_494 = tensor.collapse_shape %752 [[0, 1], [2], [3]] : tensor<1x32x8x8xf32> into tensor<32x8x8xf32>
// Row max (#0) and its i64 index (#1); only #0 is consumed in this excerpt.
%753:2 = linalg.generic {indexing_maps = [#map1, #map14, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%collapsed_494 : tensor<32x8x8xf32>) outs(%91, %89 : tensor<32x8x1xf32>, tensor<32x8x1xi64>) {
^bb0(%in: f32, %out: f32, %out_712: i64):
%1818 = linalg.index 2 : index
%1819 = arith.index_cast %1818 : index to i64
%1820 = arith.maximumf %in, %out : f32
%1821 = arith.cmpf ogt, %in, %out : f32
%1822 = arith.select %1821, %1819, %out_712 : i64
linalg.yield %1820, %1822 : f32, i64
} -> (tensor<32x8x1xf32>, tensor<32x8x1xi64>)
// Max-subtracted softmax over the last dim: subtract, exp, row sum, divide.
%754 = linalg.generic {indexing_maps = [#map1, #map14, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_494, %753#0 : tensor<32x8x8xf32>, tensor<32x8x1xf32>) outs(%82 : tensor<32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<32x8x8xf32>
%755 = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%754 : tensor<32x8x8xf32>) outs(%82 : tensor<32x8x8xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.exp %in : f32
linalg.yield %1818 : f32
} -> tensor<32x8x8xf32>
%756 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%755 : tensor<32x8x8xf32>) outs(%95 : tensor<32x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<32x8x1xf32>
%757 = linalg.generic {indexing_maps = [#map1, #map14, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%755, %756 : tensor<32x8x8xf32>, tensor<32x8x1xf32>) outs(%82 : tensor<32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.divf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<32x8x8xf32>
// Probabilities (32x8x8) times the value-role tensor (32x8x64).
%758 = linalg.batch_matmul ins(%757, %collapsed_492 : tensor<32x8x8xf32>, tensor<32x8x64xf32>) outs(%99 : tensor<32x8x64xf32>) -> tensor<32x8x64xf32>
// Undo the split (1x32x8x64 -> 1x8x32x64), copy, and flatten to 8x2048.
%expanded_495 = tensor.expand_shape %758 [[0, 1], [2], [3]] : tensor<32x8x64xf32> into tensor<1x32x8x64xf32>
%759 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_495 : tensor<1x32x8x64xf32>) outs(%101 : tensor<1x8x32x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x32x64xf32>
%760 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%759 : tensor<1x8x32x64xf32>) outs(%101 : tensor<1x8x32x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x32x64xf32>
// Output projection: transpose %cst_202, matmul, add bias %cst_201.
%761 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_202 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%collapsed_496 = tensor.collapse_shape %760 [[0, 1], [2, 3]] : tensor<1x8x32x64xf32> into tensor<8x2048xf32>
%762 = linalg.matmul ins(%collapsed_496, %761 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%763 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_201, %762 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
// Residual add with %expanded_484, the un-normalized input of this section.
%expanded_497 = tensor.expand_shape %763 [[0, 1], [2]] : tensor<8x2048xf32> into tensor<1x8x2048xf32>
%764 = linalg.generic {indexing_maps = [#map13, #map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_484, %expanded_497 : tensor<1x8x2048xf32>, tensor<1x8x2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
// Per-row normalization of %collapsed_498 (8x2048), LayerNorm-shaped:
//   y = (x - mean) * rsqrt(var + truncf(%cst_344)) * %cst_337 + %cst_200
// The variance is computed in f64 (%765-%772); the centering mean is
// recomputed in f32 (%773/%774). %cst_346/%cst_347 are presumably 2048.0 and
// %cst_344 presumably epsilon -- confirm.
%collapsed_498 = tensor.collapse_shape %764 [[0, 1], [2]] : tensor<1x8x2048xf32> into tensor<8x2048xf32>
// Widen to f64.
%765 = linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%collapsed_498 : tensor<8x2048xf32>) outs(%108 : tensor<8x2048xf64>) {
^bb0(%in: f32, %out: f64):
%1818 = arith.extf %in : f32 to f64
linalg.yield %1818 : f64
} -> tensor<8x2048xf64>
// f64 row mean.
%766 = linalg.generic {indexing_maps = [#map4, #map7], iterator_types = ["parallel", "reduction"]} ins(%765 : tensor<8x2048xf64>) outs(%111 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%767 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%766 : tensor<8x1xf64>) outs(%110 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
// Center, square, row-sum, divide by %cst_346 -> f64 variance.
%768 = linalg.generic {indexing_maps = [#map4, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%765, %767 : tensor<8x2048xf64>, tensor<8x1xf64>) outs(%108 : tensor<8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.subf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<8x2048xf64>
%769 = linalg.generic {indexing_maps = [#map4, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%768, %768 : tensor<8x2048xf64>, tensor<8x2048xf64>) outs(%108 : tensor<8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.mulf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<8x2048xf64>
%770 = linalg.generic {indexing_maps = [#map4, #map7], iterator_types = ["parallel", "reduction"]} ins(%769 : tensor<8x2048xf64>) outs(%111 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%771 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%770 : tensor<8x1xf64>) outs(%110 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
// Narrow the variance to f32.
%772 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%771 : tensor<8x1xf64>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f64, %out: f32):
%1818 = arith.truncf %in : f64 to f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
// f32 mean used for centering.
%773 = linalg.generic {indexing_maps = [#map4, #map7], iterator_types = ["parallel", "reduction"]} ins(%collapsed_498 : tensor<8x2048xf32>) outs(%120 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%774 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%773 : tensor<8x1xf32>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.divf %in, %cst_347 : f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
// var + truncf(%cst_344), then reciprocal square root.
%775 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%772 : tensor<8x1xf32>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.truncf %cst_344 : f64 to f32
%1819 = arith.addf %in, %1818 : f32
linalg.yield %1819 : f32
} -> tensor<8x1xf32>
%776 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%775 : tensor<8x1xf32>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.rsqrt %in : f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
// (x - mean) * rstd, then scale %cst_337 and shift %cst_200.
%777 = linalg.generic {indexing_maps = [#map4, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%collapsed_498, %774 : tensor<8x2048xf32>, tensor<8x1xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%778 = linalg.generic {indexing_maps = [#map4, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%777, %776 : tensor<8x2048xf32>, tensor<8x1xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%779 = linalg.generic {indexing_maps = [#map4, #map6, #map4], iterator_types = ["parallel", "parallel"]} ins(%778, %cst_337 : tensor<8x2048xf32>, tensor<2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%780 = linalg.generic {indexing_maps = [#map4, #map6, #map4], iterator_types = ["parallel", "parallel"]} ins(%779, %cst_200 : tensor<8x2048xf32>, tensor<2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%781 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_199 : tensor<8192x2048xf32>) outs(%129 : tensor<2048x8192xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x8192xf32>
%782 = linalg.matmul ins(%780, %781 : tensor<8x2048xf32>, tensor<2048x8192xf32>) outs(%132 : tensor<8x8192xf32>) -> tensor<8x8192xf32>
%783 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_198, %782 : tensor<8192xf32>, tensor<8x8192xf32>) outs(%131 : tensor<8x8192xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x8192xf32>
%784 = linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%783 : tensor<8x8192xf32>) outs(%131 : tensor<8x8192xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.cmpf ugt, %in, %cst_340 : f32
%1819 = arith.select %1818, %in, %cst_340 : f32
linalg.yield %1819 : f32
} -> tensor<8x8192xf32>
%785 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_197 : tensor<2048x8192xf32>) outs(%136 : tensor<8192x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<8192x2048xf32>
%786 = linalg.matmul ins(%784, %785 : tensor<8x8192xf32>, tensor<8192x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%787 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_196, %786 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%788 = linalg.generic {indexing_maps = [#map4, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%collapsed_498, %787 : tensor<8x2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_499 = tensor.expand_shape %788 [[0, 1], [2]] : tensor<8x2048xf32> into tensor<1x8x2048xf32>
%789 = linalg.generic {indexing_maps = [#map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_499 : tensor<1x8x2048xf32>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f32, %out: f64):
%1818 = arith.extf %in : f32 to f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
%790 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%789 : tensor<1x8x2048xf64>) outs(%42 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%791 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%790 : tensor<1x8x1xf64>) outs(%41 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%792 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%789, %791 : tensor<1x8x2048xf64>, tensor<1x8x1xf64>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.subf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
%793 = linalg.generic {indexing_maps = [#map13, #map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%792, %792 : tensor<1x8x2048xf64>, tensor<1x8x2048xf64>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.mulf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
%794 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%793 : tensor<1x8x2048xf64>) outs(%42 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%795 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%794 : tensor<1x8x1xf64>) outs(%41 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%796 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%795 : tensor<1x8x1xf64>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f64, %out: f32):
%1818 = arith.truncf %in : f64 to f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%797 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%expanded_499 : tensor<1x8x2048xf32>) outs(%51 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%798 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%797 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.divf %in, %cst_347 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%799 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%796 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.truncf %cst_344 : f64 to f32
%1819 = arith.addf %in, %1818 : f32
linalg.yield %1819 : f32
} -> tensor<1x8x1xf32>
%800 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%799 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.rsqrt %in : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%801 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_499, %798 : tensor<1x8x2048xf32>, tensor<1x8x1xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%802 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%801, %800 : tensor<1x8x2048xf32>, tensor<1x8x1xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%803 = linalg.generic {indexing_maps = [#map13, #map16, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%802, %cst_337 : tensor<1x8x2048xf32>, tensor<2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%804 = linalg.generic {indexing_maps = [#map13, #map16, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%803, %cst_195 : tensor<1x8x2048xf32>, tensor<2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%805 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_194 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%collapsed_500 = tensor.collapse_shape %804 [[0, 1], [2]] : tensor<1x8x2048xf32> into tensor<8x2048xf32>
%806 = linalg.matmul ins(%collapsed_500, %805 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%807 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_193, %806 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_501 = tensor.expand_shape %807 [[0, 1], [2]] : tensor<8x2048xf32> into tensor<1x8x2048xf32>
%808 = linalg.generic {indexing_maps = [#map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_501 : tensor<1x8x2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.mulf %in, %cst_348 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%809 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_192 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%810 = linalg.matmul ins(%collapsed_500, %809 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%811 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_191, %810 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_502 = tensor.expand_shape %811 [[0, 1], [2, 3]] : tensor<8x2048xf32> into tensor<1x8x32x64xf32>
%812 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_502 : tensor<1x8x32x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%813 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%812 : tensor<1x32x8x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%814 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_190 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%815 = linalg.matmul ins(%collapsed_500, %814 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%816 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_189, %815 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_503 = tensor.expand_shape %816 [[0, 1], [2, 3]] : tensor<8x2048xf32> into tensor<1x8x32x64xf32>
%817 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_503 : tensor<1x8x32x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%818 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%817 : tensor<1x32x8x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%expanded_504 = tensor.expand_shape %808 [[0], [1], [2, 3]] : tensor<1x8x2048xf32> into tensor<1x8x32x64xf32>
%819 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_504 : tensor<1x8x32x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%820 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%819 : tensor<1x32x8x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%collapsed_505 = tensor.collapse_shape %820 [[0, 1], [2], [3]] : tensor<1x32x8x64xf32> into tensor<32x8x64xf32>
%collapsed_506 = tensor.collapse_shape %813 [[0, 1], [2], [3]] : tensor<1x32x8x64xf32> into tensor<32x8x64xf32>
%collapsed_507 = tensor.collapse_shape %818 [[0, 1], [2], [3]] : tensor<1x32x8x64xf32> into tensor<32x8x64xf32>
%821 = linalg.generic {indexing_maps = [#map1, #map20], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_506 : tensor<32x8x64xf32>) outs(%80 : tensor<32x64x8xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<32x64x8xf32>
%822 = linalg.batch_matmul ins(%collapsed_505, %821 : tensor<32x8x64xf32>, tensor<32x64x8xf32>) outs(%83 : tensor<32x8x8xf32>) -> tensor<32x8x8xf32>
%expanded_508 = tensor.expand_shape %822 [[0, 1], [2], [3]] : tensor<32x8x8xf32> into tensor<1x32x8x8xf32>
%823 = linalg.generic {indexing_maps = [#map19, #map8, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_508, %28 : tensor<1x32x8x8xf32>, tensor<1x1x8x8xf32>) outs(%85 : tensor<1x32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x32x8x8xf32>
%824 = linalg.generic {indexing_maps = [#map19, #map11, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%823, %cst_329 : tensor<1x32x8x8xf32>, tensor<f32>) outs(%85 : tensor<1x32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.cmpf ugt, %in, %in_712 : f32
%1819 = arith.select %1818, %in, %in_712 : f32
linalg.yield %1819 : f32
} -> tensor<1x32x8x8xf32>
%collapsed_509 = tensor.collapse_shape %824 [[0, 1], [2], [3]] : tensor<1x32x8x8xf32> into tensor<32x8x8xf32>
%825:2 = linalg.generic {indexing_maps = [#map1, #map14, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%collapsed_509 : tensor<32x8x8xf32>) outs(%91, %89 : tensor<32x8x1xf32>, tensor<32x8x1xi64>) {
^bb0(%in: f32, %out: f32, %out_712: i64):
%1818 = linalg.index 2 : index
%1819 = arith.index_cast %1818 : index to i64
%1820 = arith.maximumf %in, %out : f32
%1821 = arith.cmpf ogt, %in, %out : f32
%1822 = arith.select %1821, %1819, %out_712 : i64
linalg.yield %1820, %1822 : f32, i64
} -> (tensor<32x8x1xf32>, tensor<32x8x1xi64>)
%826 = linalg.generic {indexing_maps = [#map1, #map14, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_509, %825#0 : tensor<32x8x8xf32>, tensor<32x8x1xf32>) outs(%82 : tensor<32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<32x8x8xf32>
%827 = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%826 : tensor<32x8x8xf32>) outs(%82 : tensor<32x8x8xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.exp %in : f32
linalg.yield %1818 : f32
} -> tensor<32x8x8xf32>
%828 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%827 : tensor<32x8x8xf32>) outs(%95 : tensor<32x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<32x8x1xf32>
%829 = linalg.generic {indexing_maps = [#map1, #map14, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%827, %828 : tensor<32x8x8xf32>, tensor<32x8x1xf32>) outs(%82 : tensor<32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.divf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<32x8x8xf32>
%830 = linalg.batch_matmul ins(%829, %collapsed_507 : tensor<32x8x8xf32>, tensor<32x8x64xf32>) outs(%99 : tensor<32x8x64xf32>) -> tensor<32x8x64xf32>
%expanded_510 = tensor.expand_shape %830 [[0, 1], [2], [3]] : tensor<32x8x64xf32> into tensor<1x32x8x64xf32>
%831 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_510 : tensor<1x32x8x64xf32>) outs(%101 : tensor<1x8x32x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x32x64xf32>
%832 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%831 : tensor<1x8x32x64xf32>) outs(%101 : tensor<1x8x32x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x32x64xf32>
%833 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_188 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%collapsed_511 = tensor.collapse_shape %832 [[0, 1], [2, 3]] : tensor<1x8x32x64xf32> into tensor<8x2048xf32>
%834 = linalg.matmul ins(%collapsed_511, %833 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%835 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_187, %834 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_512 = tensor.expand_shape %835 [[0, 1], [2]] : tensor<8x2048xf32> into tensor<1x8x2048xf32>
%836 = linalg.generic {indexing_maps = [#map13, #map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_499, %expanded_512 : tensor<1x8x2048xf32>, tensor<1x8x2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%collapsed_513 = tensor.collapse_shape %836 [[0, 1], [2]] : tensor<1x8x2048xf32> into tensor<8x2048xf32>
%837 = linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%collapsed_513 : tensor<8x2048xf32>) outs(%108 : tensor<8x2048xf64>) {
^bb0(%in: f32, %out: f64):
%1818 = arith.extf %in : f32 to f64
linalg.yield %1818 : f64
} -> tensor<8x2048xf64>
%838 = linalg.generic {indexing_maps = [#map4, #map7], iterator_types = ["parallel", "reduction"]} ins(%837 : tensor<8x2048xf64>) outs(%111 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%839 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%838 : tensor<8x1xf64>) outs(%110 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%840 = linalg.generic {indexing_maps = [#map4, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%837, %839 : tensor<8x2048xf64>, tensor<8x1xf64>) outs(%108 : tensor<8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.subf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<8x2048xf64>
%841 = linalg.generic {indexing_maps = [#map4, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%840, %840 : tensor<8x2048xf64>, tensor<8x2048xf64>) outs(%108 : tensor<8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.mulf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<8x2048xf64>
%842 = linalg.generic {indexing_maps = [#map4, #map7], iterator_types = ["parallel", "reduction"]} ins(%841 : tensor<8x2048xf64>) outs(%111 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%843 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%842 : tensor<8x1xf64>) outs(%110 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%844 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%843 : tensor<8x1xf64>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f64, %out: f32):
%1818 = arith.truncf %in : f64 to f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%845 = linalg.generic {indexing_maps = [#map4, #map7], iterator_types = ["parallel", "reduction"]} ins(%collapsed_513 : tensor<8x2048xf32>) outs(%120 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%846 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%845 : tensor<8x1xf32>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.divf %in, %cst_347 : f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%847 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%844 : tensor<8x1xf32>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.truncf %cst_344 : f64 to f32
%1819 = arith.addf %in, %1818 : f32
linalg.yield %1819 : f32
} -> tensor<8x1xf32>
%848 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%847 : tensor<8x1xf32>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.rsqrt %in : f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%849 = linalg.generic {indexing_maps = [#map4, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%collapsed_513, %846 : tensor<8x2048xf32>, tensor<8x1xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%850 = linalg.generic {indexing_maps = [#map4, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%849, %848 : tensor<8x2048xf32>, tensor<8x1xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%851 = linalg.generic {indexing_maps = [#map4, #map6, #map4], iterator_types = ["parallel", "parallel"]} ins(%850, %cst_337 : tensor<8x2048xf32>, tensor<2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%852 = linalg.generic {indexing_maps = [#map4, #map6, #map4], iterator_types = ["parallel", "parallel"]} ins(%851, %cst_186 : tensor<8x2048xf32>, tensor<2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%853 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_185 : tensor<8192x2048xf32>) outs(%129 : tensor<2048x8192xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x8192xf32>
%854 = linalg.matmul ins(%852, %853 : tensor<8x2048xf32>, tensor<2048x8192xf32>) outs(%132 : tensor<8x8192xf32>) -> tensor<8x8192xf32>
%855 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_184, %854 : tensor<8192xf32>, tensor<8x8192xf32>) outs(%131 : tensor<8x8192xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x8192xf32>
%856 = linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%855 : tensor<8x8192xf32>) outs(%131 : tensor<8x8192xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.cmpf ugt, %in, %cst_340 : f32
%1819 = arith.select %1818, %in, %cst_340 : f32
linalg.yield %1819 : f32
} -> tensor<8x8192xf32>
%857 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_183 : tensor<2048x8192xf32>) outs(%136 : tensor<8192x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<8192x2048xf32>
%858 = linalg.matmul ins(%856, %857 : tensor<8x8192xf32>, tensor<8192x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%859 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_182, %858 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%860 = linalg.generic {indexing_maps = [#map4, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%collapsed_513, %859 : tensor<8x2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_514 = tensor.expand_shape %860 [[0, 1], [2]] : tensor<8x2048xf32> into tensor<1x8x2048xf32>
%861 = linalg.generic {indexing_maps = [#map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_514 : tensor<1x8x2048xf32>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f32, %out: f64):
%1818 = arith.extf %in : f32 to f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
%862 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%861 : tensor<1x8x2048xf64>) outs(%42 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%863 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%862 : tensor<1x8x1xf64>) outs(%41 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%864 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%861, %863 : tensor<1x8x2048xf64>, tensor<1x8x1xf64>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.subf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
%865 = linalg.generic {indexing_maps = [#map13, #map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%864, %864 : tensor<1x8x2048xf64>, tensor<1x8x2048xf64>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.mulf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
%866 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%865 : tensor<1x8x2048xf64>) outs(%42 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%867 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%866 : tensor<1x8x1xf64>) outs(%41 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%868 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%867 : tensor<1x8x1xf64>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f64, %out: f32):
%1818 = arith.truncf %in : f64 to f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%869 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%expanded_514 : tensor<1x8x2048xf32>) outs(%51 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%870 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%869 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.divf %in, %cst_347 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%871 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%868 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.truncf %cst_344 : f64 to f32
%1819 = arith.addf %in, %1818 : f32
linalg.yield %1819 : f32
} -> tensor<1x8x1xf32>
%872 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%871 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.rsqrt %in : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%873 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_514, %870 : tensor<1x8x2048xf32>, tensor<1x8x1xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%874 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%873, %872 : tensor<1x8x2048xf32>, tensor<1x8x1xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%875 = linalg.generic {indexing_maps = [#map13, #map16, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%874, %cst_337 : tensor<1x8x2048xf32>, tensor<2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%876 = linalg.generic {indexing_maps = [#map13, #map16, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%875, %cst_181 : tensor<1x8x2048xf32>, tensor<2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%877 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_180 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%collapsed_515 = tensor.collapse_shape %876 [[0, 1], [2]] : tensor<1x8x2048xf32> into tensor<8x2048xf32>
%878 = linalg.matmul ins(%collapsed_515, %877 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%879 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_179, %878 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_516 = tensor.expand_shape %879 [[0, 1], [2]] : tensor<8x2048xf32> into tensor<1x8x2048xf32>
%880 = linalg.generic {indexing_maps = [#map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_516 : tensor<1x8x2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.mulf %in, %cst_348 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%881 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_178 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%882 = linalg.matmul ins(%collapsed_515, %881 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%883 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_177, %882 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_517 = tensor.expand_shape %883 [[0, 1], [2, 3]] : tensor<8x2048xf32> into tensor<1x8x32x64xf32>
%884 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_517 : tensor<1x8x32x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%885 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%884 : tensor<1x32x8x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%886 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_176 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%887 = linalg.matmul ins(%collapsed_515, %886 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%888 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_175, %887 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_518 = tensor.expand_shape %888 [[0, 1], [2, 3]] : tensor<8x2048xf32> into tensor<1x8x32x64xf32>
%889 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_518 : tensor<1x8x32x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%890 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%889 : tensor<1x32x8x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%expanded_519 = tensor.expand_shape %880 [[0], [1], [2, 3]] : tensor<1x8x2048xf32> into tensor<1x8x32x64xf32>
%891 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_519 : tensor<1x8x32x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%892 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%891 : tensor<1x32x8x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%collapsed_520 = tensor.collapse_shape %892 [[0, 1], [2], [3]] : tensor<1x32x8x64xf32> into tensor<32x8x64xf32>
%collapsed_521 = tensor.collapse_shape %885 [[0, 1], [2], [3]] : tensor<1x32x8x64xf32> into tensor<32x8x64xf32>
%collapsed_522 = tensor.collapse_shape %890 [[0, 1], [2], [3]] : tensor<1x32x8x64xf32> into tensor<32x8x64xf32>
%893 = linalg.generic {indexing_maps = [#map1, #map20], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_521 : tensor<32x8x64xf32>) outs(%80 : tensor<32x64x8xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<32x64x8xf32>
%894 = linalg.batch_matmul ins(%collapsed_520, %893 : tensor<32x8x64xf32>, tensor<32x64x8xf32>) outs(%83 : tensor<32x8x8xf32>) -> tensor<32x8x8xf32>
%expanded_523 = tensor.expand_shape %894 [[0, 1], [2], [3]] : tensor<32x8x8xf32> into tensor<1x32x8x8xf32>
%895 = linalg.generic {indexing_maps = [#map19, #map8, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_523, %28 : tensor<1x32x8x8xf32>, tensor<1x1x8x8xf32>) outs(%85 : tensor<1x32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x32x8x8xf32>
%896 = linalg.generic {indexing_maps = [#map19, #map11, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%895, %cst_329 : tensor<1x32x8x8xf32>, tensor<f32>) outs(%85 : tensor<1x32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.cmpf ugt, %in, %in_712 : f32
%1819 = arith.select %1818, %in, %in_712 : f32
linalg.yield %1819 : f32
} -> tensor<1x32x8x8xf32>
%collapsed_524 = tensor.collapse_shape %896 [[0, 1], [2], [3]] : tensor<1x32x8x8xf32> into tensor<32x8x8xf32>
%897:2 = linalg.generic {indexing_maps = [#map1, #map14, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%collapsed_524 : tensor<32x8x8xf32>) outs(%91, %89 : tensor<32x8x1xf32>, tensor<32x8x1xi64>) {
^bb0(%in: f32, %out: f32, %out_712: i64):
%1818 = linalg.index 2 : index
%1819 = arith.index_cast %1818 : index to i64
%1820 = arith.maximumf %in, %out : f32
%1821 = arith.cmpf ogt, %in, %out : f32
%1822 = arith.select %1821, %1819, %out_712 : i64
linalg.yield %1820, %1822 : f32, i64
} -> (tensor<32x8x1xf32>, tensor<32x8x1xi64>)
%898 = linalg.generic {indexing_maps = [#map1, #map14, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_524, %897#0 : tensor<32x8x8xf32>, tensor<32x8x1xf32>) outs(%82 : tensor<32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<32x8x8xf32>
%899 = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%898 : tensor<32x8x8xf32>) outs(%82 : tensor<32x8x8xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.exp %in : f32
linalg.yield %1818 : f32
} -> tensor<32x8x8xf32>
%900 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%899 : tensor<32x8x8xf32>) outs(%95 : tensor<32x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<32x8x1xf32>
%901 = linalg.generic {indexing_maps = [#map1, #map14, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%899, %900 : tensor<32x8x8xf32>, tensor<32x8x1xf32>) outs(%82 : tensor<32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.divf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<32x8x8xf32>
%902 = linalg.batch_matmul ins(%901, %collapsed_522 : tensor<32x8x8xf32>, tensor<32x8x64xf32>) outs(%99 : tensor<32x8x64xf32>) -> tensor<32x8x64xf32>
%expanded_525 = tensor.expand_shape %902 [[0, 1], [2], [3]] : tensor<32x8x64xf32> into tensor<1x32x8x64xf32>
%903 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_525 : tensor<1x32x8x64xf32>) outs(%101 : tensor<1x8x32x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x32x64xf32>
%904 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%903 : tensor<1x8x32x64xf32>) outs(%101 : tensor<1x8x32x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x32x64xf32>
%905 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_174 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%collapsed_526 = tensor.collapse_shape %904 [[0, 1], [2, 3]] : tensor<1x8x32x64xf32> into tensor<8x2048xf32>
%906 = linalg.matmul ins(%collapsed_526, %905 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%907 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_173, %906 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_527 = tensor.expand_shape %907 [[0, 1], [2]] : tensor<8x2048xf32> into tensor<1x8x2048xf32>
%908 = linalg.generic {indexing_maps = [#map13, #map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_514, %expanded_527 : tensor<1x8x2048xf32>, tensor<1x8x2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%collapsed_528 = tensor.collapse_shape %908 [[0, 1], [2]] : tensor<1x8x2048xf32> into tensor<8x2048xf32>
%909 = linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%collapsed_528 : tensor<8x2048xf32>) outs(%108 : tensor<8x2048xf64>) {
^bb0(%in: f32, %out: f64):
%1818 = arith.extf %in : f32 to f64
linalg.yield %1818 : f64
} -> tensor<8x2048xf64>
%910 = linalg.generic {indexing_maps = [#map4, #map7], iterator_types = ["parallel", "reduction"]} ins(%909 : tensor<8x2048xf64>) outs(%111 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%911 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%910 : tensor<8x1xf64>) outs(%110 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%912 = linalg.generic {indexing_maps = [#map4, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%909, %911 : tensor<8x2048xf64>, tensor<8x1xf64>) outs(%108 : tensor<8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.subf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<8x2048xf64>
%913 = linalg.generic {indexing_maps = [#map4, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%912, %912 : tensor<8x2048xf64>, tensor<8x2048xf64>) outs(%108 : tensor<8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.mulf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<8x2048xf64>
%914 = linalg.generic {indexing_maps = [#map4, #map7], iterator_types = ["parallel", "reduction"]} ins(%913 : tensor<8x2048xf64>) outs(%111 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%915 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%914 : tensor<8x1xf64>) outs(%110 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%916 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%915 : tensor<8x1xf64>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f64, %out: f32):
%1818 = arith.truncf %in : f64 to f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%917 = linalg.generic {indexing_maps = [#map4, #map7], iterator_types = ["parallel", "reduction"]} ins(%collapsed_528 : tensor<8x2048xf32>) outs(%120 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%918 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%917 : tensor<8x1xf32>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.divf %in, %cst_347 : f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%919 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%916 : tensor<8x1xf32>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.truncf %cst_344 : f64 to f32
%1819 = arith.addf %in, %1818 : f32
linalg.yield %1819 : f32
} -> tensor<8x1xf32>
%920 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%919 : tensor<8x1xf32>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.rsqrt %in : f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%921 = linalg.generic {indexing_maps = [#map4, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%collapsed_528, %918 : tensor<8x2048xf32>, tensor<8x1xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%922 = linalg.generic {indexing_maps = [#map4, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%921, %920 : tensor<8x2048xf32>, tensor<8x1xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%923 = linalg.generic {indexing_maps = [#map4, #map6, #map4], iterator_types = ["parallel", "parallel"]} ins(%922, %cst_337 : tensor<8x2048xf32>, tensor<2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%924 = linalg.generic {indexing_maps = [#map4, #map6, #map4], iterator_types = ["parallel", "parallel"]} ins(%923, %cst_172 : tensor<8x2048xf32>, tensor<2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%925 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_171 : tensor<8192x2048xf32>) outs(%129 : tensor<2048x8192xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x8192xf32>
%926 = linalg.matmul ins(%924, %925 : tensor<8x2048xf32>, tensor<2048x8192xf32>) outs(%132 : tensor<8x8192xf32>) -> tensor<8x8192xf32>
%927 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_170, %926 : tensor<8192xf32>, tensor<8x8192xf32>) outs(%131 : tensor<8x8192xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x8192xf32>
%928 = linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%927 : tensor<8x8192xf32>) outs(%131 : tensor<8x8192xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.cmpf ugt, %in, %cst_340 : f32
%1819 = arith.select %1818, %in, %cst_340 : f32
linalg.yield %1819 : f32
} -> tensor<8x8192xf32>
%929 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_169 : tensor<2048x8192xf32>) outs(%136 : tensor<8192x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<8192x2048xf32>
%930 = linalg.matmul ins(%928, %929 : tensor<8x8192xf32>, tensor<8192x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%931 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_168, %930 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%932 = linalg.generic {indexing_maps = [#map4, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%collapsed_528, %931 : tensor<8x2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_529 = tensor.expand_shape %932 [[0, 1], [2]] : tensor<8x2048xf32> into tensor<1x8x2048xf32>
%933 = linalg.generic {indexing_maps = [#map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_529 : tensor<1x8x2048xf32>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f32, %out: f64):
%1818 = arith.extf %in : f32 to f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
%934 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%933 : tensor<1x8x2048xf64>) outs(%42 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%935 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%934 : tensor<1x8x1xf64>) outs(%41 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%936 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%933, %935 : tensor<1x8x2048xf64>, tensor<1x8x1xf64>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.subf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
%937 = linalg.generic {indexing_maps = [#map13, #map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%936, %936 : tensor<1x8x2048xf64>, tensor<1x8x2048xf64>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.mulf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
%938 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%937 : tensor<1x8x2048xf64>) outs(%42 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%939 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%938 : tensor<1x8x1xf64>) outs(%41 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%940 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%939 : tensor<1x8x1xf64>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f64, %out: f32):
%1818 = arith.truncf %in : f64 to f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%941 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%expanded_529 : tensor<1x8x2048xf32>) outs(%51 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%942 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%941 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.divf %in, %cst_347 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%943 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%940 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.truncf %cst_344 : f64 to f32
%1819 = arith.addf %in, %1818 : f32
linalg.yield %1819 : f32
} -> tensor<1x8x1xf32>
%944 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%943 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.rsqrt %in : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%945 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_529, %942 : tensor<1x8x2048xf32>, tensor<1x8x1xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%946 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%945, %944 : tensor<1x8x2048xf32>, tensor<1x8x1xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%947 = linalg.generic {indexing_maps = [#map13, #map16, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%946, %cst_337 : tensor<1x8x2048xf32>, tensor<2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%948 = linalg.generic {indexing_maps = [#map13, #map16, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%947, %cst_167 : tensor<1x8x2048xf32>, tensor<2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%949 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_166 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%collapsed_530 = tensor.collapse_shape %948 [[0, 1], [2]] : tensor<1x8x2048xf32> into tensor<8x2048xf32>
%950 = linalg.matmul ins(%collapsed_530, %949 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%951 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_165, %950 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_531 = tensor.expand_shape %951 [[0, 1], [2]] : tensor<8x2048xf32> into tensor<1x8x2048xf32>
%952 = linalg.generic {indexing_maps = [#map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_531 : tensor<1x8x2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.mulf %in, %cst_348 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%953 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_164 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%954 = linalg.matmul ins(%collapsed_530, %953 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%955 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_163, %954 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_532 = tensor.expand_shape %955 [[0, 1], [2, 3]] : tensor<8x2048xf32> into tensor<1x8x32x64xf32>
%956 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_532 : tensor<1x8x32x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%957 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%956 : tensor<1x32x8x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%958 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_162 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%959 = linalg.matmul ins(%collapsed_530, %958 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%960 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_161, %959 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_533 = tensor.expand_shape %960 [[0, 1], [2, 3]] : tensor<8x2048xf32> into tensor<1x8x32x64xf32>
%961 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_533 : tensor<1x8x32x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%962 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%961 : tensor<1x32x8x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%expanded_534 = tensor.expand_shape %952 [[0], [1], [2, 3]] : tensor<1x8x2048xf32> into tensor<1x8x32x64xf32>
%963 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_534 : tensor<1x8x32x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%964 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%963 : tensor<1x32x8x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%collapsed_535 = tensor.collapse_shape %964 [[0, 1], [2], [3]] : tensor<1x32x8x64xf32> into tensor<32x8x64xf32>
%collapsed_536 = tensor.collapse_shape %957 [[0, 1], [2], [3]] : tensor<1x32x8x64xf32> into tensor<32x8x64xf32>
%collapsed_537 = tensor.collapse_shape %962 [[0, 1], [2], [3]] : tensor<1x32x8x64xf32> into tensor<32x8x64xf32>
%965 = linalg.generic {indexing_maps = [#map1, #map20], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_536 : tensor<32x8x64xf32>) outs(%80 : tensor<32x64x8xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<32x64x8xf32>
%966 = linalg.batch_matmul ins(%collapsed_535, %965 : tensor<32x8x64xf32>, tensor<32x64x8xf32>) outs(%83 : tensor<32x8x8xf32>) -> tensor<32x8x8xf32>
%expanded_538 = tensor.expand_shape %966 [[0, 1], [2], [3]] : tensor<32x8x8xf32> into tensor<1x32x8x8xf32>
%967 = linalg.generic {indexing_maps = [#map19, #map8, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_538, %28 : tensor<1x32x8x8xf32>, tensor<1x1x8x8xf32>) outs(%85 : tensor<1x32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x32x8x8xf32>
%968 = linalg.generic {indexing_maps = [#map19, #map11, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%967, %cst_329 : tensor<1x32x8x8xf32>, tensor<f32>) outs(%85 : tensor<1x32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.cmpf ugt, %in, %in_712 : f32
%1819 = arith.select %1818, %in, %in_712 : f32
linalg.yield %1819 : f32
} -> tensor<1x32x8x8xf32>
%collapsed_539 = tensor.collapse_shape %968 [[0, 1], [2], [3]] : tensor<1x32x8x8xf32> into tensor<32x8x8xf32>
%969:2 = linalg.generic {indexing_maps = [#map1, #map14, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%collapsed_539 : tensor<32x8x8xf32>) outs(%91, %89 : tensor<32x8x1xf32>, tensor<32x8x1xi64>) {
^bb0(%in: f32, %out: f32, %out_712: i64):
%1818 = linalg.index 2 : index
%1819 = arith.index_cast %1818 : index to i64
%1820 = arith.maximumf %in, %out : f32
%1821 = arith.cmpf ogt, %in, %out : f32
%1822 = arith.select %1821, %1819, %out_712 : i64
linalg.yield %1820, %1822 : f32, i64
} -> (tensor<32x8x1xf32>, tensor<32x8x1xi64>)
%970 = linalg.generic {indexing_maps = [#map1, #map14, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_539, %969#0 : tensor<32x8x8xf32>, tensor<32x8x1xf32>) outs(%82 : tensor<32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<32x8x8xf32>
%971 = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%970 : tensor<32x8x8xf32>) outs(%82 : tensor<32x8x8xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.exp %in : f32
linalg.yield %1818 : f32
} -> tensor<32x8x8xf32>
%972 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%971 : tensor<32x8x8xf32>) outs(%95 : tensor<32x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<32x8x1xf32>
%973 = linalg.generic {indexing_maps = [#map1, #map14, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%971, %972 : tensor<32x8x8xf32>, tensor<32x8x1xf32>) outs(%82 : tensor<32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.divf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<32x8x8xf32>
%974 = linalg.batch_matmul ins(%973, %collapsed_537 : tensor<32x8x8xf32>, tensor<32x8x64xf32>) outs(%99 : tensor<32x8x64xf32>) -> tensor<32x8x64xf32>
%expanded_540 = tensor.expand_shape %974 [[0, 1], [2], [3]] : tensor<32x8x64xf32> into tensor<1x32x8x64xf32>
%975 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_540 : tensor<1x32x8x64xf32>) outs(%101 : tensor<1x8x32x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x32x64xf32>
%976 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%975 : tensor<1x8x32x64xf32>) outs(%101 : tensor<1x8x32x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x32x64xf32>
%977 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_160 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%collapsed_541 = tensor.collapse_shape %976 [[0, 1], [2, 3]] : tensor<1x8x32x64xf32> into tensor<8x2048xf32>
%978 = linalg.matmul ins(%collapsed_541, %977 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%979 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_159, %978 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_542 = tensor.expand_shape %979 [[0, 1], [2]] : tensor<8x2048xf32> into tensor<1x8x2048xf32>
%980 = linalg.generic {indexing_maps = [#map13, #map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_529, %expanded_542 : tensor<1x8x2048xf32>, tensor<1x8x2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%collapsed_543 = tensor.collapse_shape %980 [[0, 1], [2]] : tensor<1x8x2048xf32> into tensor<8x2048xf32>
%981 = linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%collapsed_543 : tensor<8x2048xf32>) outs(%108 : tensor<8x2048xf64>) {
^bb0(%in: f32, %out: f64):
%1818 = arith.extf %in : f32 to f64
linalg.yield %1818 : f64
} -> tensor<8x2048xf64>
%982 = linalg.generic {indexing_maps = [#map4, #map7], iterator_types = ["parallel", "reduction"]} ins(%981 : tensor<8x2048xf64>) outs(%111 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%983 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%982 : tensor<8x1xf64>) outs(%110 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%984 = linalg.generic {indexing_maps = [#map4, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%981, %983 : tensor<8x2048xf64>, tensor<8x1xf64>) outs(%108 : tensor<8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.subf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<8x2048xf64>
%985 = linalg.generic {indexing_maps = [#map4, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%984, %984 : tensor<8x2048xf64>, tensor<8x2048xf64>) outs(%108 : tensor<8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.mulf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<8x2048xf64>
%986 = linalg.generic {indexing_maps = [#map4, #map7], iterator_types = ["parallel", "reduction"]} ins(%985 : tensor<8x2048xf64>) outs(%111 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%987 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%986 : tensor<8x1xf64>) outs(%110 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%988 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%987 : tensor<8x1xf64>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f64, %out: f32):
%1818 = arith.truncf %in : f64 to f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%989 = linalg.generic {indexing_maps = [#map4, #map7], iterator_types = ["parallel", "reduction"]} ins(%collapsed_543 : tensor<8x2048xf32>) outs(%120 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%990 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%989 : tensor<8x1xf32>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.divf %in, %cst_347 : f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%991 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%988 : tensor<8x1xf32>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.truncf %cst_344 : f64 to f32
%1819 = arith.addf %in, %1818 : f32
linalg.yield %1819 : f32
} -> tensor<8x1xf32>
%992 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%991 : tensor<8x1xf32>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.rsqrt %in : f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%993 = linalg.generic {indexing_maps = [#map4, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%collapsed_543, %990 : tensor<8x2048xf32>, tensor<8x1xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%994 = linalg.generic {indexing_maps = [#map4, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%993, %992 : tensor<8x2048xf32>, tensor<8x1xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%995 = linalg.generic {indexing_maps = [#map4, #map6, #map4], iterator_types = ["parallel", "parallel"]} ins(%994, %cst_337 : tensor<8x2048xf32>, tensor<2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%996 = linalg.generic {indexing_maps = [#map4, #map6, #map4], iterator_types = ["parallel", "parallel"]} ins(%995, %cst_158 : tensor<8x2048xf32>, tensor<2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%997 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_157 : tensor<8192x2048xf32>) outs(%129 : tensor<2048x8192xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x8192xf32>
%998 = linalg.matmul ins(%996, %997 : tensor<8x2048xf32>, tensor<2048x8192xf32>) outs(%132 : tensor<8x8192xf32>) -> tensor<8x8192xf32>
%999 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_156, %998 : tensor<8192xf32>, tensor<8x8192xf32>) outs(%131 : tensor<8x8192xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x8192xf32>
%1000 = linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%999 : tensor<8x8192xf32>) outs(%131 : tensor<8x8192xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.cmpf ugt, %in, %cst_340 : f32
%1819 = arith.select %1818, %in, %cst_340 : f32
linalg.yield %1819 : f32
} -> tensor<8x8192xf32>
%1001 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_155 : tensor<2048x8192xf32>) outs(%136 : tensor<8192x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<8192x2048xf32>
%1002 = linalg.matmul ins(%1000, %1001 : tensor<8x8192xf32>, tensor<8192x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%1003 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_154, %1002 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%1004 = linalg.generic {indexing_maps = [#map4, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%collapsed_543, %1003 : tensor<8x2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_544 = tensor.expand_shape %1004 [[0, 1], [2]] : tensor<8x2048xf32> into tensor<1x8x2048xf32>
%1005 = linalg.generic {indexing_maps = [#map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_544 : tensor<1x8x2048xf32>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f32, %out: f64):
%1818 = arith.extf %in : f32 to f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
%1006 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%1005 : tensor<1x8x2048xf64>) outs(%42 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%1007 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1006 : tensor<1x8x1xf64>) outs(%41 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%1008 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1005, %1007 : tensor<1x8x2048xf64>, tensor<1x8x1xf64>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.subf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
%1009 = linalg.generic {indexing_maps = [#map13, #map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1008, %1008 : tensor<1x8x2048xf64>, tensor<1x8x2048xf64>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.mulf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
%1010 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%1009 : tensor<1x8x2048xf64>) outs(%42 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%1011 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1010 : tensor<1x8x1xf64>) outs(%41 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%1012 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1011 : tensor<1x8x1xf64>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f64, %out: f32):
%1818 = arith.truncf %in : f64 to f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%1013 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%expanded_544 : tensor<1x8x2048xf32>) outs(%51 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%1014 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1013 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.divf %in, %cst_347 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%1015 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1012 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.truncf %cst_344 : f64 to f32
%1819 = arith.addf %in, %1818 : f32
linalg.yield %1819 : f32
} -> tensor<1x8x1xf32>
%1016 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1015 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.rsqrt %in : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%1017 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_544, %1014 : tensor<1x8x2048xf32>, tensor<1x8x1xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%1018 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1017, %1016 : tensor<1x8x2048xf32>, tensor<1x8x1xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%1019 = linalg.generic {indexing_maps = [#map13, #map16, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1018, %cst_337 : tensor<1x8x2048xf32>, tensor<2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%1020 = linalg.generic {indexing_maps = [#map13, #map16, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1019, %cst_153 : tensor<1x8x2048xf32>, tensor<2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%1021 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_152 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%collapsed_545 = tensor.collapse_shape %1020 [[0, 1], [2]] : tensor<1x8x2048xf32> into tensor<8x2048xf32>
%1022 = linalg.matmul ins(%collapsed_545, %1021 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%1023 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_151, %1022 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_546 = tensor.expand_shape %1023 [[0, 1], [2]] : tensor<8x2048xf32> into tensor<1x8x2048xf32>
%1024 = linalg.generic {indexing_maps = [#map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_546 : tensor<1x8x2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.mulf %in, %cst_348 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%1025 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_150 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%1026 = linalg.matmul ins(%collapsed_545, %1025 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%1027 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_149, %1026 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_547 = tensor.expand_shape %1027 [[0, 1], [2, 3]] : tensor<8x2048xf32> into tensor<1x8x32x64xf32>
%1028 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_547 : tensor<1x8x32x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%1029 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1028 : tensor<1x32x8x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%1030 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_148 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%1031 = linalg.matmul ins(%collapsed_545, %1030 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%1032 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_147, %1031 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_548 = tensor.expand_shape %1032 [[0, 1], [2, 3]] : tensor<8x2048xf32> into tensor<1x8x32x64xf32>
%1033 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_548 : tensor<1x8x32x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%1034 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1033 : tensor<1x32x8x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%expanded_549 = tensor.expand_shape %1024 [[0], [1], [2, 3]] : tensor<1x8x2048xf32> into tensor<1x8x32x64xf32>
%1035 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_549 : tensor<1x8x32x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%1036 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1035 : tensor<1x32x8x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%collapsed_550 = tensor.collapse_shape %1036 [[0, 1], [2], [3]] : tensor<1x32x8x64xf32> into tensor<32x8x64xf32>
%collapsed_551 = tensor.collapse_shape %1029 [[0, 1], [2], [3]] : tensor<1x32x8x64xf32> into tensor<32x8x64xf32>
%collapsed_552 = tensor.collapse_shape %1034 [[0, 1], [2], [3]] : tensor<1x32x8x64xf32> into tensor<32x8x64xf32>
%1037 = linalg.generic {indexing_maps = [#map1, #map20], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_551 : tensor<32x8x64xf32>) outs(%80 : tensor<32x64x8xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<32x64x8xf32>
%1038 = linalg.batch_matmul ins(%collapsed_550, %1037 : tensor<32x8x64xf32>, tensor<32x64x8xf32>) outs(%83 : tensor<32x8x8xf32>) -> tensor<32x8x8xf32>
%expanded_553 = tensor.expand_shape %1038 [[0, 1], [2], [3]] : tensor<32x8x8xf32> into tensor<1x32x8x8xf32>
%1039 = linalg.generic {indexing_maps = [#map19, #map8, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_553, %28 : tensor<1x32x8x8xf32>, tensor<1x1x8x8xf32>) outs(%85 : tensor<1x32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x32x8x8xf32>
%1040 = linalg.generic {indexing_maps = [#map19, #map11, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1039, %cst_329 : tensor<1x32x8x8xf32>, tensor<f32>) outs(%85 : tensor<1x32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.cmpf ugt, %in, %in_712 : f32
%1819 = arith.select %1818, %in, %in_712 : f32
linalg.yield %1819 : f32
} -> tensor<1x32x8x8xf32>
%collapsed_554 = tensor.collapse_shape %1040 [[0, 1], [2], [3]] : tensor<1x32x8x8xf32> into tensor<32x8x8xf32>
%1041:2 = linalg.generic {indexing_maps = [#map1, #map14, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%collapsed_554 : tensor<32x8x8xf32>) outs(%91, %89 : tensor<32x8x1xf32>, tensor<32x8x1xi64>) {
^bb0(%in: f32, %out: f32, %out_712: i64):
%1818 = linalg.index 2 : index
%1819 = arith.index_cast %1818 : index to i64
%1820 = arith.maximumf %in, %out : f32
%1821 = arith.cmpf ogt, %in, %out : f32
%1822 = arith.select %1821, %1819, %out_712 : i64
linalg.yield %1820, %1822 : f32, i64
} -> (tensor<32x8x1xf32>, tensor<32x8x1xi64>)
%1042 = linalg.generic {indexing_maps = [#map1, #map14, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_554, %1041#0 : tensor<32x8x8xf32>, tensor<32x8x1xf32>) outs(%82 : tensor<32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<32x8x8xf32>
%1043 = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1042 : tensor<32x8x8xf32>) outs(%82 : tensor<32x8x8xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.exp %in : f32
linalg.yield %1818 : f32
} -> tensor<32x8x8xf32>
%1044 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%1043 : tensor<32x8x8xf32>) outs(%95 : tensor<32x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<32x8x1xf32>
%1045 = linalg.generic {indexing_maps = [#map1, #map14, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1043, %1044 : tensor<32x8x8xf32>, tensor<32x8x1xf32>) outs(%82 : tensor<32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.divf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<32x8x8xf32>
%1046 = linalg.batch_matmul ins(%1045, %collapsed_552 : tensor<32x8x8xf32>, tensor<32x8x64xf32>) outs(%99 : tensor<32x8x64xf32>) -> tensor<32x8x64xf32>
%expanded_555 = tensor.expand_shape %1046 [[0, 1], [2], [3]] : tensor<32x8x64xf32> into tensor<1x32x8x64xf32>
%1047 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_555 : tensor<1x32x8x64xf32>) outs(%101 : tensor<1x8x32x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x32x64xf32>
%1048 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1047 : tensor<1x8x32x64xf32>) outs(%101 : tensor<1x8x32x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x32x64xf32>
%1049 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_146 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%collapsed_556 = tensor.collapse_shape %1048 [[0, 1], [2, 3]] : tensor<1x8x32x64xf32> into tensor<8x2048xf32>
%1050 = linalg.matmul ins(%collapsed_556, %1049 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%1051 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_145, %1050 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_557 = tensor.expand_shape %1051 [[0, 1], [2]] : tensor<8x2048xf32> into tensor<1x8x2048xf32>
%1052 = linalg.generic {indexing_maps = [#map13, #map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_544, %expanded_557 : tensor<1x8x2048xf32>, tensor<1x8x2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%collapsed_558 = tensor.collapse_shape %1052 [[0, 1], [2]] : tensor<1x8x2048xf32> into tensor<8x2048xf32>
%1053 = linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%collapsed_558 : tensor<8x2048xf32>) outs(%108 : tensor<8x2048xf64>) {
^bb0(%in: f32, %out: f64):
%1818 = arith.extf %in : f32 to f64
linalg.yield %1818 : f64
} -> tensor<8x2048xf64>
%1054 = linalg.generic {indexing_maps = [#map4, #map7], iterator_types = ["parallel", "reduction"]} ins(%1053 : tensor<8x2048xf64>) outs(%111 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%1055 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1054 : tensor<8x1xf64>) outs(%110 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%1056 = linalg.generic {indexing_maps = [#map4, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1053, %1055 : tensor<8x2048xf64>, tensor<8x1xf64>) outs(%108 : tensor<8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.subf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<8x2048xf64>
%1057 = linalg.generic {indexing_maps = [#map4, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%1056, %1056 : tensor<8x2048xf64>, tensor<8x2048xf64>) outs(%108 : tensor<8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.mulf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<8x2048xf64>
%1058 = linalg.generic {indexing_maps = [#map4, #map7], iterator_types = ["parallel", "reduction"]} ins(%1057 : tensor<8x2048xf64>) outs(%111 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%1059 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1058 : tensor<8x1xf64>) outs(%110 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%1060 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1059 : tensor<8x1xf64>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f64, %out: f32):
%1818 = arith.truncf %in : f64 to f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%1061 = linalg.generic {indexing_maps = [#map4, #map7], iterator_types = ["parallel", "reduction"]} ins(%collapsed_558 : tensor<8x2048xf32>) outs(%120 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%1062 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1061 : tensor<8x1xf32>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.divf %in, %cst_347 : f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%1063 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1060 : tensor<8x1xf32>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.truncf %cst_344 : f64 to f32
%1819 = arith.addf %in, %1818 : f32
linalg.yield %1819 : f32
} -> tensor<8x1xf32>
%1064 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1063 : tensor<8x1xf32>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.rsqrt %in : f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%1065 = linalg.generic {indexing_maps = [#map4, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%collapsed_558, %1062 : tensor<8x2048xf32>, tensor<8x1xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%1066 = linalg.generic {indexing_maps = [#map4, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1065, %1064 : tensor<8x2048xf32>, tensor<8x1xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%1067 = linalg.generic {indexing_maps = [#map4, #map6, #map4], iterator_types = ["parallel", "parallel"]} ins(%1066, %cst_337 : tensor<8x2048xf32>, tensor<2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%1068 = linalg.generic {indexing_maps = [#map4, #map6, #map4], iterator_types = ["parallel", "parallel"]} ins(%1067, %cst_144 : tensor<8x2048xf32>, tensor<2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%1069 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_143 : tensor<8192x2048xf32>) outs(%129 : tensor<2048x8192xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x8192xf32>
%1070 = linalg.matmul ins(%1068, %1069 : tensor<8x2048xf32>, tensor<2048x8192xf32>) outs(%132 : tensor<8x8192xf32>) -> tensor<8x8192xf32>
%1071 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_142, %1070 : tensor<8192xf32>, tensor<8x8192xf32>) outs(%131 : tensor<8x8192xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x8192xf32>
%1072 = linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%1071 : tensor<8x8192xf32>) outs(%131 : tensor<8x8192xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.cmpf ugt, %in, %cst_340 : f32
%1819 = arith.select %1818, %in, %cst_340 : f32
linalg.yield %1819 : f32
} -> tensor<8x8192xf32>
%1073 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_141 : tensor<2048x8192xf32>) outs(%136 : tensor<8192x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<8192x2048xf32>
%1074 = linalg.matmul ins(%1072, %1073 : tensor<8x8192xf32>, tensor<8192x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%1075 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_140, %1074 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%1076 = linalg.generic {indexing_maps = [#map4, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%collapsed_558, %1075 : tensor<8x2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_559 = tensor.expand_shape %1076 [[0, 1], [2]] : tensor<8x2048xf32> into tensor<1x8x2048xf32>
%1077 = linalg.generic {indexing_maps = [#map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_559 : tensor<1x8x2048xf32>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f32, %out: f64):
%1818 = arith.extf %in : f32 to f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
%1078 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%1077 : tensor<1x8x2048xf64>) outs(%42 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%1079 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1078 : tensor<1x8x1xf64>) outs(%41 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%1080 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1077, %1079 : tensor<1x8x2048xf64>, tensor<1x8x1xf64>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.subf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
// Tail of a normalization over the last (2048-wide) dimension of a 1x8x2048 activation,
// producing %1092 = (x - mean) * rsqrt(var + eps) * %cst_337 + %cst_139.
// %1080 is defined above this chunk; it is presumably the f64 centered input (x - mean),
// as in the analogous extf/subf sequence that produces %1152 further down — TODO confirm.
// The variance path runs in f64; the mean and normalization path runs in f32 on %expanded_559.
// Element-wise square of %1080 (both inputs are the same value).
%1081 = linalg.generic {indexing_maps = [#map13, #map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1080, %1080 : tensor<1x8x2048xf64>, tensor<1x8x2048xf64>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.mulf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
// Sum of squares over the last dimension (d2 is the reduction); the accumulator is taken from %42 (defined elsewhere).
%1082 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%1081 : tensor<1x8x2048xf64>) outs(%42 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
// Variance: divide by %cst_346 (value not visible; presumably the row length, i.e. a biased variance — confirm).
%1083 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1082 : tensor<1x8x1xf64>) outs(%41 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
// Narrow the f64 variance to f32.
%1084 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1083 : tensor<1x8x1xf64>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f64, %out: f32):
%1818 = arith.truncf %in : f64 to f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
// Row sum of the f32 input %expanded_559 over the last dimension (accumulator from %51, defined elsewhere).
%1085 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%expanded_559 : tensor<1x8x2048xf32>) outs(%51 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
// Row mean in f32: divide by %cst_347 (value not visible; presumably the row length — confirm).
%1086 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1085 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.divf %in, %cst_347 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
// variance + eps; eps is the f64 constant %cst_344 (value not visible), truncated to f32 inside the body.
%1087 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1084 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.truncf %cst_344 : f64 to f32
%1819 = arith.addf %in, %1818 : f32
linalg.yield %1819 : f32
} -> tensor<1x8x1xf32>
// Reciprocal standard deviation: rsqrt(variance + eps).
%1088 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1087 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.rsqrt %in : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
// Center the f32 input: x - mean (the 1x8x1 mean is broadcast along the last dim by #map15).
%1089 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_559, %1086 : tensor<1x8x2048xf32>, tensor<1x8x1xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
// Normalize: multiply by the broadcast rsqrt.
%1090 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1089, %1088 : tensor<1x8x2048xf32>, tensor<1x8x1xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
// Per-feature scale by %cst_337 (2048 values, broadcast over dims 0-1 via #map16).
// NOTE(review): every layer norm visible in this chunk uses %cst_337 as its scale; presumably a
// deduplicated splat constant (e.g. all ones) — confirm.
%1091 = linalg.generic {indexing_maps = [#map13, #map16, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1090, %cst_337 : tensor<1x8x2048xf32>, tensor<2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
// Per-feature shift by %cst_139.
%1092 = linalg.generic {indexing_maps = [#map13, #map16, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1091, %cst_139 : tensor<1x8x2048xf32>, tensor<2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
// First of three 2048->2048 affine projections of the normalized activations %1092.
// The result is scaled and later becomes the left-hand operand of the score batch_matmul (%1110).
// Materialize the transpose of weight %cst_138: #map17 writes out[j, i] = in[i, j].
%1093 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_138 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
// Drop the leading unit dim (1x8x2048 -> 8x2048). %collapsed_560 is shared by all three projections.
%collapsed_560 = tensor.collapse_shape %1092 [[0, 1], [2]] : tensor<1x8x2048xf32> into tensor<8x2048xf32>
// 8x2048 x 2048x2048 matmul; linalg.matmul accumulates into %63 (defined elsewhere, presumably zero-filled — TODO confirm).
%1094 = linalg.matmul ins(%collapsed_560, %1093 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
// Add bias %cst_137, broadcast across rows via #map6.
%1095 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_137, %1094 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
// Restore the leading unit dim: 8x2048 -> 1x8x2048.
%expanded_561 = tensor.expand_shape %1095 [[0, 1], [2]] : tensor<8x2048xf32> into tensor<1x8x2048xf32>
// Multiply every element by the scalar %cst_348 (value not visible).
%1096 = linalg.generic {indexing_maps = [#map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_561 : tensor<1x8x2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.mulf %in, %cst_348 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
// Second projection of %collapsed_560 (weight %cst_136, bias %cst_135).
// After reshaping, this becomes the transposed right-hand operand of the score batch_matmul (%1110).
// Materialize the transpose of weight %cst_136.
%1097 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_136 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
// 8x2048 x 2048x2048 matmul accumulating into %63.
%1098 = linalg.matmul ins(%collapsed_560, %1097 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
// Add bias %cst_135 across rows.
%1099 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_135, %1098 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
// Split the 2048 features into 32 groups of 64 (presumably attention heads): 8x2048 -> 1x8x32x64.
%expanded_562 = tensor.expand_shape %1099 [[0, 1], [2, 3]] : tensor<8x2048xf32> into tensor<1x8x32x64xf32>
// Swap dims 1 and 2 (output indexed by #map18): 1x8x32x64 -> 1x32x8x64.
%1100 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_562 : tensor<1x8x32x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
// Element-wise copy. #map19 pins dim 0 to index 0, and dim 0 has size 1, so %1101 is value-identical to %1100.
%1101 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1100 : tensor<1x32x8x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
// Third projection of %collapsed_560 (weight %cst_134, bias %cst_133).
// After reshaping, this becomes the right-hand operand of the weighted-sum batch_matmul (%1118).
// Materialize the transpose of weight %cst_134.
%1102 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_134 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
// 8x2048 x 2048x2048 matmul accumulating into %63.
%1103 = linalg.matmul ins(%collapsed_560, %1102 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
// Add bias %cst_133 across rows.
%1104 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_133, %1103 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
// Split 2048 into 32x64: 8x2048 -> 1x8x32x64.
%expanded_563 = tensor.expand_shape %1104 [[0, 1], [2, 3]] : tensor<8x2048xf32> into tensor<1x8x32x64xf32>
// Swap dims 1 and 2: 1x8x32x64 -> 1x32x8x64.
%1105 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_563 : tensor<1x8x32x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
// Value-identical copy (#map19 pins the size-1 dim 0).
%1106 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1105 : tensor<1x32x8x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
// Split the scaled first projection %1096 into 32x64: 1x8x2048 -> 1x8x32x64.
%expanded_564 = tensor.expand_shape %1096 [[0], [1], [2, 3]] : tensor<1x8x2048xf32> into tensor<1x8x32x64xf32>
// Swap dims 1 and 2: 1x8x32x64 -> 1x32x8x64.
%1107 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_564 : tensor<1x8x32x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
// Value-identical copy (#map19 pins the size-1 dim 0).
%1108 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1107 : tensor<1x32x8x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
// Fold the unit dim 0 into dim 1 for all three projections: 1x32x8x64 -> 32x8x64 (batch of 32).
%collapsed_565 = tensor.collapse_shape %1108 [[0, 1], [2], [3]] : tensor<1x32x8x64xf32> into tensor<32x8x64xf32>
%collapsed_566 = tensor.collapse_shape %1101 [[0, 1], [2], [3]] : tensor<1x32x8x64xf32> into tensor<32x8x64xf32>
%collapsed_567 = tensor.collapse_shape %1106 [[0, 1], [2], [3]] : tensor<1x32x8x64xf32> into tensor<32x8x64xf32>
// Transpose the last two dims of the second projection (#map20): 32x8x64 -> 32x64x8.
%1109 = linalg.generic {indexing_maps = [#map1, #map20], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_566 : tensor<32x8x64xf32>) outs(%80 : tensor<32x64x8xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<32x64x8xf32>
// Scores: for each of 32 batches, (8x64) x (64x8) -> 8x8; accumulates into %83 (defined elsewhere).
%1110 = linalg.batch_matmul ins(%collapsed_565, %1109 : tensor<32x8x64xf32>, tensor<32x64x8xf32>) outs(%83 : tensor<32x8x8xf32>) -> tensor<32x8x8xf32>
// Restore the unit dim: 32x8x8 -> 1x32x8x8.
%expanded_568 = tensor.expand_shape %1110 [[0, 1], [2], [3]] : tensor<32x8x8xf32> into tensor<1x32x8x8xf32>
// Add %28 (1x1x8x8, defined elsewhere; presumably an additive attention mask — confirm), broadcast over dims 0-1 via #map8.
%1111 = linalg.generic {indexing_maps = [#map19, #map8, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_568, %28 : tensor<1x32x8x8xf32>, tensor<1x1x8x8xf32>) outs(%85 : tensor<1x32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x32x8x8xf32>
// Element-wise max with the 0-d scalar %cst_329 (value not visible; presumably a lower clamp such as the f32 minimum — confirm).
// `ugt` is an unordered predicate: a NaN in %in compares true, so NaN is selected and propagated.
%1112 = linalg.generic {indexing_maps = [#map19, #map11, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1111, %cst_329 : tensor<1x32x8x8xf32>, tensor<f32>) outs(%85 : tensor<1x32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.cmpf ugt, %in, %in_712 : f32
%1819 = arith.select %1818, %in, %in_712 : f32
linalg.yield %1819 : f32
} -> tensor<1x32x8x8xf32>
// Softmax over the last dimension of the 32x8x8 scores, computed as exp(x - rowmax) / sum(exp(x - rowmax)).
%collapsed_569 = tensor.collapse_shape %1112 [[0, 1], [2], [3]] : tensor<1x32x8x8xf32> into tensor<32x8x8xf32>
// Row max (%1113#0) and the index of the max (%1113#1) over the last dim. Initial values come from %91/%89 (defined elsewhere).
// maximumf propagates NaN; the index is replaced only on a strict ordered `ogt`. %1113#1 is not used in this chunk.
%1113:2 = linalg.generic {indexing_maps = [#map1, #map14, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%collapsed_569 : tensor<32x8x8xf32>) outs(%91, %89 : tensor<32x8x1xf32>, tensor<32x8x1xi64>) {
^bb0(%in: f32, %out: f32, %out_712: i64):
%1818 = linalg.index 2 : index
%1819 = arith.index_cast %1818 : index to i64
%1820 = arith.maximumf %in, %out : f32
%1821 = arith.cmpf ogt, %in, %out : f32
%1822 = arith.select %1821, %1819, %out_712 : i64
linalg.yield %1820, %1822 : f32, i64
} -> (tensor<32x8x1xf32>, tensor<32x8x1xi64>)
// Subtract the row max (broadcast by #map14) so the exp below cannot overflow.
%1114 = linalg.generic {indexing_maps = [#map1, #map14, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_569, %1113#0 : tensor<32x8x8xf32>, tensor<32x8x1xf32>) outs(%82 : tensor<32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<32x8x8xf32>
// Element-wise exp.
%1115 = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1114 : tensor<32x8x8xf32>) outs(%82 : tensor<32x8x8xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.exp %in : f32
linalg.yield %1818 : f32
} -> tensor<32x8x8xf32>
// Row sum of the exponentials (accumulator from %95, defined elsewhere).
%1116 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%1115 : tensor<32x8x8xf32>) outs(%95 : tensor<32x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<32x8x1xf32>
// Normalize each row by its sum.
%1117 = linalg.generic {indexing_maps = [#map1, #map14, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1115, %1116 : tensor<32x8x8xf32>, tensor<32x8x1xf32>) outs(%82 : tensor<32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.divf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<32x8x8xf32>
// Weighted sum: (32x8x8 softmax) x (32x8x64 third projection) -> 32x8x64; accumulates into %99.
%1118 = linalg.batch_matmul ins(%1117, %collapsed_567 : tensor<32x8x8xf32>, tensor<32x8x64xf32>) outs(%99 : tensor<32x8x64xf32>) -> tensor<32x8x64xf32>
// Restore the unit dim: 32x8x64 -> 1x32x8x64.
%expanded_570 = tensor.expand_shape %1118 [[0, 1], [2], [3]] : tensor<32x8x64xf32> into tensor<1x32x8x64xf32>
// Swap dims 1 and 2 back: 1x32x8x64 -> 1x8x32x64.
%1119 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_570 : tensor<1x32x8x64xf32>) outs(%101 : tensor<1x8x32x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x32x64xf32>
// Value-identical copy (#map19 pins the size-1 dim 0).
%1120 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1119 : tensor<1x8x32x64xf32>) outs(%101 : tensor<1x8x32x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x32x64xf32>
// Materialize the transpose of the output weight %cst_132.
%1121 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_132 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
// Merge the 32x64 groups back into 2048 features and drop the unit dim: 1x8x32x64 -> 8x2048.
%collapsed_571 = tensor.collapse_shape %1120 [[0, 1], [2, 3]] : tensor<1x8x32x64xf32> into tensor<8x2048xf32>
// Output projection: 8x2048 x 2048x2048, accumulating into %63.
%1122 = linalg.matmul ins(%collapsed_571, %1121 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
// Add output bias %cst_131 across rows.
%1123 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_131, %1122 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
// Residual connection: add the projected result back onto the un-normalized input %expanded_559 (1x8x2048).
%expanded_572 = tensor.expand_shape %1123 [[0, 1], [2]] : tensor<8x2048xf32> into tensor<1x8x2048xf32>
%1124 = linalg.generic {indexing_maps = [#map13, #map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_559, %expanded_572 : tensor<1x8x2048xf32>, tensor<1x8x2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
// Normalization of the residual sum, in 2-D (8x2048) form, over the last dimension:
// %1140 = (x - mean) * rsqrt(var + eps) * %cst_337 + %cst_130. Variance is computed in f64; mean/normalize in f32.
// %collapsed_573 (8x2048) is also the residual input for the add that produces %1148.
%collapsed_573 = tensor.collapse_shape %1124 [[0, 1], [2]] : tensor<1x8x2048xf32> into tensor<8x2048xf32>
// Widen to f64 for the variance computation.
%1125 = linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%collapsed_573 : tensor<8x2048xf32>) outs(%108 : tensor<8x2048xf64>) {
^bb0(%in: f32, %out: f64):
%1818 = arith.extf %in : f32 to f64
linalg.yield %1818 : f64
} -> tensor<8x2048xf64>
// f64 row sum (reduction over d1 into 8x1; accumulator from %111).
%1126 = linalg.generic {indexing_maps = [#map4, #map7], iterator_types = ["parallel", "reduction"]} ins(%1125 : tensor<8x2048xf64>) outs(%111 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
// f64 row mean: divide by %cst_346 (value not visible; presumably the row length).
%1127 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1126 : tensor<8x1xf64>) outs(%110 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
// Center in f64: x - mean.
%1128 = linalg.generic {indexing_maps = [#map4, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1125, %1127 : tensor<8x2048xf64>, tensor<8x1xf64>) outs(%108 : tensor<8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.subf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<8x2048xf64>
// Square the centered values.
%1129 = linalg.generic {indexing_maps = [#map4, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%1128, %1128 : tensor<8x2048xf64>, tensor<8x2048xf64>) outs(%108 : tensor<8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.mulf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<8x2048xf64>
// Sum of squares per row.
%1130 = linalg.generic {indexing_maps = [#map4, #map7], iterator_types = ["parallel", "reduction"]} ins(%1129 : tensor<8x2048xf64>) outs(%111 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
// Variance: divide by %cst_346.
%1131 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1130 : tensor<8x1xf64>) outs(%110 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
// Narrow the variance to f32.
%1132 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1131 : tensor<8x1xf64>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f64, %out: f32):
%1818 = arith.truncf %in : f64 to f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
// Separate f32 row sum of the original input, used for the f32 mean.
%1133 = linalg.generic {indexing_maps = [#map4, #map7], iterator_types = ["parallel", "reduction"]} ins(%collapsed_573 : tensor<8x2048xf32>) outs(%120 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
// f32 row mean: divide by %cst_347.
%1134 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1133 : tensor<8x1xf32>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.divf %in, %cst_347 : f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
// variance + eps (eps = %cst_344 truncated from f64 to f32).
%1135 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1132 : tensor<8x1xf32>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.truncf %cst_344 : f64 to f32
%1819 = arith.addf %in, %1818 : f32
linalg.yield %1819 : f32
} -> tensor<8x1xf32>
// rsqrt(variance + eps).
%1136 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1135 : tensor<8x1xf32>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.rsqrt %in : f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
// Center in f32: x - mean.
%1137 = linalg.generic {indexing_maps = [#map4, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%collapsed_573, %1134 : tensor<8x2048xf32>, tensor<8x1xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
// Normalize: multiply by rsqrt.
%1138 = linalg.generic {indexing_maps = [#map4, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1137, %1136 : tensor<8x2048xf32>, tensor<8x1xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
// Per-feature scale %cst_337.
// NOTE(review): the same %cst_337 is used as the scale of every visible layer norm; presumably a deduplicated splat — confirm.
%1139 = linalg.generic {indexing_maps = [#map4, #map6, #map4], iterator_types = ["parallel", "parallel"]} ins(%1138, %cst_337 : tensor<8x2048xf32>, tensor<2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
// Per-feature shift %cst_130.
%1140 = linalg.generic {indexing_maps = [#map4, #map6, #map4], iterator_types = ["parallel", "parallel"]} ins(%1139, %cst_130 : tensor<8x2048xf32>, tensor<2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
// Two-layer feed-forward block on the normalized 8x2048 activations %1140, followed by a residual add:
// %1148 = %collapsed_573 + (max(%1140 * W1^T + b1, %cst_340) * W2^T + b2).
// Transpose the first weight %cst_129: 8192x2048 -> 2048x8192.
%1141 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_129 : tensor<8192x2048xf32>) outs(%129 : tensor<2048x8192xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x8192xf32>
// Expand: 8x2048 x 2048x8192 -> 8x8192 (accumulator %132, defined elsewhere).
%1142 = linalg.matmul ins(%1140, %1141 : tensor<8x2048xf32>, tensor<2048x8192xf32>) outs(%132 : tensor<8x8192xf32>) -> tensor<8x8192xf32>
// Add bias %cst_128 (8192 values) across rows.
%1143 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_128, %1142 : tensor<8192xf32>, tensor<8x8192xf32>) outs(%131 : tensor<8x8192xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x8192xf32>
// Activation: max(x, %cst_340) via an unordered compare (NaN propagates). %cst_340's value is not visible;
// presumably 0.0, which would make this a ReLU — confirm.
%1144 = linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%1143 : tensor<8x8192xf32>) outs(%131 : tensor<8x8192xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.cmpf ugt, %in, %cst_340 : f32
%1819 = arith.select %1818, %in, %cst_340 : f32
linalg.yield %1819 : f32
} -> tensor<8x8192xf32>
// Transpose the second weight %cst_127: 2048x8192 -> 8192x2048.
%1145 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_127 : tensor<2048x8192xf32>) outs(%136 : tensor<8192x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<8192x2048xf32>
// Contract back: 8x8192 x 8192x2048 -> 8x2048.
%1146 = linalg.matmul ins(%1144, %1145 : tensor<8x8192xf32>, tensor<8192x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
// Add bias %cst_126.
%1147 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_126, %1146 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
// Residual add with the pre-normalization input %collapsed_573.
%1148 = linalg.generic {indexing_maps = [#map4, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%collapsed_573, %1147 : tensor<8x2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
// Normalization of %1148 over the last dimension, in 3-D (1x8x2048) form:
// %1164 = (x - mean) * rsqrt(var + eps) * %cst_337 + %cst_125. Variance in f64; mean/normalize in f32.
// %expanded_574 is the un-normalized input that later sequences may add back as a residual (not visible here).
%expanded_574 = tensor.expand_shape %1148 [[0, 1], [2]] : tensor<8x2048xf32> into tensor<1x8x2048xf32>
// Widen to f64.
%1149 = linalg.generic {indexing_maps = [#map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_574 : tensor<1x8x2048xf32>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f32, %out: f64):
%1818 = arith.extf %in : f32 to f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
// f64 row sum over the last dim (accumulator %42).
%1150 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%1149 : tensor<1x8x2048xf64>) outs(%42 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
// f64 row mean: divide by %cst_346.
%1151 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1150 : tensor<1x8x1xf64>) outs(%41 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
// Center in f64: x - mean.
%1152 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1149, %1151 : tensor<1x8x2048xf64>, tensor<1x8x1xf64>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.subf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
// Square the centered values.
%1153 = linalg.generic {indexing_maps = [#map13, #map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1152, %1152 : tensor<1x8x2048xf64>, tensor<1x8x2048xf64>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.mulf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
// Sum of squares per row.
%1154 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%1153 : tensor<1x8x2048xf64>) outs(%42 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
// Variance: divide by %cst_346.
%1155 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1154 : tensor<1x8x1xf64>) outs(%41 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
// Narrow the variance to f32.
%1156 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1155 : tensor<1x8x1xf64>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f64, %out: f32):
%1818 = arith.truncf %in : f64 to f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
// Separate f32 row sum of the input, used for the f32 mean.
%1157 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%expanded_574 : tensor<1x8x2048xf32>) outs(%51 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
// f32 row mean: divide by %cst_347.
%1158 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1157 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.divf %in, %cst_347 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
// variance + eps (eps = %cst_344 truncated to f32).
%1159 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1156 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.truncf %cst_344 : f64 to f32
%1819 = arith.addf %in, %1818 : f32
linalg.yield %1819 : f32
} -> tensor<1x8x1xf32>
// rsqrt(variance + eps).
%1160 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1159 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.rsqrt %in : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
// Center in f32: x - mean.
%1161 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_574, %1158 : tensor<1x8x2048xf32>, tensor<1x8x1xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
// Normalize: multiply by rsqrt.
%1162 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1161, %1160 : tensor<1x8x2048xf32>, tensor<1x8x1xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
// Per-feature scale %cst_337 (shared across all visible layer norms — presumably a deduplicated splat; confirm).
%1163 = linalg.generic {indexing_maps = [#map13, #map16, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1162, %cst_337 : tensor<1x8x2048xf32>, tensor<2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
// Per-feature shift %cst_125.
%1164 = linalg.generic {indexing_maps = [#map13, #map16, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1163, %cst_125 : tensor<1x8x2048xf32>, tensor<2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
// First projection of the normalized %1164 (weight %cst_124, bias %cst_123), scaled by %cst_348.
// It becomes the left-hand operand of the score batch_matmul %1182.
// Materialize the transpose of weight %cst_124.
%1165 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_124 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
// Drop the unit dim: 1x8x2048 -> 8x2048. %collapsed_575 is shared by all three projections.
%collapsed_575 = tensor.collapse_shape %1164 [[0, 1], [2]] : tensor<1x8x2048xf32> into tensor<8x2048xf32>
// 8x2048 x 2048x2048 matmul accumulating into %63.
%1166 = linalg.matmul ins(%collapsed_575, %1165 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
// Add bias %cst_123 across rows.
%1167 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_123, %1166 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
// Restore the unit dim: 8x2048 -> 1x8x2048.
%expanded_576 = tensor.expand_shape %1167 [[0, 1], [2]] : tensor<8x2048xf32> into tensor<1x8x2048xf32>
// Multiply by the scalar %cst_348 (value not visible).
%1168 = linalg.generic {indexing_maps = [#map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_576 : tensor<1x8x2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.mulf %in, %cst_348 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
// Second projection of %collapsed_575 (weight %cst_122, bias %cst_121), split into 32x64 and with dims 1/2 swapped.
// It becomes the transposed right-hand operand of the score batch_matmul %1182.
// Materialize the transpose of weight %cst_122.
%1169 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_122 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
// 8x2048 x 2048x2048 matmul accumulating into %63.
%1170 = linalg.matmul ins(%collapsed_575, %1169 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
// Add bias %cst_121 across rows.
%1171 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_121, %1170 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
// Split 2048 into 32x64: 8x2048 -> 1x8x32x64.
%expanded_577 = tensor.expand_shape %1171 [[0, 1], [2, 3]] : tensor<8x2048xf32> into tensor<1x8x32x64xf32>
// Swap dims 1 and 2: 1x8x32x64 -> 1x32x8x64.
%1172 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_577 : tensor<1x8x32x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
// Value-identical copy (#map19 pins the size-1 dim 0).
%1173 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1172 : tensor<1x32x8x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
// Third projection of %collapsed_575 (weight %cst_120, bias %cst_119), split into 32x64 and with dims 1/2 swapped.
// It becomes the right-hand operand of the weighted-sum batch_matmul %1190.
// Materialize the transpose of weight %cst_120.
%1174 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_120 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
// 8x2048 x 2048x2048 matmul accumulating into %63.
%1175 = linalg.matmul ins(%collapsed_575, %1174 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
// Add bias %cst_119 across rows.
%1176 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_119, %1175 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
// Split 2048 into 32x64: 8x2048 -> 1x8x32x64.
%expanded_578 = tensor.expand_shape %1176 [[0, 1], [2, 3]] : tensor<8x2048xf32> into tensor<1x8x32x64xf32>
// Swap dims 1 and 2: 1x8x32x64 -> 1x32x8x64.
%1177 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_578 : tensor<1x8x32x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
// Value-identical copy (#map19 pins the size-1 dim 0).
%1178 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1177 : tensor<1x32x8x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
// Split the scaled first projection %1168 into 32x64: 1x8x2048 -> 1x8x32x64.
%expanded_579 = tensor.expand_shape %1168 [[0], [1], [2, 3]] : tensor<1x8x2048xf32> into tensor<1x8x32x64xf32>
// Swap dims 1 and 2: 1x8x32x64 -> 1x32x8x64.
%1179 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_579 : tensor<1x8x32x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
// Value-identical copy (#map19 pins the size-1 dim 0).
%1180 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1179 : tensor<1x32x8x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
// Fold the unit dim 0 into dim 1 for all three projections: 1x32x8x64 -> 32x8x64.
%collapsed_580 = tensor.collapse_shape %1180 [[0, 1], [2], [3]] : tensor<1x32x8x64xf32> into tensor<32x8x64xf32>
%collapsed_581 = tensor.collapse_shape %1173 [[0, 1], [2], [3]] : tensor<1x32x8x64xf32> into tensor<32x8x64xf32>
%collapsed_582 = tensor.collapse_shape %1178 [[0, 1], [2], [3]] : tensor<1x32x8x64xf32> into tensor<32x8x64xf32>
// Transpose the last two dims of the second projection (#map20): 32x8x64 -> 32x64x8.
%1181 = linalg.generic {indexing_maps = [#map1, #map20], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_581 : tensor<32x8x64xf32>) outs(%80 : tensor<32x64x8xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<32x64x8xf32>
// Scores: for each of 32 batches, (8x64) x (64x8) -> 8x8; accumulates into %83.
%1182 = linalg.batch_matmul ins(%collapsed_580, %1181 : tensor<32x8x64xf32>, tensor<32x64x8xf32>) outs(%83 : tensor<32x8x8xf32>) -> tensor<32x8x8xf32>
// Restore the unit dim: 32x8x8 -> 1x32x8x8.
%expanded_583 = tensor.expand_shape %1182 [[0, 1], [2], [3]] : tensor<32x8x8xf32> into tensor<1x32x8x8xf32>
// Add %28 (1x1x8x8, defined elsewhere; presumably an additive attention mask — confirm), broadcast via #map8.
%1183 = linalg.generic {indexing_maps = [#map19, #map8, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_583, %28 : tensor<1x32x8x8xf32>, tensor<1x1x8x8xf32>) outs(%85 : tensor<1x32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x32x8x8xf32>
// Element-wise max with the 0-d scalar %cst_329 (value not visible; presumably a lower clamp — confirm).
// `ugt` is unordered, so NaN inputs are selected and propagated.
%1184 = linalg.generic {indexing_maps = [#map19, #map11, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1183, %cst_329 : tensor<1x32x8x8xf32>, tensor<f32>) outs(%85 : tensor<1x32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.cmpf ugt, %in, %in_712 : f32
%1819 = arith.select %1818, %in, %in_712 : f32
linalg.yield %1819 : f32
} -> tensor<1x32x8x8xf32>
// Softmax over the last dimension of the 32x8x8 scores: exp(x - rowmax) / sum(exp(x - rowmax)).
%collapsed_584 = tensor.collapse_shape %1184 [[0, 1], [2], [3]] : tensor<1x32x8x8xf32> into tensor<32x8x8xf32>
// Row max (%1185#0) plus the index of the max (%1185#1); initial values from %91/%89 (defined elsewhere).
// maximumf propagates NaN; the index updates only on a strict ordered `ogt`. %1185#1 is unused in this chunk.
%1185:2 = linalg.generic {indexing_maps = [#map1, #map14, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%collapsed_584 : tensor<32x8x8xf32>) outs(%91, %89 : tensor<32x8x1xf32>, tensor<32x8x1xi64>) {
^bb0(%in: f32, %out: f32, %out_712: i64):
%1818 = linalg.index 2 : index
%1819 = arith.index_cast %1818 : index to i64
%1820 = arith.maximumf %in, %out : f32
%1821 = arith.cmpf ogt, %in, %out : f32
%1822 = arith.select %1821, %1819, %out_712 : i64
linalg.yield %1820, %1822 : f32, i64
} -> (tensor<32x8x1xf32>, tensor<32x8x1xi64>)
// Subtract the row max so the exp below cannot overflow.
%1186 = linalg.generic {indexing_maps = [#map1, #map14, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_584, %1185#0 : tensor<32x8x8xf32>, tensor<32x8x1xf32>) outs(%82 : tensor<32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<32x8x8xf32>
// Element-wise exp.
%1187 = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1186 : tensor<32x8x8xf32>) outs(%82 : tensor<32x8x8xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.exp %in : f32
linalg.yield %1818 : f32
} -> tensor<32x8x8xf32>
// Row sum of the exponentials (accumulator %95).
%1188 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%1187 : tensor<32x8x8xf32>) outs(%95 : tensor<32x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<32x8x1xf32>
// Normalize each row by its sum.
%1189 = linalg.generic {indexing_maps = [#map1, #map14, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1187, %1188 : tensor<32x8x8xf32>, tensor<32x8x1xf32>) outs(%82 : tensor<32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.divf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<32x8x8xf32>
%1190 = linalg.batch_matmul ins(%1189, %collapsed_582 : tensor<32x8x8xf32>, tensor<32x8x64xf32>) outs(%99 : tensor<32x8x64xf32>) -> tensor<32x8x64xf32>
%expanded_585 = tensor.expand_shape %1190 [[0, 1], [2], [3]] : tensor<32x8x64xf32> into tensor<1x32x8x64xf32>
%1191 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_585 : tensor<1x32x8x64xf32>) outs(%101 : tensor<1x8x32x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x32x64xf32>
%1192 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1191 : tensor<1x8x32x64xf32>) outs(%101 : tensor<1x8x32x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x32x64xf32>
%1193 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_118 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%collapsed_586 = tensor.collapse_shape %1192 [[0, 1], [2, 3]] : tensor<1x8x32x64xf32> into tensor<8x2048xf32>
%1194 = linalg.matmul ins(%collapsed_586, %1193 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%1195 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_117, %1194 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_587 = tensor.expand_shape %1195 [[0, 1], [2]] : tensor<8x2048xf32> into tensor<1x8x2048xf32>
%1196 = linalg.generic {indexing_maps = [#map13, #map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_574, %expanded_587 : tensor<1x8x2048xf32>, tensor<1x8x2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%collapsed_588 = tensor.collapse_shape %1196 [[0, 1], [2]] : tensor<1x8x2048xf32> into tensor<8x2048xf32>
%1197 = linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%collapsed_588 : tensor<8x2048xf32>) outs(%108 : tensor<8x2048xf64>) {
^bb0(%in: f32, %out: f64):
%1818 = arith.extf %in : f32 to f64
linalg.yield %1818 : f64
} -> tensor<8x2048xf64>
%1198 = linalg.generic {indexing_maps = [#map4, #map7], iterator_types = ["parallel", "reduction"]} ins(%1197 : tensor<8x2048xf64>) outs(%111 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%1199 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1198 : tensor<8x1xf64>) outs(%110 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%1200 = linalg.generic {indexing_maps = [#map4, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1197, %1199 : tensor<8x2048xf64>, tensor<8x1xf64>) outs(%108 : tensor<8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.subf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<8x2048xf64>
%1201 = linalg.generic {indexing_maps = [#map4, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%1200, %1200 : tensor<8x2048xf64>, tensor<8x2048xf64>) outs(%108 : tensor<8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.mulf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<8x2048xf64>
%1202 = linalg.generic {indexing_maps = [#map4, #map7], iterator_types = ["parallel", "reduction"]} ins(%1201 : tensor<8x2048xf64>) outs(%111 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%1203 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1202 : tensor<8x1xf64>) outs(%110 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%1204 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1203 : tensor<8x1xf64>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f64, %out: f32):
%1818 = arith.truncf %in : f64 to f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%1205 = linalg.generic {indexing_maps = [#map4, #map7], iterator_types = ["parallel", "reduction"]} ins(%collapsed_588 : tensor<8x2048xf32>) outs(%120 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%1206 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1205 : tensor<8x1xf32>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.divf %in, %cst_347 : f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%1207 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1204 : tensor<8x1xf32>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.truncf %cst_344 : f64 to f32
%1819 = arith.addf %in, %1818 : f32
linalg.yield %1819 : f32
} -> tensor<8x1xf32>
%1208 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1207 : tensor<8x1xf32>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.rsqrt %in : f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%1209 = linalg.generic {indexing_maps = [#map4, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%collapsed_588, %1206 : tensor<8x2048xf32>, tensor<8x1xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%1210 = linalg.generic {indexing_maps = [#map4, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1209, %1208 : tensor<8x2048xf32>, tensor<8x1xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%1211 = linalg.generic {indexing_maps = [#map4, #map6, #map4], iterator_types = ["parallel", "parallel"]} ins(%1210, %cst_337 : tensor<8x2048xf32>, tensor<2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%1212 = linalg.generic {indexing_maps = [#map4, #map6, #map4], iterator_types = ["parallel", "parallel"]} ins(%1211, %cst_116 : tensor<8x2048xf32>, tensor<2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%1213 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_115 : tensor<8192x2048xf32>) outs(%129 : tensor<2048x8192xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x8192xf32>
%1214 = linalg.matmul ins(%1212, %1213 : tensor<8x2048xf32>, tensor<2048x8192xf32>) outs(%132 : tensor<8x8192xf32>) -> tensor<8x8192xf32>
%1215 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_114, %1214 : tensor<8192xf32>, tensor<8x8192xf32>) outs(%131 : tensor<8x8192xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x8192xf32>
%1216 = linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%1215 : tensor<8x8192xf32>) outs(%131 : tensor<8x8192xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.cmpf ugt, %in, %cst_340 : f32
%1819 = arith.select %1818, %in, %cst_340 : f32
linalg.yield %1819 : f32
} -> tensor<8x8192xf32>
%1217 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_113 : tensor<2048x8192xf32>) outs(%136 : tensor<8192x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<8192x2048xf32>
%1218 = linalg.matmul ins(%1216, %1217 : tensor<8x8192xf32>, tensor<8192x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%1219 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_112, %1218 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%1220 = linalg.generic {indexing_maps = [#map4, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%collapsed_588, %1219 : tensor<8x2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_589 = tensor.expand_shape %1220 [[0, 1], [2]] : tensor<8x2048xf32> into tensor<1x8x2048xf32>
%1221 = linalg.generic {indexing_maps = [#map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_589 : tensor<1x8x2048xf32>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f32, %out: f64):
%1818 = arith.extf %in : f32 to f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
%1222 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%1221 : tensor<1x8x2048xf64>) outs(%42 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%1223 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1222 : tensor<1x8x1xf64>) outs(%41 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%1224 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1221, %1223 : tensor<1x8x2048xf64>, tensor<1x8x1xf64>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.subf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
%1225 = linalg.generic {indexing_maps = [#map13, #map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1224, %1224 : tensor<1x8x2048xf64>, tensor<1x8x2048xf64>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.mulf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
%1226 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%1225 : tensor<1x8x2048xf64>) outs(%42 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%1227 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1226 : tensor<1x8x1xf64>) outs(%41 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%1228 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1227 : tensor<1x8x1xf64>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f64, %out: f32):
%1818 = arith.truncf %in : f64 to f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%1229 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%expanded_589 : tensor<1x8x2048xf32>) outs(%51 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%1230 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1229 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.divf %in, %cst_347 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%1231 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1228 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.truncf %cst_344 : f64 to f32
%1819 = arith.addf %in, %1818 : f32
linalg.yield %1819 : f32
} -> tensor<1x8x1xf32>
%1232 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1231 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.rsqrt %in : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%1233 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_589, %1230 : tensor<1x8x2048xf32>, tensor<1x8x1xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%1234 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1233, %1232 : tensor<1x8x2048xf32>, tensor<1x8x1xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%1235 = linalg.generic {indexing_maps = [#map13, #map16, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1234, %cst_337 : tensor<1x8x2048xf32>, tensor<2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%1236 = linalg.generic {indexing_maps = [#map13, #map16, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1235, %cst_111 : tensor<1x8x2048xf32>, tensor<2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%1237 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_110 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%collapsed_590 = tensor.collapse_shape %1236 [[0, 1], [2]] : tensor<1x8x2048xf32> into tensor<8x2048xf32>
%1238 = linalg.matmul ins(%collapsed_590, %1237 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%1239 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_109, %1238 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_591 = tensor.expand_shape %1239 [[0, 1], [2]] : tensor<8x2048xf32> into tensor<1x8x2048xf32>
%1240 = linalg.generic {indexing_maps = [#map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_591 : tensor<1x8x2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.mulf %in, %cst_348 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%1241 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_108 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%1242 = linalg.matmul ins(%collapsed_590, %1241 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%1243 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_107, %1242 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_592 = tensor.expand_shape %1243 [[0, 1], [2, 3]] : tensor<8x2048xf32> into tensor<1x8x32x64xf32>
%1244 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_592 : tensor<1x8x32x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%1245 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1244 : tensor<1x32x8x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%1246 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_106 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%1247 = linalg.matmul ins(%collapsed_590, %1246 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%1248 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_105, %1247 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_593 = tensor.expand_shape %1248 [[0, 1], [2, 3]] : tensor<8x2048xf32> into tensor<1x8x32x64xf32>
%1249 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_593 : tensor<1x8x32x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%1250 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1249 : tensor<1x32x8x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%expanded_594 = tensor.expand_shape %1240 [[0], [1], [2, 3]] : tensor<1x8x2048xf32> into tensor<1x8x32x64xf32>
%1251 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_594 : tensor<1x8x32x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%1252 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1251 : tensor<1x32x8x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%collapsed_595 = tensor.collapse_shape %1252 [[0, 1], [2], [3]] : tensor<1x32x8x64xf32> into tensor<32x8x64xf32>
%collapsed_596 = tensor.collapse_shape %1245 [[0, 1], [2], [3]] : tensor<1x32x8x64xf32> into tensor<32x8x64xf32>
%collapsed_597 = tensor.collapse_shape %1250 [[0, 1], [2], [3]] : tensor<1x32x8x64xf32> into tensor<32x8x64xf32>
%1253 = linalg.generic {indexing_maps = [#map1, #map20], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_596 : tensor<32x8x64xf32>) outs(%80 : tensor<32x64x8xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<32x64x8xf32>
%1254 = linalg.batch_matmul ins(%collapsed_595, %1253 : tensor<32x8x64xf32>, tensor<32x64x8xf32>) outs(%83 : tensor<32x8x8xf32>) -> tensor<32x8x8xf32>
%expanded_598 = tensor.expand_shape %1254 [[0, 1], [2], [3]] : tensor<32x8x8xf32> into tensor<1x32x8x8xf32>
%1255 = linalg.generic {indexing_maps = [#map19, #map8, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_598, %28 : tensor<1x32x8x8xf32>, tensor<1x1x8x8xf32>) outs(%85 : tensor<1x32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x32x8x8xf32>
%1256 = linalg.generic {indexing_maps = [#map19, #map11, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1255, %cst_329 : tensor<1x32x8x8xf32>, tensor<f32>) outs(%85 : tensor<1x32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.cmpf ugt, %in, %in_712 : f32
%1819 = arith.select %1818, %in, %in_712 : f32
linalg.yield %1819 : f32
} -> tensor<1x32x8x8xf32>
%collapsed_599 = tensor.collapse_shape %1256 [[0, 1], [2], [3]] : tensor<1x32x8x8xf32> into tensor<32x8x8xf32>
%1257:2 = linalg.generic {indexing_maps = [#map1, #map14, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%collapsed_599 : tensor<32x8x8xf32>) outs(%91, %89 : tensor<32x8x1xf32>, tensor<32x8x1xi64>) {
^bb0(%in: f32, %out: f32, %out_712: i64):
%1818 = linalg.index 2 : index
%1819 = arith.index_cast %1818 : index to i64
%1820 = arith.maximumf %in, %out : f32
%1821 = arith.cmpf ogt, %in, %out : f32
%1822 = arith.select %1821, %1819, %out_712 : i64
linalg.yield %1820, %1822 : f32, i64
} -> (tensor<32x8x1xf32>, tensor<32x8x1xi64>)
%1258 = linalg.generic {indexing_maps = [#map1, #map14, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_599, %1257#0 : tensor<32x8x8xf32>, tensor<32x8x1xf32>) outs(%82 : tensor<32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<32x8x8xf32>
%1259 = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1258 : tensor<32x8x8xf32>) outs(%82 : tensor<32x8x8xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.exp %in : f32
linalg.yield %1818 : f32
} -> tensor<32x8x8xf32>
%1260 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%1259 : tensor<32x8x8xf32>) outs(%95 : tensor<32x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<32x8x1xf32>
%1261 = linalg.generic {indexing_maps = [#map1, #map14, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1259, %1260 : tensor<32x8x8xf32>, tensor<32x8x1xf32>) outs(%82 : tensor<32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.divf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<32x8x8xf32>
%1262 = linalg.batch_matmul ins(%1261, %collapsed_597 : tensor<32x8x8xf32>, tensor<32x8x64xf32>) outs(%99 : tensor<32x8x64xf32>) -> tensor<32x8x64xf32>
%expanded_600 = tensor.expand_shape %1262 [[0, 1], [2], [3]] : tensor<32x8x64xf32> into tensor<1x32x8x64xf32>
%1263 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_600 : tensor<1x32x8x64xf32>) outs(%101 : tensor<1x8x32x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x32x64xf32>
%1264 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1263 : tensor<1x8x32x64xf32>) outs(%101 : tensor<1x8x32x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x32x64xf32>
%1265 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_104 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%collapsed_601 = tensor.collapse_shape %1264 [[0, 1], [2, 3]] : tensor<1x8x32x64xf32> into tensor<8x2048xf32>
%1266 = linalg.matmul ins(%collapsed_601, %1265 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%1267 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_103, %1266 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_602 = tensor.expand_shape %1267 [[0, 1], [2]] : tensor<8x2048xf32> into tensor<1x8x2048xf32>
%1268 = linalg.generic {indexing_maps = [#map13, #map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_589, %expanded_602 : tensor<1x8x2048xf32>, tensor<1x8x2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%collapsed_603 = tensor.collapse_shape %1268 [[0, 1], [2]] : tensor<1x8x2048xf32> into tensor<8x2048xf32>
%1269 = linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%collapsed_603 : tensor<8x2048xf32>) outs(%108 : tensor<8x2048xf64>) {
^bb0(%in: f32, %out: f64):
%1818 = arith.extf %in : f32 to f64
linalg.yield %1818 : f64
} -> tensor<8x2048xf64>
%1270 = linalg.generic {indexing_maps = [#map4, #map7], iterator_types = ["parallel", "reduction"]} ins(%1269 : tensor<8x2048xf64>) outs(%111 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%1271 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1270 : tensor<8x1xf64>) outs(%110 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%1272 = linalg.generic {indexing_maps = [#map4, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1269, %1271 : tensor<8x2048xf64>, tensor<8x1xf64>) outs(%108 : tensor<8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.subf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<8x2048xf64>
%1273 = linalg.generic {indexing_maps = [#map4, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%1272, %1272 : tensor<8x2048xf64>, tensor<8x2048xf64>) outs(%108 : tensor<8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.mulf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<8x2048xf64>
%1274 = linalg.generic {indexing_maps = [#map4, #map7], iterator_types = ["parallel", "reduction"]} ins(%1273 : tensor<8x2048xf64>) outs(%111 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%1275 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1274 : tensor<8x1xf64>) outs(%110 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%1276 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1275 : tensor<8x1xf64>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f64, %out: f32):
%1818 = arith.truncf %in : f64 to f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%1277 = linalg.generic {indexing_maps = [#map4, #map7], iterator_types = ["parallel", "reduction"]} ins(%collapsed_603 : tensor<8x2048xf32>) outs(%120 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%1278 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1277 : tensor<8x1xf32>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.divf %in, %cst_347 : f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%1279 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1276 : tensor<8x1xf32>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.truncf %cst_344 : f64 to f32
%1819 = arith.addf %in, %1818 : f32
linalg.yield %1819 : f32
} -> tensor<8x1xf32>
%1280 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1279 : tensor<8x1xf32>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.rsqrt %in : f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%1281 = linalg.generic {indexing_maps = [#map4, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%collapsed_603, %1278 : tensor<8x2048xf32>, tensor<8x1xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%1282 = linalg.generic {indexing_maps = [#map4, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1281, %1280 : tensor<8x2048xf32>, tensor<8x1xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%1283 = linalg.generic {indexing_maps = [#map4, #map6, #map4], iterator_types = ["parallel", "parallel"]} ins(%1282, %cst_337 : tensor<8x2048xf32>, tensor<2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%1284 = linalg.generic {indexing_maps = [#map4, #map6, #map4], iterator_types = ["parallel", "parallel"]} ins(%1283, %cst_102 : tensor<8x2048xf32>, tensor<2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%1285 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_101 : tensor<8192x2048xf32>) outs(%129 : tensor<2048x8192xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x8192xf32>
%1286 = linalg.matmul ins(%1284, %1285 : tensor<8x2048xf32>, tensor<2048x8192xf32>) outs(%132 : tensor<8x8192xf32>) -> tensor<8x8192xf32>
%1287 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_100, %1286 : tensor<8192xf32>, tensor<8x8192xf32>) outs(%131 : tensor<8x8192xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x8192xf32>
%1288 = linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%1287 : tensor<8x8192xf32>) outs(%131 : tensor<8x8192xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.cmpf ugt, %in, %cst_340 : f32
%1819 = arith.select %1818, %in, %cst_340 : f32
linalg.yield %1819 : f32
} -> tensor<8x8192xf32>
%1289 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_99 : tensor<2048x8192xf32>) outs(%136 : tensor<8192x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<8192x2048xf32>
%1290 = linalg.matmul ins(%1288, %1289 : tensor<8x8192xf32>, tensor<8192x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%1291 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_98, %1290 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%1292 = linalg.generic {indexing_maps = [#map4, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%collapsed_603, %1291 : tensor<8x2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_604 = tensor.expand_shape %1292 [[0, 1], [2]] : tensor<8x2048xf32> into tensor<1x8x2048xf32>
%1293 = linalg.generic {indexing_maps = [#map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_604 : tensor<1x8x2048xf32>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f32, %out: f64):
%1818 = arith.extf %in : f32 to f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
%1294 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%1293 : tensor<1x8x2048xf64>) outs(%42 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%1295 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1294 : tensor<1x8x1xf64>) outs(%41 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%1296 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1293, %1295 : tensor<1x8x2048xf64>, tensor<1x8x1xf64>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.subf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
%1297 = linalg.generic {indexing_maps = [#map13, #map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1296, %1296 : tensor<1x8x2048xf64>, tensor<1x8x2048xf64>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.mulf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
%1298 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%1297 : tensor<1x8x2048xf64>) outs(%42 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%1299 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1298 : tensor<1x8x1xf64>) outs(%41 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%1300 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1299 : tensor<1x8x1xf64>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f64, %out: f32):
%1818 = arith.truncf %in : f64 to f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%1301 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%expanded_604 : tensor<1x8x2048xf32>) outs(%51 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%1302 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1301 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.divf %in, %cst_347 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%1303 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1300 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.truncf %cst_344 : f64 to f32
%1819 = arith.addf %in, %1818 : f32
linalg.yield %1819 : f32
} -> tensor<1x8x1xf32>
%1304 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1303 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.rsqrt %in : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%1305 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_604, %1302 : tensor<1x8x2048xf32>, tensor<1x8x1xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%1306 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1305, %1304 : tensor<1x8x2048xf32>, tensor<1x8x1xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%1307 = linalg.generic {indexing_maps = [#map13, #map16, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1306, %cst_337 : tensor<1x8x2048xf32>, tensor<2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%1308 = linalg.generic {indexing_maps = [#map13, #map16, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1307, %cst_97 : tensor<1x8x2048xf32>, tensor<2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%1309 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_96 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%collapsed_605 = tensor.collapse_shape %1308 [[0, 1], [2]] : tensor<1x8x2048xf32> into tensor<8x2048xf32>
%1310 = linalg.matmul ins(%collapsed_605, %1309 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%1311 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_95, %1310 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_606 = tensor.expand_shape %1311 [[0, 1], [2]] : tensor<8x2048xf32> into tensor<1x8x2048xf32>
%1312 = linalg.generic {indexing_maps = [#map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_606 : tensor<1x8x2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.mulf %in, %cst_348 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%1313 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_94 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%1314 = linalg.matmul ins(%collapsed_605, %1313 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%1315 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_93, %1314 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_607 = tensor.expand_shape %1315 [[0, 1], [2, 3]] : tensor<8x2048xf32> into tensor<1x8x32x64xf32>
%1316 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_607 : tensor<1x8x32x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%1317 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1316 : tensor<1x32x8x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%1318 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_92 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%1319 = linalg.matmul ins(%collapsed_605, %1318 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%1320 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_91, %1319 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_608 = tensor.expand_shape %1320 [[0, 1], [2, 3]] : tensor<8x2048xf32> into tensor<1x8x32x64xf32>
%1321 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_608 : tensor<1x8x32x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%1322 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1321 : tensor<1x32x8x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%expanded_609 = tensor.expand_shape %1312 [[0], [1], [2, 3]] : tensor<1x8x2048xf32> into tensor<1x8x32x64xf32>
%1323 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_609 : tensor<1x8x32x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%1324 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1323 : tensor<1x32x8x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%collapsed_610 = tensor.collapse_shape %1324 [[0, 1], [2], [3]] : tensor<1x32x8x64xf32> into tensor<32x8x64xf32>
%collapsed_611 = tensor.collapse_shape %1317 [[0, 1], [2], [3]] : tensor<1x32x8x64xf32> into tensor<32x8x64xf32>
%collapsed_612 = tensor.collapse_shape %1322 [[0, 1], [2], [3]] : tensor<1x32x8x64xf32> into tensor<32x8x64xf32>
%1325 = linalg.generic {indexing_maps = [#map1, #map20], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_611 : tensor<32x8x64xf32>) outs(%80 : tensor<32x64x8xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<32x64x8xf32>
%1326 = linalg.batch_matmul ins(%collapsed_610, %1325 : tensor<32x8x64xf32>, tensor<32x64x8xf32>) outs(%83 : tensor<32x8x8xf32>) -> tensor<32x8x8xf32>
%expanded_613 = tensor.expand_shape %1326 [[0, 1], [2], [3]] : tensor<32x8x8xf32> into tensor<1x32x8x8xf32>
%1327 = linalg.generic {indexing_maps = [#map19, #map8, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_613, %28 : tensor<1x32x8x8xf32>, tensor<1x1x8x8xf32>) outs(%85 : tensor<1x32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x32x8x8xf32>
%1328 = linalg.generic {indexing_maps = [#map19, #map11, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1327, %cst_329 : tensor<1x32x8x8xf32>, tensor<f32>) outs(%85 : tensor<1x32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.cmpf ugt, %in, %in_712 : f32
%1819 = arith.select %1818, %in, %in_712 : f32
linalg.yield %1819 : f32
} -> tensor<1x32x8x8xf32>
%collapsed_614 = tensor.collapse_shape %1328 [[0, 1], [2], [3]] : tensor<1x32x8x8xf32> into tensor<32x8x8xf32>
%1329:2 = linalg.generic {indexing_maps = [#map1, #map14, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%collapsed_614 : tensor<32x8x8xf32>) outs(%91, %89 : tensor<32x8x1xf32>, tensor<32x8x1xi64>) {
^bb0(%in: f32, %out: f32, %out_712: i64):
%1818 = linalg.index 2 : index
%1819 = arith.index_cast %1818 : index to i64
%1820 = arith.maximumf %in, %out : f32
%1821 = arith.cmpf ogt, %in, %out : f32
%1822 = arith.select %1821, %1819, %out_712 : i64
linalg.yield %1820, %1822 : f32, i64
} -> (tensor<32x8x1xf32>, tensor<32x8x1xi64>)
%1330 = linalg.generic {indexing_maps = [#map1, #map14, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_614, %1329#0 : tensor<32x8x8xf32>, tensor<32x8x1xf32>) outs(%82 : tensor<32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<32x8x8xf32>
%1331 = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1330 : tensor<32x8x8xf32>) outs(%82 : tensor<32x8x8xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.exp %in : f32
linalg.yield %1818 : f32
} -> tensor<32x8x8xf32>
%1332 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%1331 : tensor<32x8x8xf32>) outs(%95 : tensor<32x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<32x8x1xf32>
%1333 = linalg.generic {indexing_maps = [#map1, #map14, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1331, %1332 : tensor<32x8x8xf32>, tensor<32x8x1xf32>) outs(%82 : tensor<32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.divf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<32x8x8xf32>
%1334 = linalg.batch_matmul ins(%1333, %collapsed_612 : tensor<32x8x8xf32>, tensor<32x8x64xf32>) outs(%99 : tensor<32x8x64xf32>) -> tensor<32x8x64xf32>
%expanded_615 = tensor.expand_shape %1334 [[0, 1], [2], [3]] : tensor<32x8x64xf32> into tensor<1x32x8x64xf32>
%1335 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_615 : tensor<1x32x8x64xf32>) outs(%101 : tensor<1x8x32x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x32x64xf32>
%1336 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1335 : tensor<1x8x32x64xf32>) outs(%101 : tensor<1x8x32x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x32x64xf32>
%1337 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_90 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%collapsed_616 = tensor.collapse_shape %1336 [[0, 1], [2, 3]] : tensor<1x8x32x64xf32> into tensor<8x2048xf32>
%1338 = linalg.matmul ins(%collapsed_616, %1337 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%1339 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_89, %1338 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_617 = tensor.expand_shape %1339 [[0, 1], [2]] : tensor<8x2048xf32> into tensor<1x8x2048xf32>
%1340 = linalg.generic {indexing_maps = [#map13, #map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_604, %expanded_617 : tensor<1x8x2048xf32>, tensor<1x8x2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%collapsed_618 = tensor.collapse_shape %1340 [[0, 1], [2]] : tensor<1x8x2048xf32> into tensor<8x2048xf32>
%1341 = linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%collapsed_618 : tensor<8x2048xf32>) outs(%108 : tensor<8x2048xf64>) {
^bb0(%in: f32, %out: f64):
%1818 = arith.extf %in : f32 to f64
linalg.yield %1818 : f64
} -> tensor<8x2048xf64>
%1342 = linalg.generic {indexing_maps = [#map4, #map7], iterator_types = ["parallel", "reduction"]} ins(%1341 : tensor<8x2048xf64>) outs(%111 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%1343 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1342 : tensor<8x1xf64>) outs(%110 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%1344 = linalg.generic {indexing_maps = [#map4, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1341, %1343 : tensor<8x2048xf64>, tensor<8x1xf64>) outs(%108 : tensor<8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.subf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<8x2048xf64>
%1345 = linalg.generic {indexing_maps = [#map4, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%1344, %1344 : tensor<8x2048xf64>, tensor<8x2048xf64>) outs(%108 : tensor<8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.mulf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<8x2048xf64>
%1346 = linalg.generic {indexing_maps = [#map4, #map7], iterator_types = ["parallel", "reduction"]} ins(%1345 : tensor<8x2048xf64>) outs(%111 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%1347 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1346 : tensor<8x1xf64>) outs(%110 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%1348 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1347 : tensor<8x1xf64>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f64, %out: f32):
%1818 = arith.truncf %in : f64 to f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%1349 = linalg.generic {indexing_maps = [#map4, #map7], iterator_types = ["parallel", "reduction"]} ins(%collapsed_618 : tensor<8x2048xf32>) outs(%120 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%1350 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1349 : tensor<8x1xf32>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.divf %in, %cst_347 : f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%1351 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1348 : tensor<8x1xf32>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.truncf %cst_344 : f64 to f32
%1819 = arith.addf %in, %1818 : f32
linalg.yield %1819 : f32
} -> tensor<8x1xf32>
%1352 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1351 : tensor<8x1xf32>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.rsqrt %in : f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%1353 = linalg.generic {indexing_maps = [#map4, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%collapsed_618, %1350 : tensor<8x2048xf32>, tensor<8x1xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%1354 = linalg.generic {indexing_maps = [#map4, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1353, %1352 : tensor<8x2048xf32>, tensor<8x1xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%1355 = linalg.generic {indexing_maps = [#map4, #map6, #map4], iterator_types = ["parallel", "parallel"]} ins(%1354, %cst_337 : tensor<8x2048xf32>, tensor<2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%1356 = linalg.generic {indexing_maps = [#map4, #map6, #map4], iterator_types = ["parallel", "parallel"]} ins(%1355, %cst_88 : tensor<8x2048xf32>, tensor<2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%1357 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_87 : tensor<8192x2048xf32>) outs(%129 : tensor<2048x8192xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x8192xf32>
%1358 = linalg.matmul ins(%1356, %1357 : tensor<8x2048xf32>, tensor<2048x8192xf32>) outs(%132 : tensor<8x8192xf32>) -> tensor<8x8192xf32>
%1359 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_86, %1358 : tensor<8192xf32>, tensor<8x8192xf32>) outs(%131 : tensor<8x8192xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x8192xf32>
%1360 = linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%1359 : tensor<8x8192xf32>) outs(%131 : tensor<8x8192xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.cmpf ugt, %in, %cst_340 : f32
%1819 = arith.select %1818, %in, %cst_340 : f32
linalg.yield %1819 : f32
} -> tensor<8x8192xf32>
%1361 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_85 : tensor<2048x8192xf32>) outs(%136 : tensor<8192x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<8192x2048xf32>
%1362 = linalg.matmul ins(%1360, %1361 : tensor<8x8192xf32>, tensor<8192x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%1363 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_84, %1362 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%1364 = linalg.generic {indexing_maps = [#map4, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%collapsed_618, %1363 : tensor<8x2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_619 = tensor.expand_shape %1364 [[0, 1], [2]] : tensor<8x2048xf32> into tensor<1x8x2048xf32>
%1365 = linalg.generic {indexing_maps = [#map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_619 : tensor<1x8x2048xf32>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f32, %out: f64):
%1818 = arith.extf %in : f32 to f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
%1366 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%1365 : tensor<1x8x2048xf64>) outs(%42 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%1367 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1366 : tensor<1x8x1xf64>) outs(%41 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%1368 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1365, %1367 : tensor<1x8x2048xf64>, tensor<1x8x1xf64>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.subf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
%1369 = linalg.generic {indexing_maps = [#map13, #map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1368, %1368 : tensor<1x8x2048xf64>, tensor<1x8x2048xf64>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.mulf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
%1370 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%1369 : tensor<1x8x2048xf64>) outs(%42 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%1371 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1370 : tensor<1x8x1xf64>) outs(%41 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%1372 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1371 : tensor<1x8x1xf64>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f64, %out: f32):
%1818 = arith.truncf %in : f64 to f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%1373 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%expanded_619 : tensor<1x8x2048xf32>) outs(%51 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%1374 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1373 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.divf %in, %cst_347 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%1375 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1372 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.truncf %cst_344 : f64 to f32
%1819 = arith.addf %in, %1818 : f32
linalg.yield %1819 : f32
} -> tensor<1x8x1xf32>
%1376 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1375 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.rsqrt %in : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%1377 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_619, %1374 : tensor<1x8x2048xf32>, tensor<1x8x1xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%1378 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1377, %1376 : tensor<1x8x2048xf32>, tensor<1x8x1xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%1379 = linalg.generic {indexing_maps = [#map13, #map16, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1378, %cst_337 : tensor<1x8x2048xf32>, tensor<2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%1380 = linalg.generic {indexing_maps = [#map13, #map16, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1379, %cst_83 : tensor<1x8x2048xf32>, tensor<2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%1381 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_82 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%collapsed_620 = tensor.collapse_shape %1380 [[0, 1], [2]] : tensor<1x8x2048xf32> into tensor<8x2048xf32>
%1382 = linalg.matmul ins(%collapsed_620, %1381 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%1383 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_81, %1382 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_621 = tensor.expand_shape %1383 [[0, 1], [2]] : tensor<8x2048xf32> into tensor<1x8x2048xf32>
%1384 = linalg.generic {indexing_maps = [#map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_621 : tensor<1x8x2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.mulf %in, %cst_348 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%1385 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_80 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%1386 = linalg.matmul ins(%collapsed_620, %1385 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%1387 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_79, %1386 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_622 = tensor.expand_shape %1387 [[0, 1], [2, 3]] : tensor<8x2048xf32> into tensor<1x8x32x64xf32>
%1388 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_622 : tensor<1x8x32x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%1389 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1388 : tensor<1x32x8x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%1390 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_78 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%1391 = linalg.matmul ins(%collapsed_620, %1390 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%1392 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_77, %1391 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_623 = tensor.expand_shape %1392 [[0, 1], [2, 3]] : tensor<8x2048xf32> into tensor<1x8x32x64xf32>
%1393 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_623 : tensor<1x8x32x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%1394 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1393 : tensor<1x32x8x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%expanded_624 = tensor.expand_shape %1384 [[0], [1], [2, 3]] : tensor<1x8x2048xf32> into tensor<1x8x32x64xf32>
%1395 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_624 : tensor<1x8x32x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%1396 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1395 : tensor<1x32x8x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%collapsed_625 = tensor.collapse_shape %1396 [[0, 1], [2], [3]] : tensor<1x32x8x64xf32> into tensor<32x8x64xf32>
%collapsed_626 = tensor.collapse_shape %1389 [[0, 1], [2], [3]] : tensor<1x32x8x64xf32> into tensor<32x8x64xf32>
%collapsed_627 = tensor.collapse_shape %1394 [[0, 1], [2], [3]] : tensor<1x32x8x64xf32> into tensor<32x8x64xf32>
%1397 = linalg.generic {indexing_maps = [#map1, #map20], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_626 : tensor<32x8x64xf32>) outs(%80 : tensor<32x64x8xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<32x64x8xf32>
%1398 = linalg.batch_matmul ins(%collapsed_625, %1397 : tensor<32x8x64xf32>, tensor<32x64x8xf32>) outs(%83 : tensor<32x8x8xf32>) -> tensor<32x8x8xf32>
%expanded_628 = tensor.expand_shape %1398 [[0, 1], [2], [3]] : tensor<32x8x8xf32> into tensor<1x32x8x8xf32>
%1399 = linalg.generic {indexing_maps = [#map19, #map8, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_628, %28 : tensor<1x32x8x8xf32>, tensor<1x1x8x8xf32>) outs(%85 : tensor<1x32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x32x8x8xf32>
%1400 = linalg.generic {indexing_maps = [#map19, #map11, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1399, %cst_329 : tensor<1x32x8x8xf32>, tensor<f32>) outs(%85 : tensor<1x32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.cmpf ugt, %in, %in_712 : f32
%1819 = arith.select %1818, %in, %in_712 : f32
linalg.yield %1819 : f32
} -> tensor<1x32x8x8xf32>
%collapsed_629 = tensor.collapse_shape %1400 [[0, 1], [2], [3]] : tensor<1x32x8x8xf32> into tensor<32x8x8xf32>
%1401:2 = linalg.generic {indexing_maps = [#map1, #map14, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%collapsed_629 : tensor<32x8x8xf32>) outs(%91, %89 : tensor<32x8x1xf32>, tensor<32x8x1xi64>) {
^bb0(%in: f32, %out: f32, %out_712: i64):
%1818 = linalg.index 2 : index
%1819 = arith.index_cast %1818 : index to i64
%1820 = arith.maximumf %in, %out : f32
%1821 = arith.cmpf ogt, %in, %out : f32
%1822 = arith.select %1821, %1819, %out_712 : i64
linalg.yield %1820, %1822 : f32, i64
} -> (tensor<32x8x1xf32>, tensor<32x8x1xi64>)
%1402 = linalg.generic {indexing_maps = [#map1, #map14, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_629, %1401#0 : tensor<32x8x8xf32>, tensor<32x8x1xf32>) outs(%82 : tensor<32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<32x8x8xf32>
%1403 = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1402 : tensor<32x8x8xf32>) outs(%82 : tensor<32x8x8xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.exp %in : f32
linalg.yield %1818 : f32
} -> tensor<32x8x8xf32>
%1404 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%1403 : tensor<32x8x8xf32>) outs(%95 : tensor<32x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<32x8x1xf32>
%1405 = linalg.generic {indexing_maps = [#map1, #map14, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1403, %1404 : tensor<32x8x8xf32>, tensor<32x8x1xf32>) outs(%82 : tensor<32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.divf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<32x8x8xf32>
%1406 = linalg.batch_matmul ins(%1405, %collapsed_627 : tensor<32x8x8xf32>, tensor<32x8x64xf32>) outs(%99 : tensor<32x8x64xf32>) -> tensor<32x8x64xf32>
%expanded_630 = tensor.expand_shape %1406 [[0, 1], [2], [3]] : tensor<32x8x64xf32> into tensor<1x32x8x64xf32>
%1407 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_630 : tensor<1x32x8x64xf32>) outs(%101 : tensor<1x8x32x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x32x64xf32>
%1408 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1407 : tensor<1x8x32x64xf32>) outs(%101 : tensor<1x8x32x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x32x64xf32>
%1409 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_76 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%collapsed_631 = tensor.collapse_shape %1408 [[0, 1], [2, 3]] : tensor<1x8x32x64xf32> into tensor<8x2048xf32>
%1410 = linalg.matmul ins(%collapsed_631, %1409 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%1411 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_75, %1410 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_632 = tensor.expand_shape %1411 [[0, 1], [2]] : tensor<8x2048xf32> into tensor<1x8x2048xf32>
%1412 = linalg.generic {indexing_maps = [#map13, #map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_619, %expanded_632 : tensor<1x8x2048xf32>, tensor<1x8x2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%collapsed_633 = tensor.collapse_shape %1412 [[0, 1], [2]] : tensor<1x8x2048xf32> into tensor<8x2048xf32>
%1413 = linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%collapsed_633 : tensor<8x2048xf32>) outs(%108 : tensor<8x2048xf64>) {
^bb0(%in: f32, %out: f64):
%1818 = arith.extf %in : f32 to f64
linalg.yield %1818 : f64
} -> tensor<8x2048xf64>
%1414 = linalg.generic {indexing_maps = [#map4, #map7], iterator_types = ["parallel", "reduction"]} ins(%1413 : tensor<8x2048xf64>) outs(%111 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%1415 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1414 : tensor<8x1xf64>) outs(%110 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%1416 = linalg.generic {indexing_maps = [#map4, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1413, %1415 : tensor<8x2048xf64>, tensor<8x1xf64>) outs(%108 : tensor<8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.subf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<8x2048xf64>
%1417 = linalg.generic {indexing_maps = [#map4, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%1416, %1416 : tensor<8x2048xf64>, tensor<8x2048xf64>) outs(%108 : tensor<8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.mulf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<8x2048xf64>
%1418 = linalg.generic {indexing_maps = [#map4, #map7], iterator_types = ["parallel", "reduction"]} ins(%1417 : tensor<8x2048xf64>) outs(%111 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%1419 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1418 : tensor<8x1xf64>) outs(%110 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%1420 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1419 : tensor<8x1xf64>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f64, %out: f32):
%1818 = arith.truncf %in : f64 to f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%1421 = linalg.generic {indexing_maps = [#map4, #map7], iterator_types = ["parallel", "reduction"]} ins(%collapsed_633 : tensor<8x2048xf32>) outs(%120 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%1422 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1421 : tensor<8x1xf32>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.divf %in, %cst_347 : f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%1423 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1420 : tensor<8x1xf32>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.truncf %cst_344 : f64 to f32
%1819 = arith.addf %in, %1818 : f32
linalg.yield %1819 : f32
} -> tensor<8x1xf32>
%1424 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1423 : tensor<8x1xf32>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.rsqrt %in : f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%1425 = linalg.generic {indexing_maps = [#map4, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%collapsed_633, %1422 : tensor<8x2048xf32>, tensor<8x1xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%1426 = linalg.generic {indexing_maps = [#map4, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1425, %1424 : tensor<8x2048xf32>, tensor<8x1xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%1427 = linalg.generic {indexing_maps = [#map4, #map6, #map4], iterator_types = ["parallel", "parallel"]} ins(%1426, %cst_337 : tensor<8x2048xf32>, tensor<2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%1428 = linalg.generic {indexing_maps = [#map4, #map6, #map4], iterator_types = ["parallel", "parallel"]} ins(%1427, %cst_74 : tensor<8x2048xf32>, tensor<2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%1429 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_73 : tensor<8192x2048xf32>) outs(%129 : tensor<2048x8192xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x8192xf32>
%1430 = linalg.matmul ins(%1428, %1429 : tensor<8x2048xf32>, tensor<2048x8192xf32>) outs(%132 : tensor<8x8192xf32>) -> tensor<8x8192xf32>
%1431 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_72, %1430 : tensor<8192xf32>, tensor<8x8192xf32>) outs(%131 : tensor<8x8192xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x8192xf32>
%1432 = linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%1431 : tensor<8x8192xf32>) outs(%131 : tensor<8x8192xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.cmpf ugt, %in, %cst_340 : f32
%1819 = arith.select %1818, %in, %cst_340 : f32
linalg.yield %1819 : f32
} -> tensor<8x8192xf32>
%1433 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_71 : tensor<2048x8192xf32>) outs(%136 : tensor<8192x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<8192x2048xf32>
%1434 = linalg.matmul ins(%1432, %1433 : tensor<8x8192xf32>, tensor<8192x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%1435 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_70, %1434 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%1436 = linalg.generic {indexing_maps = [#map4, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%collapsed_633, %1435 : tensor<8x2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_634 = tensor.expand_shape %1436 [[0, 1], [2]] : tensor<8x2048xf32> into tensor<1x8x2048xf32>
%1437 = linalg.generic {indexing_maps = [#map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_634 : tensor<1x8x2048xf32>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f32, %out: f64):
%1818 = arith.extf %in : f32 to f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
%1438 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%1437 : tensor<1x8x2048xf64>) outs(%42 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%1439 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1438 : tensor<1x8x1xf64>) outs(%41 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%1440 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1437, %1439 : tensor<1x8x2048xf64>, tensor<1x8x1xf64>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.subf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
%1441 = linalg.generic {indexing_maps = [#map13, #map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1440, %1440 : tensor<1x8x2048xf64>, tensor<1x8x2048xf64>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.mulf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
%1442 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%1441 : tensor<1x8x2048xf64>) outs(%42 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%1443 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1442 : tensor<1x8x1xf64>) outs(%41 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%1444 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1443 : tensor<1x8x1xf64>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f64, %out: f32):
%1818 = arith.truncf %in : f64 to f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%1445 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%expanded_634 : tensor<1x8x2048xf32>) outs(%51 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%1446 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1445 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.divf %in, %cst_347 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%1447 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1444 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.truncf %cst_344 : f64 to f32
%1819 = arith.addf %in, %1818 : f32
linalg.yield %1819 : f32
} -> tensor<1x8x1xf32>
%1448 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1447 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.rsqrt %in : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%1449 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_634, %1446 : tensor<1x8x2048xf32>, tensor<1x8x1xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%1450 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1449, %1448 : tensor<1x8x2048xf32>, tensor<1x8x1xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%1451 = linalg.generic {indexing_maps = [#map13, #map16, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1450, %cst_337 : tensor<1x8x2048xf32>, tensor<2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%1452 = linalg.generic {indexing_maps = [#map13, #map16, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1451, %cst_69 : tensor<1x8x2048xf32>, tensor<2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%1453 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_68 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%collapsed_635 = tensor.collapse_shape %1452 [[0, 1], [2]] : tensor<1x8x2048xf32> into tensor<8x2048xf32>
%1454 = linalg.matmul ins(%collapsed_635, %1453 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%1455 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_67, %1454 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_636 = tensor.expand_shape %1455 [[0, 1], [2]] : tensor<8x2048xf32> into tensor<1x8x2048xf32>
%1456 = linalg.generic {indexing_maps = [#map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_636 : tensor<1x8x2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.mulf %in, %cst_348 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%1457 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_66 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%1458 = linalg.matmul ins(%collapsed_635, %1457 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%1459 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_65, %1458 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_637 = tensor.expand_shape %1459 [[0, 1], [2, 3]] : tensor<8x2048xf32> into tensor<1x8x32x64xf32>
%1460 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_637 : tensor<1x8x32x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%1461 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1460 : tensor<1x32x8x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%1462 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_64 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%1463 = linalg.matmul ins(%collapsed_635, %1462 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%1464 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_63, %1463 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_638 = tensor.expand_shape %1464 [[0, 1], [2, 3]] : tensor<8x2048xf32> into tensor<1x8x32x64xf32>
%1465 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_638 : tensor<1x8x32x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%1466 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1465 : tensor<1x32x8x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%expanded_639 = tensor.expand_shape %1456 [[0], [1], [2, 3]] : tensor<1x8x2048xf32> into tensor<1x8x32x64xf32>
%1467 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_639 : tensor<1x8x32x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%1468 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1467 : tensor<1x32x8x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%collapsed_640 = tensor.collapse_shape %1468 [[0, 1], [2], [3]] : tensor<1x32x8x64xf32> into tensor<32x8x64xf32>
%collapsed_641 = tensor.collapse_shape %1461 [[0, 1], [2], [3]] : tensor<1x32x8x64xf32> into tensor<32x8x64xf32>
%collapsed_642 = tensor.collapse_shape %1466 [[0, 1], [2], [3]] : tensor<1x32x8x64xf32> into tensor<32x8x64xf32>
%1469 = linalg.generic {indexing_maps = [#map1, #map20], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_641 : tensor<32x8x64xf32>) outs(%80 : tensor<32x64x8xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<32x64x8xf32>
%1470 = linalg.batch_matmul ins(%collapsed_640, %1469 : tensor<32x8x64xf32>, tensor<32x64x8xf32>) outs(%83 : tensor<32x8x8xf32>) -> tensor<32x8x8xf32>
%expanded_643 = tensor.expand_shape %1470 [[0, 1], [2], [3]] : tensor<32x8x8xf32> into tensor<1x32x8x8xf32>
%1471 = linalg.generic {indexing_maps = [#map19, #map8, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_643, %28 : tensor<1x32x8x8xf32>, tensor<1x1x8x8xf32>) outs(%85 : tensor<1x32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x32x8x8xf32>
%1472 = linalg.generic {indexing_maps = [#map19, #map11, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1471, %cst_329 : tensor<1x32x8x8xf32>, tensor<f32>) outs(%85 : tensor<1x32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.cmpf ugt, %in, %in_712 : f32
%1819 = arith.select %1818, %in, %in_712 : f32
linalg.yield %1819 : f32
} -> tensor<1x32x8x8xf32>
%collapsed_644 = tensor.collapse_shape %1472 [[0, 1], [2], [3]] : tensor<1x32x8x8xf32> into tensor<32x8x8xf32>
%1473:2 = linalg.generic {indexing_maps = [#map1, #map14, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%collapsed_644 : tensor<32x8x8xf32>) outs(%91, %89 : tensor<32x8x1xf32>, tensor<32x8x1xi64>) {
^bb0(%in: f32, %out: f32, %out_712: i64):
%1818 = linalg.index 2 : index
%1819 = arith.index_cast %1818 : index to i64
%1820 = arith.maximumf %in, %out : f32
%1821 = arith.cmpf ogt, %in, %out : f32
%1822 = arith.select %1821, %1819, %out_712 : i64
linalg.yield %1820, %1822 : f32, i64
} -> (tensor<32x8x1xf32>, tensor<32x8x1xi64>)
%1474 = linalg.generic {indexing_maps = [#map1, #map14, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_644, %1473#0 : tensor<32x8x8xf32>, tensor<32x8x1xf32>) outs(%82 : tensor<32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<32x8x8xf32>
%1475 = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1474 : tensor<32x8x8xf32>) outs(%82 : tensor<32x8x8xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.exp %in : f32
linalg.yield %1818 : f32
} -> tensor<32x8x8xf32>
%1476 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%1475 : tensor<32x8x8xf32>) outs(%95 : tensor<32x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<32x8x1xf32>
%1477 = linalg.generic {indexing_maps = [#map1, #map14, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1475, %1476 : tensor<32x8x8xf32>, tensor<32x8x1xf32>) outs(%82 : tensor<32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.divf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<32x8x8xf32>
%1478 = linalg.batch_matmul ins(%1477, %collapsed_642 : tensor<32x8x8xf32>, tensor<32x8x64xf32>) outs(%99 : tensor<32x8x64xf32>) -> tensor<32x8x64xf32>
%expanded_645 = tensor.expand_shape %1478 [[0, 1], [2], [3]] : tensor<32x8x64xf32> into tensor<1x32x8x64xf32>
%1479 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_645 : tensor<1x32x8x64xf32>) outs(%101 : tensor<1x8x32x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x32x64xf32>
%1480 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1479 : tensor<1x8x32x64xf32>) outs(%101 : tensor<1x8x32x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x32x64xf32>
%1481 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_62 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%collapsed_646 = tensor.collapse_shape %1480 [[0, 1], [2, 3]] : tensor<1x8x32x64xf32> into tensor<8x2048xf32>
%1482 = linalg.matmul ins(%collapsed_646, %1481 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%1483 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_61, %1482 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_647 = tensor.expand_shape %1483 [[0, 1], [2]] : tensor<8x2048xf32> into tensor<1x8x2048xf32>
%1484 = linalg.generic {indexing_maps = [#map13, #map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_634, %expanded_647 : tensor<1x8x2048xf32>, tensor<1x8x2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%collapsed_648 = tensor.collapse_shape %1484 [[0, 1], [2]] : tensor<1x8x2048xf32> into tensor<8x2048xf32>
%1485 = linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%collapsed_648 : tensor<8x2048xf32>) outs(%108 : tensor<8x2048xf64>) {
^bb0(%in: f32, %out: f64):
%1818 = arith.extf %in : f32 to f64
linalg.yield %1818 : f64
} -> tensor<8x2048xf64>
%1486 = linalg.generic {indexing_maps = [#map4, #map7], iterator_types = ["parallel", "reduction"]} ins(%1485 : tensor<8x2048xf64>) outs(%111 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%1487 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1486 : tensor<8x1xf64>) outs(%110 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%1488 = linalg.generic {indexing_maps = [#map4, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1485, %1487 : tensor<8x2048xf64>, tensor<8x1xf64>) outs(%108 : tensor<8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.subf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<8x2048xf64>
%1489 = linalg.generic {indexing_maps = [#map4, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%1488, %1488 : tensor<8x2048xf64>, tensor<8x2048xf64>) outs(%108 : tensor<8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.mulf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<8x2048xf64>
%1490 = linalg.generic {indexing_maps = [#map4, #map7], iterator_types = ["parallel", "reduction"]} ins(%1489 : tensor<8x2048xf64>) outs(%111 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%1491 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1490 : tensor<8x1xf64>) outs(%110 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%1492 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1491 : tensor<8x1xf64>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f64, %out: f32):
%1818 = arith.truncf %in : f64 to f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%1493 = linalg.generic {indexing_maps = [#map4, #map7], iterator_types = ["parallel", "reduction"]} ins(%collapsed_648 : tensor<8x2048xf32>) outs(%120 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%1494 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1493 : tensor<8x1xf32>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.divf %in, %cst_347 : f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%1495 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1492 : tensor<8x1xf32>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.truncf %cst_344 : f64 to f32
%1819 = arith.addf %in, %1818 : f32
linalg.yield %1819 : f32
} -> tensor<8x1xf32>
%1496 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1495 : tensor<8x1xf32>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.rsqrt %in : f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%1497 = linalg.generic {indexing_maps = [#map4, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%collapsed_648, %1494 : tensor<8x2048xf32>, tensor<8x1xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%1498 = linalg.generic {indexing_maps = [#map4, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1497, %1496 : tensor<8x2048xf32>, tensor<8x1xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%1499 = linalg.generic {indexing_maps = [#map4, #map6, #map4], iterator_types = ["parallel", "parallel"]} ins(%1498, %cst_337 : tensor<8x2048xf32>, tensor<2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%1500 = linalg.generic {indexing_maps = [#map4, #map6, #map4], iterator_types = ["parallel", "parallel"]} ins(%1499, %cst_60 : tensor<8x2048xf32>, tensor<2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%1501 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_59 : tensor<8192x2048xf32>) outs(%129 : tensor<2048x8192xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x8192xf32>
%1502 = linalg.matmul ins(%1500, %1501 : tensor<8x2048xf32>, tensor<2048x8192xf32>) outs(%132 : tensor<8x8192xf32>) -> tensor<8x8192xf32>
%1503 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_58, %1502 : tensor<8192xf32>, tensor<8x8192xf32>) outs(%131 : tensor<8x8192xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x8192xf32>
%1504 = linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%1503 : tensor<8x8192xf32>) outs(%131 : tensor<8x8192xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.cmpf ugt, %in, %cst_340 : f32
%1819 = arith.select %1818, %in, %cst_340 : f32
linalg.yield %1819 : f32
} -> tensor<8x8192xf32>
%1505 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_57 : tensor<2048x8192xf32>) outs(%136 : tensor<8192x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<8192x2048xf32>
%1506 = linalg.matmul ins(%1504, %1505 : tensor<8x8192xf32>, tensor<8192x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%1507 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_56, %1506 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%1508 = linalg.generic {indexing_maps = [#map4, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%collapsed_648, %1507 : tensor<8x2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_649 = tensor.expand_shape %1508 [[0, 1], [2]] : tensor<8x2048xf32> into tensor<1x8x2048xf32>
%1509 = linalg.generic {indexing_maps = [#map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_649 : tensor<1x8x2048xf32>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f32, %out: f64):
%1818 = arith.extf %in : f32 to f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
%1510 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%1509 : tensor<1x8x2048xf64>) outs(%42 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%1511 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1510 : tensor<1x8x1xf64>) outs(%41 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%1512 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1509, %1511 : tensor<1x8x2048xf64>, tensor<1x8x1xf64>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.subf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
%1513 = linalg.generic {indexing_maps = [#map13, #map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1512, %1512 : tensor<1x8x2048xf64>, tensor<1x8x2048xf64>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.mulf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
%1514 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%1513 : tensor<1x8x2048xf64>) outs(%42 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%1515 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1514 : tensor<1x8x1xf64>) outs(%41 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%1516 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1515 : tensor<1x8x1xf64>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f64, %out: f32):
%1818 = arith.truncf %in : f64 to f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%1517 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%expanded_649 : tensor<1x8x2048xf32>) outs(%51 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%1518 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1517 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.divf %in, %cst_347 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%1519 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1516 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.truncf %cst_344 : f64 to f32
%1819 = arith.addf %in, %1818 : f32
linalg.yield %1819 : f32
} -> tensor<1x8x1xf32>
%1520 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1519 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.rsqrt %in : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%1521 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_649, %1518 : tensor<1x8x2048xf32>, tensor<1x8x1xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%1522 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1521, %1520 : tensor<1x8x2048xf32>, tensor<1x8x1xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%1523 = linalg.generic {indexing_maps = [#map13, #map16, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1522, %cst_337 : tensor<1x8x2048xf32>, tensor<2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%1524 = linalg.generic {indexing_maps = [#map13, #map16, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1523, %cst_55 : tensor<1x8x2048xf32>, tensor<2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%1525 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_54 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%collapsed_650 = tensor.collapse_shape %1524 [[0, 1], [2]] : tensor<1x8x2048xf32> into tensor<8x2048xf32>
%1526 = linalg.matmul ins(%collapsed_650, %1525 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%1527 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_53, %1526 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_651 = tensor.expand_shape %1527 [[0, 1], [2]] : tensor<8x2048xf32> into tensor<1x8x2048xf32>
%1528 = linalg.generic {indexing_maps = [#map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_651 : tensor<1x8x2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.mulf %in, %cst_348 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%1529 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_52 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%1530 = linalg.matmul ins(%collapsed_650, %1529 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%1531 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_51, %1530 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_652 = tensor.expand_shape %1531 [[0, 1], [2, 3]] : tensor<8x2048xf32> into tensor<1x8x32x64xf32>
%1532 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_652 : tensor<1x8x32x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%1533 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1532 : tensor<1x32x8x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%1534 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_50 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%1535 = linalg.matmul ins(%collapsed_650, %1534 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%1536 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_49, %1535 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_653 = tensor.expand_shape %1536 [[0, 1], [2, 3]] : tensor<8x2048xf32> into tensor<1x8x32x64xf32>
%1537 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_653 : tensor<1x8x32x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%1538 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1537 : tensor<1x32x8x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%expanded_654 = tensor.expand_shape %1528 [[0], [1], [2, 3]] : tensor<1x8x2048xf32> into tensor<1x8x32x64xf32>
%1539 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_654 : tensor<1x8x32x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%1540 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1539 : tensor<1x32x8x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%collapsed_655 = tensor.collapse_shape %1540 [[0, 1], [2], [3]] : tensor<1x32x8x64xf32> into tensor<32x8x64xf32>
%collapsed_656 = tensor.collapse_shape %1533 [[0, 1], [2], [3]] : tensor<1x32x8x64xf32> into tensor<32x8x64xf32>
%collapsed_657 = tensor.collapse_shape %1538 [[0, 1], [2], [3]] : tensor<1x32x8x64xf32> into tensor<32x8x64xf32>
%1541 = linalg.generic {indexing_maps = [#map1, #map20], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_656 : tensor<32x8x64xf32>) outs(%80 : tensor<32x64x8xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<32x64x8xf32>
%1542 = linalg.batch_matmul ins(%collapsed_655, %1541 : tensor<32x8x64xf32>, tensor<32x64x8xf32>) outs(%83 : tensor<32x8x8xf32>) -> tensor<32x8x8xf32>
%expanded_658 = tensor.expand_shape %1542 [[0, 1], [2], [3]] : tensor<32x8x8xf32> into tensor<1x32x8x8xf32>
%1543 = linalg.generic {indexing_maps = [#map19, #map8, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_658, %28 : tensor<1x32x8x8xf32>, tensor<1x1x8x8xf32>) outs(%85 : tensor<1x32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x32x8x8xf32>
%1544 = linalg.generic {indexing_maps = [#map19, #map11, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1543, %cst_329 : tensor<1x32x8x8xf32>, tensor<f32>) outs(%85 : tensor<1x32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.cmpf ugt, %in, %in_712 : f32
%1819 = arith.select %1818, %in, %in_712 : f32
linalg.yield %1819 : f32
} -> tensor<1x32x8x8xf32>
%collapsed_659 = tensor.collapse_shape %1544 [[0, 1], [2], [3]] : tensor<1x32x8x8xf32> into tensor<32x8x8xf32>
%1545:2 = linalg.generic {indexing_maps = [#map1, #map14, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%collapsed_659 : tensor<32x8x8xf32>) outs(%91, %89 : tensor<32x8x1xf32>, tensor<32x8x1xi64>) {
^bb0(%in: f32, %out: f32, %out_712: i64):
%1818 = linalg.index 2 : index
%1819 = arith.index_cast %1818 : index to i64
%1820 = arith.maximumf %in, %out : f32
%1821 = arith.cmpf ogt, %in, %out : f32
%1822 = arith.select %1821, %1819, %out_712 : i64
linalg.yield %1820, %1822 : f32, i64
} -> (tensor<32x8x1xf32>, tensor<32x8x1xi64>)
%1546 = linalg.generic {indexing_maps = [#map1, #map14, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_659, %1545#0 : tensor<32x8x8xf32>, tensor<32x8x1xf32>) outs(%82 : tensor<32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<32x8x8xf32>
%1547 = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1546 : tensor<32x8x8xf32>) outs(%82 : tensor<32x8x8xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.exp %in : f32
linalg.yield %1818 : f32
} -> tensor<32x8x8xf32>
%1548 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%1547 : tensor<32x8x8xf32>) outs(%95 : tensor<32x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<32x8x1xf32>
%1549 = linalg.generic {indexing_maps = [#map1, #map14, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1547, %1548 : tensor<32x8x8xf32>, tensor<32x8x1xf32>) outs(%82 : tensor<32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.divf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<32x8x8xf32>
%1550 = linalg.batch_matmul ins(%1549, %collapsed_657 : tensor<32x8x8xf32>, tensor<32x8x64xf32>) outs(%99 : tensor<32x8x64xf32>) -> tensor<32x8x64xf32>
%expanded_660 = tensor.expand_shape %1550 [[0, 1], [2], [3]] : tensor<32x8x64xf32> into tensor<1x32x8x64xf32>
%1551 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_660 : tensor<1x32x8x64xf32>) outs(%101 : tensor<1x8x32x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x32x64xf32>
%1552 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1551 : tensor<1x8x32x64xf32>) outs(%101 : tensor<1x8x32x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x32x64xf32>
%1553 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_48 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%collapsed_661 = tensor.collapse_shape %1552 [[0, 1], [2, 3]] : tensor<1x8x32x64xf32> into tensor<8x2048xf32>
%1554 = linalg.matmul ins(%collapsed_661, %1553 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%1555 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_47, %1554 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_662 = tensor.expand_shape %1555 [[0, 1], [2]] : tensor<8x2048xf32> into tensor<1x8x2048xf32>
%1556 = linalg.generic {indexing_maps = [#map13, #map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_649, %expanded_662 : tensor<1x8x2048xf32>, tensor<1x8x2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%collapsed_663 = tensor.collapse_shape %1556 [[0, 1], [2]] : tensor<1x8x2048xf32> into tensor<8x2048xf32>
%1557 = linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%collapsed_663 : tensor<8x2048xf32>) outs(%108 : tensor<8x2048xf64>) {
^bb0(%in: f32, %out: f64):
%1818 = arith.extf %in : f32 to f64
linalg.yield %1818 : f64
} -> tensor<8x2048xf64>
%1558 = linalg.generic {indexing_maps = [#map4, #map7], iterator_types = ["parallel", "reduction"]} ins(%1557 : tensor<8x2048xf64>) outs(%111 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%1559 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1558 : tensor<8x1xf64>) outs(%110 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%1560 = linalg.generic {indexing_maps = [#map4, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1557, %1559 : tensor<8x2048xf64>, tensor<8x1xf64>) outs(%108 : tensor<8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.subf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<8x2048xf64>
%1561 = linalg.generic {indexing_maps = [#map4, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%1560, %1560 : tensor<8x2048xf64>, tensor<8x2048xf64>) outs(%108 : tensor<8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.mulf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<8x2048xf64>
%1562 = linalg.generic {indexing_maps = [#map4, #map7], iterator_types = ["parallel", "reduction"]} ins(%1561 : tensor<8x2048xf64>) outs(%111 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%1563 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1562 : tensor<8x1xf64>) outs(%110 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%1564 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1563 : tensor<8x1xf64>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f64, %out: f32):
%1818 = arith.truncf %in : f64 to f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%1565 = linalg.generic {indexing_maps = [#map4, #map7], iterator_types = ["parallel", "reduction"]} ins(%collapsed_663 : tensor<8x2048xf32>) outs(%120 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%1566 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1565 : tensor<8x1xf32>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.divf %in, %cst_347 : f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%1567 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1564 : tensor<8x1xf32>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.truncf %cst_344 : f64 to f32
%1819 = arith.addf %in, %1818 : f32
linalg.yield %1819 : f32
} -> tensor<8x1xf32>
%1568 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1567 : tensor<8x1xf32>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.rsqrt %in : f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%1569 = linalg.generic {indexing_maps = [#map4, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%collapsed_663, %1566 : tensor<8x2048xf32>, tensor<8x1xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%1570 = linalg.generic {indexing_maps = [#map4, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1569, %1568 : tensor<8x2048xf32>, tensor<8x1xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%1571 = linalg.generic {indexing_maps = [#map4, #map6, #map4], iterator_types = ["parallel", "parallel"]} ins(%1570, %cst_337 : tensor<8x2048xf32>, tensor<2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%1572 = linalg.generic {indexing_maps = [#map4, #map6, #map4], iterator_types = ["parallel", "parallel"]} ins(%1571, %cst_46 : tensor<8x2048xf32>, tensor<2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%1573 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_45 : tensor<8192x2048xf32>) outs(%129 : tensor<2048x8192xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x8192xf32>
%1574 = linalg.matmul ins(%1572, %1573 : tensor<8x2048xf32>, tensor<2048x8192xf32>) outs(%132 : tensor<8x8192xf32>) -> tensor<8x8192xf32>
%1575 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_44, %1574 : tensor<8192xf32>, tensor<8x8192xf32>) outs(%131 : tensor<8x8192xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x8192xf32>
%1576 = linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%1575 : tensor<8x8192xf32>) outs(%131 : tensor<8x8192xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.cmpf ugt, %in, %cst_340 : f32
%1819 = arith.select %1818, %in, %cst_340 : f32
linalg.yield %1819 : f32
} -> tensor<8x8192xf32>
%1577 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_43 : tensor<2048x8192xf32>) outs(%136 : tensor<8192x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<8192x2048xf32>
%1578 = linalg.matmul ins(%1576, %1577 : tensor<8x8192xf32>, tensor<8192x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%1579 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_42, %1578 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%1580 = linalg.generic {indexing_maps = [#map4, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%collapsed_663, %1579 : tensor<8x2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_664 = tensor.expand_shape %1580 [[0, 1], [2]] : tensor<8x2048xf32> into tensor<1x8x2048xf32>
%1581 = linalg.generic {indexing_maps = [#map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_664 : tensor<1x8x2048xf32>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f32, %out: f64):
%1818 = arith.extf %in : f32 to f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
%1582 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%1581 : tensor<1x8x2048xf64>) outs(%42 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%1583 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1582 : tensor<1x8x1xf64>) outs(%41 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%1584 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1581, %1583 : tensor<1x8x2048xf64>, tensor<1x8x1xf64>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.subf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
%1585 = linalg.generic {indexing_maps = [#map13, #map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1584, %1584 : tensor<1x8x2048xf64>, tensor<1x8x2048xf64>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.mulf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
%1586 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%1585 : tensor<1x8x2048xf64>) outs(%42 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%1587 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1586 : tensor<1x8x1xf64>) outs(%41 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%1588 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1587 : tensor<1x8x1xf64>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f64, %out: f32):
%1818 = arith.truncf %in : f64 to f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%1589 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%expanded_664 : tensor<1x8x2048xf32>) outs(%51 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%1590 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1589 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.divf %in, %cst_347 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%1591 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1588 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.truncf %cst_344 : f64 to f32
%1819 = arith.addf %in, %1818 : f32
linalg.yield %1819 : f32
} -> tensor<1x8x1xf32>
%1592 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1591 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.rsqrt %in : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%1593 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_664, %1590 : tensor<1x8x2048xf32>, tensor<1x8x1xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%1594 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1593, %1592 : tensor<1x8x2048xf32>, tensor<1x8x1xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%1595 = linalg.generic {indexing_maps = [#map13, #map16, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1594, %cst_337 : tensor<1x8x2048xf32>, tensor<2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%1596 = linalg.generic {indexing_maps = [#map13, #map16, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1595, %cst_41 : tensor<1x8x2048xf32>, tensor<2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%1597 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_40 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%collapsed_665 = tensor.collapse_shape %1596 [[0, 1], [2]] : tensor<1x8x2048xf32> into tensor<8x2048xf32>
%1598 = linalg.matmul ins(%collapsed_665, %1597 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%1599 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_39, %1598 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_666 = tensor.expand_shape %1599 [[0, 1], [2]] : tensor<8x2048xf32> into tensor<1x8x2048xf32>
%1600 = linalg.generic {indexing_maps = [#map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_666 : tensor<1x8x2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.mulf %in, %cst_348 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%1601 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_38 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%1602 = linalg.matmul ins(%collapsed_665, %1601 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%1603 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_37, %1602 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_667 = tensor.expand_shape %1603 [[0, 1], [2, 3]] : tensor<8x2048xf32> into tensor<1x8x32x64xf32>
%1604 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_667 : tensor<1x8x32x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%1605 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1604 : tensor<1x32x8x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%1606 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_36 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%1607 = linalg.matmul ins(%collapsed_665, %1606 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%1608 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_35, %1607 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_668 = tensor.expand_shape %1608 [[0, 1], [2, 3]] : tensor<8x2048xf32> into tensor<1x8x32x64xf32>
%1609 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_668 : tensor<1x8x32x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%1610 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1609 : tensor<1x32x8x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%expanded_669 = tensor.expand_shape %1600 [[0], [1], [2, 3]] : tensor<1x8x2048xf32> into tensor<1x8x32x64xf32>
%1611 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_669 : tensor<1x8x32x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%1612 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1611 : tensor<1x32x8x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%collapsed_670 = tensor.collapse_shape %1612 [[0, 1], [2], [3]] : tensor<1x32x8x64xf32> into tensor<32x8x64xf32>
%collapsed_671 = tensor.collapse_shape %1605 [[0, 1], [2], [3]] : tensor<1x32x8x64xf32> into tensor<32x8x64xf32>
%collapsed_672 = tensor.collapse_shape %1610 [[0, 1], [2], [3]] : tensor<1x32x8x64xf32> into tensor<32x8x64xf32>
%1613 = linalg.generic {indexing_maps = [#map1, #map20], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_671 : tensor<32x8x64xf32>) outs(%80 : tensor<32x64x8xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<32x64x8xf32>
%1614 = linalg.batch_matmul ins(%collapsed_670, %1613 : tensor<32x8x64xf32>, tensor<32x64x8xf32>) outs(%83 : tensor<32x8x8xf32>) -> tensor<32x8x8xf32>
%expanded_673 = tensor.expand_shape %1614 [[0, 1], [2], [3]] : tensor<32x8x8xf32> into tensor<1x32x8x8xf32>
%1615 = linalg.generic {indexing_maps = [#map19, #map8, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_673, %28 : tensor<1x32x8x8xf32>, tensor<1x1x8x8xf32>) outs(%85 : tensor<1x32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x32x8x8xf32>
%1616 = linalg.generic {indexing_maps = [#map19, #map11, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1615, %cst_329 : tensor<1x32x8x8xf32>, tensor<f32>) outs(%85 : tensor<1x32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.cmpf ugt, %in, %in_712 : f32
%1819 = arith.select %1818, %in, %in_712 : f32
linalg.yield %1819 : f32
} -> tensor<1x32x8x8xf32>
%collapsed_674 = tensor.collapse_shape %1616 [[0, 1], [2], [3]] : tensor<1x32x8x8xf32> into tensor<32x8x8xf32>
%1617:2 = linalg.generic {indexing_maps = [#map1, #map14, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%collapsed_674 : tensor<32x8x8xf32>) outs(%91, %89 : tensor<32x8x1xf32>, tensor<32x8x1xi64>) {
^bb0(%in: f32, %out: f32, %out_712: i64):
%1818 = linalg.index 2 : index
%1819 = arith.index_cast %1818 : index to i64
%1820 = arith.maximumf %in, %out : f32
%1821 = arith.cmpf ogt, %in, %out : f32
%1822 = arith.select %1821, %1819, %out_712 : i64
linalg.yield %1820, %1822 : f32, i64
} -> (tensor<32x8x1xf32>, tensor<32x8x1xi64>)
%1618 = linalg.generic {indexing_maps = [#map1, #map14, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_674, %1617#0 : tensor<32x8x8xf32>, tensor<32x8x1xf32>) outs(%82 : tensor<32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<32x8x8xf32>
%1619 = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1618 : tensor<32x8x8xf32>) outs(%82 : tensor<32x8x8xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.exp %in : f32
linalg.yield %1818 : f32
} -> tensor<32x8x8xf32>
%1620 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%1619 : tensor<32x8x8xf32>) outs(%95 : tensor<32x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<32x8x1xf32>
%1621 = linalg.generic {indexing_maps = [#map1, #map14, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1619, %1620 : tensor<32x8x8xf32>, tensor<32x8x1xf32>) outs(%82 : tensor<32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.divf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<32x8x8xf32>
%1622 = linalg.batch_matmul ins(%1621, %collapsed_672 : tensor<32x8x8xf32>, tensor<32x8x64xf32>) outs(%99 : tensor<32x8x64xf32>) -> tensor<32x8x64xf32>
%expanded_675 = tensor.expand_shape %1622 [[0, 1], [2], [3]] : tensor<32x8x64xf32> into tensor<1x32x8x64xf32>
%1623 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_675 : tensor<1x32x8x64xf32>) outs(%101 : tensor<1x8x32x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x32x64xf32>
%1624 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1623 : tensor<1x8x32x64xf32>) outs(%101 : tensor<1x8x32x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x32x64xf32>
%1625 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_34 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%collapsed_676 = tensor.collapse_shape %1624 [[0, 1], [2, 3]] : tensor<1x8x32x64xf32> into tensor<8x2048xf32>
%1626 = linalg.matmul ins(%collapsed_676, %1625 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%1627 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_33, %1626 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_677 = tensor.expand_shape %1627 [[0, 1], [2]] : tensor<8x2048xf32> into tensor<1x8x2048xf32>
%1628 = linalg.generic {indexing_maps = [#map13, #map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_664, %expanded_677 : tensor<1x8x2048xf32>, tensor<1x8x2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%collapsed_678 = tensor.collapse_shape %1628 [[0, 1], [2]] : tensor<1x8x2048xf32> into tensor<8x2048xf32>
%1629 = linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%collapsed_678 : tensor<8x2048xf32>) outs(%108 : tensor<8x2048xf64>) {
^bb0(%in: f32, %out: f64):
%1818 = arith.extf %in : f32 to f64
linalg.yield %1818 : f64
} -> tensor<8x2048xf64>
%1630 = linalg.generic {indexing_maps = [#map4, #map7], iterator_types = ["parallel", "reduction"]} ins(%1629 : tensor<8x2048xf64>) outs(%111 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%1631 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1630 : tensor<8x1xf64>) outs(%110 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%1632 = linalg.generic {indexing_maps = [#map4, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1629, %1631 : tensor<8x2048xf64>, tensor<8x1xf64>) outs(%108 : tensor<8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.subf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<8x2048xf64>
%1633 = linalg.generic {indexing_maps = [#map4, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%1632, %1632 : tensor<8x2048xf64>, tensor<8x2048xf64>) outs(%108 : tensor<8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.mulf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<8x2048xf64>
%1634 = linalg.generic {indexing_maps = [#map4, #map7], iterator_types = ["parallel", "reduction"]} ins(%1633 : tensor<8x2048xf64>) outs(%111 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%1635 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1634 : tensor<8x1xf64>) outs(%110 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%1636 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1635 : tensor<8x1xf64>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f64, %out: f32):
%1818 = arith.truncf %in : f64 to f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%1637 = linalg.generic {indexing_maps = [#map4, #map7], iterator_types = ["parallel", "reduction"]} ins(%collapsed_678 : tensor<8x2048xf32>) outs(%120 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%1638 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1637 : tensor<8x1xf32>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.divf %in, %cst_347 : f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%1639 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1636 : tensor<8x1xf32>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.truncf %cst_344 : f64 to f32
%1819 = arith.addf %in, %1818 : f32
linalg.yield %1819 : f32
} -> tensor<8x1xf32>
%1640 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1639 : tensor<8x1xf32>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.rsqrt %in : f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%1641 = linalg.generic {indexing_maps = [#map4, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%collapsed_678, %1638 : tensor<8x2048xf32>, tensor<8x1xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%1642 = linalg.generic {indexing_maps = [#map4, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1641, %1640 : tensor<8x2048xf32>, tensor<8x1xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%1643 = linalg.generic {indexing_maps = [#map4, #map6, #map4], iterator_types = ["parallel", "parallel"]} ins(%1642, %cst_337 : tensor<8x2048xf32>, tensor<2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%1644 = linalg.generic {indexing_maps = [#map4, #map6, #map4], iterator_types = ["parallel", "parallel"]} ins(%1643, %cst_32 : tensor<8x2048xf32>, tensor<2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%1645 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_31 : tensor<8192x2048xf32>) outs(%129 : tensor<2048x8192xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x8192xf32>
%1646 = linalg.matmul ins(%1644, %1645 : tensor<8x2048xf32>, tensor<2048x8192xf32>) outs(%132 : tensor<8x8192xf32>) -> tensor<8x8192xf32>
%1647 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_30, %1646 : tensor<8192xf32>, tensor<8x8192xf32>) outs(%131 : tensor<8x8192xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x8192xf32>
%1648 = linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%1647 : tensor<8x8192xf32>) outs(%131 : tensor<8x8192xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.cmpf ugt, %in, %cst_340 : f32
%1819 = arith.select %1818, %in, %cst_340 : f32
linalg.yield %1819 : f32
} -> tensor<8x8192xf32>
%1649 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_29 : tensor<2048x8192xf32>) outs(%136 : tensor<8192x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<8192x2048xf32>
%1650 = linalg.matmul ins(%1648, %1649 : tensor<8x8192xf32>, tensor<8192x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%1651 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_28, %1650 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%1652 = linalg.generic {indexing_maps = [#map4, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%collapsed_678, %1651 : tensor<8x2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_679 = tensor.expand_shape %1652 [[0, 1], [2]] : tensor<8x2048xf32> into tensor<1x8x2048xf32>
%1653 = linalg.generic {indexing_maps = [#map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_679 : tensor<1x8x2048xf32>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f32, %out: f64):
%1818 = arith.extf %in : f32 to f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
%1654 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%1653 : tensor<1x8x2048xf64>) outs(%42 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%1655 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1654 : tensor<1x8x1xf64>) outs(%41 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%1656 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1653, %1655 : tensor<1x8x2048xf64>, tensor<1x8x1xf64>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.subf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
%1657 = linalg.generic {indexing_maps = [#map13, #map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1656, %1656 : tensor<1x8x2048xf64>, tensor<1x8x2048xf64>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.mulf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
%1658 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%1657 : tensor<1x8x2048xf64>) outs(%42 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%1659 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1658 : tensor<1x8x1xf64>) outs(%41 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%1660 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1659 : tensor<1x8x1xf64>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f64, %out: f32):
%1818 = arith.truncf %in : f64 to f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%1661 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%expanded_679 : tensor<1x8x2048xf32>) outs(%51 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%1662 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1661 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.divf %in, %cst_347 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%1663 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1660 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.truncf %cst_344 : f64 to f32
%1819 = arith.addf %in, %1818 : f32
linalg.yield %1819 : f32
} -> tensor<1x8x1xf32>
%1664 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1663 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.rsqrt %in : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%1665 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_679, %1662 : tensor<1x8x2048xf32>, tensor<1x8x1xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%1666 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1665, %1664 : tensor<1x8x2048xf32>, tensor<1x8x1xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%1667 = linalg.generic {indexing_maps = [#map13, #map16, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1666, %cst_337 : tensor<1x8x2048xf32>, tensor<2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%1668 = linalg.generic {indexing_maps = [#map13, #map16, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1667, %cst_27 : tensor<1x8x2048xf32>, tensor<2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%1669 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_26 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%collapsed_680 = tensor.collapse_shape %1668 [[0, 1], [2]] : tensor<1x8x2048xf32> into tensor<8x2048xf32>
%1670 = linalg.matmul ins(%collapsed_680, %1669 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%1671 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_25, %1670 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_681 = tensor.expand_shape %1671 [[0, 1], [2]] : tensor<8x2048xf32> into tensor<1x8x2048xf32>
%1672 = linalg.generic {indexing_maps = [#map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_681 : tensor<1x8x2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.mulf %in, %cst_348 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%1673 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_24 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%1674 = linalg.matmul ins(%collapsed_680, %1673 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%1675 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_23, %1674 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_682 = tensor.expand_shape %1675 [[0, 1], [2, 3]] : tensor<8x2048xf32> into tensor<1x8x32x64xf32>
%1676 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_682 : tensor<1x8x32x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%1677 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1676 : tensor<1x32x8x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%1678 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_22 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%1679 = linalg.matmul ins(%collapsed_680, %1678 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%1680 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_21, %1679 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_683 = tensor.expand_shape %1680 [[0, 1], [2, 3]] : tensor<8x2048xf32> into tensor<1x8x32x64xf32>
%1681 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_683 : tensor<1x8x32x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%1682 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1681 : tensor<1x32x8x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%expanded_684 = tensor.expand_shape %1672 [[0], [1], [2, 3]] : tensor<1x8x2048xf32> into tensor<1x8x32x64xf32>
%1683 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_684 : tensor<1x8x32x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%1684 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1683 : tensor<1x32x8x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%collapsed_685 = tensor.collapse_shape %1684 [[0, 1], [2], [3]] : tensor<1x32x8x64xf32> into tensor<32x8x64xf32>
%collapsed_686 = tensor.collapse_shape %1677 [[0, 1], [2], [3]] : tensor<1x32x8x64xf32> into tensor<32x8x64xf32>
%collapsed_687 = tensor.collapse_shape %1682 [[0, 1], [2], [3]] : tensor<1x32x8x64xf32> into tensor<32x8x64xf32>
%1685 = linalg.generic {indexing_maps = [#map1, #map20], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_686 : tensor<32x8x64xf32>) outs(%80 : tensor<32x64x8xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<32x64x8xf32>
%1686 = linalg.batch_matmul ins(%collapsed_685, %1685 : tensor<32x8x64xf32>, tensor<32x64x8xf32>) outs(%83 : tensor<32x8x8xf32>) -> tensor<32x8x8xf32>
%expanded_688 = tensor.expand_shape %1686 [[0, 1], [2], [3]] : tensor<32x8x8xf32> into tensor<1x32x8x8xf32>
%1687 = linalg.generic {indexing_maps = [#map19, #map8, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_688, %28 : tensor<1x32x8x8xf32>, tensor<1x1x8x8xf32>) outs(%85 : tensor<1x32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x32x8x8xf32>
%1688 = linalg.generic {indexing_maps = [#map19, #map11, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1687, %cst_329 : tensor<1x32x8x8xf32>, tensor<f32>) outs(%85 : tensor<1x32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.cmpf ugt, %in, %in_712 : f32
%1819 = arith.select %1818, %in, %in_712 : f32
linalg.yield %1819 : f32
} -> tensor<1x32x8x8xf32>
%collapsed_689 = tensor.collapse_shape %1688 [[0, 1], [2], [3]] : tensor<1x32x8x8xf32> into tensor<32x8x8xf32>
%1689:2 = linalg.generic {indexing_maps = [#map1, #map14, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%collapsed_689 : tensor<32x8x8xf32>) outs(%91, %89 : tensor<32x8x1xf32>, tensor<32x8x1xi64>) {
^bb0(%in: f32, %out: f32, %out_712: i64):
%1818 = linalg.index 2 : index
%1819 = arith.index_cast %1818 : index to i64
%1820 = arith.maximumf %in, %out : f32
%1821 = arith.cmpf ogt, %in, %out : f32
%1822 = arith.select %1821, %1819, %out_712 : i64
linalg.yield %1820, %1822 : f32, i64
} -> (tensor<32x8x1xf32>, tensor<32x8x1xi64>)
%1690 = linalg.generic {indexing_maps = [#map1, #map14, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_689, %1689#0 : tensor<32x8x8xf32>, tensor<32x8x1xf32>) outs(%82 : tensor<32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<32x8x8xf32>
%1691 = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1690 : tensor<32x8x8xf32>) outs(%82 : tensor<32x8x8xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.exp %in : f32
linalg.yield %1818 : f32
} -> tensor<32x8x8xf32>
%1692 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%1691 : tensor<32x8x8xf32>) outs(%95 : tensor<32x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<32x8x1xf32>
%1693 = linalg.generic {indexing_maps = [#map1, #map14, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1691, %1692 : tensor<32x8x8xf32>, tensor<32x8x1xf32>) outs(%82 : tensor<32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.divf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<32x8x8xf32>
%1694 = linalg.batch_matmul ins(%1693, %collapsed_687 : tensor<32x8x8xf32>, tensor<32x8x64xf32>) outs(%99 : tensor<32x8x64xf32>) -> tensor<32x8x64xf32>
%expanded_690 = tensor.expand_shape %1694 [[0, 1], [2], [3]] : tensor<32x8x64xf32> into tensor<1x32x8x64xf32>
%1695 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_690 : tensor<1x32x8x64xf32>) outs(%101 : tensor<1x8x32x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x32x64xf32>
%1696 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1695 : tensor<1x8x32x64xf32>) outs(%101 : tensor<1x8x32x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x32x64xf32>
%1697 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_20 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%collapsed_691 = tensor.collapse_shape %1696 [[0, 1], [2, 3]] : tensor<1x8x32x64xf32> into tensor<8x2048xf32>
%1698 = linalg.matmul ins(%collapsed_691, %1697 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%1699 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_19, %1698 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_692 = tensor.expand_shape %1699 [[0, 1], [2]] : tensor<8x2048xf32> into tensor<1x8x2048xf32>
%1700 = linalg.generic {indexing_maps = [#map13, #map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_679, %expanded_692 : tensor<1x8x2048xf32>, tensor<1x8x2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%collapsed_693 = tensor.collapse_shape %1700 [[0, 1], [2]] : tensor<1x8x2048xf32> into tensor<8x2048xf32>
%1701 = linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%collapsed_693 : tensor<8x2048xf32>) outs(%108 : tensor<8x2048xf64>) {
^bb0(%in: f32, %out: f64):
%1818 = arith.extf %in : f32 to f64
linalg.yield %1818 : f64
} -> tensor<8x2048xf64>
%1702 = linalg.generic {indexing_maps = [#map4, #map7], iterator_types = ["parallel", "reduction"]} ins(%1701 : tensor<8x2048xf64>) outs(%111 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%1703 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1702 : tensor<8x1xf64>) outs(%110 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%1704 = linalg.generic {indexing_maps = [#map4, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1701, %1703 : tensor<8x2048xf64>, tensor<8x1xf64>) outs(%108 : tensor<8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.subf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<8x2048xf64>
%1705 = linalg.generic {indexing_maps = [#map4, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%1704, %1704 : tensor<8x2048xf64>, tensor<8x2048xf64>) outs(%108 : tensor<8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.mulf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<8x2048xf64>
%1706 = linalg.generic {indexing_maps = [#map4, #map7], iterator_types = ["parallel", "reduction"]} ins(%1705 : tensor<8x2048xf64>) outs(%111 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%1707 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1706 : tensor<8x1xf64>) outs(%110 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%1708 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1707 : tensor<8x1xf64>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f64, %out: f32):
%1818 = arith.truncf %in : f64 to f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%1709 = linalg.generic {indexing_maps = [#map4, #map7], iterator_types = ["parallel", "reduction"]} ins(%collapsed_693 : tensor<8x2048xf32>) outs(%120 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%1710 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1709 : tensor<8x1xf32>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.divf %in, %cst_347 : f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%1711 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1708 : tensor<8x1xf32>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.truncf %cst_344 : f64 to f32
%1819 = arith.addf %in, %1818 : f32
linalg.yield %1819 : f32
} -> tensor<8x1xf32>
%1712 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1711 : tensor<8x1xf32>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.rsqrt %in : f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%1713 = linalg.generic {indexing_maps = [#map4, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%collapsed_693, %1710 : tensor<8x2048xf32>, tensor<8x1xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%1714 = linalg.generic {indexing_maps = [#map4, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1713, %1712 : tensor<8x2048xf32>, tensor<8x1xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%1715 = linalg.generic {indexing_maps = [#map4, #map6, #map4], iterator_types = ["parallel", "parallel"]} ins(%1714, %cst_337 : tensor<8x2048xf32>, tensor<2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%1716 = linalg.generic {indexing_maps = [#map4, #map6, #map4], iterator_types = ["parallel", "parallel"]} ins(%1715, %cst_18 : tensor<8x2048xf32>, tensor<2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%1717 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_17 : tensor<8192x2048xf32>) outs(%129 : tensor<2048x8192xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x8192xf32>
%1718 = linalg.matmul ins(%1716, %1717 : tensor<8x2048xf32>, tensor<2048x8192xf32>) outs(%132 : tensor<8x8192xf32>) -> tensor<8x8192xf32>
%1719 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_16, %1718 : tensor<8192xf32>, tensor<8x8192xf32>) outs(%131 : tensor<8x8192xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x8192xf32>
%1720 = linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%1719 : tensor<8x8192xf32>) outs(%131 : tensor<8x8192xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.cmpf ugt, %in, %cst_340 : f32
%1819 = arith.select %1818, %in, %cst_340 : f32
linalg.yield %1819 : f32
} -> tensor<8x8192xf32>
%1721 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_15 : tensor<2048x8192xf32>) outs(%136 : tensor<8192x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<8192x2048xf32>
%1722 = linalg.matmul ins(%1720, %1721 : tensor<8x8192xf32>, tensor<8192x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%1723 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_14, %1722 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%1724 = linalg.generic {indexing_maps = [#map4, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%collapsed_693, %1723 : tensor<8x2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_694 = tensor.expand_shape %1724 [[0, 1], [2]] : tensor<8x2048xf32> into tensor<1x8x2048xf32>
%1725 = linalg.generic {indexing_maps = [#map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_694 : tensor<1x8x2048xf32>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f32, %out: f64):
%1818 = arith.extf %in : f32 to f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
%1726 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%1725 : tensor<1x8x2048xf64>) outs(%42 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%1727 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1726 : tensor<1x8x1xf64>) outs(%41 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%1728 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1725, %1727 : tensor<1x8x2048xf64>, tensor<1x8x1xf64>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.subf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
%1729 = linalg.generic {indexing_maps = [#map13, #map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1728, %1728 : tensor<1x8x2048xf64>, tensor<1x8x2048xf64>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.mulf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
%1730 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%1729 : tensor<1x8x2048xf64>) outs(%42 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%1731 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1730 : tensor<1x8x1xf64>) outs(%41 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%1732 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1731 : tensor<1x8x1xf64>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f64, %out: f32):
%1818 = arith.truncf %in : f64 to f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%1733 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%expanded_694 : tensor<1x8x2048xf32>) outs(%51 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%1734 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1733 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.divf %in, %cst_347 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%1735 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1732 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.truncf %cst_344 : f64 to f32
%1819 = arith.addf %in, %1818 : f32
linalg.yield %1819 : f32
} -> tensor<1x8x1xf32>
%1736 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1735 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.rsqrt %in : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%1737 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_694, %1734 : tensor<1x8x2048xf32>, tensor<1x8x1xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%1738 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1737, %1736 : tensor<1x8x2048xf32>, tensor<1x8x1xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%1739 = linalg.generic {indexing_maps = [#map13, #map16, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1738, %cst_337 : tensor<1x8x2048xf32>, tensor<2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%1740 = linalg.generic {indexing_maps = [#map13, #map16, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1739, %cst_13 : tensor<1x8x2048xf32>, tensor<2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%1741 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_12 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%collapsed_695 = tensor.collapse_shape %1740 [[0, 1], [2]] : tensor<1x8x2048xf32> into tensor<8x2048xf32>
%1742 = linalg.matmul ins(%collapsed_695, %1741 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%1743 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_11, %1742 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_696 = tensor.expand_shape %1743 [[0, 1], [2]] : tensor<8x2048xf32> into tensor<1x8x2048xf32>
%1744 = linalg.generic {indexing_maps = [#map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_696 : tensor<1x8x2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.mulf %in, %cst_348 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%1745 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_10 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%1746 = linalg.matmul ins(%collapsed_695, %1745 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%1747 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_9, %1746 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_697 = tensor.expand_shape %1747 [[0, 1], [2, 3]] : tensor<8x2048xf32> into tensor<1x8x32x64xf32>
%1748 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_697 : tensor<1x8x32x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%1749 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1748 : tensor<1x32x8x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%1750 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_8 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%1751 = linalg.matmul ins(%collapsed_695, %1750 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%1752 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_7, %1751 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_698 = tensor.expand_shape %1752 [[0, 1], [2, 3]] : tensor<8x2048xf32> into tensor<1x8x32x64xf32>
%1753 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_698 : tensor<1x8x32x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%1754 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1753 : tensor<1x32x8x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%expanded_699 = tensor.expand_shape %1744 [[0], [1], [2, 3]] : tensor<1x8x2048xf32> into tensor<1x8x32x64xf32>
%1755 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_699 : tensor<1x8x32x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%1756 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1755 : tensor<1x32x8x64xf32>) outs(%70 : tensor<1x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x8x64xf32>
%collapsed_700 = tensor.collapse_shape %1756 [[0, 1], [2], [3]] : tensor<1x32x8x64xf32> into tensor<32x8x64xf32>
%collapsed_701 = tensor.collapse_shape %1749 [[0, 1], [2], [3]] : tensor<1x32x8x64xf32> into tensor<32x8x64xf32>
%collapsed_702 = tensor.collapse_shape %1754 [[0, 1], [2], [3]] : tensor<1x32x8x64xf32> into tensor<32x8x64xf32>
%1757 = linalg.generic {indexing_maps = [#map1, #map20], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_701 : tensor<32x8x64xf32>) outs(%80 : tensor<32x64x8xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<32x64x8xf32>
%1758 = linalg.batch_matmul ins(%collapsed_700, %1757 : tensor<32x8x64xf32>, tensor<32x64x8xf32>) outs(%83 : tensor<32x8x8xf32>) -> tensor<32x8x8xf32>
%expanded_703 = tensor.expand_shape %1758 [[0, 1], [2], [3]] : tensor<32x8x8xf32> into tensor<1x32x8x8xf32>
%1759 = linalg.generic {indexing_maps = [#map19, #map8, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_703, %28 : tensor<1x32x8x8xf32>, tensor<1x1x8x8xf32>) outs(%85 : tensor<1x32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x32x8x8xf32>
%1760 = linalg.generic {indexing_maps = [#map19, #map11, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1759, %cst_329 : tensor<1x32x8x8xf32>, tensor<f32>) outs(%85 : tensor<1x32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.cmpf ugt, %in, %in_712 : f32
%1819 = arith.select %1818, %in, %in_712 : f32
linalg.yield %1819 : f32
} -> tensor<1x32x8x8xf32>
%collapsed_704 = tensor.collapse_shape %1760 [[0, 1], [2], [3]] : tensor<1x32x8x8xf32> into tensor<32x8x8xf32>
%1761:2 = linalg.generic {indexing_maps = [#map1, #map14, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%collapsed_704 : tensor<32x8x8xf32>) outs(%91, %89 : tensor<32x8x1xf32>, tensor<32x8x1xi64>) {
^bb0(%in: f32, %out: f32, %out_712: i64):
%1818 = linalg.index 2 : index
%1819 = arith.index_cast %1818 : index to i64
%1820 = arith.maximumf %in, %out : f32
%1821 = arith.cmpf ogt, %in, %out : f32
%1822 = arith.select %1821, %1819, %out_712 : i64
linalg.yield %1820, %1822 : f32, i64
} -> (tensor<32x8x1xf32>, tensor<32x8x1xi64>)
%1762 = linalg.generic {indexing_maps = [#map1, #map14, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_704, %1761#0 : tensor<32x8x8xf32>, tensor<32x8x1xf32>) outs(%82 : tensor<32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<32x8x8xf32>
%1763 = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1762 : tensor<32x8x8xf32>) outs(%82 : tensor<32x8x8xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.exp %in : f32
linalg.yield %1818 : f32
} -> tensor<32x8x8xf32>
%1764 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%1763 : tensor<32x8x8xf32>) outs(%95 : tensor<32x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<32x8x1xf32>
%1765 = linalg.generic {indexing_maps = [#map1, #map14, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1763, %1764 : tensor<32x8x8xf32>, tensor<32x8x1xf32>) outs(%82 : tensor<32x8x8xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.divf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<32x8x8xf32>
%1766 = linalg.batch_matmul ins(%1765, %collapsed_702 : tensor<32x8x8xf32>, tensor<32x8x64xf32>) outs(%99 : tensor<32x8x64xf32>) -> tensor<32x8x64xf32>
%expanded_705 = tensor.expand_shape %1766 [[0, 1], [2], [3]] : tensor<32x8x64xf32> into tensor<1x32x8x64xf32>
%1767 = linalg.generic {indexing_maps = [#map9, #map18], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_705 : tensor<1x32x8x64xf32>) outs(%101 : tensor<1x8x32x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x32x64xf32>
%1768 = linalg.generic {indexing_maps = [#map19, #map9], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1767 : tensor<1x8x32x64xf32>) outs(%101 : tensor<1x8x32x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x8x32x64xf32>
%1769 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_6 : tensor<2048x2048xf32>) outs(%60 : tensor<2048x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x2048xf32>
%collapsed_706 = tensor.collapse_shape %1768 [[0, 1], [2, 3]] : tensor<1x8x32x64xf32> into tensor<8x2048xf32>
%1770 = linalg.matmul ins(%collapsed_706, %1769 : tensor<8x2048xf32>, tensor<2048x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%1771 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_5, %1770 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_707 = tensor.expand_shape %1771 [[0, 1], [2]] : tensor<8x2048xf32> into tensor<1x8x2048xf32>
%1772 = linalg.generic {indexing_maps = [#map13, #map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_694, %expanded_707 : tensor<1x8x2048xf32>, tensor<1x8x2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%collapsed_708 = tensor.collapse_shape %1772 [[0, 1], [2]] : tensor<1x8x2048xf32> into tensor<8x2048xf32>
%1773 = linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%collapsed_708 : tensor<8x2048xf32>) outs(%108 : tensor<8x2048xf64>) {
^bb0(%in: f32, %out: f64):
%1818 = arith.extf %in : f32 to f64
linalg.yield %1818 : f64
} -> tensor<8x2048xf64>
%1774 = linalg.generic {indexing_maps = [#map4, #map7], iterator_types = ["parallel", "reduction"]} ins(%1773 : tensor<8x2048xf64>) outs(%111 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%1775 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1774 : tensor<8x1xf64>) outs(%110 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%1776 = linalg.generic {indexing_maps = [#map4, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1773, %1775 : tensor<8x2048xf64>, tensor<8x1xf64>) outs(%108 : tensor<8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.subf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<8x2048xf64>
%1777 = linalg.generic {indexing_maps = [#map4, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%1776, %1776 : tensor<8x2048xf64>, tensor<8x2048xf64>) outs(%108 : tensor<8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.mulf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<8x2048xf64>
%1778 = linalg.generic {indexing_maps = [#map4, #map7], iterator_types = ["parallel", "reduction"]} ins(%1777 : tensor<8x2048xf64>) outs(%111 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%1779 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1778 : tensor<8x1xf64>) outs(%110 : tensor<8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<8x1xf64>
%1780 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1779 : tensor<8x1xf64>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f64, %out: f32):
%1818 = arith.truncf %in : f64 to f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%1781 = linalg.generic {indexing_maps = [#map4, #map7], iterator_types = ["parallel", "reduction"]} ins(%collapsed_708 : tensor<8x2048xf32>) outs(%120 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%1782 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1781 : tensor<8x1xf32>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.divf %in, %cst_347 : f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%1783 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1780 : tensor<8x1xf32>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.truncf %cst_344 : f64 to f32
%1819 = arith.addf %in, %1818 : f32
linalg.yield %1819 : f32
} -> tensor<8x1xf32>
%1784 = linalg.generic {indexing_maps = [#map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1783 : tensor<8x1xf32>) outs(%118 : tensor<8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.rsqrt %in : f32
linalg.yield %1818 : f32
} -> tensor<8x1xf32>
%1785 = linalg.generic {indexing_maps = [#map4, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%collapsed_708, %1782 : tensor<8x2048xf32>, tensor<8x1xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%1786 = linalg.generic {indexing_maps = [#map4, #map7, #map4], iterator_types = ["parallel", "parallel"]} ins(%1785, %1784 : tensor<8x2048xf32>, tensor<8x1xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%1787 = linalg.generic {indexing_maps = [#map4, #map6, #map4], iterator_types = ["parallel", "parallel"]} ins(%1786, %cst_337 : tensor<8x2048xf32>, tensor<2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%1788 = linalg.generic {indexing_maps = [#map4, #map6, #map4], iterator_types = ["parallel", "parallel"]} ins(%1787, %cst_4 : tensor<8x2048xf32>, tensor<2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%1789 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_3 : tensor<8192x2048xf32>) outs(%129 : tensor<2048x8192xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x8192xf32>
%1790 = linalg.matmul ins(%1788, %1789 : tensor<8x2048xf32>, tensor<2048x8192xf32>) outs(%132 : tensor<8x8192xf32>) -> tensor<8x8192xf32>
%1791 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_2, %1790 : tensor<8192xf32>, tensor<8x8192xf32>) outs(%131 : tensor<8x8192xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x8192xf32>
%1792 = linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%1791 : tensor<8x8192xf32>) outs(%131 : tensor<8x8192xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.cmpf ugt, %in, %cst_340 : f32
%1819 = arith.select %1818, %in, %cst_340 : f32
linalg.yield %1819 : f32
} -> tensor<8x8192xf32>
%1793 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_1 : tensor<2048x8192xf32>) outs(%136 : tensor<8192x2048xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<8192x2048xf32>
%1794 = linalg.matmul ins(%1792, %1793 : tensor<8x8192xf32>, tensor<8192x2048xf32>) outs(%63 : tensor<8x2048xf32>) -> tensor<8x2048xf32>
%1795 = linalg.generic {indexing_maps = [#map6, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst_0, %1794 : tensor<2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%1796 = linalg.generic {indexing_maps = [#map4, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%collapsed_708, %1795 : tensor<8x2048xf32>, tensor<8x2048xf32>) outs(%62 : tensor<8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<8x2048xf32>
%expanded_709 = tensor.expand_shape %1796 [[0, 1], [2]] : tensor<8x2048xf32> into tensor<1x8x2048xf32>
%1797 = linalg.generic {indexing_maps = [#map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_709 : tensor<1x8x2048xf32>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f32, %out: f64):
%1818 = arith.extf %in : f32 to f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
%1798 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%1797 : tensor<1x8x2048xf64>) outs(%42 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%1799 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1798 : tensor<1x8x1xf64>) outs(%41 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%1800 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1797, %1799 : tensor<1x8x2048xf64>, tensor<1x8x1xf64>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.subf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
%1801 = linalg.generic {indexing_maps = [#map13, #map13, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1800, %1800 : tensor<1x8x2048xf64>, tensor<1x8x2048xf64>) outs(%39 : tensor<1x8x2048xf64>) {
^bb0(%in: f64, %in_712: f64, %out: f64):
%1818 = arith.mulf %in, %in_712 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x2048xf64>
%1802 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%1801 : tensor<1x8x2048xf64>) outs(%42 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.addf %in, %out : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%1803 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1802 : tensor<1x8x1xf64>) outs(%41 : tensor<1x8x1xf64>) {
^bb0(%in: f64, %out: f64):
%1818 = arith.divf %in, %cst_346 : f64
linalg.yield %1818 : f64
} -> tensor<1x8x1xf64>
%1804 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1803 : tensor<1x8x1xf64>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f64, %out: f32):
%1818 = arith.truncf %in : f64 to f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%1805 = linalg.generic {indexing_maps = [#map1, #map14], iterator_types = ["parallel", "parallel", "reduction"]} ins(%expanded_709 : tensor<1x8x2048xf32>) outs(%51 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.addf %in, %out : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%1806 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1805 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.divf %in, %cst_347 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%1807 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1804 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = arith.truncf %cst_344 : f64 to f32
%1819 = arith.addf %in, %1818 : f32
linalg.yield %1819 : f32
} -> tensor<1x8x1xf32>
%1808 = linalg.generic {indexing_maps = [#map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1807 : tensor<1x8x1xf32>) outs(%49 : tensor<1x8x1xf32>) {
^bb0(%in: f32, %out: f32):
%1818 = math.rsqrt %in : f32
linalg.yield %1818 : f32
} -> tensor<1x8x1xf32>
%1809 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_709, %1806 : tensor<1x8x2048xf32>, tensor<1x8x1xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.subf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%1810 = linalg.generic {indexing_maps = [#map13, #map15, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1809, %1808 : tensor<1x8x2048xf32>, tensor<1x8x1xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%1811 = linalg.generic {indexing_maps = [#map13, #map16, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1810, %cst_337 : tensor<1x8x2048xf32>, tensor<2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.mulf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%1812 = linalg.generic {indexing_maps = [#map13, #map16, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1811, %cst : tensor<1x8x2048xf32>, tensor<2048xf32>) outs(%0 : tensor<1x8x2048xf32>) {
^bb0(%in: f32, %in_712: f32, %out: f32):
%1818 = arith.addf %in, %in_712 : f32
linalg.yield %1818 : f32
} -> tensor<1x8x2048xf32>
%1813 = tensor.empty() : tensor<2048x50272xf32>
%1814 = linalg.generic {indexing_maps = [#map4, #map17], iterator_types = ["parallel", "parallel"]} ins(%cst_339 : tensor<50272x2048xf32>) outs(%1813 : tensor<2048x50272xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2048x50272xf32>
%collapsed_710 = tensor.collapse_shape %1812 [[0, 1], [2]] : tensor<1x8x2048xf32> into tensor<8x2048xf32>
%1815 = tensor.empty() : tensor<8x50272xf32>
%1816 = linalg.fill ins(%cst_340 : f32) outs(%1815 : tensor<8x50272xf32>) -> tensor<8x50272xf32>
%1817 = linalg.matmul ins(%collapsed_710, %1814 : tensor<8x2048xf32>, tensor<2048x50272xf32>) outs(%1816 : tensor<8x50272xf32>) -> tensor<8x50272xf32>
%expanded_711 = tensor.expand_shape %1817 [[0, 1], [2]] : tensor<8x50272xf32> into tensor<1x8x50272xf32>
return %expanded_711 : tensor<1x8x50272xf32>
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment