diff --git a/nemo/collections/llm/api.py b/nemo/collections/llm/api.py
index 71e006472db9..a9b3d4361f5b 100644
--- a/nemo/collections/llm/api.py
+++ b/nemo/collections/llm/api.py
@@ -436,7 +436,7 @@ def export_ckpt(
 def generate(
     path: Union[Path, str],
     prompts: list[str],
-    trainer: Optional[nl.Trainer] = None,
+    trainer: nl.Trainer,
     params_dtype: torch.dtype = torch.bfloat16,
     max_batch_size: int = 4,
     random_seed: Optional[int] = None,
diff --git a/nemo/collections/llm/inference/base.py b/nemo/collections/llm/inference/base.py
index 95da536fde06..0171f1c2dd5c 100644
--- a/nemo/collections/llm/inference/base.py
+++ b/nemo/collections/llm/inference/base.py
@@ -16,6 +16,7 @@
 
 import nemo.lightning as nl
 from nemo.lightning import io
+from nemo.lightning.ckpt_utils import ckpt_to_context_subdir
 from nemo.lightning.pytorch.strategies.megatron_strategy import MegatronStrategy
 from nemo.lightning.pytorch.strategies.utils import RestoreConfig
 
@@ -44,6 +45,7 @@ def _setup_trainer_and_restore_model(path: Path, trainer: nl.Trainer, model: pl.
         load_optim_state=False,
     )
     trainer.strategy.restore_config = restore_config
+    trainer.strategy._setup_optimizers = False
     trainer.ckpt_path = None
     trainer.strategy.connect(model)
     if trainer.strategy.launcher is not None:
@@ -61,16 +63,22 @@ def _setup_trainer_and_restore_model(path: Path, trainer: nl.Trainer, model: pl.
 
 def setup_model_and_tokenizer(
     path: Path,
-    trainer: Optional[nl.Trainer] = None,
+    trainer: nl.Trainer,
     params_dtype: torch.dtype = torch.bfloat16,
     inference_batch_times_seqlen_threshold: int = 1000,
 ) -> tuple[MCoreGPTModel, MCoreTokenizerWrappper]:
-    model: io.TrainerContext = io.load_context(path=path, subpath="model")
-    trainer = trainer or io.load_context(path=path, subpath="trainer")
+    model: io.TrainerContext = io.load_context(path=ckpt_to_context_subdir(path), subpath="model")
     _setup_trainer_and_restore_model(path=path, trainer=trainer, model=model)
 
     # This is to get the MCore model required in GPTInferenceWrapper.
-    mcore_model = model.module.module.module
+    mcore_model = model
+    while mcore_model:
+        if type(mcore_model) is MCoreGPTModel:
+            break
+        mcore_model = getattr(mcore_model, "module", None)
+    if mcore_model is None or type(mcore_model) is not MCoreGPTModel:
+        raise ValueError("Exact MCoreGPTModel instance not found in the model structure.")
+
     inference_wrapped_model = GPTInferenceWrapper(
         mcore_model,
         InferenceWrapperConfig(
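
Note on the unwrapping change: the old code reached the inner Megatron-Core model with a hard-coded `model.module.module.module`, which breaks whenever the wrapper nesting depth changes. The new loop instead walks `.module` attributes until it finds the exact `MCoreGPTModel` instance. Below is a minimal, self-contained sketch of the same technique; `TargetModel`, `Float16Wrapper`, and `DDPWrapper` are hypothetical stand-ins, not NeMo/Megatron classes:

```python
class TargetModel:
    """Stand-in for the class we want to locate (e.g. MCoreGPTModel)."""


class Float16Wrapper:
    def __init__(self, module):
        self.module = module


class DDPWrapper:
    def __init__(self, module):
        self.module = module


def unwrap(model, target_cls):
    """Follow `.module` attributes until an exact instance of target_cls is found."""
    current = model
    while current is not None:
        if type(current) is target_cls:
            return current
        # Step one wrapper level down; None terminates the walk.
        current = getattr(current, "module", None)
    raise ValueError(f"Exact {target_cls.__name__} instance not found in the model structure.")


# Works regardless of how many wrapper layers are present.
wrapped = DDPWrapper(Float16Wrapper(TargetModel()))
assert type(unwrap(wrapped, TargetModel)) is TargetModel
```

The exact-type check (`type(...) is` rather than `isinstance`) mirrors the PR: a wrapper that happens to subclass the target would not be returned prematurely and would keep being unwrapped.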
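
For the signature change: `trainer` is no longer optional and is no longer restored from the checkpoint via `io.load_context(path, subpath="trainer")`, so callers must construct and pass one explicitly. A hedged usage sketch; the checkpoint path and parallelism settings are placeholders, not values from this PR:

```python
import nemo.lightning as nl
from nemo.collections import llm

# Caller-constructed trainer; before this change it could be omitted and
# was restored from the checkpoint. Values here are illustrative only.
trainer = nl.Trainer(
    devices=1,
    strategy=nl.MegatronStrategy(tensor_model_parallel_size=1),
)

results = llm.generate(
    path="/path/to/checkpoint",  # placeholder checkpoint path
    prompts=["What is machine learning?"],
    trainer=trainer,  # required argument after this change
    max_batch_size=4,
)
```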