Last active
February 24, 2020 17:25
-
-
Save nkpro2000sr/d1ff8fbe877955a860f5e945557c1465 to your computer and use it in GitHub Desktop.
Generates batches for training a deep-learning model from a dataset.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os, random, numpy | |
def _collect_labelled_files(root, label_index):
    """Return (file paths, one-hot labels) for the leaf directories under *root*."""
    files, one_hot_labels = [], []
    for dirpath, dirnames, filenames in os.walk(root):
        # Sorting in place fixes the traversal order, which os.walk otherwise
        # inherits from the arbitrary order of os.listdir.
        dirnames.sort()
        if dirnames:
            # Only leaf directories (those without sub-directories) hold samples.
            continue
        # The label of a sample is the name of the directory that contains it.
        position = label_index.get(os.path.basename(dirpath))
        if position is None:
            # Directory name is not one of the known labels: skip its files.
            continue
        for name in sorted(filenames):
            files.append(os.path.join(dirpath, name))
            # Build a fresh list per file so callers can mutate one label
            # without affecting the others.
            encoded = [0] * len(label_index)
            encoded[position] = 1
            one_hot_labels.append(encoded)
    return files, one_hot_labels


def getPathALables(p_TDS, p_VDS=None):
    """
    Collect file paths and one-hot encoded labels from a directory-per-label dataset.

    The labels are the names of the directories located directly inside
    p_TDS, sorted alphabetically. Files are taken from every leaf directory
    (a directory without sub-directories) whose name is one of those labels;
    other directories are ignored. Files are returned in sorted order, so the
    result does not depend on the file system's listing order.

    $p_TDS = path of training data set
    $p_VDS = path of validation data set (optional)
    $return[0] = All_labels (sorted list of label names)
    $return[1][0] = paths of files in training data set
    $return[1][1] = one hot encoded labels for $return[1][0]
    $return[2] is same as $return[1] for validation data set
                 (only present when $p_VDS is given)
    """
    all_labels = sorted(
        name for name in os.listdir(p_TDS)
        if os.path.isdir(os.path.join(p_TDS, name))
    )
    # Label -> column of the one-hot vector; avoids a linear search per file.
    label_index = {label: column for column, label in enumerate(all_labels)}

    training = _collect_labelled_files(p_TDS, label_index)
    if not p_VDS:
        return (all_labels, training)
    # The validation set is encoded against the labels of the training set.
    validation = _collect_labelled_files(p_VDS, label_index)
    return (all_labels, training, validation)
def Generator(P2D, batch_size, DS_paths, DS_labels, input_shape=None, seed=258621):
    """
    Endless generator of (data, one hot encoded labels) batches.

    The (path, label) pairs are shuffled once up front and again every time
    the whole data set has been consumed. A batch that reaches the end of the
    data set continues with the first samples of the freshly shuffled set, so
    every batch has exactly batch_size samples.

    $P2D = function to get data from path (randomly chosen from DS_paths)
    $batch_size = number of samples per batch
                = -1 => batch_size == total_no_of_datas
    $DS_paths = paths of the samples
    $DS_labels = one hot encoded labels, in the same order as $DS_paths
    $input_shape is passed as second argument while calling P2D
                 = None => P2D is called with only path
    $seed = seed given to the random module before the first shuffle
    $yield[0] = batch of data (X) as numpy array
    $yield[1] = batch of one hot encoded labels (Y) as numpy array

    Raises ValueError when the data set is empty.
    """
    pairs = list(zip(DS_paths, DS_labels))
    if not pairs:
        raise ValueError("Generator needs at least one (path, label) pair")
    if batch_size == -1:
        # Documented shortcut: one batch holds the whole data set.
        batch_size = len(pairs)

    # Seeds the module-level RNG so the shuffle order is reproducible.
    random.seed(seed)
    random.shuffle(pairs)
    index = 0
    while True:
        datas, labels = [], []
        while len(labels) < batch_size:
            if index == len(pairs):
                # Data set exhausted (possibly in the middle of a batch):
                # start a new pass in a new order instead of indexing past
                # the end of the list.
                index = 0
                random.shuffle(pairs)
            path, label = pairs[index]
            if input_shape is None:
                datas.append(P2D(path))
            else:
                datas.append(P2D(path, input_shape))
            labels.append(label)
            index += 1
        yield numpy.array(datas), numpy.array(labels)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment