Skip to content

Instantly share code, notes, and snippets.

@vuiseng9
Last active February 20, 2023 03:59
Show Gist options
  • Save vuiseng9/e590e52c92d004331d7cbbd0c5aef716 to your computer and use it in GitHub Desktop.
Save vuiseng9/e590e52c92d004331d7cbbd0c5aef716 to your computer and use it in GitHub Desktop.
import os
import logging as log
from openvino.runtime import Core, PartialShape, serialize

# Route log.info(...) straight to stdout via print. Because `log` is the stdlib
# `logging` module itself (imported as `logging as log`), this rebinds
# logging.info for everything in the process, bypassing logger levels and
# handlers. Caution: extra positional args are printed space-separated rather
# than %-formatted, so this is only safe for single pre-formatted messages.
log.info = print

def get_input_output_names(ports):
    """Collect the ``any_name`` of every port, preserving port order.

    Args:
        ports: iterable of port objects exposing an ``any_name`` attribute.

    Returns:
        A new list of names, one per port.
    """
    names = []
    for p in ports:
        names.append(p.any_name)
    return names

def get_node_names(ports):
    """Collect the ``friendly_name`` of the node behind each port, in order.

    Args:
        ports: iterable of port objects exposing ``node.friendly_name``.

    Returns:
        A new list of node names, one per port.
    """
    names = []
    for p in ports:
        owner = p.node
        names.append(owner.friendly_name)
    return names

def print_inputs_and_outputs_info(model):
    """Log one line per model input and per model output.

    Each line shows the port index, name, element precision, node layout and
    the dimensions of the port's partial shape.

    Args:
        model: object exposing ``inputs`` and ``outputs`` port sequences
            (an OpenVINO model in this script).
    """
    # The prefixes are padded so that input and output lines align.
    _log_ports_info("Model input ", model.inputs)
    _log_ports_info("Model output", model.outputs)

def _log_ports_info(prefix, ports):
    """Log a ``prefix``-labelled info line for each port in *ports*."""
    for i, port in enumerate(ports):
        dims = " ".join(str(x) for x in port.partial_shape)
        log.info(f"{prefix} {i:2}: {port.any_name:20}: "
                 f"precision {port.element_type.get_type_name()}, "
                 f"dimensions ({str(port.node.layout)}): {dims}")

def print_divider(label=None):
    """Print a 100-character dashed rule, then optionally a ``+ label`` heading.

    Args:
        label: heading text; when None only the rule is printed.
    """
    rule = "-" * 100
    print(rule)
    if label is None:
        return
    # The explicit "\n" plus print's own newline leave a blank line after the heading.
    print(f"+ {label} \n")

def reshape_ir_by_input(ov_model, batch_size=1, shape=-1):
    """Reshape every input of *ov_model* and return the same model object.

    With a concrete ``shape``, every input becomes ``[batch_size, shape]``.
    This assumes BERT-style inputs with exactly two axes.

    With ``shape == -1`` every dimension of every input is made dynamic
    (``-1``) while keeping each input's rank. ``batch_size`` is ignored in
    that case.

    Args:
        ov_model: model whose ``reshape`` is called in place.
        batch_size: size of the first axis (only used when ``shape != -1``).
        shape: size of the second axis, or ``-1`` for fully dynamic inputs.

    Returns:
        ``ov_model`` itself, after reshaping.
    """
    def target_shape(port):
        if shape != -1:
            return PartialShape([batch_size, shape])
        return PartialShape([-1] * len(port.partial_shape))

    ov_model.reshape({port.any_name: target_shape(port) for port in ov_model.inputs})
    return ov_model

def write_model(ov_model, xml_name, output_dir):
    """Serialize *ov_model* as an IR pair inside *output_dir*.

    The topology goes to ``<output_dir>/<xml_name>``. The weights go to the
    same path with its ``.xml`` extension replaced by ``.bin``. If *xml_name*
    has no ``.xml`` extension, ``.bin`` is appended instead, so the two files
    never collide.

    Args:
        ov_model: model to serialize.
        xml_name: file name of the topology file, normally ending in ``.xml``.
        output_dir: directory to write both files into.
    """
    ir_xml = os.path.join(output_dir, xml_name)
    # Touch only the final extension: a plain str.replace(".xml", ...) would
    # also rewrite ".xml" occurring in directory names.
    root, ext = os.path.splitext(ir_xml)
    ir_bin = (root if ext.lower() == ".xml" else ir_xml) + ".bin"
    serialize(ov_model, ir_xml, ir_bin)


# Path to the IR topology file to load; replace the placeholder before running.
# It should be the IR's .xml file: the saved copies reuse its file name and
# derive their .bin name from it.
ir_xml = "/set/path/to/your/ir"

# Reshaped copies are written next to the original model, so keep its directory
# and file name.
outdir, xml_basename = os.path.split(ir_xml)

core = Core()
ov_model = core.read_model(ir_xml)

print_divider("Original Model Shape")
print_inputs_and_outputs_info(ov_model)

def test_reshaping(ov_model, batch_size, length, xml_basename=None, outdir="./"):
    """Reshape *ov_model*, print its new I/O info, and optionally save it.

    Args:
        ov_model: model to reshape; it is reshaped in place and also returned.
        batch_size: first-axis size (ignored when ``length == -1``, which makes
            every input dimension dynamic).
        length: second-axis size, or ``-1`` for fully dynamic inputs.
        xml_basename: when not None, the reshaped model is serialized as
            ``f"{batch_size}x{length}_{xml_basename}"`` inside *outdir*.
        outdir: directory receiving the serialized IR.

    Returns:
        The reshaped model.
    """
    # Use the batch_size parameter, not a module-level global, so the label
    # and the reshape always reflect what the caller asked for.
    print_divider(f"Shape to {batch_size}x{length}")
    ov_model = reshape_ir_by_input(ov_model, batch_size=batch_size, shape=length)
    print_inputs_and_outputs_info(ov_model)
    if xml_basename is not None:
        write_model(ov_model, f"{batch_size}x{length}_{xml_basename}", outdir)
    return ov_model

# Sequence of reshapes applied cumulatively to ov_model:
# (batch size, sequence length, save the resulting IR next to the original).
# A length of -1 makes every input dimension dynamic.
reshape_plan = (
    (1, 384, True),
    (-1, -1, False),
    (2, 256, False),
    (-1, -1, False),
    (1, 89, True),
    (-1, -1, True),
)

for bs, length, save_ir in reshape_plan:
    # Unsaved steps rely on test_reshaping's defaults (xml_basename=None).
    save_kwargs = {"xml_basename": xml_basename, "outdir": outdir} if save_ir else {}
    ov_model = test_reshaping(ov_model, batch_size=bs, length=length, **save_kwargs)

print_divider()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment