# contrib.py (generated from sigsep/open-unmix-pytorch)
import torch
import torch.nn.functional as F
from scipy.ndimage import gaussian_filter, gaussian_filter1d


def _logit(W, threshold, slope):
    # logistic compression: maps values below `threshold` towards 0 and values
    # above it towards 1; `slope` controls how abrupt the transition is
    return 1.0 / (1.0 + torch.exp(-slope * (W - threshold)))


def residual_model(v, x, alpha=1, autoscale=False):
    r"""Compute a model for the residual based on spectral subtraction.

    The method consists of two steps:

    * The provided spectrograms are summed up to obtain the *input* model for
      the mixture. This *input* model is scaled frequency-wise to best
      fit the actual observed mixture spectrogram.
    * The residual model is obtained through spectral subtraction of the
      input model from the mixture spectrogram, with flooring to 0.

    Parameters
    ----------
    v: torch.Tensor [shape=(batch, nb_frames, nb_bins, {1, nb_channels}, nb_sources)]
        Estimated spectrograms for the sources
    x: torch.Tensor [shape=(batch, nb_frames, nb_bins, nb_channels)]
        complex mixture
    alpha: float [scalar]
        exponent for the spectrograms `v`. For instance, if `alpha==1`,
        then `v` must be homogeneous to magnitudes, and if `alpha==2`, `v`
        must be homogeneous to squared magnitudes.
    autoscale: boolean
        in case you know that the spectrograms will not have the right
        magnitude, it is important that the models are scaled so that the
        residual is correctly estimated.

    Returns
    -------
    v: torch.Tensor [shape=(batch, nb_frames, nb_bins, nb_channels, nb_sources+1)]
        Spectrograms of the sources, with an appended one for the residual.

    Note
    ----
    It is not mandatory to input multichannel spectrograms. However, the
    output spectrograms *will* be multichannel.

    Warning
    -------
    You must be careful to set `alpha` as the exponent that corresponds to `v`.
    In other words, *you must have*: ``np.abs(x)**alpha`` homogeneous to `v`.
    """
    # to avoid dividing by zero
    eps = torch.finfo(v.dtype).eps

    # spectrogram for the mixture
    vx = F.threshold(x.abs() ** alpha, eps, eps)

    # compute the total model as provided
    v_total = v.sum(-1)

    if autoscale:
        # quick trick to scale the provided spectrograms to fit the mixture;
        # keepdim=True retains the frame dimension so the gain broadcasts
        # correctly over frames for any batch size
        gain = torch.sum(vx * v_total, 1, keepdim=True)
        weights = torch.sum(v_total * v_total, 1, keepdim=True).add_(eps)
        gain /= weights
        v *= gain[..., None]

        # re-sum the sources to build the new current model
        v_total = v.sum(-1)

    # the residual is the difference between the observation and the model
    vr = (vx - v_total).relu()

    return torch.cat((v, vr[..., None]), dim=-1)
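

# Illustrative usage sketch (not part of the upstream module): how
# `residual_model` might be called. The tensor names and sizes below are
# assumptions chosen only to make the expected shapes concrete.
def _example_residual_model():
    nb_frames, nb_bins, nb_channels, nb_sources = 100, 1025, 2, 4
    # complex mixture STFT and non-negative source spectrograms (alpha=1)
    x = torch.randn(1, nb_frames, nb_bins, nb_channels, dtype=torch.complex64)
    v = torch.rand(1, nb_frames, nb_bins, nb_channels, nb_sources)
    v_with_residual = residual_model(v, x, alpha=1, autoscale=True)
    # one extra entry along the last dimension holds the residual model
    assert v_with_residual.shape[-1] == nb_sources + 1
    return v_with_residual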


def smooth(v, width=1, temporal=False):
    """Smooth an ndarray with a Gaussian blur.

    Parameters
    ----------
    v: ndarray [shape=(nb_frames, ...)]
        input array
    width: int [scalar]
        lengthscale of the Gaussian blur
    temporal: boolean
        if True, smooth only along time with a 1-d blur; otherwise use a
        multidimensional Gaussian blur.

    Returns
    -------
    result: ndarray [shape=(nb_frames, ...)]
        filtered array
    """
    if temporal:
        return gaussian_filter1d(v, sigma=width, axis=0)
    else:
        return gaussian_filter(v, sigma=width, truncate=width)
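

# Minimal sketch (an assumption about typical use): `smooth` operates on NumPy
# arrays, so torch tensors should be converted first, as `reduce_interferences`
# does below. NumPy is imported locally since the module itself does not need it.
def _example_smooth():
    import numpy as np
    spec = np.random.rand(100, 1025)                   # (nb_frames, nb_bins)
    blurred = smooth(spec, width=2)                    # isotropic 2-d Gaussian blur
    blurred_t = smooth(spec, width=2, temporal=True)   # blur along time only
    return blurred, blurred_t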


def reduce_interferences(v, thresh=0.6, slope=15):
    r"""Reduce interferences between spectrograms.

    The objective of the method is to redistribute the energy of the input in
    order to "sparsify" the spectrograms along the "source" dimension. This is
    motivated by the fact that sources are somewhat sparse and it is hence
    unlikely that they are all energetic at the same time-frequency bins.

    The method is inspired by [1]_ with ad-hoc modifications.

    References
    ----------
    .. [1] Thomas Prätzlich, Rachel Bittner, Antoine Liutkus, Meinard Müller.
        "Kernel additive modeling for interference reduction in multi-
        channel music recordings." Proc. of ICASSP 2015.

    Parameters
    ----------
    v: torch.Tensor [shape=(..., nb_sources)]
        non-negative data on which to apply interference reduction
    thresh: float [scalar]
        threshold for the compression, should be between 0 and 1. The closer
        to 1, the more the interferences are reduced, at the price of more
        distortion.
    slope: float [scalar]
        the slope at which binarization is done. The higher, the more abrupt.

    Returns
    -------
    v: torch.Tensor [same shape as input]
        `v` with reduced interferences
    """
    eps = 1e-7
    # smooth the spectrograms on CPU with scipy, then move back to torch
    vsmooth = smooth(v.detach().cpu().numpy(), 10)
    vsmooth = torch.from_numpy(vsmooth).to(v.device).to(v.dtype)
    # soft-threshold each source by its share of the smoothed total energy
    total_energy = eps + vsmooth.sum(-1, keepdim=True)
    v = _logit(vsmooth / total_energy, thresh, slope) * v
    return v
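

# Hedged usage sketch: applying interference reduction to a stack of source
# spectrograms. The shapes and the 0.6/15 settings simply mirror the defaults
# above; the variable names are illustrative only.
def _example_reduce_interferences():
    nb_frames, nb_bins, nb_sources = 100, 1025, 4
    v = torch.rand(nb_frames, nb_bins, nb_sources)
    v_sparse = reduce_interferences(v, thresh=0.6, slope=15)
    # energy is redistributed towards the locally dominant source,
    # while the overall shape is preserved
    assert v_sparse.shape == v.shape
    return v_sparse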


def compress_filter(W, thresh=0.6, slope=15):
    """Apply a logit compression to a filter.

    This makes it possible to "binarize" a separation filter, reducing
    interferences at the price of distortion.

    In the case of multichannel filters, they are decomposed as the cascade of
    a pure beamformer (selection of one direction in space), followed by a
    single-channel mask. Compression is then applied to the mask only.

    Parameters
    ----------
    W: torch.Tensor [shape=(..., nb_channels, nb_channels)]
        filter on which to apply logit compression.
    thresh: float
        threshold for the compression, should be between 0 and 1. The closer
        to 1, the less interference, but the more distortion.
    slope: float
        the slope at which binarization is done. The higher, the more abrupt.

    Returns
    -------
    W: torch.Tensor [same shape as input]
        Compressed filter
    """
    eps = torch.finfo(W.dtype).eps
    nb_channels = W.shape[-1]
    if nb_channels > 1:
        # compress only the overall (trace) gain, leaving the spatial
        # structure of the filter untouched
        gains = torch.einsum('...ii', W)
        W *= (_logit(gains, thresh, slope) / (eps + gains))[..., None, None]
    else:
        W = _logit(W, thresh, slope)
    return W
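

# Illustrative sketch (assumption: W is a multichannel filter of shape
# (..., nb_channels, nb_channels), e.g. a Wiener-style spatial filter). Only
# the per-bin trace gain is compressed; passing a clone avoids mutating the
# caller's tensor, since the multichannel branch multiplies W in place.
def _example_compress_filter():
    nb_frames, nb_bins, nb_channels = 100, 1025, 2
    W = torch.rand(nb_frames, nb_bins, nb_channels, nb_channels)
    W_compressed = compress_filter(W.clone(), thresh=0.6, slope=15)
    assert W_compressed.shape == W.shape
    return W_compressed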