Skip to content

Instantly share code, notes, and snippets.

@dhpollack
Created September 15, 2017 07:53
Show Gist options
  • Save dhpollack/f33ff65fa0264d1a332f36853e6f07a4 to your computer and use it in GitHub Desktop.
Save dhpollack/f33ff65fa0264d1a332f36853e6f07a4 to your computer and use it in GitHub Desktop.
A minimal example of a variable length dataset for pytorch
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import torch\n",
"import torch.utils.data as data\n",
"import torchaudio\n",
"import librosa\n",
"import numpy as np\n",
"import random\n",
"import os\n",
"import glob"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
" class VariableLengthDataset(data.Dataset):\n",
" def __init__(self, manifest, snippet_length=24000, get_sequentially=False, ret_np=False, use_librosa=False):\n",
" self.manifest = manifest\n",
" self.snippet_length = snippet_length\n",
" self.get_sequentially = get_sequentially\n",
" self.use_librosa = use_librosa\n",
" self.ret_np = ret_np\n",
" self.acc = 0\n",
" self.snippet_counter = 0\n",
" self.audio_idx = 0\n",
" self.st = 0\n",
" self.data = {}\n",
" def __getitem__(self, index):\n",
" # load audio data from file or cache\n",
" if self.snippet_counter == 0:\n",
" self.audio_idx = index - self.acc\n",
" apath = self.manifest[self.audio_idx]\n",
" if apath not in self.data:\n",
" if self.use_librosa:\n",
" sig, sr = librosa.core.load(apath, sr=None)\n",
" sig = torch.from_numpy(sig).unsqueeze(1).float()\n",
" else:\n",
" sig, sr = torchaudio.load(apath, normalization=True)\n",
" self.data[apath] = (sig, sr)\n",
" else:\n",
" sig, sr = self.data[apath]\n",
"\n",
" # increase iterations based on length of audio\n",
" num_snippets = int(sig.size(0) // self.snippet_length)\n",
" self.acc += max(num_snippets-1,0)\n",
" else:\n",
" apath = self.manifest[self.audio_idx]\n",
" sig, sr = self.data[apath]\n",
" num_snippets = int(sig.size(0) // self.snippet_length)\n",
"\n",
" # create snippet\n",
" if self.get_sequentially:\n",
" self.st += self.snippet_length\n",
" else:\n",
" self.st = random.randrange(int(sig.size(0)-self.snippet_length))\n",
" ret_sig = sig[self.st:(self.st+self.snippet_length)]\n",
" if self.ret_np:\n",
" ret_sig = ret_sig.numpy()\n",
"\n",
" # update counter for current audio file\n",
" self.snippet_counter += 1\n",
"\n",
" # label creation\n",
" spkr = os.path.dirname(apath).rsplit(\"/\", 1)[-1]\n",
"\n",
" # check for reset\n",
" if self.snippet_counter >= num_snippets:\n",
" self.snippet_counter = 0\n",
" self.st = 0\n",
"\n",
" return ret_sig, spkr\n",
"\n",
" def __len__(self):\n",
" return len(self.data) + self.acc\n",
"\n",
" def reset_acc(self):\n",
" self.acc = 0\n",
"\n",
" datadir = \"pcsnpny-20150204-mkj\"\n",
" audio_manifest = [a for a in glob.glob(datadir+\"/**/*.wav\", recursive=True)]\n",
" def run_dataset():\n",
" for epoch in range(1):\n",
" print(epoch)\n",
" all_data = [(x, label) for x, label in ds]\n",
" ds.reset_acc()\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0\n",
"CPU times: user 19.5 ms, sys: 3.57 ms, total: 23 ms\n",
"Wall time: 16.2 ms\n",
"0\n",
"CPU times: user 72.1 ms, sys: 22 µs, total: 72.1 ms\n",
"Wall time: 36 ms\n"
]
}
],
"source": [
"ds = VariableLengthDataset(audio_manifest, 12000, get_sequentially=True, ret_np=False, use_librosa=False)\n",
"%time run_dataset()\n",
"ds = VariableLengthDataset(audio_manifest, 12000, get_sequentially=True, ret_np=False, use_librosa=True)\n",
"%time run_dataset()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [conda root]",
"language": "python",
"name": "conda-root-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.1"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment