@dschwertfeger
Created December 29, 2019 10:24
import pandas as pd
import tensorflow as tf
AUTOTUNE = tf.data.experimental.AUTOTUNE
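
# Assumed structure of meta.csv (the file itself is not part of this gist):
# a header row with `file_path` and `label` columns, for example
#
#   file_path,label
#   audio/clip_0001.wav,0
#   audio/clip_0002.wav,1
#
# The paths are assumed to point to WAV files and the labels to be numeric
# class indices; adjust to match your own data.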

def get_dataset(df):
    # Build a dataset of (file_path, label) pairs from the DataFrame columns
    file_path_ds = tf.data.Dataset.from_tensor_slices(df.file_path)
    label_ds = tf.data.Dataset.from_tensor_slices(df.label)
    return tf.data.Dataset.zip((file_path_ds, label_ds))

def load_audio(file_path, label):
    # Load one second of audio at 44.1kHz sample-rate
    audio = tf.io.read_file(file_path)
    audio, sample_rate = tf.audio.decode_wav(audio,
                                             desired_channels=1,
                                             desired_samples=44100)
    return audio, label

def prepare_for_training(ds, shuffle_buffer_size=1024, batch_size=64):
    # Randomly shuffle the (file_path, label) dataset
    ds = ds.shuffle(buffer_size=shuffle_buffer_size)
    # Load and decode audio from the file paths
    ds = ds.map(load_audio, num_parallel_calls=AUTOTUNE)
    # Repeat the dataset indefinitely
    ds = ds.repeat()
    # Group examples into batches
    ds = ds.batch(batch_size)
    # Prefetch batches in the background while the model trains
    ds = ds.prefetch(buffer_size=AUTOTUNE)
    return ds

def main():
    # Load meta.csv containing file paths and labels as a pd.DataFrame
    df = pd.read_csv('meta.csv')
    ds = get_dataset(df)
    batch_size = 64
    train_ds = prepare_for_training(ds, batch_size=batch_size)
    # The dataset repeats forever, so tell Keras how many steps make one epoch
    train_steps = len(df) // batch_size
    model = tf.keras.models.load_model('model.h5')
    model.fit(train_ds, epochs=10, steps_per_epoch=train_steps)

if __name__ == '__main__':
    main()