Skip to content

Instantly share code, notes, and snippets.

@michaellee8
Created November 7, 2021 15:26
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save michaellee8/2667922534e3589f1629efa616674c54 to your computer and use it in GitHub Desktop.
Save michaellee8/2667922534e3589f1629efa616674c54 to your computer and use it in GitHub Desktop.
jetson-inference run_ssd_example_annotate_video.py
from vision.ssd.vgg_ssd import create_vgg_ssd, create_vgg_ssd_predictor
from vision.ssd.mobilenetv1_ssd import create_mobilenetv1_ssd, create_mobilenetv1_ssd_predictor
from vision.ssd.mobilenetv1_ssd_lite import create_mobilenetv1_ssd_lite, create_mobilenetv1_ssd_lite_predictor
from vision.ssd.squeezenet_ssd_lite import create_squeezenet_ssd_lite, create_squeezenet_ssd_lite_predictor
from vision.ssd.mobilenet_v2_ssd_lite import create_mobilenetv2_ssd_lite, create_mobilenetv2_ssd_lite_predictor
from vision.utils.misc import Timer
import cv2
import sys
# Parse the five required positional arguments.
# sys.argv[0] is the script name, so all five arguments are present only when
# argv holds at least six entries; a shorter list would raise IndexError on
# sys.argv[5] below.
if len(sys.argv) < 6:
    print('Usage: python run_ssd_example_annotate_video.py <net type> <model path> <label path> <input path> <output path>')
    # Non-zero status so callers can tell a usage error from a successful run.
    sys.exit(1)
net_type = sys.argv[1]     # compared against the supported net-type names later in the script
model_path = sys.argv[2]   # passed to net.load()
label_path = sys.argv[3]   # text file read line by line into class_names
input_path = sys.argv[4]   # video opened with cv2.VideoCapture
output_path = sys.argv[5]  # video written with cv2.VideoWriter
# Read one class label per line, stripping surrounding whitespace/newlines.
# The context manager closes the label file as soon as it has been read.
with open(label_path) as label_file:
    class_names = [name.strip() for name in label_file]
# Build the network and its predictor for the requested net type.
# Each supported net type maps to its (network factory, predictor factory)
# pair, so the two stay in sync and the list of valid names lives in one place.
_NET_FACTORIES = {
    'vgg16-ssd': (create_vgg_ssd, create_vgg_ssd_predictor),
    'mb1-ssd': (create_mobilenetv1_ssd, create_mobilenetv1_ssd_predictor),
    'mb1-ssd-lite': (create_mobilenetv1_ssd_lite, create_mobilenetv1_ssd_lite_predictor),
    'mb2-ssd-lite': (create_mobilenetv2_ssd_lite, create_mobilenetv2_ssd_lite_predictor),
    'sq-ssd-lite': (create_squeezenet_ssd_lite, create_squeezenet_ssd_lite_predictor),
}
if net_type not in _NET_FACTORIES:
    # List every supported name rather than a hard-coded subset.
    print("The net type is wrong. It should be one of " + ", ".join(_NET_FACTORIES) + ".")
    sys.exit(1)
create_net, create_predictor = _NET_FACTORIES[net_type]
# The network is sized by the number of labels and created in test mode.
net = create_net(len(class_names), is_test=True)
# Weights must be loaded before the predictor is built around the network.
net.load(model_path)
predictor = create_predictor(net, candidate_size=200)
# Open the input video and create a writer that mirrors its frame rate and size.
vid_capture = cv2.VideoCapture(input_path)
if not vid_capture.isOpened():
    print("Fatal: Error opening the video file")
    # A fatal error must not report success; sys.exit is also always available,
    # unlike the site-provided exit() helper.
    sys.exit(1)
fps = vid_capture.get(cv2.CAP_PROP_FPS)
frame_count = vid_capture.get(cv2.CAP_PROP_FRAME_COUNT)
print('Frames per second : ', fps,'FPS')
print('Frame count : ', frame_count)
# Obtain frame size information using get() method; the values are converted
# to int because they are passed to the writer as its frame size.
frame_width = int(vid_capture.get(cv2.CAP_PROP_FRAME_WIDTH))
frame_height = int(vid_capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
frame_size = (frame_width,frame_height)
# The output is encoded with the XVID FourCC at the input's frame rate and size.
output = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'XVID'), fps, frame_size)
if not output.isOpened():
    # Fail before any frame is processed instead of finishing the whole video
    # with nothing written.
    print("Fatal: Error opening the output video file")
    vid_capture.release()
    sys.exit(1)
# Annotate the video frame by frame: run the predictor on each frame, draw a
# box and a "<class name>: <probability>" label for every detection, and
# append the annotated frame to the output video.
count = 0  # number of frames written so far
try:
    while vid_capture.isOpened():
        ret, frame = vid_capture.read()
        if not ret:
            # No frame was returned (end of stream or read failure): stop.
            break
        # Progress is reported only for frames that were actually read.
        if count % 100 == 0:
            print("now working on frame", count)
        # The predictor receives an RGB copy; drawing is done on the frame
        # as read (BGR), which is what gets written out.
        image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        # NOTE(review): the positional 10 and 0.4 are presumably top_k and the
        # probability threshold -- confirm against the predictor's predict().
        boxes, labels, probs = predictor.predict(image, 10, 0.4)
        for i in range(boxes.size(0)):
            box = boxes[i, :]
            # Box entries 0-1 and 2-3 are used as the two opposite corners.
            left, top = int(box[0]), int(box[1])
            right, bottom = int(box[2]), int(box[3])
            cv2.rectangle(frame, (left, top), (right, bottom), (255, 255, 0), 4)
            label = f"{class_names[labels[i]]}: {probs[i]:.2f}"
            cv2.putText(frame, label,
                        (left + 20, top + 40),  # text origin, offset from the box's first corner
                        cv2.FONT_HERSHEY_SIMPLEX,
                        1,  # font scale
                        (255, 0, 255),
                        2)  # thickness
        output.write(frame)
        count += 1
finally:
    # Always release both handles, even if prediction or drawing raises, so
    # the output file is finalized and the input is closed.
    vid_capture.release()
    output.release()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment