Skip to content

Instantly share code, notes, and snippets.

@ouor
Last active March 6, 2023 12:02
Show Gist options
  • Save ouor/f48b2815f61b03e9e4212b433edcd995 to your computer and use it in GitHub Desktop.
Save ouor/f48b2815f61b03e9e4212b433edcd995 to your computer and use it in GitHub Desktop.
import argparse
import os
import glob
import librosa
import numpy as np
import soundfile as sf
import torch
from tqdm import tqdm
from lib import dataset
from lib import nets
from lib import spec_utils
from lib import utils
class Separator:
    """Split a complex spectrogram into two complementary stems with a mask-predicting model.

    The model's mask is applied to the input magnitude to build ``y_spec``; the
    complementary mask ``1 - mask`` builds ``v_spec``. Both reuse the input phase.
    """

    def __init__(self, model, device, batchsize, cropsize, postprocess=False):
        """
        Args:
            model: network exposing ``offset``, ``eval()`` and ``predict_mask(tensor)``.
            device: torch device that input batches are moved to.
            batchsize: number of patches sent to the model per forward pass.
            cropsize: width, in frames, of each patch fed to the model.
            postprocess: if True, pass the mask through ``spec_utils.merge_artifacts``.
        """
        self.model = model
        # Margin (in frames) used when computing the patch count/padding;
        # presumably the model's context margin on each side -- TODO confirm.
        self.offset = model.offset
        self.device = device
        self.batchsize = batchsize
        self.cropsize = cropsize
        self.postprocess = postprocess

    def _separate(self, X_mag_pad, roi_size):
        """Predict a mask for a padded magnitude spectrogram, patch by patch.

        Args:
            X_mag_pad: padded, normalized magnitude array; axis 2 is time frames.
            roi_size: stride, in frames, between consecutive patches.

        Returns:
            The per-patch masks concatenated along axis 2. Presumably each patch
            yields ``roi_size`` frames of mask -- TODO confirm against predict_mask.
        """
        patches = (X_mag_pad.shape[2] - 2 * self.offset) // roi_size
        # Windows of `cropsize` frames, advanced by `roi_size` frames each.
        X_dataset = np.asarray([
            X_mag_pad[:, :, i * roi_size:i * roi_size + self.cropsize]
            for i in range(patches)
        ])

        self.model.eval()
        mask = []
        with torch.no_grad():
            # To reduce the overhead, dataloader is not used.
            for i in range(0, patches, self.batchsize):
                X_batch = torch.from_numpy(X_dataset[i:i + self.batchsize]).to(self.device)
                pred = self.model.predict_mask(X_batch).detach().cpu().numpy()
                # Lay the batch items side by side along the frame axis.
                mask.append(np.concatenate(pred, axis=2))

        return np.concatenate(mask, axis=2)

    @staticmethod
    def _pad_and_normalize(X_mag, pad_l, pad_r):
        """Zero-pad axis 2 of ``X_mag`` and scale the result so its peak is 1.

        An all-zero (silent) input is returned unscaled: dividing by a zero
        peak would otherwise turn every value into NaN.
        """
        X_mag_pad = np.pad(X_mag, ((0, 0), (0, 0), (pad_l, pad_r)), mode='constant')
        peak = X_mag_pad.max()
        if peak > 0:
            X_mag_pad /= peak
        return X_mag_pad

    def _preprocess(self, X_spec):
        """Split a complex spectrogram into ``(magnitude, phase)`` arrays."""
        X_mag = np.abs(X_spec)
        X_phase = np.angle(X_spec)

        return X_mag, X_phase

    def _postprocess(self, mask, X_mag, X_phase):
        """Apply ``mask`` and its complement to the input magnitude, reusing its phase.

        Returns:
            ``(y_spec, v_spec)`` complex arrays: ``mask`` and ``1 - mask`` applied.
        """
        if self.postprocess:
            mask = spec_utils.merge_artifacts(mask)

        # The phase term is shared by both stems; compute it once.
        phase = np.exp(1.j * X_phase)
        y_spec = mask * X_mag * phase
        v_spec = (1 - mask) * X_mag * phase

        return y_spec, v_spec

    def separate(self, X_spec):
        """Separate ``X_spec`` with a single pass over the padded input.

        Returns:
            ``(y_spec, v_spec)`` complex spectrograms (see ``_postprocess``).
        """
        X_mag, X_phase = self._preprocess(X_spec)

        n_frame = X_mag.shape[2]
        pad_l, pad_r, roi_size = dataset.make_padding(n_frame, self.cropsize, self.offset)
        mask = self._separate(self._pad_and_normalize(X_mag, pad_l, pad_r), roi_size)
        # Drop mask frames that correspond to right-hand padding.
        mask = mask[:, :, :n_frame]

        return self._postprocess(mask, X_mag, X_phase)

    def separate_tta(self, X_spec):
        """Like ``separate``, but averages with a second pass shifted by half a patch.

        The second pass pads an extra ``roi_size // 2`` frames on each side, so
        patch boundaries land in different places. Its mask is shifted back into
        alignment, and the two masks are averaged.

        Returns:
            ``(y_spec, v_spec)`` complex spectrograms (see ``_postprocess``).
        """
        X_mag, X_phase = self._preprocess(X_spec)

        n_frame = X_mag.shape[2]
        pad_l, pad_r, roi_size = dataset.make_padding(n_frame, self.cropsize, self.offset)
        mask = self._separate(self._pad_and_normalize(X_mag, pad_l, pad_r), roi_size)

        shift = roi_size // 2
        mask_tta = self._separate(
            self._pad_and_normalize(X_mag, pad_l + shift, pad_r + shift), roi_size)
        # Undo the extra left padding so both masks are frame-aligned.
        mask_tta = mask_tta[:, :, shift:]

        mask = (mask[:, :, :n_frame] + mask_tta[:, :, :n_frame]) * 0.5

        return self._postprocess(mask, X_mag, X_phase)
def _parse_args():
    """Build the command-line parser and return the parsed arguments."""
    p = argparse.ArgumentParser()
    p.add_argument('--cpu', '-c', action='store_true')
    p.add_argument('--pretrained_model', '-P', type=str, default='models/baseline.pth')
    p.add_argument('--input-dir', '-i', required=True)
    p.add_argument('--sr', '-r', type=int, default=44100)
    p.add_argument('--n_fft', '-f', type=int, default=2048)
    p.add_argument('--hop_length', '-H', type=int, default=1024)
    p.add_argument('--batchsize', '-B', type=int, default=4)
    p.add_argument('--cropsize', '-cs', type=int, default=256)
    p.add_argument('--output_image', '-I', action='store_true')
    p.add_argument('--postprocess', '-p', action='store_true')
    p.add_argument('--tta', '-t', action='store_true')
    p.add_argument('--output_dir', '-o', type=str, required=True)
    return p.parse_args()


def _load_model(args):
    """Load the CascadedNet checkpoint and move it to CUDA when available and allowed.

    Returns:
        ``(model, device)``.
    """
    tqdm.write('loading model...', end=' ')
    device = torch.device('cpu')
    model = nets.CascadedNet(args.n_fft, 32, 128)
    # NOTE(review): torch.load unpickles the checkpoint file -- only load trusted models.
    model.load_state_dict(torch.load(args.pretrained_model, map_location=device))
    if torch.cuda.is_available() and not args.cpu:
        device = torch.device('cuda')
        model.to(device)
    tqdm.write('done')
    return model, device


def _list_audio_files(input_dir):
    """Return the sorted paths of the regular .wav/.mp3 files directly inside ``input_dir``.

    The extension match is case-insensitive.
    """
    return sorted(
        os.path.join(input_dir, name)
        for name in os.listdir(input_dir)
        if name.lower().endswith(('.wav', '.mp3'))
        and os.path.isfile(os.path.join(input_dir, name))
    )


def _process_file(input_path, sp, args, vocal_dir, inst_dir):
    """Separate one audio file and write its instrument and vocal stems.

    Stems are written as .wav files, plus spectrogram .jpg images when
    ``--output_image`` is set.
    """
    tqdm.write('loading wave source...', end=' ')
    # sr/mono must be passed by keyword: librosa >= 0.10 made them keyword-only.
    X, sr = librosa.load(
        input_path, sr=args.sr, mono=False, dtype=np.float32, res_type='kaiser_fast')
    basename = os.path.splitext(os.path.basename(input_path))[0]
    tqdm.write('done')

    if X.ndim == 1:
        # mono to stereo
        X = np.asarray([X, X])

    tqdm.write('stft of wave source...', end=' ')
    X_spec = spec_utils.wave_to_spectrogram(X, args.hop_length, args.n_fft)
    tqdm.write('done')

    if args.tta:
        y_spec, v_spec = sp.separate_tta(X_spec)
    else:
        y_spec, v_spec = sp.separate(X_spec)

    tqdm.write('inverse stft of instruments...', end=' ')
    wave = spec_utils.spectrogram_to_wave(y_spec, hop_length=args.hop_length)
    tqdm.write('done')
    sf.write(os.path.join(inst_dir, basename + '.wav'), wave.T, sr)

    tqdm.write('inverse stft of vocals...', end=' ')
    wave = spec_utils.spectrogram_to_wave(v_spec, hop_length=args.hop_length)
    tqdm.write('done')
    sf.write(os.path.join(vocal_dir, basename + '.wav'), wave.T, sr)

    if args.output_image:
        utils.imwrite(os.path.join(inst_dir, basename + '.jpg'),
                      spec_utils.spectrogram_to_image(y_spec))
        utils.imwrite(os.path.join(vocal_dir, basename + '.jpg'),
                      spec_utils.spectrogram_to_image(v_spec))


def main():
    """Separate every .wav/.mp3 file in --input-dir into vocal and instrument stems.

    Files that fail are reported and skipped. A summary of the failures is
    printed at the end.
    """
    args = _parse_args()
    model, device = _load_model(args)

    input_paths = _list_audio_files(args.input_dir)
    tqdm.write(f'found {len(input_paths)} wave files from {args.input_dir}')

    # Neither the separator nor the output directories depend on the input file,
    # so set them up once instead of once per file.
    sp = Separator(model, device, args.batchsize, args.cropsize, args.postprocess)
    tqdm.write('validating output directory...', end=' ')
    vocal_dir, inst_dir = [os.path.join(args.output_dir, d) for d in ['vocals', 'instruments']]
    for d in (vocal_dir, inst_dir):
        os.makedirs(d, exist_ok=True)
    tqdm.write('done')

    errors = []
    for input_path in tqdm(input_paths, position=0, leave=True):
        try:
            _process_file(input_path, sp, args, vocal_dir, inst_dir)
        except Exception as e:
            # Batch boundary: record the failure and continue with the remaining files.
            errors.append([input_path, e])
            tqdm.write('\nerror occured while processing {}'.format(input_path))
            tqdm.write('\nerror message: \n{}'.format(e))

    tqdm.write('\nCaught {} errors'.format(len(errors)))
    for f, _ in errors:
        tqdm.write('{}'.format(f))
if __name__ == "__main__":
    # Run the batch-separation CLI only when executed as a script, not on import.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment