Skip to content

Instantly share code, notes, and snippets.

@danishansari
Created May 28, 2019 15:45
Show Gist options
  • Save danishansari/00c46052654241d93a6fa600c8bb1ce1 to your computer and use it in GitHub Desktop.
Save danishansari/00c46052654241d93a6fa600c8bb1ce1 to your computer and use it in GitHub Desktop.
Multi-label classification
import os, sys, cv2
import numpy as np
from data_manager import DataManager
from model import get_model, get_cust_model
# Input shape handed to both DataManager and the model builder.
# Presumably (height, width, channels) -- confirm against DataManager.
shape = (128, 128, 3)

# Number of label outputs requested from the model builder.
n_classes = 5
def train(inp_path, epochs=10, batch_size=64, model_path='scenery_10.hd5'):
    """Train the multi-label classifier, save it and evaluate it.

    Loads the 'TRAIN' and 'VALID' splits through ``DataManager``, builds the
    network with ``get_cust_model``, fits it, saves the model to
    *model_path* and finally evaluates it on the validation split.

    Args:
        inp_path: Dataset location, forwarded unchanged to ``DataManager``.
        epochs: Number of training epochs (default 10).
        batch_size: Batch size used for both fitting and evaluation
            (default 64).
        model_path: File the trained model is saved to
            (default 'scenery_10.hd5').

    Returns:
        Whatever ``model.evaluate`` returns for the validation split
        (presumably ``[loss, accuracy]`` given the compile arguments --
        confirm against the Keras version in use).
    """
    trn = DataManager(inp_path, shape, 'TRAIN')
    val = DataManager(inp_path, shape, 'VALID')
    trn_x, trn_y = trn.dataset[0], trn.dataset[1]
    val_x, val_y = val.dataset[0], val.dataset[1]

    model = get_cust_model(shape, n_classes)
    # Binary cross-entropy scores every class output independently, which is
    # what a multi-label target (several classes active at once) needs.
    model.compile(optimizer='adam', loss='binary_crossentropy',
                  metrics=['accuracy'])
    model.fit(trn_x, trn_y,
              epochs=epochs,
              validation_data=(val_x, val_y),
              batch_size=batch_size)
    model.save(model_path)

    # Hand the final validation metrics back to the caller instead of
    # discarding them.
    return model.evaluate(x=val_x, y=val_y, batch_size=batch_size)
def test(inp_path, model_path, split='TRAIN', threshold=0.46):
    """Report exact-match accuracy of a saved model on one dataset split.

    A sample counts as correct only when every thresholded class output
    equals the corresponding ground-truth label.

    Args:
        inp_path: Dataset location, forwarded unchanged to ``DataManager``.
        model_path: Weights file passed to ``model.load_weights``.
        split: Name of the ``DataManager`` split to score. Defaults to
            'TRAIN' to keep the historical behaviour; pass 'VALID' to score
            the validation split instead.
        threshold: Outputs below this value are binarised to 0, everything
            else to 1 (default 0.46).

    Returns:
        The fraction of samples whose full label vector was predicted
        correctly, as a float in [0, 1]. The value is also printed.

    Raises:
        ValueError: If the selected split contains no samples.
    """
    model = get_cust_model(shape, n_classes)
    model.load_weights(model_path)

    data = DataManager(inp_path, shape, split)
    images = np.asarray(data.dataset[0])
    labels = np.asarray(data.dataset[1])

    size = len(images)
    if size == 0:
        raise ValueError('no samples found in split %r of %r'
                         % (split, inp_path))

    # One batched forward pass instead of one predict() call per image.
    probs = model.predict(images)
    # np.where(probs < t, 0, 1) mirrors the original `0 if e < t else 1`.
    predn = np.where(probs < threshold, 0, 1)

    # A sample is correct only if all of its labels match.
    matches = np.equal(predn, labels).reshape(size, -1).all(axis=1)
    corr = int(matches.sum())

    # float(corr) / size is a true division on Python 2 as well; the former
    # float(corr / size) truncated to 0 or 1 there.
    acc = float(corr) / size
    print('acc = %.2f' % acc)
    return acc
if __name__ == '__main__':
    # Command line:
    #   train: <script> <data_path> -train
    #   test:  <script> <data_path> <model_path>
    # Both forms need at least two arguments after the script name; bail out
    # with a usage message instead of an IndexError when they are missing.
    if len(sys.argv) < 3:
        sys.exit('usage: %s <data_path> -train\n'
                 '       %s <data_path> <model_path>'
                 % (sys.argv[0], sys.argv[0]))

    if sys.argv[-1] == '-train':
        train(sys.argv[1])
    else:
        test(sys.argv[1], sys.argv[2])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment