import tensorflow as tf
from tensorflow.keras.applications import VGG16
from tensorflow.keras import layers, models, optimizers
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import EarlyStopping
# Load the pre-trained VGG16 model without the top layer
base_model = VGG16(input_shape=(256, 256, 3), include_top=False, weights='imagenet')
# Freeze the convolutional base
base_model.trainable = False
# Add new layers on top of the pre-trained base
model = models.Sequential([
    base_model,
    layers.Flatten(),
    layers.Dense(512, activation='relu'),
    layers.Dropout(0.5),
    layers.Dense(1, activation='sigmoid')
])
# Compile the model
model.compile(optimizer=optimizers.Adam(learning_rate=0.001),  # Initial learning rate
              loss='binary_crossentropy',
              metrics=['accuracy'])
# Data augmentation
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest'
)
test_datagen = ImageDataGenerator(rescale=1./255)
# Load the training and validation data.
# Note: flow_from_directory expects one subdirectory per class under each of
# these paths, so the raw MURA layout (per-patient study folders) may need to
# be reorganized into e.g. positive/ and negative/ directories first.
train_dir = 'MURA-v1.1/train/XR_WRIST'
validation_dir = 'MURA-v1.1/valid/XR_WRIST'
train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=(256, 256),
    batch_size=32,
    class_mode='binary'
)
validation_generator = test_datagen.flow_from_directory(
    validation_dir,
    target_size=(256, 256),
    batch_size=32,
    class_mode='binary'
)
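# Optional sanity check (not in the original gist): flow_from_directory maps
# each class subdirectory to an integer label; confirm the expected two-class
# mapping before training.
print(train_generator.class_indices)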
# Set up early stopping
early_stopping = EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)
# Train the model with data augmentation
history = model.fit(
    train_generator,
    epochs=10,  # Initial number of epochs
    validation_data=validation_generator,
    callbacks=[early_stopping]
)
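# Optional sketch (not in the original gist): save the first-stage training
# curves for inspection; assumes matplotlib is installed. The output filename
# is hypothetical.
import matplotlib.pyplot as plt

plt.plot(history.history['accuracy'], label='train accuracy')
plt.plot(history.history['val_accuracy'], label='val accuracy')
plt.xlabel('epoch')
plt.ylabel('accuracy')
plt.legend()
plt.savefig('stage1_curves.png')
plt.close()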
# Fine-tuning the model
# Unfreeze the last four layers of the VGG16 base (the three block5 conv
# layers plus block5_pool, which has no trainable weights)
base_model.trainable = True
for layer in base_model.layers[:-4]:
    layer.trainable = False
# Compile the model with a lower learning rate
model.compile(optimizer=optimizers.Adam(learning_rate=1e-5),  # Lower learning rate for fine-tuning
              loss='binary_crossentropy',
              metrics=['accuracy'])
# Continue training the model with fine-tuning
history_fine = model.fit(
    train_generator,
    epochs=10,  # Additional epochs for fine-tuning
    validation_data=validation_generator,
    callbacks=[early_stopping]
)
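# Optional sanity check (not in the original gist): report loss and accuracy
# on the validation split after fine-tuning, using the generator defined above.
val_loss, val_acc = model.evaluate(validation_generator)
print(f"Validation loss: {val_loss:.4f}, accuracy: {val_acc:.4f}")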
# Save the trained model to a single HDF5 file
model.save("modelVGG16.h5")
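# Minimal inference sketch (not in the original gist): reload the saved model
# and score a single image. 'some_xray.png' is a hypothetical path.
import numpy as np
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image

reloaded = load_model("modelVGG16.h5")
img = image.load_img("some_xray.png", target_size=(256, 256))  # hypothetical file
x = np.expand_dims(image.img_to_array(img) / 255.0, axis=0)  # same 1/255 rescale as training
prob = reloaded.predict(x)[0][0]  # sigmoid output: probability of the positive class
print(f"P(positive) = {prob:.3f}")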