Created
June 3, 2022 09:23
-
-
Save bofenghuang/c335a038699b162fb28cee635c4c3e66 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
import logging | |
import time | |
from typing import List, Optional | |
import numpy as np | |
import scipy.stats | |
import speechbrain | |
import torch | |
from scipy.stats import lognorm | |
from speechbrain.dataio.batch import PaddedBatch | |
from speechbrain.dataio.dataset import DynamicItemDataset | |
from speechbrain.dataio.sampler import DynamicBatchSampler as DynamicBatchSamplerOrg | |
from torch.utils.data import DataLoader, Sampler | |
# Send INFO-level records to stderr (root logger config), then obtain this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class DynamicBatchSampler(Sampler):
    """Group examples into length buckets and emit variable-size batches.

    Every example is assigned to a length bucket in one of three ways:

    * explicit ``bucket_boundaries`` (ascending upper boundaries),
    * ``mode="kmeans"``: 1-D k-means clustering of the lengths (requires sklearn),
    * otherwise ``mode`` names a ``scipy.stats`` continuous distribution whose
      quantiles are used as boundaries (see ``_get_boundaries_through_warping``).

    A bucket's batch size is ``max_batch_length // (longest example in the
    bucket)`` (at least 1). It is optionally capped by ``max_batch_ex``.

    Arguments
    ---------
    dataset
        Dataset to sample from. It must be a SpeechBrain ``DynamicItemDataset``
        unless ``lengths_list`` is given, in which case only ``len(dataset)``
        is used.
    max_batch_length : float
        Budget shared by a batch: batch size * longest example in its bucket.
    num_buckets : int, optional
        Number of buckets. Required unless ``bucket_boundaries`` is given.
    length_func : callable
        Applied to ``dataset.data[data_id]`` to obtain an example's length.
        Only used when ``lengths_list`` is None.
    shuffle : bool
        Shuffle example order (seeded with ``seed + epoch``) before bucketing.
    batch_ordering : str
        Order of the produced batches. One of "random", "ascending" or
        "descending" (by the longest example in each batch).
    max_batch_ex : int, optional
        Hard cap on the number of examples per batch (no cap when None).
    bucket_boundaries : list of float, optional
        Explicit upper boundaries. They must be non-negative, duplicate-free
        and ascending.
    lengths_list : list of float, optional
        Precomputed per-example lengths (index-aligned with ``dataset``).
    seed, epoch : int
        Their sum seeds the shuffling generators.
    drop_last : bool
        Discard the incomplete batch left over in each bucket.
    mode : str
        "kmeans" or the name of a ``scipy.stats`` distribution.
    fit_dataset : bool
        Distribution modes only: fit the distribution to the lengths instead
        of using a fixed shape parameter of 1.
    verbose : bool
        Stored as ``self.verbose``; not read anywhere else in this class.
    """

    def __init__(
        self,
        dataset,
        max_batch_length: float,
        num_buckets: Optional[int] = None,
        length_func=lambda x: x["duration"],
        shuffle: bool = True,
        batch_ordering: str = "random",
        max_batch_ex: Optional[int] = None,
        bucket_boundaries: Optional[List[float]] = None,
        lengths_list: Optional[List[float]] = None,
        seed: int = 42,
        epoch: int = 0,
        drop_last: bool = False,
        mode: str = "lognorm",
        fit_dataset: bool = False,
        verbose: bool = False,
    ):
        self._dataset = dataset
        self.verbose = verbose
        # None replaces the former shared mutable `[]` default; [] is still accepted.
        bucket_boundaries = [] if bucket_boundaries is None else list(bucket_boundaries)

        # num_buckets deliberately has no default, to encourage users to tune it.
        if num_buckets is None and not bucket_boundaries:
            raise RuntimeError(
                "Please specify either num_buckets or bucket boundaries. " "Check the docs, and/or the tutorial !"
            )

        if lengths_list is not None:
            # Lengths supplied directly: the dataset's items are never touched,
            # so any sized dataset works. np.array copies the caller's data.
            self._ex_lengths = np.array(lengths_list)
        else:
            if not isinstance(dataset, DynamicItemDataset):
                raise NotImplementedError("Dataset should be a Speechbrain DynamicItemDataset when using length function")
            ex_ids = dataset.data_ids
            self._ex_lengths = np.array([length_func(dataset.data[ex_ids[i]]) for i in range(len(dataset))])

        self._max_batch_length = max_batch_length
        self._shuffle_ex = shuffle
        self._batch_ordering = batch_ordering
        self._seed = seed
        self._drop_last = drop_last
        self._max_batch_ex = np.inf if max_batch_ex is None else max_batch_ex

        if bucket_boundaries:
            if not all(x >= 0 for x in bucket_boundaries):
                raise ValueError("All elements in bucket boundaries should be non-negative (>= 0).")
            if len(set(bucket_boundaries)) != len(bucket_boundaries):
                raise ValueError("Bucket_boundaries should not contain duplicates.")
            np.testing.assert_array_equal(
                np.array(bucket_boundaries),
                np.array(sorted(bucket_boundaries)),
                err_msg="The arg bucket_boundaries should be an ascending sorted list of non negative values values!",
            )
            self._bucket_boundaries = np.array(sorted(bucket_boundaries))
            self._ex_bucket_ids = self._get_bucket_ids()
        elif mode == "kmeans":
            self._ex_bucket_ids = self._get_bucket_ids_by_kmeans(num_buckets=num_buckets)
            # Boundaries are not used for assignment here; they are derived for logging/debugging only.
            self._bucket_boundaries = np.array(self._get_boundaries_for_clusters(num_buckets=num_buckets))
        else:
            try:
                rv = getattr(scipy.stats, mode)
            except AttributeError:
                msg = f"Cannot import {mode} distribution from Scipy. Please use another random variable distribution like lognorm"
                raise ImportError(msg) from None
            self._bucket_boundaries = np.array(
                self._get_boundaries_through_warping(
                    rv, max_batch_length=max_batch_length, num_quantiles=num_buckets, fit_dataset=fit_dataset
                )
            )
            self._ex_bucket_ids = self._get_bucket_ids()

        # searchsorted places examples longer than the last boundary into an
        # extra bucket with id len(boundaries). Size the bucket list from the
        # ids actually present, so that bucket exists instead of raising
        # IndexError in _generate_batches (and so it does not depend on
        # num_buckets being given).
        n_buckets = len(self._bucket_boundaries)
        if self._ex_bucket_ids.size > 0:
            n_buckets = max(n_buckets, int(self._ex_bucket_ids.max()) + 1)
        self._bucket_lens = self._get_bucket_lens(n_buckets, max_batch_length)

        self._epoch = epoch
        self._generate_batches()
        logger.info("\n\n")

    def get_durations(self, batch):
        """Return the lengths of the examples indexed by ``batch`` as a NumPy array."""
        return self._ex_lengths[batch]

    def _get_bucket_ids(self):
        """Assign each example to a bucket by binary search over the boundaries.

        ``np.searchsorted`` uses side="left", so a length equal to a boundary
        falls into that boundary's bucket. Lengths above the last boundary get
        id ``len(self._bucket_boundaries)``.
        """
        return np.searchsorted(self._bucket_boundaries, self._ex_lengths)

    def _get_bucket_ids_by_kmeans(self, num_buckets: int):
        """Cluster lengths with 1-D k-means; return cluster ids ranked by centroid.

        Id 0 is the cluster with the smallest centroid, so ids increase with length.

        Raises
        ------
        ImportError
            If scikit-learn is not installed.
        """
        try:
            from sklearn.cluster import KMeans
        except ImportError:
            msg = "Please install sklearn to use kmeans\n"
            msg += "e.g. run: pip3 install -U scikit-learn"
            raise ImportError(msg)

        lengths = self._ex_lengths.reshape(-1, 1)
        km = KMeans(n_clusters=num_buckets, random_state=self._seed).fit(lengths)
        # rank_of[cluster] = position of that cluster's centroid in ascending order.
        order = np.argsort(km.cluster_centers_.reshape((-1,)))
        rank_of = np.zeros_like(order)
        rank_of[order] = np.arange(num_buckets)
        return rank_of[km.labels_]

    def _get_boundaries_for_clusters(self, num_buckets: int, side: str = "left") -> List[float]:
        """Derive one upper boundary per k-means bucket (used for logging).

        With ``side="left"`` the boundary between buckets i and i+1 is the
        longest example of bucket i. Otherwise it is the shortest example of
        bucket i+1. The last boundary is the longest example of the last
        bucket. Buckets that received no example yield NaN instead of crashing.
        """
        ranges = []
        for bucket_id in range(num_buckets):
            members = self._ex_lengths[self._ex_bucket_ids == bucket_id]
            ranges.append((members.min(), members.max()) if members.size else (np.nan, np.nan))
        upper = [ranges[i][1] if side == "left" else ranges[i + 1][0] for i in range(num_buckets - 1)]
        upper.append(ranges[-1][1])
        return upper

    def _get_bucket_lens(self, num_buckets: int, max_batch_length: float) -> List[int]:
        """Return the number of examples per batch for each bucket id.

        The size is how many times the longest example *present* in the bucket
        fits into ``max_batch_length`` (not the bucket boundary), and is at
        least 1. Empty buckets, and buckets whose longest example has a
        non-positive length (where the division is undefined), get size 1.
        """
        bucket_lens = []
        for bucket_id in range(num_buckets):
            longest = self._ex_lengths[self._ex_bucket_ids == bucket_id].max(initial=-1)
            bucket_lens.append(max(1, int(max_batch_length / longest)) if longest > 0 else 1)
        return bucket_lens

    def _get_boundaries_through_warping(
        self,
        rv: scipy.stats.rv_continuous,
        max_batch_length: float,
        num_quantiles: int,
        fit_dataset: bool = False,
    ) -> List[float]:
        """Place ``num_quantiles`` bucket upper boundaries at quantiles of ``rv``.

        * ``fit_dataset=True``: fit ``rv`` to the example lengths and take its
          quantiles at 1/q, 2/q, ..., 1. The last boundary is ``rv.ppf(1)``
          under the fitted parameters (``inf`` for lognorm).
        * ``fit_dataset=False``: take quantiles at 1/(q+1), ..., q/(q+1) with
          ``1`` passed as the first extra ``ppf`` argument (the shape ``s``
          for lognorm), then rescale so the last boundary equals
          ``max_batch_length``. NOTE(review): distributions needing more than
          one shape parameter presumably fail here; confirm before using them
          without ``fit_dataset``.

        Does not handle a dataset with only one example.
        Returns the boundaries sorted ascending.
        """
        logger.info("Batch quantisation in latent space")
        if fit_dataset:
            latent_boundaries = np.linspace(1 / num_quantiles, 1, num_quantiles)
            # rv.fit may emit "RuntimeWarning: invalid value encountered in sqrt" for some distributions.
            rv_params = rv.fit(self._ex_lengths)
            bucket_boundaries = rv.ppf(latent_boundaries, *rv_params)
        else:
            num_boundaries = num_quantiles + 1
            # Equally spaced latent points, excluding 0 and 1.
            latent_boundaries = np.linspace(1 / num_boundaries, num_quantiles / num_boundaries, num_quantiles)
            quantiles = rv.ppf(latent_boundaries, 1)
            bucket_boundaries = quantiles * max_batch_length / quantiles[-1]

        # Ratio between consecutive boundaries, logged for inspection only.
        length_multipliers = [hi / lo for lo, hi in zip(bucket_boundaries[:-1], bucket_boundaries[1:])]
        logger.info(
            "Latent bucket boundary - buckets: {} - length multipliers: {}".format(
                list(map("{:.2f}".format, bucket_boundaries)),
                list(map("{:.2f}".format, length_multipliers)),
            )
        )
        return sorted(bucket_boundaries)

    def _permute_batches(self):
        """Reorder ``self._batches`` in place according to ``batch_ordering``.

        "random" is a permutation seeded with ``seed + epoch``. "ascending"
        and "descending" sort stably by the longest example in each batch.

        Raises
        ------
        NotImplementedError
            For any other ``batch_ordering``.
        """
        if self._batch_ordering == "random":
            g = torch.Generator()
            g.manual_seed(self._seed + self._epoch)
            order = torch.randperm(len(self._batches), generator=g).tolist()  # type: ignore
            self._batches = [self._batches[i] for i in order]
        elif self._batch_ordering in ("ascending", "descending"):
            self._batches = sorted(
                self._batches,
                key=lambda batch: self._ex_lengths[batch].max(),
                reverse=self._batch_ordering == "descending",
            )
        else:
            raise NotImplementedError

    def _generate_batches(self):
        """Rebuild ``self._batches`` for the current epoch.

        Examples are visited in (optionally shuffled) order and appended to the
        pending batch of their bucket. A pending batch is emitted as soon as
        it holds ``self._bucket_lens[bucket]`` examples or ``self._max_batch_ex``
        examples. Leftover pending batches are kept unless ``drop_last``. The
        batches are then reordered by ``_permute_batches``, and padding
        statistics are logged at epoch 0 only.
        """
        logger.info("DynamicBatchSampler: Generating dynamic batches")
        if self._shuffle_ex:
            # Deterministic shuffle based on epoch and seed.
            g = torch.Generator()
            g.manual_seed(self._seed + self._epoch)
            sampler = torch.randperm(len(self._dataset), generator=g).tolist()  # type: ignore
        else:
            # Keep the dataset order (e.g. the examples have been pre-sorted).
            sampler = range(len(self._dataset))  # type: ignore

        n_buckets = len(self._bucket_lens)
        self._batches = []
        pending = [[] for _ in range(n_buckets)]
        # Emitted batches grouped by bucket; only used for the statistics log.
        batches_by_bucket = [[] for _ in range(n_buckets)]
        for idx in sampler:
            bucket_id = self._ex_bucket_ids[idx]
            pending[bucket_id].append(idx)
            if (
                len(pending[bucket_id]) >= self._bucket_lens[bucket_id]
                or len(pending[bucket_id]) >= self._max_batch_ex
            ):
                self._batches.append(pending[bucket_id])
                batches_by_bucket[bucket_id].append(pending[bucket_id])
                pending[bucket_id] = []

        if not self._drop_last:
            for bucket_id, batch in enumerate(pending):
                if batch:
                    self._batches.append(batch)
                    batches_by_bucket[bucket_id].append(batch)

        self._permute_batches()
        if self._epoch == 0:  # only log at first epoch
            self._log_bucket_stats(batches_by_bucket)

    def _log_bucket_stats(self, batches_by_bucket):
        """Log per-bucket batch sizes, counts and padding, then global totals.

        Padding is measured as if each batch were padded to its longest
        example. Buckets or datasets with no samples report 0% padding
        instead of dividing by zero.
        """
        # The trailing inf labels the overflow bucket beyond the last boundary, if present.
        boundaries = [0] + self._bucket_boundaries.tolist() + [np.inf]
        n_true_samples = 0
        n_all_samples = 0
        n_tot_batches = 0
        for bucket_indx, batches in enumerate(batches_by_bucket):
            lengths_by_batch = [self._ex_lengths[batch] for batch in batches]
            n_items = sum(len(lengths) for lengths in lengths_by_batch)
            n_true = sum(lengths.sum() for lengths in lengths_by_batch)
            n_all = sum(len(lengths) * lengths.max() for lengths in lengths_by_batch)
            n_true_samples += n_true
            n_all_samples += n_all
            n_tot_batches += len(batches)
            pct_padding = 1 - n_true / n_all if n_all else 0
            logger.info(
                (
                    "DynamicBatchSampler: Bucket {} with boundary {:.1f}-{:.1f} and "
                    + "Batch Size {}, Num Examples {}, Num Batches {}, % of padding {:.2f}%."
                ).format(
                    bucket_indx,
                    boundaries[bucket_indx],
                    boundaries[bucket_indx + 1],
                    self._bucket_lens[bucket_indx],
                    n_items,
                    len(batches),
                    pct_padding * 100,
                )
            )
        pct_true = n_true_samples / n_all_samples * 100 if n_all_samples else 100.0
        logger.info(
            "DynamicBatchSampler: % true samples {:.2f}%, % of padding {:.2f}%, num of batches {}".format(
                pct_true, 100 - pct_true, n_tot_batches
            )
        )

    def __iter__(self):
        """Yield the current batches, then prepare the order for the next pass.

        With ``shuffle`` the batches are regenerated (same ``seed + epoch``
        unless ``set_epoch`` was called). With "random" ordering the batches
        are additionally permuted again.
        """
        yield from self._batches
        if self._shuffle_ex:  # re-generate examples if ex_ordering == "random"
            self._generate_batches()
        if self._batch_ordering == "random":
            # Randomly permute only the batches (cheaper than regenerating them).
            self._permute_batches()

    def set_epoch(self, epoch):
        """Set the epoch used for seeding and regenerate the batches.

        Mirrors the ``torch.utils.data.distributed.DistributedSampler`` interface.
        """
        self._epoch = epoch
        self._generate_batches()

    def __len__(self):
        """Return the number of batches currently generated."""
        return len(self._batches)
def count_samples(dataloader):
    """Measure how much of the loaded audio is real signal versus padding.

    Iterates ``dataloader`` once. Each batch must expose ``batch.signal`` as an
    ``(audio, lens)`` pair, where ``audio.shape[-1]`` is the padded length.
    Presumably ``lens`` holds relative lengths in [0, 1], as produced by
    SpeechBrain's ``PaddedBatch``; confirm for other collate functions.

    Returns
    -------
    tuple
        ``(ratio_true, ratio_padding, n_batches, elapsed)``. Both ratios are NaN
        when the loader produced no samples (instead of dividing by zero).
        ``elapsed`` is the iteration time in seconds.
    """
    true_samples = 0.0
    padded_samples = 0.0
    n_batches = len(dataloader)
    # perf_counter is monotonic and meant for measuring durations.
    start = time.perf_counter()
    for batch in dataloader:
        audio, lens = batch.signal
        padded_len = audio.shape[-1]
        true_samples += torch.sum(padded_len * lens).item()
        padded_samples += torch.sum(padded_len * (1 - lens)).item()
    elapsed = time.perf_counter() - start

    tot_samples = true_samples + padded_samples
    if tot_samples == 0:
        return float("nan"), float("nan"), n_batches, elapsed
    ratio_true = true_samples / tot_samples
    return ratio_true, 1 - ratio_true, n_batches, elapsed
def _report(label, dataloader, elapsed_sampler=None):
    """Run ``count_samples`` over ``dataloader`` and print its padding statistics."""
    ratio_true, ratio_padding, n_batches, elapsed = count_samples(dataloader)
    msg = "{}: ratio of true samples {:.4f}, ratio of padding {:.4f}, num of batches {}, total time {:.4f}s".format(
        label, ratio_true, ratio_padding, n_batches, elapsed
    )
    if elapsed_sampler is not None:
        msg += ", sampler initialization time {:.4f}s".format(elapsed_sampler)
    print(msg)


def main(
    csv_path: str = "/home/bhuang/asr/speechbrain/recipes/CommonVoice/ASR/CTC/results/wav2vec2_ctc_fr/1234/save/train.csv",
    batch_size: int = 32,
    num_workers: int = 64,
    num_buckets_list: Optional[List[int]] = None,
):
    """Compare padding ratios of fixed-size vs. dynamic batching on a SpeechBrain CSV.

    Arguments
    ---------
    csv_path : str
        SpeechBrain data CSV with "wav" and "duration" columns.
    batch_size : int
        Batch size of the fixed-size baselines. It also sets the dynamic
        batch budget (longest example * batch_size).
    num_workers : int
        DataLoader worker processes.
    num_buckets_list : list of int, optional
        Bucket counts to benchmark (default: 5, 10, 20, 60).
    """
    if num_buckets_list is None:
        num_buckets_list = [5, 10, 20, 60]

    train_data = speechbrain.dataio.dataset.DynamicItemDataset.from_csv(csv_path)

    # Audio-reading pipeline: "wav" column -> "signal".
    @speechbrain.utils.data_pipeline.takes("wav")
    @speechbrain.utils.data_pipeline.provides("signal")
    def audio_pipeline(file_path):
        return speechbrain.dataio.dataio.read_audio(file_path)

    train_data.add_dynamic_item(audio_pipeline)
    train_data.set_output_keys(["signal", "wav", "duration"])

    # Baseline 1: fixed-size batches. The DataLoader is not given shuffle=True,
    # so this follows dataset (CSV) order despite the label.
    _report(
        "Random Sampling",
        DataLoader(train_data, collate_fn=PaddedBatch, batch_size=batch_size, num_workers=num_workers),
    )

    # Baseline 2: fixed-size batches over examples sorted by decreasing duration.
    sorted_data = train_data.filtered_sorted(sort_key="duration", reverse=True)
    _report(
        "After sorting",
        DataLoader(sorted_data, collate_fn=PaddedBatch, batch_size=batch_size, num_workers=num_workers),
    )

    # Budget the dynamic batches like the worst fixed-size batch: the longest
    # example times batch_size (previously a hard-coded 32).
    max_dur = sorted_data[0]["duration"]
    max_batch_len = max_dur * batch_size
    print(f"\nbatch_size: {batch_size}, max_dur: {max_dur}, max_batch_len is set to {max_batch_len}")

    # (label, sampler class, extra kwargs). NOTE: the modified sampler sizes each
    # bucket's batches from its longest example rather than from its boundary.
    # A fixed (non-fitted) lognorm can be compared by adding
    # ("Mdf Dynamic Batching w/ fixed lognorm", DynamicBatchSampler, {"mode": "lognorm", "fit_dataset": False}).
    variants = [
        ("Org Dynamic Batching", DynamicBatchSamplerOrg, {}),
        ("Mdf Dynamic Batching w/ fit lognorm", DynamicBatchSampler, {"mode": "lognorm", "fit_dataset": True}),
        ("Mdf Dynamic Batching w/ fit beta", DynamicBatchSampler, {"mode": "beta", "fit_dataset": True}),
        ("Mdf Dynamic Batching w/ kmeans", DynamicBatchSampler, {"mode": "kmeans"}),
    ]
    for num_buckets in num_buckets_list:
        print(f"\nnum_buckets: {num_buckets}")
        common = {
            "max_batch_length": max_batch_len,
            "num_buckets": num_buckets,
            "length_func": lambda x: x["duration"],
            "shuffle": False,
            "batch_ordering": "descending",
        }
        for label, sampler_cls, extra in variants:
            start = time.perf_counter()
            sampler = sampler_cls(train_data, **common, **extra)
            elapsed_sampler = time.perf_counter() - start
            dataloader = DataLoader(
                train_data, batch_sampler=sampler, collate_fn=PaddedBatch, num_workers=num_workers
            )
            _report(label, dataloader, elapsed_sampler)


if __name__ == "__main__":
    main()
# TODO:
# - scale to max_len
# - throw away examples longer than the max (via searchsorted)
# - ? A higher number means that, on average, we get a higher batch size, so you must apply the same "tricks" as when the batch size is increased for standard fixed-batch-size training, e.g. increase the learning rate.
# - randomness
# - less padding
# - use our hardware to the fullest with a variable batch size -> fewer epochs and iterations
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment