-
-
Save keunwoochoi/3c4ae16b4c5b98e75ce92776d45c161e to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class NSynthDataset(data.Dataset):
    """PyTorch dataset for the NSynth dataset.

    Args:
        splitset_path: root dir containing examples.json and an ``audio``
            directory with wav files.
        audio_transform (callable, optional): A function/transform that takes in
            a sample and returns a transformed version.
        target_transform (callable, optional): A function/transform that
            takes in the target and transforms it.
        excluding_filters (list of callables, optional): List of functions, each
            of which takes two args (key, value) and returns bool indicating
            whether this item should be excluded.
        categorical_field_list: list of string. Each string is a key e.g.,
            'instrument_family' that will be used as a classification target.
            Each field value will be encoded as an integer using sklearn
            LabelEncoder.
        audio_dtype: numpy datatype for audio.

    There are 107 pitches (9 to 119), but we'll be using 61 pitches [24-84] only,
    following GANSynth paper.
    """

    def __init__(
        self,
        splitset_path: str,
        audio_transform=None,
        target_transform=None,
        excluding_filters=None,
        categorical_field_list=None,
        audio_dtype=config.DType.NUMPY,
    ):
        """Constructor"""
        # Fresh lists per instance; mutable defaults would be shared.
        if categorical_field_list is None:
            categorical_field_list = []
        if excluding_filters is None:
            excluding_filters = []
        assert isinstance(categorical_field_list, list)
        assert isinstance(excluding_filters, list)
        self.splitset_path = splitset_path
        self.excluding_filters = excluding_filters
        self.wavfile_paths = glob.glob(
            os.path.join(splitset_path, 'audio', '*.wav')
        )
        with open(os.path.join(splitset_path, 'examples.json'), 'r') as f:
            self.metadata = json.load(f)
        self._refine_metadata()
        self._filter_out()
        self.categorical_field_list = categorical_field_list
        self.audio_transform = audio_transform
        self.target_transform = target_transform
        # int16 full-scale value, used to normalize PCM wav data to [-1, 1].
        self.int16max = np.array(np.iinfo(np.int16).max)
        self.audio_dtype = audio_dtype

    @staticmethod
    def _key_of(path):
        """Return the metadata key for a wav path: basename without extension.

        NOTE: the original used ``path.split('/')[-1].rstrip('.wav')``, which
        is both non-portable and buggy — ``rstrip`` strips any trailing '.',
        'w', 'a', 'v' *characters*, corrupting keys that end in those letters.
        ``os.path.splitext(os.path.basename(...))`` is the correct idiom (and
        is what ``__getitem__`` already used).
        """
        return os.path.splitext(os.path.basename(path))[0]

    def _refine_metadata(self):
        """Drop metadata entries whose wav file is not present on disk."""
        audio_avail_keys = {self._key_of(f) for f in self.wavfile_paths}
        n_audio = len(audio_avail_keys)
        n_metadata = len(self.metadata)
        if n_audio != n_metadata:
            print(
                'The number of audio file (%d) != number of metadata (%d)'
                % (n_audio, n_metadata),
                '. The audio-unavailable metadata are removed.',
            )
            # Collect first, then delete — never mutate a dict while iterating.
            keys_to_del = [
                key for key in self.metadata if key not in audio_avail_keys
            ]
            for key in keys_to_del:
                del self.metadata[key]

    def _filter_out(self):
        """Remove items rejected by any of ``self.excluding_filters``.

        Called after init; keeps ``wavfile_paths`` and ``metadata`` in sync by
        removing the same set of keys from both.
        """
        keys_to_remove = set()
        for excluding_filter in self.excluding_filters:
            for key_read, value in self.metadata.items():
                if excluding_filter(key_read, value):
                    keys_to_remove.add(key_read)
        self.wavfile_paths = [
            f for f in self.wavfile_paths if self._key_of(f) not in keys_to_remove
        ]
        self.metadata = {
            key: value
            for key, value in self.metadata.items()
            if key not in keys_to_remove
        }

    def __len__(self):
        """Number of available (non-excluded) wav files."""
        return len(self.wavfile_paths)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns: list [audio_data, target]
            audio_data (torch.Tensor or its tuple):
                self.audio_transform(waveform); the raw waveform (1, 64000)
                tensor normalized to [-1, 1] when no transform is given.
            target: metadata dict for this item, or its transformed version
                when ``target_transform`` is given.
        """
        name = self.wavfile_paths[index]
        # NSynth clips are 4 s @ 16 kHz int16 PCM -> (64000,)
        _, waveform = scipy.io.wavfile.read(name)
        waveform = waveform.astype(self.audio_dtype) / self.int16max
        waveform = waveform.reshape(1, -1)  # (1, 64000)
        waveform = torch.from_numpy(waveform)
        target = self.metadata[self._key_of(name)]
        if self.audio_transform is not None:
            audio_data = self.audio_transform(waveform)
        else:
            audio_data = waveform
        if self.target_transform is not None:
            target = self.target_transform(target)
        return [audio_data, target]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment