Skip to content

Instantly share code, notes, and snippets.

module @module {
util.func public @main$async(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view, %arg3: !hal.buffer_view, %arg4: !hal.fence, %arg5: !hal.fence) -> !hal.buffer_view attributes {inlining_policy = #util.inline.never, iree.abi.model = "coarse-fences", iree.abi.stub} {
%c2_i64 = arith.constant 2 : i64
%0 = hal.tensor.import wait(%arg4) => %arg1 : !hal.buffer_view -> tensor<1x64xf32>
%1 = hal.tensor.import wait(%arg4) => %arg3 : !hal.buffer_view -> tensor<1xi64>
%2 = hal.tensor.import wait(%arg4) => %arg0 : !hal.buffer_view -> tensor<4x64xf32>
%3 = tensor.empty() : tensor<1xi64>
%4 = linalg.generic {indexing_maps = [affine_map<(d0) -> (0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%1 : tensor<1xi64>) outs(%3 : tensor<1xi64>) {
^bb0(%in: i64, %out: i64):
%14 = arith.muli %in, %c2_i64 : i64
module @module attributes {stream.affinity.default = #hal.device.affinity<@__device_0>} {
util.global private @__device_0 = #hal.device.target<"hip", {ordinal = 0 : index}, [#hal.executable.target<"rocm", "rocm-hsaco-fb", {abi = "hip", iree.gpu.target = #iree_gpu.target<arch = "gfx1100", features = "", wgp = <compute = fp64|fp32|fp16|int64|int32|int16|int8, storage = b64|b32|b16|b8, subgroup = shuffle|arithmetic, dot = dp4xi8toi32, mma = [<WMMAR3_F32_16x16x16_F16>, <WMMAR3_F16_16x16x16_F16>, <WMMAR3_F32_16x16x16_BF16>, <WMMAR3_BF16_16x16x16_BF16>, <WMMAR3_I32_16x16x16_I8>], subgroup_size_choices = [32, 64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, max_workgroup_counts = [2147483647, 2147483647, 2147483647], max_load_instruction_bits = 128, simds_per_wgp = 4, vgpr_space_bits = 8192>>, ukernels = "none"}>]> : !hal.device
stream.executable private @main$async_dispatch_0 {
stream.executable.export public @main$async_dispatch_0
module @module {
util.func public @main$async(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view, %arg3: !hal.fence, %arg4: !hal.fence) -> !hal.buffer_view attributes {inlining_policy = #util.inline.never, iree.abi.model = "coarse-fences", iree.abi.stub} {
%0 = hal.tensor.import wait(%arg3) => %arg1 : !hal.buffer_view -> tensor<1x64xf32>
%1 = hal.tensor.import wait(%arg3) => %arg2 : !hal.buffer_view -> tensor<1xi64>
%2 = hal.tensor.import wait(%arg3) => %arg0 : !hal.buffer_view -> tensor<4x64xf32>
%collapsed = tensor.collapse_shape %0 [[0, 1]] : tensor<1x64xf32> into tensor<64xf32>
%3 = tensor.empty() : tensor<1x64xf32>
%4 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%collapsed : tensor<64xf32>) outs(%3 : tensor<1x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
module @module attributes {stream.affinity.default = #hal.device.affinity<@__device_0>} {
util.global private @__device_0 = #hal.device.target<"hip", {ordinal = 0 : index}, [#hal.executable.target<"rocm", "rocm-hsaco-fb", {abi = "hip", iree.gpu.target = #iree_gpu.target<arch = "gfx1100", features = "", wgp = <compute = fp64|fp32|fp16|int64|int32|int16|int8, storage = b64|b32|b16|b8, subgroup = shuffle|arithmetic, dot = dp4xi8toi32, mma = [<WMMAR3_F32_16x16x16_F16>, <WMMAR3_F16_16x16x16_F16>, <WMMAR3_F32_16x16x16_BF16>, <WMMAR3_BF16_16x16x16_BF16>, <WMMAR3_I32_16x16x16_I8>], subgroup_size_choices = [32, 64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, max_workgroup_counts = [2147483647, 2147483647, 2147483647], max_load_instruction_bits = 128, simds_per_wgp = 4, vgpr_space_bits = 8192>>, ukernels = "none"}>]> : !hal.device
stream.executable private @main$async_dispatch_0 {
stream.executable.export public @main$async_dispatch_0
module @module {
util.func public @main$async(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view, %arg3: !hal.fence, %arg4: !hal.fence) -> !hal.buffer_view attributes {inlining_policy = #util.inline.never, iree.abi.model = "coarse-fences", iree.abi.stub} {
%0 = hal.tensor.import wait(%arg3) => %arg1 : !hal.buffer_view -> tensor<1x64xf32>
%1 = hal.tensor.import wait(%arg3) => %arg2 : !hal.buffer_view -> tensor<1xi64>
%2 = hal.tensor.import wait(%arg3) => %arg0 : !hal.buffer_view -> tensor<16x16x64xf32>
%collapsed = tensor.collapse_shape %0 [[0, 1]] : tensor<1x64xf32> into tensor<64xf32>
%3 = tensor.empty() : tensor<1x64xf32>
%4 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%collapsed : tensor<64xf32>) outs(%3 : tensor<1x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
module attributes {stream.affinity.default = #hal.device.affinity<@__device_0>} {
util.global private @__device_0 = #hal.device.target<"hip", {ordinal = 0 : index}, [#hal.executable.target<"rocm", "rocm-hsaco-fb", {abi = "hip", iree.gpu.target = #iree_gpu.target<arch = "gfx1100", features = "", wgp = <compute = fp64|fp32|fp16|int64|int32|int16|int8, storage = b64|b32|b16|b8, subgroup = shuffle|arithmetic, dot = dp4xi8toi32, mma = [<WMMAR3_F32_16x16x16_F16>, <WMMAR3_F16_16x16x16_F16>, <WMMAR3_F32_16x16x16_BF16>, <WMMAR3_BF16_16x16x16_BF16>, <WMMAR3_I32_16x16x16_I8>], subgroup_size_choices = [32, 64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, max_workgroup_counts = [2147483647, 2147483647, 2147483647], max_load_instruction_bits = 128, simds_per_wgp = 4, vgpr_space_bits = 8192>>, ukernels = "none"}>]> : !hal.device
stream.executable private @main_dispatch_0 {
stream.executable.export public @main_dispatch_0_elementwise_broadca
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>
#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3)>
#map3 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d2)>
#map4 = affine_map<(d0, d1, d2, d3, d4) -> ()>
#map5 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
#map6 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
module {
func.func @main(%arg0: tensor<64x128x128xf16>, %arg1: tensor<64x128x128xf16>, %arg2: tensor<64x128x128xf16>) -> tensor<64x128x128xf16> {
%c1 = arith.constant 1 : index
module @module attributes {stream.affinity.default = #hal.device.affinity<@__device_0>} {
util.global private @__device_0 = #hal.device.target<"hip", {ordinal = 0 : index}, [#hal.executable.target<"rocm", "rocm-hsaco-fb", {abi = "hip", iree.gpu.target = #iree_gpu.target<arch = "gfx1100", features = "", wgp = <compute = fp64|fp32|fp16|int64|int32|int16|int8, storage = b64|b32|b16|b8, subgroup = shuffle|arithmetic, dot = dp4xi8toi32, mma = [<WMMA_F32_16x16x16_F16>, <WMMA_F16_16x16x16_F16>, <WMMA_I32_16x16x16_I8>, <WMMA_I32_16x16x16_I8>, <WMMA_I32_16x16x16_I8>], subgroup_size_choices = [32, 64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, max_workgroup_counts = [2147483647, 2147483647, 2147483647], max_load_instruction_bits = 128, simds_per_wgp = 4, vgpr_space_bits = 8192>>, ukernels = "none"}>]> : !hal.device
util.global private @__device_1 = #hal.device.target<"hip", {ordinal = 1 : index}, [#hal.executable.target<"rocm", "rocm-hsa
module @module {
util.global private @__auto.ffn_gate.weight.shard.0 {stream.affinity = #hal.device.promise<@__device_0>} = #stream.parameter.named<"model"::"ffn_gate.weight.shard.0"> : tensor<64x64xf16>
util.global private @__auto.ffn_gate.weight.shard.1 {stream.affinity = #hal.device.promise<@__device_1>} = #stream.parameter.named<"model"::"ffn_gate.weight.shard.1"> : tensor<64x64xf16>
util.global private @__auto.ffn_up.weight.shard.0 {stream.affinity = #hal.device.promise<@__device_0>} = #stream.parameter.named<"model"::"ffn_up.weight.shard.0"> : tensor<64x64xf16>
util.global private @__auto.ffn_up.weight.shard.1 {stream.affinity = #hal.device.promise<@__device_1>} = #stream.parameter.named<"model"::"ffn_up.weight.shard.1"> : tensor<64x64xf16>
util.global private @__auto.ffn_down.weight.shard.0 {stream.affinity = #hal.device.promise<@__device_0>} = #stream.parameter.named<"model"::"ffn_down.weight.shard.0"> : tensor<64x64xf16>
util.global private @__auto.ffn_down.weight.shard.1 {stream.affinity
# Fragment of a repro/comparison script: imports both sharktank's
# RotaryEmbeddingLayer and the HuggingFace transformers Llama rotary
# embedding reference (LlamaRotaryEmbedding / apply_rotary_pos_emb).
# NOTE(review): presumably the two implementations are compared further
# down — the rest of the script is outside this view; confirm.
import torch
from sharktank.layers.rotary_embedding import RotaryEmbeddingLayer
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb
# Small test dimensions.
# NOTE(review): names suggest batch size, sequence length, attention head
# count, and per-head embedding dim — TODO confirm against the (unseen)
# usage below.
bs=2
length = 5
heads = 3
dims = 8