Skip to content

Instantly share code, notes, and snippets.

@sunprinceS
Created November 12, 2019 13:41
Show Gist options
  • Save sunprinceS/0907e0a75684b58a3dd672fde7676c3a to your computer and use it in GitHub Desktop.
Save sunprinceS/0907e0a75684b58a3dd672fde7676c3a to your computer and use it in GitHub Desktop.
class DataContainer:
    """Hold one training-loader iterator and one dev loader per data directory.

    For every entry of ``data_dirs`` the constructor builds

    * an iterator over ``get_loader(<dir>/train, ...)`` kept in
      ``self.loader_iters``, and
    * a ``get_loader(<dir>/dev, ...)`` object kept in ``self.dev_loaders``.

    ``get_item`` pulls batches from the training iterators and transparently
    rebuilds an iterator once it is exhausted, counting every rebuild in
    ``self.reload_cnt``.

    NOTE(review): ``get_loader`` is defined outside this class; its exact
    return type is not visible here (presumably a DataLoader-like iterable --
    confirm against its definition).
    """

    def __init__(self, data_dirs, batch_size, dev_batch_size, is_memmap,
                 is_bucket, num_workers=0, min_ilen=None, max_ilen=None,
                 half_batch_ilen=None, bucket_reverse=False, shuffle=True,
                 read_file=False, drop_last=False, pin_memory=True):
        """Build the train iterators and dev loaders for every directory.

        Args:
            data_dirs: Sized, indexable collection of directory objects that
                provide ``joinpath`` (e.g. ``pathlib.Path``); each must have a
                ``train`` and a ``dev`` child.
            batch_size: Batch size forwarded to the training loaders.
            dev_batch_size: Batch size forwarded to the dev loaders.
            is_memmap: Forwarded to both training and dev loaders.
            is_bucket: Forwarded to the training loaders only; dev loaders
                are always built with ``is_bucket=False``.
            num_workers: Forwarded to both training and dev loaders.
            min_ilen, max_ilen, half_batch_ilen, bucket_reverse, read_file:
                Forwarded to the training loaders only.
            shuffle: Forwarded to the training loaders; dev loaders are
                always built with ``shuffle=False``.
            drop_last, pin_memory: Stored on the instance but NOT forwarded
                to ``get_loader`` by this class.
                NOTE(review): confirm whether ``get_loader`` should get them.
        """
        self.data_dirs = data_dirs
        self.num_datasets = len(self.data_dirs)
        self.batch_size = batch_size
        self.dev_batch_size = dev_batch_size
        self.is_memmap = is_memmap
        self.is_bucket = is_bucket
        self.num_workers = num_workers
        self.min_ilen = min_ilen
        self.max_ilen = max_ilen
        self.half_batch_ilen = half_batch_ilen
        self.bucket_reverse = bucket_reverse
        self.shuffle = shuffle
        self.read_file = read_file
        self.drop_last = drop_last
        self.pin_memory = pin_memory

        # Number of times an exhausted training iterator has been rebuilt.
        self.reload_cnt = 0
        self.loader_iters = []
        self.dev_loaders = []

        # Train and dev loaders are created interleaved, directory by
        # directory, so the order of get_loader calls stays train0, dev0,
        # train1, dev1, ...
        for data_dir in self.data_dirs:
            self.loader_iters.append(self._new_train_iter(data_dir))
            self.dev_loaders.append(
                get_loader(
                    data_dir.joinpath('dev'),
                    batch_size=dev_batch_size,
                    is_memmap=self.is_memmap,
                    is_bucket=False,
                    num_workers=self.num_workers,
                    shuffle=False,
                ))

    def _new_train_iter(self, data_dir):
        """Return a fresh iterator over the training loader of ``data_dir``."""
        return iter(get_loader(
            data_dir.joinpath('train'),
            batch_size=self.batch_size,
            is_memmap=self.is_memmap,
            is_bucket=self.is_bucket,
            num_workers=self.num_workers,
            min_ilen=self.min_ilen,
            max_ilen=self.max_ilen,
            half_batch_ilen=self.half_batch_ilen,
            bucket_reverse=self.bucket_reverse,
            shuffle=self.shuffle,
            read_file=self.read_file,
        ))

    def get_item(self, lang_idx=None, num=1):
        """Fetch ``num`` training batches.

        Args:
            lang_idx: Index of the dataset to draw every batch from. When
                ``None`` (for MultiASR), a dataset index is drawn with
                ``np.random.randint(self.num_datasets)`` independently for
                each of the ``num`` batches.
            num: Number of batches to return.

        Returns:
            A list of ``(lang_id, batch)`` tuples in draw order, where
            ``batch`` is whatever the training iterator yields.

        Raises:
            StopIteration: If a training iterator is exhausted and the
                freshly rebuilt one yields nothing either.
        """
        if lang_idx is None:  # for MultiASR
            lang_ids = np.random.randint(self.num_datasets, size=num)
        else:
            lang_ids = np.repeat(lang_idx, num)

        ret_ls = []
        for lang_id in lang_ids:
            try:
                batch = next(self.loader_iters[lang_id])
            except StopIteration:
                # The iterator ran dry: rebuild it, record the reload and
                # take the first batch of the new pass.
                self.loader_iters[lang_id] = self._new_train_iter(
                    self.data_dirs[lang_id])
                self.reload_cnt += 1
                batch = next(self.loader_iters[lang_id])
            ret_ls.append((lang_id, batch))
        return ret_ls
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment