The IR for just the embedding bag is below:
#map0 = affine_map<(d0, d1, d2) -> (d1)>
#map1 = affine_map<(d0, d1, d2) -> (d0)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d2)>
module attributes {torch.debug_module_name = "ToyEmbeddingBag"} {
  func.func @forward(%arg0: tensor<4xi64>, %arg1: tensor<2xi64>) -> tensor<2x3xf32> {
    %cst = arith.constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0], [10.0, 11.0, 12.0], [13.0, 14.0, 15.0], [16.0, 17.0, 18.0], [19.0, 20.0, 21.0], [22.0, 23.0, 24.0], [25.0, 26.0, 27.0], [28.0, 29.0, 30.0]]> : tensor<10x3xf32>
    %cst_0 = arith.constant 0.000000e+00 : f32
    %c1_i64 = arith.constant 1 : i64
    %c2_i64 = arith.constant 2 : i64
    %c4_i64 = arith.constant 4 : i64
    %0 = linalg.init_tensor [2, 3] : tensor<2x3xf32>
    %1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<2x3xf32>) -> tensor<2x3xf32>
    // d0 = bag (2), d1 = position in the flat indices list (4), d2 = embedding dim (3).
    %2 = linalg.generic {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0, %arg1 : tensor<4xi64>, tensor<2xi64>) outs(%1 : tensor<2x3xf32>) {
    ^bb0(%arg2: i64, %arg3: i64, %arg4: f32):
      // %arg2 = indices[d1], %arg3 = offsets[d0], %arg4 = current output value.
      %3 = linalg.index 0 : index
      %4 = arith.index_cast %3 : index to i64
      %5 = arith.addi %4, %c1_i64 : i64
      %6 = arith.index_cast %5 : i64 to index
      %7 = arith.cmpi eq, %5, %c2_i64 : i64
      %8 = tensor.extract %arg1[%6] : tensor<2xi64>
      // End of bag d0: offsets[d0 + 1], or the total index count (4) for the last bag.
      %9 = arith.select %7, %c4_i64, %8 : i64
      %10 = linalg.index 1 : index
      %11 = arith.index_cast %10 : index to i64
      %12 = arith.cmpi slt, %arg3, %11 : i64
      %13 = arith.cmpi eq, %arg3, %11 : i64
      %14 = arith.ori %12, %13 : i1
      %15 = arith.cmpi slt, %11, %9 : i64
      // In-bag mask: offsets[d0] <= d1 < end.
      %16 = arith.andi %14, %15 : i1
      %17 = arith.index_cast %arg2 : i64 to index
      %18 = linalg.index 2 : index
      %19 = tensor.extract %cst[%17, %18] : tensor<10x3xf32>
      %20 = arith.addf %19, %arg4 : f32
      // Accumulate weights[indices[d1], d2] only when position d1 belongs to bag d0.
      %21 = arith.select %16, %20, %arg4 : f32
      linalg.yield %21 : f32
    } -> tensor<2x3xf32>
    return %2 : tensor<2x3xf32>
  }
}
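For reference, this implements a summing embedding bag. A minimal PyTorch sketch of the same computation (my reconstruction for illustration; the actual script is in the gist linked below):

import torch

# Weights match the IR constant: row i holds [3*i + 1, 3*i + 2, 3*i + 3].
weight = torch.arange(1.0, 31.0).reshape(10, 3)
bag = torch.nn.EmbeddingBag.from_pretrained(weight, mode="sum")

indices = torch.tensor([1, 2, 4, 5])  # %arg0 in the IR
offsets = torch.tensor([0, 1])        # %arg1: bag 0 = {1}, bag 1 = {2, 4, 5}
print(bag(indices, offsets))          # expect [[4, 5, 6], [36, 39, 42]]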
I set the constant weights to an easily readable pattern (row i is [3i+1, 3i+2, 3i+3]) so it is easier to interpret what the Intel GPU is outputting. Example scripts for running on GPU and CPU:
#!/bin/bash
# Compile for the OpenCL/SPIR-V target and run on the Level Zero driver (GPU).
~/nod/iree-build-latest/tools/iree-compile embedding_bag/dlrm.mlir --iree-input-type=none --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=opencl-spirv --mlir-print-op-on-diagnostic=false --iree-llvm-target-cpu-features=host --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 --iree-flow-demote-f64-to-f32=false --mlir-print-ir-after-all --iree-flow-trace-dispatch-tensors --iree-hal-dump-executable-binaries-to=embedding_bag -o embedding_bag/dlrm.vmfb 2> embedding_bag/dlrm_gpu_log.mlir
~/nod/iree-build-latest/tools/iree-run-module --device=level_zero --entry_function=forward --function_input="4xi64=1 2 4 5" --function_input="2xi64=0 1" --module_file=embedding_bag/dlrm.vmfb
#!/bin/bash
# Same compilation for the llvm-cpu target, run on the default (CPU) device.
~/nod/iree-build-latest/tools/iree-compile embedding_bag/dlrm.mlir --iree-input-type=none --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=llvm-cpu --mlir-print-op-on-diagnostic=false --iree-llvm-target-cpu-features=host --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 --iree-flow-demote-f64-to-f32=false --mlir-print-ir-after-all --iree-flow-trace-dispatch-tensors --iree-hal-dump-executable-binaries-to=embedding_bag -o embedding_bag/dlrm_cpu.vmfb 2> embedding_bag/dlrm_cpu_log.mlir
~/nod/iree-build-latest/tools/iree-run-module --entry_function=forward --function_input="4xi64=1 2 4 5" --function_input="2xi64=0 1" --module_file=embedding_bag/dlrm_cpu.vmfb
Replace the output files, module files, and IREE paths as necessary (the --iree-flow-trace-dispatch-tensors flag is what produces the per-dispatch traces below). The GPU gives this:
EXEC @forward
=== forward_dispatch_0::forward_dispatch_0_generic_2x4x3 inputs ===
2xi64=0 1
4xi64=1 2 4 5
2x3xf32=[0 0 0][0 0 0]
=== forward_dispatch_0::forward_dispatch_0_generic_2x4x3 outputs ===
2x3xf32=[0 0 0][16 17 18]
result[0]: hal.buffer_view
2x3xf32=[0 0 0][16 17 18]
This means only the last element (index 5) of the second bag is being accumulated by the embedding bag operation on GPU. CPU output for reference:
EXEC @forward
=== forward_dispatch_0::forward_dispatch_0_generic_2x4x3 inputs ===
2xi64=0 1
4xi64=1 2 4 5
2x3xf32=[0 0 0][0 0 0]
=== forward_dispatch_0::forward_dispatch_0_generic_2x4x3 outputs ===
2x3xf32=[4 5 6][36 39 42]
result[0]: hal.buffer_view
2x3xf32=[4 5 6][36 39 42]
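The CPU result is the expected one: bag 0 = weight[1] = [4, 5, 6], and bag 1 = weight[2] + weight[4] + weight[5] = [7, 8, 9] + [13, 14, 15] + [16, 17, 18] = [36, 39, 42]. The GPU's bag 1 result, [16, 17, 18], is weight[5] alone, and its bag 0 result never moved past the zero fill.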
I haven't completely narrowed down the cause, but it could be a problem either with the addressing of the constant weights or with how the workgroups are computed on GPU. If it is the latter, we might see a similar failure on other GPU backends; however, I wasn't able to get SHARK set up, and there was a crash in my local IREE build.
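To illustrate the second hypothesis (purely speculation on my part): if each d1 iteration were executed as an independent invocation that reads the zero-initialized output rather than a running partial sum, with the highest d1 writing last, we would get exactly the GPU result above. A quick Python sketch of that failure mode:

import torch

weight = torch.arange(1.0, 31.0).reshape(10, 3)
indices = [1, 2, 4, 5]
offsets = [0, 1]

out = torch.zeros(2, 3)
for d0 in range(2):  # bags
    end = offsets[d0 + 1] if d0 + 1 < len(offsets) else len(indices)
    # Model each d1 iteration as an independent invocation: it reads the
    # zero fill instead of the partial sum, and iterating in order makes
    # d1 = 3 the last writer.
    for d1 in range(len(indices)):
        init = torch.zeros(3)
        contribution = weight[indices[d1]] if offsets[d0] <= d1 < end else 0.0
        out[d0] = init + contribution
print(out)  # tensor([[0., 0., 0.], [16., 17., 18.]]) -- the GPU output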
The branch needed to run level-zero stuff is here https://github.com/qedawkins/iree/tree/ze-shark-09072022-P64
Additionally, I uploaded a Python script for the embedding bag example here: https://gist.github.com/5533c90d54988e4d13ac17528f49bffe
The IR comes from the self.sparse_arch(*sparse_features) call in DLRMShark:
def forward(
    self, dense_features: torch.Tensor, *sparse_features
) -> torch.Tensor:
    # embedded_dense = self.dense_arch(dense_features)
    embedded_sparse = self.sparse_arch(*sparse_features)
    # concatenated_dense = self.inter_arch(
    #     dense_features=embedded_dense, sparse_features=embedded_sparse
    # )
    # logits = self.over_arch(concatenated_dense)
    return embedded_sparse
The full IR can be found in this gist: https://gist.github.com/qedawkins/e3023a70ef3e5797545c4b3468e98dc1 (with modifications to remove unused arguments)
The following commands (or something like them) can be used to test the model with dummy inputs on GPU:
~/nod/iree-build-latest/tools/iree-compile mini_dlrm/dlrm.mlir --iree-input-type=none --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=opencl-spirv --mlir-print-op-on-diagnostic=false --iree-llvm-target-cpu-features=host --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 --iree-flow-demote-f64-to-f32=false --mlir-print-ir-after-all --iree-flow-trace-dispatch-tensors --iree-hal-dump-executable-binaries-to=single_matmul -o mini_dlrm/dlrm.vmfb 2> mini_dlrm/dlrm_gpu_log.mlir
~/nod/iree-build-latest/tools/iree-run-module --device=level_zero --entry_function=forward --function_input=4xi64=1 --function_input=2xi64=2 --function_input=4xi64=3 --function_input=2xi64=4 --function_input=3xi64=5 --function_input=2xi64=6 --module_file=mini_dlrm/dlrm.vmfb
CPU:
~/nod/iree-build-latest/tools/iree-compile mini_dlrm/dlrm.mlir --iree-input-type=none --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=llvm-cpu --mlir-print-op-on-diagnostic=false --iree-llvm-target-cpu-features=host --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 --iree-flow-demote-f64-to-f32=false --mlir-print-ir-after-all --iree-flow-trace-dispatch-tensors --iree-hal-dump-executable-binaries-to=single_matmul -o mini_dlrm/dlrm_cpu.vmfb 2> mini_dlrm/dlrm_cpu_log.mlir
~/nod/iree-build-latest/tools/iree-run-module --entry_function=forward --function_input=4xi64=1 --function_input=2xi64=2 --function_input=4xi64=3 --function_input=2xi64=4 --function_input=3xi64=5 --function_input=2xi64=6 --module_file=mini_dlrm/dlrm_cpu.vmfb
If we compare the dispatch results between these two runs, we find that the first dispatch is the only immediately apparent discrepancy:
GPU:
=== forward_dispatch_0::forward_dispatch_0_generic_2x4x8 outputs ===
2x8xf32=[0 0 0 0 0 0 0 0][-1.04558 -0.684866 -0.912183 0.301311 -0.196631 -1.73568 -1.09422 0.193754]
CPU:
=== forward_dispatch_0::forward_dispatch_0_generic_2x4x8 outputs ===
2x8xf32=[0 0 0 0 0 0 0 0][-2.09117 -1.36973 -1.82437 0.602622 -0.393261 -3.47136 -2.18844 0.387507]
Note that each CPU value above is exactly twice the corresponding GPU value, again consistent with only part of the bag being accumulated on GPU. This is what the IR for that dispatch looks like:
#map0 = affine_map<(d0, d1, d2) -> (d1)>
#map1 = affine_map<(d0, d1, d2) -> (d0)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d2)>
module attributes {torch.debug_module_name = "DLRMShark"} {
  flow.executable private @forward_dispatch_0 {
    flow.executable.export public @forward_dispatch_0_generic_2x4x8 workgroups(%arg0: index, %arg1: index, %arg2: index) -> (index, index, index) {
      %x, %y, %z = flow.dispatch.default_workgroup_count %arg0, %arg1, %arg2
      flow.return %x, %y, %z : index, index, index
    }
    builtin.module {
      func.func @forward_dispatch_0_generic_2x4x8(%arg0: !flow.dispatch.tensor<readonly:2xi64>, %arg1: !flow.dispatch.tensor<readonly:100x8xf32>, %arg2: !flow.dispatch.tensor<readonly:4xi64>, %arg3: !flow.dispatch.tensor<readwrite:2x8xf32>) {
        %c1_i64 = arith.constant 1 : i64
        %c2_i64 = arith.constant 2 : i64
        %c4_i64 = arith.constant 4 : i64
        %0 = flow.dispatch.tensor.load %arg0, offsets = [0], sizes = [2], strides = [1] : !flow.dispatch.tensor<readonly:2xi64> -> tensor<2xi64>
        %1 = flow.dispatch.tensor.load %arg1, offsets = [0, 0], sizes = [100, 8], strides = [1, 1] : !flow.dispatch.tensor<readonly:100x8xf32> -> tensor<100x8xf32>
        %2 = flow.dispatch.tensor.load %arg2, offsets = [0], sizes = [4], strides = [1] : !flow.dispatch.tensor<readonly:4xi64> -> tensor<4xi64>
        %3 = flow.dispatch.tensor.load %arg3, offsets = [0, 0], sizes = [2, 8], strides = [1, 1] : !flow.dispatch.tensor<readwrite:2x8xf32> -> tensor<2x8xf32>
        %4 = linalg.generic {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%2, %0 : tensor<4xi64>, tensor<2xi64>) outs(%3 : tensor<2x8xf32>) {
        ^bb0(%arg4: i64, %arg5: i64, %arg6: f32):
          %5 = linalg.index 0 : index
          %6 = arith.index_cast %5 : index to i64
          %7 = arith.addi %6, %c1_i64 : i64
          %8 = arith.index_cast %7 : i64 to index
          %9 = arith.cmpi eq, %7, %c2_i64 : i64
          %10 = tensor.extract %0[%8] : tensor<2xi64>
          %11 = arith.select %9, %c4_i64, %10 : i64
          %12 = linalg.index 1 : index
          %13 = arith.index_cast %12 : index to i64
          %14 = arith.cmpi slt, %arg5, %13 : i64
          %15 = arith.cmpi eq, %arg5, %13 : i64
          %16 = arith.ori %14, %15 : i1
          %17 = arith.cmpi slt, %13, %11 : i64
          %18 = arith.andi %16, %17 : i1
          %19 = arith.index_cast %arg4 : i64 to index
          %20 = linalg.index 2 : index
          %21 = tensor.extract %1[%19, %20] : tensor<100x8xf32>
          %22 = arith.addf %21, %arg6 : f32
          %23 = arith.select %18, %22, %arg6 : f32
          linalg.yield %23 : f32
        } -> tensor<2x8xf32>
        flow.dispatch.tensor.store %4, %arg3, offsets = [0, 0], sizes = [2, 8], strides = [1, 1] : tensor<2x8xf32> -> !flow.dispatch.tensor<readwrite:2x8xf32>
        return
      }
    }
  }
}