Skip to content

Instantly share code, notes, and snippets.

@AmosLewis
Created March 1, 2024 07:44
Show Gist options
  • Save AmosLewis/a38200c0264065e7cf60ebe52e168fe7 to your computer and use it in GitHub Desktop.
Save AmosLewis/a38200c0264065e7cf60ebe52e168fe7 to your computer and use it in GitHub Desktop.
// Location alias for "unknown"; attached to most constants, the return, and the enclosing func/module below.
#loc = loc(unknown)
module attributes {torch.debug_module_name = "StdCorrectionModule"} {
  // Standard deviation over all elements of a rank-3 f32 tensor with a fixed
  // correction of 2.0:
  //   std = sqrt(sum((x - mean(x))^2) / (N - 2)),  N = size(0) * size(1) * size(2)
  // The reduction is done in f64; the variance is cast back to f32 before the sqrt.
  // N is computed once and shared by the mean and the correction divisor.
  func.func @forward(%arg0: !torch.vtensor<[?,?,?],f32> loc(unknown)) -> !torch.vtensor<[],f32> {
    %int6 = torch.constant.int 6 loc(#loc)
    %int7 = torch.constant.int 7 loc(#loc2)
    %float1.000000e00 = torch.constant.float 1.000000e+00 loc(#loc)
    %float2.000000e00 = torch.constant.float 2.000000e+00 loc(#loc)
    %true = torch.constant.bool true loc(#loc)
    %false = torch.constant.bool false loc(#loc)
    %int0 = torch.constant.int 0 loc(#loc)
    %int1 = torch.constant.int 1 loc(#loc)
    %int2 = torch.constant.int 2 loc(#loc)
    %none = torch.constant.none loc(#loc)

    // Upcast the input to f64 (dtype code 7, per the result type) for the reduction.
    %0 = torch.aten.to.dtype %arg0, %int7, %false, %false, %none : !torch.vtensor<[?,?,?],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f64> loc(#loc2)
    // Reduce over all three dims.
    %1 = torch.prim.ListConstruct %int0, %int1, %int2 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc2)

    // N = size(0) * size(1) * size(2), computed once and reused below.
    %2 = torch.aten.size.int %0, %int0 : !torch.vtensor<[?,?,?],f64>, !torch.int -> !torch.int loc(#loc2)
    %3 = torch.aten.size.int %0, %int1 : !torch.vtensor<[?,?,?],f64>, !torch.int -> !torch.int loc(#loc2)
    %4 = torch.aten.size.int %0, %int2 : !torch.vtensor<[?,?,?],f64>, !torch.int -> !torch.int loc(#loc2)
    %5 = torch.aten.mul.int %2, %3 : !torch.int, !torch.int -> !torch.int loc(#loc2)
    %6 = torch.aten.mul.int %5, %4 : !torch.int, !torch.int -> !torch.int loc(#loc2)

    // mean = sum(x, keepdim=true) / N; the [1,1,1] shape broadcasts against x.
    %7 = torch.aten.sum.dim_IntList %0, %1, %true, %none : !torch.vtensor<[?,?,?],f64>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f64> loc(#loc2)
    %8 = torch.aten.div.Scalar %7, %6 : !torch.vtensor<[1,1,1],f64>, !torch.int -> !torch.vtensor<[1,1,1],f64> loc(#loc2)

    // Sum of squared deviations (x - mean)^2 over all dims, reduced to a 0-d tensor.
    %9 = torch.aten.sub.Tensor %0, %8, %float1.000000e00 : !torch.vtensor<[?,?,?],f64>, !torch.vtensor<[1,1,1],f64>, !torch.float -> !torch.vtensor<[?,?,?],f64> loc(#loc2)
    %10 = torch.aten.mul.Tensor %9, %9 : !torch.vtensor<[?,?,?],f64>, !torch.vtensor<[?,?,?],f64> -> !torch.vtensor<[?,?,?],f64> loc(#loc2)
    %11 = torch.aten.sum.dim_IntList %10, %1, %false, %none : !torch.vtensor<[?,?,?],f64>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[],f64> loc(#loc2)

    // Guard: correction (2.0) <= N + 1.
    %12 = torch.aten.Float.Scalar %6 : !torch.int -> !torch.float loc(#loc2)
    %13 = torch.aten.add %12, %float1.000000e00 : !torch.float, !torch.float -> !torch.float loc(#loc2)
    %14 = torch.aten.ge.float %13, %float2.000000e00 : !torch.float, !torch.float -> !torch.bool loc(#loc2)
    torch.runtime.assert %14, "correction value should be less than or equal to productDimSize + 1" loc(#loc2)

    // variance = sum_sq_dev / (N - correction).
    // NOTE(review): the guard above still admits N == 1, i.e. a divisor of -1;
    // PyTorch is believed to clamp the divisor at 0 in that case — confirm
    // whether this edge case matters for the e2e test.
    %15 = torch.aten.sub.float %12, %float2.000000e00 : !torch.float, !torch.float -> !torch.float loc(#loc2)
    %16 = torch.aten.div.Scalar %11, %15 : !torch.vtensor<[],f64>, !torch.float -> !torch.vtensor<[],f64> loc(#loc2)

    // Cast the variance back to f32 (dtype code 6, per the result type), then take the sqrt.
    %17 = torch.aten.to.dtype %16, %int6, %false, %false, %none : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32> loc(#loc2)
    %18 = torch.aten.sqrt %17 : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32> loc(#loc2)
    return %18 : !torch.vtensor<[],f32> loc(#loc)
  } loc(#loc)
} loc(#loc)
// #loc1: file/line/column location (stats.py, line 419, column 15) in the torch-mlir e2e test suite.
// #loc2: the same location wrapped in the name "aten::std". It tags the non-constant ops in @forward
// (and %int7), presumably those produced from that aten::std call.
#loc1 = loc("/home/chi/src/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir_e2e_test/test_suite/stats.py":419:15)
#loc2 = loc("aten::std"(#loc1))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment