Skip to content

Instantly share code, notes, and snippets.

@Hyuto
Last active March 16, 2024 17:40
Show Gist options
  • Save Hyuto/f3db1c0c2c36308284e101f441c2555f to your computer and use it in GitHub Desktop.
Generate metadata from a custom trained YOLO-NAS model to help achieve the best performance on inference with ONNX.
"""
custom-nas-model-metadata.py
Generate metadata from custom trained YOLO-NAS model to help achieve best performance
on inferencing with ONNX.
Usage:
$ python custom-nas-model-metadata.py -m <CHECKPOINT-PATH> \ # Custom trained YOLO-NAS checkpoint path
-t <MODEL-TYPE> \ # Custom trained YOLO-NAS model type
-n <NUM-CLASSES> # Number of classes
"""
import argparse
import json
import logging
from pathlib import Path
import numpy as np
logging.basicConfig(format="%(message)s", level=logging.INFO)
# Model variants this exporter knows how to load.
SUPPORTED_YOLO_NAS_TYPE = ["yolo_nas_s", "yolo_nas_m", "yolo_nas_l"]


def parse_opt():
    """Parse and validate the command-line arguments.

    Returns:
        argparse.Namespace with ``model`` (pathlib.Path to the checkpoint),
        ``type`` (lower-cased model variant), ``num_classes`` (int) and
        ``export_onnx`` (pathlib.Path or None).

    Raises:
        FileNotFoundError: if the checkpoint file or the ONNX output
            directory does not exist.
        NotImplementedError: if ``--type`` is not a supported YOLO-NAS
            variant.
    """
    parser = argparse.ArgumentParser(description="Export Custom Trained YOLO-NAS Model Metadata.")
    required = parser.add_argument_group("required arguments")
    required.add_argument(
        "-m", "--model", type=str, required=True, help="Custom Trained YOLO-NAS checkpoint path"
    )
    required.add_argument(
        "-t", "--type", type=str, required=True, help="Custom Trained YOLO-NAS type model"
    )
    required.add_argument(
        "-n", "--num-classes", type=int, required=True, help="Custom Trained YOLO-NAS num classes"
    )
    parser.add_argument(
        "--export-onnx", type=str, help="Convert model to onnx (path with extension)"
    )

    opt = parser.parse_args()
    opt.model = Path(opt.model)
    opt.type = opt.type.lower()  # accept e.g. "YOLO_NAS_S" on the CLI

    # Fail fast on bad paths / unsupported variants before any heavy work.
    if not opt.model.exists():
        raise FileNotFoundError("Wrong path! Model Not Found!")
    if opt.type not in SUPPORTED_YOLO_NAS_TYPE:
        raise NotImplementedError(
            f"Type: {opt.type} isn't supported.\nSupported YOLO-NAS type: {SUPPORTED_YOLO_NAS_TYPE}"
        )
    if opt.export_onnx:
        opt.export_onnx = Path(opt.export_onnx)
        if not opt.export_onnx.parent.exists():
            raise FileNotFoundError("Wrong path! Export directory not found.")

    # Echo the effective configuration (skip options the user left unset).
    shown = ", ".join(f"{key}={val}" for key, val in vars(opt).items() if val is not None)
    logging.info("🚀 \033[1m\033[94m" + "Generate Metadata: " + "\033[0m" + shown)
    return opt
def get_preprocessing_steps(preprocessing, processing):
    """Translate one super-gradients preprocessing step into a
    JSON-serializable metadata entry.

    Args:
        preprocessing: a single step instance taken from the model's
            image-processor pipeline.
        processing: the ``super_gradients.training.processing`` module
            (passed in by the caller so this helper does not import it).

    Returns:
        dict mapping a short step name to its parameters, or ``None`` for
        steps that carry no parameters the ONNX consumer needs
        (``ImagePermute`` / ``ReverseImageChannels``).

    Raises:
        NotImplementedError: for preprocessing steps this exporter does
            not know how to describe.
    """
    if isinstance(preprocessing, processing.StandardizeImage):
        return {"Standardize": {"max_value": preprocessing.max_value}}
    elif isinstance(preprocessing, processing.DetectionRescale):
        return {"DetRescale": None}
    elif isinstance(preprocessing, processing.DetectionLongestMaxSizeRescale):
        return {"DetLongMaxRescale": None}
    elif isinstance(preprocessing, processing.DetectionBottomRightPadding):
        return {
            "BotRightPad": {
                "pad_value": preprocessing.pad_value,
            }
        }
    elif isinstance(preprocessing, processing.DetectionCenterPadding):
        return {
            "CenterPad": {
                "pad_value": preprocessing.pad_value,
            }
        }
    elif isinstance(preprocessing, processing.NormalizeImage):
        # mean/std are numpy arrays; .tolist() makes them JSON-serializable.
        return {
            "Normalize": {"mean": preprocessing.mean.tolist(), "std": preprocessing.std.tolist()}
        }
    elif isinstance(preprocessing, processing.ImagePermute):
        return None
    elif isinstance(preprocessing, processing.ReverseImageChannels):
        return None
    else:
        # BUG FIX: the original did `raise NotImplemented(...)` —
        # `NotImplemented` is a constant, not an exception class, so
        # calling it raised TypeError instead of the intended error.
        raise NotImplementedError("Model have processing steps that haven't been implemented")
def main(opt):
    """Load the checkpoint, write ``custom-<type>-metadata.json`` and
    optionally export the model to ONNX.

    Args:
        opt: the validated namespace produced by ``parse_opt``.
    """
    # Imported lazily: super_gradients is heavy and only needed here.
    from super_gradients.training import models
    import super_gradients.training.processing as processing

    net = models.get(opt.type, num_classes=opt.num_classes, checkpoint_path=opt.model.as_posix())

    # Dummy image used only to discover the post-preprocessing input shape
    # (arbitrary non-square H x W x C so resize/pad effects are visible).
    dummy = np.random.randint(0, 255, (1000, 800, 3), dtype=np.uint8)

    labels = net._class_names
    iou = net._default_nms_iou
    conf = net._default_nms_conf
    preprocessing_steps = [
        get_preprocessing_steps(st, processing) for st in net._image_processor.processings
    ]
    # Run the dummy through the image processor and prepend the batch dim.
    imgsz = np.expand_dims(net._image_processor.preprocess_image(dummy)[0], 0).shape

    res = {
        "type": opt.type,
        "original_insz": imgsz,
        "iou_thres": iou,
        "score_thres": conf,
        "prep_steps": preprocessing_steps,
        "labels": labels,
    }

    filename = f"custom-{opt.type}-metadata.json"
    # BUG FIX: the log line previously didn't interpolate the output
    # filename into the message.
    logging.info(f"Export metadata to: {filename}")
    with open(filename, "w") as f:
        f.write(json.dumps(res))

    if opt.export_onnx:
        logging.info(f"Export ONNX model to: {opt.export_onnx}")
        models.convert_to_onnx(
            model=net, input_shape=imgsz[1:], out_path=opt.export_onnx.as_posix()
        )


if __name__ == "__main__":
    opt = parse_opt()
    main(opt)
@Sumeshbaba
Copy link

@Sridharanraja, @Hyuto. I have this same error, i.e. `UnpicklingError: invalid load key, '\x08'`. Is there any solution to this issue? Thank you.

@ukicomputers
Copy link

The script DOES NOT accept onnx file format. ONLY pt

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment