#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> ()>

// Widens an f16 tensor to f32 and multiplies it by a scalar (rank-0) scale.
func.func private @broadcast_scale_widen(
    %value : tensor<4x64x96xf16>, %scale : tensor<f32>) -> tensor<4x64x96xf32> {
  %empty_f32 = tensor.empty() : tensor<4x64x96xf32>
  %scaled = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel", "parallel", "parallel"]}
      ins(%value, %scale : tensor<4x64x96xf16>, tensor<f32>) outs(%empty_f32 : tensor<4x64x96xf32>) {
  ^bb0(%in0: f16, %in1: f32, %out: f32):
    %ext = arith.extf %in0 : f16 to f32
    %mul = arith.mulf %ext, %in1 : f32
    linalg.yield %mul : f32
  } -> tensor<4x64x96xf32>
  return %scaled : tensor<4x64x96xf32>
}
// Multiplies an f32 tensor by a scalar (rank-0) scale and narrows it to f16.
func.func private @broadcast_scale_narrow(
    %value : tensor<4x64x96xf32>, %scale : tensor<f32>) -> tensor<4x64x96xf16> {
  %empty_f16 = tensor.empty() : tensor<4x64x96xf16>
  %scaled = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel", "parallel", "parallel"]}
      ins(%value, %scale : tensor<4x64x96xf32>, tensor<f32>) outs(%empty_f16 : tensor<4x64x96xf16>) {
  ^bb0(%in0: f32, %in1: f32, %out: f16):
    %mul = arith.mulf %in0, %in1 : f32
    %trunc = arith.truncf %mul : f32 to f16
    linalg.yield %trunc : f16
  } -> tensor<4x64x96xf16>
  return %scaled : tensor<4x64x96xf16>
}
// Scaled attention: widens and scales Q, K, and V, runs iree_linalg_ext.attention
// in f32, then scales the result and narrows it back to f16.
func.func private @scaled_mmt(
    %query : tensor<4x64x96xf16>, %query_scale : tensor<f32>,
    %key : tensor<4x64x96xf16>, %key_scale : tensor<f32>,
    %value : tensor<4x64x96xf16>, %value_scale : tensor<f32>,
    %scale : tensor<f32>, %result_scale : tensor<f32>) -> tensor<4x64x96xf16> {
  %query_fp32 = func.call @broadcast_scale_widen(%query, %query_scale) : (tensor<4x64x96xf16>, tensor<f32>) -> tensor<4x64x96xf32>
  %key_fp32 = func.call @broadcast_scale_widen(%key, %key_scale) : (tensor<4x64x96xf16>, tensor<f32>) -> tensor<4x64x96xf32>
  %value_fp32 = func.call @broadcast_scale_widen(%value, %value_scale) : (tensor<4x64x96xf16>, tensor<f32>) -> tensor<4x64x96xf32>
  // The attention op takes its softmax scale as a scalar, so extract it from the rank-0 tensor.
  %extract = tensor.extract %scale[] : tensor<f32>
  %empty = tensor.empty() : tensor<4x64x96xf32>
  %attention = iree_linalg_ext.attention ins(%query_fp32, %key_fp32, %value_fp32, %extract : tensor<4x64x96xf32>,
      tensor<4x64x96xf32>, tensor<4x64x96xf32>, f32) outs(%empty : tensor<4x64x96xf32>) -> tensor<4x64x96xf32>
  %attention_fp16 = func.call @broadcast_scale_narrow(%attention, %result_scale) : (tensor<4x64x96xf32>, tensor<f32>) -> tensor<4x64x96xf16>
  return %attention_fp16 : tensor<4x64x96xf16>
}
// Entry point: forwards all operands to @scaled_mmt.
func.func @main(
    %query : tensor<4x64x96xf16>, %query_scale : tensor<f32>,
    %key : tensor<4x64x96xf16>, %key_scale : tensor<f32>,
    %value : tensor<4x64x96xf16>, %value_scale : tensor<f32>,
    %scale : tensor<f32>, %result_scale : tensor<f32>) -> tensor<4x64x96xf16> {
  %call = func.call @scaled_mmt(%query, %query_scale, %key, %key_scale, %value, %value_scale, %scale, %result_scale)
      : (tensor<4x64x96xf16>, tensor<f32>, tensor<4x64x96xf16>, tensor<f32>, tensor<4x64x96xf16>, tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<4x64x96xf16>
  return %call : tensor<4x64x96xf16>
}
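
For reference, the IR above computes per-tensor-scaled f16 attention: each f16 operand is widened to f32 and multiplied by its scale, iree_linalg_ext.attention evaluates softmax(Q·Kᵀ·scale)·V in f32, and the result is scaled and truncated back to f16. Below is a rough NumPy sketch of the same computation, assuming that standard attention semantics for the op; the helper names are illustrative only, and it ignores any rounding that the fused f16 kernel may introduce.

import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def scaled_mmt_reference(query, query_scale, key, key_scale,
                         value, value_scale, scale, result_scale):
    # Mirror @broadcast_scale_widen: widen f16 -> f32 and apply the per-tensor scale.
    q = query.astype(np.float32) * query_scale
    k = key.astype(np.float32) * key_scale
    v = value.astype(np.float32) * value_scale
    # iree_linalg_ext.attention (assumed): softmax(Q @ K^T * scale) @ V, batched over dim 0.
    logits = np.einsum("bmk,bnk->bmn", q, k) * scale
    result = softmax(logits, axis=-1) @ v
    # Mirror @broadcast_scale_narrow: apply the result scale and truncate to f16.
    return (result * result_scale).astype(np.float16)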