Skip to content

Instantly share code, notes, and snippets.

@stellaraccident
Created January 27, 2020 19:43
Show Gist options
  • Star 0 (you must be signed in to star a gist)
  • Fork 0 (you must be signed in to fork a gist)
  • Save stellaraccident/0ee28805fb6da9de1dea831d99e6a811 to your computer and use it in GitHub Desktop.
Examples of HLO granularity: op-carried vs. explicit broadcasts
module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 293 : i32}} {
// Bias for hidden layer 1 (16 values). Both exported functions load it and add it
// along dimension 1 of the first dot product's ?x16 result.
// Declared `mutable`, although no flow.variable.store to it appears in this file.
flow.variable @h1_bias mutable dense<[1.51671076, -1.03060472, 0.786281049, -0.111620337, 1.81119263, -0.489863962, -1.35557854, 1.12750614, -2.68010569, -1.31835032, 1.32360709, -0.169878066, -2.02759194, -1.08895075, 0.00321231596, -0.31182602]> : tensor<16xf32>
// 16x16 weight matrix for hidden layer 1. It is the right-hand operand of the
// first xla_hlo.dot against the ?x16 function input in both exported functions.
// The literal wraps onto a second line in the middle of a row; whitespace inside
// dense<...> is insignificant to the MLIR parser.
flow.variable @h1_weights mutable dense<[[-1.32382154, -0.549432516, 0.367527097, 0.727583051, -0.200104922, 0.803734958, 0.12167716, 1.32091141, -0.532794356, -0.784785628, -0.228998855, 0.517136097, -0.699431359, -0.73973155, 0.743836284, -0.946993887], [0.179826155, -0.49125874, -1.25974214, 1.15823603, -0.431264639, 0.251312494, -0.443345934, -2.01710749, 1.14093256, 0.719460964, -0.530746937, -2.79470325, -1.53676498, -1.249620e+00, 0.0132142669, -0.21497789], [1.0148958, -0.07205116, 0.406493574, -1.13559055, 0.363096684, 0.349495828, -1.05846632, 0.435198575, -0.0271317028, -0.122605741, -0.127589807, -0.243348598, -2.24777889, 0.653263628, -0.37310189, -1.71895993], [-0.646140158, 1.6863637, -0.860919892, 0.5264135, 3.2894721, -0.0983939543, 1.45377111, 2.8501966, 0.783811926, 0.588116825, -0.171106666, -0.72100544, 0.440455109, 0.108651742, 0.929431438, -1.61530697], [0.559388041, 0.364464134, -0.764375269, 2.27214599, 0.340933681, -0.875317394, -0.356756479, 8.754420e-02, -0.817409038, 0.898144125, 0.486458927, 0.0284955949, 1.56958246, -0.435689867, 1.38073099, -0.889448106], [0.258163512, -1.07554412, -1.72254181, -0.566617608, -1.15463114, 0.275421321, -0.18840237, 0.163536638, -0.485031962, -0.376853019, 0.398864329, 9.120600e-02, 1.0214175, -8.278740e-01, -0.523520172, -1.74436307], [7.966170e-03, 0.68441093, 1.65160108, -0.606498241, -1.40209818, 1.25619709, 0.00928257592, -0.597413242, -0.58628118, -1.14331686, 1.43713844, 0.560160637, 0.52714777, -0.581045747, -1.24046147, -1.6064477], [0.28464216, 0.622724533, -1.513320e+00, -0.283951521, -0.572764337, 0.265616119, -0.577383697, -0.458614886, -1.3270402, -0.0247978028, -0.962619423, -0.834290087, 0.462049156, -0.515147567, 9.894610e-01, -1.02187645], [2.21359825, -0.493265629, -0.627484142, -0.0896971598, 1.6319505, 1.03338182, -0.302158237, 1.42723691, -1.94194436, 0.306814611, 0.978998363, 0.214622647, 0.453215778, 0.709018946, -0.641262829, 0.806017994], [0.367737174, -1.975040e+00, 
-0.993579208, -0.123102225, 0.237505108, -1.45136499, -0.427790314, -1.59556103, -0.456848472, -1.52381814, -0.381176412, -0.148052752, 1.13024783, -1.02772939, 1.61302423, -0.890073538], [-1.10121667, 0.270123214, -0.364642352, -1.1423806, 1.13514245, 0.729299366, 0.0271606985, -0.109672539, -0.937480509, 2.99719644, -1.41048658, -0.639986634, -1.65872514, 1.13170183, 7.939230e-01, -0.634354054], [1.05086374, 0.438791245, 0.852006316, -4.385830e-02, 0.834755301, -0.0683089867, 1.01507759, -0.232999325, -0.800174952, 0.984396934, -1.76124716, 0.217537552, -0.445734024, 0.0509375632, 0.29531914, -0.702082396], [0.459670752, 0.283051282, -0.574137866, -0.605766594, 0.815231382, 0.225581273, 0.144795522, -0.638931453, 1.16724384, -0.0104969749, 0.23039332, -0.815333843, -0.127822071, 1.84493363, -1.51086462, 0.341582865], [0.301393211, 0.453265458, -1.93748128, 1.16636097, -1.62139785, 2.70223427, -0.104579866, -0.590782762, -0.893690407, -1.80861509, 0.058574371, -0.360783517, -0.402437121, 0.237567663, -1.01468921, -0.817559063], [-0.128787071, 0.857713997, -0.745054126, 0.210000157, -0.0683295056, -0.788205385, -0.854343831, -2.15483403, 1.42367768, 0.248510212, 0.56085211, -0.684432626, -1.11980283, 1.28147912, 0.435670376, 0.240455762], [-3.14416265, -0.876971542, -1.15776503, 1.27019072, -0.164921239, 2.13830519, 0.102101341, 1.38634789, 2.18710423, 0.700603485, 0.443497092, -0.434374183, -0.642774641, -0.280788094, -0.780260562, -0.49473241]]> : tensor<16x16xf32>
// Bias for hidden layer 2 (16 values). Both exported functions add it along
// dimension 1 of the second dot product's ?x16 result.
flow.variable @h2_bias mutable dense<[-0.967021465, -0.661053777, 0.463357359, -0.242546469, -0.403963268, -0.99858874, -0.016528355, 0.281360239, -1.22633886, 0.934534668, -1.64829016, -0.389897108, -1.32179987, -0.622112929, -2.52271295, -1.21434605]> : tensor<16xf32>
// 16x16 weight matrix for hidden layer 2. It is the right-hand operand of the
// xla_hlo.dot applied to the layer-1 activations in both exported functions.
// The literal wraps onto a second line in the middle of a row; whitespace inside
// dense<...> is insignificant to the MLIR parser.
flow.variable @h2_weights mutable dense<[[-2.13222814, 1.88626814, 0.670573651, 0.482707053, 0.448073238, -1.5314455, 0.654727458, 2.41940427, 1.08034825, 2.10553885, 1.48635793, 1.12497866, 1.36419261, 0.13896206, -0.467206061, 1.91503072], [-0.0662296936, 1.01473761, -0.615601241, 0.978598356, 0.826115847, 8.179330e-02, 1.0840745, 1.3813957, -0.387237668, -0.103054583, -0.398758054, -1.19285882, 0.107874155, -0.303926617, 0.252907723, 0.652497112], [0.927725554, -1.12642908, -2.09232068, -0.665811419, 1.34143126, -0.45511964, -1.2244817, -1.96985805, 0.592386425, 1.31674743, 0.587773263, -0.26006645, -0.87894994, 0.486421824, -0.682872772, -0.851798593], [0.106736414, 2.05440736, 0.0519925728, 0.238654211, -1.21667862, -0.382746428, -0.0899780616, -1.11597872, -1.2264291, 0.889204264, 1.66972291, 0.0722560584, 0.880449116, 0.472382516, -0.783615052, 1.15327919], [-0.129865691, 0.36419332, -2.25143456, 0.00790903252, 0.478828371, 0.0289189145, 0.660441696, 0.906809687, -0.281814605, 0.9209885, -0.742295861, 1.21671712, 0.916797995, 1.7961055, -2.35085821, 0.18215315], [-0.0301297065, 0.221202955, 0.275484145, 0.914768159, -0.601035297, -0.173001096, 0.674126744, -1.84882367, -0.15277335, -0.373465866, -0.0120063508, 0.219989568, 0.844797611, 1.43350768, 0.736365139, -0.438352674], [1.9144367, 1.26891935, -0.263682514, 0.34708482, 9.492310e-02, -0.339362174, -0.671611368, -1.57363975, 1.35520077, 0.310628414, -0.770133793, -0.42299515, 0.474828213, 0.67529422, -1.2379235, 1.01158977], [-0.391927302, 0.743282973, -0.982167363, -0.119729631, -1.33537877, -1.164930e+00, -0.359153658, -0.675622582, -0.446195781, 1.01512158, -0.562348902, -0.286203295, -0.835502743, -0.855140149, 0.0838188156, 0.819264888], [-1.86946452, 0.419203371, -1.24075806, -0.752735734, 0.214274839, -1.33938313, -0.304385483, 1.68522859, 0.690583705, 0.00531598181, 2.02068734, 0.535425603, -0.979015767, -0.0799265429, 1.18853271, 1.81206512], [0.0338045545, 0.481575817, 1.11879599, 0.531895757, 
-0.193092853, -0.589684844, 0.229198456, -1.85691774, -0.500680208, 0.641366601, -0.565744281, 0.20245333, 0.256212652, -0.801903665, 0.511094451, 0.342053592], [-0.307935685, 1.45732141, 0.952566206, -0.122174278, 0.539054692, 0.217579946, 0.652594149, -0.731528937, -0.365515947, 0.0831764116, -0.994086205, 0.0167690143, -1.11587358, -0.355338097, 0.474682778, -1.98097193], [-1.40860844, -1.1918596, 0.0346903838, -1.0781666, 0.294729888, 0.437439322, 0.305828601, 0.716489077, -0.222594798, -0.839936912, -1.98863161, 1.40359128, -0.0976508855, 0.153071091, -0.877907515, -2.29928875], [-0.693776965, 0.354467422, 1.02242219, 0.542613924, -0.695461452, 0.918443977, -0.84615308, -0.688789427, 0.0306969099, -0.645426571, 0.364052057, -1.38542044, 0.571529567, -1.13256037, 1.96275496, 0.406928778], [-1.94173968, -0.178654611, -1.19487393, -0.322406977, 0.380649149, 0.32583335, 0.373116672, -0.123307616, -0.610563278, -0.701816499, -0.638574302, 0.412249923, -0.187399045, 1.52940667, 0.865426659, 0.19339177], [3.922240e-02, 0.649915338, -7.613330e-02, 0.593228161, 0.516314864, 0.445739955, -0.893786549, 0.445071459, 0.164744481, -0.513481081, -2.88612103, -0.914083957, -1.40616357, 5.011920e-01, -0.947912454, -0.304788023], [-0.730117083, 0.205610394, -0.616440355, -7.180330e-01, 0.183839872, 0.507514238, 1.03348029, 1.91910505, -0.711628556, 1.11151874, -0.333912671, -0.0121291243, 0.0764675885, -1.44195652, 0.427113056, 0.194617867]]> : tensor<16x16xf32>
// Bias for the output layer (10 values). Both exported functions add it along
// dimension 1 of the final dot product's ?x10 result.
flow.variable @out_bias mutable dense<[0.437576413, -0.394484192, -0.45458594, 0.405400574, -0.198196054, 1.09229314, 1.62638128, 0.311785311, 0.527803242, 0.647038996]> : tensor<10xf32>
// 16x10 weight matrix for the output layer. It maps the ?x16 layer-2 activations
// to the ?x10 result in both exported functions.
// The literal wraps onto a second line in the middle of a row; whitespace inside
// dense<...> is insignificant to the MLIR parser.
flow.variable @out_weights mutable dense<[[-0.216886863, 2.63177204, -0.468145043, 1.21687174, -0.0127581339, -0.823347926, 0.105535813, 5.166300e-01, 1.09568524, -0.888645768], [-1.42299247, -2.52306557, 1.62807822, -1.31459236, -0.0839586407, -0.466504842, 1.20906091, -0.175997272, -0.523951769, -0.780084908], [-0.801886975, -0.847262561, 0.590660214, -0.154939383, -0.213584095, 0.0020446002, 0.735446572, 8.115490e-01, -1.45244837, 0.753187537], [-0.360525191, -0.171334893, -0.431902558, 0.27747193, 1.29720938, -0.818946182, 0.128241032, 0.939774275, -0.477430284, 1.2073245], [0.16224964, -1.72704029, 1.95572603, 0.349767357, 0.814667344, -1.07649386, -1.86366332, 0.547934175, 0.739157796, -0.199198663], [0.0158156492, 1.53092647, -0.14321664, 0.349571645, 0.329990983, 0.0675868616, -0.300224036, -0.515419245, 0.562449574, -0.341419965], [0.252103835, -0.979151844, -0.901627719, -0.490994513, 0.249424621, 0.643954455, -0.338068098, -0.996651351, 0.747954428, -1.11576712], [1.34119046, 0.415901124, -1.44766223, 0.895626664, -2.07433796, -1.63365293, -1.68218136, 0.0226093177, -0.680986047, -0.376427203], [-0.639261424, 1.70271373, -0.554883242, 0.588306487, 0.0282016899, 0.200512335, -0.622929275, -0.098626405, -1.096452, -1.30428338], [0.0871729106, 0.189050868, -0.745543181, 1.50433517, 0.396510452, 1.61612141, -1.3233881, -2.17477059, 1.45560098, -1.76575744], [-0.716714799, -4.730760e-01, 1.93656707, -0.46339193, -1.1567589, -1.10739279, -1.28075111, -0.187897861, 0.33074674, -0.415498972], [2.16509533, -2.81075215, -0.970568299, 1.09478748, 0.221125141, -0.673275828, -0.561795175, 1.38099718, 0.122936457, -1.44553244], [0.157974631, 0.657304168, -1.58168328, -0.606899142, -1.27687705, 0.53633827, 1.31574571, -0.656811953, -0.655197441, -0.749383151], [3.3339951, -0.44636789, 0.463400811, 1.82535017, 1.69797051, -1.53696215, 0.110338919, -0.0924209058, -1.12320101, -1.35927022], [-0.520103872, -0.793788075, -0.113174565, -0.733935058, 0.12740241, -0.257638067, 
-0.147433698, 0.35111174, -1.99087203, 0.369820863], [0.416022748, 0.377101332, 0.753733754, 0.929480195, 0.813521147, -0.0243832674, 0.287252605, -0.337466508, 0.951832711, 0.3508012]]> : tensor<16x10xf32>
// @predict: three dense layers with logistic-sigmoid activations, followed by a
// row-wise softmax over the 10 output columns.
//
// Each layer computes act = sigmoid(x . W + b). The sigmoid is expanded into
// elementwise HLO as 0.5 * tanh(0.5 * z) + 0.5, using a splat of the scalar 0.5
// as both the scale and the offset. The softmax subtracts the per-row maximum
// before exponentiating, then divides by the per-row sum.
//
// NOTE(review): tf._input_shapes lists seven shapes although the function takes
// a single argument. Presumably the extra entries are left over from the TF
// variable inputs that became flow.variable loads; confirm before relying on it.
// NOTE(review): the 0.5 splats use xla_hlo.broadcast with broadcast_sizes
// [-1, N]. Nothing here ties the -1 to the runtime batch size of %input;
// confirm that the backend accepts this form.
func @predict(%input: tensor<?x16xf32>) -> tensor<?x10xf32> attributes {iree.module.export, iree.reflection = {abi = "sip", abiv = 1 : i32, sip = "I8!S5!k0_0R3!_0"}, tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 16 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.signature.is_stateful} {
  // Scalar constants:
  //   0.5  - sigmoid scale and offset;
  //   -inf - seed for the max reduction (0xFF800000 is its f32 bit pattern);
  //   0.0  - seed for the sum reduction.
  %half = xla_hlo.constant dense<5.000000e-01> : tensor<f32>
  %neg_inf = xla_hlo.constant dense<0xFF800000> : tensor<f32>
  %zero = xla_hlo.constant dense<0.000000e+00> : tensor<f32>

  // Model parameters, grouped per layer.
  %h1_w = flow.variable.load @h1_weights : tensor<16x16xf32>
  %h1_b = flow.variable.load @h1_bias : tensor<16xf32>
  %h2_w = flow.variable.load @h2_weights : tensor<16x16xf32>
  %h2_b = flow.variable.load @h2_bias : tensor<16xf32>
  %out_w = flow.variable.load @out_weights : tensor<16x10xf32>
  %out_b = flow.variable.load @out_bias : tensor<10xf32>

  // Hidden layer 1: sigmoid(input . h1_weights + h1_bias).
  %h1_mm = "xla_hlo.dot"(%input, %h1_w) : (tensor<?x16xf32>, tensor<16x16xf32>) -> tensor<?x16xf32>
  %h1_z = "xla_hlo.add"(%h1_mm, %h1_b) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?x16xf32>, tensor<16xf32>) -> tensor<?x16xf32>
  %h1_half = "xla_hlo.broadcast"(%half) {broadcast_sizes = dense<[-1, 16]> : tensor<2xi64>} : (tensor<f32>) -> tensor<?x16xf32>
  %h1_zh = "xla_hlo.mul"(%h1_z, %h1_half) : (tensor<?x16xf32>, tensor<?x16xf32>) -> tensor<?x16xf32>
  %h1_t = "xla_hlo.tanh"(%h1_zh) : (tensor<?x16xf32>) -> tensor<?x16xf32>
  %h1_th = "xla_hlo.mul"(%h1_t, %h1_half) : (tensor<?x16xf32>, tensor<?x16xf32>) -> tensor<?x16xf32>
  %h1 = "xla_hlo.add"(%h1_th, %h1_half) : (tensor<?x16xf32>, tensor<?x16xf32>) -> tensor<?x16xf32>

  // Hidden layer 2: sigmoid(h1 . h2_weights + h2_bias).
  %h2_mm = "xla_hlo.dot"(%h1, %h2_w) : (tensor<?x16xf32>, tensor<16x16xf32>) -> tensor<?x16xf32>
  %h2_z = "xla_hlo.add"(%h2_mm, %h2_b) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?x16xf32>, tensor<16xf32>) -> tensor<?x16xf32>
  %h2_half = "xla_hlo.broadcast"(%half) {broadcast_sizes = dense<[-1, 16]> : tensor<2xi64>} : (tensor<f32>) -> tensor<?x16xf32>
  %h2_zh = "xla_hlo.mul"(%h2_z, %h2_half) : (tensor<?x16xf32>, tensor<?x16xf32>) -> tensor<?x16xf32>
  %h2_t = "xla_hlo.tanh"(%h2_zh) : (tensor<?x16xf32>) -> tensor<?x16xf32>
  %h2_th = "xla_hlo.mul"(%h2_t, %h2_half) : (tensor<?x16xf32>, tensor<?x16xf32>) -> tensor<?x16xf32>
  %h2 = "xla_hlo.add"(%h2_th, %h2_half) : (tensor<?x16xf32>, tensor<?x16xf32>) -> tensor<?x16xf32>

  // Output layer: sigmoid(h2 . out_weights + out_bias), now ?x10.
  %out_mm = "xla_hlo.dot"(%h2, %out_w) : (tensor<?x16xf32>, tensor<16x10xf32>) -> tensor<?x10xf32>
  %out_z = "xla_hlo.add"(%out_mm, %out_b) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?x10xf32>, tensor<10xf32>) -> tensor<?x10xf32>
  %out_half = "xla_hlo.broadcast"(%half) {broadcast_sizes = dense<[-1, 10]> : tensor<2xi64>} : (tensor<f32>) -> tensor<?x10xf32>
  %out_zh = "xla_hlo.mul"(%out_z, %out_half) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<?x10xf32>
  %out_t = "xla_hlo.tanh"(%out_zh) : (tensor<?x10xf32>) -> tensor<?x10xf32>
  %out_th = "xla_hlo.mul"(%out_t, %out_half) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<?x10xf32>
  %out = "xla_hlo.add"(%out_th, %out_half) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<?x10xf32>

  // Softmax along dimension 1, shifted by the row maximum for numerical stability.
  %row_max = "xla_hlo.reduce"(%out, %neg_inf) ( {
  ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
    %max = "xla_hlo.max"(%lhs, %rhs) : (tensor<f32>, tensor<f32>) -> tensor<f32>
    "xla_hlo.return"(%max) : (tensor<f32>) -> ()
  }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<?x10xf32>, tensor<f32>) -> tensor<?xf32>
  %shifted = "xla_hlo.sub"(%out, %row_max) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?x10xf32>, tensor<?xf32>) -> tensor<?x10xf32>
  %exps = "xla_hlo.exp"(%shifted) : (tensor<?x10xf32>) -> tensor<?x10xf32>
  %row_sum = "xla_hlo.reduce"(%exps, %zero) ( {
  ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
    %sum = "xla_hlo.add"(%lhs, %rhs) : (tensor<f32>, tensor<f32>) -> tensor<f32>
    "xla_hlo.return"(%sum) : (tensor<f32>) -> ()
  }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<?x10xf32>, tensor<f32>) -> tensor<?xf32>
  %probs = "xla_hlo.div"(%exps, %row_sum) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?x10xf32>, tensor<?xf32>) -> tensor<?x10xf32>
  return %probs : tensor<?x10xf32>
}
// @predict_tanh_no_softmax: the same three dense layers, read from the same
// h1/h2/out flow.variables, but with tanh on every layer and no softmax. The
// output layer's tanh values are returned directly.
// NOTE(review): as with the softmax variant, tf._input_shapes lists seven
// shapes for a single-argument function; presumably stale import metadata.
func @predict_tanh_no_softmax(%input: tensor<?x16xf32>) -> tensor<?x10xf32> attributes {iree.module.export, iree.reflection = {abi = "sip", abiv = 1 : i32, sip = "I8!S5!k0_0R3!_0"}, tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 16 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.signature.is_stateful} {
  // Model parameters, grouped per layer.
  %h1_w = flow.variable.load @h1_weights : tensor<16x16xf32>
  %h1_b = flow.variable.load @h1_bias : tensor<16xf32>
  %h2_w = flow.variable.load @h2_weights : tensor<16x16xf32>
  %h2_b = flow.variable.load @h2_bias : tensor<16xf32>
  %out_w = flow.variable.load @out_weights : tensor<16x10xf32>
  %out_b = flow.variable.load @out_bias : tensor<10xf32>

  // Hidden layer 1: tanh(input . h1_weights + h1_bias).
  %h1_mm = "xla_hlo.dot"(%input, %h1_w) : (tensor<?x16xf32>, tensor<16x16xf32>) -> tensor<?x16xf32>
  %h1_z = "xla_hlo.add"(%h1_mm, %h1_b) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?x16xf32>, tensor<16xf32>) -> tensor<?x16xf32>
  %h1 = "xla_hlo.tanh"(%h1_z) : (tensor<?x16xf32>) -> tensor<?x16xf32>

  // Hidden layer 2: tanh(h1 . h2_weights + h2_bias).
  %h2_mm = "xla_hlo.dot"(%h1, %h2_w) : (tensor<?x16xf32>, tensor<16x16xf32>) -> tensor<?x16xf32>
  %h2_z = "xla_hlo.add"(%h2_mm, %h2_b) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?x16xf32>, tensor<16xf32>) -> tensor<?x16xf32>
  %h2 = "xla_hlo.tanh"(%h2_z) : (tensor<?x16xf32>) -> tensor<?x16xf32>

  // Output layer: tanh(h2 . out_weights + out_bias), returned without normalization.
  %out_mm = "xla_hlo.dot"(%h2, %out_w) : (tensor<?x16xf32>, tensor<16x10xf32>) -> tensor<?x10xf32>
  %out_z = "xla_hlo.add"(%out_mm, %out_b) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?x10xf32>, tensor<10xf32>) -> tensor<?x10xf32>
  %out = "xla_hlo.tanh"(%out_z) : (tensor<?x10xf32>) -> tensor<?x10xf32>
  return %out : tensor<?x10xf32>
}
}
// op-carried: the broadcast is expressed as an attribute of the elementwise op
// itself. The 10-element operand is added to every row of the ?x10 operand:
// broadcast_dimensions = [1] maps its single dimension onto result dimension 1.
// No separate broadcast op appears in the IR.
%24 = "xla_hlo.add"(%23, %4) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?x10xf32>, tensor<10xf32>) -> tensor<?x10xf32>
// explicit: the same kind of bias add, with the broadcast materialized as its own op.
// broadcast_in_dim expands the 16-element %2 to ?x16, mapping operand dimension 0
// onto result dimension 1. The following add then takes two same-shaped operands
// and carries no broadcast_dimensions attribute.
%8 = "xla_hlo.broadcast_in_dim"(%2) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<16xf32>) -> tensor<?x16xf32>
%9 = xla_hlo.add %7, %8 : tensor<?x16xf32>
// explicit softmax (excerpt): every broadcast between the ?x10 values and the
// per-row ?xf32 reductions is spelled out with broadcast_in_dim.
// Row-wise max over dimension 1, seeded with %1.
// %1 is not defined in this excerpt; presumably it is the -inf constant.
%36 = "xla_hlo.reduce"(%35, %1) ( {
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>): // no predecessors
%45 = xla_hlo.max %arg1, %arg2 : tensor<f32>
"xla_hlo.return"(%45) : (tensor<f32>) -> ()
}) {dimensions = dense<1> : tensor<1xi64>} : (tensor<?x10xf32>, tensor<f32>) -> tensor<?xf32>
// Identity broadcast_in_dim: dims [0, 1], with the same ?x10 type in and out.
// Presumably emitted so that both operands of the subtraction come from an
// explicit broadcast.
%37 = "xla_hlo.broadcast_in_dim"(%35) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x10xf32>) -> tensor<?x10xf32>
// Broadcast the per-row max (?xf32) back to ?x10, mapping operand dimension 0 to result dimension 0.
%38 = "xla_hlo.broadcast_in_dim"(%36) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf32>) -> tensor<?x10xf32>
// x - max(x) and exp, both elementwise on same-shaped operands.
%39 = xla_hlo.sub %37, %38 : tensor<?x10xf32>
%40 = "xla_hlo.exp"(%39) : (tensor<?x10xf32>) -> tensor<?x10xf32>
// Row-wise sum of the exponentials over dimension 1, seeded with %2.
// NOTE(review): %2 has type tensor<f32> here but tensor<16xf32> in the explicit
// bias-add excerpt (`%8 = "xla_hlo.broadcast_in_dim"(%2)`). These explicit
// snippets therefore appear to be cut from two different function dumps; confirm.
// The final division by this sum is not part of the excerpt.
%41 = "xla_hlo.reduce"(%40, %2) ( {
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>): // no predecessors
%45 = xla_hlo.add %arg1, %arg2 : tensor<f32>
"xla_hlo.return"(%45) : (tensor<f32>) -> ()
}) {dimensions = dense<1> : tensor<1xi64>} : (tensor<?x10xf32>, tensor<f32>) -> tensor<?xf32>
// TANH: a tanh activation is a single elementwise op on the ?x16 pre-activation.
// %9 is not defined in this excerpt; presumably it is the explicit bias-add result.
%10 = "xla_hlo.tanh"(%9) : (tensor<?x16xf32>) -> tensor<?x16xf32>
// SIGMOID: the logistic sigmoid is expanded here via the identity
//   sigmoid(x) = 0.5 * tanh(0.5 * x) + 0.5.
// The scalar 0.5 is broadcast once to ?x16, and that splat serves as both the
// scale and the offset.
%0 = xla_hlo.constant dense<5.000000e-01> : tensor<f32>
// NOTE(review): broadcast_sizes = [-1, 16] writes the dynamic batch dimension as -1;
// confirm how the runtime batch size reaches this op.
%13 = "xla_hlo.broadcast"(%0) {broadcast_sizes = dense<[-1, 16]> : tensor<2xi64>} : (tensor<f32>) -> tensor<?x16xf32>
// 0.5 * x. %12 is not defined in this excerpt; presumably it is the bias-add result.
%14 = xla_hlo.mul %12, %13 : tensor<?x16xf32>
// tanh(0.5 * x)
%15 = "xla_hlo.tanh"(%14) : (tensor<?x16xf32>) -> tensor<?x16xf32>
// 0.5 * tanh(0.5 * x)
%16 = xla_hlo.mul %15, %13 : tensor<?x16xf32>
// 0.5 * tanh(0.5 * x) + 0.5
%17 = xla_hlo.add %16, %13 : tensor<?x16xf32>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.