Skip to content

Instantly share code, notes, and snippets.

@dipanjanS
Created September 20, 2019 11:07
Show Gist options
  • Save dipanjanS/61f9c8d00635b9cd95d255bcb55b7d54 to your computer and use it in GitHub Desktop.
Save dipanjanS/61f9c8d00635b9cd95d255bcb55b7d54 to your computer and use it in GitHub Desktop.
# Build the CNN and restore its saved weights so the serving function
# defined below can call `model2.predict`.
# Input shape handed to the model builder; the trailing 3 matches the three
# stacked copies of each image that the serving function creates.
INPUT_SHAPE_RN = (32, 32, 3)
model2 = create_cnn_architecture_model2(input_shape=INPUT_SHAPE_RN)
# The weights path is relative to the current working directory.
model2.load_weights('./model_weights/cnn_model2_wt.h5')
def predict_apparel_model2_regular(img, img_dims=(32,32), label_map=class_names):
    """Return the label that `model2` predicts for a single image.

    The image is copied three times along a new trailing axis, resized with
    `resize_image_array`, divided by 255 and sent to `model2.predict` as a
    batch of one.

    Args:
        img: Image array for one sample (presumably a 2-D grayscale image
            with values in 0-255 -- confirm against the caller).
        img_dims: Target size forwarded to `resize_image_array` as
            `img_size_dims`.
        label_map: Indexable collection mapping a class index to its label.

    Returns:
        The entry of `label_map` at the index of the highest model output.
    """
    # [[img]] * 3 stacked on the last axis yields a leading axis of length 1
    # and a trailing axis of length 3, i.e. a batch holding one 3-channel image.
    stacked = np.stack([[img]] * 3, axis=-1)
    resized = [
        resize_image_array(frame, img_size_dims=img_dims)
        for frame in stacked
    ]
    batch = np.array(resized) / 255.

    scores = np.array(model2.predict(batch))
    # Only one sample is in the batch, so take the first arg-max.
    best_index = np.argmax(scores, axis=1)[0]
    return label_map[best_index]
# benchmark 10K requests
%%time
pred_labels = []
for img in tqdm(test_images):
pred_label = predict_apparel_model2_regular(img)
pred_labels.append(img)
len(pred_labels)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment