Created
May 16, 2019 22:40
-
-
Save PCJohn/a73572f706efb3e6f739502aad6440bd to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Visualize self-attention graphs between detected objects.

For images of the chosen validation split, this script:
  * runs a Faster R-CNN style detector with self-attention heads,
  * draws the confident detections on the image with the attention weights
    between them overlaid as a directed graph (one output image per head),
  * writes the highest-weighted ROI pairs of every head to
    ``<output_dir>/roi_pairs.json``.

Usage:
    srun -p 1080ti-short --mem 100000 --gres=gpu:1 \
    python tools/viz_net.py --dataset coco2017 --output_dir tmp \
    --cfg configs/context/self-attn/e2e_faster_rcnn_R-50-C4_1x_attn_head-8_lr-long-v3.yaml \
    --ckpt data/models/coco-visual-8-head/model_step69999.pth

Example (8 heads v3 on COCO):
    srun -p 1080ti-short --gres=gpu:1 --mem 20000 python tools/viz_net.py \
    --dataset coco2017 \
    --output_dir tmp \
    --cfg /mnt/nfs/work1/elm/arunirc/Research/detectron-context/Detectron-pytorch-video/configs/context/self-attn/e2e_faster_rcnn_R-50-C4_1x_attn_head-8_lr-long-v3.yaml \
    --ckpt /mnt/nfs/work1/elm/arunirc/Research/detectron-context/Detectron-pytorch-video/Outputs/e2e_faster_rcnn_R-50-C4_1x_attn_head-8_lr-long-v3/Apr10-19-35-20_node106_step/ckpt/model_step69999.pth
"""
import os | |
import sys | |
import cv2 | |
import json | |
import torch | |
import argparse | |
import numpy as np | |
from numpy import unravel_index | |
import matplotlib | |
matplotlib.use('Agg') | |
from functools import reduce | |
import matplotlib.pyplot as plt | |
import networkx as nx | |
import _init_paths # pylint: disable=unused-import | |
import nn as mynn | |
from core.config import cfg, merge_cfg_from_file, merge_cfg_from_list, assert_and_infer_cfg | |
from core.test_engine import run_inference | |
from core.test import im_detect_bbox | |
from utils import net as net_utils | |
from utils import vis as vis_utils | |
import utils.boxes as box_utils | |
from utils.colormap import colormap | |
from core import test as test_utils | |
from core import test_engine | |
from modeling import model_builder | |
from datasets.dataset_catalog import DATASETS | |
def parse_args(argv=None):
    """Parse the command-line arguments of the attention-graph visualizer.

    Args:
        argv: optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is what the script uses.

    Returns:
        argparse.Namespace with ``dataset``, ``output_dir``, ``ckpt`` and ``cfg``.
    """
    parser = argparse.ArgumentParser(
        description='Visualize self-attention graphs of a trained detector')
    parser.add_argument(
        '--dataset', required=True, type=str,
        # Only these names are mapped to a dataset split by the main script;
        # any other value would leave the split undefined there.
        choices=['coco2017', 'ade20k'],
        help='dataset name: coco2017 or ade20k'
    )
    parser.add_argument(
        '--output_dir', help='folder to save the graph viz',
        default='./', type=str
    )
    parser.add_argument(
        '--ckpt', required=True,
        help='path to the model checkpoint (.pth) to visualize')
    parser.add_argument(
        '--cfg', required=True,
        help='config file')
    return parser.parse_args(argv)
def box_results(scores, boxes, use_nms=True):
    """Turn raw per-class scores/boxes into per-class detections, tracking ROI indices.

    This follows the usual test-time box post-processing (score threshold,
    NMS or soft-NMS, optional box voting, per-image detection cap). It also
    records, for every surviving detection, the index of the ROI (row of
    ``scores``/``boxes``) it came from. That lets detections be matched against
    per-ROI attention matrices.

    Args:
        scores: array indexed as ``scores[roi, class]``.
        boxes: array whose row ``roi`` holds 4 coordinates per class, i.e.
            ``boxes[roi, 4*j:4*j+4]`` is the box for class ``j``.
        use_nms: when False, NMS/soft-NMS/box voting and the per-image
            detection cap are all skipped. Every box above
            ``cfg.TEST.SCORE_THRESH`` is then kept.

    Returns:
        ``(scores, boxes, cls_boxes, sel_roi_inds)``:
          * flattened detection scores and boxes over all foreground classes,
          * the per-class ``cls_boxes`` list (entry 0, background, is ``[]``),
          * the source ROI index of every row of the flattened ``boxes``.
    """
    num_classes = cfg.MODEL.NUM_CLASSES
    cls_boxes = [[] for _ in range(num_classes)]
    sel_roi_inds = [[] for _ in range(num_classes)]
    for j in range(1, num_classes):
        inds = np.where(scores[:, j] > cfg.TEST.SCORE_THRESH)[0]
        scores_j = scores[inds, j]
        boxes_j = boxes[inds, j * 4:(j + 1) * 4]
        dets_j = np.hstack((boxes_j, scores_j[:, np.newaxis])).astype(np.float32, copy=False)
        if not use_nms:
            keep = np.arange(dets_j.shape[0])
            nms_dets = dets_j[keep, :]
        elif cfg.TEST.SOFT_NMS.ENABLED:
            # The second return value of soft_nms is used as the kept row
            # indices into dets_j (as in Detectron's implementation; confirm for
            # this fork). Previously it was discarded, so the ROI tracking
            # below read an undefined or stale ``keep``.
            nms_dets, keep = box_utils.soft_nms(
                dets_j,
                sigma=cfg.TEST.SOFT_NMS.SIGMA,
                overlap_thresh=cfg.TEST.NMS,
                score_thresh=0.0001,
                method=cfg.TEST.SOFT_NMS.METHOD
            )
        else:
            keep = box_utils.nms(dets_j, cfg.TEST.NMS)
            nms_dets = dets_j[keep, :]
        # Refine the post-NMS boxes using bounding-box voting (keeps the row count)
        if use_nms and cfg.TEST.BBOX_VOTE.ENABLED:
            nms_dets = box_utils.box_voting(
                nms_dets,
                dets_j,
                cfg.TEST.BBOX_VOTE.VOTE_TH,
                scoring_method=cfg.TEST.BBOX_VOTE.SCORING_METHOD
            )
        cls_boxes[j] = nms_dets
        # Map the kept rows of dets_j back to ROI indices.
        sel_roi_inds[j].extend(inds[k] for k in keep)

    # Limit to max_per_image detections **over all classes**. Without NMS every
    # box is kept, so the cap is skipped entirely (as before).
    if use_nms and cfg.TEST.DETECTIONS_PER_IM > 0:
        image_scores = np.hstack(
            [cls_boxes[j][:, -1] for j in range(1, num_classes)]
        )
        if len(image_scores) > cfg.TEST.DETECTIONS_PER_IM:
            image_thresh = np.sort(image_scores)[-cfg.TEST.DETECTIONS_PER_IM]
            for j in range(1, num_classes):
                keep = np.where(cls_boxes[j][:, -1] >= image_thresh)[0]
                cls_boxes[j] = cls_boxes[j][keep, :]
                sel_roi_inds[j] = [sel_roi_inds[j][k] for k in keep]

    # Flatten in class order, matching the row order of the stacked boxes.
    flat_roi_inds = [ind for per_class in sel_roi_inds for ind in per_class]
    im_results = np.vstack([cls_boxes[j] for j in range(1, num_classes)])
    boxes = im_results[:, :-1]
    scores = im_results[:, -1]
    assert len(flat_roi_inds) == len(boxes)
    return scores, boxes, cls_boxes, flat_roi_inds
def load_model(path, gpu_id=0):
    """Build a Generalized_RCNN, move it to CUDA in eval mode and load ``path``.

    Args:
        path: checkpoint file; its ``'model'`` entry is the state passed to
            ``net_utils.load_ckpt``.
        gpu_id: accepted for interface compatibility but unused; the model
            is moved with a bare ``.cuda()`` call.

    Returns:
        The network wrapped in ``mynn.DataParallel`` with ``im_info`` and
        ``roidb`` listed as CPU keywords and ``minibatch=True``.
    """
    net = model_builder.Generalized_RCNN()
    net.eval()
    net.cuda()  # always use cuda
    # map_location returning the storage unchanged deserializes tensors on CPU.
    state = torch.load(path, map_location=lambda storage, loc: storage)['model']
    net_utils.load_ckpt(net, state)
    wrapped = mynn.DataParallel(net, cpu_keywords=['im_info', 'roidb'], minibatch=True)
    return wrapped
def draw_graph(adjmat, boxes):
    """Draw ``adjmat`` as an undirected, labelled graph on the current matplotlib axes.

    Weights are scaled so the largest entry is 1. Nodes are placed at the
    centres of ``boxes``. Edges are coloured with the ``jet`` colormap by
    weight, and every positive entry gets a two-decimal label.

    Args:
        adjmat: square matrix of weights between the boxes (not modified).
        boxes: one box per row of ``adjmat``, indexed as ``(x1, y1, x2, y2, ...)``.
    """
    adjmat = np.asarray(adjmat)
    if adjmat.size == 0:
        return  # nothing to draw; adjmat.max() would raise on an empty array
    max_weight = adjmat.max()
    # Without a positive maximum the scaling would divide by zero; draw the
    # nodes only (an all-zero matrix yields a graph with no edges).
    if max_weight > 0:
        scaled_adjmat = adjmat / max_weight
    else:
        scaled_adjmat = np.zeros(adjmat.shape)
    graph = nx.from_numpy_array(scaled_adjmat)
    edge_weights = {(u, v): d['weight'] for u, v, d in graph.edges(data=True)}
    node_labels = {n: n for n in range(scaled_adjmat.shape[0])}
    # One label per positive matrix entry (both directions and self-loops).
    edge_labels = {(int(r), int(c)): "{:.2f}".format(scaled_adjmat[r, c])
                   for r, c in zip(*np.nonzero(scaled_adjmat > 0))}
    pos = {n: [(b[0] + b[2]) / 2., (b[1] + b[3]) / 2.] for n, b in enumerate(boxes)}
    nx.draw_networkx_nodes(graph, pos, alpha=0.5)
    nx.draw_networkx_labels(graph, pos, node_labels, font_size=4, alpha=0.5)
    # networkx's keyword is ``edgelist``; ``edge_list`` was silently ignored
    # (or rejected) depending on the networkx version.
    nx.draw_networkx_edges(graph, pos, alpha=0.5, edge_cmap=plt.cm.jet,
                           edge_color=list(edge_weights.values()),
                           edgelist=list(edge_weights.keys()))
    nx.draw_networkx_edge_labels(graph, pos, font_size=6, edge_labels=edge_labels, alpha=0.5)
def draw_graph_color(adjmat, boxes):
    """Draw ``adjmat`` as a directed graph whose edge colour, opacity and width follow the weight.

    Weights are scaled so the largest off-diagonal entry is 1. Nodes sit at
    the centres of ``boxes``.

    NOTE: the diagonal of ``adjmat`` is zeroed **in place** (self-edges are
    dropped). ``vis_one_image`` ranks ROI pairs on the same array after
    calling this, so self-pairs are excluded from that ranking as well.

    Args:
        adjmat: square numpy array of weights between the boxes (mutated).
        boxes: one box per row of ``adjmat``, indexed as ``(x1, y1, x2, y2, ...)``.
    """
    np.fill_diagonal(adjmat, 0.0)
    if adjmat.size == 0:
        return
    max_weight = adjmat.max()
    if not max_weight > 0:
        # No off-diagonal weight (e.g. a single selected ROI): scaling would
        # divide by zero, and unpacking zero edges below would fail.
        return
    scaled_adjmat = adjmat / max_weight
    G = nx.from_numpy_array(scaled_adjmat, create_using=nx.DiGraph)
    edges, weights = zip(*nx.get_edge_attributes(G, 'weight').items())
    pos = {n: [(b[0] + b[2]) / 2., (b[1] + b[3]) / 2.] for n, b in enumerate(boxes)}
    nx.draw(G, pos, edgelist=edges, edge_color=weights, edge_cmap=plt.cm.Blues,
            alpha=weights, width=[2.0 * x for x in weights], node_size=2.0)
def _top_roi_pairs(adjmat, sel_boxes, k=5):
    """Return the ``k`` highest-weight entries of ``adjmat`` as string triples.

    ``adjmat`` rows/columns are indexed like ``sel_boxes``. Each entry is
    ``[str(box_i[:4]), str(box_j[:4]), str(weight)]``, highest weight first.
    """
    pairs = []
    for flat_ind in np.argsort(adjmat.flatten())[::-1][:k]:
        i, j = np.unravel_index(flat_ind, adjmat.shape)
        pairs.append([str(sel_boxes[i][:4]), str(sel_boxes[j][:4]), str(adjmat[i, j])])
    return pairs


def vis_one_image(fname, im, boxes, roi_inds, id2class, thresh=0.9, box_alpha=0.0,
                  head=0, use_graph=True, adjmat_pred=None):
    """Save ``im`` with its confident detections and one head's attention graph.

    Boxes scoring at least ``thresh`` are drawn (largest first) with a class
    label. When ``use_graph`` is set, the attention matrix of ``head`` is
    restricted to those boxes and drawn as a graph. The figure is saved to
    ``fname``.

    Args:
        fname: output image path.
        im: image as read by ``cv2.imread`` (BGR).
        boxes: per-class detections (the ``cls_boxes`` list from ``box_results``).
        roi_inds: ROI index of every detection, in the flattened class order
            produced by ``box_results``.
        id2class: class names indexed by contiguous class id.
        thresh: minimum score for a box to be drawn and used in the graph.
        box_alpha: unused; kept for interface compatibility.
        head: which attention head of ``adjmat_pred`` to visualize.
        use_graph: whether to draw the attention graph and report ROI pairs.
        adjmat_pred: per-head attention matrices (indexed ``[head]``, then
            ``.detach().cpu().numpy()[0]``). When ``None``, the module-level
            ``adjmat_pred`` set by the ``__main__`` loop is used, as before.

    Returns:
        ``None`` if there is nothing above ``thresh``. Otherwise a list of the
        top-5 ``[box_i, box_j, weight]`` string triples for the selected boxes
        (empty when ``use_graph`` is False).
    """
    if isinstance(boxes, list):
        boxes, _, _, classes = vis_utils.convert_from_cls_format(boxes, None, None)
    else:
        classes = None
    if boxes is None or boxes.shape[0] == 0 or max(boxes[:, 4]) < thresh:
        return
    if classes is None:
        # The original failed here with a NameError on ``classes``.
        raise TypeError('vis_one_image expects per-class boxes (a list such as cls_boxes)')
    if use_graph and adjmat_pred is None:
        adjmat_pred = globals()['adjmat_pred']

    fig = plt.figure(frameon=False)
    ax = plt.Axes(fig, [0., 0., 1., 1.])
    ax.axis('off')
    fig.add_axes(ax)
    ax.imshow(cv2.cvtColor(im, cv2.COLOR_BGR2RGB), alpha=1.0)

    # Display in largest to smallest order to reduce occlusion
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    sorted_inds = np.argsort(-areas)
    sel_boxes = []
    sel_rois = []
    for i in sorted_inds:
        score = boxes[i, -1]
        if score < thresh:
            continue
        bbox = boxes[i, :4]
        sel_boxes.append(boxes[i])
        sel_rois.append(roi_inds[i])
        # Box outline (always drawn; box_alpha is not used).
        ax.add_patch(plt.Rectangle((bbox[0], bbox[1]), bbox[2] - bbox[0], bbox[3] - bbox[1],
                                   fill=False, edgecolor='g', linewidth=0.6))
        # Class label and score.
        ax.text(bbox[0], bbox[1] - 2,
                id2class[classes[i]] + ':' + ('%.2f' % score),
                fontsize=14, family='serif',
                bbox=dict(facecolor='g', alpha=0.4, pad=0, edgecolor='none'),
                color='white')
    assert len(sel_rois) == len(sel_boxes)
    sel_rois = np.array(sel_rois)

    adjmat = None
    if use_graph:
        adjmat = adjmat_pred[head].detach().cpu().numpy()[0]
        # Restrict to the selected ROIs; rows/columns now follow sel_boxes.
        adjmat = adjmat[sel_rois, :][:, sel_rois]
        # Note: this also zeroes adjmat's diagonal, so the ranking below skips self-pairs.
        draw_graph_color(adjmat, sel_boxes)
    plt.savefig(fname, bbox_inches='tight')
    plt.clf()
    plt.cla()
    plt.close(fig)

    if not use_graph:
        return []
    # Indices into adjmat refer to sel_boxes, not to the full ``boxes`` array
    # (the original looked them up in ``boxes`` and reported the wrong boxes).
    return _top_roi_pairs(adjmat, sel_boxes)
if __name__ == '__main__':
    args = parse_args()
    merge_cfg_from_file(args.cfg)
    # Dataset-specific overrides go before assert_and_infer_cfg(), which in
    # Detectron-style configs may make cfg immutable.
    if args.dataset == "coco2017":
        cfg.MODEL.NUM_CLASSES = 81
        ds_name = 'coco_2017_val'
    elif args.dataset == "ade20k":
        cfg.MODEL.NUM_CLASSES = 456
        ds_name = 'ade_val'
    else:
        raise ValueError('Unsupported dataset: {}'.format(args.dataset))
    assert_and_infer_cfg()

    dataset_root = DATASETS[ds_name]['image_directory']
    dataset_json = DATASETS[ds_name]['annotation_file']
    with open(dataset_json, 'r') as f:
        dataset = json.load(f)

    # Class names ordered by category id with background first; presumably
    # this matches the model's contiguous class indices (Detectron convention).
    id_name_pairs = [(0, 'background')] + [(c['id'], c['name']) for c in dataset['categories']]
    id_name_pairs.sort(key=lambda x: x[0])
    id2class = [name for _, name in id_name_pairs]

    model = load_model(args.ckpt)
    target_scale = cfg.TEST.SCALE
    target_max_size = cfg.TEST.MAX_SIZE
    ext = '.png'
    n_heads = 8
    viz_thresh = 0.7
    num_img = 200  # number of images to visualize

    all_roi_pairs = {}
    for ind, im_entry in enumerate(dataset['images']):
        if ind >= num_img:
            break
        img = cv2.imread(os.path.join(dataset_root, im_entry['file_name']))
        if img is None:
            print('Could not read image', ind, '- skipping')
            continue
        img_key = 'img-' + str(ind)
        all_roi_pairs[img_key] = {'file_name': im_entry['file_name']}
        img_dir = os.path.join(args.output_dir, img_key)
        os.makedirs(img_dir, exist_ok=True)
        im_name = os.path.splitext(im_entry['file_name'])[0] + ext
        for use_nms in [True]:
            # adjmat_pred stays a module-level name: vis_one_image may read it from globals.
            scores, pred_boxes, im_scale, blob_conv, adjmat_pred = test_utils.im_detect_bbox(
                model, img, target_scale, target_max_size)
            scores, boxes, cls_boxes, sel_roi_inds = box_results(scores, pred_boxes, use_nms=use_nms)
            # Decide per image; an empty adjmat_pred means no graph to draw.
            use_graph = len(adjmat_pred) > 0
            for head in range(n_heads):
                fname = os.path.join(
                    img_dir,
                    'head-' + str(head) + '_nms-' + str(use_nms) + im_name.replace('/', '-'))
                top_roi_pairs = vis_one_image(fname, img, cls_boxes, sel_roi_inds, id2class,
                                              head=head, use_graph=use_graph, thresh=viz_thresh)
                all_roi_pairs[img_key]['head-' + str(head)] = top_roi_pairs
        print('Saved image', ind)

    if all_roi_pairs:
        with open(os.path.join(args.output_dir, 'roi_pairs.json'), 'w') as f:
            f.write(json.dumps(all_roi_pairs, indent=2))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment