Skip to content

Instantly share code, notes, and snippets.

@ebraraktas
Created January 27, 2021 09:43
Show Gist options
  • Save ebraraktas/ab87170deb38eae979b37795015e44bc to your computer and use it in GitHub Desktop.
Save ebraraktas/ab87170deb38eae979b37795015e44bc to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "tensordot_vs_matmul_tflite_bug.ipynb",
"provenance": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "code",
"metadata": {
"id": "3_6j__XULaSt"
},
"source": [
"import tensorflow as tf\n",
"import numpy as np\n",
"import os\n",
"\n",
"os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # To suppress TFLiteConverter warnings\n",
"\n",
"\n",
"class DummyMatmul(tf.Module):\n",
" def __init__(self, matrix):\n",
" super().__init__()\n",
" self.matrix = tf.constant(matrix)\n",
"\n",
" @tf.function\n",
" def __call__(self, signal_tensor):\n",
" result = tf.matmul(signal_tensor, self.matrix)\n",
" return -result\n",
"\n",
"\n",
"class DummyTensordot(tf.Module):\n",
" def __init__(self, matrix):\n",
" super().__init__()\n",
" self.matrix = tf.constant(matrix)\n",
"\n",
" @tf.function\n",
" def __call__(self, signal_tensor):\n",
" result = tf.tensordot(signal_tensor, self.matrix, 1)\n",
" return -result\n",
"\n",
"\n",
"def save_concrete_func(tf_module: tf.Module, input_spec: tf.TensorSpec, output_path):\n",
" concrete_func = tf_module.__call__.get_concrete_function(signal_tensor=input_spec)\n",
" # Convert the model\n",
" converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])\n",
" converter.target_ops = [tf.lite.OpsSet.SELECT_TF_OPS, tf.lite.OpsSet.TFLITE_BUILTINS]\n",
" converter.experimental_enable_mlir_converter = True\n",
" tflite_model = converter.convert()\n",
" # Save the model.\n",
" with open(output_path, 'wb') as f:\n",
" f.write(tflite_model)\n",
"\n",
"\n",
"def run_interpreter(interpreter, input: np.ndarray):\n",
" interpreter_out_indices = [o['index'] for o in interpreter.get_output_details()]\n",
" interpreter.resize_tensor_input(0, input.shape)\n",
" interpreter.allocate_tensors()\n",
" interpreter.set_tensor(0, input)\n",
" interpreter.invoke()\n",
" interpreter_outs = [interpreter.get_tensor(interpreter_out_index) for interpreter_out_index in\n",
" interpreter_out_indices]\n",
" return interpreter_outs\n",
"\n",
"\n",
"def dummy_test(matrix_shape=(128, 128)):\n",
" np.random.seed(42)\n",
" matrix = np.random.randint(0, 10, matrix_shape).astype(np.float32)\n",
" np.random.seed(42)\n",
" test_input = np.random.randint(0, 10, [1, 6, matrix_shape[0]]).astype(np.float32)\n",
" print('Left matrix:')\n",
" print(test_input)\n",
" print('Right matrix:')\n",
" print(matrix)\n",
"\n",
" signal_tensor_spec = tf.TensorSpec(shape=[1, None, matrix_shape[0]], dtype=tf.float32)\n",
"\n",
" dummy_matmul = DummyMatmul(matrix)\n",
" dummy_tensordot = DummyTensordot(matrix)\n",
" save_concrete_func(dummy_matmul, signal_tensor_spec, 'dummy_matmul.tflite')\n",
" save_concrete_func(dummy_tensordot, signal_tensor_spec, 'dummy_tensordot.tflite')\n",
"\n",
" dummy_matmul_interpreter = tf.lite.Interpreter('dummy_matmul.tflite')\n",
" dummy_tensordot_interpreter = tf.lite.Interpreter('dummy_tensordot.tflite')\n",
"\n",
" print(\"Invoking tf.Modules:\")\n",
" for i in range(3):\n",
" matmul_out = [dummy_matmul(test_input)]\n",
" tensordot_out = [dummy_tensordot(test_input)]\n",
" # print(np.abs(matmul_out[0] - tensordot_out[0])) Uncomment this line to see the element-wise difference\n",
" print(f'Iteration {i} - Maximum Absolute Difference:',\n",
" [np.abs(mout - tout).max() for mout, tout in zip(matmul_out, tensordot_out)])\n",
" print('- * ' * 20)\n",
"\n",
"\n",
" print(\"Invoking interpreters:\")\n",
" for i in range(3):\n",
" matmul_out = run_interpreter(dummy_matmul_interpreter, test_input)\n",
" tensordot_out = run_interpreter(dummy_tensordot_interpreter, test_input)\n",
" # print(np.abs(matmul_out[0] - tensordot_out[0])) Uncomment this line to see the element-wise difference\n",
" print(f'Iteration {i} - Maximum Absolute Difference:',\n",
" [np.abs(mout - tout).max() for mout, tout in zip(matmul_out, tensordot_out)])\n",
" print('- * ' * 20)\n"
],
"execution_count": 22,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "cGPRmBm2Lf-a",
"outputId": "5bd62646-83b5-4f5b-ccd7-2032ec4e47d1"
},
"source": [
"dummy_test(matrix_shape=(2, 3)) # This input produces different results\n",
"dummy_test(matrix_shape=(2, 2)) # This input produces the same results"
],
"execution_count": 23,
"outputs": [
{
"output_type": "stream",
"text": [
"Left matrix:\n",
"[[[6. 3.]\n",
" [7. 4.]\n",
" [6. 9.]\n",
" [2. 6.]\n",
" [7. 4.]\n",
" [3. 7.]]]\n",
"Right matrix:\n",
"[[6. 3. 7.]\n",
" [4. 6. 9.]]\n",
"Invoking tf.Modules:\n",
"Iteration 0 - Maximum Absolute Difference: [0.0]\n",
"Iteration 1 - Maximum Absolute Difference: [0.0]\n",
"Iteration 2 - Maximum Absolute Difference: [0.0]\n",
"- * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * \n",
"Invoking interpreters:\n",
"Iteration 0 - Maximum Absolute Difference: [0.0]\n",
"Iteration 1 - Maximum Absolute Difference: [1134.0]\n",
"Iteration 2 - Maximum Absolute Difference: [1134.0]\n",
"- * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * \n",
"Left matrix:\n",
"[[[6. 3.]\n",
" [7. 4.]\n",
" [6. 9.]\n",
" [2. 6.]\n",
" [7. 4.]\n",
" [3. 7.]]]\n",
"Right matrix:\n",
"[[6. 3.]\n",
" [7. 4.]]\n",
"Invoking tf.Modules:\n",
"Iteration 0 - Maximum Absolute Difference: [0.0]\n",
"Iteration 1 - Maximum Absolute Difference: [0.0]\n",
"Iteration 2 - Maximum Absolute Difference: [0.0]\n",
"- * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * \n",
"Invoking interpreters:\n",
"Iteration 0 - Maximum Absolute Difference: [0.0]\n",
"Iteration 1 - Maximum Absolute Difference: [0.0]\n",
"Iteration 2 - Maximum Absolute Difference: [0.0]\n",
"- * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * \n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "O97QU7NILhTG"
},
"source": [
""
],
"execution_count": null,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment