Skip to content

Instantly share code, notes, and snippets.

@albanie
Created February 18, 2017 19:42
Show Gist options
  • Save albanie/4ffdda3ef38d7f0cb356fb94247e8e22 to your computer and use it in GitHub Desktop.
Save albanie/4ffdda3ef38d7f0cb356fb94247e8e22 to your computer and use it in GitHub Desktop.
A simple inspection layer for caffe debugging
from __future__ import print_function
import matplotlib
matplotlib.use('Agg')
import colorsys
import caffe
import json
import pdb
import time
import sys
import numpy as np
import matplotlib.pyplot as plt
from os.path import expanduser
from os.path import join as pjoin
import scipy.io as sio
import matplotlib.patches as patches
from zvision.zv_iterm import zv_dispFig
# Module-level switches read by PyInspectLayer.forward.
# When truthy, the inspected batch is rendered (see the in_layer check).
visual = 1
# Name of the blob under inspection; only 'annotated_data' triggers the
# image/box display in PyInspectLayer.forward. Another value that was used
# while debugging is 'conv1_1'.
in_layer = 'annotated_data'
# Directory receiving one '<layer_name>.mat' file per inspected bottom blob.
save_blob_dir = expanduser(
    '~/coding/libs/matconvnets/ssd-matconvnet/data/albanie/stored_vars'
)
class PyInspectLayer(caffe.Layer):
    """A pass-through Python layer that dumps its inputs for debugging.

    For every bottom blob it prints summary statistics (name, shape, min,
    max, L2 norm, dtype) to stderr and saves the raw array to
    ``<save_blob_dir>/<name>.mat``. When the module-level ``visual`` flag is
    set and ``in_layer == 'annotated_data'``, the image batch (``bottom[0]``)
    is drawn together with the boxes described by ``bottom[1]``.

    The first top blob, if one is declared, receives an unmodified copy of
    ``bottom[0]``, and gradients are passed straight back (identity layer).

    ``param_str`` is parsed as JSON. The optional key ``inputs`` names the
    bottom blobs, either as a list or as a single name.
    """

    # Stop the process after the first forward pass so a single batch can be
    # inspected without the solver running on. Set to False to keep training.
    exit_after_forward = True

    def setup(self, bottom, top):
        """Validate the inputs and read the bottom-blob names from param_str.

        Raises:
            Exception: if the layer has no bottom blobs.
        """
        if len(bottom) == 0:
            raise Exception("PyInspectLayer should have at least one input.")
        # Read param_str defensively: the original printed it before
        # checking that the attribute exists.
        param_str = getattr(self, 'param_str', '')
        pp(dir(self))
        pp(param_str)
        names = ['data']
        if param_str:
            params = json.loads(param_str)
            names = params.get('inputs', names)
        # A bare name (str/unicode) would otherwise be indexed character by
        # character in forward(); wrap it in a list.
        if not isinstance(names, (list, tuple)):
            names = [names]
        self.input_layers = list(names)

    def _input_name(self, i):
        """Return the configured name of bottom ``i`` (fallback ``bottom<i>``)."""
        if i < len(self.input_layers):
            return self.input_layers[i]
        return 'bottom{}'.format(i)

    def reshape(self, bottom, top):
        """Give the optional output the same shape as the first input."""
        if len(top) > 0:
            top[0].reshape(*bottom[0].data.shape)

    def forward(self, bottom, top):
        """Report, save and (optionally) display every bottom blob."""
        for i, blob in enumerate(bottom):
            data = blob.data
            layer_name = self._input_name(i)
            pp('layer name: {}'.format(layer_name))
            pp(data.shape)
            x = data.ravel()
            # ndarray reductions run in C; builtin min/max iterate in Python.
            pp('min: {}'.format(x.min()))
            pp('max: {}'.format(x.max()))
            pp('norm: {}'.format(np.linalg.norm(x)))
            pp('type:{}'.format(x.dtype))
            path = pjoin(save_blob_dir, layer_name + '.mat')
            sio.savemat(path, {'data': data})

        # Identity pass-through so downstream layers see the original input.
        if len(top) > 0:
            top[0].data[...] = bottom[0].data

        if visual and in_layer == 'annotated_data':
            if len(bottom) > 1:
                images = bottom[0].data
                labels = bottom[1].data
                # Column 0 of each label row is matched against the batch
                # index inside display_inputs.
                item_ids = labels[:, :, :, 0]
                display_inputs(images, labels, item_ids)
            else:
                pp('annotated_data display needs a label blob as bottom[1]; '
                   'skipping')

        if self.exit_after_forward:
            sys.exit()

    def backward(self, top, propagate_down, bottom):
        """Pass gradients straight through, matching the identity forward."""
        if propagate_down[0] and len(top) > 0:
            bottom[0].diff[...] = top[0].diff
# PASCAL VOC foreground class names. display_inputs looks names up with
# ``pascal_classes[group_id - 1]``, so label 1 maps to 'aeroplane'.
# 'background' is deliberately not listed. The previous version had
# 'background' without a trailing comma, which silently concatenated it into
# 'backgroundaeroplane'. Presumably label 0 (background) never appears as a
# ground-truth box; TODO confirm against the labelmap.
pascal_classes = [
    'aeroplane',
    'bicycle',
    'bird',
    'boat',
    'bottle',
    'bus',
    'car',
    'cat',
    'chair',
    'cow',
    'diningtable',
    'dog',
    'horse',
    'motorbike',
    'person',
    'pottedplant',
    'sheep',
    'sofa',
    'train',
    'tvmonitor',
]
def _get_colors(num_colors):
colors=[]
for i in np.arange(0., 360., 360. / num_colors):
hue = i/360.
lightness = (50 + np.random.rand() * 10)/100.
saturation = (90 + np.random.rand() * 10)/100.
colors.append(colorsys.hls_to_rgb(hue, lightness, saturation))
return colors
# One colour per class name; display_inputs uses it for the box outlines.
colors = _get_colors(num_colors=len(pascal_classes))
def pp(x):
    """Debug print: write ``x`` followed by a newline to standard error."""
    err_stream = sys.stderr
    print(x, file=err_stream)
def display_inputs(data, labels, item_ids, max_imgs=2):
    """Draw the first few images of a batch with their ground-truth boxes.

    Side effects:
      * saves the whole batch to a fixed ``fixed_img.mat`` path;
      * for each drawn image, writes ``figs/img<i>.png`` (the ``figs``
        directory must already exist), calls ``plt.show()`` and
        ``zv_dispFig()``, then closes the figure.

    Args:
        data: image batch indexed as ``data[i, channel, row, col]``; channels
            0, 1, 2 are treated as B, G, R.
        labels: box array indexed as ``labels[:, :, k, :]`` per box. Column 1
            is the class label (looked up as ``pascal_classes[label - 1]``);
            columns 3-6 are xmin, ymin, xmax, ymax, scaled here by the image
            width/height (so presumably normalised to [0, 1]; verify).
        item_ids: per-box batch index, compared against the image index.
        max_imgs: maximum number of batch elements to draw (default 2).
    """
    path = expanduser('~/coding/libs/matconvnets/ssd-matconvnet/data/fixed_img.mat')
    sio.savemat(path, {'data': data})
    flat_ids = item_ids.flatten()
    for i in range(min(data.shape[0], max_imgs)):
        blue, green, red = data[i, 0, :, :], data[i, 1, :, :], data[i, 2, :, :]
        pp('B-sum: {} '.format(blue.sum()))
        pp('G-sum: {} '.format(green.sum()))
        pp('R-sum: {} '.format(red.sum()))
        # imshow expects RGB order, but the blob stores B, G, R.
        img = np.stack((red, green, blue), axis=2)
        peak = img.max()
        # Scale so the brightest value maps to 1; guard against an all-zero
        # (or all-negative) image, which would divide by zero or flip signs.
        if peak > 0:
            img = img / peak
        img_height, img_width = img.shape[0], img.shape[1]

        fig = plt.figure()
        ax = fig.add_subplot(111, aspect='equal')
        ax.imshow(img)
        # No plt.hold(True): it was removed in matplotlib 3.0, and overlaying
        # patches on the same axes is the default behaviour anyway.

        label_idx = np.where(flat_ids == i)[0]
        pp('batch elem: {}'.format(i))
        pp('label_idx: {}'.format(label_idx))
        for idx in label_idx:
            row = labels[:, :, idx, :].flatten()
            group_id = int(row[1])
            xmin = row[3] * img_width
            ymin = row[4] * img_height
            xmax = row[5] * img_width
            ymax = row[6] * img_height
            label_class = pascal_classes[group_id - 1]
            box_color = colors[group_id - 1]
            ax.add_patch(patches.Rectangle(
                (xmin, ymin),
                xmax - xmin,
                ymax - ymin,
                linewidth=3,
                edgecolor=box_color,
                fill=False,
                alpha=1))
            ax.text(xmin, ymin - 2,
                    '{:s}'.format(label_class),
                    bbox=dict(facecolor='blue', alpha=0.5),
                    fontsize=10,
                    color='white')

        plt.title('training image {}, num gt: {}'.format(i, len(label_idx)))
        plt.savefig('figs/img{}.png'.format(i))
        plt.show()
        zv_dispFig()
        # Release the figure; pyplot otherwise keeps every figure alive.
        plt.close(fig)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment