@DoctorLoop
Last active July 6, 2023 12:13
Definition of a Bayesian Convolutional Architecture for regression problems
import tensorflow as tf
import tensorflow_probability as tfp

dist = tfp.distributions

tf.keras.backend.clear_session()

# Scale each layer's KL term by the training-set size (presumably 836 examples),
# so the per-batch KL penalties summed over an epoch match the ELBO.
kl_divergence_function = lambda q, p, _: dist.kl_divergence(q, p) / tf.cast(836, dtype=tf.float32)

model = tf.keras.Sequential([
    tf.keras.Input(shape=(126, 126, 1), name="basket"),
    tfp.layers.Convolution2DFlipout(16, kernel_size=5, strides=(1, 1), data_format="channels_last",
                                    padding="same", activation=tf.nn.relu, name="conv_tfp_1a",
                                    kernel_divergence_fn=kl_divergence_function),
    tf.keras.layers.MaxPool2D(strides=(4, 4), pool_size=(4, 4), padding="same"),
    tfp.layers.Convolution2DFlipout(32, kernel_size=3, strides=(1, 1), data_format="channels_last",
                                    padding="same", activation=tf.nn.relu, name="conv_tfp_1b",
                                    kernel_divergence_fn=kl_divergence_function),
    tf.keras.layers.MaxPool2D(strides=(4, 4), pool_size=(4, 4), padding="same"),
    tf.keras.layers.Flatten(),
    # Single-unit Flipout head: a scalar regression output with weight uncertainty.
    tfp.layers.DenseFlipout(1, kernel_divergence_fn=kl_divergence_function),
])

learning_rate = 1.0e-3
model.compile(loss='mse',
              optimizer=tf.keras.optimizers.Adam(learning_rate),
              metrics=['mse'])
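
Because the Flipout layers resample their kernel perturbations on every forward pass, a single prediction is just one draw from the posterior predictive. A minimal usage sketch follows (my addition, not part of the gist: x_train, y_train, x_test, the epoch/batch settings, and the 50 Monte Carlo samples are all hypothetical placeholders); it averages repeated passes to get a mean prediction and an uncertainty estimate:

# Minimal usage sketch -- x_train, y_train, and x_test are hypothetical
# placeholders standing in for real data of the expected shapes.
import numpy as np

x_train = np.random.rand(836, 126, 126, 1).astype("float32")
y_train = np.random.rand(836, 1).astype("float32")

model.fit(x_train, y_train, epochs=10, batch_size=32)

# Each forward pass resamples the weights, so repeated passes yield a
# distribution over predictions rather than a single point estimate.
x_test = np.random.rand(8, 126, 126, 1).astype("float32")
mc_preds = np.stack([model(x_test, training=False).numpy() for _ in range(50)])
pred_mean = mc_preds.mean(axis=0)  # point prediction
pred_std = mc_preds.std(axis=0)    # epistemic uncertainty estimate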
sreerajfab commented Jul 6, 2023

Hi,
Thanks for the nice explanation.
Please let me know how to save and load the given BNN model (weights/parameters) for prediction.
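
One possible approach (my suggestion, not from the gist author): since the Flipout layers carry distribution objects that do not always survive full-model serialization, a robust pattern is to persist only the weights and rebuild the architecture in code before restoring them. The checkpoint path below is a placeholder:

# Hedged sketch: persist only the variables (posterior parameters), then
# rebuild the architecture before restoring. "bnn_checkpoint" is a
# placeholder path.
model.save_weights("bnn_checkpoint")

# Later, or in a fresh session: re-run the model definition above to
# recreate the Flipout layers, then load the trained parameters back in.
model.load_weights("bnn_checkpoint")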
