@jamartinh
Last active December 5, 2023 20:22
Parallel Monte Carlo Dropout
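import torch
import torch.nn as nn


def mc_dropout_inference(model, inputs, num_samples):
    # Remember the current mode so it can be restored afterwards
    was_training = model.training
    # Enable dropout during inference, but keep every other layer (e.g.
    # BatchNorm) in eval mode; model.train() alone would also switch BatchNorm
    # to batch statistics and update its running averages
    model.eval()
    for module in model.modules():
        if isinstance(module, nn.modules.dropout._DropoutNd):
            module.train()
    # Repeat each input num_samples times; the copies of inputs[i] occupy
    # rows i*num_samples to (i+1)*num_samples - 1
    repeated_input = inputs.repeat_interleave(num_samples, dim=0)
    # Single forward pass over all copies; gradients are not needed
    with torch.no_grad():
        repeated_output = model(repeated_input)
    model.train(was_training)
    # Reshape the output to separate the samples for each original input
    output_shape = (-1, num_samples) + repeated_output.shape[1:]
    output_samples = repeated_output.reshape(output_shape)
    # The mean over samples is the prediction; the std (needs num_samples >= 2)
    # is the per-input uncertainty estimate
    averaged_output = output_samples.mean(dim=1)
    std_output = output_samples.std(dim=1)
    return averaged_output, std_output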
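For comparison, standard Monte Carlo dropout runs `num_samples` separate stochastic forward passes and then takes the mean and standard deviation over them; the gist folds those passes into one batched call. The following is a minimal sketch under stated assumptions: a PyTorch model whose only stochastic layers are dropout. `mc_dropout_sequential`, `batched_samples` and the toy `net` are hypothetical names introduced for illustration, not part of the gist. It also checks that `repeat_interleave` followed by a reshape to `(-1, num_samples, ...)` keeps the copies of each input in the same group, which is what makes the per-input mean and std correct.

```python
import torch
import torch.nn as nn


# Hypothetical helper (not from the gist): the conventional sequential loop
def mc_dropout_sequential(model: nn.Module, inputs: torch.Tensor, num_samples: int):
    model.train()  # enables dropout (would also affect BatchNorm, if present)
    with torch.no_grad():
        # Each call draws fresh dropout masks; samples are stacked along dim=1
        samples = torch.stack([model(inputs) for _ in range(num_samples)], dim=1)
    return samples.mean(dim=1), samples.std(dim=1)


# Hypothetical helper (not from the gist): the batched trick, returning raw samples
def batched_samples(model: nn.Module, inputs: torch.Tensor, num_samples: int):
    model.train()
    with torch.no_grad():
        out = model(inputs.repeat_interleave(num_samples, dim=0))
    # Rows i*num_samples .. (i+1)*num_samples - 1 all came from inputs[i]
    return out.reshape((-1, num_samples) + out.shape[1:])


# Layout check on a plain tensor: copies of each row land in the same group
x = torch.tensor([[10.0], [20.0]])
grouped = x.repeat_interleave(3, dim=0).reshape(2, 3, 1)
print(grouped[:, :, 0].tolist())  # [[10.0, 10.0, 10.0], [20.0, 20.0, 20.0]]

# With dropout probability 0.0 both routes are deterministic and should agree
torch.manual_seed(0)
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Dropout(p=0.0), nn.Linear(8, 2))
inputs = torch.randn(5, 4)
seq_mean, seq_std = mc_dropout_sequential(net, inputs, num_samples=4)
par = batched_samples(net, inputs, num_samples=4)
print(torch.allclose(seq_mean, par.mean(dim=1), atol=1e-5))  # True
```

Because the batched version pushes all samples through the model in one call, its peak memory grows with `batch_size * num_samples`; the sequential loop is slower but keeps the footprint at a single batch.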