Skip to content

Commit

Permalink
Little clean-up in frontend.
Browse files Browse the repository at this point in the history
Before implementing the STFT frontend, the frontend code is refactored to make
STFT implementation easier:
* Move num_filters from BaseFrontend to LogMelFrontend, as it is specific to
  filter bank configuration.
* Factor out the part that returns HANN coeffs from the function that applies
  the HANN window to the input. STFT inverse needs it.
  • Loading branch information
ds-hwang committed Dec 20, 2024
1 parent 6a7d2f0 commit fb9fdc9
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 12 deletions.
4 changes: 2 additions & 2 deletions axlearn/audio/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,6 @@ class Config(BaseLayer.Config):

# Number of output channels.
output_dim: Required[int] = REQUIRED
# Number of filters/bands in the output spectrogram.
num_filters: Required[int] = REQUIRED
# Number of input samples per second, e.g., 24000 for 24KHz inputs.
sample_rate: Required[int] = REQUIRED
# Size of each frame in ms.
Expand Down Expand Up @@ -132,6 +130,8 @@ class LogMelFrontend(BaseFrontend):
class Config(BaseFrontend.Config):
"""Configures LogMelFrontend."""

# Number of filters/bands in the output spectrogram.
num_filters: Required[int] = REQUIRED
# Number of output channels. Should always be 1.
output_dim: int = 1
# Optional output transformation. See `normalize_by_mean_std` for an example.
Expand Down
9 changes: 6 additions & 3 deletions axlearn/audio/frontend_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,9 +250,7 @@ def pre_emphasis(x: Tensor, *, coeff: Tensor) -> Tensor:
return x[..., 1:] - coeff * x[..., :-1]


def windowing(x: Tensor, *, window_type: WindowType, periodic: bool = True) -> Tensor:
"""Applies windowing to the input frames of shape `[..., num_windows, window_size]`."""
window_size = x.shape[-1]
def window_coffs(window_size: int, *, window_type: WindowType, periodic: bool = True) -> Tensor:
is_even = (1 - window_size % 2) * periodic

if window_type == WindowType.HANN:
Expand All @@ -261,7 +259,12 @@ def windowing(x: Tensor, *, window_type: WindowType, periodic: bool = True) -> T
coeffs = jnp.hamming(window_size + is_even)[:window_size]
else:
raise NotImplementedError(f"Unrecognized window_type {window_type}.")
return coeffs


def windowing(x: Tensor, *, window_type: WindowType, periodic: bool = True) -> Tensor:
"""Applies windowing to the input frames of shape `[..., num_windows, window_size]`."""
coeffs = window_coffs(x.shape[-1], window_type=window_type, periodic=periodic)
return (x * coeffs).astype(x.dtype)


Expand Down
40 changes: 33 additions & 7 deletions axlearn/audio/frontend_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
next_power_of_2,
pre_emphasis,
sharded_fft,
window_coffs,
windowing,
)
from axlearn.audio.test_utils import fake_audio
Expand Down Expand Up @@ -160,6 +161,20 @@ def test_window(self, input_shape, window_type: WindowType, periodic: bool):
windowing(inputs, window_type=window_type, periodic=periodic),
)

@parameterized.product(
window_size=[400, 401],
window_type=list(WindowType),
periodic=[True, False],
)
def test_window_coffs(self, window_size, window_type: WindowType, periodic: bool):
ref_coffs = _ref_window_coffs(
window_size=window_size, window_type=window_type, periodic=periodic
)
test_coeffs = window_coffs(
window_size=window_size, window_type=window_type, periodic=periodic
)
self.assertAllClose(ref_coffs, test_coeffs)


class SpectrogramTest(parameterized.TestCase, tf.test.TestCase):
"""Tests spectrograms."""
Expand Down Expand Up @@ -296,19 +311,30 @@ def _ref_pre_emphasis(*, inputs: ArrayLike, coeff: float):
return inputs[:, :, 1:] - coeff * inputs[:, :, :-1]


def _ref_window(*, inputs: ArrayLike, window_type: WindowType, **kwargs):
def _ref_window_coffs(
*, window_size: int, window_type: WindowType, periodic: bool = True, dtype=tf.float32
):
if window_type == WindowType.HANN:
tf_window = tf.signal.hann_window(window_size, periodic=periodic, dtype=dtype)
elif window_type == WindowType.HAMMING:
tf_window = tf.signal.hamming_window(window_size, periodic=periodic, dtype=dtype)
else:
raise NotImplementedError(f"Unrecognized window type: {window_type}")
return tf_window


def _ref_window(
*, inputs: ArrayLike, window_type: WindowType, periodic: bool = True, dtype=tf.float32
):
"""Lingvo window.
Reference:
https://github.com/tensorflow/lingvo/blob/4a9097a212622d99d7f8e2379804dbffdc44a97f/lingvo/tasks/asr/frontend.py#L244
"""
frame_size = inputs.shape[-1]
if window_type == WindowType.HANN:
tf_window = tf.signal.hann_window(frame_size, **kwargs)
elif window_type == WindowType.HAMMING:
tf_window = tf.signal.hamming_window(frame_size, **kwargs)
else:
raise NotImplementedError(f"Unrecognized window type: {window_type}")
tf_window = _ref_window_coffs(
window_size=frame_size, window_type=window_type, periodic=periodic, dtype=dtype
)
return inputs * tf_window


Expand Down

0 comments on commit fb9fdc9

Please sign in to comment.