Created October 11, 2019 17:12
-
-
Save Niranjankumar-c/69e44ad064d24640791d7a23b3764ec3 to your computer and use it in GitHub Desktop.
Function for occlusion analysis.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#custom function to conduct occlusion experiments | |
def occlusion(model, image, label, occ_size = 50, occ_stride = 50, occ_pixel = 0.5): | |
#get the width and height of the image | |
width, height = image.shape[-2], image.shape[-1] | |
#setting the output image width and height | |
output_height = int(np.ceil((height-occ_size)/occ_stride)) | |
output_width = int(np.ceil((width-occ_size)/occ_stride)) | |
#create a white image of sizes we defined | |
heatmap = torch.zeros((output_height, output_width)) | |
#iterate all the pixels in each column | |
for h in range(0, height): | |
for w in range(0, width): | |
h_start = h*occ_stride | |
w_start = w*occ_stride | |
h_end = min(height, h_start + occ_size) | |
w_end = min(width, w_start + occ_size) | |
if (w_end) >= width or (h_end) >= height: | |
continue | |
input_image = image.clone().detach() | |
#replacing all the pixel information in the image with occ_pixel(grey) in the specified location | |
input_image[:, :, w_start:w_end, h_start:h_end] = occ_pixel | |
#run inference on modified image | |
output = model(input_image) | |
output = nn.functional.softmax(output, dim=1) | |
prob = output.tolist()[0][label] | |
#setting the heatmap location to probability value | |
heatmap[h, w] = prob | |
return heatmap |
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment