Last active
March 16, 2024 17:40
-
-
Save Hyuto/f3db1c0c2c36308284e101f441c2555f to your computer and use it in GitHub Desktop.
Generate metadata from a custom-trained YOLO-NAS model to help achieve the best performance when running inference with ONNX.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
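"""
custom-nas-model-metadata.py
Generate metadata from a custom-trained YOLO-NAS model to help achieve the best
performance when running inference with ONNX.

Usage:
    $ python custom-nas-model-metadata.py -m <CHECKPOINT-PATH> -t <MODEL-TYPE> -n <NUM-CLASSES>

    -m, --model          Path to the custom-trained YOLO-NAS PyTorch checkpoint
                         (not an exported .onnx file)
    -t, --type           Custom-trained YOLO-NAS model type (yolo_nas_s, yolo_nas_m, yolo_nas_l)
    -n, --num-classes    Number of classes the model was trained on
    --export-onnx        Optional: also export the model to ONNX at this path (with extension)
"""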
import argparse | |
import json | |
import logging | |
from pathlib import Path | |
import numpy as np | |
logging.basicConfig(level=logging.INFO, format="%(message)s")

# Model architectures accepted by ``--type`` (the user value is lower-cased
# before being compared against this list).
SUPPORTED_YOLO_NAS_TYPE = [f"yolo_nas_{size}" for size in ("s", "m", "l")]
def parse_opt():
    """Parse and validate the command-line arguments.

    Returns:
        argparse.Namespace where ``model`` is a ``Path``, ``type`` is
        lower-cased, ``num_classes`` is an ``int`` and ``export_onnx`` is a
        ``Path`` (or ``None`` when the flag was not given).

    Raises:
        FileNotFoundError: the checkpoint file or the ONNX export directory
            does not exist.
        ValueError: the checkpoint path points to an ``.onnx`` file.
        NotImplementedError: ``--type`` is not in SUPPORTED_YOLO_NAS_TYPE.
    """
    parser = argparse.ArgumentParser(description="Export Custom Trained YOLO-NAS Model Metadata.")
    required = parser.add_argument_group("required arguments")
    required.add_argument(
        "-m", "--model", type=str, required=True, help="Custom Trained YOLO-NAS checkpoint path"
    )
    required.add_argument(
        "-t", "--type", type=str, required=True, help="Custom Trained YOLO-NAS type model"
    )
    required.add_argument(
        "-n", "--num-classes", type=int, required=True, help="Custom Trained YOLO-NAS num classes"
    )
    parser.add_argument(
        "--export-onnx", type=str, help="Convert model to onnx (path with extension)"
    )
    opt = parser.parse_args()  # parsing args
    opt.model = Path(opt.model)
    opt.type = opt.type.lower()

    # path checking
    # is_file() rather than exists(): a directory would pass exists() but can
    # never be loaded as a checkpoint.
    if not opt.model.is_file():
        raise FileNotFoundError("Wrong path! Model Not Found!")
    # The checkpoint is later loaded through torch.load (pickle). An exported
    # ONNX file fails there with an opaque "UnpicklingError: invalid load key",
    # so reject it up front with a clear message.
    if opt.model.suffix.lower() == ".onnx":
        raise ValueError(
            "Wrong checkpoint format! Expected a PyTorch checkpoint (e.g. .pth/.pt), "
            "not an exported ONNX model."
        )
    if opt.type not in SUPPORTED_YOLO_NAS_TYPE:
        raise NotImplementedError(
            f"Type: {opt.type} isn't supported.\nSupported YOLO-NAS type: {SUPPORTED_YOLO_NAS_TYPE}"
        )
    if opt.export_onnx:
        opt.export_onnx = Path(opt.export_onnx)
        if not opt.export_onnx.parent.exists():
            raise FileNotFoundError("Wrong path! Export directory not found.")

    # logging: echo every argument that has a value
    args = vars(opt).items()
    logging.info(
        "🚀 \033[1m\033[94m"
        + "Generate Metadata: "
        + "\033[0m"
        + ", ".join([f"{x}={y}" for x, y in args if y is not None])
    )
    return opt
def get_preprocessing_steps(preprocessing, processing):
    """Translate one super-gradients preprocessing step into a JSON-able dict.

    Args:
        preprocessing: a single processing object taken from the model's image
            processor.
        processing: the ``super_gradients.training.processing`` module (passed
            in so the heavy import stays inside ``main``).

    Returns:
        A one-key dict ``{step_name: params_or_None}`` describing the step, or
        ``None`` for ``ImagePermute`` and ``ReverseImageChannels``, which carry
        no exported parameters.

    Raises:
        NotImplementedError: the step type is not one of the handled classes.
    """
    # The isinstance checks are kept in this exact order in case some of these
    # processing classes subclass one another.
    if isinstance(preprocessing, processing.StandardizeImage):
        return {"Standardize": {"max_value": preprocessing.max_value}}
    elif isinstance(preprocessing, processing.DetectionRescale):
        return {"DetRescale": None}
    elif isinstance(preprocessing, processing.DetectionLongestMaxSizeRescale):
        return {"DetLongMaxRescale": None}
    elif isinstance(preprocessing, processing.DetectionBottomRightPadding):
        return {
            "BotRightPad": {
                "pad_value": preprocessing.pad_value,
            }
        }
    elif isinstance(preprocessing, processing.DetectionCenterPadding):
        return {
            "CenterPad": {
                "pad_value": preprocessing.pad_value,
            }
        }
    elif isinstance(preprocessing, processing.NormalizeImage):
        # mean/std are converted to plain lists so json.dumps can serialize them.
        return {
            "Normalize": {"mean": preprocessing.mean.tolist(), "std": preprocessing.std.tolist()}
        }
    elif isinstance(preprocessing, processing.ImagePermute):
        return None
    elif isinstance(preprocessing, processing.ReverseImageChannels):
        return None
    else:
        # NotImplementedError (the exception), not NotImplemented (a
        # non-callable constant, which would raise TypeError here instead).
        raise NotImplementedError(
            "Model has a processing step that hasn't been implemented: "
            f"{type(preprocessing).__name__}"
        )
def main(opt):
    """Write inference metadata for a custom YOLO-NAS checkpoint, optionally exporting ONNX.

    Loads the model with ``models.get``. It then pushes a random 1000x800x3
    uint8 image through the model's image processor, only to learn the
    resulting network input shape.

    It writes ``custom-<type>-metadata.json`` (relative to the current working
    directory) containing:

    - the model type;
    - the input shape;
    - the default NMS IoU and score thresholds;
    - the translated preprocessing steps;
    - the class labels.

    When ``opt.export_onnx`` is set, the model is also converted to ONNX using
    that input shape without its leading batch dimension.

    Args:
        opt: the namespace returned by ``parse_opt``.
    """
    # Imported here rather than at module top, so argument validation in
    # parse_opt runs before super_gradients is loaded.
    from super_gradients.training import models
    import super_gradients.training.processing as processing

    checkpoint = opt.model.as_posix()
    net = models.get(opt.type, num_classes=opt.num_classes, checkpoint_path=checkpoint)

    # Pixel values are irrelevant; only the post-processing shape is used.
    sample_image = np.random.randint(0, 255, (1000, 800, 3), dtype=np.uint8)

    labels = net._class_names
    iou_threshold = net._default_nms_iou
    score_threshold = net._default_nms_conf
    image_processor = net._image_processor

    prep_steps = []
    for step in image_processor.processings:
        prep_steps.append(get_preprocessing_steps(step, processing))

    processed = image_processor.preprocess_image(sample_image)[0]
    # Prepend a batch axis: shape becomes (1, *processed.shape).
    input_shape = np.expand_dims(processed, axis=0).shape

    metadata = dict(
        type=opt.type,
        original_insz=input_shape,
        iou_thres=iou_threshold,
        score_thres=score_threshold,
        prep_steps=prep_steps,
        labels=labels,
    )

    out_file = f"custom-{opt.type}-metadata.json"
    logging.info(f"Export metadata to: {out_file}")
    with open(out_file, "w") as fh:
        fh.write(json.dumps(metadata))

    if not opt.export_onnx:
        return
    logging.info(f"Export ONNX model to: {opt.export_onnx}")
    models.convert_to_onnx(
        model=net,
        input_shape=input_shape[1:],
        out_path=opt.export_onnx.as_posix(),
    )
if __name__ == "__main__":
    # Parse and validate the CLI arguments, then run the metadata export.
    main(parse_opt())
@Sridharanraja, @Hyuto: I have the same error, i.e. `UnpicklingError: invalid load key, '\x08'`. Is there any solution to this issue? Thank you.
The script does NOT accept the ONNX file format; it accepts ONLY .pt (PyTorch) checkpoints.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Error:
File "yolo nas.py", line 8, in
model = models.get('yolo_nas_s', num_classes = 1,checkpoint_path = "custom_nas.onnx")
File "C:\Users\pc\anaconda3\envs\yolonas\lib\site-packages\super_gradients\training\models\model_factory.py", line 205, in get
ckpt_entries = read_ckpt_state_dict(ckpt_path=checkpoint_path).keys()
File "C:\Users\pc\anaconda3\envs\yolonas\lib\site-packages\super_gradients\training\utils\checkpoint_utils.py", line 136, in read_ckpt_state_dict
state_dict = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
File "C:\Users\pc\anaconda3\envs\yolonas\lib\site-packages\torch\serialization.py", line 795, in load
return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
File "C:\Users\pc\anaconda3\envs\yolonas\lib\site-packages\torch\serialization.py", line 1002, in _legacy_load
magic_number = pickle_module.load(f, **pickle_load_args)
_pickle.UnpicklingError: invalid load key, '\x08'.