@emuccino
Created July 1, 2020 04:31
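from tensorflow.keras import Input, Model
from tensorflow.keras.losses import SparseCategoricalCrossentropy

# x_train, device_names and device_models are assumed to be defined elsewhere.
# Each device model returns a dict with an 'offload' tensor (fed to the next
# device) and a '<device>_outputs' tensor holding that device's class logits.
outputs = {}
# complete model input
inputs = Input(shape=x_train.shape[1:])
net = {'offload': inputs}
# stack all 3 device models: each one consumes the previous device's offload tensor
for device in device_names:
    net = device_models[device](net['offload'])
    outputs[device] = net[device + '_outputs']
# compile the complete model; outputs are raw logits, hence from_logits=True
model = Model(inputs=inputs, outputs=outputs)
model.compile(loss=SparseCategoricalCrossentropy(from_logits=True),
              optimizer='nadam',
              metrics=['accuracy'])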
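The stacking loop only works if every device model follows a fixed contract: it returns a dict whose 'offload' entry feeds the next device and whose '<device>_outputs' entry holds that device's own logits. The framework-free sketch below mirrors the loop in plain Python (no Keras) to make that data flow concrete. The device names 'end', 'edge' and 'cloud', the helper make_device_model, and its toy scaling arithmetic are hypothetical stand-ins, not the author's actual models.

```python
from typing import Callable, Dict, List

# ASSUMPTION: each device model maps its input to a dict with two keys:
#   'offload'          -> features passed on to the next device
#   '<device>_outputs' -> that device's own (stand-in) logits
Features = List[float]
DeviceModel = Callable[[Features], Dict[str, Features]]


def make_device_model(name: str, scale: float) -> DeviceModel:
    """Hypothetical toy device: scales its input and emits fake logits."""
    def device(x: Features) -> Dict[str, Features]:
        offload = [scale * v for v in x]          # features for the next device
        logits = [sum(offload), -sum(offload)]    # stand-in for a classifier head
        return {'offload': offload, name + '_outputs': logits}
    return device


def stack_devices(x: Features, device_names: List[str],
                  device_models: Dict[str, DeviceModel]) -> Dict[str, Features]:
    """Mirror of the Keras loop: chain 'offload' and collect every exit."""
    outputs = {}
    net = {'offload': x}
    for device in device_names:
        net = device_models[device](net['offload'])
        outputs[device] = net[device + '_outputs']
    return outputs


device_names = ['end', 'edge', 'cloud']  # hypothetical device names
device_models = {n: make_device_model(n, s)
                 for n, s in zip(device_names, [2.0, 0.5, 3.0])}
result = stack_devices([1.0, 2.0], device_names, device_models)
print(result)  # {'end': [6.0, -6.0], 'edge': [3.0, -3.0], 'cloud': [9.0, -9.0]}
```

Each device sees only the previous device's offload features, yet every device contributes its own entry to the outputs dict; in the Keras version those entries become the named outputs of the compiled model.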