@cdyk
Created May 31, 2020 19:33
Duct-taping onnx-models into a composite onnx-model

Christopher Dyken

Recently, I have played around a bit with onnx-models. One thing I wanted to do was to piece together multiple small models into one large model. I didn't find any utility to do this (though it might be that I didn't look hard enough), so I decided to figure out how hard it would be to do by hand. This post details that little adventure. Please be warned that I am definitely no expert on onnx-models, so don't take anything I write here as absolute truth. :-)

I use the Python packages onnx and onnxruntime. The package onnx contains code to create and read onnx files, as well as helper factory functions. The package onnxruntime contains code to evaluate an onnx-model, and has the package onnx as a dependency. All the code snippets assume the following imports:

import onnx
import onnxruntime
import numpy as np

I had some issues because I installed onnx before onnxruntime. Reverting to a clean environment with Python 3.7.7 and installing onnxruntime, which pulls in onnx as a dependency, worked for me.

The plan

From the specification, we see that an onnx-model is a cycle-free network of a set of inputs and a set of outputs connected via a set of nodes, where each node is an operator. Each internal edge gets a label that is used to describe connections. The graph is encoded in static single assignment (SSA) form; a label is only assigned a value once (hence all input, output and temporary labels must be unique), and a label cannot be used as input before it has been assigned as output, except for the graph inputs.

Thus, the graph is already topologically sorted, so an onnx runtime can evaluate all the nodes directly in the order they are declared.
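The two ordering rules above are easy to check mechanically. Below is a minimal sketch of such a check. It is a hypothetical helper of my own (`check_ssa_order` is not part of the onnx package), it works on plain `(op_type, inputs, outputs)` tuples instead of real NodeProtos, and it ignores details such as initializers and optional (empty) inputs:

```python
def check_ssa_order(graph_inputs, nodes):
    """Check that `nodes` is in SSA form and topologically sorted.

    graph_inputs: iterable of graph input labels.
    nodes: list of (op_type, input_labels, output_labels) in declaration order.
    Raises ValueError on the first violation found.
    """
    assigned = set(graph_inputs)  # graph inputs count as already assigned
    for op_type, inputs, outputs in nodes:
        for label in inputs:
            if label not in assigned:
                raise ValueError(f"{op_type}: '{label}' used before it is assigned")
        for label in outputs:
            if label in assigned:
                raise ValueError(f"{op_type}: '{label}' assigned more than once")
            assigned.add(label)


# The multiply-add graph from this post, written as plain tuples
madd_nodes = [('Mul', ['b', 'c'], ['t']), ('Add', ['a', 't'], ['r'])]
check_ssa_order(['a', 'b', 'c'], madd_nodes)  # passes silently

# Swapping the two nodes breaks the ordering: Add reads 't' before Mul writes it
try:
    check_ssa_order(['a', 'b', 'c'], madd_nodes[::-1])
except ValueError as err:
    print(err)  # Add: 't' used before it is assigned
```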

The building block: multiply-add

For this experiment, I needed a simple building block, and I wanted at least one internal edge, to make sure that handling of internal edges worked out. I chose a multiply-add operation:

r = a + b * c

built from a multiply and an add with a temporary value which I named t,

t = b * c
r = a + t

Drawing this as a nice ascii-art graph we get

   a    b     c                        2    3     5
   |    |     |                        |    |     |
   |    V     V                        |    V     V
   |    +-----+                        |    +-----+
   |    | Mul |                        |    | Mul |
   V    +--+--+                        V    +--+--+
+-----+    |                        +-----+    |
| Add |<-- t                        | Add |<-- 3*5=15
+--+--+                             +--+--+
   |                                   |
   V                                   V
   r                                 2+15 = 17

where the left side has the inputs, output, and temporary labelled, and the right side has some numbers with the associated calculations. I will use those calculations to test my network.
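Before involving onnx at all, it can be handy to have a plain NumPy reference of the block to compare runtime results against. A minimal sketch (the name `multiply_add` is just my own):

```python
import numpy as np


def multiply_add(a, b, c):
    """NumPy reference for the multiply-add block: r = a + b * c."""
    t = b * c  # the internal temporary
    return a + t


shape = (2, 3, 5)
a = np.full(shape, 2.0, dtype=np.float32)
b = np.full(shape, 3.0, dtype=np.float32)
c = np.full(shape, 5.0, dtype=np.float32)
r = multiply_add(a, b, c)
print(r.shape, r.dtype, r[0, 0, 0])  # (2, 3, 5) float32 17.0
```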

Creating the model

The onnx format is built on protocol buffers. The onnx package contains the corresponding serialization code as well as some helper factory functions.

Using the onnx package, our multiply-add network can be realized with:

inputs = [
    onnx.helper.make_tensor_value_info('a', onnx.TensorProto.FLOAT, [2,3,5]),
    onnx.helper.make_tensor_value_info('b', onnx.TensorProto.FLOAT, [2,3,5]),
    onnx.helper.make_tensor_value_info('c', onnx.TensorProto.FLOAT, [2,3,5])
]
outputs = [
    onnx.helper.make_tensor_value_info('r', onnx.TensorProto.FLOAT, [2,3,5])
]
nodes = [
    onnx.helper.make_node('Mul', ['b', 'c'], ['t']),
    onnx.helper.make_node('Add', ['a', 't'], ['r'])
]
madd_graph = onnx.helper.make_graph(nodes, "multiply-add", inputs, outputs)
madd_model = onnx.helper.make_model(madd_graph)
onnx.save(madd_model, 'mad.onnx')

Here I have specified that the three inputs (and the output) are rank-3 float tensors with dimensions [2,3,5], as I wanted to use something more interesting than plain scalars. The type and shape of the temporary t are implicit.

Running the model

Now, it would be nice to try this model to see if it works. Using a tool like netron, we can directly view the network in the onnx file and verify that it makes sense.

To actually run the model, we use the package onnxruntime. The idea behind onnx is to specify the model as a rather abstract graph and not code, and use a runtime to instantiate the model by either evaluating the graph or generating code.

First we load the model and create an inference session. We then create the input (getting the appropriate shape of the input tensors from the session), and then run it:

madd_session = onnxruntime.InferenceSession('mad.onnx')
a = np.full(madd_session.get_inputs()[0].shape, 2.0).astype(np.float32)
b = np.full(madd_session.get_inputs()[1].shape, 3.0).astype(np.float32)
c = np.full(madd_session.get_inputs()[2].shape, 5.0).astype(np.float32)
r = madd_session.run(['r'], {'a': a, 'b': b, 'c': c})

Which gives the output

print(r[0])
# outputs [[[17. 17. 17. 17. 17.]
#           [17. 17. 17. 17. 17.]
#           [17. 17. 17. 17. 17.]]
#
#          [[17. 17. 17. 17. 17.]
#           [17. 17. 17. 17. 17.]
#           [17. 17. 17. 17. 17.]]]

print(r[0].shape)
# outputs (2, 3, 5)

Thus the output is of the correct size with the right values.

This model has only one output, but a model can have multiple outputs, and you specify which of them you are interested in (and thus the runtime can prune the graph to avoid computing results that are discarded).
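The pruning idea can be illustrated with a backward pass over the topologically sorted node list: start from the requested outputs and keep every node that produces something still needed. This is only a sketch of the concept under my own assumptions (plain tuples for nodes, and a hypothetical extra `Sub` node giving a second output s), not how onnxruntime actually implements it:

```python
def required_nodes(nodes, wanted_outputs):
    """Return the nodes needed to compute wanted_outputs, in original order.

    nodes: list of (op_type, input_labels, output_labels), topologically sorted.
    Walks the list backwards; a node is kept if it produces a label that is
    still needed, and its inputs then become needed as well.
    """
    needed = set(wanted_outputs)
    kept = []
    for node in reversed(nodes):
        op_type, inputs, outputs = node
        if needed.intersection(outputs):
            kept.append(node)
            needed.update(inputs)
    kept.reverse()
    return kept


# Hypothetical two-output graph: r = a + b * c and s = b - c
nodes = [
    ('Mul', ['b', 'c'], ['t']),
    ('Add', ['a', 't'], ['r']),
    ('Sub', ['b', 'c'], ['s']),
]
print([op for op, _, _ in required_nodes(nodes, ['s'])])  # ['Sub']
print([op for op, _, _ in required_nodes(nodes, ['r'])])  # ['Mul', 'Add']
```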

Building a larger network using smaller onnx-models

Now over to the meat of the matter. I wanted to duct-tape together four basic multiply-add blocks in the following way:

i0  i1  i2  i3  i4  i5  i6  i7  i8      2   3   5   7   11  13  17  19  23
|   |   |   |   |   |   |   |   |       |   |   |   |   |   |   |   |   |
|   |   |   V   V   V   V   V   V       |   |   |   V   V   V   V   V   V
|   |   |   a---b---c   a---b---c       |   |   |   a---b---c   a---b---c
|   |   |   | Mad 1 |   | Mad 2 |       |   |   |   | Mad 1 |   | Mad 2 |
|   |   |   +---r---+   +---r---+       |   |   |   +---r---+   +---r---+
|   |   |       |           |           |   |   |       |           |
|   |   |   +- t0           |           |   |   |   +- 7+11*13=150  |
|   |   |   |               |           |   |   |   |               |
|   |   V   V               |           |   |   V   V               |
|   |   a---b---c <------- t1           |   |   a---b---c <---- 17+19*23=454
|   |   | Mad 3 |                       |   |   | Mad 3 |
|   |   +---r---+                       |   |   +---r---+
|   |       |                           |   |       |
|   |   t2--+                           |   |   +---+ 5+150*454=68105
|   |   |                               |   |   |
V   V   V                               V   V   V 
a---b---c                               a---b---c
| Mad 4 |                               | Mad 4 |
+---r---+                               +---r---+
    |                                       |
    o                                2+3*68105=204317

This doesn't compute anything interesting, but the structure has some internal nodes, some parallelism, and some sequential dependencies. On the left I have used symbols (denoting inputs i0 through i8, temporaries t0 through t2, and the single output o), and on the right I have the calculation of some numbers that I will use to test the network.

The idea is to topologically sort this graph and just concatenate the nodes of each sub-model in that order, building the node order for the full graph. I also have to relabel the inputs, outputs, and temporaries.
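In the code below, I write the sub-model ordering by hand, which is fine for four blocks. For larger compositions it could be derived with a topological sort, for example Kahn's algorithm. A sketch, assuming the wiring is given as a dict from sub-model name to `(in_map, out_map)` (a hypothetical variant of the `model_spec` list used below; `topo_order` is my own name):

```python
from collections import deque


def topo_order(wiring):
    """Order sub-models so each runs after the sub-models that feed it.

    wiring: dict mapping sub-model name -> (in_map, out_map), where the maps
    go from the sub-model's own labels to full-model labels.
    Uses Kahn's algorithm; raises ValueError if the wiring contains a cycle.
    """
    # Which sub-model produces each full-model label
    producer = {label: name
                for name, (_, out_map) in wiring.items()
                for label in out_map.values()}
    # Edges producer -> consumer; labels without a producer are graph inputs
    consumers = {name: set() for name in wiring}
    indegree = {name: 0 for name in wiring}
    for name, (in_map, _) in wiring.items():
        for src in {producer[label] for label in in_map.values() if label in producer}:
            consumers[src].add(name)
            indegree[name] += 1
    ready = deque(sorted(n for n, d in indegree.items() if d == 0))
    order = []
    while ready:
        name = ready.popleft()
        order.append(name)
        for nxt in sorted(consumers[name]):
            indegree[nxt] -= 1
            if indegree[nxt] == 0:
                ready.append(nxt)
    if len(order) != len(wiring):
        raise ValueError('cycle in sub-model wiring')
    return order


# Same wiring as the diagram, deliberately listed out of order
wiring = {
    'mad4': ({'a': 'i0', 'b': 'i1', 'c': 't2'}, {'r': 'o'}),
    'mad3': ({'a': 'i2', 'b': 't0', 'c': 't1'}, {'r': 't2'}),
    'mad1': ({'a': 'i3', 'b': 'i4', 'c': 'i5'}, {'r': 't0'}),
    'mad2': ({'a': 'i6', 'b': 'i7', 'c': 'i8'}, {'r': 't1'}),
}
print(topo_order(wiring))  # ['mad1', 'mad2', 'mad3', 'mad4']
```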

Reading the building-block models

I start by loading four instances of the same model. This way, I can just move nodes from the building-block instances into the large graph, oblivious to any node properties beyond input and output labels.

model_paths = {
    'mad1': 'mad.onnx',
    'mad2': 'mad.onnx',
    'mad3': 'mad.onnx',
    'mad4': 'mad.onnx'
}

models = {}
for key, model_path in model_paths.items():
    models[key] = onnx.load(model_path)

Generating the inputs and outputs

We define the inputs and outputs of the full model by specifying which input/output of which sub-model plays that role:

input_spec = {
    'i0': ('mad4', 'a'),
    'i1': ('mad4', 'b'),
    'i2': ('mad3', 'a'),
    'i3': ('mad1', 'a'),
    'i4': ('mad1', 'b'),
    'i5': ('mad1', 'c'),
    'i6': ('mad2', 'a'),
    'i7': ('mad2', 'b'),
    'i8': ('mad2', 'c')
}

output_spec = {
    'o': ('mad4', 'r')
}

Here, the first line of input_spec says that the full model input i0 is actually input a of sub-model mad4.

First, we need a small utility function that finds the item with a particular name in a repeated field of ValueInfoProtos:

def get_by_name(repfield, name):
    for item in repfield:
        if item.name == name:
            return item
    raise RuntimeError(f'Failed to find item named {name}')

Then, the following snippets run through the specifications, find the right ValueInfoProto (the structure that describes an input/output) in the sub-model, rename it in place, and return a list of the full model's inputs/outputs.

def generate_inputs(input_spec, models):
    inputs = []
    for global_name, (model, model_name) in input_spec.items():
        input = get_by_name(models[model].graph.input, model_name)
        input.name = global_name
        inputs.append(input)
    return inputs

def generate_outputs(output_spec, models):
    outputs = []
    for global_name, (model, model_name) in output_spec.items():
        output = get_by_name(models[model].graph.output, model_name)
        output.name = global_name
        outputs.append(output)
    return outputs

Generating the sequence of nodes

We start by specifying an ordering of the sub-models such that no value is used as input before it has been computed, along with the mapping of each sub-model's inputs and outputs to the full model's inputs, outputs, and temporaries:

model_spec = [
    ('mad1', {'a': 'i3', 'b': 'i4', 'c': 'i5'}, {'r': 't0'}),
    ('mad2', {'a': 'i6', 'b': 'i7', 'c': 'i8'}, {'r': 't1'}),
    ('mad3', {'a': 'i2', 'b': 't0', 'c': 't1'}, {'r': 't2'}),
    ('mad4', {'a': 'i0', 'b': 'i1', 'c': 't2'}, {'r': 'o'})
]

The first line specifies that we start with sub-model mad1, and we relabel its inputs a, b, and c as i3, i4, and i5, and its output r as the temporary t0.

A sub-model can only be referenced once here, since we modify and transplant the nodes of each loaded sub-model directly instead of cloning them (and that is why we loaded the same model four times earlier).
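Note that input_spec and output_spec repeat information that is already in model_spec: a full-model input is a label that some sub-model consumes but none produces, and a full-model output is a label that is produced but never consumed. A sketch of deriving both (the helper `derive_io_specs` is hypothetical, not part of the code in this post):

```python
def derive_io_specs(model_spec):
    """Derive (input_spec, output_spec) from the sub-model wiring.

    Both results use the {global_label: (sub_model, local_label)} shape
    of input_spec/output_spec in this post.
    """
    produced = {g for _, _, out_map in model_spec for g in out_map.values()}
    consumed = {g for _, in_map, _ in model_spec for g in in_map.values()}
    input_spec, output_spec = {}, {}
    for name, in_map, out_map in model_spec:
        for local, g in in_map.items():
            if g not in produced:   # consumed but never produced: graph input
                input_spec[g] = (name, local)
        for local, g in out_map.items():
            if g not in consumed:   # produced but never consumed: graph output
                output_spec[g] = (name, local)
    return input_spec, output_spec


model_spec = [
    ('mad1', {'a': 'i3', 'b': 'i4', 'c': 'i5'}, {'r': 't0'}),
    ('mad2', {'a': 'i6', 'b': 'i7', 'c': 'i8'}, {'r': 't1'}),
    ('mad3', {'a': 'i2', 'b': 't0', 'c': 't1'}, {'r': 't2'}),
    ('mad4', {'a': 'i0', 'b': 'i1', 'c': 't2'}, {'r': 'o'}),
]
input_spec, output_spec = derive_io_specs(model_spec)
print(output_spec)  # {'o': ('mad4', 'r')}
```

If the same global input fed several sub-models, this keeps only the last one, which is enough for generate_inputs since it only needs one ValueInfoProto per input.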

A final issue is that sub-models can have internal temporaries (and indeed our building-block model has exactly that, on purpose), and we have to make sure that these have unique names. We solve this by simply prefixing them with the sub-model name.

The following code takes a sequence of labels and creates a new list of labels: a label present in the map is renamed accordingly, and any other label gets the sub-model name as a prefix:

def relabel(old_labels, rename_map, prefix):
    new_labels = []
    for old_label in old_labels:
        if old_label in rename_map:
            new_labels.append(rename_map[old_label])
        else:
            new_labels.append(prefix + "_" + old_label)
    return new_labels
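To see how relabel keeps an internal edge connected, consider mad3. Its Mul node writes t (relabeled through out_spec) and its Add node reads t (relabeled through in_spec); since t is in neither map, both sides get the same prefixed name. The snippet uses an equivalent one-line version of relabel so it runs on its own:

```python
def relabel(old_labels, rename_map, prefix):
    # Equivalent to the loop version: mapped labels are renamed,
    # all other labels get "<prefix>_" prepended
    return [rename_map.get(label, prefix + "_" + label) for label in old_labels]


# mad3's wiring from model_spec
in_spec = {'a': 'i2', 'b': 't0', 'c': 't1'}
out_spec = {'r': 't2'}

# Mul node: ['b', 'c'] -> ['t']
print(relabel(['b', 'c'], in_spec, 'mad3'), relabel(['t'], out_spec, 'mad3'))
# ['t0', 't1'] ['mad3_t']

# Add node: ['a', 't'] -> ['r']
print(relabel(['a', 't'], in_spec, 'mad3'), relabel(['r'], out_spec, 'mad3'))
# ['i2', 'mad3_t'] ['t2']
```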

Then we are ready to transplant the nodes of the sub-models into the full model, while updating input and output names of each node:

def generate_nodes(model_spec, models):
    nodes = []
    for model_name, in_spec, out_spec in model_spec:
        graph = models[model_name].graph
        for node in graph.node:
            node.input[:] = relabel(node.input, in_spec, model_name)
            node.output[:] = relabel(node.output, out_spec, model_name)
            nodes.append(node)
    return nodes
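Since the whole scheme is just label bookkeeping, it can be sanity-checked without onnx. The sketch below applies the same relabeling loop as generate_nodes to plain `(op_type, inputs, outputs)` tuples and evaluates them in declaration order with Python integers; the tuple representation and the `OPS` table are my own stand-ins for NodeProtos and the onnx operators. A `KeyError` during evaluation would mean a label is read before it is written:

```python
import operator

# Nodes of the building block as (op_type, inputs, outputs), mirroring mad.onnx
madd_nodes = [('Mul', ['b', 'c'], ['t']), ('Add', ['a', 't'], ['r'])]

model_spec = [
    ('mad1', {'a': 'i3', 'b': 'i4', 'c': 'i5'}, {'r': 't0'}),
    ('mad2', {'a': 'i6', 'b': 'i7', 'c': 'i8'}, {'r': 't1'}),
    ('mad3', {'a': 'i2', 'b': 't0', 'c': 't1'}, {'r': 't2'}),
    ('mad4', {'a': 'i0', 'b': 'i1', 'c': 't2'}, {'r': 'o'}),
]


def relabel(old_labels, rename_map, prefix):
    return [rename_map.get(label, prefix + "_" + label) for label in old_labels]


# Same transplant loop as generate_nodes, but producing new tuples
nodes = []
for model_name, in_spec, out_spec in model_spec:
    for op_type, inputs, outputs in madd_nodes:
        nodes.append((op_type,
                      relabel(inputs, in_spec, model_name),
                      relabel(outputs, out_spec, model_name)))

# Evaluate in declaration order; every op here is binary with one output
OPS = {'Mul': operator.mul, 'Add': operator.add}
values = {f'i{k}': v for k, v in enumerate([2, 3, 5, 7, 11, 13, 17, 19, 23])}
for op_type, inputs, outputs in nodes:
    values[outputs[0]] = OPS[op_type](*(values[label] for label in inputs))

print(values['t0'], values['t1'], values['t2'], values['o'])  # 150 454 68105 204317
```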

Building the full model

With the bits and pieces in place, building the model and storing it in a file is pretty straightforward:

comp_graph = onnx.helper.make_graph(generate_nodes(model_spec, models),
                                    "mad-composition",
                                    generate_inputs(input_spec, models),
                                    generate_outputs(output_spec, models))
comp_model = onnx.helper.make_model(comp_graph)
onnx.save(comp_model, 'comp.onnx')

Running our composite model

To run the model, we do basically the same as we did when running the building-block model. The only difference is that the inputs are built using a loop, since there are so many of them:

import onnxruntime
s = onnxruntime.InferenceSession('comp.onnx')

vals = [2, 3, 5, 7, 11, 13, 17, 19, 23]
inputs = {}
for i, val in enumerate(vals):
    inputs['i'+str(i)] = np.full(s.get_inputs()[0].shape, val).astype(np.float32)
r = s.run(['o'], inputs)

which produces the output:

print(r[0])
# outputs: [[[204317. 204317. 204317. 204317. 204317.]
#            [204317. 204317. 204317. 204317. 204317.]
#            [204317. 204317. 204317. 204317. 204317.]]
#
#           [[204317. 204317. 204317. 204317. 204317.]
#            [204317. 204317. 204317. 204317. 204317.]
#            [204317. 204317. 204317. 204317. 204317.]]]

print(r[0].shape)
# (2, 3, 5)

Thus the output is again of the correct size with the right values.

Conclusion

In this little note I have detailed how I duct-taped multiple onnx-models into a composite onnx-model. I didn't find much info about this scenario, so I wrote up my experiments in this post. I hope someone besides me finds this interesting.

import onnxruntime
import onnx
import numpy as np

# Create building-block model
def create_madd_model():
    inputs = [
        onnx.helper.make_tensor_value_info('a', onnx.TensorProto.FLOAT, [2,3,5]),
        onnx.helper.make_tensor_value_info('b', onnx.TensorProto.FLOAT, [2,3,5]),
        onnx.helper.make_tensor_value_info('c', onnx.TensorProto.FLOAT, [2,3,5])
    ]
    outputs = [
        onnx.helper.make_tensor_value_info('r', onnx.TensorProto.FLOAT, [2,3,5])
    ]
    nodes = [
        onnx.helper.make_node('Mul', ['b', 'c'], ['t']),
        onnx.helper.make_node('Add', ['a', 't'], ['r'])
    ]
    return onnx.helper.make_model(onnx.helper.make_graph(nodes, "multiply-add", inputs, outputs))

onnx.save(create_madd_model(), 'mad.onnx')

# Run building-block model
def run_madd_model():
    madd_session = onnxruntime.InferenceSession('mad.onnx')
    a = np.full(madd_session.get_inputs()[0].shape, 2.0).astype(np.float32)
    b = np.full(madd_session.get_inputs()[1].shape, 3.0).astype(np.float32)
    c = np.full(madd_session.get_inputs()[2].shape, 5.0).astype(np.float32)
    r = madd_session.run(['r'], {'a': a, 'b': b, 'c': c})
    print(r[0])
    print(r[0].shape)

run_madd_model()

# Build composite model
model_paths = {
    'mad1': 'mad.onnx',
    'mad2': 'mad.onnx',
    'mad3': 'mad.onnx',
    'mad4': 'mad.onnx'
}
input_spec = {
    'i0': ('mad4', 'a'),
    'i1': ('mad4', 'b'),
    'i2': ('mad3', 'a'),
    'i3': ('mad1', 'a'),
    'i4': ('mad1', 'b'),
    'i5': ('mad1', 'c'),
    'i6': ('mad2', 'a'),
    'i7': ('mad2', 'b'),
    'i8': ('mad2', 'c')
}
output_spec = {
    'o': ('mad4', 'r')
}
model_spec = [
    ('mad1', {'a': 'i3', 'b': 'i4', 'c': 'i5'}, {'r': 't0'}),
    ('mad2', {'a': 'i6', 'b': 'i7', 'c': 'i8'}, {'r': 't1'}),
    ('mad3', {'a': 'i2', 'b': 't0', 'c': 't1'}, {'r': 't2'}),
    ('mad4', {'a': 'i0', 'b': 'i1', 'c': 't2'}, {'r': 'o'})
]

models = {}
for key, model_path in model_paths.items():
    models[key] = onnx.load(model_path)

def get_by_name(repfield, name):
    for item in repfield:
        if item.name == name:
            return item
    raise RuntimeError(f'Failed to find item named {name}')

def generate_inputs(input_spec, models):
    inputs = []
    for global_name, (model, model_name) in input_spec.items():
        input = get_by_name(models[model].graph.input, model_name)
        input.name = global_name
        inputs.append(input)
    return inputs

def generate_outputs(output_spec, models):
    outputs = []
    for global_name, (model, model_name) in output_spec.items():
        output = get_by_name(models[model].graph.output, model_name)
        output.name = global_name
        outputs.append(output)
    return outputs

def relabel(old_labels, rename_map, prefix):
    new_labels = []
    for old_label in old_labels:
        if old_label in rename_map:
            new_labels.append(rename_map[old_label])
        else:
            new_labels.append(prefix + "_" + old_label)
    return new_labels

def generate_nodes(model_spec, models):
    nodes = []
    for model_name, in_spec, out_spec in model_spec:
        graph = models[model_name].graph
        for node in graph.node:
            node.input[:] = relabel(node.input, in_spec, model_name)
            node.output[:] = relabel(node.output, out_spec, model_name)
            nodes.append(node)
    return nodes

comp_graph = onnx.helper.make_graph(generate_nodes(model_spec, models),
                                    "mad-composition",
                                    generate_inputs(input_spec, models),
                                    generate_outputs(output_spec, models))
comp_model = onnx.helper.make_model(comp_graph)
onnx.save(comp_model, 'comp.onnx')

# Run composite model
def run_comp_model():
    s = onnxruntime.InferenceSession('comp.onnx')
    vals = [2, 3, 5, 7, 11, 13, 17, 19, 23]
    inputs = {}
    for i, val in enumerate(vals):
        inputs['i'+str(i)] = np.full(s.get_inputs()[0].shape, val).astype(np.float32)
    r = s.run(['o'], inputs)
    print(r[0])
    print(r[0].shape)

run_comp_model()