|
# Copyright (c) Facebook, Inc. and its affiliates. |
|
# All rights reserved. |
|
# |
|
#MIT License |
|
# |
|
#Copyright (c) Facebook, Inc. and its affiliates. |
|
# |
|
#Permission is hereby granted, free of charge, to any person obtaining a copy |
|
#of this software and associated documentation files (the "Software"), to deal |
|
#in the Software without restriction, including without limitation the rights |
|
#to use, copy, modify, merge, publish, distribute, sublicense, and/or sell |
|
#copies of the Software, and to permit persons to whom the Software is |
|
#furnished to do so, subject to the following conditions: |
|
# |
|
#The above copyright notice and this permission notice shall be included in all |
|
#copies or substantial portions of the Software. |
|
# |
|
#THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
|
#IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
|
#FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE |
|
#AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER |
|
#LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, |
|
#OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE |
|
#SOFTWARE. |
|
""" |
|
This script creates realistic mixes with stems from different songs.
In particular, it will align BPM, sync up the first beat and perform pitch
shift to maximize pitch overlap.
In order to limit artifacts, only parts that can be mixed with less than 15%
tempo shift and 3 semitones of pitch shift are mixed together.
|
""" |
|
from collections import namedtuple |
|
from concurrent.futures import ProcessPoolExecutor |
|
import hashlib |
|
from pathlib import Path |
|
import random |
|
import shutil |
|
import tqdm |
|
import pickle |
|
|
|
from librosa.beat import beat_track |
|
from librosa.feature import chroma_cqt |
|
import numpy as np |
|
import torch |
|
from torch.nn import functional as F |
|
|
|
from dora.utils import try_load |
|
from demucs.audio import save_audio |
|
from demucs.repitch import repitch |
|
from demucs.pretrained import SOURCES |
|
from demucs.wav import build_metadata, Wavset, _get_musdb_valid |
|
|
|
|
|
# Locations of the input datasets.
MUSDB_PATH = '/checkpoint/defossez/datasets/musdbhq'
EXTRA_WAV_PATH = "/checkpoint/defossez/datasets/allstems_44"

# WARNING: OUTPATH will be completely erased.
OUTPATH = Path.home() / 'tmp/demucs_mdx/automix_musdb/'
# Cache for the BPM and pitch information computed for each track.
CACHE = Path.home() / 'tmp/automix_cache'

# Audio format used when loading the datasets and when saving the mixes.
CHANNELS = 2
SR = 44100

# Two tracks are only mixed together when they fit within these limits.
MAX_PITCH = 3  # maximum allowable pitch shift in semi tones
MAX_TEMPO = 0.15  # maximum allowable tempo shift


# Analysis result for one track: estimated tempo, beat onsets, chroma
# histogram of the bass line (`kr`), the trimmed track itself and its index
# in the dataset it comes from.
Spec = namedtuple("Spec", "tempo onsets kr track index")
|
|
|
|
|
def rms(wav, window=10000):
    """Sliding RMS of `wav` along its last dimension, one value per time step.

    The signal is zero padded by `window // 2` on both sides so that the output
    has the same length as the input. The running energy is obtained from the
    difference of two entries of a cumulative sum, which spans `2 * (window // 2)`
    samples, and is divided by the odd window size `2 * (window // 2) + 1`.
    """
    radius = window // 2
    span = 2 * radius  # distance between the two cumulative sum entries
    size = span + 1  # effective (odd) window size, used as the divisor
    padded = F.pad(wav, (radius, radius))
    energy = torch.cumsum(padded ** 2, dim=-1)
    upper = energy[..., span:]
    lower = energy[..., :-span]
    return torch.sqrt((upper - lower) / size)
|
|
|
|
|
def analyse_track(dset, index):
    """Analyse a track: extract its BPM and the distribution of notes of its bass line.

    Args:
        dset: dataset of tracks. `dset[index]` is used as a tensor whose first
            dimension selects the source (index 0 is treated as drums, index 1
            as bass) and whose last dimension is time. `dset.sig` names the
            cache sub-folder for this dataset.
        index: index of the track to analyse inside `dset`.

    Returns:
        `(spec, None)` on success, where `spec` is a `Spec` holding the tempo,
        the beat events, the chroma histogram of the bass, the track with its
        leading silence removed, and `index`.
        `(None, track)` when the drums or the bass are too quiet to be analysed.
    """
    track = dset[index]
    mix = track.sum(0).mean(0)
    ref = mix.std()

    # Drop the leading silence: start at the first sample whose amplitude
    # reaches 1% of the standard deviation of the mix.
    starts = (abs(mix) >= 1e-2 * ref).float().argmax().item()
    track = track[..., starts:]

    cache = CACHE / dset.sig
    cache.mkdir(exist_ok=True, parents=True)

    cache_file = cache / f"{index}.pkl"
    cached = None
    if cache_file.exists():
        # NOTE(review): the cache is written with `pickle.dump` below but read
        # back through `try_load` -- confirm `try_load` can read plain pickles.
        cached = try_load(cache_file)

    if cached is not None:
        tempo, events, hist_kr = cached
    else:
        drums = track[0].mean(0)
        if drums.std() > 1e-2 * ref:
            tempo, events = beat_track(drums.numpy(), units='time', sr=SR)
        else:
            print("failed drums", drums.std(), ref)
            return None, track

        bass = track[1].mean(0)
        # Only keep the parts of the bass whose local RMS is at least 5% of its peak.
        r = rms(bass)
        peak = r.max()
        mask = r >= 0.05 * peak
        bass = bass[mask]
        if bass.std() > 1e-2 * ref:
            kr = torch.from_numpy(chroma_cqt(bass.numpy(), sr=SR))
            # One-hot of the maximum along dim 0 for each position of dim 1,
            # averaged over dim 1 (presumably: strongest pitch class per frame,
            # averaged over frames -- confirm against `chroma_cqt` layout).
            hist_kr = (kr.max(dim=0, keepdim=True)[0] == kr).float().mean(1)
        else:
            print("failed bass", bass.std(), ref)
            return None, track

        # Only freshly computed results are written, and the file is closed
        # deterministically once the dump is done.
        with open(cache_file, 'wb') as cache_out:
            pickle.dump([tempo, events, hist_kr], cache_out)

    spec = Spec(tempo, events, hist_kr, track, index)
    return spec, None
|
|
|
|
|
def best_pitch_shift(kr_a, kr_b):
    """Return the circular shift of `kr_b` that best matches `kr_a`.

    All 12 rolls of `kr_b` along its first dimension are compared to `kr_a`
    with a mean absolute difference. The index of the smallest difference is
    returned, with shifts above 6 mapped to the equivalent negative shift
    (e.g. 10 becomes -2), so the result lies in [-5, 6].
    """
    distances = [(kr_a - kr_b.roll(step, 0)).abs().mean() for step in range(12)]
    best = np.argmin(distances)
    return best - 12 if best > 6 else best
|
|
|
|
|
def align_stems(stems):
    """Align the first beats of the stems.

    This is a naive implementation. A grid with a time resolution of 5ms is defined
    and each beat onset is represented as a gaussian over this grid.
    Then, we try each possible time shift (up to 4 seconds in either direction)
    to make two grids align the best.
    We repeat for all sources, always using the first stem as the reference.

    Args:
        stems: list of `(wav, onsets)` pairs, where `wav` is a tensor with time
            as its last dimension and `onsets` are the beat times in seconds.

    Returns:
        A tensor stacking all the stems, each left padded with zeros by its
        alignment offset and trimmed to the length of the shortest one.
    """
    width = 5e-3  # grid resolution in seconds (5ms)
    limit = 5  # half width of the gaussian bump, in grid cells
    std = 2  # standard deviation of the gaussian bump, in grid cells
    cells = torch.arange(-limit, limit + 1, 1).float()
    gauss = torch.exp(-cells**2 / (2 * std**2))

    grids = []
    for wav, onsets in stems:
        le = wav.shape[-1]
        dur = le / SR
        grid = torch.zeros(int(le / width / SR))
        for onset in onsets:
            # Beats in the first and last second are ignored, which also
            # guarantees the gaussian bump fits entirely inside the grid.
            if onset < 1 or onset >= dur - 1:
                continue
            pos = int(onset / width)
            grid[pos - limit:pos + limit + 1] += gauss
        grids.append(grid)

    max_shift = int(4 / width)
    ref_grid = grids[0]
    shifts = [0]
    for other_grid in grids[1:]:
        dots = []
        for shift in range(-max_shift, max_shift):
            if shift >= 0:
                # The other stem is late: skip its first `shift` cells.
                ref, other = ref_grid, other_grid[shift:]
            else:
                # The reference is late: skip its first `-shift` cells.
                # (`ref_grid[shift:]` would instead keep only the LAST
                # `-shift` cells, which is not a time shift.)
                ref, other = ref_grid[-shift:], other_grid
            le = min(len(other), len(ref))
            score = ref[:le].dot(other[:le]).item()
            # The shift is stored in samples rather than in grid cells.
            dots.append((score, int(shift * width * SR)))

        _, shift = max(dots)
        shifts.append(-shift)

    # Padding can only be non negative, so offsets are made relative to the
    # smallest shift.
    new_zero = min(shifts)
    outs = [F.pad(wav, (shift - new_zero, 0)) for (wav, _), shift in zip(stems, shifts)]

    le = min(out.shape[-1] for out in outs)
    return torch.stack([out[..., :le] for out in outs])
|
|
|
|
|
def find_candidate(spec_ref, catalog, pitch_match=True):
    """Pick at random a track of the catalog that can be mixed with the reference.

    Candidates are tried in a random order. A candidate is accepted when, for
    one of the tempo factors 1/4, 1/2, 1, 2 or 4, its scaled tempo is within
    `MAX_TEMPO` (relative) of the reference tempo and, if `pitch_match` is
    set, its best pitch shift with respect to the reference is at most
    `MAX_PITCH`.

    Returns:
        `(spec, delta_tempo, pitch_shift)` for the first acceptable candidate,
        with `spec.tempo` replaced by the scaled tempo; `pitch_shift` is 0 when
        `pitch_match` is false. `None` if no candidate is acceptable.
    """
    pool = list(catalog)
    random.shuffle(pool)

    for candidate in pool:
        for factor in [1/4, 1/2, 1, 2, 4]:
            tempo = candidate.tempo * factor
            delta_tempo = spec_ref.tempo / tempo - 1
            if abs(delta_tempo) < MAX_TEMPO:
                break
        else:
            # No factor brings the tempo close enough to the reference.
            print(delta_tempo, spec_ref.tempo, candidate.tempo, "FAILED TEMPO")
            continue
        candidate = candidate._replace(tempo=tempo)

        if not pitch_match:
            return candidate, delta_tempo, 0

        ps = best_pitch_shift(spec_ref.kr, candidate.kr)
        if abs(ps) > MAX_PITCH:
            # Too much of a pitch difference.
            print("Failed pitch", ps)
            continue
        return candidate, delta_tempo, ps
    return None
|
|
|
|
|
def get_part(spec, source, dt, dp):
    """Extract one stem from `spec`, applying the tempo delta `dt` and pitch delta `dp`.

    When both deltas are zero, the stem and the spec are returned untouched.
    Otherwise the stem goes through `repitch` (`dt` is passed multiplied by
    100, presumably as a percentage -- confirm against `repitch`), and the
    onsets of the returned spec are divided by `1 + dt` to follow the new tempo.
    Source index 3 is flagged as voice to `repitch`.

    Returns:
        `(wav, spec)`: the (possibly transformed) stem and the matching spec.
    """
    stem = spec.track[source]
    if not (dt or dp):
        return stem, spec

    is_voice = source == 3
    shifted = repitch(stem, dp, dt * 100, samplerate=SR, voice=is_voice)
    return shifted, spec._replace(onsets=spec.onsets / (1 + dt))
|
|
|
|
|
def build_track(ref_index, catalog):
    """Given the reference track index and a catalog of tracks, builds
    a completely new track. One source picked at random from the reference
    track is kept, and the other sources are drawn from the catalog.

    Args:
        ref_index: index in `catalog` of the reference track.
        catalog: list of `Spec`, one per analysed track.

    Returns:
        `(stems, origs)`: the aligned stems as returned by `align_stems`, and
        the list of the original (unmodified) stems, both indexed like `SOURCES`.
    """
    order = list(range(len(SOURCES)))
    random.shuffle(order)

    n_sources = len(order)
    stems = [None] * n_sources
    indexes = [None] * n_sources
    origs = [None] * n_sources
    dps = [None] * n_sources
    dts = [None] * n_sources

    first = order[0]
    spec_ref = catalog[ref_index]
    stems[first] = (spec_ref.track[first], spec_ref.onsets)
    indexes[first] = ref_index
    origs[first] = spec_ref.track[first]
    dps[first] = 0
    dts[first] = 0

    # When the kept source is index 0 (the stem `analyse_track` treats as
    # drums), the first candidate is picked without pitch constraint and
    # becomes the pitch reference for the next ones.
    # This must test `first`: comparing the list `order` to 0 is always true.
    pitch_match = first != 0

    for src in order[1:]:
        spec, dt, dp = find_candidate(spec_ref, catalog, pitch_match=pitch_match)
        if not pitch_match:
            spec_ref = spec_ref._replace(kr=spec.kr)
            pitch_match = True
        dps[src] = dp
        dts[src] = dt
        wav, spec = get_part(spec, src, dt, dp)
        stems[src] = (wav, spec.onsets)
        indexes[src] = spec.index
        # Stored at the source position so that `origs` lines up with `SOURCES`
        # (appending would put it after the pre-sized `None` entries).
        origs[src] = spec.track[src]
    print("FINAL CHOICES", ref_index, indexes, dps, dts)
    stems = align_stems(stems)
    return stems, origs
|
|
|
|
|
def get_musdb_dataset(part='train'):
    """Build the `Wavset` for the `part` sub-folder of `MUSDB_PATH`.

    Tracks whose name is returned by `_get_musdb_valid()` are left out.
    The returned dataset gets a `sig` attribute: the first 8 hex digits of the
    SHA1 of its root path, used to name its analysis cache folder.
    """
    root = Path(MUSDB_PATH) / part
    ext = '.wav'
    metadata = build_metadata(root, SOURCES, ext=ext, normalize=False)
    held_out = _get_musdb_valid()

    kept = {}
    for name, meta in metadata.items():
        if name in held_out:
            continue
        kept[name] = meta

    dataset = Wavset(
        root, kept, SOURCES,
        samplerate=SR, channels=CHANNELS, normalize=False, ext=ext)
    dataset.sig = hashlib.sha1(str(root).encode()).hexdigest()[:8]
    return dataset
|
|
|
|
|
def get_wav_dataset():
    """Build the `Wavset` for the extra stems stored under `EXTRA_WAV_PATH`.

    The returned dataset gets a `sig` attribute: the first 8 hex digits of the
    SHA1 of its root path, used to name its analysis cache folder.
    """
    root = Path(EXTRA_WAV_PATH)
    ext = '.wav'
    # `build_metadata` is the name imported from `demucs.wav` and the one used
    # by `get_musdb_dataset`; `_build_metadata` is not defined in this file.
    metadata = build_metadata(root, SOURCES, ext=ext, normalize=False)
    train_set = Wavset(
        root, metadata, SOURCES, samplerate=SR, channels=CHANNELS,
        normalize=False, ext=ext)
    sig = hashlib.sha1(str(root).encode()).hexdigest()[:8]
    train_set.sig = sig
    return train_set
|
|
|
|
|
def _write_example(folder, mix, stems):
    """Create `folder` and save the mixture along with one file per source."""
    folder.mkdir()
    save_audio(mix, folder / "mixture.wav", SR)
    for stem, source in zip(stems, SOURCES):
        save_audio(stem, folder / f"{source}.wav", SR, clip='clamp')


def main():
    """Generate the automix dataset under `OUTPATH` (which is erased first).

    All tracks are analysed in parallel. Tracks that could not be analysed are
    saved as is, `copies_rej` times each, in `rej_*` folders. Then, `copies`
    times, a new track is built around each analysed track and saved, after
    being scaled down if its mixture would otherwise go above 1 in amplitude.
    """
    random.seed(4321)
    if OUTPATH.exists():
        shutil.rmtree(OUTPATH)
    OUTPATH.mkdir(exist_ok=True, parents=True)
    for part in ('train', 'valid'):
        (OUTPATH / part).mkdir(exist_ok=True, parents=True)
    out = OUTPATH / 'train'

    dset = get_musdb_dataset()
    # Extra datasets are disabled; to use them, set for instance
    # dset2 = get_wav_dataset() and dset3 = get_musdb_dataset('test').
    dset2 = None
    dset3 = None
    copies = 6
    copies_rej = 2

    pendings = []
    with ProcessPoolExecutor(20) as pool:
        pendings.extend(
            pool.submit(analyse_track, dset, index) for index in range(len(dset)))
        for extra in (dset2, dset3):
            if extra:
                pendings.extend(
                    pool.submit(analyse_track, extra, index) for index in range(len(extra)))

        catalog = []
        rej = 0
        for pending in tqdm.tqdm(pendings, ncols=120):
            spec, track = pending.result()
            if spec is not None:
                catalog.append(spec)
                continue
            # The analysis failed: keep the track unchanged in the dataset.
            mix = track.sum(0)
            for copy in range(copies_rej):
                _write_example(out / f'rej_{rej}_{copy}', mix, track)
            rej += 1

    for copy in range(copies):
        for index in range(len(catalog)):
            track, _ = build_track(index, catalog)
            mix = track.sum(0)
            # Only ever scale down, with a 1% margin, to avoid clipping.
            scale = max(1, 1.01 * mix.abs().max())
            _write_example(out / f'{copy}_{index}', mix / scale, track / scale)


if __name__ == '__main__':
    main()