Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save cosmincatalin/f8732e53ccd379d2559e1c23537ab337 to your computer and use it in GitHub Desktop.
Save cosmincatalin/f8732e53ccd379d2559e1c23537ab337 to your computer and use it in GitHub Desktop.
Voice Recognition SageMaker Notebook
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%time\n",
"%%capture\n",
"!pip install pydub > /dev/null"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%time\n",
"# Standard library\n",
"import base64\n",
"import json\n",
"import random\n",
"import tarfile\n",
"import wave\n",
"from contextlib import closing\n",
"from os import listdir, makedirs\n",
"from os.path import isfile, join\n",
"from pickle import dump\n",
"from shutil import copy2, rmtree\n",
"from tempfile import gettempdir\n",
"from urllib.request import urlretrieve\n",
"\n",
"# Third party\n",
"import boto3\n",
"import cv2\n",
"import matplotlib\n",
"# Select the non-interactive Agg backend before pyplot is imported.\n",
"matplotlib.use(\"agg\")\n",
"import matplotlib.pyplot as plt\n",
"import mxnet as mx\n",
"import numpy as np\n",
"import pandas as pd\n",
"import sagemaker\n",
"from pydub import AudioSegment\n",
"from sagemaker.mxnet import MXNet"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Download the Cornell movie-review sentence corpus; exist_ok lets the cell be re-run.\n",
"makedirs(\"data/sentences\", exist_ok=True)\n",
"\n",
"urlretrieve(\"http://www.cs.cornell.edu/people/pabo/movie-review-data/rotten_imdb.tar.gz\",\n",
"            \"data/sentences/sentences.tar.gz\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# The context manager closes the archive even if extraction fails.\n",
"# NOTE: extractall trusts the member paths inside the downloaded archive.\n",
"with tarfile.open(\"data/sentences/sentences.tar.gz\") as tar:\n",
"    tar.extractall(\"data/sentences\")\n",
"\n",
"# Read both corpus files as ISO-8859-1 and keep the first 5000 lines of each.\n",
"with open(\"data/sentences/plot.tok.gt9.5000\", \"r\", encoding=\"ISO-8859-1\") as first_file:\n",
"    first_sentences = first_file.read().split(\"\\n\")[0:5000]\n",
"with open(\"data/sentences/quote.tok.gt9.5000\", \"r\", encoding=\"ISO-8859-1\") as second_file:\n",
"    second_sentences = second_file.read().split(\"\\n\")[0:5000]\n",
"\n",
"# Write with the same encoding the next cell reads with; relying on the locale\n",
"# default (e.g. UTF-8 or ASCII) would garble or reject non-ASCII characters.\n",
"with open(\"data/sentences/sentences.txt\", \"w\", encoding=\"ISO-8859-1\") as sentences_file:\n",
"    for sentence in first_sentences + second_sentences:\n",
"        sentences_file.write(\"{}\\n\".format(sentence))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"with open(\"data/sentences/sentences.txt\", \"r\", encoding=\"ISO-8859-1\") as sentences_file:\n",
"    # Drop the empty string produced by the trailing newline.\n",
"    sentences = sentences_file.read().split(\"\\n\")[:-1]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A voice's position in this list is used as its numeric class label.\n",
"voices = [\"Ivy\", \"Joanna\", \"Joey\", \"Justin\", \"Kendra\", \"Kimberly\", \"Matthew\", \"Salli\"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# NOTE: no cell in this notebook creates data/mp3; it must already be populated.\n",
"mp3_files = sorted([f for f in listdir(\"data/mp3\") if isfile(join(\"data/mp3\", f))])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"makedirs(\"data/wav\", exist_ok=True)\n",
"\n",
"# One random 2000 ms window shared by every recording (pydub slices in milliseconds).\n",
"sample_start = random.randint(500, 1000)\n",
"sample_finish = sample_start + 2000\n",
"\n",
"for mp3 in mp3_files:\n",
"    sound = AudioSegment.from_mp3(\"data/mp3/{}\".format(mp3))[sample_start:sample_finish]\n",
"    # mp3[:-3] keeps the trailing dot, so \"x.mp3\" is exported as \"x.wav\".\n",
"    sound.export(\"data/wav/{}wav\".format(mp3[:-3]), format=\"wav\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"wav_files = sorted([f for f in listdir(\"data/wav/\") if isfile(join(\"data/wav/\", f))])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def graph_spectrogram(wav_file, out):\n",
"    \"\"\"Save a borderless 1.4x1.4 inch spectrogram of `wav_file` to `out` as PNG.\n",
"\n",
"    Frames are decoded as int16 samples (assumes 16-bit mono audio -- TODO confirm).\n",
"    \"\"\"\n",
"    # closing() releases the file handle even if reading the frames fails.\n",
"    with closing(wave.open(wav_file, \"r\")) as wav:\n",
"        frames = wav.readframes(-1)\n",
"        frame_rate = wav.getframerate()\n",
"    sound_info = np.frombuffer(frames, \"int16\")\n",
"    fig = plt.figure()\n",
"    try:\n",
"        fig.set_size_inches((1.4, 1.4))\n",
"        # Axes covering the whole figure, with ticks and frame hidden.\n",
"        ax = plt.Axes(fig, [0., 0., 1., 1.])\n",
"        ax.set_axis_off()\n",
"        fig.add_axes(ax)\n",
"        plt.set_cmap(\"hot\")\n",
"        plt.specgram(sound_info, Fs=frame_rate)\n",
"        plt.savefig(out, format=\"png\")\n",
"    finally:\n",
"        # Always release the figure so a long loop does not accumulate open figures.\n",
"        plt.close(fig)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"makedirs(\"data/spectrograms\", exist_ok=True)\n",
"\n",
"for wav in wav_files:\n",
"    # wav[:-3] keeps the trailing dot, so \"x.wav\" is rendered as \"x.png\".\n",
"    graph_spectrogram(\"data/wav/{}\".format(wav), \"data/spectrograms/{}png\".format(wav[:-3]))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"spectrograms = sorted([join(\"data/spectrograms/\", f) for f in listdir(\"data/spectrograms/\") if isfile(join(\"data/spectrograms/\", f))])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Rows line up only if the three sorted listings share base names (assumed: each\n",
"# wav/png was derived from the mp3 of the same name and nothing else is in those dirs).\n",
"df = pd.DataFrame({\n",
"    \"wav\": [join(\"data/wav/\", f) for f in wav_files],\n",
"    \"mp3\": [join(\"data/mp3/\", f) for f in mp3_files],\n",
"    \"spectrogram\": spectrograms\n",
"})\n",
"# The voice name is the last dash-separated part of sample-<n>-<Voice>.png; extract it once.\n",
"voice_names = df.spectrogram.str.extract(r\"sample-\\d+-(\\w+)\\.png\", expand=False)\n",
"df[\"label\"] = voice_names.apply(voices.index)\n",
"df[\"voice\"] = voice_names"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Per-voice split: ~80% train, half of the remainder validation, the rest test.\n",
"split_seed = None  # set to an int for a reproducible split\n",
"train = df.groupby(\"voice\").apply(lambda x: x.sample(frac=.8, random_state=split_seed)).reset_index(0, drop=True)\n",
"remaining = df.loc[~df.index.isin(train.index), :]\n",
"validation = remaining.groupby(\"voice\").apply(lambda x: x.sample(frac=.5, random_state=split_seed)).reset_index(0, drop=True)\n",
"# Test set: every row that is in neither the train nor the validation split.\n",
"test = remaining.loc[~remaining.index.isin(validation.index), :]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def transform(row):\n",
"    \"\"\"Turn a dataframe row into an (image, label) pair.\n",
"\n",
"    The spectrogram PNG is read with OpenCV, converted to a float32 NDArray,\n",
"    transposed from HWC to CHW, scaled by 1/255, and paired with the float32 label.\n",
"    \"\"\"\n",
"    img = cv2.imread(row[\"spectrogram\"])\n",
"    # cv2.imread reports a missing or unreadable file by returning None.\n",
"    if img is None:\n",
"        raise IOError(\"Could not read spectrogram image {}\".format(row[\"spectrogram\"]))\n",
"    img = mx.nd.array(img).astype(np.float32)\n",
"    img = mx.nd.transpose(img, (2, 0, 1))\n",
"    img = img / 255\n",
"    label = np.float32(row[\"label\"])\n",
"    return img, label"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train_nd = [transform(row) for _, row in train.iterrows()]\n",
"validation_nd = [transform(row) for _, row in validation.iterrows()]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def save_to_disk(data, type):\n",
"    \"\"\"Pickle `data` to <tmpdir>/pvdwgmas/data/pickles/<type>/data.p.\n",
"\n",
"    `type` names the split (e.g. \"train\"). It shadows the builtin but is kept for\n",
"    existing callers. The directory is created if needed so the cell can be re-run.\n",
"    \"\"\"\n",
"    directory = join(gettempdir(), \"pvdwgmas\", \"data\", \"pickles\", type)\n",
"    makedirs(directory, exist_ok=True)\n",
"    with open(join(directory, \"data.p\"), \"wb\") as out:\n",
"        dump(data, out)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"save_to_disk(train_nd, \"train\")\n",
"save_to_disk(validation_nd, \"validation\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sagemaker_session = sagemaker.Session()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Upload the pickled splits; `inputs` holds the S3 location returned by upload_data.\n",
"inputs = sagemaker_session.upload_data(path=\"{}/pvdwgmas/data/pickles\".format(gettempdir()),\n",
"                                       bucket=\"redacted\", key_prefix=\"cosmin/sagemaker/demo\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Copy each test recording's mp3 into data/test/<voice>/ for endpoint evaluation.\n",
"makedirs(\"data/test\", exist_ok=True)\n",
"for _, row in test.iterrows():\n",
"    voice_dir = \"data/test/{}\".format(row[\"voice\"])\n",
"    makedirs(voice_dir, exist_ok=True)\n",
"    copy2(row[\"mp3\"], voice_dir)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# The entry-point training script is a separate file, not part of this notebook.\n",
"estimator = MXNet(\n",
"    entry_point=\"voice-recognition-sagemaker-script.py\",\n",
"    role=sagemaker.get_execution_role(),\n",
"    train_instance_count=1,\n",
"    train_instance_type=\"ml.p2.xlarge\",\n",
"    hyperparameters={\"epochs\": 5},\n",
"    py_version=\"py3\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"estimator.fit(inputs=inputs)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"predictor = estimator.deploy(initial_instance_count=1, instance_type=\"ml.m4.xlarge\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Low-level runtime client used to invoke the deployed endpoint directly.\n",
"sagemaker_runtime_client = boto3.client(service_name=\"sagemaker-runtime\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Use this cell only if you have downloaded the `Kimberly recites some shameless self promotion ad.mp3` file from\n",
"[https://raw.githubusercontent.com/cosmincatalin/voice-recognition-with-mxnet-and-sagemaker/master/Kimberly%20recites%20some%20shameless%20self%20promotion%20ad.mp3](https://raw.githubusercontent.com/cosmincatalin/voice-recognition-with-mxnet-and-sagemaker/master/Kimberly%20recites%20some%20shameless%20self%20promotion%20ad.mp3)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Uncomment to classify the single downloaded recording.\n",
"# with open(\"Kimberly recites some shameless self promotion ad.mp3\", \"rb\") as audio_file:\n",
"#     payload = base64.b64encode(audio_file.read()).decode(\"utf-8\")\n",
"#     response = sagemaker_runtime_client.invoke_endpoint(\n",
"#         EndpointName=predictor.endpoint,\n",
"#         Body=payload,\n",
"#         ContentType=\"audio/mp3\",\n",
"#         Accept=\"application/json\"\n",
"#     )[\"Body\"].read()\n",
"#     # json.loads takes bytes directly; its `encoding` argument was removed in Python 3.9.\n",
"#     print(\"Kimberly predicted as {}\".format(json.loads(response)))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Evaluate the endpoint per voice directory, sending recordings in batches.\n",
"batch_size = 5\n",
"\n",
"def predict_batch(encoded_recordings):\n",
"    \"\"\"Send base64-encoded mp3 recordings as a JSON list and return the decoded predictions.\"\"\"\n",
"    response = sagemaker_runtime_client.invoke_endpoint(\n",
"        EndpointName=predictor.endpoint,\n",
"        Body=json.dumps(encoded_recordings).encode(\"utf-8\"),\n",
"        ContentType=\"application/json\",\n",
"        Accept=\"application/json\"\n",
"    )[\"Body\"].read()\n",
"    # json.loads takes bytes directly; its `encoding` argument was removed in Python 3.9.\n",
"    return json.loads(response)\n",
"\n",
"for directory in listdir(\"data/test\"):\n",
"    file_names = listdir(\"data/test/{}\".format(directory))\n",
"    total = 0\n",
"    detected = 0\n",
"    # Step through the files in slices so a final partial batch is still evaluated.\n",
"    for start in range(0, len(file_names), batch_size):\n",
"        batch = []\n",
"        for file_name in file_names[start:start + batch_size]:\n",
"            with open(\"data/test/{}/{}\".format(directory, file_name), \"rb\") as audio_file:\n",
"                batch.append(base64.b64encode(audio_file.read()).decode(\"utf-8\"))\n",
"        for prediction in predict_batch(batch):\n",
"            total += 1\n",
"            if prediction == directory:\n",
"                detected += 1\n",
"    # Avoid ZeroDivisionError when a directory produced no predictions.\n",
"    accuracy = \"{:0.2f}\".format(detected / total) if total else \"n/a\"\n",
"    report = (\"Recordings with {}:\\n\"\n",
"              \"    Total: {}\\n\"\n",
"              \"    Detected: {}\\n\"\n",
"              \"    Accuracy: {}\\n\")\n",
"    print(report.format(directory, total, detected, accuracy))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "conda_mxnet_p36",
"language": "python",
"name": "conda_mxnet_p36"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment