-
-
Save haimat/d5f179b23e61c2b80ba424f988b90c9e to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Train a Keras model on multiple GPUs in parallel, using TF Dataset slices.
import os
from collections import Counter

import numpy as np
import tensorflow as tf
from imutils import paths
from sklearn.preprocessing import LabelEncoder
from tensorflow.keras import layers
from tensorflow.keras import models
from tensorflow.keras.applications import VGG16
from tensorflow.keras.applications import vgg16
# Download Kaggle Cats vs. Dogs - https://www.kaggle.com/c/dogs-vs-cats/data?select=train.zip
TRAIN_DS_PATH = "/tmp/catsvsdogs"  # root directory that is scanned for .jpg images
IMAGE_CLASSES = ["cat", "dog"]  # class names; its length sets the size of the softmax output
TRAIN_SPLIT = 0.8  # fraction of the images used for training (the rest is validation)
BATCH_SIZE = 32  # batch size applied to the tf.data pipelines
IMG_SIZE = 224  # images are resized to IMG_SIZE x IMG_SIZE pixels
EPOCHS = 10  # number of training epochs
def build_model():
    """Build the classifier: a frozen VGG16 backbone plus a small dense head.

    The backbone is loaded with ImageNet weights, without its top layers,
    and marked non-trainable. The head pools, flattens and classifies into
    ``len(IMAGE_CLASSES)`` softmax outputs.

    Returns:
        An uncompiled Keras ``Model`` taking ``(IMG_SIZE, IMG_SIZE, 3)`` inputs.
    """
    inputs = layers.Input(shape=(IMG_SIZE, IMG_SIZE, 3), name="model_input")
    backbone = VGG16(weights="imagenet", include_top=False, input_tensor=inputs)
    backbone.trainable = False

    # Classification head, applied in order on top of the backbone output.
    head_layers = (
        layers.AveragePooling2D(name="custom_header", pool_size=(4, 4)),
        layers.Flatten(),
        layers.Dense(256, activation="relu"),
        layers.Dropout(0.5),
        layers.Dense(len(IMAGE_CLASSES), activation="softmax"),
    )
    features = backbone.output
    for head_layer in head_layers:
        features = head_layer(features)

    return models.Model(inputs=inputs, outputs=features)
@tf.function
def load_images(image_path, label):
    """Load one JPEG from disk and prepare it as VGG16 input.

    Args:
        image_path: Scalar string tensor holding the path of a JPEG file.
        label: Label for the image; passed through unchanged.

    Returns:
        A ``(image, label)`` tuple where ``image`` is a float32 tensor of
        shape ``(IMG_SIZE, IMG_SIZE, 3)`` preprocessed for VGG16.
    """
    image = tf.io.read_file(image_path)
    image = tf.image.decode_jpeg(image, channels=3)
    # Resize first: tf.image.resize (bilinear) yields float32 values that are
    # still in the 0-255 range, which is what vgg16.preprocess_input expects.
    # Feeding it the raw uint8 tensor would make the mean subtraction happen
    # in uint8 arithmetic.
    image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
    # preprocess_input converts RGB -> BGR and zero-centers each channel with
    # the ImageNet means, without scaling. No extra division by 255 must
    # follow, otherwise the frozen ImageNet weights see inputs in the wrong
    # value range.
    image = vgg16.preprocess_input(image)
    return (image, label)
# Get all images and their labels/classes
image_paths = list(paths.list_images(TRAIN_DS_PATH, contains=".jpg"))
np.random.shuffle(image_paths)
# The label is the name of the directory that directly contains the image.
# This assumes a <TRAIN_DS_PATH>/<class>/<file>.jpg layout (TODO confirm
# against how the dataset was unpacked). Taking the parent directory instead
# of a fixed path-component index keeps this working when TRAIN_DS_PATH is
# moved to a location with a different depth or path separator.
labels = [os.path.basename(os.path.dirname(image_path)) for image_path in image_paths]
class_count = Counter(labels)
# Split images into training and validation sets: the first TRAIN_SPLIT
# fraction of the (already shuffled) paths is used for training, the
# remainder for validation. Labels are sliced at the same index so they
# stay aligned with their paths.
i = int(len(image_paths) * TRAIN_SPLIT)
train_paths, validation_paths = image_paths[:i], image_paths[i:]
train_labels, validation_labels = labels[:i], labels[i:]
# Label the dataset: map class names to integer ids, then one-hot encode.
le = LabelEncoder()
num_classes = len(IMAGE_CLASSES)
# Fit the encoder on the training labels only ...
train_labels_le = tf.one_hot(le.fit_transform(train_labels), depth=num_classes)
# ... and reuse that exact mapping for validation. Re-fitting on the
# validation labels could assign different ids to the same class name
# (e.g. if a class is missing from the validation split).
validation_labels_le = tf.one_hot(le.transform(validation_labels), depth=num_classes)
# Training pipeline: shuffle (path, label) pairs, load the images in
# parallel, batch them, and prefetch so loading overlaps with training.
training_dataset = tf.data.Dataset.from_tensor_slices((train_paths, train_labels_le))
training_dataset = training_dataset.shuffle(1024)
training_dataset = training_dataset.map(load_images, num_parallel_calls=tf.data.AUTOTUNE)
training_dataset = training_dataset.batch(BATCH_SIZE)
# Optional augmentation step, currently disabled:
# training_dataset = training_dataset.map(lambda x, y: (data_augmentation(x), y), num_parallel_calls=tf.data.AUTOTUNE)
training_dataset = training_dataset.prefetch(tf.data.AUTOTUNE)

# Validation pipeline: same stages as training, but without shuffling.
validation_dataset = tf.data.Dataset.from_tensor_slices((validation_paths, validation_labels_le))
validation_dataset = validation_dataset.map(load_images, num_parallel_calls=tf.data.AUTOTUNE)
validation_dataset = validation_dataset.batch(BATCH_SIZE)
validation_dataset = validation_dataset.prefetch(tf.data.AUTOTUNE)
# Set the sharding policy to DATA
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA
# Dataset.with_options() does not modify the dataset in place; it returns a
# new dataset carrying the options. The result must be rebound, otherwise
# the sharding policy is silently dropped.
training_dataset = training_dataset.with_options(options)
validation_dataset = validation_dataset.with_options(options)
# Mirror the model across the available devices. The model has to be created
# and compiled inside the strategy scope.
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    keras_model = build_model()
    keras_model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])

# The datasets are already batched with BATCH_SIZE, so no batch_size is passed
# to fit(): Keras documents that batch_size must not be specified for
# tf.data.Dataset inputs, since the dataset generates the batches itself.
keras_model.fit(training_dataset, validation_data=validation_dataset, epochs=EPOCHS)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment