Skip to content

Instantly share code, notes, and snippets.

@hpcslag
Last active September 6, 2020 08:06
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save hpcslag/4f1494bf3d06d747d074b4840b448130 to your computer and use it in GitHub Desktop.
Save hpcslag/4f1494bf3d06d747d074b4840b448130 to your computer and use it in GitHub Desktop.
給定模型、圖片,告訴我結果 (Given a model and an image, show the detection results.)
import os
import sys
import random
import math
import re
import time
import numpy as np
import cv2
import tensorflow as tf
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.patches as patches
# ---- Tunable parameters: edit these to match your machine ----
# Image that object detection is run on.
INFERENCE_IMAGE = r"C:\Users\Mac\Pictures\d0d3a9f4ac8a65c58b7b1701d14a440a-600x400.jpg"
# Weights file handed to model.load_weights further down.
WEIGHTS_PATH = r"C:\Users\Mac\Desktop\aktwelve_mask_rcnn\mask_rcnn_coco.h5"
# Root of the local Mask R-CNN checkout; appended to sys.path so `mrcnn` imports.
MASK_RCNN_PROJECT_PATH = r"C:\Users\Mac\Desktop\aktwelve_mask_rcnn"
# TensorFlow device the model is built under: "/cpu:0" or "/gpu:0".
DEVICE = "/cpu:0"
# Make the local Mask R-CNN checkout importable, then pull in its modules.
sys.path.append(MASK_RCNN_PROJECT_PATH)  # so `mrcnn` resolves to the local checkout
from mrcnn import utils
from mrcnn import visualize
from mrcnn.visualize import display_images
import mrcnn.model as modellib
from mrcnn.model import log

# MS COCO config/dataset helpers. In the upstream Mask R-CNN repository layout
# `coco.py` lives in samples/coco/, which is not on sys.path by default, so a
# bare `import coco` fails unless the file was copied next to this script.
# Append that directory last so a coco.py beside this script still takes
# precedence. (NOTE(review): confirm the checkout keeps the samples/coco layout.)
sys.path.append(os.path.join(MASK_RCNN_PROJECT_PATH, "samples", "coco"))
import coco

# Base COCO configuration; it is subclassed with inference overrides later on.
config = coco.CocoConfig()
# Derive an inference-only configuration from the COCO config loaded above,
# changing only the settings that differ when running detection.
class InferenceConfig(type(config)):
    """COCO configuration that runs detection on one image at a time."""

    GPU_COUNT = 1       # a single device...
    IMAGES_PER_GPU = 1  # ...handling one image per step


# Swap the base config for the inference variant and show its settings.
config = InferenceConfig()
config.display()
# Mode to inspect the model in; one of "inference" or "training".
# TODO: the "training" inspection mode is not supported yet.
# NOTE: the MaskRCNN construction further down passes mode="inference"
# directly instead of reading this setting.
TEST_MODE = "inference"
def get_ax(rows=1, cols=1, size=16):
    """Create a new figure and return its Axes for the visualizations.

    Acts as the single place that controls figure sizing: every subplot
    cell is ``size`` x ``size`` (matplotlib figsize units, inches), so the
    figure is ``size * cols`` wide and ``size * rows`` tall.

    Args:
        rows: number of subplot rows.
        cols: number of subplot columns.
        size: edge length of one subplot cell; raise it for bigger renders.

    Returns:
        The second element of ``plt.subplots(...)``: a single Axes for a
        1x1 grid, otherwise an array of Axes.
    """
    width = size * cols
    height = size * rows
    axes = plt.subplots(rows, cols, figsize=(width, height))[1]
    return axes
# Class names passed alongside r['class_ids'] to display_instances below;
# presumably indexed by class id, so the order must match the COCO weights
# ('BG' = background at index 0).
classes = [
    'BG', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign',
    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
    'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella',
    'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
    'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard',
    'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork',
    'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
    'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
    'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv',
    'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
    'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
    'scissors', 'teddy bear', 'hair drier', 'toothbrush',
]

# An (empty) COCO dataset object; prepare() must be called before any use.
# Note: the inference code below does not reference `dataset` -- class
# names come from the `classes` list above.
dataset = coco.CocoDataset()
dataset.prepare()
# Build the network graph in inference mode, placed on the configured device.
with tf.device(DEVICE):
    model = modellib.MaskRCNN(
        config=config,
        mode="inference",
        model_dir="./logs",
    )

# Load the pretrained weights from WEIGHTS_PATH, matching layers by name.
print("Loading weights ", WEIGHTS_PATH)
model.load_weights(WEIGHTS_PATH, by_name=True)
# Read the input image. cv2.imread does not raise on a missing or unreadable
# file -- it returns None, which would only fail later inside model.detect
# with an obscure error -- so check explicitly.
im = cv2.imread(INFERENCE_IMAGE)
if im is None:
    raise OSError(
        "cv2.imread could not read image (missing or unsupported file): {}".format(
            INFERENCE_IMAGE
        )
    )
# OpenCV decodes color images in BGR channel order, whereas Mask R-CNN's own
# samples feed RGB images (skimage loaders). Convert so detection and the
# displayed image both use the channel order the model expects.
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)

# Run object detection on a batch containing just this image.
results = model.detect([im], verbose=1)

# Draw the detections (boxes, masks, class names, scores) for the one image.
ax = get_ax(1)
r = results[0]
visualize.display_instances(im, r['rois'], r['masks'], r['class_ids'],
                            classes, r['scores'], ax=ax,
                            title="Predictions")
# In the upstream mrcnn implementation, display_instances only calls
# plt.show() when it creates its own axes. Because `ax` is passed here,
# show the figure explicitly so running this as a plain script opens it.
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment