Last active
July 6, 2023 12:13
Definition of a Bayesian Convolutional Architecture for regression problems
import tensorflow as tf
import tensorflow_probability as tfp

dist = tfp.distributions

tf.keras.backend.clear_session()

# KL divergence scaled by the number of training examples (836 here),
# so the total loss approximates the negative ELBO per example.
kl_divergence_function = lambda q, p, _: dist.kl_divergence(q, p) / tf.cast(836, dtype=tf.float32)

model = tf.keras.Sequential([
    tf.keras.Input(shape=(126, 126, 1), name="basket"),
    tfp.layers.Convolution2DFlipout(16, kernel_size=5, strides=(1, 1), data_format="channels_last",
                                    padding="same", activation=tf.nn.relu, name="conv_tfp_1a",
                                    kernel_divergence_fn=kl_divergence_function),
    tf.keras.layers.MaxPool2D(strides=(4, 4), pool_size=(4, 4), padding="same"),
    tfp.layers.Convolution2DFlipout(32, kernel_size=3, strides=(1, 1), data_format="channels_last",
                                    padding="same", activation=tf.nn.relu, name="conv_tfp_1b",
                                    kernel_divergence_fn=kl_divergence_function),
    tf.keras.layers.MaxPool2D(strides=(4, 4), pool_size=(4, 4), padding="same"),
    tf.keras.layers.Flatten(),
    # Single linear output unit for regression.
    tfp.layers.DenseFlipout(1, kernel_divergence_fn=kl_divergence_function),
])

learning_rate = 1.0e-3
model.compile(loss='mse',
              optimizer=tf.keras.optimizers.Adam(learning_rate),
              metrics=['mse'])
Hi,
Thanks for the nice explanation.
Please let me know how to save and load the given BNN models (weights/parameters) for prediction.