thsno02 / add_normalization_to_onnx.py
Created October 9, 2023 02:35
add_normalization_to_onnx
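import onnx


def add_normalization_to_onnx(model_path: str, first_node_name: str, mean: list, std: list):
    '''
    Edit an exported ONNX model to add a preprocessing step that subtracts `mean` and divides by `std`.
    '''
    model = onnx.load(model_path)
    # Save a copy of the unmodified model before editing it.
    # Note: str.replace is a no-op if 'inference' does not occur in model_path,
    # in which case no separate backup file is created.
    onnx.save(model, model_path.replace('inference', 'raw_inference'))
    # Assume the model has a single input tensor of shape (batch_size, num_channels, height, width)
    input_tensor = model.graph.input[0]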