Parallel Monte Carlo Dropout
Gist by jamartinh (7edbe95689be99b1db63a7e790ba10f6), last active December 5, 2023 20:22
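```python
import torch
import torch.nn as nn


@torch.no_grad()
def mc_dropout_inference(model: nn.Module, inputs: torch.Tensor, num_samples: int):
    """Monte Carlo dropout with all samples computed in one batched forward pass."""
    # Repeat each input num_samples times: [x0, x1] -> [x0, x0, ..., x1, x1, ...]
    repeated_input = inputs.repeat_interleave(num_samples, dim=0)
    # Enable dropout only; other layers (e.g. BatchNorm) stay in eval mode
    was_training = model.training
    model.eval()
    for module in model.modules():
        if isinstance(module, nn.modules.dropout._DropoutNd):  # base of all dropout layers
            module.train()
    # Forward pass with the repeated input, then restore the original mode
    repeated_output = model(repeated_input)
    model.train(was_training)
    # Reshape to (batch, num_samples, ...) to group the samples of each input
    output_shape = (-1, num_samples) + tuple(repeated_output.shape[1:])
    output_samples = repeated_output.reshape(output_shape)
    # Mean and standard deviation over the Monte Carlo samples of each input
    averaged_output = output_samples.mean(dim=1)
    std_output = output_samples.std(dim=1)
    return averaged_output, std_output
```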
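For comparison, here is a minimal sketch of the sequential approach that the batched function replaces: one forward pass per Monte Carlo sample. The helper `mc_dropout_sequential` and the toy model are hypothetical, for illustration only; because the toy model has no BatchNorm, calling `model.train()` there is safe. The last lines check the layout the batched version depends on: `repeat_interleave` places the copies of each input next to each other, so reshaping to `(-1, num_samples, ...)` groups them per input.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def mc_dropout_sequential(model: nn.Module, inputs: torch.Tensor, num_samples: int):
    """Hypothetical reference helper: one forward pass per Monte Carlo sample."""
    was_training = model.training
    model.train()  # enables dropout (the toy model below has no BatchNorm)
    # Stack the samples along dim=1 -> shape (batch, num_samples, ...)
    samples = torch.stack([model(inputs) for _ in range(num_samples)], dim=1)
    model.train(was_training)  # restore the original mode
    return samples.mean(dim=1), samples.std(dim=1)


# Hypothetical toy regression model with a dropout layer
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Dropout(p=0.5), nn.Linear(16, 1))
model.eval()
x = torch.randn(3, 4)

mean, std = mc_dropout_sequential(model, x, num_samples=50)
print(mean.shape, std.shape)  # torch.Size([3, 1]) torch.Size([3, 1])

# Layout check: after repeat_interleave, rows i*S .. i*S + S - 1 are copies of x[i],
# so reshape(-1, S, ...) yields a (batch, S, ...) tensor grouped per input.
S = 5
rep = x.repeat_interleave(S, dim=0)           # shape (3 * S, 4)
grouped = rep.reshape(-1, S, *x.shape[1:])    # shape (3, S, 4)
assert torch.equal(grouped, x.unsqueeze(1).expand(-1, S, -1))
```

Both versions draw independent dropout masks per sample, so they estimate the same mean and spread; the batched version does it in a single forward pass at the cost of a batch that is `num_samples` times larger in memory. With the default (unbiased) estimator, `std` returns NaN when `num_samples` is 1.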