Skip to content

Instantly share code, notes, and snippets.

@keunwoochoi
Created May 27, 2019 18:26
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save keunwoochoi/3c4ae16b4c5b98e75ce92776d45c161e to your computer and use it in GitHub Desktop.
class NSynthDataset(data.Dataset):
    """PyTorch dataset for the NSynth dataset.

    Args:
        splitset_path: root dir containing examples.json and an ``audio``
            directory with wav files.
        audio_transform (callable, optional): A function/transform that takes in
            a sample and returns a transformed version.
        target_transform (callable, optional): A function/transform that
            takes in the target and transforms it.
        excluding_filters (list of callable, optional): List of functions, each
            of which takes two args (key, value) and returns bool if this item
            should be excluded or not.
        categorical_field_list: list of string. Each string is a key e.g.,
            'instrument_family' that will be used as a classification target.
            Each field value will be encoding as an integer using sklearn
            LabelEncoder.
        audio_dtype: numpy datatype for audio

    There are 107 pitches (9 to 119), but we'll be using 61 pitches [24-84]
    only, following GANSynth paper.
    """

    def __init__(
        self,
        splitset_path: str,
        audio_transform=None,
        target_transform=None,
        excluding_filters=None,
        categorical_field_list=None,
        audio_dtype=config.DType.NUMPY,
    ):
        """Constructor"""
        # Use None + in-body defaults to avoid the mutable-default-argument trap.
        if categorical_field_list is None:
            categorical_field_list = []
        if excluding_filters is None:
            excluding_filters = []
        assert isinstance(categorical_field_list, list)
        assert isinstance(excluding_filters, list)
        self.splitset_path = splitset_path
        self.excluding_filters = excluding_filters
        self.wavfile_paths = glob.glob(os.path.join(splitset_path, 'audio/*.wav'))
        with open(os.path.join(splitset_path, 'examples.json'), 'r') as f:
            self.metadata = json.load(f)
        self._refine_metadata()
        self._filter_out()
        self.categorical_field_list = categorical_field_list
        self.audio_transform = audio_transform
        self.target_transform = target_transform
        # Max int16 value; used to scale raw PCM samples to [-1.0, 1.0].
        self.int16max = np.array(np.iinfo(np.int16).max)
        self.audio_dtype = audio_dtype

    @staticmethod
    def _key_from_path(path):
        """Return the metadata key for a wav path: basename without extension.

        FIX: the previous code used ``.rstrip('.wav')``, which strips any
        trailing run of the *characters* '.', 'w', 'a', 'v' rather than the
        '.wav' suffix, corrupting keys that end in those letters and silently
        dropping their items. ``splitext(basename(...))`` removes exactly the
        extension and is also portable across path separators.
        """
        return os.path.splitext(os.path.basename(path))[0]

    def _refine_metadata(self):
        """Drop metadata entries whose audio file is not present on disk."""
        audio_avail_keys = {self._key_from_path(f) for f in self.wavfile_paths}
        n_audio = len(audio_avail_keys)
        n_metadata = len(self.metadata)
        if n_audio != n_metadata:
            print(
                'The number of audio file (%d) != number of metadata (%d)'
                % (n_audio, n_metadata),
                '. The audio-unavailable metadata are removed.',
            )
        # Collect first, then delete: never mutate a dict while iterating it.
        keys_to_del = [key for key in self.metadata if key not in audio_avail_keys]
        for key in keys_to_del:
            del self.metadata[key]

    def _filter_out(self):
        """Remove items for which any ``excluding_filter`` returns True.

        This method is called after init to remove some items that are not
        gonna be used. Prunes both ``self.wavfile_paths`` and
        ``self.metadata`` so they stay in sync.
        """
        keys_to_remove = set()
        for excluding_filter in self.excluding_filters:
            for key_read, value in self.metadata.items():
                if excluding_filter(key_read, value) is True:
                    keys_to_remove.add(key_read)
        self.wavfile_paths = [
            f
            for f in self.wavfile_paths
            if self._key_from_path(f) not in keys_to_remove
        ]
        self.metadata = {
            key: value
            for key, value in self.metadata.items()
            if key not in keys_to_remove
        }

    def __len__(self):
        """Number of items = number of available wav files."""
        return len(self.wavfile_paths)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns: tuple (audio_data, *categorical_targets)
            audio_data (torch.Tensor or its tuple):
                self.audio_transform(waveform)
            - remember that `*categorical_targets` is _always_ a list.
              e.g., if only pitch is returned,
              - `*categorical_targets[0].shape == torch.Size([batch_size, n_pitches=61])`
        """
        name = self.wavfile_paths[index]
        _, waveform = scipy.io.wavfile.read(name)  # (64000, )
        # Normalize int16 PCM to float in [-1, 1].
        waveform = waveform.astype(self.audio_dtype) / self.int16max
        waveform = waveform.reshape(1, -1)  # (1, 64000)
        waveform = torch.from_numpy(waveform)
        target = self.metadata[self._key_from_path(name)]
        if self.audio_transform is not None:
            audio_data = self.audio_transform(waveform)
        else:
            audio_data = waveform
        if self.target_transform is not None:
            target = self.target_transform(target)
        return [audio_data, target]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment