Skip to content

Instantly share code, notes, and snippets.

@nkpro2000sr
Last active February 24, 2020 17:25
Show Gist options
  • Save nkpro2000sr/d1ff8fbe877955a860f5e945557c1465 to your computer and use it in GitHub Desktop.
Save nkpro2000sr/d1ff8fbe877955a860f5e945557c1465 to your computer and use it in GitHub Desktop.
Utilities to generate batches for training a deep-learning model from a dataset.
import os, random, numpy
def _collect_labelled_files(root, label_index):
    """
    Collect the files of every leaf directory under $root whose name is a known label.

    $root = directory to walk
    $label_index = dict mapping label name -> position in the one hot vector
    $return[0] = paths of files found
    $return[1] = one hot encoded labels (one fresh list per file) for $return[0]
    """
    files, onehots = [], []
    for dirpath, dirnames, filenames in os.walk(root):
        # Sorting in place makes os.walk descend in a reproducible order
        # instead of the arbitrary order the file system reports.
        dirnames.sort()
        if dirnames:
            continue  # only leaf directories (no sub directories) hold samples
        position = label_index.get(os.path.basename(dirpath))
        if position is None:
            continue  # leaf directory is not named after a known label
        for name in sorted(filenames):
            onehot = [0] * len(label_index)
            onehot[position] = 1
            files.append(os.path.join(dirpath, name))
            onehots.append(onehot)
    return files, onehots


def getPathALables(p_TDS, p_VDS=None):
    """
    Get file paths and one hot encoded labels of a data set stored as one directory per label.

    Labels are the names of the directories directly inside $p_TDS (sorted).
    A file is picked up when it lies in a leaf directory (a directory without
    sub directories) whose name is one of those labels. Files are returned in
    sorted, reproducible order.

    $p_TDS = path of training data set
    $p_VDS = path of validation data set (optional)
    $return[0] = All_labels
    $return[1][0] = paths of files in training data set
    $return[1][1] = one hot encoded labels for $return[1][0]
    $return[2] is same as $return[1] for validation data set
                (only present when $p_VDS is given)
    """
    All_labels = sorted(
        l for l in os.listdir(p_TDS) if os.path.isdir(os.path.join(p_TDS, l))
    )
    # Built once so each directory costs one dict lookup instead of list scans.
    label_index = {label: i for i, label in enumerate(All_labels)}

    TDS = _collect_labelled_files(p_TDS, label_index)
    if not p_VDS:
        return (All_labels, TDS)
    # Validation labels are encoded against the training label set.
    VDS = _collect_labelled_files(p_VDS, label_index)
    return (All_labels, TDS, VDS)
def Generator(P2D, batch_size, DS_paths, DS_labels, input_shape=None, seed=258621):
    """
    Endless generator for batches of datas and one hot encoded labels.

    $P2D = function to get data from path (randomly chosen from DS_paths)
    $batch_size = number of samples per batch
                  -1 => batch_size == total_no_of_datas
    $DS_paths = paths of the samples
    $DS_labels = one hot encoded labels, paired with $DS_paths by position
    $input_shape is passed as second argument while calling P2D
                 = None => P2D is called with only path
    $seed = seed given to the global `random` module before shuffling
    $yield[0] = batch of data (X) as numpy array
    $yield[1] = batch of one hot encoded labels (Y) as numpy array

    The samples are shuffled once at start and again each time the whole data
    set has been consumed. A batch that reaches the end of the data set is
    completed with samples from the freshly shuffled next pass, so every batch
    holds exactly $batch_size samples even when $batch_size does not divide
    the data set size.

    Raises ValueError on first iteration when the data set is empty.
    """
    random.seed(seed)
    pathAlabels = list(zip(DS_paths, DS_labels))
    total = len(pathAlabels)
    if total == 0:
        raise ValueError("Generator needs at least one (path, label) pair")
    if batch_size == -1:
        batch_size = total
    random.shuffle(pathAlabels)

    index = 0
    while True:
        datas, labels = [], []
        while len(labels) < batch_size:
            # Checked per sample (not per batch) so a batch can wrap around
            # the end of the data set instead of indexing past it.
            if index == total:
                index = 0
                random.shuffle(pathAlabels)
            path, label = pathAlabels[index]
            if input_shape is None:
                datas.append(P2D(path))
            else:
                datas.append(P2D(path, input_shape))
            labels.append(label)
            index += 1
        yield numpy.array(datas), numpy.array(labels)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment