Last active
May 2, 2020 20:00
-
-
Save PavlosMelissinos/162621051c906ea85b772997d982403a to your computer and use it in GitHub Desktop.
YOLO layers for Keras
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class YoloHead(Layer):
    """Keras layer that decodes raw YOLO features into box parameters.

    The layer takes the final convolutional feature map (channels last) and
    splits it, per grid cell and per anchor, into box centre, box size,
    objectness and class probabilities.

    Parameters
    ----------
    anchors : array-like
        Anchor box (width, height) pairs, one row per anchor.
        NOTE(review): the size decoding divides by the grid dimensions, so
        anchors are presumably expressed in grid-cell units -- confirm
        against the anchor definition used for training.
    num_classes : int
        Number of target classes.
    **kwargs
        Forwarded unchanged to the base ``Layer`` constructor.
    """

    def __init__(self, anchors, num_classes, **kwargs):
        self.anchors = anchors
        self.num_classes = num_classes
        super(YoloHead, self).__init__(**kwargs)

    def call(self, inputs, **kwargs):
        """Convert final layer features to bounding box parameters.

        Parameters
        ----------
        inputs : tensor
            Final convolutional layer features, channels last, with
            ``num_anchors * (num_classes + 5)`` channels.

        Returns
        -------
        box_xy : tensor
            x, y box predictions adjusted by spatial location in conv layer.
        box_wh : tensor
            w, h box predictions adjusted by anchors and conv spatial
            resolution.
        box_confidence : tensor
            Probability estimate for whether each box contains any object.
        box_class_probs : tensor
            Probability distribution estimate for each box over class labels.
        """
        feats = inputs
        feats_dtype = K.dtype(feats)
        num_anchors = len(self.anchors)

        # Anchors never change, so use a constant instead of allocating a new
        # backend variable on every call.
        # Shape: (1, 1, 1, num_anchors, 2) to broadcast over batch and grid.
        anchors_tensor = K.reshape(
            K.constant(self.anchors, dtype=feats_dtype),
            [1, 1, 1, num_anchors, 2])

        # Dynamic grid size so the layer works in fully convolutional models.
        grid_shape = K.shape(feats)[1:3]  # (height, width), channels last
        grid_h, grid_w = grid_shape[0], grid_shape[1]

        # Per-cell offsets: grid[i, j, 0] == (x=j, y=i). Building the two
        # coordinate planes separately keeps this correct when the feature
        # map is not square.
        grid_y = K.tile(
            K.reshape(K.arange(0, stop=grid_h), [-1, 1, 1, 1]),
            [1, grid_w, 1, 1])
        grid_x = K.tile(
            K.reshape(K.arange(0, stop=grid_w), [1, -1, 1, 1]),
            [grid_h, 1, 1, 1])
        grid = K.cast(K.concatenate([grid_x, grid_y]), feats_dtype)

        # Reshape to batch, height, width, num_anchors, box_params.
        feats = K.reshape(
            feats, [-1, grid_h, grid_w, num_anchors, self.num_classes + 5])

        # Normaliser ordered (width, height) to match the (x, y) / (w, h)
        # layout of the last axis.
        grid_wh = K.cast(
            K.reshape(K.stack([grid_w, grid_h]), [1, 1, 1, 1, 2]),
            feats_dtype)

        box_xy = K.sigmoid(feats[..., :2])
        box_wh = K.exp(feats[..., 2:4])
        box_confidence = K.sigmoid(feats[..., 4:5])
        box_class_probs = K.softmax(feats[..., 5:])

        # Adjust predictions to each spatial grid point and anchor size.
        box_xy = (box_xy + grid) / grid_wh
        box_wh = box_wh * anchors_tensor / grid_wh

        return [box_xy, box_wh, box_confidence, box_class_probs]

    def compute_output_shape(self, input_shape):
        """Return the static shapes of the four tensors produced by ``call``.

        Parameters
        ----------
        input_shape : sequence
            Shape of the input feature map: (batch, height, width, channels).

        Returns
        -------
        list of tuple
            Shapes of ``box_xy``, ``box_wh``, ``box_confidence`` and
            ``box_class_probs``, in that order.
        """
        batch, height, width, _ = input_shape
        # Use the configured anchors/classes rather than a hard-coded anchor
        # count, so the reported shapes always agree with ``call``.
        leading = (batch, height, width, len(self.anchors))
        return [
            leading + (2,),
            leading + (2,),
            leading + (1,),
            leading + (self.num_classes,),
        ]

    def compute_mask(self, inputs, mask=None):
        """Return no mask for any of the four outputs."""
        return 4 * [None]

    def get_config(self):
        """Return the layer config, including the constructor arguments.

        Anchors are converted to nested lists of floats so the config
        contains only plain Python values.
        """
        config = super(YoloHead, self).get_config()
        config['anchors'] = [
            [float(value) for value in anchor] for anchor in self.anchors
        ]
        config['num_classes'] = self.num_classes
        return config
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Εισαι ΩΡΑΙΟΣ! Cheers bro