Skip to content

Instantly share code, notes, and snippets.

@maitchison
Last active May 18, 2023 05:07
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save maitchison/77d519c3945548f063bda94ce3189587 to your computer and use it in GitHub Desktop.
Save maitchison/77d519c3945548f063bda94ce3189587 to your computer and use it in GitHub Desktop.
Example of inserting RESNET-50 into a recurrent neural network.
# tensorflow-slim very nicely includes some prebuilt models for us to make use of.
from tensorflow.contrib.slim.nets import resnet_v2
class ModelCRNN_Resnet(ConvModel):
    """
    Convolutional-recurrent classifier: ResNet v2 features fed into an LSTM.

    The concatenated thermal and flow inputs go through a four-block ResNet v2
    stack, the result is regrouped per frame, run through a single LSTM layer,
    and the LSTM output at the last frame is mapped to class logits.
    """

    # NOTE(review): "resent" looks like a typo for "resnet". The string is kept
    # unchanged because it may be used as an identifier elsewhere - confirm
    # before renaming.
    MODEL_NAME = "model_resent"
    MODEL_DESCRIPTION = "CNN + LSTM"

    # Only 'lstm_units', 'label_smoothing' and 'l2_reg' are read in this class;
    # the remaining entries are presumably consumed by ConvModel - TODO confirm.
    DEFAULT_PARAMS = {
        # training params
        'batch_size': 16,
        'learning_rate': 1e-4,
        'learning_rate_decay': 1.0,
        'l2_reg': 0,
        'label_smoothing': 0.1,
        'keep_prob': 0.5,

        # model params
        'batch_norm': True,
        'lstm_units': 256,
        'enable_flow': True,

        # augmentation
        'augmentation': True,
        'thermal_threshold': -10,
        'scale_frequency': 0.5
    }

    def __init__(self, labels, **kwargs):
        """
        Set up the hyper-parameters and build the graph.

        :param labels: number of labels for model to predict
        :param kwargs: hyper-parameter overrides, applied on top of DEFAULT_PARAMS
        """
        super().__init__()
        # applied in order, so caller-supplied values win over the defaults
        for overrides in (self.DEFAULT_PARAMS, kwargs):
            self.params.update(overrides)
        self._build_model(labels)

    def _build_model(self, label_count):
        """
        Build the graph: ResNet v2 convolutions -> LSTM -> dense logits -> loss.

        Besides wiring the layers together this registers the named nodes
        'lstm_out', 'logits', 'loss', 'class_out', 'prediction', 'accuracy',
        'state_out', 'hidden_out' and 'logits_out'.

        :param label_count: number of classes, i.e. the width of the logits layer
        """
        # dimensions are documented as follows
        #  B batch size
        #  F frames per segment
        #  C channels
        #  H frame height
        #  W frame width

        # process_inputs also returns a mask, which this model does not use
        thermal, flow, _mask = self.process_inputs()
        num_frames = tf.shape(self.X)[1]

        # ---------------------------------------------------
        # run the convolutions via resnet

        # (scope, base_depth, num_units, stride) per stage; the 3-4-6-3 unit
        # layout is the ResNet-50 configuration
        stage_specs = (
            ('block1', 64, 3, 2),
            ('block2', 128, 4, 2),
            ('block3', 256, 6, 2),
            ('block4', 512, 3, 1),
        )
        blocks = [
            resnet_v2.resnet_v2_block(scope, base_depth=depth, num_units=units, stride=stride)
            for scope, depth, units, stride in stage_specs
        ]

        conv_input = tf.concat((thermal, flow), axis=3)
        logging.info("Convolution input shape: {}".format(conv_input.shape))

        # num_classes=None: no classification layer, the network returns features.
        # NOTE(review): the slim docs wrap this call in
        # slim.arg_scope(resnet_v2.resnet_arg_scope()); no such scope is visible
        # here - confirm whether that is intentional.
        conv_output, _end_points = resnet_v2.resnet_v2(
            conv_input,
            blocks=blocks,
            num_classes=None,
            is_training=self.is_training,
        )

        # ---------------------------------------------------
        # pass resnet output to LSTM units
        logging.info("Convolution output shape: {}".format(conv_output.shape))

        # collapse everything after the leading axis into one feature axis and
        # regroup as (-1, frames, features); tools.product is presumably the
        # product of the given dimensions - TODO confirm
        feature_size = tools.product(conv_output.shape[1:])
        frame_features = tf.reshape(conv_output, [-1, num_frames, feature_size], name='thermal/out')

        # single-input concat: it only gives the LSTM input a node named 'out'
        lstm_input = tf.concat((frame_features,), axis=2, name='out')
        logging.info('Output shape {}'.format(lstm_input.shape))

        # -------------------------------------
        # run the LSTM
        cell = tf.nn.rnn_cell.LSTMCell(num_units=self.params['lstm_units'])
        cell_with_dropout = tf.nn.rnn_cell.DropoutWrapper(
            cell, output_keep_prob=self.keep_prob, dtype=np.float32
        )

        # state_in holds the two LSTM state tensors along its last axis
        initial_state = tf.nn.rnn_cell.LSTMStateTuple(self.state_in[:, :, 0], self.state_in[:, :, 1])

        lstm_outputs, (first_state, second_state) = tf.nn.dynamic_rnn(
            cell=cell_with_dropout,
            inputs=lstm_input,
            initial_state=initial_state,
            dtype=tf.float32,
            scope='lstm',
        )

        # just need the last output
        final_output = tf.identity(lstm_outputs[:, -1], 'lstm_out')
        # pack the two state tensors back into the layout state_in uses
        packed_state = tf.stack([first_state, second_state], axis=2)

        logging.info("lstm output shape: {} x {}".format(lstm_outputs.shape[1], final_output.shape))
        logging.info("lstm state shape: {}".format(packed_state.shape))

        # -------------------------------------
        # dense / logits

        # dense layer on top of the LSTM output mapping to class labels.
        logits = tf.layers.dense(inputs=final_output, units=label_count, activation=None, name='logits')
        # NOTE(review): despite the 'weights/' tag this records the logits tensor itself
        tf.summary.histogram('weights/logits', logits)

        onehot_labels = tf.one_hot(self.y, label_count)
        softmax_loss = tf.losses.softmax_cross_entropy(
            onehot_labels=onehot_labels,
            logits=logits,
            label_smoothing=self.params['label_smoothing'],
            scope='softmax_loss',
        )

        l2_weight = self.params['l2_reg']
        if l2_weight != 0:
            # only the lookup of the dense layer's kernel needs the reuse scope;
            # the loss ops stay outside it so the node is named 'loss' in both branches
            with tf.variable_scope('logits', reuse=True):
                kernel = tf.get_variable('kernel')

            reg_loss = tf.nn.l2_loss(kernel, name='loss/reg') * l2_weight
            loss = tf.add(softmax_loss, reg_loss, name='loss')
            tf.summary.scalar('loss/reg', reg_loss)
            tf.summary.scalar('loss/softmax', softmax_loss)
        else:
            # just relabel the loss node
            loss = tf.identity(softmax_loss, 'loss')

        predicted_class = tf.argmax(logits, axis=1, name='class_out')
        is_correct = tf.equal(predicted_class, self.y)
        # the next two ops exist only for their named graph nodes; the Python
        # handles are not needed
        tf.nn.softmax(logits, name='prediction')
        tf.reduce_mean(tf.cast(is_correct, dtype=tf.float32), name='accuracy')

        self.setup_novelty(logits, final_output)
        self.setup_optimizer(loss)

        # make reference to special nodes
        tf.identity(packed_state, 'state_out')
        tf.identity(final_output, 'hidden_out')
        tf.identity(logits, 'logits_out')

        self.attach_nodes()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment