Mengruw/fixed tgt lang v3 (#7866)
* Fixed language ID inference based on David's review

* Reset tgt_language to null

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: Dockerfile builder <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
3 people authored and web-flow committed Nov 9, 2023
1 parent 77d1386 commit fcb9129
Showing 2 changed files with 16 additions and 4 deletions.
@@ -13,4 +13,10 @@ source_lang: null
 target_lang: null
 tensor_model_parallel_size: 1
 pipeline_model_parallel_size: 1
-pipeline_model_parallel_split_rank: 0
+pipeline_model_parallel_split_rank: 0
+
+# setting tgt_language during inference. It can be set as a list:
+# tgt_language:
+#   - en
+#   - de
+tgt_language: null
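
The new tgt_language key defaults to null and, per the comment in the config, also accepts a single code or a list of codes. A minimal sketch of how OmegaConf parses the three forms (the values here are illustrative, not part of the commit):

from omegaconf import OmegaConf

# Illustrative configs mirroring the YAML above: tgt_language may be
# null (unset), a single language code, or a list of codes.
for raw in ("tgt_language: null", "tgt_language: de", "tgt_language: [en, de]"):
    cfg = OmegaConf.create(raw)
    value = cfg.tgt_language
    print(f"{raw!r:38} -> {type(value).__name__}: {value}")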
@@ -24,7 +24,7 @@
 
 import os
 
-from omegaconf.omegaconf import OmegaConf, open_dict
+from omegaconf import OmegaConf, open_dict
 from pytorch_lightning.trainer.trainer import Trainer
 
 from nemo.collections.nlp.models.machine_translation.megatron_nmt_model import MegatronNMTModel
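
The replaced import reached into the internal omegaconf.omegaconf module; both names are also exported from the package root, which the new line uses. A quick sanity check, assuming omegaconf is installed:

# Both symbols resolve from the package root, so the internal
# module path is unnecessary.
from omegaconf import OmegaConf, open_dict

print(OmegaConf.create({"ok": True}).ok)  # prints: True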
@@ -47,7 +47,6 @@
 
 @hydra_runner(config_path="conf", config_name="nmt_megatron_infer")
 def main(cfg) -> None:
-
     # trainer required for restoring model parallel models
     trainer = Trainer(strategy=NLPDDPStrategy(), **cfg.trainer)
     assert (
@@ -75,16 +74,23 @@ def main(cfg) -> None:
     if cfg.model_file is not None:
         if not os.path.exists(cfg.model_file):
             raise ValueError(f"Model file {cfg.model_file} does not exist")
+
+        # getting the model's config and updating tgt_language if needed
         pretrained_cfg = MegatronNMTModel.restore_from(cfg.model_file, trainer=trainer, return_config=True)
+
+        # modifying the config
         OmegaConf.set_struct(pretrained_cfg, True)
         with open_dict(pretrained_cfg):
             pretrained_cfg.precision = trainer.precision
+            if hasattr(cfg, 'tgt_language') and cfg.tgt_language is not None:
+                pretrained_cfg.tgt_language = cfg.tgt_language
+
         model = MegatronNMTModel.restore_from(
             restore_path=cfg.model_file,
             trainer=trainer,
             save_restore_connector=NLPSaveRestoreConnector(),
             override_config_path=pretrained_cfg,
         )
+
     elif cfg.checkpoint_dir is not None:
         checkpoint_path = inject_model_parallel_rank(os.path.join(cfg.checkpoint_dir, cfg.checkpoint_name))
         model = MegatronNMTModel.load_from_checkpoint(checkpoint_path, hparams_file=cfg.hparams_file, trainer=trainer)
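The override pattern above (restore just the config with return_config=True, relax struct mode, write the new key, then restore the model with override_config_path) is the heart of the fix: tgt_language may be absent from an older checkpoint's config, so it can only be written under open_dict. A self-contained sketch of the OmegaConf part, with illustrative values:

from omegaconf import OmegaConf, open_dict
from omegaconf.errors import ConfigAttributeError

# A pretrained config that predates the tgt_language key (illustrative).
pretrained_cfg = OmegaConf.create({"precision": 32})
OmegaConf.set_struct(pretrained_cfg, True)

try:
    pretrained_cfg.tgt_language = "de"  # struct mode rejects unknown keys
except ConfigAttributeError as err:
    print(f"blocked: {err}")

# open_dict temporarily lifts struct mode so the new key can be written.
with open_dict(pretrained_cfg):
    pretrained_cfg.tgt_language = ["en", "de"]

print(OmegaConf.to_yaml(pretrained_cfg))

Because open_dict restores the previous struct setting on exit, the config is write-protected again after the block.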
