Last active
May 18, 2023 05:07
-
-
Save maitchison/77d519c3945548f063bda94ce3189587 to your computer and use it in GitHub Desktop.
Example of inserting a ResNet-50 backbone into a recurrent neural network.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# TensorFlow-Slim conveniently ships some prebuilt models (including ResNet v2) that we can reuse.
import tensorflow.contrib.slim as slim
from tensorflow.contrib.slim.nets import resnet_v2
class ModelCRNN_Resnet(ConvModel):
    """Recurrent convolutional neural network built on a ResNet v2 backbone.

    The (thermal, flow) input is passed through a TF-Slim ResNet v2 feature
    extractor, the features are reshaped to (-1, frames, features), fed to an
    LSTM, and the LSTM's last output is mapped to class logits.
    """

    # NOTE(review): "model_resent" looks like a typo for "model_resnet"; it is
    # kept unchanged because the name may be persisted elsewhere (saved model
    # metadata, file names) — confirm before renaming.
    MODEL_NAME = "model_resent"
    MODEL_DESCRIPTION = "CNN + LSTM"

    DEFAULT_PARAMS = {
        # training params
        'batch_size': 16,
        'learning_rate': 1e-4,
        'learning_rate_decay': 1.0,
        'l2_reg': 0,
        'label_smoothing': 0.1,
        'keep_prob': 0.5,

        # model params
        # when True the ResNet backbone is built inside resnet_arg_scope(),
        # which supplies batch norm to its bottleneck convolutions
        'batch_norm': True,
        'lstm_units': 256,
        'enable_flow': True,

        # augmentation
        'augmentation': True,
        'thermal_threshold': -10,
        'scale_frequency': 0.5
    }

    def __init__(self, labels, **kwargs):
        """
        Initialise the model and build its graph.

        :param labels: number of labels for the model to predict
        :param kwargs: overrides applied on top of DEFAULT_PARAMS
        """
        super().__init__()
        self.params.update(self.DEFAULT_PARAMS)
        self.params.update(kwargs)
        self._build_model(labels)

    def _resnet_features(self, inputs, blocks):
        """
        Run the TF-Slim ResNet v2 backbone and return its output tensor.

        TF-Slim documents resnet_v2 as being used inside
        resnet_v2.resnet_arg_scope(); that scope is what adds batch norm to the
        bottleneck conv1/conv2 layers. It is applied when params['batch_norm']
        is set. weight_decay=0.0 stops the scope from registering L2 losses of
        its own, so the only explicit L2 penalty stays the one on the logits.

        :param inputs: image batch passed straight to resnet_v2.resnet_v2
        :param blocks: list of resnet_v2 block definitions
        :returns: backbone output (the end_points dict is discarded)
        """
        def build():
            net, _ = resnet_v2.resnet_v2(
                inputs,
                blocks=blocks,
                num_classes=None,
                is_training=self.is_training
            )
            return net

        if not self.params['batch_norm']:
            return build()
        with slim.arg_scope(resnet_v2.resnet_arg_scope(weight_decay=0.0)):
            return build()

    def _build_model(self, label_count):
        """
        Build the CNN + LSTM graph, its loss, optimizer and named output nodes.

        Dimension names:
            B batch size
            F frames per segment
            C channels
            H frame height
            W frame width

        :param label_count: number of output classes
        """
        thermal, flow, mask = self.process_inputs()
        frame_count = tf.shape(self.X)[1]

        # ---------------------------------------------------
        # run the convolutions via resnet
        # (same block layout as TF-Slim's resnet_v2_50)

        blocks = [
            resnet_v2.resnet_v2_block('block1', base_depth=64, num_units=3, stride=2),
            resnet_v2.resnet_v2_block('block2', base_depth=128, num_units=4, stride=2),
            resnet_v2.resnet_v2_block('block3', base_depth=256, num_units=6, stride=2),
            resnet_v2.resnet_v2_block('block4', base_depth=512, num_units=3, stride=1),
        ]

        # NOTE(review): params['enable_flow'] is not consulted here; flow is
        # always concatenated. Presumably process_inputs handles it — confirm.
        # Assumes thermal/flow are (B*F, H, W, C) with frames folded into the
        # batch axis — TODO confirm against process_inputs.
        layer = tf.concat((thermal, flow), axis=3)
        logging.info("Convolution input shape: {}".format(layer.shape))

        thermal_conv = self._resnet_features(layer, blocks)

        # ---------------------------------------------------
        # pass resnet output to LSTM units

        logging.info("Convolution output shape: {}".format(thermal_conv.shape))

        # regroup per-frame features into sequences: (-1, F, features)
        filtered_out = tf.reshape(
            thermal_conv,
            [-1, frame_count, tools.product(thermal_conv.shape[1:])],
            name='thermal/out'
        )
        out = tf.concat((filtered_out,), axis=2, name='out')
        logging.info('Output shape {}'.format(out.shape))

        # -------------------------------------
        # run the LSTM

        lstm_cell = tf.nn.rnn_cell.LSTMCell(num_units=self.params['lstm_units'])
        # DropoutWrapper only uses dtype for variational recurrent dropout; the
        # TensorFlow dtype matches the rest of the graph and avoids numpy.
        dropout = tf.nn.rnn_cell.DropoutWrapper(
            lstm_cell, output_keep_prob=self.keep_prob, dtype=tf.float32
        )
        # state_in carries (c, h) stacked along its last axis
        init_state = tf.nn.rnn_cell.LSTMStateTuple(self.state_in[:, :, 0], self.state_in[:, :, 1])

        lstm_outputs, lstm_states = tf.nn.dynamic_rnn(
            cell=dropout, inputs=out,
            initial_state=init_state,
            dtype=tf.float32,
            scope='lstm'
        )
        lstm_c, lstm_h = lstm_states

        # just need the last output
        lstm_output = tf.identity(lstm_outputs[:, -1], 'lstm_out')
        # re-pack (c, h) in the same layout as state_in
        lstm_state = tf.stack([lstm_c, lstm_h], axis=2)

        logging.info("lstm output shape: {} x {}".format(lstm_outputs.shape[1], lstm_output.shape))
        logging.info("lstm state shape: {}".format(lstm_state.shape))

        # -------------------------------------
        # dense / logits

        # dense layer on top of the LSTM output mapping to class labels.
        logits = tf.layers.dense(inputs=lstm_output, units=label_count, activation=None, name='logits')
        # histogram of the logit activations (not the kernel); tag kept as-is
        tf.summary.histogram('weights/logits', logits)

        softmax_loss = tf.losses.softmax_cross_entropy(
            onehot_labels=tf.one_hot(self.y, label_count),
            logits=logits, label_smoothing=self.params['label_smoothing'],
            scope='softmax_loss')
        # logged unconditionally, not only when an L2 penalty is active
        tf.summary.scalar('loss/softmax', softmax_loss)

        if self.params['l2_reg'] != 0:
            # reuse the kernel created by the 'logits' dense layer above
            with tf.variable_scope('logits', reuse=True):
                logit_weights = tf.get_variable('kernel')

            reg_loss = tf.nn.l2_loss(logit_weights, name='loss/reg') * self.params['l2_reg']
            loss = tf.add(softmax_loss, reg_loss, name='loss')
            tf.summary.scalar('loss/reg', reg_loss)
        else:
            # just relabel the loss node
            loss = tf.identity(softmax_loss, 'loss')

        class_out = tf.argmax(logits, axis=1, name='class_out')
        # NOTE(review): tf.argmax returns int64 and tf.equal needs matching
        # dtypes, so this assumes self.y is int64 — confirm.
        correct_prediction = tf.equal(class_out, self.y)
        # created for their graph names ('prediction', 'accuracy')
        tf.nn.softmax(logits, name='prediction')
        tf.reduce_mean(tf.cast(correct_prediction, dtype=tf.float32), name='accuracy')

        self.setup_novelty(logits, lstm_output)
        self.setup_optimizer(loss)

        # make reference to special nodes
        tf.identity(lstm_state, 'state_out')
        tf.identity(lstm_output, 'hidden_out')
        tf.identity(logits, 'logits_out')

        self.attach_nodes()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment