# Source: GitHub gist ouor/f48b2815f61b03e9e4212b433edcd995 (last active March 6, 2023 12:02).
# Save the gist to your computer to use it in GitHub Desktop.
# Warning: this file contains bidirectional Unicode text that may be interpreted
# or compiled differently than what appears below. To review, open the file in
# an editor that reveals hidden Unicode characters (see GitHub's notes on
# bidirectional Unicode characters).
import argparse | |
import os | |
import glob | |
import librosa | |
import numpy as np | |
import soundfile as sf | |
import torch | |
from tqdm import tqdm | |
from lib import dataset | |
from lib import nets | |
from lib import spec_utils | |
from lib import utils | |
class Separator(object):
    """Split a complex spectrogram into two complementary parts with a mask model.

    The model is asked for a mask over fixed-width crops of the normalised
    magnitude spectrogram. The mask is then applied to the original
    magnitude/phase to produce ``y_spec = mask * X`` and
    ``v_spec = (1 - mask) * X``.

    Args:
        model: Object exposing ``offset``, ``eval()`` and
            ``predict_mask(batch)``.
            NOTE(review): presumably ``predict_mask`` trims ``offset`` frames
            from each side of every crop -- confirm against the model.
        device: Torch device the input batches are moved to.
        batchsize: Number of crops sent to the model per forward pass.
        cropsize: Width (in frames, axis 2) of each crop fed to the model.
        postprocess: When true, the mask is passed through
            ``spec_utils.merge_artifacts`` before being applied.
    """

    def __init__(self, model, device, batchsize, cropsize, postprocess=False):
        self.model = model
        self.offset = model.offset
        self.device = device
        self.batchsize = batchsize
        self.cropsize = cropsize
        self.postprocess = postprocess

    @staticmethod
    def _pad_and_normalize(X_mag, pad_l, pad_r):
        """Zero-pad ``X_mag`` along axis 2 and scale it so its peak is 1.

        An all-zero (silent) input is returned unscaled: dividing by a peak of
        zero would fill the array with NaNs and poison the model input.
        """
        X_mag_pad = np.pad(X_mag, ((0, 0), (0, 0), (pad_l, pad_r)), mode='constant')
        peak = X_mag_pad.max()
        if peak > 0:
            X_mag_pad /= peak
        return X_mag_pad

    def _separate(self, X_mag_pad, roi_size):
        """Run the model over consecutive crops and join the predicted masks.

        Crops are ``self.cropsize`` frames wide and start every ``roi_size``
        frames along axis 2.

        Args:
            X_mag_pad: Padded, normalised magnitude array (frames on axis 2).
            roi_size: Stride between consecutive crops.

        Returns:
            The per-crop predictions concatenated along axis 2.

        Raises:
            ValueError: If the padded input is too short to hold one patch.
        """
        patches = (X_mag_pad.shape[2] - 2 * self.offset) // roi_size
        if patches <= 0:
            raise ValueError(
                'input is too short to extract a single patch '
                '(frames={}, offset={}, roi_size={})'.format(
                    X_mag_pad.shape[2], self.offset, roi_size))

        self.model.eval()
        mask = []
        with torch.no_grad():
            # To reduce the overhead, dataloader is not used. Crops are
            # stacked one batch at a time so that the overlapping windows are
            # never all held in memory at once.
            for first in range(0, patches, self.batchsize):
                last = min(first + self.batchsize, patches)
                crops = [
                    X_mag_pad[:, :, k * roi_size:k * roi_size + self.cropsize]
                    for k in range(first, last)
                ]
                X_batch = torch.from_numpy(np.stack(crops)).to(self.device)

                pred = self.model.predict_mask(X_batch)
                pred = pred.detach().cpu().numpy()
                # Iterating ``pred`` walks the batch axis, so this lays the
                # per-crop predictions side by side along the frame axis.
                mask.append(np.concatenate(pred, axis=2))

        return np.concatenate(mask, axis=2)

    def _preprocess(self, X_spec):
        """Return the magnitude and phase (radians) of a complex spectrogram."""
        X_mag = np.abs(X_spec)
        X_phase = np.angle(X_spec)
        return X_mag, X_phase

    def _postprocess(self, mask, X_mag, X_phase):
        """Apply ``mask`` and its complement to the original spectrogram.

        Returns:
            ``(y_spec, v_spec)`` where ``y_spec`` keeps the masked part and
            ``v_spec`` keeps the remainder; both reuse the input phase.
        """
        if self.postprocess:
            mask = spec_utils.merge_artifacts(mask)

        unit_phase = np.exp(1.j * X_phase)
        y_spec = mask * X_mag * unit_phase
        v_spec = (1 - mask) * X_mag * unit_phase
        return y_spec, v_spec

    def separate(self, X_spec):
        """Separate ``X_spec`` with a single pass over the crops.

        Returns:
            ``(y_spec, v_spec)`` complex spectrograms as produced by
            ``_postprocess``.
        """
        X_mag, X_phase = self._preprocess(X_spec)

        n_frame = X_mag.shape[2]
        pad_l, pad_r, roi_size = dataset.make_padding(n_frame, self.cropsize, self.offset)
        X_mag_pad = self._pad_and_normalize(X_mag, pad_l, pad_r)

        # Drop the frames that only exist because of the right-hand padding.
        mask = self._separate(X_mag_pad, roi_size)[:, :, :n_frame]

        return self._postprocess(mask, X_mag, X_phase)

    def separate_tta(self, X_spec):
        """Separate ``X_spec`` by averaging two passes with shifted crops.

        The second pass pads both sides by an extra half ``roi_size`` so that
        its crop boundaries fall in the middle of the first pass's crops. The
        two masks are re-aligned and averaged.

        Returns:
            ``(y_spec, v_spec)`` complex spectrograms as produced by
            ``_postprocess``.
        """
        X_mag, X_phase = self._preprocess(X_spec)

        n_frame = X_mag.shape[2]
        pad_l, pad_r, roi_size = dataset.make_padding(n_frame, self.cropsize, self.offset)
        half_roi = roi_size // 2

        mask = self._separate(
            self._pad_and_normalize(X_mag, pad_l, pad_r), roi_size)

        mask_tta = self._separate(
            self._pad_and_normalize(X_mag, pad_l + half_roi, pad_r + half_roi),
            roi_size)
        # Undo the extra left padding so both masks start at the same frame.
        mask_tta = mask_tta[:, :, half_roi:]

        mask = (mask[:, :, :n_frame] + mask_tta[:, :, :n_frame]) * 0.5

        return self._postprocess(mask, X_mag, X_phase)
def main():
    """Separate every ``.wav``/``.mp3`` file in a directory from the command line.

    For each input file the script loads the audio, computes its spectrogram,
    runs ``Separator`` on it and writes two WAV files named after the input:
    one under ``<output_dir>/instruments`` and one under
    ``<output_dir>/vocals``. With ``--output_image`` a JPEG of each
    spectrogram is written next to them.

    A failure on one file is reported and recorded but does not stop the
    batch; a summary of failed files is printed at the end.
    """
    p = argparse.ArgumentParser()
    p.add_argument('--cpu', '-c', action='store_true')
    p.add_argument('--pretrained_model', '-P', type=str, default='models/baseline.pth')
    p.add_argument('--input-dir', '-i', required=True)
    p.add_argument('--sr', '-r', type=int, default=44100)
    p.add_argument('--n_fft', '-f', type=int, default=2048)
    p.add_argument('--hop_length', '-H', type=int, default=1024)
    p.add_argument('--batchsize', '-B', type=int, default=4)
    p.add_argument('--cropsize', '-cs', type=int, default=256)
    p.add_argument('--output_image', '-I', action='store_true')
    p.add_argument('--postprocess', '-p', action='store_true')
    p.add_argument('--tta', '-t', action='store_true')
    p.add_argument('--output_dir', '-o', type=str, required=True)
    args = p.parse_args()

    tqdm.write('loading model...', end=' ')
    # Weights are always loaded onto the CPU first and moved to the GPU only
    # when one is available and --cpu was not given.
    device = torch.device('cpu')
    model = nets.CascadedNet(args.n_fft, 32, 128)
    model.load_state_dict(torch.load(args.pretrained_model, map_location=device))
    if torch.cuda.is_available() and not args.cpu:
        device = torch.device('cuda')
        model.to(device)
    tqdm.write('done')

    # Match extensions case-insensitively (e.g. "TRACK.WAV") and sort so the
    # processing order does not depend on the filesystem's listing order.
    wav_or_mp3 = sorted(
        os.path.join(args.input_dir, f)
        for f in os.listdir(args.input_dir)
        if f.lower().endswith(('.wav', '.mp3'))
    )
    tqdm.write(f'found {len(wav_or_mp3)} wave files from {args.input_dir}')

    # These do not depend on the individual file, so build them once.
    sp = Separator(model, device, args.batchsize, args.cropsize, args.postprocess)
    vocal_dir = os.path.join(args.output_dir, 'vocals')
    inst_dir = os.path.join(args.output_dir, 'instruments')

    errors = []
    for path in tqdm(wav_or_mp3, position=0, leave=True):
        try:
            tqdm.write('loading wave source...', end=' ')
            # ``sr`` and ``mono`` must be passed by keyword: recent librosa
            # releases reject them as positional arguments.
            X, sr = librosa.load(
                path, sr=args.sr, mono=False, dtype=np.float32, res_type='kaiser_fast')
            basename = os.path.splitext(os.path.basename(path))[0]
            tqdm.write('done')

            if X.ndim == 1:
                # mono to stereo
                X = np.asarray([X, X])

            tqdm.write('stft of wave source...', end=' ')
            X_spec = spec_utils.wave_to_spectrogram(X, args.hop_length, args.n_fft)
            tqdm.write('done')

            if args.tta:
                y_spec, v_spec = sp.separate_tta(X_spec)
            else:
                y_spec, v_spec = sp.separate(X_spec)

            tqdm.write('validating output directory...', end=' ')
            for d in (vocal_dir, inst_dir):
                os.makedirs(d, exist_ok=True)
            tqdm.write('done')

            tqdm.write('inverse stft of instruments...', end=' ')
            wave = spec_utils.spectrogram_to_wave(y_spec, hop_length=args.hop_length)
            tqdm.write('done')
            sf.write(os.path.join(inst_dir, basename + '.wav'), wave.T, sr)

            tqdm.write('inverse stft of vocals...', end=' ')
            wave = spec_utils.spectrogram_to_wave(v_spec, hop_length=args.hop_length)
            tqdm.write('done')
            sf.write(os.path.join(vocal_dir, basename + '.wav'), wave.T, sr)

            if args.output_image:
                image = spec_utils.spectrogram_to_image(y_spec)
                utils.imwrite(os.path.join(inst_dir, basename + '.jpg'), image)

                image = spec_utils.spectrogram_to_image(v_spec)
                utils.imwrite(os.path.join(vocal_dir, basename + '.jpg'), image)
        except Exception as e:
            # Batch boundary: report the failure and carry on with the
            # remaining files instead of aborting the whole run.
            errors.append([path, e])
            tqdm.write('\nerror occured while processing {}'.format(path))
            tqdm.write('\nerror message: \n{}'.format(e))

    tqdm.write('\nCaught {} errors'.format(len(errors)))
    for failed_path, _ in errors:
        tqdm.write('{}'.format(failed_path))
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
# Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.