Created
January 27, 2021 09:43
-
-
Save ebraraktas/ab87170deb38eae979b37795015e44bc to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "tensordot_vs_matmul_tflite_bug.ipynb", | |
"provenance": [] | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "3_6j__XULaSt" | |
}, | |
"source": [ | |
"import os\n", | |
"\n", | |
"# TensorFlow's native logging reads TF_CPP_MIN_LOG_LEVEL when it first logs (which can\n", | |
"# already happen during `import tensorflow`), so set it before importing TF for it to\n", | |
"# reliably suppress TFLiteConverter warnings.\n", | |
"os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'\n", | |
"\n", | |
"import numpy as np\n", | |
"import tensorflow as tf\n", | |
"\n", | |
"\n", | |
"class DummyMatmul(tf.Module):\n", | |
"    \"\"\"Reference module: returns -(signal_tensor @ matrix) computed with tf.matmul.\"\"\"\n", | |
"\n", | |
"    def __init__(self, matrix):\n", | |
"        super().__init__()\n", | |
"        # Captured by the traced function as a constant operand.\n", | |
"        self.matrix = tf.constant(matrix)\n", | |
"\n", | |
"    @tf.function\n", | |
"    def __call__(self, signal_tensor):\n", | |
"        # Product followed by negation, written as one expression.\n", | |
"        return -tf.matmul(signal_tensor, self.matrix)\n", | |
"\n", | |
"\n", | |
"class DummyTensordot(tf.Module):\n", | |
"    \"\"\"Same computation as DummyMatmul, expressed as tf.tensordot over one axis.\"\"\"\n", | |
"\n", | |
"    def __init__(self, matrix):\n", | |
"        super().__init__()\n", | |
"        # Captured by the traced function as a constant operand.\n", | |
"        self.matrix = tf.constant(matrix)\n", | |
"\n", | |
"    @tf.function\n", | |
"    def __call__(self, signal_tensor):\n", | |
"        # Contract the last axis of signal_tensor with the first axis of matrix, then negate.\n", | |
"        return -tf.tensordot(signal_tensor, self.matrix, axes=1)\n", | |
"\n", | |
"\n", | |
"def save_concrete_func(tf_module: tf.Module, input_spec: tf.TensorSpec, output_path):\n", | |
"    \"\"\"Trace `tf_module.__call__` for `input_spec`, convert it to TFLite and write it to `output_path`.\n", | |
"\n", | |
"    TFLite builtin ops are preferred; Select TF ops are allowed as a fallback.\n", | |
"    \"\"\"\n", | |
"    concrete_func = tf_module.__call__.get_concrete_function(signal_tensor=input_spec)\n", | |
"    converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])\n", | |
"    # `target_ops` and `experimental_enable_mlir_converter` are not TF2 TFLiteConverter\n", | |
"    # options: Python accepted the assignments but the converter never read them.\n", | |
"    # These are the TF2 spellings of the same intent.\n", | |
"    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]\n", | |
"    converter.experimental_new_converter = True  # MLIR-based converter\n", | |
"    tflite_model = converter.convert()\n", | |
"    with open(output_path, 'wb') as f:\n", | |
"        f.write(tflite_model)\n", | |
"\n", | |
"\n", | |
"def run_interpreter(interpreter, input: np.ndarray):\n", | |
"    \"\"\"Resize the interpreter's first input to `input.shape`, run it once and return all outputs.\n", | |
"\n", | |
"    Returns a list with one array per output tensor, in `get_output_details()` order.\n", | |
"    \"\"\"\n", | |
"    # The input's tensor index is assigned by the converter and need not be 0,\n", | |
"    # so look it up instead of hard-coding it.\n", | |
"    input_index = interpreter.get_input_details()[0]['index']\n", | |
"    output_indices = [o['index'] for o in interpreter.get_output_details()]\n", | |
"    interpreter.resize_tensor_input(input_index, input.shape)\n", | |
"    interpreter.allocate_tensors()\n", | |
"    interpreter.set_tensor(input_index, input)\n", | |
"    interpreter.invoke()\n", | |
"    return [interpreter.get_tensor(index) for index in output_indices]\n", | |
"\n", | |
"\n", | |
"def _print_max_abs_diff(title, run_matmul, run_tensordot, num_iterations=3):\n", | |
"    \"\"\"Call both runners `num_iterations` times and print max |matmul - tensordot| per output.\"\"\"\n", | |
"    print(title)\n", | |
"    for i in range(num_iterations):\n", | |
"        matmul_out = run_matmul()\n", | |
"        tensordot_out = run_tensordot()\n", | |
"        # To see the element-wise difference: print(np.abs(matmul_out[0] - tensordot_out[0]))\n", | |
"        print(f'Iteration {i} - Maximum Absolute Difference:',\n", | |
"              [np.abs(mout - tout).max() for mout, tout in zip(matmul_out, tensordot_out)])\n", | |
"    print('- * ' * 20)\n", | |
"\n", | |
"\n", | |
"def dummy_test(matrix_shape=(128, 128)):\n", | |
"    \"\"\"Compare tf.matmul vs tf.tensordot, both as tf.Modules and as converted TFLite models.\n", | |
"\n", | |
"    Both operands hold small integer values (0-9) stored as float32, so both paths should\n", | |
"    agree exactly. The left operand has shape [1, 6, matrix_shape[0]], and the models are\n", | |
"    converted with a dynamic middle dimension. Each path is invoked several times because\n", | |
"    the discrepancy can appear only after the first interpreter invocation.\n", | |
"    \"\"\"\n", | |
"    # A fresh RandomState(42) per draw yields the same values as the former global\n", | |
"    # `np.random.seed(42)` calls, without mutating numpy's global RNG.\n", | |
"    matrix = np.random.RandomState(42).randint(0, 10, matrix_shape).astype(np.float32)\n", | |
"    test_input = np.random.RandomState(42).randint(0, 10, [1, 6, matrix_shape[0]]).astype(np.float32)\n", | |
"    print('Left matrix:')\n", | |
"    print(test_input)\n", | |
"    print('Right matrix:')\n", | |
"    print(matrix)\n", | |
"\n", | |
"    signal_tensor_spec = tf.TensorSpec(shape=[1, None, matrix_shape[0]], dtype=tf.float32)\n", | |
"\n", | |
"    dummy_matmul = DummyMatmul(matrix)\n", | |
"    dummy_tensordot = DummyTensordot(matrix)\n", | |
"    save_concrete_func(dummy_matmul, signal_tensor_spec, 'dummy_matmul.tflite')\n", | |
"    save_concrete_func(dummy_tensordot, signal_tensor_spec, 'dummy_tensordot.tflite')\n", | |
"\n", | |
"    dummy_matmul_interpreter = tf.lite.Interpreter('dummy_matmul.tflite')\n", | |
"    dummy_tensordot_interpreter = tf.lite.Interpreter('dummy_tensordot.tflite')\n", | |
"\n", | |
"    _print_max_abs_diff(\"Invoking tf.Modules:\",\n", | |
"                        lambda: [dummy_matmul(test_input)],\n", | |
"                        lambda: [dummy_tensordot(test_input)])\n", | |
"    _print_max_abs_diff(\"Invoking interpreters:\",\n", | |
"                        lambda: run_interpreter(dummy_matmul_interpreter, test_input),\n", | |
"                        lambda: run_interpreter(dummy_tensordot_interpreter, test_input))\n" | |
], | |
"execution_count": 22, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "cGPRmBm2Lf-a", | |
"outputId": "5bd62646-83b5-4f5b-ccd7-2032ec4e47d1" | |
}, | |
"source": [ | |
"# (2, 3) reproduces the matmul/tensordot mismatch in TFLite; (2, 2) produces matching results.\n", | |
"for shape in [(2, 3), (2, 2)]:\n", | |
"    dummy_test(matrix_shape=shape)" | |
], | |
"execution_count": 23, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Left matrix:\n", | |
"[[[6. 3.]\n", | |
" [7. 4.]\n", | |
" [6. 9.]\n", | |
" [2. 6.]\n", | |
" [7. 4.]\n", | |
" [3. 7.]]]\n", | |
"Right matrix:\n", | |
"[[6. 3. 7.]\n", | |
" [4. 6. 9.]]\n", | |
"Invoking tf.Modules:\n", | |
"Iteration 0 - Maximum Absolute Difference: [0.0]\n", | |
"Iteration 1 - Maximum Absolute Difference: [0.0]\n", | |
"Iteration 2 - Maximum Absolute Difference: [0.0]\n", | |
"- * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * \n", | |
"Invoking interpreters:\n", | |
"Iteration 0 - Maximum Absolute Difference: [0.0]\n", | |
"Iteration 1 - Maximum Absolute Difference: [1134.0]\n", | |
"Iteration 2 - Maximum Absolute Difference: [1134.0]\n", | |
"- * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * \n", | |
"Left matrix:\n", | |
"[[[6. 3.]\n", | |
" [7. 4.]\n", | |
" [6. 9.]\n", | |
" [2. 6.]\n", | |
" [7. 4.]\n", | |
" [3. 7.]]]\n", | |
"Right matrix:\n", | |
"[[6. 3.]\n", | |
" [7. 4.]]\n", | |
"Invoking tf.Modules:\n", | |
"Iteration 0 - Maximum Absolute Difference: [0.0]\n", | |
"Iteration 1 - Maximum Absolute Difference: [0.0]\n", | |
"Iteration 2 - Maximum Absolute Difference: [0.0]\n", | |
"- * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * \n", | |
"Invoking interpreters:\n", | |
"Iteration 0 - Maximum Absolute Difference: [0.0]\n", | |
"Iteration 1 - Maximum Absolute Difference: [0.0]\n", | |
"Iteration 2 - Maximum Absolute Difference: [0.0]\n", | |
"- * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * \n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "O97QU7NILhTG" | |
}, | |
"source": [ | |
"" | |
], | |
"execution_count": null, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment