Creating a tf.data.Dataset object and training a model
Hi Ore,
I think I need to do a Medium write-up on this, because most of the tutorials require manipulating TensorFlow's data classes directly. But here's a simple, intuitive example I used.
First, you load your CSV into pandas.
Next, you add another column to your dataframe containing the paths to your images.
Then you do your train_test_split(), but here you pass in the (paths, labels) along with the usual parameters such as random_state, test_size, and so on.
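Here's a rough sketch of those first three steps; the file name, the images/ folder, and the image_id/label column names are just placeholders for whatever your CSV actually has:

import pandas as pd
from sklearn.model_selection import train_test_split

df = pd.read_csv('train.csv')                     # placeholder CSV name
df['path'] = 'images/' + df['image_id'] + '.jpg'  # new column of image paths

x_train, x_test, y_train, y_test = train_test_split(
    df['path'].values, df['label'].values,
    test_size=0.2, random_state=42
)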
After doing that, you should add and run the following functions in the next cell:
import tensorflow as tf

image_size = (224, 224)  # target size; use whatever your model expects

def decode_image(path, label=None):
    bits = tf.io.read_file(path)
    image = tf.image.decode_jpeg(bits, channels=3)  # read the JPEG into a uint8 tensor
    image = tf.cast(image, tf.float32) / 255.0      # standardize pixel values to [0, 1]
    image = tf.image.resize(image, image_size)
    if label is None:
        return image
    else:
        return image, label
The decode function is what lets TensorFlow read each image file off disk and turn it into a tensor it can work with.
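If you want to sanity-check it before wiring up the full pipeline, you can decode a single image directly (using the placeholder df from the sketch above):

sample = decode_image(df['path'].iloc[0])
print(sample.shape, sample.dtype)  # (224, 224, 3) float32, with the image_size above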
Then, optionally, you can add the augmentation function below:
def data_augment(image, label=None):
    # random flips plus color jitter; the exact ranges here are just
    # the ones I used, tune them for your own data
    image = tf.image.random_flip_up_down(image)
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_brightness(image, max_delta=0.5247078)
    image = tf.image.random_saturation(image, 0.3824261, 1.4029386)
    image = tf.image.random_hue(image, 0.1267652)
    image = tf.image.random_contrast(image, 0.3493415, 1.3461331)
    image = tf.clip_by_value(image, 0.0, 1.0)  # keep pixels in [0, 1] after the jitter
    if label is None:
        return image
    else:
        return image, label
After doing that, you then create the tf.data.Dataset objects like this:
AUTO = tf.data.AUTOTUNE  # let tf.data tune the parallelism itself
batch_size = 32          # pick whatever fits your GPU/memory

train_dataset = (
    tf.data.Dataset
    .from_tensor_slices((x_train, y_train))
    .shuffle(10000)
    .map(decode_image, num_parallel_calls=AUTO)
    .map(data_augment, num_parallel_calls=AUTO)
    .repeat()
    .batch(batch_size, drop_remainder=True)
    .prefetch(AUTO)
)

val_dataset = (
    tf.data.Dataset
    .from_tensor_slices((x_test, y_test))
    .map(decode_image, num_parallel_calls=AUTO)
    .batch(batch_size)
    .cache()
)
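Notice that only the training pipeline shuffles, repeats, and augments; the validation pipeline just decodes and batches, and cache() keeps the decoded batches in memory after the first epoch so later epochs run faster.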
# Then you build your model here
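# A tiny stand-in model, just so the example is end-to-end runnable --
# swap in whatever architecture you're actually using. The one real
# constraint is the single sigmoid unit at the end, to match the
# binary cross-entropy loss below.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(*image_size, 3)),
    tf.keras.layers.Conv2D(32, 3, activation='relu'),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Conv2D(64, 3, activation='relu'),
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(1, activation='sigmoid'),
])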
# Compile it here
model.compile(loss=tf.keras.losses.BinaryCrossentropy(),
              metrics=[tf.keras.metrics.BinaryAccuracy()],
              optimizer=tf.keras.optimizers.Adam(learning_rate=3e-5)
              )
Then you fit the model. That's where you pass in the train_dataset and the val_dataset objects you created earlier, like this:
STEPS_PER_EPOCH = len(x_train) // batch_size
valid_step = len(x_test) // batch_size

history = model.fit(
    train_dataset,
    epochs=50,
    steps_per_epoch=STEPS_PER_EPOCH,
    validation_data=val_dataset,
    validation_steps=valid_step
)
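One thing to note: because train_dataset has .repeat() on it, it's effectively infinite, so steps_per_epoch is what tells Keras where one epoch ends; validation_steps does the same job for the validation set.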
Let me know if you have any issues making it work.