Skip to content

Instantly share code, notes, and snippets.

@arthurdouillard
Last active April 9, 2019 16:30
Show Gist options
  • Save arthurdouillard/5a6aaf010a034043a4ad0402a37ca0d1 to your computer and use it in GitHub Desktop.
Save arthurdouillard/5a6aaf010a034043a4ad0402a37ca0d1 to your computer and use it in GitHub Desktop.
def load_retinanet(weights, n_classes, freeze=True):
    """Build a ResNet50-backed RetinaNet and load weights into it.

    Args:
        weights: Forwarded unchanged to ``model.load_weights``.
        n_classes: Number of classes; forwarded to ``resnet50_retinanet``
            as its ``num_classes`` keyword.
        freeze: When truthy, ``freeze_model`` is handed to
            ``resnet50_retinanet`` as the ``modifier``; otherwise ``None``
            is passed. (Presumably this freezes the backbone layers --
            confirm against ``freeze_model``'s definition.)

    Returns:
        The model returned by ``resnet50_retinanet`` after
        ``load_weights`` has been called on it.
    """
    modifier = freeze_model if freeze else None
    # The keyword expected by resnet50_retinanet is ``num_classes``, but the
    # value lives in this function's ``n_classes`` parameter. The previous
    # code referenced an undefined ``num_classes`` name here.
    model = resnet50_retinanet(num_classes=n_classes, modifier=modifier)
    # by_name / skip_mismatch are passed so that weights whose layers do not
    # line up with this model (e.g. heads sized for a different class count)
    # do not abort the load -- see the Keras load_weights documentation.
    model.load_weights(weights, by_name=True, skip_mismatch=True)
    return model
def compile(model):
    """Compile ``model`` with the RetinaNet losses and an Adam optimizer.

    The 'regression' output is paired with ``keras_retinanet.losses.smooth_l1()``
    and the 'classification' output with ``keras_retinanet.losses.focal()``.
    The learning rate is read from ``configs['lr']`` (``configs`` must be
    defined outside this function); gradient norm clipping is fixed at 0.001.

    Nothing is returned; ``model.compile`` is called for its effect on ``model``.

    NOTE(review): this function's name shadows the builtin ``compile`` in the
    enclosing module; it is kept because callers depend on it.
    """
    # Bind the method first so lookups happen in the same order as a direct call.
    compile_model = model.compile
    output_losses = {
        'regression': keras_retinanet.losses.smooth_l1(),
        'classification': keras_retinanet.losses.focal(),
    }
    adam = optimizers.adam(lr=configs['lr'], clipnorm=0.001)
    compile_model(loss=output_losses, optimizer=adam)
def train(model, train_gen, val_gen, callbacks, n_epochs=20):
    """Run ``model.fit_generator`` over the training and validation generators.

    Args:
        model: Object exposing ``fit_generator``.
        train_gen: Training generator (an instance of DfGenerator); its
            ``len()`` is used as the number of steps per epoch.
        val_gen: Validation generator (an instance of DfGenerator); its
            ``len()`` is used as the number of validation steps.
        callbacks: Forwarded unchanged as ``callbacks``.
        n_epochs: Forwarded as ``epochs``. Defaults to 20.

    Nothing is returned; the value produced by ``fit_generator`` is discarded.
    """
    # Bind the method first so lookups happen in the same order as a direct call.
    fit = model.fit_generator
    fit_options = dict(
        steps_per_epoch=len(train_gen),
        validation_data=val_gen,
        validation_steps=len(val_gen),
        callbacks=callbacks,
        epochs=n_epochs,
        verbose=2,
    )
    fit(train_gen, **fit_options)
@ChubaOraka
Copy link

From line 1, it looks to me that line 4 should be:

model = resnet50_retinanet(num_classes=n_classes, modifier=modifier)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment