Last active
April 9, 2019 16:30
-
-
Save arthurdouillard/5a6aaf010a034043a4ad0402a37ca0d1 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def load_retinanet(weights, n_classes, freeze=True):
    """Build a ResNet-50 RetinaNet and load pretrained weights into it.

    Args:
        weights: Path to the weights file, passed to ``model.load_weights``.
        n_classes: Number of object classes, forwarded to
            ``resnet50_retinanet`` as ``num_classes``.
        freeze: If True, ``freeze_model`` is passed as the ``modifier``
            (presumably freezing the backbone layers -- confirm against
            keras-retinanet); if False, no modifier is applied.

    Returns:
        The model built by ``resnet50_retinanet`` with weights loaded by
        layer name.
    """
    modifier = freeze_model if freeze else None
    # Bug fix: the original passed an undefined ``num_classes`` here; the
    # function's parameter is ``n_classes``.
    model = resnet50_retinanet(num_classes=n_classes, modifier=modifier)
    # by_name=True + skip_mismatch=True: layers are matched by name, and any
    # layer whose weight shapes differ (e.g. a head sized for a different
    # class count) is skipped instead of raising.
    model.load_weights(weights, by_name=True, skip_mismatch=True)
    return model
def compile(model, lr=None):
    """Compile a RetinaNet model with its regression/classification losses.

    Note: this module-level name shadows the builtin ``compile``; it is kept
    for backward compatibility with existing callers.

    Args:
        model: Keras model exposing ``regression`` and ``classification``
            outputs (the keys of the loss dict below).
        lr: Learning rate for Adam. Defaults to ``None``, in which case the
            module-level ``configs['lr']`` is used (the original behavior).
    """
    if lr is None:
        lr = configs['lr']
    model.compile(
        loss={
            'regression': keras_retinanet.losses.smooth_l1(),
            'classification': keras_retinanet.losses.focal(),
        },
        # ``Adam`` is the canonical Keras class; the lowercase ``adam`` alias
        # used previously was removed in later Keras releases. Gradients are
        # clipped when their L2 norm exceeds 0.001.
        optimizer=optimizers.Adam(lr=lr, clipnorm=0.001),
    )
def train(model, train_gen, val_gen, callbacks, n_epochs=20):
    """Fit ``model`` on generator data, validating after every epoch.

    Args:
        model: A compiled Keras model.
        train_gen: Training generator (an instance of DfGenerator); its
            ``len()`` is used as the number of steps per epoch.
        val_gen: Validation generator (an instance of DfGenerator); its
            ``len()`` is used as the number of validation steps.
        callbacks: Keras callbacks, forwarded to ``fit_generator``.
        n_epochs: Number of epochs to train for.

    Returns:
        Whatever ``model.fit_generator`` returns (in Keras, a ``History``
        object). The original discarded this value.
    """
    return model.fit_generator(
        train_gen,
        steps_per_epoch=len(train_gen),
        validation_data=val_gen,
        validation_steps=len(val_gen),
        callbacks=callbacks,
        epochs=n_epochs,
        # verbose=2: one log line per epoch rather than a per-batch bar.
        verbose=2,
    )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Judging from line 1, it looks to me like line 4 should be: