Created
January 28, 2021 19:13
-
-
Save ls-da3m0ns/2f39aed465ca0e57d2c5995d8903fa36 to your computer and use it in GitHub Desktop.
Data loading script for TFRecords
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def to_float32_2(image, label):
    """Cast the image to float32 and binarise the label against its maximum.

    Along the last axis of ``label``, every entry equal to the maximum becomes
    1 and every other entry becomes 0 (tied maxima all become 1).

    Returns:
        A pair ``(image, label)`` where the image is cast to ``tf.float32`` and
        the binarised label is cast to ``tf.int32``.
    """
    # keepdims=True keeps the reduced axis so the comparison broadcasts.
    peak = tf.reduce_max(label, axis=-1, keepdims=True)
    is_peak = tf.equal(label, peak)
    hard_label = tf.where(is_peak, tf.ones_like(label), tf.zeros_like(label))
    return tf.cast(image, tf.float32), tf.cast(hard_label, tf.int32)
def to_float32(image, label):
    """Return the image cast to ``tf.float32`` together with the unchanged label."""
    float_image = tf.cast(image, tf.float32)
    return float_image, label
def decode_image(image_data):
    """Decode JPEG bytes into a float32 RGB image tensor of shape [1024, 1024, 3].

    Args:
        image_data: the encoded JPEG contents passed to ``tf.image.decode_jpeg``.

    Returns:
        The decoded 3-channel image, divided by 255 so that decoded pixel
        values land in the [0, 1] range, reshaped to ``[1024, 1024, 3]``.
    """
    pixels = tf.image.decode_jpeg(image_data, channels=3)
    scaled = tf.cast(pixels, tf.float32) / 255.0
    # An explicit size is required here (the original author notes TPU needs it).
    return tf.reshape(scaled, [1024, 1024, 3])
def read_labeled_tfrecord(example):
    """Parse one serialized labeled record into an ``(image, label)`` pair.

    Args:
        example: a single serialized example, as passed to
            ``tf.io.parse_single_example``.

    Returns:
        image: the decoded image resized to ``IMAGE_SIZE[0]`` x ``IMAGE_SIZE[0]``.
        label: a float32 tensor holding the 11 target flags in this order:
            ETT (Abnormal, Borderline, Normal),
            NGT (Abnormal, Borderline, Incompletely Imaged, Normal),
            CVC (Abnormal, Borderline, Normal),
            Swan Ganz Catheter Present.
    """
    # Single source of truth for both the feature spec and the label order.
    label_keys = (
        "ETT - Abnormal",
        "ETT - Borderline",
        "ETT - Normal",
        "NGT - Abnormal",
        "NGT - Borderline",
        "NGT - Incompletely Imaged",
        "NGT - Normal",
        "CVC - Abnormal",
        "CVC - Borderline",
        "CVC - Normal",
        "Swan Ganz Catheter Present",
    )

    # Dictionary describing the features stored in each record.
    feature_spec = {
        "StudyInstanceUID": tf.io.FixedLenFeature([], tf.string),
        "image": tf.io.FixedLenFeature([], tf.string),
    }
    for key in label_keys:
        feature_spec[key] = tf.io.FixedLenFeature([], tf.int64)

    parsed = tf.io.parse_single_example(example, feature_spec)

    image = decode_image(parsed["image"])
    # NOTE(review): both target dimensions use IMAGE_SIZE[0], so the output is
    # always square; confirm IMAGE_SIZE[1] was not intended for the width.
    image = tf.image.resize(image, [IMAGE_SIZE[0], IMAGE_SIZE[0]])

    # Stack explicitly into one tensor instead of returning a Python list of
    # scalar tensors and relying on tf.data to convert it implicitly.
    label = tf.stack([tf.cast(parsed[key], tf.float32) for key in label_keys])
    return image, label
def read_unlabeled_tfrecord(example):
    """Parse one serialized unlabeled record into an ``(image, image_name)`` pair.

    Args:
        example: a single serialized example, as passed to
            ``tf.io.parse_single_example``.

    Returns:
        The decoded image resized to ``IMAGE_SIZE[0]`` x ``IMAGE_SIZE[0]`` and
        the record's ``StudyInstanceUID`` string.
    """
    feature_spec = {
        "StudyInstanceUID" : tf.io.FixedLenFeature([], tf.string),
        "image" : tf.io.FixedLenFeature([], tf.string),
    }
    parsed = tf.io.parse_single_example(example, feature_spec)
    target_size = [IMAGE_SIZE[0], IMAGE_SIZE[0]]
    resized = tf.image.resize(decode_image(parsed['image']), target_size)
    return resized, parsed['StudyInstanceUID']
def read_labeled_tf_record(filenames, labeled=True, ordered=False):
    """Build a parsed TFRecord dataset from ``filenames``.

    This function used to repeat the body of ``load_dataset`` statement for
    statement; it now delegates to it so the two entry points cannot drift.

    Args:
        filenames: TFRecord file name(s) handed to ``tf.data.TFRecordDataset``.
        labeled: when true, records are parsed with ``read_labeled_tfrecord``;
            otherwise with ``read_unlabeled_tfrecord``.
        ordered: when false, deterministic ordering is disabled on the
            dataset options (trades ordering for speed).

    Returns:
        The mapped ``tf.data`` dataset produced by ``load_dataset``.
    """
    return load_dataset(filenames, labeled=labeled, ordered=ordered)
def data_augment(image, label):
    """Apply the training-time image augmentations; the label passes through.

    Steps, in order: random horizontal flip, random vertical flip, random
    brightness shift (``max_delta=0.5``), random saturation in ``[0, 2]``,
    then a fixed saturation adjustment by a factor of 3.

    Returns:
        ``(augmented_image, label)``.
    """
    # NOTE(review): the flips and random_saturation all receive the same SEED
    # value while random_brightness receives none; confirm this is intended.
    augmented = tf.image.random_flip_left_right(image, seed=SEED)
    augmented = tf.image.random_flip_up_down(augmented, seed=SEED)
    augmented = tf.image.random_brightness(augmented, max_delta=0.5)
    augmented = tf.image.random_saturation(augmented, 0, 2, seed=SEED)
    # NOTE(review): a fixed saturation factor is applied after the random one.
    augmented = tf.image.adjust_saturation(augmented, 3)
    return augmented, label
def get_training_dataset(dataset):
    """Turn a parsed dataset into the shuffled, repeating, augmented training feed.

    Args:
        dataset: a dataset of ``(image, label)`` elements accepted by
            ``data_augment``.

    Returns:
        The dataset after shuffle (buffer 2048), repeat, augmentation,
        batching by ``BATCH_SIZE`` and prefetching.
    """
    return (
        dataset.shuffle(2048)
        .repeat()  # the training dataset must repeat for several epochs
        .map(data_augment, num_parallel_calls=AUTO)
        .batch(BATCH_SIZE)
        .prefetch(AUTO)  # prepare the next batch while the current one trains
    )
def get_validation_dataset(ordered=False):
    """Build the batched, cached validation dataset from ``VALIDATION_FILENAMES``.

    Args:
        ordered: forwarded to ``load_dataset``.

    Returns:
        The labeled validation dataset, batched by ``BATCH_SIZE``, cached and
        prefetched.
    """
    parsed = load_dataset(VALIDATION_FILENAMES, labeled=True, ordered=ordered)
    # Caching happens after batching, so whole batches are what gets cached.
    return parsed.batch(BATCH_SIZE).cache().prefetch(AUTO)
def get_test_dataset(ordered=False):
    """Build the batched test dataset from ``TEST_FILENAMES``.

    Args:
        ordered: forwarded to ``load_dataset``.

    Returns:
        The unlabeled test dataset, batched by ``BATCH_SIZE`` and prefetched.
    """
    parsed = load_dataset(TEST_FILENAMES, labeled=False, ordered=ordered)
    return parsed.batch(BATCH_SIZE).prefetch(AUTO)
def load_dataset(filenames, labeled=True, ordered=False):
    """Read TFRecord files and parse every record.

    Args:
        filenames: TFRecord file name(s) handed to ``tf.data.TFRecordDataset``.
        labeled: when true, records are parsed with ``read_labeled_tfrecord``;
            otherwise with ``read_unlabeled_tfrecord``.
        ordered: when false, deterministic ordering is disabled on the
            dataset options.

    Returns:
        The dataset of parsed records.
    """
    options = tf.data.Options()
    if not ordered:
        options.experimental_deterministic = False

    parse_fn = read_labeled_tfrecord if labeled else read_unlabeled_tfrecord

    records = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO)
    records = records.with_options(options)
    return records.map(parse_fn, num_parallel_calls=AUTO)
# Glob pattern(s) for the training shards under the dataset root, expanded
# into the concrete list of matching files.
training_filenames = [GCS_DS_PATH + '/train_tfrecords/*.tfrec']
TRAINING_FILENAMES = tf.io.gfile.glob(training_filenames)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment