Created
July 27, 2022 22:02
-
-
Save ianstenbit/bdb604eef78907da7d9d8c590a469de8 to your computer and use it in GitHub Desktop.
RetinaNet object-detection model implemented with KerasCV
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
from tensorflow import keras

from keras_cv import bounding_box
from keras_cv import models
from keras_cv.applications.object_detection.retina_net.__internal__ import (
    layers as layers_lib,
)
from keras_cv.applications.object_detection.retina_net.__internal__ import (
    utils as utils_lib,
)
# TODO(lukewood): update docstring to include documentation on creating a
# custom label decoder/etc.
class RetinaNet(keras.Model):
    """A Keras model implementing the RetinaNet architecture.

    Implements the RetinaNet architecture for object detection. The constructor
    requires `num_classes`, `bounding_box_format` and a `backbone`. Optionally, a
    custom label encoder, feature pyramid network, and prediction decoder may all be
    provided.

    Usage:
    ```
    retina_net = keras_cv.applications.RetinaNet(
        num_classes=20,
        bounding_box_format="xywh",
        backbone="resnet50",
        backbone_weights="imagenet",
        include_rescaling=True,
    )
    ```

    Args:
        num_classes: the number of classes in your dataset excluding the background
            class. Classes should be represented by integers in the range
            [0, num_classes).
        bounding_box_format: The format of bounding boxes of input dataset. Refer
            https://github.com/keras-team/keras-cv/blob/master/keras_cv/bounding_box/converters.py
            for more details on supported bounding box formats.
        backbone: Either 'resnet50' or a custom backbone model. Please see {link} to see
            how to construct your own backbone.
        include_rescaling: Required if provided backbone is a pre-configured model.
            If set to True, inputs will be passed through a Rescaling(1/255.0) layer.
        backbone_weights: (Optional) if using a KerasCV provided backbone, the
            underlying backbone model will be loaded using the weights provided in this
            argument. Can be a model checkpoint path, or a string from the supported
            weight sets in the underlying model.
        label_encoder: (Optional) a keras.Layer that accepts an image Tensor and a
            bounding box Tensor to its `call()` method, and returns RetinaNet training
            targets. By default, a KerasCV standard LabelEncoder is created and used.
            Results of this `call()` method are passed to the `loss` object passed into
            `compile()` as the `y_true` argument.
        feature_pyramid: (Optional) A `keras.Model` representing a feature pyramid
            network (FPN). The feature pyramid network is called on the outputs of the
            `backbone`. The keras_cv default backbones return three outputs in a list,
            but custom backbones may be written and used with custom feature pyramid
            networks. If not provided, a default feature pyramid network is produced
            by the library. The default feature pyramid network is compatible with all
            standard keras_cv backbones.
        prediction_decoder: (Optional) A `keras.layer` that is responsible for
            transforming retina_net predictions into usable bounding box Tensors. If
            not provided, a default is provided. The default PredictionDecoder layer
            operates using an AnchorBox matching algorithm and a NonMaxSuppression
            operation.
        name: (Optional) name for the model, defaults to RetinaNet.
    """

    def __init__(
        self,
        num_classes,
        bounding_box_format,
        backbone,
        include_rescaling=None,
        backbone_weights=None,
        label_encoder=None,
        feature_pyramid=None,
        prediction_decoder=None,
        name="RetinaNet",
        **kwargs,
    ):
        super().__init__(name=name, **kwargs)
        # NOTE(review): `backbone` has no default, so this condition requires
        # `include_rescaling` to be non-None for every construction — even for
        # custom backbones, for which `_parse_backbone` then requires it to be
        # falsy (i.e. explicitly False). Confirm this is the intended contract.
        if backbone is not None and include_rescaling is None:
            raise ValueError(
                "Either `backbone` OR `include_rescaling` must be set when "
                "constructing a `keras_cv.models.RetinaNet()` model. "
                "When `include_rescaling` is set, a ResNet50 backbone will be used. "
                "Rescaling will be performed according to the include_rescaling parameter. "
                "When `backbone` is set, rescaling will be the responsibility of the "
                "backbone. Please read more about input scaling at {LINK}. "
                f"Received backbone={backbone}, include_rescaling={include_rescaling}."
            )

        self.bounding_box_format = bounding_box_format
        self.num_classes = num_classes
        self.backbone = _parse_backbone(backbone, include_rescaling, backbone_weights)

        self.label_encoder = label_encoder or utils_lib.LabelEncoder(
            bounding_box_format=bounding_box_format
        )
        self.feature_pyramid = feature_pyramid or layers_lib.FeaturePyramid()
        # RetinaNet focal-loss bias prior: initialize the classification bias
        # to -log((1 - 0.01) / 0.01) so initial foreground probability ~= 0.01.
        prior_probability = tf.constant_initializer(-tf.math.log((1 - 0.01) / 0.01))
        # 9 outputs per class / per box coordinate — presumably 9 anchors per
        # feature-map location; TODO confirm against the default anchor setup.
        self.classification_head = layers_lib.PredictionHead(
            output_filters=9 * num_classes, bias_initializer=prior_probability
        )
        self.box_head = layers_lib.PredictionHead(
            output_filters=9 * 4, bias_initializer="zeros"
        )
        self.prediction_decoder = prediction_decoder or layers_lib.DecodePredictions(
            num_classes=num_classes, bounding_box_format=bounding_box_format
        )
        # Set in `compile()`: the single bounding-box format shared by metrics.
        self._metrics_bounding_box_format = None

    def compile(self, metrics=None, **kwargs):
        """Compiles the model, validating metric bounding-box formats.

        Every metric must expose a `bounding_box_format` attribute and all
        metrics must agree on one format; when no metrics are given, the
        model's own `bounding_box_format` is used for metric bookkeeping.
        """
        metrics = metrics or []
        super().compile(metrics=metrics, **kwargs)

        # Bug fix: the original computed
        # `any([m.bounding_box_format != self._metrics_bounding_box_format ...])`
        # here, which raised a ValueError whenever `metrics` was EMPTY
        # (`any([])` is False) and crashed with AttributeError (not the
        # intended ValueError) when a metric lacked the attribute. Check
        # attribute presence explicitly instead.
        if not all(hasattr(m, "bounding_box_format") for m in metrics):
            raise ValueError(
                "All metrics passed to RetinaNet.compile() must have "
                "a `bounding_box_format` attribute."
            )

        if metrics:
            self._metrics_bounding_box_format = metrics[0].bounding_box_format
        else:
            self._metrics_bounding_box_format = self.bounding_box_format

        any_wrong_format = any(
            m.bounding_box_format != self._metrics_bounding_box_format
            for m in metrics
        )
        if any_wrong_format:
            raise ValueError(
                "All metrics passed to RetinaNet.compile() must have "
                "the same `bounding_box_format` attribute. For example, if one metric "
                "uses 'xyxy', all other metrics must use 'xyxy'"
            )

    def call(self, x, training=False):
        """Forward pass.

        Returns a dict with:
            "train_preds": concatenated per-anchor [box(4) | class scores]
                tensor of shape [batch, num_anchors, 4 + num_classes], in
                `self.bounding_box_format`; consumed by the compiled loss.
            "inference": decoded predictions from `self.prediction_decoder`,
                converted back to `self.bounding_box_format`.
        """
        backbone_outputs = self.backbone(x, training=training)
        features = self.feature_pyramid(backbone_outputs, training=training)

        N = tf.shape(x)[0]
        cls_outputs = []
        box_outputs = []
        for feature in features:
            box_outputs.append(tf.reshape(self.box_head(feature), [N, -1, 4]))
            cls_outputs.append(
                tf.reshape(self.classification_head(feature), [N, -1, self.num_classes])
            )

        cls_outputs = tf.concat(cls_outputs, axis=1)
        box_outputs = tf.concat(box_outputs, axis=1)
        train_preds = tf.concat([box_outputs, cls_outputs], axis=-1)

        # no-op if default decoder is used (formats already match).
        pred_for_inference = bounding_box.convert_format(
            train_preds,
            source=self.bounding_box_format,
            target=self.prediction_decoder.bounding_box_format,
            images=x,
        )
        pred_for_inference = self.prediction_decoder(x, pred_for_inference)
        pred_for_inference = bounding_box.convert_format(
            pred_for_inference,
            source=self.prediction_decoder.bounding_box_format,
            target=self.bounding_box_format,
            images=x,
        )
        return {"train_preds": train_preds, "inference": pred_for_inference}

    def _encode_data(self, x, y):
        """Converts raw boxes `y` into label-encoder training targets.

        Returns `(y_for_metrics, y_training_target)`; the first is the
        untouched input (still in `self.bounding_box_format`) for metrics.
        """
        y_for_metrics = y
        y = bounding_box.convert_format(
            y,
            source=self.bounding_box_format,
            target=self.label_encoder.bounding_box_format,
            images=x,
        )
        y_training_target = self.label_encoder(x, y)
        # Convert the encoder's output back to the model's canonical format.
        y_training_target = bounding_box.convert_format(
            y_training_target,
            source=self.label_encoder.bounding_box_format,
            target=self.bounding_box_format,
            images=x,
        )
        return y_for_metrics, y_training_target

    def train_step(self, data):
        x, y = data
        # y comes in in self.bounding_box_format
        y_for_metrics, y_training_target = self._encode_data(x, y)

        with tf.GradientTape() as tape:
            predictions = self(x, training=True)
            # predictions technically do not have a format
            loss = self.compiled_loss(
                y_training_target,
                predictions["train_preds"],
                regularization_losses=self.losses,
            )

        # Training specific code
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # To minimize GPU transfers, we update metrics AFTER we take gradients
        # and apply them.
        self._update_metrics(y_for_metrics, predictions["inference"])
        return self._metrics_result(loss)

    def test_step(self, data):
        x, y = data
        y_for_metrics, y_training_target = self._encode_data(x, y)
        predictions = self(x)
        loss = self.compiled_loss(
            y_training_target,
            predictions["train_preds"],
            regularization_losses=self.losses,
        )
        self._update_metrics(y_for_metrics, predictions["inference"])
        return self._metrics_result(loss)

    def _update_metrics(self, y_true, y_pred):
        # NOTE(review): no `images=` argument is passed to convert_format
        # here, so converting to/from a relative format would fail — confirm
        # metric formats are always absolute.
        y_true = bounding_box.convert_format(
            y_true,
            source=self.bounding_box_format,
            target=self._metrics_bounding_box_format,
        )
        y_pred = bounding_box.convert_format(
            y_pred,
            source=self.bounding_box_format,
            target=self._metrics_bounding_box_format,
        )
        self.compiled_metrics.update_state(y_true, y_pred)

    def _metrics_result(self, loss):
        # Collects metric results into a dict and attaches the step loss.
        metrics_result = {m.name: m.result() for m in self.metrics}
        metrics_result["loss"] = loss
        return metrics_result

    def inference(self, x):
        """Runs `predict()` and returns only the decoded bounding boxes."""
        predictions = self.predict(x)
        return predictions["inference"]
