@AmosLewis
Created March 1, 2024 07:32
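Torch dialect IR for the BernoulliModule e2e test from torch-mlir's test suite (torch_mlir_e2e_test/test_suite/rng.py), dumped after decomposition: aten::bernoulli, aten::mean, and aten::std have each been expanded into more primitive torch ops, annotated with // comments below.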
#loc = loc(unknown)
module attributes {torch.debug_module_name = "BernoulliModule"} {
  func.func @forward(%arg0: !torch.vtensor<[?,?,?],f64> loc(unknown)) -> (!torch.vtensor<[],f64>, !torch.vtensor<[],f64>) {
    %true = torch.constant.bool true loc(#loc)
    %false = torch.constant.bool false loc(#loc)
    %int7 = torch.constant.int 7 loc(#loc)
    %int2 = torch.constant.int 2 loc(#loc)
    %int1 = torch.constant.int 1 loc(#loc)
    %none = torch.constant.none loc(#loc)
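    // aten::bernoulli decomposition (rng.py:193): broadcast a zeros tensor to
    // the shape of %arg0, fill it with uniform(0, 1) samples, take
    // sample < %arg0, then cast the i1 mask back to f64 (dtype code 7).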
    %float0.000000e00 = torch.constant.float 0.000000e+00 loc(#loc4)
    %float1.000000e00 = torch.constant.float 1.000000e+00 loc(#loc)
    %0 = torch.vtensor.literal(dense<0.000000e+00> : tensor<f64>) : !torch.vtensor<[],f64> loc(#loc4)
    %int0 = torch.constant.int 0 loc(#loc)
    %1 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?,?],f64>, !torch.int -> !torch.int loc(#loc4)
    %2 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[?,?,?],f64>, !torch.int -> !torch.int loc(#loc4)
    %3 = torch.aten.size.int %arg0, %int2 : !torch.vtensor<[?,?,?],f64>, !torch.int -> !torch.int loc(#loc4)
    %4 = torch.prim.ListConstruct %1, %2, %3 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc4)
    %5 = torch.aten.broadcast_to %0, %4 : !torch.vtensor<[],f64>, !torch.list<int> -> !torch.vtensor<[?,?,?],f64> loc(#loc4)
    %6 = torch.aten.uniform %5, %float0.000000e00, %float1.000000e00, %none : !torch.vtensor<[?,?,?],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[?,?,?],f64> loc(#loc4)
    %7 = torch.aten.lt.Tensor %6, %arg0 : !torch.vtensor<[?,?,?],f64>, !torch.vtensor<[?,?,?],f64> -> !torch.vtensor<[?,?,?],i1> loc(#loc4)
    %8 = torch.aten.to.dtype %7, %int7, %false, %false, %none : !torch.vtensor<[?,?,?],i1>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f64> loc(#loc4)
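    // aten::mean decomposition (rng.py:194): sum over all elements / numel.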
    %9 = torch.aten.sum %8, %none : !torch.vtensor<[?,?,?],f64>, !torch.none -> !torch.vtensor<[],f64> loc(#loc5)
    %10 = torch.aten.numel %8 : !torch.vtensor<[?,?,?],f64> -> !torch.int loc(#loc5)
    %11 = torch.aten.div.Scalar %9, %10 : !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[],f64> loc(#loc5)
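    // aten::std decomposition (rng.py:195): compute the mean over all dims
    // (keepdim for broadcasting), subtract it, square, sum, divide by
    // numel - 1 (the default correction of 1, guarded by the assert below),
    // then take the square root.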
    %12 = torch.prim.ListConstruct %int0, %int1, %int2 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc6)
    %13 = torch.aten.sum.dim_IntList %8, %12, %true, %none : !torch.vtensor<[?,?,?],f64>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f64> loc(#loc6)
    %14 = torch.aten.size.int %8, %int0 : !torch.vtensor<[?,?,?],f64>, !torch.int -> !torch.int loc(#loc6)
    %15 = torch.aten.mul.int %int1, %14 : !torch.int, !torch.int -> !torch.int loc(#loc6)
    %16 = torch.aten.size.int %8, %int1 : !torch.vtensor<[?,?,?],f64>, !torch.int -> !torch.int loc(#loc6)
    %17 = torch.aten.mul.int %15, %16 : !torch.int, !torch.int -> !torch.int loc(#loc6)
    %18 = torch.aten.size.int %8, %int2 : !torch.vtensor<[?,?,?],f64>, !torch.int -> !torch.int loc(#loc6)
    %19 = torch.aten.mul.int %17, %18 : !torch.int, !torch.int -> !torch.int loc(#loc6)
    %20 = torch.aten.div.Scalar %13, %19 : !torch.vtensor<[1,1,1],f64>, !torch.int -> !torch.vtensor<[1,1,1],f64> loc(#loc6)
    %21 = torch.aten.sub.Tensor %8, %20, %float1.000000e00 : !torch.vtensor<[?,?,?],f64>, !torch.vtensor<[1,1,1],f64>, !torch.float -> !torch.vtensor<[?,?,?],f64> loc(#loc6)
    %22 = torch.aten.mul.Tensor %21, %21 : !torch.vtensor<[?,?,?],f64>, !torch.vtensor<[?,?,?],f64> -> !torch.vtensor<[?,?,?],f64> loc(#loc6)
    %23 = torch.aten.sum.dim_IntList %22, %12, %false, %none : !torch.vtensor<[?,?,?],f64>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[],f64> loc(#loc6)
    %24 = torch.aten.size.int %8, %int0 : !torch.vtensor<[?,?,?],f64>, !torch.int -> !torch.int loc(#loc6)
    %25 = torch.aten.mul.int %int1, %24 : !torch.int, !torch.int -> !torch.int loc(#loc6)
    %26 = torch.aten.size.int %8, %int1 : !torch.vtensor<[?,?,?],f64>, !torch.int -> !torch.int loc(#loc6)
    %27 = torch.aten.mul.int %25, %26 : !torch.int, !torch.int -> !torch.int loc(#loc6)
    %28 = torch.aten.size.int %8, %int2 : !torch.vtensor<[?,?,?],f64>, !torch.int -> !torch.int loc(#loc6)
    %29 = torch.aten.mul.int %27, %28 : !torch.int, !torch.int -> !torch.int loc(#loc6)
    %30 = torch.aten.Float.Scalar %29 : !torch.int -> !torch.float loc(#loc6)
    %31 = torch.aten.add %30, %float1.000000e00 : !torch.float, !torch.float -> !torch.float loc(#loc6)
    %32 = torch.aten.ge.float %31, %float1.000000e00 : !torch.float, !torch.float -> !torch.bool loc(#loc6)
    torch.runtime.assert %32, "correction value should be less than or equal to productDimSize + 1" loc(#loc6)
    %33 = torch.aten.sub.float %30, %float1.000000e00 : !torch.float, !torch.float -> !torch.float loc(#loc6)
    %34 = torch.aten.div.Scalar %23, %33 : !torch.vtensor<[],f64>, !torch.float -> !torch.vtensor<[],f64> loc(#loc6)
    %35 = torch.aten.sqrt %34 : !torch.vtensor<[],f64> -> !torch.vtensor<[],f64> loc(#loc6)
    return %11, %35 : !torch.vtensor<[],f64>, !torch.vtensor<[],f64> loc(#loc)
  } loc(#loc)
} loc(#loc)
#loc1 = loc("/home/chi/src/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir_e2e_test/test_suite/rng.py":193:12)
#loc2 = loc("/home/chi/src/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir_e2e_test/test_suite/rng.py":194:15)
#loc3 = loc("/home/chi/src/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir_e2e_test/test_suite/rng.py":195:14)
#loc4 = loc("aten::bernoulli"(#loc1))
#loc5 = loc("aten::mean"(#loc2))
#loc6 = loc("aten::std"(#loc3))
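
For context, here is a minimal PyTorch sketch of the kind of module that produces this IR, reconstructed from the module name and the aten::bernoulli / aten::mean / aten::std source locations above. This is an assumption-laden reconstruction, not the verbatim torch-mlir test; the real test in rng.py also carries the e2e suite's decorators.

# Hypothetical reconstruction; only the op sequence (bernoulli -> mean -> std)
# is taken from the IR above.
import torch

class BernoulliModule(torch.nn.Module):
    def forward(self, x):
        # x holds per-element probabilities in [0, 1]; torch.bernoulli draws
        # 0.0/1.0 samples with P(1) = x, which the IR implements as
        # uniform(0, 1) < x followed by a cast to f64.
        a = torch.bernoulli(x)
        mean = torch.mean(a)  # IR: aten.sum / aten.numel
        std = torch.std(a)    # IR: Bessel-corrected, divide by numel - 1
        return mean, std

For i.i.d. Bernoulli(p) samples, the expected mean is p and the expected standard deviation is sqrt(p * (1 - p)), which is presumably what the test compares these two returned statistics against.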