Created
January 14, 2019 06:42
-
-
Save mahxn0/28220e9e33bb30e072a51e54d0413894 to your computer and use it in GitHub Desktop.
deeplabv3+ demo
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# coding=utf-8
import sys | |
sys.path.append('./utils/') | |
from matplotlib import pyplot as plt | |
import numpy as np | |
from PIL import Image | |
import tensorflow as tf | |
import get_dataset_colormap | |
import os | |
import matplotlib | |
import cv2 | |
# Class names indexed by label id (position in the array is the label value).
# NOTE: the second-to-last entry previously contained a stray shell command
# ("python labelme2voc.py --h") pasted into the middle of the name; it is
# restored here to match its sibling 'RoadMarking_DecelerationHeng'.
LABEL_NAMES = np.asarray([
    '_background_',
    'RoadMarking_LongSolidLine',
    'RoadMarking_DottedLine',
    'RoadMarking_ArrowLine',
    'RoadMarking_EntranceLine',
    'RoadMarking_TransverseSolidLine',
    'RoadMarking_Sidewalk',
    'RoadMarking_DottedLineChangXi',
    'mark',
    'RoadMarking_MeshLine',
    'RoadMarking_DecelerationHeng',
    'RoadMarking_DecelerationZong',
    'RoadMarking_DottedLineDuanXi',
])

# Column vector holding every label id, one row per class: shape (N, 1).
FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1)
# Color assigned to each label id by the project's colormap helper.
FULL_COLOR_MAP = get_dataset_colormap.label_to_color_image(FULL_LABEL_MAP)
class DeepLabModel(object):
    """Load a frozen DeepLab graph and run inference on single images.

    The graph is fed through the tensor named INPUT_TENSOR_NAME and the
    prediction is read from the tensor named OUTPUT_TENSOR_NAME.
    """

    INPUT_TENSOR_NAME = 'ImageTensor:0'
    OUTPUT_TENSOR_NAME = 'SemanticPredictions:0'
    # The longer side of every input image is scaled to this many pixels.
    INPUT_SIZE = 513

    def __init__(self, model_path):
        """Create a session over the frozen graph stored at `model_path`.

        Args:
          model_path: Path to a serialized GraphDef file, read in binary mode.
        """
        self.graph = tf.Graph()
        with open(model_path, 'rb') as fd:
            graph_def = tf.GraphDef.FromString(fd.read())
        with self.graph.as_default():
            tf.import_graph_def(graph_def, name='')
        self.sess = tf.Session(graph=self.graph)

    def run(self, image):
        """Run inference on a single image.

        Args:
          image: A PIL.Image object, raw input image.

        Returns:
          resized_image: RGB image resized from the original input image so
            that its longer side equals INPUT_SIZE.
          seg_map: Segmentation map of `resized_image` (first element of the
            batch returned by the session).
        """
        width, height = image.size
        resize_ratio = 1.0 * self.INPUT_SIZE / max(width, height)
        # int() truncates, so a very elongated image could otherwise end up
        # with a 0-pixel side, which resize() rejects.
        target_size = (max(1, int(resize_ratio * width)),
                       max(1, int(resize_ratio * height)))
        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
        resized_image = image.convert('RGB').resize(target_size,
                                                    Image.LANCZOS)
        batch_seg_map = self.sess.run(
            self.OUTPUT_TENSOR_NAME,
            feed_dict={
                self.INPUT_TENSOR_NAME: [np.asarray(resized_image)]
            })
        seg_map = batch_seg_map[0]
        # Print arrays untruncated when debugging. np.nan is rejected as a
        # threshold by current NumPy; sys.maxsize is the supported spelling.
        np.set_printoptions(threshold=sys.maxsize)
        return resized_image, seg_map
def vis_segmentation(image, seg_map, outputimage):
    """Overlay the colorized segmentation on the image, show it and save it.

    Args:
      image: RGB input image; converted with np.asarray before use.
      seg_map: Label map handed to get_dataset_colormap.label_to_color_image.
      outputimage: Path the blended image is written to via cv2.imwrite.
    """
    # NOTE: an earlier matplotlib-based figure (input, segmentation map,
    # overlay and label legend panels) was disabled in favor of the OpenCV
    # overlay below.
    colorized = get_dataset_colormap.label_to_color_image(
        seg_map, get_dataset_colormap.get_pascal_name())
    seg_rgb = colorized.astype(np.uint8)

    # Both layers are converted from RGB to BGR channel order for OpenCV.
    base_bgr = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
    mask_bgr = cv2.cvtColor(np.asarray(seg_rgb), cv2.COLOR_RGB2BGR)

    # 70% original image, 30% segmentation colors, no additive offset.
    blended = cv2.addWeighted(base_bgr, 0.7, mask_bgr, 0.3, 0)

    cv2.imshow("dts", blended)
    cv2.imwrite(outputimage, blended)
    # Wait 30 ms for a key press so the window can refresh.
    cv2.waitKey(30)
if __name__ == '__main__':
    # Batch demo: segment every JPEG/PNG in `inputdir` and write the blended
    # overlay for each one into `outputdir` as a PNG with the same base name.
    # The paths are hard-coded; edit them to point at your own model and data.
    model_path = "/home/mahxn0/M_DeepLearning/models/research/deeplab/datasets/RoadMarking/model/frozen_rect10000.pb"  # frozen model file
    model = DeepLabModel(model_path)
    inputdir = "/home/mahxn0/M_DeepLearning/models/research/deeplab/datasets/RoadMarking/rect_img/"  # input image directory
    outputdir = "/home/mahxn0/M_DeepLearning/models/research/deeplab/datasets/RoadMarking/result/"  # result directory

    # Make sure the result directory exists before anything is written to it.
    if not os.path.isdir(outputdir):
        os.makedirs(outputdir)

    # Sorted so the processing order does not depend on the filesystem.
    for filename in sorted(os.listdir(inputdir)):
        # splitext keeps dots inside the base name ("a.b.jpg" -> "a.b"), so
        # differently named inputs no longer collapse onto the same output.
        stem, ext = os.path.splitext(filename)
        if ext.lower() not in ('.jpg', '.jpeg', '.png'):
            continue
        inputimage = os.path.join(inputdir, filename)
        print(inputimage)
        outputimage = os.path.join(outputdir, stem + '.png')
        orignal_im = Image.open(inputimage)
        resized_im, seg_map = model.run(orignal_im)
        # Comment out the next line if you do not want the result displayed
        # (note that it also performs the save).
        vis_segmentation(resized_im, seg_map, outputimage)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment