Skip to content

Instantly share code, notes, and snippets.

@tucan9389
Created May 1, 2023 05:04
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
Save tucan9389/f587f4abf77905dbc74df82140619272 to your computer and use it in GitHub Desktop.
// MLIR module lowered from a TensorFlow SavedModel (tf.versions producer 440):
// a single SGD training step for a small 28x28x1-input CNN
// (conv 5x5x1x32 -> maxpool -> conv 5x5x32x32 -> maxpool -> dense 1568->1024
// -> dense 1024->10), exported under the name "predict". The function mixes
// tf.* resource-variable ops (reads/writes of weights) with mhlo.* compute ops.
module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 440 : i32}, tf_saved_model.semantics} {
// --- SavedModel global tensors: optimizer state and model weights. ---
// All weights and the optimizer iteration counter carry `is_mutable`; the
// learning rate does not (it is read but never assigned in this function).
// NOTE(review): the dense<""> initializers look like elided/opaque weight
// payloads from the export - confirm against the original checkpoint.
"tf_saved_model.global_tensor"() {is_mutable, sym_name = "__sm_node4__optimizer.iter", tf_saved_model.exported_names = [], type = tensor<i64>, value = dense<0> : tensor<i64>} : () -> ()
"tf_saved_model.global_tensor"() {sym_name = "__sm_node6__optimizer.learning_rate", tf_saved_model.exported_names = [], type = tensor<f32>, value = dense<0.00999999977> : tensor<f32>} : () -> ()
"tf_saved_model.global_tensor"() {is_mutable, sym_name = "__sm_node17__model.conv1.kernel", tf_saved_model.exported_names = [], type = tensor<5x5x1x32xf32>, value = dense<""> : tensor<5x5x1x32xf32>} : () -> ()
"tf_saved_model.global_tensor"() {is_mutable, sym_name = "__sm_node26__model.conv2.kernel", tf_saved_model.exported_names = [], type = tensor<5x5x32x32xf32>, value = dense<""> : tensor<5x5x32x32xf32>} : () -> ()
"tf_saved_model.global_tensor"() {is_mutable, sym_name = "__sm_node39__model.dense1.kernel", tf_saved_model.exported_names = [], type = tensor<1568x1024xf32>, value = dense<""> : tensor<1568x1024xf32>} : () -> ()
"tf_saved_model.global_tensor"() {is_mutable, sym_name = "__sm_node40__model.dense1.bias", tf_saved_model.exported_names = [], type = tensor<1024xf32>, value = dense<0.000000e+00> : tensor<1024xf32>} : () -> ()
"tf_saved_model.global_tensor"() {is_mutable, sym_name = "__sm_node49__model.dense2.kernel", tf_saved_model.exported_names = [], type = tensor<1024x10xf32>, value = dense<""> : tensor<1024x10xf32>} : () -> ()
"tf_saved_model.global_tensor"() {is_mutable, sym_name = "__sm_node50__model.dense2.bias", tf_saved_model.exported_names = [], type = tensor<10xf32>, value = dense<0.000000e+00> : tensor<10xf32>} : () -> ()
// Training-step entry point (exported as "predict").
// %arg0: batch of 32 images, 28x28x1 (NHWC). %arg1: 32x1 float labels
// (converted to i64 class indices below). %arg2..%arg9: resource handles
// bound to the global tensors above. Returns the scalar mean loss.
func @__inference_predict_3320(%arg0: tensor<32x28x28x1xf32> {tf._user_specified_name = "inputs", tf_saved_model.index_path = [0]}, %arg1: tensor<32x1xf32> {tf._user_specified_name = "targets", tf_saved_model.index_path = [1]}, %arg2: tensor<!tf.resource<tensor<5x5x1x32xf32>>> {tf_saved_model.bound_input = @__sm_node17__model.conv1.kernel}, %arg3: tensor<!tf.resource<tensor<5x5x32x32xf32>>> {tf_saved_model.bound_input = @__sm_node26__model.conv2.kernel}, %arg4: tensor<!tf.resource<tensor<1568x1024xf32>>> {tf_saved_model.bound_input = @__sm_node39__model.dense1.kernel}, %arg5: tensor<!tf.resource<tensor<1024xf32>>> {tf_saved_model.bound_input = @__sm_node40__model.dense1.bias}, %arg6: tensor<!tf.resource<tensor<1024x10xf32>>> {tf_saved_model.bound_input = @__sm_node49__model.dense2.kernel}, %arg7: tensor<!tf.resource<tensor<10xf32>>> {tf_saved_model.bound_input = @__sm_node50__model.dense2.bias}, %arg8: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @__sm_node6__optimizer.learning_rate}, %arg9: tensor<!tf.resource<tensor<i64>>> {tf_saved_model.bound_input = @__sm_node4__optimizer.iter}) -> (tensor<f32> {tf_saved_model.index_path = []}) attributes {tf._input_shapes = [#tf.shape<32x28x28x1>, #tf.shape<32x1>, #tf.shape<*>, #tf.shape<*>, #tf.shape<*>, #tf.shape<*>, #tf.shape<*>, #tf.shape<*>, #tf.shape<*>, #tf.shape<*>, #tf.shape<*>], tf.signature.is_stateful, tf_saved_model.exported_names = ["predict"]} {
// Constants. %0 = 1/32 (mean over the batch of 32); %1 = 32.0;
// %7 = 0x7FC00000 (IEEE-754 NaN, used to poison out-of-range labels);
// %8 = 0xFF800000 (-inf, identity element for max-reductions);
// %13 = false (feeds the final, effectively no-op select).
%0 = mhlo.constant dense<3.125000e-02> : tensor<32x10xf32>
%1 = mhlo.constant dense<3.200000e+01> : tensor<f32>
%2 = mhlo.constant dense<1> : tensor<i64>
%3 = mhlo.constant dense<1.000000e+00> : tensor<f32>
%4 = mhlo.constant dense<0> : tensor<32xi64>
%5 = mhlo.constant dense<10> : tensor<32xi64>
%6 = mhlo.constant dense<0.000000e+00> : tensor<32xf32>
%7 = mhlo.constant dense<0x7FC00000> : tensor<32xf32>
%8 = mhlo.constant dense<0xFF800000> : tensor<f32>
%9 = mhlo.constant dense<0.000000e+00> : tensor<32x1024xf32>
%10 = mhlo.constant dense<0.000000e+00> : tensor<32x14x14x32xf32>
%11 = mhlo.constant dense<0.000000e+00> : tensor<32x28x28x32xf32>
%12 = mhlo.constant dense<0.000000e+00> : tensor<f32>
%13 = mhlo.constant dense<false> : tensor<i1>
// Erase the static element types from the variable handles, then read the
// current weight values for the forward pass.
%14 = "tf.Cast"(%arg2) {Truncate = false} : (tensor<!tf.resource<tensor<5x5x1x32xf32>>>) -> tensor<!tf.resource>
%15 = "tf.Cast"(%arg3) {Truncate = false} : (tensor<!tf.resource<tensor<5x5x32x32xf32>>>) -> tensor<!tf.resource>
%16 = "tf.Cast"(%arg4) {Truncate = false} : (tensor<!tf.resource<tensor<1568x1024xf32>>>) -> tensor<!tf.resource>
%17 = "tf.Cast"(%arg5) {Truncate = false} : (tensor<!tf.resource<tensor<1024xf32>>>) -> tensor<!tf.resource>
%18 = "tf.Cast"(%arg6) {Truncate = false} : (tensor<!tf.resource<tensor<1024x10xf32>>>) -> tensor<!tf.resource>
%19 = "tf.Cast"(%arg7) {Truncate = false} : (tensor<!tf.resource<tensor<10xf32>>>) -> tensor<!tf.resource>
%20 = "tf.Cast"(%arg9) {Truncate = false} : (tensor<!tf.resource<tensor<i64>>>) -> tensor<!tf.resource>
%21 = "tf.ReadVariableOp"(%15) {device = ""} : (tensor<!tf.resource>) -> tensor<5x5x32x32xf32>
%22 = "tf.ReadVariableOp"(%14) {device = ""} : (tensor<!tf.resource>) -> tensor<5x5x1x32xf32>
%23 = "tf.ReadVariableOp"(%19) {device = ""} : (tensor<!tf.resource>) -> tensor<10xf32>
%24 = "tf.ReadVariableOp"(%18) {device = ""} : (tensor<!tf.resource>) -> tensor<1024x10xf32>
%25 = "tf.ReadVariableOp"(%17) {device = ""} : (tensor<!tf.resource>) -> tensor<1024xf32>
%26 = "tf.ReadVariableOp"(%16) {device = ""} : (tensor<!tf.resource>) -> tensor<1568x1024xf32>
// ---- Forward pass ----
// conv1: 5x5x1x32 kernel, padding 2 on each side (output stays 28x28),
// followed by ReLU (max with the zero tensor %11).
%27 = "mhlo.convolution"(%arg0, %22) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 1 : i64, padding = dense<2> : tensor<2x2xi64>, rhs_dilation = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>} : (tensor<32x28x28x1xf32>, tensor<5x5x1x32xf32>) -> tensor<32x28x28x32xf32>
%28 = mhlo.maximum %27, %11 : tensor<32x28x28x32xf32>
// 2x2 max-pool, stride 2 (reduce_window with `maximum`, init -inf):
// 28x28 -> 14x14.
%29 = "mhlo.reduce_window"(%28, %8) ( {
^bb0(%arg10: tensor<f32>, %arg11: tensor<f32>): // no predecessors
%130 = mhlo.maximum %arg10, %arg11 : tensor<f32>
"mhlo.return"(%130) : (tensor<f32>) -> ()
}) {padding = dense<0> : tensor<4x2xi64>, window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<32x28x28x32xf32>, tensor<f32>) -> tensor<32x14x14x32xf32>
// conv2: 5x5x32x32 kernel, same padding scheme, then ReLU.
%30 = "mhlo.convolution"(%29, %21) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 1 : i64, padding = dense<2> : tensor<2x2xi64>, rhs_dilation = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>} : (tensor<32x14x14x32xf32>, tensor<5x5x32x32xf32>) -> tensor<32x14x14x32xf32>
%31 = mhlo.maximum %30, %10 : tensor<32x14x14x32xf32>
// Second 2x2 max-pool: 14x14 -> 7x7.
%32 = "mhlo.reduce_window"(%31, %8) ( {
^bb0(%arg10: tensor<f32>, %arg11: tensor<f32>): // no predecessors
%130 = mhlo.maximum %arg10, %arg11 : tensor<f32>
"mhlo.return"(%130) : (tensor<f32>) -> ()
}) {padding = dense<0> : tensor<4x2xi64>, window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<32x14x14x32xf32>, tensor<f32>) -> tensor<32x7x7x32xf32>
// Flatten (32x7x7x32 -> 32x1568), dense1 + bias + ReLU (%37 is reused as
// the ReLU mask source in the backward pass).
%33 = "mhlo.reshape"(%32) : (tensor<32x7x7x32xf32>) -> tensor<32x1568xf32>
%34 = "mhlo.dot"(%33, %26) : (tensor<32x1568xf32>, tensor<1568x1024xf32>) -> tensor<32x1024xf32>
%35 = "mhlo.broadcast_in_dim"(%25) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1024xf32>) -> tensor<32x1024xf32>
%36 = mhlo.add %34, %35 : tensor<32x1024xf32>
%37 = mhlo.maximum %36, %9 : tensor<32x1024xf32>
// dense2 + bias: raw logits %40 (no softmax applied here).
%38 = "mhlo.dot"(%37, %24) : (tensor<32x1024xf32>, tensor<1024x10xf32>) -> tensor<32x10xf32>
%39 = "mhlo.broadcast_in_dim"(%23) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<10xf32>) -> tensor<32x10xf32>
%40 = mhlo.add %38, %39 : tensor<32x10xf32>
// Current learning rate (read once, reused for every parameter update).
%41 = "tf.ReadVariableOp"(%arg8) : (tensor<!tf.resource<tensor<f32>>>) -> tensor<f32>
// ---- Labels -> one-hot with range check ----
// Convert float labels to i64 class indices, compare against an iota of
// class ids to build a 1.0/0.0 one-hot matrix (%50).
%42 = "mhlo.convert"(%arg1) : (tensor<32x1xf32>) -> tensor<32x1xi64>
%43 = "mhlo.reshape"(%42) : (tensor<32x1xi64>) -> tensor<32xi64>
%44 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<10xi64>
%45 = "mhlo.broadcast_in_dim"(%44) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<10xi64>) -> tensor<32x10xi64>
%46 = "mhlo.broadcast_in_dim"(%43) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<32xi64>) -> tensor<32x10xi64>
%47 = "mhlo.compare"(%46, %45) {comparison_direction = "EQ"} : (tensor<32x10xi64>, tensor<32x10xi64>) -> tensor<32x10xi1>
%48 = "mhlo.broadcast"(%3) {broadcast_sizes = dense<[32, 10]> : tensor<2xi64>} : (tensor<f32>) -> tensor<32x10xf32>
%49 = "mhlo.broadcast"(%12) {broadcast_sizes = dense<[32, 10]> : tensor<2xi64>} : (tensor<f32>) -> tensor<32x10xf32>
%50 = "mhlo.select"(%47, %48, %49) : (tensor<32x10xi1>, tensor<32x10xf32>, tensor<32x10xf32>) -> tensor<32x10xf32>
// Rows whose label is outside [0, 10) get NaN added (%54 selects 0 for
// valid labels, NaN %7 otherwise), poisoning that example's loss.
%51 = "mhlo.compare"(%4, %43) {comparison_direction = "LE"} : (tensor<32xi64>, tensor<32xi64>) -> tensor<32xi1>
%52 = "mhlo.compare"(%43, %5) {comparison_direction = "LT"} : (tensor<32xi64>, tensor<32xi64>) -> tensor<32xi1>
%53 = mhlo.and %51, %52 : tensor<32xi1>
%54 = "mhlo.select"(%53, %6, %7) : (tensor<32xi1>, tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32>
%55 = "mhlo.reshape"(%54) : (tensor<32xf32>) -> tensor<32x1xf32>
%56 = "mhlo.broadcast_in_dim"(%55) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<32x1xf32>) -> tensor<32x10xf32>
%57 = mhlo.add %50, %56 : tensor<32x10xf32>
%58 = "mhlo.negate"(%57) : (tensor<32x10xf32>) -> tensor<32x10xf32>
// ---- Loss: softmax cross-entropy ----
// Numerically-stable log-softmax: subtract the per-row max (%59), then
// logits - log(sum(exp(...))) gives %66.
%59 = "mhlo.reduce"(%40, %8) ( {
^bb0(%arg10: tensor<f32>, %arg11: tensor<f32>): // no predecessors
%130 = mhlo.maximum %arg10, %arg11 : tensor<f32>
"mhlo.return"(%130) : (tensor<f32>) -> ()
}) {dimensions = dense<1> : tensor<1xi64>} : (tensor<32x10xf32>, tensor<f32>) -> tensor<32xf32>
%60 = "mhlo.broadcast_in_dim"(%59) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<32xf32>) -> tensor<32x10xf32>
%61 = mhlo.subtract %40, %60 : tensor<32x10xf32>
%62 = "mhlo.exponential"(%61) : (tensor<32x10xf32>) -> tensor<32x10xf32>
%63 = "mhlo.reduce"(%62, %12) ( {
^bb0(%arg10: tensor<f32>, %arg11: tensor<f32>): // no predecessors
%130 = mhlo.add %arg10, %arg11 : tensor<f32>
"mhlo.return"(%130) : (tensor<f32>) -> ()
}) {dimensions = dense<1> : tensor<1xi64>} : (tensor<32x10xf32>, tensor<f32>) -> tensor<32xf32>
%64 = "mhlo.log"(%63) : (tensor<32xf32>) -> tensor<32xf32>
%65 = "mhlo.broadcast_in_dim"(%64) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<32xf32>) -> tensor<32x10xf32>
%66 = mhlo.subtract %61, %65 : tensor<32x10xf32>
// Per-example cross-entropy %68 = sum over classes of
// -(one_hot + nan_mask) * log_softmax.
%67 = mhlo.multiply %58, %66 : tensor<32x10xf32>
%68 = "mhlo.reduce"(%67, %12) ( {
^bb0(%arg10: tensor<f32>, %arg11: tensor<f32>): // no predecessors
%130 = mhlo.add %arg10, %arg11 : tensor<f32>
"mhlo.return"(%130) : (tensor<f32>) -> ()
}) {dimensions = dense<1> : tensor<1xi64>} : (tensor<32x10xf32>, tensor<f32>) -> tensor<32xf32>
// ---- Backward pass (hand-lowered, no autodiff ops) ----
// Recompute softmax probabilities %75 from the logits (duplicates the
// max/exp/sum work above; this is how the exporter emitted it).
%69 = "mhlo.reduce"(%40, %8) ( {
^bb0(%arg10: tensor<f32>, %arg11: tensor<f32>): // no predecessors
%130 = mhlo.maximum %arg10, %arg11 : tensor<f32>
"mhlo.return"(%130) : (tensor<f32>) -> ()
}) {dimensions = dense<1> : tensor<1xi64>} : (tensor<32x10xf32>, tensor<f32>) -> tensor<32xf32>
%70 = "mhlo.broadcast_in_dim"(%69) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<32xf32>) -> tensor<32x10xf32>
%71 = mhlo.subtract %40, %70 : tensor<32x10xf32>
%72 = "mhlo.exponential"(%71) : (tensor<32x10xf32>) -> tensor<32x10xf32>
%73 = "mhlo.reduce"(%72, %12) ( {
^bb0(%arg10: tensor<f32>, %arg11: tensor<f32>): // no predecessors
%130 = mhlo.add %arg10, %arg11 : tensor<f32>
"mhlo.return"(%130) : (tensor<f32>) -> ()
}) {dimensions = dense<1> : tensor<1xi64>} : (tensor<32x10xf32>, tensor<f32>) -> tensor<32xf32>
%74 = "mhlo.broadcast_in_dim"(%73) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<32xf32>) -> tensor<32x10xf32>
%75 = mhlo.divide %72, %74 : tensor<32x10xf32>
// dLoss/dLogits %77 = (softmax - one_hot_with_mask) * 1/32.
%76 = mhlo.subtract %75, %57 : tensor<32x10xf32>
%77 = mhlo.multiply %76, %0 : tensor<32x10xf32>
// dense2.bias: gradient = sum of %77 over the batch; SGD update
// w <- w - lr * grad (same pattern for every variable below).
%78 = "mhlo.reduce"(%77, %12) ( {
^bb0(%arg10: tensor<f32>, %arg11: tensor<f32>): // no predecessors
%130 = mhlo.add %arg10, %arg11 : tensor<f32>
"mhlo.return"(%130) : (tensor<f32>) -> ()
}) {dimensions = dense<0> : tensor<1xi64>} : (tensor<32x10xf32>, tensor<f32>) -> tensor<10xf32>
%79 = "mhlo.broadcast_in_dim"(%41) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<10xf32>
%80 = mhlo.multiply %79, %78 : tensor<10xf32>
%81 = "tf.ReadVariableOp"(%19) : (tensor<!tf.resource>) -> tensor<*xf32>
%82 = "tf.Sub"(%81, %80) : (tensor<*xf32>, tensor<10xf32>) -> tensor<*xf32>
"tf.AssignVariableOp"(%19, %82) : (tensor<!tf.resource>, tensor<*xf32>) -> ()
// Gradient into dense1 activations: %77 @ dense2.kernel^T, gated through
// the ReLU mask (%37 > 0) to get %86.
%83 = "mhlo.transpose"(%24) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<1024x10xf32>) -> tensor<10x1024xf32>
%84 = "mhlo.dot"(%77, %83) : (tensor<32x10xf32>, tensor<10x1024xf32>) -> tensor<32x1024xf32>
%85 = "mhlo.compare"(%37, %9) {comparison_direction = "GT"} : (tensor<32x1024xf32>, tensor<32x1024xf32>) -> tensor<32x1024xi1>
%86 = "mhlo.select"(%85, %84, %9) : (tensor<32x1024xi1>, tensor<32x1024xf32>, tensor<32x1024xf32>) -> tensor<32x1024xf32>
// dense1.bias gradient + update.
%87 = "mhlo.reduce"(%86, %12) ( {
^bb0(%arg10: tensor<f32>, %arg11: tensor<f32>): // no predecessors
%130 = mhlo.add %arg10, %arg11 : tensor<f32>
"mhlo.return"(%130) : (tensor<f32>) -> ()
}) {dimensions = dense<0> : tensor<1xi64>} : (tensor<32x1024xf32>, tensor<f32>) -> tensor<1024xf32>
%88 = "mhlo.broadcast_in_dim"(%41) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<1024xf32>
%89 = mhlo.multiply %88, %87 : tensor<1024xf32>
%90 = "tf.ReadVariableOp"(%17) : (tensor<!tf.resource>) -> tensor<*xf32>
%91 = "tf.Sub"(%90, %89) : (tensor<*xf32>, tensor<1024xf32>) -> tensor<*xf32>
"tf.AssignVariableOp"(%17, %91) : (tensor<!tf.resource>, tensor<*xf32>) -> ()
// Backprop through the flatten and the second max-pool: select_and_scatter
// routes each gradient to the window's max element (max-pool gradient),
// then the conv2 ReLU mask (%31 > 0) is applied.
%92 = "mhlo.transpose"(%26) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<1568x1024xf32>) -> tensor<1024x1568xf32>
%93 = "mhlo.dot"(%86, %92) : (tensor<32x1024xf32>, tensor<1024x1568xf32>) -> tensor<32x1568xf32>
%94 = "mhlo.reshape"(%93) : (tensor<32x1568xf32>) -> tensor<32x7x7x32xf32>
%95 = "mhlo.select_and_scatter"(%31, %94, %12) ( {
^bb0(%arg10: tensor<f32>, %arg11: tensor<f32>): // no predecessors
%130 = "mhlo.compare"(%arg10, %arg11) {comparison_direction = "GE"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%130) : (tensor<i1>) -> ()
}, {
^bb0(%arg10: tensor<f32>, %arg11: tensor<f32>): // no predecessors
%130 = mhlo.add %arg10, %arg11 : tensor<f32>
"mhlo.return"(%130) : (tensor<f32>) -> ()
}) {padding = dense<0> : tensor<4x2xi64>, window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<32x14x14x32xf32>, tensor<32x7x7x32xf32>, tensor<f32>) -> tensor<32x14x14x32xf32>
%96 = "mhlo.compare"(%31, %10) {comparison_direction = "GT"} : (tensor<32x14x14x32xf32>, tensor<32x14x14x32xf32>) -> tensor<32x14x14x32xi1>
%97 = "mhlo.select"(%96, %95, %10) : (tensor<32x14x14x32xi1>, tensor<32x14x14x32xf32>, tensor<32x14x14x32xf32>) -> tensor<32x14x14x32xf32>
// conv2.kernel gradient: convolution of the layer input (%29) with the
// incoming gradient (%97); note the permuted dimension numbers - batch
// acts as the contraction dimension, producing a 5x5x32x32 result.
%98 = "mhlo.convolution"(%29, %97) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 3 : i64, input_feature_dimension = 0 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 0 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, output_batch_dimension = 2 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>}, feature_group_count = 1 : i64, lhs_dilation = dense<1> : tensor<2xi64>, padding = dense<2> : tensor<2x2xi64>, rhs_dilation = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>} : (tensor<32x14x14x32xf32>, tensor<32x14x14x32xf32>) -> tensor<5x5x32x32xf32>
%99 = "mhlo.broadcast_in_dim"(%41) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<5x5x32x32xf32>
%100 = mhlo.multiply %99, %98 : tensor<5x5x32x32xf32>
%101 = "tf.ReadVariableOp"(%15) : (tensor<!tf.resource>) -> tensor<*xf32>
%102 = "tf.Sub"(%101, %100) : (tensor<*xf32>, tensor<5x5x32x32xf32>) -> tensor<*xf32>
"tf.AssignVariableOp"(%15, %102) : (tensor<!tf.resource>, tensor<*xf32>) -> ()
// Gradient into conv2's input: convolve %97 with the spatially-reversed
// kernel (the usual transposed-convolution form of conv input gradient).
%103 = "mhlo.reverse"(%21) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<5x5x32x32xf32>) -> tensor<5x5x32x32xf32>
%104 = "mhlo.convolution"(%97, %103) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 3 : i64, kernel_output_feature_dimension = 2 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 1 : i64, lhs_dilation = dense<1> : tensor<2xi64>, padding = dense<2> : tensor<2x2xi64>, rhs_dilation = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>} : (tensor<32x14x14x32xf32>, tensor<5x5x32x32xf32>) -> tensor<32x14x14x32xf32>
// Backprop through the first max-pool and conv1's ReLU (%28 > 0).
%105 = "mhlo.select_and_scatter"(%28, %104, %12) ( {
^bb0(%arg10: tensor<f32>, %arg11: tensor<f32>): // no predecessors
%130 = "mhlo.compare"(%arg10, %arg11) {comparison_direction = "GE"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%130) : (tensor<i1>) -> ()
}, {
^bb0(%arg10: tensor<f32>, %arg11: tensor<f32>): // no predecessors
%130 = mhlo.add %arg10, %arg11 : tensor<f32>
"mhlo.return"(%130) : (tensor<f32>) -> ()
}) {padding = dense<0> : tensor<4x2xi64>, window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<32x28x28x32xf32>, tensor<32x14x14x32xf32>, tensor<f32>) -> tensor<32x28x28x32xf32>
%106 = "mhlo.compare"(%28, %11) {comparison_direction = "GT"} : (tensor<32x28x28x32xf32>, tensor<32x28x28x32xf32>) -> tensor<32x28x28x32xi1>
%107 = "mhlo.select"(%106, %105, %11) : (tensor<32x28x28x32xi1>, tensor<32x28x28x32xf32>, tensor<32x28x28x32xf32>) -> tensor<32x28x28x32xf32>
// conv1.kernel gradient (same batch-as-contraction trick as %98) + update.
%108 = "mhlo.convolution"(%arg0, %107) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 3 : i64, input_feature_dimension = 0 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 0 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, output_batch_dimension = 2 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>}, feature_group_count = 1 : i64, lhs_dilation = dense<1> : tensor<2xi64>, padding = dense<2> : tensor<2x2xi64>, rhs_dilation = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>} : (tensor<32x28x28x1xf32>, tensor<32x28x28x32xf32>) -> tensor<5x5x1x32xf32>
%109 = "mhlo.broadcast_in_dim"(%41) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<5x5x1x32xf32>
%110 = mhlo.multiply %109, %108 : tensor<5x5x1x32xf32>
%111 = "tf.ReadVariableOp"(%14) : (tensor<!tf.resource>) -> tensor<*xf32>
%112 = "tf.Sub"(%111, %110) : (tensor<*xf32>, tensor<5x5x1x32xf32>) -> tensor<*xf32>
"tf.AssignVariableOp"(%14, %112) : (tensor<!tf.resource>, tensor<*xf32>) -> ()
// dense1.kernel gradient = input^T @ grad (%33^T @ %86) + update.
%113 = "mhlo.transpose"(%33) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<32x1568xf32>) -> tensor<1568x32xf32>
%114 = "mhlo.dot"(%113, %86) : (tensor<1568x32xf32>, tensor<32x1024xf32>) -> tensor<1568x1024xf32>
%115 = "mhlo.broadcast_in_dim"(%41) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<1568x1024xf32>
%116 = mhlo.multiply %115, %114 : tensor<1568x1024xf32>
%117 = "tf.ReadVariableOp"(%16) : (tensor<!tf.resource>) -> tensor<*xf32>
%118 = "tf.Sub"(%117, %116) : (tensor<*xf32>, tensor<1568x1024xf32>) -> tensor<*xf32>
"tf.AssignVariableOp"(%16, %118) : (tensor<!tf.resource>, tensor<*xf32>) -> ()
// dense2.kernel gradient = dense1_activations^T @ dLogits + update.
%119 = "mhlo.transpose"(%37) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<32x1024xf32>) -> tensor<1024x32xf32>
%120 = "mhlo.dot"(%119, %77) : (tensor<1024x32xf32>, tensor<32x10xf32>) -> tensor<1024x10xf32>
%121 = "mhlo.broadcast_in_dim"(%41) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<1024x10xf32>
%122 = mhlo.multiply %121, %120 : tensor<1024x10xf32>
%123 = "tf.ReadVariableOp"(%18) : (tensor<!tf.resource>) -> tensor<*xf32>
%124 = "tf.Sub"(%123, %122) : (tensor<*xf32>, tensor<1024x10xf32>) -> tensor<*xf32>
"tf.AssignVariableOp"(%18, %124) : (tensor<!tf.resource>, tensor<*xf32>) -> ()
// Increment the optimizer iteration counter by 1.
%125 = "tf.ReadVariableOp"(%20) : (tensor<!tf.resource>) -> tensor<*xi64>
%126 = "tf.AddV2"(%125, %2) : (tensor<*xi64>, tensor<i64>) -> tensor<*xi64>
"tf.AssignVariableOp"(%20, %126) : (tensor<!tf.resource>, tensor<*xi64>) -> ()
// Mean loss: sum of per-example losses divided by 32. The final select's
// predicate %13 is the constant false, so it always yields the mean %128
// (likely a lowered ReplaceNaN/conditional that folded to a constant).
%127 = "mhlo.reduce"(%68, %12) ( {
^bb0(%arg10: tensor<f32>, %arg11: tensor<f32>): // no predecessors
%130 = mhlo.add %arg10, %arg11 : tensor<f32>
"mhlo.return"(%130) : (tensor<f32>) -> ()
}) {dimensions = dense<0> : tensor<1xi64>} : (tensor<32xf32>, tensor<f32>) -> tensor<f32>
%128 = mhlo.divide %127, %1 : tensor<f32>
%129 = "mhlo.select"(%13, %12, %128) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
return %129 : tensor<f32>
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment