Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ShawnHymel/610d31df82b3ffdd7416dfec99badf74 to your computer and use it in GitHub Desktop.
Save ShawnHymel/610d31df82b3ffdd7416dfec99badf74 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from os import listdir\n",
"from os.path import isdir, join\n",
"import array\n",
"import struct\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import random\n",
"import librosa\n",
"import sounddevice as sd\n",
"import tensorflow as tf\n",
"from tensorflow.keras import layers, models, optimizers, regularizers, backend\n",
"from tensorflow import lite"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"TensorFlow v2.0.0\n",
"Keras v2.2.4-tf\n"
]
}
],
"source": [
"# Report which TensorFlow / Keras versions this notebook is running on\n",
"print('TensorFlow v{}'.format(tf.__version__))\n",
"print('Keras v{}'.format(tf.keras.__version__))"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Settings (some of the many hyperparameters)\n",
"dataset_path = 'C:\\\\Users\\\\sgmustadio\\\\Documents\\\\Python\\\\datasets\\\\speech_commands_dataset'\n",
"target_word = 'stop'\n",
"model_filename = 'wake_word_stft_stop_model.h5'\n",
"tflite_model_filename = 'wake_word_stft_stop_model.tflite'\n",
"c_model_name = 'wake_word_stft_stop_model'\n",
"rec_length = 1.0 # Time (seconds) of expected recordings\n",
"perc_target = 0.1 # Fraction (0.0-1.0) of samples that should be our target wake word\n",
"val_ratio = 0.2 # Fraction (0.0-1.0) of samples that should be held for validation set\n",
"test_ratio = 0.2 # Fraction (0.0-1.0) of samples that should be held for test set\n",
"sample_rate = 16000 # The Arduino Nano 33 BLE basically forces 16 kHz sampling on us\n",
"stft_n_fft = 512 # FFT size (number of samples in each slice)\n",
"stft_hop_length = 340 # Distance between each FFT slice (number of samples)\n",
"# \"The window of choice if you don't have any better ideas\"\n",
"# 'hann' is the canonical SciPy name for this window; the legacy 'hanning' alias\n",
"# selects the same window but is reportedly rejected by newer SciPy releases\n",
"stft_window = 'hann'\n",
"stft_n_windows = 46 # Number of slices (windows) to look for in each STFT\n",
"stft_min_bin = 1 # Lowest bin to use (inclusive; basically, filter out DC)\n",
"stft_max_bin = 65 # Highest bin (exclusive; basically, filter out >2kHz)\n",
"shift_n_bits = 3 # Number of bits to shift 16-bit STFT values to make 8-bit values (before clipping)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['backward', 'bed', 'bird', 'cat', 'dog', 'down', 'eight', 'five', 'follow', 'forward', 'four', 'go', 'happy', 'house', 'learn', 'left', 'marvin', 'nine', 'no', 'off', 'on', 'one', 'right', 'seven', 'sheila', 'six', 'stop', 'three', 'tree', 'two', 'up', 'visual', 'wow', 'yes', 'zero']\n"
]
}
],
"source": [
"# Create an all targets list (without background noise set).\n",
"# The noise folder is filtered out instead of removed afterwards, so a dataset copy\n",
"# that lacks it does not raise, and the list is sorted so the class order does not\n",
"# depend on the order in which the OS lists the directory.\n",
"all_targets = sorted(name for name in listdir(dataset_path)\n",
"                     if isdir(join(dataset_path, name)) and name != '_background_noise_')\n",
"print(all_targets)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total number of samples: 105829\n"
]
}
],
"source": [
"# Create a list of file paths along with ground truth vector (y)\n",
"# Note that y[i] is 0 for \"not <target word>\" and 1 for \"<target word>\"\n",
"# Each *_filenames / *_y list holds one entry per class directory (a list of paths\n",
"# and an array of labels respectively); they are flattened in a later cell.\n",
"target_filenames = []\n",
"target_y = []\n",
"other_filenames = []\n",
"other_y = []\n",
"num_samples = 0\n",
"for target in all_targets:\n",
"    # List the class directory once and reuse it for both the count and the paths\n",
"    samples_in_dir = [join(dataset_path, target, sample)\n",
"                      for sample in listdir(join(dataset_path, target))]\n",
"    num_samples += len(samples_in_dir)\n",
"    if target == target_word:\n",
"        target_filenames.append(samples_in_dir)\n",
"        target_y.append(np.ones(len(samples_in_dir)))\n",
"    else:\n",
"        other_filenames.append(samples_in_dir)\n",
"        other_y.append(np.zeros(len(samples_in_dir)))\n",
"print('Total number of samples:', num_samples)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Number of target samples: 3872\n",
"Number of other samples: 34848\n",
"Number of total samples: 38720\n"
]
}
],
"source": [
"# Decide how many target and non-target clips to keep, based on `perc_target`\n",
"num_target_samples = len(listdir(join(dataset_path, target_word)))\n",
"target_is_scarce = (num_target_samples / num_samples) < perc_target\n",
"if target_is_scarce:\n",
"    # Keep every target clip and derive the total from the desired ratio\n",
"    num_total_samples = round(num_target_samples / perc_target)\n",
"else:\n",
"    # Use the full sample count as the total and cap the number of target clips\n",
"    num_total_samples = num_samples\n",
"    num_target_samples = round(perc_target * num_total_samples)\n",
"num_other_samples = num_total_samples - num_target_samples\n",
"for group, count in (('target', num_target_samples),\n",
"                     ('other', num_other_samples),\n",
"                     ('total', num_total_samples)):\n",
"    print('Number of ' + group + ' samples:', count)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# Helper: collapse a list of per-directory sequences into one flat list\n",
"def _flatten(nested):\n",
"    return [entry for group in nested for entry in group]\n",
"\n",
"# Flatten the target paths and labels, then pair each path with its label\n",
"target_filenames = _flatten(target_filenames)\n",
"target_y = _flatten(target_y)\n",
"target_filenames_y = list(zip(target_filenames, target_y))\n",
"\n",
"# Same for every non-target word\n",
"other_filenames = _flatten(other_filenames)\n",
"other_y = _flatten(other_y)\n",
"other_filenames_y = list(zip(other_filenames, other_y))"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# Seed Python's RNG so the clip selection here (and every later random.shuffle\n",
"# call in this kernel session) is reproducible after a kernel restart\n",
"random.seed(42)\n",
"\n",
"# Shuffle target filename/truth list and only keep the desired amount\n",
"random.shuffle(target_filenames_y)\n",
"target_filenames_y = target_filenames_y[:num_target_samples]\n",
"\n",
"# Shuffle other filename/truth list and only keep the desired amount\n",
"random.shuffle(other_filenames_y)\n",
"other_filenames_y = other_filenames_y[:num_other_samples]"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# Merge both groups into one list, mix them, then split the (path, label)\n",
"# pairs back into two parallel tuples\n",
"filenames_y = [*target_filenames_y, *other_filenames_y]\n",
"random.shuffle(filenames_y)\n",
"filenames, y = zip(*filenames_y)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"# Number of clips to hold out for the validation set and for the test set\n",
"num_clips = len(filenames)\n",
"val_set_size = int(num_clips * val_ratio)\n",
"test_set_size = int(num_clips * test_ratio)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Number of training samples: 23232\n",
"Number of validation samples: 7744\n",
"Number of test samples: 7744\n"
]
}
],
"source": [
"# Split the file list: validation clips first, then test clips, the rest is training\n",
"holdout_end = val_set_size + test_set_size\n",
"filenames_val = filenames[:val_set_size]\n",
"filenames_test = filenames[val_set_size:holdout_end]\n",
"filenames_train = filenames[holdout_end:]\n",
"for split_name, split_files in (('training', filenames_train),\n",
"                                ('validation', filenames_val),\n",
"                                ('test', filenames_test)):\n",
"    print('Number of ' + split_name + ' samples:', len(split_files))"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"# Split the labels with exactly the same boundaries used for the file list\n",
"holdout_end = val_set_size + test_set_size\n",
"y_orig_val = y[:val_set_size]\n",
"y_orig_test = y[val_set_size:holdout_end]\n",
"y_orig_train = y[holdout_end:]"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"# Function: load file, play it, draw waveform, draw spectrogram\n",
"def analyze_clip(file_path):\n",
"    \"\"\"Load one audio clip, play it, and plot its waveform and compressed STFT.\n",
"\n",
"    Reads the notebook-level settings sample_rate, stft_n_fft, stft_hop_length,\n",
"    stft_window, stft_min_bin, stft_max_bin and shift_n_bits.\n",
"\n",
"    Args:\n",
"        file_path: Path of the audio file to analyze.\n",
"\n",
"    Returns:\n",
"        2-D float64 array (frequency rows x time-slice columns) in which every\n",
"        pair of adjacent kept frequency bins has been averaged into one row.\n",
"    \"\"\"\n",
"\n",
"    # Load file (librosa resamples to the requested rate)\n",
"    waveform, fs = librosa.load(file_path, sr=sample_rate)\n",
"\n",
"    # Test playing it\n",
"    sd.play(waveform, fs)\n",
"\n",
"    # Convert floating point wav data (-1.0 to 1.0) to 16-bit PCM\n",
"    waveform = np.around(waveform * 32767)\n",
"\n",
"    # Calculate STFT (magnitude only)\n",
"    stft = np.abs(librosa.stft(waveform,\n",
"                               n_fft=stft_n_fft,\n",
"                               hop_length=stft_hop_length,\n",
"                               win_length=stft_n_fft,\n",
"                               window=stft_window,\n",
"                               center=False))\n",
"\n",
"    # Adjust for quantization and scaling in 16-bit fixed point FFT\n",
"    stft = np.around(stft / stft_n_fft)\n",
"\n",
"    # Reduce precision by converting to 8-bit unsigned values [0..255]\n",
"    stft = np.around(stft / (2 ** shift_n_bits))\n",
"    stft = np.clip(stft, a_min=0, a_max=255)\n",
"\n",
"    # Only keep the frequency bins we care about (i.e. filter out unwanted frequencies)\n",
"    stft = stft[stft_min_bin:stft_max_bin, :]\n",
"\n",
"    # Average every 2 bins together to reduce size of STFT. Grouping the rows in\n",
"    # pairs and averaging over the pair axis replaces the per-column Python loop;\n",
"    # the cast keeps the float64 result type of the original zero-initialised array.\n",
"    stft_comp = stft.reshape(-1, 2, stft.shape[1]).mean(axis=1).astype(np.float64)\n",
"\n",
"    # Print information about clip\n",
"    print(file_path)\n",
"    print('Waveform shape:', waveform.shape)\n",
"    print('Compressed STFT shape:', stft_comp.shape)\n",
"\n",
"    # Draw time domain signal\n",
"    plt.plot(waveform)\n",
"    plt.title('Waveform')\n",
"    plt.xlabel('Sample')\n",
"    plt.ylabel('Amplitude')\n",
"    plt.show()\n",
"\n",
"    # Draw spectrogram\n",
"    plt.figure()\n",
"    plt.imshow(stft_comp, cmap='inferno', origin='lower')\n",
"    plt.title('Compressed STFT')\n",
"    plt.xlabel('Time slice')\n",
"    plt.ylabel('Averaged frequency bin')\n",
"\n",
"    return stft_comp"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(32, 46)\n",
"C:\\Users\\sgmustadio\\Documents\\Python\\datasets\\speech_commands_dataset\\off\\716757ce_nohash_0.wav\n",
"Waveform shape: (16000,)\n",
"Compressed STFT shape: (32, 46)\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"<Figure size 432x288 with 0 Axes>"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAVoAAAD4CAYAAACt8i4nAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAAVFElEQVR4nO3dfYyc1XXH8d+Z9dpeYxJjXoyDSR0SqpIgMNRyLRFFBNLKbVFMWpEGtZHVRnEqJVKQqFrKP9BKlfJH8ya1irQJFEelSSwgAaGoKnKKCFVLYggBUoOKiHkJi23wC8Yvu56Z0z9mULzmOXfWd+bu7Mx8PxLamXvnznPnsnvm8TNnzjV3FwCgnFq/JwAAw45ACwCFEWgBoDACLQAURqAFgMIWzefBzCyR4pCK+dGwEhkTljFmoWRu5My907jUa1sorxtYMF5393NPbZzXQNsyVtlqtjQxplnZ6j6TGJMXPCyxJB6OaySOlev0g6bZeNzp1WvYGph4zX4iccQSrxsYZI0Xq1q5dAAAhRFoAaAwAi0AFEagBYDCCLQAUFjHrANrpQM8ImlJ+/H3uPttZrZS0vckrZW0W9In3f1Auam+06Kxs7PG1Rv7wz5XPXc6gd6nToWZEanMgtSRkpkFvZa7HsDgmssZ7bSka9z9cknrJG0ys42SbpG0w90vlrSjfR8AcIqOgdZb3mrfHW//55I2S9rWbt8m6foiMwSAATena7RmNmZmT0raK+khd39M0ip3n5Kk9s/zyk0TAAbXnAKtuzfcfZ2kNZI2mNmlcz2AmW01s51mtjN3kgAwyE4r68DdD0p6WNImSXvMbLUktX/uDcZMuvt6d1/f5VwBYCB1DLRmdq6ZrWjfnpD0MUnPSnpA0pb2w7ZIur/UJAFgkM2lqMxqSdvMbEytwLzd3R80s/+WtN3MPiPpJUk3dDORnBSjVJpWqsDKorGVYV8zUaim2Twa9KSKqxRIWbLq90f3eB7pgjO56V3VBYLS8lLQgEHWMdC6+1OSrqhof0PStSUmBQDDhG+GAUBhBFoAKIxACwCFEWgBoLA+bGVTrVZbFvZFGQnu04kxcfZAbraCBQVR8vMK4k/to2O1Dhh9cp/IfvBUMZd4XKpwT6NxqPpQ2StCwRkMJ85oAaAwAi0AFEagBYDCCLQAUBiBFgAKI9ACQGELJr0rVVRmfGxFZfuisaWJ54uLlxw/UVnRsT3ueNgXS6UlzZ/W9m4Bj/dCM1sc9jWbifUIitsoUdwGGEWc0QJAYQRaACiMQAsAhRFoAaAwAi0AFEagBYDCFkx6V2ovqRNBta16I7EPVpR6JGnZ4gvCvun6wbCv0TxS2Z6qIpZddSox/ygVzlKHSjzfWO3MsK/eiNcjrvqV2kuMKlwYPZzRAkBhBFoAKIxACwCFEWgBoDACLQAURqAFgMI6pneZ2YWSvi3pfLVysCbd/etmdrukz0ra137ore7+w9yJmC2JO4NKXE0/mhgTv4ccmX4hax4WpC3lb0aYEs8/mkcqhSv1usZqcV+zGVcEa/a80hmpXxhOc8mjrUu62d2fMLMzJT1uZg+1+77q7v9YbnoAMPg6Blp3n5I01b592Mx2SYoz/gEAs5zWNVozWyvpCkmPtZu+YGZPmdmdZnZWj+cGAENhzoHWzJZLulfSTe7+pqRvSHq/pHVqnfF+ORi31cx2mtnOHswXAAbOnAKtmY2rFWTvdvf7JMnd97h7w1t7xnxT0oaqse4+6e7r3X19ryYNAIOkY6A1M5N0h6Rd7v6Vk9pXn/SwT0h6pvfTA4DBN5esg6skfVrS02b2ZLvtVkk3mtk6tXJydkv6XJEZSloyfnZle83OD8c0EhW1js9MhX2pzRnj5KPclKW4Yllqs8p4UGJTRD8Wds0kjxXPMW8MKVyjI+fvInej04X9ezWXrINHVf3qs3NmAWCU8M0wACiMQAsAhRFoAaAwAi0AFEagBY
DCFtDmjLHjJ16rbDfFmzPWaovDvqWLV4d9jWacFtZoVqd+NZuHwzG5LJHm4sGmiMkKaLmCymlpqffvRAoaFqj5rLi2sNO0cnFGCwCFEWgBoDACLQAURqAFgMIItABQ2ILJOqhZYipWnUFQbxwMhzQb8X5i9caB+FDBsSQlPoHPLYSROFTy09fTf39MFalJZTik9iFTTuEbDKDcTIDhzCDIwRktABRGoAWAwgi0AFAYgRYACiPQAkBhBFoAKGzBpHctGjsj7Dt38W9WttcT+4LNeJzedeD482Ffs/lW2Nd7eekvUTpWep+xRHGYRDGarL3LkuazQAmwMHBGCwCFEWgBoDACLQAURqAFgMIItABQGIEWAArrmN5lZhdK+rak89XKEZp096+b2UpJ35O0VtJuSZ9097gsVgdLxt4V9v3q6E8r28dqS8Mxi2oTYd97lv122He0Gb+EwzOvVrafqO8Px+TvkRWnY8VJUHnvm6kUrprFa9z06j3UkqlkwAiay19mXdLN7n6JpI2SPm9mH5R0i6Qd7n6xpB3t+wCAU3QMtO4+5e5PtG8flrRL0gWSNkva1n7YNknXl5okAAyy0/q3ppmtlXSFpMckrXL3KakVjCWd1+vJAcAwmPNXcM1suaR7Jd3k7m+azW1XATPbKmlr3vQAYPDN6YzWzMbVCrJ3u/t97eY9Zra63b9a0t6qse4+6e7r3X19LyYMAIOmY6C11qnrHZJ2uftXTup6QNKW9u0tku7v/fQAYPCZe7pikpl9WNKPJT2tX+ft3KrWddrtkt4r6SVJN7h7Ks9JZubSWGXfh5b9cTjuuemHK9ubzSi9SPJEWlVuOlO0UWGzGVcKS6c6pdY+dWnm9NO4zKrXXZLce52CRhUujKrG41X/eu94jdbdH1X8V39tt9MCgGHHN8MAoDACLQAURqAFgMIItABQGIEWAApbMJszXmjnhH1/edEnKtsXj8VpSW9MLw77JvftDvteOvJo2BenheWmcMUpVynR5owp+SlcqXHV80/NL5V2R1oYhhVntABQGIEWAAoj0AJAYQRaACiMQAsAhXUsKtPTgyWKyqyYuDQct6i2pLL9uqUbwzHvGo/nce3q18O+Zw+9O+zb+UZ1ksY9hybDMa56PJHswjFRlkPv3zdbFTIDHry2oPiOlC7oQ9YBBl91URnOaAGgMAItABRGoAWAwgi0AFAYgRYACiPQAkBhCya967JlfxKOe81+Wdm+7+hTiWOl0pLiIjDLl6yJhwVpVUemXw3HND3e1yxdjCal+v0xtS9YXrqYVLOJeFTzSHVHIr0rxX0m0ZtKhSMtDAsF6V0A0BcEWgAojEALAIURaAGgMAItABRGoAWAwjruGWZmd0q6TtJed7+03Xa7pM9K2td+2K3u/sNuJnLJ+Nlh381rqqt3vXTkynBMM5Hxc8++N8K+Z2ceDvsajaPBsUqkcKVSlqqfM5WpZ5kpV2EKlyRZ9a+PJfZCc59OHI0ULgynufz13SVpU0X7V919Xfu/roIsAAyzjoHW3R+RtH8e5gIAQ6mba7RfMLOnzOxOMzsrepCZbTWznWa2s4tjAcDAyg2035D0fknrJE1J+nL0QHefdPf1VV9LA4BRkBVo3X2PuzfcvSnpm5I29HZaADA8sgKtma0+6e4nJD3Tm+kAwPCZS3rXdyRdLekcM3tF0m2SrjazdWrl3OyW9Lm5H7I6hedn9ZfDET/dXb0J4NUTF4Vj3re8EfbdtfFA2Pfi/mvDvn95fkVl+4NH7grHpDcjzE39Ov3nS88jlqoIFj2nK5XClXpvJ4ULw6ljoHX3Gyua7ygwFwAYSnwzDAAKI9ACQGEEWgAojEALAIURaAGgsD5szlid6PCRib8Ixz3pj1a2H6vHJRg8sQFjsxlX2zpjyXvCvlqw4eORmT3hmHojVSYitfapSlY58t5T0+ldUQpdnFpntjjxfKkUNFK/sFCk/jbrbM4IAP1AoAWAwgi0AFAYgRYACiPQAk
BhHWsd9F71p8f3/Okj4YizNh+sbD/x+Ew45thr8R5kt//gU2Hf3W9WZzhI0tF69V5jTY/nkZb69DLnPTC3SE2qGE2cQZCTGZGfWcB+YlgoTv/3jTNaACiMQAsAhRFoAaAwAi0AFEagBYDCCLQAUNiCKSrz4Yk/D8d99Jwlle1/8+nt4ZglV8YpV/ZH/xT2+Xfj7c+enPxIZfvvPvZqOObgsV1hX6r4SlpU6CWV3pV6T80dlzN/0rQWJv6/9EaDojIA0A8EWgAojEALAIURaAGgMAItABRGoAWAwjqmd5nZnZKuk7TX3S9tt62U9D1JayXtlvRJdz/Q8WCJ9K7fmfh0OO7njR9Vtk/PvBaOGRs7M+xrej3sO3fiQ2FfLUir2j/zy3DM9ImpsC8/5SpHbmWsXlcRI1VoNtKqhkt+etddkjad0naLpB3ufrGkHe37AIAKHQOtuz8i6dStXDdL2ta+vU3S9T2eFwAMjdzC36vcfUqS3H3KzM6LHmhmWyVtzTwOAAy84jssuPukpEnp7Wu0ADBacrMO9pjZaklq/9zbuykBwHDJDbQPSNrSvr1F0v29mQ4ADJ+Olw7M7DuSrpZ0jpm9Iuk2SV+StN3MPiPpJUk3dDuRq1Ysj/v08cr2F97KuxLxuD8X9u05/r9hX6N5LGh/K3G03KslOVWzUqlCUcUvySzuS2/OmDOPFFKd5o61GiQdA6273xh0XdvjuQDAUOKbYQBQGIEWAAoj0AJAYQRaACiMQAsAhfVhc8bqVKILl18TjvtA46LK9vUrF4djli+KK0itWBxX75o6Fj/ns4eq1+rBo/EmkfXGG2FfKuUqrwJWXgqXaTyehR9PzCOaY26lsBRSljAI2JwRAPqCQAsAhRFoAaAwAi0AFEagBYDCCLQAUFgf0ruq69icufQD4bi3pl+ubHc/0ZN5naxWWxb3WXXqlyc2e2w0DyWOlkrvSlXNisSpU5Y4lmcdqzXydOeRn6aV85xUuMJ8I70LAPqCQAsAhRFoAaAwAi0AFEagBYDCim83Pldrxi4N+8Ynrqxsf7N2MBxT13TYd6j+q7Dv2In9YV8kvWdYbhGV3IIz1VxxZkSvj1XmE/2c5ySzAAsDZ7QAUBiBFgAKI9ACQGEEWgAojEALAIURaAGgsK7Su8xst6TDalVAqVcVU3in6pSbg9objlil91a2/5avDceMW5xWVRu/JOw7VosLrBwK9s/6yfF4zzBP7LllidQvT6Ym5bw/ptK0covK5KSukXKF0dOLPNqPuvvrPXgeABhKXDoAgMK6DbQu6T/M7HEz21r1ADPbamY7zWxnl8cCgIHU7aWDq9z9VTM7T9JDZvasuz9y8gPcfVLSpPR24W8AGC1dndG6+6vtn3slfV/Shl5MCgCGSXagNbMzzOzMt29L+j1Jz/RqYgAwLLq5dLBK0vetlUa1SNK/ufu/5z7ZnmNPhX1Tzf8JelJpSbn7RcWVrMJ0LEstY271rhypFK7Ue2pOha7Uc+Y+X65ojblShYUhO9C6+wuSLu/hXABgKJHeBQCFEWgBoDACLQAURqAFgMIItABQ2Dxvzmiy4JCrJi4LRx1tHqhsn2nEmyI2Pd6MsNGMN25s+kzYF/I4nSmVYJSq0LVobEXY12gcCnpSmyym5hG/35qd/nO6l6gGllPNLDfNLDctLHf+GHac0QJAYQRaACiMQAsAhRFoAaAwAi0AFEagBYDC5jm9K8+q2vurOxJvE3XF6V0nLE7vaviJsK8WpDpF6WeSdODoL8K+sbF3h33jY2eEfc1m9YaP7vHrSqWSpTaJTFk6fn5l+7GZl7OeL/d9P5p/eoPLlFRKWyplLOd4850SllPpjLS1bnFGCwCFEWgBoDACLQAURqAFgMIItABQ2DxnHbg8yAaYOvJf4aipUtPpmdSnsvF72cqlQTaFpHoig6DePFbd3oiLueTuXGY2HvZN16uzLcyWhGNSmRG5ouwCs8WJQXH2QO7801kOOYV2cj/tz/
l9TM0vdT7W6337hjPDgTNaACiMQAsAhRFoAaAwAi0AFEagBYDCCLQAUFhX6V1mtknS19WqwvEtd/9ST2b1ziNljCmx71Mk7/1qmZ0V9u05sSvsG6vF6UeRerjPmLQoUdym3jgc9sWpX6nCK/Fa1WrL4mHJfdmqCwHl7l3mXl20p5PU/moe/DqmCvrkFsXJec5oL79WZ+L32+NjWW0i7Gs2o/3+Uq85Veyn1/vU5e2j516depl9Rmuto/2zpN+X9EFJN5rZB3OfDwCGVTeXDjZIet7dX3D3GUnflbS5N9MCgOHRTaC9QNLJhUdfabfNYmZbzWynme3s4lgAMLC6uUZbdXHjHRdY3H1S0qQkmdngfocOADJ1c0b7iqQLT7q/RtKr3U0HAIZPN4H2p5IuNrP3Wat6x6ckPdCbaQHA8DCP8k7mMtjsDyR9Ta28izvd/R86PH6fpBfbd8+R9Hr2wYcP6zEb6zEb6zHbQl2P33D3c09t7CrQdsPMdrr7+r4cfAFiPWZjPWZjPWYbtPXgm2EAUBiBFgAK62egnezjsRci1mM21mM21mO2gVqPvl2jBYBRwaUDACiMQAsAhfUl0JrZJjN7zsyeN7Nb+jGHfjKzO81sr5k9c1LbSjN7yMz+r/0zrqE4ZMzsQjP7TzPbZWa/MLMvtttHck3MbKmZ/cTMft5ej79rt4/kekitaoFm9jMze7B9f6DWYt4DLeUVJUl3Sdp0Ststkna4+8WSdrTvj4q6pJvd/RJJGyV9vv07MaprMi3pGne/XNI6SZvMbKNGdz0k6YuSTi7SPFBr0Y8z2pEvr+juj0jaf0rzZknb2re3Sbp+XifVR+4+5e5PtG8fVusP6gKN6Jp4y9uVscfb/7lGdD3MbI2kP5T0rZOaB2ot+hFo51RecQStcvcpqRV4JJ3X5/n0hZmtlXSFpMc0wmvS/qfyk5L2SnrI3Ud5Pb4m6a81e+uOgVqLfgTaOZVXxOgxs+WS7pV0k7u/2e/59JO7N9x9nVpV8TaY2aX9nlM/mNl1kva6++P9nks3+hFoKa9YbY+ZrZak9s+9fZ7PvLLWBmT3Srrb3e9rN4/0mkiSux+U9LBa1/RHcT2ukvRxM9ut1mXGa8zsXzVga9GPQEt5xWoPSNrSvr1F0v19nMu8MjOTdIekXe7+lZO6RnJNzOxcM1vRvj0h6WOSntUIroe7/627r3H3tWrFih+5+59pwNaiL98MO93yisPGzL4j6Wq1Sr3tkXSbpB9I2i7pvZJeknSDu5/6gdlQMrMPS/qxpKf16+twt6p1nXbk1sTMLlPrA54xtU6Gtrv735vZ2RrB9XibmV0t6a/c/bpBWwu+ggsAhfHNMAAojEALAIURaAGgMAItABRGoAWAwgi0AFAYgRYACvt/isFB2h3ATUcAAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"<Figure size 432x288 with 0 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"stft = analyze_clip(filenames_train[1])"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"# Function: Create STFT, keeping only ones of desired length\n",
"def extract_stft(in_files, in_y):\n",
"    \"\"\"Compute compressed STFT features for every .wav file in in_files.\n",
"\n",
"    Reads the notebook-level settings sample_rate, stft_n_fft, stft_hop_length,\n",
"    stft_window, stft_min_bin, stft_max_bin, shift_n_bits and stft_n_windows.\n",
"\n",
"    Args:\n",
"        in_files: Sequence of file paths; entries not ending in '.wav' are skipped\n",
"            without being counted.\n",
"        in_y: Sequence of ground-truth labels, index-aligned with in_files.\n",
"\n",
"    Returns:\n",
"        Tuple (out_x, out_y, prob_cnt): the list of 2-D feature arrays, the list\n",
"        of matching labels, and the number of .wav clips dropped because their\n",
"        STFT did not contain exactly stft_n_windows time slices.\n",
"    \"\"\"\n",
"    prob_cnt = 0\n",
"    out_x = []\n",
"    out_y = []\n",
"\n",
"    for index, path in enumerate(in_files):\n",
"\n",
"        # Check to make sure we're reading a .wav file\n",
"        if not path.endswith('.wav'):\n",
"            continue\n",
"\n",
"        # Load .wav file (the returned sample rate is not needed here)\n",
"        waveform, _fs = librosa.load(path, sr=sample_rate)\n",
"\n",
"        # Convert to something that approximates a 16-bit PCM waveform\n",
"        waveform = np.around(waveform * 32767)\n",
"\n",
"        # Calculate STFT (magnitude only)\n",
"        stft = np.abs(librosa.stft(waveform,\n",
"                                   n_fft=stft_n_fft,\n",
"                                   hop_length=stft_hop_length,\n",
"                                   win_length=stft_n_fft,\n",
"                                   window=stft_window,\n",
"                                   center=False))\n",
"\n",
"        # Adjust for quantization and scaling in 16-bit fixed point FFT\n",
"        stft = np.around(stft / stft_n_fft)\n",
"\n",
"        # Reduce precision by converting to 8-bit unsigned values [0..255]\n",
"        stft = np.around(stft / (2 ** shift_n_bits))\n",
"        stft = np.clip(stft, a_min=0, a_max=255)\n",
"\n",
"        # Only keep the frequency bins we care about\n",
"        stft = stft[stft_min_bin:stft_max_bin, :]\n",
"\n",
"        # Only keep STFTs with given length; checked before compressing so that\n",
"        # no work is spent on clips that are dropped anyway\n",
"        if stft.shape[1] != stft_n_windows:\n",
"            prob_cnt += 1\n",
"            continue\n",
"\n",
"        # Average every 2 bins together to reduce size of STFT. Grouping the rows\n",
"        # in pairs and averaging over the pair axis replaces the per-column loop;\n",
"        # the cast keeps the float64 dtype of the original zero-initialised array.\n",
"        stft_comp = stft.reshape(-1, 2, stft.shape[1]).mean(axis=1).astype(np.float64)\n",
"\n",
"        out_x.append(stft_comp)\n",
"        out_y.append(in_y[index])\n",
"\n",
"    return out_x, out_y, prob_cnt"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Extracting features from training set...\n",
"Removed percentage: 0.09633264462809918\n",
"Extracting features from validation set...\n",
"Removed percentage: 0.10046487603305786\n",
"Extracting features from test set...\n",
"Removed percentage: 0.09517045454545454\n"
]
}
],
"source": [
"# Run feature extraction on each split and report the fraction of clips dropped\n",
"print('Extracting features from training set...')\n",
"x_train, y_train, prob_cnt = extract_stft(filenames_train, y_orig_train)\n",
"removed_fraction = prob_cnt / len(y_orig_train)\n",
"print('Removed percentage:', removed_fraction)\n",
"\n",
"print('Extracting features from validation set...')\n",
"x_val, y_val, prob_cnt = extract_stft(filenames_val, y_orig_val)\n",
"removed_fraction = prob_cnt / len(y_orig_val)\n",
"print('Removed percentage:', removed_fraction)\n",
"\n",
"print('Extracting features from test set...')\n",
"x_test, y_test, prob_cnt = extract_stft(filenames_test, y_orig_test)\n",
"removed_fraction = prob_cnt / len(y_orig_test)\n",
"print('Removed percentage:', removed_fraction)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"# Turn the Python lists of features and labels into numpy arrays\n",
"x_train, x_val, x_test = (np.array(features) for features in (x_train, x_val, x_test))\n",
"y_train, y_val, y_test = (np.array(labels) for labels in (y_train, y_val, y_test))"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training samples:\t (20994, 32, 46)\n",
"Validation samples:\t (6966, 32, 46)\n",
"Test samples:\t\t (7007, 32, 46)\n",
"Training truth set:\t (20994,)\n",
"Validation truth set:\t (6966,)\n",
"Test truth set:\t\t (7007,)\n"
]
}
],
"source": [
"# Show the shape of every feature tensor and every label vector\n",
"for caption, tensor in (('Training samples:\\t', x_train),\n",
"                        ('Validation samples:\\t', x_val),\n",
"                        ('Test samples:\\t\\t', x_test),\n",
"                        ('Training truth set:\\t', y_train),\n",
"                        ('Validation truth set:\\t', y_val),\n",
"                        ('Test truth set:\\t\\t', y_test)):\n",
"    print(caption, tensor.shape)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(20994, 32, 46, 1)\n",
"(6966, 32, 46, 1)\n",
"(7007, 32, 46, 1)\n"
]
}
],
"source": [
"# The CNN expects (batch, height, width, channels), so every STFT gets a\n",
"# trailing \"color\" channel of size 1 appended to its first three dimensions\n",
"x_train = x_train.reshape(x_train.shape[:3] + (1,))\n",
"x_val = x_val.reshape(x_val.shape[:3] + (1,))\n",
"x_test = x_test.reshape(x_test.shape[:3] + (1,))\n",
"for tensor in (x_train, x_val, x_test):\n",
"    print(tensor.shape)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(32, 46, 1)\n"
]
}
],
"source": [
"# The CNN's input shape is the shape of a single sample, i.e. the tensor\n",
"# shape with the leading batch dimension dropped\n",
"sample_shape = tuple(x_test.shape[1:])\n",
"print(sample_shape)"
]
},
{
"cell_type": "code",
"execution_count": 66,
"metadata": {},
"outputs": [],
"source": [
"# Build model\n",
"# Based on: https://www.geeksforgeeks.org/python-image-classification-using-keras/\n",
"model = models.Sequential()\n",
"\n",
"# Feature extractor: four identical Conv2D -> MaxPooling2D -> Dropout stages.\n",
"# Only the first stage declares the input shape.\n",
"for stage in range(4):\n",
"    conv_kwargs = {'input_shape': sample_shape} if stage == 0 else {}\n",
"    model.add(layers.Conv2D(16, (2, 2), activation='relu', **conv_kwargs))\n",
"    model.add(layers.MaxPooling2D(pool_size=(2, 2)))\n",
"    model.add(layers.Dropout(0.1))\n",
"\n",
"# Classifier: small L1-regularized dense layer followed by a sigmoid output\n",
"model.add(layers.Flatten())\n",
"model.add(layers.Dense(16, activation='relu', kernel_regularizer=regularizers.l1(0.01)))\n",
"model.add(layers.Dropout(0.5))\n",
"model.add(layers.Dense(1, activation='sigmoid'))"
]
},
{
"cell_type": "code",
"execution_count": 67,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"sequential_2\"\n",
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"conv2d_8 (Conv2D) (None, 31, 45, 16) 80 \n",
"_________________________________________________________________\n",
"max_pooling2d_8 (MaxPooling2 (None, 15, 22, 16) 0 \n",
"_________________________________________________________________\n",
"dropout_6 (Dropout) (None, 15, 22, 16) 0 \n",
"_________________________________________________________________\n",
"conv2d_9 (Conv2D) (None, 14, 21, 16) 1040 \n",
"_________________________________________________________________\n",
"max_pooling2d_9 (MaxPooling2 (None, 7, 10, 16) 0 \n",
"_________________________________________________________________\n",
"dropout_7 (Dropout) (None, 7, 10, 16) 0 \n",
"_________________________________________________________________\n",
"conv2d_10 (Conv2D) (None, 6, 9, 16) 1040 \n",
"_________________________________________________________________\n",
"max_pooling2d_10 (MaxPooling (None, 3, 4, 16) 0 \n",
"_________________________________________________________________\n",
"dropout_8 (Dropout) (None, 3, 4, 16) 0 \n",
"_________________________________________________________________\n",
"conv2d_11 (Conv2D) (None, 2, 3, 16) 1040 \n",
"_________________________________________________________________\n",
"max_pooling2d_11 (MaxPooling (None, 1, 1, 16) 0 \n",
"_________________________________________________________________\n",
"dropout_9 (Dropout) (None, 1, 1, 16) 0 \n",
"_________________________________________________________________\n",
"flatten_2 (Flatten) (None, 16) 0 \n",
"_________________________________________________________________\n",
"dense_4 (Dense) (None, 16) 272 \n",
"_________________________________________________________________\n",
"dropout_10 (Dropout) (None, 16) 0 \n",
"_________________________________________________________________\n",
"dense_5 (Dense) (None, 1) 17 \n",
"=================================================================\n",
"Total params: 3,489\n",
"Trainable params: 3,489\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"# Print the layer-by-layer summary (output shapes and parameter counts)\n",
"model.summary()"
]
},
{
"cell_type": "code",
"execution_count": 68,
"metadata": {},
"outputs": [],
"source": [
"# Quantization aware training (Does not seem to be supported on this version)\n",
"#sess = tf.compat.v1.keras.backend.get_session()\n",
"#tf.contrib.quantize.create_training_graph(sess.graph)\n",
"#sess.run(tf.global_variables_initializer())"
]
},
{
"cell_type": "code",
"execution_count": 69,
"metadata": {},
"outputs": [],
"source": [
"# Configure training: RMSprop optimizer, binary cross-entropy loss, accuracy metric\n",
"rmsprop = optimizers.RMSprop(learning_rate=0.001)\n",
"compile_options = {\n",
"    'loss': 'binary_crossentropy',\n",
"    'optimizer': rmsprop,\n",
"    'metrics': ['acc'],\n",
"}\n",
"model.compile(**compile_options)"
]
},
{
"cell_type": "code",
"execution_count": 70,
"metadata": {},
"outputs": [],
"source": [
"# Train silently, validating against the held-out validation set after each epoch\n",
"fit_options = dict(epochs=100,\n",
"                   batch_size=100,\n",
"                   validation_data=(x_val, y_val),\n",
"                   verbose=0)\n",
"history = model.fit(x_train, y_train, **fit_options)"
]
},
{
"cell_type": "code",
"execution_count": 71,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Plot results (matplotlib is already imported as plt in the first cell)\n",
"acc = history.history['acc']\n",
"val_acc = history.history['val_acc']\n",
"loss = history.history['loss']\n",
"val_loss = history.history['val_loss']\n",
"\n",
"# Epochs are numbered from 1 on the x axis\n",
"epochs = range(1, len(acc) + 1)\n",
"\n",
"plt.plot(epochs, acc, 'bo', label='Training acc')\n",
"plt.plot(epochs, val_acc, 'b', label='Validation acc')\n",
"plt.title('Training and validation accuracy')\n",
"plt.xlabel('Epoch')\n",
"plt.ylabel('Accuracy')\n",
"plt.legend()\n",
"\n",
"# Start a second figure so the loss curves do not share axes with accuracy\n",
"plt.figure()\n",
"\n",
"plt.plot(epochs, loss, 'bo', label='Training loss')\n",
"plt.plot(epochs, val_loss, 'b', label='Validation loss')\n",
"plt.title('Training and validation loss')\n",
"plt.xlabel('Epoch')\n",
"plt.ylabel('Loss')\n",
"plt.legend()\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 72,
"metadata": {},
"outputs": [],
"source": [
"# Persist the trained Keras model to disk under `model_filename`\n",
"model.save(model_filename)"
]
},
{
"cell_type": "code",
"execution_count": 73,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"8344"
]
},
"execution_count": 73,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Convert the model to a TF Lite model file with uint8 quantization\n",
"# NOTE(review): no representative dataset is supplied, so this presumably does not\n",
"# yield a fully uint8-quantized model -- confirm against the TFLite converter docs\n",
"converter = lite.TFLiteConverter.from_keras_model(model)\n",
"converter.optimizations = [lite.Optimize.DEFAULT]\n",
"converter.target_spec.supported_types = [tf.uint8]\n",
"tflite_model = converter.convert()\n",
"\n",
"# Write through a context manager so the file is flushed and closed right away\n",
"# instead of relying on garbage collection of an anonymous file object\n",
"with open(tflite_model_filename, 'wb') as tflite_file:\n",
"    tflite_file.write(tflite_model)"
]
},
{
"cell_type": "code",
"execution_count": 74,
"metadata": {},
"outputs": [],
"source": [
"# Function: Convert some hex value into an array for C programming\n",
"def hex_to_c_array(hex_data, var_name):\n",
"    \"\"\"Render a byte sequence as the text of a C header defining a byte array.\n",
"\n",
"    The header is wrapped in an include guard named after var_name (upper-cased,\n",
"    with an '_H' suffix) and defines 'unsigned int <var_name>_len' followed by\n",
"    'unsigned char <var_name>[]'.\n",
"\n",
"    Args:\n",
"        hex_data: Sequence of integer byte values (e.g. a bytes object).\n",
"        var_name: Name to use for the generated C variables.\n",
"\n",
"    Returns:\n",
"        The complete header contents as a single string.\n",
"    \"\"\"\n",
"    guard = var_name.upper() + '_H'\n",
"    total = len(hex_data)\n",
"\n",
"    # One token per byte, formatted as 0x?? with a separating comma after every\n",
"    # byte but the last and a line break after every 12th byte\n",
"    tokens = []\n",
"    for position, value in enumerate(hex_data, start=1):\n",
"        token = format(value, '#04x')\n",
"        if position < total:\n",
"            token += ','\n",
"        if position % 12 == 0:\n",
"            token += '\\n '\n",
"        tokens.append(token)\n",
"\n",
"    pieces = [\n",
"        # Header guard\n",
"        '#ifndef ' + guard + '\\n',\n",
"        '#define ' + guard + '\\n\\n',\n",
"        # Array length, declared before the array itself\n",
"        '\\nunsigned int ' + var_name + '_len = ' + str(total) + ';\\n',\n",
"        # Array declaration, contents and closing brace\n",
"        'unsigned char ' + var_name + '[] = {',\n",
"        '\\n ' + ' '.join(tokens) + '\\n};\\n\\n',\n",
"        # Close out header guard\n",
"        '#endif //' + guard,\n",
"    ]\n",
"    return ''.join(pieces)"
]
},
{
"cell_type": "code",
"execution_count": 75,
"metadata": {},
"outputs": [],
"source": [
"# Render the TFLite model as C header text and write it to '<c_model_name>.h'\n",
"header_text = hex_to_c_array(tflite_model, c_model_name)\n",
"with open(c_model_name + '.h', 'w') as header_file:\n",
"    header_file.write(header_text)"
]
},
{
"cell_type": "code",
"execution_count": 76,
"metadata": {},
"outputs": [],
"source": [
"#print(hex_to_c_array(tflite_model, c_model_name))"
]
},
{
"cell_type": "code",
"execution_count": 77,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.1499670784922049, 0.94976455]\n"
]
}
],
"source": [
"# Evaluate model with test set; the result is [loss, accuracy] because the model\n",
"# was compiled with a single 'acc' metric. The name avoids shadowing builtin eval().\n",
"test_metrics = model.evaluate(x=x_test, y=y_test, verbose=0)\n",
"print(test_metrics)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment