Last active
December 16, 2019 03:33
-
-
Save juanting/3fcb814a9d1db89352a3d00100c7f680 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def _select_network(network_type, params):
    '''Return the (network, loss, metrics) callables for `network_type`.

    Args:
        network_type: str, one of 'onset', 'pitch', 'frame' or 'union'.
        params: dict of estimator params; 'velocity_weights' is read only
            when network_type == 'pitch'.

    Returns:
        A tuple (net_func, loss_func, metric_func).

    Raises:
        ValueError: if `network_type` is not a known network type.
    '''
    if network_type == 'onset':
        return onset_network, ModelLoss.onset_loss, ConfigMetrics.onset_metrics
    if network_type == 'pitch':
        # The pitch loss is the only one parameterised by velocity weights.
        pitch_loss = functools.partial(
            ModelLoss.pitch_loss_weighted_velocity,
            velocity_weights=params['velocity_weights'])
        return pitch_network, pitch_loss, ConfigMetrics.pitch_metrics
    if network_type == 'frame':
        return frame_network, ModelLoss.frame_loss, ConfigMetrics.frame_metrics
    if network_type == 'union':
        return union_network, ModelLoss.model_loss, ConfigMetrics.model_metrics
    raise ValueError(
        "Unknown params['network'] %r; expected one of "
        "'onset', 'pitch', 'frame', 'union'." % (network_type,))


def _warm_start_from_checkpoints(model_paths):
    '''Warm-start sub-network scopes from separate checkpoints.

    With two or more paths, path i initialises scope i of
    ('onset/', 'pitch/', 'frame/'); paths beyond the third are ignored.
    With fewer than two paths nothing is done here (a single path is
    handled by RestoreHook instead).
    '''
    if len(model_paths) < 2:
        return
    for path, scope in zip(model_paths, ('onset/', 'pitch/', 'frame/')):
        tf.train.init_from_checkpoint(path, {scope: scope})


def model_fn(features, labels, mode, params):
    '''The model function of the estimator.

    Args:
        features: input batch, passed straight to the selected network.
        labels: targets used for the loss and the eval metrics
            (not used in PREDICT mode).
        mode: a tf.estimator.ModeKeys value.
        params: dict. Keys read here:
            'network': one of 'onset', 'pitch', 'frame', 'union'.
            'weight_path': optional comma-separated checkpoint path(s).
            'initial_learning_rate': used in TRAIN mode only.
            'velocity_weights': used by the 'pitch' network only.

    Returns:
        A tf.estimator.EstimatorSpec for `mode`.

    Raises:
        ValueError: if params['network'] is not a known network type.
    '''
    net_func, loss_func, metric_func = _select_network(params['network'], params)

    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    # NOTE(review): the meaning of the third positional argument (True) is
    # defined by the network functions elsewhere in the project - confirm.
    logits = net_func(features, is_training, True)

    # The prediction part.
    predictions = {
        'probs': tf.nn.sigmoid(logits),
    }

    # Restore parameters from user-specified checkpoint(s):
    #   * one path    -> RestoreHook restores it when the run starts at step 0;
    #   * 2 or 3 paths -> 'onset/', 'pitch/' (and 'frame/') scopes are
    #     warm-started from the respective checkpoints.
    weight_path = params.get('weight_path')
    restore_hook = RestoreHook(weight_path, ['global_step'])
    if weight_path:
        model_paths = [path for path in weight_path.split(',') if path]
        _warm_start_from_checkpoints(model_paths)

    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(
            mode=mode, predictions=predictions,
            prediction_hooks=[restore_hook])

    # The train part: task loss plus the regularisation term.
    loss = loss_func(labels, logits) + ModelLoss.regu_loss()

    train_op = None
    if is_training:
        train_op = tf.contrib.layers.optimize_loss(
            name='training',
            loss=loss,
            global_step=tf.train.get_or_create_global_step(),
            learning_rate=params['initial_learning_rate'],
            # Staircase decay: lr * 0.98 ** (global_step // 20000).
            learning_rate_decay_fn=functools.partial(
                tf.train.exponential_decay,
                decay_steps=20000, decay_rate=0.98, staircase=True),
            clip_gradients=3.0,
            optimizer='Adam')

    # The evaluation part.
    metrics = metric_func(labels, predictions['probs'])
    return tf.estimator.EstimatorSpec(
        mode=mode, predictions=predictions, loss=loss, train_op=train_op,
        eval_metric_ops=metrics,
        training_hooks=[restore_hook],
        evaluation_hooks=[restore_hook])
class RestoreHook(tf.train.SessionRunHook):
    '''Restore the model's parameters from a user-specified checkpoint.

    With an Estimator, restoring parameters from a checkpoint chosen by
    ourselves (rather than the one the Estimator picks) needs this little
    trick. Parameters are restored only when exactly one checkpoint path is
    given, and only if the session starts at global step 0, so a resumed
    run is not overwritten.

    Usage:
        In the model function's returned tf.estimator.EstimatorSpec, add
        training_hooks=[RestoreHook(weight_path, ['global_step'])].
    '''

    def __init__(self, weight_path, exclude):
        '''Prepare the checkpoint path(s) and the variables to restore.

        Args:
            weight_path: str or None. Comma-separated checkpoint path(s);
                empty entries are ignored. This hook restores only when
                exactly one path is given.
            exclude: list of str, variables that won't be restored.
        '''
        self.model_paths = []
        self.init_fn = None
        self._global_step = None
        if not weight_path:
            return
        self.model_paths = [path for path in weight_path.split(',') if path]
        self.var_list = tf.contrib.framework.get_variables_to_restore(
            exclude=exclude)
        if len(self.model_paths) == 1:
            # ignore_missing_vars=True: variables missing from the checkpoint
            # are skipped instead of failing the restore.
            self.init_fn = tf.contrib.framework.assign_from_checkpoint_fn(
                self.model_paths[0], self.var_list, ignore_missing_vars=True)

    def begin(self):
        # Fetch (or create) the global step while the graph can still be
        # modified; after_create_session runs on a finalized graph.
        self._global_step = tf.train.get_or_create_global_step()

    def after_create_session(self, session, coord=None):
        if self.init_fn is None:
            return
        global_step = self._global_step
        if global_step is None:
            # begin() was not called; fall back to a lookup.
            global_step = tf.train.get_or_create_global_step()
        # Only restore into a fresh run so resumed training keeps its values.
        if session.run(global_step) == 0:
            self.init_fn(session)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment