Skip to content

Instantly share code, notes, and snippets.

@ls-da3m0ns
Created January 28, 2021 19:13
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ls-da3m0ns/2f39aed465ca0e57d2c5995d8903fa36 to your computer and use it in GitHub Desktop.
Save ls-da3m0ns/2f39aed465ca0e57d2c5995d8903fa36 to your computer and use it in GitHub Desktop.
Data loading script for TFRecords.
def to_float32_2(image, label):
    """Cast image to float32 and harden the label into a 0/1 int32 vector.

    Every position holding the row-wise maximum of ``label`` becomes 1 and
    all other positions become 0 (ties with the maximum all become 1).
    """
    peak = tf.reduce_max(label, axis=-1, keepdims=True)
    is_peak = tf.equal(label, peak)
    hardened = tf.where(is_peak, tf.ones_like(label), tf.zeros_like(label))
    return tf.cast(image, tf.float32), tf.cast(hardened, tf.int32)
def to_float32(image, label):
    """Cast the image tensor to float32; pass the label through unchanged."""
    image_f32 = tf.cast(image, tf.float32)
    return image_f32, label
def decode_image(image_data, size=(1024, 1024)):
    """Decode a JPEG byte string into a float32 image tensor in [0, 1].

    Args:
        image_data: scalar string tensor holding the raw JPEG bytes.
        size: (height, width) of the decoded image. The explicit static
            reshape is needed for TPU execution; defaults to the original
            hard-coded 1024x1024 so existing callers are unchanged.

    Returns:
        A float32 tensor of shape (size[0], size[1], 3) with values in [0, 1].
    """
    image = tf.image.decode_jpeg(image_data, channels=3)
    image = tf.cast(image, tf.float32) / 255.0  # uint8 [0, 255] -> float [0, 1]
    image = tf.reshape(image, [size[0], size[1], 3])  # explicit size needed for TPU
    return image
def read_labeled_tfrecord(example):
    """Parse one labeled TFRecord example into an (image, label-list) pair.

    The image is decoded and resized to an IMAGE_SIZE[0] square; the eleven
    catheter-placement flags come back as float32 scalars in a fixed order:
    ETT (3), NGT (4), CVC (3), Swan Ganz (1).
    """
    # Label feature names, in the exact order the model expects.
    label_keys = [
        "ETT - Abnormal",
        "ETT - Borderline",
        "ETT - Normal",
        "NGT - Abnormal",
        "NGT - Borderline",
        "NGT - Incompletely Imaged",
        "NGT - Normal",
        "CVC - Abnormal",
        "CVC - Borderline",
        "CVC - Normal",
        "Swan Ganz Catheter Present",
    ]
    # Feature spec describing every serialized field.
    feature_spec = {
        "StudyInstanceUID": tf.io.FixedLenFeature([], tf.string),
        "image": tf.io.FixedLenFeature([], tf.string),
    }
    feature_spec.update(
        {key: tf.io.FixedLenFeature([], tf.int64) for key in label_keys}
    )
    parsed = tf.io.parse_single_example(example, feature_spec)
    image = decode_image(parsed['image'])
    image = tf.image.resize(image, [IMAGE_SIZE[0], IMAGE_SIZE[0]])
    label = [tf.cast(parsed[key], tf.float32) for key in label_keys]
    return image, label  # dataset yields (image, label) pairs
def read_unlabeled_tfrecord(example):
    """Parse one unlabeled TFRecord example into (image, StudyInstanceUID)."""
    spec = {
        "StudyInstanceUID": tf.io.FixedLenFeature([], tf.string),
        "image": tf.io.FixedLenFeature([], tf.string),
    }
    parsed = tf.io.parse_single_example(example, spec)
    image = tf.image.resize(
        decode_image(parsed['image']), [IMAGE_SIZE[0], IMAGE_SIZE[0]]
    )
    return image, parsed['StudyInstanceUID']  # (image, id) pairs
def read_labeled_tf_record(filenames, labeled=True, ordered=False):
    """Build a tf.data pipeline over TFRecord files.

    This was a line-for-line duplicate of load_dataset(); it now delegates
    to it so the pipeline construction lives in one place, while keeping
    the old name and signature for backward compatibility.

    Args:
        filenames: TFRecord file paths (or glob results) to read.
        labeled: parse label fields when True, else only image + UID.
        ordered: keep deterministic file order when True; otherwise allow
            nondeterministic interleaving for speed.

    Returns:
        A tf.data.Dataset of (image, label) or (image, uid) pairs.
    """
    return load_dataset(filenames, labeled=labeled, ordered=ordered)
def data_augment(image, label):
    """Apply random photometric/geometric augmentation to a training image.

    Random horizontal + vertical flips, random brightness and saturation,
    then a fixed saturation boost. Pixel values are clipped back to [0, 1]
    afterwards, because brightness (max_delta=0.5) and saturation changes
    can push floats outside the range decode_image() produces.
    """
    image = tf.image.random_flip_left_right(image, seed=SEED)
    image = tf.image.random_flip_up_down(image, seed=SEED)
    image = tf.image.random_brightness(image, max_delta=0.5)  # random brightness
    image = tf.image.random_saturation(image, 0, 2, seed=SEED)
    # NOTE(review): a fixed adjust_saturation(3) on top of the random
    # saturation above is unusually strong — confirm it is intentional.
    image = tf.image.adjust_saturation(image, 3)
    image = tf.clip_by_value(image, 0.0, 1.0)  # keep pixels in [0, 1]
    return image, label
def get_training_dataset(dataset):
    """Shuffle, repeat, augment, batch and prefetch a training dataset."""
    return (
        dataset
        .shuffle(2048)
        .repeat()  # the training input must repeat across epochs
        .map(data_augment, num_parallel_calls=AUTO)
        .batch(BATCH_SIZE)
        .prefetch(AUTO)  # overlap preprocessing with training
    )
def get_validation_dataset(ordered=False):
    """Build the validation pipeline: parse, batch, cache, prefetch."""
    ds = load_dataset(VALIDATION_FILENAMES, labeled=True, ordered=ordered)
    # Cache the batched records so later epochs skip parsing entirely.
    ds = ds.batch(BATCH_SIZE).cache().prefetch(AUTO)
    return ds
def get_test_dataset(ordered=False):
    """Build the unlabeled test pipeline: parse, batch, prefetch."""
    ds = load_dataset(TEST_FILENAMES, labeled=False, ordered=ordered)
    return ds.batch(BATCH_SIZE).prefetch(AUTO)
def load_dataset(filenames, labeled=True, ordered=False):
    """Read TFRecord files into a parsed tf.data.Dataset.

    When ``ordered`` is False, deterministic ordering is disabled so records
    are consumed as soon as any file produces them, which speeds up parallel
    reads at the cost of a stable element order.
    """
    options = tf.data.Options()
    if not ordered:
        options.experimental_deterministic = False
    parser = read_labeled_tfrecord if labeled else read_unlabeled_tfrecord
    dataset = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO)
    dataset = dataset.with_options(options)
    return dataset.map(parser, num_parallel_calls=AUTO)
# Glob every training TFRecord shard under the GCS dataset path.
training_filenames = [GCS_DS_PATH + '/train_tfrecords/*.tfrec']
TRAINING_FILENAMES = tf.io.gfile.glob(training_filenames)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment