Skip to content

Instantly share code, notes, and snippets.

@charleslparker
Created August 17, 2022 16:44
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save charleslparker/57778a1ffde2cf660fa741162c1eb399 to your computer and use it in GitHub Desktop.
Save charleslparker/57778a1ffde2cf660fa741162c1eb399 to your computer and use it in GitHub Desktop.
Practical TensorFlow Test for GPU
import os
import sys
# Zip archive of the "cats_and_dogs_filtered" image dataset (train/ and
# validation/ sub-directories). run_gpu_test downloads and extracts it through
# tf.keras.utils.get_file, which caches it locally between runs.
URL = "https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip"
def run_gpu_test(use_cuda, *, epochs=32):
    """Train a MobileNetV2 transfer-learning classifier as a device benchmark.

    Downloads (or reuses the Keras-cached copy of) the dataset at ``URL``,
    builds a binary cats-vs-dogs classifier on top of a frozen ImageNet
    MobileNetV2 backbone and trains it. Running the same workload with
    ``use_cuda=True`` and ``use_cuda=False`` gives a practical GPU/CPU
    comparison.

    Args:
        use_cuda: When False, every CUDA device is hidden from TensorFlow
            (``CUDA_VISIBLE_DEVICES=-1``) so training runs on the CPU.
        epochs: Number of training epochs (defaults to the original 32).

    Returns:
        The ``History`` object returned by ``model.fit``.
    """
    # Both variables are read by TensorFlow's native runtime, so they are set
    # before the (deliberately deferred) import below. NOTE(review): they have
    # no effect if TensorFlow was already imported/initialised in this process.
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
    if not use_cuda:
        os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
    import tensorflow as tf

    if use_cuda and not tf.config.list_physical_devices("GPU"):
        # Without a visible GPU the run would quietly use the CPU, and TF's own
        # logging is muted above -- make the fallback explicit.
        print(
            "Warning: no GPU is visible to TensorFlow; this run uses the CPU.",
            file=sys.stderr,
        )

    img_size = (160, 160)
    train_dataset, validation_dataset = _load_datasets(
        tf, batch_size=32, img_size=img_size
    )
    model = _build_model(tf, img_size=img_size)
    return model.fit(
        train_dataset, verbose=2, epochs=epochs, validation_data=validation_dataset
    )


def _load_datasets(tf, batch_size, img_size):
    """Download/extract the dataset and return prefetched (train, validation)."""
    path_to_zip = tf.keras.utils.get_file(
        "cats_and_dogs.zip", origin=URL, extract=True
    )
    candidates = (
        # Archive path returned: data sits next to the zip.
        os.path.join(os.path.dirname(path_to_zip), "cats_and_dogs_filtered"),
        # NOTE(review): newer Keras releases appear to return the extraction
        # directory itself instead of the archive path -- confirm per version.
        os.path.join(path_to_zip, "cats_and_dogs_filtered"),
    )
    # Fall back to the original location so a missing dataset still fails
    # with the same error as before.
    data_dir = next((c for c in candidates if os.path.isdir(c)), candidates[0])

    def load(subdir):
        return tf.keras.utils.image_dataset_from_directory(
            os.path.join(data_dir, subdir),
            shuffle=True,
            batch_size=batch_size,
            image_size=img_size,
        )

    train_dataset = load("train")
    validation_dataset = load("validation")
    # The first fifth of the validation batches was carved off as a test split
    # that was never used; skip it anyway so validation sees the same amount
    # of data as before.
    validation_dataset = validation_dataset.skip(validation_dataset.cardinality() // 5)

    autotune = tf.data.AUTOTUNE
    return (
        train_dataset.prefetch(buffer_size=autotune),
        validation_dataset.prefetch(buffer_size=autotune),
    )


def _build_model(tf, img_size):
    """Return a compiled binary classifier on a frozen MobileNetV2 backbone."""
    img_shape = tuple(img_size) + (3,)
    base_model = tf.keras.applications.MobileNetV2(
        input_shape=img_shape, include_top=False, weights="imagenet"
    )
    # Freeze the pretrained ImageNet weights; only the new head is trained.
    base_model.trainable = False

    data_augmentation = tf.keras.Sequential(
        [
            tf.keras.layers.RandomFlip("horizontal"),
            tf.keras.layers.RandomRotation(0.2),
        ]
    )

    inputs = tf.keras.Input(shape=img_shape)
    x = data_augmentation(inputs)
    x = tf.keras.applications.mobilenet_v2.preprocess_input(x)
    # Run the frozen backbone in inference mode.
    x = base_model(x, training=False)
    x = tf.keras.layers.GlobalAveragePooling2D()(x)
    x = tf.keras.layers.Dropout(0.2)(x)
    outputs = tf.keras.layers.Dense(1)(x)  # single logit; see from_logits below
    model = tf.keras.Model(inputs, outputs)

    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
        loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
        metrics=["accuracy"],
    )
    return model
def main(argv=None):
    """Run the benchmark on the device named on the command line.

    Expects exactly one argument, ``gpu`` or ``cpu`` (case-insensitive).
    On bad usage a usage line is printed to stderr and the process exits
    with status 1.

    Args:
        argv: Full argument vector including the program name; defaults to
            ``sys.argv``.
    """
    if argv is None:
        argv = sys.argv
    if len(argv) != 2 or argv[1].lower() not in ("gpu", "cpu"):
        print("Usage: run_gpu_test gpu|cpu", file=sys.stderr)
        # sys.exit, not the site-injected builtin exit(), which is absent
        # under `python -S` and in some embedded/frozen interpreters.
        sys.exit(1)
    # Argument is validated above, so it is exactly "gpu" or "cpu".
    run_gpu_test(argv[1].lower() == "gpu")


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment