Skip to content

Instantly share code, notes, and snippets.

module @module {
util.func public @decode_bs4$async(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view, %arg3: !hal.buffer_view, %arg4: !hal.buffer_view, %arg5: !hal.fence, %arg6: !hal.fence) -> !hal.buffer_view attributes {inlining_policy = #util.inline.never, iree.abi.model = "coarse-fences", iree.abi.stub} {
%cst = arith.constant dense<1> : tensor<1x1xi64>
%cst_0 = arith.constant dense<0> : tensor<1x1xi64>
%cst_1 = arith.constant 1.000000e+00 : f32
%cst_2 = arith.constant dense<1.000000e+00> : tensor<3200x3200xf32>
%cst_3 = arith.constant dense<1.000000e+00> : tensor<3200xf32>
%c0_i64 = arith.constant 0 : i64
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
module @module {
func.func @decode_bs4(%arg0: !torch.vtensor<[4,1],si64>, %arg1: !torch.vtensor<[4],si64>, %arg2: !torch.vtensor<[4],si64>, %arg3: !torch.vtensor<[4,?],si64>, %arg4: !torch.tensor<[?,2662400],f32>) -> !torch.vtensor<[4,1],si64> attributes {torch.assume_strict_symbolic_shapes} {
%int2662400 = torch.constant.int 2662400
%int16 = torch.constant.int 16
%int26 = torch.constant.int 26
%int100 = torch.constant.int 100
%int32 = torch.constant.int 32
%int3200 = torch.constant.int 3200
%int4 = torch.constant.int 4
%int0 = torch.constant.int 0
module @module {
util.global private @__auto.token_embd.weight = #stream.parameter.named<"model"::"token_embd.weight"> : tensor<32000x3200xf16>
util.global private @__auto.blk.0.attn_norm.weight = #stream.parameter.named<"model"::"blk.0.attn_norm.weight"> : tensor<3200xf32>
util.global private @__auto.blk.0.attn_k.weight = #stream.parameter.named<"model"::"blk.0.attn_k.weight"> : tensor<3200x3200xf16>
util.global private @__auto.blk.0.attn_v.weight = #stream.parameter.named<"model"::"blk.0.attn_v.weight"> : tensor<3200x3200xf16>
util.func public @decode_bs4$async(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view, %arg3: !hal.buffer_view, %arg4: !hal.buffer_view, %arg5: !hal.fence, %arg6: !hal.fence) -> !hal.buffer_view attributes {inlining_policy = #util.inline.never, iree.abi.model = "coarse-fences", iree.abi.stub} {
%cst = arith.constant dense<1> : tensor<1x1xi64>
%cst_0 = arith.constant dense<0> : tensor<1x1xi64>
%c0_i64 = arith.constant 0 : i64
%c1 = arith
@rsuderman
rsuderman / gist:49c4d2b679b2ecffc2a9123617a912c8
Created June 11, 2024 19:57
Quant Conv with Scale and Offset
import matplotlib.pyplot as plt
import torch
# Shapes of the two 4-D operands used by this snippet.
# NOTE(review): presumably A is the convolution input and B the filter
# (the gist is titled "Quant Conv ...") — their use is not visible here; confirm.
A_SHAPE = (4, 8, 16, 16)
B_SHAPE = (8, 8, 4, 4)
# Seed torch's global RNG so the random tensors created afterwards are reproducible.
torch.manual_seed(12345)
def generate_input(shape):
    """Return a random float32 tensor of the requested shape.

    Values are drawn by ``torch.rand`` (uniform on [0, 1)) from torch's
    global RNG, so results are reproducible after ``torch.manual_seed``.

    Args:
        shape: Sequence of ints giving the dimensions of the tensor.

    Returns:
        A ``torch.Tensor`` with dtype ``torch.float`` and the given shape.
    """
    # The function body must be indented; the flattened original was a syntax error.
    return torch.rand(shape, dtype=torch.float)
@rsuderman
rsuderman / gist:24f5835706e46241af313abcb0bf7394
Created June 11, 2024 18:14
QMM with scale and offset corrections
import matplotlib.pyplot as plt
import torch
# Shapes of the two 2-D operands; both share a trailing dimension of 128.
A_SHAPE = (8, 128)
B_SHAPE = (16, 128)
# Seed torch's global RNG so the random tensors below are reproducible.
torch.manual_seed(12345)
# Random float32 column vectors of shape (rows, 1), one entry per row of A / B,
# drawn by torch.rand (uniform on [0, 1)).
# NOTE(review): presumably per-row quantization offsets — their use is not visible here; confirm.
A_OFFSET = torch.rand((A_SHAPE[0],1), dtype=torch.float)
B_OFFSET = torch.rand((B_SHAPE[0],1), dtype=torch.float)
@rsuderman
rsuderman / gist:62f79b3ef527c2aecb7f1e1803392f48
Last active June 14, 2024 22:00
Conv per channel quant
import matplotlib.pyplot as plt
import torch
# Shapes of the two 4-D operands used by this snippet.
# NOTE(review): presumably A is the convolution input and B the filter
# (the gist is titled "Conv per channel quant") — their use is not visible here; confirm.
A_SHAPE = (4, 8, 16, 16)
B_SHAPE = (8, 8, 4, 4)
# Seed torch's global RNG so the random tensors created afterwards are reproducible.
torch.manual_seed(12345)
def generate_input(shape):
    """Return a random float32 tensor of the requested shape.

    Values are drawn by ``torch.rand`` (uniform on [0, 1)) from torch's
    global RNG, so results are reproducible after ``torch.manual_seed``.

    Args:
        shape: Sequence of ints giving the dimensions of the tensor.

    Returns:
        A ``torch.Tensor`` with dtype ``torch.float`` and the given shape.
    """
    # The function body must be indented; the flattened original was a syntax error.
    return torch.rand(shape, dtype=torch.float)
@rsuderman
rsuderman / gist:ca2dbf8d998e34c4880a51fb94fceb85
Last active June 10, 2024 21:10
Matmul per channel quant
import matplotlib.pyplot as plt
import torch
# Shapes of the two 2-D operands; both share a trailing dimension of 128.
A_SHAPE = (8, 128)
B_SHAPE = (16, 128)
# Seed torch's global RNG so the random tensors below are reproducible.
torch.manual_seed(12345)
# Random float32 column vectors of shape (rows, 1), one entry per row of A / B,
# drawn by torch.rand (uniform on [0, 1)).
# NOTE(review): presumably per-channel quantization scales — their use is not visible here; confirm.
A_QUANT = torch.rand((A_SHAPE[0],1), dtype=torch.float)
B_QUANT = torch.rand((B_SHAPE[0],1), dtype=torch.float)
// Affine-map aliases over a 3-D iteration space (d0, d1, d2).
// Identity: each iteration index maps to the same result position.
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// Same as the identity map but with the last two results swapped (d2 before d1).
#map1 = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
// No results: every iteration point maps to the single element of a rank-0 operand.
#map2 = affine_map<(d0, d1, d2) -> ()>
// Drops d2: the result depends only on (d0, d1).
#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
// Like #map3 but keeps a third result fixed at the constant 0.
#map4 = affine_map<(d0, d1, d2) -> (d0, d1, 0)>
module @module {
ml_program.global private mutable @global_seed(dense<0> : tensor<i64>) : tensor<i64>
func.func @main(%arg0: tensor<4x64x32xf8E4M3FNUZ>, %arg1: tensor<4x64x32xf8E4M3FNUZ>, %arg2: tensor<4x64x32xf8E4M3FNUZ>, %arg3: tensor<f32>, %arg4: tensor<f32>, %arg5: tensor<f32>, %arg6: tensor<f32>) -> tensor<4x64x32xf8E4M3FNUZ> {
%cst = arith.constant 0.000000e+00 : f32
%c0_i64 = arith.constant 0 : i64
#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> ()>
func.func private @broadcast_scale_widen(
%value : tensor<4x64x96xf8E4M3FNUZ>, %scale : tensor<f32>) -> tensor<4x64x96xf32> {
%empty_f32 = tensor.empty() : tensor<4x64x96xf32>
%scaled = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel", "parallel", "parallel"]}
#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> ()>
func.func private @broadcast_scale_widen(
%value : tensor<4x64x96xf16>, %scale : tensor<f32>) -> tensor<4x64x96xf32> {
%empty_f32 = tensor.empty() : tensor<4x64x96xf32>
%scaled = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel", "parallel", "parallel"]}