From 764652a1e1bf5c133cfb6f8ba2d6f95bef0a04c4 Mon Sep 17 00:00:00 2001 From: Somshubra Majumdar Date: Tue, 3 Oct 2023 13:00:50 -0700 Subject: [PATCH] Update name to ConfidenceMethodConfig Signed-off-by: smajumdar --- .../asr/parts/submodules/ctc_greedy_decoding.py | 4 ++-- .../asr/parts/submodules/rnnt_greedy_decoding.py | 8 ++++---- nemo/collections/asr/parts/utils/asr_confidence_utils.py | 2 +- scripts/confidence_ensembles/build_ensemble.py | 2 +- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py b/nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py index 3a55a00185af..516781b1223e 100644 --- a/nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py +++ b/nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py @@ -253,8 +253,8 @@ class GreedyCTCInferConfig: preserve_alignments: bool = False compute_timestamps: bool = False preserve_frame_confidence: bool = False - confidence_measure_cfg: Optional[ConfidenceMeasureConfig] = field( - default_factory=lambda: ConfidenceMeasureConfig() + confidence_measure_cfg: Optional[ConfidenceMethodConfig] = field( + default_factory=lambda: ConfidenceMethodConfig() ) confidence_method_cfg: str = "DEPRECATED" diff --git a/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py index b6787918750b..090c9d6e4702 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py @@ -2185,8 +2185,8 @@ class GreedyRNNTInferConfig: max_symbols_per_step: Optional[int] = 10 preserve_alignments: bool = False preserve_frame_confidence: bool = False - confidence_measure_cfg: Optional[ConfidenceMeasureConfig] = field( - default_factory=lambda: ConfidenceMeasureConfig() + confidence_measure_cfg: Optional[ConfidenceMethodConfig] = field( + default_factory=lambda: ConfidenceMethodConfig() ) confidence_method_cfg: str = "DEPRECATED" @@ -2204,8 +2204,8 @@ class GreedyBatchedRNNTInferConfig: max_symbols_per_step: Optional[int] = 10 preserve_alignments: bool = False preserve_frame_confidence: bool = False - confidence_measure_cfg: Optional[ConfidenceMeasureConfig] = field( - default_factory=lambda: ConfidenceMeasureConfig() + confidence_measure_cfg: Optional[ConfidenceMethodConfig] = field( + default_factory=lambda: ConfidenceMethodConfig() ) confidence_method_cfg: str = "DEPRECATED" diff --git a/nemo/collections/asr/parts/utils/asr_confidence_utils.py b/nemo/collections/asr/parts/utils/asr_confidence_utils.py index 0c387709ea2b..45e501078b64 100644 --- a/nemo/collections/asr/parts/utils/asr_confidence_utils.py +++ b/nemo/collections/asr/parts/utils/asr_confidence_utils.py @@ -175,7 +175,7 @@ class ConfidenceConfig: preserve_word_confidence: bool = False exclude_blank: bool = True aggregation: str = "min" - measure_cfg: ConfidenceMeasureConfig = field(default_factory=lambda: ConfidenceMeasureConfig()) + measure_cfg: ConfidenceMethodConfig = field(default_factory=lambda: ConfidenceMethodConfig()) method_cfg: str = "DEPRECATED" def __post_init__(self): diff --git a/scripts/confidence_ensembles/build_ensemble.py b/scripts/confidence_ensembles/build_ensemble.py index 8190b6ce2c20..e40997c4aca2 100644 --- a/scripts/confidence_ensembles/build_ensemble.py +++ b/scripts/confidence_ensembles/build_ensemble.py @@ -215,7 +215,7 @@ class BuildEnsembleConfig: preserve_frame_confidence=True, exclude_blank=True, aggregation="mean", - measure_cfg=ConfidenceMeasureConfig(name="entropy", entropy_type="renyi", alpha=0.25, entropy_norm="lin",), + measure_cfg=ConfidenceMethodConfig(name="entropy", entropy_type="renyi", alpha=0.25, entropy_norm="lin",), ) ) temperature: float = 1.0