#!/usr/bin/env python
"""
Based on https://colab.research.google.com/drive/1mypqbHDrusZaIbqPoiEGY-WIbnpMHa2I?usp=sharing#scrollTo=bY0AwDiSxdVE

Batching strategies compared:
- Random Sampling
- After sorting (minimum padding but no randomness)
- Org Dynamic Batching (the one in the current SpeechBrain version)
- **Mdf Dynamic Batching w/ fitted lognorm** (bucket boundaries set with a lognormal distribution fitted on the dataset)
- **Mdf Dynamic Batching w/ fitted beta** (bucket boundaries set with a beta distribution fitted on the dataset, mentioned in the tutorial)
- **Mdf Dynamic Batching w/ KMeans** (examples aggregated into buckets by KMeans)
"""
import logging
import time
from typing import List

import numpy as np
import scipy.stats
import speechbrain
import torch
from scipy.stats import lognorm
from speechbrain.dataio.batch import PaddedBatch
from speechbrain.dataio.dataset import DynamicItemDataset
from speechbrain.dataio.sampler import DynamicBatchSampler as DynamicBatchSamplerOrg
from torch.utils.data import DataLoader, Sampler

logger = logging.getLogger(__name__)
# logging.basicConfig(level=logging.INFO)


class DynamicBatchSamplerMdf(Sampler):
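    """Variant of SpeechBrain's DynamicBatchSampler where, in addition to the
    original fixed lognormal warping, bucket boundaries can be derived from a
    distribution fitted on the observed example lengths.

    boundary_mode:
        "fixed"  -- original behaviour (lognormal warping scaled to max_batch_length).
        "fitted" -- fit ``boundary_rv_dist`` (a scipy.stats distribution name,
        e.g. "lognorm" or "beta") on the example lengths and place boundaries
        at its quantiles.
    """
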
    def __init__(
        self,
        dataset,
        max_batch_length: int,
        num_buckets: int = None,
        length_func=lambda x: x["duration"],
        shuffle: bool = True,
        batch_ordering: str = "random",
        max_batch_ex: int = None,
        bucket_boundaries: List[int] = [],
        lengths_list: List[int] = None,
        seed: int = 42,
        epoch: int = 0,
        drop_last: bool = False,
        verbose: bool = False,
        boundary_mode: str = "fixed",
        boundary_rv_dist: str = "lognorm",
    ):
        self._dataset = dataset
        self._ex_lengths = {}
        ex_ids = self._dataset.data_ids
        self.verbose = verbose
        # We do not put a default on num_buckets to encourage users to play with this parameter
        if num_buckets is None and len(bucket_boundaries) == 0:
            raise RuntimeError(
                "Please specify either num_buckets or bucket_boundaries. " "Check the docs and/or the tutorial!"
            )
        if lengths_list is not None:
            # take the length of each example from this argument and bypass length_func
            for indx in range(len(lengths_list)):
                self._ex_lengths[str(indx)] = lengths_list[indx]
        else:
            # use the length function
            if not isinstance(dataset, DynamicItemDataset):
                raise NotImplementedError("Dataset should be a SpeechBrain DynamicItemDataset when using length_func")
            for indx in range(len(self._dataset)):
                self._ex_lengths[str(indx)] = length_func(self._dataset.data[ex_ids[indx]])

        if len(bucket_boundaries) > 0:
            if not all([x >= 0 for x in bucket_boundaries]):
                raise ValueError("All elements in bucket_boundaries should be non-negative (>= 0).")
            if not len(set(bucket_boundaries)) == len(bucket_boundaries):
                raise ValueError("bucket_boundaries should not contain duplicates.")
            np.testing.assert_array_equal(
                np.array(bucket_boundaries),
                np.array(sorted(bucket_boundaries)),
                err_msg="The arg bucket_boundaries should be an ascending sorted list of non-negative values!",
            )
            self._bucket_boundaries = np.array(sorted(bucket_boundaries))
        else:
            # use num_buckets
            # original version:
            # self._bucket_boundaries = np.array(
            #     self._get_boundaries_through_warping(
            #         max_batch_length=max_batch_length,
            #         num_quantiles=num_buckets,
            #     )
            # )
            self._bucket_boundaries = np.array(
                self._get_boundaries(
                    max_batch_length=max_batch_length,
                    num_quantiles=num_buckets,
                    mode=boundary_mode,
                    rv_dist=boundary_rv_dist,
                )
            )

        self._max_batch_length = max_batch_length
        self._shuffle_ex = shuffle
        self._batch_ordering = batch_ordering
        self._seed = seed
        self._drop_last = drop_last
        if max_batch_ex is None:
            max_batch_ex = np.inf
        self._max_batch_ex = max_batch_ex
        # Calculate bucket lengths - how often does one bucket boundary fit into max_batch_length?
        self._bucket_lens = [
            max(1, int(max_batch_length / self._bucket_boundaries[i])) for i in range(len(self._bucket_boundaries))
        ] + [1]
        self._epoch = epoch
        self._generate_batches()
    def get_durations(self, batch):
        return [self._ex_lengths[str(idx)] for idx in batch]
    def _get_boundaries(self, max_batch_length: int, num_quantiles: int, mode: str, rv_dist: str):
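        """Dispatch boundary computation: "fixed" reproduces the original
        lognormal warping; "fitted" fits ``rv_dist`` on the example lengths."""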
        if mode == "fixed":
            return self._get_boundaries_through_warping(
                max_batch_length=max_batch_length,
                num_quantiles=num_quantiles,
            )
        elif mode == "fitted":
            return self._get_boundaries_through_fitted_warping(num_buckets=num_quantiles, rv_dist=rv_dist)
        else:
            raise NotImplementedError
    def _get_boundaries_through_fitted_warping(
        self,
        num_buckets: int,
        rv_dist: str,
    ):
        latent_boundaries = np.linspace(
            1 / num_buckets,
            1,
            num_buckets,
        )
        lengths = list(self._ex_lengths.values())
        try:
            rv = getattr(scipy.stats, rv_dist)
        except AttributeError:
            logger.warning(f"scipy.stats has no distribution named {rv_dist!r}, using beta instead")
            rv = scipy.stats.beta
        # ? should we add floc=0 and fscale=1
        # RuntimeWarning: invalid value encountered in sqrt sk = 2*(b-a)*np.sqrt(a + b + 1) / (a + b + 2) / np.sqrt(a*b)
        rv_params = rv.fit(lengths)
        # the last upper boundary (quantile 1.0) is always inf for unbounded distributions
        bucket_boundaries = rv.ppf(latent_boundaries, *rv_params)
        # replace inf by max length
        # bucket_boundaries[-1] = max(max(lengths), max(bucket_boundaries))
        return bucket_boundaries
    def _get_boundaries_through_warping(
        self,
        max_batch_length: int,
        num_quantiles: int,
    ) -> List[int]:
        # NOTE: the following lines do not cover the case where there is only one example in the dataset
        # warp the frames (duration) distribution of the train data
        logger.info("Batch quantisation in latent space")
        # linspace set-up
        num_boundaries = num_quantiles + 1
        # create linearly equally spaced buckets in latent space
        latent_boundaries = np.linspace(
            1 / num_boundaries,
            num_quantiles / num_boundaries,
            num_quantiles,
        )
        # get quantiles using the lognormal distribution
        quantiles = lognorm.ppf(latent_boundaries, 1)
        # scale up to max_batch_length
        bucket_boundaries = quantiles * max_batch_length / quantiles[-1]
        # compute resulting bucket length multipliers
        length_multipliers = [bucket_boundaries[x + 1] / bucket_boundaries[x] for x in range(num_quantiles - 1)]
        # logging
        logger.info(
            "Latent bucket boundary - buckets: {} - length multipliers: {}".format(
                list(map("{:.2f}".format, bucket_boundaries)),
                list(map("{:.2f}".format, length_multipliers)),
            )
        )
        return list(sorted(bucket_boundaries))
    def _permute_batches(self):
        if self._batch_ordering == "random":
            # deterministically shuffle based on epoch and seed
            g = torch.Generator()
            g.manual_seed(self._seed + self._epoch)
            sampler = torch.randperm(len(self._batches), generator=g).tolist()  # type: ignore
            self._batches = [self._batches[idx] for idx in sampler]
        elif self._batch_ordering == "ascending":
            self._batches = sorted(
                self._batches,
                key=lambda x: max([self._ex_lengths[str(idx)] for idx in x]),
            )
        elif self._batch_ordering == "descending":
            self._batches = sorted(
                self._batches,
                key=lambda x: max([self._ex_lengths[str(idx)] for idx in x]),
                reverse=True,
            )
        else:
            raise NotImplementedError
    def _generate_batches(self):
        logger.info("DynamicBatchSampler: Generating dynamic batches")
        if self._shuffle_ex:
            # deterministically shuffle based on epoch and seed
            g = torch.Generator()
            g.manual_seed(self._seed + self._epoch)
            sampler = torch.randperm(len(self._dataset), generator=g).tolist()  # type: ignore
        else:
            # take examples as they are: e.g. they have been sorted
            sampler = range(len(self._dataset))  # type: ignore

        self._batches = []
        bucket_batches = [[] for _ in self._bucket_lens]
        stats_tracker = [{"min": np.inf, "max": -np.inf, "tot": 0, "n_ex": 0} for _ in self._bucket_lens]

        for idx in sampler:
            # length of pre-sampled audio
            item_len = self._ex_lengths[str(idx)]
            # find the smallest bucket that fits this example (least padding)
            bucket_id = np.searchsorted(self._bucket_boundaries, item_len)
            # fill the audio's duration into that bucket
            bucket_batches[bucket_id].append(idx)
            stats_tracker[bucket_id]["min"] = min(stats_tracker[bucket_id]["min"], item_len)
            stats_tracker[bucket_id]["max"] = max(stats_tracker[bucket_id]["max"], item_len)
            stats_tracker[bucket_id]["tot"] += item_len
            stats_tracker[bucket_id]["n_ex"] += 1
            # track #samples - why not duration/#frames; rounded up?
            # keep track of durations, if necessary
            if (
                len(bucket_batches[bucket_id]) >= self._bucket_lens[bucket_id]
                or len(bucket_batches[bucket_id]) >= self._max_batch_ex
            ):
                self._batches.append(bucket_batches[bucket_id])
                bucket_batches[bucket_id] = []
                # keep track of durations

        # Dump remaining batches
        if not self._drop_last:
            for batch in bucket_batches:
                if batch:
                    self._batches.append(batch)
        self._permute_batches()  # possibly reorder batches
        if self._epoch == 0:  # only log at the first epoch
            # frames per batch & their remaining padding
            boundaries = [0] + self._bucket_boundaries.tolist()
            for bucket_indx in range(len(self._bucket_boundaries)):
                try:
                    num_batches = stats_tracker[bucket_indx]["tot"] // (self._max_batch_length)
                    pad_factor = (stats_tracker[bucket_indx]["max"] - stats_tracker[bucket_indx]["min"]) / (
                        stats_tracker[bucket_indx]["tot"] / stats_tracker[bucket_indx]["n_ex"]
                    )
                except ZeroDivisionError:
                    num_batches = 0
                    pad_factor = 0
                logger.info(
                    (
                        "DynamicBatchSampler: Bucket {} with boundary {:.1f}-{:.1f} and "
                        + "batch_size {}: Num Examples {:.1f}, Num Full Batches {:.3f}, Pad Factor (%) {:.3f}."
                    ).format(
                        bucket_indx,
                        boundaries[bucket_indx],
                        boundaries[bucket_indx + 1],
                        self._bucket_lens[bucket_indx],
                        stats_tracker[bucket_indx]["n_ex"],
                        num_batches,
                        pad_factor * 100,
                    )
                )
        if self.verbose:
            batch_stats = {
                "tot_frames": [],
                "tot_pad_frames": [],
                "pad_%": [],
            }
            for batch in self._batches:
                tot_frames = sum([self._ex_lengths[str(idx)] for idx in batch])
                batch_stats["tot_frames"].append(tot_frames)
                max_frames = max([self._ex_lengths[str(idx)] for idx in batch])
                tot_pad = sum([max_frames - self._ex_lengths[str(idx)] for idx in batch])
                batch_stats["tot_pad_frames"].append(tot_pad)
                batch_stats["pad_%"].append(tot_pad / tot_frames * 100)
            padding_details = "Batch {} with {:.1f} frames with {} files - {:.1f} padding, {:.2f} (%) of total."
            padding_details = "DynamicBatchSampler: " + padding_details
            for i in range(len(self._batches)):
                logger.info(
                    padding_details.format(
                        i,
                        batch_stats["tot_frames"][i],
                        len(self._batches[i]),
                        batch_stats["tot_pad_frames"][i],
                        batch_stats["pad_%"][i],
                    )
                )
    def __iter__(self):
        for batch in self._batches:
            yield batch
        if self._shuffle_ex:  # re-generate batches if ex_ordering == "random"
            self._generate_batches()
        if self._batch_ordering == "random":
            # we randomly permute the batches only --> faster
            self._permute_batches()

    def set_epoch(self, epoch):
        """
        You can also just set self._epoch directly, but we maintain this interface
        to mirror torch.utils.data.distributed.DistributedSampler.
        """
        self._epoch = epoch
        self._generate_batches()

    def __len__(self):
        return len(self._batches)


class DynamicBatchSamplerKM(Sampler):
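    """Variant of SpeechBrain's DynamicBatchSampler where buckets are obtained
    by running 1-D KMeans on the example lengths instead of using explicit
    boundaries: each example is assigned to the cluster of its length, and the
    per-bucket batch size is derived from the longest example in each cluster.
    """
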
    def __init__(
        self,
        dataset,
        max_batch_length: int,
        num_buckets: int = None,
        length_func=lambda x: x["duration"],
        shuffle: bool = True,
        batch_ordering: str = "random",
        max_batch_ex: int = None,
        bucket_boundaries: List[int] = [],
        lengths_list: List[int] = None,
        seed: int = 42,
        epoch: int = 0,
        drop_last: bool = False,
        verbose: bool = False,
    ):
        self._dataset = dataset
        self._ex_lengths = {}
        ex_ids = self._dataset.data_ids
        self.verbose = verbose
        # We do not put a default on num_buckets to encourage users to play with this parameter
        if num_buckets is None and len(bucket_boundaries) == 0:
            raise RuntimeError(
                "Please specify either num_buckets or bucket_boundaries. " "Check the docs and/or the tutorial!"
            )
        if lengths_list is not None:
            # take the length of each example from this argument and bypass length_func
            for indx in range(len(lengths_list)):
                self._ex_lengths[str(indx)] = lengths_list[indx]
        else:
            # use the length function
            if not isinstance(dataset, DynamicItemDataset):
                raise NotImplementedError("Dataset should be a SpeechBrain DynamicItemDataset when using length_func")
            for indx in range(len(self._dataset)):
                self._ex_lengths[str(indx)] = length_func(self._dataset.data[ex_ids[indx]])
        # bucket_boundaries are not used in this sampler: buckets come from the
        # KMeans cluster assignments computed in _get_bucket_ids below.
        self._max_batch_length = max_batch_length
        self._shuffle_ex = shuffle
        self._batch_ordering = batch_ordering
        self._seed = seed
        self._drop_last = drop_last
        if max_batch_ex is None:
            max_batch_ex = np.inf
        self._max_batch_ex = max_batch_ex
        self._ex_bucket_ids = self._get_bucket_ids(num_buckets=num_buckets)
        # Calculate bucket lengths - how many examples of the bucket's max length fit into max_batch_length?
        # original boundary-based version:
        # self._bucket_lens = [
        #     max(1, int(max_batch_length / self._bucket_boundaries[i])) for i in range(len(self._bucket_boundaries))
        # ] + [1]
        self._bucket_lens = self._get_bucket_lens(num_buckets, max_batch_length)
        self._epoch = epoch
        self._generate_batches()
    def get_durations(self, batch):
        return [self._ex_lengths[str(idx)] for idx in batch]
    def _get_bucket_ids(self, num_buckets: int):
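        """Cluster the (1-D) example lengths into ``num_buckets`` groups with
        KMeans, then relabel the clusters so that bucket ids are ordered by
        ascending centroid (bucket 0 = shortest examples)."""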
        try:
            from sklearn.cluster import KMeans
        except ImportError:
            raise ImportError("scikit-learn is required here: pip3 install -U scikit-learn")

        lengths = np.asarray(list(self._ex_lengths.values()))
        lengths = lengths.reshape(-1, 1)
        km = KMeans(n_clusters=num_buckets, random_state=self._seed).fit(lengths)
        # sort clusters by centroid
        idx = np.argsort(km.cluster_centers_.reshape((-1,)))
        lut = np.zeros_like(idx)
        lut[idx] = np.arange(num_buckets)
        return lut[km.labels_]
    def _get_bucket_lens(self, num_buckets: int, max_batch_length: int):
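        """Per-bucket batch size: how many examples of the bucket's maximum
        length fit into max_batch_length (at least 1)."""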
        lengths = np.asarray(list(self._ex_lengths.values()))
        max_lens_by_bucket = []
        for bucket_id in range(num_buckets):
            max_lens_by_bucket.append(lengths[np.where(self._ex_bucket_ids == bucket_id)].max())
        # ? one less bucket than the other methods
        bucket_lens = [max(1, int(max_batch_length / max_len)) for max_len in max_lens_by_bucket]
        return bucket_lens
    # NOTE: kept from the boundary-based sampler for reference; not used by
    # this KMeans-based sampler.
    def _get_boundaries_through_warping(
        self,
        max_batch_length: int,
        num_quantiles: int,
    ) -> List[int]:
        # NOTE: the following lines do not cover the case where there is only one example in the dataset
        # warp the frames (duration) distribution of the train data
        logger.info("Batch quantisation in latent space")
        # linspace set-up
        num_boundaries = num_quantiles + 1
        # create linearly equally spaced buckets in latent space
        latent_boundaries = np.linspace(
            1 / num_boundaries,
            num_quantiles / num_boundaries,
            num_quantiles,
        )
        # get quantiles using the lognormal distribution
        quantiles = lognorm.ppf(latent_boundaries, 1)
        # scale up to max_batch_length
        bucket_boundaries = quantiles * max_batch_length / quantiles[-1]
        # compute resulting bucket length multipliers
        length_multipliers = [bucket_boundaries[x + 1] / bucket_boundaries[x] for x in range(num_quantiles - 1)]
        # logging
        logger.info(
            "Latent bucket boundary - buckets: {} - length multipliers: {}".format(
                list(map("{:.2f}".format, bucket_boundaries)),
                list(map("{:.2f}".format, length_multipliers)),
            )
        )
        return list(sorted(bucket_boundaries))
    def _permute_batches(self):
        if self._batch_ordering == "random":
            # deterministically shuffle based on epoch and seed
            g = torch.Generator()
            g.manual_seed(self._seed + self._epoch)
            sampler = torch.randperm(len(self._batches), generator=g).tolist()  # type: ignore
            self._batches = [self._batches[idx] for idx in sampler]
        elif self._batch_ordering == "ascending":
            self._batches = sorted(
                self._batches,
                key=lambda x: max([self._ex_lengths[str(idx)] for idx in x]),
            )
        elif self._batch_ordering == "descending":
            self._batches = sorted(
                self._batches,
                key=lambda x: max([self._ex_lengths[str(idx)] for idx in x]),
                reverse=True,
            )
        else:
            raise NotImplementedError
    def _generate_batches(self):
        logger.info("DynamicBatchSampler: Generating dynamic batches")
        if self._shuffle_ex:
            # deterministically shuffle based on epoch and seed
            g = torch.Generator()
            g.manual_seed(self._seed + self._epoch)
            sampler = torch.randperm(len(self._dataset), generator=g).tolist()  # type: ignore
        else:
            # take examples as they are: e.g. they have been sorted
            sampler = range(len(self._dataset))  # type: ignore

        self._batches = []
        bucket_batches = [[] for _ in self._bucket_lens]
        stats_tracker = [{"min": np.inf, "max": -np.inf, "tot": 0, "n_ex": 0} for _ in self._bucket_lens]

        for idx in sampler:
            # length of pre-sampled audio
            item_len = self._ex_lengths[str(idx)]
            # the bucket is given by the KMeans cluster assignment
            # bucket_id = np.searchsorted(self._bucket_boundaries, item_len)
            bucket_id = self._ex_bucket_ids[idx]
            # fill the audio's duration into that bucket
            bucket_batches[bucket_id].append(idx)
            stats_tracker[bucket_id]["min"] = min(stats_tracker[bucket_id]["min"], item_len)
            stats_tracker[bucket_id]["max"] = max(stats_tracker[bucket_id]["max"], item_len)
            stats_tracker[bucket_id]["tot"] += item_len
            stats_tracker[bucket_id]["n_ex"] += 1
            if (
                len(bucket_batches[bucket_id]) >= self._bucket_lens[bucket_id]
                or len(bucket_batches[bucket_id]) >= self._max_batch_ex
            ):
                self._batches.append(bucket_batches[bucket_id])
                bucket_batches[bucket_id] = []

        # Dump remaining batches
        if not self._drop_last:
            for batch in bucket_batches:
                if batch:
                    self._batches.append(batch)
        self._permute_batches()  # possibly reorder batches
        # Per-bucket / per-batch padding statistics logging is omitted here;
        # see DynamicBatchSamplerMdf._generate_batches for the equivalent code.
    def __iter__(self):
        for batch in self._batches:
            yield batch
        if self._shuffle_ex:  # re-generate batches if ex_ordering == "random"
            self._generate_batches()
        if self._batch_ordering == "random":
            # we randomly permute the batches only --> faster
            self._permute_batches()

    def set_epoch(self, epoch):
        """
        You can also just set self._epoch directly, but we maintain this interface
        to mirror torch.utils.data.distributed.DistributedSampler.
        """
        self._epoch = epoch
        self._generate_batches()

    def __len__(self):
        return len(self._batches)


def count_samples(dataloader):
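    """Iterate once over the dataloader and measure the fraction of true vs.
    padded samples, plus the elapsed wall-clock time. PaddedBatch yields
    relative lengths in [0, 1], so each example contributes max_len * rel_len
    true samples and max_len * (1 - rel_len) padding."""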
    true_samples = 0
    padded_samples = 0
    t1 = time.time()
    for batch in dataloader:
        audio, lens = batch.signal
        true_samples += torch.sum(audio.shape[-1] * lens).item()
        padded_samples += torch.sum(audio.shape[-1] * (1 - lens)).item()
        # print(audio.shape)
    elapsed = time.time() - t1
    tot_samples = true_samples + padded_samples
    return true_samples / tot_samples, padded_samples / tot_samples, elapsed


def main():
    # load the minilibrispeech dataset, as described in the tutorial
    train_data = speechbrain.dataio.dataset.DynamicItemDataset.from_json("data.json")

    # define a pipeline to read audio
    @speechbrain.utils.data_pipeline.takes("file_path")
    @speechbrain.utils.data_pipeline.provides("signal")
    def audio_pipeline(file_path):
        sig = speechbrain.dataio.dataio.read_audio(file_path)
        return sig

    # set the pipeline
    train_data.add_dynamic_item(audio_pipeline)
    train_data.set_output_keys(["signal", "file_path"])

    batch_size = 32
    max_batch_len = 17 * 32

    # random
    dataloader = DataLoader(train_data, collate_fn=PaddedBatch, batch_size=batch_size)
    percent_true, percent_padded, elapsed = count_samples(dataloader)
    print("Random Sampling: % True samples {}, % of padding {}, Total time {}".format(percent_true, percent_padded, elapsed))

    # sort
    sorted_data = train_data.filtered_sorted(sort_key="length")
    dataloader = DataLoader(sorted_data, collate_fn=PaddedBatch, batch_size=batch_size)
    percent_true, percent_padded, elapsed = count_samples(dataloader)
    print("After sorting: % True samples {}, % of padding {}, Total time {}".format(percent_true, percent_padded, elapsed))
    # num_buckets = 20
    num_buckets_list = [5, 10, 15, 20, 60]
    for num_buckets in num_buckets_list:
        print(f"\nnum_buckets: {num_buckets}")

        t1 = time.time()
        dynamic_batcher = DynamicBatchSamplerOrg(
            train_data,
            max_batch_length=max_batch_len,
            num_buckets=num_buckets,
            length_func=lambda x: x["length"] / 16000,
            shuffle=False,
            batch_ordering="descending",
        )
        elapsed_sampler = time.time() - t1
        dataloader = DataLoader(train_data, batch_sampler=dynamic_batcher, collate_fn=PaddedBatch)
        percent_true, percent_padded, elapsed = count_samples(dataloader)
        print(
            "Org Dynamic Batching: % True samples {}, % of padding {}, Total time {}, Sampler initialization time {}".format(
                percent_true, percent_padded, elapsed, elapsed_sampler
            )
        )
        # t1 = time.time()
        # dynamic_batcher = DynamicBatchSamplerMdf(
        #     train_data,
        #     max_batch_length=max_batch_len,
        #     num_buckets=num_buckets,
        #     length_func=lambda x: x["length"] / 16000,
        #     shuffle=False,
        #     batch_ordering="descending",
        #     boundary_mode="fixed",
        # )
        # elapsed_sampler = time.time() - t1
        # dataloader = DataLoader(train_data, batch_sampler=dynamic_batcher, collate_fn=PaddedBatch)
        # percent_true, percent_padded, elapsed = count_samples(dataloader)
        # print(
        #     "Mdf Dynamic Batching w/ fixed lognorm: % True samples {}, % of padding {}, Total time {}, Sampler initialization time {}".format(
        #         percent_true, percent_padded, elapsed, elapsed_sampler
        #     )
        # )
        t1 = time.time()
        dynamic_batcher = DynamicBatchSamplerMdf(
            train_data,
            max_batch_length=max_batch_len,
            num_buckets=num_buckets,
            length_func=lambda x: x["length"] / 16000,
            shuffle=False,
            batch_ordering="descending",
            boundary_mode="fitted",
            boundary_rv_dist="lognorm",
        )
        elapsed_sampler = time.time() - t1
        dataloader = DataLoader(train_data, batch_sampler=dynamic_batcher, collate_fn=PaddedBatch)
        percent_true, percent_padded, elapsed = count_samples(dataloader)
        print(
            "Mdf Dynamic Batching w/ fitted lognorm: % True samples {}, % of padding {}, Total time {}, Sampler initialization time {}".format(
                percent_true, percent_padded, elapsed, elapsed_sampler
            )
        )

        t1 = time.time()
        dynamic_batcher = DynamicBatchSamplerMdf(
            train_data,
            max_batch_length=max_batch_len,
            num_buckets=num_buckets,
            length_func=lambda x: x["length"] / 16000,
            shuffle=False,
            batch_ordering="descending",
            boundary_mode="fitted",
            boundary_rv_dist="beta",
        )
        elapsed_sampler = time.time() - t1
        dataloader = DataLoader(train_data, batch_sampler=dynamic_batcher, collate_fn=PaddedBatch)
        percent_true, percent_padded, elapsed = count_samples(dataloader)
        print(
            "Mdf Dynamic Batching w/ fitted beta: % True samples {}, % of padding {}, Total time {}, Sampler initialization time {}".format(
                percent_true, percent_padded, elapsed, elapsed_sampler
            )
        )

        t1 = time.time()
        dynamic_batcher = DynamicBatchSamplerKM(
            train_data,
            max_batch_length=max_batch_len,
            num_buckets=num_buckets,
            length_func=lambda x: x["length"] / 16000,
            shuffle=False,
            batch_ordering="descending",
        )
        elapsed_sampler = time.time() - t1
        dataloader = DataLoader(train_data, batch_sampler=dynamic_batcher, collate_fn=PaddedBatch)
        percent_true, percent_padded, elapsed = count_samples(dataloader)
        print(
            "Mdf Dynamic Batching w/ Kmeans: % True samples {}, % of padding {}, Total time {}, Sampler initialization time {}".format(
                percent_true, percent_padded, elapsed, elapsed_sampler
            )
        )


if __name__ == "__main__":
    main()