Skip to content

Instantly share code, notes, and snippets.

@emuccino
Created July 1, 2020 04:22
Show Gist options
  • Save emuccino/a37f7362f49367680e797e957a14de04 to your computer and use it in GitHub Desktop.
Save emuccino/a37f7362f49367680e797e957a14de04 to your computer and use it in GitHub Desktop.
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, Dense, GlobalAvgPool2D, GlobalMaxPool2D, Concatenate, BatchNormalization
from tensorflow.keras.losses import SparseCategoricalCrossentropy
def compile_device_model(input_shape=None, n_filters=None, name=None, offload=False,
                         num_classes=None):
    """Build and compile the CNN for one device of the end/edge/cloud cascade.

    The network is two 3x3 Conv2D + BatchNormalization stages, an optional
    offload stage, global max- and average-pooling (concatenated), a hidden
    Dense + BatchNormalization layer and a linear Dense head emitting logits.

    Args:
        input_shape: Shape of one input sample, excluding the batch axis
            (presumably (height, width, channels) -- TODO confirm).
        n_filters: Filter count of the 3x3 Conv2D layers and unit count of
            the hidden Dense layer.
        name: Model name; the prediction output and its layer are named
            ``name + '_outputs'``.
        offload: If truthy, insert a single-filter 5x5 Conv2D followed by a
            BatchNormalization layer named ``'offload'``; its activations are
            exposed as the extra ``'offload'`` model output, and the network
            continues from them.
        num_classes: Number of output classes. When omitted, the module-level
            ``n_classes`` is used, matching the original behaviour.

    Returns:
        A compiled ``Model`` whose outputs are a dict holding the optional
        ``'offload'`` tensor and the ``name + '_outputs'`` logits. It is
        compiled with sparse categorical cross-entropy on logits, the
        ``'nadam'`` optimizer and the ``'accuracy'`` metric.

    Raises:
        TypeError: If ``name`` is not provided.
    """
    if name is None:
        raise TypeError("compile_device_model() requires a model 'name'")
    if num_classes is None:
        # Fall back to the module-level constant the original code relied on.
        num_classes = n_classes

    outputs = {}
    inputs = Input(shape=input_shape)
    net = inputs
    # Two identical conv + batch-norm stages.
    for _ in range(2):
        net = Conv2D(n_filters, 3, activation='relu', kernel_initializer='he_uniform')(net)
        net = BatchNormalization()(net)

    # Set up the output layer whose activations are offloaded to the next device.
    if offload:
        # A single filter yields a one-channel feature map.
        net = Conv2D(1, 5, activation='relu', kernel_initializer='he_uniform')(net)
        offload_tensor = BatchNormalization(name='offload')(net)
        outputs['offload'] = offload_tensor
        net = Conv2D(n_filters, 3, activation='relu', kernel_initializer='he_uniform')(offload_tensor)

    # Pool spatially both ways and concatenate (shared by both branches).
    net = Concatenate()([GlobalMaxPool2D()(net), GlobalAvgPool2D()(net)])
    net = Dense(n_filters, activation='relu', kernel_initializer='he_uniform')(net)
    net = BatchNormalization()(net)

    # Output prediction layer: linear activation emits logits, matching
    # from_logits=True in the loss below.
    outputs[name + '_outputs'] = Dense(num_classes, activation='linear',
                                       kernel_initializer='glorot_uniform',
                                       name=name + '_outputs')(net)
    model = Model(inputs=inputs, outputs=outputs, name=name)
    model.compile(loss=SparseCategoricalCrossentropy(from_logits=True),
                  optimizer='nadam',
                  metrics=['accuracy'])
    return model
# Number of output classes predicted by every device model.
n_classes = 10
# Devices in the cascade, in the order their models are built.
device_names = ['end', 'edge', 'cloud']
# Number of convolutional filters in each device network.
device_n_filters = [4, 8, 16]
# Whether each device model offloads data to the next device or not.
device_offload = [True, True, False]
device_models = {}
# Input shape of the end-device model. x_train is defined elsewhere
# (presumably an array of shape (n_samples, H, W, C) -- TODO confirm).
input_shape = x_train.shape[1:]
# Compile the end, edge, and cloud device models in order; each offloading
# model determines the input shape of the model that follows it.
for device, n_filters, offload in zip(device_names, device_n_filters, device_offload):
    device_models[device] = compile_device_model(
        input_shape=input_shape, n_filters=n_filters, name=device, offload=offload,
    )
    if offload:
        # The next device consumes the 'offload' layer's activations. Use the
        # public layer/shape API instead of the private Tensor._shape_tuple(),
        # which newer TF/Keras tensors do not provide.
        offload_layer = device_models[device].get_layer('offload')
        input_shape = tuple(offload_layer.output.shape[1:])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment