automix.py
This script creates realistic mixes with stems from different songs.
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# MIT License
#
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""
This script creates realistic mixes with stems from different songs.
In particular, it will align BPM, sync up the first beat and perform pitch
shift to maximize pitch overlap.
To limit artifacts, only parts that can be mixed with less than 15% tempo
shift and at most 3 semitones of pitch shift are mixed together.
"""
from collections import namedtuple
from concurrent.futures import ProcessPoolExecutor
import hashlib
from pathlib import Path
import random
import shutil
import tqdm
import pickle
from librosa.beat import beat_track
from librosa.feature import chroma_cqt
import numpy as np
import torch
from torch.nn import functional as F
from dora.utils import try_load
from demucs.audio import save_audio
from demucs.repitch import repitch
from demucs.pretrained import SOURCES
from demucs.wav import build_metadata, Wavset, _get_musdb_valid

MUSDB_PATH = '/checkpoint/defossez/datasets/musdbhq'
EXTRA_WAV_PATH = "/checkpoint/defossez/datasets/allstems_44"
# WARNING: OUTPATH will be completely erased.
OUTPATH = Path.home() / 'tmp/demucs_mdx/automix_musdb/'
CACHE = Path.home() / 'tmp/automix_cache'  # cache BPM and pitch information.
CHANNELS = 2
SR = 44100
MAX_PITCH = 3  # maximum allowable pitch shift, in semitones
MAX_TEMPO = 0.15  # maximum allowable tempo shift, as a relative fraction

Spec = namedtuple("Spec", "tempo onsets kr track index")


def rms(wav, window=10000):
    """Efficient RMS computed for each time step over a given window."""
    half = window // 2
    window = 2 * half + 1
    wav = F.pad(wav, (half, half))
    tot = wav.pow(2).cumsum(dim=-1)
    return ((tot[..., window - 1:] - tot[..., :-window + 1]) / window).sqrt()
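
# How the RMS above stays efficient: windowed sums of squares are computed as
# differences of a single cumulative sum, so each position costs O(1) instead
# of O(window). Sanity check (illustrative): rms(torch.full((1, 1000), 0.5))
# is ~0.5 at every position away from the zero-padded edges.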


def analyse_track(dset, index):
    """Analyse a track, extract bpm and distribution of notes from the bass line."""
    track = dset[index]
    mix = track.sum(0).mean(0)
    ref = mix.std()

    starts = (abs(mix) >= 1e-2 * ref).float().argmax().item()
    track = track[..., starts:]

    cache = CACHE / dset.sig
    cache.mkdir(exist_ok=True, parents=True)
    cache_file = cache / f"{index}.pkl"
    cached = None
    if cache_file.exists():
        cached = try_load(cache_file)
        if cached is not None:
            tempo, events, hist_kr = cached

    if cached is None:
        drums = track[0].mean(0)
        if drums.std() > 1e-2 * ref:
            tempo, events = beat_track(y=drums.numpy(), units='time', sr=SR)
        else:
            print("failed drums", drums.std(), ref)
            return None, track

        bass = track[1].mean(0)
        r = rms(bass)
        peak = r.max()
        mask = r >= 0.05 * peak
        bass = bass[mask]
        if bass.std() > 1e-2 * ref:
            kr = torch.from_numpy(chroma_cqt(y=bass.numpy(), sr=SR))
            hist_kr = (kr.max(dim=0, keepdim=True)[0] == kr).float().mean(1)
        else:
            print("failed bass", bass.std(), ref)
            return None, track

        pickle.dump([tempo, events, hist_kr], open(cache_file, 'wb'))
    spec = Spec(tempo, events, hist_kr, track, index)
    return spec, None


def best_pitch_shift(kr_a, kr_b):
    """Find the best pitch shift between two chroma distributions."""
    deltas = []
    for p in range(12):
        deltas.append((kr_a - kr_b).abs().mean())
        kr_b = kr_b.roll(1, 0)

    ps = np.argmin(deltas)
    if ps > 6:
        ps = ps - 12
    return ps
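
# Chroma bins are circular (12 semitones per octave), so shifts above 6 are
# mapped into [-5, 6], e.g. an argmin at p=11 is returned as -1. Illustrative
# example: if kr_b lines up with kr_a after being rolled twice, the argmin
# lands at p=2 and a 2-semitone shift is returned.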


def align_stems(stems):
    """Align the first beats of the stems.
    This is a naive implementation. A grid with a time resolution of 5 ms is
    defined and each beat onset is represented as a gaussian over this grid.
    Then, we try each possible time shift to make two grids align the best.
    We repeat for all sources.
    """
    sources = len(stems)
    width = 5e-3  # grid resolution of 5 ms
    limit = 5
    std = 2
    x = torch.arange(-limit, limit + 1, 1).float()
    gauss = torch.exp(-x**2 / (2 * std**2))
    grids = []
    for wav, onsets in stems:
        le = wav.shape[-1]
        dur = le / SR
        grid = torch.zeros(int(le / width / SR))
        for onset in onsets:
            pos = int(onset / width)
            if onset >= dur - 1:
                continue
            if onset < 1:
                continue
            grid[pos - limit:pos + limit + 1] += gauss
        grids.append(grid)

    shifts = [0]
    for s in range(1, sources):
        max_shift = int(4 / width)
        dots = []
        for shift in range(-max_shift, max_shift):
            other = grids[s]
            ref = grids[0]
            if shift >= 0:
                other = other[shift:]
            else:
                # negative shift: advance the reference grid instead
                ref = ref[-shift:]
            le = min(len(other), len(ref))
            dots.append((ref[:le].dot(other[:le]), int(shift * width * SR)))

        _, shift = max(dots)
        shifts.append(-shift)

    outs = []
    new_zero = min(shifts)
    for (wav, _), shift in zip(stems, shifts):
        offset = shift - new_zero
        wav = F.pad(wav, (offset, 0))
        outs.append(wav)

    le = min(x.shape[-1] for x in outs)
    outs = [w[..., :le] for w in outs]
    return torch.stack(outs)
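
# Illustrative walk-through of the alignment: with the 5 ms grid, an onset at
# t=1.230 s becomes a gaussian bump centered on cell 246. Cross-correlating two
# such grids over shifts of up to +/- 4 s (max_shift = 4 / width cells) and
# taking the argmax gives the offset that best lines up the beat patterns
# before the stems are padded to a common start.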


def find_candidate(spec_ref, catalog, pitch_match=True):
    """Given a reference track, find a track in the catalog that is a
    potential match (pitch and tempo deltas must be within the allowable limits).
    """
    candidates = list(catalog)
    random.shuffle(candidates)

    for spec in candidates:
        ok = False
        for scale in [1/4, 1/2, 1, 2, 4]:
            tempo = spec.tempo * scale
            delta_tempo = spec_ref.tempo / tempo - 1
            if abs(delta_tempo) < MAX_TEMPO:
                ok = True
                break
        if not ok:
            print(delta_tempo, spec_ref.tempo, spec.tempo, "FAILED TEMPO")
            # too much of a tempo difference
            continue
        spec = spec._replace(tempo=tempo)

        ps = 0
        if pitch_match:
            ps = best_pitch_shift(spec_ref.kr, spec.kr)
            if abs(ps) > MAX_PITCH:
                print("Failed pitch", ps)
                # too much pitch difference
                continue
        return spec, delta_tempo, ps
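
# Why the scale loop above tries [1/4, 1/2, 1, 2, 4]: beat trackers frequently
# report half- or double-time estimates, so a candidate also qualifies when a
# power-of-two multiple of its BPM is close enough. Example (hypothetical
# numbers): a 180 BPM estimate matches a 92 BPM reference through the 1/2
# scale (90 BPM, delta ~= 2.2% < 15%).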


def get_part(spec, source, dt, dp):
    """Apply given delta of tempo and delta of pitch to a stem."""
    wav = spec.track[source]
    if dt or dp:
        wav = repitch(wav, dp, dt * 100, samplerate=SR, voice=source == 3)
        spec = spec._replace(onsets=spec.onsets / (1 + dt))
    return wav, spec


def build_track(ref_index, catalog):
    """Given the reference track index and a catalog of tracks, build a
    completely new track. One source of the reference track, chosen at random,
    is kept and the other sources are drawn from the catalog.
    """
    order = list(range(len(SOURCES)))
    random.shuffle(order)

    stems = [None] * len(order)
    indexes = [None] * len(order)
    origs = [None] * len(order)
    dps = [None] * len(order)
    dts = [None] * len(order)

    first = order[0]
    spec_ref = catalog[ref_index]
    stems[first] = (spec_ref.track[first], spec_ref.onsets)
    indexes[first] = ref_index
    origs[first] = spec_ref.track[first]
    dps[first] = 0
    dts[first] = 0

    # Drums (source 0) carry no pitch information, so skip pitch matching when
    # they are the reference stem; the reference chroma is then borrowed from
    # the first matched candidate.
    pitch_match = first != 0

    for src in order[1:]:
        spec, dt, dp = find_candidate(spec_ref, catalog, pitch_match=pitch_match)
        if not pitch_match:
            spec_ref = spec_ref._replace(kr=spec.kr)
        pitch_match = True
        dps[src] = dp
        dts[src] = dt
        wav, spec = get_part(spec, src, dt, dp)
        stems[src] = (wav, spec.onsets)
        indexes[src] = spec.index
        origs[src] = spec.track[src]

    print("FINAL CHOICES", ref_index, indexes, dps, dts)
    stems = align_stems(stems)
    return stems, origs


def get_musdb_dataset(part='train'):
    root = Path(MUSDB_PATH) / part
    ext = '.wav'
    metadata = build_metadata(root, SOURCES, ext=ext, normalize=False)
    valid_tracks = _get_musdb_valid()
    metadata_train = {name: meta for name, meta in metadata.items() if name not in valid_tracks}
    train_set = Wavset(
        root, metadata_train, SOURCES, samplerate=SR, channels=CHANNELS,
        normalize=False, ext=ext)
    sig = hashlib.sha1(str(root).encode()).hexdigest()[:8]
    train_set.sig = sig
    return train_set


def get_wav_dataset():
    root = Path(EXTRA_WAV_PATH)
    ext = '.wav'
    metadata = build_metadata(root, SOURCES, ext=ext, normalize=False)
    train_set = Wavset(
        root, metadata, SOURCES, samplerate=SR, channels=CHANNELS,
        normalize=False, ext=ext)
    sig = hashlib.sha1(str(root).encode()).hexdigest()[:8]
    train_set.sig = sig
    return train_set


def main():
    random.seed(4321)
    if OUTPATH.exists():
        shutil.rmtree(OUTPATH)
    OUTPATH.mkdir(exist_ok=True, parents=True)
    (OUTPATH / 'train').mkdir(exist_ok=True, parents=True)
    (OUTPATH / 'valid').mkdir(exist_ok=True, parents=True)
    out = OUTPATH / 'train'

    dset = get_musdb_dataset()
    # dset2 = get_wav_dataset()
    # dset3 = get_musdb_dataset('test')
    dset2 = None
    dset3 = None

    pendings = []
    copies = 6
    copies_rej = 2
    with ProcessPoolExecutor(20) as pool:
        for index in range(len(dset)):
            pendings.append(pool.submit(analyse_track, dset, index))
        if dset2:
            for index in range(len(dset2)):
                pendings.append(pool.submit(analyse_track, dset2, index))
        if dset3:
            for index in range(len(dset3)):
                pendings.append(pool.submit(analyse_track, dset3, index))

        catalog = []
        rej = 0
        for pending in tqdm.tqdm(pendings, ncols=120):
            spec, track = pending.result()
            if spec is not None:
                catalog.append(spec)
            else:
                mix = track.sum(0)
                for copy in range(copies_rej):
                    folder = out / f'rej_{rej}_{copy}'
                    folder.mkdir()
                    save_audio(mix, folder / "mixture.wav", SR)
                    for stem, source in zip(track, SOURCES):
                        save_audio(stem, folder / f"{source}.wav", SR, clip='clamp')
                rej += 1

    for copy in range(copies):
        for index in range(len(catalog)):
            track, origs = build_track(index, catalog)
            mix = track.sum(0)
            mx = mix.abs().max()
            scale = max(1, 1.01 * mx)
            mix = mix / scale
            track = track / scale
            folder = out / f'{copy}_{index}'
            folder.mkdir()
            save_audio(mix, folder / "mixture.wav", SR)
            for stem, source, orig in zip(track, SOURCES, origs):
                save_audio(stem, folder / f"{source}.wav", SR, clip='clamp')
                # save_audio(stem.std() * orig / (1e-6 + orig.std()), folder / f"{source}_orig.wav",
                #            SR, clip='clamp')


if __name__ == '__main__':
    main()

data_augmentation.py

import os
import subprocess as sp
import tempfile
import warnings
from argparse import ArgumentParser
from ast import literal_eval
import numpy as np
import soundfile as sf
import torch
from pathlib import Path
from tqdm import tqdm

warnings.simplefilter(action='ignore', category=Warning)

source_names = ['vocals', 'drums', 'bass', 'other']
sample_rate = 44100


def main(args):
    data_dir = Path(args.data_dir)
    train = args.train
    test = args.test
    P = [-2, -1, 0, 1, 2]  # pitch shift amounts (in semitones)
    T = [-20, -10, 0, 10, 20]  # tempo change amounts (10 means 10% faster, see `shift`)
    for p in P:
        for t in T:
            if not (p == 0 and t == 0):
                if train:
                    save_shifted_dataset(p, t, data_dir, 'train')
                if test:
                    save_shifted_dataset(p, t, data_dir, 'test')


def shift(wav, pitch, tempo, voice=False, quick=False, samplerate=44100):
    """
    tempo is a relative delta in percentage, so tempo=10 means tempo at 110%!
    pitch is in semitones.
    Requires `soundstretch` to be installed, see
    https://www.surina.net/soundtouch/soundstretch.html
    """
    def i16_pcm(wav):
        if wav.dtype == torch.int16:
            return wav
        return (wav * 2 ** 15).clamp_(-2 ** 15, 2 ** 15 - 1).short()

    def f32_pcm(wav):
        if wav.dtype == torch.float32:
            return wav
        return wav.float() / 2 ** 15

    inputfile = tempfile.NamedTemporaryFile(suffix=".wav")
    outfile = tempfile.NamedTemporaryFile(suffix=".wav")
    sf.write(inputfile.name, data=i16_pcm(wav).t().numpy(), samplerate=samplerate, format='WAV')
    command = [
        "soundstretch",
        inputfile.name,
        outfile.name,
        f"-pitch={pitch}",
        f"-tempo={tempo:.6f}",
    ]
    if quick:
        command += ["-quick"]
    if voice:
        command += ["-speech"]
    try:
        sp.run(command, capture_output=True, check=True)
    except sp.CalledProcessError as error:
        raise RuntimeError(f"Could not change bpm because {error.stderr.decode('utf-8')}")

    wav, sr = sf.read(outfile.name, dtype='float32')
    # wav = np.float32(wav)
    # wav = f32_pcm(torch.from_numpy(wav).t())
    assert sr == samplerate
    return wav
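
# Illustrative usage (assumes `wav` is a (channels, samples) float tensor and
# `soundstretch` is on the PATH): shift(wav, pitch=2, tempo=10) returns audio
# repitched 2 semitones up and played ~10% faster, since soundstretch treats
# positive -tempo values as a tempo increase.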


def save_shifted_dataset(delta_pitch, delta_tempo, data_dir, split):
    aug_split = split + f'_p={delta_pitch}_t={delta_tempo}'
    in_dir = data_dir.joinpath(split)
    out_dir = data_dir.joinpath(aug_split)
    if not out_dir.exists():
        os.mkdir(out_dir)
    track_names = os.listdir(in_dir)
    for track_name in tqdm(track_names):
        in_path = in_dir.joinpath(track_name)
        out_path = out_dir.joinpath(track_name)
        if not out_path.exists():
            os.mkdir(out_path)
        for s_name in source_names:
            source = load_wav(in_path.joinpath(s_name + '.wav'))
            shifted = shift(
                torch.tensor(source),
                delta_pitch,
                delta_tempo,
                voice=s_name == 'vocals')
            sf.write(out_path.joinpath(s_name + '.wav'), shifted, samplerate=sample_rate,
                     format='WAV')


def load_wav(path, sr=None):
    return sf.read(path, samplerate=sr, dtype='float32')[0].T


if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('--data_dir', type=str)
    parser.add_argument('--train', default=True, type=literal_eval)
    parser.add_argument('--test', default=False, type=literal_eval)
    main(parser.parse_args())

For the data augmentation script, you will also need to install the soundstretch command-line tool (see the SoundTouch documentation at https://www.surina.net/soundtouch/soundstretch.html):

sudo apt install soundstretch

data_augmentation.py

This script was taken from the mdx-net-submission repo and was used in the Sony Music Demixing Challenge (2021).

data_augmentation.py takes three arguments in total, of which only --data_dir is required (a usage example follows the list):

    --data_dir [/the/absolute/path/where/datasets/are/stored]
    --train [True|False] Default: True
    --test  [True|False] Default: False
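
For example, to generate the augmented training splits only (the path below is a placeholder):

    python data_augmentation.py --data_dir /absolute/path/to/musdbhq --train True --test False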

data_dir file structure:

The data directory is expected to have the same layout as MUSDB18-HQ:
/full/path/to/data_dir/
                  |
                  |_ test/
                  |   |
                  |   |_ songname/
                  |             |_ bass.wav
                  |             |_ drums.wav
                  |             |_ vocals.wav
                  |             |_ other.wav
                  |
                  |
                  |_ train/
                      |
                      |_ songname/
                               |_ bass.wav
                               |_ drums.wav
                               |_ vocals.wav
                               |_ other.wav
