"""
Usage:
srun -p 1080ti-short --mem 100000 --gres=gpu:1 \
python tools/viz_net.py --dataset coco2017 --output_dir tmp \
--cfg configs/context/self-attn/e2e_faster_rcnn_R-50-C4_1x_attn_head-8_lr-long-v3.yaml \
--ckpt data/models/coco-visual-8-head/model_step69999.pth
Example (8 heads v3 on COCO):
srun -p 1080ti-short --gres=gpu:1 --mem 20000 python tools/viz_net.py \
--dataset coco2017 \
--output_dir tmp \
--cfg /mnt/nfs/work1/elm/arunirc/Research/detectron-context/Detectron-pytorch-video/configs/context/self-attn/e2e_faster_rcnn_R-50-C4_1x_attn_head-8_lr-long-v3.yaml \
--ckpt /mnt/nfs/work1/elm/arunirc/Research/detectron-context/Detectron-pytorch-video/Outputs/e2e_faster_rcnn_R-50-C4_1x_attn_head-8_lr-long-v3/Apr10-19-35-20_node106_step/ckpt/model_step69999.pth
"""
import os
import cv2
import json
import torch
import argparse
import numpy as np
from numpy import unravel_index
import matplotlib
matplotlib.use('Agg')  # headless backend: must be set before importing pyplot
import matplotlib.pyplot as plt
import networkx as nx
import _init_paths  # pylint: disable=unused-import
import nn as mynn
from core.config import cfg, merge_cfg_from_file, assert_and_infer_cfg
from core import test as test_utils
from utils import net as net_utils
from utils import vis as vis_utils
import utils.boxes as box_utils
from utils.colormap import colormap
from modeling import model_builder
from datasets.dataset_catalog import DATASETS
def parse_args():
    parser = argparse.ArgumentParser(description='Visualize predicted attention graphs over detections')
    parser.add_argument(
        '--dataset', help='dataset name (coco2017 or ade20k)',
        required=True, type=str
    )
    parser.add_argument(
        '--output_dir', help='folder to save the graph visualizations',
        default='./', type=str
    )
    parser.add_argument(
        '--ckpt', required=True,
        help='path to the model checkpoint (.pth)')
    parser.add_argument(
        '--cfg', required=True,
        help='path to the model config file (.yaml)')
    return parser.parse_args()
def box_results(scores, boxes, use_nms=True):
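    """Convert raw per-class scores and box regressions into final detections.

    Applies the score threshold, (soft-)NMS and optional bounding-box voting per
    class, then caps the number of detections per image. Returns the scores, boxes,
    per-class boxes, and the indices of the RoIs that survive, so attention
    rows/columns can later be matched to boxes.
    """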
    num_classes = cfg.MODEL.NUM_CLASSES
    cls_boxes = [[] for _ in range(num_classes)]
    sel_roi_inds = [[] for _ in range(num_classes)]
    for j in range(1, num_classes):
        inds = np.where(scores[:, j] > cfg.TEST.SCORE_THRESH)[0]
        scores_j = scores[inds, j]
        boxes_j = boxes[inds, j * 4:(j + 1) * 4]
        dets_j = np.hstack((boxes_j, scores_j[:, np.newaxis])).astype(np.float32, copy=False)
        if use_nms and cfg.TEST.SOFT_NMS.ENABLED:
            # capture the kept indices (the original discarded them, which left
            # `keep` undefined on this branch)
            nms_dets, keep = box_utils.soft_nms(
                dets_j,
                sigma=cfg.TEST.SOFT_NMS.SIGMA,
                overlap_thresh=cfg.TEST.NMS,
                score_thresh=0.0001,
                method=cfg.TEST.SOFT_NMS.METHOD
            )
        else:
            keep = box_utils.nms(dets_j, cfg.TEST.NMS)
            if not use_nms:
                keep = range(dets_j.shape[0])
            nms_dets = dets_j[keep, :]
        # Refine the post-NMS boxes using bounding-box voting
        if use_nms and cfg.TEST.BBOX_VOTE.ENABLED:
            nms_dets = box_utils.box_voting(
                nms_dets,
                dets_j,
                cfg.TEST.BBOX_VOTE.VOTE_TH,
                scoring_method=cfg.TEST.BBOX_VOTE.SCORING_METHOD
            )
        cls_boxes[j] = nms_dets
        sel_roi_inds[j].extend([inds[k] for k in keep])  # track the roi indices picked up by nms
    # Limit to max_per_image detections **over all classes**
    if cfg.TEST.DETECTIONS_PER_IM > 0:
        image_scores = np.hstack(
            [cls_boxes[j][:, -1] for j in range(1, num_classes)]
        )
        if len(image_scores) > cfg.TEST.DETECTIONS_PER_IM:
            image_thresh = np.sort(image_scores)[-cfg.TEST.DETECTIONS_PER_IM]
            for j in range(1, num_classes):
                keep = np.where(cls_boxes[j][:, -1] >= image_thresh)[0]
                if not use_nms:
                    keep = range(cls_boxes[j].shape[0])
                cls_boxes[j] = cls_boxes[j][keep, :]
                sel_roi_inds[j] = [sel_roi_inds[j][kp] for kp in keep]
    sel_roi_inds = [ind for sl in sel_roi_inds for ind in sl]
    im_results = np.vstack([cls_boxes[j] for j in range(1, num_classes)])
    boxes = im_results[:, :-1]
    scores = im_results[:, -1]
    assert len(sel_roi_inds) == len(boxes)
    return scores, boxes, cls_boxes, sel_roi_inds
def load_model(path, gpu_id=0):
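    """Build a Generalized_RCNN, load the checkpoint weights, and wrap it in DataParallel."""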
    model = model_builder.Generalized_RCNN()
    model.eval()
    model.cuda()  # always use cuda
    checkpoint = torch.load(path, map_location=lambda storage, loc: storage)
    net_utils.load_ckpt(model, checkpoint['model'])
    model = mynn.DataParallel(model, cpu_keywords=['im_info', 'roidb'], minibatch=True)
    return model
def draw_graph(adjmat,boxes):
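    """Draw the attention graph with numeric edge labels, nodes placed at box centers."""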
    # scale adjmat so max=1
    # np.fill_diagonal(adjmat, 0.0)
    scaled_adjmat = adjmat / adjmat.max()
    graph = nx.from_numpy_matrix(scaled_adjmat)
    edge_weights = {(u, v): d['weight'] for u, v, d in graph.edges(data=True)}
    labels = dict(enumerate(range(scaled_adjmat.shape[0])))
    edge_labels = {(n1, n2): "{:.2f}".format(scaled_adjmat[n1, n2])
                   for n1 in range(scaled_adjmat.shape[0])
                   for n2 in range(scaled_adjmat.shape[0])
                   if scaled_adjmat[n1, n2] > 0}
    # place each node at the center of its box
    pos = {n: [(b[0] + b[2]) / 2., (b[1] + b[3]) / 2.] for n, b in enumerate(boxes)}
    nx.draw_networkx_nodes(graph, pos, alpha=0.5)
    nx.draw_networkx_labels(graph, pos, labels, font_size=4, alpha=0.5)
    edge_colors = list(edge_weights.values())
    edge_list = list(edge_weights.keys())
    # the networkx keyword is `edgelist` (the original passed `edge_list`,
    # which is ignored or rejected depending on the networkx version)
    nx.draw_networkx_edges(graph, pos, alpha=0.5, edge_cmap=plt.cm.jet,
                           edge_color=edge_colors, edgelist=edge_list)
    nx.draw_networkx_edge_labels(graph, pos, font_size=6, edge_labels=edge_labels, alpha=0.5)
def draw_graph_color(adjmat,boxes):
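    """Draw the attention graph as a directed graph, edges colored and widened by attention weight."""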
    # zero out self-attention, then scale adjmat so max=1
    np.fill_diagonal(adjmat, 0.0)
    scaled_adjmat = adjmat / adjmat.max()
    G = nx.from_numpy_matrix(scaled_adjmat, create_using=nx.DiGraph)
    edge_attrs = nx.get_edge_attributes(G, 'weight')
    if not edge_attrs:
        return  # nothing to draw (the original zip(*...) would raise here)
    edges, weights = zip(*edge_attrs.items())
    # place each node at the center of its box
    pos = {n: [(b[0] + b[2]) / 2., (b[1] + b[3]) / 2.] for n, b in enumerate(boxes)}
    nx.draw(G, pos, edgelist=edges, edge_color=weights, edge_cmap=plt.cm.Blues,
            alpha=weights, width=[2.0 * x for x in weights], node_size=2.0)
def vis_one_image(fname, im, boxes, roi_inds, id2class, thresh=0.9, box_alpha=0.0, head=0, use_graph=True):
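    """Draw detections above `thresh` on the image and overlay the attention graph
    for the given head. Reads the per-head attention matrices from the module-level
    `adjmat_pred` set in __main__. Returns the top-5 RoI pairs by edge weight.
    """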
    if isinstance(boxes, list):
        boxes, _, _, classes = vis_utils.convert_from_cls_format(boxes, None, None)
    if boxes is None or boxes.shape[0] == 0 or max(boxes[:, 4]) < thresh:
        return
    color_list = colormap(rgb=True) / 255
    cmap = plt.get_cmap('rainbow')
    fig = plt.figure(frameon=False)
    ax = plt.Axes(fig, [0., 0., 1., 1.])
    ax.axis('off')
    fig.add_axes(ax)
    ax.imshow(cv2.cvtColor(im, cv2.COLOR_BGR2RGB), alpha=1.0)
    # Display in largest to smallest order to reduce occlusion
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    sorted_inds = np.argsort(-areas)
    sel_boxes = []
    sel_rois = []
    for i in sorted_inds:
        c = classes[i]
        bbox = boxes[i, :4]
        score = boxes[i, -1]
        if score < thresh:
            continue
        sel_boxes.append(boxes[i])
        sel_rois.append(roi_inds[i])
        # show box
        ax.add_patch(plt.Rectangle((bbox[0], bbox[1]), bbox[2] - bbox[0], bbox[3] - bbox[1],
                                   fill=False, edgecolor='g', linewidth=0.6))
        # show class label and score
        ax.text(bbox[0], bbox[1] - 2,
                id2class[c] + ':' + ('%.2f' % score),
                fontsize=14, family='serif',
                bbox=dict(facecolor='g', alpha=0.4, pad=0, edgecolor='none'),
                color='white')
    assert len(sel_rois) == len(sel_boxes)
    sel_rois = np.array(sel_rois)
    if use_graph:
        adjmat = adjmat_pred[head].detach().cpu().numpy()[0]
        adjmat = adjmat[sel_rois, :][:, sel_rois]  # only selected rois
        # draw_graph(adjmat, sel_boxes)
        draw_graph_color(adjmat, sel_boxes)
    plt.savefig(fname, bbox_inches='tight')
    plt.clf(); plt.cla(); plt.close(fig)
    # save the top-5 roi pairs with their edge weights
    top_roi_pairs = []
    if use_graph:
        for k, ind in enumerate(np.argsort(adjmat.flatten())[::-1][:5]):
            i, j = unravel_index(ind, adjmat.shape)
            edge_weight = adjmat[i, j]
            # adjmat was subset to the selected rois above, so index into
            # sel_boxes (the original indexed the full `boxes` array here)
            bbox1 = sel_boxes[i][:4]
            bbox2 = sel_boxes[j][:4]
            top_roi_pairs.append([str(bbox1), str(bbox2), str(edge_weight)])
    return top_roi_pairs
if __name__ == '__main__':
    args = parse_args()
    merge_cfg_from_file(args.cfg)
    assert_and_infer_cfg()
    if args.dataset == "coco2017":
        cfg.MODEL.NUM_CLASSES = 81
        ds_name = 'coco_2017_val'
    elif args.dataset == "ade20k":
        cfg.MODEL.NUM_CLASSES = 456
        ds_name = 'ade_val'
    else:
        raise ValueError('Unsupported dataset: ' + args.dataset)
    dataset_root = DATASETS[ds_name]['image_directory']
    dataset_json = DATASETS[ds_name]['annotation_file']
    with open(dataset_json, 'r') as f:
        dataset = json.load(f)
    # map category ids to names, with id 0 reserved for background
    classmap = dataset['categories']
    id2class = [(0, 'background')] + [(c['id'], c['name']) for c in classmap]
    id2class.sort(key=lambda x: x[0])
    id2class = [v for k, v in id2class]
    model = load_model(args.ckpt)
    target_scale = cfg.TEST.SCALE
    target_max_size = cfg.TEST.MAX_SIZE
    ext = '.png'
    n_heads = 8
    use_graph = True
    viz_thresh = 0.7
    num_img = 200
    all_roi_pairs = {}
    for ind, im_entry in enumerate(dataset['images']):
        all_roi_pairs['img-' + str(ind)] = {'file_name': im_entry['file_name']}
        img_dir = os.path.join(args.output_dir, 'img-' + str(ind))
        if not os.path.exists(img_dir):
            os.makedirs(img_dir)
        img = cv2.imread(os.path.join(dataset_root, im_entry['file_name']))
        for use_nms in [True]:
            # the modified im_detect_bbox also returns the per-head attention
            # matrices (the stock Detectron version returns only the first four)
            scores, pred_boxes, im_scale, blob_conv, adjmat_pred = test_utils.im_detect_bbox(
                model, img, target_scale, target_max_size)
            scores, boxes, cls_boxes, sel_roi_inds = box_results(scores, pred_boxes, use_nms=use_nms)
            if len(adjmat_pred) == 0:
                use_graph = False
            for head in range(n_heads):
                im_name = os.path.splitext(im_entry['file_name'])[0] + ext
                fname = os.path.join(img_dir,
                                     'head-' + str(head) + '_nms-' + str(use_nms) + im_name.replace('/', '-'))
                top_roi_pairs = vis_one_image(fname, img, cls_boxes, sel_roi_inds, id2class,
                                              head=head, use_graph=use_graph, thresh=viz_thresh)
                all_roi_pairs['img-' + str(ind)]['head-' + str(head)] = top_roi_pairs
        print('Saved image', ind)
        if ind >= num_img:
            break
    if len(all_roi_pairs) > 0:
        with open(os.path.join(args.output_dir, 'roi_pairs.json'), 'w') as f:
            f.write(json.dumps(all_roi_pairs, indent=2))
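# The saved roi_pairs.json has this shape (values illustrative):
# {
#   "img-0": {
#     "file_name": "000000397133.jpg",
#     "head-0": [["[x1 y1 x2 y2]", "[x1 y1 x2 y2]", "0.73"], ...],
#     ...
#   },
#   ...
# }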