Lei Zhang (antiagainst)
%5:1283 = stream.resource.pack slices({
  [0, 3] = %c640,
  [0, 3] = %c153664,
  [1, 3] = %c640,
  [2, 6] = %c1280,
  [3, 5] = %c640,
  [3, 5] = %c1327104,
  [4, 7] = %c1280,
  [5, 989] = %c11796480,
  [6, 8] = %c20480,
module attributes {torch.debug_module_name = "_lambda"} {
  flow.executable private @forward_dispatch_0 {
    flow.executable.export public @forward_dispatch_0_generic_16384 workgroups(%arg0: index) -> (index, index, index) {
      %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0
      flow.return %x, %y, %z : index, index, index
    }
    builtin.module {
      func.func @forward_dispatch_0_generic_16384(%arg0: !flow.dispatch.tensor<readonly:tensor<16384xf16>>, %arg1: !flow.dispatch.tensor<writeonly:tensor<16384xf16>>) {
        %cst = arith.constant 5.4899807 : f32
        %0 = flow.dispatch.tensor.load %arg0, offsets = [0], sizes = [16384], strides = [1] : !flow.dispatch.tensor<readonly:tensor<16384xf16>> -> tensor<16384xf16>
module attributes {tf_saved_model.semantics} {
  flow.executable private @main_dispatch_0 {
    flow.executable.export public @main_dispatch_0_generic_DxDxD
    builtin.module {
      func.func @main_dispatch_0_generic_DxDxD(%arg0: index, %arg1: index, %arg2: index, %arg3: !flow.dispatch.tensor<readonly:tensor<1x30522x128xi8>>, %arg4: index, %arg5: !flow.dispatch.tensor<readonly:tensor<?x?xi32>>, %arg6: index, %arg7: index, %arg8: index, %arg9: index, %arg10: index, %arg11: !flow.dispatch.tensor<writeonly:tensor<?x?x?xi8>>) {
        %0 = flow.dispatch.tie_shape %arg5 : !flow.dispatch.tensor<readonly:tensor<?x?xi32>>{%arg6, %arg7}
        %1 = flow.dispatch.tie_shape %arg11 : !flow.dispatch.tensor<writeonly:tensor<?x?x?xi8>>{%arg8, %arg9, %arg10}
        %2 = flow.dispatch.tensor.load %arg3, offsets = [0, 0, 0], sizes = [1, 30522, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<1x30522x128xi8>> -> tensor<1x30522x128xi8>
        %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [%
spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Float16, CooperativeMatrixNV], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]> {
  spirv.GlobalVariable @__builtin_var_WorkgroupId__ built_in("WorkgroupId") : !spirv.ptr<vector<3xi32>, Input>
  spirv.GlobalVariable @__workgroup_mem__5 : !spirv.ptr<!spirv.struct<(!spirv.array<256 x vector<4xf32>>)>, Workgroup>
  spirv.GlobalVariable @__workgroup_mem__4 : !spirv.ptr<!spirv.struct<(!spirv.array<256 x vector<4xf32>>)>, Workgroup>
  spirv.GlobalVariable @__builtin_var_LocalInvocationId__ built_in("LocalInvocationId") : !spirv.ptr<vector<3xi32>, Input>
  spirv.GlobalVariable @__resource_var_0_0_ bind(0, 0) : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
  spirv.GlobalVariable @__resource_var_0_1_ bind(0, 1) : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
  spirv.GlobalVariable @__resource_var_0_2_ bind(0, 2) : !spirv.ptr<!spirv.struct<(!spirv.rtarra
// -----// IR Dump After TileAndDistributeToWorkgroups (iree-codegen-tile-and-distribute-to-workgroups) //----- //
hal.executable.variant public @vulkan_spirv_fb, target = <"vulkan", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.6, [Shader, Float64, Float16, Int64, Int16, Int8, StorageBuffer16BitAccess, StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform, GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformShuffle, GroupNonUniformShuffleRelative, GroupNonUniformClustered, GroupNonUniformQuad, VariablePointers, VariablePointersStorageBuffer], [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>, AMD:DiscreteGPU, #spirv.resource_limits<max_compute_shared_memory_size = 65536, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 1024], subgroup_size = 64, cooperative_matrix_proper
# Copied from https://colab.sandbox.google.com/github/iree-org/iree/blob/main/samples/colab/resnet.ipynb
# Run the following commands to install the needed packages:
# pip install --upgrade iree-compiler iree-runtime iree-tools-tf -f https://github.com/iree-org/iree/releases
# pip install --upgrade tf-nightly
from iree import runtime as ireert
from iree import compiler as ireec
from iree.tf.support import module_utils
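# Rough sketch (not part of the original notebook) of how these imports are
# typically wired together: compile a tiny MLIR module for the Vulkan/SPIR-V
# backend and invoke it through the runtime. The @add function and the
# driver/backend choices here are assumptions for illustration only, and the
# VmModule constructor signature has changed across iree-runtime releases.
ADD_MLIR = """
func.func @add(%a: f32, %b: f32) -> f32 {
  %0 = arith.addf %a, %b : f32
  return %0 : f32
}
"""

# Compile the MLIR source to an IREE VM FlatBuffer targeting vulkan-spirv.
vmfb = ireec.compile_str(ADD_MLIR, target_backends=["vulkan-spirv"])

# Load the module on the Vulkan HAL driver and call the exported function.
config = ireert.Config("vulkan")
ctx = ireert.SystemContext(config=config)
ctx.add_vm_module(ireert.VmModule.from_flatbuffer(ctx.instance, vmfb))
print(ctx.modules.module["add"](1.5, 2.5))  # expect 4.0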
// Option #1: use a plain sequence of blocks
//   Loses all the goodies that regions give us :(
// Option #2: use the scf dialect for structured control flow
//   No story for it inside the spirv dialect; hard to recover when deserializing from a SPIR-V blob :(
// Option #3: flesh out spirv structured control flow ops by introducing new terminators:
//   - spv.mlir.fallthrough // jump to the next region
//   - spv.mlir.condition   // jump to the next region if true; jump out of the current scf op if false
//   - spv.mlir.merge       // jump out of the current scf op
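// A purely hypothetical sketch of where those terminators would sit. The
// wrapping op name (spv.mlir.scf.selection) and the region layout are made up
// here just for illustration; only the three terminators above come from the
// notes, and the guard computation is elided.
spv.mlir.scf.selection {
  // header region: compute the guard
  %cond = ...  // some i1 value
  // true -> fall through into the next region; false -> jump past the op
  spv.mlir.condition %cond
}, {
  // guarded region: runs only when the guard held
  ...
  // done: jump out of the enclosing scf op
  spv.mlir.merge
}
// A body split across several regions would chain them with spv.mlir.fallthrough.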
// tools/iree-compile --iree-input-type=mhlo --iree-hal-target-backends=vulkan-spirv --iree-vulkan-target-triple=valhall-unknown-android31 --iree-flow-enable-fuse-padding-into-consumer-ops ~/models/mhlo-conv.mlir -o /dev/null --mlir-print-ir-after-all --mlir-print-ir-after-change --mlir-disable-threading --mlir-elide-elementsattrs-if-larger=8 -debug-only=iree-spirv-vectorize &>! mhlo-conv.log
// iree-org/iree@a8e4c38c
// -----// IR Dump After mlir::iree_compiler::IREE::HAL::(anonymous namespace)::MaterializeInterfacesPass (iree-hal-materialize-interfaces) //----- //
#device_target_vulkan = #hal.device.target<"vulkan", {executable_targets = [#hal.executable.target<"vulkan", "vulkan-spirv-fb", {spv.target_env = #spv.target_env<#spv.vce<v1.4, [Shader, Float16, Int16, Int8, StorageBuffer16BitAccess, StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform, GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot, Group
// tools/iree-compile --iree-input-type=mhlo --iree-hal-target-backends=vulkan-spirv --iree-vulkan-target-triple=valhall-unknown-android31 ~/models/mhlo-dot.mlir -o /dev/null --mlir-print-ir-after-all --mlir-print-ir-after-change --mlir-disable-threading --mlir-elide-elementsattrs-if-larger=8 -debug-only=iree-spirv-vectorize &>! mhlo-dot.log
// iree-org/iree@a8e4c38c
// -----// IR Dump After mlir::iree_compiler::IREE::HAL::(anonymous namespace)::MaterializeInterfacesPass (iree-hal-materialize-interfaces) //----- //
#device_target_vulkan = #hal.device.target<"vulkan", {executable_targets = [#hal.executable.target<"vulkan", "vulkan-spirv-fb", {spv.target_env = #spv.target_env<#spv.vce<v1.4, [Shader, Float16, Int16, Int8, StorageBuffer16BitAccess, StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform, GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformShuffle, GroupNonUniformShuffleRelative, G
func.func @conv_pad_dispatch_1_conv_2d_nhwc_hwcf_1x112x112x16x3x3x3() {
  %cst = arith.constant dense<0.000000e+00> : vector<1x2x2x4xf32>
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c3 = arith.constant 3 : index
  %c2 = arith.constant 2 : index
  %c4 = arith.constant 4 : index
  %c112 = arith.constant 112 : index
  %c16 = arith.constant 16 : index
  %cst_0 = arith.constant 0.000000e+00 : f32