Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save yuyu2172/25e7c8cb613a9b6d8f77db97e5da3f38 to your computer and use it in GitHub Desktop.
Save yuyu2172/25e7c8cb613a9b6d8f77db97e5da3f38 to your computer and use it in GitHub Desktop.
import os
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import chainer
from chainercv.datasets import cityscapes_semantic_segmentation_label_colors
from chainercv.datasets import cityscapes_semantic_segmentation_label_names
from chainercv.datasets import voc_bbox_label_names
from chainercv.links import PSPNet
from chainercv.links import SSD512
from chainercv.utils import read_image
from chainercv.visualizations import vis_image
from chainercv.visualizations import vis_bbox
from chainercv.visualizations import vis_semantic_segmentation
# Put chainer in test mode for the whole script: functions that consult
# chainer.config.train (e.g. dropout, batch normalization) then use their
# inference-time behavior.
setattr(chainer.config, 'train', False)
# Pretrained models shared across visualize() calls. They are filled on first
# use so the weights are loaded and copied to the GPU once, not once per image.
_MODELS = {}


def _get_models():
    """Return ``(pspnet, ssd)``, building them and moving them to the GPU on first use."""
    if not _MODELS:
        pspnet = PSPNet(pretrained_model='cityscapes')
        pspnet.to_gpu()
        ssd = SSD512(pretrained_model='voc0712')
        ssd.to_gpu()
        _MODELS['pspnet'] = pspnet
        _MODELS['ssd'] = ssd
    return _MODELS['pspnet'], _MODELS['ssd']


def visualize(path, out_path):
    """Run segmentation and detection on one image and save a composite figure.

    The saved figure shows the original image across the top two of three
    grid rows. The bottom-left panel shows SSD512 ('voc0712' weights)
    detections. The bottom-right panel shows the PSPNet ('cityscapes' weights)
    label map overlaid on the image.

    Args:
        path: Path of the input image, read with ``chainercv.utils.read_image``.
        out_path: Destination file passed to ``Figure.savefig``. Its parent
            directory must already exist.
    """
    orig_img = read_image(path)
    # Predict on a copy and draw from the untouched original. This presumably
    # guards against predict() altering its input in place.
    img = orig_img.copy()
    pspnet, ssd = _get_models()

    # Semantic segmentation: predict() takes a batch (list), so unwrap the
    # single result.
    label_map = pspnet.predict([img])[0]

    # Detection: same batch-of-one convention.
    bboxes, labels, scores = ssd.predict([img])
    bbox, label, score = bboxes[0], labels[0], scores[0]

    # Use a fresh figure per call. Reusing plt.gcf() across calls keeps drawing
    # into the same figure/axes, so artists from earlier frames pile up and
    # memory grows for the whole run.
    fig = plt.figure(figsize=(40, 30))
    try:
        gs = gridspec.GridSpec(3, 2)
        gs.update(wspace=0.05, hspace=0)
        ax1 = fig.add_subplot(gs[:-1, :])  # top two rows: raw image
        ax2 = fig.add_subplot(gs[-1, 0])   # bottom-left: detections
        ax3 = fig.add_subplot(gs[-1, 1])   # bottom-right: segmentation overlay
        for ax in (ax1, ax2, ax3):
            ax.set_axis_off()

        vis_image(orig_img, ax=ax1)
        vis_bbox(orig_img, bbox, label, score, voc_bbox_label_names, ax=ax2)
        vis_image(orig_img, ax=ax3)
        vis_semantic_segmentation(
            label_map, cityscapes_semantic_segmentation_label_names,
            cityscapes_semantic_segmentation_label_colors, alpha=0.9, ax=ax3)

        fig.savefig(out_path, bbox_inches='tight', pad_inches=0)
    finally:
        # Free the figure even if drawing or saving raised.
        plt.close(fig)
base_dir = '/dataA/yuyu2172/pfnet/chainercv/cityscapes/leftImg8bit/demoVideo/stuttgart_00/'
out_dir = 'output'
# savefig() does not create missing directories. Without this, the very first
# frame would fail when ./output does not exist yet.
if not os.path.isdir(out_dir):
    os.makedirs(out_dir)
# Sorting by file name makes output index i correspond to the i-th input frame.
paths = sorted(os.path.join(base_dir, name) for name in os.listdir(base_dir))
for i, path in enumerate(paths):
    print('{}: processing {}'.format(i, path))
    visualize(path, os.path.join(out_dir, '{}.png'.format(i)))
@yuyu2172
Copy link
Author

yuyu2172 commented Oct 9, 2017

2

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment