-
-
Save mmeendez8/44712f11376486c2d8feb8c4c63b5493 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def parse_function(filename, image_shape=(48, 48, 3)):
    """Load one PNG file and turn it into a float32 image tensor.

    :param filename: scalar string tensor holding the path to a PNG file
        (one element of the filename dataset built in load_and_process_data)
    :param image_shape: static shape attached to the returned tensor;
        defaults to (48, 48, 3), the value previously hard-coded here
    :return: float32 image tensor with static shape ``image_shape``
    """
    image_string = tf.read_file(filename)
    # Decode as 4 channels (RGBA).
    # NOTE(review): the static shape set below has 3 channels by default, so
    # remove_noise presumably drops/uses the alpha channel and returns RGB —
    # confirm against remove_noise's definition.
    image = tf.image.decode_png(image_string, channels=4)
    image = remove_noise(image)
    # If remove_noise returns an integer (e.g. uint8) image, convert_image_dtype
    # also rescales the values into [0, 1] while casting to float32.
    image = tf.image.convert_image_dtype(image, dtype=tf.float32)
    # Give the tensor a static shape so downstream batching/model code can
    # rely on known dimensions.
    image.set_shape(image_shape)
    return image
def load_and_process_data(filenames, batch_size, shuffle=True, shuffle_buffer_size=5000):
    '''
    Receive a list of filenames and return the preprocessed images as a batched tf.data dataset.

    :param filenames: list of image file paths
    :param batch_size: mini-batch size
    :param shuffle: Boolean; if True, the order of the files is shuffled
    :param shuffle_buffer_size: number of elements kept in the shuffle buffer
        to randomly sample from (default 5000, the previously hard-coded value)
    :return: tf.data.Dataset yielding batches of images produced by
        parse_function, with up to 2 batches prefetched
    '''
    # Place the input-pipeline ops on the CPU.
    with tf.device('/cpu:0'):
        dataset = tf.data.Dataset.from_tensor_slices(filenames)
        if shuffle:
            # Shuffle the filename strings *before* the map: map is a 1:1,
            # order-preserving transform, so the resulting order is the same
            # kind of buffered shuffle, but the buffer now holds cheap path
            # strings instead of up to `shuffle_buffer_size` decoded images.
            dataset = dataset.shuffle(shuffle_buffer_size)
        dataset = dataset.map(parse_function, num_parallel_calls=4)
        dataset = dataset.batch(batch_size)
        # Prepare the next batches while the current one is being consumed.
        dataset = dataset.prefetch(2)
    return dataset
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment