Skip to content

Instantly share code, notes, and snippets.

@amankharwal
Created November 8, 2020 06:16
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save amankharwal/b71375a7d70ad0481b35559b8b307e2f to your computer and use it in GitHub Desktop.
Save amankharwal/b71375a7d70ad0481b35559b8b307e2f to your computer and use it in GitHub Desktop.
### Helper functions for loading and preprocessing the data; they are used to build the training batches.
def get_image_from_number(num, df):
    """Load the image and label found at row position ``num`` of ``df``.

    The row is unpacked into exactly two values: an image id and its label.
    ``".jpg"`` is appended to the id, and the file is looked up under
    ``base_path`` inside three nested directories named after the first
    three characters of the file name.

    Returns:
        A ``(image, label)`` tuple. ``image`` is the unchecked result of
        ``cv2.imread``; this function does not validate that the read
        succeeded.
    """
    image_id, label = df.iloc[num, :]
    filename = image_id + ".jpg"
    # One directory level per leading character: <c0>/<c1>/<c2>/<filename>.
    relative_path = os.path.join(*filename[:3], filename)
    image = cv2.imread(os.path.join(base_path, relative_path))
    return image, label
def image_reshape(im, target_size):
    """Return ``im`` resized by ``cv2.resize`` with ``target_size`` as the size argument."""
    resized = cv2.resize(im, target_size)
    return resized
def get_batch(dataframe, start, batch_size, target_size=(224, 224)):
    """Build one batch of scaled images and encoded labels.

    Rows at positions ``start`` up to (but not including)
    ``start + batch_size`` are loaded; the range is clamped to the length of
    ``dataframe`` so the final batch may be shorter than ``batch_size``.

    Args:
        dataframe: Table passed through to ``get_image_from_number``.
        start: Position of the first row in the batch.
        batch_size: Maximum number of rows to load.
        target_size: Size handed to ``image_reshape`` for every image.
            Defaults to ``(224, 224)``, the value that was previously
            hard-coded.

    Returns:
        A ``(images, labels)`` tuple of NumPy arrays. Each image is resized
        and divided by 255.0; labels are passed through ``encode_label``
        before being converted to an array.
    """
    # Clamp so the last batch does not index past the end of the table.
    stop = min(start + batch_size, len(dataframe))

    image_array = []
    label_array = []
    for row in range(start, stop):
        im, label = get_image_from_number(row, dataframe)
        # Divide by 255.0 to scale the pixel values before training.
        image_array.append(image_reshape(im, target_size) / 255.0)
        label_array.append(label)

    label_array = encode_label(label_array)
    return np.array(image_array), np.array(label_array)
# --- Training configuration -------------------------------------------------
batch_size = 16
epoch_shuffle = True   # reshuffle the training rows at the start of each epoch
weight_classes = True  # not read anywhere in this section
epochs = 15

# Shuffle once, then split by position: the first 80% of the rows are used for
# training and the remaining 20% are held out for validation.
# Positional slicing replaces np.split(DataFrame, ...), which goes through
# DataFrame.swapaxes (deprecated since pandas 2.1, emits a FutureWarning).
_shuffled = df.sample(frac=1)
_split_at = int(.8 * len(df))
train = _shuffled.iloc[:_split_at]
validate = _shuffled.iloc[_split_at:]  # not used further in this section
print("Training on:", len(train), "samples")
print("Validation on:", len(validate), "samples")

for e in range(epochs):
    print("Epoch: ", f"{e + 1}/{epochs}")
    if epoch_shuffle:
        train = train.sample(frac=1)
    # ceil() so a trailing partial batch is still trained on.
    for it in range(int(np.ceil(len(train) / batch_size))):
        X_train, y_train = get_batch(train, it * batch_size, batch_size)
        model.train_on_batch(X_train, y_train)

# NOTE(review): the original indentation was lost; saving once after the final
# epoch is assumed here -- confirm it was not meant to run every epoch.
model.save("Model.h5")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment