Skip to content

Instantly share code, notes, and snippets.

@tcwalther
Last active January 8, 2021 14:43
Show Gist options
  • Save tcwalther/f1f2a31a2f2fba3e8f2fa3ea99164002 to your computer and use it in GitHub Desktop.
Save tcwalther/f1f2a31a2f2fba3e8f2fa3ea99164002 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"import numpy as np\n",
"import scipy"
]
},
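{
"cell_type": "markdown",
"metadata": {},
"source": [
"# RFFT Tests\n",
"\n",
"Forward 2-D real FFT checks: `tf.signal.rfft2d` is compared against hard-coded expected spectra for `fft_length` equal to, smaller than, and larger than the input's trailing two dimensions, and for an input of rank greater than 2."
]
},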
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def test_fft_length_matches_input_size():\n",
"    \"\"\"rfft2d of a 4x4 signal with fft_length equal to the input shape.\n",
"\n",
"    Both TensorFlow and NumPy are checked against the same expected spectrum,\n",
"    matching the structure of the other RFFT tests in this notebook.\n",
"    \"\"\"\n",
"    signal = np.array([1, 2, 3, 4, 3, 8, 6, 3, 5, 2, 7, 6, 9, 5, 8, 3]).reshape((4, 4))\n",
"\n",
"    np_result = np.fft.rfft2(signal, (4, 4)).reshape(-1)\n",
"    tf_result = tf.signal.rfft2d(signal, fft_length=(4, 4)).numpy().reshape(-1)\n",
"\n",
"    expected_result = np.array([75, -6-1j, 9, -10+5j, -3+2j, -6+11j,\n",
"                                -15, -2+13j, -5, -10-5j, 3-6j, -6-11j])\n",
"\n",
"    # FFT output is floating-point complex; compare with a tolerance instead of\n",
"    # exact equality so round-off cannot cause spurious failures.\n",
"    np.testing.assert_array_almost_equal(np_result, expected_result)\n",
"    np.testing.assert_array_almost_equal(tf_result, expected_result)\n",
"\n",
"test_fft_length_matches_input_size()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def test_fft_length_smaller_than_input_size():\n",
"    \"\"\"rfft2d with fft_length (4, 4) applied to a 4x5 input.\n",
"\n",
"    The input is the 4x4 signal from the matching-size test plus an all-zero\n",
"    trailing column that lies outside fft_length, so the expected spectrum is\n",
"    the 4x4 one.\n",
"    \"\"\"\n",
"    expected = np.array([75, -6-1j, 9, -10+5j, -3+2j, -6+11j,\n",
"                         -15, -2+13j, -5, -10-5j, 3-6j, -6-11j])\n",
"    signal = np.array([1, 2, 3, 4, 0,\n",
"                       3, 8, 6, 3, 0,\n",
"                       5, 2, 7, 6, 0,\n",
"                       9, 5, 8, 3, 0]).reshape((4, 5))\n",
"\n",
"    # NumPy first, then TensorFlow; both must reproduce the same spectrum.\n",
"    spectra = (np.fft.rfft2(signal, (4, 4)),\n",
"               tf.signal.rfft2d(signal, fft_length=(4, 4)).numpy())\n",
"    for spectrum in spectra:\n",
"        np.testing.assert_array_almost_equal(spectrum.reshape(-1), expected)\n",
"\n",
"test_fft_length_smaller_than_input_size()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def test_fft_length_greater_than_input_size():\n",
"    \"\"\"rfft2d of a 3x4 input padded out to fft_length (4, 8).\n",
"\n",
"    The transform has 4 x (8 // 2 + 1) = 20 coefficients, listed row by row\n",
"    in `expected` below.\n",
"    \"\"\"\n",
"    signal = np.array([[1, 2, 3, 4],\n",
"                       [3, 8, 6, 3],\n",
"                       [5, 2, 7, 6]])\n",
"    padded_shape = (4, 8)\n",
"    expected = np.array([\n",
"        50, 8.29289341-33.6776695j, -7+1j, 9.70710659-1.67766953j, 0,\n",
"        -10-20j, -16.3639603-1.12132037j, -5+1j, -7.19238806-2.05025244j, -6+2j,\n",
"        10, -4.7781744-6.12132025j, -1+11j, 10.7781744+1.87867963j, 4,\n",
"        -10+20j, 11.1923885+11.9497471j, 5-5j, -3.63603902-3.12132025j, -6-2j])\n",
"\n",
"    np_flat = np.fft.rfft2(signal, padded_shape).ravel()\n",
"    tf_flat = tf.signal.rfft2d(signal, fft_length=padded_shape).numpy().ravel()\n",
"\n",
"    for flat in (np_flat, tf_flat):\n",
"        np.testing.assert_array_almost_equal(flat, expected)\n",
"\n",
"test_fft_length_greater_than_input_size()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"def test_input_dims_greater_than_2():\n",
"    \"\"\"rfft2d on a rank-3 input of shape (2, 2, 4).\n",
"\n",
"    The 2-D transform acts on the two trailing axes, giving one (2, 3)\n",
"    spectrum per leading index, i.e. 12 coefficients in total.\n",
"    \"\"\"\n",
"    signal = np.array([1, 2, 3, 4, 3, 8, 6, 3,\n",
"                       5, 2, 7, 6, 7, 3, 23, 5]).reshape((2, 2, 4))\n",
"    expected = np.array([\n",
"        30, -5-3j, -4,\n",
"        -10, 1+7j, 0,\n",
"        58, -18+6j, 26,\n",
"        -18, 14+2j, -18,\n",
"    ])\n",
"\n",
"    np_flat = np.fft.rfft2(signal, (2, 4)).reshape(-1)\n",
"    tf_flat = tf.signal.rfft2d(signal, fft_length=(2, 4)).numpy().reshape(-1)\n",
"\n",
"    for flat in (np_flat, tf_flat):\n",
"        np.testing.assert_array_almost_equal(flat, expected)\n",
"\n",
"test_input_dims_greater_than_2()"
]
},
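{
"cell_type": "markdown",
"metadata": {},
"source": [
"# IRFFT Tests\n",
"\n",
"Inverse 2-D real FFT checks: `tf.signal.irfft2d` and `np.fft.irfft2` are applied to the spectra from the RFFT section and compared against hard-coded expected signals for the same `fft_length` cases."
]
},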
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"def test_fft_length_matches_input_size():\n",
"    \"\"\"irfft2d with fft_length (4, 4) recovers the 4x4 signal from its half-spectrum.\n",
"\n",
"    The (4, 3) input is the spectrum expected by the RFFT matching-size test.\n",
"    \"\"\"\n",
"    spectrum = np.array([[75, -6-1j, 9],\n",
"                         [-10+5j, -3+2j, -6+11j],\n",
"                         [-15, -2+13j, -5],\n",
"                         [-10-5j, 3-6j, -6-11j]])\n",
"    expected = np.array([1, 2, 3, 4, 3, 8, 6, 3, 5, 2, 7, 6, 9, 5, 8, 3])\n",
"\n",
"    np_signal = np.fft.irfft2(spectrum, (4, 4))\n",
"    tf_signal = tf.signal.irfft2d(spectrum, fft_length=(4, 4)).numpy()\n",
"    for recovered in (np_signal, tf_signal):\n",
"        np.testing.assert_array_almost_equal(recovered.ravel(), expected)\n",
"\n",
"test_fft_length_matches_input_size()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"def test_fft_length_smaller_than_input_size():\n",
"    \"\"\"Forward rfft2/rfft2d check with fft_length (4, 4) on a 4x5 input.\n",
"\n",
"    NOTE(review): despite sitting under the IRFFT heading, this cell calls the\n",
"    forward transforms and duplicates the RFFT section's test of the same name;\n",
"    the next cell redefines this name with the inverse-transform version.\n",
"    Consider moving or deleting this cell.\n",
"    \"\"\"\n",
"    signal = np.array([1, 2, 3, 4, 0, 3, 8, 6, 3, 0, 5, 2, 7, 6, 0, 9, 5, 8, 3, 0]).reshape((4, 5))\n",
"\n",
"    np_result = np.fft.rfft2(signal, (4, 4))\n",
"    tf_result = tf.signal.rfft2d(signal, fft_length=(4, 4)).numpy()\n",
"\n",
"    expected_result = np.array([75, -6-1j, 9, -10+5j, -3+2j, -6+11j,\n",
"                                -15, -2+13j, -5, -10-5j, 3-6j, -6-11j])\n",
"\n",
"    np.testing.assert_array_almost_equal(np_result.reshape(-1), expected_result)\n",
"    np.testing.assert_array_almost_equal(tf_result.reshape(-1), expected_result)\n",
"\n",
"test_fft_length_smaller_than_input_size()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"def test_fft_length_smaller_than_input_size():\n",
"    \"\"\"irfft2d with fft_length (2, 2), smaller than the (4, 3) half-spectrum input.\n",
"\n",
"    Only the leading part of the spectrum contributes, so the result is a 2x2\n",
"    signal rather than the original 4x4 one.\n",
"    \"\"\"\n",
"    spectrum = np.array([75, -6-1j, 9, -10+5j, -3+2j, -6+11j,\n",
"                         -15, -2+13j, -5, -10-5j, 3-6j, -6-11j]).reshape(4, 3)\n",
"    expected = np.array([[14, 18.5],\n",
"                         [20.5, 22]])\n",
"\n",
"    results = [np.fft.irfft2(spectrum, (2, 2)),\n",
"               tf.signal.irfft2d(spectrum, fft_length=(2, 2)).numpy()]\n",
"    for result in results:\n",
"        np.testing.assert_array_almost_equal(result, expected)\n",
"\n",
"test_fft_length_smaller_than_input_size()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"def test_fft_length_greater_than_input_size():\n",
"    \"\"\"irfft2d with fft_length (4, 8), larger than the (4, 3) half-spectrum input.\n",
"\n",
"    The output has shape (4, 8); `expected` lists it row by row.\n",
"    \"\"\"\n",
"    spectrum = np.array([75, -6-1j, 9, -10+5j, -3+2j, -6+11j,\n",
"                         -15, -2+13j, -5, -10-5j, 3-6j, -6-11j]).reshape((4, 3))\n",
"    expected = np.array([\n",
"        [0.25, 0.54289322, 1.25, 1.25, 1.25, 1.95710678, 2.25, 1.25],\n",
"        [1.25, 2.85355339, 4.25, 3.91421356, 2.75, 2.14644661, 1.75, 1.08578644],\n",
"        [3., 1.43933983, 0.5, 2.14644661, 4., 3.56066017, 2.5, 2.85355339],\n",
"        [5.625, 3.65533009, 1.375, 3.3017767, 5.125, 2.59466991, 0.375, 2.9482233],\n",
"    ])\n",
"\n",
"    target_shape = (4, 8)\n",
"    np_signal = np.fft.irfft2(spectrum, target_shape)\n",
"    tf_signal = tf.signal.irfft2d(spectrum, fft_length=target_shape).numpy()\n",
"\n",
"    for recovered in (np_signal, tf_signal):\n",
"        np.testing.assert_array_almost_equal(recovered, expected)\n",
"\n",
"test_fft_length_greater_than_input_size()"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"def test_input_dims_greater_than_2():\n",
"    \"\"\"irfft2d on a (2, 2, 3) half-spectrum input with fft_length (2, 4).\n",
"\n",
"    This inverts the RFFT rank-3 test: the leading axis indexes two separate\n",
"    2-D spectra, and the expected output is the original (2, 2, 4) signal.\n",
"    \"\"\"\n",
"    spectrum = np.array([\n",
"        30, -5-3j, -4,\n",
"        -10, 1+7j, 0,\n",
"        58, -18+6j, 26,\n",
"        -18, 14+2j, -18,\n",
"    ]).reshape(2, 2, 3)\n",
"    expected = np.array([1., 2., 3., 4., 3., 8., 6., 3.,\n",
"                         5., 2., 7., 6., 7., 3., 23., 5.]).reshape(2, 2, 4)\n",
"\n",
"    np_signal = np.fft.irfft2(spectrum, (2, 4))\n",
"    tf_signal = tf.signal.irfft2d(spectrum, fft_length=(2, 4)).numpy()\n",
"    for recovered in (np_signal, tf_signal):\n",
"        np.testing.assert_array_almost_equal(recovered, expected)\n",
"\n",
"test_input_dims_greater_than_2()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [conda env:tflite]",
"language": "python",
"name": "conda-env-tflite-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.12"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment