Skip to content

Instantly share code, notes, and snippets.

@bsridatta
Created May 28, 2019 16:25
Show Gist options
  • Save bsridatta/ce0fdcf4ce6aca425416c64e9ca03e37 to your computer and use it in GitHub Desktop.
Save bsridatta/ce0fdcf4ce6aca425416c64e9ca03e37 to your computer and use it in GitHub Desktop.
Creating TFRecords Dataset
import tensorflow as tf
SHUFFLE_BUFFER = 1000  # number of elements held in the dataset shuffle buffer
BATCH_SIZE = 32  # examples grouped into each batch
NUM_CLASSES = 12  # depth of the one-hot label encoding

# Parsing spec applied to every serialized tf.Example:
#   feature0 -- fixed-length float32 vector of 256 * 128 = 32768 values
#   feature1 -- a single int64 value (presumably the class index -- TODO confirm)
feature_description = dict(
    feature0=tf.FixedLenFeature([256 * 128], tf.float32),
    feature1=tf.FixedLenFeature([1], tf.int64),
)
def _parse_function(example_proto):
    """Decode one serialized tf.Example into a dict of feature tensors.

    Parsing follows the module-level ``feature_description``. The flat
    ``feature0`` vector is viewed as a 256 x 128 matrix and then transposed,
    so the returned ``feature0`` has shape (128, 256). ``feature1`` is
    returned exactly as parsed.

    Args:
        example_proto: a string tensor holding one serialized tf.Example
            (as produced by ``tf.data.TFRecordDataset``).

    Returns:
        dict mapping feature names to tensors.
    """
    features = tf.parse_single_example(example_proto, feature_description)
    flat = features['feature0']
    # View the flat vector as 256 x 128, then swap the two axes -> 128 x 256.
    features["feature0"] = tf.transpose(tf.reshape(flat, (256, 128)))
    return features
def create_dataset(filepath):
    """Build an endlessly repeating, shuffled, batched pipeline over TFRecords.

    Reads serialized tf.Example records from ``filepath`` and decodes each one
    with ``_parse_function``. The stream repeats forever, is shuffled through
    a buffer of ``SHUFFLE_BUFFER`` elements, and is grouped into batches of
    ``BATCH_SIZE`` examples.

    Args:
        filepath: TFRecord file path(s), in any form accepted by
            ``tf.data.TFRecordDataset`` as ``filenames``.

    Returns:
        (lmfcc, label): TF1 graph tensors from a one-shot iterator. Each
        evaluation yields the next batch.
        - lmfcc: float32, shape (batch, 128, 256).
        - label: float32 one-hot, shape (batch, 1, NUM_CLASSES).
    """
    dataset = tf.data.TFRecordDataset(filepath)
    dataset = dataset.map(_parse_function)
    # Repeat forever; the consumer decides how many batches to pull.
    dataset = dataset.repeat()
    # Shuffling after repeat() lets the buffer mix examples across epoch
    # boundaries.
    dataset = dataset.shuffle(SHUFFLE_BUFFER)
    dataset = dataset.batch(BATCH_SIZE)

    # TF1-style iterator: every run of the get_next() tensors advances it.
    iterator = dataset.make_one_shot_iterator()
    feature = iterator.get_next()

    lmfcc = feature["feature0"]
    label = feature["feature1"]

    # batch() prepends a batch axis of size BATCH_SIZE. A hard-coded leading 1
    # would only be valid when BATCH_SIZE == 1, so let TF infer it with -1.
    lmfcc = tf.reshape(lmfcc, [-1, 128, 256])

    # feature1 is a length-1 int64 vector per example, so one_hot produces
    # (batch, 1, NUM_CLASSES); the reshape keeps that layout explicit and
    # uses NUM_CLASSES rather than a duplicated literal.
    label = tf.one_hot(label, NUM_CLASSES)
    label = tf.reshape(label, [-1, 1, NUM_CLASSES])
    return lmfcc, label
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment