Skip to content

Commit

Permalink
[Audio] Metric with Squim objective and MOS (#9751)
Browse files Browse the repository at this point in the history
* Metric with Squim Objective and MOS

Signed-off-by: Ante Jukić <[email protected]>

* Removed utility functions

Signed-off-by: Ante Jukić <[email protected]>

---------

Signed-off-by: Ante Jukić <[email protected]>
  • Loading branch information
anteju authored Jul 23, 2024
1 parent b7ca4a3 commit 1711334
Show file tree
Hide file tree
Showing 6 changed files with 355 additions and 3 deletions.
22 changes: 20 additions & 2 deletions examples/audio/audio_to_audio_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@

from nemo.collections.audio.data import audio_to_audio_dataset
from nemo.collections.audio.data.audio_to_audio_lhotse import LhotseAudioToTargetDataset
from nemo.collections.audio.metrics.audio import AudioMetricWrapper
from nemo.collections.audio.metrics import AudioMetricWrapper, SquimMOSMetric, SquimObjectiveMetric
from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config
from nemo.collections.common.parts.preprocessing import manifest
from nemo.core.config import hydra_runner
Expand Down Expand Up @@ -128,7 +128,17 @@ def get_evaluation_dataloader(config):

def get_metrics(cfg: AudioEvaluationConfig):
"""Prepare a dictionary with metrics."""
available_metrics = ['sdr', 'sisdr', 'stoi', 'estoi', 'pesq']
available_metrics = [
'sdr',
'sisdr',
'stoi',
'estoi',
'pesq',
'squim_mos',
'squim_stoi',
'squim_pesq',
'squim_si_sdr',
]

metrics = dict()
for name in sorted(set(cfg.metrics)):
Expand All @@ -143,6 +153,14 @@ def get_metrics(cfg: AudioEvaluationConfig):
metric = AudioMetricWrapper(metric=ShortTimeObjectiveIntelligibility(fs=cfg.sample_rate, extended=True))
elif name == 'pesq':
metric = AudioMetricWrapper(metric=PerceptualEvaluationSpeechQuality(fs=cfg.sample_rate, mode='wb'))
elif name == 'squim_mos':
metric = AudioMetricWrapper(metric=SquimMOSMetric(fs=cfg.sample_rate))
elif name == 'squim_stoi':
metric = AudioMetricWrapper(metric=SquimObjectiveMetric(metric='stoi', fs=cfg.sample_rate))
elif name == 'squim_pesq':
metric = AudioMetricWrapper(metric=SquimObjectiveMetric(metric='pesq', fs=cfg.sample_rate))
elif name == 'squim_si_sdr':
metric = AudioMetricWrapper(metric=SquimObjectiveMetric(metric='si_sdr', fs=cfg.sample_rate))
else:
raise ValueError(f'Unexpected metric: {name}. Currently available metrics: {available_metrics}')

Expand Down
2 changes: 1 addition & 1 deletion nemo/collections/asr/parts/utils/transcribe_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ def prepare_audio_data(cfg: DictConfig) -> Tuple[List[str], bool]:
with open(cfg.dataset_manifest, "rt") as fh:
for line in fh:
item = json.loads(line)
item["audio_filepath"] = get_full_path(item["audio_filepath"], cfg.dataset_manifest)
item[audio_key] = get_full_path(item[audio_key], cfg.dataset_manifest)
if item.get("duration") is None and cfg.presort_manifest:
raise ValueError(
f"Requested presort_manifest=True, but line {line} in manifest {cfg.dataset_manifest} lacks a 'duration' field."
Expand Down
3 changes: 3 additions & 0 deletions nemo/collections/audio/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from nemo.collections.audio.metrics.audio import AudioMetricWrapper
from nemo.collections.audio.metrics.squim import SquimMOSMetric, SquimObjectiveMetric
3 changes: 3 additions & 0 deletions nemo/collections/audio/metrics/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from torchmetrics.audio.sdr import ScaleInvariantSignalDistortionRatio, SignalDistortionRatio
from torchmetrics.audio.snr import ScaleInvariantSignalNoiseRatio, SignalNoiseRatio
from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility
from nemo.collections.audio.metrics.squim import SquimMOSMetric, SquimObjectiveMetric

from nemo.utils import logging

Expand All @@ -34,6 +35,8 @@
SignalNoiseRatio,
PerceptualEvaluationSpeechQuality,
ShortTimeObjectiveIntelligibility,
SquimMOSMetric,
SquimObjectiveMetric,
]


Expand Down
197 changes: 197 additions & 0 deletions nemo/collections/audio/metrics/squim.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any

import torch
from torchmetrics import Metric
from nemo.utils import logging

try:
import torchaudio

HAVE_TORCHAUDIO = True
except ModuleNotFoundError:
HAVE_TORCHAUDIO = False


class SquimMOSMetric(Metric):
    """A metric calculating the average Torchaudio Squim MOS.

    Every example is scored with the model returned by
    ``torchaudio.pipelines.SQUIM_SUBJECTIVE.get_model()``. The metric
    accumulates the sum of the scores and the number of scored examples,
    and ``compute`` returns their ratio.

    Args:
        fs: sampling rate of the input signals. When it differs from
            ``sample_rate``, the inputs are resampled before scoring.
        **kwargs: forwarded to the ``Metric`` base class.

    Raises:
        ModuleNotFoundError: if `torchaudio` could not be imported.
    """

    sample_rate: int = 16000  # sample rate of the model
    mos_sum: torch.Tensor
    num_examples: torch.Tensor
    higher_is_better: bool = True
    # `update` only adds batch statistics to the accumulated sums, so the batch
    # result never depends on the global state. Declaring this lets torchmetrics'
    # `forward` call `update` (and therefore the MOS network) once per batch
    # instead of twice.
    full_state_update: bool = False

    def __init__(self, fs: int, **kwargs: Any):
        super().__init__(**kwargs)

        if not HAVE_TORCHAUDIO:
            raise ModuleNotFoundError(f"{self.__class__.__name__} metric needs `torchaudio`.")

        if fs != self.sample_rate:
            # Resampler: kaiser_best
            self._squim_mos_metric_resampler = torchaudio.transforms.Resample(
                orig_freq=fs,
                new_freq=self.sample_rate,
                lowpass_filter_width=64,
                rolloff=0.9475937167399596,
                resampling_method='sinc_interp_kaiser',
                beta=14.769656459379492,
            )
            logging.warning('Input signals will be resampled from fs=%d to %d Hz', fs, self.sample_rate)
        # Remembered so that `update` knows whether the resampler exists.
        self.fs = fs

        # MOS model
        self._squim_mos_metric_model = torchaudio.pipelines.SQUIM_SUBJECTIVE.get_model()

        self.add_state('mos_sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
        self.add_state('num_examples', default=torch.tensor(0), dist_reduce_fx='sum')
        logging.debug('Setup metric %s with input fs=%s', self.__class__.__name__, self.fs)

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        """Update the metric by calculating the MOS score for the current batch.

        Args:
            preds: tensor with predictions, shape (B, T) or (T,)
            target: tensor with target signals, shape (B, T) or (T,).
                Target can be a non-matching reference.

        Raises:
            ValueError: if `preds` has more than two dimensions.
        """
        if self.fs != self.sample_rate:
            preds = self._squim_mos_metric_resampler(preds)
            target = self._squim_mos_metric_resampler(target)

        if preds.ndim == 1:
            # Unsqueeze batch dimension
            preds = preds.unsqueeze(0)
            target = target.unsqueeze(0)
        elif preds.ndim > 2:
            raise ValueError(f'Expected 1D or 2D signals, got {preds.ndim}D signals')

        mos_batch = self._squim_mos_metric_model(preds, target)

        # Accumulate the sum and the count so that `compute` can return the mean.
        self.mos_sum += mos_batch.sum()
        self.num_examples += mos_batch.numel()

    def compute(self) -> torch.Tensor:
        """Return the accumulated MOS sum divided by the number of scored examples."""
        return self.mos_sum / self.num_examples

    def state_dict(self, *args, **kwargs):
        """Return the state dict without the MOS model and the resampler.

        Every key containing the attribute name of the resampler or of the
        MOS model is removed from the dictionary produced by the base class.
        """
        state_dict = super().state_dict(*args, **kwargs)
        # Do not include resampler or mos_model in the state dict
        remove_keys = [
            key
            for key in state_dict.keys()
            if '_squim_mos_metric_resampler' in key or '_squim_mos_metric_model' in key
        ]
        for key in remove_keys:
            del state_dict[key]
        return state_dict


class SquimObjectiveMetric(Metric):
    """A metric calculating the average Torchaudio Squim objective metric.

    Every example is scored with the model returned by
    ``torchaudio.pipelines.SQUIM_OBJECTIVE.get_model()``, which yields three
    outputs per call (unpacked here as STOI, PESQ and SI-SDR). Only the output
    selected by `metric` is accumulated; ``compute`` returns its mean.

    Args:
        fs: sampling rate of the input signals. When it differs from
            ``sample_rate``, the inputs are resampled before scoring.
        metric: the objective metric to calculate. One of 'stoi', 'pesq', 'si_sdr'
        **kwargs: forwarded to the ``Metric`` base class.

    Raises:
        ModuleNotFoundError: if `torchaudio` could not be imported.
        ValueError: if `metric` is not one of the supported names.
    """

    sample_rate: int = 16000  # sample rate of the model
    metric_sum: torch.Tensor
    num_examples: torch.Tensor
    higher_is_better: bool = True
    # `update` only adds batch statistics to the accumulated sums, so the batch
    # result never depends on the global state. Declaring this lets torchmetrics'
    # `forward` call `update` (and therefore the objective network) once per
    # batch instead of twice.
    full_state_update: bool = False

    def __init__(self, fs: int, metric: str, **kwargs: Any):
        super().__init__(**kwargs)

        if not HAVE_TORCHAUDIO:
            raise ModuleNotFoundError(f"{self.__class__.__name__} needs `torchaudio`.")

        if fs != self.sample_rate:
            # Resampler: kaiser_best
            self._squim_objective_metric_resampler = torchaudio.transforms.Resample(
                orig_freq=fs,
                new_freq=self.sample_rate,
                lowpass_filter_width=64,
                rolloff=0.9475937167399596,
                resampling_method='sinc_interp_kaiser',
                beta=14.769656459379492,
            )
            logging.warning('Input signals will be resampled from fs=%d to %d Hz', fs, self.sample_rate)
        # Remembered so that `update` knows whether the resampler exists.
        self.fs = fs

        if metric not in ['stoi', 'pesq', 'si_sdr']:
            raise ValueError(f'Unsupported metric {metric}. Supported metrics are "stoi", "pesq", "si_sdr".')

        self.metric = metric

        # Objective model
        self._squim_objective_metric_model = torchaudio.pipelines.SQUIM_OBJECTIVE.get_model()

        self.add_state('metric_sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
        self.add_state('num_examples', default=torch.tensor(0), dist_reduce_fx='sum')
        logging.debug('Setup %s with metric=%s, input fs=%s', self.__class__.__name__, self.metric, self.fs)

    def update(self, preds: torch.Tensor, target: Any = None) -> None:
        """Update the metric by calculating the selected metric score for the current batch.

        Args:
            preds: tensor with predictions, shape (B, T) or (T,)
            target: None, not used. Keeping for interface compatibility with other metrics.

        Raises:
            ValueError: if `preds` has more than two dimensions, or if
                ``self.metric`` is not one of the supported names.
        """
        if self.fs != self.sample_rate:
            preds = self._squim_objective_metric_resampler(preds)

        if preds.ndim == 1:
            # Unsqueeze batch dimension
            preds = preds.unsqueeze(0)
        elif preds.ndim > 2:
            raise ValueError(f'Expected 1D or 2D signals, got {preds.ndim}D signals')

        # The model returns all three estimates in a single call.
        stoi_batch, pesq_batch, si_sdr_batch = self._squim_objective_metric_model(preds)

        if self.metric == 'stoi':
            metric_batch = stoi_batch
        elif self.metric == 'pesq':
            metric_batch = pesq_batch
        elif self.metric == 'si_sdr':
            metric_batch = si_sdr_batch
        else:
            # Only reachable if `self.metric` was changed after construction.
            raise ValueError(f'Unknown metric {self.metric}')

        # Accumulate the sum and the count so that `compute` can return the mean.
        self.metric_sum += metric_batch.sum()
        self.num_examples += metric_batch.numel()

    def compute(self) -> torch.Tensor:
        """Return the accumulated metric sum divided by the number of scored examples."""
        return self.metric_sum / self.num_examples

    def state_dict(self, *args, **kwargs):
        """Return the state dict without the objective model and the resampler.

        Every key containing the attribute name of the resampler or of the
        objective model is removed from the dictionary produced by the base class.
        """
        state_dict = super().state_dict(*args, **kwargs)
        # Do not include resampler or objective model in the state dict
        remove_keys = [
            key
            for key in state_dict.keys()
            if '_squim_objective_metric_resampler' in key or '_squim_objective_metric_model' in key
        ]
        for key in remove_keys:
            del state_dict[key]
        return state_dict
Loading

0 comments on commit 1711334

Please sign in to comment.