Skip to content

Commit

Permalink
fix hybrid eval (#7759) (#7760)
Browse files Browse the repository at this point in the history
* fix



* rename



* docs



* warning



* if



---------

Signed-off-by: Nikolay Karpov <[email protected]>
Co-authored-by: Nikolay Karpov <[email protected]>
Co-authored-by: Nikolay Karpov <[email protected]>
  • Loading branch information
3 people authored Oct 19, 2023
1 parent 519a52f commit 493413e
Showing 1 changed file with 8 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,12 @@
# Config Help
To discover all arguments of the script, please run :
python eval_beamsearch_ngram.py --help
python eval_beamsearch_ngram.py --cfg job
python eval_beamsearch_ngram_ctc.py --help
python eval_beamsearch_ngram_ctc.py --cfg job
# USAGE
python eval_beamsearch_ngram.py nemo_model_file=<path to the .nemo file of the model> \
python eval_beamsearch_ngram_ctc.py nemo_model_file=<path to the .nemo file of the model> \
input_manifest=<path to the evaluation JSON manifest file> \
kenlm_model_file=<path to the binary KenLM model> \
beam_width=[<list of the beam widths, separated with commas>] \
Expand Down Expand Up @@ -140,7 +140,10 @@ def beam_search_eval(
level = logging.getEffectiveLevel()
logging.setLevel(logging.CRITICAL)
# Reset config
model.change_decoding_strategy(None)
if isinstance(model, EncDecHybridRNNTCTCModel):
model.change_decoding_strategy(decoding_cfg=None, decoder_type="ctc")
else:
model.change_decoding_strategy(None)

# Override the beam search config with current search candidate configuration
cfg.decoding.beam_size = beam_width
Expand Down Expand Up @@ -257,7 +260,6 @@ def beam_search_eval(

@hydra_runner(config_path=None, config_name='EvalBeamSearchNGramConfig', schema=EvalBeamSearchNGramConfig)
def main(cfg: EvalBeamSearchNGramConfig):
logging.warning("This file will be renamed to eval_beamsearch_ngram_ctc.py in the future NeMo (1.21) release.")
if is_dataclass(cfg):
cfg = OmegaConf.structured(cfg) # type: EvalBeamSearchNGramConfig

Expand Down Expand Up @@ -333,6 +335,7 @@ def default_autocast():

all_probs = all_logits
if cfg.probs_cache_file:
os.makedirs(os.path.split(cfg.probs_cache_file)[0], exist_ok=True)
logging.info(f"Writing pickle files of probabilities at '{cfg.probs_cache_file}'...")
with open(cfg.probs_cache_file, 'wb') as f_dump:
pickle.dump(all_probs, f_dump)
Expand Down

0 comments on commit 493413e

Please sign in to comment.