diff --git a/axlearn/audio/frontend.py b/axlearn/audio/frontend.py index 344be56c..242191d6 100644 --- a/axlearn/audio/frontend.py +++ b/axlearn/audio/frontend.py @@ -68,8 +68,6 @@ class Config(BaseLayer.Config): # Number of output channels. output_dim: Required[int] = REQUIRED - # Number of filters/bands in the output spectrogram. - num_filters: Required[int] = REQUIRED # Number of input samples per second, e.g., 24000 for 24KHz inputs. sample_rate: Required[int] = REQUIRED # Size of each frame in ms. @@ -132,6 +130,8 @@ class LogMelFrontend(BaseFrontend): class Config(BaseFrontend.Config): """Configures LogMelFrontend.""" + # Number of filters/bands in the output spectrogram. + num_filters: Required[int] = REQUIRED # Number of output channels. Should always be 1. output_dim: int = 1 # Optional output transformation. See `normalize_by_mean_std` for an example. diff --git a/axlearn/audio/frontend_utils.py b/axlearn/audio/frontend_utils.py index 55ff48b7..c2988287 100644 --- a/axlearn/audio/frontend_utils.py +++ b/axlearn/audio/frontend_utils.py @@ -250,9 +250,7 @@ def pre_emphasis(x: Tensor, *, coeff: Tensor) -> Tensor: return x[..., 1:] - coeff * x[..., :-1] -def windowing(x: Tensor, *, window_type: WindowType, periodic: bool = True) -> Tensor: - """Applies windowing to the input frames of shape `[..., num_windows, window_size]`.""" - window_size = x.shape[-1] +def window_coffs(window_size: int, *, window_type: WindowType, periodic: bool = True) -> Tensor: is_even = (1 - window_size % 2) * periodic if window_type == WindowType.HANN: @@ -261,7 +259,12 @@ def windowing(x: Tensor, *, window_type: WindowType, periodic: bool = True) -> T coeffs = jnp.hamming(window_size + is_even)[:window_size] else: raise NotImplementedError(f"Unrecognized window_type {window_type}.") + return coeffs + +def windowing(x: Tensor, *, window_type: WindowType, periodic: bool = True) -> Tensor: + """Applies windowing to the input frames of shape `[..., num_windows, window_size]`.""" + coeffs = window_coffs(x.shape[-1], window_type=window_type, periodic=periodic) return (x * coeffs).astype(x.dtype) diff --git a/axlearn/audio/frontend_utils_test.py b/axlearn/audio/frontend_utils_test.py index dbca80e1..74537b57 100644 --- a/axlearn/audio/frontend_utils_test.py +++ b/axlearn/audio/frontend_utils_test.py @@ -31,6 +31,7 @@ next_power_of_2, pre_emphasis, sharded_fft, + window_coffs, windowing, ) from axlearn.audio.test_utils import fake_audio @@ -160,6 +161,20 @@ def test_window(self, input_shape, window_type: WindowType, periodic: bool): windowing(inputs, window_type=window_type, periodic=periodic), ) + @parameterized.product( + window_size=[400, 401], + window_type=list(WindowType), + periodic=[True, False], + ) + def test_window_coffs(self, window_size, window_type: WindowType, periodic: bool): + ref_coffs = _ref_window_coffs( + window_size=window_size, window_type=window_type, periodic=periodic + ) + test_coeffs = window_coffs( + window_size=window_size, window_type=window_type, periodic=periodic + ) + self.assertAllClose(ref_coffs, test_coeffs) + class SpectrogramTest(parameterized.TestCase, tf.test.TestCase): """Tests spectrograms.""" @@ -296,19 +311,30 @@ def _ref_pre_emphasis(*, inputs: ArrayLike, coeff: float): return inputs[:, :, 1:] - coeff * inputs[:, :, :-1] -def _ref_window(*, inputs: ArrayLike, window_type: WindowType, **kwargs): +def _ref_window_coffs( + *, window_size: int, window_type: WindowType, periodic: bool = True, dtype=tf.float32 +): + if window_type == WindowType.HANN: + tf_window = tf.signal.hann_window(window_size, periodic=periodic, dtype=dtype) + elif window_type == WindowType.HAMMING: + tf_window = tf.signal.hamming_window(window_size, periodic=periodic, dtype=dtype) + else: + raise NotImplementedError(f"Unrecognized window type: {window_type}") + return tf_window + + +def _ref_window( + *, inputs: ArrayLike, window_type: WindowType, periodic: bool = True, dtype=tf.float32 +): """Lingvo window. Reference: https://github.com/tensorflow/lingvo/blob/4a9097a212622d99d7f8e2379804dbffdc44a97f/lingvo/tasks/asr/frontend.py#L244 """ frame_size = inputs.shape[-1] - if window_type == WindowType.HANN: - tf_window = tf.signal.hann_window(frame_size, **kwargs) - elif window_type == WindowType.HAMMING: - tf_window = tf.signal.hamming_window(frame_size, **kwargs) - else: - raise NotImplementedError(f"Unrecognized window type: {window_type}") + tf_window = _ref_window_coffs( + window_size=frame_size, window_type=window_type, periodic=periodic, dtype=dtype + ) return inputs * tf_window