Skip to content

Instantly share code, notes, and snippets.

@LukeAI
Last active June 19, 2023 19:28
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save LukeAI/bbfc3ab749601ab0f2cb06e4b8fc75cb to your computer and use it in GitHub Desktop.
Save LukeAI/bbfc3ab749601ab0f2cb06e4b8fc75cb to your computer and use it in GitHub Desktop.
#!/usr/bin/env python
import os
from super_gradients.training import models
from super_gradients.common.object_names import Models
import onnx
import torch
import torch.nn as nn
# CONFIG
# These module-level settings are read by the export steps further down.
NO_CLASSES = 80  # passed to the NMS wrapper as n_classes
batch_size = 1  # leading dimension used when labelling the ONNX outputs
topk_all = 100  # detection cap; passed to the NMS wrapper as max_obj
input_shape = (3, 640, 640)  # forwarded to convert_to_onnx (presumably C, H, W)
iou_thres = 0.45  # IoU threshold given to the NMS op
score_thres = 0.25  # score threshold given to the NMS op
end2end = True  # True: export with the NMS post-process attached
onnx_path = "yolo_nas_s.onnx"  # output file; "_nms" is inserted when end2end
net = models.get(Models.YOLO_NAS_S, pretrained_weights="coco")
# Alternative: load a custom checkpoint instead of the COCO weights, e.g.
# net = models.get(Models.YOLO_NAS_L, num_classes=NO_CLASSES,
#                  checkpoint_path="/home/luke/yoloNAS/checkpoints/yolo_nas_l_spss_export/ckpt_latest.pth")
class TRT_NMS(torch.autograd.Function):
    """Autograd function representing the TensorRT NMS operation.

    ``forward`` does not perform any suppression: it only returns randomly
    filled tensors with the shapes and dtypes of the four NMS outputs
    (presumably so the exporter has correctly shaped placeholders).
    ``symbolic`` is what emits the ``TRT::EfficientNMS_TRT`` node into the
    ONNX graph.
    """

    @staticmethod
    def forward(
        ctx,
        boxes,
        scores,
        background_class=-1,
        box_coding=0,
        iou_threshold=0.45,
        max_output_boxes=100,
        plugin_version="1",
        score_activation=0,
        score_threshold=0.25,
        class_agnostic=1,
    ):
        """Return placeholder ``(num_det, det_boxes, det_scores, det_classes)``.

        Only ``scores.shape`` (three dimensions; the first and last are used)
        and ``max_output_boxes`` influence the result; every other argument
        is accepted to mirror ``symbolic`` and is otherwise ignored here.
        """
        n_batch, _, n_classes = scores.shape
        count_shape = (n_batch, 1)
        per_box_shape = (n_batch, max_output_boxes)

        # Keep this creation order: it fixes the sequence of RNG draws.
        num_det = torch.randint(0, max_output_boxes, count_shape, dtype=torch.int32)
        det_boxes = torch.randn(*per_box_shape, 4)
        det_scores = torch.randn(*per_box_shape)
        det_classes = torch.randint(0, n_classes, per_box_shape, dtype=torch.int32)
        return num_det, det_boxes, det_scores, det_classes

    @staticmethod
    def symbolic(
        g,
        boxes,
        scores,
        background_class=-1,
        box_coding=0,
        iou_threshold=0.45,
        max_output_boxes=100,
        plugin_version="1",
        score_activation=0,
        score_threshold=0.25,
        class_agnostic=1,
    ):
        """Emit a ``TRT::EfficientNMS_TRT`` node with four outputs.

        The ``_i`` / ``_f`` / ``_s`` suffixes on the attribute names are the
        type tags passed through to ``g.op``.
        """
        plugin_attrs = {
            "background_class_i": background_class,
            "box_coding_i": box_coding,
            "iou_threshold_f": iou_threshold,
            "max_output_boxes_i": max_output_boxes,
            "plugin_version_s": plugin_version,
            "class_agnostic_i": class_agnostic,
            "score_activation_i": score_activation,
            "score_threshold_f": score_threshold,
        }
        num_dets, nms_boxes, nms_scores, nms_classes = g.op(
            "TRT::EfficientNMS_TRT", boxes, scores, outputs=4, **plugin_attrs
        )
        return num_dets, nms_boxes, nms_scores, nms_classes
class ONNX_TRT(nn.Module):
    """ONNX post-processing module that applies the TensorRT NMS operation.

    The module stores the NMS settings and, in ``forward``, hands them to
    ``TRT_NMS.apply`` together with the ``(boxes, scores)`` pair it receives.

    Args:
        max_obj: Maximum number of output boxes passed to the NMS op.
        iou_thres: IoU threshold passed to the NMS op.
        score_thres: Score threshold passed to the NMS op.
        max_wh: Must be ``None``; any other value fails the assertion.
        device: Stored on the instance (defaults to CPU); not used by
            ``forward``.
        n_classes: Stored on the instance; not used by ``forward``.
    """

    def __init__(self, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=None, device=None, n_classes=80):
        super().__init__()
        assert max_wh is None
        self.device = device if device else torch.device('cpu')
        # Plain ints, not 1-tuples: the original trailing commas turned these
        # into (-1,) and (0,), which then reach the op's integer attributes
        # as sequences instead of scalars.
        self.background_class = -1
        self.box_coding = 0
        self.iou_threshold = iou_thres
        self.max_obj = max_obj
        self.plugin_version = '1'
        self.score_activation = 0
        self.score_threshold = score_thres
        self.n_classes = n_classes

    def forward(self, x):
        """Run the NMS op on ``x``, a ``(boxes, scores)`` pair.

        Returns:
            The tuple ``(num_det, det_boxes, det_scores, det_classes)``.
        """
        boxes, confscores = x
        # class_agnostic is not passed, so TRT_NMS uses its own default.
        num_det, det_boxes, det_scores, det_classes = TRT_NMS.apply(
            boxes,
            confscores,
            self.background_class,
            self.box_coding,
            self.iou_threshold,
            self.max_obj,
            self.plugin_version,
            self.score_activation,
            self.score_threshold,
        )
        return num_det, det_boxes, det_scores, det_classes
# Put the network in eval mode and run super-gradients' pre-conversion hook
# before exporting.
net.eval()
net.prep_model_for_conversion()
# https://github.com/Deci-AI/super-gradients/blob/master/documentation/source/BenchmarkingYoloNAS.md
if end2end:
    # Export with the NMS op attached as post-processing, under a "_nms" name.
    onnx_path = os.path.splitext(onnx_path)[0] + "_nms" + ".onnx"
    NMS = ONNX_TRT(
        max_obj=topk_all,
        iou_thres=iou_thres,
        score_thres=score_thres,
        max_wh=None,
        device=None,
        n_classes=NO_CLASSES,
    )
    NMS.eval()
    onnx_export_kwargs = {
        'input_names': ['images'],
        'output_names': ["num_dets", "det_boxes", "det_scores", "det_classes"],
    }
    models.convert_to_onnx(model=net, input_shape=input_shape, post_process=NMS, out_path=onnx_path,
                           torch_onnx_export_kwargs=onnx_export_kwargs)
else:
    models.convert_to_onnx(model=net, input_shape=input_shape, out_path=onnx_path)

# set output dimensions
# note: this makes no functional difference, just explicitly labels output dims
# so can be understood better when onnx inspected with netron etc.
# One tuple per NMS output, in the order given by output_names above.
shapes = [
    (batch_size, 1),            # num_dets
    (batch_size, topk_all, 4),  # det_boxes
    (batch_size, topk_all),     # det_scores
    (batch_size, topk_all),     # det_classes
]
onnx_model = onnx.load(onnx_path)  # load onnx model
onnx.checker.check_model(onnx_model)  # check onnx model
if end2end:
    # These labels only describe the four NMS outputs. Without the NMS
    # post-process the graph has different outputs, so relabelling them
    # (as the unguarded loop used to do) would attach wrong dimensions.
    for graph_output, dims in zip(onnx_model.graph.output, shapes):
        for dim, label in zip(graph_output.type.tensor_type.shape.dim, dims):
            dim.dim_param = str(label)
    onnx.save(onnx_model, onnx_path)
@olivierbeauve
Copy link

Hello! When I try your code, I get this error:

Traceback (most recent call last):
  File "test.py", line 119, in <module>
    models.convert_to_onnx(model=net, input_shape=input_shape, post_process=NMS, out_path=onnx_path,
  File "/home/beauveol/Desktop/lambda-checkout-supermarket-iot/venv/lib/python3.8/site-packages/super_gradients/common/decorators/factory_decorator.py", line 36, in wrapper
    return func(*args, **kwargs)
  File "/home/beauveol/Desktop/lambda-checkout-supermarket-iot/venv/lib/python3.8/site-packages/super_gradients/common/decorators/factory_decorator.py", line 36, in wrapper
    return func(*args, **kwargs)
  File "/home/beauveol/Desktop/lambda-checkout-supermarket-iot/venv/lib/python3.8/site-packages/super_gradients/training/models/conversion.py", line 98, in convert_to_onnx
    onnx_simplify(out_path, out_path)
  File "/home/beauveol/Desktop/lambda-checkout-supermarket-iot/venv/lib/python3.8/site-packages/super_gradients/training/models/conversion.py", line 156, in onnx_simplify
    model_sim, check = simplify(model=onnx_path)
  File "/home/beauveol/Desktop/lambda-checkout-supermarket-iot/venv/lib/python3.8/site-packages/onnxsim/onnx_simplifier.py", line 199, in simplify
    model_opt_bytes = C.simplify(
  File "/home/beauveol/Desktop/lambda-checkout-supermarket-iot/venv/lib/python3.8/site-packages/onnxsim/onnx_simplifier.py", line 260, in Run
    output_arrs = sess.run(output_names, inputs, run_options=run_options)
  File "/home/beauveol/Desktop/lambda-checkout-supermarket-iot/venv/lib/python3.8/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 200, in run
    return self._sess.run(output_names, input_feed, run_options)
onnxruntime.capi.onnxruntime_pybind11_state.InvalidArgument: [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Invalid rank for input: 1100 Got: 3 Expected: 2 Please fix either the inputs or the model.

Do you know why?

Thank you!

@LukeAI
Copy link
Author

LukeAI commented Jun 19, 2023

No idea, I'm afraid. If you copy and paste the code above and run it as-is, without modification, does it still do that? Are your super-gradients, onnxsim, onnxruntime, etc. up to date?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment