Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save rsuderman/cf0703aa2d686da4bbd9cec9e288066f to your computer and use it in GitHub Desktop.
Save rsuderman/cf0703aa2d686da4bbd9cec9e288066f to your computer and use it in GitHub Desktop.
// Indexing maps shared by the linalg.generic ops in @main. Each map is written
// over the three loop dimensions (b, i, j) of the generic that uses it.
// Identity: operand indexed exactly like the iteration space.
#map = affine_map<(b, i, j) -> (b, i, j)>
// Swap the two innermost dims (output map of the transpose generic).
#map1 = affine_map<(b, i, j) -> (b, j, i)>
// No results: the operand is a rank-0 tensor, read at every iteration (scalar broadcast).
#map2 = affine_map<(b, i, j) -> ()>
// Drop the innermost dim (outputs of the row-wise max/argmax reduction).
#map3 = affine_map<(b, i, j) -> (b, i)>
// Pin the innermost dim to 0: addresses a size-1 trailing dim, used both to
// broadcast a per-row value and as a keep-dim reduction output.
#map4 = affine_map<(b, i, j) -> (b, i, 0)>
// Single-function module. @main chains two batched matmuls around a row-wise
// softmax, with f8E4M3FNUZ <-> f32 conversions and scalar scale factors between
// the steps. Its structure resembles fp8 scaled-dot-product attention
// (arg0/arg1/arg2 presumably Q/K/V) -- NOTE(review): confirm with the producer.
module @module {
// Mutable i64 global initialized to 0. Nothing in this module reads or writes
// it; presumably left over from the frontend that emitted this IR -- confirm.
ml_program.global private mutable @global_seed(dense<0> : tensor<i64>) : tensor<i64>
// For each batch index b in [0, 4), with all arithmetic in f32:
//   S   = fp8_round((arg0[b] @ transpose(arg1[b])) * arg3) * arg4      (64x64)
//   P   = fp8_round(softmax_rows(S) * arg5)                             (64x64)
//   out = truncf_to_fp8((P @ arg2[b]) * arg6)                           (64x32)
// where fp8_round(x) = extf(truncf(x)) round-trips through f8E4M3FNUZ.
// arg3..arg6 are rank-0 tensors broadcast over every element.
// No 1/sqrt(d) factor appears explicitly; any such scaling would have to be
// folded into arg3/arg4 by the caller (assumption -- verify).
func.func @main(%arg0: tensor<4x64x32xf8E4M3FNUZ>, %arg1: tensor<4x64x32xf8E4M3FNUZ>, %arg2: tensor<4x64x32xf8E4M3FNUZ>, %arg3: tensor<f32>, %arg4: tensor<f32>, %arg5: tensor<f32>, %arg6: tensor<f32>) -> tensor<4x64x32xf8E4M3FNUZ> {
// 0.0: initial value for both matmul accumulators and the exp-sum reduction.
%cst = arith.constant 0.000000e+00 : f32
// 0: initial value for the argmax-index accumulator of %17.
%c0_i64 = arith.constant 0 : i64
// 0xFF800000 is the IEEE-754 f32 bit pattern for -inf, the identity for the max reduction.
%cst_0 = arith.constant 0xFF800000 : f32
// Shared 4x64x32 f32 destination for the widenings (%1, %2, %27), the second
// matmul accumulator (%28) and the final scaling (%30).
%0 = tensor.empty() : tensor<4x64x32xf32>
// Widen arg0 from f8E4M3FNUZ to f32.
%1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0 : tensor<4x64x32xf8E4M3FNUZ>) outs(%0 : tensor<4x64x32xf32>) {
^bb0(%in: f8E4M3FNUZ, %out: f32):
%33 = arith.extf %in : f8E4M3FNUZ to f32
linalg.yield %33 : f32
} -> tensor<4x64x32xf32>
// Widen arg1 from f8E4M3FNUZ to f32.
%2 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg1 : tensor<4x64x32xf8E4M3FNUZ>) outs(%0 : tensor<4x64x32xf32>) {
^bb0(%in: f8E4M3FNUZ, %out: f32):
%33 = arith.extf %in : f8E4M3FNUZ to f32
linalg.yield %33 : f32
} -> tensor<4x64x32xf32>
// Transpose the two inner dims of %2: %4[b][j][i] = %2[b][i][j] (4x64x32 -> 4x32x64).
%3 = tensor.empty() : tensor<4x32x64xf32>
%4 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%2 : tensor<4x64x32xf32>) outs(%3 : tensor<4x32x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<4x32x64xf32>
// Shared 4x64x64 f32 destination for the score/softmax stage.
%5 = tensor.empty() : tensor<4x64x64xf32>
// Zero-filled accumulator, then %7 = %1 @ %4 per batch -> 4x64x64.
%6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<4x64x64xf32>) -> tensor<4x64x64xf32>
%7 = linalg.batch_matmul ins(%1, %4 : tensor<4x64x32xf32>, tensor<4x32x64xf32>) outs(%6 : tensor<4x64x64xf32>) -> tensor<4x64x64xf32>
// Multiply every element by the scalar arg3.
%8 = linalg.generic {indexing_maps = [#map, #map2, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%7, %arg3 : tensor<4x64x64xf32>, tensor<f32>) outs(%5 : tensor<4x64x64xf32>) {
^bb0(%in: f32, %in_1: f32, %out: f32):
%33 = arith.mulf %in, %in_1 : f32
linalg.yield %33 : f32
} -> tensor<4x64x64xf32>
// Round-trip through f8E4M3FNUZ: %10 truncates to fp8 and %11 widens back.
// Values are rounded to fp8 precision but computation continues in f32.
// %9 is also reused as the fp8 destination of %25.
%9 = tensor.empty() : tensor<4x64x64xf8E4M3FNUZ>
%10 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%8 : tensor<4x64x64xf32>) outs(%9 : tensor<4x64x64xf8E4M3FNUZ>) {
^bb0(%in: f32, %out: f8E4M3FNUZ):
%33 = arith.truncf %in : f32 to f8E4M3FNUZ
linalg.yield %33 : f8E4M3FNUZ
} -> tensor<4x64x64xf8E4M3FNUZ>
%11 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%10 : tensor<4x64x64xf8E4M3FNUZ>) outs(%5 : tensor<4x64x64xf32>) {
^bb0(%in: f8E4M3FNUZ, %out: f32):
%33 = arith.extf %in : f8E4M3FNUZ to f32
linalg.yield %33 : f32
} -> tensor<4x64x64xf32>
// Multiply every element by the scalar arg4; %12 is the input to the softmax.
%12 = linalg.generic {indexing_maps = [#map, #map2, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%11, %arg4 : tensor<4x64x64xf32>, tensor<f32>) outs(%5 : tensor<4x64x64xf32>) {
^bb0(%in: f32, %in_1: f32, %out: f32):
%33 = arith.mulf %in, %in_1 : f32
linalg.yield %33 : f32
} -> tensor<4x64x64xf32>
// Row-wise reduction over the innermost dim (the "reduction" iterator), producing:
//   %17#0: the row max, starting from -inf;
//   %17#1: the index of the max, starting from 0 and updated only when an element
//          is strictly greater (ogt) than the running max, so with in-order
//          iteration the first maximum wins.
// Only %17#0 is used below; %17#1 is dead.
%13 = tensor.empty() : tensor<4x64xi64>
%14 = linalg.fill ins(%c0_i64 : i64) outs(%13 : tensor<4x64xi64>) -> tensor<4x64xi64>
%15 = tensor.empty() : tensor<4x64xf32>
%16 = linalg.fill ins(%cst_0 : f32) outs(%15 : tensor<4x64xf32>) -> tensor<4x64xf32>
%17:2 = linalg.generic {indexing_maps = [#map, #map3, #map3], iterator_types = ["parallel", "parallel", "reduction"]} ins(%12 : tensor<4x64x64xf32>) outs(%16, %14 : tensor<4x64xf32>, tensor<4x64xi64>) {
^bb0(%in: f32, %out: f32, %out_1: i64):
%33 = linalg.index 2 : index
%34 = arith.index_cast %33 : index to i64
%35 = arith.maximumf %in, %out : f32
%36 = arith.cmpf ogt, %in, %out : f32
%37 = arith.select %36, %34, %out_1 : i64
linalg.yield %35, %37 : f32, i64
} -> (tensor<4x64xf32>, tensor<4x64xi64>)
// Reshape the row max to 4x64x1 so #map4 can broadcast it along the last dim.
%expanded = tensor.expand_shape %17#0 [[0], [1, 2]] output_shape [4, 64, 1] : tensor<4x64xf32> into tensor<4x64x1xf32>
// Subtract the row max from every element of its row (numerically stable softmax).
%18 = linalg.generic {indexing_maps = [#map, #map4, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%12, %expanded : tensor<4x64x64xf32>, tensor<4x64x1xf32>) outs(%5 : tensor<4x64x64xf32>) {
^bb0(%in: f32, %in_1: f32, %out: f32):
%33 = arith.subf %in, %in_1 : f32
linalg.yield %33 : f32
} -> tensor<4x64x64xf32>
// Element-wise exp.
%19 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%18 : tensor<4x64x64xf32>) outs(%5 : tensor<4x64x64xf32>) {
^bb0(%in: f32, %out: f32):
%33 = math.exp %in : f32
linalg.yield %33 : f32
} -> tensor<4x64x64xf32>
// Zero-filled 4x64x1 accumulator; %22 holds the sum of %19 along the innermost dim.
%20 = tensor.empty() : tensor<4x64x1xf32>
%21 = linalg.fill ins(%cst : f32) outs(%20 : tensor<4x64x1xf32>) -> tensor<4x64x1xf32>
%22 = linalg.generic {indexing_maps = [#map, #map4], iterator_types = ["parallel", "parallel", "reduction"]} ins(%19 : tensor<4x64x64xf32>) outs(%21 : tensor<4x64x1xf32>) {
^bb0(%in: f32, %out: f32):
%33 = arith.addf %in, %out : f32
linalg.yield %33 : f32
} -> tensor<4x64x1xf32>
// Divide each element by its row sum; %23 is the row-wise softmax of %12.
%23 = linalg.generic {indexing_maps = [#map, #map4, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%19, %22 : tensor<4x64x64xf32>, tensor<4x64x1xf32>) outs(%5 : tensor<4x64x64xf32>) {
^bb0(%in: f32, %in_1: f32, %out: f32):
%33 = arith.divf %in, %in_1 : f32
linalg.yield %33 : f32
} -> tensor<4x64x64xf32>
// Multiply by the scalar arg5, then round-trip through f8E4M3FNUZ
// (%25 truncates into %9, %26 widens back to f32).
%24 = linalg.generic {indexing_maps = [#map, #map2, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%23, %arg5 : tensor<4x64x64xf32>, tensor<f32>) outs(%5 : tensor<4x64x64xf32>) {
^bb0(%in: f32, %in_1: f32, %out: f32):
%33 = arith.mulf %in, %in_1 : f32
linalg.yield %33 : f32
} -> tensor<4x64x64xf32>
%25 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%24 : tensor<4x64x64xf32>) outs(%9 : tensor<4x64x64xf8E4M3FNUZ>) {
^bb0(%in: f32, %out: f8E4M3FNUZ):
%33 = arith.truncf %in : f32 to f8E4M3FNUZ
linalg.yield %33 : f8E4M3FNUZ
} -> tensor<4x64x64xf8E4M3FNUZ>
%26 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%25 : tensor<4x64x64xf8E4M3FNUZ>) outs(%5 : tensor<4x64x64xf32>) {
^bb0(%in: f8E4M3FNUZ, %out: f32):
%33 = arith.extf %in : f8E4M3FNUZ to f32
linalg.yield %33 : f32
} -> tensor<4x64x64xf32>
// Widen arg2 from f8E4M3FNUZ to f32.
%27 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg2 : tensor<4x64x32xf8E4M3FNUZ>) outs(%0 : tensor<4x64x32xf32>) {
^bb0(%in: f8E4M3FNUZ, %out: f32):
%33 = arith.extf %in : f8E4M3FNUZ to f32
linalg.yield %33 : f32
} -> tensor<4x64x32xf32>
// Zero-filled accumulator, then %29 = %26 @ %27 per batch -> 4x64x32.
%28 = linalg.fill ins(%cst : f32) outs(%0 : tensor<4x64x32xf32>) -> tensor<4x64x32xf32>
%29 = linalg.batch_matmul ins(%26, %27 : tensor<4x64x64xf32>, tensor<4x64x32xf32>) outs(%28 : tensor<4x64x32xf32>) -> tensor<4x64x32xf32>
// Multiply every element by the scalar arg6.
%30 = linalg.generic {indexing_maps = [#map, #map2, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%29, %arg6 : tensor<4x64x32xf32>, tensor<f32>) outs(%0 : tensor<4x64x32xf32>) {
^bb0(%in: f32, %in_1: f32, %out: f32):
%33 = arith.mulf %in, %in_1 : f32
linalg.yield %33 : f32
} -> tensor<4x64x32xf32>
// Narrow the result to f8E4M3FNUZ and return it.
%31 = tensor.empty() : tensor<4x64x32xf8E4M3FNUZ>
%32 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%30 : tensor<4x64x32xf32>) outs(%31 : tensor<4x64x32xf8E4M3FNUZ>) {
^bb0(%in: f32, %out: f8E4M3FNUZ):
%33 = arith.truncf %in : f32 to f8E4M3FNUZ
linalg.yield %33 : f8E4M3FNUZ
} -> tensor<4x64x32xf8E4M3FNUZ>
return %32 : tensor<4x64x32xf8E4M3FNUZ>
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment