Skip to content

Instantly share code, notes, and snippets.

@sgodfrey66
Last active June 16, 2020 15:40
Show Gist options
  • Save sgodfrey66/dc0c5a808f8866ef7622a7f2889d1d1a to your computer and use it in GitHub Desktop.
Save sgodfrey66/dc0c5a808f8866ef7622a7f2889d1d1a to your computer and use it in GitHub Desktop.
class VGG16Test(tf.keras.models.Sequential):
    """Sequential image classifier built by transfer learning on a VGG16 base.

    This follows the work of Sivaramakrishnan Rajaraman et al., who used
    pre-trained convolutional neural networks to detect malaria parasites in
    thin blood smear images. Here the pretrained model is VGG16 with ImageNet
    weights, truncated at ``block5_conv2`` and frozen so it acts as a fixed
    feature extractor.

    As with SequentialTest, the first layer of this Sequential model is itself
    a (truncated) VGG16 model. The class therefore inherits from Sequential
    and builds the complete model in ``__init__()``.

    Architecture:
        VGG16 (ImageNet weights, no top) up to ``block5_conv2``, frozen
        -> GlobalAveragePooling2D
        -> Dense(1024, relu)
        -> Dropout(0.5)
        -> Dense(num_classes, softmax)

    The instance also exposes ``loss_object`` and ``optimizer`` for use by a
    training loop. This class does not call ``compile()`` itself.

    References:
        https://lhncbc.nlm.nih.gov/publication/pub9932
        https://github.com/sivaramakrishnan-rajaraman/CNN-for-malaria-parasite-detection
        https://androidkt.com/how-to-use-vgg-model-in-tensorflow-keras/
        https://keras.io/api/applications/vgg/#vgg16-function
        https://arxiv.org/abs/1409.1556
    """

    # Name supplied to the constructor
    model_name = ''
    # Loss function (assigned in __init__)
    loss_object = None
    # Optimizer (assigned in __init__)
    optimizer = None

    def __init__(self,
                 model_name: str,
                 input_shape: tuple,
                 num_classes: int):
        """Build the full model, its loss function and its optimizer.

        Args:
            model_name: A name for this model; stored as ``self.model_name``.
            input_shape: The input layer's shape,
                (image_height, image_width, channels). It is passed directly
                to ``VGG16``. Per the Keras docs, with ``include_top=False``
                it must have exactly 3 channels and a height/width of at
                least 32.
            num_classes: The number of classes in this classification task.
        """
        super().__init__()
        # Keep the caller-supplied name on the instance.
        self.model_name = model_name

        # Following Rajaraman et al., start from ImageNet-pretrained VGG16
        # without its fully connected classifier head.
        vgg16 = tf.keras.applications.VGG16(weights='imagenet',
                                            include_top=False,
                                            input_shape=input_shape)
        # Truncate the network at block5_conv2 and use its feature maps as
        # the extracted features.
        base_model = tf.keras.Model(
            inputs=vgg16.input,
            outputs=vgg16.get_layer('block5_conv2').output)
        # Freeze the base so its pretrained weights are not updated.
        base_model.trainable = False

        # Classification head on top of the frozen feature extractor.
        self.add(base_model)
        self.add(tf.keras.layers.GlobalAveragePooling2D())
        self.add(tf.keras.layers.Dense(units=1024, activation='relu'))
        self.add(tf.keras.layers.Dropout(rate=0.5))
        # The output layer applies softmax, so the model emits probabilities.
        self.add(tf.keras.layers.Dense(units=num_classes,
                                       activation='softmax'))

        # from_logits=False because the final layer already outputs softmax
        # probabilities, not raw logits.
        self.loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=False)

        # SGD with Nesterov momentum. The learning rate follows
        # 1e-5 / (1 + 1e-6 * step), which is what the legacy `decay=1e-6`
        # argument computed. Writing it as a schedule and passing
        # `learning_rate` (rather than the deprecated `lr`) keeps this
        # working on TensorFlow versions whose optimizers reject `lr`/`decay`.
        lr_schedule = tf.keras.optimizers.schedules.InverseTimeDecay(
            initial_learning_rate=0.00001,
            decay_steps=1,
            decay_rate=1e-6)
        self.optimizer = tf.keras.optimizers.SGD(learning_rate=lr_schedule,
                                                 momentum=0.9,
                                                 nesterov=True)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment