diff --git a/scripts/asr_language_modeling/ngram_lm/eval_beamsearch_ngram.py b/scripts/asr_language_modeling/ngram_lm/eval_beamsearch_ngram_ctc.py
similarity index 97%
rename from scripts/asr_language_modeling/ngram_lm/eval_beamsearch_ngram.py
rename to scripts/asr_language_modeling/ngram_lm/eval_beamsearch_ngram_ctc.py
index b1cd385f4198..2af8283c7b82 100644
--- a/scripts/asr_language_modeling/ngram_lm/eval_beamsearch_ngram.py
+++ b/scripts/asr_language_modeling/ngram_lm/eval_beamsearch_ngram_ctc.py
@@ -25,12 +25,12 @@
 # Config Help
 
 To discover all arguments of the script, please run :
-python eval_beamsearch_ngram.py --help
-python eval_beamsearch_ngram.py --cfg job
+python eval_beamsearch_ngram_ctc.py --help
+python eval_beamsearch_ngram_ctc.py --cfg job
 
 # USAGE
 
-python eval_beamsearch_ngram.py nemo_model_file=<path to the .nemo file of the model> \
+python eval_beamsearch_ngram_ctc.py nemo_model_file=<path to the .nemo file of the model> \
            input_manifest=<path to the evaluation JSON manifest file> \
            kenlm_model_file=<path to the binary KenLM model> \
            beam_width=[<list of the beam widths, separated with commas>] \
@@ -140,7 +140,10 @@ def beam_search_eval(
     level = logging.getEffectiveLevel()
     logging.setLevel(logging.CRITICAL)
     # Reset config
-    model.change_decoding_strategy(None)
+    if isinstance(model, EncDecHybridRNNTCTCModel):
+        model.change_decoding_strategy(decoding_cfg=None, decoder_type="ctc")
+    else:
+        model.change_decoding_strategy(None)
 
     # Override the beam search config with current search candidate configuration
     cfg.decoding.beam_size = beam_width
@@ -257,7 +260,6 @@ def beam_search_eval(
 
 @hydra_runner(config_path=None, config_name='EvalBeamSearchNGramConfig', schema=EvalBeamSearchNGramConfig)
 def main(cfg: EvalBeamSearchNGramConfig):
-    logging.warning("This file will be renamed to eval_beamsearch_ngram_ctc.py in the future NeMo (1.21) release.")
     if is_dataclass(cfg):
         cfg = OmegaConf.structured(cfg)  # type: EvalBeamSearchNGramConfig
@@ -333,6 +335,7 @@ def default_autocast():
         all_probs = all_logits
 
     if cfg.probs_cache_file:
+        os.makedirs(os.path.split(cfg.probs_cache_file)[0], exist_ok=True)
         logging.info(f"Writing pickle files of probabilities at '{cfg.probs_cache_file}'...")
         with open(cfg.probs_cache_file, 'wb') as f_dump:
             pickle.dump(all_probs, f_dump)