def _parse_backbone(backbone, include_rescaling, backbone_weights):
    """Resolve the `backbone` constructor argument into a backbone model.

    A string selects a preset backbone (currently only "resnet50", built via
    `_resnet50_backbone`). Anything else must be a callable (e.g. a
    `keras.Model`) and is returned unchanged after validation; in that case
    `include_rescaling` and `backbone_weights` must not be set.
    """
    if isinstance(backbone, str):
        # Preset path: only "resnet50" is recognized.
        if backbone != "resnet50":
            raise ValueError(
                "backbone expected to be one of ['resnet50', keras.Model]. "
                f"Received backbone={backbone}."
            )
        return _resnet50_backbone(include_rescaling, backbone_weights)

    # Custom-backbone path: the rescaling/weights knobs only apply to presets.
    if include_rescaling or backbone_weights:
        raise ValueError(
            "When a custom backbone is used, include_rescaling and "
            f"backbone_weights are not supported. Received backbone={backbone}, "
            f"include_rescaling={include_rescaling}, and "
            f"backbone_weights={backbone_weights}."
        )
    if not callable(backbone):
        raise ValueError(
            "Custom backbones should be subclasses of a keras.Model. "
            f"Received backbone={backbone}."
        )
    return backbone
# --- Building the ResNet50 backbone ---
def _resnet50_backbone(include_rescaling, backbone_weights):
    """Builds a ResNet50-based feature extractor for RetinaNet.

    Returns a `keras.Model` mapping an image batch to the C3, C4 and C5
    feature maps (outputs of the last conv layer in ResNet stacks 2, 3, 4),
    which feed the feature pyramid network.

    Args:
        include_rescaling: forwarded to `models.ResNet50`; when True the
            backbone rescales inputs by 1/255.
        backbone_weights: checkpoint path or preset weight-set name forwarded
            to `models.ResNet50`.
    """
    backbone = models.ResNet50(
        include_top=False, include_rescaling=include_rescaling, weights=backbone_weights
    )
    c3_output, c4_output, c5_output = [
        backbone.get_layer(layer_name).output
        for layer_name in [
            "v1_stack_2_block4_3_conv",
            "v1_stack_3_block6_3_conv",
            "v1_stack_4_block3_3_conv",
        ]
    ]
    # Bug fix: the original created a fresh `keras.layers.Input` that was
    # never connected to `backbone`'s graph, so `keras.Model(...)` would fail
    # with a "graph disconnected" error — the outputs derive from the
    # backbone's own input tensor. Use `backbone.inputs` instead.
    return keras.Model(inputs=backbone.inputs, outputs=[c3_output, c4_output, c5_output])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment