Created
September 12, 2023 15:29
-
-
Save Max191/c76474d1d484cb05d191543877110d28 to your computer and use it in GitHub Desktop.
IR for reassociation of quantized matmul
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Baseline IR for a group-quantized matmul: the i8 weight tensor is first
// dequantized to f32 (subtract zero point, multiply by per-group scale),
// then contracted against the f32 activation in plain float arithmetic.
// (Scrape artifacts " | |" removed from every line so the IR parses.)
builtin.module {
  // %arg0: quantized weights, 11008 rows x 32 groups x 128 elems/group (i8).
  // %arg1: per-group scales (f32); %arg2: per-group zero points (f32).
  // %arg3: activation of shape 1x1x32x128 (f32).
  func.func @quantized_matmul(%arg0: tensor<11008x32x128xi8>, %arg1: tensor<11008x32x1xf32>, %arg2: tensor<11008x32x1xf32>, %arg3: tensor<1x1x32x128xf32>) -> tensor<1x1x11008xf32> {
    %cst = arith.constant 0.000000e+00 : f32
    %4 = tensor.empty() : tensor<1x1x11008xf32>
    %5 = tensor.empty() : tensor<11008x32x128xf32>
    %6 = linalg.fill ins(%cst : f32) outs(%4 : tensor<1x1x11008xf32>) -> tensor<1x1x11008xf32>
    // Dequantize each weight: f32 = (zext(w) - zp) * scale; scales and zero
    // points broadcast along the 128-wide element dimension (map to dim 0).
    %7 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0, %arg1, %arg2 : tensor<11008x32x128xi8>, tensor<11008x32x1xf32>, tensor<11008x32x1xf32>) outs(%5 : tensor<11008x32x128xf32>) {
    ^bb0(%in: i8, %in_0: f32, %in_1: f32, %out: f32):
      %9 = arith.extui %in : i8 to i32
      %10 = arith.uitofp %9 : i32 to f32
      %11 = arith.subf %10, %in_1 : f32
      %12 = arith.mulf %11, %in_0 : f32
      linalg.yield %12 : f32
    } -> tensor<11008x32x128xf32>
    // Float matvec: reduce over group (d3) and element-within-group (d4).
    %8 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%arg3, %7 : tensor<1x1x32x128xf32>, tensor<11008x32x128xf32>) outs(%6 : tensor<1x1x11008xf32>) {
    ^bb0(%in: f32, %in_0: f32, %out: f32):
      %9 = arith.mulf %in, %in_0 : f32
      %10 = arith.addf %9, %out : f32
      linalg.yield %10 : f32
    } -> tensor<1x1x11008xf32>
    return %8 : tensor<1x1x11008xf32>
  }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Reassociated form of the quantized matmul: the f32 activation is
// dynamically quantized to i8 so the inner 128-wide reduction runs in
// integer arithmetic, with the float scales and zero-point correction
// factored out and applied once per (row, group).
//
// Math: with w = (w_q - zp) * s_m and v ~= v_q * s_v,
//   sum_k w * v ~= s_v * s_m * (sum_k w_q * v_q) - zp * s_m * (sum_k v)
// so the kernel needs only the integer dot product per group plus the
// per-group float sum of the activation (%vec_scaled_sums).
// (Scrape artifacts " | |" removed from every line so the IR parses.)
builtin.module {
  func.func @quantized_matmul_reassociated(%matrix: tensor<11008x32x128xi8>, %matrix_scales: tensor<11008x32x1xf32>, %matrix_zps: tensor<11008x32x1xf32>, %vector: tensor<1x1x32x128xf32>) -> (tensor<1x1x11008xf32>) {
    %cst_127 = arith.constant 127.000000e+00 : f32
    %cst = arith.constant 0.000000e+00 : f32
    // (removed an unused `arith.constant 0 : i8` that was dead code)
    %cst_1 = arith.constant 0 : i16
    %10 = tensor.empty() : tensor<1x1x32xf32>
    %0 = linalg.fill ins(%cst : f32) outs(%10 : tensor<1x1x32xf32>) -> tensor<1x1x32xf32>
    %1 = tensor.empty() : tensor<1x1x32xf32>
    %12 = tensor.empty() : tensor<1x1x32xf32>
    %2 = linalg.fill ins(%cst : f32) outs(%12 : tensor<1x1x32xf32>) -> tensor<1x1x32xf32>
    %3 = tensor.empty() : tensor<1x1x32x128xi8>
    %14 = tensor.empty() : tensor<1x1x11008x32xi16>
    // NOTE(review): an i16 accumulator can overflow here — 128 products of
    // magnitude up to 255*127 far exceed 32767. Confirm whether a wider
    // (i32) accumulator is required; left as i16 to preserve the original
    // kernel shape.
    %4 = linalg.fill ins(%cst_1 : i16) outs(%14 : tensor<1x1x11008x32xi16>) -> tensor<1x1x11008x32xi16>
    %15 = tensor.empty() : tensor<1x1x11008xf32>
    %5 = linalg.fill ins(%cst : f32) outs(%15 : tensor<1x1x11008xf32>) -> tensor<1x1x11008xf32>
    // Per-group max(|v|): reduce over the 128-element group dim (d3).
    %vec_max = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
                                                affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>],
                               iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
                               ins(%vector : tensor<1x1x32x128xf32>)
                               outs(%0 : tensor<1x1x32xf32>) {
    ^bb0(%vec: f32, %out: f32):
      %abs = math.absf %vec : f32
      %max = arith.maxf %abs, %out : f32
      linalg.yield %max : f32
    } -> tensor<1x1x32xf32>
    // Symmetric per-group activation scale: s_v = max|v| / 127.
    %vec_scales = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
                                                   affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
                                  iterator_types = ["parallel", "parallel", "parallel"]}
                                  ins(%vec_max : tensor<1x1x32xf32>)
                                  outs(%1 : tensor<1x1x32xf32>) {
    ^bb0(%vec_m: f32, %out: f32):
      %scale = arith.divf %vec_m, %cst_127 : f32
      linalg.yield %scale : f32
    } -> tensor<1x1x32xf32>
    // Quantize the activation: v_q = v / s_v, in [-127, 127].
    // FIX: was arith.fptoui — the scaled value can be negative (the scale
    // comes from max|v|, i.e. symmetric signed quantization) and fptoui on
    // a negative input is poison; fptosi matches the arith.extsi consumer
    // in %dequant_matvec_0 below.
    %quantized_vec = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
                                                      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>,
                                                      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
                                     iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
                                     ins(%vector, %vec_scales : tensor<1x1x32x128xf32>, tensor<1x1x32xf32>)
                                     outs(%3 : tensor<1x1x32x128xi8>) {
    ^bb0(%vec: f32, %vec_s: f32, %out: i8):
      %scaled = arith.divf %vec, %vec_s : f32
      %quant = arith.fptosi %scaled : f32 to i8
      linalg.yield %quant : i8
    } -> tensor<1x1x32x128xi8>
    // Per-group float sum of the (unquantized) activation; feeds the
    // zero-point correction term in the final generic.
    %vec_scaled_sums = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
                                                        affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>],
                                       iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
                                       ins(%vector : tensor<1x1x32x128xf32>)
                                       outs(%2 : tensor<1x1x32xf32>) {
    ^bb0(%vec: f32, %out: f32):
      %sum = arith.addf %vec, %out : f32
      linalg.yield %sum : f32
    } -> tensor<1x1x32xf32>
    // Integer dot product per (row d2, group d3), reducing over d4.
    // Weights are zero-point-offset unsigned (extui); activations signed (extsi).
    %dequant_matvec_0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>,
                                                         affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>,
                                                         affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>],
                                        iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]}
                                        ins(%quantized_vec, %matrix : tensor<1x1x32x128xi8>, tensor<11008x32x128xi8>)
                                        outs(%4 : tensor<1x1x11008x32xi16>) {
    ^bb0(%vec: i8, %weight: i8, %out: i16):
      %vec_i16 = arith.extsi %vec : i8 to i16
      %weight_i16 = arith.extui %weight : i8 to i16
      %product = arith.muli %weight_i16, %vec_i16 : i16
      %result = arith.addi %product, %out : i16
      linalg.yield %result : i16
    } -> tensor<1x1x11008x32xi16>
    // Apply scales and zero-point correction, then reduce over groups (d3):
    //   out += matvec * s_v * s_m - zp * s_m * sum(v)
    %dequant_matvec_1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
                                                         affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
                                                         affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
                                                         affine_map<(d0, d1, d2, d3) -> (d2, d3, 0)>,
                                                         affine_map<(d0, d1, d2, d3) -> (d2, d3, 0)>,
                                                         affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>],
                                        iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
                                        ins(%dequant_matvec_0, %vec_scales, %vec_scaled_sums, %matrix_scales, %matrix_zps : tensor<1x1x11008x32xi16>, tensor<1x1x32xf32>, tensor<1x1x32xf32>, tensor<11008x32x1xf32>, tensor<11008x32x1xf32>)
                                        outs(%5 : tensor<1x1x11008xf32>) {
    ^bb0(%matvec: i16, %vec_s: f32, %vec_ss: f32, %mat_s: f32, %mat_zp: f32, %out: f32):
      %dq_matvec_0 = arith.extsi %matvec : i16 to i32
      %dq_matvec_1 = arith.sitofp %dq_matvec_0 : i32 to f32
      %scaled_result_0 = arith.mulf %dq_matvec_1, %vec_s : f32
      %scaled_result_1 = arith.mulf %scaled_result_0, %mat_s : f32
      %zp_scaled_0 = arith.mulf %mat_zp, %mat_s : f32
      %zp_scaled_1 = arith.mulf %zp_scaled_0, %vec_ss : f32
      %group_result = arith.subf %scaled_result_1, %zp_scaled_1 : f32
      %result = arith.addf %group_result, %out : f32
      linalg.yield %result : f32
    } -> tensor<1x1x11008xf32>
    return %dequant_matvec_1 : tensor<1x1x11008xf32>
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment