Skip to content

Instantly share code, notes, and snippets.

@vivekkhandelwal1
Created December 22, 2022 05:04
Show Gist options
  • Save vivekkhandelwal1/3f912e3a9ba62ce2895533185f837b44 to your computer and use it in GitHub Desktop.
module attributes {torch.debug_module_name = "_lambda"} {
func.func @forward(%arg0: !torch.vtensor<[1,128],si64>) -> !torch.vtensor<[1,2],f32> {
%int1 = torch.constant.int 1
%int0 = torch.constant.int 0
%int2 = torch.constant.int 2
%true = torch.constant.bool true
%none = torch.constant.none
%int768 = torch.constant.int 768
%int128 = torch.constant.int 128
%int-1 = torch.constant.int -1
%int6 = torch.constant.int 6
%false = torch.constant.bool false
%0 = torch.vtensor.literal(dense_resource<__elided__> : tensor<2x768xf32>) : !torch.vtensor<[2,768],f32>
%1 = torch.vtensor.literal(dense_resource<__elided__> : tensor<3072x768xf32>) : !torch.vtensor<[3072,768],f32>
%2 = torch.vtensor.literal(dense_resource<__elided__> : tensor<768x3072xf32>) : !torch.vtensor<[768,3072],f32>
%3 = torch.vtensor.literal(dense_resource<__elided__> : tensor<3072xf32>) : !torch.vtensor<[3072],f32>
%4 = torch.vtensor.literal(dense_resource<__elided__> : tensor<768x768xf32>) : !torch.vtensor<[768,768],f32>
%5 = torch.vtensor.literal(dense_resource<__elided__> : tensor<1x1x1024x1024xui8>) : !torch.vtensor<[1,1,1024,1024],ui8>
%6 = torch.vtensor.literal(dense_resource<__elided__> : tensor<768x2304xf32>) : !torch.vtensor<[768,2304],f32>
%7 = torch.vtensor.literal(dense_resource<__elided__> : tensor<2304xf32>) : !torch.vtensor<[2304],f32>
%8 = torch.vtensor.literal(dense_resource<__elided__> : tensor<768xf32>) : !torch.vtensor<[768],f32>
%9 = torch.vtensor.literal(dense_resource<__elided__> : tensor<1024x768xf32>) : !torch.vtensor<[1024,768],f32>
%10 = torch.vtensor.literal(dense_resource<__elided__> : tensor<50257x768xf32>) : !torch.vtensor<[50257,768],f32>
%int-2 = torch.constant.int -2
%int11 = torch.constant.int 11
%float-3.402820e38 = torch.constant.float -3.4028234663852886E+38
%int4 = torch.constant.int 4
%float1.000000e-05 = torch.constant.float 1.000000e-05
%int2304 = torch.constant.int 2304
%int294912 = torch.constant.int 294912
%int1536 = torch.constant.int 1536
%int12 = torch.constant.int 12
%int64 = torch.constant.int 64
%int3 = torch.constant.int 3
%float8.000000e00 = torch.constant.float 8.000000e+00
%int1024 = torch.constant.int 1024
%int1048576 = torch.constant.int 1048576
%int3072 = torch.constant.int 3072
%float5.000000e-01 = torch.constant.float 5.000000e-01
%float3.000000e00 = torch.constant.float 3.000000e+00
%float4.471500e-02 = torch.constant.float 4.471500e-02
%float7.978850e-01 = torch.constant.float 0.79788456080286541
%float1.000000e00 = torch.constant.float 1.000000e+00
%cpu = torch.constant.device "cpu"
%11 = torch.prim.ListConstruct %int-1, %int128 : (!torch.int, !torch.int) -> !torch.list<int>
%12 = torch.aten.view %arg0, %11 : !torch.vtensor<[1,128],si64>, !torch.list<int> -> !torch.vtensor<[1,128],si64>
%13 = torch.aten.arange.start_step %int0, %int128, %int1, %int4, %none, %cpu, %false : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[128],si64>
%14 = torch.aten.unsqueeze %13, %int0 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[1,128],si64>
%15 = torch.aten.view %14, %11 : !torch.vtensor<[1,128],si64>, !torch.list<int> -> !torch.vtensor<[1,128],si64>
%16 = torch.aten.embedding %10, %12, %int-1, %false, %false : !torch.vtensor<[50257,768],f32>, !torch.vtensor<[1,128],si64>, !torch.int, !torch.bool, !torch.bool -> !torch.vtensor<[1,128,768],f32>
%17 = torch.aten.embedding %9, %15, %int-1, %false, %false : !torch.vtensor<[1024,768],f32>, !torch.vtensor<[1,128],si64>, !torch.int, !torch.bool, !torch.bool -> !torch.vtensor<[1,128,768],f32>
%18 = torch.aten.add.Tensor %16, %17, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32>
%19 = torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<int>
%20 = torch.aten.sum.dim_IntList %18, %19, %true, %none : !torch.vtensor<[1,128,768],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
%21 = torch.aten.div.Scalar %20, %int768 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%22 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%23 = torch.aten.broadcast_to %21, %22 : !torch.vtensor<[1,128,1],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32>
%24 = torch.aten.sub.Tensor %18, %23, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32>
%25 = torch.aten.mul.Tensor %24, %24 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32> -> !torch.vtensor<[1,128,768],f32>
%26 = torch.aten.sum.dim_IntList %25, %19, %true, %none : !torch.vtensor<[1,128,768],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
%27 = torch.aten.div.Scalar %26, %int768 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%28 = torch.aten.add.Scalar %27, %float1.000000e-05, %int1 : !torch.vtensor<[1,128,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,128,1],f32>
%29 = torch.aten.rsqrt %28 : !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,1],f32>
%30 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%31 = torch.aten.broadcast_to %29, %30 : !torch.vtensor<[1,128,1],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32>
%32 = torch.aten.mul.Tensor %24, %31 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32> -> !torch.vtensor<[1,128,768],f32>
%33 = torch.aten.mul.Tensor %32, %8 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[768],f32> -> !torch.vtensor<[1,128,768],f32>
%34 = torch.aten.add.Tensor %33, %8, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32>
%35 = torch.prim.ListConstruct %int-1, %int768 : (!torch.int, !torch.int) -> !torch.list<int>
%36 = torch.aten.view %34, %35 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[128,768],f32>
%37 = torch.aten.mm %36, %6 : !torch.vtensor<[128,768],f32>, !torch.vtensor<[768,2304],f32> -> !torch.vtensor<[128,2304],f32>
%38 = torch.aten.mul.Scalar %7, %int1 : !torch.vtensor<[2304],f32>, !torch.int -> !torch.vtensor<[2304],f32>
%39 = torch.aten.add.Tensor %38, %37, %int1 : !torch.vtensor<[2304],f32>, !torch.vtensor<[128,2304],f32>, !torch.int -> !torch.vtensor<[128,2304],f32>
%40 = torch.prim.ListConstruct %int1, %int128, %int2304 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%41 = torch.aten.view %39, %40 : !torch.vtensor<[128,2304],f32>, !torch.list<int> -> !torch.vtensor<[1,128,2304],f32>
%42 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%43 = torch.prim.ListConstruct %int294912, %int2304, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%44 = torch.aten.as_strided %41, %42, %43, %int0 : !torch.vtensor<[1,128,2304],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,128,768],f32>
%45 = torch.aten.as_strided %41, %42, %43, %int768 : !torch.vtensor<[1,128,2304],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,128,768],f32>
%46 = torch.aten.as_strided %41, %42, %43, %int1536 : !torch.vtensor<[1,128,2304],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,128,768],f32>
%47 = torch.prim.ListConstruct %int1, %int128, %int12, %int64 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%48 = torch.aten.view %44, %47 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[1,128,12,64],f32>
%49 = torch.prim.ListConstruct %int0, %int2, %int1, %int3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%50 = torch.aten.permute %48, %49 : !torch.vtensor<[1,128,12,64],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,64],f32>
%51 = torch.aten.view %45, %47 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[1,128,12,64],f32>
%52 = torch.aten.permute %51, %49 : !torch.vtensor<[1,128,12,64],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,64],f32>
%53 = torch.aten.view %46, %47 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[1,128,12,64],f32>
%54 = torch.aten.permute %53, %49 : !torch.vtensor<[1,128,12,64],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,64],f32>
%55 = torch.aten.transpose.int %52, %int-1, %int-2 : !torch.vtensor<[1,12,128,64],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,12,64,128],f32>
%56 = torch.prim.ListConstruct %int1, %int12, %int128, %int64 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%57 = torch.aten.broadcast_to %50, %56 : !torch.vtensor<[1,12,128,64],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,64],f32>
%58 = torch.prim.ListConstruct %int12, %int128, %int64 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%59 = torch.aten.view %57, %58 : !torch.vtensor<[1,12,128,64],f32>, !torch.list<int> -> !torch.vtensor<[12,128,64],f32>
%60 = torch.prim.ListConstruct %int1, %int12, %int64, %int128 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%61 = torch.aten.broadcast_to %55, %60 : !torch.vtensor<[1,12,64,128],f32>, !torch.list<int> -> !torch.vtensor<[1,12,64,128],f32>
%62 = torch.prim.ListConstruct %int12, %int64, %int128 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%63 = torch.aten.view %61, %62 : !torch.vtensor<[1,12,64,128],f32>, !torch.list<int> -> !torch.vtensor<[12,64,128],f32>
%64 = torch.aten.bmm %59, %63 : !torch.vtensor<[12,128,64],f32>, !torch.vtensor<[12,64,128],f32> -> !torch.vtensor<[12,128,128],f32>
%65 = torch.prim.ListConstruct %int1, %int12, %int128, %int128 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%66 = torch.aten.view %64, %65 : !torch.vtensor<[12,128,128],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,128],f32>
%67 = torch.prim.ListConstruct : () -> !torch.list<int>
%68 = torch.prim.NumToTensor.Scalar %float8.000000e00 : !torch.float -> !torch.vtensor<[],f64>
%69 = torch.aten.to.dtype %68, %int6, %false, %false, %none : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
%70 = torch.aten.broadcast_to %69, %67 : !torch.vtensor<[],f32>, !torch.list<int> -> !torch.vtensor<[],f32>
%71 = torch.aten.div.Tensor %66, %70 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[1,12,128,128],f32>
%72 = torch.prim.ListConstruct %int1, %int1, %int1024, %int1024 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%73 = torch.prim.ListConstruct %int1048576, %int1048576, %int1024, %int1 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%74 = torch.aten.as_strided %5, %72, %73, %int0 : !torch.vtensor<[1,1,1024,1024],ui8>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,1024,1024],ui8>
%75 = torch.aten.as_strided %74, %72, %73, %int0 : !torch.vtensor<[1,1,1024,1024],ui8>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,1024,1024],ui8>
%76 = torch.prim.ListConstruct %int1, %int1, %int128, %int1024 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%77 = torch.aten.as_strided %75, %76, %73, %int0 : !torch.vtensor<[1,1,1024,1024],ui8>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,128,1024],ui8>
%78 = torch.prim.ListConstruct %int1, %int1, %int128, %int128 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%79 = torch.aten.as_strided %77, %78, %73, %int0 : !torch.vtensor<[1,1,128,1024],ui8>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,128,128],ui8>
%80 = torch.aten.to.dtype %79, %int11, %false, %false, %none : !torch.vtensor<[1,1,128,128],ui8>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,1,128,128],i1>
%81 = torch.prim.NumToTensor.Scalar %float-3.402820e38 : !torch.float -> !torch.vtensor<[],f64>
%82 = torch.aten.to.dtype %81, %int6, %false, %false, %none : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
%83 = torch.aten.broadcast_to %82, %67 : !torch.vtensor<[],f32>, !torch.list<int> -> !torch.vtensor<[],f32>
%84 = torch.aten.where.self %80, %71, %83 : !torch.vtensor<[1,1,128,128],i1>, !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[1,12,128,128],f32>
%85 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<int>
%values, %indices = torch.aten.max.dim %84, %int-1, %true : !torch.vtensor<[1,12,128,128],f32>, !torch.int, !torch.bool -> !torch.vtensor<[1,12,128,1],f32>, !torch.vtensor<[1,12,128,1],si64>
%86 = torch.aten.sub.Tensor %84, %values, %int1 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[1,12,128,1],f32>, !torch.int -> !torch.vtensor<[1,12,128,128],f32>
%87 = torch.aten.exp %86 : !torch.vtensor<[1,12,128,128],f32> -> !torch.vtensor<[1,12,128,128],f32>
%88 = torch.aten.sum.dim_IntList %87, %85, %true, %none : !torch.vtensor<[1,12,128,128],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,12,128,1],f32>
%89 = torch.aten.div.Tensor %87, %88 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[1,12,128,1],f32> -> !torch.vtensor<[1,12,128,128],f32>
%90 = torch.aten.broadcast_to %89, %65 : !torch.vtensor<[1,12,128,128],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,128],f32>
%91 = torch.prim.ListConstruct %int12, %int128, %int128 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%92 = torch.aten.view %90, %91 : !torch.vtensor<[1,12,128,128],f32>, !torch.list<int> -> !torch.vtensor<[12,128,128],f32>
%93 = torch.aten.broadcast_to %54, %56 : !torch.vtensor<[1,12,128,64],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,64],f32>
%94 = torch.aten.view %93, %58 : !torch.vtensor<[1,12,128,64],f32>, !torch.list<int> -> !torch.vtensor<[12,128,64],f32>
%95 = torch.aten.bmm %92, %94 : !torch.vtensor<[12,128,128],f32>, !torch.vtensor<[12,128,64],f32> -> !torch.vtensor<[12,128,64],f32>
%96 = torch.aten.view %95, %56 : !torch.vtensor<[12,128,64],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,64],f32>
%97 = torch.aten.permute %96, %49 : !torch.vtensor<[1,12,128,64],f32>, !torch.list<int> -> !torch.vtensor<[1,128,12,64],f32>
%98 = torch.aten.clone %97, %int0 : !torch.vtensor<[1,128,12,64],f32>, !torch.int -> !torch.vtensor<[1,128,12,64],f32>
%99 = torch.aten.view %98, %42 : !torch.vtensor<[1,128,12,64],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32>
%100 = torch.aten.view %99, %35 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[128,768],f32>
%101 = torch.aten.mm %100, %4 : !torch.vtensor<[128,768],f32>, !torch.vtensor<[768,768],f32> -> !torch.vtensor<[128,768],f32>
%102 = torch.aten.mul.Scalar %8, %int1 : !torch.vtensor<[768],f32>, !torch.int -> !torch.vtensor<[768],f32>
%103 = torch.aten.add.Tensor %102, %101, %int1 : !torch.vtensor<[768],f32>, !torch.vtensor<[128,768],f32>, !torch.int -> !torch.vtensor<[128,768],f32>
%104 = torch.aten.view %103, %42 : !torch.vtensor<[128,768],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32>
%105 = torch.aten.add.Tensor %104, %18, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32>
%106 = torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<int>
%107 = torch.aten.sum.dim_IntList %105, %106, %true, %none : !torch.vtensor<[1,128,768],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
%108 = torch.aten.div.Scalar %107, %int768 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%109 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%110 = torch.aten.broadcast_to %108, %109 : !torch.vtensor<[1,128,1],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32>
%111 = torch.aten.sub.Tensor %105, %110, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32>
%112 = torch.aten.mul.Tensor %111, %111 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32> -> !torch.vtensor<[1,128,768],f32>
%113 = torch.aten.sum.dim_IntList %112, %106, %true, %none : !torch.vtensor<[1,128,768],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
%114 = torch.aten.div.Scalar %113, %int768 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%115 = torch.aten.add.Scalar %114, %float1.000000e-05, %int1 : !torch.vtensor<[1,128,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,128,1],f32>
%116 = torch.aten.rsqrt %115 : !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,1],f32>
%117 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%118 = torch.aten.broadcast_to %116, %117 : !torch.vtensor<[1,128,1],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32>
%119 = torch.aten.mul.Tensor %111, %118 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32> -> !torch.vtensor<[1,128,768],f32>
%120 = torch.aten.mul.Tensor %119, %8 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[768],f32> -> !torch.vtensor<[1,128,768],f32>
%121 = torch.aten.add.Tensor %120, %8, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32>
%122 = torch.aten.view %121, %35 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[128,768],f32>
%123 = torch.aten.mm %122, %2 : !torch.vtensor<[128,768],f32>, !torch.vtensor<[768,3072],f32> -> !torch.vtensor<[128,3072],f32>
%124 = torch.aten.mul.Scalar %3, %int1 : !torch.vtensor<[3072],f32>, !torch.int -> !torch.vtensor<[3072],f32>
%125 = torch.aten.add.Tensor %124, %123, %int1 : !torch.vtensor<[3072],f32>, !torch.vtensor<[128,3072],f32>, !torch.int -> !torch.vtensor<[128,3072],f32>
%126 = torch.prim.ListConstruct %int1, %int128, %int3072 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%127 = torch.aten.view %125, %126 : !torch.vtensor<[128,3072],f32>, !torch.list<int> -> !torch.vtensor<[1,128,3072],f32>
%128 = torch.aten.mul.Scalar %127, %float5.000000e-01 : !torch.vtensor<[1,128,3072],f32>, !torch.float -> !torch.vtensor<[1,128,3072],f32>
%129 = torch.aten.pow.Tensor_Scalar %127, %float3.000000e00 : !torch.vtensor<[1,128,3072],f32>, !torch.float -> !torch.vtensor<[1,128,3072],f32>
%130 = torch.aten.mul.Scalar %129, %float4.471500e-02 : !torch.vtensor<[1,128,3072],f32>, !torch.float -> !torch.vtensor<[1,128,3072],f32>
%131 = torch.aten.add.Tensor %127, %130, %int1 : !torch.vtensor<[1,128,3072],f32>, !torch.vtensor<[1,128,3072],f32>, !torch.int -> !torch.vtensor<[1,128,3072],f32>
%132 = torch.aten.mul.Scalar %131, %float7.978850e-01 : !torch.vtensor<[1,128,3072],f32>, !torch.float -> !torch.vtensor<[1,128,3072],f32>
%133 = torch.aten.tanh %132 : !torch.vtensor<[1,128,3072],f32> -> !torch.vtensor<[1,128,3072],f32>
%134 = torch.aten.add.Scalar %133, %float1.000000e00, %int1 : !torch.vtensor<[1,128,3072],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,128,3072],f32>
%135 = torch.aten.mul.Tensor %128, %134 : !torch.vtensor<[1,128,3072],f32>, !torch.vtensor<[1,128,3072],f32> -> !torch.vtensor<[1,128,3072],f32>
%136 = torch.prim.ListConstruct %int-1, %int3072 : (!torch.int, !torch.int) -> !torch.list<int>
%137 = torch.aten.view %135, %136 : !torch.vtensor<[1,128,3072],f32>, !torch.list<int> -> !torch.vtensor<[128,3072],f32>
%138 = torch.aten.mm %137, %1 : !torch.vtensor<[128,3072],f32>, !torch.vtensor<[3072,768],f32> -> !torch.vtensor<[128,768],f32>
%139 = torch.aten.mul.Scalar %8, %int1 : !torch.vtensor<[768],f32>, !torch.int -> !torch.vtensor<[768],f32>
%140 = torch.aten.add.Tensor %139, %138, %int1 : !torch.vtensor<[768],f32>, !torch.vtensor<[128,768],f32>, !torch.int -> !torch.vtensor<[128,768],f32>
%141 = torch.aten.view %140, %42 : !torch.vtensor<[128,768],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32>
%142 = torch.aten.add.Tensor %105, %141, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32>
%143 = torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<int>
%144 = torch.aten.sum.dim_IntList %142, %143, %true, %none : !torch.vtensor<[1,128,768],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
%145 = torch.aten.div.Scalar %144, %int768 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%146 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%147 = torch.aten.broadcast_to %145, %146 : !torch.vtensor<[1,128,1],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32>
%148 = torch.aten.sub.Tensor %142, %147, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32>
%149 = torch.aten.mul.Tensor %148, %148 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32> -> !torch.vtensor<[1,128,768],f32>
%150 = torch.aten.sum.dim_IntList %149, %143, %true, %none : !torch.vtensor<[1,128,768],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
%151 = torch.aten.div.Scalar %150, %int768 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%152 = torch.aten.add.Scalar %151, %float1.000000e-05, %int1 : !torch.vtensor<[1,128,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,128,1],f32>
%153 = torch.aten.rsqrt %152 : !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,1],f32>
%154 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%155 = torch.aten.broadcast_to %153, %154 : !torch.vtensor<[1,128,1],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32>
%156 = torch.aten.mul.Tensor %148, %155 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32> -> !torch.vtensor<[1,128,768],f32>
%157 = torch.aten.mul.Tensor %156, %8 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[768],f32> -> !torch.vtensor<[1,128,768],f32>
%158 = torch.aten.add.Tensor %157, %8, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32>
%159 = torch.aten.view %158, %35 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[128,768],f32>
%160 = torch.aten.mm %159, %6 : !torch.vtensor<[128,768],f32>, !torch.vtensor<[768,2304],f32> -> !torch.vtensor<[128,2304],f32>
%161 = torch.aten.mul.Scalar %7, %int1 : !torch.vtensor<[2304],f32>, !torch.int -> !torch.vtensor<[2304],f32>
%162 = torch.aten.add.Tensor %161, %160, %int1 : !torch.vtensor<[2304],f32>, !torch.vtensor<[128,2304],f32>, !torch.int -> !torch.vtensor<[128,2304],f32>
%163 = torch.aten.view %162, %40 : !torch.vtensor<[128,2304],f32>, !torch.list<int> -> !torch.vtensor<[1,128,2304],f32>
%164 = torch.aten.as_strided %163, %42, %43, %int0 : !torch.vtensor<[1,128,2304],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,128,768],f32>
%165 = torch.aten.as_strided %163, %42, %43, %int768 : !torch.vtensor<[1,128,2304],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,128,768],f32>
%166 = torch.aten.as_strided %163, %42, %43, %int1536 : !torch.vtensor<[1,128,2304],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,128,768],f32>
%167 = torch.aten.view %164, %47 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[1,128,12,64],f32>
%168 = torch.aten.permute %167, %49 : !torch.vtensor<[1,128,12,64],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,64],f32>
%169 = torch.aten.view %165, %47 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[1,128,12,64],f32>
%170 = torch.aten.permute %169, %49 : !torch.vtensor<[1,128,12,64],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,64],f32>
%171 = torch.aten.view %166, %47 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[1,128,12,64],f32>
%172 = torch.aten.permute %171, %49 : !torch.vtensor<[1,128,12,64],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,64],f32>
%173 = torch.aten.transpose.int %170, %int-1, %int-2 : !torch.vtensor<[1,12,128,64],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,12,64,128],f32>
%174 = torch.aten.broadcast_to %168, %56 : !torch.vtensor<[1,12,128,64],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,64],f32>
%175 = torch.aten.view %174, %58 : !torch.vtensor<[1,12,128,64],f32>, !torch.list<int> -> !torch.vtensor<[12,128,64],f32>
%176 = torch.aten.broadcast_to %173, %60 : !torch.vtensor<[1,12,64,128],f32>, !torch.list<int> -> !torch.vtensor<[1,12,64,128],f32>
%177 = torch.aten.view %176, %62 : !torch.vtensor<[1,12,64,128],f32>, !torch.list<int> -> !torch.vtensor<[12,64,128],f32>
%178 = torch.aten.bmm %175, %177 : !torch.vtensor<[12,128,64],f32>, !torch.vtensor<[12,64,128],f32> -> !torch.vtensor<[12,128,128],f32>
%179 = torch.aten.view %178, %65 : !torch.vtensor<[12,128,128],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,128],f32>
%180 = torch.prim.NumToTensor.Scalar %float8.000000e00 : !torch.float -> !torch.vtensor<[],f64>
%181 = torch.aten.to.dtype %180, %int6, %false, %false, %none : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
%182 = torch.aten.broadcast_to %181, %67 : !torch.vtensor<[],f32>, !torch.list<int> -> !torch.vtensor<[],f32>
%183 = torch.aten.div.Tensor %179, %182 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[1,12,128,128],f32>
%184 = torch.aten.as_strided %5, %72, %73, %int0 : !torch.vtensor<[1,1,1024,1024],ui8>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,1024,1024],ui8>
%185 = torch.aten.as_strided %184, %72, %73, %int0 : !torch.vtensor<[1,1,1024,1024],ui8>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,1024,1024],ui8>
%186 = torch.aten.as_strided %185, %76, %73, %int0 : !torch.vtensor<[1,1,1024,1024],ui8>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,128,1024],ui8>
%187 = torch.aten.as_strided %186, %78, %73, %int0 : !torch.vtensor<[1,1,128,1024],ui8>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,128,128],ui8>
%188 = torch.aten.to.dtype %187, %int11, %false, %false, %none : !torch.vtensor<[1,1,128,128],ui8>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,1,128,128],i1>
%189 = torch.prim.NumToTensor.Scalar %float-3.402820e38 : !torch.float -> !torch.vtensor<[],f64>
%190 = torch.aten.to.dtype %189, %int6, %false, %false, %none : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
%191 = torch.aten.broadcast_to %190, %67 : !torch.vtensor<[],f32>, !torch.list<int> -> !torch.vtensor<[],f32>
%192 = torch.aten.where.self %188, %183, %191 : !torch.vtensor<[1,1,128,128],i1>, !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[1,12,128,128],f32>
%values_0, %indices_1 = torch.aten.max.dim %192, %int-1, %true : !torch.vtensor<[1,12,128,128],f32>, !torch.int, !torch.bool -> !torch.vtensor<[1,12,128,1],f32>, !torch.vtensor<[1,12,128,1],si64>
%193 = torch.aten.sub.Tensor %192, %values_0, %int1 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[1,12,128,1],f32>, !torch.int -> !torch.vtensor<[1,12,128,128],f32>
%194 = torch.aten.exp %193 : !torch.vtensor<[1,12,128,128],f32> -> !torch.vtensor<[1,12,128,128],f32>
%195 = torch.aten.sum.dim_IntList %194, %85, %true, %none : !torch.vtensor<[1,12,128,128],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,12,128,1],f32>
%196 = torch.aten.div.Tensor %194, %195 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[1,12,128,1],f32> -> !torch.vtensor<[1,12,128,128],f32>
%197 = torch.aten.broadcast_to %196, %65 : !torch.vtensor<[1,12,128,128],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,128],f32>
%198 = torch.aten.view %197, %91 : !torch.vtensor<[1,12,128,128],f32>, !torch.list<int> -> !torch.vtensor<[12,128,128],f32>
%199 = torch.aten.broadcast_to %172, %56 : !torch.vtensor<[1,12,128,64],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,64],f32>
%200 = torch.aten.view %199, %58 : !torch.vtensor<[1,12,128,64],f32>, !torch.list<int> -> !torch.vtensor<[12,128,64],f32>
%201 = torch.aten.bmm %198, %200 : !torch.vtensor<[12,128,128],f32>, !torch.vtensor<[12,128,64],f32> -> !torch.vtensor<[12,128,64],f32>
%202 = torch.aten.view %201, %56 : !torch.vtensor<[12,128,64],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,64],f32>
%203 = torch.aten.permute %202, %49 : !torch.vtensor<[1,12,128,64],f32>, !torch.list<int> -> !torch.vtensor<[1,128,12,64],f32>
%204 = torch.aten.clone %203, %int0 : !torch.vtensor<[1,128,12,64],f32>, !torch.int -> !torch.vtensor<[1,128,12,64],f32>
%205 = torch.aten.view %204, %42 : !torch.vtensor<[1,128,12,64],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32>
%206 = torch.aten.view %205, %35 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[128,768],f32>
%207 = torch.aten.mm %206, %4 : !torch.vtensor<[128,768],f32>, !torch.vtensor<[768,768],f32> -> !torch.vtensor<[128,768],f32>
%208 = torch.aten.mul.Scalar %8, %int1 : !torch.vtensor<[768],f32>, !torch.int -> !torch.vtensor<[768],f32>
%209 = torch.aten.add.Tensor %208, %207, %int1 : !torch.vtensor<[768],f32>, !torch.vtensor<[128,768],f32>, !torch.int -> !torch.vtensor<[128,768],f32>
%210 = torch.aten.view %209, %42 : !torch.vtensor<[128,768],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32>
// ---- Feed-forward (MLP) sub-block: residual add -> LayerNorm -> Linear(768->3072)
// -> tanh-approximated GELU -> Linear(3072->768) -> residual add. ----
// NOTE(review): shapes (hidden 768, seq 128, FFN 3072, ctx 1024) match a GPT-2-small
// style transformer unrolled by the compiler -- inferred from shapes; confirm.
%211 = torch.aten.add.Tensor %210, %142, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32>
// LayerNorm, decomposed: mean over the last (size-768) dim ...
%212 = torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<int>
%213 = torch.aten.sum.dim_IntList %211, %212, %true, %none : !torch.vtensor<[1,128,768],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
%214 = torch.aten.div.Scalar %213, %int768 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%215 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%216 = torch.aten.broadcast_to %214, %215 : !torch.vtensor<[1,128,1],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32>
// ... centered input, biased variance, then 1/sqrt(var + 1e-5)
%217 = torch.aten.sub.Tensor %211, %216, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32>
%218 = torch.aten.mul.Tensor %217, %217 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32> -> !torch.vtensor<[1,128,768],f32>
%219 = torch.aten.sum.dim_IntList %218, %212, %true, %none : !torch.vtensor<[1,128,768],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
%220 = torch.aten.div.Scalar %219, %int768 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%221 = torch.aten.add.Scalar %220, %float1.000000e-05, %int1 : !torch.vtensor<[1,128,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,128,1],f32>
%222 = torch.aten.rsqrt %221 : !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,1],f32>
%223 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%224 = torch.aten.broadcast_to %222, %223 : !torch.vtensor<[1,128,1],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32>
%225 = torch.aten.mul.Tensor %217, %224 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32> -> !torch.vtensor<[1,128,768],f32>
// affine: the same elided literal %8 appears as both scale and shift here
// (all dense_resource payloads are elided in this dump) -- verify against checkpoint
%226 = torch.aten.mul.Tensor %225, %8 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[768],f32> -> !torch.vtensor<[1,128,768],f32>
%227 = torch.aten.add.Tensor %226, %8, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32>
// fc: [128,768] x [768,3072] (%2) plus bias %3
%228 = torch.aten.view %227, %35 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[128,768],f32>
%229 = torch.aten.mm %228, %2 : !torch.vtensor<[128,768],f32>, !torch.vtensor<[768,3072],f32> -> !torch.vtensor<[128,3072],f32>
%230 = torch.aten.mul.Scalar %3, %int1 : !torch.vtensor<[3072],f32>, !torch.int -> !torch.vtensor<[3072],f32>
%231 = torch.aten.add.Tensor %230, %229, %int1 : !torch.vtensor<[3072],f32>, !torch.vtensor<[128,3072],f32>, !torch.int -> !torch.vtensor<[128,3072],f32>
%232 = torch.aten.view %231, %126 : !torch.vtensor<[128,3072],f32>, !torch.list<int> -> !torch.vtensor<[1,128,3072],f32>
// GELU, tanh approximation: 0.5*x*(1 + tanh(0.7978850*(x + 0.044715*x^3)))
// (0.7978850 ~= sqrt(2/pi))
%233 = torch.aten.mul.Scalar %232, %float5.000000e-01 : !torch.vtensor<[1,128,3072],f32>, !torch.float -> !torch.vtensor<[1,128,3072],f32>
%234 = torch.aten.pow.Tensor_Scalar %232, %float3.000000e00 : !torch.vtensor<[1,128,3072],f32>, !torch.float -> !torch.vtensor<[1,128,3072],f32>
%235 = torch.aten.mul.Scalar %234, %float4.471500e-02 : !torch.vtensor<[1,128,3072],f32>, !torch.float -> !torch.vtensor<[1,128,3072],f32>
%236 = torch.aten.add.Tensor %232, %235, %int1 : !torch.vtensor<[1,128,3072],f32>, !torch.vtensor<[1,128,3072],f32>, !torch.int -> !torch.vtensor<[1,128,3072],f32>
%237 = torch.aten.mul.Scalar %236, %float7.978850e-01 : !torch.vtensor<[1,128,3072],f32>, !torch.float -> !torch.vtensor<[1,128,3072],f32>
%238 = torch.aten.tanh %237 : !torch.vtensor<[1,128,3072],f32> -> !torch.vtensor<[1,128,3072],f32>
%239 = torch.aten.add.Scalar %238, %float1.000000e00, %int1 : !torch.vtensor<[1,128,3072],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,128,3072],f32>
%240 = torch.aten.mul.Tensor %233, %239 : !torch.vtensor<[1,128,3072],f32>, !torch.vtensor<[1,128,3072],f32> -> !torch.vtensor<[1,128,3072],f32>
// proj: [128,3072] x [3072,768] (%1) plus bias %8
%241 = torch.aten.view %240, %136 : !torch.vtensor<[1,128,3072],f32>, !torch.list<int> -> !torch.vtensor<[128,3072],f32>
%242 = torch.aten.mm %241, %1 : !torch.vtensor<[128,3072],f32>, !torch.vtensor<[3072,768],f32> -> !torch.vtensor<[128,768],f32>
%243 = torch.aten.mul.Scalar %8, %int1 : !torch.vtensor<[768],f32>, !torch.int -> !torch.vtensor<[768],f32>
%244 = torch.aten.add.Tensor %243, %242, %int1 : !torch.vtensor<[768],f32>, !torch.vtensor<[128,768],f32>, !torch.int -> !torch.vtensor<[128,768],f32>
%245 = torch.aten.view %244, %42 : !torch.vtensor<[128,768],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32>
// residual connection around the MLP
%246 = torch.aten.add.Tensor %211, %245, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32>
// ---- Self-attention sub-block: LayerNorm -> fused QKV Linear(768->2304) ->
// split into 12 heads of 64 -> scaled dot-product attention with masking ->
// softmax -> attention @ V -> output projection -> residual add. ----
// LayerNorm, decomposed (mean, variance, rsqrt with eps 1e-5, affine via %8)
%247 = torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<int>
%248 = torch.aten.sum.dim_IntList %246, %247, %true, %none : !torch.vtensor<[1,128,768],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
%249 = torch.aten.div.Scalar %248, %int768 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%250 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%251 = torch.aten.broadcast_to %249, %250 : !torch.vtensor<[1,128,1],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32>
%252 = torch.aten.sub.Tensor %246, %251, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32>
%253 = torch.aten.mul.Tensor %252, %252 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32> -> !torch.vtensor<[1,128,768],f32>
%254 = torch.aten.sum.dim_IntList %253, %247, %true, %none : !torch.vtensor<[1,128,768],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
%255 = torch.aten.div.Scalar %254, %int768 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%256 = torch.aten.add.Scalar %255, %float1.000000e-05, %int1 : !torch.vtensor<[1,128,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,128,1],f32>
%257 = torch.aten.rsqrt %256 : !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,1],f32>
%258 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%259 = torch.aten.broadcast_to %257, %258 : !torch.vtensor<[1,128,1],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32>
%260 = torch.aten.mul.Tensor %252, %259 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32> -> !torch.vtensor<[1,128,768],f32>
%261 = torch.aten.mul.Tensor %260, %8 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[768],f32> -> !torch.vtensor<[1,128,768],f32>
%262 = torch.aten.add.Tensor %261, %8, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32>
// fused QKV projection: [128,768] x [768,2304] (%6) plus bias %7
%263 = torch.aten.view %262, %35 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[128,768],f32>
%264 = torch.aten.mm %263, %6 : !torch.vtensor<[128,768],f32>, !torch.vtensor<[768,2304],f32> -> !torch.vtensor<[128,2304],f32>
%265 = torch.aten.mul.Scalar %7, %int1 : !torch.vtensor<[2304],f32>, !torch.int -> !torch.vtensor<[2304],f32>
%266 = torch.aten.add.Tensor %265, %264, %int1 : !torch.vtensor<[2304],f32>, !torch.vtensor<[128,2304],f32>, !torch.int -> !torch.vtensor<[128,2304],f32>
%267 = torch.aten.view %266, %40 : !torch.vtensor<[128,2304],f32>, !torch.list<int> -> !torch.vtensor<[1,128,2304],f32>
// split packed QKV at element offsets 0 / 768 / 1536 via as_strided
%268 = torch.aten.as_strided %267, %42, %43, %int0 : !torch.vtensor<[1,128,2304],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,128,768],f32>
%269 = torch.aten.as_strided %267, %42, %43, %int768 : !torch.vtensor<[1,128,2304],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,128,768],f32>
%270 = torch.aten.as_strided %267, %42, %43, %int1536 : !torch.vtensor<[1,128,2304],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,128,768],f32>
// reshape each of Q/K/V to [1,128,12,64] and permute to [1,12,128,64]
%271 = torch.aten.view %268, %47 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[1,128,12,64],f32>
%272 = torch.aten.permute %271, %49 : !torch.vtensor<[1,128,12,64],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,64],f32>
%273 = torch.aten.view %269, %47 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[1,128,12,64],f32>
%274 = torch.aten.permute %273, %49 : !torch.vtensor<[1,128,12,64],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,64],f32>
%275 = torch.aten.view %270, %47 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[1,128,12,64],f32>
%276 = torch.aten.permute %275, %49 : !torch.vtensor<[1,128,12,64],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,64],f32>
// attention scores: Q @ K^T as a batched matmul over the 12 heads
%277 = torch.aten.transpose.int %274, %int-1, %int-2 : !torch.vtensor<[1,12,128,64],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,12,64,128],f32>
%278 = torch.aten.broadcast_to %272, %56 : !torch.vtensor<[1,12,128,64],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,64],f32>
%279 = torch.aten.view %278, %58 : !torch.vtensor<[1,12,128,64],f32>, !torch.list<int> -> !torch.vtensor<[12,128,64],f32>
%280 = torch.aten.broadcast_to %277, %60 : !torch.vtensor<[1,12,64,128],f32>, !torch.list<int> -> !torch.vtensor<[1,12,64,128],f32>
%281 = torch.aten.view %280, %62 : !torch.vtensor<[1,12,64,128],f32>, !torch.list<int> -> !torch.vtensor<[12,64,128],f32>
%282 = torch.aten.bmm %279, %281 : !torch.vtensor<[12,128,64],f32>, !torch.vtensor<[12,64,128],f32> -> !torch.vtensor<[12,128,128],f32>
%283 = torch.aten.view %282, %65 : !torch.vtensor<[12,128,128],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,128],f32>
// scale scores by 8.0 (= sqrt(64), the per-head dimension)
%284 = torch.prim.NumToTensor.Scalar %float8.000000e00 : !torch.float -> !torch.vtensor<[],f64>
%285 = torch.aten.to.dtype %284, %int6, %false, %false, %none : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
%286 = torch.aten.broadcast_to %285, %67 : !torch.vtensor<[],f32>, !torch.list<int> -> !torch.vtensor<[],f32>
%287 = torch.aten.div.Tensor %283, %286 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[1,12,128,128],f32>
// slice the [1,1,1024,1024] ui8 mask buffer (%5) down to [1,1,128,128] and cast
// to i1; masked positions are filled with -3.40282e38 (~ -FLT_MAX).
// NOTE(review): presumably a causal mask -- the elided payload hides it; confirm.
%288 = torch.aten.as_strided %5, %72, %73, %int0 : !torch.vtensor<[1,1,1024,1024],ui8>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,1024,1024],ui8>
%289 = torch.aten.as_strided %288, %72, %73, %int0 : !torch.vtensor<[1,1,1024,1024],ui8>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,1024,1024],ui8>
%290 = torch.aten.as_strided %289, %76, %73, %int0 : !torch.vtensor<[1,1,1024,1024],ui8>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,128,1024],ui8>
%291 = torch.aten.as_strided %290, %78, %73, %int0 : !torch.vtensor<[1,1,128,1024],ui8>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,128,128],ui8>
%292 = torch.aten.to.dtype %291, %int11, %false, %false, %none : !torch.vtensor<[1,1,128,128],ui8>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,1,128,128],i1>
%293 = torch.prim.NumToTensor.Scalar %float-3.402820e38 : !torch.float -> !torch.vtensor<[],f64>
%294 = torch.aten.to.dtype %293, %int6, %false, %false, %none : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
%295 = torch.aten.broadcast_to %294, %67 : !torch.vtensor<[],f32>, !torch.list<int> -> !torch.vtensor<[],f32>
%296 = torch.aten.where.self %292, %287, %295 : !torch.vtensor<[1,1,128,128],i1>, !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[1,12,128,128],f32>
// numerically stable softmax over the last dim: subtract row max, exp, normalize
%values_2, %indices_3 = torch.aten.max.dim %296, %int-1, %true : !torch.vtensor<[1,12,128,128],f32>, !torch.int, !torch.bool -> !torch.vtensor<[1,12,128,1],f32>, !torch.vtensor<[1,12,128,1],si64>
%297 = torch.aten.sub.Tensor %296, %values_2, %int1 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[1,12,128,1],f32>, !torch.int -> !torch.vtensor<[1,12,128,128],f32>
%298 = torch.aten.exp %297 : !torch.vtensor<[1,12,128,128],f32> -> !torch.vtensor<[1,12,128,128],f32>
%299 = torch.aten.sum.dim_IntList %298, %85, %true, %none : !torch.vtensor<[1,12,128,128],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,12,128,1],f32>
%300 = torch.aten.div.Tensor %298, %299 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[1,12,128,1],f32> -> !torch.vtensor<[1,12,128,128],f32>
// attention-weighted sum of V (batched matmul)
%301 = torch.aten.broadcast_to %300, %65 : !torch.vtensor<[1,12,128,128],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,128],f32>
%302 = torch.aten.view %301, %91 : !torch.vtensor<[1,12,128,128],f32>, !torch.list<int> -> !torch.vtensor<[12,128,128],f32>
%303 = torch.aten.broadcast_to %276, %56 : !torch.vtensor<[1,12,128,64],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,64],f32>
%304 = torch.aten.view %303, %58 : !torch.vtensor<[1,12,128,64],f32>, !torch.list<int> -> !torch.vtensor<[12,128,64],f32>
%305 = torch.aten.bmm %302, %304 : !torch.vtensor<[12,128,128],f32>, !torch.vtensor<[12,128,64],f32> -> !torch.vtensor<[12,128,64],f32>
// merge heads back to [1,128,768], output projection (%4) + bias, residual
%306 = torch.aten.view %305, %56 : !torch.vtensor<[12,128,64],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,64],f32>
%307 = torch.aten.permute %306, %49 : !torch.vtensor<[1,12,128,64],f32>, !torch.list<int> -> !torch.vtensor<[1,128,12,64],f32>
%308 = torch.aten.clone %307, %int0 : !torch.vtensor<[1,128,12,64],f32>, !torch.int -> !torch.vtensor<[1,128,12,64],f32>
%309 = torch.aten.view %308, %42 : !torch.vtensor<[1,128,12,64],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32>
%310 = torch.aten.view %309, %35 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[128,768],f32>
%311 = torch.aten.mm %310, %4 : !torch.vtensor<[128,768],f32>, !torch.vtensor<[768,768],f32> -> !torch.vtensor<[128,768],f32>
%312 = torch.aten.mul.Scalar %8, %int1 : !torch.vtensor<[768],f32>, !torch.int -> !torch.vtensor<[768],f32>
%313 = torch.aten.add.Tensor %312, %311, %int1 : !torch.vtensor<[768],f32>, !torch.vtensor<[128,768],f32>, !torch.int -> !torch.vtensor<[128,768],f32>
%314 = torch.aten.view %313, %42 : !torch.vtensor<[128,768],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32>
%315 = torch.aten.add.Tensor %314, %246, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32>
// ---- Feed-forward (MLP) sub-block: LayerNorm -> Linear(768->3072, %2/%3) ->
// tanh-approximated GELU -> Linear(3072->768, %1/%8) -> residual add. ----
// LayerNorm, decomposed: mean, variance, rsqrt with eps 1e-5, affine via %8
%316 = torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<int>
%317 = torch.aten.sum.dim_IntList %315, %316, %true, %none : !torch.vtensor<[1,128,768],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
%318 = torch.aten.div.Scalar %317, %int768 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%319 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%320 = torch.aten.broadcast_to %318, %319 : !torch.vtensor<[1,128,1],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32>
%321 = torch.aten.sub.Tensor %315, %320, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32>
%322 = torch.aten.mul.Tensor %321, %321 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32> -> !torch.vtensor<[1,128,768],f32>
%323 = torch.aten.sum.dim_IntList %322, %316, %true, %none : !torch.vtensor<[1,128,768],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
%324 = torch.aten.div.Scalar %323, %int768 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%325 = torch.aten.add.Scalar %324, %float1.000000e-05, %int1 : !torch.vtensor<[1,128,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,128,1],f32>
%326 = torch.aten.rsqrt %325 : !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,1],f32>
%327 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%328 = torch.aten.broadcast_to %326, %327 : !torch.vtensor<[1,128,1],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32>
%329 = torch.aten.mul.Tensor %321, %328 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32> -> !torch.vtensor<[1,128,768],f32>
%330 = torch.aten.mul.Tensor %329, %8 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[768],f32> -> !torch.vtensor<[1,128,768],f32>
%331 = torch.aten.add.Tensor %330, %8, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32>
// fc: [128,768] x [768,3072] plus bias
%332 = torch.aten.view %331, %35 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[128,768],f32>
%333 = torch.aten.mm %332, %2 : !torch.vtensor<[128,768],f32>, !torch.vtensor<[768,3072],f32> -> !torch.vtensor<[128,3072],f32>
%334 = torch.aten.mul.Scalar %3, %int1 : !torch.vtensor<[3072],f32>, !torch.int -> !torch.vtensor<[3072],f32>
%335 = torch.aten.add.Tensor %334, %333, %int1 : !torch.vtensor<[3072],f32>, !torch.vtensor<[128,3072],f32>, !torch.int -> !torch.vtensor<[128,3072],f32>
%336 = torch.aten.view %335, %126 : !torch.vtensor<[128,3072],f32>, !torch.list<int> -> !torch.vtensor<[1,128,3072],f32>
// GELU, tanh approximation: 0.5*x*(1 + tanh(0.7978850*(x + 0.044715*x^3)))
%337 = torch.aten.mul.Scalar %336, %float5.000000e-01 : !torch.vtensor<[1,128,3072],f32>, !torch.float -> !torch.vtensor<[1,128,3072],f32>
%338 = torch.aten.pow.Tensor_Scalar %336, %float3.000000e00 : !torch.vtensor<[1,128,3072],f32>, !torch.float -> !torch.vtensor<[1,128,3072],f32>
%339 = torch.aten.mul.Scalar %338, %float4.471500e-02 : !torch.vtensor<[1,128,3072],f32>, !torch.float -> !torch.vtensor<[1,128,3072],f32>
%340 = torch.aten.add.Tensor %336, %339, %int1 : !torch.vtensor<[1,128,3072],f32>, !torch.vtensor<[1,128,3072],f32>, !torch.int -> !torch.vtensor<[1,128,3072],f32>
%341 = torch.aten.mul.Scalar %340, %float7.978850e-01 : !torch.vtensor<[1,128,3072],f32>, !torch.float -> !torch.vtensor<[1,128,3072],f32>
%342 = torch.aten.tanh %341 : !torch.vtensor<[1,128,3072],f32> -> !torch.vtensor<[1,128,3072],f32>
%343 = torch.aten.add.Scalar %342, %float1.000000e00, %int1 : !torch.vtensor<[1,128,3072],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,128,3072],f32>
%344 = torch.aten.mul.Tensor %337, %343 : !torch.vtensor<[1,128,3072],f32>, !torch.vtensor<[1,128,3072],f32> -> !torch.vtensor<[1,128,3072],f32>
// proj: [128,3072] x [3072,768] plus bias, then residual connection
%345 = torch.aten.view %344, %136 : !torch.vtensor<[1,128,3072],f32>, !torch.list<int> -> !torch.vtensor<[128,3072],f32>
%346 = torch.aten.mm %345, %1 : !torch.vtensor<[128,3072],f32>, !torch.vtensor<[3072,768],f32> -> !torch.vtensor<[128,768],f32>
%347 = torch.aten.mul.Scalar %8, %int1 : !torch.vtensor<[768],f32>, !torch.int -> !torch.vtensor<[768],f32>
%348 = torch.aten.add.Tensor %347, %346, %int1 : !torch.vtensor<[768],f32>, !torch.vtensor<[128,768],f32>, !torch.int -> !torch.vtensor<[128,768],f32>
%349 = torch.aten.view %348, %42 : !torch.vtensor<[128,768],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32>
%350 = torch.aten.add.Tensor %315, %349, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32>
// ---- Self-attention sub-block: LayerNorm -> fused QKV Linear(768->2304) ->
// 12 heads of 64 -> scaled dot-product attention with masking -> softmax ->
// attention @ V -> output projection -> residual add. ----
// LayerNorm, decomposed (mean, variance, rsqrt with eps 1e-5, affine via %8)
%351 = torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<int>
%352 = torch.aten.sum.dim_IntList %350, %351, %true, %none : !torch.vtensor<[1,128,768],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
%353 = torch.aten.div.Scalar %352, %int768 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%354 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%355 = torch.aten.broadcast_to %353, %354 : !torch.vtensor<[1,128,1],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32>
%356 = torch.aten.sub.Tensor %350, %355, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32>
%357 = torch.aten.mul.Tensor %356, %356 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32> -> !torch.vtensor<[1,128,768],f32>
%358 = torch.aten.sum.dim_IntList %357, %351, %true, %none : !torch.vtensor<[1,128,768],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
%359 = torch.aten.div.Scalar %358, %int768 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%360 = torch.aten.add.Scalar %359, %float1.000000e-05, %int1 : !torch.vtensor<[1,128,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,128,1],f32>
%361 = torch.aten.rsqrt %360 : !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,1],f32>
%362 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%363 = torch.aten.broadcast_to %361, %362 : !torch.vtensor<[1,128,1],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32>
%364 = torch.aten.mul.Tensor %356, %363 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32> -> !torch.vtensor<[1,128,768],f32>
%365 = torch.aten.mul.Tensor %364, %8 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[768],f32> -> !torch.vtensor<[1,128,768],f32>
%366 = torch.aten.add.Tensor %365, %8, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32>
// fused QKV projection plus bias; split at offsets 0 / 768 / 1536
%367 = torch.aten.view %366, %35 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[128,768],f32>
%368 = torch.aten.mm %367, %6 : !torch.vtensor<[128,768],f32>, !torch.vtensor<[768,2304],f32> -> !torch.vtensor<[128,2304],f32>
%369 = torch.aten.mul.Scalar %7, %int1 : !torch.vtensor<[2304],f32>, !torch.int -> !torch.vtensor<[2304],f32>
%370 = torch.aten.add.Tensor %369, %368, %int1 : !torch.vtensor<[2304],f32>, !torch.vtensor<[128,2304],f32>, !torch.int -> !torch.vtensor<[128,2304],f32>
%371 = torch.aten.view %370, %40 : !torch.vtensor<[128,2304],f32>, !torch.list<int> -> !torch.vtensor<[1,128,2304],f32>
%372 = torch.aten.as_strided %371, %42, %43, %int0 : !torch.vtensor<[1,128,2304],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,128,768],f32>
%373 = torch.aten.as_strided %371, %42, %43, %int768 : !torch.vtensor<[1,128,2304],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,128,768],f32>
%374 = torch.aten.as_strided %371, %42, %43, %int1536 : !torch.vtensor<[1,128,2304],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,128,768],f32>
// reshape Q/K/V to per-head layout [1,12,128,64]
%375 = torch.aten.view %372, %47 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[1,128,12,64],f32>
%376 = torch.aten.permute %375, %49 : !torch.vtensor<[1,128,12,64],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,64],f32>
%377 = torch.aten.view %373, %47 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[1,128,12,64],f32>
%378 = torch.aten.permute %377, %49 : !torch.vtensor<[1,128,12,64],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,64],f32>
%379 = torch.aten.view %374, %47 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[1,128,12,64],f32>
%380 = torch.aten.permute %379, %49 : !torch.vtensor<[1,128,12,64],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,64],f32>
// attention scores Q @ K^T, scaled by 8.0 (= sqrt(64))
%381 = torch.aten.transpose.int %378, %int-1, %int-2 : !torch.vtensor<[1,12,128,64],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,12,64,128],f32>
%382 = torch.aten.broadcast_to %376, %56 : !torch.vtensor<[1,12,128,64],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,64],f32>
%383 = torch.aten.view %382, %58 : !torch.vtensor<[1,12,128,64],f32>, !torch.list<int> -> !torch.vtensor<[12,128,64],f32>
%384 = torch.aten.broadcast_to %381, %60 : !torch.vtensor<[1,12,64,128],f32>, !torch.list<int> -> !torch.vtensor<[1,12,64,128],f32>
%385 = torch.aten.view %384, %62 : !torch.vtensor<[1,12,64,128],f32>, !torch.list<int> -> !torch.vtensor<[12,64,128],f32>
%386 = torch.aten.bmm %383, %385 : !torch.vtensor<[12,128,64],f32>, !torch.vtensor<[12,64,128],f32> -> !torch.vtensor<[12,128,128],f32>
%387 = torch.aten.view %386, %65 : !torch.vtensor<[12,128,128],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,128],f32>
%388 = torch.prim.NumToTensor.Scalar %float8.000000e00 : !torch.float -> !torch.vtensor<[],f64>
%389 = torch.aten.to.dtype %388, %int6, %false, %false, %none : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
%390 = torch.aten.broadcast_to %389, %67 : !torch.vtensor<[],f32>, !torch.list<int> -> !torch.vtensor<[],f32>
%391 = torch.aten.div.Tensor %387, %390 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[1,12,128,128],f32>
// mask: slice %5 to [1,1,128,128], cast to i1, fill masked slots with ~ -FLT_MAX
// NOTE(review): presumably causal -- payload is elided; confirm.
%392 = torch.aten.as_strided %5, %72, %73, %int0 : !torch.vtensor<[1,1,1024,1024],ui8>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,1024,1024],ui8>
%393 = torch.aten.as_strided %392, %72, %73, %int0 : !torch.vtensor<[1,1,1024,1024],ui8>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,1024,1024],ui8>
%394 = torch.aten.as_strided %393, %76, %73, %int0 : !torch.vtensor<[1,1,1024,1024],ui8>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,128,1024],ui8>
%395 = torch.aten.as_strided %394, %78, %73, %int0 : !torch.vtensor<[1,1,128,1024],ui8>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,128,128],ui8>
%396 = torch.aten.to.dtype %395, %int11, %false, %false, %none : !torch.vtensor<[1,1,128,128],ui8>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,1,128,128],i1>
%397 = torch.prim.NumToTensor.Scalar %float-3.402820e38 : !torch.float -> !torch.vtensor<[],f64>
%398 = torch.aten.to.dtype %397, %int6, %false, %false, %none : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
%399 = torch.aten.broadcast_to %398, %67 : !torch.vtensor<[],f32>, !torch.list<int> -> !torch.vtensor<[],f32>
%400 = torch.aten.where.self %396, %391, %399 : !torch.vtensor<[1,1,128,128],i1>, !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[1,12,128,128],f32>
// numerically stable softmax (max-subtract, exp, normalize)
%values_4, %indices_5 = torch.aten.max.dim %400, %int-1, %true : !torch.vtensor<[1,12,128,128],f32>, !torch.int, !torch.bool -> !torch.vtensor<[1,12,128,1],f32>, !torch.vtensor<[1,12,128,1],si64>
%401 = torch.aten.sub.Tensor %400, %values_4, %int1 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[1,12,128,1],f32>, !torch.int -> !torch.vtensor<[1,12,128,128],f32>
%402 = torch.aten.exp %401 : !torch.vtensor<[1,12,128,128],f32> -> !torch.vtensor<[1,12,128,128],f32>
%403 = torch.aten.sum.dim_IntList %402, %85, %true, %none : !torch.vtensor<[1,12,128,128],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,12,128,1],f32>
%404 = torch.aten.div.Tensor %402, %403 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[1,12,128,1],f32> -> !torch.vtensor<[1,12,128,128],f32>
// attention-weighted sum of V
%405 = torch.aten.broadcast_to %404, %65 : !torch.vtensor<[1,12,128,128],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,128],f32>
%406 = torch.aten.view %405, %91 : !torch.vtensor<[1,12,128,128],f32>, !torch.list<int> -> !torch.vtensor<[12,128,128],f32>
%407 = torch.aten.broadcast_to %380, %56 : !torch.vtensor<[1,12,128,64],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,64],f32>
%408 = torch.aten.view %407, %58 : !torch.vtensor<[1,12,128,64],f32>, !torch.list<int> -> !torch.vtensor<[12,128,64],f32>
%409 = torch.aten.bmm %406, %408 : !torch.vtensor<[12,128,128],f32>, !torch.vtensor<[12,128,64],f32> -> !torch.vtensor<[12,128,64],f32>
// merge heads, output projection (%4) + bias, residual
%410 = torch.aten.view %409, %56 : !torch.vtensor<[12,128,64],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,64],f32>
%411 = torch.aten.permute %410, %49 : !torch.vtensor<[1,12,128,64],f32>, !torch.list<int> -> !torch.vtensor<[1,128,12,64],f32>
%412 = torch.aten.clone %411, %int0 : !torch.vtensor<[1,128,12,64],f32>, !torch.int -> !torch.vtensor<[1,128,12,64],f32>
%413 = torch.aten.view %412, %42 : !torch.vtensor<[1,128,12,64],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32>
%414 = torch.aten.view %413, %35 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[128,768],f32>
%415 = torch.aten.mm %414, %4 : !torch.vtensor<[128,768],f32>, !torch.vtensor<[768,768],f32> -> !torch.vtensor<[128,768],f32>
%416 = torch.aten.mul.Scalar %8, %int1 : !torch.vtensor<[768],f32>, !torch.int -> !torch.vtensor<[768],f32>
%417 = torch.aten.add.Tensor %416, %415, %int1 : !torch.vtensor<[768],f32>, !torch.vtensor<[128,768],f32>, !torch.int -> !torch.vtensor<[128,768],f32>
%418 = torch.aten.view %417, %42 : !torch.vtensor<[128,768],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32>
%419 = torch.aten.add.Tensor %418, %350, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32>
// ---- Feed-forward (MLP) sub-block: LayerNorm -> Linear(768->3072, %2/%3) ->
// tanh-approximated GELU -> Linear(3072->768, %1/%8) -> residual add. ----
// LayerNorm, decomposed (mean, variance, rsqrt with eps 1e-5, affine via %8)
%420 = torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<int>
%421 = torch.aten.sum.dim_IntList %419, %420, %true, %none : !torch.vtensor<[1,128,768],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
%422 = torch.aten.div.Scalar %421, %int768 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%423 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%424 = torch.aten.broadcast_to %422, %423 : !torch.vtensor<[1,128,1],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32>
%425 = torch.aten.sub.Tensor %419, %424, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32>
%426 = torch.aten.mul.Tensor %425, %425 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32> -> !torch.vtensor<[1,128,768],f32>
%427 = torch.aten.sum.dim_IntList %426, %420, %true, %none : !torch.vtensor<[1,128,768],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
%428 = torch.aten.div.Scalar %427, %int768 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%429 = torch.aten.add.Scalar %428, %float1.000000e-05, %int1 : !torch.vtensor<[1,128,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,128,1],f32>
%430 = torch.aten.rsqrt %429 : !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,1],f32>
%431 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%432 = torch.aten.broadcast_to %430, %431 : !torch.vtensor<[1,128,1],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32>
%433 = torch.aten.mul.Tensor %425, %432 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32> -> !torch.vtensor<[1,128,768],f32>
%434 = torch.aten.mul.Tensor %433, %8 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[768],f32> -> !torch.vtensor<[1,128,768],f32>
%435 = torch.aten.add.Tensor %434, %8, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32>
// fc: [128,768] x [768,3072] plus bias
%436 = torch.aten.view %435, %35 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[128,768],f32>
%437 = torch.aten.mm %436, %2 : !torch.vtensor<[128,768],f32>, !torch.vtensor<[768,3072],f32> -> !torch.vtensor<[128,3072],f32>
%438 = torch.aten.mul.Scalar %3, %int1 : !torch.vtensor<[3072],f32>, !torch.int -> !torch.vtensor<[3072],f32>
%439 = torch.aten.add.Tensor %438, %437, %int1 : !torch.vtensor<[3072],f32>, !torch.vtensor<[128,3072],f32>, !torch.int -> !torch.vtensor<[128,3072],f32>
%440 = torch.aten.view %439, %126 : !torch.vtensor<[128,3072],f32>, !torch.list<int> -> !torch.vtensor<[1,128,3072],f32>
// GELU, tanh approximation: 0.5*x*(1 + tanh(0.7978850*(x + 0.044715*x^3)))
%441 = torch.aten.mul.Scalar %440, %float5.000000e-01 : !torch.vtensor<[1,128,3072],f32>, !torch.float -> !torch.vtensor<[1,128,3072],f32>
%442 = torch.aten.pow.Tensor_Scalar %440, %float3.000000e00 : !torch.vtensor<[1,128,3072],f32>, !torch.float -> !torch.vtensor<[1,128,3072],f32>
%443 = torch.aten.mul.Scalar %442, %float4.471500e-02 : !torch.vtensor<[1,128,3072],f32>, !torch.float -> !torch.vtensor<[1,128,3072],f32>
%444 = torch.aten.add.Tensor %440, %443, %int1 : !torch.vtensor<[1,128,3072],f32>, !torch.vtensor<[1,128,3072],f32>, !torch.int -> !torch.vtensor<[1,128,3072],f32>
%445 = torch.aten.mul.Scalar %444, %float7.978850e-01 : !torch.vtensor<[1,128,3072],f32>, !torch.float -> !torch.vtensor<[1,128,3072],f32>
%446 = torch.aten.tanh %445 : !torch.vtensor<[1,128,3072],f32> -> !torch.vtensor<[1,128,3072],f32>
%447 = torch.aten.add.Scalar %446, %float1.000000e00, %int1 : !torch.vtensor<[1,128,3072],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,128,3072],f32>
%448 = torch.aten.mul.Tensor %441, %447 : !torch.vtensor<[1,128,3072],f32>, !torch.vtensor<[1,128,3072],f32> -> !torch.vtensor<[1,128,3072],f32>
// proj: [128,3072] x [3072,768] plus bias, then residual connection
%449 = torch.aten.view %448, %136 : !torch.vtensor<[1,128,3072],f32>, !torch.list<int> -> !torch.vtensor<[128,3072],f32>
%450 = torch.aten.mm %449, %1 : !torch.vtensor<[128,3072],f32>, !torch.vtensor<[3072,768],f32> -> !torch.vtensor<[128,768],f32>
%451 = torch.aten.mul.Scalar %8, %int1 : !torch.vtensor<[768],f32>, !torch.int -> !torch.vtensor<[768],f32>
%452 = torch.aten.add.Tensor %451, %450, %int1 : !torch.vtensor<[768],f32>, !torch.vtensor<[128,768],f32>, !torch.int -> !torch.vtensor<[128,768],f32>
%453 = torch.aten.view %452, %42 : !torch.vtensor<[128,768],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32>
%454 = torch.aten.add.Tensor %419, %453, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32>
// ---- LayerNorm followed by the start of the next self-attention sub-block
// (fused QKV projection and Q-slice); the remainder continues past this excerpt. ----
// LayerNorm, decomposed (mean, variance, rsqrt with eps 1e-5, affine via %8)
%455 = torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<int>
%456 = torch.aten.sum.dim_IntList %454, %455, %true, %none : !torch.vtensor<[1,128,768],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
%457 = torch.aten.div.Scalar %456, %int768 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%458 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%459 = torch.aten.broadcast_to %457, %458 : !torch.vtensor<[1,128,1],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32>
%460 = torch.aten.sub.Tensor %454, %459, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32>
%461 = torch.aten.mul.Tensor %460, %460 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32> -> !torch.vtensor<[1,128,768],f32>
%462 = torch.aten.sum.dim_IntList %461, %455, %true, %none : !torch.vtensor<[1,128,768],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
%463 = torch.aten.div.Scalar %462, %int768 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%464 = torch.aten.add.Scalar %463, %float1.000000e-05, %int1 : !torch.vtensor<[1,128,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,128,1],f32>
%465 = torch.aten.rsqrt %464 : !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,1],f32>
%466 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%467 = torch.aten.broadcast_to %465, %466 : !torch.vtensor<[1,128,1],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32>
%468 = torch.aten.mul.Tensor %460, %467 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32> -> !torch.vtensor<[1,128,768],f32>
%469 = torch.aten.mul.Tensor %468, %8 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[768],f32> -> !torch.vtensor<[1,128,768],f32>
%470 = torch.aten.add.Tensor %469, %8, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32>
// fused QKV projection: [128,768] x [768,2304] (%6) plus bias %7
%471 = torch.aten.view %470, %35 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[128,768],f32>
%472 = torch.aten.mm %471, %6 : !torch.vtensor<[128,768],f32>, !torch.vtensor<[768,2304],f32> -> !torch.vtensor<[128,2304],f32>
%473 = torch.aten.mul.Scalar %7, %int1 : !torch.vtensor<[2304],f32>, !torch.int -> !torch.vtensor<[2304],f32>
%474 = torch.aten.add.Tensor %473, %472, %int1 : !torch.vtensor<[2304],f32>, !torch.vtensor<[128,2304],f32>, !torch.int -> !torch.vtensor<[128,2304],f32>
%475 = torch.aten.view %474, %40 : !torch.vtensor<[128,2304],f32>, !torch.list<int> -> !torch.vtensor<[1,128,2304],f32>
// first slice (offset 0) of the packed QKV tensor
%476 = torch.aten.as_strided %475, %42, %43, %int0 : !torch.vtensor<[1,128,2304],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,128,768],f32>
%477 = torch.aten.as_strided %475, %42, %43, %int768 : !torch.vtensor<[1,128,2304],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,128,768],f32>
%478 = torch.aten.as_strided %475, %42, %43, %int1536 : !torch.vtensor<[1,128,2304],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,128,768],f32>
%479 = torch.aten.view %476, %47 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[1,128,12,64],f32>
%480 = torch.aten.permute %479, %49 : !torch.vtensor<[1,128,12,64],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,64],f32>
%481 = torch.aten.view %477, %47 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[1,128,12,64],f32>
%482 = torch.aten.permute %481, %49 : !torch.vtensor<[1,128,12,64],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,64],f32>
%483 = torch.aten.view %478, %47 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[1,128,12,64],f32>
%484 = torch.aten.permute %483, %49 : !torch.vtensor<[1,128,12,64],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,64],f32>
%485 = torch.aten.transpose.int %482, %int-1, %int-2 : !torch.vtensor<[1,12,128,64],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,12,64,128],f32>
%486 = torch.aten.broadcast_to %480, %56 : !torch.vtensor<[1,12,128,64],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,64],f32>
%487 = torch.aten.view %486, %58 : !torch.vtensor<[1,12,128,64],f32>, !torch.list<int> -> !torch.vtensor<[12,128,64],f32>
%488 = torch.aten.broadcast_to %485, %60 : !torch.vtensor<[1,12,64,128],f32>, !torch.list<int> -> !torch.vtensor<[1,12,64,128],f32>
%489 = torch.aten.view %488, %62 : !torch.vtensor<[1,12,64,128],f32>, !torch.list<int> -> !torch.vtensor<[12,64,128],f32>
%490 = torch.aten.bmm %487, %489 : !torch.vtensor<[12,128,64],f32>, !torch.vtensor<[12,64,128],f32> -> !torch.vtensor<[12,128,128],f32>
%491 = torch.aten.view %490, %65 : !torch.vtensor<[12,128,128],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,128],f32>
%492 = torch.prim.NumToTensor.Scalar %float8.000000e00 : !torch.float -> !torch.vtensor<[],f64>
%493 = torch.aten.to.dtype %492, %int6, %false, %false, %none : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
%494 = torch.aten.broadcast_to %493, %67 : !torch.vtensor<[],f32>, !torch.list<int> -> !torch.vtensor<[],f32>
%495 = torch.aten.div.Tensor %491, %494 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[1,12,128,128],f32>
%496 = torch.aten.as_strided %5, %72, %73, %int0 : !torch.vtensor<[1,1,1024,1024],ui8>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,1024,1024],ui8>
%497 = torch.aten.as_strided %496, %72, %73, %int0 : !torch.vtensor<[1,1,1024,1024],ui8>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,1024,1024],ui8>
%498 = torch.aten.as_strided %497, %76, %73, %int0 : !torch.vtensor<[1,1,1024,1024],ui8>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,128,1024],ui8>
%499 = torch.aten.as_strided %498, %78, %73, %int0 : !torch.vtensor<[1,1,128,1024],ui8>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,128,128],ui8>
%500 = torch.aten.to.dtype %499, %int11, %false, %false, %none : !torch.vtensor<[1,1,128,128],ui8>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,1,128,128],i1>
%501 = torch.prim.NumToTensor.Scalar %float-3.402820e38 : !torch.float -> !torch.vtensor<[],f64>
%502 = torch.aten.to.dtype %501, %int6, %false, %false, %none : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
%503 = torch.aten.broadcast_to %502, %67 : !torch.vtensor<[],f32>, !torch.list<int> -> !torch.vtensor<[],f32>
%504 = torch.aten.where.self %500, %495, %503 : !torch.vtensor<[1,1,128,128],i1>, !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[1,12,128,128],f32>
%values_6, %indices_7 = torch.aten.max.dim %504, %int-1, %true : !torch.vtensor<[1,12,128,128],f32>, !torch.int, !torch.bool -> !torch.vtensor<[1,12,128,1],f32>, !torch.vtensor<[1,12,128,1],si64>
%505 = torch.aten.sub.Tensor %504, %values_6, %int1 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[1,12,128,1],f32>, !torch.int -> !torch.vtensor<[1,12,128,128],f32>
%506 = torch.aten.exp %505 : !torch.vtensor<[1,12,128,128],f32> -> !torch.vtensor<[1,12,128,128],f32>
%507 = torch.aten.sum.dim_IntList %506, %85, %true, %none : !torch.vtensor<[1,12,128,128],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,12,128,1],f32>
%508 = torch.aten.div.Tensor %506, %507 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[1,12,128,1],f32> -> !torch.vtensor<[1,12,128,128],f32>
%509 = torch.aten.broadcast_to %508, %65 : !torch.vtensor<[1,12,128,128],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,128],f32>
%510 = torch.aten.view %509, %91 : !torch.vtensor<[1,12,128,128],f32>, !torch.list<int> -> !torch.vtensor<[12,128,128],f32>
%511 = torch.aten.broadcast_to %484, %56 : !torch.vtensor<[1,12,128,64],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,64],f32>
%512 = torch.aten.view %511, %58 : !torch.vtensor<[1,12,128,64],f32>, !torch.list<int> -> !torch.vtensor<[12,128,64],f32>
%513 = torch.aten.bmm %510, %512 : !torch.vtensor<[12,128,128],f32>, !torch.vtensor<[12,128,64],f32> -> !torch.vtensor<[12,128,64],f32>
%514 = torch.aten.view %513, %56 : !torch.vtensor<[12,128,64],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,64],f32>
%515 = torch.aten.permute %514, %49 : !torch.vtensor<[1,12,128,64],f32>, !torch.list<int> -> !torch.vtensor<[1,128,12,64],f32>
%516 = torch.aten.clone %515, %int0 : !torch.vtensor<[1,128,12,64],f32>, !torch.int -> !torch.vtensor<[1,128,12,64],f32>
%517 = torch.aten.view %516, %42 : !torch.vtensor<[1,128,12,64],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32>
%518 = torch.aten.view %517, %35 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[128,768],f32>
%519 = torch.aten.mm %518, %4 : !torch.vtensor<[128,768],f32>, !torch.vtensor<[768,768],f32> -> !torch.vtensor<[128,768],f32>
%520 = torch.aten.mul.Scalar %8, %int1 : !torch.vtensor<[768],f32>, !torch.int -> !torch.vtensor<[768],f32>
%521 = torch.aten.add.Tensor %520, %519, %int1 : !torch.vtensor<[768],f32>, !torch.vtensor<[128,768],f32>, !torch.int -> !torch.vtensor<[128,768],f32>
%522 = torch.aten.view %521, %42 : !torch.vtensor<[128,768],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32>
%523 = torch.aten.add.Tensor %522, %454, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32>
%524 = torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<int>
%525 = torch.aten.sum.dim_IntList %523, %524, %true, %none : !torch.vtensor<[1,128,768],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
%526 = torch.aten.div.Scalar %525, %int768 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%527 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%528 = torch.aten.broadcast_to %526, %527 : !torch.vtensor<[1,128,1],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32>
%529 = torch.aten.sub.Tensor %523, %528, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32>
%530 = torch.aten.mul.Tensor %529, %529 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32> -> !torch.vtensor<[1,128,768],f32>
%531 = torch.aten.sum.dim_IntList %530, %524, %true, %none : !torch.vtensor<[1,128,768],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
%532 = torch.aten.div.Scalar %531, %int768 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%533 = torch.aten.add.Scalar %532, %float1.000000e-05, %int1 : !torch.vtensor<[1,128,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,128,1],f32>
%534 = torch.aten.rsqrt %533 : !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,1],f32>
%535 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%536 = torch.aten.broadcast_to %534, %535 : !torch.vtensor<[1,128,1],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32>
%537 = torch.aten.mul.Tensor %529, %536 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32> -> !torch.vtensor<[1,128,768],f32>
%538 = torch.aten.mul.Tensor %537, %8 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[768],f32> -> !torch.vtensor<[1,128,768],f32>
%539 = torch.aten.add.Tensor %538, %8, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32>
%540 = torch.aten.view %539, %35 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[128,768],f32>
%541 = torch.aten.mm %540, %2 : !torch.vtensor<[128,768],f32>, !torch.vtensor<[768,3072],f32> -> !torch.vtensor<[128,3072],f32>
%542 = torch.aten.mul.Scalar %3, %int1 : !torch.vtensor<[3072],f32>, !torch.int -> !torch.vtensor<[3072],f32>
%543 = torch.aten.add.Tensor %542, %541, %int1 : !torch.vtensor<[3072],f32>, !torch.vtensor<[128,3072],f32>, !torch.int -> !torch.vtensor<[128,3072],f32>
%544 = torch.aten.view %543, %126 : !torch.vtensor<[128,3072],f32>, !torch.list<int> -> !torch.vtensor<[1,128,3072],f32>
%545 = torch.aten.mul.Scalar %544, %float5.000000e-01 : !torch.vtensor<[1,128,3072],f32>, !torch.float -> !torch.vtensor<[1,128,3072],f32>
%546 = torch.aten.pow.Tensor_Scalar %544, %float3.000000e00 : !torch.vtensor<[1,128,3072],f32>, !torch.float -> !torch.vtensor<[1,128,3072],f32>
%547 = torch.aten.mul.Scalar %546, %float4.471500e-02 : !torch.vtensor<[1,128,3072],f32>, !torch.float -> !torch.vtensor<[1,128,3072],f32>
%548 = torch.aten.add.Tensor %544, %547, %int1 : !torch.vtensor<[1,128,3072],f32>, !torch.vtensor<[1,128,3072],f32>, !torch.int -> !torch.vtensor<[1,128,3072],f32>
%549 = torch.aten.mul.Scalar %548, %float7.978850e-01 : !torch.vtensor<[1,128,3072],f32>, !torch.float -> !torch.vtensor<[1,128,3072],f32>
%550 = torch.aten.tanh %549 : !torch.vtensor<[1,128,3072],f32> -> !torch.vtensor<[1,128,3072],f32>
%551 = torch.aten.add.Scalar %550, %float1.000000e00, %int1 : !torch.vtensor<[1,128,3072],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,128,3072],f32>
%552 = torch.aten.mul.Tensor %545, %551 : !torch.vtensor<[1,128,3072],f32>, !torch.vtensor<[1,128,3072],f32> -> !torch.vtensor<[1,128,3072],f32>
%553 = torch.aten.view %552, %136 : !torch.vtensor<[1,128,3072],f32>, !torch.list<int> -> !torch.vtensor<[128,3072],f32>
%554 = torch.aten.mm %553, %1 : !torch.vtensor<[128,3072],f32>, !torch.vtensor<[3072,768],f32> -> !torch.vtensor<[128,768],f32>
%555 = torch.aten.mul.Scalar %8, %int1 : !torch.vtensor<[768],f32>, !torch.int -> !torch.vtensor<[768],f32>
%556 = torch.aten.add.Tensor %555, %554, %int1 : !torch.vtensor<[768],f32>, !torch.vtensor<[128,768],f32>, !torch.int -> !torch.vtensor<[128,768],f32>
%557 = torch.aten.view %556, %42 : !torch.vtensor<[128,768],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32>
%558 = torch.aten.add.Tensor %523, %557, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32>
%559 = torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<int>
%560 = torch.aten.sum.dim_IntList %558, %559, %true, %none : !torch.vtensor<[1,128,768],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
%561 = torch.aten.div.Scalar %560, %int768 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%562 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%563 = torch.aten.broadcast_to %561, %562 : !torch.vtensor<[1,128,1],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32>
%564 = torch.aten.sub.Tensor %558, %563, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32>
%565 = torch.aten.mul.Tensor %564, %564 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32> -> !torch.vtensor<[1,128,768],f32>
%566 = torch.aten.sum.dim_IntList %565, %559, %true, %none : !torch.vtensor<[1,128,768],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
%567 = torch.aten.div.Scalar %566, %int768 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%568 = torch.aten.add.Scalar %567, %float1.000000e-05, %int1 : !torch.vtensor<[1,128,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,128,1],f32>
%569 = torch.aten.rsqrt %568 : !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,1],f32>
%570 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%571 = torch.aten.broadcast_to %569, %570 : !torch.vtensor<[1,128,1],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32>
%572 = torch.aten.mul.Tensor %564, %571 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32> -> !torch.vtensor<[1,128,768],f32>
%573 = torch.aten.mul.Tensor %572, %8 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[768],f32> -> !torch.vtensor<[1,128,768],f32>
%574 = torch.aten.add.Tensor %573, %8, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32>
%575 = torch.aten.view %574, %35 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[128,768],f32>
%576 = torch.aten.mm %575, %6 : !torch.vtensor<[128,768],f32>, !torch.vtensor<[768,2304],f32> -> !torch.vtensor<[128,2304],f32>
%577 = torch.aten.mul.Scalar %7, %int1 : !torch.vtensor<[2304],f32>, !torch.int -> !torch.vtensor<[2304],f32>
%578 = torch.aten.add.Tensor %577, %576, %int1 : !torch.vtensor<[2304],f32>, !torch.vtensor<[128,2304],f32>, !torch.int -> !torch.vtensor<[128,2304],f32>
%579 = torch.aten.view %578, %40 : !torch.vtensor<[128,2304],f32>, !torch.list<int> -> !torch.vtensor<[1,128,2304],f32>
%580 = torch.aten.as_strided %579, %42, %43, %int0 : !torch.vtensor<[1,128,2304],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,128,768],f32>
%581 = torch.aten.as_strided %579, %42, %43, %int768 : !torch.vtensor<[1,128,2304],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,128,768],f32>
%582 = torch.aten.as_strided %579, %42, %43, %int1536 : !torch.vtensor<[1,128,2304],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,128,768],f32>
%583 = torch.aten.view %580, %47 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[1,128,12,64],f32>
%584 = torch.aten.permute %583, %49 : !torch.vtensor<[1,128,12,64],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,64],f32>
%585 = torch.aten.view %581, %47 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[1,128,12,64],f32>
%586 = torch.aten.permute %585, %49 : !torch.vtensor<[1,128,12,64],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,64],f32>
%587 = torch.aten.view %582, %47 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[1,128,12,64],f32>
%588 = torch.aten.permute %587, %49 : !torch.vtensor<[1,128,12,64],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,64],f32>
%589 = torch.aten.transpose.int %586, %int-1, %int-2 : !torch.vtensor<[1,12,128,64],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,12,64,128],f32>
%590 = torch.aten.broadcast_to %584, %56 : !torch.vtensor<[1,12,128,64],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,64],f32>
%591 = torch.aten.view %590, %58 : !torch.vtensor<[1,12,128,64],f32>, !torch.list<int> -> !torch.vtensor<[12,128,64],f32>
%592 = torch.aten.broadcast_to %589, %60 : !torch.vtensor<[1,12,64,128],f32>, !torch.list<int> -> !torch.vtensor<[1,12,64,128],f32>
%593 = torch.aten.view %592, %62 : !torch.vtensor<[1,12,64,128],f32>, !torch.list<int> -> !torch.vtensor<[12,64,128],f32>
%594 = torch.aten.bmm %591, %593 : !torch.vtensor<[12,128,64],f32>, !torch.vtensor<[12,64,128],f32> -> !torch.vtensor<[12,128,128],f32>
%595 = torch.aten.view %594, %65 : !torch.vtensor<[12,128,128],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,128],f32>
%596 = torch.prim.NumToTensor.Scalar %float8.000000e00 : !torch.float -> !torch.vtensor<[],f64>
%597 = torch.aten.to.dtype %596, %int6, %false, %false, %none : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
%598 = torch.aten.broadcast_to %597, %67 : !torch.vtensor<[],f32>, !torch.list<int> -> !torch.vtensor<[],f32>
%599 = torch.aten.div.Tensor %595, %598 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[1,12,128,128],f32>
%600 = torch.aten.as_strided %5, %72, %73, %int0 : !torch.vtensor<[1,1,1024,1024],ui8>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,1024,1024],ui8>
%601 = torch.aten.as_strided %600, %72, %73, %int0 : !torch.vtensor<[1,1,1024,1024],ui8>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,1024,1024],ui8>
%602 = torch.aten.as_strided %601, %76, %73, %int0 : !torch.vtensor<[1,1,1024,1024],ui8>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,128,1024],ui8>
%603 = torch.aten.as_strided %602, %78, %73, %int0 : !torch.vtensor<[1,1,128,1024],ui8>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,128,128],ui8>
%604 = torch.aten.to.dtype %603, %int11, %false, %false, %none : !torch.vtensor<[1,1,128,128],ui8>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,1,128,128],i1>
%605 = torch.prim.NumToTensor.Scalar %float-3.402820e38 : !torch.float -> !torch.vtensor<[],f64>
%606 = torch.aten.to.dtype %605, %int6, %false, %false, %none : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
%607 = torch.aten.broadcast_to %606, %67 : !torch.vtensor<[],f32>, !torch.list<int> -> !torch.vtensor<[],f32>
%608 = torch.aten.where.self %604, %599, %607 : !torch.vtensor<[1,1,128,128],i1>, !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[1,12,128,128],f32>
%values_8, %indices_9 = torch.aten.max.dim %608, %int-1, %true : !torch.vtensor<[1,12,128,128],f32>, !torch.int, !torch.bool -> !torch.vtensor<[1,12,128,1],f32>, !torch.vtensor<[1,12,128,1],si64>
%609 = torch.aten.sub.Tensor %608, %values_8, %int1 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[1,12,128,1],f32>, !torch.int -> !torch.vtensor<[1,12,128,128],f32>
%610 = torch.aten.exp %609 : !torch.vtensor<[1,12,128,128],f32> -> !torch.vtensor<[1,12,128,128],f32>
%611 = torch.aten.sum.dim_IntList %610, %85, %true, %none : !torch.vtensor<[1,12,128,128],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,12,128,1],f32>
%612 = torch.aten.div.Tensor %610, %611 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[1,12,128,1],f32> -> !torch.vtensor<[1,12,128,128],f32>
%613 = torch.aten.broadcast_to %612, %65 : !torch.vtensor<[1,12,128,128],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,128],f32>
%614 = torch.aten.view %613, %91 : !torch.vtensor<[1,12,128,128],f32>, !torch.list<int> -> !torch.vtensor<[12,128,128],f32>
%615 = torch.aten.broadcast_to %588, %56 : !torch.vtensor<[1,12,128,64],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,64],f32>
%616 = torch.aten.view %615, %58 : !torch.vtensor<[1,12,128,64],f32>, !torch.list<int> -> !torch.vtensor<[12,128,64],f32>
%617 = torch.aten.bmm %614, %616 : !torch.vtensor<[12,128,128],f32>, !torch.vtensor<[12,128,64],f32> -> !torch.vtensor<[12,128,64],f32>
%618 = torch.aten.view %617, %56 : !torch.vtensor<[12,128,64],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,64],f32>
%619 = torch.aten.permute %618, %49 : !torch.vtensor<[1,12,128,64],f32>, !torch.list<int> -> !torch.vtensor<[1,128,12,64],f32>
%620 = torch.aten.clone %619, %int0 : !torch.vtensor<[1,128,12,64],f32>, !torch.int -> !torch.vtensor<[1,128,12,64],f32>
%621 = torch.aten.view %620, %42 : !torch.vtensor<[1,128,12,64],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32>
%622 = torch.aten.view %621, %35 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[128,768],f32>
%623 = torch.aten.mm %622, %4 : !torch.vtensor<[128,768],f32>, !torch.vtensor<[768,768],f32> -> !torch.vtensor<[128,768],f32>
%624 = torch.aten.mul.Scalar %8, %int1 : !torch.vtensor<[768],f32>, !torch.int -> !torch.vtensor<[768],f32>
%625 = torch.aten.add.Tensor %624, %623, %int1 : !torch.vtensor<[768],f32>, !torch.vtensor<[128,768],f32>, !torch.int -> !torch.vtensor<[128,768],f32>
%626 = torch.aten.view %625, %42 : !torch.vtensor<[128,768],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32>
%627 = torch.aten.add.Tensor %626, %558, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32>
%628 = torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<int>
%629 = torch.aten.sum.dim_IntList %627, %628, %true, %none : !torch.vtensor<[1,128,768],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
%630 = torch.aten.div.Scalar %629, %int768 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%631 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%632 = torch.aten.broadcast_to %630, %631 : !torch.vtensor<[1,128,1],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32>
%633 = torch.aten.sub.Tensor %627, %632, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32>
%634 = torch.aten.mul.Tensor %633, %633 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32> -> !torch.vtensor<[1,128,768],f32>
%635 = torch.aten.sum.dim_IntList %634, %628, %true, %none : !torch.vtensor<[1,128,768],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
%636 = torch.aten.div.Scalar %635, %int768 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%637 = torch.aten.add.Scalar %636, %float1.000000e-05, %int1 : !torch.vtensor<[1,128,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,128,1],f32>
%638 = torch.aten.rsqrt %637 : !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,1],f32>
%639 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%640 = torch.aten.broadcast_to %638, %639 : !torch.vtensor<[1,128,1],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32>
%641 = torch.aten.mul.Tensor %633, %640 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32> -> !torch.vtensor<[1,128,768],f32>
%642 = torch.aten.mul.Tensor %641, %8 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[768],f32> -> !torch.vtensor<[1,128,768],f32>
%643 = torch.aten.add.Tensor %642, %8, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32>
%644 = torch.aten.view %643, %35 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[128,768],f32>
%645 = torch.aten.mm %644, %2 : !torch.vtensor<[128,768],f32>, !torch.vtensor<[768,3072],f32> -> !torch.vtensor<[128,3072],f32>
%646 = torch.aten.mul.Scalar %3, %int1 : !torch.vtensor<[3072],f32>, !torch.int -> !torch.vtensor<[3072],f32>
%647 = torch.aten.add.Tensor %646, %645, %int1 : !torch.vtensor<[3072],f32>, !torch.vtensor<[128,3072],f32>, !torch.int -> !torch.vtensor<[128,3072],f32>
%648 = torch.aten.view %647, %126 : !torch.vtensor<[128,3072],f32>, !torch.list<int> -> !torch.vtensor<[1,128,3072],f32>
%649 = torch.aten.mul.Scalar %648, %float5.000000e-01 : !torch.vtensor<[1,128,3072],f32>, !torch.float -> !torch.vtensor<[1,128,3072],f32>
%650 = torch.aten.pow.Tensor_Scalar %648, %float3.000000e00 : !torch.vtensor<[1,128,3072],f32>, !torch.float -> !torch.vtensor<[1,128,3072],f32>
%651 = torch.aten.mul.Scalar %650, %float4.471500e-02 : !torch.vtensor<[1,128,3072],f32>, !torch.float -> !torch.vtensor<[1,128,3072],f32>
%652 = torch.aten.add.Tensor %648, %651, %int1 : !torch.vtensor<[1,128,3072],f32>, !torch.vtensor<[1,128,3072],f32>, !torch.int -> !torch.vtensor<[1,128,3072],f32>
%653 = torch.aten.mul.Scalar %652, %float7.978850e-01 : !torch.vtensor<[1,128,3072],f32>, !torch.float -> !torch.vtensor<[1,128,3072],f32>
%654 = torch.aten.tanh %653 : !torch.vtensor<[1,128,3072],f32> -> !torch.vtensor<[1,128,3072],f32>
%655 = torch.aten.add.Scalar %654, %float1.000000e00, %int1 : !torch.vtensor<[1,128,3072],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,128,3072],f32>
%656 = torch.aten.mul.Tensor %649, %655 : !torch.vtensor<[1,128,3072],f32>, !torch.vtensor<[1,128,3072],f32> -> !torch.vtensor<[1,128,3072],f32>
%657 = torch.aten.view %656, %136 : !torch.vtensor<[1,128,3072],f32>, !torch.list<int> -> !torch.vtensor<[128,3072],f32>
%658 = torch.aten.mm %657, %1 : !torch.vtensor<[128,3072],f32>, !torch.vtensor<[3072,768],f32> -> !torch.vtensor<[128,768],f32>
%659 = torch.aten.mul.Scalar %8, %int1 : !torch.vtensor<[768],f32>, !torch.int -> !torch.vtensor<[768],f32>
%660 = torch.aten.add.Tensor %659, %658, %int1 : !torch.vtensor<[768],f32>, !torch.vtensor<[128,768],f32>, !torch.int -> !torch.vtensor<[128,768],f32>
%661 = torch.aten.view %660, %42 : !torch.vtensor<[128,768],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32>
%662 = torch.aten.add.Tensor %627, %661, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32>
%663 = torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<int>
%664 = torch.aten.sum.dim_IntList %662, %663, %true, %none : !torch.vtensor<[1,128,768],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
%665 = torch.aten.div.Scalar %664, %int768 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%666 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%667 = torch.aten.broadcast_to %665, %666 : !torch.vtensor<[1,128,1],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32>
%668 = torch.aten.sub.Tensor %662, %667, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32>
%669 = torch.aten.mul.Tensor %668, %668 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32> -> !torch.vtensor<[1,128,768],f32>
%670 = torch.aten.sum.dim_IntList %669, %663, %true, %none : !torch.vtensor<[1,128,768],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32>
%671 = torch.aten.div.Scalar %670, %int768 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32>
%672 = torch.aten.add.Scalar %671, %float1.000000e-05, %int1 : !torch.vtensor<[1,128,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,128,1],f32>
%673 = torch.aten.rsqrt %672 : !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,1],f32>
%674 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%675 = torch.aten.broadcast_to %673, %674 : !torch.vtensor<[1,128,1],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32>
%676 = torch.aten.mul.Tensor %668, %675 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32> -> !torch.vtensor<[1,128,768],f32>
%677 = torch.aten.mul.Tensor %676, %8 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[768],f32> -> !torch.vtensor<[1,128,768],f32>
%678 = torch.aten.add.Tensor %677, %8, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32>
%679 = torch.aten.view %678, %42 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32>
%680 = torch.aten.transpose.int %0, %int0, %int1 : !torch.vtensor<[2,768],f32>, !torch.int, !torch.int -> !torch.vtensor<[768,2],f32>
%681 = torch.prim.ListConstruct %int128, %int768 : (!torch.int, !torch.int) -> !torch.list<int>
%682 = torch.aten.view %679, %681 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[128,768],f32>
%683 = torch.aten.mm %682, %680 : !torch.vtensor<[128,768],f32>, !torch.vtensor<[768,2],f32> -> !torch.vtensor<[128,2],f32>
%684 = torch.prim.ListConstruct %int1, %int128, %int2 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%685 = torch.aten.view %683, %684 : !torch.vtensor<[128,2],f32>, !torch.list<int> -> !torch.vtensor<[1,128,2],f32>
%686 = torch.aten.arange.start_step %int0, %int1, %int1, %none, %none, %cpu, %false : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[1],si64>
%687 = torch.aten.slice.Tensor %685, %int1, %int-1, %int0, %int1 : !torch.vtensor<[1,128,2],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,1,2],f32>
%688 = torch.aten.squeeze.dim %687, %int1 : !torch.vtensor<[1,1,2],f32>, !torch.int -> !torch.vtensor<[1,2],f32>
%689 = torch.prim.ListConstruct %686 : (!torch.vtensor<[1],si64>) -> !torch.list<vtensor>
%690 = torch.aten.index.Tensor %688, %689 : !torch.vtensor<[1,2],f32>, !torch.list<vtensor> -> !torch.vtensor<[1,2],f32>
return %690 : !torch.vtensor<[1,2],f32>
}
}
// (scraping residue, not part of the IR): Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment