Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix hybrid eval #7759

Merged
merged 5 commits into from
Oct 19, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,12 @@
# Config Help
To discover all arguments of the script, please run :
python eval_beamsearch_ngram.py --help
python eval_beamsearch_ngram.py --cfg job
python eval_beamsearch_ngram_ctc.py --help
python eval_beamsearch_ngram_ctc.py --cfg job
# USAGE
python eval_beamsearch_ngram.py nemo_model_file=<path to the .nemo file of the model> \
python eval_beamsearch_ngram_ctc.py nemo_model_file=<path to the .nemo file of the model> \
input_manifest=<path to the evaluation JSON manifest file> \
kenlm_model_file=<path to the binary KenLM model> \
beam_width=[<list of the beam widths, separated with commas>] \
Expand Down Expand Up @@ -140,7 +140,10 @@ def beam_search_eval(
level = logging.getEffectiveLevel()
logging.setLevel(logging.CRITICAL)
# Reset config
model.change_decoding_strategy(None)
if isinstance(model, EncDecHybridRNNTCTCModel):
model.change_decoding_strategy(decoding_cfg=None, decoder_type="ctc")
else:
model.change_decoding_strategy(None)

# Override the beam search config with current search candidate configuration
cfg.decoding.beam_size = beam_width
Expand Down Expand Up @@ -257,7 +260,6 @@ def beam_search_eval(

@hydra_runner(config_path=None, config_name='EvalBeamSearchNGramConfig', schema=EvalBeamSearchNGramConfig)
def main(cfg: EvalBeamSearchNGramConfig):
logging.warning("This file will be renamed to eval_beamsearch_ngram_ctc.py in the future NeMo (1.21) release.")
if is_dataclass(cfg):
cfg = OmegaConf.structured(cfg) # type: EvalBeamSearchNGramConfig

Expand Down Expand Up @@ -333,6 +335,7 @@ def default_autocast():

all_probs = all_logits
if cfg.probs_cache_file:
os.makedirs(os.path.split(cfg.probs_cache_file)[0], exist_ok=True)
logging.info(f"Writing pickle files of probabilities at '{cfg.probs_cache_file}'...")
with open(cfg.probs_cache_file, 'wb') as f_dump:
pickle.dump(all_probs, f_dump)
Expand Down
Loading