-
Notifications
You must be signed in to change notification settings - Fork 25
/
globalsdr_musdb18.py
65 lines (48 loc) · 2.21 KB
/
globalsdr_musdb18.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import os

import musdb
import numpy as np
import wandb

from base_model import BasicPredictor
# MUSDB18 test splits stored as WAV files (is_wav=True). The roots default to the
# original author's local paths; set MUSDB18_DEV_ROOT / MUSDB18_HQ_ROOT in the
# environment to point at another checkout without editing this file.
dataset_dev = musdb.DB(
    root=os.environ.get('MUSDB18_DEV_ROOT', 'D:\\repos\\musdb18_dev_wav'),
    subsets='test',
    is_wav=True,
)
dataset_hq = musdb.DB(
    root=os.environ.get('MUSDB18_HQ_ROOT', 'D:\\repos\\musdb18hq'),
    subsets='test',
    is_wav=True,
)
# Predictor under evaluation; the flag meanings are defined by base_model.BasicPredictor.
final_predictor = BasicPredictor(use_mixer=True, demucs=True)
def sdr(references, estimates):
    """Return the global SDR (in dB) of every source of one song.

    For each source j:
        SDR_j = 10 * log10((sum(s_j**2) + delta) / (sum((s_j - est_j)**2) + delta))

    Energies are summed over axes 1 and 2, so both arrays must be 3-D with the
    source on axis 0. In this script they are presumably (source, sample, channel).

    Args:
        references: ground-truth stems (any real numeric dtype, including integer PCM).
        estimates: estimated stems, same shape as ``references``.

    Returns:
        1-D float64 array holding one SDR value per source (axis 0).
    """
    delta = 1e-7  # keeps the ratio finite for silent references / perfect estimates
    # Promote to float64 before squaring: integer PCM (e.g. int16) would overflow,
    # and adding a float delta in place to an integer sum raises a casting error.
    references = np.asarray(references, dtype=np.float64)
    estimates = np.asarray(estimates, dtype=np.float64)
    num = np.sum(np.square(references), axis=(1, 2)) + delta
    den = np.sum(np.square(references - estimates), axis=(1, 2)) + delta
    return 10 * np.log10(num / den)
def eval_dataset(_dataset, _predictor):
    """Evaluate ``_predictor`` on every track of ``_dataset`` and log SDRs to W&B.

    A W&B run is opened with the predictor flags and dataset root as config.
    Each track whose estimate has one stem per entry of ``sources`` gets its
    per-source SDR (see ``sdr``) logged under ``test_result/<source>``. The
    song mean is logged under ``test_result/song``. Both use ``step=<track index>``.
    Afterwards, averages over the evaluated tracks are logged under
    ``test_avg_result/<source>`` and ``test_avg_result/song``, and the run is
    finished.

    Args:
        _dataset: ``musdb.DB``-like object supporting ``len()``, integer
            indexing, and a ``root`` attribute.
        _predictor: object exposing ``use_mixer``, ``use_demucs`` and
            ``demix(mixture)``. ``demix`` receives ``track.audio.T``; its
            result is presumably an array indexed (source, channel, sample).
            TODO confirm against BasicPredictor.

    Raises:
        NotImplementedError: if an estimate is shorter than its reference track.
    """
    config = {
        'use_mixer': _predictor.use_mixer,
        'use_demucs': _predictor.use_demucs,
        'dataset': _dataset.root,
    }
    wandb.init(project="KUIELab-MDX-Net-GlobalSDR", entity="ielab", config=config)
    sources = ['bass', 'drums', 'other', 'vocals']
    sdr_results = []
    song_sdrs = []
    try:
        for idx in range(len(_dataset)):
            track = _dataset[idx]
            estimation = _predictor.demix(track.audio.T)
            if len(estimation) != len(sources):
                # Cannot be scored stem-by-stem; skip it, but not silently.
                print('Skipping track {}: expected {} stems, got {}'.format(
                    idx, len(sources), len(estimation)))
                continue
            track_length = track.samples
            if track_length > estimation.shape[-1]:
                raise NotImplementedError(
                    'Estimate for track {} has {} samples, reference has {}'.format(
                        idx, estimation.shape[-1], track_length))
            # Drop any padding past the end of the track so refs and ests align.
            estimation = estimation[..., :track_length]
            refs = np.stack([track.sources[source].audio for source in sources])
            ests = np.stack([estimated.T for estimated in estimation])
            sdrs = sdr(refs, ests)
            song_sdr = np.mean(sdrs)
            metrics = {'test_result/{}'.format(source): source_sdr
                       for source, source_sdr in zip(sources, sdrs)}
            metrics['test_result/song'] = song_sdr
            wandb.log(metrics, step=idx)
            sdr_results.append(sdrs)
            song_sdrs.append(song_sdr)

        if not sdr_results:
            # np.stack([]) would raise; nothing to average.
            print('No track could be evaluated; skipping average results.')
            return
        avg_sdrs = np.mean(np.stack(sdr_results), axis=0)
        avg_metrics = {'test_avg_result/{}'.format(source): source_sdr
                       for source, source_sdr in zip(sources, avg_sdrs)}
        avg_metrics['test_avg_result/song'] = np.mean(song_sdrs)
        wandb.log(avg_metrics)
    finally:
        # Close the run so a later eval_dataset call starts a fresh one.
        wandb.finish()
# Script entry: evaluate the predictor on dataset_dev (module-level, so it also runs on import).
eval_dataset(_dataset=dataset_dev, _predictor=final_predictor)