Skip to content

Instantly share code, notes, and snippets.

@bofenghuang
Last active November 8, 2022 23:54
Show Gist options
  • Save bofenghuang/87245983a3ad9a0567ec364bc0aea18c to your computer and use it in GitHub Desktop.
Save bofenghuang/87245983a3ad9a0567ec364bc0aea18c to your computer and use it in GitHub Desktop.
#!/usr/bin/env python
"""
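Benchmark of batching strategies on MiniLibriSpeech, based on https://colab.research.google.com/drive/1mypqbHDrusZaIbqPoiEGY-WIbnpMHa2I?usp=sharing#scrollTo=bY0AwDiSxdVE
It compares the fraction of padding produced by:
- Random Sampling (data loaded in its original order, fixed batch size)
- After sorting (minimum padding, but no randomness)
- Org Dynamic Batching (the original DynamicBatchSampler shipped with the current SpeechBrain version)
- **Mdf Dynamic Batching w/ fitted lognorm** (modified sampler; bucket boundaries set from a lognormal distribution fitted on the dataset)
- **Mdf Dynamic Batching w/ fitted beta** (modified sampler; bucket boundaries set from a beta distribution fitted on the dataset, as mentioned in the tutorial)
- **Mdf Dynamic Batching w/ Kmeans** (modified sampler; examples grouped into buckets by K-means clustering of their lengths)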
"""
import logging
import time
from typing import List
import numpy as np
import scipy.stats
import speechbrain
import torch
from scipy.stats import lognorm
from speechbrain.dataio.batch import PaddedBatch
from speechbrain.dataio.dataset import DynamicItemDataset
from speechbrain.dataio.sampler import DynamicBatchSampler as DynamicBatchSamplerOrg
from torch.utils.data import DataLoader, Sampler
# Module-level logger; the samplers below report bucket statistics through it at INFO level.
logger = logging.getLogger(name=__name__)
# Uncomment to display those INFO messages when running this file as a script:
# logging.basicConfig(level=logging.INFO)
class DynamicBatchSamplerMdf(Sampler):
    """Dynamic batch sampler that groups examples of similar length into buckets.

    Modified version of SpeechBrain's ``DynamicBatchSampler``. Bucket upper boundaries come from
    one of three sources:

    - ``bucket_boundaries``, when it is non-empty;
    - a fixed lognormal warping scaled to ``max_batch_length`` (``boundary_mode="fixed"``);
    - the quantiles of a ``scipy.stats`` distribution fitted on the example lengths
      (``boundary_mode="fitted"``, distribution named by ``boundary_rv_dist``).

    An example goes to the first bucket whose boundary is >= its length (``np.searchsorted``).
    Examples longer than every boundary go to an extra overflow bucket holding one example.
    A bucket is emitted as a batch once it reaches ``max(1, int(max_batch_length / boundary))``
    examples, or ``max_batch_ex`` examples. Iterating the sampler yields lists of dataset indices.

    Arguments
    ---------
    dataset : sized dataset
        Dataset to sample from. Must be a SpeechBrain ``DynamicItemDataset`` unless
        ``lengths_list`` is given.
    max_batch_length : int
        Length budget used to size the buckets (same unit as the example lengths).
    num_buckets : int, optional
        Number of buckets to derive when ``bucket_boundaries`` is not given.
    length_func : callable
        Maps ``dataset.data[data_id]`` to the example length (only used without ``lengths_list``).
    shuffle : bool
        Shuffle examples (deterministically, from ``seed + epoch``) before bucketing.
    batch_ordering : str
        "random", "ascending" or "descending" order of the produced batches.
    max_batch_ex : int, optional
        Maximum number of examples per batch (unlimited when None).
    bucket_boundaries : list, optional
        Explicit ascending, non-negative, duplicate-free upper boundaries.
    lengths_list : list, optional
        Precomputed length of every example, indexed like the dataset; bypasses ``length_func``.
    seed, epoch : int
        Seed and epoch combined for the deterministic shuffles.
    drop_last : bool
        Drop the incomplete batches still left in the buckets after bucketing.
    verbose : bool
        At epoch 0, additionally log padding statistics for every batch.
    boundary_mode : str
        "fixed" or "fitted"; only used when ``bucket_boundaries`` is empty.
    boundary_rv_dist : str
        Name of the ``scipy.stats`` distribution fitted in "fitted" mode (falls back to beta).
    """

    def __init__(
        self,
        dataset,
        max_batch_length: int,
        num_buckets: int = None,
        length_func=lambda x: x["duration"],
        shuffle: bool = True,
        batch_ordering: str = "random",
        max_batch_ex: int = None,
        bucket_boundaries: List[int] = None,
        lengths_list: List[int] = None,
        seed: int = 42,
        epoch: int = 0,
        drop_last: bool = False,
        verbose: bool = False,
        boundary_mode: str = "fixed",
        boundary_rv_dist: str = "lognorm",
    ):
        # None replaces a shared mutable [] default; an empty sequence still means "not given".
        if bucket_boundaries is None:
            bucket_boundaries = []
        self._dataset = dataset
        self._ex_lengths = {}
        self.verbose = verbose

        # We do not put a default on num_buckets to encourage users to play with this parameter
        if num_buckets is None and len(bucket_boundaries) == 0:
            raise RuntimeError(
                "Please specify either num_buckets or bucket boundaries. Check the docs, and/or the tutorial !"
            )

        if lengths_list is not None:
            # take length of examples from this argument and bypass length_func
            for indx, length in enumerate(lengths_list):
                self._ex_lengths[str(indx)] = length
        else:
            # use length func
            if not isinstance(dataset, DynamicItemDataset):
                raise NotImplementedError("Dataset should be a Speechbrain DynamicItemDataset when using length function")
            # data_ids is only needed (and only guaranteed to exist) on this path
            ex_ids = self._dataset.data_ids
            for indx in range(len(self._dataset)):
                self._ex_lengths[str(indx)] = length_func(self._dataset.data[ex_ids[indx]])

        if len(bucket_boundaries) > 0:
            if not all(x >= 0 for x in bucket_boundaries):
                raise ValueError("All elements in bucket boundaries should be non-negative (>= 0).")
            if not len(set(bucket_boundaries)) == len(bucket_boundaries):
                raise ValueError("Bucket_boundaries should not contain duplicates.")
            np.testing.assert_array_equal(
                np.array(bucket_boundaries),
                np.array(sorted(bucket_boundaries)),
                err_msg="The arg bucket_boundaries should be an ascending sorted list of non negative values!",
            )
            self._bucket_boundaries = np.array(sorted(bucket_boundaries))
        else:
            # use num_buckets: derive boundaries with the fixed or fitted warping
            self._bucket_boundaries = np.array(
                self._get_boundaries(
                    max_batch_length=max_batch_length,
                    num_quantiles=num_buckets,
                    mode=boundary_mode,
                    rv_dist=boundary_rv_dist,
                )
            )

        self._max_batch_length = max_batch_length
        self._shuffle_ex = shuffle
        self._batch_ordering = batch_ordering
        self._seed = seed
        self._drop_last = drop_last
        self._max_batch_ex = np.inf if max_batch_ex is None else max_batch_ex
        # Calculate bucket lengths - how often does one bucket boundary fit into max_batch_length?
        # The trailing 1 is the overflow bucket for examples longer than the last boundary.
        self._bucket_lens = [max(1, int(max_batch_length / boundary)) for boundary in self._bucket_boundaries] + [1]
        self._epoch = epoch
        self._generate_batches()

    def get_durations(self, batch):
        """Return the stored length of every example index in ``batch``."""
        return [self._ex_lengths[str(idx)] for idx in batch]

    def _get_boundaries(self, max_batch_length: int, num_quantiles: int, mode: str, rv_dist: str):
        """Dispatch to the boundary strategy selected by ``mode`` ("fixed" or "fitted").

        Raises NotImplementedError for any other mode.
        """
        if mode == "fixed":
            return self._get_boundaries_through_warping(
                max_batch_length=max_batch_length,
                num_quantiles=num_quantiles,
            )
        if mode == "fitted":
            return self._get_boundaries_through_fitted_warping(num_buckets=num_quantiles, rv_dist=rv_dist)
        raise NotImplementedError(f"Unknown boundary_mode: {mode!r}")

    def _get_boundaries_through_fitted_warping(
        self,
        num_buckets: int,
        rv_dist: str,
    ):
        """Return ``num_buckets`` boundaries at quantiles 1/num_buckets, ..., 1 of a fitted distribution.

        ``rv_dist`` names a ``scipy.stats`` distribution. An unknown name logs a warning and
        falls back to ``scipy.stats.beta``. The distribution is fitted on all example lengths.
        """
        latent_boundaries = np.linspace(
            1 / num_buckets,
            1,
            num_buckets,
        )
        lengths = list(self._ex_lengths.values())
        try:
            rv = getattr(scipy.stats, rv_dist)
        except AttributeError:
            logger.warning(f"{rv_dist} doesn't exist, use beta instead")
            rv = scipy.stats.beta
        # ? should add floc=0 and fscale=1
        # RuntimeWarning: invalid value encountered in sqrt sk = 2*(b-a)*np.sqrt(a + b + 1) / (a + b + 2) / np.sqrt(a*b)
        rv_params = rv.fit(lengths)
        # The last quantile is 1, so for distributions with unbounded support (e.g. lognorm)
        # the last boundary is inf; for bounded ones (e.g. beta) it is the fitted upper bound.
        return rv.ppf(latent_boundaries, *rv_params)

    def _get_boundaries_through_warping(
        self,
        max_batch_length: int,
        num_quantiles: int,
    ) -> List[int]:
        """Return ``num_quantiles`` ascending boundaries from a standard lognormal warping.

        Quantiles at k / (num_quantiles + 1) are rescaled so that the last one equals
        ``max_batch_length``. The dataset lengths are not used.
        """
        # NOTE: the following lines do not cover that there is only one example in the dataset
        logger.info("Batch quantisation in latent space")
        num_boundaries = num_quantiles + 1
        # create latent linearly equal spaced buckets
        latent_boundaries = np.linspace(
            1 / num_boundaries,
            num_quantiles / num_boundaries,
            num_quantiles,
        )
        # get quantiles using lognormal distribution
        quantiles = lognorm.ppf(latent_boundaries, 1)
        # scale up to max_batch_length
        bucket_boundaries = quantiles * max_batch_length / quantiles[-1]
        # compute resulting bucket length multipliers
        length_multipliers = [bucket_boundaries[x + 1] / bucket_boundaries[x] for x in range(num_quantiles - 1)]
        logger.info(
            "Latent bucket boundary - buckets: {} - length multipliers: {}".format(
                list(map("{:.2f}".format, bucket_boundaries)),
                list(map("{:.2f}".format, length_multipliers)),
            )
        )
        return list(sorted(bucket_boundaries))

    def _permute_batches(self):
        """Reorder ``self._batches`` according to ``self._batch_ordering``.

        "random" uses a generator seeded with ``seed + epoch``. "ascending" and "descending"
        sort by the longest example of each batch. Any other value raises NotImplementedError.
        """
        if self._batch_ordering == "random":
            # deterministically shuffle based on epoch and seed
            g = torch.Generator()
            g.manual_seed(self._seed + self._epoch)
            sampler = torch.randperm(len(self._batches), generator=g).tolist()  # type: ignore
            self._batches = [self._batches[idx] for idx in sampler]
        elif self._batch_ordering == "ascending":
            self._batches = sorted(
                self._batches,
                key=lambda batch: max(self._ex_lengths[str(idx)] for idx in batch),
            )
        elif self._batch_ordering == "descending":
            self._batches = sorted(
                self._batches,
                key=lambda batch: max(self._ex_lengths[str(idx)] for idx in batch),
                reverse=True,
            )
        else:
            raise NotImplementedError(f"Unknown batch_ordering: {self._batch_ordering!r}")

    def _generate_batches(self):
        """Fill the buckets with example indices and rebuild ``self._batches``.

        At epoch 0 this also logs per-bucket statistics, and per-batch padding if ``verbose``.
        """
        logger.info("DynamicBatchSampler: Generating dynamic batches")
        if self._shuffle_ex:
            # deterministically shuffle based on epoch and seed
            g = torch.Generator()
            g.manual_seed(self._seed + self._epoch)
            sampler = torch.randperm(len(self._dataset), generator=g).tolist()  # type: ignore
        else:
            # take examples as they are: e.g. they have been sorted
            sampler = range(len(self._dataset))  # type: ignore

        self._batches = []
        bucket_batches = [[] for _ in self._bucket_lens]
        stats_tracker = [{"min": np.inf, "max": -np.inf, "tot": 0, "n_ex": 0} for _ in self._bucket_lens]

        for idx in sampler:
            item_len = self._ex_lengths[str(idx)]
            # first bucket whose upper boundary is >= item_len (overflow bucket if none)
            bucket_id = np.searchsorted(self._bucket_boundaries, item_len)
            bucket_batches[bucket_id].append(idx)

            stats = stats_tracker[bucket_id]
            stats["min"] = min(stats["min"], item_len)
            stats["max"] = max(stats["max"], item_len)
            stats["tot"] += item_len
            stats["n_ex"] += 1

            # emit the bucket as a batch once it is full
            if (
                len(bucket_batches[bucket_id]) >= self._bucket_lens[bucket_id]
                or len(bucket_batches[bucket_id]) >= self._max_batch_ex
            ):
                self._batches.append(bucket_batches[bucket_id])
                bucket_batches[bucket_id] = []

        # Dump remaining (incomplete) batches
        if not self._drop_last:
            self._batches.extend(batch for batch in bucket_batches if batch)

        self._permute_batches()  # possibly reorder batches

        if self._epoch == 0:  # only log at first epoch
            # frames per batch & their padding remaining (overflow bucket not reported)
            boundaries = [0] + self._bucket_boundaries.tolist()
            for bucket_indx in range(len(self._bucket_boundaries)):
                stats = stats_tracker[bucket_indx]
                try:
                    num_batches = stats["tot"] // (self._max_batch_length)
                    pad_factor = (stats["max"] - stats["min"]) / (stats["tot"] / stats["n_ex"])
                except ZeroDivisionError:
                    # empty bucket
                    num_batches = 0
                    pad_factor = 0
                logger.info(
                    (
                        "DynamicBatchSampler: Bucket {} with boundary {:.1f}-{:.1f} and "
                        + "batch_size {}: Num Examples {:.1f}, Num Full Batches {:.3f}, Pad Factor {:.3f}."
                    ).format(
                        bucket_indx,
                        boundaries[bucket_indx],
                        boundaries[bucket_indx + 1],
                        self._bucket_lens[bucket_indx],
                        stats["n_ex"],
                        num_batches,
                        pad_factor * 100,
                    )
                )

            if self.verbose:
                batch_stats = {"tot_frames": [], "tot_pad_frames": [], "pad_%": []}
                for batch in self._batches:
                    lengths = [self._ex_lengths[str(idx)] for idx in batch]
                    tot_frames = sum(lengths)
                    max_frames = max(lengths)
                    tot_pad = sum(max_frames - length for length in lengths)
                    batch_stats["tot_frames"].append(tot_frames)
                    batch_stats["tot_pad_frames"].append(tot_pad)
                    # a batch of zero-length examples would otherwise divide by zero
                    batch_stats["pad_%"].append(tot_pad / tot_frames * 100 if tot_frames else 0.0)
                padding_details = "Batch {} with {:.1f} frames with {} files - {:.1f} padding, {:.2f} (%) of total."
                padding_details = "DynamicBatchSampler: " + padding_details
                for i, batch in enumerate(self._batches):
                    logger.info(
                        padding_details.format(
                            i,
                            batch_stats["tot_frames"][i],
                            len(batch),
                            batch_stats["tot_pad_frames"][i],
                            batch_stats["pad_%"][i],
                        )
                    )

    def __iter__(self):
        """Yield the current batches, then prepare the ordering for the next pass."""
        yield from self._batches
        # NOTE(review): regeneration reuses seed + epoch, so unless set_epoch() is called the
        # example shuffle repeats; with batch_ordering="random" the batches are then permuted
        # a second time below (behaviour kept from the SpeechBrain sampler).
        if self._shuffle_ex:  # re-generate examples if ex_ordering == "random"
            self._generate_batches()
        if self._batch_ordering == "random":
            # we randomly permute the batches only --> faster
            self._permute_batches()

    def set_epoch(self, epoch):
        """
        You can also just access self.epoch, but we maintain this interface
        to mirror torch.utils.data.distributed.DistributedSampler
        """
        self._epoch = epoch
        self._generate_batches()

    def __len__(self):
        """Number of batches in the current epoch."""
        return len(self._batches)
class DynamicBatchSamplerKM(Sampler):
    """Dynamic batch sampler whose buckets are K-means clusters of the example lengths.

    Example lengths are clustered into ``num_buckets`` groups with scikit-learn's ``KMeans``.
    Clusters are relabelled so that bucket ids grow with the centroid length. A bucket holds at most
    ``max(1, int(max_batch_length / longest_length_in_bucket))`` examples (and at most
    ``max_batch_ex``) before it is emitted as a batch. Iterating yields lists of dataset indices.

    Arguments
    ---------
    dataset : sized dataset
        Dataset to sample from. Must be a SpeechBrain ``DynamicItemDataset`` unless
        ``lengths_list`` is given.
    max_batch_length : int
        Length budget used to size the buckets (same unit as the example lengths).
    num_buckets : int
        Number of K-means clusters (required).
    length_func : callable
        Maps ``dataset.data[data_id]`` to the example length (only used without ``lengths_list``).
    shuffle : bool
        Shuffle examples (deterministically, from ``seed + epoch``) before bucketing.
    batch_ordering : str
        "random", "ascending" or "descending" order of the produced batches.
    max_batch_ex : int, optional
        Maximum number of examples per batch (unlimited when None).
    bucket_boundaries : list, optional
        Accepted for signature compatibility with the other samplers; ignored, because
        K-means decides the buckets.
    lengths_list : list, optional
        Precomputed length of every example, indexed like the dataset; bypasses ``length_func``.
    seed, epoch : int
        Seed (also used as the KMeans ``random_state``) and epoch for the deterministic shuffles.
    drop_last : bool
        Drop the incomplete batches still left in the buckets after bucketing.
    verbose : bool
        At epoch 0, additionally log padding statistics for every batch.
    """

    def __init__(
        self,
        dataset,
        max_batch_length: int,
        num_buckets: int = None,
        length_func=lambda x: x["duration"],
        shuffle: bool = True,
        batch_ordering: str = "random",
        max_batch_ex: int = None,
        bucket_boundaries: List[int] = None,
        lengths_list: List[int] = None,
        seed: int = 42,
        epoch: int = 0,
        drop_last: bool = False,
        verbose: bool = False,
    ):
        # None replaces a shared mutable [] default; an empty sequence still means "not given".
        if bucket_boundaries is None:
            bucket_boundaries = []
        self._dataset = dataset
        self._ex_lengths = {}
        self.verbose = verbose

        # We do not put a default on num_buckets to encourage users to play with this parameter
        if num_buckets is None and len(bucket_boundaries) == 0:
            raise RuntimeError(
                "Please specify either num_buckets or bucket boundaries. Check the docs, and/or the tutorial !"
            )
        if num_buckets is None:
            # K-means needs a cluster count; explicit boundaries cannot be used by this sampler.
            raise ValueError("DynamicBatchSamplerKM requires num_buckets (bucket_boundaries are ignored).")

        if lengths_list is not None:
            # take length of examples from this argument and bypass length_func
            for indx, length in enumerate(lengths_list):
                self._ex_lengths[str(indx)] = length
        else:
            # use length func
            if not isinstance(dataset, DynamicItemDataset):
                raise NotImplementedError("Dataset should be a Speechbrain DynamicItemDataset when using length function")
            # data_ids is only needed (and only guaranteed to exist) on this path
            ex_ids = self._dataset.data_ids
            for indx in range(len(self._dataset)):
                self._ex_lengths[str(indx)] = length_func(self._dataset.data[ex_ids[indx]])

        self._max_batch_length = max_batch_length
        self._shuffle_ex = shuffle
        self._batch_ordering = batch_ordering
        self._seed = seed
        self._drop_last = drop_last
        self._max_batch_ex = np.inf if max_batch_ex is None else max_batch_ex
        # one bucket id per example, in dataset order
        self._ex_bucket_ids = self._get_bucket_ids(num_buckets=num_buckets)
        # how many examples of each bucket fit into max_batch_length (no overflow bucket here)
        self._bucket_lens = self._get_bucket_lens(num_buckets, max_batch_length)
        self._epoch = epoch
        self._generate_batches()

    def get_durations(self, batch):
        """Return the stored length of every example index in ``batch``."""
        return [self._ex_lengths[str(idx)] for idx in batch]

    def _get_bucket_ids(self, num_buckets: int):
        """Cluster the example lengths with K-means and return one bucket id per example.

        Cluster ids are relabelled so that bucket 0 has the smallest centroid and bucket
        ``num_buckets - 1`` the largest. Raises ImportError if scikit-learn is missing.
        """
        try:
            from sklearn.cluster import KMeans
        except ImportError as err:
            # fail loudly: nothing below can run without KMeans
            raise ImportError("DynamicBatchSamplerKM requires scikit-learn: pip3 install -U scikit-learn") from err
        lengths = np.asarray(list(self._ex_lengths.values())).reshape(-1, 1)
        km = KMeans(n_clusters=num_buckets, random_state=self._seed).fit(lengths)
        # sort clusters by centroid so that bucket ids grow with example length
        order = np.argsort(km.cluster_centers_.reshape((-1,)))
        lut = np.zeros_like(order)
        lut[order] = np.arange(num_buckets)
        return lut[km.labels_]

    def _get_bucket_lens(self, num_buckets: int, max_batch_length: int):
        """Return, per bucket, how many of its longest examples fit into ``max_batch_length`` (at least 1).

        A bucket with no examples (K-means can leave clusters empty, e.g. with fewer distinct
        lengths than clusters) gets size 1; no example is ever routed to it.
        """
        lengths = np.asarray(list(self._ex_lengths.values()))
        bucket_lens = []
        for bucket_id in range(num_buckets):
            bucket_lengths = lengths[self._ex_bucket_ids == bucket_id]
            if bucket_lengths.size == 0:
                bucket_lens.append(1)
                continue
            bucket_lens.append(max(1, int(max_batch_length / bucket_lengths.max())))
        return bucket_lens

    def _get_boundaries_through_warping(
        self,
        max_batch_length: int,
        num_quantiles: int,
    ) -> List[int]:
        """Return ``num_quantiles`` ascending boundaries from a standard lognormal warping.

        Not used by the K-means bucketing; kept so this class exposes the same helpers as the
        other samplers.
        """
        # NOTE: the following lines do not cover that there is only one example in the dataset
        logger.info("Batch quantisation in latent space")
        num_boundaries = num_quantiles + 1
        # create latent linearly equal spaced buckets
        latent_boundaries = np.linspace(
            1 / num_boundaries,
            num_quantiles / num_boundaries,
            num_quantiles,
        )
        # get quantiles using lognormal distribution
        quantiles = lognorm.ppf(latent_boundaries, 1)
        # scale up to max_batch_length
        bucket_boundaries = quantiles * max_batch_length / quantiles[-1]
        # compute resulting bucket length multipliers
        length_multipliers = [bucket_boundaries[x + 1] / bucket_boundaries[x] for x in range(num_quantiles - 1)]
        logger.info(
            "Latent bucket boundary - buckets: {} - length multipliers: {}".format(
                list(map("{:.2f}".format, bucket_boundaries)),
                list(map("{:.2f}".format, length_multipliers)),
            )
        )
        return list(sorted(bucket_boundaries))

    def _permute_batches(self):
        """Reorder ``self._batches`` according to ``self._batch_ordering``.

        "random" uses a generator seeded with ``seed + epoch``. "ascending" and "descending"
        sort by the longest example of each batch. Any other value raises NotImplementedError.
        """
        if self._batch_ordering == "random":
            # deterministically shuffle based on epoch and seed
            g = torch.Generator()
            g.manual_seed(self._seed + self._epoch)
            sampler = torch.randperm(len(self._batches), generator=g).tolist()  # type: ignore
            self._batches = [self._batches[idx] for idx in sampler]
        elif self._batch_ordering == "ascending":
            self._batches = sorted(
                self._batches,
                key=lambda batch: max(self._ex_lengths[str(idx)] for idx in batch),
            )
        elif self._batch_ordering == "descending":
            self._batches = sorted(
                self._batches,
                key=lambda batch: max(self._ex_lengths[str(idx)] for idx in batch),
                reverse=True,
            )
        else:
            raise NotImplementedError(f"Unknown batch_ordering: {self._batch_ordering!r}")

    def _generate_batches(self):
        """Fill the K-means buckets with example indices and rebuild ``self._batches``.

        At epoch 0 this also logs per-bucket statistics, and per-batch padding if ``verbose``.
        """
        logger.info("DynamicBatchSampler: Generating dynamic batches")
        if self._shuffle_ex:
            # deterministically shuffle based on epoch and seed
            g = torch.Generator()
            g.manual_seed(self._seed + self._epoch)
            sampler = torch.randperm(len(self._dataset), generator=g).tolist()  # type: ignore
        else:
            # take examples as they are: e.g. they have been sorted
            sampler = range(len(self._dataset))  # type: ignore

        self._batches = []
        bucket_batches = [[] for _ in self._bucket_lens]
        stats_tracker = [{"min": np.inf, "max": -np.inf, "tot": 0, "n_ex": 0} for _ in self._bucket_lens]

        for idx in sampler:
            item_len = self._ex_lengths[str(idx)]
            # bucket precomputed by K-means
            bucket_id = self._ex_bucket_ids[idx]
            bucket_batches[bucket_id].append(idx)

            stats = stats_tracker[bucket_id]
            stats["min"] = min(stats["min"], item_len)
            stats["max"] = max(stats["max"], item_len)
            stats["tot"] += item_len
            stats["n_ex"] += 1

            # emit the bucket as a batch once it is full
            if (
                len(bucket_batches[bucket_id]) >= self._bucket_lens[bucket_id]
                or len(bucket_batches[bucket_id]) >= self._max_batch_ex
            ):
                self._batches.append(bucket_batches[bucket_id])
                bucket_batches[bucket_id] = []

        # Dump remaining (incomplete) batches
        if not self._drop_last:
            self._batches.extend(batch for batch in bucket_batches if batch)

        self._permute_batches()  # possibly reorder batches

        if self._epoch == 0:  # only log at first epoch
            # K-means buckets have no fixed boundaries: report the observed length range instead
            for bucket_indx, stats in enumerate(stats_tracker):
                try:
                    num_batches = stats["tot"] // (self._max_batch_length)
                    pad_factor = (stats["max"] - stats["min"]) / (stats["tot"] / stats["n_ex"])
                except ZeroDivisionError:
                    # empty bucket
                    num_batches = 0
                    pad_factor = 0
                low, high = (stats["min"], stats["max"]) if stats["n_ex"] else (0.0, 0.0)
                logger.info(
                    (
                        "DynamicBatchSampler: Bucket {} with lengths {:.1f}-{:.1f} and "
                        + "batch_size {}: Num Examples {:.1f}, Num Full Batches {:.3f}, Pad Factor {:.3f}."
                    ).format(
                        bucket_indx,
                        low,
                        high,
                        self._bucket_lens[bucket_indx],
                        stats["n_ex"],
                        num_batches,
                        pad_factor * 100,
                    )
                )

            if self.verbose:
                padding_details = "Batch {} with {:.1f} frames with {} files - {:.1f} padding, {:.2f} (%) of total."
                padding_details = "DynamicBatchSampler: " + padding_details
                for i, batch in enumerate(self._batches):
                    lengths = [self._ex_lengths[str(idx)] for idx in batch]
                    tot_frames = sum(lengths)
                    max_frames = max(lengths)
                    tot_pad = sum(max_frames - length for length in lengths)
                    # a batch of zero-length examples would otherwise divide by zero
                    pad_percent = tot_pad / tot_frames * 100 if tot_frames else 0.0
                    logger.info(padding_details.format(i, tot_frames, len(batch), tot_pad, pad_percent))

    def __iter__(self):
        """Yield the current batches, then prepare the ordering for the next pass."""
        yield from self._batches
        # NOTE(review): regeneration reuses seed + epoch, so unless set_epoch() is called the
        # example shuffle repeats; with batch_ordering="random" the batches are then permuted
        # a second time below (behaviour kept from the SpeechBrain sampler).
        if self._shuffle_ex:  # re-generate examples if ex_ordering == "random"
            self._generate_batches()
        if self._batch_ordering == "random":
            # we randomly permute the batches only --> faster
            self._permute_batches()

    def set_epoch(self, epoch):
        """
        You can also just access self.epoch, but we maintain this interface
        to mirror torch.utils.data.distributed.DistributedSampler
        """
        self._epoch = epoch
        self._generate_batches()

    def __len__(self):
        """Number of batches in the current epoch."""
        return len(self._batches)
def count_samples(dataloader, key: str = "signal"):
    """Measure the share of real vs. padded samples produced by ``dataloader``.

    Each batch must expose ``getattr(batch, key)`` as a ``(padded_tensor, lens)`` pair. ``lens``
    is treated as relative lengths, so ``padded_tensor.shape[-1] * lens`` gives the true number
    of samples of each item.

    Arguments
    ---------
    dataloader : iterable
        Yields batches (e.g. SpeechBrain ``PaddedBatch`` objects).
    key : str
        Attribute of each batch holding the ``(tensor, lens)`` pair (default "signal").

    Returns
    -------
    (true_fraction, padded_fraction, elapsed_seconds). Both fractions are 0.0 when the
    loader yields no samples.
    """
    true_samples = 0
    padded_samples = 0
    t1 = time.time()
    for batch in dataloader:
        audio, lens = getattr(batch, key)
        padded_len = audio.shape[-1]
        true_samples += torch.sum(padded_len * lens).item()
        padded_samples += torch.sum(padded_len * (1 - lens)).item()
    elapsed = time.time() - t1
    tot_samples = true_samples + padded_samples
    if tot_samples == 0:
        # empty loader: avoid ZeroDivisionError
        return 0.0, 0.0, elapsed
    return true_samples / tot_samples, padded_samples / tot_samples, elapsed
def _benchmark_sampler(label, sampler_cls, dataset, **sampler_kwargs):
    """Time ``sampler_cls(dataset, **sampler_kwargs)``, iterate it through a DataLoader and print padding stats.

    Every sampler is reported with the same line format so results can be compared directly.
    """
    t1 = time.time()
    dynamic_batcher = sampler_cls(dataset, **sampler_kwargs)
    elapsed_sampler = time.time() - t1
    dataloader = DataLoader(dataset, batch_sampler=dynamic_batcher, collate_fn=PaddedBatch)
    percent_true, percent_padded, elapsed = count_samples(dataloader)
    print(
        "{}: % True samples {}, % of padding {}, Total time {}, Sampler initialization time {}".format(
            label, percent_true, percent_padded, elapsed, elapsed_sampler
        )
    )


def main():
    """Compare padding ratios of fixed-size and dynamic batching on the dataset in ``data.json``."""
    # load the minilibrispeech dataset, as mentioned in the tutorial
    train_data = speechbrain.dataio.dataset.DynamicItemDataset.from_json("data.json")

    # we define a pipeline to read audio
    @speechbrain.utils.data_pipeline.takes("file_path")
    @speechbrain.utils.data_pipeline.provides("signal")
    def audio_pipeline(file_path):
        sig = speechbrain.dataio.dataio.read_audio(file_path)
        return sig

    # setting the pipeline
    train_data.add_dynamic_item(audio_pipeline)
    train_data.set_output_keys(["signal", "file_path"])

    batch_size = 32
    # budget for the dynamic samplers, in the unit returned by length_in_seconds
    max_batch_len = 17 * 32

    # random
    dataloader = DataLoader(train_data, collate_fn=PaddedBatch, batch_size=batch_size)
    percent_true, percent_padded, elapsed = count_samples(dataloader)
    print("Random Sampling: % True samples {}, % of padding {}, Total time {}".format(percent_true, percent_padded, elapsed))

    # sort
    sorted_data = train_data.filtered_sorted(sort_key="length")
    dataloader = DataLoader(sorted_data, collate_fn=PaddedBatch, batch_size=batch_size)
    percent_true, percent_padded, elapsed = count_samples(dataloader)
    print("After sorting: % True samples {}, % of padding {}, Total time {}".format(percent_true, percent_padded, elapsed))

    def length_in_seconds(x):
        # presumably "length" is a sample count of 16 kHz audio — confirm against data.json
        return x["length"] / 16000

    common_kwargs = dict(
        max_batch_length=max_batch_len,
        length_func=length_in_seconds,
        shuffle=False,
        batch_ordering="descending",
    )
    # (label, sampler class, extra kwargs). To also benchmark the fixed lognormal warping, add
    # a DynamicBatchSamplerMdf entry with {"boundary_mode": "fixed"}.
    sampler_configs = [
        ("Org Dynamic Batching", DynamicBatchSamplerOrg, {}),
        (
            "Mdf Dynamic Batching w/ fitted lognorm",
            DynamicBatchSamplerMdf,
            {"boundary_mode": "fitted", "boundary_rv_dist": "lognorm"},
        ),
        (
            "Mdf Dynamic Batching w/ fitted beta",
            DynamicBatchSamplerMdf,
            {"boundary_mode": "fitted", "boundary_rv_dist": "beta"},
        ),
        ("Mdf Dynamic Batching w/ Kmeans", DynamicBatchSamplerKM, {}),
    ]

    num_buckets_list = [5, 10, 15, 20, 60]
    for num_buckets in num_buckets_list:
        print(f"\nnum_buckets: {num_buckets}")
        for label, sampler_cls, extra_kwargs in sampler_configs:
            _benchmark_sampler(
                label,
                sampler_cls,
                train_data,
                num_buckets=num_buckets,
                **common_kwargs,
                **extra_kwargs,
            )
if __name__ == "__main__":
    # Run the benchmark only when executed as a script, not when the samplers are imported.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment