Skip to content

Instantly share code, notes, and snippets.

@ctodd
Created July 9, 2020 06:12
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ctodd/467e4e18c85ceb8203db00cfb255086f to your computer and use it in GitHub Desktop.
import glob
# Collect the test images in sorted order: glob's own ordering depends on the
# file system, so sorting keeps the listing and processing order reproducible.
test_images = sorted(glob.glob('images/test/*'))
print(*test_images, sep="\n")
def prediction_to_bbox_data(image_path, prediction):
    """Convert one detection into a pixel-space bounding-box dict.

    Args:
        image_path: Path of the image the prediction belongs to; the image is
            opened only to read its pixel width and height.
        prediction: Sequence ``(class_id, confidence, xmin, ymin, xmax, ymax)``.
            The coordinates are multiplied by the image width/height, so they
            are presumably normalized to [0, 1] -- confirm against the model.

    Returns:
        Dict with keys 'class_id', 'height', 'width', 'left' and 'top'; the
        last four are the box size and top-left corner in pixels.
    """
    class_id, _confidence, xmin, ymin, xmax, ymax = prediction
    # Use the context manager so the file handle is released immediately
    # rather than lingering until the Image object is garbage-collected.
    with Image.open(image_path) as img:
        width, height = img.size
    return {
        'class_id': class_id,
        'height': (ymax - ymin) * height,
        'width': (xmax - xmin) * width,
        'left': xmin * width,
        'top': ymin * height,
    }
import matplotlib.pyplot as plt
# SageMaker runtime client; passed to get_predictions_for_img to invoke the endpoint.
runtime_client = boto3.client(service_name='sagemaker-runtime')
def get_predictions_for_img(runtime_client, endpoint_name, img_path):
    """Post an image file to a SageMaker endpoint and return the decoded JSON reply.

    Args:
        runtime_client: Client exposing ``invoke_endpoint`` (the SageMaker
            runtime client).
        endpoint_name: Name of the deployed endpoint to call.
        img_path: Path of the image whose raw bytes are sent as the request body.

    Returns:
        The response body parsed with ``json.loads``.
    """
    with open(img_path, 'rb') as image_file:
        request_body = bytearray(image_file.read())
    reply = runtime_client.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType='application/x-image',
        Body=request_body,
    )
    return json.loads(reply['Body'].read())
# Wait until the endpoint reports InService, then double-check its final status.
# `client` is presumably the boto3 SageMaker client created elsewhere -- confirm.
client.get_waiter('endpoint_in_service').wait(EndpointName=endpoint_name)
endpoint_response = client.describe_endpoint(EndpointName=endpoint_name)
status = endpoint_response['EndpointStatus']
if status != 'InService':
    # Include the status and the service-reported reason so the failure can be
    # diagnosed without another describe_endpoint call. RuntimeError subclasses
    # Exception, so existing `except Exception` handlers still catch it.
    raise RuntimeError(
        'Endpoint creation failed: status={}, reason={}'.format(
            status, endpoint_response.get('FailureReason', 'unknown')))
# Display settings; hoisted out of the per-image loop because they never change.
confidence_threshold = .03  # keep only detections whose confidence exceeds this
best_n = 1  # number of most-confident detections to draw per image

for test_image in test_images:
    result = get_predictions_for_img(runtime_client, endpoint_name, test_image)
    # Each entry of result['prediction'] is laid out as
    # [class_id, confidence, xmin, ymin, xmax, ymax] (the layout that
    # prediction_to_bbox_data unpacks), so index 1 is the confidence.
    predictions = [prediction for prediction in result['prediction']
                   if prediction[1] > confidence_threshold]
    predictions.sort(reverse=True, key=lambda x: x[1])
    top_predictions = predictions[:best_n]
    if not top_predictions:
        # Nothing to annotate for this image; skip it rather than labelling it
        # with an undefined or stale confidence value.
        print('No detections above {} for {}'.format(confidence_threshold, test_image))
        continue
    bboxes = [prediction_to_bbox_data(test_image, prediction)
              for prediction in top_predictions]
    # Label with the best detection's confidence as a 2-decimal string
    # (e.g. '0.87'). Read it by position: the largest element of the row may be
    # class_id or a box coordinate, not the confidence.
    conf_precision3 = format(top_predictions[0][1], '.2f')
    show_annotated_image(test_image, bboxes, conf_precision3)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment