Skip to content

Instantly share code, notes, and snippets.

@chiehpower
Last active May 18, 2020 01:31
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save chiehpower/7d2a598c9c2b6bef96a525c2f93ae927 to your computer and use it in GitHub Desktop.
Save chiehpower/7d2a598c9c2b6bef96a525c2f93ae927 to your computer and use it in GitHub Desktop.
For detectron2 issue
import numpy as np
import cv2
import os
import sys
import requests
import torch
import detectron2
#from detectron2.utils.visualizer import Visualizer
from detectron2.data.catalog import MetadataCatalog, DatasetCatalog
from detectron2.engine import DefaultTrainer
from detectron2.config import get_cfg
from detectron2.engine.defaults import DefaultPredictor
#from detectron2.utils.visualizer import ColorMode
# from detectron2.data.datasets import register_coco_instances, register_semantic
import time
import json
import copy
from detectron2.utils.logger import setup_logger
import ctypes
import torch
# from skimage import measure
# import skimage
# Absolute directory that contains this script (symlinks resolved). It is the
# default download location for the weight file and the directory the ONNX
# files are written to.
currentfile_path = os.path.dirname(os.path.realpath(__file__))
# edited by Brilian
class model_custom(torch.nn.Module):
    """Run a selection of a model's direct children one after another.

    The children of ``fs`` are collected in definition order. With
    ``include_head=True`` the list is cut to ``children[:uptolayer]``; otherwise
    the last child is replaced by the ``head`` attribute of the child found at
    index ``uptolayer``.
    """

    def __init__(self, fs, uptolayer=-1, include_head=False):
        super().__init__()
        stages = list(fs.children())
        if include_head:
            stages = stages[:uptolayer]
        else:
            stages[-1] = stages[uptolayer].head
        self.features = torch.nn.ModuleList(stages)

    def forward(self, x, debug=False):
        """Feed ``x[0]`` through every stored stage and return the result.

        ``x`` must be indexable: ``x[0]`` is the data passed through the
        stages and ``x[1]`` is the size value handed to the stage at index 1.
        ``debug`` is accepted but not used.
        """
        image_size = x[1]
        out = x[0]
        print('..................input: ', out.shape, image_size)
        for index, stage in enumerate(self.features):
            if index == 1:
                # The second stage takes (sizes, features, None) instead of a
                # single argument — presumably the proposal generator of the
                # wrapped detector; verify against the wrapped model.
                out = stage([image_size, [torch.tensor(1)]], out, None)
            else:
                out = stage(out)
        return out
def download_model(basepath=currentfile_path,
                   download_file='model_final_a3ec72.pkl'):
    """Download the Mask R-CNN R101-FPN 3x weight file unless it already exists.

    Args:
        basepath: Directory the weight file is stored in.
        download_file: File name of the weight file inside ``basepath``.

    Returns:
        The full path of the weight file.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection problems or timeouts.
    """
    print('\n')
    print('*'*80)
    download_file = os.path.join(basepath, download_file)
    if os.path.exists(download_file):
        print('...[download_model] The weight file is already downloaded, skip downloading...')
        return download_file

    url = 'https://dl.fbaipublicfiles.com/detectron2/COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x/138205316/model_final_a3ec72.pkl'
    print('...[download_model] Downloading weight file... from {}'.format(url))
    # Download into a temporary name first: an interrupted or failed transfer
    # must not leave a file at the final path, because the existence check
    # above would then treat it as a complete download on the next run.
    tmp_path = download_file + '.part'
    try:
        # stream=True avoids holding the whole file in memory; the timeout
        # applies per socket operation, not to the transfer as a whole.
        with requests.get(url, allow_redirects=True, stream=True, timeout=60) as r:
            # Fail loudly instead of saving an HTTP error page as weights.
            r.raise_for_status()
            print('...[download_model] Write the downloaded file...')
            with open(tmp_path, 'wb') as out:
                for chunk in r.iter_content(chunk_size=1 << 20):
                    out.write(chunk)
        os.replace(tmp_path, download_file)
    finally:
        # Only still present when something above failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    print('...[download_model] Download weight file finish...')
    return download_file
def combine_config(path, weightfile, basepath = currentfile_path):
    """Create a detectron2 config from a YAML file and apply local overrides.

    Args:
        path: Path of the YAML config merged on top of the defaults returned
            by ``get_cfg()``.
        weightfile: Value stored in ``cfg.MODEL.WEIGHTS``.
        basepath: Currently unused.

    Returns:
        The config object after all overrides were applied.
    """
    print('\n')
    print('*'*80)
    print('...[combine_config] Combining config file... from {}'.format(path))

    # Project-level settings; only a subset of the keys is consumed below.
    voc_config = {
        "min_dimension": 800,
        "max_dimension": 1333,
        "test_score_thresh": 0.8,
        "fine_tune_checkpoint": "",
        "max_iter": 500,
        "save_by_second": -1,
        "network": 0,
        "max_detections": 100,
        "is_use_data_aug": 1,
        "gpu_limit": -1,
        "UseAngle": 1,
        "segmenting_images": "",
        "class_name": "",
        "n_classes": 80,
        "TrainModel_n_classes": 0,
        "ResetModel": 0,
    }
    short_side = voc_config['min_dimension']
    long_side = voc_config['max_dimension']

    cfg = get_cfg()
    cfg.merge_from_file(path)

    cfg.DATALOADER.NUM_WORKERS = 0

    # Model settings.
    cfg.MODEL.WEIGHTS = weightfile
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = voc_config['n_classes']
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = voc_config['test_score_thresh']

    # Solver settings.
    cfg.SOLVER.IMS_PER_BATCH = 1
    cfg.SOLVER.BASE_LR = 0.0025
    cfg.SOLVER.MAX_ITER = voc_config["max_iter"]

    # The same size limits are used for training and testing.
    cfg.INPUT.MIN_SIZE_TRAIN = short_side
    cfg.INPUT.MAX_SIZE_TRAIN = long_side
    cfg.INPUT.MIN_SIZE_TEST = short_side
    cfg.INPUT.MAX_SIZE_TEST = long_side

    print('...[combine_config] Combining config file finish...')
    return cfg
def dummy_convert(cfg, device = torch.device('cuda:0'), only_backbone = True):
    """Export the detector described by ``cfg`` to ONNX using a dummy input.

    The ONNX file is written to ``checkmodel.onnx`` next to this script and is
    then passed to ``onnxsim``, which writes ``checkmodel_simplify.onnx``.

    Args:
        cfg: detectron2 config; ``cfg.MODEL.WEIGHTS`` is the weight file and
            ``cfg.INPUT.MIN_SIZE_TRAIN`` is used as both height and width of
            the dummy input.
        device: Device the dummy input (and exported model) is moved to.
        only_backbone: If True, export ``model_custom`` built from the
            detector with ``uptolayer=-2, include_head=True``; otherwise
            export the complete ``detector.model``.
    """
    import pickle
    import subprocess

    print('\n')
    print('*'*80)
    print('...[dummy_convert] Loading base model and load weightfile...')
    detector = DefaultPredictor(cfg)
    # SECURITY: pickle can execute arbitrary code while loading. Only point
    # cfg.MODEL.WEIGHTS at weight files obtained from a trusted source.
    with open(cfg.MODEL.WEIGHTS, 'rb') as f:
        checkpoint = pickle.loads(f.read(), encoding='latin1')
    # NOTE(review): strict=False silently ignores every key that does not
    # match; confirm this reload actually changes any parameter.
    detector.model.load_state_dict(checkpoint, strict = False)

    # create a dummy input
    print('...[dummy_convert] Initialize dummy input and size...')
    side = cfg.INPUT.MIN_SIZE_TRAIN
    dummy_input = torch.randn(1, 3, side, side).to(device)
    imsize = torch.IntTensor((side, side)).to(device)

    # One path for both export branches and for the simplify step. Previously
    # the name was only bound in the backbone branch (NameError otherwise) and
    # used a '\\' separator that did not match the file actually written.
    onnx_path = os.path.join(currentfile_path, 'checkmodel.onnx')

    if only_backbone:
        print('...[dummy_convert] Loading custom model...')
        cpmodel = model_custom(detector.model, uptolayer=-2, include_head=True).to(device)
        print('...[dummy_convert] The custom model: \n', cpmodel)
        # export the model
        print('...[dummy_convert] Export model...')
        torch.onnx.export(cpmodel,
                          [dummy_input, imsize],
                          onnx_path,
                          opset_version=11,
                          export_params=True)
    else:
        print('...[dummy_convert] Test detect...')
        print('what is this value: ', dummy_input.shape, imsize)
        # FIXME: passing height/width as Python ints (int(imsize[0]), ...)
        # failed with an "INT" error; passing them as tensors, as done here,
        # fails during export with an "Instances" error:
        #   RuntimeError: Only tuples, lists and Variables supported as JIT
        #   inputs/outputs ... But got unsupported type Instances
        inputs = {"image": dummy_input[0], "height": imsize[0], "width": imsize[1]}
        with torch.no_grad():
            _ = detector.model([inputs])
        print('...[dummy_convert] Export model in full...')
        torch.onnx.export(detector.model.to(device),
                          [inputs],
                          onnx_path,
                          opset_version=11,
                          export_params=True)

    print('...[dummy_convert] Export model to simplify model...')
    # splitext only touches the final extension, so dots in directory names
    # survive (split('.') + ''.join dropped all of them).
    stem, ext = os.path.splitext(onnx_path)
    simplified_path = stem + '_simplify' + ext
    # Argument list instead of a shell string: paths with spaces or shell
    # metacharacters are passed through unchanged. sys.executable runs onnxsim
    # with the interpreter that is executing this script.
    result = subprocess.run([sys.executable, '-m', 'onnxsim', onnx_path, simplified_path])
    if result.returncode != 0:
        print('...[dummy_convert] onnxsim failed with exit code {}'.format(result.returncode))
    print('...[dummy_convert] Export model finish...')
if __name__ == "__main__":
    # Fetch the weight file (skipped automatically when it already exists);
    # the returned value is the path of the weight file.
    weightfile = download_model()

    # Location of the Mask R-CNN config, relative to this script's directory.
    relative_cfg = "../configs/COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml"
    path_maskrcnn_cfg = os.path.join(currentfile_path, relative_cfg)

    # Combine our custom settings with the default config from Mask R-CNN;
    # the result is a config holding all of its parameters.
    cfg = combine_config(path_maskrcnn_cfg, weightfile)
    print('cfg:', cfg)

    # Try to export the model. Switch to only_backbone = True to export only
    # backbone + FPN:
    # dummy_convert(cfg, only_backbone = True)
    dummy_convert(cfg, only_backbone = False) # all
@chiehpower
Copy link
Author

chiehpower commented Apr 29, 2020

Please check lines 172 and 173, shown below:

dummy_convert(cfg, only_backbone = True) # only backbone + FPN
dummy_convert(cfg, only_backbone = False) # all

If only_backbone = True, the conversion succeeds, but the exported model contains only the backbone + FPN.
However, if only_backbone = False, the whole model is exported and the conversion fails.

@chiehpower
Copy link
Author

Command:

$ python3 test_detect.py       

Output:

Failed to load OpenCL runtime


********************************************************************************
...[download_model] The weight file is already downloaded, skip downloading...


********************************************************************************
...[combine_config] Combining config file... from /home/nvidia/ssd240/github/detectron2/tools/python_class_issue/../../configs/COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml
...[combine_config] Combining config file finish...
cfg: CUDNN_BENCHMARK: False
DATALOADER:
  ASPECT_RATIO_GROUPING: True
  FILTER_EMPTY_ANNOTATIONS: True
  NUM_WORKERS: 0
  REPEAT_THRESHOLD: 0.0
  SAMPLER_TRAIN: TrainingSampler
DATASETS:
  PRECOMPUTED_PROPOSAL_TOPK_TEST: 1000
  PRECOMPUTED_PROPOSAL_TOPK_TRAIN: 2000
  PROPOSAL_FILES_TEST: ()
  PROPOSAL_FILES_TRAIN: ()
  TEST: ('coco_2017_val',)
  TRAIN: ('coco_2017_train',)
GLOBAL:
  HACK: 1.0
INPUT:
  CROP:
    ENABLED: False
    SIZE: [0.9, 0.9]
    TYPE: relative_range
  FORMAT: BGR
  MASK_FORMAT: polygon
  MAX_SIZE_TEST: 1333
  MAX_SIZE_TRAIN: 1333
  MIN_SIZE_TEST: 800
  MIN_SIZE_TRAIN: 800
  MIN_SIZE_TRAIN_SAMPLING: choice
MODEL:
  ANCHOR_GENERATOR:
    ANGLES: [[-90, 0, 90]]
    ASPECT_RATIOS: [[0.5, 1.0, 2.0]]
    NAME: DefaultAnchorGenerator
    OFFSET: 0.0
    SIZES: [[32], [64], [128], [256], [512]]
  BACKBONE:
    FREEZE_AT: 2
    NAME: build_resnet_fpn_backbone
  DEVICE: cuda
  FPN:
    FUSE_TYPE: sum
    IN_FEATURES: ['res2', 'res3', 'res4', 'res5']
    NORM: 
    OUT_CHANNELS: 256
  KEYPOINT_ON: False
  LOAD_PROPOSALS: False
  MASK_ON: True
  META_ARCHITECTURE: GeneralizedRCNN
  PANOPTIC_FPN:
    COMBINE:
      ENABLED: True
      INSTANCES_CONFIDENCE_THRESH: 0.5
      OVERLAP_THRESH: 0.5
      STUFF_AREA_LIMIT: 4096
    INSTANCE_LOSS_WEIGHT: 1.0
  PIXEL_MEAN: [103.53, 116.28, 123.675]
  PIXEL_STD: [1.0, 1.0, 1.0]
  PROPOSAL_GENERATOR:
    MIN_SIZE: 0
    NAME: RPN
  RESNETS:
    DEFORM_MODULATED: False
    DEFORM_NUM_GROUPS: 1
    DEFORM_ON_PER_STAGE: [False, False, False, False]
    DEPTH: 101
    NORM: FrozenBN
    NUM_GROUPS: 1
    OUT_FEATURES: ['res2', 'res3', 'res4', 'res5']
    RES2_OUT_CHANNELS: 256
    RES5_DILATION: 1
    STEM_OUT_CHANNELS: 64
    STRIDE_IN_1X1: True
    WIDTH_PER_GROUP: 64
  RETINANET:
    BBOX_REG_WEIGHTS: (1.0, 1.0, 1.0, 1.0)
    FOCAL_LOSS_ALPHA: 0.25
    FOCAL_LOSS_GAMMA: 2.0
    IN_FEATURES: ['p3', 'p4', 'p5', 'p6', 'p7']
    IOU_LABELS: [0, -1, 1]
    IOU_THRESHOLDS: [0.4, 0.5]
    NMS_THRESH_TEST: 0.5
    NUM_CLASSES: 80
    NUM_CONVS: 4
    PRIOR_PROB: 0.01
    SCORE_THRESH_TEST: 0.05
    SMOOTH_L1_LOSS_BETA: 0.1
    TOPK_CANDIDATES_TEST: 1000
  ROI_BOX_CASCADE_HEAD:
    BBOX_REG_WEIGHTS: ((10.0, 10.0, 5.0, 5.0), (20.0, 20.0, 10.0, 10.0), (30.0, 30.0, 15.0, 15.0))
    IOUS: (0.5, 0.6, 0.7)
  ROI_BOX_HEAD:
    BBOX_REG_WEIGHTS: (10.0, 10.0, 5.0, 5.0)
    CLS_AGNOSTIC_BBOX_REG: False
    CONV_DIM: 256
    FC_DIM: 1024
    NAME: FastRCNNConvFCHead
    NORM: 
    NUM_CONV: 0
    NUM_FC: 2
    POOLER_RESOLUTION: 7
    POOLER_SAMPLING_RATIO: 0
    POOLER_TYPE: ROIAlignV2
    SMOOTH_L1_BETA: 0.0
    TRAIN_ON_PRED_BOXES: False
  ROI_HEADS:
    BATCH_SIZE_PER_IMAGE: 512
    IN_FEATURES: ['p2', 'p3', 'p4', 'p5']
    IOU_LABELS: [0, 1]
    IOU_THRESHOLDS: [0.5]
    NAME: StandardROIHeads
    NMS_THRESH_TEST: 0.5
    NUM_CLASSES: 80
    POSITIVE_FRACTION: 0.25
    PROPOSAL_APPEND_GT: True
    SCORE_THRESH_TEST: 0.8
  ROI_KEYPOINT_HEAD:
    CONV_DIMS: (512, 512, 512, 512, 512, 512, 512, 512)
    LOSS_WEIGHT: 1.0
    MIN_KEYPOINTS_PER_IMAGE: 1
    NAME: KRCNNConvDeconvUpsampleHead
    NORMALIZE_LOSS_BY_VISIBLE_KEYPOINTS: True
    NUM_KEYPOINTS: 17
    POOLER_RESOLUTION: 14
    POOLER_SAMPLING_RATIO: 0
    POOLER_TYPE: ROIAlignV2
  ROI_MASK_HEAD:
    CLS_AGNOSTIC_MASK: False
    CONV_DIM: 256
    NAME: MaskRCNNConvUpsampleHead
    NORM: 
    NUM_CONV: 4
    POOLER_RESOLUTION: 14
    POOLER_SAMPLING_RATIO: 0
    POOLER_TYPE: ROIAlignV2
  RPN:
    BATCH_SIZE_PER_IMAGE: 256
    BBOX_REG_WEIGHTS: (1.0, 1.0, 1.0, 1.0)
    BOUNDARY_THRESH: -1
    HEAD_NAME: StandardRPNHead
    IN_FEATURES: ['p2', 'p3', 'p4', 'p5', 'p6']
    IOU_LABELS: [0, -1, 1]
    IOU_THRESHOLDS: [0.3, 0.7]
    LOSS_WEIGHT: 1.0
    NMS_THRESH: 0.7
    POSITIVE_FRACTION: 0.5
    POST_NMS_TOPK_TEST: 1000
    POST_NMS_TOPK_TRAIN: 1000
    PRE_NMS_TOPK_TEST: 1000
    PRE_NMS_TOPK_TRAIN: 2000
    SMOOTH_L1_BETA: 0.0
  SEM_SEG_HEAD:
    COMMON_STRIDE: 4
    CONVS_DIM: 128
    IGNORE_VALUE: 255
    IN_FEATURES: ['p2', 'p3', 'p4', 'p5']
    LOSS_WEIGHT: 1.0
    NAME: SemSegFPNHead
    NORM: GN
    NUM_CLASSES: 54
  WEIGHTS: /home/nvidia/ssd240/github/detectron2/tools/python_class_issue/model_final_a3ec72.pkl
OUTPUT_DIR: ./output
SEED: -1
SOLVER:
  BASE_LR: 0.0025
  BIAS_LR_FACTOR: 1.0
  CHECKPOINT_PERIOD: 5000
  GAMMA: 0.1
  IMS_PER_BATCH: 1
  LR_SCHEDULER_NAME: WarmupMultiStepLR
  MAX_ITER: 500
  MOMENTUM: 0.9
  STEPS: (210000, 250000)
  WARMUP_FACTOR: 0.001
  WARMUP_ITERS: 1000
  WARMUP_METHOD: linear
  WEIGHT_DECAY: 0.0001
  WEIGHT_DECAY_BIAS: 0.0001
  WEIGHT_DECAY_NORM: 0.0
TEST:
  AUG:
    ENABLED: False
    FLIP: True
    MAX_SIZE: 4000
    MIN_SIZES: (400, 500, 600, 700, 800, 900, 1000, 1100, 1200)
  DETECTIONS_PER_IMAGE: 100
  EVAL_PERIOD: 0
  EXPECTED_RESULTS: []
  KEYPOINT_OKS_SIGMAS: []
  PRECISE_BN:
    ENABLED: False
    NUM_ITER: 200
VERSION: 2
VIS_PERIOD: 0


********************************************************************************
...[dummy_convert] Loading base model and load weightfile...
...[dummy_convert] Initialize dummy input and size...
...[dummy_convert] Test detect...
what is this value:  torch.Size([1, 3, 800, 800]) tensor([800, 800], device='cuda:0', dtype=torch.int32)
...[dummy_convert] Export model in full...
/home/nvidia/ssd240/github/detectron2/detectron2/structures/image_list.py:78: TracerWarning: Converting a tensor to a Python float might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  max_size[-2] = int(math.ceil(max_size[-2] / stride) * stride)  # type: ignore
/home/nvidia/ssd240/github/detectron2/detectron2/structures/image_list.py:79: TracerWarning: Converting a tensor to a Python float might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  max_size[-1] = int(math.ceil(max_size[-1] / stride) * stride)  # type: ignore
/home/nvidia/ssd240/github/detectron2/detectron2/structures/image_list.py:89: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if all(x == 0 for x in padding_size):  # https://github.com/pytorch/pytorch/issues/31734
/home/nvidia/ssd240/github/detectron2/detectron2/modeling/anchor_generator.py:188: TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  num_images = len(features[0])
/home/nvidia/ssd240/github/detectron2/detectron2/structures/boxes.py:145: TracerWarning: torch.as_tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  tensor = torch.as_tensor(tensor, dtype=torch.float32, device=device)
/home/nvidia/ssd240/github/detectron2/detectron2/structures/boxes.py:150: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert tensor.dim() == 2 and tensor.size(-1) == 4, tensor.size()
/home/nvidia/ssd240/github/detectron2/detectron2/modeling/box_regression.py:106: TracerWarning: There are 2 live references to the data region being modified when tracing in-place operator copy_ (possibly due to an assignment). This might cause the trace to be incorrect, because all other views that also reference this data will not reflect this change in the trace! On the other hand, if all other views use the same memory chunk, but are disjoint (e.g. are outputs of torch.split), this might still be safe.
  pred_boxes[:, 0::4] = pred_ctr_x - 0.5 * pred_w  # x1
/home/nvidia/ssd240/github/detectron2/detectron2/modeling/box_regression.py:107: TracerWarning: There are 2 live references to the data region being modified when tracing in-place operator copy_ (possibly due to an assignment). This might cause the trace to be incorrect, because all other views that also reference this data will not reflect this change in the trace! On the other hand, if all other views use the same memory chunk, but are disjoint (e.g. are outputs of torch.split), this might still be safe.
  pred_boxes[:, 1::4] = pred_ctr_y - 0.5 * pred_h  # y1
/home/nvidia/ssd240/github/detectron2/detectron2/modeling/box_regression.py:108: TracerWarning: There are 2 live references to the data region being modified when tracing in-place operator copy_ (possibly due to an assignment). This might cause the trace to be incorrect, because all other views that also reference this data will not reflect this change in the trace! On the other hand, if all other views use the same memory chunk, but are disjoint (e.g. are outputs of torch.split), this might still be safe.
  pred_boxes[:, 2::4] = pred_ctr_x + 0.5 * pred_w  # x2
/home/nvidia/ssd240/github/detectron2/detectron2/modeling/box_regression.py:109: TracerWarning: There are 2 live references to the data region being modified when tracing in-place operator copy_ (possibly due to an assignment). This might cause the trace to be incorrect, because all other views that also reference this data will not reflect this change in the trace! On the other hand, if all other views use the same memory chunk, but are disjoint (e.g. are outputs of torch.split), this might still be safe.
  pred_boxes[:, 3::4] = pred_ctr_y + 0.5 * pred_h  # y2
/home/nvidia/ssd240/github/detectron2/detectron2/modeling/proposal_generator/rpn_outputs.py:105: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  num_proposals_i = min(pre_nms_topk, Hi_Wi_A)
/home/nvidia/ssd240/github/detectron2/detectron2/modeling/proposal_generator/rpn_outputs.py:110: TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  topk_scores_i = logits_i[batch_idx, :num_proposals_i]
/home/nvidia/ssd240/github/detectron2/detectron2/modeling/proposal_generator/rpn_outputs.py:111: TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  topk_idx = idx[batch_idx, :num_proposals_i]
/home/nvidia/ssd240/github/detectron2/detectron2/modeling/proposal_generator/rpn_outputs.py:131: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if not valid_mask.all():
/home/nvidia/ssd240/github/detectron2/detectron2/structures/boxes.py:185: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert torch.isfinite(self.tensor).all(), "Box tensor contains infinite or NaN!"
/home/nvidia/ssd240/github/detectron2/detectron2/structures/boxes.py:187: TracerWarning: There are 3 live references to the data region being modified when tracing in-place operator clamp_. This might cause the trace to be incorrect, because all other views that also reference this data will not reflect this change in the trace! On the other hand, if all other views use the same memory chunk, but are disjoint (e.g. are outputs of torch.split), this might still be safe.
  self.tensor[:, 0].clamp_(min=0, max=w)
/home/nvidia/ssd240/github/detectron2/detectron2/structures/boxes.py:188: TracerWarning: There are 3 live references to the data region being modified when tracing in-place operator clamp_. This might cause the trace to be incorrect, because all other views that also reference this data will not reflect this change in the trace! On the other hand, if all other views use the same memory chunk, but are disjoint (e.g. are outputs of torch.split), this might still be safe.
  self.tensor[:, 1].clamp_(min=0, max=h)
/home/nvidia/ssd240/github/detectron2/detectron2/structures/boxes.py:189: TracerWarning: There are 3 live references to the data region being modified when tracing in-place operator clamp_. This might cause the trace to be incorrect, because all other views that also reference this data will not reflect this change in the trace! On the other hand, if all other views use the same memory chunk, but are disjoint (e.g. are outputs of torch.split), this might still be safe.
  self.tensor[:, 2].clamp_(min=0, max=w)
/home/nvidia/ssd240/github/detectron2/detectron2/structures/boxes.py:190: TracerWarning: There are 3 live references to the data region being modified when tracing in-place operator clamp_. This might cause the trace to be incorrect, because all other views that also reference this data will not reflect this change in the trace! On the other hand, if all other views use the same memory chunk, but are disjoint (e.g. are outputs of torch.split), this might still be safe.
  self.tensor[:, 3].clamp_(min=0, max=h)
/home/nvidia/ssd240/github/detectron2/detectron2/modeling/proposal_generator/rpn_outputs.py:139: TracerWarning: Converting a tensor to a Python number might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if keep.sum().item() != len(boxes):
/home/nvidia/ssd240/github/detectron2/detectron2/modeling/proposal_generator/rpn_outputs.py:139: TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if keep.sum().item() != len(boxes):
/home/nvidia/ssd240/github/detectron2/detectron2/layers/nms.py:13: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert boxes.shape[-1] == 4
/home/nvidia/ssd240/github/detectron2/detectron2/layers/nms.py:16: TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if len(boxes) < 40000:
/home/nvidia/ssd240/github/detectron2/detectron2/structures/instances.py:71: TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  data_len = len(value)
/home/nvidia/ssd240/github/detectron2/detectron2/structures/instances.py:135: TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  return len(v)
/home/nvidia/ssd240/github/detectron2/detectron2/modeling/poolers.py:207: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  0
/home/nvidia/ssd240/github/detectron2/detectron2/modeling/poolers.py:73: TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  (len(box_tensor), 1), batch_index, dtype=box_tensor.dtype, device=box_tensor.device
/home/nvidia/ssd240/github/detectron2/detectron2/modeling/poolers.py:221: TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  num_boxes = len(pooler_fmt_boxes)
/home/nvidia/ssd240/github/detectron2/detectron2/layers/roi_align.py:93: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert rois.dim() == 2 and rois.size(1) == 5
/home/nvidia/ssd240/github/detectron2/detectron2/modeling/roi_heads/fast_rcnn.py:270: TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  num_pred = len(self.proposals)
/home/nvidia/ssd240/github/detectron2/detectron2/modeling/roi_heads/fast_rcnn.py:91: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if not valid_mask.all():
/home/nvidia/ssd240/github/detectron2/detectron2/structures/boxes.py:187: TracerWarning: There are 4 live references to the data region being modified when tracing in-place operator clamp_. This might cause the trace to be incorrect, because all other views that also reference this data will not reflect this change in the trace! On the other hand, if all other views use the same memory chunk, but are disjoint (e.g. are outputs of torch.split), this might still be safe.
  self.tensor[:, 0].clamp_(min=0, max=w)
/home/nvidia/ssd240/github/detectron2/detectron2/structures/boxes.py:188: TracerWarning: There are 4 live references to the data region being modified when tracing in-place operator clamp_. This might cause the trace to be incorrect, because all other views that also reference this data will not reflect this change in the trace! On the other hand, if all other views use the same memory chunk, but are disjoint (e.g. are outputs of torch.split), this might still be safe.
  self.tensor[:, 1].clamp_(min=0, max=h)
/home/nvidia/ssd240/github/detectron2/detectron2/structures/boxes.py:189: TracerWarning: There are 4 live references to the data region being modified when tracing in-place operator clamp_. This might cause the trace to be incorrect, because all other views that also reference this data will not reflect this change in the trace! On the other hand, if all other views use the same memory chunk, but are disjoint (e.g. are outputs of torch.split), this might still be safe.
  self.tensor[:, 2].clamp_(min=0, max=w)
/home/nvidia/ssd240/github/detectron2/detectron2/structures/boxes.py:190: TracerWarning: There are 4 live references to the data region being modified when tracing in-place operator clamp_. This might cause the trace to be incorrect, because all other views that also reference this data will not reflect this change in the trace! On the other hand, if all other views use the same memory chunk, but are disjoint (e.g. are outputs of torch.split), this might still be safe.
  self.tensor[:, 3].clamp_(min=0, max=h)
/home/nvidia/ssd240/github/detectron2/detectron2/modeling/roi_heads/fast_rcnn.py:107: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if num_bbox_reg_classes == 1:
/home/nvidia/ssd240/github/detectron2/detectron2/modeling/roi_heads/mask_head.py:128: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if cls_agnostic_mask:
/home/nvidia/ssd240/github/detectron2/detectron2/structures/boxes.py:265: TracerWarning: There are 3 live references to the data region being modified when tracing in-place operator mul_. This might cause the trace to be incorrect, because all other views that also reference this data will not reflect this change in the trace! On the other hand, if all other views use the same memory chunk, but are disjoint (e.g. are outputs of torch.split), this might still be safe.
  self.tensor[:, 0::2] *= scale_x
/home/nvidia/ssd240/github/detectron2/detectron2/structures/boxes.py:265: TracerWarning: There are 5 live references to the data region being modified when tracing in-place operator copy_ (possibly due to an assignment). This might cause the trace to be incorrect, because all other views that also reference this data will not reflect this change in the trace! On the other hand, if all other views use the same memory chunk, but are disjoint (e.g. are outputs of torch.split), this might still be safe.
  self.tensor[:, 0::2] *= scale_x
/home/nvidia/ssd240/github/detectron2/detectron2/structures/boxes.py:266: TracerWarning: There are 3 live references to the data region being modified when tracing in-place operator mul_. This might cause the trace to be incorrect, because all other views that also reference this data will not reflect this change in the trace! On the other hand, if all other views use the same memory chunk, but are disjoint (e.g. are outputs of torch.split), this might still be safe.
  self.tensor[:, 1::2] *= scale_y
/home/nvidia/ssd240/github/detectron2/detectron2/structures/boxes.py:266: TracerWarning: There are 5 live references to the data region being modified when tracing in-place operator copy_ (possibly due to an assignment). This might cause the trace to be incorrect, because all other views that also reference this data will not reflect this change in the trace! On the other hand, if all other views use the same memory chunk, but are disjoint (e.g. are outputs of torch.split), this might still be safe.
  self.tensor[:, 1::2] *= scale_y
/home/nvidia/ssd240/github/detectron2/detectron2/layers/mask_ops.py:88: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert masks.shape[-1] == masks.shape[-2], "Only square mask predictions are supported"
/home/nvidia/ssd240/github/detectron2/detectron2/layers/mask_ops.py:89: TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  N = len(masks)
Traceback (most recent call last):
  File "test_detect.py", line 173, in <module>
    dummy_convert(cfg, only_backbone = False) # all
  File "test_detect.py", line 152, in dummy_convert
    export_params=True
  File "/usr/local/lib/python3.6/dist-packages/torch/onnx/__init__.py", line 143, in export
    strip_doc_string, dynamic_axes, keep_initializers_as_inputs)
  File "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py", line 66, in export
    dynamic_axes=dynamic_axes, keep_initializers_as_inputs=keep_initializers_as_inputs)
  File "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py", line 382, in _export
    fixed_batch_size=fixed_batch_size)
  File "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py", line 249, in _model_to_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args, training)
  File "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py", line 206, in _trace_and_get_graph_from_model
    trace, torch_out, inputs_states = torch.jit.get_trace_graph(model, args, _force_outplace=True, _return_inputs_states=True)
  File "/usr/local/lib/python3.6/dist-packages/torch/jit/__init__.py", line 275, in get_trace_graph
    return LegacyTracedModule(f, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/torch/jit/__init__.py", line 355, in forward
    out_vars, _ = _flatten(out)
RuntimeError: Only tuples, lists and Variables supported as JIT inputs/outputs. Dictionaries and strings are also accepted but their usage is not recommended. But got unsupported type Instances

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment