Skip to content

Instantly share code, notes, and snippets.

@Max191
Created September 12, 2023 15:29
Show Gist options
  • Save Max191/c76474d1d484cb05d191543877110d28 to your computer and use it in GitHub Desktop.
IR for reassociation of quantized matmul
// Reference IR (pre-reassociation): dequantize the entire i8 weight tensor to
// f32, then perform the matvec fully in f32.
// From the body below: %arg0 holds unsigned 8-bit weights (extui/uitofp),
// %arg1 is a per-(row, group) scale, %arg2 a per-(row, group) zero point
// (both broadcast over the 128-wide inner dim), and %arg3 is the f32 input
// vector shaped 1x1x(32 groups)x(128 elems).
builtin.module {
func.func @quantized_matmul(%arg0: tensor<11008x32x128xi8>, %arg1: tensor<11008x32x1xf32>, %arg2: tensor<11008x32x1xf32>, %arg3: tensor<1x1x32x128xf32>) -> tensor<1x1x11008xf32> {
%cst = arith.constant 0.000000e+00 : f32
%4 = tensor.empty() : tensor<1x1x11008xf32>
%5 = tensor.empty() : tensor<11008x32x128xf32>
// Zero-init the f32 matvec accumulator.
%6 = linalg.fill ins(%cst : f32) outs(%4 : tensor<1x1x11008xf32>) -> tensor<1x1x11008xf32>
// Elementwise dequantize: f32(w) = (uitofp(extui(w)) - zero_point) * scale.
// The (d0, d1, 0) maps broadcast the per-group scale/zero-point across d2
// (the 128 elements within a group).
%7 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0, %arg1, %arg2 : tensor<11008x32x128xi8>, tensor<11008x32x1xf32>, tensor<11008x32x1xf32>) outs(%5 : tensor<11008x32x128xf32>) {
^bb0(%in: i8, %in_0: f32, %in_1: f32, %out: f32):
%9 = arith.extui %in : i8 to i32
%10 = arith.uitofp %9 : i32 to f32
%11 = arith.subf %10, %in_1 : f32
%12 = arith.mulf %11, %in_0 : f32
linalg.yield %12 : f32
} -> tensor<11008x32x128xf32>
// f32 matvec: d0/d1 are unit batch dims, d2 indexes the 11008 output rows,
// and the reduction runs over d3 (32 groups) and d4 (128 elems per group).
%8 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%arg3, %7 : tensor<1x1x32x128xf32>, tensor<11008x32x128xf32>) outs(%6 : tensor<1x1x11008xf32>) {
^bb0(%in: f32, %in_0: f32, %out: f32):
%9 = arith.mulf %in, %in_0 : f32
%10 = arith.addf %9, %out : f32
linalg.yield %10 : f32
} -> tensor<1x1x11008xf32>
return %8 : tensor<1x1x11008xf32>
}
}
builtin.module {
// Reassociated quantized matvec: quantize the f32 vector to signed i8 per
// 32-element group, run the inner product in integer arithmetic, and fold the
// scales / zero points back in at the end. Per (row, group):
//   sum_k mat_s*(w_k - zp)*v_k
//     ~= mat_s * vec_s * sum_k(w_k * qv_k)  -  mat_s * zp * sum_k(v_k)
// where qv = v / vec_s is the quantized vector (vec_s = max|v| / 127).
func.func @quantized_matmul_reassociated(%matrix: tensor<11008x32x128xi8>, %matrix_scales: tensor<11008x32x1xf32>, %matrix_zps: tensor<11008x32x1xf32>, %vector: tensor<1x1x32x128xf32>) -> (tensor<1x1x11008xf32>) {
%cst_127 = arith.constant 127.000000e+00 : f32
%cst = arith.constant 0.000000e+00 : f32
// Fix(review): integer accumulator widened i16 -> i32. The d4 reduction below
// sums 128 products, each up to 255*127 (~32e3), which overflows i16.
// (Also removed the unused %cst_0 = arith.constant 0 : i8.)
%cst_1 = arith.constant 0 : i32
%10 = tensor.empty() : tensor<1x1x32xf32>
%0 = linalg.fill ins(%cst : f32) outs(%10 : tensor<1x1x32xf32>) -> tensor<1x1x32xf32>
%1 = tensor.empty() : tensor<1x1x32xf32>
%12 = tensor.empty() : tensor<1x1x32xf32>
%2 = linalg.fill ins(%cst : f32) outs(%12 : tensor<1x1x32xf32>) -> tensor<1x1x32xf32>
%3 = tensor.empty() : tensor<1x1x32x128xi8>
%14 = tensor.empty() : tensor<1x1x11008x32xi32>
%4 = linalg.fill ins(%cst_1 : i32) outs(%14 : tensor<1x1x11008x32xi32>) -> tensor<1x1x11008x32xi32>
%15 = tensor.empty() : tensor<1x1x11008xf32>
%5 = linalg.fill ins(%cst : f32) outs(%15 : tensor<1x1x11008xf32>) -> tensor<1x1x11008xf32>
// Per-group amax: reduce |v| over the 128-wide inner dim (d3).
%vec_max = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
                                            affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>],
                           iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
                           ins(%vector : tensor<1x1x32x128xf32>)
                           outs(%0 : tensor<1x1x32xf32>) {
^bb0(%vec: f32, %out: f32):
  %abs = math.absf %vec : f32
  %max = arith.maxf %abs, %out : f32
  linalg.yield %max : f32
} -> tensor<1x1x32xf32>
// Per-group symmetric quantization scale: vec_s = amax / 127.
%vec_scales = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
                                               affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
                              iterator_types = ["parallel", "parallel", "parallel"]}
                              ins(%vec_max : tensor<1x1x32xf32>)
                              outs(%1 : tensor<1x1x32xf32>) {
^bb0(%vec_m: f32, %out: f32):
  %scale = arith.divf %vec_m, %cst_127 : f32
  linalg.yield %scale : f32
} -> tensor<1x1x32xf32>
// Quantize the vector: qv = v / vec_s, cast to *signed* i8.
// Fix(review): fptoui -> fptosi. v (and hence v/vec_s) can be negative, and
// arith.fptoui of a negative f32 is poison; the matvec below sign-extends the
// vector (extsi), which matches signed quantization into [-127, 127].
%quantized_vec = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
                                                  affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>,
                                                  affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
                                 iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
                                 ins(%vector, %vec_scales : tensor<1x1x32x128xf32>, tensor<1x1x32xf32>)
                                 outs(%3 : tensor<1x1x32x128xi8>) {
^bb0(%vec: f32, %vec_s: f32, %out: i8):
  %scaled = arith.divf %vec, %vec_s : f32
  %quant = arith.fptosi %scaled : f32 to i8
  linalg.yield %quant : i8
} -> tensor<1x1x32x128xi8>
// Per-group sums of the *original* f32 vector: feeds the zero-point
// correction term mat_zp * mat_s * sum_k(v_k).
%vec_scaled_sums = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
                                                    affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>],
                                   iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
                                   ins(%vector : tensor<1x1x32x128xf32>)
                                   outs(%2 : tensor<1x1x32xf32>) {
^bb0(%vec: f32, %out: f32):
  %sum = arith.addf %vec, %out : f32
  linalg.yield %sum : f32
} -> tensor<1x1x32xf32>
// Integer partial matvec: per (row d2, group d3), sum over d4 of
// extui(w) * extsi(qv), accumulated in i32 (see overflow note above).
%dequant_matvec_0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>,
                                                     affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>,
                                                     affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>],
                                    iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]}
                                    ins(%quantized_vec, %matrix : tensor<1x1x32x128xi8>, tensor<11008x32x128xi8>)
                                    outs(%4 : tensor<1x1x11008x32xi32>) {
^bb0(%vec: i8, %weight: i8, %out: i32):
  %vec_i32 = arith.extsi %vec : i8 to i32
  %weight_i32 = arith.extui %weight : i8 to i32
  %product = arith.muli %weight_i32, %vec_i32 : i32
  %result = arith.addi %product, %out : i32
  linalg.yield %result : i32
} -> tensor<1x1x11008x32xi32>
// Fold the scales and zero points back in, reducing over the 32 groups (d3).
// Per group: mat_s*vec_s*matvec - mat_s*zp*sum(v), accumulated into %out.
%dequant_matvec_1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
                                                     affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
                                                     affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
                                                     affine_map<(d0, d1, d2, d3) -> (d2, d3, 0)>,
                                                     affine_map<(d0, d1, d2, d3) -> (d2, d3, 0)>,
                                                     affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>],
                                    iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
                                    ins(%dequant_matvec_0, %vec_scales, %vec_scaled_sums, %matrix_scales, %matrix_zps : tensor<1x1x11008x32xi32>, tensor<1x1x32xf32>, tensor<1x1x32xf32>, tensor<11008x32x1xf32>, tensor<11008x32x1xf32>)
                                    outs(%5 : tensor<1x1x11008xf32>) {
^bb0(%matvec: i32, %vec_s: f32, %vec_ss: f32, %mat_s: f32, %mat_zp: f32, %out: f32):
  %dq_matvec = arith.sitofp %matvec : i32 to f32
  %scaled_result_0 = arith.mulf %dq_matvec, %vec_s : f32
  %scaled_result_1 = arith.mulf %scaled_result_0, %mat_s : f32
  %zp_scaled_0 = arith.mulf %mat_zp, %mat_s : f32
  %zp_scaled_1 = arith.mulf %zp_scaled_0, %vec_ss : f32
  %group_result = arith.subf %scaled_result_1, %zp_scaled_1 : f32
  %result = arith.addf %group_result, %out : f32
  linalg.yield %result : f32
} -> tensor<1x1x11008xf32>
return %dequant_matvec_1 : tensor<1x1x11008xf32>
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment