Skip to content

Instantly share code, notes, and snippets.

@AhmadMoussa
Last active December 8, 2019 05:35
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save AhmadMoussa/1c42e819e4cbc22497093fabb58f3aa3 to your computer and use it in GitHub Desktop.
Save AhmadMoussa/1c42e819e4cbc22497093fabb58f3aa3 to your computer and use it in GitHub Desktop.
DataLoader for data loading purposes C:
import os
import numpy as np
class DataLoader():
    """Batch generator over the files found in a data directory.

    The loader keeps two views of the file index: ``data_index`` (the order
    the index was created/loaded in) and ``shuffled_index`` (the order batches
    are drawn in, reshuffled by :meth:`data_shuffler`).
    """

    def __init__(self, data_path, data_shape, index_path = 0, batch_size = 1):
        """
        :param data_path: directory that contains the data files
        :param data_shape: shape of one sample; the first two entries are used
            to size the batch buffers
        :param index_path: path of a saved ``.npy`` index; any falsy value
            (the default) builds the index from ``data_path`` instead
        :param batch_size: number of samples per batch, must be >= 1
        :raises ValueError: if ``batch_size`` is smaller than 1
        """
        if batch_size < 1:
            # Fail early: 0 would raise ZeroDivisionError in __len__ and a
            # negative size would produce a negative length.
            raise ValueError("batch_size must be >= 1, got {}".format(batch_size))
        self.data_path = data_path
        self.data_shape = data_shape
        if index_path:
            self.data_index = self._load_data_index(index_path)
        else:
            self.data_index = self._create_data_index(data_path)
        # Keep an independent copy so that shuffling never reorders
        # data_index itself (both list and ndarray provide .copy()).
        self.shuffled_index = self.data_index.copy()
        self.batch_size = batch_size

    def _create_data_index(self, data_path):
        """
        List the entries of ``data_path`` and save them as ``index.npy`` in
        the current working directory.

        :return: list of entry names as returned by ``os.listdir``
        """
        file_names = os.listdir(data_path)  # scandir() faster than listdir()
        np.save("index.npy", file_names)
        return file_names

    def _load_data_index(self, index_path):
        """
        :return: the index stored at ``index_path``, loaded with ``np.load``
        """
        return np.load(index_path)

    def __len__(self):
        """
        :return: number of total batches, depends on batch size and index;
            a trailing partial batch is not counted
        """
        return len(self.data_index) // self.batch_size

    def __str__(self):
        """
        :return: multi-line summary of path, dataset size, batch size,
            number of batches and data shape
        """
        return "-- Dataset Path: {}\n" \
               "-- Dataset Size: {}\n" \
               "-- Batch Size: {}\n" \
               "-- # of batches: {}\n" \
               "-- Data Shape: {}\n" \
               "".format(self.data_path, len(self.data_index), self.batch_size, len(self), self.data_shape)

    def load_batch(self):
        """
        Generate one ``(dry_batch, wet_batch)`` pair per full batch.

        Both arrays have shape ``(batch_size, data_shape[0], data_shape[1])``
        and are allocated with ``np.empty``, so their contents are
        uninitialised until the per-file loading step below is filled in.

        :return: generator of ``(dry_batch, wet_batch)`` tuples
        """
        buffer_shape = (self.batch_size, self.data_shape[0], self.data_shape[1])
        for i in range(len(self)):
            # File names that belong to this batch, in the current shuffled order.
            batch = self.shuffled_index[i * self.batch_size:(i + 1) * self.batch_size]
            dry_batch = np.empty(buffer_shape)
            wet_batch = np.empty(buffer_shape)
            # TODO: fill the buffers from the files in `batch`.
            # For audio this could be used:
            #   for j, file in enumerate(batch):
            #       audio = np.load(os.path.join(self.data_path, file))
            #       dry_batch[j,] = audio[:16384].reshape(self.data_shape)
            #       wet_batch[j,] = audio[16384:32768].reshape(self.data_shape)
            yield dry_batch, wet_batch

    # this should be called at the end of the epoch to shuffle data
    def data_shuffler(self):
        """Shuffle ``shuffled_index`` in place; ``data_index`` is left untouched."""
        np.random.shuffle(self.shuffled_index)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment