Skip to content

Instantly share code, notes, and snippets.

@mahxn0
Created January 14, 2019 06:42
Show Gist options
  • Save mahxn0/28220e9e33bb30e072a51e54d0413894 to your computer and use it in GitHub Desktop.
Save mahxn0/28220e9e33bb30e072a51e54d0413894 to your computer and use it in GitHub Desktop.
deeplabv3+ demo
# -*- coding: utf-8 -*-
import sys
sys.path.append('./utils/')
from matplotlib import pyplot as plt
import numpy as np
from PIL import Image
import tensorflow as tf
import get_dataset_colormap
import os
import matplotlib
import cv2
# Human-readable class names for the RoadMarking segmentation model.
# NOTE(review): presumably the entry at position i names class id i in the
# model's segmentation map (FULL_LABEL_MAP is built as 0..len-1 on that
# assumption) -- confirm against the label list used for training.
LABEL_NAMES = np.asarray([
    '_background_',
    'RoadMarking_LongSolidLine',
    'RoadMarking_DottedLine',
    'RoadMarking_ArrowLine',
    'RoadMarking_EntranceLine',
    'RoadMarking_TransverseSolidLine',
    'RoadMarking_Sidewalk',
    'RoadMarking_DottedLineChangXi',
    'mark',
    'RoadMarking_MeshLine',
    'RoadMarking_DecelerationHeng',
    # Previously 'RoadMarking_python labelme2voc.py --hDecelerationZong':
    # a shell command had been pasted into the middle of this name, which
    # pairs with 'RoadMarking_DecelerationHeng' above.
    'RoadMarking_DecelerationZong',
    'RoadMarking_DottedLineDuanXi',
])
# Column vector of every label id, 0 .. len(LABEL_NAMES) - 1, shape (N, 1).
FULL_LABEL_MAP = np.arange(len(LABEL_NAMES))[:, np.newaxis]
# The colour assigned to each of those ids by the project's colormap helper
# (called here without an explicit dataset name).
FULL_COLOR_MAP = get_dataset_colormap.label_to_color_image(FULL_LABEL_MAP)
class DeepLabModel(object):
    """Class to load deeplab model and run inference.

    A frozen TensorFlow ``GraphDef`` is imported into a private graph and a
    ``tf.Session`` is opened on it in ``__init__``.  The session stays open
    until ``close()`` is called; the instance can also be used in a ``with``
    statement, which closes the session on exit.
    """

    INPUT_TENSOR_NAME = 'ImageTensor:0'
    OUTPUT_TENSOR_NAME = 'SemanticPredictions:0'
    # The longer image side is scaled to (at most) this many pixels in run().
    INPUT_SIZE = 513

    def __init__(self, model_path):
        """Creates and loads pretrained deeplab model.

        Args:
          model_path: Path to a binary-serialized ``GraphDef`` (.pb) file.
        """
        self.graph = tf.Graph()
        with open(model_path, 'rb') as fd:
            graph_def = tf.GraphDef.FromString(fd.read())
        with self.graph.as_default():
            # An empty name imports nodes without a prefix, so the tensor
            # names above can be looked up exactly as written.
            tf.import_graph_def(graph_def, name='')
        self.sess = tf.Session(graph=self.graph)

    def close(self):
        """Close the TensorFlow session if it is still open.

        Calling this more than once is harmless: the first call sets
        ``self.sess`` to None and later calls do nothing.
        """
        sess = getattr(self, 'sess', None)
        if sess is not None:
            sess.close()
            self.sess = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        # Do not suppress exceptions raised inside the with-block.
        return False

    def run(self, image):
        """Runs inference on a single image.

        Args:
          image: A PIL.Image object, raw input image.

        Returns:
          resized_image: RGB copy of ``image`` scaled (aspect ratio kept) so
            that its longer side is at most ``INPUT_SIZE`` pixels.
          seg_map: Segmentation map of `resized_image` (first element of the
            batch fetched from ``OUTPUT_TENSOR_NAME``).
        """
        width, height = image.size
        resize_ratio = 1.0 * self.INPUT_SIZE / max(width, height)
        target_size = (int(resize_ratio * width), int(resize_ratio * height))
        # LANCZOS is the filter the old ANTIALIAS alias referred to; the
        # ANTIALIAS name no longer exists in Pillow 10+.
        resized_image = image.convert('RGB').resize(target_size, Image.LANCZOS)
        # The graph consumes a batch, so feed a batch containing one image.
        batch_seg_map = self.sess.run(
            self.OUTPUT_TENSOR_NAME,
            feed_dict={self.INPUT_TENSOR_NAME: [np.asarray(resized_image)]})
        # The former np.set_printoptions(threshold=np.nan) call is gone: it
        # only served a commented-out debug print, changed global NumPy state
        # on every call, and newer NumPy releases reject a NaN threshold with
        # ValueError.
        seg_map = batch_seg_map[0]
        return resized_image, seg_map
def vis_segmentation(image, seg_map, outputimage, show=True):
    """Blend a colour-coded segmentation map over its image and save it.

    Args:
      image: RGB input image, e.g. the ``resized_image`` returned by
        ``DeepLabModel.run``; it is converted with ``np.asarray``.
      seg_map: Label map for ``image``; each label is coloured with
        ``get_dataset_colormap`` using the PASCAL colormap.
      outputimage: Path the blended image is written to with ``cv2.imwrite``.
      show: If True (default, the original behaviour), also display the
        result in an OpenCV window named "dts" and wait 30 ms so the window
        refreshes.  Pass False when no display is available.

    Raises:
      IOError: If ``cv2.imwrite`` reports that the file was not written
        (for example, when the target directory does not exist).
    """
    seg_image = get_dataset_colormap.label_to_color_image(
        seg_map, get_dataset_colormap.get_pascal_name()).astype(np.uint8)
    # OpenCV expects BGR channel order; both inputs here are RGB.
    img_ori = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
    img_seg = cv2.cvtColor(seg_image, cv2.COLOR_RGB2BGR)
    # 70% original image + 30% segmentation colours.
    dst = cv2.addWeighted(img_ori, 0.7, img_seg, 0.3, 0)
    if show:
        cv2.imshow("dts", dst)
    # cv2.imwrite signals failure through its return value, not an exception.
    if not cv2.imwrite(outputimage, dst):
        raise IOError(
            'Could not write segmentation overlay to {}'.format(outputimage))
    if show:
        cv2.waitKey(30)
if __name__ == '__main__':
    # Usage: python <this script> [model_path [input_dir [output_dir]]]
    # Any argument that is omitted falls back to the hard-coded path below.
    # Frozen DeepLab model (.pb).
    model_path = sys.argv[1] if len(sys.argv) > 1 else "/home/mahxn0/M_DeepLearning/models/research/deeplab/datasets/RoadMarking/model/frozen_rect10000.pb"
    # Directory of input images.
    inputdir = sys.argv[2] if len(sys.argv) > 2 else "/home/mahxn0/M_DeepLearning/models/research/deeplab/datasets/RoadMarking/rect_img/"
    # Directory the blended results are saved to.
    outputdir = sys.argv[3] if len(sys.argv) > 3 else "/home/mahxn0/M_DeepLearning/models/research/deeplab/datasets/RoadMarking/result/"

    # cv2.imwrite does not create missing directories.
    if not os.path.isdir(outputdir):
        os.makedirs(outputdir)

    model = DeepLabModel(model_path)
    # Sorted so images are processed in a reproducible order.
    for filename in sorted(os.listdir(inputdir)):
        stem, ext = os.path.splitext(filename)
        # Case-insensitive, so e.g. "IMG_01.JPG" is processed as well.
        if ext.lower() not in ('.jpg', '.jpeg', '.png'):
            continue
        inputimage = os.path.join(inputdir, filename)
        print(inputimage)
        # Strip only the final extension ("a.b.jpg" -> "a.b.png") so names
        # that contain dots do not collide in the output directory.
        outputimage = os.path.join(outputdir, stem + '.png')
        # run() returns a new resized image, so the source file can be
        # closed as soon as inference is done.
        with Image.open(inputimage) as original_im:
            resized_im, seg_map = model.run(original_im)
        # Displays the overlay in an OpenCV window and writes it to
        # outputimage; comment this call out to skip both.
        vis_segmentation(resized_im, seg_map, outputimage)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment