Skip to content

Instantly share code, notes, and snippets.

@vivekkhandelwal1
Last active July 6, 2023 14:57
Show Gist options
  • Select an option

  • Save vivekkhandelwal1/b458ba8e74936b2445693663ff89a967 to your computer and use it in GitHub Desktop.

Select an option

Save vivekkhandelwal1/b458ba8e74936b2445693663ff89a967 to your computer and use it in GitHub Desktop.
// IREE HAL executable for dispatch @forward_dispatch_1. It contains a single
// Vulkan/SPIR-V variant whose only entry point writes a dynamically sized
// 1-D f32 tensor into storage-buffer binding 0 of descriptor set 0.
hal.executable public @forward_dispatch_1 {
// Target variant: Vulkan SPIR-V (v1.6) for an NVIDIA discrete GPU. The
// target_env lists the SPIR-V capabilities/extensions and resource limits
// (49152 B of shared memory, at most 1024 invocations per workgroup,
// subgroup size 32, cooperative-matrix shapes) that this variant targets.
hal.executable.variant public @vulkan_spirv_fb, target = <"vulkan", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.6, [Shader, Float64, Float16, Int64, Int16, Int8, StorageBuffer16BitAccess, StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform, GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformShuffle, GroupNonUniformShuffleRelative, GroupNonUniformClustered, GroupNonUniformQuad, VariablePointers, VariablePointersStorageBuffer, DotProduct, DotProductInputAll, DotProductInput4x8BitPacked, DotProductInput4x8Bit, CooperativeMatrixNV], [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers, SPV_NV_cooperative_matrix]>, api=Vulkan, NVIDIA:DiscreteGPU, #spirv.resource_limits<max_compute_shared_memory_size = 49152, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 1024], min_subgroup_size = 32, max_subgroup_size = 32, cooperative_matrix_properties_nv = [#spirv.coop_matrix_props<m_size = 8, n_size = 8, k_size = 32, a_type = i8, b_type = i8, c_type = i32, result_type = i32, scope = <Subgroup>>, #spirv.coop_matrix_props<m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f16, result_type = f16, scope = <Subgroup>>, #spirv.coop_matrix_props<m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f32, result_type = f32, scope = <Subgroup>>]>>}> {
// Entry point. Its pipeline layout declares 4 push constants and one
// descriptor set (set 0) with a single storage-buffer binding (binding 0).
hal.executable.export public @forward_dispatch_1_generic_D_i1xf32 ordinal(0) layout(#hal.pipeline.layout<push_constants = 4, sets = [<0, bindings = [<0, storage_buffer>]>]>) {
// Workgroup-count region: %arg1 and %arg2 are the two workload values and
// the workgroup count is derived from them via
// flow.dispatch.workgroup_count_from_slice. %arg0 (the device) is unused.
^bb0(%arg0: !hal.device loc(callsite("<stdin>":820:10 at "<stdin>":26:3)), %arg1: index loc("<stdin>":26:3), %arg2: index loc(callsite("<stdin>":786:10 at "<stdin>":26:3))):
%x, %y, %z = flow.dispatch.workgroup_count_from_slice %arg1, %arg2 loc(callsite("<stdin>":820:10 at "<stdin>":26:3))
hal.return %x, %y, %z : index, index, index loc(callsite("<stdin>":820:10 at "<stdin>":26:3))
} loc(callsite("<stdin>":820:10 at "<stdin>":26:3))
builtin.module {
// Kernel body. As written, its net effect is out[i] = 1.0 for every i in
// [0, N0 + 1), where N0 is workload value 0 (see the NOTE(review) comments
// before the store about the output length).
func.func @forward_dispatch_1_generic_D_i1xf32() {
// Constants: the 32-bit shift used to rebuild 64-bit values, index 0 for the
// gather, the i1 `true` fill value, the +1 applied to N0, and the offset of
// the output view within binding 0.
%c32_i64 = arith.constant 32 : i64 loc(callsite("<stdin>":820:10 at "<stdin>":26:3))
%c0 = arith.constant 0 : index loc(callsite("<stdin>":820:10 at "<stdin>":26:3))
%true = arith.constant true loc(callsite("<stdin>":787:13 at "<stdin>":26:3))
%c1_i64 = arith.constant 1 : i64 loc(callsite("<stdin>":784:11 at "<stdin>":26:3))
%c427868160 = arith.constant 427868160 : index loc(callsite("<stdin>":820:10 at "<stdin>":26:3))
// Push constants 0/1 carry the low/high 32-bit halves of one 64-bit value
// (N0, workload ordinal 0); push constants 2/3 carry N1 (workload ordinal 1).
%0 = hal.interface.constant.load[0] : i32 loc("<stdin>":26:3)
%1 = hal.interface.constant.load[1] : i32 loc("<stdin>":26:3)
%2 = hal.interface.constant.load[2] : i32 loc(callsite("<stdin>":786:10 at "<stdin>":26:3))
%3 = hal.interface.constant.load[3] : i32 loc(callsite("<stdin>":786:10 at "<stdin>":26:3))
// N0 = zext(c0) | (zext(c1) << 32), then cast to index.
%4 = arith.extui %1 : i32 to i64 loc("<stdin>":26:3)
%5 = arith.shli %4, %c32_i64 : i64 loc("<stdin>":26:3)
%6 = arith.extui %0 : i32 to i64 loc("<stdin>":26:3)
%7 = arith.ori %6, %5 : i64 loc("<stdin>":26:3)
%8 = arith.index_castui %7 : i64 to index loc("<stdin>":26:3)
// N1 = zext(c2) | (zext(c3) << 32), then cast to index.
%9 = arith.extui %3 : i32 to i64 loc(callsite("<stdin>":786:10 at "<stdin>":26:3))
%10 = arith.shli %9, %c32_i64 : i64 loc(callsite("<stdin>":786:10 at "<stdin>":26:3))
%11 = arith.extui %2 : i32 to i64 loc(callsite("<stdin>":786:10 at "<stdin>":26:3))
%12 = arith.ori %11, %10 : i64 loc(callsite("<stdin>":786:10 at "<stdin>":26:3))
%13 = arith.index_castui %12 : i64 to index loc(callsite("<stdin>":786:10 at "<stdin>":26:3))
// N1 (%14) is the dynamic length of the output: binding 0 viewed as a
// write-only tensor<?xf32> starting at offset %c427868160.
%14 = flow.dispatch.workload.ordinal %13, 1 : index loc(callsite("<stdin>":820:10 at "<stdin>":26:3))
%15 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c427868160) : !flow.dispatch.tensor<writeonly:tensor<?xf32>>{%14} loc(callsite("<stdin>":820:10 at "<stdin>":26:3))
// Length of the computed tensors: %19 = N0 + 1 (the addition is done in i64).
%16 = flow.dispatch.workload.ordinal %8, 0 : index loc(callsite("<stdin>":820:10 at "<stdin>":26:3))
%17 = arith.index_cast %16 : index to i64 loc(callsite("<stdin>":777:12 at "<stdin>":26:3))
%18 = arith.addi %17, %c1_i64 : i64 loc(callsite("<stdin>":785:12 at "<stdin>":26:3))
%19 = arith.index_cast %18 : i64 to index loc(callsite("<stdin>":786:10 at "<stdin>":26:3))
// Uninitialized destination tensors, each with dynamic extent N0 + 1.
%20 = tensor.empty(%19) : tensor<1x1x1x?xi1> loc(callsite("<stdin>":810:21 at "<stdin>":26:3))
%21 = tensor.empty(%19) : tensor<?xf32> loc(callsite("<stdin>":820:10 at "<stdin>":26:3))
%22 = tensor.empty(%19) : tensor<?xi1> loc(callsite("<stdin>":813:10 at "<stdin>":26:3))
// Rank-4 1x1x1x(N0+1) i1 mask with every element set to `true`.
%23 = linalg.fill ins(%true : i1) outs(%20 : tensor<1x1x1x?xi1>) -> tensor<1x1x1x?xi1> loc(callsite("<stdin>":811:10 at "<stdin>":26:3))
// Gather the mask's innermost dimension into a rank-1 tensor:
// %24[i] = %23[0, 0, 0, i]. Since %23 is all-true, every element is true.
%24 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} outs(%22 : tensor<?xi1>) {
^bb0(%out: i1 loc("<stdin>":814:10)):
%26 = linalg.index 0 : index loc(callsite("<stdin>":815:15 at "<stdin>":26:3))
%extracted = tensor.extract %23[%c0, %c0, %c0, %26] : tensor<1x1x1x?xi1> loc(callsite("<stdin>":816:20 at "<stdin>":26:3))
linalg.yield %extracted : i1 loc(callsite("<stdin>":817:7 at "<stdin>":26:3))
} -> tensor<?xi1> loc(callsite("<stdin>":813:10 at "<stdin>":26:3))
// Elementwise unsigned i1 -> f32 conversion, so true becomes 1.0.
%25 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%24 : tensor<?xi1>) outs(%21 : tensor<?xf32>) {
^bb0(%in: i1 loc(callsite("<stdin>":820:10 at "<stdin>":26:3)), %out: f32 loc(callsite("<stdin>":820:10 at "<stdin>":26:3))):
%26 = arith.uitofp %in : i1 to f32 loc(callsite("<stdin>":822:15 at "<stdin>":26:3))
linalg.yield %26 : f32 loc(callsite("<stdin>":823:7 at "<stdin>":26:3))
} -> tensor<?xf32> loc(callsite("<stdin>":820:10 at "<stdin>":26:3))
// Write the result into the output binding.
// NOTE(review): %25 has length N0 + 1 (%19), but the store uses sizes [N1]
// (%14). This is only consistent if the host always passes N1 == N0 + 1;
// confirm against the code that fills the push constants.
// NOTE(review): because the mask is constant, the fill/gather/convert chain
// appears equivalent to filling the output with 1.0. It is left unchanged here.
flow.dispatch.tensor.store %25, %15, offsets = [0], sizes = [%14], strides = [1] : tensor<?xf32> -> !flow.dispatch.tensor<writeonly:tensor<?xf32>>{%14} loc(callsite("<stdin>":820:10 at "<stdin>":26:3))
return loc(callsite("<stdin>":820:10 at "<stdin>":26:3))
} loc(callsite("<stdin>":820:10 at "<stdin>":26:3))
} loc(callsite("<stdin>":820:10 at "<stdin>":26:3))
} loc(callsite("<stdin>":820:10 at "<stdin>":26:3))
} loc(callsite("<stdin>":820:10 at "<stdin>":26:3))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.