Skip to content

Instantly share code, notes, and snippets.

@juanting
Last active December 16, 2019 03:33
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save juanting/3fcb814a9d1db89352a3d00100c7f680 to your computer and use it in GitHub Desktop.
Save juanting/3fcb814a9d1db89352a3d00100c7f680 to your computer and use it in GitHub Desktop.
def model_fn(features, labels, mode, params):
    '''Build the tf.estimator.EstimatorSpec for the selected network.

    Args:
        features: input passed unchanged to the selected network function.
        labels: targets passed to the loss and metric functions; not read
            in PREDICT mode.
        mode: a tf.estimator.ModeKeys value.
        params: dict of settings. Keys read here:
            'network': one of 'onset', 'pitch', 'frame' or 'union'.
            'velocity_weights': only read when 'network' is 'pitch'.
            'weight_path': comma-separated checkpoint path(s); a falsy
                value disables restoring.
            'initial_learning_rate': only read in TRAIN mode.

    Returns:
        A tf.estimator.EstimatorSpec. In PREDICT mode it carries only the
        predictions ({'probs': sigmoid(logits)}) and the restore hook;
        otherwise it also carries the loss, the train op (None outside
        TRAIN mode) and the evaluation metrics.

    Raises:
        ValueError: if params['network'] is not a supported network type.
    '''
    network_type = params['network']
    if network_type == 'onset':
        net_func = onset_network
        loss_func = ModelLoss.onset_loss
        metric_func = ConfigMetrics.onset_metrics
    elif network_type == 'pitch':
        net_func = pitch_network
        loss_func = functools.partial(
            ModelLoss.pitch_loss_weighted_velocity,
            velocity_weights=params['velocity_weights'])
        metric_func = ConfigMetrics.pitch_metrics
    elif network_type == 'frame':
        net_func = frame_network
        loss_func = ModelLoss.frame_loss
        metric_func = ConfigMetrics.frame_metrics
    elif network_type == 'union':
        net_func = union_network
        loss_func = ModelLoss.model_loss
        metric_func = ConfigMetrics.model_metrics
    else:
        # Fail early with a clear message; without this branch the first
        # use of net_func below dies with an opaque UnboundLocalError.
        raise ValueError(
            "Unknown network type {!r}; expected one of "
            "'onset', 'pitch', 'frame', 'union'".format(network_type))

    is_training = mode == tf.estimator.ModeKeys.TRAIN
    # NOTE(review): the meaning of the third positional argument (True) is
    # defined by the network functions, which are not visible here.
    logits = net_func(features, is_training, True)

    # the prediction part
    probs = tf.nn.sigmoid(logits)
    predictions = {
        'probs': probs,
    }

    # Restore parameters from the given checkpoint path(s). The hook is
    # created after the network is built because its constructor collects
    # the variables to restore; it only restores when exactly one path is
    # given.
    restore_hook = RestoreHook(params['weight_path'], ['global_step'])
    if params['weight_path']:
        # With two or more checkpoints, each one initializes its own
        # variable scope: first 'onset/', second 'pitch/', and an optional
        # third 'frame/'.
        model_paths = [x for x in params['weight_path'].split(',') if x]
        if len(model_paths) >= 2:
            tf.train.init_from_checkpoint(model_paths[0],
                                          {'onset/': 'onset/'})
            tf.train.init_from_checkpoint(model_paths[1],
                                          {'pitch/': 'pitch/'})
        if len(model_paths) == 3:
            tf.train.init_from_checkpoint(model_paths[2],
                                          {'frame/': 'frame/'})

    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions,
                                          prediction_hooks=[restore_hook])

    # the train part
    loss = loss_func(labels, logits) + ModelLoss.regu_loss()
    if mode == tf.estimator.ModeKeys.TRAIN:
        global_step = tf.train.get_or_create_global_step()
        lr = params['initial_learning_rate']
        # Adam with staircase exponential decay (x0.98 every 20000 steps)
        # and clip_gradients=3.0.
        train_op = tf.contrib.layers.optimize_loss(
            name='training',
            loss=loss,
            global_step=global_step,
            learning_rate=lr,
            learning_rate_decay_fn=functools.partial(
                tf.train.exponential_decay,
                decay_steps=20000,
                decay_rate=0.98,
                staircase=True),
            clip_gradients=3.0,
            optimizer='Adam')
    else:
        train_op = None

    # the evaluation part
    metrics = metric_func(labels, predictions['probs'])
    return tf.estimator.EstimatorSpec(
        mode=mode, predictions=predictions, loss=loss, train_op=train_op,
        eval_metric_ops=metrics,
        training_hooks=[restore_hook],
        evaluation_hooks=[restore_hook])
class RestoreHook(tf.train.SessionRunHook):
    '''Session hook that loads variables from a user-specified checkpoint.

    An estimator does not directly offer a way to restore parameters from
    an arbitrary saved model, so this hook does it as a small trick right
    after the session is created.

    Usage:
        in the tf.estimator.EstimatorSpec returned by the model function,
        add training_hooks = [RestoreHook(weight_path, ['global_step'])]
    '''

    def __init__(self, weight_path, exclude):
        '''Prepare the checkpoint path(s) and the variables to restore.

        A restore function is only built when weight_path names exactly
        one model; with no path or with several paths the hook does
        nothing.

        Args:
            weight_path: str, model file path. It can contain one or more
                models, separated by commas.
            exclude: list, variables that won't be restored.
        '''
        self.init_fn = None
        if not weight_path:
            return

        # Drop empty entries produced by stray or trailing commas.
        self.model_paths = [path for path in weight_path.split(',') if path]
        self.var_list = tf.contrib.framework.get_variables_to_restore(
            exclude=exclude)
        if len(self.model_paths) == 1:
            self.init_fn = tf.contrib.framework.assign_from_checkpoint_fn(
                self.model_paths[0], self.var_list, ignore_missing_vars=True)

    def after_create_session(self, session, coord=None):
        '''Run the restore function if the global step is still 0.'''
        current_step = session.run(tf.train.get_or_create_global_step())
        if current_step != 0:
            return
        if self.init_fn:
            self.init_fn(session)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment