Skip to content

Instantly share code, notes, and snippets.

@faroit
Last active May 9, 2018 23:55
Show Gist options
  • Save faroit/a692f1a6d8ecae8f1312131cf12c1247 to your computer and use it in GitHub Desktop.
Save faroit/a692f1a6d8ecae8f1312131cf12c1247 to your computer and use it in GitHub Desktop.
import numpy as np
import pescador
import torch.utils.data
# Fix the global NumPy seed so the synthetic tracks are reproducible.
np.random.seed(42)

# Experiment configuration.
nb_tracks = 100
track_length = 30
excerpt_length = 10
excerpt_hop = 2
batch_size = 5

# Synthetic audio tracks, each shaped (nb_samples, nb_features).
track_shape = (track_length, 1)
tracks = [np.random.random(track_shape) for _ in range(nb_tracks)]
class TrackData(torch.utils.data.Dataset):
    """Map-style ``Dataset`` facade over a pescador stream of track excerpts.

    Each track is wrapped in a ``pescador.Streamer`` that runs ``excerpt_gen``
    over it. The streamers are combined by a ``pescador.StochasticMux`` in
    ``'exhaustive'`` mode, and ``pescador.tuples`` turns every sample dict
    into an ``(X, X)`` tuple.

    NOTE: samples are pulled from the mux in whatever order it produces them,
    so the index handed to ``__getitem__`` is ignored. The underlying stream
    is created once in ``__init__`` and is not restarted.
    """

    def __init__(self, tracks):
        """Create one streamer per track and the mux that draws from them.

        Args:
            tracks: sequence of arrays; ``excerpt_gen`` slices each one along
                its first axis.
        """
        super().__init__()
        self.tracks = tracks
        self.streams = [
            pescador.Streamer(excerpt_gen, track) for track in tracks
        ]
        # Keep all streamers active at once: the number of active streams is
        # derived from the tracks actually passed in, not from a global.
        self.mux = pescador.tuples(
            pescador.StochasticMux(
                self.streams, len(self.streams), rate=None, mode='exhaustive'
            ),
            'X', 'X'
        )

    def __len__(self):
        """Return the total number of excerpts ``excerpt_gen`` yields."""
        # Mirror the start offsets used by excerpt_gen so the reported length
        # matches the number of samples the streamers can produce; reporting
        # more would make a DataLoader read past the end of the stream.
        return sum(
            len(range(0, track.shape[0] - excerpt_length, excerpt_hop))
            for track in self.tracks
        )

    def __iter__(self):
        """Return an iterator over the remaining ``(X, X)`` samples."""
        # self.mux is consumed with next() in __getitem__, so treat it
        # through the plain iterator protocol here as well.
        return iter(self.mux)

    def __getitem__(self, idx):
        """Return the next ``(X, X)`` sample from the stream.

        Args:
            idx: ignored; present only to satisfy the ``Dataset`` protocol.

        Raises:
            IndexError: if the underlying stream has no samples left.
        """
        try:
            return next(self.mux)
        except StopIteration:
            # A StopIteration escaping __getitem__ would silently terminate
            # whatever loop is driving the dataset; surface it explicitly.
            raise IndexError('excerpt stream is exhausted') from None
def excerpt_gen(data):
    """Yield fixed-length excerpts of ``data`` as ``{'X': excerpt}`` dicts.

    Excerpts are slices along the first axis, ``excerpt_length`` rows long,
    whose start offsets advance by ``excerpt_hop``. Start offsets stop before
    ``data.shape[0] - excerpt_length``, so a window ending exactly on the
    last row is not emitted.

    Args:
        data: 2-D array indexed as ``data[rows, :]``.
    """
    n_rows = data.shape[0]
    for start in range(0, n_rows - excerpt_length, excerpt_hop):
        stop = start + excerpt_length
        yield {'X': data[start:stop, :]}
# Wrap the excerpt stream in a Dataset and let the DataLoader collate
# batch_size samples at a time.
dataset = TrackData(tracks)
train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)

# Each batch unpacks into two elements; only the first is used here, so the
# second is discarded and no batch counter is kept.
for X, _ in train_loader:
    print(X.mean())
import numpy as np
import pescador
# Seed NumPy's global RNG so every run generates the same tracks.
np.random.seed(42)

# Sizes controlling the synthetic data and how it is cut into excerpts.
nb_tracks = 100
track_length = 30
excerpt_length = 10
excerpt_hop = 2
batch_size = 5

# define audio tracks as (nb_samples, nb_features)
tracks = [
    np.random.random((track_length, 1))
    for _ in range(nb_tracks)
]
# yield excerpts from audio tracks
def excerpt_gen(data):
    """Generate ``dict(X=excerpt)`` samples by sliding a window over ``data``.

    The window covers ``excerpt_length`` rows of the first axis and moves
    forward ``excerpt_hop`` rows per sample. Because the start offsets come
    from ``range(0, data.shape[0] - excerpt_length, excerpt_hop)``, the
    window that would end exactly on the last row is not produced.

    Args:
        data: 2-D array indexed as ``data[rows, :]``.
    """
    first_axis_len = data.shape[0]
    offsets = range(0, first_axis_len - excerpt_length, excerpt_hop)
    for offset in offsets:
        yield dict(X=data[offset:offset + excerpt_length, :])
# set up track streamers: one per track, each running excerpt_gen on it
streams = [
    pescador.Streamer(excerpt_gen, track)
    for track in tracks
]

# randomly sample from streamers, keeping nb_tracks of them active
mux = pescador.StochasticMux(
    streams, nb_tracks, rate=None, mode='exhaustive'
)

# collect batch_size samples from the mux into each buffered batch
buffered_sample_gen = pescador.buffer_stream(mux, batch_size)

# iterate over data
for batch in buffered_sample_gen:
    batch_mean = batch['X'].mean()
    print(batch_mean)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment