Skip to content

Instantly share code, notes, and snippets.

@sthalles
Created November 26, 2019 16:20
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save sthalles/d4e0c4691dc2be2497ba1cdbfe3bc2eb to your computer and use it in GitHub Desktop.
Save sthalles/d4e0c4691dc2be2497ba1cdbfe3bc2eb to your computer and use it in GitHub Desktop.
import os
import tempfile
import warnings
def add_regularization(model, regularizer=tf.keras.regularizers.l2(0.0001),
                       *, attributes=('kernel_regularizer',)):
    """Attach ``regularizer`` to the layers of ``model`` and rebuild it.

    Assigning a regularizer attribute on an existing layer only changes the
    layer's serialized config, not the live model. This function therefore
    serializes the architecture to JSON, which picks up the new attribute
    values. It then rebuilds the model from that JSON and reloads the
    original weights into it.

    NOTE(review): this snippet never imports TensorFlow. ``tf`` must already
    be bound (e.g. ``import tensorflow as tf``) before this function is
    defined, because the default argument is evaluated at definition time.

    Args:
        model: The Keras model whose layers should be regularized.
        regularizer: A ``tf.keras.regularizers.Regularizer`` instance.
            Defaults to L2 regularization with factor 0.0001.
        attributes: Names of layer attributes to overwrite with
            ``regularizer``. Layers that lack an attribute are skipped.
            Defaults to ``('kernel_regularizer',)``.

    Returns:
        A new model rebuilt from the modified config, with the original
        weights loaded (matched by layer name). If ``regularizer`` is not a
        ``Regularizer`` instance, a warning is issued and the original
        ``model`` is returned unchanged. The rebuilt model comes from the
        architecture JSON only, so the original's compile state (optimizer,
        loss, metrics) is not carried over. Compile it again before training.
    """
    if not isinstance(regularizer, tf.keras.regularizers.Regularizer):
        # Best-effort: leave the model untouched rather than raising, but
        # report through the warnings machinery instead of stdout.
        warnings.warn(
            "Regularizer must be a subclass of tf.keras.regularizers.Regularizer",
            stacklevel=2,
        )
        return model

    for layer in model.layers:
        for attr in attributes:
            if hasattr(layer, attr):
                setattr(layer, attr, regularizer)

    # When we change the layers' attributes, the change only shows up in the
    # model config, so round-trip the architecture through JSON.
    model_json = model.to_json()

    # Stash the weights in a private, per-call temporary directory. A fixed
    # filename in the shared temp dir can be clobbered by concurrent calls or
    # other processes, and the old code never deleted it. The directory and
    # the weights file are removed when the ``with`` block exits.
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_weights_path = os.path.join(tmp_dir, 'tmp_weights.h5')
        model.save_weights(tmp_weights_path)
        # Rebuild the model from the modified config...
        model = tf.keras.models.model_from_json(model_json)
        # ...and restore the saved weights into the new layers.
        model.load_weights(tmp_weights_path, by_name=True)
    return model
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment