Save vivekkhandelwal1/3f912e3a9ba62ce2895533185f837b44 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters.
module attributes {torch.debug_module_name = "_lambda"} { | |
func.func @forward(%arg0: !torch.vtensor<[1,128],si64>) -> !torch.vtensor<[1,2],f32> { | |
%int1 = torch.constant.int 1 | |
%int0 = torch.constant.int 0 | |
%int2 = torch.constant.int 2 | |
%true = torch.constant.bool true | |
%none = torch.constant.none | |
%int768 = torch.constant.int 768 | |
%int128 = torch.constant.int 128 | |
%int-1 = torch.constant.int -1 | |
%int6 = torch.constant.int 6 | |
%false = torch.constant.bool false | |
%0 = torch.vtensor.literal(dense_resource<__elided__> : tensor<2x768xf32>) : !torch.vtensor<[2,768],f32> | |
%1 = torch.vtensor.literal(dense_resource<__elided__> : tensor<3072x768xf32>) : !torch.vtensor<[3072,768],f32> | |
%2 = torch.vtensor.literal(dense_resource<__elided__> : tensor<768x3072xf32>) : !torch.vtensor<[768,3072],f32> | |
%3 = torch.vtensor.literal(dense_resource<__elided__> : tensor<3072xf32>) : !torch.vtensor<[3072],f32> | |
%4 = torch.vtensor.literal(dense_resource<__elided__> : tensor<768x768xf32>) : !torch.vtensor<[768,768],f32> | |
%5 = torch.vtensor.literal(dense_resource<__elided__> : tensor<1x1x1024x1024xui8>) : !torch.vtensor<[1,1,1024,1024],ui8> | |
%6 = torch.vtensor.literal(dense_resource<__elided__> : tensor<768x2304xf32>) : !torch.vtensor<[768,2304],f32> | |
%7 = torch.vtensor.literal(dense_resource<__elided__> : tensor<2304xf32>) : !torch.vtensor<[2304],f32> | |
%8 = torch.vtensor.literal(dense_resource<__elided__> : tensor<768xf32>) : !torch.vtensor<[768],f32> | |
%9 = torch.vtensor.literal(dense_resource<__elided__> : tensor<1024x768xf32>) : !torch.vtensor<[1024,768],f32> | |
%10 = torch.vtensor.literal(dense_resource<__elided__> : tensor<50257x768xf32>) : !torch.vtensor<[50257,768],f32> | |
%int-2 = torch.constant.int -2 | |
%int11 = torch.constant.int 11 | |
%float-3.402820e38 = torch.constant.float -3.4028234663852886E+38 | |
%int4 = torch.constant.int 4 | |
%float1.000000e-05 = torch.constant.float 1.000000e-05 | |
%int2304 = torch.constant.int 2304 | |
%int294912 = torch.constant.int 294912 | |
%int1536 = torch.constant.int 1536 | |
%int12 = torch.constant.int 12 | |
%int64 = torch.constant.int 64 | |
%int3 = torch.constant.int 3 | |
%float8.000000e00 = torch.constant.float 8.000000e+00 | |
%int1024 = torch.constant.int 1024 | |
%int1048576 = torch.constant.int 1048576 | |
%int3072 = torch.constant.int 3072 | |
%float5.000000e-01 = torch.constant.float 5.000000e-01 | |
%float3.000000e00 = torch.constant.float 3.000000e+00 | |
%float4.471500e-02 = torch.constant.float 4.471500e-02 | |
%float7.978850e-01 = torch.constant.float 0.79788456080286541 | |
%float1.000000e00 = torch.constant.float 1.000000e+00 | |
%cpu = torch.constant.device "cpu" | |
%11 = torch.prim.ListConstruct %int-1, %int128 : (!torch.int, !torch.int) -> !torch.list<int> | |
%12 = torch.aten.view %arg0, %11 : !torch.vtensor<[1,128],si64>, !torch.list<int> -> !torch.vtensor<[1,128],si64> | |
%13 = torch.aten.arange.start_step %int0, %int128, %int1, %int4, %none, %cpu, %false : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[128],si64> | |
%14 = torch.aten.unsqueeze %13, %int0 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[1,128],si64> | |
%15 = torch.aten.view %14, %11 : !torch.vtensor<[1,128],si64>, !torch.list<int> -> !torch.vtensor<[1,128],si64> | |
%16 = torch.aten.embedding %10, %12, %int-1, %false, %false : !torch.vtensor<[50257,768],f32>, !torch.vtensor<[1,128],si64>, !torch.int, !torch.bool, !torch.bool -> !torch.vtensor<[1,128,768],f32> | |
%17 = torch.aten.embedding %9, %15, %int-1, %false, %false : !torch.vtensor<[1024,768],f32>, !torch.vtensor<[1,128],si64>, !torch.int, !torch.bool, !torch.bool -> !torch.vtensor<[1,128,768],f32> | |
%18 = torch.aten.add.Tensor %16, %17, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32> | |
%19 = torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<int> | |
%20 = torch.aten.sum.dim_IntList %18, %19, %true, %none : !torch.vtensor<[1,128,768],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32> | |
%21 = torch.aten.div.Scalar %20, %int768 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32> | |
%22 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> | |
%23 = torch.aten.broadcast_to %21, %22 : !torch.vtensor<[1,128,1],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32> | |
%24 = torch.aten.sub.Tensor %18, %23, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32> | |
%25 = torch.aten.mul.Tensor %24, %24 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32> -> !torch.vtensor<[1,128,768],f32> | |
%26 = torch.aten.sum.dim_IntList %25, %19, %true, %none : !torch.vtensor<[1,128,768],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32> | |
%27 = torch.aten.div.Scalar %26, %int768 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32> | |
%28 = torch.aten.add.Scalar %27, %float1.000000e-05, %int1 : !torch.vtensor<[1,128,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,128,1],f32> | |
%29 = torch.aten.rsqrt %28 : !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,1],f32> | |
%30 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> | |
%31 = torch.aten.broadcast_to %29, %30 : !torch.vtensor<[1,128,1],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32> | |
%32 = torch.aten.mul.Tensor %24, %31 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32> -> !torch.vtensor<[1,128,768],f32> | |
%33 = torch.aten.mul.Tensor %32, %8 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[768],f32> -> !torch.vtensor<[1,128,768],f32> | |
%34 = torch.aten.add.Tensor %33, %8, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32> | |
%35 = torch.prim.ListConstruct %int-1, %int768 : (!torch.int, !torch.int) -> !torch.list<int> | |
%36 = torch.aten.view %34, %35 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[128,768],f32> | |
%37 = torch.aten.mm %36, %6 : !torch.vtensor<[128,768],f32>, !torch.vtensor<[768,2304],f32> -> !torch.vtensor<[128,2304],f32> | |
%38 = torch.aten.mul.Scalar %7, %int1 : !torch.vtensor<[2304],f32>, !torch.int -> !torch.vtensor<[2304],f32> | |
%39 = torch.aten.add.Tensor %38, %37, %int1 : !torch.vtensor<[2304],f32>, !torch.vtensor<[128,2304],f32>, !torch.int -> !torch.vtensor<[128,2304],f32> | |
%40 = torch.prim.ListConstruct %int1, %int128, %int2304 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> | |
%41 = torch.aten.view %39, %40 : !torch.vtensor<[128,2304],f32>, !torch.list<int> -> !torch.vtensor<[1,128,2304],f32> | |
%42 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> | |
%43 = torch.prim.ListConstruct %int294912, %int2304, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> | |
%44 = torch.aten.as_strided %41, %42, %43, %int0 : !torch.vtensor<[1,128,2304],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,128,768],f32> | |
%45 = torch.aten.as_strided %41, %42, %43, %int768 : !torch.vtensor<[1,128,2304],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,128,768],f32> | |
%46 = torch.aten.as_strided %41, %42, %43, %int1536 : !torch.vtensor<[1,128,2304],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,128,768],f32> | |
%47 = torch.prim.ListConstruct %int1, %int128, %int12, %int64 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> | |
%48 = torch.aten.view %44, %47 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[1,128,12,64],f32> | |
%49 = torch.prim.ListConstruct %int0, %int2, %int1, %int3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> | |
%50 = torch.aten.permute %48, %49 : !torch.vtensor<[1,128,12,64],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,64],f32> | |
%51 = torch.aten.view %45, %47 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[1,128,12,64],f32> | |
%52 = torch.aten.permute %51, %49 : !torch.vtensor<[1,128,12,64],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,64],f32> | |
%53 = torch.aten.view %46, %47 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[1,128,12,64],f32> | |
%54 = torch.aten.permute %53, %49 : !torch.vtensor<[1,128,12,64],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,64],f32> | |
%55 = torch.aten.transpose.int %52, %int-1, %int-2 : !torch.vtensor<[1,12,128,64],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,12,64,128],f32> | |
%56 = torch.prim.ListConstruct %int1, %int12, %int128, %int64 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> | |
%57 = torch.aten.broadcast_to %50, %56 : !torch.vtensor<[1,12,128,64],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,64],f32> | |
%58 = torch.prim.ListConstruct %int12, %int128, %int64 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> | |
%59 = torch.aten.view %57, %58 : !torch.vtensor<[1,12,128,64],f32>, !torch.list<int> -> !torch.vtensor<[12,128,64],f32> | |
%60 = torch.prim.ListConstruct %int1, %int12, %int64, %int128 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> | |
%61 = torch.aten.broadcast_to %55, %60 : !torch.vtensor<[1,12,64,128],f32>, !torch.list<int> -> !torch.vtensor<[1,12,64,128],f32> | |
%62 = torch.prim.ListConstruct %int12, %int64, %int128 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> | |
%63 = torch.aten.view %61, %62 : !torch.vtensor<[1,12,64,128],f32>, !torch.list<int> -> !torch.vtensor<[12,64,128],f32> | |
%64 = torch.aten.bmm %59, %63 : !torch.vtensor<[12,128,64],f32>, !torch.vtensor<[12,64,128],f32> -> !torch.vtensor<[12,128,128],f32> | |
%65 = torch.prim.ListConstruct %int1, %int12, %int128, %int128 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> | |
%66 = torch.aten.view %64, %65 : !torch.vtensor<[12,128,128],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,128],f32> | |
%67 = torch.prim.ListConstruct : () -> !torch.list<int> | |
%68 = torch.prim.NumToTensor.Scalar %float8.000000e00 : !torch.float -> !torch.vtensor<[],f64> | |
%69 = torch.aten.to.dtype %68, %int6, %false, %false, %none : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32> | |
%70 = torch.aten.broadcast_to %69, %67 : !torch.vtensor<[],f32>, !torch.list<int> -> !torch.vtensor<[],f32> | |
%71 = torch.aten.div.Tensor %66, %70 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[1,12,128,128],f32> | |
%72 = torch.prim.ListConstruct %int1, %int1, %int1024, %int1024 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> | |
%73 = torch.prim.ListConstruct %int1048576, %int1048576, %int1024, %int1 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> | |
%74 = torch.aten.as_strided %5, %72, %73, %int0 : !torch.vtensor<[1,1,1024,1024],ui8>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,1024,1024],ui8> | |
%75 = torch.aten.as_strided %74, %72, %73, %int0 : !torch.vtensor<[1,1,1024,1024],ui8>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,1024,1024],ui8> | |
%76 = torch.prim.ListConstruct %int1, %int1, %int128, %int1024 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> | |
%77 = torch.aten.as_strided %75, %76, %73, %int0 : !torch.vtensor<[1,1,1024,1024],ui8>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,128,1024],ui8> | |
%78 = torch.prim.ListConstruct %int1, %int1, %int128, %int128 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> | |
%79 = torch.aten.as_strided %77, %78, %73, %int0 : !torch.vtensor<[1,1,128,1024],ui8>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,128,128],ui8> | |
%80 = torch.aten.to.dtype %79, %int11, %false, %false, %none : !torch.vtensor<[1,1,128,128],ui8>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,1,128,128],i1> | |
%81 = torch.prim.NumToTensor.Scalar %float-3.402820e38 : !torch.float -> !torch.vtensor<[],f64> | |
%82 = torch.aten.to.dtype %81, %int6, %false, %false, %none : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32> | |
%83 = torch.aten.broadcast_to %82, %67 : !torch.vtensor<[],f32>, !torch.list<int> -> !torch.vtensor<[],f32> | |
%84 = torch.aten.where.self %80, %71, %83 : !torch.vtensor<[1,1,128,128],i1>, !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[1,12,128,128],f32> | |
%85 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<int> | |
%values, %indices = torch.aten.max.dim %84, %int-1, %true : !torch.vtensor<[1,12,128,128],f32>, !torch.int, !torch.bool -> !torch.vtensor<[1,12,128,1],f32>, !torch.vtensor<[1,12,128,1],si64> | |
%86 = torch.aten.sub.Tensor %84, %values, %int1 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[1,12,128,1],f32>, !torch.int -> !torch.vtensor<[1,12,128,128],f32> | |
%87 = torch.aten.exp %86 : !torch.vtensor<[1,12,128,128],f32> -> !torch.vtensor<[1,12,128,128],f32> | |
%88 = torch.aten.sum.dim_IntList %87, %85, %true, %none : !torch.vtensor<[1,12,128,128],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,12,128,1],f32> | |
%89 = torch.aten.div.Tensor %87, %88 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[1,12,128,1],f32> -> !torch.vtensor<[1,12,128,128],f32> | |
%90 = torch.aten.broadcast_to %89, %65 : !torch.vtensor<[1,12,128,128],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,128],f32> | |
%91 = torch.prim.ListConstruct %int12, %int128, %int128 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> | |
%92 = torch.aten.view %90, %91 : !torch.vtensor<[1,12,128,128],f32>, !torch.list<int> -> !torch.vtensor<[12,128,128],f32> | |
%93 = torch.aten.broadcast_to %54, %56 : !torch.vtensor<[1,12,128,64],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,64],f32> | |
%94 = torch.aten.view %93, %58 : !torch.vtensor<[1,12,128,64],f32>, !torch.list<int> -> !torch.vtensor<[12,128,64],f32> | |
%95 = torch.aten.bmm %92, %94 : !torch.vtensor<[12,128,128],f32>, !torch.vtensor<[12,128,64],f32> -> !torch.vtensor<[12,128,64],f32> | |
%96 = torch.aten.view %95, %56 : !torch.vtensor<[12,128,64],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,64],f32> | |
%97 = torch.aten.permute %96, %49 : !torch.vtensor<[1,12,128,64],f32>, !torch.list<int> -> !torch.vtensor<[1,128,12,64],f32> | |
%98 = torch.aten.clone %97, %int0 : !torch.vtensor<[1,128,12,64],f32>, !torch.int -> !torch.vtensor<[1,128,12,64],f32> | |
%99 = torch.aten.view %98, %42 : !torch.vtensor<[1,128,12,64],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32> | |
%100 = torch.aten.view %99, %35 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[128,768],f32> | |
%101 = torch.aten.mm %100, %4 : !torch.vtensor<[128,768],f32>, !torch.vtensor<[768,768],f32> -> !torch.vtensor<[128,768],f32> | |
%102 = torch.aten.mul.Scalar %8, %int1 : !torch.vtensor<[768],f32>, !torch.int -> !torch.vtensor<[768],f32> | |
%103 = torch.aten.add.Tensor %102, %101, %int1 : !torch.vtensor<[768],f32>, !torch.vtensor<[128,768],f32>, !torch.int -> !torch.vtensor<[128,768],f32> | |
%104 = torch.aten.view %103, %42 : !torch.vtensor<[128,768],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32> | |
%105 = torch.aten.add.Tensor %104, %18, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32> | |
%106 = torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<int> | |
%107 = torch.aten.sum.dim_IntList %105, %106, %true, %none : !torch.vtensor<[1,128,768],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32> | |
%108 = torch.aten.div.Scalar %107, %int768 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32> | |
%109 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> | |
%110 = torch.aten.broadcast_to %108, %109 : !torch.vtensor<[1,128,1],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32> | |
%111 = torch.aten.sub.Tensor %105, %110, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32> | |
%112 = torch.aten.mul.Tensor %111, %111 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32> -> !torch.vtensor<[1,128,768],f32> | |
%113 = torch.aten.sum.dim_IntList %112, %106, %true, %none : !torch.vtensor<[1,128,768],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32> | |
%114 = torch.aten.div.Scalar %113, %int768 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32> | |
%115 = torch.aten.add.Scalar %114, %float1.000000e-05, %int1 : !torch.vtensor<[1,128,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,128,1],f32> | |
%116 = torch.aten.rsqrt %115 : !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,1],f32> | |
%117 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> | |
%118 = torch.aten.broadcast_to %116, %117 : !torch.vtensor<[1,128,1],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32> | |
%119 = torch.aten.mul.Tensor %111, %118 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32> -> !torch.vtensor<[1,128,768],f32> | |
%120 = torch.aten.mul.Tensor %119, %8 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[768],f32> -> !torch.vtensor<[1,128,768],f32> | |
%121 = torch.aten.add.Tensor %120, %8, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32> | |
%122 = torch.aten.view %121, %35 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[128,768],f32> | |
%123 = torch.aten.mm %122, %2 : !torch.vtensor<[128,768],f32>, !torch.vtensor<[768,3072],f32> -> !torch.vtensor<[128,3072],f32> | |
%124 = torch.aten.mul.Scalar %3, %int1 : !torch.vtensor<[3072],f32>, !torch.int -> !torch.vtensor<[3072],f32> | |
%125 = torch.aten.add.Tensor %124, %123, %int1 : !torch.vtensor<[3072],f32>, !torch.vtensor<[128,3072],f32>, !torch.int -> !torch.vtensor<[128,3072],f32> | |
%126 = torch.prim.ListConstruct %int1, %int128, %int3072 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> | |
%127 = torch.aten.view %125, %126 : !torch.vtensor<[128,3072],f32>, !torch.list<int> -> !torch.vtensor<[1,128,3072],f32> | |
%128 = torch.aten.mul.Scalar %127, %float5.000000e-01 : !torch.vtensor<[1,128,3072],f32>, !torch.float -> !torch.vtensor<[1,128,3072],f32> | |
%129 = torch.aten.pow.Tensor_Scalar %127, %float3.000000e00 : !torch.vtensor<[1,128,3072],f32>, !torch.float -> !torch.vtensor<[1,128,3072],f32> | |
%130 = torch.aten.mul.Scalar %129, %float4.471500e-02 : !torch.vtensor<[1,128,3072],f32>, !torch.float -> !torch.vtensor<[1,128,3072],f32> | |
%131 = torch.aten.add.Tensor %127, %130, %int1 : !torch.vtensor<[1,128,3072],f32>, !torch.vtensor<[1,128,3072],f32>, !torch.int -> !torch.vtensor<[1,128,3072],f32> | |
%132 = torch.aten.mul.Scalar %131, %float7.978850e-01 : !torch.vtensor<[1,128,3072],f32>, !torch.float -> !torch.vtensor<[1,128,3072],f32> | |
%133 = torch.aten.tanh %132 : !torch.vtensor<[1,128,3072],f32> -> !torch.vtensor<[1,128,3072],f32> | |
%134 = torch.aten.add.Scalar %133, %float1.000000e00, %int1 : !torch.vtensor<[1,128,3072],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,128,3072],f32> | |
%135 = torch.aten.mul.Tensor %128, %134 : !torch.vtensor<[1,128,3072],f32>, !torch.vtensor<[1,128,3072],f32> -> !torch.vtensor<[1,128,3072],f32> | |
%136 = torch.prim.ListConstruct %int-1, %int3072 : (!torch.int, !torch.int) -> !torch.list<int> | |
%137 = torch.aten.view %135, %136 : !torch.vtensor<[1,128,3072],f32>, !torch.list<int> -> !torch.vtensor<[128,3072],f32> | |
%138 = torch.aten.mm %137, %1 : !torch.vtensor<[128,3072],f32>, !torch.vtensor<[3072,768],f32> -> !torch.vtensor<[128,768],f32> | |
%139 = torch.aten.mul.Scalar %8, %int1 : !torch.vtensor<[768],f32>, !torch.int -> !torch.vtensor<[768],f32> | |
%140 = torch.aten.add.Tensor %139, %138, %int1 : !torch.vtensor<[768],f32>, !torch.vtensor<[128,768],f32>, !torch.int -> !torch.vtensor<[128,768],f32> | |
%141 = torch.aten.view %140, %42 : !torch.vtensor<[128,768],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32> | |
%142 = torch.aten.add.Tensor %105, %141, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32> | |
%143 = torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<int> | |
%144 = torch.aten.sum.dim_IntList %142, %143, %true, %none : !torch.vtensor<[1,128,768],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32> | |
%145 = torch.aten.div.Scalar %144, %int768 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32> | |
%146 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> | |
%147 = torch.aten.broadcast_to %145, %146 : !torch.vtensor<[1,128,1],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32> | |
%148 = torch.aten.sub.Tensor %142, %147, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32> | |
%149 = torch.aten.mul.Tensor %148, %148 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32> -> !torch.vtensor<[1,128,768],f32> | |
%150 = torch.aten.sum.dim_IntList %149, %143, %true, %none : !torch.vtensor<[1,128,768],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32> | |
%151 = torch.aten.div.Scalar %150, %int768 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32> | |
%152 = torch.aten.add.Scalar %151, %float1.000000e-05, %int1 : !torch.vtensor<[1,128,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,128,1],f32> | |
%153 = torch.aten.rsqrt %152 : !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,1],f32> | |
%154 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> | |
%155 = torch.aten.broadcast_to %153, %154 : !torch.vtensor<[1,128,1],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32> | |
%156 = torch.aten.mul.Tensor %148, %155 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32> -> !torch.vtensor<[1,128,768],f32> | |
%157 = torch.aten.mul.Tensor %156, %8 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[768],f32> -> !torch.vtensor<[1,128,768],f32> | |
%158 = torch.aten.add.Tensor %157, %8, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32> | |
%159 = torch.aten.view %158, %35 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[128,768],f32> | |
%160 = torch.aten.mm %159, %6 : !torch.vtensor<[128,768],f32>, !torch.vtensor<[768,2304],f32> -> !torch.vtensor<[128,2304],f32> | |
%161 = torch.aten.mul.Scalar %7, %int1 : !torch.vtensor<[2304],f32>, !torch.int -> !torch.vtensor<[2304],f32> | |
%162 = torch.aten.add.Tensor %161, %160, %int1 : !torch.vtensor<[2304],f32>, !torch.vtensor<[128,2304],f32>, !torch.int -> !torch.vtensor<[128,2304],f32> | |
%163 = torch.aten.view %162, %40 : !torch.vtensor<[128,2304],f32>, !torch.list<int> -> !torch.vtensor<[1,128,2304],f32> | |
%164 = torch.aten.as_strided %163, %42, %43, %int0 : !torch.vtensor<[1,128,2304],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,128,768],f32> | |
%165 = torch.aten.as_strided %163, %42, %43, %int768 : !torch.vtensor<[1,128,2304],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,128,768],f32> | |
%166 = torch.aten.as_strided %163, %42, %43, %int1536 : !torch.vtensor<[1,128,2304],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,128,768],f32> | |
%167 = torch.aten.view %164, %47 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[1,128,12,64],f32> | |
%168 = torch.aten.permute %167, %49 : !torch.vtensor<[1,128,12,64],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,64],f32> | |
%169 = torch.aten.view %165, %47 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[1,128,12,64],f32> | |
%170 = torch.aten.permute %169, %49 : !torch.vtensor<[1,128,12,64],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,64],f32> | |
%171 = torch.aten.view %166, %47 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[1,128,12,64],f32> | |
%172 = torch.aten.permute %171, %49 : !torch.vtensor<[1,128,12,64],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,64],f32> | |
%173 = torch.aten.transpose.int %170, %int-1, %int-2 : !torch.vtensor<[1,12,128,64],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,12,64,128],f32> | |
%174 = torch.aten.broadcast_to %168, %56 : !torch.vtensor<[1,12,128,64],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,64],f32> | |
%175 = torch.aten.view %174, %58 : !torch.vtensor<[1,12,128,64],f32>, !torch.list<int> -> !torch.vtensor<[12,128,64],f32> | |
%176 = torch.aten.broadcast_to %173, %60 : !torch.vtensor<[1,12,64,128],f32>, !torch.list<int> -> !torch.vtensor<[1,12,64,128],f32> | |
%177 = torch.aten.view %176, %62 : !torch.vtensor<[1,12,64,128],f32>, !torch.list<int> -> !torch.vtensor<[12,64,128],f32> | |
%178 = torch.aten.bmm %175, %177 : !torch.vtensor<[12,128,64],f32>, !torch.vtensor<[12,64,128],f32> -> !torch.vtensor<[12,128,128],f32> | |
%179 = torch.aten.view %178, %65 : !torch.vtensor<[12,128,128],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,128],f32> | |
%180 = torch.prim.NumToTensor.Scalar %float8.000000e00 : !torch.float -> !torch.vtensor<[],f64> | |
%181 = torch.aten.to.dtype %180, %int6, %false, %false, %none : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32> | |
%182 = torch.aten.broadcast_to %181, %67 : !torch.vtensor<[],f32>, !torch.list<int> -> !torch.vtensor<[],f32> | |
%183 = torch.aten.div.Tensor %179, %182 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[1,12,128,128],f32> | |
%184 = torch.aten.as_strided %5, %72, %73, %int0 : !torch.vtensor<[1,1,1024,1024],ui8>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,1024,1024],ui8> | |
%185 = torch.aten.as_strided %184, %72, %73, %int0 : !torch.vtensor<[1,1,1024,1024],ui8>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,1024,1024],ui8> | |
%186 = torch.aten.as_strided %185, %76, %73, %int0 : !torch.vtensor<[1,1,1024,1024],ui8>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,128,1024],ui8> | |
%187 = torch.aten.as_strided %186, %78, %73, %int0 : !torch.vtensor<[1,1,128,1024],ui8>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,128,128],ui8> | |
%188 = torch.aten.to.dtype %187, %int11, %false, %false, %none : !torch.vtensor<[1,1,128,128],ui8>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,1,128,128],i1> | |
%189 = torch.prim.NumToTensor.Scalar %float-3.402820e38 : !torch.float -> !torch.vtensor<[],f64> | |
%190 = torch.aten.to.dtype %189, %int6, %false, %false, %none : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32> | |
%191 = torch.aten.broadcast_to %190, %67 : !torch.vtensor<[],f32>, !torch.list<int> -> !torch.vtensor<[],f32> | |
%192 = torch.aten.where.self %188, %183, %191 : !torch.vtensor<[1,1,128,128],i1>, !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[1,12,128,128],f32> | |
%values_0, %indices_1 = torch.aten.max.dim %192, %int-1, %true : !torch.vtensor<[1,12,128,128],f32>, !torch.int, !torch.bool -> !torch.vtensor<[1,12,128,1],f32>, !torch.vtensor<[1,12,128,1],si64> | |
%193 = torch.aten.sub.Tensor %192, %values_0, %int1 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[1,12,128,1],f32>, !torch.int -> !torch.vtensor<[1,12,128,128],f32> | |
%194 = torch.aten.exp %193 : !torch.vtensor<[1,12,128,128],f32> -> !torch.vtensor<[1,12,128,128],f32> | |
%195 = torch.aten.sum.dim_IntList %194, %85, %true, %none : !torch.vtensor<[1,12,128,128],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,12,128,1],f32> | |
%196 = torch.aten.div.Tensor %194, %195 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[1,12,128,1],f32> -> !torch.vtensor<[1,12,128,128],f32> | |
%197 = torch.aten.broadcast_to %196, %65 : !torch.vtensor<[1,12,128,128],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,128],f32> | |
%198 = torch.aten.view %197, %91 : !torch.vtensor<[1,12,128,128],f32>, !torch.list<int> -> !torch.vtensor<[12,128,128],f32> | |
%199 = torch.aten.broadcast_to %172, %56 : !torch.vtensor<[1,12,128,64],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,64],f32> | |
%200 = torch.aten.view %199, %58 : !torch.vtensor<[1,12,128,64],f32>, !torch.list<int> -> !torch.vtensor<[12,128,64],f32> | |
%201 = torch.aten.bmm %198, %200 : !torch.vtensor<[12,128,128],f32>, !torch.vtensor<[12,128,64],f32> -> !torch.vtensor<[12,128,64],f32> | |
%202 = torch.aten.view %201, %56 : !torch.vtensor<[12,128,64],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,64],f32> | |
%203 = torch.aten.permute %202, %49 : !torch.vtensor<[1,12,128,64],f32>, !torch.list<int> -> !torch.vtensor<[1,128,12,64],f32> | |
%204 = torch.aten.clone %203, %int0 : !torch.vtensor<[1,128,12,64],f32>, !torch.int -> !torch.vtensor<[1,128,12,64],f32> | |
%205 = torch.aten.view %204, %42 : !torch.vtensor<[1,128,12,64],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32> | |
%206 = torch.aten.view %205, %35 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[128,768],f32> | |
%207 = torch.aten.mm %206, %4 : !torch.vtensor<[128,768],f32>, !torch.vtensor<[768,768],f32> -> !torch.vtensor<[128,768],f32> | |
%208 = torch.aten.mul.Scalar %8, %int1 : !torch.vtensor<[768],f32>, !torch.int -> !torch.vtensor<[768],f32> | |
%209 = torch.aten.add.Tensor %208, %207, %int1 : !torch.vtensor<[768],f32>, !torch.vtensor<[128,768],f32>, !torch.int -> !torch.vtensor<[128,768],f32> | |
%210 = torch.aten.view %209, %42 : !torch.vtensor<[128,768],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32> | |
%211 = torch.aten.add.Tensor %210, %142, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32> | |
%212 = torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<int> | |
%213 = torch.aten.sum.dim_IntList %211, %212, %true, %none : !torch.vtensor<[1,128,768],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32> | |
%214 = torch.aten.div.Scalar %213, %int768 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32> | |
%215 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> | |
%216 = torch.aten.broadcast_to %214, %215 : !torch.vtensor<[1,128,1],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32> | |
%217 = torch.aten.sub.Tensor %211, %216, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32> | |
%218 = torch.aten.mul.Tensor %217, %217 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32> -> !torch.vtensor<[1,128,768],f32> | |
%219 = torch.aten.sum.dim_IntList %218, %212, %true, %none : !torch.vtensor<[1,128,768],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32> | |
%220 = torch.aten.div.Scalar %219, %int768 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32> | |
%221 = torch.aten.add.Scalar %220, %float1.000000e-05, %int1 : !torch.vtensor<[1,128,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,128,1],f32> | |
%222 = torch.aten.rsqrt %221 : !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,1],f32> | |
%223 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> | |
%224 = torch.aten.broadcast_to %222, %223 : !torch.vtensor<[1,128,1],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32> | |
%225 = torch.aten.mul.Tensor %217, %224 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32> -> !torch.vtensor<[1,128,768],f32> | |
%226 = torch.aten.mul.Tensor %225, %8 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[768],f32> -> !torch.vtensor<[1,128,768],f32> | |
%227 = torch.aten.add.Tensor %226, %8, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32> | |
%228 = torch.aten.view %227, %35 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[128,768],f32> | |
%229 = torch.aten.mm %228, %2 : !torch.vtensor<[128,768],f32>, !torch.vtensor<[768,3072],f32> -> !torch.vtensor<[128,3072],f32> | |
%230 = torch.aten.mul.Scalar %3, %int1 : !torch.vtensor<[3072],f32>, !torch.int -> !torch.vtensor<[3072],f32> | |
%231 = torch.aten.add.Tensor %230, %229, %int1 : !torch.vtensor<[3072],f32>, !torch.vtensor<[128,3072],f32>, !torch.int -> !torch.vtensor<[128,3072],f32> | |
%232 = torch.aten.view %231, %126 : !torch.vtensor<[128,3072],f32>, !torch.list<int> -> !torch.vtensor<[1,128,3072],f32> | |
%233 = torch.aten.mul.Scalar %232, %float5.000000e-01 : !torch.vtensor<[1,128,3072],f32>, !torch.float -> !torch.vtensor<[1,128,3072],f32> | |
%234 = torch.aten.pow.Tensor_Scalar %232, %float3.000000e00 : !torch.vtensor<[1,128,3072],f32>, !torch.float -> !torch.vtensor<[1,128,3072],f32> | |
%235 = torch.aten.mul.Scalar %234, %float4.471500e-02 : !torch.vtensor<[1,128,3072],f32>, !torch.float -> !torch.vtensor<[1,128,3072],f32> | |
%236 = torch.aten.add.Tensor %232, %235, %int1 : !torch.vtensor<[1,128,3072],f32>, !torch.vtensor<[1,128,3072],f32>, !torch.int -> !torch.vtensor<[1,128,3072],f32> | |
%237 = torch.aten.mul.Scalar %236, %float7.978850e-01 : !torch.vtensor<[1,128,3072],f32>, !torch.float -> !torch.vtensor<[1,128,3072],f32> | |
%238 = torch.aten.tanh %237 : !torch.vtensor<[1,128,3072],f32> -> !torch.vtensor<[1,128,3072],f32> | |
%239 = torch.aten.add.Scalar %238, %float1.000000e00, %int1 : !torch.vtensor<[1,128,3072],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,128,3072],f32> | |
%240 = torch.aten.mul.Tensor %233, %239 : !torch.vtensor<[1,128,3072],f32>, !torch.vtensor<[1,128,3072],f32> -> !torch.vtensor<[1,128,3072],f32> | |
%241 = torch.aten.view %240, %136 : !torch.vtensor<[1,128,3072],f32>, !torch.list<int> -> !torch.vtensor<[128,3072],f32> | |
%242 = torch.aten.mm %241, %1 : !torch.vtensor<[128,3072],f32>, !torch.vtensor<[3072,768],f32> -> !torch.vtensor<[128,768],f32> | |
%243 = torch.aten.mul.Scalar %8, %int1 : !torch.vtensor<[768],f32>, !torch.int -> !torch.vtensor<[768],f32> | |
%244 = torch.aten.add.Tensor %243, %242, %int1 : !torch.vtensor<[768],f32>, !torch.vtensor<[128,768],f32>, !torch.int -> !torch.vtensor<[128,768],f32> | |
%245 = torch.aten.view %244, %42 : !torch.vtensor<[128,768],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32> | |
%246 = torch.aten.add.Tensor %211, %245, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32> | |
%247 = torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<int> | |
%248 = torch.aten.sum.dim_IntList %246, %247, %true, %none : !torch.vtensor<[1,128,768],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32> | |
%249 = torch.aten.div.Scalar %248, %int768 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32> | |
%250 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> | |
%251 = torch.aten.broadcast_to %249, %250 : !torch.vtensor<[1,128,1],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32> | |
%252 = torch.aten.sub.Tensor %246, %251, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32> | |
%253 = torch.aten.mul.Tensor %252, %252 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32> -> !torch.vtensor<[1,128,768],f32> | |
%254 = torch.aten.sum.dim_IntList %253, %247, %true, %none : !torch.vtensor<[1,128,768],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32> | |
%255 = torch.aten.div.Scalar %254, %int768 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32> | |
%256 = torch.aten.add.Scalar %255, %float1.000000e-05, %int1 : !torch.vtensor<[1,128,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,128,1],f32> | |
%257 = torch.aten.rsqrt %256 : !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,1],f32> | |
%258 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> | |
%259 = torch.aten.broadcast_to %257, %258 : !torch.vtensor<[1,128,1],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32> | |
%260 = torch.aten.mul.Tensor %252, %259 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32> -> !torch.vtensor<[1,128,768],f32> | |
%261 = torch.aten.mul.Tensor %260, %8 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[768],f32> -> !torch.vtensor<[1,128,768],f32> | |
%262 = torch.aten.add.Tensor %261, %8, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32> | |
%263 = torch.aten.view %262, %35 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[128,768],f32> | |
%264 = torch.aten.mm %263, %6 : !torch.vtensor<[128,768],f32>, !torch.vtensor<[768,2304],f32> -> !torch.vtensor<[128,2304],f32> | |
%265 = torch.aten.mul.Scalar %7, %int1 : !torch.vtensor<[2304],f32>, !torch.int -> !torch.vtensor<[2304],f32> | |
%266 = torch.aten.add.Tensor %265, %264, %int1 : !torch.vtensor<[2304],f32>, !torch.vtensor<[128,2304],f32>, !torch.int -> !torch.vtensor<[128,2304],f32> | |
%267 = torch.aten.view %266, %40 : !torch.vtensor<[128,2304],f32>, !torch.list<int> -> !torch.vtensor<[1,128,2304],f32> | |
%268 = torch.aten.as_strided %267, %42, %43, %int0 : !torch.vtensor<[1,128,2304],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,128,768],f32> | |
%269 = torch.aten.as_strided %267, %42, %43, %int768 : !torch.vtensor<[1,128,2304],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,128,768],f32> | |
%270 = torch.aten.as_strided %267, %42, %43, %int1536 : !torch.vtensor<[1,128,2304],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,128,768],f32> | |
%271 = torch.aten.view %268, %47 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[1,128,12,64],f32> | |
%272 = torch.aten.permute %271, %49 : !torch.vtensor<[1,128,12,64],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,64],f32> | |
%273 = torch.aten.view %269, %47 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[1,128,12,64],f32> | |
%274 = torch.aten.permute %273, %49 : !torch.vtensor<[1,128,12,64],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,64],f32> | |
%275 = torch.aten.view %270, %47 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[1,128,12,64],f32> | |
%276 = torch.aten.permute %275, %49 : !torch.vtensor<[1,128,12,64],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,64],f32> | |
%277 = torch.aten.transpose.int %274, %int-1, %int-2 : !torch.vtensor<[1,12,128,64],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,12,64,128],f32> | |
%278 = torch.aten.broadcast_to %272, %56 : !torch.vtensor<[1,12,128,64],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,64],f32> | |
%279 = torch.aten.view %278, %58 : !torch.vtensor<[1,12,128,64],f32>, !torch.list<int> -> !torch.vtensor<[12,128,64],f32> | |
%280 = torch.aten.broadcast_to %277, %60 : !torch.vtensor<[1,12,64,128],f32>, !torch.list<int> -> !torch.vtensor<[1,12,64,128],f32> | |
%281 = torch.aten.view %280, %62 : !torch.vtensor<[1,12,64,128],f32>, !torch.list<int> -> !torch.vtensor<[12,64,128],f32> | |
%282 = torch.aten.bmm %279, %281 : !torch.vtensor<[12,128,64],f32>, !torch.vtensor<[12,64,128],f32> -> !torch.vtensor<[12,128,128],f32> | |
%283 = torch.aten.view %282, %65 : !torch.vtensor<[12,128,128],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,128],f32> | |
%284 = torch.prim.NumToTensor.Scalar %float8.000000e00 : !torch.float -> !torch.vtensor<[],f64> | |
%285 = torch.aten.to.dtype %284, %int6, %false, %false, %none : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32> | |
%286 = torch.aten.broadcast_to %285, %67 : !torch.vtensor<[],f32>, !torch.list<int> -> !torch.vtensor<[],f32> | |
%287 = torch.aten.div.Tensor %283, %286 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[1,12,128,128],f32> | |
%288 = torch.aten.as_strided %5, %72, %73, %int0 : !torch.vtensor<[1,1,1024,1024],ui8>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,1024,1024],ui8> | |
%289 = torch.aten.as_strided %288, %72, %73, %int0 : !torch.vtensor<[1,1,1024,1024],ui8>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,1024,1024],ui8> | |
%290 = torch.aten.as_strided %289, %76, %73, %int0 : !torch.vtensor<[1,1,1024,1024],ui8>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,128,1024],ui8> | |
%291 = torch.aten.as_strided %290, %78, %73, %int0 : !torch.vtensor<[1,1,128,1024],ui8>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,128,128],ui8> | |
%292 = torch.aten.to.dtype %291, %int11, %false, %false, %none : !torch.vtensor<[1,1,128,128],ui8>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,1,128,128],i1> | |
%293 = torch.prim.NumToTensor.Scalar %float-3.402820e38 : !torch.float -> !torch.vtensor<[],f64> | |
%294 = torch.aten.to.dtype %293, %int6, %false, %false, %none : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32> | |
%295 = torch.aten.broadcast_to %294, %67 : !torch.vtensor<[],f32>, !torch.list<int> -> !torch.vtensor<[],f32> | |
%296 = torch.aten.where.self %292, %287, %295 : !torch.vtensor<[1,1,128,128],i1>, !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[1,12,128,128],f32> | |
%values_2, %indices_3 = torch.aten.max.dim %296, %int-1, %true : !torch.vtensor<[1,12,128,128],f32>, !torch.int, !torch.bool -> !torch.vtensor<[1,12,128,1],f32>, !torch.vtensor<[1,12,128,1],si64> | |
%297 = torch.aten.sub.Tensor %296, %values_2, %int1 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[1,12,128,1],f32>, !torch.int -> !torch.vtensor<[1,12,128,128],f32> | |
%298 = torch.aten.exp %297 : !torch.vtensor<[1,12,128,128],f32> -> !torch.vtensor<[1,12,128,128],f32> | |
%299 = torch.aten.sum.dim_IntList %298, %85, %true, %none : !torch.vtensor<[1,12,128,128],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,12,128,1],f32> | |
%300 = torch.aten.div.Tensor %298, %299 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[1,12,128,1],f32> -> !torch.vtensor<[1,12,128,128],f32> | |
%301 = torch.aten.broadcast_to %300, %65 : !torch.vtensor<[1,12,128,128],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,128],f32> | |
%302 = torch.aten.view %301, %91 : !torch.vtensor<[1,12,128,128],f32>, !torch.list<int> -> !torch.vtensor<[12,128,128],f32> | |
%303 = torch.aten.broadcast_to %276, %56 : !torch.vtensor<[1,12,128,64],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,64],f32> | |
%304 = torch.aten.view %303, %58 : !torch.vtensor<[1,12,128,64],f32>, !torch.list<int> -> !torch.vtensor<[12,128,64],f32> | |
%305 = torch.aten.bmm %302, %304 : !torch.vtensor<[12,128,128],f32>, !torch.vtensor<[12,128,64],f32> -> !torch.vtensor<[12,128,64],f32> | |
%306 = torch.aten.view %305, %56 : !torch.vtensor<[12,128,64],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,64],f32> | |
%307 = torch.aten.permute %306, %49 : !torch.vtensor<[1,12,128,64],f32>, !torch.list<int> -> !torch.vtensor<[1,128,12,64],f32> | |
%308 = torch.aten.clone %307, %int0 : !torch.vtensor<[1,128,12,64],f32>, !torch.int -> !torch.vtensor<[1,128,12,64],f32> | |
%309 = torch.aten.view %308, %42 : !torch.vtensor<[1,128,12,64],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32> | |
%310 = torch.aten.view %309, %35 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[128,768],f32> | |
%311 = torch.aten.mm %310, %4 : !torch.vtensor<[128,768],f32>, !torch.vtensor<[768,768],f32> -> !torch.vtensor<[128,768],f32> | |
%312 = torch.aten.mul.Scalar %8, %int1 : !torch.vtensor<[768],f32>, !torch.int -> !torch.vtensor<[768],f32> | |
%313 = torch.aten.add.Tensor %312, %311, %int1 : !torch.vtensor<[768],f32>, !torch.vtensor<[128,768],f32>, !torch.int -> !torch.vtensor<[128,768],f32> | |
%314 = torch.aten.view %313, %42 : !torch.vtensor<[128,768],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32> | |
%315 = torch.aten.add.Tensor %314, %246, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32> | |
%316 = torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<int> | |
%317 = torch.aten.sum.dim_IntList %315, %316, %true, %none : !torch.vtensor<[1,128,768],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32> | |
%318 = torch.aten.div.Scalar %317, %int768 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32> | |
%319 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> | |
%320 = torch.aten.broadcast_to %318, %319 : !torch.vtensor<[1,128,1],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32> | |
%321 = torch.aten.sub.Tensor %315, %320, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32> | |
%322 = torch.aten.mul.Tensor %321, %321 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32> -> !torch.vtensor<[1,128,768],f32> | |
%323 = torch.aten.sum.dim_IntList %322, %316, %true, %none : !torch.vtensor<[1,128,768],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32> | |
%324 = torch.aten.div.Scalar %323, %int768 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32> | |
%325 = torch.aten.add.Scalar %324, %float1.000000e-05, %int1 : !torch.vtensor<[1,128,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,128,1],f32> | |
%326 = torch.aten.rsqrt %325 : !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,1],f32> | |
%327 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> | |
%328 = torch.aten.broadcast_to %326, %327 : !torch.vtensor<[1,128,1],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32> | |
%329 = torch.aten.mul.Tensor %321, %328 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32> -> !torch.vtensor<[1,128,768],f32> | |
%330 = torch.aten.mul.Tensor %329, %8 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[768],f32> -> !torch.vtensor<[1,128,768],f32> | |
%331 = torch.aten.add.Tensor %330, %8, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32> | |
%332 = torch.aten.view %331, %35 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[128,768],f32> | |
%333 = torch.aten.mm %332, %2 : !torch.vtensor<[128,768],f32>, !torch.vtensor<[768,3072],f32> -> !torch.vtensor<[128,3072],f32> | |
%334 = torch.aten.mul.Scalar %3, %int1 : !torch.vtensor<[3072],f32>, !torch.int -> !torch.vtensor<[3072],f32> | |
%335 = torch.aten.add.Tensor %334, %333, %int1 : !torch.vtensor<[3072],f32>, !torch.vtensor<[128,3072],f32>, !torch.int -> !torch.vtensor<[128,3072],f32> | |
%336 = torch.aten.view %335, %126 : !torch.vtensor<[128,3072],f32>, !torch.list<int> -> !torch.vtensor<[1,128,3072],f32> | |
%337 = torch.aten.mul.Scalar %336, %float5.000000e-01 : !torch.vtensor<[1,128,3072],f32>, !torch.float -> !torch.vtensor<[1,128,3072],f32> | |
%338 = torch.aten.pow.Tensor_Scalar %336, %float3.000000e00 : !torch.vtensor<[1,128,3072],f32>, !torch.float -> !torch.vtensor<[1,128,3072],f32> | |
%339 = torch.aten.mul.Scalar %338, %float4.471500e-02 : !torch.vtensor<[1,128,3072],f32>, !torch.float -> !torch.vtensor<[1,128,3072],f32> | |
%340 = torch.aten.add.Tensor %336, %339, %int1 : !torch.vtensor<[1,128,3072],f32>, !torch.vtensor<[1,128,3072],f32>, !torch.int -> !torch.vtensor<[1,128,3072],f32> | |
%341 = torch.aten.mul.Scalar %340, %float7.978850e-01 : !torch.vtensor<[1,128,3072],f32>, !torch.float -> !torch.vtensor<[1,128,3072],f32> | |
%342 = torch.aten.tanh %341 : !torch.vtensor<[1,128,3072],f32> -> !torch.vtensor<[1,128,3072],f32> | |
%343 = torch.aten.add.Scalar %342, %float1.000000e00, %int1 : !torch.vtensor<[1,128,3072],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,128,3072],f32> | |
%344 = torch.aten.mul.Tensor %337, %343 : !torch.vtensor<[1,128,3072],f32>, !torch.vtensor<[1,128,3072],f32> -> !torch.vtensor<[1,128,3072],f32> | |
%345 = torch.aten.view %344, %136 : !torch.vtensor<[1,128,3072],f32>, !torch.list<int> -> !torch.vtensor<[128,3072],f32> | |
%346 = torch.aten.mm %345, %1 : !torch.vtensor<[128,3072],f32>, !torch.vtensor<[3072,768],f32> -> !torch.vtensor<[128,768],f32> | |
%347 = torch.aten.mul.Scalar %8, %int1 : !torch.vtensor<[768],f32>, !torch.int -> !torch.vtensor<[768],f32> | |
%348 = torch.aten.add.Tensor %347, %346, %int1 : !torch.vtensor<[768],f32>, !torch.vtensor<[128,768],f32>, !torch.int -> !torch.vtensor<[128,768],f32> | |
%349 = torch.aten.view %348, %42 : !torch.vtensor<[128,768],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32> | |
%350 = torch.aten.add.Tensor %315, %349, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32> | |
%351 = torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<int> | |
%352 = torch.aten.sum.dim_IntList %350, %351, %true, %none : !torch.vtensor<[1,128,768],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32> | |
%353 = torch.aten.div.Scalar %352, %int768 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32> | |
%354 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> | |
%355 = torch.aten.broadcast_to %353, %354 : !torch.vtensor<[1,128,1],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32> | |
%356 = torch.aten.sub.Tensor %350, %355, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32> | |
%357 = torch.aten.mul.Tensor %356, %356 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32> -> !torch.vtensor<[1,128,768],f32> | |
%358 = torch.aten.sum.dim_IntList %357, %351, %true, %none : !torch.vtensor<[1,128,768],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32> | |
%359 = torch.aten.div.Scalar %358, %int768 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32> | |
%360 = torch.aten.add.Scalar %359, %float1.000000e-05, %int1 : !torch.vtensor<[1,128,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,128,1],f32> | |
%361 = torch.aten.rsqrt %360 : !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,1],f32> | |
%362 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> | |
%363 = torch.aten.broadcast_to %361, %362 : !torch.vtensor<[1,128,1],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32> | |
%364 = torch.aten.mul.Tensor %356, %363 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32> -> !torch.vtensor<[1,128,768],f32> | |
%365 = torch.aten.mul.Tensor %364, %8 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[768],f32> -> !torch.vtensor<[1,128,768],f32> | |
%366 = torch.aten.add.Tensor %365, %8, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32> | |
%367 = torch.aten.view %366, %35 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[128,768],f32> | |
%368 = torch.aten.mm %367, %6 : !torch.vtensor<[128,768],f32>, !torch.vtensor<[768,2304],f32> -> !torch.vtensor<[128,2304],f32> | |
%369 = torch.aten.mul.Scalar %7, %int1 : !torch.vtensor<[2304],f32>, !torch.int -> !torch.vtensor<[2304],f32> | |
%370 = torch.aten.add.Tensor %369, %368, %int1 : !torch.vtensor<[2304],f32>, !torch.vtensor<[128,2304],f32>, !torch.int -> !torch.vtensor<[128,2304],f32> | |
%371 = torch.aten.view %370, %40 : !torch.vtensor<[128,2304],f32>, !torch.list<int> -> !torch.vtensor<[1,128,2304],f32> | |
%372 = torch.aten.as_strided %371, %42, %43, %int0 : !torch.vtensor<[1,128,2304],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,128,768],f32> | |
%373 = torch.aten.as_strided %371, %42, %43, %int768 : !torch.vtensor<[1,128,2304],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,128,768],f32> | |
%374 = torch.aten.as_strided %371, %42, %43, %int1536 : !torch.vtensor<[1,128,2304],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,128,768],f32> | |
%375 = torch.aten.view %372, %47 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[1,128,12,64],f32> | |
%376 = torch.aten.permute %375, %49 : !torch.vtensor<[1,128,12,64],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,64],f32> | |
%377 = torch.aten.view %373, %47 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[1,128,12,64],f32> | |
%378 = torch.aten.permute %377, %49 : !torch.vtensor<[1,128,12,64],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,64],f32> | |
%379 = torch.aten.view %374, %47 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[1,128,12,64],f32> | |
%380 = torch.aten.permute %379, %49 : !torch.vtensor<[1,128,12,64],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,64],f32> | |
%381 = torch.aten.transpose.int %378, %int-1, %int-2 : !torch.vtensor<[1,12,128,64],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,12,64,128],f32> | |
%382 = torch.aten.broadcast_to %376, %56 : !torch.vtensor<[1,12,128,64],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,64],f32> | |
%383 = torch.aten.view %382, %58 : !torch.vtensor<[1,12,128,64],f32>, !torch.list<int> -> !torch.vtensor<[12,128,64],f32> | |
%384 = torch.aten.broadcast_to %381, %60 : !torch.vtensor<[1,12,64,128],f32>, !torch.list<int> -> !torch.vtensor<[1,12,64,128],f32> | |
%385 = torch.aten.view %384, %62 : !torch.vtensor<[1,12,64,128],f32>, !torch.list<int> -> !torch.vtensor<[12,64,128],f32> | |
%386 = torch.aten.bmm %383, %385 : !torch.vtensor<[12,128,64],f32>, !torch.vtensor<[12,64,128],f32> -> !torch.vtensor<[12,128,128],f32> | |
%387 = torch.aten.view %386, %65 : !torch.vtensor<[12,128,128],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,128],f32> | |
%388 = torch.prim.NumToTensor.Scalar %float8.000000e00 : !torch.float -> !torch.vtensor<[],f64> | |
%389 = torch.aten.to.dtype %388, %int6, %false, %false, %none : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32> | |
%390 = torch.aten.broadcast_to %389, %67 : !torch.vtensor<[],f32>, !torch.list<int> -> !torch.vtensor<[],f32> | |
%391 = torch.aten.div.Tensor %387, %390 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[1,12,128,128],f32> | |
%392 = torch.aten.as_strided %5, %72, %73, %int0 : !torch.vtensor<[1,1,1024,1024],ui8>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,1024,1024],ui8> | |
%393 = torch.aten.as_strided %392, %72, %73, %int0 : !torch.vtensor<[1,1,1024,1024],ui8>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,1024,1024],ui8> | |
%394 = torch.aten.as_strided %393, %76, %73, %int0 : !torch.vtensor<[1,1,1024,1024],ui8>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,128,1024],ui8> | |
%395 = torch.aten.as_strided %394, %78, %73, %int0 : !torch.vtensor<[1,1,128,1024],ui8>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,128,128],ui8> | |
%396 = torch.aten.to.dtype %395, %int11, %false, %false, %none : !torch.vtensor<[1,1,128,128],ui8>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,1,128,128],i1> | |
%397 = torch.prim.NumToTensor.Scalar %float-3.402820e38 : !torch.float -> !torch.vtensor<[],f64> | |
%398 = torch.aten.to.dtype %397, %int6, %false, %false, %none : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32> | |
%399 = torch.aten.broadcast_to %398, %67 : !torch.vtensor<[],f32>, !torch.list<int> -> !torch.vtensor<[],f32> | |
%400 = torch.aten.where.self %396, %391, %399 : !torch.vtensor<[1,1,128,128],i1>, !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[1,12,128,128],f32> | |
%values_4, %indices_5 = torch.aten.max.dim %400, %int-1, %true : !torch.vtensor<[1,12,128,128],f32>, !torch.int, !torch.bool -> !torch.vtensor<[1,12,128,1],f32>, !torch.vtensor<[1,12,128,1],si64> | |
%401 = torch.aten.sub.Tensor %400, %values_4, %int1 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[1,12,128,1],f32>, !torch.int -> !torch.vtensor<[1,12,128,128],f32> | |
%402 = torch.aten.exp %401 : !torch.vtensor<[1,12,128,128],f32> -> !torch.vtensor<[1,12,128,128],f32> | |
%403 = torch.aten.sum.dim_IntList %402, %85, %true, %none : !torch.vtensor<[1,12,128,128],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,12,128,1],f32> | |
%404 = torch.aten.div.Tensor %402, %403 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[1,12,128,1],f32> -> !torch.vtensor<[1,12,128,128],f32> | |
%405 = torch.aten.broadcast_to %404, %65 : !torch.vtensor<[1,12,128,128],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,128],f32> | |
%406 = torch.aten.view %405, %91 : !torch.vtensor<[1,12,128,128],f32>, !torch.list<int> -> !torch.vtensor<[12,128,128],f32> | |
%407 = torch.aten.broadcast_to %380, %56 : !torch.vtensor<[1,12,128,64],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,64],f32> | |
%408 = torch.aten.view %407, %58 : !torch.vtensor<[1,12,128,64],f32>, !torch.list<int> -> !torch.vtensor<[12,128,64],f32> | |
%409 = torch.aten.bmm %406, %408 : !torch.vtensor<[12,128,128],f32>, !torch.vtensor<[12,128,64],f32> -> !torch.vtensor<[12,128,64],f32> | |
%410 = torch.aten.view %409, %56 : !torch.vtensor<[12,128,64],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,64],f32> | |
%411 = torch.aten.permute %410, %49 : !torch.vtensor<[1,12,128,64],f32>, !torch.list<int> -> !torch.vtensor<[1,128,12,64],f32> | |
%412 = torch.aten.clone %411, %int0 : !torch.vtensor<[1,128,12,64],f32>, !torch.int -> !torch.vtensor<[1,128,12,64],f32> | |
%413 = torch.aten.view %412, %42 : !torch.vtensor<[1,128,12,64],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32> | |
%414 = torch.aten.view %413, %35 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[128,768],f32> | |
%415 = torch.aten.mm %414, %4 : !torch.vtensor<[128,768],f32>, !torch.vtensor<[768,768],f32> -> !torch.vtensor<[128,768],f32> | |
%416 = torch.aten.mul.Scalar %8, %int1 : !torch.vtensor<[768],f32>, !torch.int -> !torch.vtensor<[768],f32> | |
%417 = torch.aten.add.Tensor %416, %415, %int1 : !torch.vtensor<[768],f32>, !torch.vtensor<[128,768],f32>, !torch.int -> !torch.vtensor<[128,768],f32> | |
%418 = torch.aten.view %417, %42 : !torch.vtensor<[128,768],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32> | |
%419 = torch.aten.add.Tensor %418, %350, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32> | |
%420 = torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<int> | |
%421 = torch.aten.sum.dim_IntList %419, %420, %true, %none : !torch.vtensor<[1,128,768],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32> | |
%422 = torch.aten.div.Scalar %421, %int768 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32> | |
%423 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> | |
%424 = torch.aten.broadcast_to %422, %423 : !torch.vtensor<[1,128,1],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32> | |
%425 = torch.aten.sub.Tensor %419, %424, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32> | |
%426 = torch.aten.mul.Tensor %425, %425 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32> -> !torch.vtensor<[1,128,768],f32> | |
%427 = torch.aten.sum.dim_IntList %426, %420, %true, %none : !torch.vtensor<[1,128,768],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32> | |
%428 = torch.aten.div.Scalar %427, %int768 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32> | |
%429 = torch.aten.add.Scalar %428, %float1.000000e-05, %int1 : !torch.vtensor<[1,128,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,128,1],f32> | |
%430 = torch.aten.rsqrt %429 : !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,1],f32> | |
%431 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> | |
%432 = torch.aten.broadcast_to %430, %431 : !torch.vtensor<[1,128,1],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32> | |
%433 = torch.aten.mul.Tensor %425, %432 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32> -> !torch.vtensor<[1,128,768],f32> | |
%434 = torch.aten.mul.Tensor %433, %8 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[768],f32> -> !torch.vtensor<[1,128,768],f32> | |
%435 = torch.aten.add.Tensor %434, %8, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32> | |
%436 = torch.aten.view %435, %35 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[128,768],f32> | |
%437 = torch.aten.mm %436, %2 : !torch.vtensor<[128,768],f32>, !torch.vtensor<[768,3072],f32> -> !torch.vtensor<[128,3072],f32> | |
%438 = torch.aten.mul.Scalar %3, %int1 : !torch.vtensor<[3072],f32>, !torch.int -> !torch.vtensor<[3072],f32> | |
%439 = torch.aten.add.Tensor %438, %437, %int1 : !torch.vtensor<[3072],f32>, !torch.vtensor<[128,3072],f32>, !torch.int -> !torch.vtensor<[128,3072],f32> | |
%440 = torch.aten.view %439, %126 : !torch.vtensor<[128,3072],f32>, !torch.list<int> -> !torch.vtensor<[1,128,3072],f32> | |
%441 = torch.aten.mul.Scalar %440, %float5.000000e-01 : !torch.vtensor<[1,128,3072],f32>, !torch.float -> !torch.vtensor<[1,128,3072],f32> | |
%442 = torch.aten.pow.Tensor_Scalar %440, %float3.000000e00 : !torch.vtensor<[1,128,3072],f32>, !torch.float -> !torch.vtensor<[1,128,3072],f32> | |
%443 = torch.aten.mul.Scalar %442, %float4.471500e-02 : !torch.vtensor<[1,128,3072],f32>, !torch.float -> !torch.vtensor<[1,128,3072],f32> | |
%444 = torch.aten.add.Tensor %440, %443, %int1 : !torch.vtensor<[1,128,3072],f32>, !torch.vtensor<[1,128,3072],f32>, !torch.int -> !torch.vtensor<[1,128,3072],f32> | |
%445 = torch.aten.mul.Scalar %444, %float7.978850e-01 : !torch.vtensor<[1,128,3072],f32>, !torch.float -> !torch.vtensor<[1,128,3072],f32> | |
%446 = torch.aten.tanh %445 : !torch.vtensor<[1,128,3072],f32> -> !torch.vtensor<[1,128,3072],f32> | |
%447 = torch.aten.add.Scalar %446, %float1.000000e00, %int1 : !torch.vtensor<[1,128,3072],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,128,3072],f32> | |
%448 = torch.aten.mul.Tensor %441, %447 : !torch.vtensor<[1,128,3072],f32>, !torch.vtensor<[1,128,3072],f32> -> !torch.vtensor<[1,128,3072],f32> | |
%449 = torch.aten.view %448, %136 : !torch.vtensor<[1,128,3072],f32>, !torch.list<int> -> !torch.vtensor<[128,3072],f32> | |
%450 = torch.aten.mm %449, %1 : !torch.vtensor<[128,3072],f32>, !torch.vtensor<[3072,768],f32> -> !torch.vtensor<[128,768],f32> | |
%451 = torch.aten.mul.Scalar %8, %int1 : !torch.vtensor<[768],f32>, !torch.int -> !torch.vtensor<[768],f32> | |
%452 = torch.aten.add.Tensor %451, %450, %int1 : !torch.vtensor<[768],f32>, !torch.vtensor<[128,768],f32>, !torch.int -> !torch.vtensor<[128,768],f32> | |
%453 = torch.aten.view %452, %42 : !torch.vtensor<[128,768],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32> | |
%454 = torch.aten.add.Tensor %419, %453, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32> | |
%455 = torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<int> | |
%456 = torch.aten.sum.dim_IntList %454, %455, %true, %none : !torch.vtensor<[1,128,768],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32> | |
%457 = torch.aten.div.Scalar %456, %int768 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32> | |
%458 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> | |
%459 = torch.aten.broadcast_to %457, %458 : !torch.vtensor<[1,128,1],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32> | |
%460 = torch.aten.sub.Tensor %454, %459, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32> | |
%461 = torch.aten.mul.Tensor %460, %460 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32> -> !torch.vtensor<[1,128,768],f32> | |
%462 = torch.aten.sum.dim_IntList %461, %455, %true, %none : !torch.vtensor<[1,128,768],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32> | |
%463 = torch.aten.div.Scalar %462, %int768 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32> | |
%464 = torch.aten.add.Scalar %463, %float1.000000e-05, %int1 : !torch.vtensor<[1,128,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,128,1],f32> | |
%465 = torch.aten.rsqrt %464 : !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,1],f32> | |
%466 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> | |
%467 = torch.aten.broadcast_to %465, %466 : !torch.vtensor<[1,128,1],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32> | |
%468 = torch.aten.mul.Tensor %460, %467 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32> -> !torch.vtensor<[1,128,768],f32> | |
%469 = torch.aten.mul.Tensor %468, %8 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[768],f32> -> !torch.vtensor<[1,128,768],f32> | |
%470 = torch.aten.add.Tensor %469, %8, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32> | |
%471 = torch.aten.view %470, %35 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[128,768],f32> | |
%472 = torch.aten.mm %471, %6 : !torch.vtensor<[128,768],f32>, !torch.vtensor<[768,2304],f32> -> !torch.vtensor<[128,2304],f32> | |
%473 = torch.aten.mul.Scalar %7, %int1 : !torch.vtensor<[2304],f32>, !torch.int -> !torch.vtensor<[2304],f32> | |
%474 = torch.aten.add.Tensor %473, %472, %int1 : !torch.vtensor<[2304],f32>, !torch.vtensor<[128,2304],f32>, !torch.int -> !torch.vtensor<[128,2304],f32> | |
%475 = torch.aten.view %474, %40 : !torch.vtensor<[128,2304],f32>, !torch.list<int> -> !torch.vtensor<[1,128,2304],f32> | |
%476 = torch.aten.as_strided %475, %42, %43, %int0 : !torch.vtensor<[1,128,2304],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,128,768],f32> | |
%477 = torch.aten.as_strided %475, %42, %43, %int768 : !torch.vtensor<[1,128,2304],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,128,768],f32> | |
%478 = torch.aten.as_strided %475, %42, %43, %int1536 : !torch.vtensor<[1,128,2304],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,128,768],f32> | |
%479 = torch.aten.view %476, %47 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[1,128,12,64],f32> | |
%480 = torch.aten.permute %479, %49 : !torch.vtensor<[1,128,12,64],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,64],f32> | |
%481 = torch.aten.view %477, %47 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[1,128,12,64],f32> | |
%482 = torch.aten.permute %481, %49 : !torch.vtensor<[1,128,12,64],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,64],f32> | |
%483 = torch.aten.view %478, %47 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[1,128,12,64],f32> | |
%484 = torch.aten.permute %483, %49 : !torch.vtensor<[1,128,12,64],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,64],f32> | |
%485 = torch.aten.transpose.int %482, %int-1, %int-2 : !torch.vtensor<[1,12,128,64],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,12,64,128],f32> | |
%486 = torch.aten.broadcast_to %480, %56 : !torch.vtensor<[1,12,128,64],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,64],f32> | |
%487 = torch.aten.view %486, %58 : !torch.vtensor<[1,12,128,64],f32>, !torch.list<int> -> !torch.vtensor<[12,128,64],f32> | |
%488 = torch.aten.broadcast_to %485, %60 : !torch.vtensor<[1,12,64,128],f32>, !torch.list<int> -> !torch.vtensor<[1,12,64,128],f32> | |
%489 = torch.aten.view %488, %62 : !torch.vtensor<[1,12,64,128],f32>, !torch.list<int> -> !torch.vtensor<[12,64,128],f32> | |
%490 = torch.aten.bmm %487, %489 : !torch.vtensor<[12,128,64],f32>, !torch.vtensor<[12,64,128],f32> -> !torch.vtensor<[12,128,128],f32> | |
%491 = torch.aten.view %490, %65 : !torch.vtensor<[12,128,128],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,128],f32> | |
%492 = torch.prim.NumToTensor.Scalar %float8.000000e00 : !torch.float -> !torch.vtensor<[],f64> | |
%493 = torch.aten.to.dtype %492, %int6, %false, %false, %none : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32> | |
%494 = torch.aten.broadcast_to %493, %67 : !torch.vtensor<[],f32>, !torch.list<int> -> !torch.vtensor<[],f32> | |
%495 = torch.aten.div.Tensor %491, %494 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[1,12,128,128],f32> | |
%496 = torch.aten.as_strided %5, %72, %73, %int0 : !torch.vtensor<[1,1,1024,1024],ui8>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,1024,1024],ui8> | |
%497 = torch.aten.as_strided %496, %72, %73, %int0 : !torch.vtensor<[1,1,1024,1024],ui8>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,1024,1024],ui8> | |
%498 = torch.aten.as_strided %497, %76, %73, %int0 : !torch.vtensor<[1,1,1024,1024],ui8>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,128,1024],ui8> | |
%499 = torch.aten.as_strided %498, %78, %73, %int0 : !torch.vtensor<[1,1,128,1024],ui8>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,128,128],ui8> | |
%500 = torch.aten.to.dtype %499, %int11, %false, %false, %none : !torch.vtensor<[1,1,128,128],ui8>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,1,128,128],i1> | |
%501 = torch.prim.NumToTensor.Scalar %float-3.402820e38 : !torch.float -> !torch.vtensor<[],f64> | |
%502 = torch.aten.to.dtype %501, %int6, %false, %false, %none : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32> | |
%503 = torch.aten.broadcast_to %502, %67 : !torch.vtensor<[],f32>, !torch.list<int> -> !torch.vtensor<[],f32> | |
%504 = torch.aten.where.self %500, %495, %503 : !torch.vtensor<[1,1,128,128],i1>, !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[1,12,128,128],f32> | |
%values_6, %indices_7 = torch.aten.max.dim %504, %int-1, %true : !torch.vtensor<[1,12,128,128],f32>, !torch.int, !torch.bool -> !torch.vtensor<[1,12,128,1],f32>, !torch.vtensor<[1,12,128,1],si64> | |
%505 = torch.aten.sub.Tensor %504, %values_6, %int1 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[1,12,128,1],f32>, !torch.int -> !torch.vtensor<[1,12,128,128],f32> | |
%506 = torch.aten.exp %505 : !torch.vtensor<[1,12,128,128],f32> -> !torch.vtensor<[1,12,128,128],f32> | |
%507 = torch.aten.sum.dim_IntList %506, %85, %true, %none : !torch.vtensor<[1,12,128,128],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,12,128,1],f32> | |
%508 = torch.aten.div.Tensor %506, %507 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[1,12,128,1],f32> -> !torch.vtensor<[1,12,128,128],f32> | |
%509 = torch.aten.broadcast_to %508, %65 : !torch.vtensor<[1,12,128,128],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,128],f32> | |
%510 = torch.aten.view %509, %91 : !torch.vtensor<[1,12,128,128],f32>, !torch.list<int> -> !torch.vtensor<[12,128,128],f32> | |
%511 = torch.aten.broadcast_to %484, %56 : !torch.vtensor<[1,12,128,64],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,64],f32> | |
%512 = torch.aten.view %511, %58 : !torch.vtensor<[1,12,128,64],f32>, !torch.list<int> -> !torch.vtensor<[12,128,64],f32> | |
%513 = torch.aten.bmm %510, %512 : !torch.vtensor<[12,128,128],f32>, !torch.vtensor<[12,128,64],f32> -> !torch.vtensor<[12,128,64],f32> | |
%514 = torch.aten.view %513, %56 : !torch.vtensor<[12,128,64],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,64],f32> | |
%515 = torch.aten.permute %514, %49 : !torch.vtensor<[1,12,128,64],f32>, !torch.list<int> -> !torch.vtensor<[1,128,12,64],f32> | |
%516 = torch.aten.clone %515, %int0 : !torch.vtensor<[1,128,12,64],f32>, !torch.int -> !torch.vtensor<[1,128,12,64],f32> | |
%517 = torch.aten.view %516, %42 : !torch.vtensor<[1,128,12,64],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32> | |
%518 = torch.aten.view %517, %35 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[128,768],f32> | |
%519 = torch.aten.mm %518, %4 : !torch.vtensor<[128,768],f32>, !torch.vtensor<[768,768],f32> -> !torch.vtensor<[128,768],f32> | |
%520 = torch.aten.mul.Scalar %8, %int1 : !torch.vtensor<[768],f32>, !torch.int -> !torch.vtensor<[768],f32> | |
%521 = torch.aten.add.Tensor %520, %519, %int1 : !torch.vtensor<[768],f32>, !torch.vtensor<[128,768],f32>, !torch.int -> !torch.vtensor<[128,768],f32> | |
%522 = torch.aten.view %521, %42 : !torch.vtensor<[128,768],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32> | |
%523 = torch.aten.add.Tensor %522, %454, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32> | |
%524 = torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<int> | |
%525 = torch.aten.sum.dim_IntList %523, %524, %true, %none : !torch.vtensor<[1,128,768],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32> | |
%526 = torch.aten.div.Scalar %525, %int768 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32> | |
%527 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> | |
%528 = torch.aten.broadcast_to %526, %527 : !torch.vtensor<[1,128,1],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32> | |
%529 = torch.aten.sub.Tensor %523, %528, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32> | |
%530 = torch.aten.mul.Tensor %529, %529 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32> -> !torch.vtensor<[1,128,768],f32> | |
%531 = torch.aten.sum.dim_IntList %530, %524, %true, %none : !torch.vtensor<[1,128,768],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32> | |
%532 = torch.aten.div.Scalar %531, %int768 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32> | |
%533 = torch.aten.add.Scalar %532, %float1.000000e-05, %int1 : !torch.vtensor<[1,128,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,128,1],f32> | |
%534 = torch.aten.rsqrt %533 : !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,1],f32> | |
%535 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> | |
%536 = torch.aten.broadcast_to %534, %535 : !torch.vtensor<[1,128,1],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32> | |
%537 = torch.aten.mul.Tensor %529, %536 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32> -> !torch.vtensor<[1,128,768],f32> | |
%538 = torch.aten.mul.Tensor %537, %8 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[768],f32> -> !torch.vtensor<[1,128,768],f32> | |
%539 = torch.aten.add.Tensor %538, %8, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32> | |
%540 = torch.aten.view %539, %35 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[128,768],f32> | |
%541 = torch.aten.mm %540, %2 : !torch.vtensor<[128,768],f32>, !torch.vtensor<[768,3072],f32> -> !torch.vtensor<[128,3072],f32> | |
%542 = torch.aten.mul.Scalar %3, %int1 : !torch.vtensor<[3072],f32>, !torch.int -> !torch.vtensor<[3072],f32> | |
%543 = torch.aten.add.Tensor %542, %541, %int1 : !torch.vtensor<[3072],f32>, !torch.vtensor<[128,3072],f32>, !torch.int -> !torch.vtensor<[128,3072],f32> | |
%544 = torch.aten.view %543, %126 : !torch.vtensor<[128,3072],f32>, !torch.list<int> -> !torch.vtensor<[1,128,3072],f32> | |
%545 = torch.aten.mul.Scalar %544, %float5.000000e-01 : !torch.vtensor<[1,128,3072],f32>, !torch.float -> !torch.vtensor<[1,128,3072],f32> | |
%546 = torch.aten.pow.Tensor_Scalar %544, %float3.000000e00 : !torch.vtensor<[1,128,3072],f32>, !torch.float -> !torch.vtensor<[1,128,3072],f32> | |
%547 = torch.aten.mul.Scalar %546, %float4.471500e-02 : !torch.vtensor<[1,128,3072],f32>, !torch.float -> !torch.vtensor<[1,128,3072],f32> | |
%548 = torch.aten.add.Tensor %544, %547, %int1 : !torch.vtensor<[1,128,3072],f32>, !torch.vtensor<[1,128,3072],f32>, !torch.int -> !torch.vtensor<[1,128,3072],f32> | |
%549 = torch.aten.mul.Scalar %548, %float7.978850e-01 : !torch.vtensor<[1,128,3072],f32>, !torch.float -> !torch.vtensor<[1,128,3072],f32> | |
%550 = torch.aten.tanh %549 : !torch.vtensor<[1,128,3072],f32> -> !torch.vtensor<[1,128,3072],f32> | |
%551 = torch.aten.add.Scalar %550, %float1.000000e00, %int1 : !torch.vtensor<[1,128,3072],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,128,3072],f32> | |
%552 = torch.aten.mul.Tensor %545, %551 : !torch.vtensor<[1,128,3072],f32>, !torch.vtensor<[1,128,3072],f32> -> !torch.vtensor<[1,128,3072],f32> | |
%553 = torch.aten.view %552, %136 : !torch.vtensor<[1,128,3072],f32>, !torch.list<int> -> !torch.vtensor<[128,3072],f32> | |
%554 = torch.aten.mm %553, %1 : !torch.vtensor<[128,3072],f32>, !torch.vtensor<[3072,768],f32> -> !torch.vtensor<[128,768],f32> | |
%555 = torch.aten.mul.Scalar %8, %int1 : !torch.vtensor<[768],f32>, !torch.int -> !torch.vtensor<[768],f32> | |
%556 = torch.aten.add.Tensor %555, %554, %int1 : !torch.vtensor<[768],f32>, !torch.vtensor<[128,768],f32>, !torch.int -> !torch.vtensor<[128,768],f32> | |
%557 = torch.aten.view %556, %42 : !torch.vtensor<[128,768],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32> | |
%558 = torch.aten.add.Tensor %523, %557, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32> | |
%559 = torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<int> | |
%560 = torch.aten.sum.dim_IntList %558, %559, %true, %none : !torch.vtensor<[1,128,768],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32> | |
%561 = torch.aten.div.Scalar %560, %int768 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32> | |
%562 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> | |
%563 = torch.aten.broadcast_to %561, %562 : !torch.vtensor<[1,128,1],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32> | |
%564 = torch.aten.sub.Tensor %558, %563, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32> | |
%565 = torch.aten.mul.Tensor %564, %564 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32> -> !torch.vtensor<[1,128,768],f32> | |
%566 = torch.aten.sum.dim_IntList %565, %559, %true, %none : !torch.vtensor<[1,128,768],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32> | |
%567 = torch.aten.div.Scalar %566, %int768 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32> | |
%568 = torch.aten.add.Scalar %567, %float1.000000e-05, %int1 : !torch.vtensor<[1,128,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,128,1],f32> | |
%569 = torch.aten.rsqrt %568 : !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,1],f32> | |
%570 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> | |
%571 = torch.aten.broadcast_to %569, %570 : !torch.vtensor<[1,128,1],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32> | |
%572 = torch.aten.mul.Tensor %564, %571 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32> -> !torch.vtensor<[1,128,768],f32> | |
%573 = torch.aten.mul.Tensor %572, %8 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[768],f32> -> !torch.vtensor<[1,128,768],f32> | |
%574 = torch.aten.add.Tensor %573, %8, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32> | |
%575 = torch.aten.view %574, %35 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[128,768],f32> | |
%576 = torch.aten.mm %575, %6 : !torch.vtensor<[128,768],f32>, !torch.vtensor<[768,2304],f32> -> !torch.vtensor<[128,2304],f32> | |
%577 = torch.aten.mul.Scalar %7, %int1 : !torch.vtensor<[2304],f32>, !torch.int -> !torch.vtensor<[2304],f32> | |
%578 = torch.aten.add.Tensor %577, %576, %int1 : !torch.vtensor<[2304],f32>, !torch.vtensor<[128,2304],f32>, !torch.int -> !torch.vtensor<[128,2304],f32> | |
%579 = torch.aten.view %578, %40 : !torch.vtensor<[128,2304],f32>, !torch.list<int> -> !torch.vtensor<[1,128,2304],f32> | |
%580 = torch.aten.as_strided %579, %42, %43, %int0 : !torch.vtensor<[1,128,2304],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,128,768],f32> | |
%581 = torch.aten.as_strided %579, %42, %43, %int768 : !torch.vtensor<[1,128,2304],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,128,768],f32> | |
%582 = torch.aten.as_strided %579, %42, %43, %int1536 : !torch.vtensor<[1,128,2304],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,128,768],f32> | |
%583 = torch.aten.view %580, %47 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[1,128,12,64],f32> | |
%584 = torch.aten.permute %583, %49 : !torch.vtensor<[1,128,12,64],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,64],f32> | |
%585 = torch.aten.view %581, %47 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[1,128,12,64],f32> | |
%586 = torch.aten.permute %585, %49 : !torch.vtensor<[1,128,12,64],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,64],f32> | |
%587 = torch.aten.view %582, %47 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[1,128,12,64],f32> | |
%588 = torch.aten.permute %587, %49 : !torch.vtensor<[1,128,12,64],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,64],f32> | |
%589 = torch.aten.transpose.int %586, %int-1, %int-2 : !torch.vtensor<[1,12,128,64],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,12,64,128],f32> | |
%590 = torch.aten.broadcast_to %584, %56 : !torch.vtensor<[1,12,128,64],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,64],f32> | |
%591 = torch.aten.view %590, %58 : !torch.vtensor<[1,12,128,64],f32>, !torch.list<int> -> !torch.vtensor<[12,128,64],f32> | |
%592 = torch.aten.broadcast_to %589, %60 : !torch.vtensor<[1,12,64,128],f32>, !torch.list<int> -> !torch.vtensor<[1,12,64,128],f32> | |
%593 = torch.aten.view %592, %62 : !torch.vtensor<[1,12,64,128],f32>, !torch.list<int> -> !torch.vtensor<[12,64,128],f32> | |
%594 = torch.aten.bmm %591, %593 : !torch.vtensor<[12,128,64],f32>, !torch.vtensor<[12,64,128],f32> -> !torch.vtensor<[12,128,128],f32> | |
%595 = torch.aten.view %594, %65 : !torch.vtensor<[12,128,128],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,128],f32> | |
%596 = torch.prim.NumToTensor.Scalar %float8.000000e00 : !torch.float -> !torch.vtensor<[],f64> | |
%597 = torch.aten.to.dtype %596, %int6, %false, %false, %none : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32> | |
%598 = torch.aten.broadcast_to %597, %67 : !torch.vtensor<[],f32>, !torch.list<int> -> !torch.vtensor<[],f32> | |
%599 = torch.aten.div.Tensor %595, %598 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[1,12,128,128],f32> | |
%600 = torch.aten.as_strided %5, %72, %73, %int0 : !torch.vtensor<[1,1,1024,1024],ui8>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,1024,1024],ui8> | |
%601 = torch.aten.as_strided %600, %72, %73, %int0 : !torch.vtensor<[1,1,1024,1024],ui8>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,1024,1024],ui8> | |
%602 = torch.aten.as_strided %601, %76, %73, %int0 : !torch.vtensor<[1,1,1024,1024],ui8>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,128,1024],ui8> | |
%603 = torch.aten.as_strided %602, %78, %73, %int0 : !torch.vtensor<[1,1,128,1024],ui8>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,128,128],ui8> | |
%604 = torch.aten.to.dtype %603, %int11, %false, %false, %none : !torch.vtensor<[1,1,128,128],ui8>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,1,128,128],i1> | |
%605 = torch.prim.NumToTensor.Scalar %float-3.402820e38 : !torch.float -> !torch.vtensor<[],f64> | |
%606 = torch.aten.to.dtype %605, %int6, %false, %false, %none : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32> | |
%607 = torch.aten.broadcast_to %606, %67 : !torch.vtensor<[],f32>, !torch.list<int> -> !torch.vtensor<[],f32> | |
%608 = torch.aten.where.self %604, %599, %607 : !torch.vtensor<[1,1,128,128],i1>, !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[1,12,128,128],f32> | |
%values_8, %indices_9 = torch.aten.max.dim %608, %int-1, %true : !torch.vtensor<[1,12,128,128],f32>, !torch.int, !torch.bool -> !torch.vtensor<[1,12,128,1],f32>, !torch.vtensor<[1,12,128,1],si64> | |
%609 = torch.aten.sub.Tensor %608, %values_8, %int1 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[1,12,128,1],f32>, !torch.int -> !torch.vtensor<[1,12,128,128],f32> | |
%610 = torch.aten.exp %609 : !torch.vtensor<[1,12,128,128],f32> -> !torch.vtensor<[1,12,128,128],f32> | |
%611 = torch.aten.sum.dim_IntList %610, %85, %true, %none : !torch.vtensor<[1,12,128,128],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,12,128,1],f32> | |
%612 = torch.aten.div.Tensor %610, %611 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[1,12,128,1],f32> -> !torch.vtensor<[1,12,128,128],f32> | |
%613 = torch.aten.broadcast_to %612, %65 : !torch.vtensor<[1,12,128,128],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,128],f32> | |
%614 = torch.aten.view %613, %91 : !torch.vtensor<[1,12,128,128],f32>, !torch.list<int> -> !torch.vtensor<[12,128,128],f32> | |
%615 = torch.aten.broadcast_to %588, %56 : !torch.vtensor<[1,12,128,64],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,64],f32> | |
%616 = torch.aten.view %615, %58 : !torch.vtensor<[1,12,128,64],f32>, !torch.list<int> -> !torch.vtensor<[12,128,64],f32> | |
%617 = torch.aten.bmm %614, %616 : !torch.vtensor<[12,128,128],f32>, !torch.vtensor<[12,128,64],f32> -> !torch.vtensor<[12,128,64],f32> | |
%618 = torch.aten.view %617, %56 : !torch.vtensor<[12,128,64],f32>, !torch.list<int> -> !torch.vtensor<[1,12,128,64],f32> | |
%619 = torch.aten.permute %618, %49 : !torch.vtensor<[1,12,128,64],f32>, !torch.list<int> -> !torch.vtensor<[1,128,12,64],f32> | |
%620 = torch.aten.clone %619, %int0 : !torch.vtensor<[1,128,12,64],f32>, !torch.int -> !torch.vtensor<[1,128,12,64],f32> | |
%621 = torch.aten.view %620, %42 : !torch.vtensor<[1,128,12,64],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32> | |
%622 = torch.aten.view %621, %35 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[128,768],f32> | |
%623 = torch.aten.mm %622, %4 : !torch.vtensor<[128,768],f32>, !torch.vtensor<[768,768],f32> -> !torch.vtensor<[128,768],f32> | |
%624 = torch.aten.mul.Scalar %8, %int1 : !torch.vtensor<[768],f32>, !torch.int -> !torch.vtensor<[768],f32> | |
%625 = torch.aten.add.Tensor %624, %623, %int1 : !torch.vtensor<[768],f32>, !torch.vtensor<[128,768],f32>, !torch.int -> !torch.vtensor<[128,768],f32> | |
%626 = torch.aten.view %625, %42 : !torch.vtensor<[128,768],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32> | |
%627 = torch.aten.add.Tensor %626, %558, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32> | |
%628 = torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<int> | |
%629 = torch.aten.sum.dim_IntList %627, %628, %true, %none : !torch.vtensor<[1,128,768],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32> | |
%630 = torch.aten.div.Scalar %629, %int768 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32> | |
%631 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> | |
%632 = torch.aten.broadcast_to %630, %631 : !torch.vtensor<[1,128,1],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32> | |
%633 = torch.aten.sub.Tensor %627, %632, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32> | |
%634 = torch.aten.mul.Tensor %633, %633 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32> -> !torch.vtensor<[1,128,768],f32> | |
%635 = torch.aten.sum.dim_IntList %634, %628, %true, %none : !torch.vtensor<[1,128,768],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32> | |
%636 = torch.aten.div.Scalar %635, %int768 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32> | |
%637 = torch.aten.add.Scalar %636, %float1.000000e-05, %int1 : !torch.vtensor<[1,128,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,128,1],f32> | |
%638 = torch.aten.rsqrt %637 : !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,1],f32> | |
%639 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> | |
%640 = torch.aten.broadcast_to %638, %639 : !torch.vtensor<[1,128,1],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32> | |
%641 = torch.aten.mul.Tensor %633, %640 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32> -> !torch.vtensor<[1,128,768],f32> | |
%642 = torch.aten.mul.Tensor %641, %8 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[768],f32> -> !torch.vtensor<[1,128,768],f32> | |
%643 = torch.aten.add.Tensor %642, %8, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32> | |
%644 = torch.aten.view %643, %35 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[128,768],f32> | |
%645 = torch.aten.mm %644, %2 : !torch.vtensor<[128,768],f32>, !torch.vtensor<[768,3072],f32> -> !torch.vtensor<[128,3072],f32> | |
%646 = torch.aten.mul.Scalar %3, %int1 : !torch.vtensor<[3072],f32>, !torch.int -> !torch.vtensor<[3072],f32> | |
%647 = torch.aten.add.Tensor %646, %645, %int1 : !torch.vtensor<[3072],f32>, !torch.vtensor<[128,3072],f32>, !torch.int -> !torch.vtensor<[128,3072],f32> | |
%648 = torch.aten.view %647, %126 : !torch.vtensor<[128,3072],f32>, !torch.list<int> -> !torch.vtensor<[1,128,3072],f32> | |
%649 = torch.aten.mul.Scalar %648, %float5.000000e-01 : !torch.vtensor<[1,128,3072],f32>, !torch.float -> !torch.vtensor<[1,128,3072],f32> | |
%650 = torch.aten.pow.Tensor_Scalar %648, %float3.000000e00 : !torch.vtensor<[1,128,3072],f32>, !torch.float -> !torch.vtensor<[1,128,3072],f32> | |
%651 = torch.aten.mul.Scalar %650, %float4.471500e-02 : !torch.vtensor<[1,128,3072],f32>, !torch.float -> !torch.vtensor<[1,128,3072],f32> | |
%652 = torch.aten.add.Tensor %648, %651, %int1 : !torch.vtensor<[1,128,3072],f32>, !torch.vtensor<[1,128,3072],f32>, !torch.int -> !torch.vtensor<[1,128,3072],f32> | |
%653 = torch.aten.mul.Scalar %652, %float7.978850e-01 : !torch.vtensor<[1,128,3072],f32>, !torch.float -> !torch.vtensor<[1,128,3072],f32> | |
%654 = torch.aten.tanh %653 : !torch.vtensor<[1,128,3072],f32> -> !torch.vtensor<[1,128,3072],f32> | |
%655 = torch.aten.add.Scalar %654, %float1.000000e00, %int1 : !torch.vtensor<[1,128,3072],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,128,3072],f32> | |
%656 = torch.aten.mul.Tensor %649, %655 : !torch.vtensor<[1,128,3072],f32>, !torch.vtensor<[1,128,3072],f32> -> !torch.vtensor<[1,128,3072],f32> | |
%657 = torch.aten.view %656, %136 : !torch.vtensor<[1,128,3072],f32>, !torch.list<int> -> !torch.vtensor<[128,3072],f32> | |
%658 = torch.aten.mm %657, %1 : !torch.vtensor<[128,3072],f32>, !torch.vtensor<[3072,768],f32> -> !torch.vtensor<[128,768],f32> | |
%659 = torch.aten.mul.Scalar %8, %int1 : !torch.vtensor<[768],f32>, !torch.int -> !torch.vtensor<[768],f32> | |
%660 = torch.aten.add.Tensor %659, %658, %int1 : !torch.vtensor<[768],f32>, !torch.vtensor<[128,768],f32>, !torch.int -> !torch.vtensor<[128,768],f32> | |
%661 = torch.aten.view %660, %42 : !torch.vtensor<[128,768],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32> | |
%662 = torch.aten.add.Tensor %627, %661, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32> | |
%663 = torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<int> | |
%664 = torch.aten.sum.dim_IntList %662, %663, %true, %none : !torch.vtensor<[1,128,768],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32> | |
%665 = torch.aten.div.Scalar %664, %int768 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32> | |
%666 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> | |
%667 = torch.aten.broadcast_to %665, %666 : !torch.vtensor<[1,128,1],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32> | |
%668 = torch.aten.sub.Tensor %662, %667, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32> | |
%669 = torch.aten.mul.Tensor %668, %668 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32> -> !torch.vtensor<[1,128,768],f32> | |
%670 = torch.aten.sum.dim_IntList %669, %663, %true, %none : !torch.vtensor<[1,128,768],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,128,1],f32> | |
%671 = torch.aten.div.Scalar %670, %int768 : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128,1],f32> | |
%672 = torch.aten.add.Scalar %671, %float1.000000e-05, %int1 : !torch.vtensor<[1,128,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,128,1],f32> | |
%673 = torch.aten.rsqrt %672 : !torch.vtensor<[1,128,1],f32> -> !torch.vtensor<[1,128,1],f32> | |
%674 = torch.prim.ListConstruct %int1, %int128, %int768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> | |
%675 = torch.aten.broadcast_to %673, %674 : !torch.vtensor<[1,128,1],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32> | |
%676 = torch.aten.mul.Tensor %668, %675 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[1,128,768],f32> -> !torch.vtensor<[1,128,768],f32> | |
%677 = torch.aten.mul.Tensor %676, %8 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[768],f32> -> !torch.vtensor<[1,128,768],f32> | |
%678 = torch.aten.add.Tensor %677, %8, %int1 : !torch.vtensor<[1,128,768],f32>, !torch.vtensor<[768],f32>, !torch.int -> !torch.vtensor<[1,128,768],f32> | |
%679 = torch.aten.view %678, %42 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[1,128,768],f32> | |
%680 = torch.aten.transpose.int %0, %int0, %int1 : !torch.vtensor<[2,768],f32>, !torch.int, !torch.int -> !torch.vtensor<[768,2],f32> | |
%681 = torch.prim.ListConstruct %int128, %int768 : (!torch.int, !torch.int) -> !torch.list<int> | |
%682 = torch.aten.view %679, %681 : !torch.vtensor<[1,128,768],f32>, !torch.list<int> -> !torch.vtensor<[128,768],f32> | |
%683 = torch.aten.mm %682, %680 : !torch.vtensor<[128,768],f32>, !torch.vtensor<[768,2],f32> -> !torch.vtensor<[128,2],f32> | |
%684 = torch.prim.ListConstruct %int1, %int128, %int2 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> | |
%685 = torch.aten.view %683, %684 : !torch.vtensor<[128,2],f32>, !torch.list<int> -> !torch.vtensor<[1,128,2],f32> | |
%686 = torch.aten.arange.start_step %int0, %int1, %int1, %none, %none, %cpu, %false : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[1],si64> | |
%687 = torch.aten.slice.Tensor %685, %int1, %int-1, %int0, %int1 : !torch.vtensor<[1,128,2],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,1,2],f32> | |
%688 = torch.aten.squeeze.dim %687, %int1 : !torch.vtensor<[1,1,2],f32>, !torch.int -> !torch.vtensor<[1,2],f32> | |
%689 = torch.prim.ListConstruct %686 : (!torch.vtensor<[1],si64>) -> !torch.list<vtensor> | |
%690 = torch.aten.index.Tensor %688, %689 : !torch.vtensor<[1,2],f32>, !torch.list<vtensor> -> !torch.vtensor<[1,2],f32> | |
return %690 : !torch.vtensor<[1,2],f32> | |
} | |
} | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment