Gist by @manashmandal, created February 18, 2018 15:53
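import numpy as np

# NOTE: the default `model=model` binds to a trained multi-label model (e.g. a Keras
# model with sigmoid outputs) that must already be defined before this `def` runs.
def infer(input_data, model=model):
    inference = []
    classes = np.array(['desert', 'mountain', 'sea', 'sunset', 'trees'])
    # One row of per-class probabilities per input sample
    y_pred = model.predict(input_data)
    # Threshold the probabilities at 0.5 to get a 0/1 mask per class
    y_pred = (y_pred > 0.5) * 1.0
    for i in range(y_pred.shape[0]):
        # Select the indices of the classes predicted as present
        indices = np.where(y_pred[i] == 1.0)[0]
        # Map those indices to class names and store them for this sample
        inference.append(classes[indices].tolist())
    return inference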
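The thresholding and index-to-label mapping can be checked without a trained network. The sketch below assumes a hypothetical `StubModel` whose `predict` returns fixed sigmoid-style scores, and a hypothetical `threshold` parameter (the original hard-codes 0.5). It uses a boolean mask to index the class array directly, which selects the same labels as the `np.where(... == 1.0)` step above.

```python
import numpy as np

CLASSES = np.array(['desert', 'mountain', 'sea', 'sunset', 'trees'])


class StubModel:
    """Hypothetical stand-in for a trained multi-label model (illustration only)."""

    def __init__(self, scores):
        self.scores = np.asarray(scores, dtype=float)

    def predict(self, input_data):
        # Ignores the input values; returns one row of fixed scores per sample.
        return self.scores[: len(input_data)]


def infer(input_data, model, threshold=0.5):
    y_pred = model.predict(input_data)
    # Boolean mask of shape (n_samples, n_classes): True where a class is predicted
    mask = y_pred > threshold
    # Boolean indexing picks the class names whose mask entry is True
    return [CLASSES[row].tolist() for row in mask]


model = StubModel([[0.9, 0.1, 0.7, 0.2, 0.4],
                   [0.3, 0.6, 0.1, 0.8, 0.55]])
print(infer(np.zeros((2, 100, 100, 3)), model))
# prints [['desert', 'sea'], ['mountain', 'sunset', 'trees']]
```

Because each sample can carry several labels (or none), the result is a list of lists rather than a single class per image.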