Created
March 20, 2018 12:26
-
-
Save skrish13/4e10fb46017b7abf459d1eabe5967041 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def ROIAlign(feature_maps, rois, config, pool_size, mode='bilinear'):
    """Pool a fixed-size crop for every ROI from its pyramid level.

    Each ROI is assigned to one pyramid level (2..5) from the square root
    of its area relative to 224, then cropped from that level's feature map
    and resized to ``pool_size x pool_size`` with ``CropAndResize``.

    Params:
    - feature_maps: List of feature maps ordered [P2, P3, P4, P5]
                    (index 0 is level 2, index 3 is level 5). Presumably
                    each is [batch, channels, height, width] with 256
                    channels, since the output buffer is hard-coded to
                    256 -- TODO confirm against the caller.
    - rois: [batch, num_boxes, (x1, y1, x2, y2)]. Presumably in pixel
            coordinates of the input image, because they are divided by
            config.IMAGE_MAX_DIM before cropping -- TODO confirm.
    - config: Object providing IMAGES_PER_GPU (used as the batch size) and
              IMAGE_MAX_DIM (used to normalize the boxes).
    - pool_size: Height and width of every pooled output region.
    - mode: Currently unused; kept so existing callers keep working.

    Output:
    Pooled regions in the shape:
    [IMAGES_PER_GPU, num_boxes, 256, pool_size, pool_size], wrapped in a
    Variable and moved to the GPU with ``.cuda()``.

    NOTE(review): the crops are copied via ``.data`` into a freshly
    allocated buffer, so the result carries no autograd history back to
    ``feature_maps``. Confirm this is intended before training through it.
    """
    # Reference note kept from the original source: affine matrix for a
    # box (x1, y1, x2, y2) on a W x H map. It is not used by the code below.
    #
    #   [  x2-x1             x1 + x2 - W + 1  ]
    #   [  -----      0      ---------------  ]
    #   [  W - 1                  W - 1       ]
    #   [                                     ]
    #   [           y2-y1    y1 + y2 - H + 1  ]
    #   [    0      -----    ---------------  ]
    #   [           H - 1         H - 1       ]

    # feature_maps = [P2, P3, P4, P5]
    rois = rois.detach()
    crop_resize = CropAndResize(pool_size, pool_size, 0)
    roi_number = rois.size()[1]
    total_rois = config.IMAGES_PER_GPU * roi_number

    # Output buffer with the same tensor type as the ROIs, one slot per ROI.
    pooled = rois.data.new(total_rois, 256, pool_size, pool_size).zero_()

    # Flatten to [batch * num_boxes, 4]; row j belongs to image
    # j // num_boxes.
    rois = rois.view(total_rois, 4)

    x_1 = rois[:, 0]
    y_1 = rois[:, 1]
    x_2 = rois[:, 2]
    y_2 = rois[:, 3]

    # Pyramid level per ROI: round(log2(sqrt(area) / 224)) + 4, clamped to
    # the available levels [2, 5].
    roi_level = log2_graph(
        torch.div(torch.sqrt((y_2 - y_1) * (x_2 - x_1)), 224.0))
    roi_level = torch.clamp(torch.clamp(
        torch.add(torch.round(roi_level), 4), min=2), max=5)

    # Loop through levels and apply ROI pooling to each. P2 to P5.
    for i, level in enumerate(range(2, 6)):
        ixx = torch.eq(roi_level, level)
        ix = torch.unsqueeze(ixx, 1)
        level_boxes = torch.masked_select(rois, ix)
        if level_boxes.size()[0] == 0:
            continue
        level_boxes = level_boxes.view(-1, 4)

        # Flattened positions of this level's ROIs, in the same ascending
        # order in which masked_select returned their coordinates.
        indices_pooled = ixx.nonzero()[:, 0]

        # Batch image each box must be cropped from. The original passed
        # all zeros, which read every crop from image 0 when
        # IMAGES_PER_GPU > 1. Dividing then truncating with .int() yields
        # floor(position / num_boxes) whether `/` is integer or true
        # division in the installed torch version.
        box_indices = (indices_pooled / roi_number).int()

        # Normalize to [0, 1] and reorder (x1, y1, x2, y2) ->
        # (y1, x1, y2, x2) before handing the boxes to CropAndResize.
        crops = crop_resize(feature_maps[i], torch.div(
            level_boxes, float(config.IMAGE_MAX_DIM)
        )[:, [1, 0, 3, 2]], box_indices)

        # Scatter the crops back to their ROI slots.
        pooled[indices_pooled.data, :, :, :] = crops.data

    pooled = pooled.view(config.IMAGES_PER_GPU, roi_number,
                         256, pool_size, pool_size)
    pooled = Variable(pooled).cuda()
    return pooled
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment