Last active
June 16, 2020 15:40
-
-
Save sgodfrey66/dc0c5a808f8866ef7622a7f2889d1d1a to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class VGG16Test(tf.keras.models.Sequential):
    """Construct a Sequential model using transfer learning with VGG16 as the base.

    This class follows the work of Dr. Sivarama Krishnan Rajaraman, et al. in
    using 'pre-trained convolutional neural networks' to detect malaria
    infections in thin blood smear samples; specifically, the pretrained
    VGG16 model.

    As is the case with SequentialTest, we're building a Sequential model in
    which the first layer is a part of the VGG16 model.  Therefore, this class
    inherits from the Sequential class and the model is completely built in
    the __init__() method.

    Layer stack assembled by the constructor:
        1. VGG16 (ImageNet weights, no top) truncated at 'block5_conv2', frozen
        2. GlobalAveragePooling2D
        3. Dense(1024, relu)
        4. Dropout(0.5)
        5. Dense(num_classes, softmax)

    The constructor only creates the loss object and the optimizer and stores
    them on the instance; it does not call compile().

    References:
        https://lhncbc.nlm.nih.gov/publication/pub9932
        https://github.com/sivaramakrishnan-rajaraman/CNN-for-malaria-parasite-detection
        https://androidkt.com/how-to-use-vgg-model-in-tensorflow-keras/
        https://keras.io/api/applications/vgg/#vgg16-function
        https://arxiv.org/abs/1409.1556
    """

    # Model name
    model_name = ''
    # Loss function
    loss_object = None
    # Optimizer
    optimizer = None

    def __init__(self,
                 model_name: str,
                 input_shape: tuple,
                 num_classes: int):
        """Build the VGG16-based classifier.

        Args:
            model_name: str
                A name for this model
            input_shape: tuple
                The input_layer's input shape, (image_height, image_width, channels)
            num_classes: int
                The number of classes in this classification task
        """
        # Initialise the superclass (Keras's Sequential) before touching any
        # attributes or adding layers
        super().__init__()

        # Keep the caller-supplied name; previously the argument was accepted
        # but never stored, so model_name always stayed ''
        self.model_name = model_name

        # Following Rajaraman, et al. build a model starting with VGG16
        # pretrained on ImageNet, without its fully-connected top
        base_model = applications.VGG16(weights='imagenet',
                                        include_top=False,
                                        input_shape=input_shape)

        # Truncate the base at a particular layer
        base_output = base_model.get_layer('block5_conv2').output
        base_model = Model(inputs=base_model.input, outputs=base_output)

        # Freeze the base_model so these layers are not trained
        base_model.trainable = False

        # Add base_model to the class
        self.add(base_model)
        # Add a global average pooling layer
        self.add(tf.keras.layers.GlobalAveragePooling2D())
        # Add a dense layer
        self.add(tf.keras.layers.Dense(units=1024, activation='relu'))
        # Add a dropout layer
        self.add(tf.keras.layers.Dropout(rate=0.5))
        # Add the predictions (dense) layer; softmax means the model emits
        # class probabilities, not logits
        self.add(tf.keras.layers.Dense(units=num_classes, activation='softmax'))

        # Set the loss object.  The predictions layer already applies softmax,
        # so the loss must treat its input as probabilities
        # (from_logits=False); from_logits=True would apply softmax twice.
        self.loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=False)

        # Set the optimizer.  `learning_rate` replaces the deprecated `lr`
        # alias.
        # NOTE(review): `decay` is kept from the original; confirm it is still
        # accepted by the Keras version in use.
        self.optimizer = tf.keras.optimizers.SGD(learning_rate=0.00001,
                                                 decay=1e-6,
                                                 momentum=0.9,
                                                 nesterov=True)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment